diff --git a/turbo_alignment/settings/pipelines/train/dpo.py b/turbo_alignment/settings/pipelines/train/dpo.py
index 9464031..84991fd 100755
--- a/turbo_alignment/settings/pipelines/train/dpo.py
+++ b/turbo_alignment/settings/pipelines/train/dpo.py
@@ -22,6 +22,8 @@ class DPOLossesType(str, Enum):
     CPO = 'cpo'
     ORPO = 'orpo'
     SIMPO = 'simpo'
+    APO_ZERO = 'apo_zero'
+    APO_DOWN = 'apo_down'
 
 
 class DPOLossSettings(ExtraFieldsNotAllowedBaseModel):
@@ -75,6 +77,14 @@ class ORPOLossSettings(DPOLossSettings):
     beta: float = 0.1
 
 
+class APOZeroLossSettings(DPOLossSettings):
+    loss_type: Literal[DPOLossesType.APO_ZERO]
+
+
+class APODownLossSettings(DPOLossSettings):
+    loss_type: Literal[DPOLossesType.APO_DOWN]
+
+
 class SyncRefModelSettings(ExtraFieldsNotAllowedBaseModel):
     sync_ref_model: bool = False
     alpha: float = 1.0
@@ -92,6 +102,8 @@ class DPOTrainerSettings(TrainerSettings):
         | SimPOLossSettings
         | SlicHfLossSettings
         | SigmoidLossWithMarginSettings
+        | APOZeroLossSettings
+        | APODownLossSettings
     )
     sync_ref_settings: SyncRefModelSettings
     use_ref_model: bool = True
diff --git a/turbo_alignment/trainers/dpo.py b/turbo_alignment/trainers/dpo.py
index 47d6beb..a647f77 100755
--- a/turbo_alignment/trainers/dpo.py
+++ b/turbo_alignment/trainers/dpo.py
@@ -24,6 +24,8 @@
 from turbo_alignment.common.tf.callbacks.sync_ref_model import SyncRefModelCallback
 from turbo_alignment.constants import DISABLE_LOSS_LABEL
 from turbo_alignment.settings.pipelines.train.dpo import (
+    APODownLossSettings,
+    APOZeroLossSettings,
     CPOLossSettings,
     DPOLossesType,
     HingeLossSettings,
@@ -331,6 +333,70 @@ def compute_loss(
         )
 
 
+@DPOLossRegistry.register(DPOLossesType.APO_DOWN)
+class APODownLoss(DPOLossRegistry):
+    def __init__(self, *args, beta: float = 0.1, **kwargs) -> None:
+        self.beta = beta
+        super().__init__(*args, **kwargs)
+
+    def compute_loss(
+        self,
+        policy_chosen_logps: torch.FloatTensor,
+        policy_rejected_logps: torch.FloatTensor,
+        reference_chosen_logps: torch.FloatTensor,
+        reference_rejected_logps: torch.FloatTensor,
+        precomputed_margins: torch.FloatTensor | None,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        chosen_logratios = policy_chosen_logps - reference_chosen_logps
+        rejected_logratios = policy_rejected_logps - reference_rejected_logps
+
+        losses_chosen = F.sigmoid(self.beta * chosen_logratios)
+        losses_rejected = 1 - F.sigmoid(self.beta * (chosen_logratios - rejected_logratios))
+
+        loss = losses_chosen + losses_rejected
+
+        chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
+        rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
+
+        return (
+            loss,
+            chosen_rewards,
+            rejected_rewards,
+        )
+
+
+@DPOLossRegistry.register(DPOLossesType.APO_ZERO)
+class APOZeroLoss(DPOLossRegistry):
+    def __init__(self, *args, beta: float = 0.1, **kwargs) -> None:
+        self.beta = beta
+        super().__init__(*args, **kwargs)
+
+    def compute_loss(
+        self,
+        policy_chosen_logps: torch.FloatTensor,
+        policy_rejected_logps: torch.FloatTensor,
+        reference_chosen_logps: torch.FloatTensor,
+        reference_rejected_logps: torch.FloatTensor,
+        precomputed_margins: torch.FloatTensor | None,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        chosen_logratios = policy_chosen_logps - reference_chosen_logps
+        rejected_logratios = policy_rejected_logps - reference_rejected_logps
+
+        losses_chosen = 1 - F.sigmoid(self.beta * chosen_logratios)
+        losses_rejected = F.sigmoid(self.beta * rejected_logratios)
+
+        loss = losses_chosen + losses_rejected
+
+        chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
+        rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
+
+        return (
+            loss,
+            chosen_rewards,
+            rejected_rewards,
+        )
+
+
 @dataclass
 class DPOTrainingArguments(TrainingArguments):
     loss_settings: (
@@ -344,6 +410,8 @@ class DPOTrainingArguments(TrainingArguments):
         | SimPOLossSettings
         | SlicHfLossSettings
         | SigmoidLossWithMarginSettings
+        | APOZeroLossSettings
+        | APODownLossSettings
     ) = field(
         default_factory=SigmoidLossSettings(loss_type=DPOLossesType.SIGMOID)
     )  # type: ignore[call-overload]
@@ -544,6 +612,8 @@ def get_batch_metrics(
 
         metrics = self._compute_metrics(metrics, dpo_prefix_name, chosen_rewards, rejected_rewards)
 
+        logp_accuracies = (policy_chosen_logps > policy_rejected_logps).float()
+        metrics[f'{prefix}logps/accuracies'] = (logp_accuracies).detach().cpu().mean().item()
         metrics[f'{prefix}logps/rejected'] = (policy_rejected_logps).detach().cpu().mean().item()
         metrics[f'{prefix}logps/chosen'] = (policy_chosen_logps).detach().cpu().mean().item()
 
@@ -551,9 +621,15 @@ def get_batch_metrics(
         metrics[f'{prefix}logits/chosen'] = (policy_chosen_logits).detach().cpu().mean().item()
 
         if self.args.use_ref_model:
+            ref_logp_accuracies = (reference_chosen_logps > reference_rejected_logps).float()
+            metrics[f'{prefix}logps/ref_accuracies'] = (ref_logp_accuracies).detach().cpu().mean().item()
             metrics[f'{prefix}logps/ref_rejected'] = (reference_rejected_logps).detach().cpu().mean().item()
             metrics[f'{prefix}logps/ref_chosen'] = (reference_chosen_logps).detach().cpu().mean().item()
 
+            metrics = self._compute_flips(
+                metrics, prefix, logp_accuracies.detach().cpu(), ref_logp_accuracies.detach().cpu()
+            )
+
         if self.loss_type == DPOLossesType.KTO:
             kto_chosen_KL = (
                 (policy_chosen_logps.detach().cpu() - reference_chosen_logps.detach().cpu()).mean().clamp(min=0)
@@ -622,6 +698,37 @@ def _compute_metrics(
 
         return metrics
 
+    def _compute_flips(
+        self,
+        metrics: dict[str, Any],
+        prefix_name: str,
+        logp_accuracies: torch.Tensor,
+        ref_logp_accuracies: torch.Tensor,
+    ):
+        correct_correct = (ref_logp_accuracies == 1) & (logp_accuracies == 1)
+        correct_incorrect = (ref_logp_accuracies == 1) & (logp_accuracies == 0)
+        incorrect_correct = (ref_logp_accuracies == 0) & (logp_accuracies == 1)
+        incorrect_incorrect = (ref_logp_accuracies == 0) & (logp_accuracies == 0)
+
+        correct_correct_count = correct_correct.sum().item()
+        correct_incorrect_count = correct_incorrect.sum().item()
+        incorrect_correct_count = incorrect_correct.sum().item()
+        incorrect_incorrect_count = incorrect_incorrect.sum().item()
+
+        total_count = len(logp_accuracies)
+
+        correct_correct_ratio = correct_correct_count / total_count
+        correct_incorrect_ratio = correct_incorrect_count / total_count
+        incorrect_correct_ratio = incorrect_correct_count / total_count
+        incorrect_incorrect_ratio = incorrect_incorrect_count / total_count
+
+        metrics[f'{prefix_name}flips/correct->correct'] = correct_correct_ratio
+        metrics[f'{prefix_name}flips/correct->incorrect'] = correct_incorrect_ratio
+        metrics[f'{prefix_name}flips/incorrect->correct'] = incorrect_correct_ratio
+        metrics[f'{prefix_name}flips/incorrect->incorrect'] = incorrect_incorrect_ratio
+
+        return metrics
+
     def compute_loss(
         self,
         model: PreTrainedModel | nn.Module,
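
Reviewer note, not part of the patch: a minimal standalone sketch of what the two new losses compute, mirroring the formulas in `APOZeroLoss` and `APODownLoss` above. The tensor values are invented toy log-probabilities, used only to make the math concrete.

```python
import torch
import torch.nn.functional as F

beta = 0.1

# Invented per-example sequence log-probabilities, for illustration only.
policy_chosen_logps = torch.tensor([-10.0, -12.0])
policy_rejected_logps = torch.tensor([-11.0, -11.5])
reference_chosen_logps = torch.tensor([-10.5, -11.0])
reference_rejected_logps = torch.tensor([-10.8, -11.2])

chosen_logratios = policy_chosen_logps - reference_chosen_logps
rejected_logratios = policy_rejected_logps - reference_rejected_logps

# APO-zero: independently pushes chosen log-probs above the reference
# and rejected log-probs below it.
apo_zero = (1 - F.sigmoid(beta * chosen_logratios)) + F.sigmoid(beta * rejected_logratios)

# APO-down: pushes chosen log-probs down toward the reference while still
# widening the chosen-rejected margin relative to the reference.
apo_down = F.sigmoid(beta * chosen_logratios) + (
    1 - F.sigmoid(beta * (chosen_logratios - rejected_logratios))
)

print(apo_zero.mean().item(), apo_down.mean().item())
```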
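
Likewise, a toy illustration of the new `flips/*` metrics from `_compute_flips`: each preference pair is bucketed by whether the reference model and the policy rank the chosen completion above the rejected one. The 0/1 rankings below are made up so that each bucket is hit exactly once.

```python
import torch

# Invented rankings: 1 means the model assigns the chosen completion the higher log-prob.
logp_accuracies = torch.tensor([1.0, 0.0, 1.0, 0.0])
ref_logp_accuracies = torch.tensor([1.0, 1.0, 0.0, 0.0])

total = len(logp_accuracies)
flips = {
    'flips/correct->correct': ((ref_logp_accuracies == 1) & (logp_accuracies == 1)).sum().item() / total,
    'flips/correct->incorrect': ((ref_logp_accuracies == 1) & (logp_accuracies == 0)).sum().item() / total,
    'flips/incorrect->correct': ((ref_logp_accuracies == 0) & (logp_accuracies == 1)).sum().item() / total,
    'flips/incorrect->incorrect': ((ref_logp_accuracies == 0) & (logp_accuracies == 0)).sum().item() / total,
}
print(flips)  # each bucket is 0.25 for this toy input
```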