Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API duck-typing for n_neighbors in CNN and deprecate estimator_ #891

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions doc/whats_new/v0.10.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,20 @@ Deprecation
estimator where `n_jobs` is set.
:pr:`887` by :user:`Guillaume Lemaitre <glemaitre>`.

- The fitted attribute `estimator_` in
:class:`~imblearn.under_sampling.CondensedNearestNeighbour`
has been deprecated and will be removed in 0.12. Instead, use the
`n_neighbors_` fitted attribute.
:pr:`891` by :user:`Guillaume Lemaitre <glemaitre>`.

Enhancements
............

- Add support to accept compatible `NearestNeighbors` objects by only
duck-typing. For instance, it allows to accept cuML instances.
:pr:`858` by :user:`NV-jpt <NV-jpt>` and
:user:`Guillaume Lemaitre <glemaitre>`.

- Add support to accept compatible `KNearestNeighbors` objects that can be
clone and have an attribute `n_neighbors`.
:pr:`891` by :user:`Guillaume Lemaitre <glemaitre>`.
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@
# License: MIT

from collections import Counter
from numbers import Integral

import numpy as np

from scipy.sparse import issparse

from sklearn.base import clone
from sklearn.base import clone, is_classifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.utils import check_random_state, _safe_indexing
from sklearn.utils.deprecation import deprecated

from ..base import BaseCleaningSampler
from ...utils import Substitution
Expand Down Expand Up @@ -58,9 +60,16 @@ class CondensedNearestNeighbour(BaseCleaningSampler):
corresponds to the class labels from which to sample and the values
are the number of samples to sample.

n_neighbors_ : estimator object
The validated K-nearest neighbor estimator created from `n_neighbors` parameter.

estimator_ : estimator object
The validated K-nearest neighbor estimator created from `n_neighbors` parameter.

.. deprecated:: 0.10
`estimator_` is deprecated in 0.10 and will be removed in 0.12.
Use `n_neighbors_` instead.

sample_indices_ : ndarray of shape (n_new_samples,)
Indices of the samples selected.

Expand Down Expand Up @@ -94,18 +103,17 @@ class CondensedNearestNeighbour(BaseCleaningSampler):

