failing tests for bert_score
Expected properties that are currently violated.
Guilherme Paulino-Passos @ DoC-cluster committed Sep 9, 2024
1 parent dc431e2 commit 021f01a
Showing 1 changed file with 100 additions and 0 deletions.
100 changes: 100 additions & 0 deletions tests/unittests/text/test_bertscore.py
@@ -13,6 +13,7 @@
# limitations under the License.
import os
from functools import partial
from itertools import product
from typing import Sequence

import pytest
@@ -190,4 +191,103 @@ def test_bertscore_sorting(idf: bool):
    score = metric(preds, targets)

    # First index should be the self-comparison - sorting by length should not shuffle this

@skip_on_connection_issues()
@pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>=4.4")
@pytest.mark.parametrize(
    ["idf", "batch_size"],
    [(False, 1), (False, 9), (True, 1), (True, 9)],
)
def test_bertscore_most_similar(idf: bool, batch_size: int):
    """Test that BERTScore gives its highest score to the self-comparison of each sentence."""
    short = "hello there"
    long = "master kenobi"
    longer = "general kenobi"

    sentences = [short, long, longer]
    preds, targets = zip(*product(sentences, sentences))
    score = bert_score(
        preds, targets, idf=idf, lang="en", rescale_with_baseline=False, batch_size=batch_size
    )

    n = len(sentences)
    for i in range(len(preds)):
        # product() orders the pairs row-major: pair i is (sentences[i // n], sentences[i % n]),
        # so the self-comparison pairs sit on the diagonal at indices j * (n + 1).
        pred_self = (i // n) * (n + 1)  # index of (preds[i], preds[i])
        target_self = (i % n) * (n + 1)  # index of (targets[i], targets[i])
        assert score["f1"][i] <= score["f1"][pred_self], (
            f"pair {preds[i], targets[i]} does not score lower than {preds[pred_self], targets[pred_self]}\n{i=} {pred_self=}"
        )
        assert score["f1"][i] <= score["f1"][target_self], (
            f"pair {preds[i], targets[i]} does not score lower than {preds[target_self], targets[target_self]}\n{i=} {target_self=}"
        )



@skip_on_connection_issues()
@pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>=4.4")
@pytest.mark.parametrize(
    ["idf", "batch_size"],
    [(False, 1), (False, 9), (True, 1), (True, 9)],
)
def test_bertscore_symmetry(idf: bool, batch_size: int):
    """Test that the BERTScore F1 is symmetric in prediction and target.

    Precision and recall swap roles when a pair is reversed, so F1, their harmonic mean, should be unchanged.
    """
    short = "hello there"
    long = "master kenobi"
    longer = "general kenobi"

    sentences = [short, long, longer]
    preds, targets = zip(*product(sentences, sentences))
    score = bert_score(
        preds, targets, idf=idf, lang="en", rescale_with_baseline=False, batch_size=batch_size
    )

    for i in range(len(preds)):
        for j in range(len(targets)):
            if preds[i] == targets[j] and preds[j] == targets[i]:
                assert score["f1"][i] == pytest.approx(score["f1"][j]), (
                    f"f1 score for {(preds[i], targets[i])} is not the same as for {(preds[j], targets[j])}."
                )
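The mirrored pair could also be computed directly rather than searched for with the double loop; a quick sketch under the same row-major product layout as above (not part of the commit):

    from itertools import product

    sentences = ["hello there", "master kenobi", "general kenobi"]
    n = len(sentences)
    pairs = list(product(sentences, sentences))

    # Swapping pred and target transposes the index: pair i maps to (i % n) * n + i // n.
    for i, (pred, target) in enumerate(pairs):
        assert pairs[(i % n) * n + i // n] == (target, pred)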


@skip_on_connection_issues()
@pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>=4.4")
@pytest.mark.parametrize(
    ["idf", "batch_size"],
    [(False, 1), (False, 3)],
)
def test_bertscore_additional_sentence(idf: bool, batch_size: int):
    """Test that appending pairs to the input lists does not change the scores of the existing pairs.

    With idf=False each pair is scored independently, so this should hold.
    """
    short = "hello there"
    long = "master kenobi"
    longer = "general kenobi"

    preds = [long, long]
    targets = [long, short]

    score = bert_score(
        preds, targets, idf=idf, lang="en", rescale_with_baseline=False, batch_size=batch_size
    )

    longlong = score["f1"][0]
    longshort = score["f1"][1]
    # Index 0 is the self-comparison - sorting by length inside the metric should not shuffle this
    assert longlong > longshort

    preds = preds + [short, longer]
    targets = targets + [longer, long]

    score = bert_score(
        preds, targets, idf=idf, lang="en", rescale_with_baseline=False, batch_size=batch_size
    )

    # The first two scores should be exactly as in the previous call to the metric
    assert score["f1"][0] == pytest.approx(longlong)
    assert score["f1"][1] == pytest.approx(longshort)
    # Indices 1 and 2 should still score below the self-comparison at index 0.
    assert score["f1"][0] > score["f1"][1]
    assert score["f1"][0] > score["f1"][2]
