Cluster-Kfold. Now with correct spelling! (#636)
* name

* progress

* update

* rename

* getting close

* Update sklego/model_selection.py

Co-authored-by: Francesco Bruzzesi <[email protected]>

---------

Co-authored-by: Francesco Bruzzesi <[email protected]>
koaning and FBruzzesi authored Mar 24, 2024
1 parent d7ae46c commit d321198
Showing 5 changed files with 78 additions and 6 deletions.
25 changes: 25 additions & 0 deletions docs/_scripts/cross-validation.py
@@ -204,3 +204,28 @@ def print_folds(cv, X, y, groups):
grid.best_estimator_.get_params()["reg__alpha"]
# 0.8
# --8<-- [end:grid-search]



######################################## ClusterKfold ####################################
##########################################################################################

# --8<-- [start:cluster-fold-start]
from sklego.model_selection import ClusterFoldValidation
from sklearn.cluster import KMeans

# Each KMeans cluster will end up as one validation fold.
clusterer = KMeans(n_clusters=5, random_state=42)
folder = ClusterFoldValidation(clusterer)
# --8<-- [end:cluster-fold-start]


# --8<-- [start:cluster-fold-plot]
import matplotlib.pylab as plt
import numpy as np

# Uniform 2D data: each color in the scatter plot is one validation fold.
X_orig = np.random.uniform(0, 1, (1000, 2))
for i, split in enumerate(folder.split(X_orig)):
    x_train, x_valid = split
    plt.scatter(X_orig[x_valid, 0], X_orig[x_valid, 1], label=f"split {i}")
plt.legend();
# --8<-- [end:cluster-fold-plot]
Binary file added docs/_static/cross-validation/kfold.png
35 changes: 35 additions & 0 deletions docs/user-guide/cross-validation.md
@@ -127,5 +127,40 @@ To use `GroupTimeSeriesSplit` with sklearn's [GridSearchCV](https://scikit-learn
--8<-- "docs/_scripts/cross-validation.py:grid-search"
```

## Cluster-Kfold

The [ClusterFoldValidation](clusterfold-api) object is a cross-validator that splits the data into `n_splits` folds, where each fold is determined by a clustering algorithm. This is not a common pattern, probably more like an anti-pattern really, but it might be useful when you want to make sure that the train and test sets are very distinct. This can be seen as a way to make it harder for the algorithm to perform well, because the training sets are sampled differently than the test sets.
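
To build some intuition first, here is a rough conceptual sketch of what such a splitter boils down to: cluster the rows, then hold out one cluster per fold. This is illustrative only and not the actual `ClusterFoldValidation` implementation.

```py title="Conceptual sketch of cluster-based folds"
import numpy as np
from sklearn.cluster import KMeans


def cluster_folds(X, clusterer):
    """Yield (train_idx, valid_idx) pairs where every validation set is one cluster."""
    labels = clusterer.fit_predict(X)
    indices = np.arange(len(X))
    for label in np.unique(labels):
        yield indices[labels != label], indices[labels == label]


X_demo = np.random.uniform(0, 1, (1000, 2))
for train_idx, valid_idx in cluster_folds(X_demo, KMeans(n_clusters=5, random_state=42)):
    print(len(train_idx), len(valid_idx))
```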

### Example

Here's how you could set up a cross-validator that uses KMeans.

```py title="Using Kmeans to generate folds"
--8<-- "docs/_scripts/cross-validation.py:cluster-fold-start"
```

You can also use other clustering methods, but the nice thing about KMeans is that it demos well. Here's how it would generate folds on a uniform dataset.

```py title="Plotting the folds generated by KMeans"
--8<-- "docs/_scripts/cross-validation.py:cluster-fold-plot"
```

![example-1](../_static/cross-validation/kfold.png)

As you can see, each split focuses on one cluster of the data. Hopefully this also makes it clear that this method ensures that each validation set is rather distinct from the train set. These sets are not only mutually exclusive, they are also drawn from different regions of the data by design.
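
If you want to double-check this, here is a quick sanity check. It is only a sketch and assumes the `folder` and `X_orig` from the plotting snippet above are still in scope: it verifies that the indices never overlap and prints where each validation fold sits in feature space.

```py title="Sanity check on the folds (sketch)"
for i, (train_idx, valid_idx) in enumerate(folder.split(X_orig)):
    # Train and validation indices never overlap ...
    assert set(train_idx).isdisjoint(valid_idx)
    # ... and each validation fold has its own centroid in feature space.
    print(f"split {i}: validation centroid = {X_orig[valid_idx].mean(axis=0).round(2)}")
```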

Note that this image is mostly for illustrative purposes because you typically won't directly generate these folds yourself. Instead you'd use a helper function like `cross_val_score` or `GridSearchCV` to do this for you.

```py title="More realistic example"
from sklearn.cluster import KMeans
from sklearn.model_selection import cross_val_score

from sklego.model_selection import ClusterFoldValidation

# Given an existing pipeline and X, y dataset, you probably would do something like this:
fold_method = ClusterFoldValidation(
    KMeans(n_clusters=5, random_state=42)
)
cross_val_score(pipeline, X, y, cv=fold_method)
```
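
The same fold method can be passed to `GridSearchCV` as well. The snippet below is a sketch that assumes a `pipeline` exposing a `reg__alpha` hyperparameter, like the grid-search example earlier on this page.

```py title="Grid search with cluster-based folds (sketch)"
from sklearn.model_selection import GridSearchCV

# `pipeline` and `reg__alpha` are assumed from the grid-search example above.
grid = GridSearchCV(
    pipeline,
    param_grid={"reg__alpha": [0.1, 0.5, 0.8, 1.0]},
    cv=fold_method,
)
grid.fit(X, y)
print(grid.best_estimator_.get_params()["reg__alpha"])
```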

[time-gap-split-api]: ../../api/model-selection#sklego.model_selection.TimeGapSplit
[group-ts-split-api]: ../../api/model-selection#sklego.model_selection.GroupTimeSeriesSplit
[clusterfold-api]: ../../api/model-selection#sklego.model_selection.ClusterFoldValidation
16 changes: 14 additions & 2 deletions sklego/model_selection.py
@@ -248,8 +248,20 @@ def get_split_info(X, indices, j, part, summary):
return pd.DataFrame(summary)


class KlusterFoldValidation:
    """KlusterFold cross validator. Create folds based on provided cluster method
def KlusterFoldValidation(**kwargs):
    warn(
        "Please use `ClusterFoldValidation` instead of `KlusterFoldValidation`."
        "We will use correct spelling going forward and `KlusterFoldValidation` will be deprecated.",
        DeprecationWarning,
    )
    return ClusterFoldValidation(**kwargs)


class ClusterFoldValidation:
    """Cross validator that creates folds based on provided cluster method.
    This ensures that data points in the same cluster are not split across different folds.

    !!! info "New in version 0.9.0"

    Parameters
    ----------
@@ -6,7 +6,7 @@
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

from sklego.model_selection import KlusterFoldValidation
from sklego.model_selection import ClusterFoldValidation
from tests.conftest import id_func

k_means_pipeline = make_pipeline(StandardScaler(), KMeans())
@@ -34,7 +34,7 @@ def fit_predict(self, X):
def test_splits_not_fitted(cluster_method, random_xy_dataset_regr):
    cluster_method = clone(cluster_method)
    X, y = random_xy_dataset_regr
    kf = KlusterFoldValidation(cluster_method=cluster_method)
    kf = ClusterFoldValidation(cluster_method=cluster_method)
    for train_index, test_index in kf.split(X):
        assert len(train_index) > 0
        assert len(test_index) > 0
@@ -49,7 +49,7 @@ def test_splits_fitted(cluster_method, random_xy_dataset_regr):
    cluster_method = clone(cluster_method)
    X, y = random_xy_dataset_regr
    cluster_method = cluster_method.fit(X)
    kf = KlusterFoldValidation(cluster_method=cluster_method)
    kf = ClusterFoldValidation(cluster_method=cluster_method)
    for train_index, test_index in kf.split(X):
        assert len(train_index) > 0
        assert len(test_index) > 0
@@ -59,7 +59,7 @@ def test_no_split(random_xy_dataset_regr):
    X, y = random_xy_dataset_regr
    # With only one split, the method should raise a ValueError
    cluster_method = DummyCluster(n_splits=1)
    kf = KlusterFoldValidation(cluster_method=cluster_method)
    kf = ClusterFoldValidation(cluster_method=cluster_method)
    with pytest.raises(ValueError):
        for train_index, test_index in kf.split(X):
            assert len(train_index) > 0
