From a9e09e3f8cd36703d29a1900c58a5b62fba821b4 Mon Sep 17 00:00:00 2001 From: Fede Raimondo Date: Fri, 30 Aug 2024 14:55:41 +0200 Subject: [PATCH] Remove final model fit requirement for inspector --- julearn/api.py | 11 +++++++---- julearn/inspect/tests/test_inspector.py | 23 ++++++++++++++++++++++- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/julearn/api.py b/julearn/api.py index 130b57e0a..17fe483a1 100644 --- a/julearn/api.py +++ b/julearn/api.py @@ -194,11 +194,11 @@ def run_cross_validation( # noqa: C901 ) if return_inspector: if return_estimator is None: - logger.info("Inspector requested: setting return_estimator='all'") return_estimator = "all" - if return_estimator != "all": + if return_estimator not in ["all", "cv"]: raise_error( - "return_inspector=True requires return_estimator to be `all`." + "return_inspector=True requires return_estimator to be `all` " + "or `cv`" ) X_types = {} if X_types is None else X_types @@ -441,6 +441,9 @@ def run_cross_validation( # noqa: C901 groups=df_groups, cv=cv_outer, ) - out = scores_df, pipeline, inspector + if isinstance(out, tuple): + out = (*out, inspector) + else: + out = out, inspector return out diff --git a/julearn/inspect/tests/test_inspector.py b/julearn/inspect/tests/test_inspector.py index 8643cee1d..6f069324b 100644 --- a/julearn/inspect/tests/test_inspector.py +++ b/julearn/inspect/tests/test_inspector.py @@ -54,7 +54,9 @@ def test_normal_usage(df_iris: "pd.DataFrame") -> None: """ X = list(df_iris.iloc[:, :-1].columns) - scores, pipe, inspect = run_cross_validation( + + # All estimators + out = run_cross_validation( X=X, y="species", data=df_iris, @@ -63,6 +65,7 @@ def test_normal_usage(df_iris: "pd.DataFrame") -> None: return_inspector=True, problem_type="classification", ) + scores, pipe, inspect = out assert pipe == inspect.model._model # type: ignore for (_, score), inspect_fold in zip( scores.iterrows(), # type: ignore @@ -70,6 +73,24 @@ def test_normal_usage(df_iris: "pd.DataFrame") -> None: ): assert score["estimator"] == inspect_fold.model._model + del pipe + # only CV estimators + out = run_cross_validation( + X=X, + y="species", + data=df_iris, + model="svm", + return_estimator="cv", + return_inspector=True, + problem_type="classification", + ) + scores, inspect = out + for (_, score), inspect_fold in zip( + scores.iterrows(), # type: ignore + inspect.folds, # type: ignore + ): + assert score["estimator"] == inspect_fold.model._model + def test_normal_usage_with_search(df_iris: "pd.DataFrame") -> None: """Test inspector with search.