From fcf4848997491470d3f8873baddd4ba16c4730f3 Mon Sep 17 00:00:00 2001
From: Pavel Geyn
Date: Fri, 18 Oct 2024 09:56:38 +0500
Subject: [PATCH] Fix table logging
Converting pandas DataFrame to wandb.Table moved to WandbLoggingCallback
---
turbo_alignment/common/tf/callbacks/logging.py | 11 +++++++----
1 file changed, 7 insertions(+), 4 deletions(-)
diff --git a/turbo_alignment/common/tf/callbacks/logging.py b/turbo_alignment/common/tf/callbacks/logging.py
index 1ad484d..c72ebc6 100755
--- a/turbo_alignment/common/tf/callbacks/logging.py
+++ b/turbo_alignment/common/tf/callbacks/logging.py
@@ -46,9 +46,6 @@ def _rewrite_logs(logs: dict[str, Any]) -> dict[str, Any]:
cherry_pick_prefix = 'cherry_pick_'
cherry_pick_prefix_len = len(cherry_pick_prefix)
for k, v in logs.items():
- if isinstance(v, pd.DataFrame):
- v = wandb.Table(dataframe=v)
-
if k.startswith(eval_prefix):
rewritten_logs['eval/' + k[eval_prefix_len:]] = v
elif k.startswith(test_prefix):
@@ -69,9 +66,15 @@ def __init__(self, wandb_run: Run | RunDisabled) -> None:
self._wandb_run = wandb_run
def _log(self, logs: dict[str, Any], state: TrainerState) -> None:
- rewritten_logs: dict[str, Any] = self._rewrite_logs(logs)
+ rewritten_logs = self._fix_table_type(self._rewrite_logs(logs))
self._wandb_run.log({**rewritten_logs, 'train/global_step': state.global_step}, step=state.global_step)
+ def _fix_table_type(self, logs: dict[str, Any]) -> dict[str, Any]:
+ return {
+ k: wandb.Table(dataframe=v) if isinstance(v, pd.DataFrame) else v
+ for k, v in logs.items()
+ }
+
class ClearMLLoggingCallback(LoggingCallback):
def __init__(self, clearml_task: Task) -> None: