Skip to content

Commit

Permalink
FIX missing imports and other small fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
dantegd committed Dec 11, 2024
1 parent 3a1a31e commit 08ce1d7
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 7 deletions.
6 changes: 5 additions & 1 deletion python/cuml/cuml/internals/base.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ except ImportError:

import cuml
import cuml.common
from cuml.common.sparse_utils import is_sparse
import cuml.internals.logger as logger
import cuml.internals
import cuml.internals.input_utils
Expand All @@ -47,6 +48,7 @@ from cuml.internals.input_utils import (
determine_array_type,
input_to_cuml_array,
input_to_host_array,
input_to_host_array_with_sparse_support,
is_array_like
)
from cuml.internals.memory_utils import determine_array_memtype
Expand Down Expand Up @@ -676,7 +678,9 @@ class UniversalBase(Base):

def args_to_cpu(self, *args, **kwargs):
# put all the args on host
new_args = tuple(input_to_host_array_with_sparse_support(arg) for arg in args)
new_args = tuple(
input_to_host_array_with_sparse_support(arg) for arg in args
)

# put all the kwargs on host
new_kwargs = dict()
Expand Down
5 changes: 5 additions & 0 deletions python/cuml/cuml/internals/input_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,11 @@ def input_to_host_array(


def input_to_host_array_with_sparse_support(X):
try:
if scipy_sparse.isspmatrix(X):
return X
except UnavailableError:
pass
_array_type, is_sparse = determine_array_type_full(X)
if is_sparse:
if _array_type == "cupy":
Expand Down
18 changes: 12 additions & 6 deletions python/cuml/cuml/tests/experimental/accel/test_sparse_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,17 @@
from sklearn.decomposition import TruncatedSVD
from sklearn.kernel_ridge import KernelRidge
from sklearn.linear_model import (
LinearRegression, LogisticRegression, ElasticNet,
Ridge, Lasso
LinearRegression,
LogisticRegression,
ElasticNet,
Ridge,
Lasso,
)
from sklearn.neighbors import (
NearestNeighbors,
KNeighborsClassifier,
KNeighborsRegressor,
)
from sklearn.neighbors import NearestNeighbors, KNeighborsClassifier, KNeighborsRegressor
from sklearn.base import is_classifier, is_regressor


Expand All @@ -44,15 +51,14 @@

@pytest.mark.parametrize("estimator_name", list(estimators.keys()))
def test_sparse_support(estimator_name):
X_sparse = csr_matrix([[0, 1],
[1, 0]])
X_sparse = csr_matrix([[0, 1], [1, 0]])
print(X_sparse.shape[0])
y_class = np.array([0, 1])
y_reg = np.array([0.0, 1.0])
estimator = estimators[estimator_name]()
# Fit or fit_transform depending on the estimator type
if isinstance(estimator, (KMeans, DBSCAN, TruncatedSVD, NearestNeighbors)):
if hasattr(estimator, 'fit_transform'):
if hasattr(estimator, "fit_transform"):
estimator.fit_transform(X_sparse)
else:
estimator.fit(X_sparse)
Expand Down

0 comments on commit 08ce1d7

Please sign in to comment.