-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathevaluator.py
37 lines (31 loc) · 1.55 KB
/
evaluator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# -*- coding:utf-8 -*-
import math
"""知乎提供的评测方案"""
def score_eval(predict_label_and_marked_label_list):
"""
:param predict_label_and_marked_label_list: 一个元组列表。例如
[ ([1, 2, 3, 4, 5], [4, 5, 6, 7]),
([3, 2, 1, 4, 7], [5, 7, 3])
]
需要注意这里 predict_label 是去重复的,例如 [1,2,3,2,4,1,6],去重后变成[1,2,3,4,6]
marked_label_list 本身没有顺序性,但提交结果有,例如上例的命中情况分别为
[0,0,0,1,1] (4,5命中)
[1,0,0,0,1] (3,7命中)
"""
right_label_num = 0 #总命中标签数量
right_label_at_pos_num = [0, 0, 0, 0, 0] #在各个位置上总命中数量
sample_num = 0 #总问题数量
all_marked_label_num = 0 #总标签数量
for predict_labels, marked_labels in predict_label_and_marked_label_list:
sample_num += 1
marked_label_set = set(marked_labels)
all_marked_label_num += len(marked_label_set)
for pos, label in zip(range(0, min(len(predict_labels), 5)), predict_labels):
if label in marked_label_set: #命中
right_label_num += 1
right_label_at_pos_num[pos] += 1
precision = 0.0
for pos, right_num in zip(range(0, 5), right_label_at_pos_num):
precision += ((right_num / float(sample_num))) / math.log(2.0 + pos) # 下标0-4 映射到 pos1-5 + 1,所以最终+2
recall = float(right_label_num) / all_marked_label_num
return precision, recall, (precision * recall) / (precision + recall )