diff --git a/tests/sklearn/test_sklearn.py b/tests/sklearn/test_sklearn.py index 822760c1..b2d81e29 100644 --- a/tests/sklearn/test_sklearn.py +++ b/tests/sklearn/test_sklearn.py @@ -17,6 +17,7 @@ from optuna_integration.sklearn.sklearn import _is_arraylike from optuna_integration.sklearn.sklearn import _make_indexable from optuna_integration.sklearn.sklearn import _num_samples +from sklearn.base import BaseEstimator from sklearn.datasets import make_blobs from sklearn.datasets import make_regression from sklearn.decomposition import PCA @@ -25,6 +26,7 @@ from sklearn.linear_model import LogisticRegression from sklearn.linear_model import SGDClassifier from sklearn.model_selection import PredefinedSplit +from sklearn.metrics import r2_score, make_scorer from sklearn.neighbors import KernelDensity from sklearn.tree import DecisionTreeRegressor @@ -458,6 +460,42 @@ def test_callbacks() -> None: assert callback.call_count == n_trials +@pytest.mark.parametrize("catch", [(FloatingPointError,), ()]) +def test_catch(catch): + X, y = make_blobs(n_samples=10) + + class MockEstimator(BaseEstimator): + def __init__(self, param: int = 1): + self.param = param + + def fit(self, X, y): + raise FloatingPointError + + est = MockEstimator() + param_dist = {"param": distributions.IntDistribution(1, 10)} + n_trials = 3 + + optuna_search = OptunaSearchCV( + est, + param_dist, + cv=3, + max_iter=5, + n_trials=n_trials, + error_score=0, + refit=False, + scoring=make_scorer(r2_score), + catch=catch, + ) + + if catch: + optuna_search.fit(X, y) + assert optuna_search.n_trials_ == n_trials + else: + with pytest.raises(FloatingPointError): + optuna_search.fit(X, y) + + + @pytest.mark.filterwarnings("ignore::UserWarning") @patch("optuna_integration.sklearn.sklearn.cross_validate") def test_terminator_cv_score_reporting(mock: MagicMock) -> None: