diff --git a/docs/changes/newsfragments/268.bugfix b/docs/changes/newsfragments/268.bugfix new file mode 100644 index 000000000..338a56f84 --- /dev/null +++ b/docs/changes/newsfragments/268.bugfix @@ -0,0 +1 @@ +Fix usage of optuna distributions as hyperparameters in older versions of scikit-learn by `Fede Raimondo`_. \ No newline at end of file diff --git a/julearn/model_selection/_optuna_searcher.py b/julearn/model_selection/_optuna_searcher.py index ec6952b2f..921389afc 100644 --- a/julearn/model_selection/_optuna_searcher.py +++ b/julearn/model_selection/_optuna_searcher.py @@ -38,6 +38,29 @@ def register_optuna_searcher(): _recreate_reset_copy() +def is_optuna_valid_distribution(obj: Any) -> bool: + """Check if an object is a valid Optuna distribution. + + Parameters + ---------- + obj : any + The object to check. + + Returns + ------- + bool + Whether the object is a valid Optuna distribution. + + """ + _valid_classes = [ + "IntDistribution", + "FloatDistribution", + "CategoricalDistribution", + ] + + return obj.__class__.__name__ in _valid_classes + + def _prepare_optuna_hyperparameters_distributions( params_to_tune: Dict[str, Any], ) -> Dict[str, Any]: diff --git a/julearn/model_selection/tests/test_optuna_searcher.py b/julearn/model_selection/tests/test_optuna_searcher.py index 231f5becd..0dea2c9de 100644 --- a/julearn/model_selection/tests/test_optuna_searcher.py +++ b/julearn/model_selection/tests/test_optuna_searcher.py @@ -9,6 +9,7 @@ from julearn.model_selection._optuna_searcher import ( _prepare_optuna_hyperparameters_distributions, + is_optuna_valid_distribution, ) @@ -164,3 +165,35 @@ def test__prepare_optuna_hyperparameters_distributions( ) else: pytest.fail("Invalid distribution type") + + +@pytest.mark.parametrize( + "obj,expected", + [ + (optd.IntDistribution(1, 20, log=False), True), + (optd.FloatDistribution(0.2, 0.7, log=False), True), + (optd.CategoricalDistribution([1, 2, 3]), True), + (optd.CategoricalDistribution(["a", "b", "c"]), True), 
+ (optd.CategoricalDistribution(["a", "b", "c", "d"]), True), + ("uniform", False), + ("log-uniform", False), + ("categorical", False), + (1, False), + (1.0, False), + ([1, 2, 3], False), + (["a", "b", "c"], False), + ], +) +def test_optuna_valid_distributions(obj: Any, expected: bool) -> None: + """Test the is_optuna_valid_distribution function. + + Parameters + ---------- + obj : Any + The object to check. + expected : bool + The expected result. + + """ + out = is_optuna_valid_distribution(obj) + assert out == expected diff --git a/julearn/pipeline/pipeline_creator.py b/julearn/pipeline/pipeline_creator.py index 979bf828a..60d0be052 100644 --- a/julearn/pipeline/pipeline_creator.py +++ b/julearn/pipeline/pipeline_creator.py @@ -16,6 +16,7 @@ from ..base import ColumnTypes, ColumnTypesLike, JuTransformer, WrapModel from ..model_selection._optuna_searcher import ( _prepare_optuna_hyperparameters_distributions, + is_optuna_valid_distribution, ) from ..model_selection._skopt_searcher import ( _prepare_skopt_hyperparameters_distributions, @@ -263,6 +264,9 @@ def add( # If it is a distribution, we will tune it. 
logger.info(f"Tuning hyperparameter {param} = {vals}") params_to_tune[param] = vals + elif is_optuna_valid_distribution(vals): + logger.info(f"Tuning hyperparameter {param} = {vals}") + params_to_tune[param] = vals else: logger.info(f"Setting hyperparameter {param} = {vals}") params_to_set[param] = vals diff --git a/julearn/pipeline/tests/test_merger.py b/julearn/pipeline/tests/test_merger.py index 1a93e1786..49cf9afe0 100644 --- a/julearn/pipeline/tests/test_merger.py +++ b/julearn/pipeline/tests/test_merger.py @@ -51,7 +51,7 @@ def test_merger_pipelines() -> None: assert "scaler" == named_steps[1] assert "rf" == named_steps[2] assert len(merged.param_distributions) == 3 # type: ignore - assert merged.param_distributions[-1][ + assert merged.param_distributions[-1][ # type: ignore "rf__max_features" ] == [ # type: ignore 2, diff --git a/julearn/transformers/tests/test_jucolumntransformers.py b/julearn/transformers/tests/test_jucolumntransformers.py index 13cd3d9be..3d418f1dd 100644 --- a/julearn/transformers/tests/test_jucolumntransformers.py +++ b/julearn/transformers/tests/test_jucolumntransformers.py @@ -157,7 +157,7 @@ def test_JuColumnTransformer_row_select(): mean_both = ( transformer_both.fit(X) - .column_transformer_.transformers_[0][1] + .column_transformer_.transformers_[0][1] # type: ignore .mean_ # type: ignore )