-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathevaluate_baseline.py
133 lines (109 loc) · 4.19 KB
/
evaluate_baseline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#!/usr/bin/env python3
# Author: Armit
# Create Time: 2024/04/27
# 0. 训练一个模型,保存好权重
# 1. 加载预训练好的权重,恢复现场,布置推理环境
# - 预训练权重
# - 测试数据 (测训练时用的所有数据集)
# 2. 过一遍dataloader,计算每个batch的性能指标,记录下来
# 3. 统计整个数据集上的该性能指标
import yaml
from tqdm import tqdm
from train import *
from train import is_clf as is_clf_fn
from metrics import *
from train_baseline import fix_last_layer, LitModel
def _restore_model(args):
    '''
    Restore the training-time hparams, the backbone and the trained weights for `args.load`.

    :param args: CLI namespace; reads `args.load` (checkpoint path) and `args.model`
                 (backbone constructor name, resolved in this module's globals)
    :returns: (hp, dataset_cls, model) -- the loaded hparams dict, the dataset class named by
              hp['dataset'], and the restored model moved to `device` in eval mode
    '''
    # hparams.yaml is expected two directories above the checkpoint file
    # (e.g. <run_dir>/checkpoints/<name>.ckpt -> <run_dir>/hparams.yaml)
    fp = Path(args.load).parent.parent / 'hparams.yaml'
    with open(fp, 'r', encoding='utf-8') as fh:
        # NOTE(security): unsafe_load can construct arbitrary Python objects from the YAML;
        # only point --load at checkpoints you produced or otherwise trust.
        hp = yaml.unsafe_load(fh)
    dataset_cls = get_dataset_cls(hp['dataset'])
    # --load is required on the CLI, so this normally builds the backbone without downloading
    # pretrained weights; the trained weights come from the checkpoint below
    model = globals()[args.model](pretrained=args.load is None)
    # resize the classifier head to the dataset's output dim *before* loading, so the
    # checkpoint's state_dict shapes match
    fix_last_layer(model, HEAD_DIMS[dataset_cls.head.value])
    model = LitModel.load_from_checkpoint(args.load, model=model).model.to(device).eval()
    return hp, dataset_cls, model


def _collect_outputs(model, dataloader, is_clf, is_ldl):
    '''
    Run `model` once over every batch of `dataloader` and gather targets/outputs for the split.

    :returns: dict of numpy arrays concatenated along the batch axis:
              clf  -> 'preds' (argmax of outputs), 'Y_clfs' (class targets)
              ldl  -> additionally 'probs' (softmax of outputs), 'Y_ldls' (raw targets)
              rgr  -> 'outputs' (raw model outputs), 'Y_rgrs' (raw targets)
    :raises RuntimeError: if `dataloader` yields no batches
    '''
    Y_clf_list: List[Tensor] = []
    Y_ldl_list: List[Tensor] = []
    Y_rgr_list: List[Tensor] = []
    pred_list: List[Tensor] = []
    prob_list: List[Tensor] = []
    output_list: List[Tensor] = []
    n_batches = 0
    for X, Y in tqdm(dataloader):
        n_batches += 1
        X, Y = X.to(device), Y.to(device)
        if is_clf:
            # for LDL data the target is compared against softmax probabilities below,
            # so its argmax is used as the hard class label
            Y_clf_list.append(torch.argmax(Y, dim=-1) if is_ldl else Y)
            if is_ldl:
                Y_ldl_list.append(Y)
        else:
            Y_rgr_list.append(Y)

        out = model(X)
        if is_clf:
            pred_list.append(torch.argmax(out, dim=-1))
            if is_ldl:
                prob_list.append(F.softmax(out, dim=-1))
        else:
            output_list.append(out)

    if n_batches == 0:
        # torch.cat would otherwise fail with an unhelpful "non-empty list" error
        raise RuntimeError('dataloader yielded no batches, nothing to evaluate')

    def to_numpy(tensors: List[Tensor]):
        return torch.cat(tensors).cpu().numpy()

    if not is_clf:
        return {'outputs': to_numpy(output_list), 'Y_rgrs': to_numpy(Y_rgr_list)}
    arrays = {'preds': to_numpy(pred_list), 'Y_clfs': to_numpy(Y_clf_list)}
    if is_ldl:
        arrays['probs'] = to_numpy(prob_list)
        arrays['Y_ldls'] = to_numpy(Y_ldl_list)
    return arrays


