-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
API Refactor - MatcherResults and metrics (#70)
- Loading branch information
1 parent
96430e7
commit dd15f95
Showing
13 changed files
with
659 additions
and
341 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import unittest | ||
import math | ||
|
||
from tests import df1, df2 | ||
from valentine.algorithms.matcher_results import MatcherResults | ||
from valentine.algorithms import JaccardDistanceMatcher | ||
from valentine.metrics import Precision | ||
from valentine import valentine_match | ||
|
||
|
||
class TestMatcherResults(unittest.TestCase): | ||
def setUp(self): | ||
self.matches = valentine_match(df1, df2, JaccardDistanceMatcher()) | ||
self.ground_truth = [ | ||
('Cited by', 'Cited by'), | ||
('Authors', 'Authors'), | ||
('EID', 'EID') | ||
] | ||
|
||
def test_dict(self): | ||
assert isinstance(self.matches, dict) | ||
|
||
def test_get_metrics(self): | ||
metrics = self.matches.get_metrics(self.ground_truth) | ||
assert all([x in metrics for x in {"Precision", "Recall", "F1Score"}]) | ||
|
||
metrics_specific = self.matches.get_metrics(self.ground_truth, metrics={Precision()}) | ||
assert "Precision" in metrics_specific | ||
|
||
def test_one_to_one(self): | ||
m = self.matches | ||
|
||
# Add multiple matches per column | ||
pairs = list(m.keys()) | ||
for (ta, ca), (tb, cb) in pairs: | ||
m[((ta, ca), (tb, cb + 'foo'))] = m[((ta, ca), (tb, cb))] / 2 | ||
|
||
# Verify that len gets corrected from 6 to 3 | ||
m_one_to_one = m.one_to_one() | ||
assert len(m_one_to_one) == 3 and len(m) == 6 | ||
|
||
# Verify that none of the lower similarity "foo" entries made it | ||
for (ta, ca), (tb, cb) in pairs: | ||
assert ((ta, ca), (tb, cb + 'foo')) not in m_one_to_one | ||
|
||
# Verify that the cache resets on a new MatcherResults instance | ||
m_entry = MatcherResults(m) | ||
assert m_entry._cached_one_to_one is None | ||
|
||
# Add one new entry with lower similarity | ||
m_entry[(('table_1', 'BLA'), ('table_2', 'BLA'))] = 0.7214057 | ||
|
||
# Verify that the new one_to_one is different from the old one | ||
m_entry_one_to_one = m_entry.one_to_one() | ||
assert m_one_to_one != m_entry_one_to_one | ||
|
||
# Verify that all remaining values are above the median | ||
median = sorted(list(m_entry.values()), reverse=True)[math.ceil(len(m_entry)/2)] | ||
for k in m_entry_one_to_one: | ||
assert m_entry_one_to_one[k] >= median | ||
|
||
def test_take_top_percent(self): | ||
take_0_percent = self.matches.take_top_percent(0) | ||
assert len(take_0_percent) == 0 | ||
|
||
take_40_percent = self.matches.take_top_percent(40) | ||
assert len(take_40_percent) == 2 | ||
|
||
take_100_percent = self.matches.take_top_percent(100) | ||
assert len(take_100_percent) == len(self.matches) | ||
|
||
def test_take_top_n(self): | ||
take_none = self.matches.take_top_n(0) | ||
assert len(take_none) == 0 | ||
|
||
take_some = self.matches.take_top_n(2) | ||
assert len(take_some) == 2 | ||
|
||
take_all = self.matches.take_top_n(len(self.matches)) | ||
assert len(take_all) == len(self.matches) | ||
|
||
take_more_than_all = self.matches.take_top_n(len(self.matches)+1) | ||
assert len(take_more_than_all) == len(self.matches) | ||
|
||
def test_copy(self): | ||
assert self.matches.get_copy() is not self.matches |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,47 +1,76 @@ | ||
import unittest | ||
from valentine.metrics import * | ||
from valentine.algorithms.matcher_results import MatcherResults | ||
from valentine.metrics.metric_helpers import get_fp, get_tp_fn | ||
|
||
import math | ||
from valentine.metrics.metrics import one_to_one_matches | ||
from copy import deepcopy | ||
class TestMetrics(unittest.TestCase): | ||
def setUp(self): | ||
self.matches = MatcherResults({ | ||
(('table_1', 'Cited by'), ('table_2', 'Cited by')): 0.8374313, | ||
(('table_1', 'Authors'), ('table_2', 'Authors')): 0.83498037, | ||
(('table_1', 'EID'), ('table_2', 'EID')): 0.8214057, | ||
(('table_1', 'Title'), ('table_2', 'DUMMY1')): 0.8214057, | ||
(('table_1', 'Title'), ('table_2', 'DUMMY2')): 0.8114057, | ||
}) | ||
self.ground_truth = [ | ||
('Cited by', 'Cited by'), | ||
('Authors', 'Authors'), | ||
('EID', 'EID'), | ||
('Title', 'Title'), | ||
('DUMMY3', 'DUMMY3') | ||
|
||
matches = { | ||
(('table_1', 'Cited by'), ('table_2', 'Cited by')): 0.8374313, | ||
(('table_1', 'Authors'), ('table_2', 'Authors')): 0.83498037, | ||
(('table_1', 'EID'), ('table_2', 'EID')): 0.8214057, | ||
} | ||
] | ||
|
||
ground_truth = [ | ||
('Cited by', 'Cited by'), | ||
('Authors', 'Authors'), | ||
('EID', 'EID') | ||
] | ||
def test_precision(self): | ||
precision = self.matches.get_metrics(self.ground_truth, metrics={Precision()}) | ||
assert 'Precision' in precision and precision['Precision'] == 0.75 | ||
|
||
precision_not_one_to_one = self.matches.get_metrics(self.ground_truth, metrics={Precision(one_to_one=False)}) | ||
assert 'Precision' in precision_not_one_to_one and precision_not_one_to_one['Precision'] == 0.6 | ||
|
||
class TestMetrics(unittest.TestCase): | ||
def test_recall(self): | ||
recall = self.matches.get_metrics(self.ground_truth, metrics={Recall()}) | ||
assert 'Recall' in recall and recall['Recall'] == 0.6 | ||
|
||
recall_not_one_to_one = self.matches.get_metrics(self.ground_truth, metrics={Recall(one_to_one=False)}) | ||
assert 'Recall' in recall_not_one_to_one and recall_not_one_to_one['Recall'] == 0.6 | ||
|
||
def test_f1(self): | ||
f1 = self.matches.get_metrics(self.ground_truth, metrics={F1Score()}) | ||
assert 'F1Score' in f1 and round(100*f1['F1Score']) == 67 | ||
|
||
f1_not_one_to_one = self.matches.get_metrics(self.ground_truth, metrics={F1Score(one_to_one=False)}) | ||
assert 'F1Score' in f1_not_one_to_one and f1_not_one_to_one['F1Score'] == 0.6 | ||
|
||
def test_precision_top_n_percent(self): | ||
precision_0 = self.matches.get_metrics(self.ground_truth, metrics={PrecisionTopNPercent(n=0)}) | ||
assert 'PrecisionTop0Percent' in precision_0 and precision_0['PrecisionTop0Percent'] == 0 | ||
|
||
def test_one_to_one(self): | ||
m = deepcopy(matches) | ||
precision_50 = self.matches.get_metrics(self.ground_truth, metrics={PrecisionTopNPercent(n=50)}) | ||
assert 'PrecisionTop50Percent' in precision_50 and precision_50['PrecisionTop50Percent'] == 1.0 | ||
|
||
# Add multiple matches per column | ||
pairs = list(m.keys()) | ||
for (ta, ca), (tb, cb) in pairs: | ||
m[((ta, ca), (tb, cb + 'foo'))] = m[((ta, ca), (tb, cb))] / 2 | ||
precision = self.matches.get_metrics(self.ground_truth, metrics={Precision()}) | ||
precision_100 = self.matches.get_metrics(self.ground_truth, metrics={PrecisionTopNPercent(n=100)}) | ||
assert 'PrecisionTop100Percent' in precision_100 and precision_100['PrecisionTop100Percent'] == precision['Precision'] | ||
|
||
# Verify that len gets corrected to 3 | ||
m_one_to_one = one_to_one_matches(m) | ||
assert len(m_one_to_one) == 3 and len(m) == 6 | ||
precision_70_not_one_to_one = self.matches.get_metrics(self.ground_truth, metrics={PrecisionTopNPercent(n=70, one_to_one=False)}) | ||
assert 'PrecisionTop70Percent' in precision_70_not_one_to_one and precision_70_not_one_to_one['PrecisionTop70Percent'] == 0.75 | ||
|
||
# Verify that none of the lower similarity "foo" entries made it | ||
for (ta, ca), (tb, cb) in pairs: | ||
assert ((ta, ca), (tb, cb + 'foo')) not in m_one_to_one | ||
def test_recall_at_size_of_ground_truth(self): | ||
recall = self.matches.get_metrics(self.ground_truth, metrics={RecallAtSizeofGroundTruth()}) | ||
assert 'RecallAtSizeofGroundTruth' in recall and recall['RecallAtSizeofGroundTruth'] == 0.6 | ||
|
||
# Add one new entry with lower similarity | ||
m_entry = deepcopy(matches) | ||
m_entry[(('table_1', 'BLA'), ('table_2', 'BLA'))] = 0.7214057 | ||
def test_metric_helpers(self): | ||
limit = 2 | ||
tp, fn = get_tp_fn(self.matches, self.ground_truth, n=limit) | ||
assert tp <= len(self.ground_truth) and fn <= len(self.ground_truth) | ||
|
||
m_entry_one_to_one = one_to_one_matches(m_entry) | ||
fp = get_fp(self.matches, self.ground_truth, n=limit) | ||
assert fp <= limit | ||
assert tp == 2 and fn == 3 # Since we limit to 2 of the matches | ||
assert fp == 0 | ||
|
||
# Verify that all remaining values are above the median | ||
median = sorted(set(m_entry.values()), reverse=True)[math.ceil(len(m_entry)/2)] | ||
for k in m_entry_one_to_one: | ||
assert m_entry_one_to_one[k] >= median | ||
def test_metric_equals(self): | ||
assert PrecisionTopNPercent(n=10, one_to_one=False) == PrecisionTopNPercent(n=10, one_to_one=False) | ||
assert PrecisionTopNPercent(n=10, one_to_one=False) != PrecisionTopNPercent(n=10, one_to_one=True) | ||
assert PrecisionTopNPercent(n=10, one_to_one=False) != Precision() |
Oops, something went wrong.