diff --git a/turbo_alignment/common/tf/callbacks/logging.py b/turbo_alignment/common/tf/callbacks/logging.py index c72ebc6..8f90288 100755 --- a/turbo_alignment/common/tf/callbacks/logging.py +++ b/turbo_alignment/common/tf/callbacks/logging.py @@ -70,10 +70,7 @@ def _log(self, logs: dict[str, Any], state: TrainerState) -> None: self._wandb_run.log({**rewritten_logs, 'train/global_step': state.global_step}, step=state.global_step) def _fix_table_type(self, logs: dict[str, Any]) -> dict[str, Any]: - return { - k: wandb.Table(dataframe=v) if isinstance(v, pd.DataFrame) else v - for k, v in logs.items() - } + return {k: wandb.Table(dataframe=v) if isinstance(v, pd.DataFrame) else v for k, v in logs.items()} class ClearMLLoggingCallback(LoggingCallback): diff --git a/turbo_alignment/metrics/distinctness.py b/turbo_alignment/metrics/distinctness.py index a1ff2bd..4f10de2 100755 --- a/turbo_alignment/metrics/distinctness.py +++ b/turbo_alignment/metrics/distinctness.py @@ -1,4 +1,5 @@ from collections import defaultdict + from transformers.tokenization_utils_base import PreTrainedTokenizerBase from turbo_alignment.metrics.metric import Metric diff --git a/turbo_alignment/metrics/registry.py b/turbo_alignment/metrics/registry.py index d9f5abd..c99195c 100755 --- a/turbo_alignment/metrics/registry.py +++ b/turbo_alignment/metrics/registry.py @@ -1,4 +1,5 @@ from enum import Enum + from pydantic import field_validator from turbo_alignment.common.registry import Registrable diff --git a/turbo_alignment/trainers/dpo.py b/turbo_alignment/trainers/dpo.py index 490e6d1..e49231b 100755 --- a/turbo_alignment/trainers/dpo.py +++ b/turbo_alignment/trainers/dpo.py @@ -26,9 +26,9 @@ from turbo_alignment.settings.pipelines.train.dpo import ( APODownLossSettings, APOZeroLossSettings, + ASFTLossSettings, CPOLossSettings, DPOLossesType, - ASFTLossSettings, HingeLossSettings, IPOLossSettings, KTOLossSettings,