From fcf4848997491470d3f8873baddd4ba16c4730f3 Mon Sep 17 00:00:00 2001 From: Pavel Geyn Date: Fri, 18 Oct 2024 09:56:38 +0500 Subject: [PATCH] Fix table logging Converting pandas DataFrame to wandb.Table moved to WandbLoggingCallback --- turbo_alignment/common/tf/callbacks/logging.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/turbo_alignment/common/tf/callbacks/logging.py b/turbo_alignment/common/tf/callbacks/logging.py index 1ad484d..c72ebc6 100755 --- a/turbo_alignment/common/tf/callbacks/logging.py +++ b/turbo_alignment/common/tf/callbacks/logging.py @@ -46,9 +46,6 @@ def _rewrite_logs(logs: dict[str, Any]) -> dict[str, Any]: cherry_pick_prefix = 'cherry_pick_' cherry_pick_prefix_len = len(cherry_pick_prefix) for k, v in logs.items(): - if isinstance(v, pd.DataFrame): - v = wandb.Table(dataframe=v) - if k.startswith(eval_prefix): rewritten_logs['eval/' + k[eval_prefix_len:]] = v elif k.startswith(test_prefix): @@ -69,9 +66,15 @@ def __init__(self, wandb_run: Run | RunDisabled) -> None: self._wandb_run = wandb_run def _log(self, logs: dict[str, Any], state: TrainerState) -> None: - rewritten_logs: dict[str, Any] = self._rewrite_logs(logs) + rewritten_logs = self._fix_table_type(self._rewrite_logs(logs)) self._wandb_run.log({**rewritten_logs, 'train/global_step': state.global_step}, step=state.global_step) + def _fix_table_type(self, logs: dict[str, Any]) -> dict[str, Any]: + return { + k: wandb.Table(dataframe=v) if isinstance(v, pd.DataFrame) else v + for k, v in logs.items() + } + class ClearMLLoggingCallback(LoggingCallback): def __init__(self, clearml_task: Task) -> None: