-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest_images.py
28 lines (25 loc) · 1.04 KB
/
test_images.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import numpy as np
import torch
def test_images(test_X, test_Y, num, cnn, basewidth, loss_func, device):
    """Evaluate ``cnn`` on a test set one sample at a time.

    Args:
        test_X: Iterable of test images; each item is reshaped to
            ``(1, 1, basewidth, basewidth)`` before the forward pass.
        test_Y: Indexable of labels aligned with ``test_X``; ``test_Y[i]`` is
            reshaped to a 1-element tensor before being passed to
            ``loss_func``.
        num: Number of test samples; sizes the prediction array and is the
            denominator of the accuracy.
        cnn: Model called as ``cnn(x)``; the first element of its return
            value is used as the class scores, with classes along dim 1.
        basewidth: Side length of the square input images.
        loss_func: Loss called as ``loss_func(output, target)``.
        device: Device that inputs and targets are moved to.

    Returns:
        ``(accuracy, predictions, total_loss)``: accuracy as a percentage of
        ``num`` (``0`` when ``num`` is 0), a float array of length ``num``
        holding the predicted class index per sample, and the summed loss
        (a tensor, or ``0`` if ``test_X`` is empty).

    NOTE(review): this does not call ``cnn.eval()``; presumably the caller
    switches dropout/batch-norm layers to eval mode before calling.
    """
    correct = 0
    test_prediction = np.zeros(num)
    tot_loss = 0
    # Inference only: without no_grad, every loss summed into tot_loss would
    # keep its autograd graph alive until the function returns.
    with torch.no_grad():
        for step, b_x in enumerate(test_X):
            b_x = b_x.reshape(1, 1, basewidth, basewidth).to(device)
            output = cnn(b_x)[0]  # first element of the model output: class scores
            del b_x
            # Put the target on the same device as the output so loss_func
            # does not fail with a device mismatch.
            b_y = test_Y[step].reshape(1).to(device)
            tot_loss = tot_loss + loss_func(output, b_y)
            pred_y = torch.max(output, 1)[1].item()
            test_prediction[step] = pred_y
            if pred_y == b_y.item():
                correct += 1
    # Computed after the loop so a percentage is returned even when test_X
    # yields fewer than num samples; guarded against num == 0.
    test_accuracy = correct / num * 100 if num > 0 else 0
    return test_accuracy, test_prediction, tot_loss


# This helper is named like a test inside test_images.py, which matches
# pytest's default discovery; opt it out so pytest does not try to run it
# (it would fail looking for fixtures named after its parameters).
test_images.__test__ = False