Skip to content

Commit

Permalink
Pretty fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexey Gorbatovski authored and Alexey Gorbatovski committed Sep 20, 2024
1 parent fd779ff commit 2ce346f
Showing 1 changed file with 3 additions and 7 deletions.
10 changes: 3 additions & 7 deletions turbo_alignment/trainers/dpo.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -648,20 +648,16 @@ def get_batch_metrics(
 metrics[f'{prefix}rewards/kto_grad_term_rejected'] = kto_grad_term_rejected.item()

 elif self.loss_type == DPOLossesType.ORPO:
-labels_w = batch['inputs_w']['labels'][:, 1:].clone()
-loss_mask_w = labels_w != DISABLE_LOSS_LABEL
-length_norm_policy_chosen_logps = policy_chosen_logps / loss_mask_w.sum(-1)
-
 log_odds = (policy_chosen_logps - policy_rejected_logps) - (
 torch.log1p(-torch.clamp(torch.exp(policy_chosen_logps), max=1 - 1e-7))
 - torch.log1p(-torch.clamp(torch.exp(policy_rejected_logps), max=1 - 1e-7))
 )
 ratio = -F.logsigmoid(log_odds)
 or_loss = self.dpo_loss_registry.beta * ratio
-nll_loss = -length_norm_policy_chosen_logps
+nll_loss = -policy_chosen_logps

-metrics[f'{prefix}orpo/nll_loss'] = nll_loss.clone().detach().cpu().mean().item()
-metrics[f'{prefix}orpo/or_loss'] = or_loss.clone().detach().cpu().mean().item()
+metrics[f'{prefix}orpo/nll_loss'] = nll_loss.detach().cpu().mean().item()
+metrics[f'{prefix}orpo/or_loss'] = or_loss.detach().cpu().mean().item()
 metrics[f'{prefix}orpo/ratio'] = (ratio).detach().cpu().mean().item()
 metrics[f'{prefix}orpo/log_odds'] = (log_odds).detach().cpu().mean().item()

Expand Down

0 comments on commit 2ce346f

Please sign in to comment.