Skip to content

Commit

Permalink
skip half test cases when cuSparse version < 12.0.1
Browse files Browse the repository at this point in the history
  • Loading branch information
rhdong committed Jan 8, 2025
1 parent cbc5d38 commit 3a5d4e0
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions cpp/test/neighbors/brute_force_prefiltered.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <raft/random/rng_state.hpp>
#include <raft/util/popc.cuh>

#include <cusparse.h>
#include <gtest/gtest.h>

#include <cuda_fp16.h>
Expand Down Expand Up @@ -146,6 +147,22 @@ void set_bitmap(const index_t* src,
RAFT_CUDA_TRY(cudaGetLastError());
}

bool isCuSparseVersionGreaterThan_12_0_1()
{
int version;
cusparseHandle_t handle;
cusparseCreate(&handle);
cusparseGetVersion(handle, &version);

int major = version / 1000;
int minor = (version % 1000) / 100;
int patch = version % 100;

cusparseDestroy(handle);

return (major > 12) || (major == 12 && minor > 0) || (major == 12 && minor == 0 && patch >= 2);
}

template <typename value_t, typename dist_t, typename index_t, typename bitmap_t = uint32_t>
class PrefilteredBruteForceOnBitmapTest
: public ::testing::TestWithParam<PrefilteredBruteForceInputs<index_t>> {
Expand Down Expand Up @@ -352,6 +369,9 @@ class PrefilteredBruteForceOnBitmapTest

void SetUp() override
{
if (std::is_same_v<value_t, half> && !isCuSparseVersionGreaterThan_12_0_1()) {
GTEST_SKIP() << "Skipping all tests for half-float as cuSparse doesn't support it.";
}
index_t element =
raft::ceildiv(params.n_queries * params.n_dataset, index_t(sizeof(bitmap_t) * 8));
std::vector<bitmap_t> filter_h(element);
Expand Down Expand Up @@ -776,6 +796,9 @@ class PrefilteredBruteForceOnBitsetTest

void SetUp() override
{
if (std::is_same_v<value_t, half> && !isCuSparseVersionGreaterThan_12_0_1()) {
GTEST_SKIP() << "Skipping all tests for half-float as cuSparse doesn't support it.";
}
index_t element = raft::ceildiv(1 * params.n_dataset, index_t(sizeof(bitset_t) * 8));
std::vector<bitset_t> filter_h(element);
std::vector<bitset_t> filter_repeat_h(element * params.n_queries);
Expand Down

0 comments on commit 3a5d4e0

Please sign in to comment.