forked from barakmich/bbqvec
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathresult.go
102 lines (92 loc) · 1.71 KB
/
result.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
package bbq
import (
"fmt"
"sync"
)
type Result struct {
Similarity float32
ID ID
}
func (r Result) String() string {
return fmt.Sprintf("(%d %0.4f)", r.ID, r.Similarity)
}
type ResultSet struct {
inner sync.Mutex
sims []float32
ids []ID
k int
valid int
}
func NewResultSet(topK int) *ResultSet {
return &ResultSet{
k: topK,
sims: make([]float32, topK),
ids: make([]ID, topK),
valid: 0,
}
}
func (rs *ResultSet) Len() int {
return len(rs.sims)
}
func (rs *ResultSet) ComputeRecall(baseline *ResultSet, at int) float64 {
found := 0
for _, v := range baseline.ids[:at] {
for _, w := range rs.ids[:at] {
if v == w {
found += 1
}
}
}
return float64(found) / float64(at)
}
func (rs *ResultSet) String() string {
return fmt.Sprint(rs.ToSlice())
}
func (rs *ResultSet) AddResult(id ID, sim float32) bool {
// Do a quick check...
if rs.valid == rs.k {
// Bail if the last one beats us
last := rs.sims[len(rs.sims)-1]
if last > sim {
return false
}
}
rs.inner.Lock()
defer rs.inner.Unlock()
insert := 0
found := false
for insert != rs.k {
// If we're building it out, then the new insertion point is at the end.
if rs.valid <= insert {
rs.valid += 1
found = true
break
}
if rs.ids[insert] == id {
return true
}
if rs.sims[insert] < sim {
found = true
break
}
insert++
}
if !found {
return false
}
copy(rs.sims[insert+1:], rs.sims[insert:])
rs.sims[insert] = sim
copy(rs.ids[insert+1:], rs.ids[insert:])
rs.ids[insert] = id
return true
}
func (rs *ResultSet) ToSlice() []*Result {
out := make([]*Result, rs.valid)
for i := range out {
out[i] = &Result{
Similarity: rs.sims[i],
ID: rs.ids[i],
}
}
return out
}