Commit
Modify logic for final model training
Showing 6 changed files with 229 additions and 14 deletions.
julearn/model_selection/final_model_cv.py
@@ -0,0 +1,93 @@
"""CV Wrapper that includes a fold with all the data."""

# Authors: Federico Raimondo <[email protected]>
# License: AGPL
from typing import TYPE_CHECKING, Generator, Optional, Tuple

import numpy as np


if TYPE_CHECKING:
    from sklearn.model_selection import BaseCrossValidator


class _JulearnFinalModelCV:
    """Final model cross-validation iterator.

    Wraps any CV iterator to provide an extra iteration with the full dataset.

    Parameters
    ----------
    cv : BaseCrossValidator
        The cross-validation iterator to wrap.

    """

    def __init__(self, cv: "BaseCrossValidator") -> None:
        self.cv = cv

    def split(
        self,
        X: np.ndarray,  # noqa: N803
        y: np.ndarray,
        groups: Optional[np.ndarray] = None,
    ) -> Generator[Tuple[np.ndarray, np.ndarray], None, None]:
        """Generate indices to split data into training and test set.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where n_samples is the number of samples
            and n_features is the number of features.
            Note that providing ``y`` is sufficient to generate the splits
            and hence ``np.zeros(n_samples)`` may be used as a placeholder
            for ``X`` instead of actual training data.
        y : array-like of shape (n_samples,)
            The target variable for supervised learning problems.
        groups : array-like of shape (n_samples,), default=None
            Group labels for the samples used while splitting the dataset
            into train/test set.

        Yields
        ------
        train : ndarray
            The training set indices for that split.
        test : ndarray
            The testing set indices for that split.

        Notes
        -----
        This CV splitter will generate an extra fold where the full dataset
        is used for training and testing. This is useful to train the final
        model on the full dataset at the same time as the cross-validation,
        profiting from the same joblib calls.

        """
        yield from self.cv.split(X, y, groups)
        all_inds = np.arange(len(X))
        # For the last fold, train on all samples and return only 2 samples
        # for testing.
        yield all_inds, all_inds[:2]

    def get_n_splits(self) -> int:
        """Get the number of splits.

        Returns
        -------
        int
            The number of splits.

        """
        return self.cv.get_n_splits() + 1

    def __repr__(self) -> str:
        """Return the representation of the object.

        Returns
        -------
        str
            The representation of the object.

        """
        return f"{self.cv} (incl. final model)"
julearn/model_selection/tests/test_final_model_cv.py
@@ -0,0 +1,51 @@
"""Provides tests for the final model CV."""

# Authors: Federico Raimondo <[email protected]>
# License: AGPL
import numpy as np
from numpy.testing import assert_array_equal
from sklearn.model_selection import RepeatedStratifiedKFold

from julearn.model_selection.final_model_cv import _JulearnFinalModelCV
from julearn.utils import _compute_cvmdsum


def test_final_model_cv() -> None:
    """Test the final model CV."""
    sklearn_cv = RepeatedStratifiedKFold(
        n_repeats=2, n_splits=5, random_state=42
    )

    julearn_cv = _JulearnFinalModelCV(sklearn_cv)

    assert julearn_cv.get_n_splits() == 11

    n_features = 10
    n_samples = 123
    X = np.zeros((n_samples, n_features))
    y = np.zeros(n_samples)

    all_ju = list(julearn_cv.split(X, y))
    all_sk = list(sklearn_cv.split(X, y))

    assert len(all_ju) == len(all_sk) + 1
    for i in range(10):
        assert_array_equal(all_ju[i][0], all_sk[i][0])
        assert_array_equal(all_ju[i][1], all_sk[i][1])

    assert all_ju[-1][0].shape[0] == n_samples
    assert all_ju[-1][1].shape[0] == 2
    assert_array_equal(all_ju[-1][0], np.arange(n_samples))


def test_final_model_cv_mdsum() -> None:
    """Test the mdsum of the final model CV."""
    sklearn_cv = RepeatedStratifiedKFold(
        n_repeats=2, n_splits=5, random_state=42
    )

    julearn_cv = _JulearnFinalModelCV(sklearn_cv)

    mdsum = _compute_cvmdsum(julearn_cv)
    mdsum_sk = _compute_cvmdsum(sklearn_cv)
    assert mdsum == mdsum_sk
Check failure on line 51 in julearn/model_selection/tests/test_final_model_cv.py (GitHub Actions / lint): Ruff W292, no newline at end of file.
@@ -0,0 +1,47 @@
"""Utility functions for model selection in julearn."""

# Authors: Federico Raimondo <[email protected]>
# License: AGPL

from sklearn.model_selection import check_cv as sk_check_cv

from .final_model_cv import _JulearnFinalModelCV


def check_cv(cv=5, classifier=False, include_final_model=False):
    """Check the CV instance and return the proper CV for julearn.

    Parameters
    ----------
    cv : int, str or cross-validation generator | None
        Cross-validation splitting strategy to use for model evaluation.
        Options are:

        * None: defaults to 5-fold
        * int: the number of folds in a `(Stratified)KFold`
        * CV Splitter (see scikit-learn documentation on CV)
        * An iterable yielding (train, test) splits as arrays of indices.

    classifier : bool, default=False
        Whether the task is a classification task, in which case
        stratified KFold will be used.
    include_final_model : bool, default=False
        Whether to include the final model in the cross-validation. If true,
        one more fold will be added to the cross-validation, where the full
        dataset is used for training and testing.

    Returns
    -------
    checked_cv : a cross-validator instance.
        The return value is a cross-validator which generates the train/test
        splits via the ``split`` method.

    """
    cv = sk_check_cv(cv, classifier=classifier)
    if include_final_model:
        cv = _JulearnFinalModelCV(cv)

    return cv
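A rough usage sketch of the ``include_final_model`` flag: it simply decides whether the checked CV gets wrapped in ``_JulearnFinalModelCV``. The import path ``julearn.model_selection.utils`` below is an assumption, since the file name of this utility module is not visible in the captured diff:

from sklearn.model_selection import StratifiedKFold

# Assumed import path; the module name is not shown in this diff.
from julearn.model_selection.utils import check_cv

plain_cv = check_cv(StratifiedKFold(n_splits=5), classifier=True)
final_cv = check_cv(
    StratifiedKFold(n_splits=5), classifier=True, include_final_model=True
)

print(plain_cv.get_n_splits())  # 5
print(final_cv.get_n_splits())  # 6: one extra fold for the final model
print(final_cv)  # "StratifiedKFold(...) (incl. final model)"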