Skip to content

Commit

Permalink
Handle CPU/GPU interop
Browse files (browse the repository at this point in the history)
  • Loading branch information
betatim committed Jan 14, 2025
1 parent 6bb9600 commit 52b0bfa
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 14 deletions.
25 changes: 12 additions & 13 deletions python/cuml/cuml/manifold/umap.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ class UMAP(UniversalBase,
build_kwds: dict (optional, default=None)
Build algorithm argument {'nnd_graph_degree': 64, 'nnd_intermediate_graph_degree': 128,
'nnd_max_iterations': 20, 'nnd_termination_threshold': 0.0001, 'nnd_return_distances': True,
'nnd_n_clusters': 1}
'nnd_n_clusters': 10}
Note that nnd_n_clusters > 1 will result in batch-building with NN Descent.
.. versionchanged:: 25.06
Expand Down Expand Up @@ -458,7 +458,7 @@ class UMAP(UniversalBase,
if self.min_dist > self.spread:
raise ValueError("min_dist should be <= spread")

def _build_umap_params(self, build_algo):
def _build_umap_params(self, build_algo, sparse):
IF GPUBUILD == 1:
cdef UMAPParams* umap_params = new UMAPParams()
umap_params.n_neighbors = <int> self.n_neighbors
Expand Down Expand Up @@ -499,15 +499,15 @@ class UMAP(UniversalBase,
"The default value of `nnd_n_clusters` will change from 1 to 10 in 25.06.",
FutureWarning
)
umap_params.nn_descent_params.n_clusters = <uint64_t> 10
umap_params.nn_descent_params.n_clusters = <uint64_t> 3
else:
n_clusters = self.build_kwds.get("nnd_n_clusters", "warn")
if n_clusters == "warn":
warnings.warn(
"The default value of `nnd_n_clusters` will change from 1 to 10 in 25.06.",
FutureWarning
)
n_clusters = 10
n_clusters = 3
umap_params.nn_descent_params.graph_degree = <uint64_t> self.build_kwds.get("nnd_graph_degree", 64)
umap_params.nn_descent_params.intermediate_graph_degree = <uint64_t> self.build_kwds.get("nnd_intermediate_graph_degree", 128)
umap_params.nn_descent_params.max_iterations = <uint64_t> self.build_kwds.get("nnd_max_iterations", 20)
Expand All @@ -524,7 +524,7 @@ class UMAP(UniversalBase,

try:
umap_params.metric = metric_parsing[self.metric.lower()]
if self.sparse_fit_:
if sparse:
if umap_params.metric not in SPARSE_SUPPORTED_METRICS:
raise NotImplementedError(f"Metric '{self.metric}' not supported for sparse inputs.")
elif umap_params.metric not in DENSE_SUPPORTED_METRICS:
Expand Down Expand Up @@ -716,7 +716,7 @@ class UMAP(UniversalBase,
<handle_t*> <size_t> self.handle.getHandle()
fss_graph = GraphHolder.new_graph(handle_.get_stream())
cdef UMAPParams* umap_params = \
<UMAPParams*> <size_t> self._build_umap_params(self.build_algo_)
<UMAPParams*> <size_t> self._build_umap_params(self.build_algo_, self.sparse_fit_)
if self.sparse_fit_:
fit_sparse(handle_[0],
<int*><uintptr_t> self._raw_data.indptr.ptr,
Expand Down Expand Up @@ -835,13 +835,13 @@ class UMAP(UniversalBase,
if len(X.shape) != 2:
raise ValueError("X should be two dimensional")

if is_sparse(X) and not self.sparse_fit_:
if is_sparse(X) and not self._sparse_data:
logger.warn("Model was trained on dense data but sparse "
"data was provided to transform(). Converting "
"to dense.")
X = X.todense()

elif not is_sparse(X) and self.sparse_fit_:
elif not is_sparse(X) and self._sparse_data:
logger.warn("Model was trained on sparse data but dense "
"data was provided to transform(). Converting "
"to sparse.")
Expand Down Expand Up @@ -877,19 +877,18 @@ class UMAP(UniversalBase,

cdef uintptr_t _embed_ptr = self.embedding_.ptr

build_algo = getattr(self, "build_algo_", "brute_force_knn")
# NN Descent doesn't support transform yet
if self.build_algo_ == "nn_descent" or self.build_algo_ == "auto":
if build_algo == "nn_descent" or build_algo == "auto":
build_algo = "brute_force_knn"
logger.info("Transform can only be run with brute force. Using brute force.")
else:
build_algo = self.build_algo_

IF GPUBUILD == 1:
cdef UMAPParams* umap_params = \
<UMAPParams*> <size_t> self._build_umap_params(build_algo)
<UMAPParams*> <size_t> self._build_umap_params(build_algo, self._sparse_data)
cdef handle_t * handle_ = \
<handle_t*> <size_t> self.handle.getHandle()
if self.sparse_fit_:
if self._sparse_data:
transform_sparse(handle_[0],
<int*><uintptr_t> X_m.indptr.ptr,
<int*><uintptr_t> X_m.indices.ptr,
Expand Down
2 changes: 1 addition & 1 deletion python/cuml/cuml/tests/test_umap.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def test_new_data_on_host_default():
u.fit(data, data_on_host=False)
u.fit_transform(data, data_on_host=False)

# XXX crashes with CUDA memory error, why?
# XXX crashes with CUDA memory error, why? Too many rows?
"""
# No warning when the data is sparse
print("E")
Expand Down

0 comments on commit 52b0bfa

Please sign in to comment.