Skip to content

Commit

Permalink
Modify logic for final model training
Browse files Browse the repository at this point in the history
  • Loading branch information
fraimondo committed Sep 26, 2024
1 parent eb7207f commit ebb0297
Show file tree
Hide file tree
Showing 6 changed files with 229 additions and 14 deletions.
38 changes: 25 additions & 13 deletions julearn/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@
import sklearn
from sklearn.base import BaseEstimator
from sklearn.model_selection import (
check_cv,
cross_validate,
)
from sklearn.model_selection._search import BaseSearchCV
from sklearn.pipeline import Pipeline

from .inspect import Inspector
from .model_selection.utils import check_cv
from .pipeline import PipelineCreator
from .pipeline.merger import merge_pipelines
from .prepare import check_consistency, prepare_input_data
Expand Down Expand Up @@ -541,16 +541,19 @@ def run_cross_validation(
seed=seed,
)

include_final_model = return_estimator in ["final", "all"]
cv_return_estimator = return_estimator in ["cv", "all", "final"]

# Prepare cross validation
cv_outer = check_cv(
cv, # type: ignore
classifier=problem_type == "classification",
include_final_model=include_final_model,
)
logger.info(f"Using outer CV scheme {cv_outer}")

check_consistency(df_y, cv, groups, problem_type) # type: ignore

