Skip to content

Commit

Permalink
Avoid passing eval_dataset=None to transformers due to >=v4.46.0 crash
Browse files Browse the repository at this point in the history
  • Loading branch information
tomaarsen committed Nov 5, 2024
1 parent b9316f9 commit 475d0c8
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 1 deletion.
8 changes: 7 additions & 1 deletion sentence_transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def __init__(
"args": args,
"data_collator": data_collator,
"train_dataset": train_dataset,
"eval_dataset": eval_dataset,
"eval_dataset": eval_dataset if eval_dataset is not None or evaluator is None else "dummy",
"model_init": model_init,
"compute_metrics": compute_metrics,
"callbacks": callbacks,
Expand All @@ -222,6 +222,12 @@ def __init__(
else:
super_kwargs["tokenizer"] = tokenizer
super().__init__(**super_kwargs)
# Transformers v4.46.0 introduced a ValueError if `eval_dataset` is None while eval_strategy is not "no",
# but in Sentence Transformers you can also evaluate without an eval_dataset via an evaluator, so we set
# it to "dummy" in that case to avoid the ValueError
if self.eval_dataset == "dummy":
self.eval_dataset = None

# Every Sentence Transformer model can always return a loss, so we set this to True
# to avoid having to specify it in the data collator or model's forward
self.can_return_loss = True
Expand Down
46 changes: 46 additions & 0 deletions tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@

import pytest
import torch
from datasets.dataset_dict import DatasetDict

from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from sentence_transformers.training_args import SentenceTransformerTrainingArguments
from sentence_transformers.util import is_datasets_available, is_training_available
from tests.utils import SafeTemporaryDirectory
Expand Down Expand Up @@ -230,3 +232,47 @@ def test_trainer(
original_embeddings = original_model.encode("The cat is on the mat.", convert_to_tensor=True)
new_embeddings = model.encode("The cat is on the the mat.", convert_to_tensor=True)
assert not torch.equal(original_embeddings, new_embeddings)


@pytest.mark.parametrize("use_eval_dataset", [True, False])
@pytest.mark.parametrize("use_evaluator", [True, False])
def test_trainer_no_eval_dataset_with_eval_strategy(
    stsb_bert_tiny_model: SentenceTransformer,
    stsb_dataset_dict: DatasetDict,
    use_eval_dataset: bool,
    use_evaluator: bool,
    tmp_path: Path,
) -> None:
    """Constructing a trainer with eval_strategy != "no" must only raise when
    neither an `eval_dataset` nor an `evaluator` is supplied; otherwise it
    should succeed without error."""
    model = stsb_bert_tiny_model
    train_split = stsb_dataset_dict["train"].select(range(10))
    validation_split = stsb_dataset_dict["validation"].select(range(10))
    evaluator = EmbeddingSimilarityEvaluator(
        sentences1=validation_split["sentence1"],
        sentences2=validation_split["sentence2"],
        scores=[score / 5 for score in validation_split["score"]],
        name="stsb-validation",
    )
    loss = losses.CosineSimilarityLoss(model=model)
    args = SentenceTransformerTrainingArguments(output_dir=tmp_path, eval_strategy="steps")

    # Only pass the evaluation inputs selected by the parametrization.
    trainer_kwargs = {}
    if use_eval_dataset:
        trainer_kwargs["eval_dataset"] = validation_split
    if use_evaluator:
        trainer_kwargs["evaluator"] = evaluator

    # A ValueError is expected only when no way to evaluate was provided at all.
    expect_crash = not use_eval_dataset and not use_evaluator
    context = pytest.raises(ValueError, match=".*`args.eval_strategy`.*") if expect_crash else nullcontext()

    with context:
        SentenceTransformerTrainer(
            model=model,
            args=args,
            train_dataset=train_split,
            loss=loss,
            **trainer_kwargs,
        )

0 comments on commit 475d0c8

Please sign in to comment.