Skip to content

Commit

Permalink
Add tests, ensure that labels are set correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
tomaarsen committed Jan 16, 2024
1 parent 174abc4 commit 82c9ddb
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,8 @@ def test_trainer_evaluate_with_strings(model: SetFitModel):
# The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool."
model.predict(["another positive sentence"])

assert set(model.labels) == {"positive", "negative"}


def test_trainer_evaluate_multilabel_f1():
dataset = Dataset.from_dict({"text_new": ["", "a", "b", "ab"], "label_new": [[0, 0], [1, 0], [0, 1], [1, 1]]})
Expand Down Expand Up @@ -606,6 +608,7 @@ def test_evaluate_with_strings(model: SetFitModel) -> None:
trainer.train()
metrics = trainer.evaluate()
assert "accuracy" in metrics
assert set(model.labels) == {"positive", "negative"}


def test_trainer_wrong_args(model: SetFitModel, tmp_path: Path) -> None:
Expand Down

0 comments on commit 82c9ddb

Please sign in to comment.