Skip to content

Commit

Permalink
Added test of catch functionality to OptunaSearchCV
Browse files Browse the repository at this point in the history
  • Loading branch information
muhlbach authored Sep 9, 2024
1 parent 9fdb0ad commit 29e5f5d
Showing 1 changed file with 38 additions and 0 deletions.
38 changes: 38 additions & 0 deletions tests/sklearn/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from optuna_integration.sklearn.sklearn import _is_arraylike
from optuna_integration.sklearn.sklearn import _make_indexable
from optuna_integration.sklearn.sklearn import _num_samples
from sklearn.base import BaseEstimator
from sklearn.datasets import make_blobs
from sklearn.datasets import make_regression
from sklearn.decomposition import PCA
Expand All @@ -25,6 +26,7 @@
from sklearn.linear_model import LogisticRegression
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import PredefinedSplit
from sklearn.metrics import r2_score, make_scorer
from sklearn.neighbors import KernelDensity
from sklearn.tree import DecisionTreeRegressor

Expand Down Expand Up @@ -458,6 +460,42 @@ def test_callbacks() -> None:
assert callback.call_count == n_trials


@pytest.mark.parametrize("catch", [(FloatingPointError,), ()])
def test_catch(catch):
X, y = make_blobs(n_samples=10)

class MockEstimator(BaseEstimator):
def __init__(self, param: int = 1):
self.param = param

def fit(self, X, y):
raise FloatingPointError

est = MockEstimator()
param_dist = {"param": distributions.IntDistribution(1, 10)}
n_trials = 3

optuna_search = OptunaSearchCV(
est,
param_dist,
cv=3,
max_iter=5,
n_trials=n_trials,
error_score=0,
refit=False,
scoring=make_scorer(r2_score),
catch=catch,
)

if catch:
optuna_search.fit(X, y)
assert optuna_search.n_trials_ == n_trials
else:
with pytest.raises(FloatingPointError):
optuna_search.fit(X, y)



@pytest.mark.filterwarnings("ignore::UserWarning")
@patch("optuna_integration.sklearn.sklearn.cross_validate")
def test_terminator_cv_score_reporting(mock: MagicMock) -> None:
Expand Down

0 comments on commit 29e5f5d

Please sign in to comment.