From 91ec87390225e6e73efb69fe66bd6a278e9b0c36 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Mon, 11 Nov 2024 11:00:37 +0100 Subject: [PATCH] Rename slightly; add docstring which'll go in docs --- sentence_transformers/trainer.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/sentence_transformers/trainer.py b/sentence_transformers/trainer.py index 684429171..b3cd416e9 100644 --- a/sentence_transformers/trainer.py +++ b/sentence_transformers/trainer.py @@ -172,10 +172,6 @@ def __init__( "the evaluation loss." ) - # Get a dictionary of the default training arguments, so we can determine which arguments have been changed - # for the model card - default_args_dict = SentenceTransformerTrainingArguments(output_dir="unused").to_dict() - # If the model ID is set via the SentenceTransformerTrainingArguments, but not via the SentenceTransformerModelCardData, # then we can set it here for the model card regardless if args.hub_model_id and not model.model_card_data.model_id: @@ -277,11 +273,23 @@ def __init__( self.eval_dataset = self.maybe_add_prompts_or_dataset_name_column( eval_dataset, args.prompts, dataset_name="eval" ) - self.model_card_callback_init(default_args_dict) + self.add_model_card_callback() + + def add_model_card_callback(self) -> None: + """ + Add a callback responsible for automatically tracking data required for the automatic model card generation + + This method is called in the ``__init__`` method of the + :class:`~sentence_transformers.trainer.SentenceTransformerTrainer` class. + + .. note:: + + This method can be overriden by subclassing the trainer to remove/customize this callback in custom uses cases + """ + # Get a dictionary of the default training arguments, so we can determine which arguments have been changed + # for the model card + default_args_dict = SentenceTransformerTrainingArguments(output_dir="unused").to_dict() - def model_card_callback_init(self, default_args_dict: dict[str, Any]) -> None: - # Add a callback responsible for automatically tracking data required for the automatic model card generation - # can be overriden by subclassing the trainer to remove/customize this callback in custom uses cases model_card_callback = ModelCardCallback(self, default_args_dict) self.add_callback(model_card_callback) model_card_callback.on_init_end(self.args, self.state, self.control, self.model)