Skip to content

Commit

Permalink
Merge pull request #268 from juaml/fix/optuna_distributions
Browse files Browse the repository at this point in the history
Fix usage of optuna distributions
  • Loading branch information
fraimondo authored Jun 19, 2024
2 parents c93ef00 + d501297 commit be6a63b
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/changes/newsfragments/268.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix usage of optuna distributions as hyperparameters in older versions of scikit-learn by `Fede Raimondo`_.
23 changes: 23 additions & 0 deletions julearn/model_selection/_optuna_searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,29 @@ def register_optuna_searcher():
_recreate_reset_copy()


def is_optuna_valid_distribution(obj: Any) -> bool:
    """Tell whether an object looks like a supported Optuna distribution.

    The decision is based only on the name of the object's class; the
    object is not otherwise inspected.

    Parameters
    ----------
    obj : any
        The object to check.

    Returns
    -------
    bool
        True if the class of ``obj`` is named ``IntDistribution``,
        ``FloatDistribution`` or ``CategoricalDistribution``, False
        otherwise.

    """
    supported_names = (
        "IntDistribution",
        "FloatDistribution",
        "CategoricalDistribution",
    )
    observed_name = obj.__class__.__name__

    return any(observed_name == name for name in supported_names)


def _prepare_optuna_hyperparameters_distributions(
params_to_tune: Dict[str, Any],
) -> Dict[str, Any]:
Expand Down
33 changes: 33 additions & 0 deletions julearn/model_selection/tests/test_optuna_searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from julearn.model_selection._optuna_searcher import (
_prepare_optuna_hyperparameters_distributions,
is_optuna_valid_distribution,
)


Expand Down Expand Up @@ -164,3 +165,35 @@ def test__prepare_optuna_hyperparameters_distributions(
)
else:
pytest.fail("Invalid distribution type")


@pytest.mark.parametrize(
    "obj,expected",
    [
        # Genuine Optuna distribution objects must be recognised.
        *(
            (distribution, True)
            for distribution in (
                optd.IntDistribution(1, 20, log=False),
                optd.FloatDistribution(0.2, 0.7, log=False),
                optd.CategoricalDistribution([1, 2, 3]),
                optd.CategoricalDistribution(["a", "b", "c"]),
                optd.CategoricalDistribution(["a", "b", "c", "d"]),
            )
        ),
        # Plain strings, numbers and lists must be rejected.
        *(
            (other, False)
            for other in (
                "uniform",
                "log-uniform",
                "categorical",
                1,
                1.0,
                [1, 2, 3],
                ["a", "b", "c"],
            )
        ),
    ],
)
def test_optuna_valid_distributions(obj: Any, expected: bool) -> None:
    """Test the is_optuna_valid_distribution function.

    Parameters
    ----------
    obj : Any
        The object passed to the function under test.
    expected : bool
        The result the function is expected to return for ``obj``.

    """
    assert is_optuna_valid_distribution(obj) == expected
4 changes: 4 additions & 0 deletions julearn/pipeline/pipeline_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ..base import ColumnTypes, ColumnTypesLike, JuTransformer, WrapModel
from ..model_selection._optuna_searcher import (
_prepare_optuna_hyperparameters_distributions,
is_optuna_valid_distribution,
)
from ..model_selection._skopt_searcher import (
_prepare_skopt_hyperparameters_distributions,
Expand Down Expand Up @@ -263,6 +264,9 @@ def add(
# If it is a distribution, we will tune it.
logger.info(f"Tuning hyperparameter {param} = {vals}")
params_to_tune[param] = vals
elif is_optuna_valid_distribution(vals):
logger.info(f"Tuning hyperparameter {param} = {vals}")
params_to_tune[param] = vals
else:
logger.info(f"Setting hyperparameter {param} = {vals}")
params_to_set[param] = vals
Expand Down
2 changes: 1 addition & 1 deletion julearn/pipeline/tests/test_merger.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_merger_pipelines() -> None:
assert "scaler" == named_steps[1]
assert "rf" == named_steps[2]
assert len(merged.param_distributions) == 3 # type: ignore
assert merged.param_distributions[-1][
assert merged.param_distributions[-1][ # type: ignore
"rf__max_features"
] == [ # type: ignore
2,
Expand Down
2 changes: 1 addition & 1 deletion julearn/transformers/tests/test_jucolumntransformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def test_JuColumnTransformer_row_select():

mean_both = (
transformer_both.fit(X)
.column_transformer_.transformers_[0][1]
.column_transformer_.transformers_[0][1] # type: ignore
.mean_ # type: ignore
)

Expand Down

0 comments on commit be6a63b

Please sign in to comment.