Skip to content

Commit

Permalink
adaboost: adapt to scikit-learn's 1.4 deprecation of base_estimator
Browse files Browse the repository at this point in the history
  • Loading branch information
markotoplak committed Nov 17, 2023
1 parent 0957652 commit 1231f47
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 21 deletions.
48 changes: 34 additions & 14 deletions Orange/ensembles/ada_boost.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import warnings

import sklearn.ensemble as skl_ensemble

from Orange.base import SklLearner
Expand All @@ -7,6 +9,8 @@
from Orange.regression.base_regression import (
SklLearnerRegression, SklModelRegression
)
from Orange.util import OrangeDeprecationWarning


__all__ = ['SklAdaBoostClassificationLearner', 'SklAdaBoostRegressionLearner']

Expand All @@ -15,21 +19,32 @@ class SklAdaBoostClassifier(SklModelClassification):
pass


def base_estimator_deprecation():
warnings.warn(
"`base_estimator` is deprecated: use `estimator` instead.",
OrangeDeprecationWarning, stacklevel=3)


class SklAdaBoostClassificationLearner(SklLearnerClassification):
__wraps__ = skl_ensemble.AdaBoostClassifier
__returns__ = SklAdaBoostClassifier
supports_weights = True

def __init__(self, base_estimator=None, n_estimators=50, learning_rate=1.,
algorithm='SAMME.R', random_state=None, preprocessors=None):
def __init__(self, estimator=None, n_estimators=50, learning_rate=1.,
algorithm='SAMME.R', random_state=None, preprocessors=None,
base_estimator="deprecated"):
if base_estimator != "deprecated":
base_estimator_deprecation()
estimator = base_estimator
del base_estimator
from Orange.modelling import Fitter
# If fitter, get the appropriate Learner instance
if isinstance(base_estimator, Fitter):
base_estimator = base_estimator.get_learner(
base_estimator.CLASSIFICATION)
if isinstance(estimator, Fitter):
estimator = estimator.get_learner(
estimator.CLASSIFICATION)
# If sklearn learner, get the underlying sklearn representation
if isinstance(base_estimator, SklLearner):
base_estimator = base_estimator.__wraps__(**base_estimator.params)
if isinstance(estimator, SklLearner):
estimator = estimator.__wraps__(**estimator.params)
super().__init__(preprocessors=preprocessors)
self.params = vars()

Expand All @@ -43,15 +58,20 @@ class SklAdaBoostRegressionLearner(SklLearnerRegression):
__returns__ = SklAdaBoostRegressor
supports_weights = True

def __init__(self, base_estimator=None, n_estimators=50, learning_rate=1.,
loss='linear', random_state=None, preprocessors=None):
def __init__(self, estimator=None, n_estimators=50, learning_rate=1.,
loss='linear', random_state=None, preprocessors=None,
base_estimator="deprecated"):
if base_estimator != "deprecated":
base_estimator_deprecation()
estimator = base_estimator
del base_estimator
from Orange.modelling import Fitter
# If fitter, get the appropriate Learner instance
if isinstance(base_estimator, Fitter):
base_estimator = base_estimator.get_learner(
base_estimator.REGRESSION)
if isinstance(estimator, Fitter):
estimator = estimator.get_learner(
estimator.REGRESSION)
# If sklearn learner, get the underlying sklearn representation
if isinstance(base_estimator, SklLearner):
base_estimator = base_estimator.__wraps__(**base_estimator.params)
if isinstance(estimator, SklLearner):
estimator = estimator.__wraps__(**estimator.params)
super().__init__(preprocessors=preprocessors)
self.params = vars()
23 changes: 17 additions & 6 deletions Orange/tests/test_ada_boost.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@
# pylint: disable=missing-docstring

import unittest
from distutils.version import LooseVersion

import numpy as np

import Orange
from Orange.data import Table
from Orange.classification import SklTreeLearner
from Orange.regression import SklTreeRegressionLearner
Expand All @@ -27,14 +31,14 @@ def test_adaboost(self):
self.assertGreater(ca, 0.9)
self.assertLess(ca, 0.99)

def test_adaboost_base_estimator(self):
def test_adaboost_estimator(self):
np.random.seed(0)
stump_estimator = SklTreeLearner(max_depth=1)
tree_estimator = SklTreeLearner()
stump = SklAdaBoostClassificationLearner(
base_estimator=stump_estimator, n_estimators=5)
estimator=stump_estimator, n_estimators=5)
tree = SklAdaBoostClassificationLearner(
base_estimator=tree_estimator, n_estimators=5)
estimator=tree_estimator, n_estimators=5)
cv = CrossValidation(k=4)
results = cv(self.iris, [stump, tree])
ca = CA(results)
Expand Down Expand Up @@ -68,12 +72,12 @@ def test_adaboost_reg(self):
results = cv(self.housing, [learn])
_ = RMSE(results)

def test_adaboost_reg_base_estimator(self):
def test_adaboost_reg_estimator(self):
np.random.seed(0)
stump_estimator = SklTreeRegressionLearner(max_depth=1)
tree_estimator = SklTreeRegressionLearner()
stump = SklAdaBoostRegressionLearner(base_estimator=stump_estimator)
tree = SklAdaBoostRegressionLearner(base_estimator=tree_estimator)
stump = SklAdaBoostRegressionLearner(estimator=stump_estimator)
tree = SklAdaBoostRegressionLearner(estimator=tree_estimator)
cv = CrossValidation(k=3)
results = cv(self.housing, [stump, tree])
rmse = RMSE(results)
Expand Down Expand Up @@ -103,3 +107,10 @@ def test_predict_numpy_reg(self):
def test_adaboost_adequacy_reg(self):
learner = SklAdaBoostRegressionLearner()
self.assertRaises(ValueError, learner, self.iris)

def test_remove_deprecation(self):
if LooseVersion(Orange.__version__) >= LooseVersion("3.39"):
self.fail(
"`base_estimator` was deprecated in "
"version 3.37. Please remove everything related to it."
)
2 changes: 1 addition & 1 deletion Orange/widgets/model/owadaboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def create_learner(self):
if self.base_estimator is None:
return None
return self.LEARNER(
base_estimator=self.base_estimator,
estimator=self.base_estimator,
n_estimators=self.n_estimators,
learning_rate=self.learning_rate,
random_state=self.random_seed,
Expand Down

0 comments on commit 1231f47

Please sign in to comment.