⚓️ Add APO losses & flip metrics

alekseymalakhov11 authored Sep 16, 2024
2 parents 54bff60 + c7b4cec · commit 71ecf0f
Showing 2 changed files with 119 additions and 0 deletions.
12 changes: 12 additions & 0 deletions turbo_alignment/settings/pipelines/train/dpo.py
@@ -22,6 +22,8 @@ class DPOLossesType(str, Enum):
CPO = 'cpo'
ORPO = 'orpo'
SIMPO = 'simpo'
APO_ZERO = 'apo_zero'
APO_DOWN = 'apo_down'


class DPOLossSettings(ExtraFieldsNotAllowedBaseModel):
@@ -75,6 +77,14 @@ class ORPOLossSettings(DPOLossSettings):
beta: float = 0.1


class APOZeroLossSettings(DPOLossSettings):
loss_type: Literal[DPOLossesType.APO_ZERO]


class APODownLossSettings(DPOLossSettings):
loss_type: Literal[DPOLossesType.APO_DOWN]


class SyncRefModelSettings(ExtraFieldsNotAllowedBaseModel):
sync_ref_model: bool = False
alpha: float = 1.0
@@ -92,6 +102,8 @@ class DPOTrainerSettings(TrainerSettings):
| SimPOLossSettings
| SlicHfLossSettings
| SigmoidLossWithMarginSettings
| APOZeroLossSettings
| APODownLossSettings
)
sync_ref_settings: SyncRefModelSettings
use_ref_model: bool = True
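With these classes registered, `apo_zero` and `apo_down` (the two Anchored Preference Optimization variants) become selectable values of the `loss_settings` union on `DPOTrainerSettings`; the `Literal` `loss_type` field is what lets the union resolve to the matching settings class. A minimal usage sketch, assuming only what this diff shows (whether the base `DPOLossSettings` exposes a `beta` field is not visible here, so the default is used):

# Hypothetical sketch, not part of the commit.
from turbo_alignment.settings.pipelines.train.dpo import (
    APOZeroLossSettings,
    DPOLossesType,
)

loss_settings = APOZeroLossSettings(loss_type=DPOLossesType.APO_ZERO)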
107 changes: 107 additions & 0 deletions turbo_alignment/trainers/dpo.py
@@ -24,6 +24,8 @@
from turbo_alignment.common.tf.callbacks.sync_ref_model import SyncRefModelCallback
from turbo_alignment.constants import DISABLE_LOSS_LABEL
from turbo_alignment.settings.pipelines.train.dpo import (
APODownLossSettings,
APOZeroLossSettings,
CPOLossSettings,
DPOLossesType,
HingeLossSettings,
@@ -331,6 +333,70 @@ def compute_loss(
)


@DPOLossRegistry.register(DPOLossesType.APO_DOWN)
class APODownLoss(DPOLossRegistry):
def __init__(self, *args, beta: float = 0.1, **kwargs) -> None:
self.beta = beta
super().__init__(*args, **kwargs)

def compute_loss(
self,
policy_chosen_logps: torch.FloatTensor,
policy_rejected_logps: torch.FloatTensor,
reference_chosen_logps: torch.FloatTensor,
reference_rejected_logps: torch.FloatTensor,
precomputed_margins: torch.FloatTensor | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
chosen_logratios = policy_chosen_logps - reference_chosen_logps
rejected_logratios = policy_rejected_logps - reference_rejected_logps

losses_chosen = F.sigmoid(self.beta * chosen_logratios)
losses_rejected = 1 - F.sigmoid(self.beta * (chosen_logratios - rejected_logratios))

loss = losses_chosen + losses_rejected

chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()

return (
loss,
chosen_rewards,
rejected_rewards,
)
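Writing Δ_c = policy_chosen_logps − reference_chosen_logps and Δ_r for the analogous rejected log-ratio, the class above minimizes

σ(β·Δ_c) + [1 − σ(β·(Δ_c − Δ_r))]

per pair: the first term pushes the chosen log-ratio down toward the reference, while the second still requires chosen to beat rejected by a margin. This is consistent with the APO-down objective, which targets data whose chosen answers are weaker than the current policy.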


@DPOLossRegistry.register(DPOLossesType.APO_ZERO)
class APOZeroLoss(DPOLossRegistry):
def __init__(self, *args, beta: float = 0.1, **kwargs) -> None:
self.beta = beta
super().__init__(*args, **kwargs)

def compute_loss(
self,
policy_chosen_logps: torch.FloatTensor,
policy_rejected_logps: torch.FloatTensor,
reference_chosen_logps: torch.FloatTensor,
reference_rejected_logps: torch.FloatTensor,
precomputed_margins: torch.FloatTensor | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
chosen_logratios = policy_chosen_logps - reference_chosen_logps
rejected_logratios = policy_rejected_logps - reference_rejected_logps

losses_chosen = 1 - F.sigmoid(self.beta * chosen_logratios)
losses_rejected = F.sigmoid(self.beta * rejected_logratios)

loss = losses_chosen + losses_rejected

chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()

return (
loss,
chosen_rewards,
rejected_rewards,
)
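By the same notation, APOZeroLoss minimizes

[1 − σ(β·Δ_c)] + σ(β·Δ_r)

pushing the chosen log-ratio up and the rejected one down, each anchored to the reference model independently. A self-contained sanity check of both formulas on dummy log-probabilities (a sketch, not part of the commit; the loss bodies are inlined so it runs without the repo installed):

import torch
import torch.nn.functional as F

beta = 0.1
# Dummy per-sequence log-probs for two preference pairs.
policy_chosen = torch.tensor([-10.0, -12.0])
policy_rejected = torch.tensor([-14.0, -11.0])
ref_chosen = torch.tensor([-11.0, -11.5])
ref_rejected = torch.tensor([-13.0, -11.5])

delta_c = policy_chosen - ref_chosen      # chosen log-ratios
delta_r = policy_rejected - ref_rejected  # rejected log-ratios

apo_zero = (1 - F.sigmoid(beta * delta_c)) + F.sigmoid(beta * delta_r)
apo_down = F.sigmoid(beta * delta_c) + (1 - F.sigmoid(beta * (delta_c - delta_r)))
print(apo_zero.mean().item(), apo_down.mean().item())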


@dataclass
class DPOTrainingArguments(TrainingArguments):
loss_settings: (
@@ -344,6 +410,8 @@ class DPOTrainingArguments(TrainingArguments):
| SimPOLossSettings
| SlicHfLossSettings
| SigmoidLossWithMarginSettings
| APOZeroLossSettings
| APODownLossSettings
) = field(
default_factory=lambda: SigmoidLossSettings(loss_type=DPOLossesType.SIGMOID)
) # type: ignore[call-overload]
@@ -544,16 +612,24 @@ def get_batch_metrics(

metrics = self._compute_metrics(metrics, dpo_prefix_name, chosen_rewards, rejected_rewards)

logp_accuracies = (policy_chosen_logps > policy_rejected_logps).float()
metrics[f'{prefix}logps/accuracies'] = (logp_accuracies).detach().cpu().mean().item()
metrics[f'{prefix}logps/rejected'] = (policy_rejected_logps).detach().cpu().mean().item()
metrics[f'{prefix}logps/chosen'] = (policy_chosen_logps).detach().cpu().mean().item()

metrics[f'{prefix}logits/rejected'] = (policy_rejected_logits).detach().cpu().mean().item()
metrics[f'{prefix}logits/chosen'] = (policy_chosen_logits).detach().cpu().mean().item()

if self.args.use_ref_model:
ref_logp_accuracies = (reference_chosen_logps > reference_rejected_logps).float()
metrics[f'{prefix}logps/ref_accuracies'] = (ref_logp_accuracies).detach().cpu().mean().item()
metrics[f'{prefix}logps/ref_rejected'] = (reference_rejected_logps).detach().cpu().mean().item()
metrics[f'{prefix}logps/ref_chosen'] = (reference_chosen_logps).detach().cpu().mean().item()

metrics = self._compute_flips(
metrics, prefix, logp_accuracies.detach().cpu(), ref_logp_accuracies.detach().cpu()
)

if self.loss_type == DPOLossesType.KTO:
kto_chosen_KL = (
(policy_chosen_logps.detach().cpu() - reference_chosen_logps.detach().cpu()).mean().clamp(min=0)
@@ -622,6 +698,37 @@ def _compute_metrics(

return metrics

def _compute_flips(
self,
metrics: dict[str, Any],
prefix_name: str,
logp_accuracies: torch.Tensor,
ref_logp_accuracies: torch.Tensor,
):
correct_correct = (ref_logp_accuracies == 1) & (logp_accuracies == 1)
correct_incorrect = (ref_logp_accuracies == 1) & (logp_accuracies == 0)
incorrect_correct = (ref_logp_accuracies == 0) & (logp_accuracies == 1)
incorrect_incorrect = (ref_logp_accuracies == 0) & (logp_accuracies == 0)

correct_correct_count = correct_correct.sum().item()
correct_incorrect_count = correct_incorrect.sum().item()
incorrect_correct_count = incorrect_correct.sum().item()
incorrect_incorrect_count = incorrect_incorrect.sum().item()

total_count = len(logp_accuracies)

correct_correct_ratio = correct_correct_count / total_count
correct_incorrect_ratio = correct_incorrect_count / total_count
incorrect_correct_ratio = incorrect_correct_count / total_count
incorrect_incorrect_ratio = incorrect_incorrect_count / total_count

metrics[f'{prefix_name}flips/correct->correct'] = correct_correct_ratio
metrics[f'{prefix_name}flips/correct->incorrect'] = correct_incorrect_ratio
metrics[f'{prefix_name}flips/incorrect->correct'] = incorrect_correct_ratio
metrics[f'{prefix_name}flips/incorrect->incorrect'] = incorrect_incorrect_ratio

return metrics
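_compute_flips cross-tabulates, pair by pair, whether the reference model and the policy each rank the chosen completion above the rejected one; the four ratios sum to 1, and the correct->incorrect bucket is the regression signal to watch. A toy run (a sketch, not part of the commit):

import torch

# Reference ranks pairs 0 and 1 correctly; the policy ranks pairs 1 and 2 correctly.
ref_acc = torch.tensor([1.0, 1.0, 0.0, 0.0])
pol_acc = torch.tensor([0.0, 1.0, 1.0, 0.0])

total = len(pol_acc)
cc = ((ref_acc == 1) & (pol_acc == 1)).sum().item() / total  # 0.25: stayed correct
ci = ((ref_acc == 1) & (pol_acc == 0)).sum().item() / total  # 0.25: regressed
ic = ((ref_acc == 0) & (pol_acc == 1)).sum().item() / total  # 0.25: improved
ii = ((ref_acc == 0) & (pol_acc == 0)).sum().item() / total  # 0.25: stayed wrong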

def compute_loss(
self,
model: PreTrainedModel | nn.Module,