Support Chunked DPO Loss Kernel (#378)
## Summary

Add support for a fused, torch-compiled, and chunked DPO ([Direct
Preference Optimization](https://arxiv.org/html/2305.18290v3)) loss
kernel, as requested in
#371.
This implementation is largely based on the excellent work done on ORPO
(#362) by @shivam15s.

### DPO Loss Formulation

In the reference-based setting (i.e., not reference-free), the implicit reward margin is

$$r_\theta(x, y_c) - r_\theta(x, y_r) = \beta\left[\log\frac{\pi_\theta(y_c|x)}{\pi_{\theta_{\text{ref}}}(y_c|x)} - \log\frac{\pi_\theta(y_r|x)}{\pi_{\theta_{\text{ref}}}(y_r|x)}\right]$$

and the resulting loss is

$$\mathcal{L}_{\text{DPO}} = -\log\sigma\Bigl(\beta\bigl[\log\pi_\theta(y_c|x) - \log\pi_\theta(y_r|x) - \log\pi_{\theta_{\text{ref}}}(y_c|x) + \log\pi_{\theta_{\text{ref}}}(y_r|x)\bigr]\Bigr)$$

This corresponds to:
```python
import torch.nn.functional as F

# Policy model log probabilities (average log-probs of the target tokens per sequence)
policy_chosen_logps = log_probs(policy_chosen_logits)
policy_rejected_logps = log_probs(policy_rejected_logits)

# Reference model log probabilities
ref_chosen_logps = log_probs(ref_chosen_logits)
ref_rejected_logps = log_probs(ref_rejected_logits)

# Per-response advantages over the reference model
chosen_advantages = policy_chosen_logps - ref_chosen_logps
rejected_advantages = policy_rejected_logps - ref_rejected_logps

# DPO loss: beta scales the log-ratio margin
logits_diff = beta * (chosen_advantages - rejected_advantages)
losses = -F.logsigmoid(logits_diff)
```
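
The `log_probs` helper above is left abstract (and is shown taking only logits for brevity). Below is a minimal sketch of one common way to compute such per-sequence average log-probabilities; the function name and masking details are illustrative, not the kernel code added in this PR:

```python
import torch
import torch.nn.functional as F


def log_probs(logits, target, ignore_index=-100):
    # logits: (B, T, V), target: (B, T) -> per-sequence average log-prob, shape (B,)
    logps = F.log_softmax(logits.float(), dim=-1)
    # Log-probability assigned to each target token; clamp keeps gather indices valid
    # at ignored positions, which are masked out below anyway.
    per_token = torch.gather(logps, dim=-1, index=target.clamp(min=0).unsqueeze(-1)).squeeze(-1)
    mask = target != ignore_index
    return (per_token * mask).sum(dim=-1) / mask.sum(dim=-1)
```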

In this PR:

1. The formulation above shows that DPO maximizes the reward margin
    $$r_\theta(x, y_c) - r_\theta(x, y_r)$$
2. Since the kernel receives only the policy's (average) chosen and rejected log-probabilities, the loss reduces to:
    $$-\log\sigma\bigl(\beta\,(\log\pi_\theta(y_c|x) - \log\pi_\theta(y_r|x))\bigr)$$
3. So the kernel implements:
    ```python
    logits_diff = beta * (chosen_logps - rejected_logps)  # β(log π_θ(y_c|x) − log π_θ(y_r|x))
    losses = -F.logsigmoid(logits_diff)                   # −log σ(logits_diff)
    ```
4. DPO and NLL losses are summed (a usage sketch of the fused function follows this list):
    $$\mathcal{L}_{\text{DPO+NLL}} = \mathcal{L}_{\text{DPO}} + \alpha\,\mathcal{L}_{\text{NLL}}$$
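
For reference, a minimal usage sketch of the fused function; the shapes and argument values are taken from the new benchmark script, and the last positional argument corresponds to `compute_nll_loss`:

```python
import torch

from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction

B, T, H, V = 4, 512, 1024, 128256  # illustrative sizes, matching the benchmark config
dtype, device = torch.bfloat16, "cuda"

_input = torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=True)  # last hidden states
weight = torch.randn(V, H, device=device, dtype=dtype, requires_grad=True)     # final projection weight
target = torch.randint(V, (B, T), device=device, dtype=torch.long)

# Positional args mirror the benchmark call:
# (input, weight, target, bias, ignore_index, beta, compute_nll_loss)
loss = LigerFusedLinearDPOFunction.apply(_input, weight, target, None, -100, 0.1, True)
loss.backward()
```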

## Testing Done


![dpo_loss_memory](https://github.com/user-attachments/assets/d48965a2-bab7-4a81-9872-a43826106731)

![dpo_loss_speed](https://github.com/user-attachments/assets/10ab33c3-a905-435f-886b-67c911b8fff6)


- Hardware Type: **NVIDIA L40S (48G)**
- [X] run `make test` to ensure correctness
- [X] run `make checkstyle` to ensure code style
- [X] run `make test-convergence` to ensure convergence

---------

Signed-off-by: Austin Liu <[email protected]>
Co-authored-by: shivam15s <[email protected]>
austin362667 and shivam15s authored Nov 15, 2024
1 parent 2281b7e commit 1aa3d83
Showing 3 changed files with 503 additions and 0 deletions.
226 changes: 226 additions & 0 deletions benchmark/scripts/benchmark_dpo_loss.py
@@ -0,0 +1,226 @@
from test.chunked_loss.test_dpo_loss import HF_DPO_Loss

import torch
import triton
from utils import (
    QUANTILES,
    SingleBenchmarkRunInput,
    SingleBenchmarkRunOutput,
    _test_memory,
    parse_benchmark_script_args,
    run_benchmarks,
)

from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction


class TorchDPOLoss(torch.nn.Module):
    def __init__(
        self,
        H: int,
        V: int,
        dtype: torch.dtype,
        beta: float = 0.1,
        ignore_index: int = -100,
        bias: bool = False,
    ):
        super().__init__()
        self.lin = torch.nn.Linear(
            in_features=H, out_features=V, bias=bias, dtype=dtype
        )
        self.dpo_loss = HF_DPO_Loss(beta=beta, ignore_index=ignore_index)

    def forward(self, x, target):
        return self.dpo_loss.get_batch_loss_metrics(
            x,
            self.lin.weight,
            target,
            self.lin.bias if hasattr(self.lin, "bias") else None,
        )


class LigerDPOLoss(torch.nn.Module):
    def __init__(
        self,
        H: int,
        V: int,
        dtype: torch.dtype,
        beta: float = 0.1,
        ignore_index: int = -100,
        bias: bool = False,
    ):
        super().__init__()
        self.lin = torch.nn.Linear(
            in_features=H, out_features=V, bias=bias, dtype=dtype
        )
        self.beta = beta
        self.ignore_index = ignore_index

    def forward(self, x, target):
        return LigerFusedLinearDPOFunction.apply(
            x,
            self.lin.weight,
            target,
            self.lin.bias if hasattr(self.lin, "bias") else None,
            self.ignore_index,
            self.beta,
            True,
        )


def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    B = input.x
    T = input.extra_benchmark_config["T"]
    H = input.extra_benchmark_config["H"]
    V = input.extra_benchmark_config["V"]
    dtype = input.extra_benchmark_config["dtype"]
    bias = input.extra_benchmark_config["bias"]
    beta = input.extra_benchmark_config["beta"]
    ignore_index = input.extra_benchmark_config["ignore_index"]
    provider = input.kernel_provider

    device = "cuda"
    torch_dpo_loss = TorchDPOLoss(
        H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
    ).to(device)
    liger_dpo_loss = LigerDPOLoss(
        H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
    ).to(device)

    # Input shape: [B, T, H]
    _input = torch.randn(B, T, H, device=device, dtype=dtype)
    # Target shape: [B, T]
    target = torch.randint(V, (B, T), dtype=torch.long, device=device)

    # Add ignore_index tokens to simulate padding
    num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
    indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
    target.view(-1)[indices_to_assign] = ignore_index

    def fwd():
        if provider == "liger":
            return liger_dpo_loss(_input, target)
        elif provider == "huggingface":
            return torch_dpo_loss(_input, target)

    def full():
        y = fwd()
        y.backward()

    mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


def bench_speed_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    B = input.x
    T = input.extra_benchmark_config["T"]
    H = input.extra_benchmark_config["H"]
    V = input.extra_benchmark_config["V"]
    dtype = input.extra_benchmark_config["dtype"]
    bias = input.extra_benchmark_config["bias"]
    beta = input.extra_benchmark_config["beta"]
    ignore_index = input.extra_benchmark_config["ignore_index"]
    provider = input.kernel_provider
    mode = input.kernel_operation_mode

    device = "cuda"
    torch_dpo_loss = TorchDPOLoss(
        H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
    ).to(device)
    liger_dpo_loss = LigerDPOLoss(
        H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
    ).to(device)

    # Input shape: [B, T, H]
    _input = torch.randn(B, T, H, device=device, dtype=dtype)

    # Target shape: [B, T]
    target = torch.randint(V, (B, T), device=device, dtype=torch.long)

    # Add ignore_index tokens
    num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
    indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
    target.view(-1)[indices_to_assign] = ignore_index

    def fwd():
        if provider == "liger":
            return liger_dpo_loss(_input, target)
        elif provider == "huggingface":
            return torch_dpo_loss(_input, target)

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            fwd,
            rep=100,
            quantiles=QUANTILES,
        )
    elif mode == "backward":
        y = fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(retain_graph=True),
            grad_to_none=[_input],
            rep=100,
            quantiles=QUANTILES,
        )
    elif mode == "full":

        def full():
            y = fwd()
            y.backward()

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            full,
            rep=100,
            quantiles=QUANTILES,
        )

    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    common_configs = {
        "kernel_name": "dpo_loss",
        "x_name": "B",
        "x_label": "Batch Size (B)",
        "x_values": [2**i for i in range(1, 6)],
        "kernel_providers": ["liger", "huggingface"],
        "extra_benchmark_configs": [
            {
                "T": 512,
                "H": 1024,
                "V": 128256,
                "mode": "forward",
                "dtype": torch.bfloat16,
                "bias": True,
                "beta": 0.1,
                "ignore_index": 42,
            }
        ],
        "overwrite": args.overwrite,
    }

    run_benchmarks(
        bench_test_fn=bench_speed_dpo_loss,
        kernel_operation_modes=["forward", "full"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs
    )

    run_benchmarks(
        bench_test_fn=bench_memory_dpo_loss,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs
    )
57 changes: 57 additions & 0 deletions src/liger_kernel/chunked_loss/dpo_loss.py
@@ -0,0 +1,57 @@
import torch.nn.functional as F

from liger_kernel.chunked_loss.fused_linear_preference import (
    LigerFusedLinearPreferenceBase,
)


class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):

    @staticmethod
    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
        """
        Compute DPO loss (Direct Preference Optimization).
        Args:
            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
            beta (float): Weight for the direct preference loss.
        """
        logits_diff = beta * (chosen_logps - rejected_logps)
        losses = -F.logsigmoid(logits_diff)
        return losses.sum()

    @staticmethod
    def forward(
        ctx,
        _input,
        weight,
        target,
        bias=None,
        ignore_index=-100,
        beta=0.1,
        compute_nll_loss=True,
        compiled=True,
    ):
        """
        Fused linear layer with DPO (Direct Preference Optimization) loss.
        Handles both the forward and backward pass of the final linear layer with DPO loss.
        """
        return LigerFusedLinearPreferenceBase.forward(
            ctx=ctx,
            _input=_input,
            weight=weight,
            target=target,
            bias=bias,
            loss_fn=LigerFusedLinearDPOFunction.preference_loss_fn,
            compute_nll_loss=compute_nll_loss,
            ignore_index=ignore_index,
            beta=beta,
            compiled=compiled,
        )

    @staticmethod
    def backward(ctx, grad_output):
        # Get gradients for _input, weight, bias, and target from the base class
        grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
        # Return these gradients, followed by None for the remaining inputs
        return *grads, None, None, None, None
