diff --git a/turbo_alignment/settings/pipelines/train/dpo.py b/turbo_alignment/settings/pipelines/train/dpo.py
index bc85133..3d35fa3 100755
--- a/turbo_alignment/settings/pipelines/train/dpo.py
+++ b/turbo_alignment/settings/pipelines/train/dpo.py
@@ -26,6 +26,7 @@ class DPOLossesType(str, Enum):
     APO_DOWN = 'apo_down'
     ASFT = 'asft'
     DPOP = 'dpop'
+    NCA_PAIR = 'nca_pair'
 
 
 class DPOLossSettings(ExtraFieldsNotAllowedBaseModel):
@@ -93,6 +94,10 @@ class DPOPLossSettings(DPOLossSettings):
     lam: float = 0.1
 
 
+class NCAPairLossSettings(DPOLossSettings):
+    loss_type: Literal[DPOLossesType.NCA_PAIR]
+
+
 class SyncRefModelSettings(ExtraFieldsNotAllowedBaseModel):
     sync_ref_model: bool = False
     alpha: float = 1.0
@@ -114,6 +119,7 @@ class DPOTrainerSettings(TrainerSettings):
         | APOZeroLossSettings
         | APODownLossSettings
         | DPOPLossSettings
+        | NCAPairLossSettings
     )
     sync_ref_settings: SyncRefModelSettings
     use_ref_model: bool = True
diff --git a/turbo_alignment/trainers/dpo.py b/turbo_alignment/trainers/dpo.py
index ef3288f..7473987 100755
--- a/turbo_alignment/trainers/dpo.py
+++ b/turbo_alignment/trainers/dpo.py
@@ -39,6 +39,7 @@
     SimPOLossSettings,
     SlicHfLossSettings,
     SyncRefModelSettings,
+    NCAPairLossSettings,
 )
 from turbo_alignment.trainers.utils import (
     DPOLossRegistry,
@@ -456,6 +457,35 @@ def compute_loss(
         return loss, chosen_rewards, rejected_rewards
 
 
+@DPOLossRegistry.register(DPOLossesType.NCA_PAIR)
+class NCAPairLoss(DPOLossRegistry):
+    def __init__(self, *args, beta: float = 0.1, **kwargs) -> None:
+        self.beta = beta
+        super().__init__(*args, **kwargs)
+
+    def compute_loss(
+        self,
+        policy_chosen_logps: torch.FloatTensor,
+        policy_rejected_logps: torch.FloatTensor,
+        reference_chosen_logps: torch.FloatTensor,
+        reference_rejected_logps: torch.FloatTensor,
+        precomputed_margins: torch.FloatTensor | None,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        chosen_logratios = policy_chosen_logps - reference_chosen_logps
+        rejected_logratios = policy_rejected_logps - reference_rejected_logps
+
+        chosen_rewards = self.beta * chosen_logratios.detach()
+        rejected_rewards = self.beta * rejected_logratios.detach()
+
+        loss = (
+            -F.logsigmoid(chosen_logratios * self.beta)
+            - 0.5 * F.logsigmoid(-chosen_logratios * self.beta)
+            - 0.5 * F.logsigmoid(-rejected_logratios * self.beta)
+        )
+
+        return loss, chosen_rewards, rejected_rewards
+
+
 @dataclass
 class DPOTrainingArguments(TrainingArguments):
     loss_settings: (
@@ -473,6 +503,7 @@ class DPOTrainingArguments(TrainingArguments):
         | APOZeroLossSettings
         | APODownLossSettings
         | DPOPLossSettings
+        | NCAPairLossSettings
     ) = field(
         default_factory=SigmoidLossSettings(loss_type=DPOLossesType.SIGMOID)
     )  # type: ignore[call-overload]
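For reference, the objective added by NCAPairLoss can be checked outside the trainer on a few dummy tensors. The sketch below is illustrative and not part of the patch: the helper name nca_pair_loss and the example log-probabilities are made up, but the arithmetic mirrors the compute_loss body above.

# Standalone sketch (assumed helper, not part of the patch): the NCA-Pair
# objective as implemented in NCAPairLoss.compute_loss, run on dummy tensors.
import torch
import torch.nn.functional as F


def nca_pair_loss(
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    reference_chosen_logps: torch.Tensor,
    reference_rejected_logps: torch.Tensor,
    beta: float = 0.1,
) -> torch.Tensor:
    # Per-pair log-ratios of the policy against the frozen reference model.
    chosen_logratios = policy_chosen_logps - reference_chosen_logps
    rejected_logratios = policy_rejected_logps - reference_rejected_logps
    # -logsigmoid(beta * r_w) pushes the chosen log-ratio up, while the two
    # 0.5-weighted terms penalise large log-ratios for chosen and rejected
    # alike, so the chosen term is optimised at a finite positive log-ratio
    # rather than at an unbounded chosen-rejected margin as in sigmoid DPO.
    return (
        -F.logsigmoid(beta * chosen_logratios)
        - 0.5 * F.logsigmoid(-beta * chosen_logratios)
        - 0.5 * F.logsigmoid(-beta * rejected_logratios)
    )


# Two preference pairs with made-up sequence log-probabilities.
policy_chosen = torch.tensor([-12.0, -15.0])
policy_rejected = torch.tensor([-20.0, -18.0])
reference_chosen = torch.tensor([-14.0, -15.5])
reference_rejected = torch.tensor([-19.0, -17.0])
print(nca_pair_loss(policy_chosen, policy_rejected, reference_chosen, reference_rejected))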