From 121f72c6604b8c60446916c3c6b0fb1356b44042 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Fri, 10 Jan 2025 06:50:09 -0800 Subject: [PATCH] Add support for refinement with index type --- cpp/include/cuvs/neighbors/refine.hpp | 45 +++++++++++++++++++ cpp/src/neighbors/ivf_flat_index.cpp | 1 + .../detail/refine_device_float_float.cu | 1 + cpp/src/neighbors/refine/refine_device.cuh | 13 +++--- 4 files changed, 54 insertions(+), 6 deletions(-) diff --git a/cpp/include/cuvs/neighbors/refine.hpp b/cpp/include/cuvs/neighbors/refine.hpp index 19fbd30bb..5e60ff537 100644 --- a/cpp/include/cuvs/neighbors/refine.hpp +++ b/cpp/include/cuvs/neighbors/refine.hpp @@ -76,6 +76,51 @@ void refine(raft::resources const& handle, raft::device_matrix_view distances, cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded); +/** + * @brief Refine nearest neighbor search. + * + * Refinement is an operation that follows an approximate NN search. The approximate search has + * already selected n_candidates neighbor candidates for each query. We narrow it down to k + * neighbors. For each query, we calculate the exact distance between the query and its + * n_candidates neighbor candidates, and select the k nearest ones. + * + * The k nearest neighbors and distances are returned. 
+ * + * Example usage + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default index parameters + * ivf_pq::index_params index_params; + * // create and fill the index from a [N, D] dataset + * auto index = ivf_pq::build(handle, index_params, dataset); + * // use default search parameters + * ivf_pq::search_params search_params; + * // search m = 4 * k nearest neighbours for each of the N queries + * ivf_pq::search(handle, search_params, index, queries, neighbor_candidates, + * out_dists_tmp); + * // refine it to the k nearest ones + * refine(handle, dataset, queries, neighbor_candidates, out_indices, out_dists, + * index.metric()); + * @endcode + * + * + * @param[in] handle the raft handle + * @param[in] dataset device matrix that stores the dataset [n_rows, dims] + * @param[in] queries device matrix of the queries [n_queries, dims] + * @param[in] neighbor_candidates indices of candidate vectors [n_queries, n_candidates], where + * n_candidates >= k + * @param[out] indices device matrix that stores the refined indices [n_queries, k] + * @param[out] distances device matrix that stores the refined distances [n_queries, k] + * @param[in] metric distance metric to use. Euclidean (L2) is used by default + */ +void refine(raft::resources const& handle, + raft::device_matrix_view dataset, + raft::device_matrix_view queries, + raft::device_matrix_view neighbor_candidates, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded); + /** * @brief Refine nearest neighbor search. 
* diff --git a/cpp/src/neighbors/ivf_flat_index.cpp b/cpp/src/neighbors/ivf_flat_index.cpp index 6f7d11e50..c16dc47aa 100644 --- a/cpp/src/neighbors/ivf_flat_index.cpp +++ b/cpp/src/neighbors/ivf_flat_index.cpp @@ -226,6 +226,7 @@ void index::check_consistency() "inconsistent number of lists (clusters)"); } +template struct index; // Used for refine function template struct index; template struct index; template struct index; diff --git a/cpp/src/neighbors/refine/detail/refine_device_float_float.cu b/cpp/src/neighbors/refine/detail/refine_device_float_float.cu index 25bad201b..76b792d1c 100644 --- a/cpp/src/neighbors/refine/detail/refine_device_float_float.cu +++ b/cpp/src/neighbors/refine/detail/refine_device_float_float.cu @@ -43,5 +43,6 @@ } instantiate_cuvs_neighbors_refine_d(int64_t, float, float, int64_t); +instantiate_cuvs_neighbors_refine_d(uint32_t, float, float, int64_t); #undef instantiate_cuvs_neighbors_refine_d diff --git a/cpp/src/neighbors/refine/refine_device.cuh b/cpp/src/neighbors/refine/refine_device.cuh index 6184e540b..1ce115449 100644 --- a/cpp/src/neighbors/refine/refine_device.cuh +++ b/cpp/src/neighbors/refine/refine_device.cuh @@ -84,12 +84,13 @@ void refine_device( cuvs::neighbors::ivf_flat::index refinement_index( handle, cuvs::distance::DistanceType(metric), n_queries, false, true, dim); - cuvs::neighbors::ivf_flat::detail::fill_refinement_index(handle, - &refinement_index, - dataset.data_handle(), - neighbor_candidates.data_handle(), - n_queries, - n_candidates); + cuvs::neighbors::ivf_flat::detail::fill_refinement_index( + handle, + &refinement_index, + dataset.data_handle(), + neighbor_candidates.data_handle(), + (idx_t)n_queries, + (uint32_t)n_candidates); uint32_t grid_dim_x = 1; // the neighbor ids will be computed in uint32_t as offset