-
Notifications
You must be signed in to change notification settings - Fork 232
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support Chunked DPO Loss Kernel (#378)
## Summary

Add support for a fused, torch-compiled, and chunked DPO ([Direct Preference Optimization](https://arxiv.org/html/2305.18290v3)) loss kernel, as requested in #371. This implementation is largely based on the excellent work done on ORPO (#362) by @shivam15s.

### DPO Loss Formulation

In a reference setting (not reference free):

$$r_\theta(x,y_c) - r_\theta(x,y_r) = \log(\pi_\theta(y_c|x)) - \log(\pi_\theta(y_r|x))$$

$$-\log(\sigma(\beta(\log(\pi_\theta(y_c|x)) - \log(\pi_\theta(y_r|x)) - \log(\pi_{\theta_{\text{ref}}}(y_c|x)) + \log(\pi_{\theta_{\text{ref}}}(y_r|x)))))$$

Corresponds to:

```python
# Policy model log probabilities
policy_chosen_logps = log_probs(policy_chosen_logits)
policy_rejected_logps = log_probs(policy_rejected_logits)

# Reference model log probabilities
ref_chosen_logps = log_probs(ref_chosen_logits)
ref_rejected_logps = log_probs(ref_rejected_logits)

# Compute advantages
chosen_advantages = policy_chosen_logps - ref_chosen_logps
rejected_advantages = policy_rejected_logps - ref_rejected_logps

# DPO loss
logits_diff = beta * (chosen_advantages - rejected_advantages)
losses = -F.logsigmoid(logits_diff)
```

In this PR:

1. The equation above shows that maximizing the reward difference means maximizing: $$r_θ(x_c) - r_θ(x_r)$$
2. Dropping the reference-model terms, the loss simplifies to: $$-\log(σ(β(\log π_θ(x_c) - \log π_θ(x_r))))$$
3. So, the code implements:

```python
logits_diff = beta * (chosen_logps - rejected_logps)  # β(log π_θ(x_c) - log π_θ(x_r))
losses = -F.logsigmoid(logits_diff)  # -log(σ(logits_diff))
```

4. Sum up DPO and NLL: $$L_{DPO+NLL} = L_{DPO}+αL_{NLL}$$

## Testing Done

![dpo_loss_memory](https://github.com/user-attachments/assets/d48965a2-bab7-4a81-9872-a43826106731)
![dpo_loss_speed](https://github.com/user-attachments/assets/10ab33c3-a905-435f-886b-67c911b8fff6)

- Hardware Type: **NVIDIA L40S (48G)**
- [X] run `make test` to ensure correctness
- [X] run `make checkstyle` to ensure code style
- [X] run `make test-convergence` to ensure convergence

---------

Signed-off-by: Austin Liu <[email protected]>
Co-authored-by: shivam15s <[email protected]>
- Loading branch information
1 parent
2281b7e
commit 1aa3d83
Showing
3 changed files
with
503 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,226 @@ | ||
from test.chunked_loss.test_dpo_loss import HF_DPO_Loss | ||
|
||
import torch | ||
import triton | ||
from utils import ( | ||
QUANTILES, | ||
SingleBenchmarkRunInput, | ||
SingleBenchmarkRunOutput, | ||
_test_memory, | ||
parse_benchmark_script_args, | ||
run_benchmarks, | ||
) | ||
|
||
from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction | ||
|
||
|
||
class TorchDPOLoss(torch.nn.Module):
    """Benchmark baseline: a linear projection paired with ``HF_DPO_Loss``.

    Args:
        H: ``in_features`` of the linear layer.
        V: ``out_features`` of the linear layer.
        dtype: Parameter dtype of the linear layer.
        beta: Forwarded unchanged to ``HF_DPO_Loss``.
        ignore_index: Forwarded unchanged to ``HF_DPO_Loss``.
        bias: Whether the linear layer is created with a bias term.
    """

    def __init__(
        self,
        H: int,
        V: int,
        dtype: torch.dtype,
        beta: float = 0.1,
        ignore_index: int = -100,
        bias: bool = False,
    ):
        super().__init__()
        projection = torch.nn.Linear(
            in_features=H, out_features=V, bias=bias, dtype=dtype
        )
        self.lin = projection
        self.dpo_loss = HF_DPO_Loss(beta=beta, ignore_index=ignore_index)

    def forward(self, x, target):
        """Return ``HF_DPO_Loss.get_batch_loss_metrics`` for ``x`` and ``target``.

        The projection's weight and bias are handed to the loss object, which
        is responsible for applying them.
        """
        # Falls back to None when the projection exposes no ``bias`` attribute.
        projection_bias = getattr(self.lin, "bias", None)
        return self.dpo_loss.get_batch_loss_metrics(
            x, self.lin.weight, target, projection_bias
        )
|
||
|
||
class LigerDPOLoss(torch.nn.Module):
    """Benchmark wrapper around ``LigerFusedLinearDPOFunction``.

    Owns the linear projection whose weight and bias are handed to the fused
    autograd function together with the loss hyper-parameters.

    Args:
        H: ``in_features`` of the linear layer.
        V: ``out_features`` of the linear layer.
        dtype: Parameter dtype of the linear layer.
        beta: Passed as the ``beta`` argument of the fused function.
        ignore_index: Passed as the ``ignore_index`` argument of the fused
            function.
        bias: Whether the linear layer is created with a bias term.
        compute_nll_loss: Passed as the ``compute_nll_loss`` argument of the
            fused function. Defaults to ``True``, the value this wrapper used
            to hard-code.
    """

    def __init__(
        self,
        H: int,
        V: int,
        dtype: torch.dtype,
        beta: float = 0.1,
        ignore_index: int = -100,
        bias: bool = False,
        compute_nll_loss: bool = True,
    ):
        super().__init__()
        self.lin = torch.nn.Linear(
            in_features=H, out_features=V, bias=bias, dtype=dtype
        )
        self.beta = beta
        self.ignore_index = ignore_index
        self.compute_nll_loss = compute_nll_loss

    def forward(self, x, target):
        """Run the fused linear + DPO loss on ``x`` against ``target``."""
        # ``autograd.Function.apply`` only takes positional arguments, so the
        # order below must match ``LigerFusedLinearDPOFunction.forward``:
        # (_input, weight, target, bias, ignore_index, beta, compute_nll_loss).
        return LigerFusedLinearDPOFunction.apply(
            x,
            self.lin.weight,
            target,
            # ``nn.Linear`` always defines ``bias`` (None when bias=False), so
            # no ``hasattr`` guard is needed.
            self.lin.bias,
            self.ignore_index,
            self.beta,
            self.compute_nll_loss,
        )
|
||
|
||
def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Measure memory of one forward + backward pass of the selected DPO loss.

    Args:
        input: Benchmark descriptor. ``input.x`` is used as the batch size B,
            ``input.extra_benchmark_config`` supplies ``T``, ``H``, ``V``,
            ``dtype``, ``bias``, ``beta`` and ``ignore_index``, and
            ``input.kernel_provider`` selects ``"liger"`` or ``"huggingface"``.

    Returns:
        A ``SingleBenchmarkRunOutput`` holding the three values returned by
        ``_test_memory`` as ``y_20`` / ``y_50`` / ``y_80``.

    Raises:
        ValueError: If ``input.kernel_provider`` is not a known provider.
    """
    config = input.extra_benchmark_config
    B = input.x
    T = config["T"]
    H = config["H"]
    V = config["V"]
    dtype = config["dtype"]
    bias = config["bias"]
    beta = config["beta"]
    ignore_index = config["ignore_index"]
    provider = input.kernel_provider

    # Fail fast with a clear message; previously an unknown provider made
    # ``fwd`` return None and surfaced later as an AttributeError.
    if provider not in ("liger", "huggingface"):
        raise ValueError(
            f"Unknown kernel provider {provider!r}; expected 'liger' or 'huggingface'"
        )

    device = "cuda"
    # NOTE(review): both modules are built regardless of provider, as before;
    # confirm whether the unused one should count toward the measurement.
    torch_dpo_loss = TorchDPOLoss(
        H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
    ).to(device)
    liger_dpo_loss = LigerDPOLoss(
        H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
    ).to(device)
    loss_module = liger_dpo_loss if provider == "liger" else torch_dpo_loss

    # Input shape: [B, T, H]
    _input = torch.randn(B, T, H, device=device, dtype=dtype)
    # Target shape: [B, T]
    target = torch.randint(V, (B, T), dtype=torch.long, device=device)

    # Overwrite a random subset (fewer than half) of the targets with
    # ignore_index to simulate padding.
    num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
    indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
    target.view(-1)[indices_to_assign] = ignore_index

    def fwd():
        return loss_module(_input, target)

    def full():
        y = fwd()
        y.backward()

    # presumably the results come back in QUANTILES order (50th, 20th, 80th)
    # -- verify against utils.
    mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )
|
||
|
||
def bench_speed_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Time the selected DPO loss with ``triton.testing.do_bench``.

    Args:
        input: Benchmark descriptor. ``input.x`` is used as the batch size B,
            ``input.extra_benchmark_config`` supplies ``T``, ``H``, ``V``,
            ``dtype``, ``bias``, ``beta`` and ``ignore_index``,
            ``input.kernel_provider`` selects ``"liger"`` or ``"huggingface"``
            and ``input.kernel_operation_mode`` selects ``"forward"``,
            ``"backward"`` or ``"full"`` (forward followed by backward).

    Returns:
        A ``SingleBenchmarkRunOutput`` holding the three values returned by
        ``do_bench`` as ``y_20`` / ``y_50`` / ``y_80``.

    Raises:
        ValueError: If the provider or the operation mode is not recognised.
    """
    config = input.extra_benchmark_config
    B = input.x
    T = config["T"]
    H = config["H"]
    V = config["V"]
    dtype = config["dtype"]
    bias = config["bias"]
    beta = config["beta"]
    ignore_index = config["ignore_index"]
    provider = input.kernel_provider
    mode = input.kernel_operation_mode

    # Fail fast with a clear message. Previously an unknown provider made
    # ``fwd`` return None and an unknown mode left ms_* unbound.
    if provider not in ("liger", "huggingface"):
        raise ValueError(
            f"Unknown kernel provider {provider!r}; expected 'liger' or 'huggingface'"
        )
    if mode not in ("forward", "backward", "full"):
        raise ValueError(
            f"Unknown operation mode {mode!r}; expected 'forward', 'backward' or 'full'"
        )

    device = "cuda"
    torch_dpo_loss = TorchDPOLoss(
        H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
    ).to(device)
    liger_dpo_loss = LigerDPOLoss(
        H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
    ).to(device)
    loss_module = liger_dpo_loss if provider == "liger" else torch_dpo_loss

    # Input shape: [B, T, H]
    _input = torch.randn(B, T, H, device=device, dtype=dtype)

    # Target shape: [B, T]
    target = torch.randint(V, (B, T), device=device, dtype=torch.long)

    # Overwrite a random subset (fewer than half) of the targets with
    # ignore_index.
    num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
    indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
    target.view(-1)[indices_to_assign] = ignore_index

    def fwd():
        return loss_module(_input, target)

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            fwd,
            rep=100,
            quantiles=QUANTILES,
        )
    elif mode == "backward":
        # One forward pass up front; only the backward pass is timed, so the
        # graph has to be retained across repetitions.
        y = fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(retain_graph=True),
            grad_to_none=[_input],
            rep=100,
            quantiles=QUANTILES,
        )
    else:  # mode == "full", guaranteed by the validation above

        def full():
            y = fwd()
            y.backward()

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            full,
            rep=100,
            quantiles=QUANTILES,
        )

    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )
|
||
|
||
if __name__ == "__main__":
    # Script entry point: run the speed benchmark, then the memory benchmark,
    # with one shared configuration.
    cli_args = parse_benchmark_script_args()

    dpo_benchmark_config = {
        "T": 512,
        "H": 1024,
        "V": 128256,
        "mode": "forward",
        "dtype": torch.bfloat16,
        "bias": True,
        "beta": 0.1,
        "ignore_index": 42,
    }

    # Keyword arguments common to both run_benchmarks calls.
    shared_kwargs = {
        "kernel_name": "dpo_loss",
        "x_name": "B",
        "x_label": "Batch Size (B)",
        "x_values": [2**exponent for exponent in range(1, 6)],
        "kernel_providers": ["liger", "huggingface"],
        "extra_benchmark_configs": [dpo_benchmark_config],
        "overwrite": cli_args.overwrite,
    }

    # (benchmark function, operation modes, metric name, metric unit)
    benchmark_plan = (
        (bench_speed_dpo_loss, ["forward", "full"], "speed", "ms"),
        (bench_memory_dpo_loss, ["full"], "memory", "MB"),
    )
    for bench_fn, operation_modes, metric, unit in benchmark_plan:
        run_benchmarks(
            bench_test_fn=bench_fn,
            kernel_operation_modes=operation_modes,
            metric_name=metric,
            metric_unit=unit,
            **shared_kwargs
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import torch.nn.functional as F | ||
|
||
from liger_kernel.chunked_loss.fused_linear_preference import ( | ||
LigerFusedLinearPreferenceBase, | ||
) | ||
|
||
|
||
class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
    """Fused linear + DPO loss built on ``LigerFusedLinearPreferenceBase``.

    The class supplies the DPO preference term and delegates the rest of the
    forward and backward computation to the base class.
    """

    @staticmethod
    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
        """
        Compute the DPO (Direct Preference Optimization) preference loss.

        Evaluates ``-logsigmoid(beta * (chosen_logps - rejected_logps))`` and
        sums the result over all elements.

        Args:
            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
            beta (float): Scale applied to the log-probability margin.
        Returns:
            torch.Tensor: The summed loss.
        """
        margin = chosen_logps - rejected_logps
        scaled_margin = beta * margin
        per_pair_loss = -F.logsigmoid(scaled_margin)
        return per_pair_loss.sum()

    @staticmethod
    def forward(
        ctx,
        _input,
        weight,
        target,
        bias=None,
        ignore_index=-100,
        beta=0.1,
        compute_nll_loss=True,
        compiled=True,
    ):
        """
        Fused linear layer with DPO (Direct Preference Optimization) loss.

        Forwards every argument to ``LigerFusedLinearPreferenceBase.forward``
        with ``preference_loss_fn`` plugged in as ``loss_fn``, and returns
        whatever the base class returns.
        """
        base_kwargs = dict(
            ctx=ctx,
            _input=_input,
            weight=weight,
            target=target,
            bias=bias,
            loss_fn=LigerFusedLinearDPOFunction.preference_loss_fn,
            compute_nll_loss=compute_nll_loss,
            ignore_index=ignore_index,
            beta=beta,
            compiled=compiled,
        )
        return LigerFusedLinearPreferenceBase.forward(**base_kwargs)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the base-class gradients padded to match ``forward``'s inputs."""
        # Keep the first four entries produced by the base class; these line up
        # with forward's (_input, weight, target, bias) positions.
        base_grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
        # ignore_index, beta, compute_nll_loss and compiled receive no gradient.
        no_grad_slots = (None,) * 4
        return tuple(base_grads) + no_grad_slots
Oops, something went wrong.