From 1aa3d83c47184df41b5479e526eb80a9a936b65c Mon Sep 17 00:00:00 2001
From: Austin Liu
Date: Fri, 15 Nov 2024 09:29:31 +0800
Subject: [PATCH] Support Chunked DPO Loss Kernel (#378)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Summary

Add support for a fused, torch-compiled, and chunked DPO ([Direct Preference Optimization](https://arxiv.org/html/2305.18290v3)) loss kernel, as requested in https://github.com/linkedin/Liger-Kernel/issues/371.

This implementation is largely based on the excellent work done on ORPO (https://github.com/linkedin/Liger-Kernel/pull/362) by @shivam15s.

### DPO Loss Formulation

In a reference setting (not reference free), the implicit reward margin is

$$r_\theta(x,y_c) - r_\theta(x,y_r) = \beta\left(\log\frac{\pi_\theta(y_c|x)}{\pi_{\theta_{\text{ref}}}(y_c|x)} - \log\frac{\pi_\theta(y_r|x)}{\pi_{\theta_{\text{ref}}}(y_r|x)}\right)$$

and the per-example loss is

$$-\log\sigma\left(\beta\left(\log\pi_\theta(y_c|x) - \log\pi_\theta(y_r|x) - \log\pi_{\theta_{\text{ref}}}(y_c|x) + \log\pi_{\theta_{\text{ref}}}(y_r|x)\right)\right)$$

This corresponds to:

```python
# Policy model log probabilities
policy_chosen_logps = log_probs(policy_chosen_logits)
policy_rejected_logps = log_probs(policy_rejected_logits)

# Reference model log probabilities
ref_chosen_logps = log_probs(ref_chosen_logits)
ref_rejected_logps = log_probs(ref_rejected_logits)

# Compute advantages
chosen_advantages = policy_chosen_logps - ref_chosen_logps
rejected_advantages = policy_rejected_logps - ref_rejected_logps

# DPO loss
logits_diff = beta * (chosen_advantages - rejected_advantages)
losses = -F.logsigmoid(logits_diff)
```

In this PR:

1. The equations above show that the objective is to maximize the reward margin $r_\theta(x,y_c) - r_\theta(x,y_r)$.
2. Since this kernel is reference free, the loss simplifies to $-\log\sigma\left(\beta\left(\log\pi_\theta(y_c|x) - \log\pi_\theta(y_r|x)\right)\right)$.
3. So, the code implements:
```python
logits_diff = beta * (chosen_logps - rejected_logps)  # β(log π_θ(y_c|x) - log π_θ(y_r|x))
losses = -F.logsigmoid(logits_diff)                   # -log σ(logits_diff)
```
4.
Sum up DPO and NLL: $$L_{DPO+NLL} = L_{DPO}+αL_{NLL}$$ ## Testing Done ![dpo_loss_memory](https://github.com/user-attachments/assets/d48965a2-bab7-4a81-9872-a43826106731) ![dpo_loss_speed](https://github.com/user-attachments/assets/10ab33c3-a905-435f-886b-67c911b8fff6) - Hardware Type: **NVIDIA L40S (48G)** - [X] run `make test` to ensure correctness - [X] run `make checkstyle` to ensure code style - [X] run `make test-convergence` to ensure convergence --------- Signed-off-by: Austin Liu Co-authored-by: shivam15s --- benchmark/scripts/benchmark_dpo_loss.py | 226 ++++++++++++++++++++++ src/liger_kernel/chunked_loss/dpo_loss.py | 57 ++++++ test/chunked_loss/test_dpo_loss.py | 220 +++++++++++++++++++++ 3 files changed, 503 insertions(+) create mode 100644 benchmark/scripts/benchmark_dpo_loss.py create mode 100644 src/liger_kernel/chunked_loss/dpo_loss.py create mode 100644 test/chunked_loss/test_dpo_loss.py diff --git a/benchmark/scripts/benchmark_dpo_loss.py b/benchmark/scripts/benchmark_dpo_loss.py new file mode 100644 index 000000000..537be47bc --- /dev/null +++ b/benchmark/scripts/benchmark_dpo_loss.py @@ -0,0 +1,226 @@ +from test.chunked_loss.test_dpo_loss import HF_DPO_Loss + +import torch +import triton +from utils import ( + QUANTILES, + SingleBenchmarkRunInput, + SingleBenchmarkRunOutput, + _test_memory, + parse_benchmark_script_args, + run_benchmarks, +) + +from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction + + +class TorchDPOLoss(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + beta: float = 0.1, + ignore_index: int = -100, + bias: bool = False, + ): + super().__init__() + self.lin = torch.nn.Linear( + in_features=H, out_features=V, bias=bias, dtype=dtype + ) + self.dpo_loss = HF_DPO_Loss(beta=beta, ignore_index=ignore_index) + + def forward(self, x, target): + return self.dpo_loss.get_batch_loss_metrics( + x, + self.lin.weight, + target, + self.lin.bias if hasattr(self.lin, "bias") else None, + ) + + +class LigerDPOLoss(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + beta: float = 0.1, + ignore_index: int = -100, + bias: bool = False, + ): + super().__init__() + self.lin = torch.nn.Linear( + in_features=H, out_features=V, bias=bias, dtype=dtype + ) + self.beta = beta + self.ignore_index = ignore_index + + def forward(self, x, target): + return LigerFusedLinearDPOFunction.apply( + x, + self.lin.weight, + target, + self.lin.bias if hasattr(self.lin, "bias") else None, + self.ignore_index, + self.beta, + True, + ) + + +def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + B = input.x + T = input.extra_benchmark_config["T"] + H = input.extra_benchmark_config["H"] + V = input.extra_benchmark_config["V"] + dtype = input.extra_benchmark_config["dtype"] + bias = input.extra_benchmark_config["bias"] + beta = input.extra_benchmark_config["beta"] + ignore_index = input.extra_benchmark_config["ignore_index"] + provider = input.kernel_provider + + device = "cuda" + torch_dpo_loss = TorchDPOLoss( + H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias + ).to(device) + liger_dpo_loss = LigerDPOLoss( + H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias + ).to(device) + + # Input shape: [B, T, H] + _input = torch.randn(B, T, H, device=device, dtype=dtype) + # Target shape: [B, T] + target = torch.randint(V, (B, T), dtype=torch.long, device=device) + + # Add ignore_index tokens to simulate padding + num_elements_to_assign = 
torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + target.view(-1)[indices_to_assign] = ignore_index + + def fwd(): + if provider == "liger": + return liger_dpo_loss(_input, target) + elif provider == "huggingface": + return torch_dpo_loss(_input, target) + + def full(): + y = fwd() + y.backward() + + mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +def bench_speed_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + B = input.x + T = input.extra_benchmark_config["T"] + H = input.extra_benchmark_config["H"] + V = input.extra_benchmark_config["V"] + dtype = input.extra_benchmark_config["dtype"] + bias = input.extra_benchmark_config["bias"] + beta = input.extra_benchmark_config["beta"] + ignore_index = input.extra_benchmark_config["ignore_index"] + provider = input.kernel_provider + mode = input.kernel_operation_mode + + device = "cuda" + torch_dpo_loss = TorchDPOLoss( + H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias + ).to(device) + liger_dpo_loss = LigerDPOLoss( + H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias + ).to(device) + + # Input shape: [B, T, H] + _input = torch.randn(B, T, H, device=device, dtype=dtype) + + # Target shape: [B, T] + target = torch.randint(V, (B, T), device=device, dtype=torch.long) + + # Add ignore_index tokens + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + target.view(-1)[indices_to_assign] = ignore_index + + def fwd(): + if provider == "liger": + return liger_dpo_loss(_input, target) + elif provider == "huggingface": + return torch_dpo_loss(_input, target) + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench( + fwd, + rep=100, + quantiles=QUANTILES, + ) + elif mode == "backward": + y = fwd() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(retain_graph=True), + grad_to_none=[_input], + rep=100, + quantiles=QUANTILES, + ) + elif mode == "full": + + def full(): + y = fwd() + y.backward() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + full, + rep=100, + quantiles=QUANTILES, + ) + + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + common_configs = { + "kernel_name": "dpo_loss", + "x_name": "B", + "x_label": "Batch Size (B)", + "x_values": [2**i for i in range(1, 6)], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [ + { + "T": 512, + "H": 1024, + "V": 128256, + "mode": "forward", + "dtype": torch.bfloat16, + "bias": True, + "beta": 0.1, + "ignore_index": 42, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_dpo_loss, + kernel_operation_modes=["forward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs + ) + + run_benchmarks( + bench_test_fn=bench_memory_dpo_loss, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs + ) diff --git a/src/liger_kernel/chunked_loss/dpo_loss.py b/src/liger_kernel/chunked_loss/dpo_loss.py new file mode 100644 index 000000000..150cb9e1c --- /dev/null +++ b/src/liger_kernel/chunked_loss/dpo_loss.py @@ -0,0 +1,57 @@ +import torch.nn.functional as F + +from liger_kernel.chunked_loss.fused_linear_preference import ( + 
    LigerFusedLinearPreferenceBase,
+)
+
+
+class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
+
+    @staticmethod
+    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
+        """
+        Compute DPO loss (Direct Preference Optimization).
+        Args:
+            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
+            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
+            beta (float): Weight for the direct preference loss.
+        """
+        logits_diff = beta * (chosen_logps - rejected_logps)
+        losses = -F.logsigmoid(logits_diff)
+        return losses.sum()
+
+    @staticmethod
+    def forward(
+        ctx,
+        _input,
+        weight,
+        target,
+        bias=None,
+        ignore_index=-100,
+        beta=0.1,
+        compute_nll_loss=True,
+        compiled=True,
+    ):
+        """
+        Fused linear layer with DPO (Direct Preference Optimization) loss.
+        Handles both the forward and backward pass of the final linear layer with DPO loss.
+        """
+        return LigerFusedLinearPreferenceBase.forward(
+            ctx=ctx,
+            _input=_input,
+            weight=weight,
+            target=target,
+            bias=bias,
+            loss_fn=LigerFusedLinearDPOFunction.preference_loss_fn,
+            compute_nll_loss=compute_nll_loss,
+            ignore_index=ignore_index,
+            beta=beta,
+            compiled=compiled,
+        )
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        # Get gradients for _input, weight, bias, and target from the base class
+        grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
+        # Return these gradients, followed by None for the remaining inputs
+        return *grads, None, None, None, None
diff --git a/test/chunked_loss/test_dpo_loss.py b/test/chunked_loss/test_dpo_loss.py
new file mode 100644
index 000000000..0495fa723
--- /dev/null
+++ b/test/chunked_loss/test_dpo_loss.py
@@ -0,0 +1,220 @@
+from test.utils import assert_verbose_allclose, set_seed
+from typing import Tuple
+
+import pytest
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
+
+# set random seed globally
+set_seed()
+
+
+class HF_DPO_Loss:
+    """
+    Implementation of the Direct Preference Optimization (DPO) loss,
+    adapted from Hugging Face's implementation.
+    Reference: https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py
+    """
+
+    def __init__(self, ignore_index: int = -100, beta: float = 0.1):
+        self.ignore_index = ignore_index
+        self.beta = beta
+
+    def get_batch_logps(
+        self,
+        logits: torch.FloatTensor,
+        labels: torch.LongTensor,
+        average_log_prob: bool = False,
+    ) -> torch.FloatTensor:
+        """Compute the log probabilities of the given labels under the given logits.
+
+        Args:
+            logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
+            labels: Labels for which to compute the log probabilities. Label tokens with a value of ignore_index are ignored. Shape: (batch_size, sequence_length)
+            average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
+            is_encoder_decoder: Whether the model is an encoder-decoder model.
+
+        Returns:
+            A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
+        """
+        if logits.shape[:-1] != labels.shape:
+            raise ValueError(
+                "Logits (batch and sequence length dim) and labels must have the same shape."
+ ) + + loss_mask = labels != self.ignore_index + + # dummy token; we'll ignore the losses on these tokens later + labels = torch.where(labels == self.ignore_index, 0, labels) + + per_token_logps = torch.gather( + logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2) + ).squeeze(2) + + if average_log_prob: + return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + else: + return (per_token_logps * loss_mask).sum(-1) + + def dpo_loss( + self, + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + ) -> torch.FloatTensor: + """Compute DPO loss for a batch of policy log probabilities. + Args: + policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) + policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) + + Returns: + The losses tensor contains the DPO loss for each example in the batch. + """ + # Derived from https://huggingface.co/papers/2305.18290 + logits_diff = self.beta * (policy_chosen_logps - policy_rejected_logps) + losses = -F.logsigmoid(logits_diff) + return losses + + def concatenated_forward( + self, + _input: torch.FloatTensor, + weight: torch.FloatTensor, + target: torch.LongTensor, + bias: torch.FloatTensor = None, + ) -> Tuple[ + torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor + ]: + """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. + + We do this to avoid doing two forward passes, because it's faster for FSDP. + """ + len_chosen = _input.shape[0] // 2 + + outputs = _input @ weight.t() + if bias is not None: + outputs = outputs + bias + all_logits = outputs.float() + + def cross_entropy_loss(logits, labels): + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss(ignore_index=self.ignore_index) + logits = logits.view(-1, logits.shape[-1]) + labels = labels.view(-1) + # Enable model parallelism + labels = labels.to(logits.device) + loss = loss_fct(logits, labels) + return loss + + labels = target + chosen_nll_loss = cross_entropy_loss( + all_logits[:len_chosen], labels[:len_chosen] + ) + + all_logps = self.get_batch_logps( + all_logits, + target, + average_log_prob=True, + ) + + chosen_logps = all_logps[:len_chosen] + rejected_logps = all_logps[len_chosen:] + + chosen_logits = all_logits[:len_chosen] + rejected_logits = all_logits[len_chosen:] + + return ( + chosen_logps, + rejected_logps, + chosen_logits, + rejected_logits, + chosen_nll_loss, + ) + + def get_batch_loss_metrics( + self, + _input: torch.FloatTensor, + weight: torch.FloatTensor, + target: torch.LongTensor, + bias: torch.FloatTensor = None, + ): + """Compute the DPO loss and other metrics for the given batch of inputs for train or test.""" + + forward_output = self.concatenated_forward(_input, weight, target, bias) + ( + policy_chosen_logps, + policy_rejected_logps, + policy_chosen_logits, + policy_rejected_logits, + policy_nll_loss, + ) = forward_output[:5] + + losses = self.dpo_loss(policy_chosen_logps, policy_rejected_logps) + # full DPO loss + loss = policy_nll_loss - losses.mean() + return loss + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (8, 128, 1024, 4096), + (3, 47, 31, 123), # random shape + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-2, 5e-1), + (1.0, torch.float32, 2e-2, 5e-1), + ], +) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize("ignore_index, beta", [(-100, 0.1), (42, 0.2)]) +def 
test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta):
+    B = 2 * B  # dpo loss requires B to be even
+
+    _input = torch.randn(B, T, H, device="cuda", dtype=dtype) * scalar
+    input1 = _input.detach().clone().requires_grad_(True)
+    input2 = _input.detach().clone().requires_grad_(True)
+
+    target = torch.randint(
+        0,
+        V,
+        (
+            B,
+            T,
+        ),
+        device="cuda",
+        dtype=torch.long,
+    )
+    # Assign some random number of elements as ignore_index
+    num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
+    indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
+    target.view(-1)[indices_to_assign] = ignore_index
+
+    _weight = torch.randn(V, H, device="cuda", dtype=dtype)
+    weight1 = _weight.detach().clone().requires_grad_(True)
+    weight2 = _weight.detach().clone().requires_grad_(True)
+
+    _bias = torch.randn(V, device="cuda", dtype=dtype) if bias else None
+    bias1 = _bias.detach().clone().requires_grad_(True) if bias else None
+    bias2 = _bias.detach().clone().requires_grad_(True) if bias else None
+
+    loss1 = HF_DPO_Loss(ignore_index=ignore_index, beta=beta).get_batch_loss_metrics(
+        input1, weight1, target, bias1
+    )
+    loss2 = LigerFusedLinearDPOFunction.apply(
+        input2, weight2, target, bias2, ignore_index, beta, True
+    )
+
+    assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)
+
+    loss1.backward()
+    loss2.backward()
+
+    assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol)
+    assert_verbose_allclose(weight1.grad, weight2.grad, atol=atol, rtol=rtol)
+    if bias:
+        assert_verbose_allclose(bias1.grad, bias2.grad, atol=atol, rtol=rtol)
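
For reviewers, here is a minimal usage sketch of the new function, mirroring the `LigerDPOLoss` wrapper from the benchmark script above. The `ChunkedDPOLoss` module name and the example shapes are illustrative only and not part of this patch; the argument order follows `LigerFusedLinearDPOFunction.forward` as defined in `dpo_loss.py`.

```python
import torch

from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction


class ChunkedDPOLoss(torch.nn.Module):
    """Illustrative wrapper around the fused, chunked DPO loss (not part of this PR)."""

    def __init__(self, H, V, dtype, beta=0.1, ignore_index=-100, bias=False):
        super().__init__()
        # Final projection (lm_head) whose forward/backward is fused with the loss
        self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
        self.beta = beta
        self.ignore_index = ignore_index

    def forward(self, x, target):
        # x: hidden states [B, T, H]; target: token ids [B, T].
        # B must be even: the first B/2 rows are chosen responses, the last B/2 rejected.
        return LigerFusedLinearDPOFunction.apply(
            x,
            self.lin.weight,
            target,
            self.lin.bias,
            self.ignore_index,
            self.beta,
            True,  # compute_nll_loss
        )


loss_fn = ChunkedDPOLoss(H=1024, V=128256, dtype=torch.bfloat16).to("cuda")
x = torch.randn(4, 512, 1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)
target = torch.randint(0, 128256, (4, 512), device="cuda")
loss = loss_fn(x, target)
loss.backward()
```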
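
The chunking itself lives in `LigerFusedLinearPreferenceBase` (introduced in the ORPO PR) and is not part of this diff. As a rough mental model only, the reference-free DPO path is equivalent to the eager sketch below, which processes chosen/rejected pairs chunk by chunk so the full `[B, T, V]` logits tensor is never materialized at once. The helper names, the chunk size, and the omission of the NLL term are simplifications for illustration, not the actual kernel code.

```python
import torch
import torch.nn.functional as F


def chunked_dpo_loss_reference(hidden, weight, target, bias=None, beta=0.1,
                               ignore_index=-100, chunk_size=1):
    """Eager, illustrative sketch of a chunked, reference-free DPO loss."""
    B = hidden.shape[0]
    len_chosen = B // 2  # first half of the batch is chosen, second half rejected

    def avg_logps(h_chunk, t_chunk):
        # Project one chunk to vocab space, then average log-probs over non-masked tokens
        logits = h_chunk @ weight.t()
        if bias is not None:
            logits = logits + bias
        logps = logits.float().log_softmax(-1)
        mask = t_chunk != ignore_index
        safe_t = torch.where(mask, t_chunk, torch.zeros_like(t_chunk))
        per_token = torch.gather(logps, 2, safe_t.unsqueeze(2)).squeeze(2)
        return (per_token * mask).sum(-1) / mask.sum(-1)

    chosen_logps, rejected_logps = [], []
    for start in range(0, len_chosen, chunk_size):
        end = min(start + chunk_size, len_chosen)
        chosen_logps.append(avg_logps(hidden[start:end], target[start:end]))
        rejected_logps.append(
            avg_logps(
                hidden[len_chosen + start : len_chosen + end],
                target[len_chosen + start : len_chosen + end],
            )
        )
    chosen_logps = torch.cat(chosen_logps)
    rejected_logps = torch.cat(rejected_logps)

    # Same preference term as LigerFusedLinearDPOFunction.preference_loss_fn
    return -F.logsigmoid(beta * (chosen_logps - rejected_logps)).sum()
```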