-
Notifications
You must be signed in to change notification settings - Fork 120
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ENH] Sort out clustering base class #2251
base: main
Are you sure you want to change the base?
Changes from 15 commits
2306ebd
0035e49
2c4fd41
04c385b
a346080
60615dd
aad00ef
bed81be
24732d2
e83387c
30fbda9
3557ac4
2a5807d
42e5a16
3f80902
63ed72d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,5 @@ | ||
"""Base class for clustering.""" | ||
|
||
from typing import Optional | ||
|
||
__maintainer__ = [] | ||
__all__ = ["BaseClusterer"] | ||
|
||
|
@@ -28,8 +26,7 @@ class BaseClusterer(BaseCollectionEstimator): | |
"fit_is_empty": False, | ||
} | ||
|
||
def __init__(self, n_clusters: Optional[int] = None): | ||
self.n_clusters = n_clusters | ||
def __init__(self): | ||
# required for compatibility with some sklearn interfaces e.g. | ||
# CalibratedClassifierCV | ||
self._estimator_type = "clusterer" | ||
|
@@ -125,6 +122,7 @@ def predict_proba(self, X) -> np.ndarray: | |
self._check_shape(X) | ||
return self._predict_proba(X) | ||
|
||
@final | ||
def fit_predict(self, X, y=None) -> np.ndarray: | ||
"""Compute cluster centers and predict cluster index for each time series. | ||
|
||
|
@@ -143,11 +141,10 @@ def fit_predict(self, X, y=None) -> np.ndarray: | |
np.ndarray (1d array of shape (n_cases,)) | ||
Index of the cluster each time series in X belongs to. | ||
""" | ||
self.fit(X) | ||
return self.predict(X) | ||
return self._fit_predict(X, y) | ||
|
||
def score(self, X, y=None) -> float: | ||
"""Score the quality of the clusterer. | ||
def _fit_predict(self, X, y=None) -> np.ndarray: | ||
"""Fit predict using base methods. | ||
|
||
Parameters | ||
---------- | ||
|
@@ -159,13 +156,13 @@ def score(self, X, y=None) -> float: | |
|
||
Returns | ||
------- | ||
score : float | ||
Score of the clusterer. | ||
np.ndarray (1d array of shape (n_cases,)) | ||
Index of the cluster each time series in X belongs to. | ||
""" | ||
self._check_is_fitted() | ||
X = self._preprocess_collection(X, store_metadata=False) | ||
self._check_shape(X) | ||
return self._score(X, y) | ||
self.fit(X) | ||
if hasattr(self, "labels_"): | ||
return self.labels_ | ||
return self.predict(X) | ||
|
||
def _predict_proba(self, X) -> np.ndarray: | ||
"""Predicts labels probabilities for sequences in X. | ||
|
@@ -198,17 +195,17 @@ def _predict_proba(self, X) -> np.ndarray: | |
for i, u in enumerate(unique): | ||
preds[preds == u] = i | ||
n_cases = len(preds) | ||
n_clusters = self.n_clusters | ||
if hasattr(self, "n_clusters"): | ||
n_clusters = self.n_clusters | ||
else: | ||
n_clusters = len(np.unique(preds)) | ||
Comment on lines
+198
to
+201
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Think this could be risky down the line if there are non-int methods of generating |
||
if n_clusters is None: | ||
n_clusters = int(max(preds)) + 1 | ||
dists = np.zeros((X.shape[0], n_clusters)) | ||
dists = np.zeros((len(X), n_clusters)) | ||
for i in range(n_cases): | ||
dists[i, preds[i]] = 1 | ||
return dists | ||
|
||
@abstractmethod | ||
def _score(self, X, y=None): ... | ||
|
||
@abstractmethod | ||
def _predict(self, X) -> np.ndarray: | ||
"""Predict the closest cluster each sample in X belongs to. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -173,10 +173,10 @@ def __init__( | |
self.save_last_model = save_last_model | ||
self.best_file_name = best_file_name | ||
self.random_state = random_state | ||
self.n_clusters = n_clusters | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can keep this for now, but I think @hadifawaz1999 said these would be better removed ( There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i think it should not be added as we're gonna remove it in a few days |
||
|
||
super().__init__( | ||
estimator=estimator, | ||
n_clusters=n_clusters, | ||
clustering_algorithm=clustering_algorithm, | ||
clustering_params=clustering_params, | ||
batch_size=batch_size, | ||
|
@@ -336,12 +336,6 @@ def _fit(self, X): | |
|
||
return self | ||
|
||
def _score(self, X, y=None): | ||
# Transpose to conform to Keras input style. | ||
X = X.transpose(0, 2, 1) | ||
latent_space = self.model_.layers[1].predict(X) | ||
return self._estimator.score(latent_space) | ||
|
||
def _fit_multi_rec_model( | ||
self, | ||
autoencoder, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ideally, everything should have this.