diff --git a/turbo_alignment/trainers/dpo.py b/turbo_alignment/trainers/dpo.py index c58b22d..53f4825 100755 --- a/turbo_alignment/trainers/dpo.py +++ b/turbo_alignment/trainers/dpo.py @@ -291,11 +291,7 @@ def compute_loss( ) ratio = -F.logsigmoid(log_odds) - losses = self.beta * ratio - - nll_loss = -policy_chosen_logps - - losses += nll_loss + losses = -policy_chosen_logps + self.beta * ratio chosen_rewards = self.beta * policy_chosen_logps.detach() rejected_rewards = self.beta * policy_rejected_logps.detach()