Skip to content

Commit

Permalink
dask compatible kmeans
Browse files Browse the repository at this point in the history
  • Loading branch information
noahnovsak committed May 15, 2023
1 parent 2fe829d commit 2733715
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 8 deletions.
26 changes: 26 additions & 0 deletions Orange/clustering/kmeans.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import warnings
from typing import Union

import numpy as np
import dask.array as da
import sklearn.cluster

from Orange.clustering.clustering import Clustering, ClusteringModel
from Orange.data import Table
from Orange.data.dask import DaskTable


__all__ = ["KMeans"]
Expand Down Expand Up @@ -37,6 +41,28 @@ def __init__(self, n_clusters=8, init='k-means++', n_init=10, max_iter=300,
preprocessors, {k: v for k, v in vars().items()
if k != "compute_silhouette_score"})

def fit(self, X: Union[np.ndarray, da.Array], y: np.ndarray = None):
params = self.params.copy()
__wraps__ = self.__wraps__
if isinstance(X, da.Array):
try:
import dask_ml.cluster

del params["n_init"]
assert params["init"] == "k-means||"

X = X.rechunk({0: "auto", 1: -1})
__wraps__ = dask_ml.cluster.KMeans

except ImportError:
warnings.warn("dask_ml is not installed. Using sklearn instead.")

return self.__returns__(__wraps__(**params).fit(X))

def preprocess(self, data):
# temporary workaround until preprocessors support dask
return data if isinstance(data, DaskTable) else super().preprocess(data)


if __name__ == "__main__":
d = Table("iris")
Expand Down
37 changes: 29 additions & 8 deletions Orange/widgets/unsupervised/owkmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from Orange.clustering import KMeans
from Orange.clustering.kmeans import KMeansModel
from Orange.data import Table, Domain, DiscreteVariable, ContinuousVariable
from Orange.data.dask import DaskTable
from Orange.data.util import get_unique_names, array_equal
from Orange.preprocess import Normalize
from Orange.preprocess.impute import ReplaceUnknowns
Expand Down Expand Up @@ -135,6 +136,7 @@ class Warning(widget.OWWidget.Warning):

INIT_METHODS = (("Initialize with KMeans++", "k-means++"),
("Random initialization", "random"))
DASK_METHODS = (("Initialize with KMeans||", "k-means||"),)

resizing_enabled = False

Expand Down Expand Up @@ -175,18 +177,18 @@ def __init__(self):
box="Number of Clusters", callback=self.update_method,
)

layout.addWidget(
gui.appendRadioButton(bg, "Fixed:", addToLayout=False), 1, 1)
self.fixed_radio_button = gui.appendRadioButton(bg, "Fixed:", addToLayout=False)
layout.addWidget(self.fixed_radio_button, 1, 1)
sb = gui.hBox(None, margin=0)
gui.spin(
sb, self, "k", minv=2, maxv=30,
controlWidth=60, alignment=Qt.AlignRight, callback=self.update_k)
gui.rubber(sb)
layout.addWidget(sb, 1, 2)

layout.addWidget(
gui.appendRadioButton(bg, "From", addToLayout=False), 2, 1)
ftobox = gui.hBox(None)
self.range_radio_button = gui.appendRadioButton(bg, "From", addToLayout=False)
layout.addWidget(self.range_radio_button, 2, 1)
self.ftobox = ftobox = gui.hBox(None)
ftobox.layout().setContentsMargins(0, 0, 0, 0)
layout.addWidget(ftobox, 2, 2)
gui.spin(
Expand Down Expand Up @@ -293,6 +295,12 @@ def _compute_clustering(data, k, init, n_init, max_iter, random_state):
random_state=random_state, preprocessors=[]
).get_model(data)

if isinstance(data, DaskTable):
# just skip silhouettes for now
model.silhouette_samples = None
model.silhouette = np.nan
return model

if data.X.shape[0] <= SILHOUETTE_MAX_SAMPLES:
model.silhouette_samples = silhouette_samples(data.X, model.labels)
model.silhouette = np.mean(model.silhouette_samples)
Expand Down Expand Up @@ -501,9 +509,7 @@ def preproces(self, data):
self.Warning.no_sparse_normalization()
else:
data = Normalize()(data)
for preprocessor in KMeans.preprocessors: # use same preprocessors than
data = preprocessor(data)
return data
return KMeans().preprocess(data) # why?

def send_data(self):
if self.optimize_k:
Expand Down Expand Up @@ -584,13 +590,28 @@ def set_data(self, data):
self.controls.normalize.setDisabled(
bool(self.data) and sp.issparse(self.data.X))

if type(data) is not type(old_data):
self.setup_controls(isinstance(self.data, DaskTable))

# Do not needlessly recluster the data if X hasn't changed
if old_data and self.data and array_equal(self.data.X, old_data.X):
if self.auto_commit:
self.send_data()
else:
self.invalidate(unconditional=True)

def setup_controls(self, is_dask):
self.ftobox.setDisabled(is_dask)
self.range_radio_button.setDisabled(is_dask)
if is_dask: self.fixed_radio_button.setChecked(True)
self.optimize_k = self.range_radio_button.isChecked()
self.INIT_METHODS = OWKMeans.DASK_METHODS \
if is_dask else OWKMeans.INIT_METHODS
self.controls.smart_init.clear()
self.controls.smart_init.addItems([t[0] for t in self.INIT_METHODS])
self.smart_init = 0
self.controls.n_init.setDisabled(is_dask)

def send_report(self):
# False positives (Setting is not recognized as int)
# pylint: disable=invalid-sequence-index
Expand Down

0 comments on commit 2733715

Please sign in to comment.