cv_return_estimator = return_estimator in ["cv", "all"]
scoring = check_scoring(
pipeline, # type: ignore
scoring,
Expand Down Expand Up @@ -583,30 +586,39 @@ def run_cross_validation(
**_sklearn_deprec_fit_params,
)

n_repeats = getattr(cv_outer, "n_repeats", 1)
n_folds = len(scores["fit_time"]) // n_repeats

repeats = np.repeat(np.arange(n_repeats), n_folds)
folds = np.tile(np.arange(n_folds), n_repeats)

fold_sizes = np.array(
[
list(map(len, x))
for x in cv_outer.split(df_X, df_y, groups=df_groups)
]
)

if include_final_model:
# If we include the final model, we need to remove the last item in
# the scores as this is the final model
pipeline = scores["estimator"][-1]
if return_estimator == "final":
scores.pop("estimator")
scores = {k: v[:-1] for k, v in scores.items()}
fold_sizes = fold_sizes[:-1]

n_repeats = getattr(cv_outer, "n_repeats", 1)
n_folds = len(scores["fit_time"]) // n_repeats

repeats = np.repeat(np.arange(n_repeats), n_folds)
folds = np.tile(np.arange(n_folds), n_repeats)

scores["n_train"] = fold_sizes[:, 0]
scores["n_test"] = fold_sizes[:, 1]
scores["repeat"] = repeats
scores["fold"] = folds
scores["cv_mdsum"] = cv_mdsum

scores_df = pd.DataFrame(scores)

out = scores_df
if return_estimator in ["final", "all"]:
logger.info("Fitting final model")
pipeline.fit(df_X, df_y, **fit_params)
out = scores_df, pipeline
if include_final_model:
out = out, pipeline

if return_inspector:
inspector = Inspector(
Expand All @@ -615,7 +627,7 @@ def run_cross_validation(
X=df_X,
y=df_y,
groups=df_groups,
cv=cv_outer,
cv=cv_outer.cv if include_final_model else cv_outer,
)
if isinstance(out, tuple):
out = (*out, inspector)
Expand Down
93 changes: 93 additions & 0 deletions julearn/model_selection/final_model_cv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""CV Wrapper that includes a fold with all the data."""

# Authors: Federico Raimondo <[email protected]>
# License: AGPL
from typing import TYPE_CHECKING, Generator, Optional, Tuple

import numpy as np


if TYPE_CHECKING:
from sklearn.model_selection import BaseCrossValidator


class _JulearnFinalModelCV:
    """Final model cross-validation iterator.

    Wraps any CV iterator to provide an extra iteration with the full dataset.

    Parameters
    ----------
    cv : BaseCrossValidator
        The cross-validation iterator to wrap.

    """

    def __init__(self, cv: "BaseCrossValidator") -> None:
        self.cv = cv

    def split(
        self,
        X: np.ndarray,  # noqa: N803
        y: Optional[np.ndarray] = None,
        groups: Optional[np.ndarray] = None,
    ) -> Generator[Tuple[np.ndarray, np.ndarray], None, None]:
        """Generate indices to split data into training and test set.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where n_samples is the number of samples
            and n_features is the number of features.
            Only ``len(X)`` is used by this wrapper itself; the object is
            otherwise forwarded untouched to the wrapped CV.
        y : array-like of shape (n_samples,), default=None
            The target variable for supervised learning problems.
        groups : array-like of shape (n_samples,), default=None
            Group labels for the samples used while splitting the dataset into
            train/test set.

        Yields
        ------
        train : ndarray
            The training set indices for that split.
        test : ndarray
            The testing set indices for that split.

        Notes
        -----
        This CV splitter generates one extra fold, after all the folds of the
        wrapped CV, in which the full dataset is used for training and only
        the first two samples are used for testing. This is useful to train
        the final model on the full dataset at the same time as the
        cross-validation, profiting from the same joblib calls.

        """
        # First, every fold of the wrapped CV, unchanged and in order.
        yield from self.cv.split(X, y, groups)
        all_inds = np.arange(len(X))
        # For the last fold, train on all samples and return only 2 for
        # testing: the scores of this fold are meant to be discarded, so the
        # test set is kept as small as possible.
        yield all_inds, all_inds[:2]

    def get_n_splits(
        self,
        X: Optional[np.ndarray] = None,  # noqa: N803
        y: Optional[np.ndarray] = None,
        groups: Optional[np.ndarray] = None,
    ) -> int:
        """Get the number of splits.

        The arguments are optional and are forwarded to the wrapped CV, so
        that this wrapper follows the scikit-learn splitter signature and
        also works with wrapped CVs whose number of splits depends on the
        data.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features), default=None
            Training data, forwarded to the wrapped CV.
        y : array-like of shape (n_samples,), default=None
            The target variable, forwarded to the wrapped CV.
        groups : array-like of shape (n_samples,), default=None
            Group labels for the samples, forwarded to the wrapped CV.

        Returns
        -------
        int
            The number of splits of the wrapped CV plus one (the extra
            final-model fold).

        """
        return self.cv.get_n_splits(X, y, groups) + 1

    def __repr__(self) -> str:
        """Return the representation of the object.

        Returns
        -------
        str
            The representation of the object.

        """
        return f"{self.cv} (incl. final model)"
51 changes: 51 additions & 0 deletions julearn/model_selection/tests/test_final_model_cv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Provides tests for the final model CV."""

# Authors: Federico Raimondo <[email protected]>
# License: AGPL
import numpy as np
from numpy.testing import assert_array_equal
from sklearn.model_selection import RepeatedStratifiedKFold

from julearn.model_selection.final_model_cv import _JulearnFinalModelCV
from julearn.utils import _compute_cvmdsum


def test_final_model_cv() -> None:
    """Check that the wrapper replays the wrapped CV and adds a final fold."""
    n_samples, n_features = 123, 10
    X = np.zeros((n_samples, n_features))
    y = np.zeros(n_samples)

    sklearn_cv = RepeatedStratifiedKFold(
        n_repeats=2, n_splits=5, random_state=42
    )
    julearn_cv = _JulearnFinalModelCV(sklearn_cv)

    # 2 repeats x 5 splits from the wrapped CV, plus the final-model fold.
    assert julearn_cv.get_n_splits() == 11

    all_sk = list(sklearn_cv.split(X, y))
    all_ju = list(julearn_cv.split(X, y))
    assert len(all_ju) == len(all_sk) + 1

    # zip stops at the shorter list, so only the wrapped CV folds are
    # compared here; the extra fold is checked separately below.
    for (ju_train, ju_test), (sk_train, sk_test) in zip(all_ju, all_sk):
        assert_array_equal(ju_train, sk_train)
        assert_array_equal(ju_test, sk_test)

    final_train, final_test = all_ju[-1]
    assert final_train.shape[0] == n_samples
    assert final_test.shape[0] == 2
    assert_array_equal(final_train, np.arange(n_samples))


def test_final_model_cv_mdsum() -> None:
    """Check that wrapping a CV does not change its mdsum."""
    wrapped_cv = RepeatedStratifiedKFold(
        n_repeats=2, n_splits=5, random_state=42
    )
    final_model_cv = _JulearnFinalModelCV(wrapped_cv)

    # The wrapper must hash exactly like the CV it wraps.
    expected = _compute_cvmdsum(wrapped_cv)
    assert _compute_cvmdsum(final_model_cv) == expected

Check failure on line 51 in julearn/model_selection/tests/test_final_model_cv.py

View workflow job for this annotation

GitHub Actions / lint

Ruff (W292)

julearn/model_selection/tests/test_final_model_cv.py:51:29: W292 No newline at end of file

Check failure on line 51 in julearn/model_selection/tests/test_final_model_cv.py

View workflow job for this annotation

GitHub Actions / lint

Ruff (W292)

julearn/model_selection/tests/test_final_model_cv.py:51:29: W292 No newline at end of file
47 changes: 47 additions & 0 deletions julearn/model_selection/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Utility functions for model selection in julearn."""

# Authors: Federico Raimondo <[email protected]>
# License: AGPL

from sklearn.model_selection import check_cv as sk_check_cv

from .final_model_cv import _JulearnFinalModelCV


def check_cv(cv=5, classifier=False, include_final_model=False):
    """Validate the CV specification and build the splitter julearn uses.

    The specification is first resolved by scikit-learn's ``check_cv``; the
    result is optionally wrapped so that it yields one extra fold for the
    final model.

    Parameters
    ----------
    cv : int, cross-validation generator, iterable or None, default=5
        Cross-validation splitting strategy to use for model evaluation,
        forwarded to scikit-learn's ``check_cv``.
        Options are:

        * None: defaults to 5-fold
        * int: the number of folds in a `(Stratified)KFold`
        * CV Splitter (see scikit-learn documentation on CV)
        * An iterable yielding (train, test) splits as arrays of indices.

    classifier : bool, default=False
        Whether the task is a classification task, forwarded to
        scikit-learn's ``check_cv``.
    include_final_model : bool, default=False
        Whether to include the final model in the cross-validation. If True,
        the checked CV is wrapped in ``_JulearnFinalModelCV``, which adds one
        more fold where the full dataset is used for training.

    Returns
    -------
    checked_cv : a cross-validator instance.
        The return value is a cross-validator which generates the train/test
        splits via the ``split`` method.

    """
    checked_cv = sk_check_cv(cv, classifier=classifier)
    if not include_final_model:
        return checked_cv
    return _JulearnFinalModelCV(checked_cv)
5 changes: 5 additions & 0 deletions julearn/utils/_cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
ContinuousStratifiedGroupKFold,
RepeatedContinuousStratifiedGroupKFold,
)
from ..model_selection.final_model_cv import _JulearnFinalModelCV


def _recurse_to_list(a):
Expand All @@ -40,6 +41,9 @@ def _recurse_to_list(a):

def _compute_cvmdsum(cv):
"""Compute the sum of the CV generator."""
if isinstance(cv, _JulearnFinalModelCV):
return _compute_cvmdsum(cv.cv)

params = dict(vars(cv).items())
params["class"] = cv.__class__.__name__

Expand All @@ -59,6 +63,7 @@ def _compute_cvmdsum(cv):
params["test_fold"] = params["test_fold"].tolist()
params["unique_folds"] = params["unique_folds"].tolist()


if "cv" in params:
if inspect.isclass(params["cv"]):
params["cv"] = params["cv"].__class__.__name__
Expand Down
9 changes: 8 additions & 1 deletion julearn/utils/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from sklearn.model_selection import BaseCrossValidator, BaseShuffleSplit
from sklearn.model_selection._split import _RepeatedSplits

from ..model_selection.final_model_cv import _JulearnFinalModelCV


try: # sklearn >= 1.4.0
from sklearn.metrics._scorer import _Scorer # type: ignore
Expand Down Expand Up @@ -387,5 +389,10 @@ def get_apply_to(self) -> ColumnTypes:


CVLike = Union[
int, BaseCrossValidator, _RepeatedSplits, BaseShuffleSplit, Iterable
int,
BaseCrossValidator,
_RepeatedSplits,
BaseShuffleSplit,
Iterable,
_JulearnFinalModelCV,
]

0 comments on commit ebb0297

Please sign in to comment.