From 6bb9600c7e823eaff5e9a4ee4fc4bab2d294fcf5 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Mon, 13 Jan 2025 04:38:01 -0800 Subject: [PATCH] Switch to new default for UMAP Leave data on the host by default and enable batching when using NN descrent. --- python/cuml/cuml/manifold/umap.pyx | 236 +++++++++++++++++----------- python/cuml/cuml/tests/test_umap.py | 33 ++++ 2 files changed, 178 insertions(+), 91 deletions(-) diff --git a/python/cuml/cuml/manifold/umap.pyx b/python/cuml/cuml/manifold/umap.pyx index 079b270d0a..84f62862e4 100644 --- a/python/cuml/cuml/manifold/umap.pyx +++ b/python/cuml/cuml/manifold/umap.pyx @@ -305,6 +305,9 @@ class UMAP(UniversalBase, 'nnd_n_clusters': 1} Note that nnd_n_clusters > 1 will result in batch-building with NN Descent. + .. versionchanged:: 25.06 + The default value for `nnd_n_clusters` will change from 1 to 10 in 25.06. + Notes ----- This module is heavily based on Leland McInnes' reference UMAP package. @@ -428,7 +431,6 @@ class UMAP(UniversalBase, self.validate_hyperparams() - self.sparse_fit = False self._input_hash = None self._small_data = False @@ -448,18 +450,7 @@ class UMAP(UniversalBase, # perform the necessary arithmetic. logger.set_level(logger.level_enum(6 - self._verbose)) - if build_algo == "auto" or build_algo == "brute_force_knn" or build_algo == "nn_descent": - if self.deterministic and build_algo == "auto": - # TODO: for now, users should be able to see the same results as previous version - # (i.e. running brute force knn) when they explicitly pass random_state - # https://github.com/rapidsai/cuml/issues/5985 - logger.info("build_algo set to brute_force_knn because random_state is given") - self.build_algo ="brute_force_knn" - else: - self.build_algo = build_algo - else: - raise Exception("Invalid build algo: {}. Only support auto, brute_force_knn and nn_descent" % build_algo) - + self.build_algo = build_algo self.build_kwds = build_kwds def validate_hyperparams(self): @@ -467,77 +458,89 @@ class UMAP(UniversalBase, if self.min_dist > self.spread: raise ValueError("min_dist should be <= spread") - @staticmethod - def _build_umap_params(cls, sparse): + def _build_umap_params(self, build_algo): IF GPUBUILD == 1: cdef UMAPParams* umap_params = new UMAPParams() - umap_params.n_neighbors = cls.n_neighbors - umap_params.n_components = cls.n_components - umap_params.n_epochs = cls.n_epochs if cls.n_epochs else 0 - umap_params.learning_rate = cls.learning_rate - umap_params.min_dist = cls.min_dist - umap_params.spread = cls.spread - umap_params.set_op_mix_ratio = cls.set_op_mix_ratio - umap_params.local_connectivity = cls.local_connectivity - umap_params.repulsion_strength = cls.repulsion_strength - umap_params.negative_sample_rate = cls.negative_sample_rate - umap_params.transform_queue_size = cls.transform_queue_size - umap_params.verbosity = cls.verbose - umap_params.a = cls.a - umap_params.b = cls.b - if cls.init == "spectral": + umap_params.n_neighbors = self.n_neighbors + umap_params.n_components = self.n_components + umap_params.n_epochs = self.n_epochs if self.n_epochs else 0 + umap_params.learning_rate = self.learning_rate + umap_params.min_dist = self.min_dist + umap_params.spread = self.spread + umap_params.set_op_mix_ratio = self.set_op_mix_ratio + umap_params.local_connectivity = self.local_connectivity + umap_params.repulsion_strength = self.repulsion_strength + umap_params.negative_sample_rate = self.negative_sample_rate + umap_params.transform_queue_size = self.transform_queue_size + umap_params.verbosity = self.verbose + umap_params.a = self.a + umap_params.b = self.b + if self.init == "spectral": umap_params.init = 1 else: # self.init == "random" umap_params.init = 0 - umap_params.target_n_neighbors = cls.target_n_neighbors - if cls.target_metric == "euclidean": + umap_params.target_n_neighbors = self.target_n_neighbors + if self.target_metric == "euclidean": umap_params.target_metric = MetricType.EUCLIDEAN else: # self.target_metric == "categorical" umap_params.target_metric = MetricType.CATEGORICAL - if cls.build_algo == "brute_force_knn": + if build_algo == "brute_force_knn": umap_params.build_algo = graph_build_algo.BRUTE_FORCE_KNN else: # self.init == "nn_descent" umap_params.build_algo = graph_build_algo.NN_DESCENT - if cls.build_kwds is None: + + if self.build_kwds is None: umap_params.nn_descent_params.graph_degree = 64 umap_params.nn_descent_params.intermediate_graph_degree = 128 umap_params.nn_descent_params.max_iterations = 20 umap_params.nn_descent_params.termination_threshold = 0.0001 umap_params.nn_descent_params.return_distances = True - umap_params.nn_descent_params.n_clusters = 1 + warnings.warn( + "The default value of `nnd_n_clusters` will change from 1 to 10 in 25.06.", + FutureWarning + ) + umap_params.nn_descent_params.n_clusters = 10 else: - umap_params.nn_descent_params.graph_degree = cls.build_kwds.get("nnd_graph_degree", 64) - umap_params.nn_descent_params.intermediate_graph_degree = cls.build_kwds.get("nnd_intermediate_graph_degree", 128) - umap_params.nn_descent_params.max_iterations = cls.build_kwds.get("nnd_max_iterations", 20) - umap_params.nn_descent_params.termination_threshold = cls.build_kwds.get("nnd_termination_threshold", 0.0001) - umap_params.nn_descent_params.return_distances = cls.build_kwds.get("nnd_return_distances", True) - if cls.build_kwds.get("nnd_n_clusters", 1) < 1: - logger.info("Negative number of nnd_n_clusters not allowed. Changing nnd_n_clusters to 1") - umap_params.nn_descent_params.n_clusters = cls.build_kwds.get("nnd_n_clusters", 1) - - umap_params.target_weight = cls.target_weight - umap_params.random_state = cls.random_state - umap_params.deterministic = cls.deterministic + n_clusters = self.build_kwds.get("nnd_n_clusters", "warn") + if n_clusters == "warn": + warnings.warn( + "The default value of `nnd_n_clusters` will change from 1 to 10 in 25.06.", + FutureWarning + ) + n_clusters = 10 + umap_params.nn_descent_params.graph_degree = self.build_kwds.get("nnd_graph_degree", 64) + umap_params.nn_descent_params.intermediate_graph_degree = self.build_kwds.get("nnd_intermediate_graph_degree", 128) + umap_params.nn_descent_params.max_iterations = self.build_kwds.get("nnd_max_iterations", 20) + umap_params.nn_descent_params.termination_threshold = self.build_kwds.get("nnd_termination_threshold", 0.0001) + umap_params.nn_descent_params.return_distances = self.build_kwds.get("nnd_return_distances", True) + if self.build_kwds.get("nnd_n_clusters", 1) < 1: + # XXX this is broken/doesn't do what it says?? + logger.info("Negative number of nnd_n_clusters not allowed. Changing nnd_n_clusters to 10") + umap_params.nn_descent_params.n_clusters = n_clusters + + umap_params.target_weight = self.target_weight + umap_params.random_state = self.random_state + umap_params.deterministic = self.deterministic try: - umap_params.metric = metric_parsing[cls.metric.lower()] - if sparse: + umap_params.metric = metric_parsing[self.metric.lower()] + if self.sparse_fit_: if umap_params.metric not in SPARSE_SUPPORTED_METRICS: - raise NotImplementedError(f"Metric '{cls.metric}' not supported for sparse inputs.") + raise NotImplementedError(f"Metric '{self.metric}' not supported for sparse inputs.") elif umap_params.metric not in DENSE_SUPPORTED_METRICS: - raise NotImplementedError(f"Metric '{cls.metric}' not supported for dense inputs.") + raise NotImplementedError(f"Metric '{self.metric}' not supported for dense inputs.") except KeyError: - raise ValueError(f"Invalid value for metric: {cls.metric}") + raise ValueError(f"Invalid value for metric: {self.metric}") - if cls.metric_kwds is None: + if self.metric_kwds is None: umap_params.p = 2.0 else: - umap_params.p = cls.metric_kwds.get('p') + umap_params.p = self.metric_kwds.get('p') cdef uintptr_t callback_ptr = 0 - if cls.callback: - callback_ptr = cls.callback.get_native_callback() + if self.callback: + callback_ptr = self.callback.get_native_callback() umap_params.callback = callback_ptr return umap_params @@ -558,21 +561,31 @@ class UMAP(UniversalBase, skip_parameters_heading=True) @enable_device_interop def fit(self, X, y=None, convert_dtype=True, - knn_graph=None, data_on_host=False) -> "UMAP": + knn_graph=None, data_on_host="warn") -> "UMAP": """ Fit X into an embedded space. Parameters ---------- knn_graph : array / sparse array / tuple, optional (device or host) - Either one of a tuple (indices, distances) of - arrays of shape (n_samples, n_neighbors), a pairwise distances - dense array of shape (n_samples, n_samples) or a KNN graph - sparse array (preferably CSR/COO). This feature allows - the precomputation of the KNN outside of UMAP - and also allows the use of a custom distance function. This function - should match the metric used to train the UMAP embeedings. - Takes precedence over the precomputed_knn parameter. + Either one of a tuple (indices, distances) of + arrays of shape (n_samples, n_neighbors), a pairwise distances + dense array of shape (n_samples, n_samples) or a KNN graph + sparse array (preferably CSR/COO). This feature allows + the precomputation of the KNN outside of UMAP + and also allows the use of a custom distance function. This function + should match the metric used to train the UMAP embeedings. + Takes precedence over the precomputed_knn parameter. + + data_on_host : bool, default=True + Whether to move input data to the host or not. + + With the default value of "auto" the data will be left on the host + when the nn descent algorithm is being used and the data has more + than 50000 rows. + + .. versionchanged:: 25.06 + The default value for `data_on_host` will change from False to "auto" in 25.06. """ if len(X.shape) != 2: raise ValueError("data should be two dimensional") @@ -584,22 +597,38 @@ class UMAP(UniversalBase, # Handle sparse inputs if is_sparse(X): - self._raw_data = SparseCumlArray(X, convert_to_dtype=cupy.float32, convert_format=False) self.n_rows, self.n_dims = self._raw_data.shape - self.sparse_fit = True + self.sparse_fit_ = True self._sparse_data = True + convert_to_mem_type = None if self.build_algo == "nn_descent": raise ValueError("NN Descent does not support sparse inputs") # Handle dense inputs else: + self.sparse_fit_ = False self._sparse_data = False - if data_on_host: - convert_to_mem_type = MemoryType.host + + if data_on_host == "warn": + warnings.warn( + 'The default value of `data_on_host` will change from False to "auto" in 25.06.', + FutureWarning + ) + data_on_host = "auto" + + if data_on_host == "auto": + if self.build_algo != "brute_force_knn" and X.shape[0] > 50000: + convert_to_mem_type = MemoryType.host + else: + convert_to_mem_type = MemoryType.device else: - convert_to_mem_type = MemoryType.device + # `data_on_host` isn't a string so we can use simple boolean conditions + if data_on_host: + convert_to_mem_type = MemoryType.host + else: + convert_to_mem_type = MemoryType.device self._raw_data, self.n_rows, self.n_dims, _ = \ input_to_cuml_array(X, order='C', check_dtype=np.float32, @@ -608,22 +637,37 @@ class UMAP(UniversalBase, else None), convert_to_mem_type=convert_to_mem_type) - if self.build_algo == "auto": - if self.n_rows <= 50000 or self.sparse_fit: - # brute force is faster for small datasets - logger.info("Building knn graph using brute force") - self.build_algo = "brute_force_knn" + if self.build_algo in ("auto", "brute_force_knn", "nn_descent"): + if self.build_algo == "auto": + if self.deterministic: + # TODO: for now, users should be able to see the same results as previous version + # (i.e. running brute force knn) when they explicitly pass random_state + # https://github.com/rapidsai/cuml/issues/5985 + logger.info("build_algo set to brute_force_knn because random_state is given") + self.build_algo_ = "brute_force_knn" + + elif self.n_rows <= 50000 or self.sparse_fit_: + # brute force is faster for small datasets + logger.info("Building knn graph using brute force") + self.build_algo_ = "brute_force_knn" + + else: + logger.info("Building knn graph using nn descent") + self.build_algo_ = "nn_descent" + else: - logger.info("Building knn graph using nn descent") - self.build_algo = "nn_descent" + self.build_algo_ = self.build_algo + + else: + raise Exception("Invalid build algo: {}. Only support auto, brute_force_knn and nn_descent" % self.build_algo) - if self.build_algo == "brute_force_knn" and data_on_host: + if self.build_algo_ == "brute_force_knn" and convert_to_mem_type == MemoryType.host: raise ValueError("Data cannot be on host for building with brute force knn") if self.n_rows <= 1: raise ValueError("There needs to be more than 1 sample to " "build nearest the neighbors graph") - if self.build_algo == "nn_descent" and self.n_rows < 150: + if self.build_algo_ == "nn_descent" and self.n_rows < 150: # https://github.com/rapidsai/cuvs/issues/184 warnings.warn("using nn_descent as build_algo on a small dataset (< 150 samples) is unstable") @@ -636,7 +680,7 @@ class UMAP(UniversalBase, elif self.precomputed_knn is not None: knn_indices, knn_dists = self.precomputed_knn - if self.sparse_fit: + if self.sparse_fit_: knn_indices, _, _, _ = \ input_to_cuml_array(knn_indices, convert_to_dtype=np.int32) @@ -672,9 +716,8 @@ class UMAP(UniversalBase, self.handle.getHandle() fss_graph = GraphHolder.new_graph(handle_.get_stream()) cdef UMAPParams* umap_params = \ - UMAP._build_umap_params(self, - self.sparse_fit) - if self.sparse_fit: + self._build_umap_params(self.build_algo_) + if self.sparse_fit_: fit_sparse(handle_[0], self._raw_data.indptr.ptr, self._raw_data.indices.ptr, @@ -720,7 +763,7 @@ class UMAP(UniversalBase, @cuml.internals.api_base_fit_transform() @enable_device_interop def fit_transform(self, X, y=None, convert_dtype=True, - knn_graph=None, data_on_host=False) -> CumlArray: + knn_graph=None, data_on_host="warn") -> CumlArray: """ Fit X into an embedded space and return that transformed output. @@ -752,6 +795,16 @@ class UMAP(UniversalBase, Acceptable formats: sparse SciPy ndarray, CuPy device ndarray, CSR/COO preferred other formats will go through conversion to CSR + data_on_host : bool, default="auto" + Whether to move input data to the host or not. + + With the default value of "auto" the data will be left on the host + when the nn descent algorithm is being used and the data has more + than 50000 rows. + + .. versionchanged:: 25.06 + The default value for `data_on_host` will change from False to "auto" in 25.06. + """ self.fit(X, y, convert_dtype=convert_dtype, knn_graph=knn_graph, data_on_host=data_on_host) @@ -782,13 +835,13 @@ class UMAP(UniversalBase, if len(X.shape) != 2: raise ValueError("X should be two dimensional") - if is_sparse(X) and not self.sparse_fit: + if is_sparse(X) and not self.sparse_fit_: logger.warn("Model was trained on dense data but sparse " "data was provided to transform(). Converting " "to dense.") X = X.todense() - elif not is_sparse(X) and self.sparse_fit: + elif not is_sparse(X) and self.sparse_fit_: logger.warn("Model was trained on sparse data but dense " "data was provided to transform(). Converting " "to sparse.") @@ -825,17 +878,18 @@ class UMAP(UniversalBase, cdef uintptr_t _embed_ptr = self.embedding_.ptr # NN Descent doesn't support transform yet - if self.build_algo == "nn_descent" or self.build_algo == "auto": - self.build_algo = "brute_force_knn" + if self.build_algo_ == "nn_descent" or self.build_algo_ == "auto": + build_algo = "brute_force_knn" logger.info("Transform can only be run with brute force. Using brute force.") + else: + build_algo = self.build_algo_ IF GPUBUILD == 1: cdef UMAPParams* umap_params = \ - UMAP._build_umap_params(self, - self.sparse_fit) + self._build_umap_params(build_algo) cdef handle_t * handle_ = \ self.handle.getHandle() - if self.sparse_fit: + if self.sparse_fit_: transform_sparse(handle_[0], X_m.indptr.ptr, X_m.indices.ptr, diff --git a/python/cuml/cuml/tests/test_umap.py b/python/cuml/cuml/tests/test_umap.py index 6d91012177..4ef3512263 100644 --- a/python/cuml/cuml/tests/test_umap.py +++ b/python/cuml/cuml/tests/test_umap.py @@ -54,6 +54,39 @@ dataset_names = ["iris", "digits", "wine", "blobs"] +def test_new_data_on_host_default(): + data, labels = make_blobs( + # Make the data big enough so that we can have it on the host + n_samples=50_000 + 1, + n_features=10, + centers=5, + random_state=0, + ) + u = cuUMAP() + + with pytest.warns( + FutureWarning, + match='The default value of `data_on_host` will change from False to "auto" in 25.06.', + ): + u.fit(data) + u.fit_transform(data) + + # No warnings when value is explicitly set + u.fit(data, data_on_host=True) + u.fit_transform(data, data_on_host=True) + u.fit(data, data_on_host=False) + u.fit_transform(data, data_on_host=False) + + # XXX crashes with CUDA memory error, why? + """ + # No warning when the data is sparse + print("E") + data = scipy_sparse.csr_matrix(data) + u = cuUMAP() + u.fit_transform(data) + """ + + @pytest.mark.parametrize( "nrows", [unit_param(500), quality_param(5000), stress_param(500000)] )