forked from barakmich/bbqvec
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbackend_quantized_memory.go
107 lines (93 loc) · 2.27 KB
/
backend_quantized_memory.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
package bbq
import (
"errors"
"math/rand"
"time"
)
// QuantizedMemoryBackend is an in-memory backend that stores each vector in
// its lowered form V, as produced by the Quantization Q. Vectors live in a
// slice indexed directly by ID; a nil slot means no vector is stored for that
// ID.
type QuantizedMemoryBackend[V any, Q Quantization[V]] struct {
	// vecs[i] points at the lowered vector stored for ID i, or is nil when
	// nothing has been stored there (e.g. padding added by grow).
	vecs []*V
	// rng is seeded from the clock at construction. It is not read by any of
	// the methods visible here — NOTE(review): confirm whether other code
	// still uses it.
	rng *rand.Rand
	// dim is the exact length every incoming Vector must have.
	dim int
	// quantization lowers full vectors to V (Lower) and scores two V values
	// against each other (Similarity).
	quantization Q
}
// Compile-time proof that the backend implements the interfaces it is used
// through; typed nil pointers are enough for the check.
var (
	_ scannableBackend     = (*QuantizedMemoryBackend[Vector, NoQuantization])(nil)
	_ VectorGetter[Vector] = (*QuantizedMemoryBackend[Vector, NoQuantization])(nil)
)
// NewQuantizedMemoryBackend creates an empty backend whose expected vector
// length is dimensions and which lowers vectors through quantization. The
// backend's random source is seeded from the current time in microseconds.
func NewQuantizedMemoryBackend[V any, Q Quantization[V]](dimensions int, quantization Q) *QuantizedMemoryBackend[V, Q] {
	seed := time.Now().UnixMicro()
	backend := new(QuantizedMemoryBackend[V, Q])
	backend.rng = rand.New(rand.NewSource(seed))
	backend.dim = dimensions
	backend.quantization = quantization
	return backend
}
// Close satisfies the backend contract. All state is held in process memory,
// so there is nothing to release; it always returns nil.
func (*QuantizedMemoryBackend[V, Q]) Close() error { return nil }
// PutVector lowers vector through the backend's quantization and stores the
// result under id, replacing any vector already stored there. The backing
// slice is extended with nil slots as needed so that id becomes a valid index.
//
// It returns an error when len(vector) differs from the backend's dimensions,
// when id cannot be used as a slice index, or when lowering fails.
func (q *QuantizedMemoryBackend[V, Q]) PutVector(id ID, vector Vector) error {
	if len(vector) != q.dim {
		return errors.New("QuantizedMemoryBackend: vector dimension doesn't match")
	}
	idx := int(id)
	// int(id) is negative for a negative ID, or for an unsigned ID too large to
	// fit in an int; indexing q.vecs with it would panic.
	if idx < 0 {
		return errors.New("QuantizedMemoryBackend: id out of range")
	}
	v, err := q.quantization.Lower(vector)
	if err != nil {
		return err
	}
	switch {
	case idx < len(q.vecs):
		q.vecs[idx] = &v
	case idx == len(q.vecs):
		q.vecs = append(q.vecs, &v)
	default:
		// Pad the gap between the current end and idx with nil slots.
		q.grow(idx)
		q.vecs[idx] = &v
	}
	return nil
}
// grow appends nil slots to q.vecs until index `to` is in range, leaving
// len(q.vecs) == to+1. It expects to >= len(q.vecs)-1; a smaller value makes
// the padding count negative and make panics.
func (q *QuantizedMemoryBackend[V, Q]) grow(to int) {
	missing := to + 1 - len(q.vecs)
	q.vecs = append(q.vecs, make([]*V, missing)...)
}
// ComputeSimilarity lowers vector with the backend's quantization and returns
// the quantization's similarity score between the vector stored under
// targetID and the lowered input.
//
// It returns an error when len(vector) differs from the backend's dimensions
// (the same check PutVector applies), when lowering fails, or when no vector
// is stored under targetID (the error from GetVector).
func (q *QuantizedMemoryBackend[V, Q]) ComputeSimilarity(vector Vector, targetID ID) (float32, error) {
	// Every stored vector was length-checked against q.dim on insert. A query
	// of another length would reach Similarity with mismatched operands.
	if len(vector) != q.dim {
		return 0, errors.New("QuantizedMemoryBackend: vector dimension doesn't match")
	}
	v, err := q.quantization.Lower(vector)
	if err != nil {
		return 0, err
	}
	target, err := q.GetVector(targetID)
	if err != nil {
		return 0, err
	}
	return q.quantization.Similarity(target, v), nil
}
// Info describes the backend: its configured dimensionality, and the fact
// that it carries no index data.
func (q *QuantizedMemoryBackend[V, Q]) Info() BackendInfo {
	var info BackendInfo
	info.Dimensions = q.dim
	info.HasIndexData = false
	return info
}
// Exists reports whether a vector is currently stored under id.
func (q *QuantizedMemoryBackend[V, Q]) Exists(id ID) bool {
	i := int(id)
	// int(id) is negative for a negative ID, or for an unsigned ID too large
	// for int. Nothing can be stored at such an index, and indexing with it
	// would panic, so report it as absent.
	if i < 0 || i >= len(q.vecs) {
		return false
	}
	return q.vecs[i] != nil
}
// GetVector returns the lowered vector stored under id. It returns the zero V
// and ErrIDNotFound when id is outside the stored range or its slot is empty.
func (q *QuantizedMemoryBackend[V, Q]) GetVector(id ID) (v V, err error) {
	i := int(id)
	// A negative index (negative ID, or an unsigned ID too large for int)
	// would otherwise panic instead of reporting ErrIDNotFound.
	if i < 0 || i >= len(q.vecs) || q.vecs[i] == nil {
		return v, ErrIDNotFound
	}
	return *q.vecs[i], nil
}
// ForEachVector invokes cb with the ID of every stored vector, in ascending ID
// order, skipping empty slots. Iteration stops at the first error cb returns,
// and that error is returned; otherwise the result is nil.
func (q *QuantizedMemoryBackend[V, Q]) ForEachVector(cb func(ID) error) error {
	// Snapshot the slice header once, matching range's evaluate-once
	// semantics even if cb stores new vectors mid-iteration.
	slots := q.vecs
	for idx := 0; idx < len(slots); idx++ {
		if slots[idx] == nil {
			continue
		}
		if err := cb(ID(idx)); err != nil {
			return err
		}
	}
	return nil
}