From 354cb65d6e9f349e84baa30953c91004bcf43ae5 Mon Sep 17 00:00:00 2001 From: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com> Date: Mon, 30 Sep 2024 16:49:18 +0200 Subject: [PATCH] [`model cards`] Keep evaluation order in training logs if there's multiple evaluators (#2963) Also rename "loss" to Validation Loss --- sentence_transformers/model_card.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/sentence_transformers/model_card.py b/sentence_transformers/model_card.py index abe454d1a..92eed253f 100644 --- a/sentence_transformers/model_card.py +++ b/sentence_transformers/model_card.py @@ -152,6 +152,8 @@ def on_evaluate( **kwargs, ) -> None: loss_dict = {" ".join(key.split("_")[1:]): metrics[key] for key in metrics if key.endswith("_loss")} + if len(loss_dict) == 1 and "loss" in loss_dict: + loss_dict = {"Validation Loss": loss_dict["loss"]} if ( model.model_card_data.training_logs and model.model_card_data.training_logs[-1]["Step"] == state.global_step @@ -830,19 +832,25 @@ def try_to_pure_python(value: Any) -> Any: def format_training_logs(self): # Get the keys from all evaluation lines - eval_lines_keys = {key for lines in self.training_logs for key in lines.keys()} + eval_lines_keys = [] + for lines in self.training_logs: + for key in lines.keys(): + if key not in eval_lines_keys: + eval_lines_keys.append(key) # Sort the metric columns: Epoch, Step, Training Loss, Validation Loss, Evaluator results def sort_metrics(key: str) -> str: if key == "Epoch": - return "0" + return 0 if key == "Step": - return "1" + return 1 if key == "Training Loss": - return "2" + return 2 + if key == "Validation Loss": + return 3 if key.endswith("loss"): - return "3" - return key + return 4 + return eval_lines_keys.index(key) + 5 sorted_eval_lines_keys = sorted(eval_lines_keys, key=sort_metrics) training_logs = [