def _report_clf(preds, Y_clfs):
    ''' Compute and print accuracy / precision / recall / F1 of hard predictions. '''
    # NOTE(review): metrics are called as (pred, target). If metrics.py wraps sklearn-style
    # (y_true, y_pred) functions, precision and recall would come out swapped -- confirm there.
    acc_v = acc (preds, Y_clfs)
    prec_v = prec (preds, Y_clfs)
    recall_v = recall(preds, Y_clfs)
    f1_v = f1 (preds, Y_clfs)
    print(f'>> acc: {acc_v:.3%}')
    print(f'>> prec: {prec_v:.3%}')
    print(f'>> recall: {recall_v:.3%}')
    print(f'>> f1: {f1_v:.3%}')


def _report_ldl(probs, Y_ldls):
    ''' Compute and print distance/similarity metrics between predicted and target distributions. '''
    chebyshev_v = chebyshev_dist (probs, Y_ldls)
    clark_v = clark_dist (probs, Y_ldls)
    canberra_v = canberra_dist (probs, Y_ldls)
    kullback_leibler_v = kullback_leibler_dist(probs, Y_ldls)
    cosine_sim_v = cosine_sim (probs, Y_ldls)
    intersection_sim_v = intersection_sim (probs, Y_ldls)
    print(f'>> chebyshev: {chebyshev_v:.5f}')
    print(f'>> clark: {clark_v:.5f}')
    print(f'>> canberra: {canberra_v:.5f}')
    print(f'>> kl_div: {kullback_leibler_v:.5f}')
    print(f'>> cos_sim: {cosine_sim_v:.5f}')
    print(f'>> intersect_sim: {intersection_sim_v:.5f}')


def _report_rgr(outputs, Y_rgrs):
    ''' Compute and print regression metrics; NOTE: recentres `outputs` in place first. '''
    # FIXME: tmp fix for mean shift, DO NOT know why
    # NOTE(review): this shifts the predictions onto the *ground-truth* mean of the evaluated
    # split, i.e. it uses label statistics at evaluation time. Zeroing the mean error is the
    # best possible constant shift for squared error, so MSE/RMSE can only go down (and R^2 up,
    # assuming standard definitions) -- results are optimistic vs. a true held-out evaluation.
    # Kept (in place, preserving dtype) for comparability with previously reported numbers.
    outputs -= outputs.mean()
    outputs += Y_rgrs.mean()
    mae_v = mae (outputs, Y_rgrs)
    mse_v = mse (outputs, Y_rgrs)
    rmse_v = rmse(outputs, Y_rgrs)
    r2_v = r2 (outputs, Y_rgrs)
    # msle is intentionally not reported (it was disabled in the original script)
    print(f'>> mae: {mae_v:.5f}')
    print(f'>> mse: {mse_v:.5f}')
    print(f'>> rmse: {rmse_v:.5f}')
    print(f'>> r2: {r2_v:.5f}')


@torch.inference_mode()
def run(args):
    '''
    Evaluate a trained baseline checkpoint on one split of its dataset and print the metrics.

    Steps:
      1. restore hparams (from the hparams.yaml next to the checkpoint), backbone and weights
      2. iterate the split once in order (no shuffle, last partial batch kept)
      3. print classification (+ label-distribution) metrics, or regression metrics,
         computed over the whole split

    :param args: CLI namespace with `load`, `model`, `split` and `batch_size`
    :raises RuntimeError: if the split yields no batches
    '''
    # Model & Ckpt
    hp, dataset_cls, model = _restore_model(args)

    # Data: single-process, in-order loading so every sample is seen exactly once
    dataloader_kwargs = {
        'num_workers': 0,
        'persistent_workers': False,
        'pin_memory': False,
    }
    dataloader = DataLoader(dataset_cls(args.split), args.batch_size, shuffle=False, drop_last=False, **dataloader_kwargs)

    # Bookkeep: which kind of head/targets this dataset has
    is_ldl = dataset_cls.is_ldl
    is_clf = is_clf_fn(dataset_cls.head.value)

    # Evaluate
    arrays = _collect_outputs(model, dataloader, is_clf, is_ldl)
    print(f'>> [{hp["dataset"]}]')
    if is_clf:
        _report_clf(arrays['preds'], arrays['Y_clfs'])
        if is_ldl:
            _report_ldl(arrays['probs'], arrays['Y_ldls'])
    else:
        _report_rgr(arrays['outputs'], arrays['Y_rgrs'])
if __name__ == '__main__':
    # CLI entry point: declare the options as (flags, add_argument kwargs) pairs,
    # register them in order, then hand the parsed namespace to run().
    cli = ArgumentParser()
    option_specs = (
        (('-L', '--load'), dict(type=Path, required=True, help='load pretrained weights')),
        (('-M', '--model'), dict(default='resnet50', choices=list(BACKBONE_CLASSES.keys()), help='backbone net')),
        (('--split',), dict(default='valid', choices=['train', 'valid'], help='dataset split')),
        (('-B', '--batch_size'), dict(type=int, default=128)),
    )
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)
    run(cli.parse_args())