diff --git a/turbo_alignment/trainers/dpo.py b/turbo_alignment/trainers/dpo.py
index 53f4825..d08849e 100755
--- a/turbo_alignment/trainers/dpo.py
+++ b/turbo_alignment/trainers/dpo.py
@@ -648,20 +648,16 @@ def get_batch_metrics(
             metrics[f'{prefix}rewards/kto_grad_term_rejected'] = kto_grad_term_rejected.item()

         elif self.loss_type == DPOLossesType.ORPO:
-            labels_w = batch['inputs_w']['labels'][:, 1:].clone()
-            loss_mask_w = labels_w != DISABLE_LOSS_LABEL
-            length_norm_policy_chosen_logps = policy_chosen_logps / loss_mask_w.sum(-1)
-
             log_odds = (policy_chosen_logps - policy_rejected_logps) - (
                 torch.log1p(-torch.clamp(torch.exp(policy_chosen_logps), max=1 - 1e-7))
                 - torch.log1p(-torch.clamp(torch.exp(policy_rejected_logps), max=1 - 1e-7))
             )
             ratio = -F.logsigmoid(log_odds)
             or_loss = self.dpo_loss_registry.beta * ratio
-            nll_loss = -length_norm_policy_chosen_logps
+            nll_loss = -policy_chosen_logps

-            metrics[f'{prefix}orpo/nll_loss'] = nll_loss.clone().detach().cpu().mean().item()
-            metrics[f'{prefix}orpo/or_loss'] = or_loss.clone().detach().cpu().mean().item()
+            metrics[f'{prefix}orpo/nll_loss'] = nll_loss.detach().cpu().mean().item()
+            metrics[f'{prefix}orpo/or_loss'] = or_loss.detach().cpu().mean().item()
             metrics[f'{prefix}orpo/ratio'] = (ratio).detach().cpu().mean().item()
             metrics[f'{prefix}orpo/log_odds'] = (log_odds).detach().cpu().mean().item()
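
Note for reviewers, as a sketch of the math the unchanged log_odds lines implement, assuming the standard ORPO (odds ratio preference optimization) formulation: with p_w = exp(policy_chosen_logps) and p_l = exp(policy_rejected_logps), and the clamp to 1 - 1e-7 serving only to keep log(1 - p) finite when p approaches 1,

\[
\operatorname{odds}(p) = \frac{p}{1 - p}, \qquad
\log \frac{\operatorname{odds}(p_w)}{\operatorname{odds}(p_l)}
  = \bigl(\log p_w - \log p_l\bigr) - \bigl(\log(1 - p_w) - \log(1 - p_l)\bigr)
\]
\[
\mathtt{ratio} = -\log \sigma(\mathtt{log\_odds}), \qquad
\mathtt{or\_loss} = \beta \cdot \mathtt{ratio}
\]

With this patch, nll_loss = -policy_chosen_logps is the summed sequence NLL; the removed lines had divided by the number of unmasked label tokens (loss_mask_w.sum(-1)) to produce a length-normalized, per-token NLL instead.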