PQ with pytorch
Summary: This diff implements Product Quantization using Pytorch only.

Differential Revision: D67766798
mlomeli1 authored and facebook-github-bot committed Jan 2, 2025
1 parent 0cbc2a8 commit 7ce6f63
Showing 3 changed files with 83 additions and 12 deletions.
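For orientation, a minimal usage sketch of the API this diff adds (a sketch only: the parameter values mirror the test at the bottom, and the import path is the one that test uses):

import torch
from faiss.contrib.torch import quantization

d, M, nbits = 64, 4, 8          # 4 subvectors of 16 dims, 8 bits each
x = torch.rand(2000, d)

pq = quantization.ProductQuantizer(d, M, nbits)
pq.train(x)                     # k-means per subvector, fills pq.codebook
codes = pq.encode(x)            # (2000, 4) uint8 codes
x_rec = pq.decode(codes)        # (2000, 64) reconstruction
err = ((x - x_rec) ** 2).sum()  # total squared reconstruction error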
1 change: 1 addition & 0 deletions contrib/torch/clustering.py
@@ -13,6 +13,7 @@
# the kmeans can produce both torch and numpy centroids
from faiss.contrib.clustering import kmeans


class DatasetAssign:
    """Wrapper for a tensor that offers a function to assign the vectors
    to centroids. All other implementations offer the same interface"""
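For context, a hedged sketch of how this wrapper feeds the pure-Python kmeans re-exported above (the same pattern appears in ProductQuantizer.train below; k and the data shape are arbitrary here):

import torch
from faiss.contrib.torch import clustering

xt = torch.rand(1000, 16)
data = clustering.DatasetAssign(xt)       # wraps the tensor for assignment
centroids = clustering.kmeans(256, data)  # torch centroids out, shape (256, 16)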
61 changes: 53 additions & 8 deletions contrib/torch/quantization.py
@@ -7,33 +7,47 @@
This contrib module contains Pytorch code for quantization.
"""

import numpy as np
import torch
import faiss

from faiss.contrib import torch_utils
import math
from faiss.contrib import clustering
# the kmeans can produce both torch and numpy centroids


class Quantizer:

    def __init__(self, d, code_size):
        """
        d: dimension of vectors
        code_size: nb of bytes of the code (per vector)
        """
        self.d = d
        self.code_size = code_size

    def train(self, x):
        """
        takes an n-by-d array and performs training
        """
        pass

    def encode(self, x):
        """
        takes an n-by-d float array, encodes to an n-by-code_size uint8 array
        """
        pass

    def decode(self, codes):
        """
        takes an n-by-code_size uint8 array, returns an n-by-d array
        """
        pass


class VectorQuantizer(Quantizer):

    def __init__(self, d, k):
        code_size = int(math.ceil(math.log2(k) / 8))
        Quantizer.__init__(self, d, code_size)
        self.k = k

@@ -42,12 +56,43 @@ def train(self, x):


class ProductQuantizer(Quantizer):

    def __init__(self, d, M, nbits):
        """ M: number of subvectors, d % M == 0
        nbits: number of bits that each subvector index is encoded into
        """
        assert d % M == 0
        assert nbits == 8  # todo: implement other nbits values
        code_size = int(math.ceil(M * nbits / 8))
        Quantizer.__init__(self, d, code_size)
        self.M = M
        self.nbits = nbits

    def train(self, x):
        # one (2 ** nbits)-entry codebook of d // M dims per subvector
        self.codebook = torch.zeros(
            (self.M, 2 ** self.nbits, self.d // self.M),
            device=x.device, dtype=x.dtype)
        dsub = self.d // self.M
        for m in range(self.M):
            # run the pure-Python kmeans on the m-th subvector slice
            data = clustering.DatasetAssign(
                x[:, m * dsub: (m + 1) * dsub].contiguous())
            self.codebook[m] = clustering.kmeans(2 ** self.nbits, data)

    def encode(self, x):
        codes = torch.zeros((x.shape[0], self.code_size), dtype=torch.uint8)
        dsub = self.d // self.M
        for m in range(self.M):
            # index of the nearest centroid in the m-th codebook
            _, I = faiss.knn(
                x[:, m * dsub: (m + 1) * dsub].contiguous(),
                self.codebook[m],
                1,
            )
            codes[:, m] = I.ravel()
        return codes

    def decode(self, codes):
        n = codes.shape[0]
        x_rec = torch.zeros(n, self.d,
                            device=self.codebook.device,
                            dtype=self.codebook.dtype)
        dsub = self.d // self.M
        for m in range(self.M):
            # look up each vector's centroid for subvector m
            idx = codes[:, m].long().to(self.codebook.device)
            x_rec[:, m * dsub: (m + 1) * dsub] = self.codebook[m, idx]
        return x_rec
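As a sanity check on the size bookkeeping (a sketch using only quantities defined above): with d=64, M=4, nbits=8, each vector splits into 4 subvectors of 16 dims, and each code occupies ceil(M * nbits / 8) = 4 bytes.

import math

d, M, nbits, n = 64, 4, 8, 2000
code_size = int(math.ceil(M * nbits / 8))  # 4 bytes per encoded vector
codebook_shape = (M, 2 ** nbits, d // M)   # (4, 256, 16), as allocated in train()
codes_shape = (n, code_size)               # (2000, 4) uint8, as returned by encode()
compression = d * 4 / code_size            # 64x smaller than float32 storage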
33 changes: 29 additions & 4 deletions tests/torch_test_contrib.py
@@ -4,13 +4,14 @@
# LICENSE file in the root directory of this source tree.

import torch  # usort: skip
import unittest  # usort: skip
import numpy as np  # usort: skip

import faiss  # usort: skip
import faiss.contrib.torch_utils  # usort: skip
from faiss.contrib import datasets
from faiss.contrib.torch import clustering, quantization




@@ -400,3 +401,27 @@ def test_python_kmeans(self):
        # 33498.332 33380.477
        # print(err, err2) 1/0
        self.assertLess(err2, err * 1.1)


class TestQuantization(unittest.TestCase):
    def test_python_product_quantization(self):
        """ Test the python implementation of product quantization """
        d = 64
        n = 2000
        cs = 4   # number of sub-quantizers for the reference faiss PQ
        nbits = 8
        M = 4
        x = np.random.random(size=(n, d)).astype('float32')
        pq = faiss.ProductQuantizer(d, cs, nbits)
        pq.train(x)
        codes = pq.compute_codes(x)
        x2 = pq.decode(codes)
        diff = ((x - x2) ** 2).sum()
        # vs pure pytorch impl
        xt = torch.from_numpy(x)
        my_pq = quantization.ProductQuantizer(d, M, nbits)
        my_pq.train(xt)
        my_codes = my_pq.encode(xt)
        xt2 = my_pq.decode(my_codes)
        my_diff = ((xt - xt2) ** 2).sum()
        self.assertLess(abs(diff - my_diff), 100)
