Skip to content

Commit

Permalink
Merge pull request #6487 from markotoplak/dask-preprocessors
Browse files Browse the repository at this point in the history
Dask preprocessors
  • Loading branch information
markotoplak committed Oct 10, 2023
2 parents b09c2d6 + 22283ec commit d7070ae
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 27 deletions.
4 changes: 4 additions & 0 deletions Orange/data/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ def join_columns(self, data):

class _FromTableConversionDask(_FromTableConversion):

# set very high to make the compute graph smaller, because
# for dask operations it does not matter how high it is
max_rows_at_once = 5000*1000

_array_conversion_class = _ArrayConversionDask

def __init__(self, source, destination):
Expand Down
11 changes: 9 additions & 2 deletions Orange/preprocess/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from sklearn.impute import SimpleImputer

import Orange.data
from Orange.data.dask import DaskTable
from Orange.data.filter import HasClass
from Orange.statistics import distribution
from Orange.util import Reprable, Enum, deprecated
Expand Down Expand Up @@ -157,8 +158,14 @@ def __call__(self, data):
from Orange.data.sql.table import SqlTable
if isinstance(data, SqlTable):
return Impute()(data)
imputer = SimpleImputer(strategy=self.strategy)
imputer.fit(data.X)
wraps = self.__wraps__
X = data.X
if isinstance(data, DaskTable):
import dask_ml.impute
wraps = dask_ml.impute.SimpleImputer
X = X.rechunk({0: "auto", 1: -1})
imputer = wraps(strategy=self.strategy)
imputer.fit(X)
# Create new variables with appropriate `compute_value`, but
# drop the ones which do not have valid `imputer.statistics_`
# (i.e. all NaN columns). `sklearn.preprocessing.Imputer` already
Expand Down
77 changes: 54 additions & 23 deletions Orange/tests/test_continuize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,27 @@
import unittest

from Orange.data import Table, Variable
from Orange.data.dask import DaskTable
from Orange.preprocess.continuize import DomainContinuizer
from Orange.preprocess import Continuize
from Orange.preprocess import transformation
from Orange.tests import test_filename
from Orange.tests.test_dasktable import temp_dasktable


class TestDomainContinuizer(unittest.TestCase):
def setUp(self):
self.data = Table(test_filename("datasets/test4"))

@classmethod
def setUpClass(cls):
cls.data = Table(test_filename("datasets/test4"))

def compare_tables(self, data, solution):
for i in range(len(data)):
for j in range(len(data[i])):
if type(solution[i][j]) == float:
self.assertAlmostEqual(data[i, j], solution[i][j], places=20)
else:
self.assertEqual(data[i, j], solution[i][j])

def test_default(self):
for inp in (self.data, self.data.domain):
Expand All @@ -29,9 +41,10 @@ def test_default(self):

dat2 = self.data.transform(dom)
# c1 c2 d2 d3 cl1
self.assertEqual(dat2[0], [1, -2, 1, 0, 1, 0, 0, "a"])
self.assertEqual(dat2[1], [0, 0, 0, 1, 0, 1, 0, "b"])
self.assertEqual(dat2[2], [2, 2, 0, 1, 0, 0, 1, "c"])
solution = [[1, -2, 1, 0, 1, 0, 0, "a"],
[0, 0, 0, 1, 0, 1, 0, "b"],
[2, 2, 0, 1, 0, 0, 1, "c"]]
self.compare_tables(dat2, solution)

def test_continuous_transform_class(self):
for inp in (self.data, self.data.domain):
Expand All @@ -48,9 +61,10 @@ def test_continuous_transform_class(self):

dat2 = self.data.transform(dom)
# c1 c2 d2 d3 cl1
self.assertEqual(dat2[0], [1, -2, 1, 0, 1, 0, 0, 1, 0, 0])
self.assertEqual(dat2[1], [0, 0, 0, 1, 0, 1, 0, 0, 1, 0])
self.assertEqual(dat2[2], [2, 2, 0, 1, 0, 0, 1, 0, 0, 1])
solution = [[1, -2, 1, 0, 1, 0, 0, 1, 0, 0],
[0, 0, 0, 1, 0, 1, 0, 0, 1, 0],
[2, 2, 0, 1, 0, 0, 1, 0, 0, 1]]
self.compare_tables(dat2, solution)

def test_multi_indicators(self):
for inp in (self.data, self.data.domain):
Expand All @@ -69,9 +83,10 @@ def test_multi_indicators(self):

dat2 = self.data.transform(dom)
# c1 c2 d2 d3 cl1
self.assertEqual(dat2[0], [1, -2, 1, 0, 1, 0, 0, "a"])
self.assertEqual(dat2[1], [0, 0, 0, 1, 0, 1, 0, "b"])
self.assertEqual(dat2[2], [2, 2, 0, 1, 0, 0, 1, "c"])
solution = [[1, -2, 1, 0, 1, 0, 0, "a"],
[0, 0, 0, 1, 0, 1, 0, "b"],
[2, 2, 0, 1, 0, 0, 1, "c"]]
self.compare_tables(dat2, solution)

