diff --git a/src/torchmetrics/functional/text/bert.py b/src/torchmetrics/functional/text/bert.py index cfdb8c743b4..18a1ae94ef6 100644 --- a/src/torchmetrics/functional/text/bert.py +++ b/src/torchmetrics/functional/text/bert.py @@ -354,6 +354,8 @@ def bert_score( preds = list(preds) if not isinstance(target, (str, list, dict)): # dict for BERTScore class compute call target = list(target) + if not isinstance(idf, bool): + raise ValueError(f"The value of idf must be a boolean. Value passed:{idf=}") if verbose and (not _TQDM_AVAILABLE): raise ModuleNotFoundError( diff --git a/tests/unittests/text/test_bertscore.py b/tests/unittests/text/test_bertscore.py index dfd6d60a0e5..230514b9091 100644 --- a/tests/unittests/text/test_bertscore.py +++ b/tests/unittests/text/test_bertscore.py @@ -175,7 +175,7 @@ def test_bertscore_differentiability( @skip_on_connection_issues() @pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>4.4") @pytest.mark.parametrize( - "idf", + ["idf"], [(False,), (True,)], ) def test_bertscore_sorting(idf: bool):