[ENH] Add sklearn backend for forest (#36)
* add sklearn backend

* add unit tests

* Update CHANGELOG.rst
xuyxu authored Feb 16, 2021
1 parent fd25465 commit 37c33df
Showing 7 changed files with 148 additions and 48 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
@@ -30,6 +30,7 @@ Version 0.1.*
.. |Fix| replace:: :raw-html:`<span class="badge badge-danger">Fix</span>` :raw-latex:`{\small\sc [Fix]}`
.. |API| replace:: :raw-html:`<span class="badge badge-warning">API Change</span>` :raw-latex:`{\small\sc [API Change]}`

- |Feature| add scikit-learn backend (`#36 <https://github.com/LAMDA-NJU/Deep-Forest/pull/36>`__) @xuyxu
- |Feature| add official support for Mac-OS (`#34 <https://github.com/LAMDA-NJU/Deep-Forest/pull/34>`__) @T-Allen-sudo
- |Feature| support configurable criterion (`#28 <https://github.com/LAMDA-NJU/Deep-Forest/issues/28>`__) @tczhao
- |Feature| support regression prediction (`#25 <https://github.com/LAMDA-NJU/Deep-Forest/issues/25>`__) @tczhao
145 changes: 106 additions & 39 deletions deepforest/_estimator.py
@@ -3,12 +3,19 @@

__all__ = ["Estimator"]

import numpy as np
from .forest import (
RandomForestClassifier,
ExtraTreesClassifier,
RandomForestRegressor,
ExtraTreesRegressor,
)
from sklearn.ensemble import (
RandomForestClassifier as sklearn_RandomForestClassifier,
ExtraTreesClassifier as sklearn_ExtraTreesClassifier,
RandomForestRegressor as sklearn_RandomForestRegressor,
ExtraTreesRegressor as sklearn_ExtraTreesRegressor,
)


def make_classifier_estimator(
@@ -17,29 +24,54 @@ def make_classifier_estimator(
n_trees=100,
max_depth=None,
min_samples_leaf=1,
backend="custom",
n_jobs=None,
random_state=None,
):
# RandomForestClassifier
if name == "rf":
estimator = RandomForestClassifier(
criterion=criterion,
n_estimators=n_trees,
max_depth=max_depth,
min_samples_leaf=min_samples_leaf,
n_jobs=n_jobs,
random_state=random_state,
)
if backend == "custom":
estimator = RandomForestClassifier(
criterion=criterion,
n_estimators=n_trees,
max_depth=max_depth,
min_samples_leaf=min_samples_leaf,
n_jobs=n_jobs,
random_state=random_state,
)
elif backend == "sklearn":
estimator = sklearn_RandomForestClassifier(
criterion=criterion,
n_estimators=n_trees,
max_depth=max_depth,
min_samples_leaf=min_samples_leaf,
bootstrap=True,
oob_score=True,
n_jobs=n_jobs,
random_state=random_state,
)
# ExtraTreesClassifier
elif name == "erf":
estimator = ExtraTreesClassifier(
criterion=criterion,
n_estimators=n_trees,
max_depth=max_depth,
min_samples_leaf=min_samples_leaf,
n_jobs=n_jobs,
random_state=random_state,
)
if backend == "custom":
estimator = ExtraTreesClassifier(
criterion=criterion,
n_estimators=n_trees,
max_depth=max_depth,
min_samples_leaf=min_samples_leaf,
n_jobs=n_jobs,
random_state=random_state,
)
elif backend == "sklearn":
estimator = sklearn_ExtraTreesClassifier(
criterion=criterion,
n_estimators=n_trees,
max_depth=max_depth,
min_samples_leaf=min_samples_leaf,
bootstrap=True,
oob_score=True,
n_jobs=n_jobs,
random_state=random_state,
)
else:
msg = "Unknown type of estimator, which should be one of {{rf, erf}}."
raise NotImplementedError(msg)
@@ -53,29 +85,54 @@ def make_regressor_estimator(
n_trees=100,
max_depth=None,
min_samples_leaf=1,
backend="custom",
n_jobs=None,
random_state=None,
):
# RandomForestRegressor
if name == "rf":
estimator = RandomForestRegressor(
criterion=criterion,
n_estimators=n_trees,
max_depth=max_depth,
min_samples_leaf=min_samples_leaf,
n_jobs=n_jobs,
random_state=random_state,
)
if backend == "custom":
estimator = RandomForestRegressor(
criterion=criterion,
n_estimators=n_trees,
max_depth=max_depth,
min_samples_leaf=min_samples_leaf,
n_jobs=n_jobs,
random_state=random_state,
)
elif backend == "sklearn":
estimator = sklearn_RandomForestRegressor(
criterion=criterion,
n_estimators=n_trees,
max_depth=max_depth,
min_samples_leaf=min_samples_leaf,
bootstrap=True,
oob_score=True,
n_jobs=n_jobs,
random_state=random_state,
)
# ExtraTreesRegressor
elif name == "erf":
estimator = ExtraTreesRegressor(
criterion=criterion,
n_estimators=n_trees,
max_depth=max_depth,
min_samples_leaf=min_samples_leaf,
n_jobs=n_jobs,
random_state=random_state,
)
if backend == "custom":
estimator = ExtraTreesRegressor(
criterion=criterion,
n_estimators=n_trees,
max_depth=max_depth,
min_samples_leaf=min_samples_leaf,
n_jobs=n_jobs,
random_state=random_state,
)
elif backend == "sklearn":
estimator = sklearn_ExtraTreesRegressor(
criterion=criterion,
n_estimators=n_trees,
max_depth=max_depth,
min_samples_leaf=min_samples_leaf,
bootstrap=True,
oob_score=True,
n_jobs=n_jobs,
random_state=random_state,
)
else:
msg = "Unknown type of estimator, which should be one of {{rf, erf}}."
raise NotImplementedError(msg)
@@ -91,11 +148,13 @@ def __init__(
n_trees=100,
max_depth=None,
min_samples_leaf=1,
backend="custom",
n_jobs=None,
random_state=None,
is_classifier=True,
):

self.backend = backend
self.is_classifier = is_classifier
if self.is_classifier:
self.estimator_ = make_classifier_estimator(
@@ -104,6 +163,7 @@ def __init__(
n_trees,
max_depth,
min_samples_leaf,
backend,
n_jobs,
random_state,
)
@@ -114,26 +174,33 @@ def __init__(
n_trees,
max_depth,
min_samples_leaf,
backend,
n_jobs,
random_state,
)

@property
def oob_decision_function_(self):
# Scikit-Learn uses `oob_prediction_` for ForestRegressor
if self.backend == "sklearn" and not self.is_classifier:
oob_prediction = self.estimator_.oob_prediction_
if len(oob_prediction.shape) == 1:
oob_prediction = np.expand_dims(oob_prediction, 1)
return oob_prediction
return self.estimator_.oob_decision_function_

def fit_transform(self, X, y, sample_weight=None):
self.estimator_.fit(X, y, sample_weight)
X_aug = self.estimator_.oob_decision_function_

return X_aug
return self.oob_decision_function_

def transform(self, X):
if self.is_classifier:
return self.estimator_.predict_proba(X)
return self.estimator_.predict(X)
"""Preserved for the naming consistency."""
return self.predict(X)

def predict(self, X):
if self.is_classifier:
return self.estimator_.predict_proba(X)
return self.estimator_.predict(X)
pred = self.estimator_.predict(X)
if len(pred.shape) == 1:
pred = np.expand_dims(pred, 1)
return pred
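
The sklearn branch above depends on out-of-bag (OOB) estimates, so the forests are built with bootstrap=True and oob_score=True, and the wrapper's oob_decision_function_ property promotes the 1-d oob_prediction_ of a scikit-learn regressor to a 2-d column so downstream layers can concatenate it. A minimal sketch of that shape handling, using scikit-learn's RandomForestRegressor directly on synthetic data (illustration only, not the committed code):

    import numpy as np
    from sklearn.datasets import make_regression
    from sklearn.ensemble import RandomForestRegressor

    # OOB estimates are only available with bootstrap sampling, matching the
    # flags set in the sklearn branch of the diff.
    X, y = make_regression(n_samples=200, n_features=10, random_state=0)
    forest = RandomForestRegressor(
        n_estimators=100, bootstrap=True, oob_score=True, random_state=0
    )
    forest.fit(X, y)

    # Regressors expose 1-d OOB predictions via `oob_prediction_`; the Estimator
    # wrapper expands them to shape (n_samples, 1) before passing them on.
    oob = forest.oob_prediction_
    if oob.ndim == 1:
        oob = np.expand_dims(oob, 1)
    print(oob.shape)  # (200, 1)
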
3 changes: 3 additions & 0 deletions deepforest/_layer.py
@@ -52,6 +52,7 @@ def __init__(
n_trees=100,
max_depth=None,
min_samples_leaf=1,
backend="custom",
partial_mode=False,
buffer=None,
n_jobs=None,
@@ -66,6 +67,7 @@ def __init__(
self.n_trees = n_trees
self.max_depth = max_depth
self.min_samples_leaf = min_samples_leaf
self.backend = backend
self.partial_mode = partial_mode
self.buffer = buffer
self.n_jobs = n_jobs
@@ -95,6 +97,7 @@ def _make_estimator(self, estimator_idx, estimator_name):
n_trees=self.n_trees,
max_depth=self.max_depth,
min_samples_leaf=self.min_samples_leaf,
backend=self.backend,
n_jobs=self.n_jobs,
random_state=random_state,
is_classifier=self.is_classifier,
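
The layer change is plumbing: the stored backend is forwarded unchanged into the estimator factory from _estimator.py. A hedged sketch of the dispatch that forwarding triggers, assuming the leading name and criterion parameters of make_classifier_estimator that the hunk truncates:

    from sklearn.ensemble import RandomForestClassifier
    from deepforest._estimator import make_classifier_estimator

    # backend="sklearn" selects scikit-learn's implementation and enables
    # bootstrap/oob_score so OOB outputs remain available for cascading.
    est = make_classifier_estimator(
        "rf",               # estimator name, one of {"rf", "erf"}
        "gini",             # split criterion
        n_trees=100,
        backend="sklearn",
        n_jobs=-1,
        random_state=0,
    )
    assert isinstance(est, RandomForestClassifier)
    assert est.bootstrap and est.oob_score
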
20 changes: 20 additions & 0 deletions deepforest/cascade.py
@@ -235,6 +235,10 @@ def _build_regressor_predictor(
Specifying this will extend/overwrite the original parameters inherited
from deep forest. If ``use_predictor`` is False, this parameter will
have no effect.
backend : :obj:`{"custom", "sklearn"}`, default="custom"
The backend of the forest estimator. Supported backends are ``custom``
for higher time and memory efficiency and ``sklearn`` for additional
functionality.
n_tolerant_rounds : :obj:`int`, default=2
Specify when to conduct early stopping. The training process
terminates when the validation performance on the training set does
@@ -345,6 +349,10 @@ def _build_regressor_predictor(
Specifying this will extend/overwrite the original parameters inherited
from deep forest.
If ``use_predictor`` is False, this parameter will have no effect.
backend : :obj:`{"custom", "sklearn"}`, default="custom"
The backend of the forest estimator. Supported backends are ``custom``
for higher time and memory efficiency and ``sklearn`` for additional
functionality.
n_tolerant_rounds : :obj:`int`, default=2
Specify when to conduct early stopping. The training process
terminates when the validation performance on the training set does
@@ -461,6 +469,7 @@ def __init__(
use_predictor=False,
predictor="forest",
predictor_kwargs={},
backend="custom",
n_tolerant_rounds=2,
delta=1e-5,
partial_mode=False,
@@ -478,6 +487,7 @@ def __init__(
self.max_depth = max_depth
self.min_samples_leaf = min_samples_leaf
self.predictor_kwargs = predictor_kwargs
self.backend = backend
self.n_tolerant_rounds = n_tolerant_rounds
self.delta = delta
self.partial_mode = partial_mode
@@ -607,6 +617,10 @@ def _validate_params(self):
msg = "max_layers = {} should be strictly positive."
raise ValueError(msg.format(self.max_layers))

if self.backend not in ("custom", "sklearn"):
msg = "backend = {} should be one of {{custom, sklearn}}."
raise ValueError(msg.format(self.backend))

if not self.n_tolerant_rounds > 0:
msg = "n_tolerant_rounds = {} should be strictly positive."
raise ValueError(msg.format(self.n_tolerant_rounds))
@@ -729,6 +743,7 @@ def fit(self, X, y, sample_weight=None):
self._set_n_trees(0),
self.max_depth,
self.min_samples_leaf,
self.backend,
self.partial_mode,
self.buffer_,
self.n_jobs,
@@ -805,6 +820,7 @@ def fit(self, X, y, sample_weight=None):
self._set_n_trees(layer_idx),
self.max_depth,
self.min_samples_leaf,
self.backend,
self.partial_mode,
self.buffer_,
self.n_jobs,
@@ -1134,6 +1150,7 @@ def __init__(
use_predictor=False,
predictor="forest",
predictor_kwargs={},
backend="custom",
n_tolerant_rounds=2,
delta=1e-5,
partial_mode=False,
@@ -1154,6 +1171,7 @@ def __init__(
use_predictor=use_predictor,
predictor=predictor,
predictor_kwargs=predictor_kwargs,
backend=backend,
n_tolerant_rounds=n_tolerant_rounds,
delta=delta,
partial_mode=partial_mode,
@@ -1331,6 +1349,7 @@ def __init__(
use_predictor=False,
predictor="forest",
predictor_kwargs={},
backend="custom",
n_tolerant_rounds=2,
delta=1e-5,
partial_mode=False,
@@ -1351,6 +1370,7 @@ def __init__(
use_predictor=use_predictor,
predictor=predictor,
predictor_kwargs=predictor_kwargs,
backend=backend,
n_tolerant_rounds=n_tolerant_rounds,
delta=delta,
partial_mode=partial_mode,
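
At the cascade level the new keyword is threaded through to both public estimators and checked in _validate_params. A usage sketch, assuming the CascadeForestClassifier exported by the deepforest package (dataset and split are only for illustration):

    from sklearn.datasets import load_digits
    from sklearn.metrics import accuracy_score
    from sklearn.model_selection import train_test_split
    from deepforest import CascadeForestClassifier

    X, y = load_digits(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

    # backend="sklearn" makes every forest in the cascade a scikit-learn forest
    # (bootstrap + OOB enabled, per the diff above); the default stays "custom".
    model = CascadeForestClassifier(backend="sklearn", random_state=0)
    model.fit(X_train, y_train)
    print(accuracy_score(y_test, model.predict(X_test)))

    # Any other value is rejected when fit() calls _validate_params, e.g.
    # CascadeForestClassifier(backend="xgboost").fit(X_train, y_train)
    # raises ValueError: backend = xgboost should be one of {custom, sklearn}.
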