
Fix DPO with Reference Model #387

Closed
wants to merge 4 commits

Conversation

@austin362667 (Collaborator) commented Nov 15, 2024

Summary

Thanks to @ByronHsu, who identified that the implementation in #378 lacked a reference model for DPO, effectively making it CPO (Contrastive Preference Optimization) instead. To address this, I have:

  1. Added a reference model
  2. Implemented ref_chosen_logps and ref_rejected_logps
  3. Incorporated a partial function in the forward pass (see the sketch below)

These changes ensure that DPO tests and benchmarks now function correctly.
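For item 3, here is a minimal sketch of how the reference log-probs could be pre-bound into the loss function with functools.partial; the names (dpo_loss, beta) and shapes are illustrative, not the exact Liger-Kernel API:

from functools import partial

import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.1):
    # Policy advantage over the reference model for each completion
    chosen_advantages = policy_chosen_logps - ref_chosen_logps
    rejected_advantages = policy_rejected_logps - ref_rejected_logps
    # DPO objective: -log sigmoid(beta * margin), averaged over the batch
    return -F.logsigmoid((chosen_advantages - rejected_advantages) * beta).mean()

# Dummy per-sequence log-probs for illustration (batch of 4 preference pairs)
policy_chosen_logps = torch.randn(4)
policy_rejected_logps = torch.randn(4)
ref_chosen_logps = torch.randn(4)
ref_rejected_logps = torch.randn(4)

# Pre-bind the fixed reference log-probs so the forward pass
# only needs to supply the policy log-probs.
loss_fn = partial(dpo_loss,
                  ref_chosen_logps=ref_chosen_logps,
                  ref_rejected_logps=ref_rejected_logps)
loss = loss_fn(policy_chosen_logps, policy_rejected_logps)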

DPO Loss Formulation

As mentioned in the previous PR #378, without a reference model the implicit reward difference reduces to:

$$r_\theta(x,y_c) - r_\theta(x,y_r) = \log\pi_\theta(y_c|x) - \log\pi_\theta(y_r|x)$$

With a reference model, the DPO loss becomes:

$$-\log\sigma\Big(\beta\Big[\big(\log\pi_\theta(y_c|x) - \log\pi_{\theta_{\text{ref}}}(y_c|x)\big) - \big(\log\pi_\theta(y_r|x) - \log\pi_{\theta_{\text{ref}}}(y_r|x)\big)\Big]\Big)$$

This corresponds to the code:

import torch.nn.functional as F

# log_probs(...) denotes the summed log-probabilities of the target tokens
# under the given logits (one value per sequence).

# Policy model log probabilities
policy_chosen_logps = log_probs(policy_chosen_logits)
policy_rejected_logps = log_probs(policy_rejected_logits)

# Reference model log probabilities
ref_chosen_logps = log_probs(ref_chosen_logits)
ref_rejected_logps = log_probs(ref_rejected_logits)

# Compute advantages of the policy over the reference model
chosen_advantages = policy_chosen_logps - ref_chosen_logps
rejected_advantages = policy_rejected_logps - ref_rejected_logps

# policy_chosen_logps - ref_chosen_logps - policy_rejected_logps + ref_rejected_logps
logits_diff = (chosen_advantages - rejected_advantages) * beta

# DPO loss
losses = -F.logsigmoid(logits_diff)
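The log_probs helper above is shorthand; a rough sketch of what it might do, under the assumption that it gathers and sums the target-token log-probabilities (prompt/padding masking omitted for brevity):

import torch

def log_probs(logits: torch.Tensor, target_ids: torch.Tensor) -> torch.Tensor:
    # logits: (batch, seq_len, vocab); target_ids: (batch, seq_len)
    logp = torch.log_softmax(logits, dim=-1)
    # Log-probability assigned to each target token ...
    token_logps = logp.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
    # ... summed over the sequence: one log-prob per completion
    return token_logps.sum(dim=-1)

# Example: batch of 2 completions, 5 tokens each, vocab size 10
logits = torch.randn(2, 5, 10)
targets = torch.randint(0, 10, (2, 5))
print(log_probs(logits, targets))  # shape (2,)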

Testing Done

Updated benchmarks:

[Benchmark plots: dpo_loss_speed]

  • Hardware Type: NVIDIA A100 (40GB)
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

Signed-off-by: Austin Liu <[email protected]>
@austin362667 (Collaborator, Author) commented:

This is wrong!! The correct implementation is in #405.

What I did wrong:

# This is incorrect:
ref_chosen_logps = torch.randn(B // 2, device="cuda", dtype=dtype)
ref_rejected_logps = torch.randn(B // 2, device="cuda", dtype=dtype)

Why I'm wrong:
I should not create random tensors for ref_chosen_logps and ref_rejected_logps. Here's why:

In DPO, the reference log probabilities MUST come from evaluating the reference model on the same inputs.
My random tensors break the crucial relationships between:

  • The input sequences
  • The policy model's predictions
  • The reference model's predictions

How to fix:
I need to add a reference-model flag to toggle reference model usage, and compute proper reference log probabilities when it is enabled, as sketched below.
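A minimal sketch of that fix, assuming an illustrative use_ref_model flag and a frozen reference model evaluated under torch.no_grad() on the same inputs (the toy model and the omitted next-token shift are for illustration only; the actual interface is the one adopted in #405):

import torch
import torch.nn as nn

@torch.no_grad()  # the reference model is frozen; no gradients are needed
def reference_logps(ref_model, input_ids, target_ids):
    # Evaluate the reference model on the SAME inputs the policy model sees.
    logits = ref_model(input_ids)                          # (batch, seq_len, vocab)
    logp = torch.log_softmax(logits, dim=-1)
    token_logps = logp.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
    return token_logps.sum(dim=-1)                         # one log-prob per sequence

# Toy stand-in for a frozen reference LM: embedding -> vocab projection
vocab, hidden = 32, 8
ref_model = nn.Sequential(nn.Embedding(vocab, hidden), nn.Linear(hidden, vocab)).eval()

use_ref_model = True  # illustrative flag; when False the loss reduces to the reference-free (CPO-like) case
chosen_ids = torch.randint(0, vocab, (2, 6))
rejected_ids = torch.randint(0, vocab, (2, 6))
if use_ref_model:
    ref_chosen_logps = reference_logps(ref_model, chosen_ids, chosen_ids)
    ref_rejected_logps = reference_logps(ref_model, rejected_ids, rejected_ids)
else:
    ref_chosen_logps = ref_rejected_logps = torch.zeros(2)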

Thanks to @shivam15s, the correct implementation is already in PR #405. I'll close this PR to align with that approach.
