Skip to content

Commit

Permalink
ORPO fix
Browse files · Browse the repository at this point in the history
  • Loading branch information
Alexey Gorbatovski authored and Alexey Gorbatovski committed Sep 19, 2024
1 parent 71ecf0f commit abc448e
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions turbo_alignment/trainers/dpo.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -292,6 +292,10 @@ def compute_loss(

ratio = -F.logsigmoid(log_odds)
losses = self.beta * ratio

nll_loss = -policy_chosen_logps

losses += nll_loss

chosen_rewards = self.beta * policy_chosen_logps.detach()
rejected_rewards = self.beta * policy_rejected_logps.detach()
Expand Down Expand Up @@ -657,9 +661,11 @@ def get_batch_metrics(
- torch.log1p(-torch.clamp(torch.exp(policy_rejected_logps), max=1 - 1e-7))
)
ratio = -F.logsigmoid(log_odds)
or_loss = self.dpo_loss_registry.beta * ratio
nll_loss = -length_norm_policy_chosen_logps

metrics[f'{prefix}orpo/nll_loss'] = nll_loss.clone().detach().cpu().mean().item()
metrics[f'{prefix}orpo/or_loss'] = or_loss.clone().detach().cpu().mean().item()
metrics[f'{prefix}orpo/ratio'] = (ratio).detach().cpu().mean().item()
metrics[f'{prefix}orpo/log_odds'] = (log_odds).detach().cpu().mean().item()

Expand Down

0 comments on commit abc448e

Please sign in to comment.