-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path eval.py
executable file
·36 lines (25 loc) · 1.07 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from utils.utils import extract_dataset
from common.eval import *  # brings in P, model, test_loader, adversary_t, corruption_* names

# Evaluation entry point: dispatch on P.mode to the requested test.
model.eval()

if P.mode == 'test_clean_acc':
    # Standard (clean) test-set accuracy.
    from evals import test_classifier
    test_classifier(P, model, test_loader, 0, logger=None)
elif P.mode == 'test_adv_acc':
    # Accuracy under the adversary configured in common.eval (adversary_t).
    from evals import test_classifier_adv
    test_classifier_adv(P, model, test_loader, 0,
                        adversary=adversary_t, logger=None, ret='adv')
elif P.mode == 'test_auto_attack':
    # Robust accuracy via the standard AutoAttack suite; AutoAttack reports
    # results itself, the returned adversarial batch is kept for reference.
    from autoattack import AutoAttack
    attacker = AutoAttack(model, norm=P.distance, eps=P.epsilon, version='standard')
    inputs, targets = extract_dataset(test_loader)
    adv_inputs = attacker.run_standard_evaluation(inputs, targets)
elif P.mode == 'test_mce':
    # Mean corruption error: average the per-corruption classification error
    # over every corruption type in corruption_list.
    from evals import test_classifier
    total_error = 0.
    for corruption in corruption_list:
        err = test_classifier(P, model, corruption_loader[corruption], 0, logger=None)
        total_error += err
        print(f'Error of {corruption}: {err}%\n')
    print(f'MCE: {total_error/len(corruption_list)} %')
else:
    raise NotImplementedError()