Skip to content

Commit

Permalink
Add support for refinement with index type
Browse files Browse the repository at this point in the history
  • Loading branch information
lowener committed Jan 10, 2025
1 parent 2a10353 commit 121f72c
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 6 deletions.
45 changes: 45 additions & 0 deletions cpp/include/cuvs/neighbors/refine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,51 @@ void refine(raft::resources const& handle,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded);

/**
* @brief Refine nearest neighbor search (overload for uint32_t neighbor indices).
*
* Refinement is an operation that follows an approximate NN search. The approximate search has
* already selected n_candidates neighbor candidates for each query. We narrow them down to k
* neighbors. For each query, we calculate the exact distance between the query and each of its
* n_candidates neighbor candidates, and select the k nearest ones.
*
* The k nearest neighbors and distances are returned.
*
* In this overload both the candidate indices and the refined indices are stored as uint32_t.
*
* Example usage
* @code{.cpp}
* using namespace cuvs::neighbors;
* // use default index parameters
* ivf_pq::index_params index_params;
* // create and fill the index from a [N, D] dataset
* auto index = ivf_pq::build(handle, index_params, dataset);
* // use default search parameters
* ivf_pq::search_params search_params;
* // search m = 4 * k nearest neighbours for each of the N queries
* // NOTE(review): this overload needs neighbor_candidates to hold uint32_t indices; confirm
* // that the search used here writes that index type.
* ivf_pq::search(handle, search_params, index, queries, neighbor_candidates,
* out_dists_tmp);
* // refine them to the k nearest ones
* refine(handle, dataset, queries, neighbor_candidates, out_indices, out_dists,
* index.metric());
* @endcode
*
*
* @param[in] handle the raft handle
* @param[in] dataset device matrix that stores the dataset [n_rows, dims]
* @param[in] queries device matrix of the queries [n_queries, dims]
* @param[in] neighbor_candidates uint32_t indices of candidate vectors [n_queries, n_candidates],
* where n_candidates >= k
* @param[out] indices device matrix that stores the refined uint32_t indices [n_queries, k]
* @param[out] distances device matrix that stores the refined distances [n_queries, k]
* @param[in] metric distance metric to use. Euclidean (L2) is used by default
*/
void refine(raft::resources const& handle,
raft::device_matrix_view<const float, int64_t, raft::row_major> dataset,
raft::device_matrix_view<const float, int64_t, raft::row_major> queries,
raft::device_matrix_view<const uint32_t, int64_t, raft::row_major> neighbor_candidates,
raft::device_matrix_view<uint32_t, int64_t, raft::row_major> indices,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded);

/**
* @brief Refine nearest neighbor search.
*
Expand Down
1 change: 1 addition & 0 deletions cpp/src/neighbors/ivf_flat_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ void index<T, IdxT>::check_consistency()
"inconsistent number of lists (clusters)");
}

template struct index<float, uint32_t>; // uint32_t index type; used by the refine function
template struct index<float, int64_t>;
template struct index<half, int64_t>;
template struct index<int8_t, int64_t>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,6 @@
}

// Two instantiations that differ only in the first macro argument (int64_t vs uint32_t).
// NOTE(review): the first argument is presumably the neighbor index type; confirm against the
// macro definition.
instantiate_cuvs_neighbors_refine_d(int64_t, float, float, int64_t);
instantiate_cuvs_neighbors_refine_d(uint32_t, float, float, int64_t);

#undef instantiate_cuvs_neighbors_refine_d
13 changes: 7 additions & 6 deletions cpp/src/neighbors/refine/refine_device.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,13 @@ void refine_device(
cuvs::neighbors::ivf_flat::index<data_t, idx_t> refinement_index(
handle, cuvs::distance::DistanceType(metric), n_queries, false, true, dim);

cuvs::neighbors::ivf_flat::detail::fill_refinement_index(handle,
&refinement_index,
dataset.data_handle(),
neighbor_candidates.data_handle(),
n_queries,
n_candidates);
cuvs::neighbors::ivf_flat::detail::fill_refinement_index<data_t, idx_t>(
handle,
&refinement_index,
dataset.data_handle(),
neighbor_candidates.data_handle(),
(idx_t)n_queries,
(uint32_t)n_candidates);
uint32_t grid_dim_x = 1;

// the neighbor ids will be computed in uint32_t as offset
Expand Down

0 comments on commit 121f72c

Please sign in to comment.