diff --git a/julearn/api.py b/julearn/api.py index aa5ec06bb..3f00a82d5 100644 --- a/julearn/api.py +++ b/julearn/api.py @@ -26,7 +26,7 @@ from .utils.typing import CVLike -def _validata_api_params( +def _validata_api_params( # noqa: C901 X: List[str], # noqa: N803 y: str, model: Union[str, PipelineCreator, BaseEstimator, List[PipelineCreator]], @@ -116,6 +116,29 @@ def _validata_api_params( is provided for at least one hyperparameter, a search will be performed. + search_params : dict | None + Additional parameters in case Hyperparameter Tuning is performed, with + the following keys: + + * 'kind': The kind of search algorithm to use, Valid options are: + + * ``"grid"`` : :class:`~sklearn.model_selection.GridSearchCV` + * ``"random"`` : + :class:`~sklearn.model_selection.RandomizedSearchCV` + * ``"bayes"`` : :class:`~skopt.BayesSearchCV` + * ``"optuna"`` : + :class:`~optuna_integration.OptunaSearchCV` + * user-registered searcher name : see + :func:`~julearn.model_selection.register_searcher` + * ``scikit-learn``-compatible searcher + + * 'cv': If a searcher is going to be used, the cross-validation + splitting strategy to use. Defaults to same CV as for the model + evaluation. + * 'scoring': If a searcher is going to be used, the scoring metric to + evaluate the performance. + + See :ref:`hp_tuning` for details. seed : int | None If not None, set the random seed before any operation. Useful for reproducibility. @@ -138,6 +161,7 @@ def _validata_api_params( Whether to wrap the score or not. problem_type : str The problem type. + """ if return_estimator not in [None, "final", "cv", "all"]: raise_error(