From 2c501649636ed2920c0876ee503d9e852829b2b0 Mon Sep 17 00:00:00 2001 From: Xiao Fu Date: Wed, 22 May 2024 10:57:44 -0700 Subject: [PATCH] Add tutorial for FastScan with refinement for cpp (#3474) Summary: This commit focus on the cpp version of PQfastscan tutorial with index refinement by defining the k factor. Differential Revision: D57680905 --- tutorial/cpp/8-PQFastScanRefine.cpp | 84 +++++++++++++++++++++++++++++ tutorial/cpp/CMakeLists.txt | 3 ++ 2 files changed, 87 insertions(+) create mode 100644 tutorial/cpp/8-PQFastScanRefine.cpp diff --git a/tutorial/cpp/8-PQFastScanRefine.cpp b/tutorial/cpp/8-PQFastScanRefine.cpp new file mode 100644 index 0000000000..2435d94d2c --- /dev/null +++ b/tutorial/cpp/8-PQFastScanRefine.cpp @@ -0,0 +1,84 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +#include +#include + +using idx_t = faiss::idx_t; + +int main() { + int d = 64; // dimension + int nb = 100000; // database size + int nq = 10000; // nb of queries + + std::mt19937 rng; + std::uniform_real_distribution<> distrib; + + float* xb = new float[(int)(d * nb)]; + float* xq = new float[(int)(d * nq)]; + + for (int i = 0; i < nb; i++) { + for (int j = 0; j < d; j++) { + xb[d * i + j] = distrib(rng); + } + xb[d * i] += i / 1000.; + } + + for (int i = 0; i < nq; i++) { + for (int j = 0; j < d; j++) { + xq[d * i + j] = distrib(rng); + } + xq[d * i] += i / 1000.; + } + + int m = 8; + int n_bit = 4; + + faiss::IndexPQFastScan index(d, m, n_bit); + faiss::IndexRefineFlat index_refine(&index); + // refine index after PQFastScan + + printf("Index is trained? %s\n", + index_refine.is_trained ? "true" : "false"); + index_refine.train(nb, xb); + printf("Index is trained? %s\n", + index_refine.is_trained ? "true" : "false"); + index_refine.add(nb, xb); + + int k = 4; + { // search xq + idx_t* I = new idx_t[(int)(k * nq)]; + float* D = new float[(int)(k * nq)]; + float k_factor = 3; + faiss::IndexRefineSearchParameters* params = + new faiss::IndexRefineSearchParameters(); + params->k_factor = k_factor; + index_refine.search(nq, xq, k, D, I, params); + + printf("I=\n"); + for (int i = nq - 5; i < nq; i++) { + for (int j = 0; j < k; j++) { + printf("%5zd ", I[i * k + j]); + } + printf("\n"); + } + + delete[] I; + delete[] D; + delete params; + } + + delete[] xb; + delete[] xq; + + return 0; +} diff --git a/tutorial/cpp/CMakeLists.txt b/tutorial/cpp/CMakeLists.txt index 894fb4168e..ad152c499d 100644 --- a/tutorial/cpp/CMakeLists.txt +++ b/tutorial/cpp/CMakeLists.txt @@ -24,3 +24,6 @@ target_link_libraries(6-HNSW PRIVATE faiss) add_executable(7-PQFastScan EXCLUDE_FROM_ALL 7-PQFastScan.cpp) target_link_libraries(7-PQFastScan PRIVATE faiss) + +add_executable(8-PQFastScanRefine EXCLUDE_FROM_ALL 8-PQFastScanRefine.cpp) +target_link_libraries(8-PQFastScanRefine PRIVATE faiss)