Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix usage of optuna distributions #268

Merged
merged 3 commits into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/changes/newsfragments/268.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix usage of optuna distributions as hyperparameters in older versions of scikit-learn by `Fede Raimondo`_.
23 changes: 23 additions & 0 deletions julearn/model_selection/_optuna_searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,29 @@ def register_optuna_searcher():
_recreate_reset_copy()


def is_optuna_valid_distribution(obj: Any) -> bool:
    """Check if an object is a valid Optuna distribution.

    The check is done by class name rather than ``isinstance`` so that
    optuna does not need to be imported for the comparison.

    Parameters
    ----------
    obj : any
        The object to check.

    Returns
    -------
    bool
        Whether the object is a valid Optuna distribution.

    """
    # Names of the optuna distribution classes accepted as hyperparameter
    # specifications.
    return type(obj).__name__ in {
        "IntDistribution",
        "FloatDistribution",
        "CategoricalDistribution",
    }


def _prepare_optuna_hyperparameters_distributions(
params_to_tune: Dict[str, Any],
) -> Dict[str, Any]:
Expand Down
33 changes: 33 additions & 0 deletions julearn/model_selection/tests/test_optuna_searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from julearn.model_selection._optuna_searcher import (
_prepare_optuna_hyperparameters_distributions,
is_optuna_valid_distribution,
)


Expand Down Expand Up @@ -164,3 +165,35 @@ def test__prepare_optuna_hyperparameters_distributions(
)
else:
pytest.fail("Invalid distribution type")


@pytest.mark.parametrize(
    "obj,expected",
    [
        (optd.IntDistribution(1, 20, log=False), True),
        (optd.FloatDistribution(0.2, 0.7, log=False), True),
        (optd.CategoricalDistribution([1, 2, 3]), True),
        (optd.CategoricalDistribution(["a", "b", "c"]), True),
        (optd.CategoricalDistribution(["a", "b", "c", "d"]), True),
        ("uniform", False),
        ("log-uniform", False),
        ("categorical", False),
        (1, False),
        (1.0, False),
        ([1, 2, 3], False),
        (["a", "b", "c"], False),
    ],
)
def test_optuna_valid_distributions(obj: Any, expected: bool) -> None:
    """Test detection of valid optuna distribution objects.

    Parameters
    ----------
    obj : Any
        The object to check.
    expected : bool
        The expected result.

    """
    # Real optuna distributions must be detected; plain strings, numbers
    # and lists must not.
    assert is_optuna_valid_distribution(obj) == expected
4 changes: 4 additions & 0 deletions julearn/pipeline/pipeline_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ..base import ColumnTypes, ColumnTypesLike, JuTransformer, WrapModel
from ..model_selection._optuna_searcher import (
_prepare_optuna_hyperparameters_distributions,
is_optuna_valid_distribution,
)
from ..model_selection._skopt_searcher import (
_prepare_skopt_hyperparameters_distributions,
Expand Down Expand Up @@ -263,6 +264,9 @@ def add(
# If it is a distribution, we will tune it.
logger.info(f"Tuning hyperparameter {param} = {vals}")
params_to_tune[param] = vals
elif is_optuna_valid_distribution(vals):
logger.info(f"Tuning hyperparameter {param} = {vals}")
params_to_tune[param] = vals
else:
logger.info(f"Setting hyperparameter {param} = {vals}")
params_to_set[param] = vals
Expand Down
2 changes: 1 addition & 1 deletion julearn/pipeline/tests/test_merger.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_merger_pipelines() -> None:
assert "scaler" == named_steps[1]
assert "rf" == named_steps[2]
assert len(merged.param_distributions) == 3 # type: ignore
assert merged.param_distributions[-1][
assert merged.param_distributions[-1][ # type: ignore
"rf__max_features"
] == [ # type: ignore
2,
Expand Down
2 changes: 1 addition & 1 deletion julearn/transformers/tests/test_jucolumntransformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def test_JuColumnTransformer_row_select():

mean_both = (
transformer_both.fit(X)
.column_transformer_.transformers_[0][1]
.column_transformer_.transformers_[0][1] # type: ignore
.mean_ # type: ignore
)

Expand Down
Loading