Examples
--------
>>> from collections import Counter # doctest: +SKIP
>>> from sklearn.datasets import fetch_mldata # doctest: +SKIP
>>> from collections import Counter
>>> from sklearn.datasets import load_breast_cancer
>>> from imblearn.under_sampling import \
CondensedNearestNeighbour # doctest: +SKIP
>>> pima = fetch_mldata('diabetes_scale') # doctest: +SKIP
>>> X, y = pima['data'], pima['target'] # doctest: +SKIP
>>> print('Original dataset shape %s' % Counter(y)) # doctest: +SKIP
Original dataset shape Counter({{1: 500, -1: 268}}) # doctest: +SKIP
>>> cnn = CondensedNearestNeighbour(random_state=42) # doctest: +SKIP
>>> X_res, y_res = cnn.fit_resample(X, y) #doctest: +SKIP
>>> print('Resampled dataset shape %s' % Counter(y_res)) # doctest: +SKIP
Resampled dataset shape Counter({{-1: 268, 1: 227}}) # doctest: +SKIP
CondensedNearestNeighbour
>>> X, y = load_breast_cancer(return_X_y=True)
>>> print('Original dataset shape %s' % Counter(y))
Original dataset shape Counter({{1: 357, 0: 212}})
>>> cnn = CondensedNearestNeighbour(random_state=42)
>>> X_res, y_res = cnn.fit_resample(X, y)
>>> print('Resampled dataset shape %s' % Counter(y_res))
Resampled dataset shape Counter({{0: 212, 1: 50}})
"""

@_deprecate_positional_args
Expand All @@ -125,20 +133,20 @@ def __init__(
self.n_jobs = n_jobs

def _validate_estimator(self):
"""Private function to create the NN estimator"""
if self.n_neighbors is None:
self.estimator_ = KNeighborsClassifier(n_neighbors=1, n_jobs=self.n_jobs)
elif isinstance(self.n_neighbors, int):
self.estimator_ = KNeighborsClassifier(
self.n_neighbors_ = KNeighborsClassifier(n_neighbors=1, n_jobs=self.n_jobs)
elif isinstance(self.n_neighbors, Integral):
self.n_neighbors_ = KNeighborsClassifier(
n_neighbors=self.n_neighbors, n_jobs=self.n_jobs
)
elif isinstance(self.n_neighbors, KNeighborsClassifier):
self.estimator_ = clone(self.n_neighbors)
elif is_classifier(self.n_neighbors) and hasattr(
self.n_neighbors, "n_neighbors"
):
self.n_neighbors_ = clone(self.n_neighbors)
else:
raise ValueError(
f"`n_neighbors` has to be a int or an object"
f" inhereited from KNeighborsClassifier."
f" Got {type(self.n_neighbors)} instead."
"`n_neighbors` must be an integer or a KNN classifier having an "
f"attribute `n_neighbors`. Got {self.n_neighbors!r} instead."
)

def _fit_resample(self, X, y):
Expand Down Expand Up @@ -175,7 +183,7 @@ def _fit_resample(self, X, y):
S_y = _safe_indexing(y, S_indices)

# fit knn on C
self.estimator_.fit(C_x, C_y)
self.n_neighbors_.fit(C_x, C_y)

good_classif_label = idx_maj_sample.copy()
# Check each sample in S if we keep it or drop it
Expand All @@ -188,7 +196,7 @@ def _fit_resample(self, X, y):
# Classify on S
if not issparse(x_sam):
x_sam = x_sam.reshape(1, -1)
pred_y = self.estimator_.predict(x_sam)
pred_y = self.n_neighbors_.predict(x_sam)

# If the prediction do not agree with the true label
# append it in C_x
Expand All @@ -202,12 +210,12 @@ def _fit_resample(self, X, y):
C_y = _safe_indexing(y, C_indices)

# fit a knn on C
self.estimator_.fit(C_x, C_y)
self.n_neighbors_.fit(C_x, C_y)

# This experimental to speed up the search
# Classify all the element in S and avoid to test the
# well classified elements
pred_S_y = self.estimator_.predict(S_x)
pred_S_y = self.n_neighbors_.predict(S_x)
good_classif_label = np.unique(
np.append(idx_maj_sample, np.flatnonzero(pred_S_y == S_y))
)
Expand All @@ -224,3 +232,11 @@ def _fit_resample(self, X, y):

def _more_tags(self):
return {"sample_indices": True}

@deprecated( # type: ignore
"`estimator_` is deprecated in version 0.10 and will be "
"removed in version 0.12. Use `n_neighbors_` instead."
)
@property
def estimator_(self):
return self.n_neighbors_
Original file line number Diff line number Diff line change
Expand Up @@ -101,5 +101,20 @@ def test_cnn_fit_resample_with_object():
def test_cnn_fit_resample_with_wrong_object():
knn = "rnd"
cnn = CondensedNearestNeighbour(random_state=RND_SEED, n_neighbors=knn)
with pytest.raises(ValueError, match="has to be a int or an "):
msg = "`n_neighbors` must be an integer or a KNN classifier"
with pytest.raises(ValueError, match=msg):
cnn.fit_resample(X, Y)


def test_cnn_estimator_deprecation():
cnn = CondensedNearestNeighbour(random_state=RND_SEED)
cnn.fit_resample(X, Y)

msg = "`estimator_` is deprecated in version 0.10"
with pytest.warns(FutureWarning, match=msg):
assert cnn.estimator_ == cnn.n_neighbors_


def test_cnn_custom_knn():
# FIXME: accept any arbitrary KNN classifier
pass