Fix DPO with Reference Model #387

Closed · wants to merge 4 commits

43 changes: 24 additions & 19 deletions benchmark/scripts/benchmark_dpo_loss.py
@@ -1,4 +1,5 @@
-from test.chunked_loss.test_dpo_loss import HF_DPO_Loss
+import os
+import sys

import torch
import triton
@@ -13,6 +14,8 @@

from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))


class TorchDPOLoss(torch.nn.Module):
def __init__(
@@ -24,18 +27,22 @@ def __init__(
ignore_index: int = -100,
bias: bool = False,
):
from test.chunked_loss.test_dpo_loss import HF_DPO_Loss

super().__init__()
self.lin = torch.nn.Linear(
in_features=H, out_features=V, bias=bias, dtype=dtype
)
self.dpo_loss = HF_DPO_Loss(beta=beta, ignore_index=ignore_index)

-def forward(self, x, target):
+def forward(self, x, target, ref_chosen_logps, ref_rejected_logps):
return self.dpo_loss.get_batch_loss_metrics(
x,
self.lin.weight,
target,
self.lin.bias if hasattr(self.lin, "bias") else None,
ref_chosen_logps,
ref_rejected_logps,
)


@@ -51,20 +58,13 @@ def __init__(
):
super().__init__()
self.lin = torch.nn.Linear(
-in_features=H, out_features=V, bias=bias, dtype=dtype
+in_features=H, out_features=V, bias=False, dtype=dtype
)
self.beta = beta
self.ignore_index = ignore_index
self.dpo_loss = LigerFusedLinearDPOFunction.apply

-def forward(self, x, target):
-return LigerFusedLinearDPOFunction.apply(
-x,
-self.lin.weight,
-target,
-self.lin.bias if hasattr(self.lin, "bias") else None,
-self.ignore_index,
-self.beta,
-True,
+def forward(self, x, y, ref_chosen_logps, ref_rejected_logps):
+return self.dpo_loss(
+x, self.lin.weight, y, ref_chosen_logps, ref_rejected_logps
+)


@@ -92,16 +92,19 @@ def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
# Target shape: [B, T]
target = torch.randint(V, (B, T), dtype=torch.long, device=device)

ref_chosen_logps = torch.randn(B // 2, device=device)
ref_rejected_logps = torch.randn(B // 2, device=device)

# Add ignore_index tokens to simulate padding
num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
target.view(-1)[indices_to_assign] = ignore_index

def fwd():
if provider == "liger":
-return liger_dpo_loss(_input, target)
+return liger_dpo_loss(_input, target, ref_chosen_logps, ref_rejected_logps)
elif provider == "huggingface":
-return torch_dpo_loss(_input, target)
+return torch_dpo_loss(_input, target, ref_chosen_logps, ref_rejected_logps)

def full():
y = fwd()
@@ -137,20 +140,22 @@ def bench_speed_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:

# Input shape: [B, T, H]
_input = torch.randn(B, T, H, device=device, dtype=dtype)

# Target shape: [B, T]
target = torch.randint(V, (B, T), device=device, dtype=torch.long)

ref_chosen_logps = torch.randn(B // 2, device=device)
ref_rejected_logps = torch.randn(B // 2, device=device)

# Add ignore_index tokens
num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
target.view(-1)[indices_to_assign] = ignore_index

def fwd():
if provider == "liger":
-return liger_dpo_loss(_input, target)
+return liger_dpo_loss(_input, target, ref_chosen_logps, ref_rejected_logps)
elif provider == "huggingface":
-return torch_dpo_loss(_input, target)
+return torch_dpo_loss(_input, target, ref_chosen_logps, ref_rejected_logps)

if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(
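Note: in the benchmark above, ref_chosen_logps and ref_rejected_logps are drawn from torch.randn purely for timing. In real training they would come from a frozen reference model. The sketch below shows one way to produce them; the names reference_logps and ref_model, and the HF-style .logits attribute, are illustrative assumptions, not code from this PR.

import torch

@torch.no_grad()
def reference_logps(ref_model, input_ids, target, ignore_index=-100):
    # Average per-token log-probs of `target` under a frozen reference model,
    # split into (chosen, rejected) halves of an even-sized batch.
    logits = ref_model(input_ids).logits  # (B, T, V); HF-style output assumed
    logps = torch.log_softmax(logits.float(), dim=-1)
    token_logps = torch.gather(
        logps, dim=2, index=target.clamp(min=0).unsqueeze(-1)
    ).squeeze(-1)  # (B, T)
    mask = target != ignore_index
    avg_logps = (token_logps * mask).sum(dim=-1) / mask.sum(dim=-1).clamp(min=1)
    half = avg_logps.shape[0] // 2
    return avg_logps[:half], avg_logps[half:]
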
68 changes: 55 additions & 13 deletions src/liger_kernel/chunked_loss/dpo_loss.py
@@ -7,26 +7,42 @@


class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):

@staticmethod
-def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
+def preference_loss_fn(
+chosen_logps,
+rejected_logps,
+beta=0.1,
+ref_chosen_logps=None,
+ref_rejected_logps=None,
+):
"""
-Compute DPO loss (Direct Preference Optimization).
+Compute DPO (Direct Preference Optimization) loss using policy and reference model probabilities.
Args:
-chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
-rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
-beta (float): Weight for the direct preference loss.
+chosen_logps (torch.Tensor): Policy model avg log probs of chosen tokens. Shape: (batch_size,).
+rejected_logps (torch.Tensor): Policy model avg log probs of rejected tokens. Shape: (batch_size,).
+beta (float): Temperature parameter for the DPO loss.
+ref_chosen_logps (torch.Tensor): Reference model avg log probs of chosen tokens. Shape: (batch_size,).
+ref_rejected_logps (torch.Tensor): Reference model avg log probs of rejected tokens. Shape: (batch_size,).
"""
-logits_diff = beta * (chosen_logps - rejected_logps)
-losses = -F.logsigmoid(logits_diff)
-return losses.sum()
+if ref_chosen_logps is None or ref_rejected_logps is None:
+raise ValueError("Reference model logits are required for DPO loss")

+chosen_advantages = chosen_logps - ref_chosen_logps
+rejected_advantages = rejected_logps - ref_rejected_logps

+logits = beta * (chosen_advantages - rejected_advantages)

+loss = -F.logsigmoid(logits).mean()
+return loss

@staticmethod
def forward(
ctx,
_input,
weight,
target,
ref_chosen_logps,
ref_rejected_logps,
bias=None,
ignore_index=-100,
beta=0.1,
@@ -36,14 +52,38 @@ def forward(
"""
Fused linear layer with DPO (Direct Preference Optimization) loss.
Handles both the forward and backward pass of the final linear layer with DPO loss.

Args:
_input (torch.Tensor): Input tensor. Shape: (batch_size, seq_len, hidden_size).
weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
target (torch.Tensor): Target tensor. Shape: (batch_size, seq_len).
ref_chosen_logps (torch.Tensor): Reference model log probs for chosen responses. Shape: (batch_size,).
ref_rejected_logps (torch.Tensor): Reference model log probs for rejected responses. Shape: (batch_size,).
bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
ignore_index (int): Index to ignore in loss computation.
beta (float): Temperature parameter for DPO loss.
compute_nll_loss (bool): Whether to compute and add NLL loss.
compiled (bool): Whether to use torch.compile for chunk accumulation.
"""
# Create partial function with reference model logits
partial_loss_fn = (
lambda c, r, beta=beta: LigerFusedLinearDPOFunction.preference_loss_fn(
c,
r,
beta=beta,
ref_chosen_logps=ref_chosen_logps,
ref_rejected_logps=ref_rejected_logps,
)
)

return LigerFusedLinearPreferenceBase.forward(
ctx=ctx,
_input=_input,
weight=weight,
target=target,
bias=bias,
-loss_fn=LigerFusedLinearDPOFunction.preference_loss_fn,
+loss_fn=partial_loss_fn,
compute_nll_loss=compute_nll_loss,
ignore_index=ignore_index,
beta=beta,
compute_nll_loss=compute_nll_loss,
@@ -52,10 +92,12 @@

@staticmethod
def backward(ctx, grad_output):
-# Get gradients for _input, weight, bias, and target from the base class
-grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
+# Get gradients from base class
+d_input, d_weight, d_target, d_bias = LigerFusedLinearPreferenceBase.backward(
+ctx, grad_output
+)[:4]
# Return these gradients, followed by None for the remaining inputs
-return *grads, None, None, None, None
+return d_input, d_weight, d_target, None, None, d_bias, None, None, None


class LigerFusedLinearDPOLoss(torch.nn.Module):
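For reference, the updated preference_loss_fn implements the standard DPO objective (Rafailov et al., 2023, https://huggingface.co/papers/2305.18290), with chosen_logps/rejected_logps taken as per-token average log-probabilities and the loss averaged over preference pairs:

$$
\mathcal{L}_{\mathrm{DPO}} = -\log \sigma\Big(\beta\big[(\log \pi_\theta(y_w \mid x) - \log \pi_{\mathrm{ref}}(y_w \mid x)) - (\log \pi_\theta(y_l \mid x) - \log \pi_{\mathrm{ref}}(y_l \mid x))\big]\Big)
$$

where $y_w$ and $y_l$ are the chosen and rejected responses, $\pi_\theta$ is the policy being trained, and $\pi_{\mathrm{ref}}$ is the frozen reference model; the two bracketed differences correspond to chosen_advantages and rejected_advantages in the code.
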
19 changes: 13 additions & 6 deletions test/chunked_loss/test_dpo_loss.py
@@ -29,14 +29,19 @@ def alignment_loss(
):
"""Compute DPO loss for a batch of policy log probabilities.
Args:
-policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
-policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
+policy_chosen_logps: Log probs of policy model for chosen responses. Shape: (batch_size,)
+policy_rejected_logps: Log probs of policy model for rejected responses. Shape: (batch_size,)
+reference_chosen_logps: Log probs of reference model for chosen responses. Shape: (batch_size,)
+reference_rejected_logps: Log probs of reference model for rejected responses. Shape: (batch_size,)

Returns:
-The losses tensor contains the DPO loss for each example in the batch.
+DPO loss for each example in the batch
"""
# Derived from https://huggingface.co/papers/2305.18290
-logits_diff = self.beta * (policy_chosen_logps - policy_rejected_logps)

+chosen_advantages = policy_chosen_logps - reference_chosen_logps
+rejected_advantages = policy_rejected_logps - reference_rejected_logps
+logits_diff = self.beta * (chosen_advantages - rejected_advantages)

losses = -F.logsigmoid(logits_diff)
return losses

@@ -99,7 +104,9 @@ def forward(self, x, y):
)
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("ignore_index, beta", [(-100, 0.1), (42, 0.2)])
-def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta):
+def test_dpo_correctness(
+B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta
+):
B = 2 * B # dpo loss requires B to be even

torch_lm_head_dpo = TorchLMHeadDPO(
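To make the new calling convention concrete, here is a minimal sketch of the updated LigerFusedLinearDPOFunction.apply argument order as exercised by the benchmark's Liger path. The shapes, dtypes, and random reference log-probs are placeholders, and it assumes the call returns a scalar loss as in the benchmark's fwd()/full() path; it is illustrative, not code from this PR.

import torch

from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction

B, T, H, V = 4, 64, 128, 512  # B must be even: first half chosen, second half rejected
_input = torch.randn(B, T, H, requires_grad=True)
weight = torch.randn(V, H, requires_grad=True)
target = torch.randint(V, (B, T), dtype=torch.long)

# Placeholders for frozen reference-model log-probs (one per preference pair).
ref_chosen_logps = torch.randn(B // 2)
ref_rejected_logps = torch.randn(B // 2)

loss = LigerFusedLinearDPOFunction.apply(
    _input, weight, target, ref_chosen_logps, ref_rejected_logps
)
loss.backward()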