Skip to content

Commit

Permalink
Rename slightly; add docstring which'll go in docs
Browse files Browse the repository at this point in the history
  • Loading branch information
tomaarsen committed Nov 11, 2024
1 parent 00e2005 commit 91ec873
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions sentence_transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,6 @@ def __init__(
"the evaluation loss."
)

# Get a dictionary of the default training arguments, so we can determine which arguments have been changed
# for the model card
default_args_dict = SentenceTransformerTrainingArguments(output_dir="unused").to_dict()

# If the model ID is set via the SentenceTransformerTrainingArguments, but not via the SentenceTransformerModelCardData,
# then we can set it here for the model card regardless
if args.hub_model_id and not model.model_card_data.model_id:
Expand Down Expand Up @@ -277,11 +273,23 @@ def __init__(
self.eval_dataset = self.maybe_add_prompts_or_dataset_name_column(
eval_dataset, args.prompts, dataset_name="eval"
)
self.model_card_callback_init(default_args_dict)
self.add_model_card_callback()

def add_model_card_callback(self) -> None:
"""
Add a callback responsible for automatically tracking data required for the automatic model card generation
This method is called in the ``__init__`` method of the
:class:`~sentence_transformers.trainer.SentenceTransformerTrainer` class.
.. note::
This method can be overriden by subclassing the trainer to remove/customize this callback in custom uses cases
"""
# Get a dictionary of the default training arguments, so we can determine which arguments have been changed
# for the model card
default_args_dict = SentenceTransformerTrainingArguments(output_dir="unused").to_dict()

def model_card_callback_init(self, default_args_dict: dict[str, Any]) -> None:
# Add a callback responsible for automatically tracking data required for the automatic model card generation
# can be overriden by subclassing the trainer to remove/customize this callback in custom uses cases
model_card_callback = ModelCardCallback(self, default_args_dict)
self.add_callback(model_card_callback)
model_card_callback.on_init_end(self.args, self.state, self.control, self.model)
Expand Down

0 comments on commit 91ec873

Please sign in to comment.