Skip to content

Commit

Permalink
When calling bert_score separately, it works
Browse files Browse the repository at this point in the history
  • Loading branch information
Guilherme Paulino-Passos @ DoC-cluster committed Sep 13, 2024
1 parent 021f01a commit 9c1f295
Showing 1 changed file with 58 additions and 2 deletions.
60 changes: 58 additions & 2 deletions tests/unittests/text/test_bertscore.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,36 @@ def test_bertscore_most_similar(idf: bool, batch_size: int):
assert score["f1"][i] <= score["f1"][max_target], \
f"pair: {preds[i], targets[i]} does not have a lower score than {preds[max_target], targets[max_target]}\n{i=}{max_target=}"

@skip_on_connection_issues()
@pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>4.4")
@pytest.mark.parametrize(
    ["idf"],
    [(False,), (True,)],
)
def test_bertscore_most_similar_separate_calls(idf: bool):
    """Check that self-pairs get the highest F1 when every pair is scored in its own call.

    Each (prediction, target) pair drawn from a small sentence set is passed to
    ``bert_score`` individually rather than as one batch. For every pair, the F1
    must not exceed the F1 of the two self-pairs built from its own target and
    from its own prediction.

    Args:
        idf: Forwarded to ``bert_score`` as its ``idf`` argument.
    """
    sentences = ["hello there", "master kenobi", "general kenobi"]
    n_sentences = len(sentences)

    # Materialise the cartesian product once so the scores and the
    # preds/targets used in the messages are guaranteed to share one ordering.
    pairs = list(product(sentences, sentences))
    preds, targets = zip(*pairs)

    score = {
        "f1": [
            bert_score([pred], [target], idf=idf, lang="en", rescale_with_baseline=False)["f1"].item()
            for pred, target in pairs
        ]
    }

    for i in range(len(pairs)):
        # product() advances the rightmost element fastest, so
        # i == pred_idx * n_sentences + target_idx, and the self-pair of
        # sentence k sits at index k * (n_sentences + 1).
        pred_idx, target_idx = divmod(i, n_sentences)
        max_pred = target_idx * (1 + n_sentences)  # self-pair of this pair's target
        max_target = pred_idx * (1 + n_sentences)  # self-pair of this pair's prediction
        assert score["f1"][i] <= score["f1"][max_pred], (
            f"pair: {preds[i], targets[i]} does not have a lower score than {preds[max_pred], targets[max_pred]}\n{i=}{max_pred=}"
        )
        assert score["f1"][i] <= score["f1"][max_target], (
            f"pair: {preds[i], targets[i]} does not have a lower score than {preds[max_target], targets[max_target]}\n{i=}{max_target=}"
        )


@skip_on_connection_issues()
@pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>4.4")
@pytest.mark.parametrize(
Expand Down Expand Up @@ -251,7 +279,35 @@ def test_bertscore_symmetry(idf: bool, batch_size: int):
f"f1 score for {(preds[i], targets[i])} is not the same as {(preds[j], targets[j])}."
pass


@skip_on_connection_issues()
@pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>4.4")
@pytest.mark.parametrize(
    ["idf"],
    [(False,), (True,)],
)
def test_bertscore_symmetry_separate_calls(idf: bool):
    """Check that BERTScore F1 is unchanged when prediction and target are swapped.

    Each (prediction, target) pair drawn from a small sentence set is scored in
    its own ``bert_score`` call. The F1 of a pair must then match, up to
    ``pytest.approx`` tolerance, the F1 of the mirrored pair.

    Args:
        idf: Forwarded to ``bert_score`` as its ``idf`` argument.
    """
    sentences = ["hello there", "master kenobi", "general kenobi"]

    # Materialise the cartesian product once so the scores and the
    # preds/targets used in the message are guaranteed to share one ordering.
    pairs = list(product(sentences, sentences))
    preds, targets = zip(*pairs)

    score = {
        "f1": [
            bert_score([pred], [target], idf=idf, lang="en", rescale_with_baseline=False)["f1"].item()
            for pred, target in pairs
        ]
    }

    for i, (pred_i, target_i) in enumerate(pairs):
        for j, (pred_j, target_j) in enumerate(pairs):
            # Pair j mirrors pair i when its prediction/target are i's target/prediction.
            if pred_i == target_j and pred_j == target_i:
                assert score["f1"][i] == pytest.approx(score["f1"][j]), (
                    f"f1 score for {(preds[i], targets[i])} is not the same as {(preds[j], targets[j])}."
                )

@skip_on_connection_issues()
@pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>4.4")
@pytest.mark.parametrize(
Expand Down

0 comments on commit 9c1f295

Please sign in to comment.