diff --git a/turbo_alignment/settings/pipelines/train/dpo.py b/turbo_alignment/settings/pipelines/train/dpo.py
index c5312f0..bc85133 100755
--- a/turbo_alignment/settings/pipelines/train/dpo.py
+++ b/turbo_alignment/settings/pipelines/train/dpo.py
@@ -25,6 +25,7 @@ class DPOLossesType(str, Enum):
     APO_ZERO = 'apo_zero'
     APO_DOWN = 'apo_down'
     ASFT = 'asft'
+    DPOP = 'dpop'
 
 
 class DPOLossSettings(ExtraFieldsNotAllowedBaseModel):
@@ -87,6 +88,11 @@ class APODownLossSettings(DPOLossSettings):
     loss_type: Literal[DPOLossesType.APO_DOWN]
 
 
+class DPOPLossSettings(DPOLossSettings):
+    loss_type: Literal[DPOLossesType.DPOP]
+    lam: float = 0.1
+
+
 class SyncRefModelSettings(ExtraFieldsNotAllowedBaseModel):
     sync_ref_model: bool = False
     alpha: float = 1.0
@@ -107,6 +113,7 @@ class DPOTrainerSettings(TrainerSettings):
         | SigmoidLossWithMarginSettings
         | APOZeroLossSettings
         | APODownLossSettings
+        | DPOPLossSettings
     )
     sync_ref_settings: SyncRefModelSettings
     use_ref_model: bool = True
diff --git a/turbo_alignment/trainers/dpo.py b/turbo_alignment/trainers/dpo.py
index 490e6d1..ef3288f 100755
--- a/turbo_alignment/trainers/dpo.py
+++ b/turbo_alignment/trainers/dpo.py
@@ -26,9 +26,10 @@ from turbo_alignment.settings.pipelines.train.dpo import (
     APODownLossSettings,
     APOZeroLossSettings,
+    ASFTLossSettings,
     CPOLossSettings,
     DPOLossesType,
-    ASFTLossSettings,
+    DPOPLossSettings,
     HingeLossSettings,
     IPOLossSettings,
     KTOLossSettings,
@@ -426,6 +427,35 @@ def compute_loss(
         )
 
 
+@DPOLossRegistry.register(DPOLossesType.DPOP)
+class DPOPLoss(DPOLossRegistry):
+    def __init__(self, *args, beta: float = 0.1, lam: float = 0.1, **kwargs) -> None:
+        self.beta = beta
+        self.lam = lam
+        super().__init__(*args, **kwargs)
+
+    def compute_loss(
+        self,
+        policy_chosen_logps: torch.FloatTensor,
+        policy_rejected_logps: torch.FloatTensor,
+        reference_chosen_logps: torch.FloatTensor,
+        reference_rejected_logps: torch.FloatTensor,
+        precomputed_margins: torch.FloatTensor | None,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        pi_logratios = policy_chosen_logps - policy_rejected_logps
+        ref_logratios = reference_chosen_logps - reference_rejected_logps
+        penalty_term = self.lam * torch.relu(reference_chosen_logps - policy_chosen_logps)
+
+        logits = pi_logratios - ref_logratios
+
+        chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps - penalty_term).detach()
+        rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
+
+        loss = -F.logsigmoid(self.beta * (logits - penalty_term))
+
+        return loss, chosen_rewards, rejected_rewards
+
+
 @dataclass
 class DPOTrainingArguments(TrainingArguments):
     loss_settings: (
@@ -442,6 +472,7 @@ class DPOTrainingArguments(TrainingArguments):
         | SigmoidLossWithMarginSettings
         | APOZeroLossSettings
         | APODownLossSettings
+        | DPOPLossSettings
     ) = field(
        default_factory=SigmoidLossSettings(loss_type=DPOLossesType.SIGMOID)
     )  # type: ignore[call-overload]
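
The new `DPOPLoss` implements the DPO-Positive objective: the standard DPO log-ratio logits are reduced by a penalty `lam * relu(reference_chosen_logps - policy_chosen_logps)`, which only activates when the policy's log-probability of the chosen response falls below the reference model's. Below is a minimal standalone sketch of that computation on made-up toy tensors, using plain `torch` rather than the trainer classes from the diff; `beta` and `lam` mirror the 0.1 defaults above, and the tensor values are purely illustrative.

```python
# Standalone sketch of the DPOP (DPO-Positive) loss math added in DPOPLoss.compute_loss.
# Toy log-prob values; in the trainer these come from the policy and reference models.
import torch
import torch.nn.functional as F

beta, lam = 0.1, 0.1

policy_chosen_logps = torch.tensor([-5.0, -8.0])
policy_rejected_logps = torch.tensor([-9.0, -7.5])
reference_chosen_logps = torch.tensor([-6.0, -6.5])
reference_rejected_logps = torch.tensor([-8.0, -7.0])

pi_logratios = policy_chosen_logps - policy_rejected_logps
ref_logratios = reference_chosen_logps - reference_rejected_logps

# Penalty is nonzero only when the policy assigns the chosen response a lower
# log-prob than the reference does (the second example here).
penalty_term = lam * torch.relu(reference_chosen_logps - policy_chosen_logps)

loss = -F.logsigmoid(beta * ((pi_logratios - ref_logratios) - penalty_term))
print(loss)  # per-example losses; the trainer is expected to reduce them (e.g. mean)
```

To select this loss via the settings added in the first hunk, a `DPOPLossSettings(loss_type=DPOLossesType.DPOP, lam=0.1)` instance would be passed wherever `loss_settings` is configured; the exact wiring of `beta` (assumed here to live on the `DPOLossSettings` base class, since `DPOPLoss.__init__` accepts it) is not shown in this diff.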