Skip to content

Commit

Permalink
Merge pull request #48 from pavelgein/fix_cleaml_logging
Browse files Browse the repository at this point in the history
Fix table logging
  • Loading branch information
alekseymalakhov11 authored Oct 25, 2024
2 parents 3b4075e + fcf4848 commit e75a7eb
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions turbo_alignment/common/tf/callbacks/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,6 @@ def _rewrite_logs(logs: dict[str, Any]) -> dict[str, Any]:
cherry_pick_prefix = 'cherry_pick_'
cherry_pick_prefix_len = len(cherry_pick_prefix)
for k, v in logs.items():
if isinstance(v, pd.DataFrame):
v = wandb.Table(dataframe=v)

if k.startswith(eval_prefix):
rewritten_logs['eval/' + k[eval_prefix_len:]] = v
elif k.startswith(test_prefix):
Expand All @@ -69,9 +66,15 @@ def __init__(self, wandb_run: Run | RunDisabled) -> None:
self._wandb_run = wandb_run

def _log(self, logs: dict[str, Any], state: TrainerState) -> None:
rewritten_logs: dict[str, Any] = self._rewrite_logs(logs)
rewritten_logs = self._fix_table_type(self._rewrite_logs(logs))
self._wandb_run.log({**rewritten_logs, 'train/global_step': state.global_step}, step=state.global_step)

def _fix_table_type(self, logs: dict[str, Any]) -> dict[str, Any]:
return {
k: wandb.Table(dataframe=v) if isinstance(v, pd.DataFrame) else v
for k, v in logs.items()
}


class ClearMLLoggingCallback(LoggingCallback):
def __init__(self, clearml_task: Task) -> None:
Expand Down

0 comments on commit e75a7eb

Please sign in to comment.