diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 0814c758..40f37b5d 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -472,6 +472,8 @@ def test_trainer_evaluate_with_strings(model: SetFitModel): # The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool." model.predict(["another positive sentence"]) + assert set(model.labels) == {"positive", "negative"} + def test_trainer_evaluate_multilabel_f1(): dataset = Dataset.from_dict({"text_new": ["", "a", "b", "ab"], "label_new": [[0, 0], [1, 0], [0, 1], [1, 1]]}) @@ -606,6 +608,7 @@ def test_evaluate_with_strings(model: SetFitModel) -> None: trainer.train() metrics = trainer.evaluate() assert "accuracy" in metrics + assert set(model.labels) == {"positive", "negative"} def test_trainer_wrong_args(model: SetFitModel, tmp_path: Path) -> None: