Skip to content

Commit

Permalink
Merge branch 'branch-25.02' into rhdong/bf-bitset
Browse files Browse the repository at this point in the history
  • Loading branch information
rhdong authored Jan 10, 2025
2 parents 3a5d4e0 + 1e548f8 commit 8a45192
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 59 deletions.
28 changes: 28 additions & 0 deletions cpp/include/cuvs/neighbors/brute_force.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,20 @@ auto build(raft::resources const& handle,
raft::device_matrix_view<const float, int64_t, raft::row_major> dataset)
-> cuvs::neighbors::brute_force::index<float, float>;

/**
* @brief Build the index from the dataset for efficient search.
*
* @param[in] handle
* @param[in] index_params parameters such as the distance metric to use
* @param[in] dataset a host pointer to a row-major matrix [n_rows, dim]
*
* @return the constructed brute-force index
*/
auto build(raft::resources const& handle,
const cuvs::neighbors::brute_force::index_params& index_params,
raft::host_matrix_view<const float, int64_t, raft::row_major> dataset)
-> cuvs::neighbors::brute_force::index<float, float>;

[[deprecated]] auto build(
raft::resources const& handle,
raft::device_matrix_view<const float, int64_t, raft::row_major> dataset,
Expand Down Expand Up @@ -231,6 +245,20 @@ auto build(raft::resources const& handle,
raft::device_matrix_view<const half, int64_t, raft::row_major> dataset)
-> cuvs::neighbors::brute_force::index<half, float>;

/**
* @brief Build the index from the dataset for efficient search.
*
* @param[in] handle
* @param[in] index_params parameters such as the distance metric to use
* @param[in] dataset a host pointer to a row-major matrix [n_rows, dim]
*
* @return the constructed brute-force index
*/
auto build(raft::resources const& handle,
const cuvs::neighbors::brute_force::index_params& index_params,
raft::host_matrix_view<const half, int64_t, raft::row_major> dataset)
-> cuvs::neighbors::brute_force::index<half, float>;

[[deprecated]] auto build(
raft::resources const& handle,
raft::device_matrix_view<const half, int64_t, raft::row_major> dataset,
Expand Down
7 changes: 7 additions & 0 deletions cpp/src/neighbors/brute_force.cu
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,13 @@ void index<T, DistT>::update_dataset(
{ \
return detail::build<T, DistT>(res, dataset, index_params.metric, index_params.metric_arg); \
} \
auto build(raft::resources const& res, \
const cuvs::neighbors::brute_force::index_params& index_params, \
raft::host_matrix_view<const T, int64_t, raft::row_major> dataset) \
->cuvs::neighbors::brute_force::index<T, DistT> \
{ \
return detail::build<T, DistT>(res, dataset, index_params.metric, index_params.metric_arg); \
} \
auto build(raft::resources const& res, \
raft::device_matrix_view<const T, int64_t, raft::row_major> dataset, \
cuvs::distance::DistanceType metric, \
Expand Down
22 changes: 18 additions & 4 deletions cpp/src/neighbors/detail/knn_brute_force.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "./knn_utils.cuh"

#include <raft/core/bitmap.cuh>
#include <raft/core/copy.cuh>
#include <raft/core/device_csr_matrix.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/resource/cuda_stream.hpp>
Expand Down Expand Up @@ -775,10 +776,10 @@ void search(raft::resources const& res,
}
}

template <typename T, typename DistT, typename LayoutT = raft::row_major>
template <typename T, typename DistT, typename AccessorT, typename LayoutT = raft::row_major>
cuvs::neighbors::brute_force::index<T, DistT> build(
raft::resources const& res,
raft::device_matrix_view<const T, int64_t, LayoutT> dataset,
mdspan<const T, matrix_extent<int64_t>, LayoutT, AccessorT> dataset,
cuvs::distance::DistanceType metric,
DistT metric_arg)
{
Expand All @@ -789,18 +790,31 @@ cuvs::neighbors::brute_force::index<T, DistT> build(
if (metric == cuvs::distance::DistanceType::L2Expanded ||
metric == cuvs::distance::DistanceType::L2SqrtExpanded ||
metric == cuvs::distance::DistanceType::CosineExpanded) {
auto dataset_storage = std::optional<device_matrix<T, int64_t, LayoutT>>{};
auto dataset_view = [&res, &dataset_storage, dataset]() {
if constexpr (std::is_same_v<decltype(dataset),
raft::device_matrix_view<const T, int64_t, row_major>>) {
return dataset;
} else {
dataset_storage =
make_device_matrix<T, int64_t, LayoutT>(res, dataset.extent(0), dataset.extent(1));
raft::copy(res, dataset_storage->view(), dataset);
return raft::make_const_mdspan(dataset_storage->view());
}
}();

norms = raft::make_device_vector<DistT, int64_t>(res, dataset.extent(0));
// cosine needs the l2norm, where as l2 distances needs the squared norm
if (metric == cuvs::distance::DistanceType::CosineExpanded) {
raft::linalg::norm(res,
dataset,
dataset_view,
norms->view(),
raft::linalg::NormType::L2Norm,
raft::linalg::Apply::ALONG_ROWS,
raft::sqrt_op{});
} else {
raft::linalg::norm(res,
dataset,
dataset_view,
norms->view(),
raft::linalg::NormType::L2Norm,
raft::linalg::Apply::ALONG_ROWS);
Expand Down
142 changes: 87 additions & 55 deletions cpp/test/neighbors/brute_force.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <cuvs/selection/select_k.hpp>

#include <cuvs/neighbors/brute_force.hpp>
#include <raft/core/host_mdarray.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/linalg/transpose.cuh>
#include <raft/matrix/init.cuh>
Expand Down Expand Up @@ -210,14 +211,15 @@ struct RandomKNNInputs {
int k;
cuvs::distance::DistanceType metric;
bool row_major;
bool host_dataset;
};

std::ostream& operator<<(std::ostream& os, const RandomKNNInputs& input)
{
return os << "num_queries:" << input.num_queries << " num_vecs:" << input.num_db_vecs
<< " dim:" << input.dim << " k:" << input.k
<< " metric:" << cuvs::neighbors::print_metric{input.metric}
<< " row_major:" << input.row_major;
<< " row_major:" << input.row_major << " host_dataset:" << input.host_dataset;
}

template <typename T, typename DistT = T>
Expand Down Expand Up @@ -399,12 +401,15 @@ class RandomBruteForceKNNTest : public ::testing::TestWithParam<RandomKNNInputs>

cuvs::neighbors::brute_force::search_params search_params;

if (params_.row_major) {
auto idx =
cuvs::neighbors::brute_force::build(handle_,
index_params,
raft::make_device_matrix_view<const T, int64_t>(
database.data(), params_.num_db_vecs, params_.dim));
if (params_.host_dataset) {
// test building from a dataset in host memory
auto host_database =
raft::make_host_matrix<T, int64_t, raft::row_major>(params_.num_db_vecs, params_.dim);
raft::copy(
host_database.data_handle(), database.data(), params_.num_db_vecs * params_.dim, stream_);

auto idx = cuvs::neighbors::brute_force::build(
handle_, index_params, raft::make_const_mdspan(host_database.view()));

cuvs::neighbors::brute_force::search(
handle_,
Expand All @@ -416,21 +421,39 @@ class RandomBruteForceKNNTest : public ::testing::TestWithParam<RandomKNNInputs>
distances,
cuvs::neighbors::filtering::none_sample_filter{});
} else {
auto idx = cuvs::neighbors::brute_force::build(
handle_,
index_params,
raft::make_device_matrix_view<const T, int64_t, raft::col_major>(
database.data(), params_.num_db_vecs, params_.dim));
if (params_.row_major) {
auto idx =
cuvs::neighbors::brute_force::build(handle_,
index_params,
raft::make_device_matrix_view<const T, int64_t>(
database.data(), params_.num_db_vecs, params_.dim));

cuvs::neighbors::brute_force::search(
handle_,
search_params,
idx,
raft::make_device_matrix_view<const T, int64_t, raft::col_major>(
search_queries.data(), params_.num_queries, params_.dim),
indices,
distances,
cuvs::neighbors::filtering::none_sample_filter{});
cuvs::neighbors::brute_force::search(
handle_,
search_params,
idx,
raft::make_device_matrix_view<const T, int64_t>(
search_queries.data(), params_.num_queries, params_.dim),
indices,
distances,
cuvs::neighbors::filtering::none_sample_filter{});
} else {
auto idx = cuvs::neighbors::brute_force::build(
handle_,
index_params,
raft::make_device_matrix_view<const T, int64_t, raft::col_major>(
database.data(), params_.num_db_vecs, params_.dim));

cuvs::neighbors::brute_force::search(
handle_,
search_params,
idx,
raft::make_device_matrix_view<const T, int64_t, raft::col_major>(
search_queries.data(), params_.num_queries, params_.dim),
indices,
distances,
cuvs::neighbors::filtering::none_sample_filter{});
}
}

ASSERT_TRUE(cuvs::neighbors::devArrMatchKnnPair(ref_indices_.data(),
Expand Down Expand Up @@ -480,42 +503,51 @@ class RandomBruteForceKNNTest : public ::testing::TestWithParam<RandomKNNInputs>

const std::vector<RandomKNNInputs> random_inputs = {
// test each distance metric on a small-ish input, with row-major inputs
{100, 256, 2, 65, cuvs::distance::DistanceType::L2Expanded, true},
{256, 512, 16, 8, cuvs::distance::DistanceType::L2Unexpanded, true},
{256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtExpanded, true},
{256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtUnexpanded, true},
{256, 512, 16, 8, cuvs::distance::DistanceType::L1, true},
{256, 512, 16, 8, cuvs::distance::DistanceType::Linf, true},
{256, 512, 16, 8, cuvs::distance::DistanceType::InnerProduct, true},
{256, 512, 16, 8, cuvs::distance::DistanceType::CorrelationExpanded, true},
{256, 512, 16, 8, cuvs::distance::DistanceType::CosineExpanded, true},
{256, 512, 16, 8, cuvs::distance::DistanceType::LpUnexpanded, true},
{256, 512, 16, 8, cuvs::distance::DistanceType::JensenShannon, true},
{256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtExpanded, true},
{256, 512, 16, 8, cuvs::distance::DistanceType::Canberra, true},
{100, 256, 2, 65, cuvs::distance::DistanceType::L2Expanded, true, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::L2Unexpanded, true, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtExpanded, true, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtUnexpanded, true, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::L1, true, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::Linf, true, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::InnerProduct, true, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::CorrelationExpanded, true, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::CosineExpanded, true, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::LpUnexpanded, true, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::JensenShannon, true, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtExpanded, true, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::Canberra, true, false},
// test each distance metric with col-major inputs
{256, 512, 16, 7, cuvs::distance::DistanceType::L2Expanded, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::L2Unexpanded, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtExpanded, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtUnexpanded, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::L1, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::Linf, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::InnerProduct, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::CorrelationExpanded, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::CosineExpanded, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::LpUnexpanded, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::JensenShannon, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtExpanded, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::Canberra, false},
{256, 512, 16, 7, cuvs::distance::DistanceType::L2Expanded, false, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::L2Unexpanded, false, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtExpanded, false, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtUnexpanded, false, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::L1, false, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::Linf, false, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::InnerProduct, false, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::CorrelationExpanded, false, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::CosineExpanded, false, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::LpUnexpanded, false, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::JensenShannon, false, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtExpanded, false, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::Canberra, false, false},
// larger tests on different sized data / k values
{10000, 40000, 32, 30, cuvs::distance::DistanceType::L2Expanded, false},
{345, 1023, 16, 128, cuvs::distance::DistanceType::CosineExpanded, true},
{789, 20516, 64, 256, cuvs::distance::DistanceType::L2SqrtExpanded, false},
{1000, 200000, 128, 128, cuvs::distance::DistanceType::L2Expanded, true},
{1000, 200000, 128, 128, cuvs::distance::DistanceType::L2Expanded, false},
{1000, 5000, 128, 128, cuvs::distance::DistanceType::LpUnexpanded, true},
{1000, 5000, 128, 128, cuvs::distance::DistanceType::L2SqrtExpanded, false},
{1000, 5000, 128, 128, cuvs::distance::DistanceType::InnerProduct, false}};
{10000, 40000, 32, 30, cuvs::distance::DistanceType::L2Expanded, false, false},
{345, 1023, 16, 128, cuvs::distance::DistanceType::CosineExpanded, true, false},
{789, 20516, 64, 256, cuvs::distance::DistanceType::L2SqrtExpanded, false, false},
{1000, 200000, 128, 128, cuvs::distance::DistanceType::L2Expanded, true, false},
{1000, 200000, 128, 128, cuvs::distance::DistanceType::L2Expanded, false, false},
{1000, 5000, 128, 128, cuvs::distance::DistanceType::LpUnexpanded, true, false},
{1000, 5000, 128, 128, cuvs::distance::DistanceType::L2SqrtExpanded, false, false},
{1000, 5000, 128, 128, cuvs::distance::DistanceType::InnerProduct, false, false},
// test with datasets on host memory
{256, 512, 16, 8, cuvs::distance::DistanceType::L2Expanded, true, true},
{256, 512, 32, 16, cuvs::distance::DistanceType::L2Unexpanded, true, true},
{256, 512, 8, 8, cuvs::distance::DistanceType::L2SqrtExpanded, true, true},
{256, 128, 32, 8, cuvs::distance::DistanceType::L2SqrtUnexpanded, true, true},
{256, 512, 16, 8, cuvs::distance::DistanceType::L1, true, true},
{256, 512, 16, 8, cuvs::distance::DistanceType::Linf, true, true},
{256, 512, 16, 8, cuvs::distance::DistanceType::InnerProduct, true, true},
{256, 512, 16, 7, cuvs::distance::DistanceType::L2Expanded, true, true}};

typedef RandomBruteForceKNNTest<float, float> RandomBruteForceKNNTestF;
TEST_P(RandomBruteForceKNNTestF, BruteForce) { this->testBruteForce(); }
Expand Down

0 comments on commit 8a45192

Please sign in to comment.