Skip to content

Commit

Permalink
Support RF Classifier
Browse files Browse the repository at this point in the history
  • Loading branch information
hcho3 committed Jan 9, 2025
1 parent 641efea commit cc9e1e2
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 5 deletions.
1 change: 1 addition & 0 deletions python/cuml/cuml/ensemble/randomforest_common.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ class BaseRandomForestModel(UniversalBase):
self.treelite_serialized_model = treelite_serialize(self._temp.handle)
self._obtain_treelite_handle()
self.dtype = np.float64
self.update_labels = False
super().cpu_to_gpu()

def gpu_to_cpu(self):
Expand Down
10 changes: 10 additions & 0 deletions python/cuml/cuml/ensemble/randomforestclassifier.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
#

# distutils: language = c++
from cuml.internals.api_decorators import device_interop_preparation
from cuml.internals.api_decorators import enable_device_interop
from cuml.internals.safe_imports import (
cpu_only_import,
gpu_only_import,
Expand Down Expand Up @@ -247,6 +249,9 @@ class RandomForestClassifier(BaseRandomForestModel,
<https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html>`_.
"""

_cpu_estimator_import_path = 'sklearn.ensemble.RandomForestClassifier'

@device_interop_preparation
def __init__(self, *, split_criterion=0, handle=None, verbose=False,
output_type=None,
**kwargs):
Expand Down Expand Up @@ -337,6 +342,9 @@ class RandomForestClassifier(BaseRandomForestModel,
self.treelite_serialized_model = None
self.n_cols = None

def get_attr_names(self):
return []

def convert_to_treelite_model(self):
"""
Converts the cuML RF model to a Treelite model
Expand Down Expand Up @@ -417,6 +425,7 @@ class RandomForestClassifier(BaseRandomForestModel,
@cuml.internals.api_base_return_any(set_output_type=False,
set_output_dtype=True,
set_n_features_in=False)
@enable_device_interop
def fit(self, X, y, convert_dtype=True):
"""
Perform Random Forest Classification on the input data
Expand Down Expand Up @@ -555,6 +564,7 @@ class RandomForestClassifier(BaseRandomForestModel,
@insert_into_docstring(parameters=[('dense', '(n_samples, n_features)')],
return_values=[('dense', '(n_samples, 1)')])
@cuml.internals.api_base_return_array(get_output_dtype=True)
@enable_device_interop
def predict(self, X, predict_model="GPU", threshold=0.5,
algo='auto', convert_dtype=True,
fil_sparse_format='auto') -> CumlArray:
Expand Down
4 changes: 2 additions & 2 deletions python/cuml/cuml/ensemble/randomforestregressor.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from cuml.internals.api_decorators import device_interop_preparation
from cuml.internals.api_decorators import enable_device_interop
# distutils: language = c++

from cuml.internals.api_decorators import device_interop_preparation
from cuml.internals.api_decorators import enable_device_interop
from cuml.internals.safe_imports import (
cpu_only_import,
gpu_only_import,
Expand Down
16 changes: 13 additions & 3 deletions python/cuml/cuml/tests/test_device_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from cuml.decomposition import PCA, TruncatedSVD
from cuml.cluster import KMeans
from cuml.cluster import DBSCAN
from cuml.ensemble import RandomForestRegressor
from cuml.ensemble import RandomForestClassifier, RandomForestRegressor
from cuml.common.device_selection import DeviceType, using_device_type
from cuml.testing.utils import assert_dbscan_equal
from hdbscan import HDBSCAN as refHDBSCAN
Expand Down Expand Up @@ -1016,9 +1016,19 @@ def test_dbscan_methods(train_device, infer_device):

@pytest.mark.parametrize("train_device", ["cpu", "gpu"])
@pytest.mark.parametrize("infer_device", ["cpu", "gpu"])
def test_random_forest_methods(train_device, infer_device):
def test_random_forest_regressor(train_device, infer_device):
model = RandomForestRegressor()
with using_device_type(train_device):
model.fit(X_train_reg, y_train_reg)
with using_device_type(infer_device):
output = model.predict(X_test_reg)
_ = model.predict(X_test_reg)


@pytest.mark.parametrize("train_device", ["cpu", "gpu"])
@pytest.mark.parametrize("infer_device", ["cpu", "gpu"])
def test_random_forest_classifier(train_device, infer_device):
model = RandomForestClassifier()
with using_device_type(train_device):
model.fit(X_train_blob, y_train_blob)
with using_device_type(infer_device):
_ = model.predict(X_test_blob)

0 comments on commit cc9e1e2

Please sign in to comment.