Skip to content

Commit

Permalink
update code
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki authored Oct 10, 2024
1 parent 36de1a0 commit 51062d5
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
7 changes: 5 additions & 2 deletions src/torchmetrics/functional/text/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,13 +349,16 @@ def bert_score(
"""
if len(preds) != len(target):
raise ValueError("Number of predicted and reference sententes must be the same!")
raise ValueError(
"Expected number of predicted and reference sententes to be the same, but got"
f"{len(preds)} and {len(target)}"
)
if not isinstance(preds, (str, list, dict)): # dict for BERTScore class compute call
preds = list(preds)
if not isinstance(target, (str, list, dict)): # dict for BERTScore class compute call
target = list(target)
if not isinstance(idf, bool):
raise ValueError(f"The value of idf must be a boolean. Value passed:{idf=}")
raise ValueError(f"Expected argument `idf` to be a boolean, but got {idf}.")

if verbose and (not _TQDM_AVAILABLE):
raise ModuleNotFoundError(
Expand Down
5 changes: 1 addition & 4 deletions tests/unittests/text/test_bertscore.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,7 @@ def test_bertscore_differentiability(

@skip_on_connection_issues()
@pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>4.4")
@pytest.mark.parametrize(
"idf",
[(False,), (True,)],
)
@pytest.mark.parametrize("idf", [True, False])
def test_bertscore_sorting(idf: bool):
"""Test that BERTScore is invariant to the order of the inputs."""
short = "Short text"
Expand Down

0 comments on commit 51062d5

Please sign in to comment.