Cluster-Kfold. Now with correct spelling! #636

Merged
merged 9 commits on Mar 24, 2024
25 changes: 25 additions & 0 deletions docs/_scripts/cross-validation.py
@@ -204,3 +204,28 @@ def print_folds(cv, X, y, groups):
grid.best_estimator_.get_params()["reg__alpha"]
# 0.8
# --8<-- [end:grid-search]



######################################## ClusterKfold ####################################
##########################################################################################

# --8<-- [start:cluster-fold-start]
from sklego.model_selection import ClusterFoldValidation
from sklearn.cluster import KMeans

clusterer = KMeans(n_clusters=5, random_state=42)
folder = ClusterFoldValidation(clusterer)
# --8<-- [end:cluster-fold-start]


# --8<-- [start:cluster-fold-plot]
import matplotlib.pyplot as plt
import numpy as np

X_orig = np.random.uniform(0, 1, (1000, 2))  # uniform 2D data with no natural grouping
for i, split in enumerate(folder.split(X_orig)):
    x_train, x_valid = split
    plt.scatter(X_orig[x_valid, 0], X_orig[x_valid, 1], label=f"split {i}")
plt.legend();
# --8<-- [end:cluster-fold-plot]
Binary file added docs/_static/cross-validation/kfold.png
35 changes: 35 additions & 0 deletions docs/user-guide/cross-validation.md
@@ -127,5 +127,40 @@ To use `GroupTimeSeriesSplit` with sklearn's [GridSearchCV](https://scikit-learn
--8<-- "docs/_scripts/cross-validation.py:grid-search"
```

## Cluster-Kfold

The [ClusterFoldValidation](clusterfold-api) object is a cross-validator that splits the data into `n_splits` folds, where each fold is determined by a clustering algorithm. This is not a common pattern, probably more of an anti-pattern really, but it might be useful when you want to make sure that the train and test sets are very distinct. It can be seen as a way to make it harder for the algorithm to perform well, because the training sets are sampled differently than the test sets.

### Example

Here's how you could set up a cross-validator that uses KMeans.

```py title="Using Kmeans to generate folds"
--8<-- "docs/_scripts/cross-validation.py:cluster-fold-start"
```

You can also use other clustering methods, but the nice thing about KMeans is that it demos well. Here's how it would generate folds on a uniform dataset.

```py title="Using Kmeans to generate folds"
--8<-- "docs/_scripts/cross-validation.py:cluster-fold-plot"
```

![example-1](../_static/cross-validation/kfold.png)

As you can see, each split will focus on a cluster of the data. Hopefully this also makes it clear that this method will ensure that each validation set will be rather distinct from the train set. These sets are not only exclusive, but they are also from a different region of the data by design.
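
To make that concrete, here is a minimal sketch (not part of the original example) that reuses the `folder`, `clusterer`, and `X_orig` objects from the snippets above: it checks that train and validation indices never overlap and counts how many clusters each validation fold draws from.

```py title="Sanity-checking the folds (illustrative sketch)"
# Minimal sketch, reusing `folder`, `clusterer`, and `X_orig` from the snippets above.
labels = clusterer.fit_predict(X_orig)  # cluster label per row; deterministic here thanks to random_state

for i, (train_idx, valid_idx) in enumerate(folder.split(X_orig)):
    assert set(train_idx).isdisjoint(valid_idx)  # train and validation never overlap
    n_regions = len(set(labels[valid_idx]))      # clusters represented in this validation fold
    print(f"split {i}: {len(valid_idx)} validation rows drawn from {n_regions} cluster(s)")
```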

Note that this image is mostly for illustrative purposes because you typically won't directly generate these folds yourself. Instead you'd use a helper function like `cross_val_score` or `GridSearchCV` to do this for you.

```py title="More realistic example"
from sklearn.model_selection import cross_val_score

# Given an existing pipeline and X,y dataset, you probably would do something like this:
fold_method = KlusterFoldValidation(
KMeans(n_cluster=5, random_state=42)
)
cross_val_score(pipeline, X, y, cv=fold_method)
```
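
For something copy-pasteable, a self-contained variant could look as follows; the `Ridge` pipeline and the synthetic `X`, `y` below are illustrative stand-ins rather than part of the original example.

```py title="Self-contained variant (illustrative)"
import numpy as np
from sklearn.cluster import KMeans
from sklearn.linear_model import Ridge
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

from sklego.model_selection import ClusterFoldValidation

# Illustrative stand-ins for the "existing pipeline and X, y dataset" mentioned above.
X = np.random.uniform(0, 1, (1000, 2))
y = 2 * X[:, 0] + np.random.normal(0, 0.1, size=1000)
pipeline = make_pipeline(StandardScaler(), Ridge())

fold_method = ClusterFoldValidation(KMeans(n_clusters=5, random_state=42))
print(cross_val_score(pipeline, X, y, cv=fold_method))
```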

[time-gap-split-api]: ../../api/model-selection#sklego.model_selection.TimeGapSplit
[group-ts-split-api]: ../../api/model-selection#sklego.model_selection.GroupTimeSeriesSplit
[clusterfold-api]: ../../api/model-selection#sklego.model_selection.ClusterFoldValidation
13 changes: 11 additions & 2 deletions sklego/model_selection.py
@@ -248,8 +248,17 @@ def get_split_info(X, indices, j, part, summary):
    return pd.DataFrame(summary)


class KlusterFoldValidation:
"""KlusterFold cross validator. Create folds based on provided cluster method
def KlusterFoldValidation(**kwargs):
    warn(
        "Please use `ClusterFoldValidation` instead of `KlusterFoldValidation`. "
        "We will use correct spelling going forward and `KlusterFoldValidation` will be deprecated.",
        DeprecationWarning,
    )
    return ClusterFoldValidation(**kwargs)


class ClusterFoldValidation:
"""Cross validator that create folds based on provided cluster method.
koaning marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
Expand Down
@@ -6,7 +6,7 @@
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

from sklego.model_selection import KlusterFoldValidation
from sklego.model_selection import ClusterFoldValidation
from tests.conftest import id_func

k_means_pipeline = make_pipeline(StandardScaler(), KMeans())
@@ -34,7 +34,7 @@ def fit_predict(self, X):
def test_splits_not_fitted(cluster_method, random_xy_dataset_regr):
    cluster_method = clone(cluster_method)
    X, y = random_xy_dataset_regr
    kf = ClusterFoldValidation(cluster_method=cluster_method)
    for train_index, test_index in kf.split(X):
        assert len(train_index) > 0
        assert len(test_index) > 0
assert len(test_index) > 0
@@ -49,7 +49,7 @@ def test_splits_fitted(cluster_method, random_xy_dataset_regr):
    cluster_method = clone(cluster_method)
    X, y = random_xy_dataset_regr
    cluster_method = cluster_method.fit(X)
    kf = ClusterFoldValidation(cluster_method=cluster_method)
    for train_index, test_index in kf.split(X):
        assert len(train_index) > 0
        assert len(test_index) > 0
assert len(test_index) > 0
@@ -59,7 +59,7 @@ def test_no_split(random_xy_dataset_regr):
    X, y = random_xy_dataset_regr
    # With only one split, the method should raise a ValueError
    cluster_method = DummyCluster(n_splits=1)
    kf = ClusterFoldValidation(cluster_method=cluster_method)
    with pytest.raises(ValueError):
        for train_index, test_index in kf.split(X):
            assert len(train_index) > 0
assert len(train_index) > 0