Skip to content

Commit

Permalink
Remove final model fit requirement for inspector
Browse files Browse the repository at this point in the history
  • Loading branch information
fraimondo committed Aug 30, 2024
1 parent be6a63b commit a9e09e3
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 5 deletions.
11 changes: 7 additions & 4 deletions julearn/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,11 @@ def run_cross_validation( # noqa: C901
)
if return_inspector:
if return_estimator is None:
logger.info("Inspector requested: setting return_estimator='all'")
return_estimator = "all"
if return_estimator != "all":
if return_estimator not in ["all", "cv"]:
raise_error(
"return_inspector=True requires return_estimator to be `all`."
"return_inspector=True requires return_estimator to be `all` "
"or `cv`"
)

X_types = {} if X_types is None else X_types
Expand Down Expand Up @@ -441,6 +441,9 @@ def run_cross_validation( # noqa: C901
groups=df_groups,
cv=cv_outer,
)
out = scores_df, pipeline, inspector
if isinstance(out, tuple):
out = (*out, inspector)
else:
out = out, inspector

return out
23 changes: 22 additions & 1 deletion julearn/inspect/tests/test_inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ def test_normal_usage(df_iris: "pd.DataFrame") -> None:
"""
X = list(df_iris.iloc[:, :-1].columns)
scores, pipe, inspect = run_cross_validation(

# All estimators
out = run_cross_validation(
X=X,
y="species",
data=df_iris,
Expand All @@ -63,13 +65,32 @@ def test_normal_usage(df_iris: "pd.DataFrame") -> None:
return_inspector=True,
problem_type="classification",
)
scores, pipe, inspect = out
assert pipe == inspect.model._model # type: ignore
for (_, score), inspect_fold in zip(
scores.iterrows(), # type: ignore
inspect.folds, # type: ignore
):
assert score["estimator"] == inspect_fold.model._model

del pipe
# only CV estimators
out = run_cross_validation(
X=X,
y="species",
data=df_iris,
model="svm",
return_estimator="cv",
return_inspector=True,
problem_type="classification",
)
scores, inspect = out
for (_, score), inspect_fold in zip(
scores.iterrows(), # type: ignore
inspect.folds, # type: ignore
):
assert score["estimator"] == inspect_fold.model._model


def test_normal_usage_with_search(df_iris: "pd.DataFrame") -> None:
"""Test inspector with search.
Expand Down

0 comments on commit a9e09e3

Please sign in to comment.