def test_multi_lowest_base(self):
for inp in (self.data, self.data.domain):
Expand All @@ -89,9 +104,10 @@ def test_multi_lowest_base(self):

dat2 = self.data.transform(dom)
# c1 c2 d2 d3 cl1
self.assertEqual(dat2[0], [1, -2, 0, 0, 0, "a"])
self.assertEqual(dat2[1], [0, 0, 1, 1, 0, "b"])
self.assertEqual(dat2[2], [2, 2, 1, 0, 1, "c"])
solution = [[1, -2, 0, 0, 0, "a"],
[0, 0, 1, 1, 0, "b"],
[2, 2, 1, 0, 1, "c"]]
self.compare_tables(dat2, solution)

def test_multi_ignore(self):
dom = DomainContinuizer(multinomial_treatment=Continuize.Remove)
Expand Down Expand Up @@ -153,9 +169,10 @@ def test_as_ordinal(self):

dat2 = self.data.transform(dom)
# c1 c2 d2 d3 cl1
self.assertEqual(dat2[0], [1, -2, 0, 0, "a"])
self.assertEqual(dat2[1], [0, 0, 1, 1, "b"])
self.assertEqual(dat2[2], [2, 2, 1, 2, "c"])
solution = [[1, -2, 0, 0, "a"],
[0, 0, 1, 1, "b"],
[2, 2, 1, 2, "c"]]
self.compare_tables(dat2, solution)

def test_as_ordinal_class(self):
for inp in (self.data, self.data.domain):
Expand All @@ -172,9 +189,10 @@ def test_as_ordinal_class(self):

dat2 = self.data.transform(dom)
# c1 c2 d2 d3 cl1
self.assertEqual(dat2[0], [1, -2, 0, 0, 0])
self.assertEqual(dat2[1], [0, 0, 1, 1, 1])
self.assertEqual(dat2[2], [2, 2, 1, 2, 2])
solution = [[1, -2, 0, 0, 0],
[0, 0, 1, 1, 1],
[2, 2, 1, 2, 2]]
self.compare_tables(dat2, solution)

def test_as_normalized_ordinal(self):
for inp in (self.data, self.data.domain):
Expand All @@ -190,6 +208,19 @@ def test_as_normalized_ordinal(self):

dat2 = self.data.transform(dom)
# c1 c2 d2 d3 cl1
self.assertEqual(dat2[0], [1, -2, 0, 0, "a"])
self.assertEqual(dat2[1], [0, 0, 1, 0.5, "b"])
self.assertEqual(dat2[2], [2, 2, 1, 1, "c"])
solution = [[1, -2, 0, 0, "a"],
[0, 0, 1, 0.5, "b"],
[2, 2, 1, 1, "c"]]
self.compare_tables(dat2, solution)


class TestDomainContinuizerDask(TestDomainContinuizer):

@classmethod
def setUpClass(cls):
super().setUpClass()
cls.data = temp_dasktable(cls.data)

def compare_tables(self, data, solution):
self.assertIsInstance(data, DaskTable)
super().compare_tables(data.compute(), solution) # .compute avoids warning
22 changes: 20 additions & 2 deletions Orange/tests/test_normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
import scipy.sparse as sp

from Orange.data import Table, Domain, ContinuousVariable
from Orange.data.dask import DaskTable
from Orange.preprocess import Normalize
from Orange.tests import test_filename
from Orange.tests.test_dasktable import temp_dasktable


class TestNormalizer(unittest.TestCase):
Expand All @@ -26,6 +28,8 @@ def compare_tables(self, dataNorm, solution):
@classmethod
def setUpClass(cls):
cls.data = Table(test_filename("datasets/test5.tab"))
cls.iris = Table("iris")
cls.test10 = Table(test_filename("datasets/test10.tab"))

def test_normalize_default(self):
normalizer = Normalize()
Expand Down Expand Up @@ -134,7 +138,7 @@ def test_skip_normalization(self):
np.testing.assert_array_equal(data.X, normalized.X)

def test_datetime_normalization(self):
data = Table(test_filename("datasets/test10.tab"))
data = self.test10
normalizer = Normalize(zero_based=False,
norm_type=Normalize.NormalizeBySD,
transform_class=False)
Expand All @@ -145,7 +149,7 @@ def test_datetime_normalization(self):
self.compare_tables(data_norm, solution)

def test_retain_vars_attributes(self):
data = Table("iris")
data = self.iris
attributes = {"foo": "foo", "baz": 1}
data.domain.attributes[0].attributes = attributes
self.assertDictEqual(
Expand All @@ -169,5 +173,19 @@ def test_number_of_decimals(self):
self.assertEqual(str(val1[0]), val2)


class TestNormalizerDask(TestNormalizer):

@classmethod
def setUpClass(cls):
super().setUpClass()
cls.data = temp_dasktable(cls.data)
cls.iris = temp_dasktable(cls.iris)
cls.test10 = temp_dasktable(cls.test10)

def compare_tables(self, dataNorm, solution):
self.assertIsInstance(dataNorm, DaskTable)
super().compare_tables(dataNorm.compute(), solution) # .compute avoids warnings


if __name__ == "__main__":
unittest.main()

0 comments on commit d7070ae

Please sign in to comment.