Skip to content

Commit

Permalink
Update sklearn.py by addind catch to OptunaSearchCV
Browse files Browse the repository at this point in the history
To allow for catching exceptions during fit, an init param for "catch" is added, which is passed to self.study_.optimize
  • Loading branch information
muhlbach committed Sep 6, 2024
1 parent e401294 commit a0bc602
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions optuna_integration/sklearn/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,7 @@ def __init__(
timeout: float | None = None,
verbose: int = 0,
callbacks: list[Callable[[study_module.Study, FrozenTrial], None]] | None = None,
catch: Iterable[type[Exception]] | type[Exception] = (),
) -> None:
_imports.check()

Expand Down Expand Up @@ -767,6 +768,7 @@ def __init__(
self.timeout = timeout
self.verbose = verbose
self.callbacks = callbacks
self.catch = catch

def _check_is_fitted(self) -> None:
attributes = ["n_splits_", "sample_indices_", "scorer_", "study_"]
Expand Down Expand Up @@ -925,6 +927,7 @@ def fit(
n_trials=self.n_trials,
timeout=self.timeout,
callbacks=self.callbacks,
catch=self.catch,
)

_logger.info("Finished hyperparameter search!")
Expand Down

0 comments on commit a0bc602

Please sign in to comment.