diff --git a/turbo_alignment/common/tf/callbacks/logging.py b/turbo_alignment/common/tf/callbacks/logging.py index 1ad484d..c72ebc6 100755 --- a/turbo_alignment/common/tf/callbacks/logging.py +++ b/turbo_alignment/common/tf/callbacks/logging.py @@ -46,9 +46,6 @@ def _rewrite_logs(logs: dict[str, Any]) -> dict[str, Any]: cherry_pick_prefix = 'cherry_pick_' cherry_pick_prefix_len = len(cherry_pick_prefix) for k, v in logs.items(): - if isinstance(v, pd.DataFrame): - v = wandb.Table(dataframe=v) - if k.startswith(eval_prefix): rewritten_logs['eval/' + k[eval_prefix_len:]] = v elif k.startswith(test_prefix): @@ -69,9 +66,15 @@ def __init__(self, wandb_run: Run | RunDisabled) -> None: self._wandb_run = wandb_run def _log(self, logs: dict[str, Any], state: TrainerState) -> None: - rewritten_logs: dict[str, Any] = self._rewrite_logs(logs) + rewritten_logs = self._fix_table_type(self._rewrite_logs(logs)) self._wandb_run.log({**rewritten_logs, 'train/global_step': state.global_step}, step=state.global_step) + def _fix_table_type(self, logs: dict[str, Any]) -> dict[str, Any]: + return { + k: wandb.Table(dataframe=v) if isinstance(v, pd.DataFrame) else v + for k, v in logs.items() + } + class ClearMLLoggingCallback(LoggingCallback): def __init__(self, clearml_task: Task) -> None: