diff --git a/julearn/model_selection/_skopt_searcher.py b/julearn/model_selection/_skopt_searcher.py index e0904b10c..3bd246a3b 100644 --- a/julearn/model_selection/_skopt_searcher.py +++ b/julearn/model_selection/_skopt_searcher.py @@ -2,11 +2,14 @@ # Authors: Federico Raimondo # License: AGPL +from typing import Any, Dict +from ..utils import logger from .available_searchers import _recreate_reset_copy, register_searcher try: + import skopt.space as sksp from skopt import BayesSearchCV except ImportError: from sklearn.model_selection._search import BaseSearchCV @@ -30,3 +33,62 @@ def register_bayes_searcher(): # Update the "reset copy" of available searchers _recreate_reset_copy() + + + +def _prepare_skopt_hyperparameters_distributions( + params_to_tune: Dict[str, Any], +) -> Dict[str, Any]: + """Prepare hyperparameters distributions for RandomizedSearchCV. + + This method replaces tuples with distributions for RandomizedSearchCV + following the skopt convention. That is, if a parameter is a tuple + with 3 elements, the first two elements are the bounds of the + distribution and the third element is the type of distribution. In case + the last element is "categorical", the parameter is considered + categorical and all the previous elements are the choices. + + Parameters + ---------- + params_to_tune : dict + The parameters to tune. + + Returns + ------- + dict + The modified parameters to tune. 
+ + """ + out = {} + for k, v in params_to_tune.items(): + if isinstance(v, tuple) and len(v) == 3: + prior = v[2] + if prior == "categorical": + logger.info(f"Hyperparameter {k} is categorical with 2 " + f"options: [{v[0]} and {v[1]}]") + out[k] = sksp.Categorical(v[:-1]) + elif isinstance(v[0], int) and isinstance(v[1], int): + logger.info( + f"Hyperparameter {k} is {prior} integer " + f"[{v[0]}, {v[1]}]" + ) + out[k] = sksp.Integer(v[0], v[1], prior=prior) + elif isinstance(v[0], float) and isinstance(v[1], float): + logger.info( + f"Hyperparameter {k} is {prior} float " + f"[{v[0]}, {v[1]}]" + ) + out[k] = sksp.Real(v[0], v[1], prior=prior) + else: + logger.info(f"Hyperparameter {k} as is {v}") + out[k] = v + elif ( + isinstance(v, tuple) + and isinstance(v[-1], str) + and v[-1] == "categorical" + ): + out[k] = sksp.Categorical(v[:-1]) + else: + logger.info(f"Hyperparameter {k} as is {v}") + out[k] = v + return out diff --git a/julearn/model_selection/tests/test_skopt_searcher.py b/julearn/model_selection/tests/test_skopt_searcher.py new file mode 100644 index 000000000..9e8c67fb0 --- /dev/null +++ b/julearn/model_selection/tests/test_skopt_searcher.py @@ -0,0 +1,133 @@ +"""Provides tests for the bayes searcher.""" + +# Authors: Federico Raimondo +# License: AGPL +from typing import Dict + +import pytest +import skopt.space as sksp + +from julearn.model_selection._skopt_searcher import ( + _prepare_skopt_hyperparameters_distributions, +) + + +@pytest.mark.parametrize( + "params_to_tune,expected_types, expected_dist", + [ + ( + { + "n_components": (0.2, 0.7, "uniform"), + "n_neighbors": (1.0, 10.0, "log-uniform"), + }, + ("float", "float"), + ("uniform", "log-uniform"), + ), + ( + { + "n_components": (1, 20, "uniform"), + "n_neighbors": (1, 10, "log-uniform"), + }, + ("int", "int", "int"), + ("uniform", "log-uniform"), + ), + ( + { + "options": (True, False, "categorical"), + "more_options": ("a", "b", "c", "d", "categorical"), + }, + (None, None), + 
@pytest.mark.parametrize(
    "params_to_tune,expected_types, expected_dist",
    [
        (
            {
                "n_components": (0.2, 0.7, "uniform"),
                "n_neighbors": (1.0, 10.0, "log-uniform"),
            },
            ("float", "float"),
            ("uniform", "log-uniform"),
        ),
        (
            {
                "n_components": (1, 20, "uniform"),
                "n_neighbors": (1, 10, "log-uniform"),
            },
            ("int", "int"),
            ("uniform", "log-uniform"),
        ),
        (
            {
                "options": (True, False, "categorical"),
                "more_options": ("a", "b", "c", "d", "categorical"),
            },
            (None, None),
            ("categorical", "categorical"),
        ),
        (
            {
                "n_components": sksp.Real(0.2, 0.7, prior="uniform"),
                "n_neighbors": sksp.Real(1.0, 10.0, prior="log-uniform"),
            },
            ("float", "float"),
            ("uniform", "log-uniform"),
        ),
        (
            {
                "n_components": sksp.Integer(1, 20, prior="uniform"),
                "n_neighbors": sksp.Integer(1, 10, prior="log-uniform"),
            },
            ("int", "int"),
            ("uniform", "log-uniform"),
        ),
        (
            {
                "options": sksp.Categorical([True, False]),
                "more_options": sksp.Categorical(
                    ("a", "b", "c", "d"),
                ),
            },
            (None, None),
            ("categorical", "categorical"),
        ),
    ],
)
def test__prepare_skopt_hyperparameters_distributions(
    params_to_tune: Dict[str, object],
    expected_types: tuple,
    expected_dist: tuple,
) -> None:
    """Test the _prepare_skopt_hyperparameters_distributions function.

    Parameters
    ----------
    params_to_tune : dict
        The parameters to tune. Values are either tuples following the
        skopt convention or already-built skopt dimensions.
    expected_types : tuple
        The expected types of each parameter ("int", "float" or None for
        categorical parameters).
    expected_dist : tuple
        The expected distributions of each parameter.

    """
    new_params = _prepare_skopt_hyperparameters_distributions(params_to_tune)

    # Without these checks the loop below would pass vacuously if the
    # function dropped parameters (e.g. returned an empty dict) or if the
    # parametrized expectations did not line up with the parameters.
    assert list(new_params) == list(params_to_tune)
    assert len(expected_types) == len(new_params)
    assert len(expected_dist) == len(new_params)

    # Numeric dimensions share the same checks; only the class differs.
    numeric_classes = {"int": sksp.Integer, "float": sksp.Real}

    for (k, v), exp_type, exp_dist in zip(
        new_params.items(), expected_types, expected_dist
    ):
        original = params_to_tune[k]
        if exp_type in numeric_classes:
            dim_cls = numeric_classes[exp_type]
            assert isinstance(v, dim_cls)
            assert v.prior == exp_dist
            if isinstance(original, tuple):
                lower, upper = original[0], original[1]
            else:
                # Pre-built dimensions must come through unchanged.
                assert isinstance(original, dim_cls)
                lower = original.bounds[0]  # type: ignore
                upper = original.bounds[1]  # type: ignore
                assert original.prior == v.prior  # type: ignore
            assert v.bounds[0] == lower  # type: ignore
            assert v.bounds[1] == upper  # type: ignore
        elif exp_dist == "categorical":
            assert isinstance(v, sksp.Categorical)
            if isinstance(original, tuple):
                # The trailing "categorical" marker is not a choice.
                expected_choices = original[:-1]
            else:
                assert isinstance(original, sksp.Categorical)
                expected_choices = original.categories  # type: ignore
            # Compare as sets of choices, in both directions.
            assert all(x in v.categories for x in expected_choices)
            assert all(x in expected_choices for x in v.categories)
        else:
            pytest.fail("Invalid distribution type")