From 660a2caa64f864e38e0e7bd19df86556d25aa7db Mon Sep 17 00:00:00 2001 From: Tarang Jain <40517122+tarang-jain@users.noreply.github.com> Date: Thu, 19 Dec 2024 13:29:55 -0800 Subject: [PATCH 1/7] Additional Distances for CAGRA C and Python API (#546) Add InnerProduct metric to CAGRA C and Python API + updates to CAGRA pytests. Closes https://github.com/rapidsai/cuvs/issues/545 Authors: - Tarang Jain (https://github.com/tarang-jain) Approvers: - Divye Gala (https://github.com/divyegala) URL: https://github.com/rapidsai/cuvs/pull/546 --- cpp/include/cuvs/neighbors/cagra.h | 3 +++ cpp/src/neighbors/cagra_c.cpp | 6 ++++-- python/cuvs/cuvs/neighbors/cagra/cagra.pxd | 2 ++ python/cuvs/cuvs/neighbors/cagra/cagra.pyx | 20 ++++++++++++-------- python/cuvs/cuvs/test/test_cagra.py | 8 +++++--- 5 files changed, 26 insertions(+), 13 deletions(-) diff --git a/cpp/include/cuvs/neighbors/cagra.h b/cpp/include/cuvs/neighbors/cagra.h index 14331ebbc..f7f58a19c 100644 --- a/cpp/include/cuvs/neighbors/cagra.h +++ b/cpp/include/cuvs/neighbors/cagra.h @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include #include @@ -87,6 +88,8 @@ typedef struct cuvsCagraCompressionParams* cuvsCagraCompressionParams_t; * */ struct cuvsCagraIndexParams { + /** Distance type. */ + cuvsDistanceType metric; /** Degree of input graph for pruning. */ size_t intermediate_graph_degree; /** Degree of output graph. */ diff --git a/cpp/src/neighbors/cagra_c.cpp b/cpp/src/neighbors/cagra_c.cpp index 326a89665..02b7a566e 100644 --- a/cpp/src/neighbors/cagra_c.cpp +++ b/cpp/src/neighbors/cagra_c.cpp @@ -41,7 +41,8 @@ void* _build(cuvsResources_t res, cuvsCagraIndexParams params, DLManagedTensor* auto res_ptr = reinterpret_cast(res); auto index = new cuvs::neighbors::cagra::index(*res_ptr); - auto index_params = cuvs::neighbors::cagra::index_params(); + auto index_params = cuvs::neighbors::cagra::index_params(); + index_params.metric = static_cast((int)params.metric), index_params.intermediate_graph_degree = params.intermediate_graph_degree; index_params.graph_degree = params.graph_degree; @@ -252,7 +253,8 @@ extern "C" cuvsError_t cuvsCagraSearch(cuvsResources_t res, extern "C" cuvsError_t cuvsCagraIndexParamsCreate(cuvsCagraIndexParams_t* params) { return cuvs::core::translate_exceptions([=] { - *params = new cuvsCagraIndexParams{.intermediate_graph_degree = 128, + *params = new cuvsCagraIndexParams{.metric = L2Expanded, + .intermediate_graph_degree = 128, .graph_degree = 64, .build_algo = IVF_PQ, .nn_descent_niter = 20}; diff --git a/python/cuvs/cuvs/neighbors/cagra/cagra.pxd b/python/cuvs/cuvs/neighbors/cagra/cagra.pxd index bba5a91a8..a0f811480 100644 --- a/python/cuvs/cuvs/neighbors/cagra/cagra.pxd +++ b/python/cuvs/cuvs/neighbors/cagra/cagra.pxd @@ -28,6 +28,7 @@ from libcpp cimport bool from cuvs.common.c_api cimport cuvsError_t, cuvsResources_t from cuvs.common.cydlpack cimport DLDataType, DLManagedTensor +from cuvs.distance_type cimport cuvsDistanceType cdef extern from "cuvs/neighbors/cagra.h" nogil: @@ -47,6 +48,7 @@ cdef extern from "cuvs/neighbors/cagra.h" nogil: ctypedef cuvsCagraCompressionParams* cuvsCagraCompressionParams_t ctypedef struct cuvsCagraIndexParams: + cuvsDistanceType metric size_t intermediate_graph_degree size_t graph_degree cuvsCagraGraphBuildAlgo build_algo diff --git a/python/cuvs/cuvs/neighbors/cagra/cagra.pyx b/python/cuvs/cuvs/neighbors/cagra/cagra.pyx index 752aef741..fd55905cf 100644 --- a/python/cuvs/cuvs/neighbors/cagra/cagra.pyx +++ b/python/cuvs/cuvs/neighbors/cagra/cagra.pyx @@ -28,11 +28,13 @@ from libcpp cimport bool, cast from libcpp.string cimport string from cuvs.common cimport cydlpack +from cuvs.distance_type cimport cuvsDistanceType from pylibraft.common import auto_convert_output, cai_wrapper, device_ndarray from pylibraft.common.cai_wrapper import wrap_array from pylibraft.common.interruptible import cuda_interruptible +from cuvs.distance import DISTANCE_TYPES from cuvs.neighbors.common import _check_input_array from libc.stdint cimport ( @@ -131,9 +133,11 @@ cdef class IndexParams: Parameters ---------- metric : string denoting the metric type, default="sqeuclidean" - Valid values for metric: ["sqeuclidean"], where + Valid values for metric: ["sqeuclidean", "inner_product"], where - sqeuclidean is the euclidean distance without the square root operation, i.e.: distance(a,b) = \\sum_i (a_i - b_i)^2 + - inner_product distance is defined as + distance(a, b) = \\sum_i a_i * b_i. intermediate_graph_degree : int, default = 128 graph_degree : int, default = 64 @@ -151,6 +155,7 @@ cdef class IndexParams: """ cdef cuvsCagraIndexParams* params + cdef object _metric # hold on to a reference to the compression, to keep from being GC'ed cdef public object compression @@ -170,10 +175,8 @@ cdef class IndexParams: nn_descent_niter=20, compression=None): - # todo (dgd): enable once other metrics are present - # and exposed in cuVS C API - # self.params.metric = _get_metric(metric) - # self.params.metric_arg = 0 + self._metric = metric + self.params.metric = DISTANCE_TYPES[metric] self.params.intermediate_graph_degree = intermediate_graph_degree self.params.graph_degree = graph_degree if build_algo == "ivf_pq": @@ -186,9 +189,9 @@ cdef class IndexParams: self.params.compression = \ compression.get_handle() - # @property - # def metric(self): - # return self.params.metric + @property + def metric(self): + return self._metric @property def intermediate_graph_degree(self): @@ -247,6 +250,7 @@ def build(IndexParams index_params, dataset, resources=None): The following distance metrics are supported: - L2 + - InnerProduct Parameters ---------- diff --git a/python/cuvs/cuvs/test/test_cagra.py b/python/cuvs/cuvs/test/test_cagra.py index 56e132c23..d3b03a5d0 100644 --- a/python/cuvs/cuvs/test/test_cagra.py +++ b/python/cuvs/cuvs/test/test_cagra.py @@ -29,7 +29,7 @@ def run_cagra_build_search_test( n_queries=100, k=10, dtype=np.float32, - metric="euclidean", + metric="sqeuclidean", intermediate_graph_degree=128, graph_degree=64, build_algo="ivf_pq", @@ -42,6 +42,8 @@ def run_cagra_build_search_test( ): dataset = generate_data((n_rows, n_cols), dtype) if metric == "inner_product": + if dtype in [np.int8, np.uint8]: + pytest.skip("skip normalization for int8/uint8 data") dataset = normalize(dataset, norm="l2", axis=1) dataset_device = device_ndarray(dataset) @@ -122,7 +124,7 @@ def run_cagra_build_search_test( @pytest.mark.parametrize("dtype", [np.float32, np.int8, np.uint8]) @pytest.mark.parametrize("array_type", ["device", "host"]) @pytest.mark.parametrize("build_algo", ["ivf_pq", "nn_descent"]) -@pytest.mark.parametrize("metric", ["euclidean"]) +@pytest.mark.parametrize("metric", ["sqeuclidean", "inner_product"]) def test_cagra_dataset_dtype_host_device( dtype, array_type, inplace, build_algo, metric ): @@ -145,7 +147,7 @@ def test_cagra_dataset_dtype_host_device( "graph_degree": 32, "add_data_on_build": True, "k": 1, - "metric": "euclidean", + "metric": "sqeuclidean", "build_algo": "ivf_pq", }, { From 89ebf15150223f4bce4a08bf3a6a4089380a1d0a Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Thu, 19 Dec 2024 19:34:09 -0800 Subject: [PATCH 2/7] Use nvidia-sphinx-theme for docs (#528) Authors: - Ben Frederickson (https://github.com/benfred) Approvers: - Corey J. Nolet (https://github.com/cjnolet) - James Lamb (https://github.com/jameslamb) URL: https://github.com/rapidsai/cuvs/pull/528 --- conda/environments/all_cuda-118_arch-aarch64.yaml | 6 ++++-- conda/environments/all_cuda-118_arch-x86_64.yaml | 6 ++++-- conda/environments/all_cuda-125_arch-aarch64.yaml | 6 ++++-- conda/environments/all_cuda-125_arch-x86_64.yaml | 6 ++++-- dependencies.yaml | 8 +++++--- docs/source/conf.py | 2 +- 6 files changed, 22 insertions(+), 12 deletions(-) diff --git a/conda/environments/all_cuda-118_arch-aarch64.yaml b/conda/environments/all_cuda-118_arch-aarch64.yaml index 50aa3fe7e..a6d98ea3b 100644 --- a/conda/environments/all_cuda-118_arch-aarch64.yaml +++ b/conda/environments/all_cuda-118_arch-aarch64.yaml @@ -7,7 +7,6 @@ channels: - conda-forge - nvidia dependencies: -- breathe - c-compiler - clang - clang-tools=16.0.6 @@ -44,7 +43,6 @@ dependencies: - nvcc_linux-aarch64=11.8 - openblas - pre-commit -- pydata-sphinx-theme - pylibraft==25.2.*,>=0.0.0a0 - pytest-cov - pytest==7.* @@ -55,5 +53,9 @@ dependencies: - scikit-learn - sphinx-copybutton - sphinx-markdown-tables +- sphinx>=8.0.0 - sysroot_linux-aarch64==2.17 +- pip: + - breathe>=4.35.0 + - nvidia-sphinx-theme name: all_cuda-118_arch-aarch64 diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index 8f15b6164..1063e4d6c 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -7,7 +7,6 @@ channels: - conda-forge - nvidia dependencies: -- breathe - c-compiler - clang - clang-tools=16.0.6 @@ -44,7 +43,6 @@ dependencies: - nvcc_linux-64=11.8 - openblas - pre-commit -- pydata-sphinx-theme - pylibraft==25.2.*,>=0.0.0a0 - pytest-cov - pytest==7.* @@ -55,5 +53,9 @@ dependencies: - scikit-learn - sphinx-copybutton - sphinx-markdown-tables +- sphinx>=8.0.0 - sysroot_linux-64==2.17 +- pip: + - breathe>=4.35.0 + - nvidia-sphinx-theme name: all_cuda-118_arch-x86_64 diff --git a/conda/environments/all_cuda-125_arch-aarch64.yaml b/conda/environments/all_cuda-125_arch-aarch64.yaml index f194c01a3..ee7b37695 100644 --- a/conda/environments/all_cuda-125_arch-aarch64.yaml +++ b/conda/environments/all_cuda-125_arch-aarch64.yaml @@ -7,7 +7,6 @@ channels: - conda-forge - nvidia dependencies: -- breathe - c-compiler - clang - clang-tools=16.0.6 @@ -40,7 +39,6 @@ dependencies: - numpydoc - openblas - pre-commit -- pydata-sphinx-theme - pylibraft==25.2.*,>=0.0.0a0 - pytest-cov - pytest==7.* @@ -51,5 +49,9 @@ dependencies: - scikit-learn - sphinx-copybutton - sphinx-markdown-tables +- sphinx>=8.0.0 - sysroot_linux-aarch64==2.17 +- pip: + - breathe>=4.35.0 + - nvidia-sphinx-theme name: all_cuda-125_arch-aarch64 diff --git a/conda/environments/all_cuda-125_arch-x86_64.yaml b/conda/environments/all_cuda-125_arch-x86_64.yaml index 912d1629b..7c8e1fd99 100644 --- a/conda/environments/all_cuda-125_arch-x86_64.yaml +++ b/conda/environments/all_cuda-125_arch-x86_64.yaml @@ -7,7 +7,6 @@ channels: - conda-forge - nvidia dependencies: -- breathe - c-compiler - clang - clang-tools=16.0.6 @@ -40,7 +39,6 @@ dependencies: - numpydoc - openblas - pre-commit -- pydata-sphinx-theme - pylibraft==25.2.*,>=0.0.0a0 - pytest-cov - pytest==7.* @@ -51,5 +49,9 @@ dependencies: - scikit-learn - sphinx-copybutton - sphinx-markdown-tables +- sphinx>=8.0.0 - sysroot_linux-64==2.17 +- pip: + - breathe>=4.35.0 + - nvidia-sphinx-theme name: all_cuda-125_arch-x86_64 diff --git a/dependencies.yaml b/dependencies.yaml index eca97d2f5..a73fe7b8f 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -394,22 +394,24 @@ dependencies: common: - output_types: [conda] packages: - - breathe - doxygen>=1.8.20 - graphviz - ipython - numpydoc - - pydata-sphinx-theme - recommonmark + - sphinx>=8.0.0 - sphinx-copybutton - sphinx-markdown-tables + - pip: + - nvidia-sphinx-theme + - breathe>=4.35.0 rust: common: - output_types: [conda] packages: - make - rust - # clang/liblclang only needed for bindgen support + # clang/libclang only needed for bindgen support - clang - libclang build_wheels: diff --git a/docs/source/conf.py b/docs/source/conf.py index 0d667833a..c14919568 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -99,7 +99,7 @@ # a list of builtin themes. # -html_theme = "pydata_sphinx_theme" +html_theme = "nvidia_sphinx_theme" # Theme options are theme-specific and customize the look and feel of a theme From f48e9aab593232b72f74fd79ad256ed51b997b43 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Thu, 19 Dec 2024 19:39:29 -0800 Subject: [PATCH 3/7] Add support for float16 to the python pairwise distance api (#547) Authors: - Ben Frederickson (https://github.com/benfred) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/cuvs/pull/547 --- cpp/src/distance/pairwise_distance_c.cpp | 13 +++++++++---- python/cuvs/cuvs/distance/distance.pyx | 7 +++++-- python/cuvs/cuvs/test/test_distance.py | 13 ++++++++++--- 3 files changed, 24 insertions(+), 9 deletions(-) diff --git a/cpp/src/distance/pairwise_distance_c.cpp b/cpp/src/distance/pairwise_distance_c.cpp index d457198a2..061adaa2c 100644 --- a/cpp/src/distance/pairwise_distance_c.cpp +++ b/cpp/src/distance/pairwise_distance_c.cpp @@ -29,7 +29,7 @@ namespace { -template +template void _pairwise_distance(cuvsResources_t res, DLManagedTensor* x_tensor, DLManagedTensor* y_tensor, @@ -40,7 +40,7 @@ void _pairwise_distance(cuvsResources_t res, auto res_ptr = reinterpret_cast(res); using mdspan_type = raft::device_matrix_view; - using distances_mdspan_type = raft::device_matrix_view; + using distances_mdspan_type = raft::device_matrix_view; auto x_mds = cuvs::core::from_dlpack(x_tensor); auto y_mds = cuvs::core::from_dlpack(y_tensor); @@ -71,9 +71,14 @@ extern "C" cuvsError_t cuvsPairwiseDistance(cuvsResources_t res, } if (x_dt.bits == 32) { - _pairwise_distance(res, x_tensor, y_tensor, distances_tensor, metric, metric_arg); + _pairwise_distance( + res, x_tensor, y_tensor, distances_tensor, metric, metric_arg); + } else if (x_dt.bits == 16) { + _pairwise_distance( + res, x_tensor, y_tensor, distances_tensor, metric, metric_arg); } else if (x_dt.bits == 64) { - _pairwise_distance(res, x_tensor, y_tensor, distances_tensor, metric, metric_arg); + _pairwise_distance( + res, x_tensor, y_tensor, distances_tensor, metric, metric_arg); } else { RAFT_FAIL("Unsupported DLtensor dtype: %d and bits: %d", x_dt.code, x_dt.bits); } diff --git a/python/cuvs/cuvs/distance/distance.pyx b/python/cuvs/cuvs/distance/distance.pyx index eb34366e4..187532bfe 100644 --- a/python/cuvs/cuvs/distance/distance.pyx +++ b/python/cuvs/cuvs/distance/distance.pyx @@ -100,7 +100,10 @@ def pairwise_distance(X, Y, out=None, metric="euclidean", metric_arg=2.0, n = y_cai.shape[0] if out is None: - out = device_ndarray.empty((m, n), dtype=y_cai.dtype) + output_dtype = y_cai.dtype + if np.issubdtype(y_cai.dtype, np.float16): + output_dtype = np.float32 + out = device_ndarray.empty((m, n), dtype=output_dtype) out_cai = wrap_array(out) x_k = x_cai.shape[1] @@ -119,7 +122,7 @@ def pairwise_distance(X, Y, out=None, metric="euclidean", metric_arg=2.0, y_dt = y_cai.dtype d_dt = out_cai.dtype - if x_dt != y_dt or x_dt != d_dt: + if x_dt != y_dt: raise ValueError("Inputs must have the same dtypes") cdef cydlpack.DLManagedTensor* x_dlpack = \ diff --git a/python/cuvs/cuvs/test/test_distance.py b/python/cuvs/cuvs/test/test_distance.py index 681217fc8..f466c2743 100644 --- a/python/cuvs/cuvs/test/test_distance.py +++ b/python/cuvs/cuvs/test/test_distance.py @@ -40,7 +40,7 @@ ], ) @pytest.mark.parametrize("inplace", [True, False]) -@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.float16]) def test_distance(n_rows, n_cols, inplace, metric, dtype): input1 = np.random.random_sample((n_rows, n_cols)) input1 = np.asarray(input1).astype(dtype) @@ -55,7 +55,10 @@ def test_distance(n_rows, n_cols, inplace, metric, dtype): norm = np.sum(input1, axis=1) input1 = (input1.T / norm).T - output = np.zeros((n_rows, n_rows), dtype=dtype) + output_dtype = dtype + if np.issubdtype(dtype, np.float16): + output_dtype = np.float32 + output = np.zeros((n_rows, n_rows), dtype=output_dtype) if metric == "inner_product": expected = np.matmul(input1, input1.T) @@ -76,4 +79,8 @@ def test_distance(n_rows, n_cols, inplace, metric, dtype): actual = output_device.copy_to_host() - assert np.allclose(expected, actual, atol=1e-3, rtol=1e-3) + tol = 1e-3 + if np.issubdtype(dtype, np.float16): + tol = 1e-1 + + assert np.allclose(expected, actual, atol=tol, rtol=tol) From ac49c414254cb448efce02d7a7b08190e43584e8 Mon Sep 17 00:00:00 2001 From: Vyas Ramasubramani Date: Mon, 30 Dec 2024 11:44:25 -0800 Subject: [PATCH 4/7] Check if nightlies have succeeded recently enough (#548) Contributes to https://github.com/rapidsai/build-planning/issues/127 This PR cannot be merged unless nightly CI has passed within the past 7 days, so if it remains unmerged that will itself be an indication that nightly CI needs fixing. Authors: - Vyas Ramasubramani (https://github.com/vyasr) Approvers: - James Lamb (https://github.com/jameslamb) URL: https://github.com/rapidsai/cuvs/pull/548 --- .github/workflows/pr.yaml | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index 4c3b4d06a..91f51bd90 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -12,6 +12,7 @@ concurrency: jobs: pr-builder: needs: + - check-nightly-ci - changed-files - checks - conda-cpp-build @@ -29,6 +30,18 @@ jobs: if: always() with: needs: ${{ toJSON(needs) }} + check-nightly-ci: + # Switch to ubuntu-latest once it defaults to a version of Ubuntu that + # provides at least Python 3.11 (see + # https://docs.python.org/3/library/datetime.html#datetime.date.fromisoformat) + runs-on: ubuntu-24.04 + env: + RAPIDS_GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + steps: + - name: Check if nightly CI is passing + uses: rapidsai/shared-actions/check_nightly_success/dispatch@main + with: + repo: cuvs changed-files: secrets: inherit uses: rapidsai/shared-workflows/.github/workflows/changed-files.yaml@branch-25.02 From a57227310a54b42481e20aaece72d0879f4c5b96 Mon Sep 17 00:00:00 2001 From: Vyas Ramasubramani Date: Mon, 30 Dec 2024 16:09:03 -0800 Subject: [PATCH 5/7] Update for raft logger changes (#540) This PR updates cuvs to use raft's updated logger implementation using [rapids-logger](https://github.com/rapidsai/rapids-logger). It is a breaking change because it changes the kmeans `base_params` verbosity type from an int to a `raft::level_enum`. This PR requires https://github.com/rapidsai/raft/pull/2530. Contributes to https://github.com/rapidsai/build-planning/issues/104 Authors: - Vyas Ramasubramani (https://github.com/vyasr) Approvers: - Bradley Dice (https://github.com/bdice) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/cuvs/pull/540 --- cpp/CMakeLists.txt | 11 ++++--- cpp/bench/ann/CMakeLists.txt | 21 +++++++++---- cpp/bench/ann/src/common/benchmark.hpp | 31 +++++++------------ cpp/include/cuvs/cluster/kmeans.hpp | 3 +- cpp/src/cluster/detail/kmeans.cuh | 12 +++---- cpp/src/cluster/detail/kmeans_auto_find_k.cuh | 4 +-- cpp/src/cluster/detail/kmeans_balanced.cuh | 4 +-- cpp/src/cluster/detail/kmeans_common.cuh | 2 +- .../detail/sparse/coo_spmv_kernel.cuh | 2 ++ cpp/src/neighbors/detail/ann_utils.cuh | 2 +- cpp/src/neighbors/detail/cagra/add_nodes.cuh | 2 -- .../neighbors/detail/cagra/cagra_build.cuh | 4 +-- .../detail/cagra/cagra_serialize.cuh | 4 +-- .../detail/cagra/compute_distance.hpp | 2 +- .../detail/cagra/search_multi_cta.cuh | 2 +- .../cagra/search_multi_cta_kernel-inl.cuh | 2 +- .../detail/cagra/search_multi_kernel.cuh | 2 +- .../detail/cagra/search_single_cta.cuh | 2 +- .../cagra/search_single_cta_kernel-inl.cuh | 3 +- .../neighbors/detail/dataset_serialize.hpp | 2 +- cpp/src/neighbors/detail/dynamic_batching.cuh | 2 -- .../neighbors/detail/vamana/vamana_build.cuh | 4 +-- .../detail/vamana/vamana_serialize.cuh | 2 +- .../detail/vamana/vamana_structs.cuh | 2 +- cpp/src/neighbors/detail/vpq_dataset.cuh | 2 +- cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh | 3 +- .../ivf_flat/ivf_flat_interleaved_scan.cuh | 2 +- .../neighbors/ivf_flat/ivf_flat_search.cuh | 3 +- cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh | 3 +- cpp/src/neighbors/ivf_pq/ivf_pq_fp_8bit.cuh | 2 +- cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh | 2 +- cpp/src/neighbors/ivf_pq/ivf_pq_serialize.cuh | 2 +- cpp/src/neighbors/mg/omp_checks.cpp | 1 - cpp/test/CMakeLists.txt | 4 +-- cpp/test/neighbors/ann_ivf_pq.cuh | 2 -- cpp/test/neighbors/ann_utils.cuh | 2 -- cpp/test/neighbors/brute_force.cu | 2 -- 37 files changed, 71 insertions(+), 86 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 79e50c1c1..26c0b82d3 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -486,13 +486,14 @@ if(BUILD_SHARED_LIBS) "$<$:${CUVS_CUDA_FLAGS}>" ) target_link_libraries( - cuvs_objs PUBLIC raft::raft rmm::rmm rmm::rmm_logger ${CUVS_CTK_MATH_DEPENDENCIES} - $ - PRIVATE rmm::rmm_logger_impl + cuvs_objs + PUBLIC raft::raft rmm::rmm rmm::rmm_logger ${CUVS_CTK_MATH_DEPENDENCIES} + $ + PRIVATE rmm::rmm_logger_impl raft::raft_logger_impl ) add_library(cuvs SHARED $,EXCLUDE,rmm.*logger>) - add_library(cuvs_static STATIC $,EXCLUDE,rmm.*logger>) + add_library(cuvs_static STATIC $,EXCLUDE,rmm.*logger>) target_compile_options( cuvs INTERFACE $<$:--expt-extended-lambda @@ -704,7 +705,7 @@ target_compile_definitions(cuvs::cuvs INTERFACE $<$:NVTX_ENAB target_link_libraries( cuvs_c PUBLIC cuvs::cuvs ${CUVS_CTK_MATH_DEPENDENCIES} - PRIVATE raft::raft rmm::rmm_logger_impl + PRIVATE raft::raft rmm::rmm_logger_impl raft::raft_logger_impl ) # ensure CUDA symbols aren't relocated to the middle of the debug build binaries diff --git a/cpp/bench/ann/CMakeLists.txt b/cpp/bench/ann/CMakeLists.txt index 144cd3048..200b52ab3 100644 --- a/cpp/bench/ann/CMakeLists.txt +++ b/cpp/bench/ann/CMakeLists.txt @@ -126,10 +126,11 @@ function(ConfigureAnnBench) PRIVATE ${ConfigureAnnBench_LINKS} nlohmann_json::nlohmann_json Threads::Threads + $ $<$:CUDA::cudart_static> $ $ - $ + $ ) set_target_properties( @@ -175,9 +176,11 @@ function(ConfigureAnnBench) add_dependencies(CUVS_ANN_BENCH_ALL ${BENCH_NAME}) endfunction() -if(CUVS_FAISS_ENABLE_GPU) - add_library(cuvs_bench_rmm_logger OBJECT) - target_link_libraries(cuvs_bench_rmm_logger PRIVATE rmm::rmm_logger_impl) +if(CUVS_FAISS_ENABLE_GPU OR CUVS_ANN_BENCH_SINGLE_EXE) + add_library(cuvs_bench_logger OBJECT) + target_link_libraries( + cuvs_bench_logger PRIVATE rmm::rmm_logger_impl $ + ) endif() # ################################################################################################## @@ -303,8 +306,14 @@ if(CUVS_ANN_BENCH_SINGLE_EXE) target_link_libraries( ANN_BENCH - PRIVATE raft::raft nlohmann_json::nlohmann_json benchmark::benchmark dl fmt::fmt-header-only - spdlog::spdlog_header_only $<$:CUDA::nvtx3> rmm::rmm_logger_impl + PRIVATE raft::raft + nlohmann_json::nlohmann_json + benchmark::benchmark + dl + fmt::fmt-header-only + spdlog::spdlog_header_only + $<$:CUDA::nvtx3> + cuvs_bench_logger ) set_target_properties( ANN_BENCH diff --git a/cpp/bench/ann/src/common/benchmark.hpp b/cpp/bench/ann/src/common/benchmark.hpp index 06e1e27af..49be78673 100644 --- a/cpp/bench/ann/src/common/benchmark.hpp +++ b/cpp/bench/ann/src/common/benchmark.hpp @@ -597,18 +597,16 @@ inline auto parse_string_flag(const char* arg, const char* pat, std::string& res inline auto run_main(int argc, char** argv) -> int { - bool force_overwrite = false; - bool build_mode = false; - bool search_mode = false; - bool no_lap_sync = false; - std::string data_prefix = "data"; - std::string index_prefix = "index"; - std::string new_override_kv = ""; - std::string mode = "latency"; - std::string threads_arg_txt = ""; - std::vector threads = {1, -1}; // min_thread, max_thread - std::string log_level_str = ""; - [[maybe_unused]] int raft_log_level = 0; // raft::logger::get(RAFT_NAME).get_level(); + bool force_overwrite = false; + bool build_mode = false; + bool search_mode = false; + bool no_lap_sync = false; + std::string data_prefix = "data"; + std::string index_prefix = "index"; + std::string new_override_kv = ""; + std::string mode = "latency"; + std::string threads_arg_txt = ""; + std::vector threads = {1, -1}; // min_thread, max_thread kv_series override_kv{}; char arg0_default[] = "benchmark"; // NOLINT @@ -639,12 +637,7 @@ inline auto run_main(int argc, char** argv) -> int parse_string_flag(argv[i], "--index_prefix", index_prefix) || parse_string_flag(argv[i], "--mode", mode) || parse_string_flag(argv[i], "--override_kv", new_override_kv) || - parse_string_flag(argv[i], "--threads", threads_arg_txt) || - parse_string_flag(argv[i], "--raft_log_level", log_level_str)) { - if (!log_level_str.empty()) { - raft_log_level = std::stoi(log_level_str); - log_level_str = ""; - } + parse_string_flag(argv[i], "--threads", threads_arg_txt)) { if (!threads_arg_txt.empty()) { auto threads_arg = split(threads_arg_txt, ':'); threads[0] = std::stoi(threads_arg[0]); @@ -673,8 +666,6 @@ inline auto run_main(int argc, char** argv) -> int } } - // raft::logger::get(RAFT_NAME).set_level(raft_log_level); - Mode metric_objective = Mode::kLatency; if (mode == "throughput") { metric_objective = Mode::kThroughput; } diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index 89b3acc24..cb8d36b10 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include @@ -85,7 +86,7 @@ struct params : base_params { /** * verbosity level. */ - int verbosity = RAFT_LEVEL_INFO; + raft::level_enum verbosity = raft::level_enum::info; /** * Seed to the random number generator. diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index 3d054f0fd..e943b8afc 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -25,7 +25,7 @@ #include #include #include -#include +#include #include #include #include @@ -56,8 +56,6 @@ namespace cuvs::cluster::kmeans::detail { -// TODO(cjnolet): RAFT_NAME needs to be removed and the raft::logger fixed to not require it -static const std::string RAFT_NAME = "raft"; static const std::string CUVS_NAME = "cuvs"; // ========================================================= @@ -373,7 +371,7 @@ void kmeans_fit_main(raft::resources const& handle, rmm::device_uvector& workspace) { raft::common::nvtx::range fun_scope("kmeans_fit_main"); - raft::logger::get(RAFT_NAME).set_level(params.verbosity); + raft::default_logger().set_level(params.verbosity); cudaStream_t stream = raft::resource::get_cuda_stream(handle); auto n_samples = X.extent(0); auto n_features = X.extent(1); @@ -879,7 +877,7 @@ void kmeans_fit(raft::resources const& handle, pams.n_clusters); } - raft::logger::get(RAFT_NAME).set_level(pams.verbosity); + raft::default_logger().set_level(pams.verbosity); // Allocate memory rmm::device_uvector workspace(0, stream); @@ -1025,7 +1023,7 @@ void kmeans_predict(raft::resources const& handle, RAFT_EXPECTS(centroids.extent(1) == n_features, "invalid parameter (centroids.extent(1) != n_features)"); - raft::logger::get(RAFT_NAME).set_level(pams.verbosity); + raft::default_logger().set_level(pams.verbosity); auto metric = pams.metric; // Allocate memory @@ -1218,7 +1216,7 @@ void kmeans_transform(raft::resources const& handle, raft::device_matrix_view X_new) { raft::common::nvtx::range fun_scope("kmeans_transform"); - raft::logger::get(RAFT_NAME).set_level(pams.verbosity); + raft::default_logger().set_level(pams.verbosity); cudaStream_t stream = raft::resource::get_cuda_stream(handle); auto n_samples = X.extent(0); auto n_features = X.extent(1); diff --git a/cpp/src/cluster/detail/kmeans_auto_find_k.cuh b/cpp/src/cluster/detail/kmeans_auto_find_k.cuh index 6441f7ad5..797b33bca 100644 --- a/cpp/src/cluster/detail/kmeans_auto_find_k.cuh +++ b/cpp/src/cluster/detail/kmeans_auto_find_k.cuh @@ -21,7 +21,7 @@ #include #include #include -#include +#include #include #include #include @@ -230,4 +230,4 @@ void find_k(raft::resources const& handle, n_iter); } } -} // namespace cuvs::cluster::kmeans::detail \ No newline at end of file +} // namespace cuvs::cluster::kmeans::detail diff --git a/cpp/src/cluster/detail/kmeans_balanced.cuh b/cpp/src/cluster/detail/kmeans_balanced.cuh index 3f1ad2334..ba4cabbde 100644 --- a/cpp/src/cluster/detail/kmeans_balanced.cuh +++ b/cpp/src/cluster/detail/kmeans_balanced.cuh @@ -25,7 +25,8 @@ #include #include -#include +#include +#include #include #include #include @@ -59,7 +60,6 @@ namespace cuvs::cluster::kmeans::detail { -static const std::string RAFT_NAME = "raft"; constexpr static inline float kAdjustCentersWeight = 7.0f; /** diff --git a/cpp/src/cluster/detail/kmeans_common.cuh b/cpp/src/cluster/detail/kmeans_common.cuh index eec71b5d2..03db08bd1 100644 --- a/cpp/src/cluster/detail/kmeans_common.cuh +++ b/cpp/src/cluster/detail/kmeans_common.cuh @@ -24,7 +24,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/cpp/src/distance/detail/sparse/coo_spmv_kernel.cuh b/cpp/src/distance/detail/sparse/coo_spmv_kernel.cuh index 1f4b19af4..e44edc68a 100644 --- a/cpp/src/distance/detail/sparse/coo_spmv_kernel.cuh +++ b/cpp/src/distance/detail/sparse/coo_spmv_kernel.cuh @@ -16,6 +16,8 @@ #pragma once +#include + #include #include #include diff --git a/cpp/src/neighbors/detail/ann_utils.cuh b/cpp/src/neighbors/detail/ann_utils.cuh index 529356351..149eea3f1 100644 --- a/cpp/src/neighbors/detail/ann_utils.cuh +++ b/cpp/src/neighbors/detail/ann_utils.cuh @@ -18,7 +18,7 @@ #include #include -#include +#include #include #include #include diff --git a/cpp/src/neighbors/detail/cagra/add_nodes.cuh b/cpp/src/neighbors/detail/cagra/add_nodes.cuh index 952039130..358b7643e 100644 --- a/cpp/src/neighbors/detail/cagra/add_nodes.cuh +++ b/cpp/src/neighbors/detail/cagra/add_nodes.cuh @@ -31,8 +31,6 @@ namespace cuvs::neighbors::cagra { -static const std::string RAFT_NAME = "raft"; - template void add_node_core( raft::resources const& handle, diff --git a/cpp/src/neighbors/detail/cagra/cagra_build.cuh b/cpp/src/neighbors/detail/cagra/cagra_build.cuh index b7fec724b..340986448 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_build.cuh @@ -26,7 +26,7 @@ #include #include #include -#include +#include #include #include @@ -46,8 +46,6 @@ namespace cuvs::neighbors::cagra::detail { -static const std::string RAFT_NAME = "raft"; - template void write_to_graph(raft::host_matrix_view knn_graph, raft::host_matrix_view neighbors_host_view, diff --git a/cpp/src/neighbors/detail/cagra/cagra_serialize.cuh b/cpp/src/neighbors/detail/cagra/cagra_serialize.cuh index 0f6cf852f..c83da7bb1 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_serialize.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_serialize.cuh @@ -18,7 +18,7 @@ #include #include -#include +#include #include #include #include @@ -34,8 +34,6 @@ namespace cuvs::neighbors::cagra::detail { -static const std::string RAFT_NAME = "raft"; - constexpr int serialization_version = 4; /** diff --git a/cpp/src/neighbors/detail/cagra/compute_distance.hpp b/cpp/src/neighbors/detail/cagra/compute_distance.hpp index 7eb798459..2227e4f9e 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance.hpp +++ b/cpp/src/neighbors/detail/cagra/compute_distance.hpp @@ -22,7 +22,7 @@ #include #include #include -#include +#include #include // TODO: This shouldn't be invoking spatial/knn diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh index ecfd856f1..9cb432bcb 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh @@ -26,7 +26,7 @@ #include #include -#include +#include #include #include #include diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh index 9fa9d5894..7535ff217 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh @@ -26,7 +26,7 @@ #include "utils.hpp" #include -#include +#include #include #include #include diff --git a/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh b/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh index c6fe21642..469c80a08 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh @@ -23,7 +23,7 @@ #include "utils.hpp" #include -#include +#include #include #include diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta.cuh index fa71dbaf9..161aa8c4a 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta.cuh @@ -26,7 +26,7 @@ #include "utils.hpp" #include -#include +#include #include #include #include diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh index 678ed0cb4..188862fbb 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh @@ -28,7 +28,7 @@ #include #include -#include +#include #include #include #include @@ -64,7 +64,6 @@ namespace cuvs::neighbors::cagra::detail { namespace single_cta_search { -using raft::RAFT_NAME; // TODO: this is required for RAFT_LOG_XXX messages. // #define _CLK_BREAKDOWN diff --git a/cpp/src/neighbors/detail/dataset_serialize.hpp b/cpp/src/neighbors/detail/dataset_serialize.hpp index 0ecc2cf5d..ba3090b59 100644 --- a/cpp/src/neighbors/detail/dataset_serialize.hpp +++ b/cpp/src/neighbors/detail/dataset_serialize.hpp @@ -21,7 +21,7 @@ #include #include -#include +#include #include diff --git a/cpp/src/neighbors/detail/dynamic_batching.cuh b/cpp/src/neighbors/detail/dynamic_batching.cuh index 5c6b1654e..cb8e08ef5 100644 --- a/cpp/src/neighbors/detail/dynamic_batching.cuh +++ b/cpp/src/neighbors/detail/dynamic_batching.cuh @@ -50,8 +50,6 @@ namespace cuvs::neighbors::dynamic_batching::detail { -using raft::RAFT_NAME; // TODO: a workaround for RAFT_LOG_XXX macros - /** * A helper to make the requester threads more cooperative when busy-spinning. * It is used in the wait loops across this file to reduce the CPU usage. diff --git a/cpp/src/neighbors/detail/vamana/vamana_build.cuh b/cpp/src/neighbors/detail/vamana/vamana_build.cuh index da24decb3..ec75c99c1 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_build.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_build.cuh @@ -29,7 +29,7 @@ #include #include #include -#include +#include #include #include #include @@ -52,8 +52,6 @@ namespace cuvs::neighbors::experimental::vamana::detail { * @{ */ -static const std::string RAFT_NAME = "raft"; - static const int blockD = 32; static const int maxBlocks = 10000; diff --git a/cpp/src/neighbors/detail/vamana/vamana_serialize.cuh b/cpp/src/neighbors/detail/vamana/vamana_serialize.cuh index a554464f6..c360ae19a 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_serialize.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_serialize.cuh @@ -20,7 +20,7 @@ #include #include -#include +#include #include #include #include diff --git a/cpp/src/neighbors/detail/vamana/vamana_structs.cuh b/cpp/src/neighbors/detail/vamana/vamana_structs.cuh index 86cb4e1f8..f6f0279f7 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_structs.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_structs.cuh @@ -29,7 +29,7 @@ #include #include #include -#include +#include #include #include diff --git a/cpp/src/neighbors/detail/vpq_dataset.cuh b/cpp/src/neighbors/detail/vpq_dataset.cuh index d85bad920..0d7882b4b 100644 --- a/cpp/src/neighbors/detail/vpq_dataset.cuh +++ b/cpp/src/neighbors/detail/vpq_dataset.cuh @@ -25,7 +25,7 @@ #include #include -#include +#include #include #include #include diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh index d6ffc1218..f594343c7 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh @@ -27,7 +27,8 @@ #include "../../cluster/kmeans_balanced.cuh" #include "../detail/ann_utils.cuh" #include -#include +#include +#include #include #include #include diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh index f5a4267cd..79b4f1a18 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh @@ -23,7 +23,7 @@ #include "../detail/ann_utils.cuh" #include -#include // RAFT_LOG_TRACE +#include #include #include #include // RAFT_CUDA_TRY diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh index 032b6a8ff..2df6f4f0e 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh @@ -27,7 +27,8 @@ #include // is_min_close, DistanceType #include // cuvs::selection::select_k #include -#include // RAFT_LOG_TRACE +#include +#include #include #include // raft::resources #include // raft::linalg::gemm diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh index 1d4acea1e..44a1b11fa 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh @@ -30,7 +30,7 @@ #include "../../cluster/kmeans_balanced.cuh" #include -#include +#include #include #include #include @@ -68,7 +68,6 @@ #include namespace cuvs::neighbors::ivf_pq::detail { -using raft::RAFT_NAME; // TODO: this is required for RAFT_LOG_XXX messages. using namespace cuvs::spatial::knn::detail; // NOLINT using internal_extents_t = int64_t; // The default mdspan extent type used internally. diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_fp_8bit.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_fp_8bit.cuh index 5b41e5f3d..1b098ac5c 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_fp_8bit.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_fp_8bit.cuh @@ -20,7 +20,7 @@ #include #include -#include +#include #include #include #include diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh index db8f9fbd3..05bb99353 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh @@ -28,7 +28,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_serialize.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_serialize.cuh index 5eaebe69d..4af9dbb8e 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_serialize.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_serialize.cuh @@ -21,7 +21,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/cpp/src/neighbors/mg/omp_checks.cpp b/cpp/src/neighbors/mg/omp_checks.cpp index e09182dfe..c8cc27414 100644 --- a/cpp/src/neighbors/mg/omp_checks.cpp +++ b/cpp/src/neighbors/mg/omp_checks.cpp @@ -18,7 +18,6 @@ #include namespace cuvs::neighbors::mg { -using raft::RAFT_NAME; void check_omp_threads(const int requirements) { diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 4d13daaed..cca061455 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -89,7 +89,7 @@ function(ConfigureTest) endfunction() add_library(test_rmm_logger OBJECT) -target_link_libraries(test_rmm_logger PRIVATE rmm::rmm_logger_impl) +target_link_libraries(test_rmm_logger PRIVATE rmm::rmm_logger_impl raft::raft_logger_impl) # ################################################################################################## # test sources ################################################################################## @@ -236,7 +236,7 @@ if(BUILD_TESTS) NAME SPARSE_TEST PATH sparse/cluster/cluster_solvers.cu sparse/cluster/eigen_solvers.cu sparse/cluster/spectral.cu GPUS 1 PERCENT 100 ) - + ConfigureTest( NAME PREPROCESSING_TEST PATH preprocessing/scalar_quantization.cu GPUS 1 PERCENT 100 ) diff --git a/cpp/test/neighbors/ann_ivf_pq.cuh b/cpp/test/neighbors/ann_ivf_pq.cuh index 3a92b5e3d..01efd804e 100644 --- a/cpp/test/neighbors/ann_ivf_pq.cuh +++ b/cpp/test/neighbors/ann_ivf_pq.cuh @@ -31,8 +31,6 @@ namespace cuvs::neighbors::ivf_pq { -using raft::RAFT_NAME; // For logging - struct test_ivf_sample_filter { static constexpr unsigned offset = 300; }; diff --git a/cpp/test/neighbors/ann_utils.cuh b/cpp/test/neighbors/ann_utils.cuh index 94bccade2..ded8cb5af 100644 --- a/cpp/test/neighbors/ann_utils.cuh +++ b/cpp/test/neighbors/ann_utils.cuh @@ -38,8 +38,6 @@ namespace cuvs::neighbors { -using raft::RAFT_NAME; // For logging - struct print_dtype { cudaDataType_t value; }; diff --git a/cpp/test/neighbors/brute_force.cu b/cpp/test/neighbors/brute_force.cu index 8c354baa9..2cefb1098 100644 --- a/cpp/test/neighbors/brute_force.cu +++ b/cpp/test/neighbors/brute_force.cu @@ -76,11 +76,9 @@ class KNNTest : public ::testing::TestWithParam> { protected: void testBruteForce() { - // #if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_DEBUG) raft::print_device_vector("Input array: ", input_.data(), rows_ * cols_, std::cout); std::cout << "K: " << k_ << std::endl; raft::print_device_vector("Labels array: ", search_labels_.data(), rows_, std::cout); - // #endif auto index = raft::make_device_matrix_view( (const T*)(input_.data()), rows_, cols_); From 55c5a7f0f9c3e103a33264a913dbd17b059eff78 Mon Sep 17 00:00:00 2001 From: Vyas Ramasubramani Date: Mon, 30 Dec 2024 18:48:13 -0800 Subject: [PATCH 6/7] Get Breathe from conda again (#554) As part of https://github.com/rapidsai/cuvs/pull/528 cuvs's doc builds were modified to pull Breathe from pip. That was necessary because the nvidia-sphinx-theme requires Sphinx 8 but [the conda-forge Breathe package was not compatible with that Sphinx version](https://github.com/conda-forge/breathe-feedstock/issues/63). I fixed that in https://github.com/conda-forge/breathe-feedstock/pull/64, so now we can go back to using Breathe from conda to avoid mixing pip and conda for dependency management in the same environment. Authors: - Vyas Ramasubramani (https://github.com/vyasr) Approvers: - Bradley Dice (https://github.com/bdice) URL: https://github.com/rapidsai/cuvs/pull/554 --- conda/environments/all_cuda-118_arch-aarch64.yaml | 2 +- conda/environments/all_cuda-118_arch-x86_64.yaml | 2 +- conda/environments/all_cuda-125_arch-aarch64.yaml | 2 +- conda/environments/all_cuda-125_arch-x86_64.yaml | 2 +- dependencies.yaml | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/conda/environments/all_cuda-118_arch-aarch64.yaml b/conda/environments/all_cuda-118_arch-aarch64.yaml index a6d98ea3b..01853da84 100644 --- a/conda/environments/all_cuda-118_arch-aarch64.yaml +++ b/conda/environments/all_cuda-118_arch-aarch64.yaml @@ -7,6 +7,7 @@ channels: - conda-forge - nvidia dependencies: +- breathe>=4.35.0 - c-compiler - clang - clang-tools=16.0.6 @@ -56,6 +57,5 @@ dependencies: - sphinx>=8.0.0 - sysroot_linux-aarch64==2.17 - pip: - - breathe>=4.35.0 - nvidia-sphinx-theme name: all_cuda-118_arch-aarch64 diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index 1063e4d6c..a1ad68d7f 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -7,6 +7,7 @@ channels: - conda-forge - nvidia dependencies: +- breathe>=4.35.0 - c-compiler - clang - clang-tools=16.0.6 @@ -56,6 +57,5 @@ dependencies: - sphinx>=8.0.0 - sysroot_linux-64==2.17 - pip: - - breathe>=4.35.0 - nvidia-sphinx-theme name: all_cuda-118_arch-x86_64 diff --git a/conda/environments/all_cuda-125_arch-aarch64.yaml b/conda/environments/all_cuda-125_arch-aarch64.yaml index ee7b37695..ee0213fff 100644 --- a/conda/environments/all_cuda-125_arch-aarch64.yaml +++ b/conda/environments/all_cuda-125_arch-aarch64.yaml @@ -7,6 +7,7 @@ channels: - conda-forge - nvidia dependencies: +- breathe>=4.35.0 - c-compiler - clang - clang-tools=16.0.6 @@ -52,6 +53,5 @@ dependencies: - sphinx>=8.0.0 - sysroot_linux-aarch64==2.17 - pip: - - breathe>=4.35.0 - nvidia-sphinx-theme name: all_cuda-125_arch-aarch64 diff --git a/conda/environments/all_cuda-125_arch-x86_64.yaml b/conda/environments/all_cuda-125_arch-x86_64.yaml index 7c8e1fd99..d93dcaf7a 100644 --- a/conda/environments/all_cuda-125_arch-x86_64.yaml +++ b/conda/environments/all_cuda-125_arch-x86_64.yaml @@ -7,6 +7,7 @@ channels: - conda-forge - nvidia dependencies: +- breathe>=4.35.0 - c-compiler - clang - clang-tools=16.0.6 @@ -52,6 +53,5 @@ dependencies: - sphinx>=8.0.0 - sysroot_linux-64==2.17 - pip: - - breathe>=4.35.0 - nvidia-sphinx-theme name: all_cuda-125_arch-x86_64 diff --git a/dependencies.yaml b/dependencies.yaml index a73fe7b8f..a11e59e31 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -394,6 +394,7 @@ dependencies: common: - output_types: [conda] packages: + - breathe>=4.35.0 - doxygen>=1.8.20 - graphviz - ipython @@ -404,7 +405,6 @@ dependencies: - sphinx-markdown-tables - pip: - nvidia-sphinx-theme - - breathe>=4.35.0 rust: common: - output_types: [conda] From 0e735ea025f8e1e24e8e9b3d3f2ac502711f5387 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Tue, 7 Jan 2025 12:32:02 -0600 Subject: [PATCH 7/7] remove setup.cfg files, other packaging cleanup (#544) Similar to https://github.com/rapidsai/raft/pull/2532, this proposes some small packaging cleanup. * removes `setup.cfg` files - *these are currently being ignored by tools, in favor of identical configuration in `pyproject.toml` and `.flake8` files* - e.g. https://github.com/rapidsai/cuvs/blob/b3ce774d39e149d4e34c401068f24136eac44e13/.pre-commit-config.yaml#L31-L35 * alphabetizes dependency lists in `dependencies.yaml` * changes `cupy:` group in `dependencies.yaml` to `depends_on_cupy:` (for consistency with other dependencies) Authors: - James Lamb (https://github.com/jameslamb) Approvers: - Vyas Ramasubramani (https://github.com/vyasr) - Micka (https://github.com/lowener) URL: https://github.com/rapidsai/cuvs/pull/544 --- .pre-commit-config.yaml | 3 +-- dependencies.yaml | 40 +++++++++++++++--------------- pyproject.toml | 4 +-- python/cuvs/setup.cfg | 39 ----------------------------- setup.cfg | 55 ----------------------------------------- 5 files changed, 23 insertions(+), 118 deletions(-) delete mode 100644 python/cuvs/setup.cfg delete mode 100644 setup.cfg diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5e53abd92..fcfc7e1fa 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -108,8 +108,7 @@ repos: [.](cmake|cpp|cu|cuh|h|hpp|sh|pxd|py|pyx|rs)$| CMakeLists[.]txt$| CMakeLists_standalone[.]txt$| - meta[.]yaml$| - setup[.]cfg$ + meta[.]yaml$ exclude: | (?x) docs/source/sphinxext/github_link\.py| diff --git a/dependencies.yaml b/dependencies.yaml index a11e59e31..fbd1d8372 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -7,39 +7,39 @@ files: arch: [x86_64, aarch64] includes: - build - - rapids_build - build_py_cuvs + - build_wheels + - checks - cuda - cuda_version - - depends_on_pylibraft + - depends_on_cupy - depends_on_librmm + - depends_on_pylibraft - develop - - checks - - build_wheels - - test_libcuvs - docs + - rapids_build - run_py_cuvs + - rust + - test_libcuvs - test_python_common - test_py_cuvs - - cupy - - rust bench_ann: output: conda matrix: cuda: ["11.8", "12.5"] arch: [x86_64, aarch64] includes: - - rapids_build + - bench + - bench_python - build_py_cuvs - cuda - cuda_version + - depends_on_cupy - depends_on_pylibraft - depends_on_librmm - develop - - bench - - bench_python + - rapids_build - rapids_build_setuptools - - cupy test_cpp: output: none includes: @@ -49,10 +49,10 @@ files: output: none includes: - cuda_version + - depends_on_cupy - py_version - test_python_common - test_py_cuvs - - cupy checks: output: none includes: @@ -61,19 +61,19 @@ files: docs: output: none includes: + - cuda - cuda_version - - cupy + - depends_on_cupy - docs - py_version - - rust - rapids_build - - cuda + - rust rust: output: none includes: + - cuda - cuda_version - rapids_build - - cuda - rust py_build_cuvs: output: pyproject @@ -89,8 +89,8 @@ files: table: tool.rapids-build-backend key: requires includes: - - rapids_build - build_py_cuvs + - rapids_build py_run_cuvs: output: pyproject pyproject_dir: python/cuvs @@ -98,8 +98,8 @@ files: table: project includes: - cuda_wheels - - run_py_cuvs - depends_on_pylibraft + - run_py_cuvs py_test_cuvs: output: pyproject pyproject_dir: python/cuvs @@ -107,9 +107,9 @@ files: table: project.optional-dependencies key: test includes: + - depends_on_cupy - test_python_common - test_py_cuvs - - cupy py_build_cuvs_bench: output: pyproject pyproject_dir: python/cuvs_bench @@ -368,7 +368,7 @@ dependencies: - nvidia-cusolver - nvidia-cusparse - cupy: + depends_on_cupy: common: - output_types: conda packages: diff --git a/pyproject.toml b/pyproject.toml index fbf4cf41f..417514466 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ force-exclude = ''' # unlike the match option above this match-dir will have no effect when # pydocstyle is invoked from pre-commit. Therefore this exclusion list must # also be maintained in the pre-commit config file. -match-dir = "^(?!(ci|cpp|conda|docs)).*$" +match-dir = "^(?!(ci|cpp|conda|docs|notebooks)).*$" select = "D201, D204, D206, D207, D208, D209, D210, D211, D214, D215, D300, D301, D302, D403, D405, D406, D407, D408, D409, D410, D411, D412, D414, D418" # Would like to enable the following rules in the future: # D200, D202, D205, D400 @@ -42,6 +42,6 @@ follow_imports = "skip" skip = "./.git,./.github,./cpp/build,.*egg-info.*,./.mypy_cache,.*_skbuild" # ignore short words, and typename parameters like OffsetT ignore-regex = "\\b(.{1,4}|[A-Z]\\w*T)\\b" -ignore-words-list = "inout,numer" +ignore-words-list = "inout,unparseable,numer" builtin = "clear" quiet-level = 3 diff --git a/python/cuvs/setup.cfg b/python/cuvs/setup.cfg deleted file mode 100644 index 57b4954bc..000000000 --- a/python/cuvs/setup.cfg +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. - -[isort] -line_length=79 -multi_line_output=3 -include_trailing_comma=True -force_grid_wrap=0 -combine_as_imports=True -order_by_type=True -known_dask= - dask - distributed - dask_cuda -known_rapids= - cuvs - nvtext - cudf - cuml - raft - cugraph - dask_cudf - rmm -known_first_party= - cuvs -default_section=THIRDPARTY -sections=FUTURE,STDLIB,THIRDPARTY,DASK,RAPIDS,FIRSTPARTY,LOCALFOLDER -skip= - thirdparty - .eggs - .git - .hg - .mypy_cache - .tox - .venv - _build - buck-out - build - dist - __init__.py diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index e64641d05..000000000 --- a/setup.cfg +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. - -[flake8] -filename = *.py, *.pyx, *.pxd, *.pxi -exclude = __init__.py, *.egg, build, docs, .git -force-check = True -ignore = - # line break before binary operator - W503, - # whitespace before : - E203 -per-file-ignores = - # Rules ignored only in Cython: - # E211: whitespace before '(' (used in multi-line imports) - # E225: Missing whitespace around operators (breaks cython casting syntax like ) - # E226: Missing whitespace around arithmetic operators (breaks cython pointer syntax like int*) - # E227: Missing whitespace around bitwise or shift operator (Can also break casting syntax) - # E275: Missing whitespace after keyword (Doesn't work with Cython except?) - # E402: invalid syntax (works for Python, not Cython) - # E999: invalid syntax (works for Python, not Cython) - # W504: line break after binary operator (breaks lines that end with a pointer) - *.pyx: E211, E225, E226, E227, E275, E402, E999, W504 - *.pxd: E211, E225, E226, E227, E275, E402, E999, W504 - *.pxi: E211, E225, E226, E227, E275, E402, E999, W504 - -[pydocstyle] -# Due to https://github.com/PyCQA/pydocstyle/issues/363, we must exclude rather -# than include using match-dir. Note that as discussed in -# https://stackoverflow.com/questions/65478393/how-to-filter-directories-using-the-match-dir-flag-for-pydocstyle, -# unlike the match option above this match-dir will have no effect when -# pydocstyle is invoked from pre-commit. Therefore this exclusion list must -# also be maintained in the pre-commit config file. -match-dir = ^(?!(ci|cpp|conda|docs|java|notebooks)).*$ -# Allow missing docstrings for docutils -ignore-decorators = .*(docutils|doc_apply|copy_docstring).* -select = - D201, D204, D206, D207, D208, D209, D210, D211, D214, D215, D300, D301, D302, D403, D405, D406, D407, D408, D409, D410, D411, D412, D414, D418 - # Would like to enable the following rules in the future: - # D200, D202, D205, D400 - -[mypy] -ignore_missing_imports = True -# If we don't specify this, then mypy will check excluded files if -# they are imported by a checked file. -follow_imports = skip - -[codespell] -# note: pre-commit passes explicit lists of files here, which this skip file list doesn't override - -# this is only to allow you to run codespell interactively -skip = ./.git,./.github,./cpp/build,.*egg-info.*,./.mypy_cache,.*_skbuild -# ignore short words, and typename parameters like OffsetT -ignore-regex = \b(.{1,4}|[A-Z]\w*T)\b -ignore-words-list = inout,unparseable,numer -builtin = clear -quiet-level = 3