Skip to content

Commit

Permalink
feat: DIA-1410: Multi label text classification (#155)
Browse files Browse the repository at this point in the history
  • Loading branch information
hakan458 authored Sep 24, 2024
1 parent 1306798 commit 7bd4384
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
11 changes: 8 additions & 3 deletions evalme/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class ClassificationEvalItem(EvalItem):

SHAPE_KEY = 'undefined'

def exact_match(self, item, label_weights=None, per_label=False):
def exact_match(self, item, label_weights=None, per_label=False, label_order_matters=False):
label_weights = label_weights or {}
if self.empty and item.empty:
return {} if per_label else 1
Expand Down Expand Up @@ -42,6 +42,10 @@ def exact_match(self, item, label_weights=None, per_label=False):
region = EvalItem.has_regions([x, y])
if region:
mismatched_spans = not bool(EvalItem.general_iou_by_type(region, x, y))
# If order does not matter, sort labels
if not label_order_matters:
labels = sorted(labels)
y_labels = sorted(y_labels)
# choices are mismatched
if labels != y_labels or mismatched_spans:
if per_label:
Expand Down Expand Up @@ -110,7 +114,8 @@ def _as_pairwise(item, shape_key, **kwargs):
def exact_matching_choices(item_gt, item_pred, label_weights=None, per_label=False, shape_key=None, **kwargs):
return _as_choices(item_gt, shape_key, **kwargs).exact_match(_as_choices(item_pred, shape_key, **kwargs),
label_weights,
per_label=per_label)
per_label=per_label,
label_order_matters=False)


def exact_matching_pairwise(item_gt, item_pred, label_weights=None, per_label=False, shape_key=None, **kwargs):
Expand All @@ -119,7 +124,7 @@ def exact_matching_pairwise(item_gt, item_pred, label_weights=None, per_label=Fa
per_label=per_label)


def naive(x, y, per_label=False, label_order_matters=True, **kwargs):
def naive(x, y, per_label=False, label_order_matters=False, **kwargs):
"""
Naive comparison of annotations
Expand Down
8 changes: 4 additions & 4 deletions evalme/tests/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,8 +299,8 @@ def test_naive_order_doesnt_matter():
test_x = [first, second]
test_y = [second, first]

assert naive(test_x, test_y, label_order_matters=False) == 1.0
assert naive(test_x, test_y) == 0.0
assert naive(test_x, test_y) == 1.0
assert naive(test_x, test_y, label_order_matters=True) == 0.0


def test_naive_order_doesnt_matter_partial_agreement():
Expand Down Expand Up @@ -343,8 +343,8 @@ def test_naive_order_doesnt_matter_partial_agreement():
test_x = [first, second]
test_y = [second, first_2]

assert naive(test_x, test_y, label_order_matters=False) == 0.5
assert naive(test_x, test_y) == 0.0
assert naive(test_x, test_y) == 0.5
assert naive(test_x, test_y, label_order_matters=True) == 0.0


def test_naive_not_matching_per_label():
Expand Down

0 comments on commit 7bd4384

Please sign in to comment.