From 2a39f0dcf8bb04f27b834e60ae26aa404e00cbe9 Mon Sep 17 00:00:00 2001 From: Shivam Sahni Date: Wed, 20 Nov 2024 21:45:45 -0800 Subject: [PATCH] add nn.module support for chunked loss function (#402) ## Summary Same as title ## Testing Done - Hardware Type: - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence --- src/liger_kernel/chunked_loss/__init__.py | 4 + src/liger_kernel/chunked_loss/cpo_loss.py | 42 ++++- src/liger_kernel/chunked_loss/dpo_loss.py | 39 ++++- src/liger_kernel/chunked_loss/functional.py | 9 ++ .../chunked_loss/fused_linear_preference.py | 2 +- src/liger_kernel/chunked_loss/orpo_loss.py | 40 ++++- src/liger_kernel/chunked_loss/simpo_loss.py | 43 ++++++ test/chunked_loss/test_cpo_loss.py | 144 +++++++++++++++++- test/chunked_loss/test_dpo_loss.py | 137 ++++++++++++++++- test/chunked_loss/test_orpo_loss.py | 140 +++++++++++++++-- test/chunked_loss/test_simpo_loss.py | 122 ++++++++++++++- test/utils.py | 2 +- 12 files changed, 686 insertions(+), 38 deletions(-) create mode 100644 src/liger_kernel/chunked_loss/functional.py diff --git a/src/liger_kernel/chunked_loss/__init__.py b/src/liger_kernel/chunked_loss/__init__.py index e69de29bb..238bdded9 100644 --- a/src/liger_kernel/chunked_loss/__init__.py +++ b/src/liger_kernel/chunked_loss/__init__.py @@ -0,0 +1,4 @@ +from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401 +from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401 +from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss # noqa: F401 +from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss # noqa: F401 diff --git a/src/liger_kernel/chunked_loss/cpo_loss.py b/src/liger_kernel/chunked_loss/cpo_loss.py index cc8bd44ef..84336b4eb 100644 --- a/src/liger_kernel/chunked_loss/cpo_loss.py +++ b/src/liger_kernel/chunked_loss/cpo_loss.py @@ -1,3 +1,4 @@ +import torch import torch.nn.functional as F from liger_kernel.chunked_loss.fused_linear_preference import ( @@ -46,10 +47,10 @@ def forward( target, bias, loss_fn=LigerFusedLinearCPOFunction.preference_loss_fn, - compute_nll_loss=compute_nll_loss, ignore_index=ignore_index, alpha=alpha, beta=beta, + compute_nll_loss=compute_nll_loss, compiled=compiled, ) @@ -59,3 +60,42 @@ def backward(ctx, grad_output): grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] # Return these gradients, followed by None for the remaining inputs return *grads, None, None, None, None, None + + +class LigerFusedLinearCPOLoss(torch.nn.Module): + """ + Fused linear layer with CPO loss. + """ + + def __init__( + self, + ignore_index: int = -100, + beta: float = 0.1, + alpha: float = 1.0, + compute_nll_loss: bool = True, + compiled: bool = True, + ): + """ + Args: + ignore_index (int): Index to ignore in the loss. + beta (float): Weight for the odds ratio loss. 
+ """ + super().__init__() + self.ignore_index = ignore_index + self.beta = beta + self.alpha = alpha + self.compute_nll_loss = compute_nll_loss + self.compiled = compiled + + def forward(self, lin_weight, _input, target, bias=None): + return LigerFusedLinearCPOFunction.apply( + _input, + lin_weight, + target, + bias, + self.ignore_index, + self.beta, + self.alpha, + self.compute_nll_loss, + self.compiled, + ) diff --git a/src/liger_kernel/chunked_loss/dpo_loss.py b/src/liger_kernel/chunked_loss/dpo_loss.py index 150cb9e1c..601c15c3d 100644 --- a/src/liger_kernel/chunked_loss/dpo_loss.py +++ b/src/liger_kernel/chunked_loss/dpo_loss.py @@ -1,3 +1,4 @@ +import torch import torch.nn.functional as F from liger_kernel.chunked_loss.fused_linear_preference import ( @@ -43,9 +44,9 @@ def forward( target=target, bias=bias, loss_fn=LigerFusedLinearDPOFunction.preference_loss_fn, - compute_nll_loss=compute_nll_loss, ignore_index=ignore_index, beta=beta, + compute_nll_loss=compute_nll_loss, compiled=compiled, ) @@ -55,3 +56,39 @@ def backward(ctx, grad_output): grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] # Return these gradients, followed by None for the remaining inputs return *grads, None, None, None, None + + +class LigerFusedLinearDPOLoss(torch.nn.Module): + """ + Fused linear layer with DPO loss. + """ + + def __init__( + self, + ignore_index: int = -100, + beta: float = 0.1, + compute_nll_loss: bool = True, + compiled: bool = True, + ): + """ + Args: + ignore_index (int): Index to ignore in the loss. + beta (float): Weight for the odds ratio loss. + """ + super().__init__() + self.ignore_index = ignore_index + self.beta = beta + self.compute_nll_loss = compute_nll_loss + self.compiled = compiled + + def forward(self, lin_weight, _input, target, bias=None): + return LigerFusedLinearDPOFunction.apply( + _input, + lin_weight, + target, + bias, + self.ignore_index, + self.beta, + self.compute_nll_loss, + self.compiled, + ) diff --git a/src/liger_kernel/chunked_loss/functional.py b/src/liger_kernel/chunked_loss/functional.py new file mode 100644 index 000000000..5a51d3f72 --- /dev/null +++ b/src/liger_kernel/chunked_loss/functional.py @@ -0,0 +1,9 @@ +from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction +from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction +from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction +from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction + +liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply +liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply +liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply +liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply diff --git a/src/liger_kernel/chunked_loss/fused_linear_preference.py b/src/liger_kernel/chunked_loss/fused_linear_preference.py index 73981dff4..7dd2af160 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_preference.py +++ b/src/liger_kernel/chunked_loss/fused_linear_preference.py @@ -27,10 +27,10 @@ def forward( bias=None, loss_fn=None, chunk_size=1, - compute_nll_loss=True, ignore_index=-100, alpha=1.0, beta=0.1, + compute_nll_loss=True, compiled=True, **loss_kwargs, ): diff --git a/src/liger_kernel/chunked_loss/orpo_loss.py b/src/liger_kernel/chunked_loss/orpo_loss.py index a921f3f11..d578f1f71 100644 --- a/src/liger_kernel/chunked_loss/orpo_loss.py +++ b/src/liger_kernel/chunked_loss/orpo_loss.py @@ -34,7 +34,7 @@ def forward( ignore_index=-100, beta=0.1, 
compute_nll_loss=True, - compiled=False, + compiled=True, ): """ Fused linear layer with ORPO (Odds-Ratio Preference Optimization) loss. @@ -49,9 +49,9 @@ def forward( target=target, bias=bias, loss_fn=LigerFusedLinearORPOFunction.preference_loss_fn, - compute_nll_loss=compute_nll_loss, ignore_index=ignore_index, beta=beta, + compute_nll_loss=compute_nll_loss, compiled=compiled, ) @@ -61,3 +61,39 @@ def backward(ctx, grad_output): grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] # Return these gradients, followed by None for the remaining inputs return *grads, None, None, None, None + + +class LigerFusedLinearORPOLoss(torch.nn.Module): + """ + Fused linear layer with ORPO (Odds-Ratio Preference Optimization) loss. + """ + + def __init__( + self, + ignore_index: int = -100, + beta: float = 0.1, + compute_nll_loss: bool = True, + compiled: bool = True, + ): + """ + Args: + ignore_index (int): Index to ignore in the loss. + beta (float): Weight for the odds ratio loss. + """ + super().__init__() + self.ignore_index = ignore_index + self.beta = beta + self.compute_nll_loss = compute_nll_loss + self.compiled = compiled + + def forward(self, lin_weight, _input, target, bias=None): + return LigerFusedLinearORPOFunction.apply( + _input, + lin_weight, + target, + bias, + self.ignore_index, + self.beta, + self.compute_nll_loss, + self.compiled, + ) diff --git a/src/liger_kernel/chunked_loss/simpo_loss.py b/src/liger_kernel/chunked_loss/simpo_loss.py index eff581406..1753f7809 100644 --- a/src/liger_kernel/chunked_loss/simpo_loss.py +++ b/src/liger_kernel/chunked_loss/simpo_loss.py @@ -1,3 +1,4 @@ +import torch import torch.nn.functional as F from liger_kernel.chunked_loss.fused_linear_preference import ( @@ -62,3 +63,45 @@ def backward(ctx, grad_output): grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] # Return these gradients, followed by None for the remaining inputs return *grads, None, None, None, None, None, None + + +class LigerFusedLinearSimPOLoss(torch.nn.Module): + """ + Fused linear layer with SimPO loss. + """ + + def __init__( + self, + ignore_index: int = -100, + beta: float = 0.1, + alpha: float = 1.0, + compute_nll_loss: bool = True, + compiled: bool = True, + gamma: float = 0.5, + ): + """ + Args: + ignore_index (int): Index to ignore in the loss. + beta (float): Weight for the odds ratio loss. 
+ """ + super().__init__() + self.ignore_index = ignore_index + self.beta = beta + self.alpha = alpha + self.compute_nll_loss = compute_nll_loss + self.compiled = compiled + self.gamma = gamma + + def forward(self, lin_weight, _input, target, bias=None): + return LigerFusedLinearSimPOFunction.apply( + _input, + lin_weight, + target, + bias, + self.ignore_index, + self.beta, + self.alpha, + self.compute_nll_loss, + self.compiled, + self.gamma, + ) diff --git a/test/chunked_loss/test_cpo_loss.py b/test/chunked_loss/test_cpo_loss.py index b8fce9e06..6f9305ec8 100644 --- a/test/chunked_loss/test_cpo_loss.py +++ b/test/chunked_loss/test_cpo_loss.py @@ -5,7 +5,9 @@ import torch import torch.nn.functional as F +from liger_kernel.chunked_loss import LigerFusedLinearCPOLoss from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction +from liger_kernel.chunked_loss.functional import liger_fused_linear_cpo # set random seed globally set_seed() @@ -72,6 +74,57 @@ def alignment_loss( return losses +class TorchLMHeadCPO(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + bias: bool = False, + ignore_index: int = -100, + beta: float = 0.1, + alpha: float = 1.0, + loss_type: str = "sigmoid", + simpo_gamma: float = 0.5, + ): + super().__init__() + self.lin = torch.nn.Linear( + in_features=H, out_features=V, bias=bias, dtype=dtype + ) + self.cpo_loss = HFCPOLoss( + ignore_index=ignore_index, + beta=beta, + loss_type=loss_type, + simpo_gamma=simpo_gamma, + ).get_batch_loss_metrics + + def forward(self, x, y): + return self.cpo_loss(self.lin.weight, x, y, self.lin.bias) + + +class LigerLMHeadCPO(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + bias: bool = False, + ignore_index: int = -100, + beta: float = 0.1, + alpha: float = 1.0, + ): + super().__init__() + self.lin = torch.nn.Linear( + in_features=H, out_features=V, bias=bias, dtype=dtype + ) + self.cpo_loss = LigerFusedLinearCPOLoss( + ignore_index=ignore_index, beta=beta, alpha=alpha + ) + + def forward(self, x, y): + return self.cpo_loss(self.lin.weight, x, y, self.lin.bias) + + @pytest.mark.parametrize( "B, T, H, V", [ @@ -95,6 +148,32 @@ def test_correctness( ): B = 2 * B # cpo loss requires B to be even + torch_lm_head_cpo = TorchLMHeadCPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + ) + liger_lm_head_cpo = LigerLMHeadCPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + ) + + torch_lm_head_cpo.lin.weight.data = liger_lm_head_cpo.lin.weight.data = torch.randn( + V, H, device="cuda", dtype=dtype + ) + + if bias: + torch_lm_head_cpo.lin.bias.data = liger_lm_head_cpo.lin.bias.data = torch.randn( + V, device="cuda", dtype=dtype + ) + _input = torch.randn(B, T, H, device="cuda", dtype=dtype) * scalar input1 = _input.detach().clone().requires_grad_(True) input2 = _input.detach().clone().requires_grad_(True) @@ -114,6 +193,63 @@ def test_correctness( indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] target.view(-1)[indices_to_assign] = ignore_index + loss1 = torch_lm_head_cpo(input1, target) + loss2 = liger_lm_head_cpo(input2, target) + + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + loss1.backward() + loss2.backward() + + assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol) + assert_verbose_allclose( + torch_lm_head_cpo.lin.weight.grad, + liger_lm_head_cpo.lin.weight.grad, + atol=atol, + rtol=rtol, + ) + if bias: + assert_verbose_allclose( 
+ torch_lm_head_cpo.lin.bias.grad, + liger_lm_head_cpo.lin.bias.grad, + atol=atol, + rtol=rtol, + ) + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (2, 2, 8, 8), + (3, 47, 31, 123), # random shape + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-2, 5e-1), + (1.0, torch.float32, 1e-5, 5e-4), + ], +) +@pytest.mark.parametrize("bias", [True, False]) +def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias): + B = 2 * B + + _input = torch.randn(B, T, H, device="cuda", dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + target = torch.randint( + 0, + V, + ( + B, + T, + ), + device="cuda", + dtype=torch.long, + ) + _weight = torch.randn(V, H, device="cuda", dtype=dtype) weight1 = _weight.detach().clone().requires_grad_(True) weight2 = _weight.detach().clone().requires_grad_(True) @@ -122,12 +258,8 @@ def test_correctness( bias1 = _bias.detach().clone().requires_grad_(True) if bias else None bias2 = _bias.detach().clone().requires_grad_(True) if bias else None - loss1 = HFCPOLoss(ignore_index=ignore_index, beta=beta).get_batch_loss_metrics( - input1, weight1, target, bias1, alpha=alpha - ) - loss2 = LigerFusedLinearCPOFunction.apply( - input2, weight2, target, bias2, ignore_index, beta, alpha, True - ) + loss1 = LigerFusedLinearCPOFunction.apply(input1, weight1, target, bias1) + loss2 = liger_fused_linear_cpo(input2, weight2, target, bias2) assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) diff --git a/test/chunked_loss/test_dpo_loss.py b/test/chunked_loss/test_dpo_loss.py index 7f4eef053..e858626fd 100644 --- a/test/chunked_loss/test_dpo_loss.py +++ b/test/chunked_loss/test_dpo_loss.py @@ -4,13 +4,15 @@ import torch import torch.nn.functional as F +from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction +from liger_kernel.chunked_loss.functional import liger_fused_linear_dpo # set random seed globally set_seed() -class HF_DPO_Loss(HFAlignmentLoss): +class HFDPOLoss(HFAlignmentLoss): """ Implementation of the Odds Ratio Preference Optimization (ORPO) loss, adapted from Hugging Face's implementation. 
@@ -39,6 +41,48 @@ def alignment_loss( return losses +class TorchLMHeadDPO(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + bias: bool = False, + ignore_index: int = -100, + beta: float = 0.1, + ): + super().__init__() + self.lin = torch.nn.Linear( + in_features=H, out_features=V, bias=bias, dtype=dtype + ) + self.dpo_loss = HFDPOLoss( + ignore_index=ignore_index, beta=beta + ).get_batch_loss_metrics + + def forward(self, x, y): + return self.dpo_loss(self.lin.weight, x, y, self.lin.bias) + + +class LigerLMHeadDPO(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + bias: bool = False, + ignore_index: int = -100, + beta: float = 0.1, + ): + super().__init__() + self.lin = torch.nn.Linear( + in_features=H, out_features=V, bias=bias, dtype=dtype + ) + self.dpo_loss = LigerFusedLinearDPOLoss(ignore_index=ignore_index, beta=beta) + + def forward(self, x, y): + return self.dpo_loss(self.lin.weight, x, y, self.lin.bias) + + @pytest.mark.parametrize( "B, T, H, V", [ @@ -58,6 +102,32 @@ def alignment_loss( def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta): B = 2 * B # dpo loss requires B to be even + torch_lm_head_dpo = TorchLMHeadDPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + ) + liger_lm_head_dpo = LigerLMHeadDPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + ) + + torch_lm_head_dpo.lin.weight.data = liger_lm_head_dpo.lin.weight.data = torch.randn( + V, H, device="cuda", dtype=dtype + ) + + if bias: + torch_lm_head_dpo.lin.bias.data = liger_lm_head_dpo.lin.bias.data = torch.randn( + V, device="cuda", dtype=dtype + ) + _input = torch.randn(B, T, H, device="cuda", dtype=dtype) * scalar input1 = _input.detach().clone().requires_grad_(True) input2 = _input.detach().clone().requires_grad_(True) @@ -77,6 +147,63 @@ def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] target.view(-1)[indices_to_assign] = ignore_index + loss1 = torch_lm_head_dpo(input1, target) + loss2 = liger_lm_head_dpo(input2, target) + + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + loss1.backward() + loss2.backward() + + assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol) + assert_verbose_allclose( + torch_lm_head_dpo.lin.weight.grad, + liger_lm_head_dpo.lin.weight.grad, + atol=atol, + rtol=rtol, + ) + if bias: + assert_verbose_allclose( + torch_lm_head_dpo.lin.bias.grad, + liger_lm_head_dpo.lin.bias.grad, + atol=atol, + rtol=rtol, + ) + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (2, 2, 8, 8), + (3, 47, 31, 123), # random shape + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-2, 5e-1), + (1.0, torch.float32, 1e-5, 5e-4), + ], +) +@pytest.mark.parametrize("bias", [True, False]) +def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias): + B = 2 * B + + _input = torch.randn(B, T, H, device="cuda", dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + target = torch.randint( + 0, + V, + ( + B, + T, + ), + device="cuda", + dtype=torch.long, + ) + _weight = torch.randn(V, H, device="cuda", dtype=dtype) weight1 = _weight.detach().clone().requires_grad_(True) weight2 = _weight.detach().clone().requires_grad_(True) @@ -85,12 +212,8 @@ def test_correctness(B, T, H, V, 
scalar, dtype, atol, rtol, bias, ignore_index, bias1 = _bias.detach().clone().requires_grad_(True) if bias else None bias2 = _bias.detach().clone().requires_grad_(True) if bias else None - loss1 = HF_DPO_Loss(ignore_index=ignore_index, beta=beta).get_batch_loss_metrics( - input1, weight1, target, bias1 - ) - loss2 = LigerFusedLinearDPOFunction.apply( - input2, weight2, target, bias2, ignore_index, beta, True - ) + loss1 = LigerFusedLinearDPOFunction.apply(input1, weight1, target, bias1) + loss2 = liger_fused_linear_dpo(input2, weight2, target, bias2) assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) diff --git a/test/chunked_loss/test_orpo_loss.py b/test/chunked_loss/test_orpo_loss.py index 5e532938b..41e6c9421 100644 --- a/test/chunked_loss/test_orpo_loss.py +++ b/test/chunked_loss/test_orpo_loss.py @@ -5,6 +5,8 @@ import torch import torch.nn.functional as F +from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss +from liger_kernel.chunked_loss.functional import liger_fused_linear_orpo from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction # set random seed globally @@ -57,6 +59,48 @@ def alignment_loss( return losses +class TorchLMHeadORPO(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + bias: bool = False, + ignore_index: int = -100, + beta: float = 0.1, + ): + super().__init__() + self.lin = torch.nn.Linear( + in_features=H, out_features=V, bias=bias, dtype=dtype + ) + self.orpo_loss = HFORPOLoss( + ignore_index=ignore_index, beta=beta + ).get_batch_loss_metrics + + def forward(self, x, y): + return self.orpo_loss(self.lin.weight, x, y, self.lin.bias) + + +class LigerLMHeadORPO(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + bias: bool = False, + ignore_index: int = -100, + beta: float = 0.1, + ): + super().__init__() + self.lin = torch.nn.Linear( + in_features=H, out_features=V, bias=bias, dtype=dtype + ) + self.orpo_loss = LigerFusedLinearORPOLoss(ignore_index=ignore_index, beta=beta) + + def forward(self, x, y): + return self.orpo_loss(self.lin.weight, x, y, self.lin.bias) + + @pytest.mark.parametrize( "B, T, H, V", [ @@ -75,6 +119,31 @@ def alignment_loss( @pytest.mark.parametrize("ignore_index, beta", [(-100, 0.1), (42, 0.2)]) def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta): B = 2 * B # orpo loss requires B to be even + torch_lm_head_orpo = TorchLMHeadORPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + ) + liger_lm_head_orpo = LigerLMHeadORPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + ) + + torch_lm_head_orpo.lin.weight.data = liger_lm_head_orpo.lin.weight.data = ( + torch.randn(V, H, device="cuda", dtype=dtype) + ) + + if bias: + torch_lm_head_orpo.lin.bias.data = liger_lm_head_orpo.lin.bias.data = ( + torch.randn(V, device="cuda", dtype=dtype) + ) _input = torch.randn(B, T, H, device="cuda", dtype=dtype) * scalar input1 = _input.detach().clone().requires_grad_(True) @@ -95,6 +164,63 @@ def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] target.view(-1)[indices_to_assign] = ignore_index + loss1 = torch_lm_head_orpo(input1, target) + loss2 = liger_lm_head_orpo(input2, target) + + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + loss1.backward() + loss2.backward() + + assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol) + 
assert_verbose_allclose( + torch_lm_head_orpo.lin.weight.grad, + liger_lm_head_orpo.lin.weight.grad, + atol=atol, + rtol=rtol, + ) + if bias: + assert_verbose_allclose( + torch_lm_head_orpo.lin.bias.grad, + liger_lm_head_orpo.lin.bias.grad, + atol=atol, + rtol=rtol, + ) + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (2, 2, 8, 8), + (3, 47, 31, 123), # random shape + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-2, 5e-1), + (1.0, torch.float32, 1e-5, 5e-4), + ], +) +@pytest.mark.parametrize("bias", [True, False]) +def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias): + B = 2 * B + + _input = torch.randn(B, T, H, device="cuda", dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + target = torch.randint( + 0, + V, + ( + B, + T, + ), + device="cuda", + dtype=torch.long, + ) + _weight = torch.randn(V, H, device="cuda", dtype=dtype) weight1 = _weight.detach().clone().requires_grad_(True) weight2 = _weight.detach().clone().requires_grad_(True) @@ -103,18 +229,8 @@ def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, bias1 = _bias.detach().clone().requires_grad_(True) if bias else None bias2 = _bias.detach().clone().requires_grad_(True) if bias else None - loss1 = HFORPOLoss(ignore_index=ignore_index, beta=beta).get_batch_loss_metrics( - input1, weight1, target, bias1 - ) - loss2 = LigerFusedLinearORPOFunction.apply( - input2, - weight2, - target, - bias2, - ignore_index, - beta, - True, - ) + loss1 = LigerFusedLinearORPOFunction.apply(input1, weight1, target, bias1) + loss2 = liger_fused_linear_orpo(input2, weight2, target, bias2) assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) diff --git a/test/chunked_loss/test_simpo_loss.py b/test/chunked_loss/test_simpo_loss.py index 727aaa56e..89658b69c 100644 --- a/test/chunked_loss/test_simpo_loss.py +++ b/test/chunked_loss/test_simpo_loss.py @@ -1,15 +1,41 @@ -from test.chunked_loss.test_cpo_loss import HFCPOLoss +from test.chunked_loss.test_cpo_loss import TorchLMHeadCPO from test.utils import assert_verbose_allclose, set_seed import pytest import torch +from liger_kernel.chunked_loss import LigerFusedLinearSimPOLoss +from liger_kernel.chunked_loss.functional import liger_fused_linear_simpo from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction # set random seed globally set_seed() +class LigerLMHeadSimPO(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + bias: bool = False, + ignore_index: int = -100, + beta: float = 0.1, + alpha: float = 1.0, + gamma: float = 0.5, + ): + super().__init__() + self.lin = torch.nn.Linear( + in_features=H, out_features=V, bias=bias, dtype=dtype + ) + self.simpo_loss = LigerFusedLinearSimPOLoss( + ignore_index=ignore_index, beta=beta, alpha=alpha, gamma=gamma + ) + + def forward(self, x, y): + return self.simpo_loss(self.lin.weight, x, y, self.lin.bias) + + @pytest.mark.parametrize( "B, T, H, V", [ @@ -33,6 +59,35 @@ def test_correctness( ): B = 2 * B # SimPO loss requires B to be even + torch_lm_head_simpo = TorchLMHeadCPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + loss_type="simpo", + simpo_gamma=gamma, + ) + liger_lm_head_simpo = LigerLMHeadSimPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + gamma=gamma, + ) + + torch_lm_head_simpo.lin.weight.data = 
liger_lm_head_simpo.lin.weight.data = ( + torch.randn(V, H, device="cuda", dtype=dtype) + ) + + if bias: + torch_lm_head_simpo.lin.bias.data = liger_lm_head_simpo.lin.bias.data = ( + torch.randn(V, device="cuda", dtype=dtype) + ) + _input = torch.randn(B, T, H, device="cuda", dtype=dtype) * scalar input1 = _input.detach().clone().requires_grad_(True) input2 = _input.detach().clone().requires_grad_(True) @@ -52,6 +107,63 @@ def test_correctness( indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] target.view(-1)[indices_to_assign] = ignore_index + loss1 = torch_lm_head_simpo(input1, target) + loss2 = liger_lm_head_simpo(input2, target) + + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + loss1.backward() + loss2.backward() + + assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol) + assert_verbose_allclose( + torch_lm_head_simpo.lin.weight.grad, + liger_lm_head_simpo.lin.weight.grad, + atol=atol, + rtol=rtol, + ) + if bias: + assert_verbose_allclose( + torch_lm_head_simpo.lin.bias.grad, + liger_lm_head_simpo.lin.bias.grad, + atol=atol, + rtol=rtol, + ) + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (2, 2, 8, 8), + (3, 47, 31, 123), # random shape + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-2, 5e-1), + (1.0, torch.float32, 1e-5, 5e-4), + ], +) +@pytest.mark.parametrize("bias", [True, False]) +def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias): + B = 2 * B + + _input = torch.randn(B, T, H, device="cuda", dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + target = torch.randint( + 0, + V, + ( + B, + T, + ), + device="cuda", + dtype=torch.long, + ) + _weight = torch.randn(V, H, device="cuda", dtype=dtype) weight1 = _weight.detach().clone().requires_grad_(True) weight2 = _weight.detach().clone().requires_grad_(True) @@ -60,12 +172,8 @@ def test_correctness( bias1 = _bias.detach().clone().requires_grad_(True) if bias else None bias2 = _bias.detach().clone().requires_grad_(True) if bias else None - loss1 = HFCPOLoss( - ignore_index=ignore_index, beta=beta, simpo_gamma=gamma, loss_type="simpo" - ).get_batch_loss_metrics(input1, weight1, target, bias1) - loss2 = LigerFusedLinearSimPOFunction.apply( - input2, weight2, target, bias2, ignore_index, beta, 1.0, True, True, gamma - ) + loss1 = LigerFusedLinearSimPOFunction.apply(input1, weight1, target, bias1) + loss2 = liger_fused_linear_simpo(input2, weight2, target, bias2) assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) diff --git a/test/utils.py b/test/utils.py index f1b919687..e65bbabdc 100644 --- a/test/utils.py +++ b/test/utils.py @@ -458,8 +458,8 @@ def cross_entropy_loss(logits, labels): def get_batch_loss_metrics( self, - _input: torch.FloatTensor, weight: torch.FloatTensor, + _input: torch.FloatTensor, target: torch.LongTensor, bias: torch.FloatTensor = None, alpha: float = 1.0,
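
---

For reviewers: a minimal usage sketch of the module and functional APIs introduced above, mirroring the test setup. This snippet is **not part of the patch**; the shapes, dtype, and CUDA device are illustrative assumptions, and it only exercises signatures that appear in the diff (`LigerFusedLinearORPOLoss(ignore_index, beta, compute_nll_loss, compiled)`, `forward(lin_weight, _input, target, bias=None)`, and `liger_fused_linear_orpo(_input, weight, target, bias=None)`).

```python
# Illustrative sketch (not part of the patch): exercising the new nn.Module and
# functional entry points. Shapes, dtype, and the CUDA device are assumptions
# that mirror the tests above.
import torch

from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss
from liger_kernel.chunked_loss.functional import liger_fused_linear_orpo

B, T, H, V = 4, 128, 512, 32000  # assumed sizes; B must be even (the tests double B for this reason)

# The loss module holds only hyperparameters; the lm_head weight/bias are passed
# in at call time, so an existing nn.Linear's parameters can be reused.
weight = torch.randn(V, H, device="cuda", dtype=torch.float32, requires_grad=True)
hidden = torch.randn(B, T, H, device="cuda", dtype=torch.float32, requires_grad=True)
target = torch.randint(0, V, (B, T), device="cuda", dtype=torch.long)

# Module API: forward(lin_weight, _input, target, bias=None)
orpo_loss = LigerFusedLinearORPOLoss(ignore_index=-100, beta=0.1)
loss = orpo_loss(weight, hidden, target)
loss.backward()  # populates hidden.grad and weight.grad

# Equivalent functional entry point from chunked_loss/functional.py:
# liger_fused_linear_orpo(_input, weight, target, bias=None)
hidden2 = hidden.detach().clone().requires_grad_(True)
weight2 = weight.detach().clone().requires_grad_(True)
loss2 = liger_fused_linear_orpo(hidden2, weight2, target)
```

Note the weight-first argument order of the module's `forward`: the module does not own the lm_head parameters, which is how the `TorchLMHead*` / `LigerLMHead*` test wrappers above reuse `self.lin.weight` and `self.lin.bias` directly.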