-
Notifications
You must be signed in to change notification settings - Fork 72
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
0be30fb
commit e9996b9
Showing
2 changed files
with
297 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,297 @@ | ||
from inspect import getmembers, isclass | ||
from itertools import combinations | ||
|
||
import desbordante as desb | ||
import unittest | ||
|
||
for algorithm_name, type_ in getmembers(desb, isclass): | ||
if not (issubclass(type_, desb.FdAlgorithm) | ||
and type_ is not desb.FdAlgorithm): | ||
continue | ||
algorithm = type_() | ||
for option_name in algorithm.get_possible_options(): | ||
print(option_name, algorithm.get_option_type(option_name)) | ||
algorithm.load_data('WDC_satellites.csv', ',', False) | ||
algorithm.execute() | ||
print(algorithm_name, algorithm.get_fds()) | ||
|
||
class ITestAlgorithm: | ||
def __init__(self, alg_type, testcase, **kwargs): | ||
self.alg_type = alg_type | ||
self.testcase = testcase | ||
self.options = kwargs | ||
|
||
def __get_options_combinations(self): | ||
dct = [] | ||
for i in range(1, len(self.options) + 1): | ||
a = list(combinations(self.options.items(), i)) | ||
dct += list((map(dict, a))) | ||
return dct | ||
|
||
def _init_alg(self): | ||
raise NotImplementedError() | ||
|
||
def __test_options(self, **kwargs): | ||
self._init_alg() | ||
for i in kwargs: | ||
self.alg.set_option(i, kwargs[i]) | ||
for i in kwargs: | ||
if isinstance(kwargs[i], float): | ||
self.testcase.assertEqual(float(self.alg.get_opts()[i]), kwargs[i]) | ||
else: | ||
self.testcase.assertEqual(self.alg.get_opts()[i], str(kwargs[i])) | ||
|
||
def _test_data_loading(self): | ||
raise NotImplementedError() | ||
|
||
def __test_setting_options(self): | ||
for i in self.__get_options_combinations(): | ||
with self.testcase.subTest(msg=f"testing options combination:{i}"): | ||
self.__test_options(**i) | ||
|
||
def execute(self): | ||
with self.testcase.subTest(msg="testing algorithm.load_data()"): | ||
self._test_data_loading() | ||
with self.testcase.subTest(msg="testing algorithm.set_option()"): | ||
self.__test_setting_options() | ||
|
||
|
||
class TestCommonAlgo(ITestAlgorithm): | ||
def _test_data_loading(self): | ||
alg = self.alg_type() | ||
alg.load_data("WDC_satellites.csv", ",", False, is_null_equal_null=False) | ||
|
||
self.testcase.assertEqual("False", alg.get_opts()["is_null_equal_null"]) | ||
|
||
alg = self.alg_type() | ||
alg.load_data("WDC_satellites.csv", ",", False, is_null_equal_null=True) | ||
|
||
self.testcase.assertEqual("True", alg.get_opts()["is_null_equal_null"]) | ||
|
||
def _init_alg(self): | ||
self.alg = self.alg_type() | ||
self.alg.load_data("WDC_satellites.csv", ",", False, is_null_equal_null=False) | ||
|
||
|
||
class TestApriori(ITestAlgorithm): | ||
def _test_concrete_load(self, input_format, **kwargs): | ||
alg = self.alg_type() | ||
|
||
alg.load_data("TestWide.csv", ",", True, input_format=input_format, **kwargs) | ||
for i in kwargs: | ||
self.testcase.assertEqual(str(kwargs[i]), alg.get_opts()[i]) | ||
|
||
def _test_data_loading(self): | ||
self._test_concrete_load(input_format="tabular", has_tid=True) | ||
self._test_concrete_load(input_format="tabular", has_tid=False) | ||
|
||
self._test_concrete_load( | ||
input_format="singular", | ||
tid_column_index=1, | ||
item_column_index=2 | ||
) | ||
|
||
self._test_concrete_load( | ||
input_format="singular", | ||
tid_column_index=0, | ||
item_column_index=2 | ||
) | ||
|
||
self._test_concrete_load( | ||
input_format="singular", | ||
tid_column_index=1, | ||
item_column_index=1 | ||
) | ||
|
||
def _init_alg(self): | ||
self.alg = self.alg_type() | ||
self.alg.load_data( | ||
"rules-kaggle-rows.csv", | ||
",", | ||
True, | ||
is_null_equal_null=False, | ||
input_format="tabular", | ||
has_tid=True, | ||
tid_column_index=0, | ||
item_column_index=1, | ||
) | ||
|
||
|
||
def run_test(alg_type, testcase, **kwargs): | ||
tester = TestCommonAlgo(alg_type, testcase, **kwargs) | ||
|
||
if alg_type == desb.Apriori: | ||
tester = TestApriori(alg_type, testcase, **kwargs) | ||
|
||
tester.execute() | ||
|
||
|
||
def test_metricverifier(dataset, **kwargs) -> bool: | ||
alg = desb.MetricVerifier() | ||
alg.load_data(dataset, ",", True) | ||
worked = True | ||
try: | ||
for i in kwargs: | ||
alg.set_option(i, kwargs[i]) | ||
except Exception as _: | ||
print(_) | ||
worked = False | ||
return worked | ||
|
||
|
||
class TestBindings(unittest.TestCase): | ||
only_null_eq_null = [ | ||
desb.Aid, | ||
desb.Depminer, | ||
desb.FDep, | ||
desb.FUN, | ||
desb.FdMine, | ||
desb.HyFD, | ||
] | ||
|
||
def test_algos_with_only_null_eq_null(self): | ||
for i in self.only_null_eq_null: | ||
with self.subTest(i=i): | ||
run_test(i, self) | ||
|
||
def test_pyro(self): | ||
run_test(desb.Pyro, self, seed=1, max_lhs=12, threads=15, error=0.015) | ||
|
||
def test_dfd(self): | ||
run_test(desb.DFD, self, threads=15) | ||
|
||
def test_fastfds(self): | ||
run_test(desb.FastFDs, self, max_lhs=12, threads=15) | ||
|
||
def test_tane(self): | ||
run_test(desb.Tane, self, max_lhs=12, error=0.015) | ||
|
||
def test_datastats(self): | ||
run_test(desb.DataStats, self, threads=15) | ||
|
||
def test_hyucc(self): | ||
run_test(desb.HyUCC, self, threads=15) | ||
|
||
def test_fdverifier(self): | ||
run_test(desb.FDVerifier, self, lhs_indices=[1, 2, 3], rhs_indices=[1, 2, 3]) | ||
|
||
def test_apriori(self): | ||
run_test(desb.Apriori, self, minconf=0.00312, minsup=0.2321) | ||
|
||
def test_metricverifier_euclidean(self): | ||
with self.subTest(msg="metric = euclidean, non numeric values"): | ||
self.assertFalse( | ||
test_metricverifier( | ||
"WDC_satellites.csv", | ||
metric="euclidean", | ||
rhs_indices=[1, 2], | ||
) | ||
) | ||
|
||
with self.subTest(msg="metric = euclidean, q should not be set"): | ||
self.assertFalse( | ||
test_metricverifier("TestLong.csv", metric="euclidean", q=123) | ||
) | ||
|
||
with self.subTest( | ||
msg="metric = euclidean, metric_algorithm set before rhs_indices" | ||
): | ||
self.assertFalse( | ||
test_metricverifier( | ||
"TestLong.csv", | ||
metric="euclidean", | ||
metric_algorithm="brute", | ||
) | ||
) | ||
|
||
with self.subTest(msg="should work normal"): | ||
alg = desb.MetricVerifier() | ||
alg.load_data("TestLong.csv", ",", True) | ||
opts = { | ||
"metric": "euclidean", | ||
"rhs_indices": [1, 2], | ||
"parameter": 213.213111, | ||
"dist_from_null_is_infinity": False, | ||
"metric_algorithm": "approx", | ||
"lhs_indices": [0, 1, 2], | ||
} | ||
for i in opts: | ||
alg.set_option(i, opts[i]) | ||
for i in opts: | ||
self.assertEqual(alg.get_opts()[i], str(opts[i])) | ||
|
||
def test_metricverifier_levenshtein(self): | ||
with self.subTest(msg="metric = levenshtein, multidimentional rhs_indices"): | ||
self.assertFalse( | ||
test_metricverifier( | ||
"WDC_satellites.csv", | ||
metric="levenshtein", | ||
rhs_indices=[1, 2], | ||
) | ||
) | ||
|
||
with self.subTest(msg="metric = levenshtein, q should not be set"): | ||
self.assertFalse( | ||
test_metricverifier("WDC_satellites.csv", metric="levenshtein", q=123) | ||
) | ||
|
||
with self.subTest( | ||
msg="metric = levenshtein, metric_algorithm set before rhs_indices" | ||
): | ||
self.assertFalse( | ||
test_metricverifier( | ||
"WDC_satellites.csv", | ||
metric="levenshtein", | ||
metric_algorithm="brute", | ||
) | ||
) | ||
|
||
with self.subTest(msg="metric = levenshtein, non string columns"): | ||
self.assertFalse( | ||
test_metricverifier( | ||
"TestLong.csv", metric="levenshtein", rhs_indices=[1] | ||
) | ||
) | ||
|
||
with self.subTest(msg="should work normal"): | ||
alg = desb.MetricVerifier() | ||
alg.load_data("WDC_satellites.csv", ",", True) | ||
opts = { | ||
"metric": "levenshtein", | ||
"rhs_indices": [1], | ||
"parameter": 213.213111, | ||
"dist_from_null_is_infinity": False, | ||
"metric_algorithm": "approx", | ||
"lhs_indices": [0, 1, 2], | ||
} | ||
for i in opts: | ||
alg.set_option(i, opts[i]) | ||
for i in opts: | ||
self.assertEqual(alg.get_opts()[i], str(opts[i])) | ||
|
||
def test_metricverifier_cosine(self): | ||
with self.subTest(msg="cosine metric unavliable for non string columns"): | ||
self.assertFalse(test_metricverifier("TestLong.csv", rhs_indices=[1])) | ||
|
||
with self.subTest(msg="multidimentional rhs is not allowed"): | ||
self.assertFalse( | ||
test_metricverifier("WDC_satellites.csv", rhs_indices=[1, 2]) | ||
) | ||
|
||
with self.subTest(msg="set q before rhs"): | ||
self.assertFalse(test_metricverifier("WDC_satellites.csv", q=123)) | ||
|
||
with self.subTest(msg="set metric_algorithm before rhs"): | ||
self.assertFalse( | ||
test_metricverifier("WDC_satellites.csv", metric_algorithm="approx") | ||
) | ||
|
||
with self.subTest(msg="normal options"): | ||
alg = desb.MetricVerifier() | ||
alg.load_data("WDC_satellites.csv", ",", True) | ||
opts = { | ||
"metric": "cosine", | ||
"rhs_indices": [1], | ||
"parameter": 213.213111, | ||
"dist_from_null_is_infinity": False, | ||
"q": 123, | ||
"metric_algorithm": "approx", | ||
"lhs_indices": [0, 1, 2], | ||
} | ||
for i in opts: | ||
alg.set_option(i, opts[i]) | ||
for i in opts: | ||
self.assertEqual(alg.get_opts()[i], str(opts[i])) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |