From edbbf5f4f1ebdef8f57fb74b77dcffed5048a500 Mon Sep 17 00:00:00 2001 From: ryankert01 Date: Sun, 8 Dec 2024 20:11:32 +0800 Subject: [PATCH 01/10] add softcapping to preference based fused linear --- .../chunked_loss/fused_linear_preference.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/liger_kernel/chunked_loss/fused_linear_preference.py b/src/liger_kernel/chunked_loss/fused_linear_preference.py index c31cbba8b..cbe910e2e 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_preference.py +++ b/src/liger_kernel/chunked_loss/fused_linear_preference.py @@ -26,11 +26,16 @@ def chunk_forward( bias=None, ignore_index=-100, compute_nll_loss=True, + softcap=None ): len_chosen_chunk = target_chunk.shape[0] // 2 logits_chunk = input_chunk @ weight.t() if bias is not None: logits_chunk = logits_chunk + bias + if softcap is not None: + logits_chunk = logits_chunk / softcap + logits_chunk = torch.tanh(logits_chunk) + logits_chunk = logits_chunk * softcap log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1) chosen_nll_loss = 0.0 @@ -81,6 +86,7 @@ def forward( use_ref_model=False, ref_weight=None, ref_bias=None, + softcap=None, **loss_kwargs, ): """ @@ -103,6 +109,7 @@ def forward( use_ref_model (bool): Whether to use a reference model for the alignment loss. ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size). ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,). + softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap). loss_kwargs (dict): Other possible arguments that a loss function might need """ # TODO: Tune CHUNK_SIZE to fully utilize the GPU @@ -286,6 +293,7 @@ def _compute_loss( use_ref_model=False, ref_weight=None, ref_bias=None, + softcap=None, **loss_kwargs, ): """ @@ -304,6 +312,7 @@ def _compute_loss( use_ref_model (bool): Whether to use a reference model for the alignment loss. ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size). ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,). + softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap). loss_kwargs (dict): Additional arguments for the loss function. 
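For context, the three lines added to chunk_forward implement standard tanh soft-capping: logits = softcap * tanh(logits / softcap). The transform is close to the identity for small logits and smoothly bounds every value to the open interval (-softcap, +softcap). A minimal standalone sketch of the equivalent computation (the soft_cap helper below is illustrative only, not part of the library):

    import torch

    def soft_cap(logits: torch.Tensor, softcap: float) -> torch.Tensor:
        # Same math as the three in-place steps added to chunk_forward:
        #   logits = logits / softcap; logits = torch.tanh(logits); logits = logits * softcap
        return softcap * torch.tanh(logits / softcap)

    x = torch.tensor([-100.0, -1.0, 0.0, 1.0, 100.0])
    print(soft_cap(x, 30.0))  # ~[-29.92, -1.00, 0.00, 1.00, 29.92]: bounded by +/-30, near-identity around 0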
""" ( @@ -319,6 +328,7 @@ def _compute_loss( bias=bias, ignore_index=ignore_index, compute_nll_loss=compute_nll_loss, + softcap=softcap ) chosen_nll_loss = ( chosen_nll_loss @@ -346,6 +356,7 @@ def _compute_loss( ref_bias, ignore_index=ignore_index, compute_nll_loss=False, # We don't need NLL loss for the reference model + softcap=softcap ) loss_kwargs["ref_chosen_logps"] = ref_chosen_logps loss_kwargs["ref_rejected_logps"] = ref_rejected_logps From f3b52318db47103a99f48b9910d35f37a9953786 Mon Sep 17 00:00:00 2001 From: ryankert01 Date: Sun, 8 Dec 2024 22:26:28 +0800 Subject: [PATCH 02/10] test scatch (debugging) --- .../test_fused_linear_preference.py | 214 ++++++++++++++++++ 1 file changed, 214 insertions(+) create mode 100644 test/transformers/test_fused_linear_preference.py diff --git a/test/transformers/test_fused_linear_preference.py b/test/transformers/test_fused_linear_preference.py new file mode 100644 index 000000000..64d5a334d --- /dev/null +++ b/test/transformers/test_fused_linear_preference.py @@ -0,0 +1,214 @@ +from test.utils import assert_verbose_allclose, set_seed +from typing import Optional + +import pytest +import torch +import torch.nn.functional as F + +from liger_kernel.chunked_loss.fused_linear_preference import ( + LigerFusedLinearPreferenceBase, +) +from liger_kernel.utils import infer_device + +device = infer_device() + +# set random seed globally +set_seed() + + +class TorchLMHeadPreference(torch.nn.Module): + """Ground truth implementation of the linear fused with torch based preference loss. + + :param H: hidden size + :param V: vocab size + :param bias: whether to use bias + :param beta: weight for the odds ratio loss + :param softcap: scaler for softcapping logits + """ + + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + bias: bool = False, + ignore_index: int = -100, + beta: float = 0.1, + softcap: Optional[float] = None, + ): + super().__init__() + self.lin = torch.nn.Linear( + in_features=H, out_features=V, bias=bias, dtype=dtype + ) + self.ignore_index = ignore_index + self.beta = beta + self.softcap = softcap + + def forward(self, x, target): + logits = self.lin(x).to(torch.float32) + if self.softcap is not None and self.softcap != 0.0: + logits = self.softcap * torch.tanh(logits / self.softcap) + + log_probs = F.log_softmax(logits, dim=-1) + + len_chosen = target.shape[0] // 2 + loss_mask = target != self.ignore_index + label = torch.where(loss_mask, target, 0) + + per_token_logps = log_probs.gather(-1, label.unsqueeze(-1)).squeeze(-1) + average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + + chosen_logps = average_log_prob[:len_chosen] + rejected_logps = average_log_prob[len_chosen:] + + # Simple preference loss + preference_loss = -self.beta * (chosen_logps - rejected_logps).mean() + + return preference_loss + + +class LigerLMHeadPreference(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + bias: bool = False, + ignore_index: int = -100, + beta: float = 0.1, + softcap: Optional[float] = None, + ): + super().__init__() + self.lin = torch.nn.Linear( + in_features=H, out_features=V, bias=bias, dtype=dtype + ) + self.ignore_index = ignore_index + self.beta = beta + self.softcap = softcap + + def forward(self, x, target): + def simple_preference_loss(chosen_logps, rejected_logps, target, beta=0.1): + return -beta * (chosen_logps - rejected_logps).mean() + + loss, *_ = LigerFusedLinearPreferenceBase.apply( + x, + self.lin.weight, + target, + self.lin.bias, + 
simple_preference_loss, + chunk_size=1, + ignore_index=self.ignore_index, + beta=self.beta, + compute_nll_loss=False, + compiled=True, + softcap=self.softcap, + ) + return loss + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (8, 128, 1024, 4096), + (4, 47, 31, 123), # random shape + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-3, 5e-2), + (1.0, torch.float32, 1e-5, 5e-4), + ], +) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize( + "ignore_index, beta, softcap", + [ + (-100, 0.1, None), + (42, 0.2, 30.0), # Pass non-default values to ensure all params work + ], +) +def test_correctness( + B, + T, + H, + V, + scalar, + dtype, + bias, + ignore_index, + beta, + softcap, + atol, + rtol, +): + torch_lm_head = TorchLMHeadPreference( + H=H, + V=V, + bias=bias, + ignore_index=ignore_index, + beta=beta, + softcap=softcap, + dtype=dtype, + ).to(device) + + liger_lm_head = LigerLMHeadPreference( + H=H, + V=V, + bias=bias, + ignore_index=ignore_index, + beta=beta, + softcap=softcap, + dtype=dtype, + ).to(device) + + # init the linear layers with the same weights + torch_lm_head.lin.weight.data = liger_lm_head.lin.weight.data = torch.rand( + V, H, device=device, dtype=dtype + ) + + if bias: + torch_lm_head.lin.bias.data = liger_lm_head.lin.bias.data = torch.rand( + V, device=device, dtype=dtype + ) + + # Create input tensors + _tensor = torch.randn(B * T * 2, H, device=device, dtype=dtype) * scalar # *2 for chosen/rejected pairs + _input1 = _tensor.detach().clone().requires_grad_(True) + _input2 = _tensor.detach().clone().requires_grad_(True) + + # Create target tensor + target = torch.randint(0, V, (B * T * 2,), device=device, dtype=torch.long) + + # Assign some random elements as ignore_index + num_elements_to_assign = torch.randint(1, B * T, (1,)).item() + indices_to_assign = torch.randperm(B * T * 2)[:num_elements_to_assign] + target[indices_to_assign] = ignore_index + + # Forward pass + output1 = torch_lm_head(_input1, target) + output2 = liger_lm_head(_input2, target) + + # Check outputs match + assert_verbose_allclose(output1, output2, atol=atol, rtol=rtol) + + # Backward pass + output1.backward() + output2.backward() + + # Check gradients match + assert_verbose_allclose(_input1.grad, _input2.grad, atol=atol, rtol=rtol) + assert_verbose_allclose( + torch_lm_head.lin.weight.grad, + liger_lm_head.lin.weight.grad, + atol=atol, + rtol=rtol, + ) + + if bias: + assert_verbose_allclose( + torch_lm_head.lin.bias.grad, + liger_lm_head.lin.bias.grad, + atol=atol, + rtol=rtol, + ) From be5e271500eeaedace7ddceae13a6c6e7666ba8c Mon Sep 17 00:00:00 2001 From: Ryan Date: Wed, 11 Dec 2024 15:28:17 +0800 Subject: [PATCH 03/10] move test to chunk loss folder --- .../test_fused_linear_preference.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename test/{transformers => chunked_loss}/test_fused_linear_preference.py (100%) diff --git a/test/transformers/test_fused_linear_preference.py b/test/chunked_loss/test_fused_linear_preference.py similarity index 100% rename from test/transformers/test_fused_linear_preference.py rename to test/chunked_loss/test_fused_linear_preference.py From 4bb5b3b3732b4b33688f924c944a067dbb02fd67 Mon Sep 17 00:00:00 2001 From: Ryan Date: Wed, 11 Dec 2024 17:28:09 +0800 Subject: [PATCH 04/10] apply softcap to all preference loss --- src/liger_kernel/chunked_loss/cpo_loss.py | 7 + src/liger_kernel/chunked_loss/dpo_loss.py | 7 + src/liger_kernel/chunked_loss/orpo_loss.py | 7 + 
src/liger_kernel/chunked_loss/simpo_loss.py | 7 + .../test_fused_linear_preference.py | 214 ------------------ 5 files changed, 28 insertions(+), 214 deletions(-) delete mode 100644 test/chunked_loss/test_fused_linear_preference.py diff --git a/src/liger_kernel/chunked_loss/cpo_loss.py b/src/liger_kernel/chunked_loss/cpo_loss.py index 4f68e0b16..2eec9aad8 100644 --- a/src/liger_kernel/chunked_loss/cpo_loss.py +++ b/src/liger_kernel/chunked_loss/cpo_loss.py @@ -1,5 +1,6 @@ import torch import torch.nn.functional as F +from typing import Optional from liger_kernel.chunked_loss.fused_linear_preference import ( LigerFusedLinearPreferenceBase, @@ -33,6 +34,7 @@ def forward( alpha=1.0, compute_nll_loss=True, compiled=True, + softcap=None, ): """ Fused linear layer with CPO (Odds-Ratio Preference Optimization) loss. @@ -52,6 +54,7 @@ def forward( beta=beta, compute_nll_loss=compute_nll_loss, compiled=compiled, + softcap=softcap ) @staticmethod @@ -74,11 +77,13 @@ def __init__( alpha: float = 1.0, compute_nll_loss: bool = True, compiled: bool = True, + softcap: Optional[float] = None ): """ Args: ignore_index (int): Index to ignore in the loss. beta (float): Weight for the odds ratio loss. + softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap). """ super().__init__() self.ignore_index = ignore_index @@ -86,6 +91,7 @@ def __init__( self.alpha = alpha self.compute_nll_loss = compute_nll_loss self.compiled = compiled + self.softcap = softcap def forward(self, lin_weight, _input, target, bias=None): return LigerFusedLinearCPOFunction.apply( @@ -98,4 +104,5 @@ def forward(self, lin_weight, _input, target, bias=None): self.alpha, self.compute_nll_loss, self.compiled, + self.softcap, ) diff --git a/src/liger_kernel/chunked_loss/dpo_loss.py b/src/liger_kernel/chunked_loss/dpo_loss.py index 9e41d38c5..40a8069ea 100644 --- a/src/liger_kernel/chunked_loss/dpo_loss.py +++ b/src/liger_kernel/chunked_loss/dpo_loss.py @@ -1,5 +1,6 @@ import torch import torch.nn.functional as F +from typing import Optional from liger_kernel.chunked_loss.fused_linear_preference import ( LigerFusedLinearPreferenceBase, @@ -52,6 +53,7 @@ def forward( compute_nll_loss=True, compiled=True, use_ref_model=True, + softcap=None ): """ Fused linear layer with DPO (Direct Preference Optimization) loss. @@ -71,6 +73,7 @@ def forward( use_ref_model=use_ref_model, ref_weight=ref_weight, ref_bias=ref_bias, + softcap=softcap, ) @staticmethod @@ -93,6 +96,7 @@ def __init__( compute_nll_loss: bool = True, compiled: bool = True, use_ref_model: bool = False, + softcap: Optional[float] = None ): """ Args: @@ -101,6 +105,7 @@ def __init__( compute_nll_loss (bool): Whether to compute the NLL loss. compiled (bool): Whether to use the torch compiled kernel. use_ref_model (bool): Whether to use a reference model for the DPO loss. + softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap). 
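For callers of the DPO wrapper, the only change is the new softcap keyword. A hedged usage sketch based on the signatures visible in this diff (shapes, sizes, and the softcap value of 30.0 are illustrative; as in the tests, the first half of the batch is treated as chosen and the second half as rejected):

    import torch
    from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss

    B, T, H, V = 2, 16, 64, 128
    weight = torch.randn(V, H, requires_grad=True)     # policy LM head
    ref_weight = torch.randn(V, H)                     # frozen reference LM head
    x = torch.randn(2 * B, T, H, requires_grad=True)   # chosen sequences first, rejected second
    target = torch.randint(0, V, (2 * B, T))

    dpo = LigerFusedLinearDPOLoss(beta=0.1, use_ref_model=True, softcap=30.0)
    out = dpo(weight, x, target, ref_weight=ref_weight)
    loss = out[0] if isinstance(out, tuple) else out   # loss comes first; aux outputs may follow
    loss.backward()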
""" super().__init__() self.ignore_index = ignore_index @@ -108,6 +113,7 @@ def __init__( self.compute_nll_loss = compute_nll_loss self.compiled = compiled self.use_ref_model = use_ref_model + self.softcap = softcap def forward( self, lin_weight, _input, target, bias=None, ref_weight=None, ref_bias=None @@ -124,4 +130,5 @@ def forward( self.compute_nll_loss, self.compiled, self.use_ref_model, + self.softcap, ) diff --git a/src/liger_kernel/chunked_loss/orpo_loss.py b/src/liger_kernel/chunked_loss/orpo_loss.py index 9e7caec19..f67799a1a 100644 --- a/src/liger_kernel/chunked_loss/orpo_loss.py +++ b/src/liger_kernel/chunked_loss/orpo_loss.py @@ -1,5 +1,6 @@ import torch import torch.nn.functional as F +from typing import Optional from liger_kernel.chunked_loss.fused_linear_preference import ( LigerFusedLinearPreferenceBase, @@ -43,6 +44,7 @@ def forward( beta=0.1, compute_nll_loss=True, compiled=True, + softcap=None ): """ Fused linear layer with ORPO (Odds-Ratio Preference Optimization) loss. @@ -61,6 +63,7 @@ def forward( beta=beta, compute_nll_loss=compute_nll_loss, compiled=compiled, + softcap=softcap, ) @staticmethod @@ -82,17 +85,20 @@ def __init__( beta: float = 0.1, compute_nll_loss: bool = True, compiled: bool = True, + softcap: Optional[float] = None ): """ Args: ignore_index (int): Index to ignore in the loss. beta (float): Weight for the odds ratio loss. + softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap). """ super().__init__() self.ignore_index = ignore_index self.beta = beta self.compute_nll_loss = compute_nll_loss self.compiled = compiled + self.softcap = softcap def forward(self, lin_weight, _input, target, bias=None): return LigerFusedLinearORPOFunction.apply( @@ -104,4 +110,5 @@ def forward(self, lin_weight, _input, target, bias=None): self.beta, self.compute_nll_loss, self.compiled, + self.softcap, ) diff --git a/src/liger_kernel/chunked_loss/simpo_loss.py b/src/liger_kernel/chunked_loss/simpo_loss.py index c9c1459d6..e3cab4493 100644 --- a/src/liger_kernel/chunked_loss/simpo_loss.py +++ b/src/liger_kernel/chunked_loss/simpo_loss.py @@ -1,5 +1,6 @@ import torch import torch.nn.functional as F +from typing import Optional from liger_kernel.chunked_loss.fused_linear_preference import ( LigerFusedLinearPreferenceBase, @@ -37,6 +38,7 @@ def forward( compute_nll_loss=False, compiled=True, gamma=0.5, + softcap=None, ): """ Fused linear layer with SimPO (Simple Preference Optimization) loss. https://arxiv.org/pdf/2405.14734 @@ -57,6 +59,7 @@ def forward( beta=beta, compiled=compiled, gamma=gamma, + softcap=softcap ) @staticmethod @@ -80,11 +83,13 @@ def __init__( compute_nll_loss: bool = True, compiled: bool = True, gamma: float = 0.5, + softcap: Optional[float] = None ): """ Args: ignore_index (int): Index to ignore in the loss. beta (float): Weight for the odds ratio loss. + softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap). 
""" super().__init__() self.ignore_index = ignore_index @@ -93,6 +98,7 @@ def __init__( self.compute_nll_loss = compute_nll_loss self.compiled = compiled self.gamma = gamma + self.softcap = softcap def forward(self, lin_weight, _input, target, bias=None): return LigerFusedLinearSimPOFunction.apply( @@ -106,4 +112,5 @@ def forward(self, lin_weight, _input, target, bias=None): self.compute_nll_loss, self.compiled, self.gamma, + self.softcap, ) diff --git a/test/chunked_loss/test_fused_linear_preference.py b/test/chunked_loss/test_fused_linear_preference.py deleted file mode 100644 index 64d5a334d..000000000 --- a/test/chunked_loss/test_fused_linear_preference.py +++ /dev/null @@ -1,214 +0,0 @@ -from test.utils import assert_verbose_allclose, set_seed -from typing import Optional - -import pytest -import torch -import torch.nn.functional as F - -from liger_kernel.chunked_loss.fused_linear_preference import ( - LigerFusedLinearPreferenceBase, -) -from liger_kernel.utils import infer_device - -device = infer_device() - -# set random seed globally -set_seed() - - -class TorchLMHeadPreference(torch.nn.Module): - """Ground truth implementation of the linear fused with torch based preference loss. - - :param H: hidden size - :param V: vocab size - :param bias: whether to use bias - :param beta: weight for the odds ratio loss - :param softcap: scaler for softcapping logits - """ - - def __init__( - self, - H: int, - V: int, - dtype: torch.dtype, - bias: bool = False, - ignore_index: int = -100, - beta: float = 0.1, - softcap: Optional[float] = None, - ): - super().__init__() - self.lin = torch.nn.Linear( - in_features=H, out_features=V, bias=bias, dtype=dtype - ) - self.ignore_index = ignore_index - self.beta = beta - self.softcap = softcap - - def forward(self, x, target): - logits = self.lin(x).to(torch.float32) - if self.softcap is not None and self.softcap != 0.0: - logits = self.softcap * torch.tanh(logits / self.softcap) - - log_probs = F.log_softmax(logits, dim=-1) - - len_chosen = target.shape[0] // 2 - loss_mask = target != self.ignore_index - label = torch.where(loss_mask, target, 0) - - per_token_logps = log_probs.gather(-1, label.unsqueeze(-1)).squeeze(-1) - average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) - - chosen_logps = average_log_prob[:len_chosen] - rejected_logps = average_log_prob[len_chosen:] - - # Simple preference loss - preference_loss = -self.beta * (chosen_logps - rejected_logps).mean() - - return preference_loss - - -class LigerLMHeadPreference(torch.nn.Module): - def __init__( - self, - H: int, - V: int, - dtype: torch.dtype, - bias: bool = False, - ignore_index: int = -100, - beta: float = 0.1, - softcap: Optional[float] = None, - ): - super().__init__() - self.lin = torch.nn.Linear( - in_features=H, out_features=V, bias=bias, dtype=dtype - ) - self.ignore_index = ignore_index - self.beta = beta - self.softcap = softcap - - def forward(self, x, target): - def simple_preference_loss(chosen_logps, rejected_logps, target, beta=0.1): - return -beta * (chosen_logps - rejected_logps).mean() - - loss, *_ = LigerFusedLinearPreferenceBase.apply( - x, - self.lin.weight, - target, - self.lin.bias, - simple_preference_loss, - chunk_size=1, - ignore_index=self.ignore_index, - beta=self.beta, - compute_nll_loss=False, - compiled=True, - softcap=self.softcap, - ) - return loss - - -@pytest.mark.parametrize( - "B, T, H, V", - [ - (8, 128, 1024, 4096), - (4, 47, 31, 123), # random shape - ], -) -@pytest.mark.parametrize( - "scalar, dtype, atol, rtol", - [ - 
(1.0, torch.bfloat16, 5e-3, 5e-2), - (1.0, torch.float32, 1e-5, 5e-4), - ], -) -@pytest.mark.parametrize("bias", [True, False]) -@pytest.mark.parametrize( - "ignore_index, beta, softcap", - [ - (-100, 0.1, None), - (42, 0.2, 30.0), # Pass non-default values to ensure all params work - ], -) -def test_correctness( - B, - T, - H, - V, - scalar, - dtype, - bias, - ignore_index, - beta, - softcap, - atol, - rtol, -): - torch_lm_head = TorchLMHeadPreference( - H=H, - V=V, - bias=bias, - ignore_index=ignore_index, - beta=beta, - softcap=softcap, - dtype=dtype, - ).to(device) - - liger_lm_head = LigerLMHeadPreference( - H=H, - V=V, - bias=bias, - ignore_index=ignore_index, - beta=beta, - softcap=softcap, - dtype=dtype, - ).to(device) - - # init the linear layers with the same weights - torch_lm_head.lin.weight.data = liger_lm_head.lin.weight.data = torch.rand( - V, H, device=device, dtype=dtype - ) - - if bias: - torch_lm_head.lin.bias.data = liger_lm_head.lin.bias.data = torch.rand( - V, device=device, dtype=dtype - ) - - # Create input tensors - _tensor = torch.randn(B * T * 2, H, device=device, dtype=dtype) * scalar # *2 for chosen/rejected pairs - _input1 = _tensor.detach().clone().requires_grad_(True) - _input2 = _tensor.detach().clone().requires_grad_(True) - - # Create target tensor - target = torch.randint(0, V, (B * T * 2,), device=device, dtype=torch.long) - - # Assign some random elements as ignore_index - num_elements_to_assign = torch.randint(1, B * T, (1,)).item() - indices_to_assign = torch.randperm(B * T * 2)[:num_elements_to_assign] - target[indices_to_assign] = ignore_index - - # Forward pass - output1 = torch_lm_head(_input1, target) - output2 = liger_lm_head(_input2, target) - - # Check outputs match - assert_verbose_allclose(output1, output2, atol=atol, rtol=rtol) - - # Backward pass - output1.backward() - output2.backward() - - # Check gradients match - assert_verbose_allclose(_input1.grad, _input2.grad, atol=atol, rtol=rtol) - assert_verbose_allclose( - torch_lm_head.lin.weight.grad, - liger_lm_head.lin.weight.grad, - atol=atol, - rtol=rtol, - ) - - if bias: - assert_verbose_allclose( - torch_lm_head.lin.bias.grad, - liger_lm_head.lin.bias.grad, - atol=atol, - rtol=rtol, - ) From cb11e26124f7bc1430be35436da5404f59264cda Mon Sep 17 00:00:00 2001 From: ryankert01 Date: Thu, 12 Dec 2024 10:59:40 +0800 Subject: [PATCH 05/10] fix inconsistant gradients --- log | 1991 +++++++++++++++++++ src/liger_kernel/chunked_loss/cpo_loss.py | 2 +- src/liger_kernel/chunked_loss/dpo_loss.py | 2 +- src/liger_kernel/chunked_loss/orpo_loss.py | 2 +- src/liger_kernel/chunked_loss/simpo_loss.py | 2 +- 5 files changed, 1995 insertions(+), 4 deletions(-) create mode 100644 log diff --git a/log b/log new file mode 100644 index 000000000..da042a6a0 --- /dev/null +++ b/log @@ -0,0 +1,1991 @@ +============================= test session starts ============================== +platform linux -- Python 3.12.7, pytest-8.3.4, pluggy-1.5.0 +rootdir: /home/ryan/Documents/GitHub/Liger-Kernel +configfile: pyproject.toml +collected 24 items + +test/chunked_loss/test_cpo_loss.py::test_correctness[-100-0.1-1.0-True-1.0-dtype0-0.005-0.005-8-128-1024-4096] FAILED [ 4%] +test/chunked_loss/test_cpo_loss.py::test_correctness[-100-0.1-1.0-True-1.0-dtype0-0.005-0.005-3-47-31-123] FAILED [ 8%] +test/chunked_loss/test_cpo_loss.py::test_correctness[-100-0.1-1.0-True-1.0-dtype1-1e-05-0.0005-8-128-1024-4096] FAILED [ 12%] +test/chunked_loss/test_cpo_loss.py::test_correctness[-100-0.1-1.0-True-1.0-dtype1-1e-05-0.0005-3-47-31-123] 
FAILED [ 16%] +test/chunked_loss/test_cpo_loss.py::test_correctness[-100-0.1-1.0-False-1.0-dtype0-0.005-0.005-8-128-1024-4096] FAILED [ 20%] +test/chunked_loss/test_cpo_loss.py::test_correctness[-100-0.1-1.0-False-1.0-dtype0-0.005-0.005-3-47-31-123] FAILED [ 25%] +test/chunked_loss/test_cpo_loss.py::test_correctness[-100-0.1-1.0-False-1.0-dtype1-1e-05-0.0005-8-128-1024-4096] FAILED [ 29%] +test/chunked_loss/test_cpo_loss.py::test_correctness[-100-0.1-1.0-False-1.0-dtype1-1e-05-0.0005-3-47-31-123] FAILED [ 33%] +test/chunked_loss/test_cpo_loss.py::test_correctness[42-0.2-0.85-True-1.0-dtype0-0.005-0.005-8-128-1024-4096] FAILED [ 37%] +test/chunked_loss/test_cpo_loss.py::test_correctness[42-0.2-0.85-True-1.0-dtype0-0.005-0.005-3-47-31-123] FAILED [ 41%] +test/chunked_loss/test_cpo_loss.py::test_correctness[42-0.2-0.85-True-1.0-dtype1-1e-05-0.0005-8-128-1024-4096] FAILED [ 45%] +test/chunked_loss/test_cpo_loss.py::test_correctness[42-0.2-0.85-True-1.0-dtype1-1e-05-0.0005-3-47-31-123] FAILED [ 50%] +test/chunked_loss/test_cpo_loss.py::test_correctness[42-0.2-0.85-False-1.0-dtype0-0.005-0.005-8-128-1024-4096] FAILED [ 54%] +test/chunked_loss/test_cpo_loss.py::test_correctness[42-0.2-0.85-False-1.0-dtype0-0.005-0.005-3-47-31-123] FAILED [ 58%] +test/chunked_loss/test_cpo_loss.py::test_correctness[42-0.2-0.85-False-1.0-dtype1-1e-05-0.0005-8-128-1024-4096] FAILED [ 62%] +test/chunked_loss/test_cpo_loss.py::test_correctness[42-0.2-0.85-False-1.0-dtype1-1e-05-0.0005-3-47-31-123] FAILED [ 66%] +test/chunked_loss/test_cpo_loss.py::test_correctness_functional[True-1.0-dtype0-0.05-0.5-2-2-8-8] PASSED [ 70%] +test/chunked_loss/test_cpo_loss.py::test_correctness_functional[True-1.0-dtype0-0.05-0.5-3-47-31-123] PASSED [ 75%] +test/chunked_loss/test_cpo_loss.py::test_correctness_functional[True-1.0-dtype1-1e-05-0.0005-2-2-8-8] PASSED [ 79%] +test/chunked_loss/test_cpo_loss.py::test_correctness_functional[True-1.0-dtype1-1e-05-0.0005-3-47-31-123] PASSED [ 83%] +test/chunked_loss/test_cpo_loss.py::test_correctness_functional[False-1.0-dtype0-0.05-0.5-2-2-8-8] PASSED [ 87%] +test/chunked_loss/test_cpo_loss.py::test_correctness_functional[False-1.0-dtype0-0.05-0.5-3-47-31-123] PASSED [ 91%] +test/chunked_loss/test_cpo_loss.py::test_correctness_functional[False-1.0-dtype1-1e-05-0.0005-2-2-8-8] PASSED [ 95%] +test/chunked_loss/test_cpo_loss.py::test_correctness_functional[False-1.0-dtype1-1e-05-0.0005-3-47-31-123] PASSED [100%] + +=================================== FAILURES =================================== +__ test_correctness[-100-0.1-1.0-True-1.0-dtype0-0.005-0.005-8-128-1024-4096] __ + +B = 16, T = 128, H = 1024, V = 4096, scalar = 1.0, dtype = torch.bfloat16 +atol = 0.005, rtol = 0.005, bias = True, ignore_index = -100, beta = 0.1 +alpha = 1.0 + + @pytest.mark.parametrize( + "B, T, H, V", + [ + (8, 128, 1024, 4096), + (3, 47, 31, 123), # random shape + ], + ) + @pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-3, 5e-3), + (1.0, torch.float32, 1e-5, 5e-4), + ], + ) + @pytest.mark.parametrize("bias", [True, False]) + @pytest.mark.parametrize( + "ignore_index, beta, alpha", [(-100, 0.1, 1.0), (42, 0.2, 0.85)] + ) + def test_correctness( + B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, alpha + ): + B = 2 * B # cpo loss requires B to be even + + torch_lm_head_cpo = TorchLMHeadCPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + ) + liger_lm_head_cpo = LigerLMHeadCPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + 
ignore_index=ignore_index, + beta=beta, + ) + + torch_lm_head_cpo.lin.weight.data = liger_lm_head_cpo.lin.weight.data = torch.randn( + V, H, device=device, dtype=dtype + ) + + if bias: + torch_lm_head_cpo.lin.bias.data = liger_lm_head_cpo.lin.bias.data = torch.randn( + V, device=device, dtype=dtype + ) + + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + target = torch.randint( + 0, + V, + ( + B, + T, + ), + device=device, + dtype=torch.long, + ) + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + target.view(-1)[indices_to_assign] = ignore_index + + loss1, aggregated_aux_outputs1 = torch_lm_head_cpo(input1, target) + loss2, aggregated_aux_outputs2 = liger_lm_head_cpo(input2, target) + + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + assert len(aggregated_aux_outputs1) == len(aggregated_aux_outputs2) + + for i in range(len(aggregated_aux_outputs1)): + assert_verbose_allclose( + aggregated_aux_outputs1[i], + aggregated_aux_outputs2[i], + atol=atol, + rtol=rtol, + ) + + loss1.backward() +> loss2.backward() + +test/chunked_loss/test_cpo_loss.py:214: +_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ +../../../miniconda3/lib/python3.12/site-packages/torch/_tensor.py:581: in backward + torch.autograd.backward( +../../../miniconda3/lib/python3.12/site-packages/torch/autograd/__init__.py:347: in backward + _engine_run_backward( +_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ + +t_outputs = (tensor(116.6298, device='cuda:0', + grad_fn=),) +args = ((tensor(1., device='cuda:0'),), False, False, ()) +kwargs = {'accumulate_grad': True, 'allow_unreachable': True} +attach_logging_hooks = False + + def _engine_run_backward( + t_outputs: Sequence[Union[torch.Tensor, GradientEdge]], + *args: Any, + **kwargs: Any, + ) -> Tuple[torch.Tensor, ...]: + attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG + if attach_logging_hooks: + unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs) + try: +> return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass + t_outputs, *args, **kwargs + ) # Calls into the C++ engine to run the backward pass +E RuntimeError: function LigerFusedLinearCPOFunctionBackward returned an incorrect number of gradients (expected 10, got 9) + +../../../miniconda3/lib/python3.12/site-packages/torch/autograd/graph.py:825: RuntimeError +____ test_correctness[-100-0.1-1.0-True-1.0-dtype0-0.005-0.005-3-47-31-123] ____ + +B = 6, T = 47, H = 31, V = 123, scalar = 1.0, dtype = torch.bfloat16 +atol = 0.005, rtol = 0.005, bias = True, ignore_index = -100, beta = 0.1 +alpha = 1.0 + + @pytest.mark.parametrize( + "B, T, H, V", + [ + (8, 128, 1024, 4096), + (3, 47, 31, 123), # random shape + ], + ) + @pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-3, 5e-3), + (1.0, torch.float32, 1e-5, 5e-4), + ], + ) + @pytest.mark.parametrize("bias", [True, False]) + @pytest.mark.parametrize( + "ignore_index, beta, alpha", [(-100, 0.1, 1.0), (42, 0.2, 0.85)] + ) + def test_correctness( + B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, alpha + ): + B = 2 * B # cpo loss requires B to be even + + torch_lm_head_cpo = TorchLMHeadCPO( + H=H, + V=V, + 
dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + ) + liger_lm_head_cpo = LigerLMHeadCPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + ) + + torch_lm_head_cpo.lin.weight.data = liger_lm_head_cpo.lin.weight.data = torch.randn( + V, H, device=device, dtype=dtype + ) + + if bias: + torch_lm_head_cpo.lin.bias.data = liger_lm_head_cpo.lin.bias.data = torch.randn( + V, device=device, dtype=dtype + ) + + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + target = torch.randint( + 0, + V, + ( + B, + T, + ), + device=device, + dtype=torch.long, + ) + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + target.view(-1)[indices_to_assign] = ignore_index + + loss1, aggregated_aux_outputs1 = torch_lm_head_cpo(input1, target) + loss2, aggregated_aux_outputs2 = liger_lm_head_cpo(input2, target) + + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + assert len(aggregated_aux_outputs1) == len(aggregated_aux_outputs2) + + for i in range(len(aggregated_aux_outputs1)): + assert_verbose_allclose( + aggregated_aux_outputs1[i], + aggregated_aux_outputs2[i], + atol=atol, + rtol=rtol, + ) + + loss1.backward() +> loss2.backward() + +test/chunked_loss/test_cpo_loss.py:214: +_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ +../../../miniconda3/lib/python3.12/site-packages/torch/_tensor.py:581: in backward + torch.autograd.backward( +../../../miniconda3/lib/python3.12/site-packages/torch/autograd/__init__.py:347: in backward + _engine_run_backward( +_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ + +t_outputs = (tensor(16.4326, device='cuda:0', grad_fn=),) +args = ((tensor(1., device='cuda:0'),), False, False, ()) +kwargs = {'accumulate_grad': True, 'allow_unreachable': True} +attach_logging_hooks = False + + def _engine_run_backward( + t_outputs: Sequence[Union[torch.Tensor, GradientEdge]], + *args: Any, + **kwargs: Any, + ) -> Tuple[torch.Tensor, ...]: + attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG + if attach_logging_hooks: + unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs) + try: +> return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass + t_outputs, *args, **kwargs + ) # Calls into the C++ engine to run the backward pass +E RuntimeError: function LigerFusedLinearCPOFunctionBackward returned an incorrect number of gradients (expected 10, got 9) + +../../../miniconda3/lib/python3.12/site-packages/torch/autograd/graph.py:825: RuntimeError +_ test_correctness[-100-0.1-1.0-True-1.0-dtype1-1e-05-0.0005-8-128-1024-4096] __ + +B = 16, T = 128, H = 1024, V = 4096, scalar = 1.0, dtype = torch.float32 +atol = 1e-05, rtol = 0.0005, bias = True, ignore_index = -100, beta = 0.1 +alpha = 1.0 + + @pytest.mark.parametrize( + "B, T, H, V", + [ + (8, 128, 1024, 4096), + (3, 47, 31, 123), # random shape + ], + ) + @pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-3, 5e-3), + (1.0, torch.float32, 1e-5, 5e-4), + ], + ) + @pytest.mark.parametrize("bias", [True, False]) + @pytest.mark.parametrize( + "ignore_index, beta, alpha", [(-100, 0.1, 1.0), (42, 0.2, 0.85)] + ) + def test_correctness( + B, T, H, V, scalar, 
dtype, atol, rtol, bias, ignore_index, beta, alpha + ): + B = 2 * B # cpo loss requires B to be even + + torch_lm_head_cpo = TorchLMHeadCPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + ) + liger_lm_head_cpo = LigerLMHeadCPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + ) + + torch_lm_head_cpo.lin.weight.data = liger_lm_head_cpo.lin.weight.data = torch.randn( + V, H, device=device, dtype=dtype + ) + + if bias: + torch_lm_head_cpo.lin.bias.data = liger_lm_head_cpo.lin.bias.data = torch.randn( + V, device=device, dtype=dtype + ) + + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + target = torch.randint( + 0, + V, + ( + B, + T, + ), + device=device, + dtype=torch.long, + ) + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + target.view(-1)[indices_to_assign] = ignore_index + + loss1, aggregated_aux_outputs1 = torch_lm_head_cpo(input1, target) + loss2, aggregated_aux_outputs2 = liger_lm_head_cpo(input2, target) + + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + assert len(aggregated_aux_outputs1) == len(aggregated_aux_outputs2) + + for i in range(len(aggregated_aux_outputs1)): + assert_verbose_allclose( + aggregated_aux_outputs1[i], + aggregated_aux_outputs2[i], + atol=atol, + rtol=rtol, + ) + + loss1.backward() +> loss2.backward() + +test/chunked_loss/test_cpo_loss.py:214: +_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ +../../../miniconda3/lib/python3.12/site-packages/torch/_tensor.py:581: in backward + torch.autograd.backward( +../../../miniconda3/lib/python3.12/site-packages/torch/autograd/__init__.py:347: in backward + _engine_run_backward( +_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ + +t_outputs = (tensor(116.1180, device='cuda:0', + grad_fn=),) +args = ((tensor(1., device='cuda:0'),), False, False, ()) +kwargs = {'accumulate_grad': True, 'allow_unreachable': True} +attach_logging_hooks = False + + def _engine_run_backward( + t_outputs: Sequence[Union[torch.Tensor, GradientEdge]], + *args: Any, + **kwargs: Any, + ) -> Tuple[torch.Tensor, ...]: + attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG + if attach_logging_hooks: + unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs) + try: +> return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass + t_outputs, *args, **kwargs + ) # Calls into the C++ engine to run the backward pass +E RuntimeError: function LigerFusedLinearCPOFunctionBackward returned an incorrect number of gradients (expected 10, got 9) + +../../../miniconda3/lib/python3.12/site-packages/torch/autograd/graph.py:825: RuntimeError +___ test_correctness[-100-0.1-1.0-True-1.0-dtype1-1e-05-0.0005-3-47-31-123] ____ + +B = 6, T = 47, H = 31, V = 123, scalar = 1.0, dtype = torch.float32 +atol = 1e-05, rtol = 0.0005, bias = True, ignore_index = -100, beta = 0.1 +alpha = 1.0 + + @pytest.mark.parametrize( + "B, T, H, V", + [ + (8, 128, 1024, 4096), + (3, 47, 31, 123), # random shape + ], + ) + @pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-3, 5e-3), + (1.0, torch.float32, 1e-5, 5e-4), + ], + ) + @pytest.mark.parametrize("bias", 
[True, False]) + @pytest.mark.parametrize( + "ignore_index, beta, alpha", [(-100, 0.1, 1.0), (42, 0.2, 0.85)] + ) + def test_correctness( + B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, alpha + ): + B = 2 * B # cpo loss requires B to be even + + torch_lm_head_cpo = TorchLMHeadCPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + ) + liger_lm_head_cpo = LigerLMHeadCPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + ) + + torch_lm_head_cpo.lin.weight.data = liger_lm_head_cpo.lin.weight.data = torch.randn( + V, H, device=device, dtype=dtype + ) + + if bias: + torch_lm_head_cpo.lin.bias.data = liger_lm_head_cpo.lin.bias.data = torch.randn( + V, device=device, dtype=dtype + ) + + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + target = torch.randint( + 0, + V, + ( + B, + T, + ), + device=device, + dtype=torch.long, + ) + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + target.view(-1)[indices_to_assign] = ignore_index + + loss1, aggregated_aux_outputs1 = torch_lm_head_cpo(input1, target) + loss2, aggregated_aux_outputs2 = liger_lm_head_cpo(input2, target) + + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + assert len(aggregated_aux_outputs1) == len(aggregated_aux_outputs2) + + for i in range(len(aggregated_aux_outputs1)): + assert_verbose_allclose( + aggregated_aux_outputs1[i], + aggregated_aux_outputs2[i], + atol=atol, + rtol=rtol, + ) + + loss1.backward() +> loss2.backward() + +test/chunked_loss/test_cpo_loss.py:214: +_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ +../../../miniconda3/lib/python3.12/site-packages/torch/_tensor.py:581: in backward + torch.autograd.backward( +../../../miniconda3/lib/python3.12/site-packages/torch/autograd/__init__.py:347: in backward + _engine_run_backward( +_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ + +t_outputs = (tensor(16.4060, device='cuda:0', grad_fn=),) +args = ((tensor(1., device='cuda:0'),), False, False, ()) +kwargs = {'accumulate_grad': True, 'allow_unreachable': True} +attach_logging_hooks = False + + def _engine_run_backward( + t_outputs: Sequence[Union[torch.Tensor, GradientEdge]], + *args: Any, + **kwargs: Any, + ) -> Tuple[torch.Tensor, ...]: + attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG + if attach_logging_hooks: + unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs) + try: +> return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass + t_outputs, *args, **kwargs + ) # Calls into the C++ engine to run the backward pass +E RuntimeError: function LigerFusedLinearCPOFunctionBackward returned an incorrect number of gradients (expected 10, got 9) + +../../../miniconda3/lib/python3.12/site-packages/torch/autograd/graph.py:825: RuntimeError +_ test_correctness[-100-0.1-1.0-False-1.0-dtype0-0.005-0.005-8-128-1024-4096] __ + +B = 16, T = 128, H = 1024, V = 4096, scalar = 1.0, dtype = torch.bfloat16 +atol = 0.005, rtol = 0.005, bias = False, ignore_index = -100, beta = 0.1 +alpha = 1.0 + + @pytest.mark.parametrize( + "B, T, H, V", + [ + (8, 128, 1024, 4096), + (3, 47, 31, 123), # random shape + ], + ) + 
@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-3, 5e-3), + (1.0, torch.float32, 1e-5, 5e-4), + ], + ) + @pytest.mark.parametrize("bias", [True, False]) + @pytest.mark.parametrize( + "ignore_index, beta, alpha", [(-100, 0.1, 1.0), (42, 0.2, 0.85)] + ) + def test_correctness( + B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, alpha + ): + B = 2 * B # cpo loss requires B to be even + + torch_lm_head_cpo = TorchLMHeadCPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + ) + liger_lm_head_cpo = LigerLMHeadCPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + ) + + torch_lm_head_cpo.lin.weight.data = liger_lm_head_cpo.lin.weight.data = torch.randn( + V, H, device=device, dtype=dtype + ) + + if bias: + torch_lm_head_cpo.lin.bias.data = liger_lm_head_cpo.lin.bias.data = torch.randn( + V, device=device, dtype=dtype + ) + + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + target = torch.randint( + 0, + V, + ( + B, + T, + ), + device=device, + dtype=torch.long, + ) + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + target.view(-1)[indices_to_assign] = ignore_index + + loss1, aggregated_aux_outputs1 = torch_lm_head_cpo(input1, target) + loss2, aggregated_aux_outputs2 = liger_lm_head_cpo(input2, target) + + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + assert len(aggregated_aux_outputs1) == len(aggregated_aux_outputs2) + + for i in range(len(aggregated_aux_outputs1)): + assert_verbose_allclose( + aggregated_aux_outputs1[i], + aggregated_aux_outputs2[i], + atol=atol, + rtol=rtol, + ) + + loss1.backward() +> loss2.backward() + +test/chunked_loss/test_cpo_loss.py:214: +_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ +../../../miniconda3/lib/python3.12/site-packages/torch/_tensor.py:581: in backward + torch.autograd.backward( +../../../miniconda3/lib/python3.12/site-packages/torch/autograd/__init__.py:347: in backward + _engine_run_backward( +_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ + +t_outputs = (tensor(119.2628, device='cuda:0', + grad_fn=),) +args = ((tensor(1., device='cuda:0'),), False, False, ()) +kwargs = {'accumulate_grad': True, 'allow_unreachable': True} +attach_logging_hooks = False + + def _engine_run_backward( + t_outputs: Sequence[Union[torch.Tensor, GradientEdge]], + *args: Any, + **kwargs: Any, + ) -> Tuple[torch.Tensor, ...]: + attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG + if attach_logging_hooks: + unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs) + try: +> return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass + t_outputs, *args, **kwargs + ) # Calls into the C++ engine to run the backward pass +E RuntimeError: function LigerFusedLinearCPOFunctionBackward returned an incorrect number of gradients (expected 10, got 9) + +../../../miniconda3/lib/python3.12/site-packages/torch/autograd/graph.py:825: RuntimeError +___ test_correctness[-100-0.1-1.0-False-1.0-dtype0-0.005-0.005-3-47-31-123] ____ + +B = 6, T = 47, H = 31, V = 123, scalar = 1.0, dtype = torch.bfloat16 +atol = 0.005, rtol = 0.005, bias = False, 
ignore_index = -100, beta = 0.1 +alpha = 1.0 + + @pytest.mark.parametrize( + "B, T, H, V", + [ + (8, 128, 1024, 4096), + (3, 47, 31, 123), # random shape + ], + ) + @pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-3, 5e-3), + (1.0, torch.float32, 1e-5, 5e-4), + ], + ) + @pytest.mark.parametrize("bias", [True, False]) + @pytest.mark.parametrize( + "ignore_index, beta, alpha", [(-100, 0.1, 1.0), (42, 0.2, 0.85)] + ) + def test_correctness( + B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, alpha + ): + B = 2 * B # cpo loss requires B to be even + + torch_lm_head_cpo = TorchLMHeadCPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + ) + liger_lm_head_cpo = LigerLMHeadCPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + ) + + torch_lm_head_cpo.lin.weight.data = liger_lm_head_cpo.lin.weight.data = torch.randn( + V, H, device=device, dtype=dtype + ) + + if bias: + torch_lm_head_cpo.lin.bias.data = liger_lm_head_cpo.lin.bias.data = torch.randn( + V, device=device, dtype=dtype + ) + + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + target = torch.randint( + 0, + V, + ( + B, + T, + ), + device=device, + dtype=torch.long, + ) + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + target.view(-1)[indices_to_assign] = ignore_index + + loss1, aggregated_aux_outputs1 = torch_lm_head_cpo(input1, target) + loss2, aggregated_aux_outputs2 = liger_lm_head_cpo(input2, target) + + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + assert len(aggregated_aux_outputs1) == len(aggregated_aux_outputs2) + + for i in range(len(aggregated_aux_outputs1)): + assert_verbose_allclose( + aggregated_aux_outputs1[i], + aggregated_aux_outputs2[i], + atol=atol, + rtol=rtol, + ) + + loss1.backward() +> loss2.backward() + +test/chunked_loss/test_cpo_loss.py:214: +_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ +../../../miniconda3/lib/python3.12/site-packages/torch/_tensor.py:581: in backward + torch.autograd.backward( +../../../miniconda3/lib/python3.12/site-packages/torch/autograd/__init__.py:347: in backward + _engine_run_backward( +_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ + +t_outputs = (tensor(15.0560, device='cuda:0', grad_fn=),) +args = ((tensor(1., device='cuda:0'),), False, False, ()) +kwargs = {'accumulate_grad': True, 'allow_unreachable': True} +attach_logging_hooks = False + + def _engine_run_backward( + t_outputs: Sequence[Union[torch.Tensor, GradientEdge]], + *args: Any, + **kwargs: Any, + ) -> Tuple[torch.Tensor, ...]: + attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG + if attach_logging_hooks: + unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs) + try: +> return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass + t_outputs, *args, **kwargs + ) # Calls into the C++ engine to run the backward pass +E RuntimeError: function LigerFusedLinearCPOFunctionBackward returned an incorrect number of gradients (expected 10, got 9) + +../../../miniconda3/lib/python3.12/site-packages/torch/autograd/graph.py:825: RuntimeError +_ 
test_correctness[-100-0.1-1.0-False-1.0-dtype1-1e-05-0.0005-8-128-1024-4096] _ + +B = 16, T = 128, H = 1024, V = 4096, scalar = 1.0, dtype = torch.float32 +atol = 1e-05, rtol = 0.0005, bias = False, ignore_index = -100, beta = 0.1 +alpha = 1.0 + + @pytest.mark.parametrize( + "B, T, H, V", + [ + (8, 128, 1024, 4096), + (3, 47, 31, 123), # random shape + ], + ) + @pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-3, 5e-3), + (1.0, torch.float32, 1e-5, 5e-4), + ], + ) + @pytest.mark.parametrize("bias", [True, False]) + @pytest.mark.parametrize( + "ignore_index, beta, alpha", [(-100, 0.1, 1.0), (42, 0.2, 0.85)] + ) + def test_correctness( + B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, alpha + ): + B = 2 * B # cpo loss requires B to be even + + torch_lm_head_cpo = TorchLMHeadCPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + ) + liger_lm_head_cpo = LigerLMHeadCPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + ) + + torch_lm_head_cpo.lin.weight.data = liger_lm_head_cpo.lin.weight.data = torch.randn( + V, H, device=device, dtype=dtype + ) + + if bias: + torch_lm_head_cpo.lin.bias.data = liger_lm_head_cpo.lin.bias.data = torch.randn( + V, device=device, dtype=dtype + ) + + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + target = torch.randint( + 0, + V, + ( + B, + T, + ), + device=device, + dtype=torch.long, + ) + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + target.view(-1)[indices_to_assign] = ignore_index + + loss1, aggregated_aux_outputs1 = torch_lm_head_cpo(input1, target) + loss2, aggregated_aux_outputs2 = liger_lm_head_cpo(input2, target) + + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + assert len(aggregated_aux_outputs1) == len(aggregated_aux_outputs2) + + for i in range(len(aggregated_aux_outputs1)): + assert_verbose_allclose( + aggregated_aux_outputs1[i], + aggregated_aux_outputs2[i], + atol=atol, + rtol=rtol, + ) + + loss1.backward() +> loss2.backward() + +test/chunked_loss/test_cpo_loss.py:214: +_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ +../../../miniconda3/lib/python3.12/site-packages/torch/_tensor.py:581: in backward + torch.autograd.backward( +../../../miniconda3/lib/python3.12/site-packages/torch/autograd/__init__.py:347: in backward + _engine_run_backward( +_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ + +t_outputs = (tensor(115.2395, device='cuda:0', + grad_fn=),) +args = ((tensor(1., device='cuda:0'),), False, False, ()) +kwargs = {'accumulate_grad': True, 'allow_unreachable': True} +attach_logging_hooks = False + + def _engine_run_backward( + t_outputs: Sequence[Union[torch.Tensor, GradientEdge]], + *args: Any, + **kwargs: Any, + ) -> Tuple[torch.Tensor, ...]: + attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG + if attach_logging_hooks: + unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs) + try: +> return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass + t_outputs, *args, **kwargs + ) # Calls into the C++ engine to run the backward pass +E RuntimeError: function LigerFusedLinearCPOFunctionBackward 
returned an incorrect number of gradients (expected 10, got 9) + +../../../miniconda3/lib/python3.12/site-packages/torch/autograd/graph.py:825: RuntimeError +___ test_correctness[-100-0.1-1.0-False-1.0-dtype1-1e-05-0.0005-3-47-31-123] ___ + +B = 6, T = 47, H = 31, V = 123, scalar = 1.0, dtype = torch.float32 +atol = 1e-05, rtol = 0.0005, bias = False, ignore_index = -100, beta = 0.1 +alpha = 1.0 + + @pytest.mark.parametrize( + "B, T, H, V", + [ + (8, 128, 1024, 4096), + (3, 47, 31, 123), # random shape + ], + ) + @pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-3, 5e-3), + (1.0, torch.float32, 1e-5, 5e-4), + ], + ) + @pytest.mark.parametrize("bias", [True, False]) + @pytest.mark.parametrize( + "ignore_index, beta, alpha", [(-100, 0.1, 1.0), (42, 0.2, 0.85)] + ) + def test_correctness( + B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, alpha + ): + B = 2 * B # cpo loss requires B to be even + + torch_lm_head_cpo = TorchLMHeadCPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + ) + liger_lm_head_cpo = LigerLMHeadCPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + ) + + torch_lm_head_cpo.lin.weight.data = liger_lm_head_cpo.lin.weight.data = torch.randn( + V, H, device=device, dtype=dtype + ) + + if bias: + torch_lm_head_cpo.lin.bias.data = liger_lm_head_cpo.lin.bias.data = torch.randn( + V, device=device, dtype=dtype + ) + + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + target = torch.randint( + 0, + V, + ( + B, + T, + ), + device=device, + dtype=torch.long, + ) + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + target.view(-1)[indices_to_assign] = ignore_index + + loss1, aggregated_aux_outputs1 = torch_lm_head_cpo(input1, target) + loss2, aggregated_aux_outputs2 = liger_lm_head_cpo(input2, target) + + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + assert len(aggregated_aux_outputs1) == len(aggregated_aux_outputs2) + + for i in range(len(aggregated_aux_outputs1)): + assert_verbose_allclose( + aggregated_aux_outputs1[i], + aggregated_aux_outputs2[i], + atol=atol, + rtol=rtol, + ) + + loss1.backward() +> loss2.backward() + +test/chunked_loss/test_cpo_loss.py:214: +_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ +../../../miniconda3/lib/python3.12/site-packages/torch/_tensor.py:581: in backward + torch.autograd.backward( +../../../miniconda3/lib/python3.12/site-packages/torch/autograd/__init__.py:347: in backward + _engine_run_backward( +_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ + +t_outputs = (tensor(15.0284, device='cuda:0', grad_fn=),) +args = ((tensor(1., device='cuda:0'),), False, False, ()) +kwargs = {'accumulate_grad': True, 'allow_unreachable': True} +attach_logging_hooks = False + + def _engine_run_backward( + t_outputs: Sequence[Union[torch.Tensor, GradientEdge]], + *args: Any, + **kwargs: Any, + ) -> Tuple[torch.Tensor, ...]: + attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG + if attach_logging_hooks: + unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs) + try: +> return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward 
pass + t_outputs, *args, **kwargs + ) # Calls into the C++ engine to run the backward pass +E RuntimeError: function LigerFusedLinearCPOFunctionBackward returned an incorrect number of gradients (expected 10, got 9) + +../../../miniconda3/lib/python3.12/site-packages/torch/autograd/graph.py:825: RuntimeError +__ test_correctness[42-0.2-0.85-True-1.0-dtype0-0.005-0.005-8-128-1024-4096] ___ + +B = 16, T = 128, H = 1024, V = 4096, scalar = 1.0, dtype = torch.bfloat16 +atol = 0.005, rtol = 0.005, bias = True, ignore_index = 42, beta = 0.2 +alpha = 0.85 + + @pytest.mark.parametrize( + "B, T, H, V", + [ + (8, 128, 1024, 4096), + (3, 47, 31, 123), # random shape + ], + ) + @pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-3, 5e-3), + (1.0, torch.float32, 1e-5, 5e-4), + ], + ) + @pytest.mark.parametrize("bias", [True, False]) + @pytest.mark.parametrize( + "ignore_index, beta, alpha", [(-100, 0.1, 1.0), (42, 0.2, 0.85)] + ) + def test_correctness( + B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, alpha + ): + B = 2 * B # cpo loss requires B to be even + + torch_lm_head_cpo = TorchLMHeadCPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + ) + liger_lm_head_cpo = LigerLMHeadCPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + ) + + torch_lm_head_cpo.lin.weight.data = liger_lm_head_cpo.lin.weight.data = torch.randn( + V, H, device=device, dtype=dtype + ) + + if bias: + torch_lm_head_cpo.lin.bias.data = liger_lm_head_cpo.lin.bias.data = torch.randn( + V, device=device, dtype=dtype + ) + + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + target = torch.randint( + 0, + V, + ( + B, + T, + ), + device=device, + dtype=torch.long, + ) + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + target.view(-1)[indices_to_assign] = ignore_index + + loss1, aggregated_aux_outputs1 = torch_lm_head_cpo(input1, target) + loss2, aggregated_aux_outputs2 = liger_lm_head_cpo(input2, target) + + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + assert len(aggregated_aux_outputs1) == len(aggregated_aux_outputs2) + + for i in range(len(aggregated_aux_outputs1)): + assert_verbose_allclose( + aggregated_aux_outputs1[i], + aggregated_aux_outputs2[i], + atol=atol, + rtol=rtol, + ) + + loss1.backward() +> loss2.backward() + +test/chunked_loss/test_cpo_loss.py:214: +_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ +../../../miniconda3/lib/python3.12/site-packages/torch/_tensor.py:581: in backward + torch.autograd.backward( +../../../miniconda3/lib/python3.12/site-packages/torch/autograd/__init__.py:347: in backward + _engine_run_backward( +_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ + +t_outputs = (tensor(117.2575, device='cuda:0', + grad_fn=),) +args = ((tensor(1., device='cuda:0'),), False, False, ()) +kwargs = {'accumulate_grad': True, 'allow_unreachable': True} +attach_logging_hooks = False + + def _engine_run_backward( + t_outputs: Sequence[Union[torch.Tensor, GradientEdge]], + *args: Any, + **kwargs: Any, + ) -> Tuple[torch.Tensor, ...]: + attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG + if attach_logging_hooks: + unregister_hooks 
= _register_logging_hooks_on_whole_graph(t_outputs) + try: +> return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass + t_outputs, *args, **kwargs + ) # Calls into the C++ engine to run the backward pass +E RuntimeError: function LigerFusedLinearCPOFunctionBackward returned an incorrect number of gradients (expected 10, got 9) + +../../../miniconda3/lib/python3.12/site-packages/torch/autograd/graph.py:825: RuntimeError +----------------------------- Captured stderr call ----------------------------- +W1211 23:18:09.758000 217831 site-packages/torch/_dynamo/convert_frame.py:844] [0/8] torch._dynamo hit config.cache_size_limit (8) +W1211 23:18:09.758000 217831 site-packages/torch/_dynamo/convert_frame.py:844] [0/8] function: 'fused_fwd_bwd' (/home/ryan/Documents/GitHub/Liger-Kernel/src/liger_kernel/chunked_loss/fused_linear_preference.py:103) +W1211 23:18:09.758000 217831 site-packages/torch/_dynamo/convert_frame.py:844] [0/8] last reason: 0/0: L['compute_loss'].keywords['ignore_index'] == -100 +W1211 23:18:09.758000 217831 site-packages/torch/_dynamo/convert_frame.py:844] [0/8] To log all recompilation reasons, use TORCH_LOGS="recompiles". +W1211 23:18:09.758000 217831 site-packages/torch/_dynamo/convert_frame.py:844] [0/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html. +____ test_correctness[42-0.2-0.85-True-1.0-dtype0-0.005-0.005-3-47-31-123] _____ + +B = 6, T = 47, H = 31, V = 123, scalar = 1.0, dtype = torch.bfloat16 +atol = 0.005, rtol = 0.005, bias = True, ignore_index = 42, beta = 0.2 +alpha = 0.85 + + @pytest.mark.parametrize( + "B, T, H, V", + [ + (8, 128, 1024, 4096), + (3, 47, 31, 123), # random shape + ], + ) + @pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-3, 5e-3), + (1.0, torch.float32, 1e-5, 5e-4), + ], + ) + @pytest.mark.parametrize("bias", [True, False]) + @pytest.mark.parametrize( + "ignore_index, beta, alpha", [(-100, 0.1, 1.0), (42, 0.2, 0.85)] + ) + def test_correctness( + B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, alpha + ): + B = 2 * B # cpo loss requires B to be even + + torch_lm_head_cpo = TorchLMHeadCPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + ) + liger_lm_head_cpo = LigerLMHeadCPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + ) + + torch_lm_head_cpo.lin.weight.data = liger_lm_head_cpo.lin.weight.data = torch.randn( + V, H, device=device, dtype=dtype + ) + + if bias: + torch_lm_head_cpo.lin.bias.data = liger_lm_head_cpo.lin.bias.data = torch.randn( + V, device=device, dtype=dtype + ) + + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + target = torch.randint( + 0, + V, + ( + B, + T, + ), + device=device, + dtype=torch.long, + ) + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + target.view(-1)[indices_to_assign] = ignore_index + + loss1, aggregated_aux_outputs1 = torch_lm_head_cpo(input1, target) + loss2, aggregated_aux_outputs2 = liger_lm_head_cpo(input2, target) + + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + assert len(aggregated_aux_outputs1) == len(aggregated_aux_outputs2) + + for i in 
range(len(aggregated_aux_outputs1)): + assert_verbose_allclose( + aggregated_aux_outputs1[i], + aggregated_aux_outputs2[i], + atol=atol, + rtol=rtol, + ) + + loss1.backward() +> loss2.backward() + +test/chunked_loss/test_cpo_loss.py:214: +_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ +../../../miniconda3/lib/python3.12/site-packages/torch/_tensor.py:581: in backward + torch.autograd.backward( +../../../miniconda3/lib/python3.12/site-packages/torch/autograd/__init__.py:347: in backward + _engine_run_backward( +_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ + +t_outputs = (tensor(15.0183, device='cuda:0', grad_fn=),) +args = ((tensor(1., device='cuda:0'),), False, False, ()) +kwargs = {'accumulate_grad': True, 'allow_unreachable': True} +attach_logging_hooks = False + + def _engine_run_backward( + t_outputs: Sequence[Union[torch.Tensor, GradientEdge]], + *args: Any, + **kwargs: Any, + ) -> Tuple[torch.Tensor, ...]: + attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG + if attach_logging_hooks: + unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs) + try: +> return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass + t_outputs, *args, **kwargs + ) # Calls into the C++ engine to run the backward pass +E RuntimeError: function LigerFusedLinearCPOFunctionBackward returned an incorrect number of gradients (expected 10, got 9) + +../../../miniconda3/lib/python3.12/site-packages/torch/autograd/graph.py:825: RuntimeError +__ test_correctness[42-0.2-0.85-True-1.0-dtype1-1e-05-0.0005-8-128-1024-4096] __ + +B = 16, T = 128, H = 1024, V = 4096, scalar = 1.0, dtype = torch.float32 +atol = 1e-05, rtol = 0.0005, bias = True, ignore_index = 42, beta = 0.2 +alpha = 0.85 + + @pytest.mark.parametrize( + "B, T, H, V", + [ + (8, 128, 1024, 4096), + (3, 47, 31, 123), # random shape + ], + ) + @pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-3, 5e-3), + (1.0, torch.float32, 1e-5, 5e-4), + ], + ) + @pytest.mark.parametrize("bias", [True, False]) + @pytest.mark.parametrize( + "ignore_index, beta, alpha", [(-100, 0.1, 1.0), (42, 0.2, 0.85)] + ) + def test_correctness( + B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, alpha + ): + B = 2 * B # cpo loss requires B to be even + + torch_lm_head_cpo = TorchLMHeadCPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + ) + liger_lm_head_cpo = LigerLMHeadCPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + ) + + torch_lm_head_cpo.lin.weight.data = liger_lm_head_cpo.lin.weight.data = torch.randn( + V, H, device=device, dtype=dtype + ) + + if bias: + torch_lm_head_cpo.lin.bias.data = liger_lm_head_cpo.lin.bias.data = torch.randn( + V, device=device, dtype=dtype + ) + + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + target = torch.randint( + 0, + V, + ( + B, + T, + ), + device=device, + dtype=torch.long, + ) + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + target.view(-1)[indices_to_assign] = ignore_index + + loss1, aggregated_aux_outputs1 = torch_lm_head_cpo(input1, target) + loss2, aggregated_aux_outputs2 = liger_lm_head_cpo(input2, 
target) + + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + assert len(aggregated_aux_outputs1) == len(aggregated_aux_outputs2) + + for i in range(len(aggregated_aux_outputs1)): + assert_verbose_allclose( + aggregated_aux_outputs1[i], + aggregated_aux_outputs2[i], + atol=atol, + rtol=rtol, + ) + + loss1.backward() +> loss2.backward() + +test/chunked_loss/test_cpo_loss.py:214: +_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ +../../../miniconda3/lib/python3.12/site-packages/torch/_tensor.py:581: in backward + torch.autograd.backward( +../../../miniconda3/lib/python3.12/site-packages/torch/autograd/__init__.py:347: in backward + _engine_run_backward( +_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ + +t_outputs = (tensor(116.7569, device='cuda:0', + grad_fn=),) +args = ((tensor(1., device='cuda:0'),), False, False, ()) +kwargs = {'accumulate_grad': True, 'allow_unreachable': True} +attach_logging_hooks = False + + def _engine_run_backward( + t_outputs: Sequence[Union[torch.Tensor, GradientEdge]], + *args: Any, + **kwargs: Any, + ) -> Tuple[torch.Tensor, ...]: + attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG + if attach_logging_hooks: + unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs) + try: +> return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass + t_outputs, *args, **kwargs + ) # Calls into the C++ engine to run the backward pass +E RuntimeError: function LigerFusedLinearCPOFunctionBackward returned an incorrect number of gradients (expected 10, got 9) + +../../../miniconda3/lib/python3.12/site-packages/torch/autograd/graph.py:825: RuntimeError +____ test_correctness[42-0.2-0.85-True-1.0-dtype1-1e-05-0.0005-3-47-31-123] ____ + +B = 6, T = 47, H = 31, V = 123, scalar = 1.0, dtype = torch.float32 +atol = 1e-05, rtol = 0.0005, bias = True, ignore_index = 42, beta = 0.2 +alpha = 0.85 + + @pytest.mark.parametrize( + "B, T, H, V", + [ + (8, 128, 1024, 4096), + (3, 47, 31, 123), # random shape + ], + ) + @pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-3, 5e-3), + (1.0, torch.float32, 1e-5, 5e-4), + ], + ) + @pytest.mark.parametrize("bias", [True, False]) + @pytest.mark.parametrize( + "ignore_index, beta, alpha", [(-100, 0.1, 1.0), (42, 0.2, 0.85)] + ) + def test_correctness( + B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, alpha + ): + B = 2 * B # cpo loss requires B to be even + + torch_lm_head_cpo = TorchLMHeadCPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + ) + liger_lm_head_cpo = LigerLMHeadCPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + ) + + torch_lm_head_cpo.lin.weight.data = liger_lm_head_cpo.lin.weight.data = torch.randn( + V, H, device=device, dtype=dtype + ) + + if bias: + torch_lm_head_cpo.lin.bias.data = liger_lm_head_cpo.lin.bias.data = torch.randn( + V, device=device, dtype=dtype + ) + + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + target = torch.randint( + 0, + V, + ( + B, + T, + ), + device=device, + dtype=torch.long, + ) + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + 
target.view(-1)[indices_to_assign] = ignore_index + + loss1, aggregated_aux_outputs1 = torch_lm_head_cpo(input1, target) + loss2, aggregated_aux_outputs2 = liger_lm_head_cpo(input2, target) + + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + assert len(aggregated_aux_outputs1) == len(aggregated_aux_outputs2) + + for i in range(len(aggregated_aux_outputs1)): + assert_verbose_allclose( + aggregated_aux_outputs1[i], + aggregated_aux_outputs2[i], + atol=atol, + rtol=rtol, + ) + + loss1.backward() +> loss2.backward() + +test/chunked_loss/test_cpo_loss.py:214: +_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ +../../../miniconda3/lib/python3.12/site-packages/torch/_tensor.py:581: in backward + torch.autograd.backward( +../../../miniconda3/lib/python3.12/site-packages/torch/autograd/__init__.py:347: in backward + _engine_run_backward( +_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ + +t_outputs = (tensor(16.0889, device='cuda:0', grad_fn=),) +args = ((tensor(1., device='cuda:0'),), False, False, ()) +kwargs = {'accumulate_grad': True, 'allow_unreachable': True} +attach_logging_hooks = False + + def _engine_run_backward( + t_outputs: Sequence[Union[torch.Tensor, GradientEdge]], + *args: Any, + **kwargs: Any, + ) -> Tuple[torch.Tensor, ...]: + attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG + if attach_logging_hooks: + unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs) + try: +> return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass + t_outputs, *args, **kwargs + ) # Calls into the C++ engine to run the backward pass +E RuntimeError: function LigerFusedLinearCPOFunctionBackward returned an incorrect number of gradients (expected 10, got 9) + +../../../miniconda3/lib/python3.12/site-packages/torch/autograd/graph.py:825: RuntimeError +__ test_correctness[42-0.2-0.85-False-1.0-dtype0-0.005-0.005-8-128-1024-4096] __ + +B = 16, T = 128, H = 1024, V = 4096, scalar = 1.0, dtype = torch.bfloat16 +atol = 0.005, rtol = 0.005, bias = False, ignore_index = 42, beta = 0.2 +alpha = 0.85 + + @pytest.mark.parametrize( + "B, T, H, V", + [ + (8, 128, 1024, 4096), + (3, 47, 31, 123), # random shape + ], + ) + @pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-3, 5e-3), + (1.0, torch.float32, 1e-5, 5e-4), + ], + ) + @pytest.mark.parametrize("bias", [True, False]) + @pytest.mark.parametrize( + "ignore_index, beta, alpha", [(-100, 0.1, 1.0), (42, 0.2, 0.85)] + ) + def test_correctness( + B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, alpha + ): + B = 2 * B # cpo loss requires B to be even + + torch_lm_head_cpo = TorchLMHeadCPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + ) + liger_lm_head_cpo = LigerLMHeadCPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + ) + + torch_lm_head_cpo.lin.weight.data = liger_lm_head_cpo.lin.weight.data = torch.randn( + V, H, device=device, dtype=dtype + ) + + if bias: + torch_lm_head_cpo.lin.bias.data = liger_lm_head_cpo.lin.bias.data = torch.randn( + V, device=device, dtype=dtype + ) + + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + target = torch.randint( + 0, + V, + ( + B, + T, + ), + device=device, + dtype=torch.long, + ) + # Assign some random number of elements 
as ignore_index + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + target.view(-1)[indices_to_assign] = ignore_index + + loss1, aggregated_aux_outputs1 = torch_lm_head_cpo(input1, target) + loss2, aggregated_aux_outputs2 = liger_lm_head_cpo(input2, target) + + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + assert len(aggregated_aux_outputs1) == len(aggregated_aux_outputs2) + + for i in range(len(aggregated_aux_outputs1)): + assert_verbose_allclose( + aggregated_aux_outputs1[i], + aggregated_aux_outputs2[i], + atol=atol, + rtol=rtol, + ) + + loss1.backward() +> loss2.backward() + +test/chunked_loss/test_cpo_loss.py:214: +_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ +../../../miniconda3/lib/python3.12/site-packages/torch/_tensor.py:581: in backward + torch.autograd.backward( +../../../miniconda3/lib/python3.12/site-packages/torch/autograd/__init__.py:347: in backward + _engine_run_backward( +_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ + +t_outputs = (tensor(116.8942, device='cuda:0', + grad_fn=),) +args = ((tensor(1., device='cuda:0'),), False, False, ()) +kwargs = {'accumulate_grad': True, 'allow_unreachable': True} +attach_logging_hooks = False + + def _engine_run_backward( + t_outputs: Sequence[Union[torch.Tensor, GradientEdge]], + *args: Any, + **kwargs: Any, + ) -> Tuple[torch.Tensor, ...]: + attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG + if attach_logging_hooks: + unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs) + try: +> return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass + t_outputs, *args, **kwargs + ) # Calls into the C++ engine to run the backward pass +E RuntimeError: function LigerFusedLinearCPOFunctionBackward returned an incorrect number of gradients (expected 10, got 9) + +../../../miniconda3/lib/python3.12/site-packages/torch/autograd/graph.py:825: RuntimeError +____ test_correctness[42-0.2-0.85-False-1.0-dtype0-0.005-0.005-3-47-31-123] ____ + +B = 6, T = 47, H = 31, V = 123, scalar = 1.0, dtype = torch.bfloat16 +atol = 0.005, rtol = 0.005, bias = False, ignore_index = 42, beta = 0.2 +alpha = 0.85 + + @pytest.mark.parametrize( + "B, T, H, V", + [ + (8, 128, 1024, 4096), + (3, 47, 31, 123), # random shape + ], + ) + @pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-3, 5e-3), + (1.0, torch.float32, 1e-5, 5e-4), + ], + ) + @pytest.mark.parametrize("bias", [True, False]) + @pytest.mark.parametrize( + "ignore_index, beta, alpha", [(-100, 0.1, 1.0), (42, 0.2, 0.85)] + ) + def test_correctness( + B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, alpha + ): + B = 2 * B # cpo loss requires B to be even + + torch_lm_head_cpo = TorchLMHeadCPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + ) + liger_lm_head_cpo = LigerLMHeadCPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + ) + + torch_lm_head_cpo.lin.weight.data = liger_lm_head_cpo.lin.weight.data = torch.randn( + V, H, device=device, dtype=dtype + ) + + if bias: + torch_lm_head_cpo.lin.bias.data = liger_lm_head_cpo.lin.bias.data = torch.randn( + V, device=device, dtype=dtype + ) + + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = 
_input.detach().clone().requires_grad_(True) + + target = torch.randint( + 0, + V, + ( + B, + T, + ), + device=device, + dtype=torch.long, + ) + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + target.view(-1)[indices_to_assign] = ignore_index + + loss1, aggregated_aux_outputs1 = torch_lm_head_cpo(input1, target) + loss2, aggregated_aux_outputs2 = liger_lm_head_cpo(input2, target) + + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + assert len(aggregated_aux_outputs1) == len(aggregated_aux_outputs2) + + for i in range(len(aggregated_aux_outputs1)): + assert_verbose_allclose( + aggregated_aux_outputs1[i], + aggregated_aux_outputs2[i], + atol=atol, + rtol=rtol, + ) + + loss1.backward() +> loss2.backward() + +test/chunked_loss/test_cpo_loss.py:214: +_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ +../../../miniconda3/lib/python3.12/site-packages/torch/_tensor.py:581: in backward + torch.autograd.backward( +../../../miniconda3/lib/python3.12/site-packages/torch/autograd/__init__.py:347: in backward + _engine_run_backward( +_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ + +t_outputs = (tensor(15.8006, device='cuda:0', grad_fn=),) +args = ((tensor(1., device='cuda:0'),), False, False, ()) +kwargs = {'accumulate_grad': True, 'allow_unreachable': True} +attach_logging_hooks = False + + def _engine_run_backward( + t_outputs: Sequence[Union[torch.Tensor, GradientEdge]], + *args: Any, + **kwargs: Any, + ) -> Tuple[torch.Tensor, ...]: + attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG + if attach_logging_hooks: + unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs) + try: +> return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass + t_outputs, *args, **kwargs + ) # Calls into the C++ engine to run the backward pass +E RuntimeError: function LigerFusedLinearCPOFunctionBackward returned an incorrect number of gradients (expected 10, got 9) + +../../../miniconda3/lib/python3.12/site-packages/torch/autograd/graph.py:825: RuntimeError +_ test_correctness[42-0.2-0.85-False-1.0-dtype1-1e-05-0.0005-8-128-1024-4096] __ + +B = 16, T = 128, H = 1024, V = 4096, scalar = 1.0, dtype = torch.float32 +atol = 1e-05, rtol = 0.0005, bias = False, ignore_index = 42, beta = 0.2 +alpha = 0.85 + + @pytest.mark.parametrize( + "B, T, H, V", + [ + (8, 128, 1024, 4096), + (3, 47, 31, 123), # random shape + ], + ) + @pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-3, 5e-3), + (1.0, torch.float32, 1e-5, 5e-4), + ], + ) + @pytest.mark.parametrize("bias", [True, False]) + @pytest.mark.parametrize( + "ignore_index, beta, alpha", [(-100, 0.1, 1.0), (42, 0.2, 0.85)] + ) + def test_correctness( + B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, alpha + ): + B = 2 * B # cpo loss requires B to be even + + torch_lm_head_cpo = TorchLMHeadCPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + ) + liger_lm_head_cpo = LigerLMHeadCPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + ) + + torch_lm_head_cpo.lin.weight.data = liger_lm_head_cpo.lin.weight.data = torch.randn( + V, H, device=device, dtype=dtype + ) + + if bias: + torch_lm_head_cpo.lin.bias.data = liger_lm_head_cpo.lin.bias.data = torch.randn( + V, 
device=device, dtype=dtype + ) + + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + target = torch.randint( + 0, + V, + ( + B, + T, + ), + device=device, + dtype=torch.long, + ) + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + target.view(-1)[indices_to_assign] = ignore_index + + loss1, aggregated_aux_outputs1 = torch_lm_head_cpo(input1, target) + loss2, aggregated_aux_outputs2 = liger_lm_head_cpo(input2, target) + + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + assert len(aggregated_aux_outputs1) == len(aggregated_aux_outputs2) + + for i in range(len(aggregated_aux_outputs1)): + assert_verbose_allclose( + aggregated_aux_outputs1[i], + aggregated_aux_outputs2[i], + atol=atol, + rtol=rtol, + ) + + loss1.backward() +> loss2.backward() + +test/chunked_loss/test_cpo_loss.py:214: +_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ +../../../miniconda3/lib/python3.12/site-packages/torch/_tensor.py:581: in backward + torch.autograd.backward( +../../../miniconda3/lib/python3.12/site-packages/torch/autograd/__init__.py:347: in backward + _engine_run_backward( +_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ + +t_outputs = (tensor(115.5728, device='cuda:0', + grad_fn=),) +args = ((tensor(1., device='cuda:0'),), False, False, ()) +kwargs = {'accumulate_grad': True, 'allow_unreachable': True} +attach_logging_hooks = False + + def _engine_run_backward( + t_outputs: Sequence[Union[torch.Tensor, GradientEdge]], + *args: Any, + **kwargs: Any, + ) -> Tuple[torch.Tensor, ...]: + attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG + if attach_logging_hooks: + unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs) + try: +> return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass + t_outputs, *args, **kwargs + ) # Calls into the C++ engine to run the backward pass +E RuntimeError: function LigerFusedLinearCPOFunctionBackward returned an incorrect number of gradients (expected 10, got 9) + +../../../miniconda3/lib/python3.12/site-packages/torch/autograd/graph.py:825: RuntimeError +___ test_correctness[42-0.2-0.85-False-1.0-dtype1-1e-05-0.0005-3-47-31-123] ____ + +B = 6, T = 47, H = 31, V = 123, scalar = 1.0, dtype = torch.float32 +atol = 1e-05, rtol = 0.0005, bias = False, ignore_index = 42, beta = 0.2 +alpha = 0.85 + + @pytest.mark.parametrize( + "B, T, H, V", + [ + (8, 128, 1024, 4096), + (3, 47, 31, 123), # random shape + ], + ) + @pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-3, 5e-3), + (1.0, torch.float32, 1e-5, 5e-4), + ], + ) + @pytest.mark.parametrize("bias", [True, False]) + @pytest.mark.parametrize( + "ignore_index, beta, alpha", [(-100, 0.1, 1.0), (42, 0.2, 0.85)] + ) + def test_correctness( + B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, alpha + ): + B = 2 * B # cpo loss requires B to be even + + torch_lm_head_cpo = TorchLMHeadCPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + ) + liger_lm_head_cpo = LigerLMHeadCPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + ) + + torch_lm_head_cpo.lin.weight.data = liger_lm_head_cpo.lin.weight.data 
= torch.randn( + V, H, device=device, dtype=dtype + ) + + if bias: + torch_lm_head_cpo.lin.bias.data = liger_lm_head_cpo.lin.bias.data = torch.randn( + V, device=device, dtype=dtype + ) + + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + target = torch.randint( + 0, + V, + ( + B, + T, + ), + device=device, + dtype=torch.long, + ) + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + target.view(-1)[indices_to_assign] = ignore_index + + loss1, aggregated_aux_outputs1 = torch_lm_head_cpo(input1, target) + loss2, aggregated_aux_outputs2 = liger_lm_head_cpo(input2, target) + + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + assert len(aggregated_aux_outputs1) == len(aggregated_aux_outputs2) + + for i in range(len(aggregated_aux_outputs1)): + assert_verbose_allclose( + aggregated_aux_outputs1[i], + aggregated_aux_outputs2[i], + atol=atol, + rtol=rtol, + ) + + loss1.backward() +> loss2.backward() + +test/chunked_loss/test_cpo_loss.py:214: +_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ +../../../miniconda3/lib/python3.12/site-packages/torch/_tensor.py:581: in backward + torch.autograd.backward( +../../../miniconda3/lib/python3.12/site-packages/torch/autograd/__init__.py:347: in backward + _engine_run_backward( +_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ + +t_outputs = (tensor(14.6347, device='cuda:0', grad_fn=),) +args = ((tensor(1., device='cuda:0'),), False, False, ()) +kwargs = {'accumulate_grad': True, 'allow_unreachable': True} +attach_logging_hooks = False + + def _engine_run_backward( + t_outputs: Sequence[Union[torch.Tensor, GradientEdge]], + *args: Any, + **kwargs: Any, + ) -> Tuple[torch.Tensor, ...]: + attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG + if attach_logging_hooks: + unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs) + try: +> return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass + t_outputs, *args, **kwargs + ) # Calls into the C++ engine to run the backward pass +E RuntimeError: function LigerFusedLinearCPOFunctionBackward returned an incorrect number of gradients (expected 10, got 9) + +../../../miniconda3/lib/python3.12/site-packages/torch/autograd/graph.py:825: RuntimeError +=============================== warnings summary =============================== +../../../miniconda3/lib/python3.12/site-packages/_pytest/config/__init__.py:1441 + /home/ryan/miniconda3/lib/python3.12/site-packages/_pytest/config/__init__.py:1441: PytestConfigWarning: Unknown config option: asyncio_mode + + self._warn_or_fail_if_strict(f"Unknown config option: {key}\n") + +-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html +=========================== short test summary info ============================ +FAILED test/chunked_loss/test_cpo_loss.py::test_correctness[-100-0.1-1.0-True-1.0-dtype0-0.005-0.005-8-128-1024-4096] +FAILED test/chunked_loss/test_cpo_loss.py::test_correctness[-100-0.1-1.0-True-1.0-dtype0-0.005-0.005-3-47-31-123] +FAILED test/chunked_loss/test_cpo_loss.py::test_correctness[-100-0.1-1.0-True-1.0-dtype1-1e-05-0.0005-8-128-1024-4096] +FAILED 
test/chunked_loss/test_cpo_loss.py::test_correctness[-100-0.1-1.0-True-1.0-dtype1-1e-05-0.0005-3-47-31-123] +FAILED test/chunked_loss/test_cpo_loss.py::test_correctness[-100-0.1-1.0-False-1.0-dtype0-0.005-0.005-8-128-1024-4096] +FAILED test/chunked_loss/test_cpo_loss.py::test_correctness[-100-0.1-1.0-False-1.0-dtype0-0.005-0.005-3-47-31-123] +FAILED test/chunked_loss/test_cpo_loss.py::test_correctness[-100-0.1-1.0-False-1.0-dtype1-1e-05-0.0005-8-128-1024-4096] +FAILED test/chunked_loss/test_cpo_loss.py::test_correctness[-100-0.1-1.0-False-1.0-dtype1-1e-05-0.0005-3-47-31-123] +FAILED test/chunked_loss/test_cpo_loss.py::test_correctness[42-0.2-0.85-True-1.0-dtype0-0.005-0.005-8-128-1024-4096] +FAILED test/chunked_loss/test_cpo_loss.py::test_correctness[42-0.2-0.85-True-1.0-dtype0-0.005-0.005-3-47-31-123] +FAILED test/chunked_loss/test_cpo_loss.py::test_correctness[42-0.2-0.85-True-1.0-dtype1-1e-05-0.0005-8-128-1024-4096] +FAILED test/chunked_loss/test_cpo_loss.py::test_correctness[42-0.2-0.85-True-1.0-dtype1-1e-05-0.0005-3-47-31-123] +FAILED test/chunked_loss/test_cpo_loss.py::test_correctness[42-0.2-0.85-False-1.0-dtype0-0.005-0.005-8-128-1024-4096] +FAILED test/chunked_loss/test_cpo_loss.py::test_correctness[42-0.2-0.85-False-1.0-dtype0-0.005-0.005-3-47-31-123] +FAILED test/chunked_loss/test_cpo_loss.py::test_correctness[42-0.2-0.85-False-1.0-dtype1-1e-05-0.0005-8-128-1024-4096] +FAILED test/chunked_loss/test_cpo_loss.py::test_correctness[42-0.2-0.85-False-1.0-dtype1-1e-05-0.0005-3-47-31-123] +=================== 16 failed, 8 passed, 1 warning in 6.01s ==================== diff --git a/src/liger_kernel/chunked_loss/cpo_loss.py b/src/liger_kernel/chunked_loss/cpo_loss.py index 4a407597d..7daae3816 100644 --- a/src/liger_kernel/chunked_loss/cpo_loss.py +++ b/src/liger_kernel/chunked_loss/cpo_loss.py @@ -68,7 +68,7 @@ def forward( @staticmethod def backward(ctx, *grad_output): grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] - return *grads, None, None, None, None, None + return *grads, None, None, None, None, None, None class LigerFusedLinearCPOLoss(torch.nn.Module): diff --git a/src/liger_kernel/chunked_loss/dpo_loss.py b/src/liger_kernel/chunked_loss/dpo_loss.py index 19a6bdfb2..4c3e9feee 100644 --- a/src/liger_kernel/chunked_loss/dpo_loss.py +++ b/src/liger_kernel/chunked_loss/dpo_loss.py @@ -89,7 +89,7 @@ def forward( @staticmethod def backward(ctx, *grad_output): grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] - return *grads, None, None, None, None, None, None, None + return *grads, None, None, None, None, None, None, None, None class LigerFusedLinearDPOLoss(torch.nn.Module): diff --git a/src/liger_kernel/chunked_loss/orpo_loss.py b/src/liger_kernel/chunked_loss/orpo_loss.py index 819868f78..9c652a010 100644 --- a/src/liger_kernel/chunked_loss/orpo_loss.py +++ b/src/liger_kernel/chunked_loss/orpo_loss.py @@ -77,7 +77,7 @@ def forward( @staticmethod def backward(ctx, *grad_output): grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] - return *grads, None, None, None, None + return *grads, None, None, None, None, None class LigerFusedLinearORPOLoss(torch.nn.Module): diff --git a/src/liger_kernel/chunked_loss/simpo_loss.py b/src/liger_kernel/chunked_loss/simpo_loss.py index d27a2d164..8a378288d 100644 --- a/src/liger_kernel/chunked_loss/simpo_loss.py +++ b/src/liger_kernel/chunked_loss/simpo_loss.py @@ -73,7 +73,7 @@ def forward( @staticmethod def backward(ctx, *grad_output): grads = LigerFusedLinearPreferenceBase.backward(ctx, 
grad_output)[:4] - return *grads, None, None, None, None, None, None + return *grads, None, None, None, None, None, None, None class LigerFusedLinearSimPOLoss(torch.nn.Module): From 3f58a61713fe07651943ccc2224823fe51e6e6ec Mon Sep 17 00:00:00 2001 From: ryankert01 Date: Thu, 12 Dec 2024 11:00:09 +0800 Subject: [PATCH 06/10] remove redundant file --- log | 1991 ----------------------------------------------------------- 1 file changed, 1991 deletions(-) delete mode 100644 log diff --git a/log b/log deleted file mode 100644 index da042a6a0..000000000 --- a/log +++ /dev/null @@ -1,1991 +0,0 @@ -============================= test session starts ============================== -platform linux -- Python 3.12.7, pytest-8.3.4, pluggy-1.5.0 -rootdir: /home/ryan/Documents/GitHub/Liger-Kernel -configfile: pyproject.toml -collected 24 items - -test/chunked_loss/test_cpo_loss.py::test_correctness[-100-0.1-1.0-True-1.0-dtype0-0.005-0.005-8-128-1024-4096] FAILED [ 4%] -test/chunked_loss/test_cpo_loss.py::test_correctness[-100-0.1-1.0-True-1.0-dtype0-0.005-0.005-3-47-31-123] FAILED [ 8%] -test/chunked_loss/test_cpo_loss.py::test_correctness[-100-0.1-1.0-True-1.0-dtype1-1e-05-0.0005-8-128-1024-4096] FAILED [ 12%] -test/chunked_loss/test_cpo_loss.py::test_correctness[-100-0.1-1.0-True-1.0-dtype1-1e-05-0.0005-3-47-31-123] FAILED [ 16%] -test/chunked_loss/test_cpo_loss.py::test_correctness[-100-0.1-1.0-False-1.0-dtype0-0.005-0.005-8-128-1024-4096] FAILED [ 20%] -test/chunked_loss/test_cpo_loss.py::test_correctness[-100-0.1-1.0-False-1.0-dtype0-0.005-0.005-3-47-31-123] FAILED [ 25%] -test/chunked_loss/test_cpo_loss.py::test_correctness[-100-0.1-1.0-False-1.0-dtype1-1e-05-0.0005-8-128-1024-4096] FAILED [ 29%] -test/chunked_loss/test_cpo_loss.py::test_correctness[-100-0.1-1.0-False-1.0-dtype1-1e-05-0.0005-3-47-31-123] FAILED [ 33%] -test/chunked_loss/test_cpo_loss.py::test_correctness[42-0.2-0.85-True-1.0-dtype0-0.005-0.005-8-128-1024-4096] FAILED [ 37%] -test/chunked_loss/test_cpo_loss.py::test_correctness[42-0.2-0.85-True-1.0-dtype0-0.005-0.005-3-47-31-123] FAILED [ 41%] -test/chunked_loss/test_cpo_loss.py::test_correctness[42-0.2-0.85-True-1.0-dtype1-1e-05-0.0005-8-128-1024-4096] FAILED [ 45%] -test/chunked_loss/test_cpo_loss.py::test_correctness[42-0.2-0.85-True-1.0-dtype1-1e-05-0.0005-3-47-31-123] FAILED [ 50%] -test/chunked_loss/test_cpo_loss.py::test_correctness[42-0.2-0.85-False-1.0-dtype0-0.005-0.005-8-128-1024-4096] FAILED [ 54%] -test/chunked_loss/test_cpo_loss.py::test_correctness[42-0.2-0.85-False-1.0-dtype0-0.005-0.005-3-47-31-123] FAILED [ 58%] -test/chunked_loss/test_cpo_loss.py::test_correctness[42-0.2-0.85-False-1.0-dtype1-1e-05-0.0005-8-128-1024-4096] FAILED [ 62%] -test/chunked_loss/test_cpo_loss.py::test_correctness[42-0.2-0.85-False-1.0-dtype1-1e-05-0.0005-3-47-31-123] FAILED [ 66%] -test/chunked_loss/test_cpo_loss.py::test_correctness_functional[True-1.0-dtype0-0.05-0.5-2-2-8-8] PASSED [ 70%] -test/chunked_loss/test_cpo_loss.py::test_correctness_functional[True-1.0-dtype0-0.05-0.5-3-47-31-123] PASSED [ 75%] -test/chunked_loss/test_cpo_loss.py::test_correctness_functional[True-1.0-dtype1-1e-05-0.0005-2-2-8-8] PASSED [ 79%] -test/chunked_loss/test_cpo_loss.py::test_correctness_functional[True-1.0-dtype1-1e-05-0.0005-3-47-31-123] PASSED [ 83%] -test/chunked_loss/test_cpo_loss.py::test_correctness_functional[False-1.0-dtype0-0.05-0.5-2-2-8-8] PASSED [ 87%] -test/chunked_loss/test_cpo_loss.py::test_correctness_functional[False-1.0-dtype0-0.05-0.5-3-47-31-123] PASSED [ 91%] 
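Context for the cpo/dpo/orpo/simpo `backward` diffs above: `torch.autograd.Function.backward` must return one gradient (or `None`) per argument that `forward` received, so adding the `softcap` argument to `forward` means each subclass's `backward` needs one extra trailing `None`; until then autograd raises the "returned an incorrect number of gradients (expected 10, got 9)" error seen throughout this log. Below is a minimal sketch of that rule, using a hypothetical `Scale` function that is not part of Liger-Kernel.

import torch


class Scale(torch.autograd.Function):
    """Toy autograd.Function used only to illustrate the gradient-count rule."""

    @staticmethod
    def forward(ctx, x, factor, softcap=None):
        # forward takes three arguments after ctx: x, factor, softcap.
        # softcap is accepted but unused here; only the argument count matters
        # for how many values backward must return.
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_output):
        # backward must return one value per forward argument (excluding ctx):
        # a tensor gradient for x, and None for the non-tensor factor and softcap.
        # Returning only two values here would raise an error of the form
        # "returned an incorrect number of gradients (expected 3, got 2)".
        return grad_output * ctx.factor, None, None


x = torch.randn(4, requires_grad=True)
Scale.apply(x, 2.0, 30.0).sum().backward()
print(x.grad)  # tensor([2., 2., 2., 2.])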
-test/chunked_loss/test_cpo_loss.py::test_correctness_functional[False-1.0-dtype1-1e-05-0.0005-2-2-8-8] PASSED [ 95%] -test/chunked_loss/test_cpo_loss.py::test_correctness_functional[False-1.0-dtype1-1e-05-0.0005-3-47-31-123] PASSED [100%] - -=================================== FAILURES =================================== -__ test_correctness[-100-0.1-1.0-True-1.0-dtype0-0.005-0.005-8-128-1024-4096] __ - -B = 16, T = 128, H = 1024, V = 4096, scalar = 1.0, dtype = torch.bfloat16 -atol = 0.005, rtol = 0.005, bias = True, ignore_index = -100, beta = 0.1 -alpha = 1.0 - - @pytest.mark.parametrize( - "B, T, H, V", - [ - (8, 128, 1024, 4096), - (3, 47, 31, 123), # random shape - ], - ) - @pytest.mark.parametrize( - "scalar, dtype, atol, rtol", - [ - (1.0, torch.bfloat16, 5e-3, 5e-3), - (1.0, torch.float32, 1e-5, 5e-4), - ], - ) - @pytest.mark.parametrize("bias", [True, False]) - @pytest.mark.parametrize( - "ignore_index, beta, alpha", [(-100, 0.1, 1.0), (42, 0.2, 0.85)] - ) - def test_correctness( - B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, alpha - ): - B = 2 * B # cpo loss requires B to be even - - torch_lm_head_cpo = TorchLMHeadCPO( - H=H, - V=V, - dtype=dtype, - bias=bias, - ignore_index=ignore_index, - beta=beta, - ) - liger_lm_head_cpo = LigerLMHeadCPO( - H=H, - V=V, - dtype=dtype, - bias=bias, - ignore_index=ignore_index, - beta=beta, - ) - - torch_lm_head_cpo.lin.weight.data = liger_lm_head_cpo.lin.weight.data = torch.randn( - V, H, device=device, dtype=dtype - ) - - if bias: - torch_lm_head_cpo.lin.bias.data = liger_lm_head_cpo.lin.bias.data = torch.randn( - V, device=device, dtype=dtype - ) - - _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar - input1 = _input.detach().clone().requires_grad_(True) - input2 = _input.detach().clone().requires_grad_(True) - - target = torch.randint( - 0, - V, - ( - B, - T, - ), - device=device, - dtype=torch.long, - ) - # Assign some random number of elements as ignore_index - num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() - indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] - target.view(-1)[indices_to_assign] = ignore_index - - loss1, aggregated_aux_outputs1 = torch_lm_head_cpo(input1, target) - loss2, aggregated_aux_outputs2 = liger_lm_head_cpo(input2, target) - - assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) - - assert len(aggregated_aux_outputs1) == len(aggregated_aux_outputs2) - - for i in range(len(aggregated_aux_outputs1)): - assert_verbose_allclose( - aggregated_aux_outputs1[i], - aggregated_aux_outputs2[i], - atol=atol, - rtol=rtol, - ) - - loss1.backward() -> loss2.backward() - -test/chunked_loss/test_cpo_loss.py:214: -_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ -../../../miniconda3/lib/python3.12/site-packages/torch/_tensor.py:581: in backward - torch.autograd.backward( -../../../miniconda3/lib/python3.12/site-packages/torch/autograd/__init__.py:347: in backward - _engine_run_backward( -_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ - -t_outputs = (tensor(116.6298, device='cuda:0', - grad_fn=),) -args = ((tensor(1., device='cuda:0'),), False, False, ()) -kwargs = {'accumulate_grad': True, 'allow_unreachable': True} -attach_logging_hooks = False - - def _engine_run_backward( - t_outputs: Sequence[Union[torch.Tensor, GradientEdge]], - *args: Any, - **kwargs: Any, - ) -> Tuple[torch.Tensor, ...]: - attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG - if attach_logging_hooks: - 
unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs) - try: -> return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass - t_outputs, *args, **kwargs - ) # Calls into the C++ engine to run the backward pass -E RuntimeError: function LigerFusedLinearCPOFunctionBackward returned an incorrect number of gradients (expected 10, got 9) - -../../../miniconda3/lib/python3.12/site-packages/torch/autograd/graph.py:825: RuntimeError -____ test_correctness[-100-0.1-1.0-True-1.0-dtype0-0.005-0.005-3-47-31-123] ____ - -B = 6, T = 47, H = 31, V = 123, scalar = 1.0, dtype = torch.bfloat16 -atol = 0.005, rtol = 0.005, bias = True, ignore_index = -100, beta = 0.1 -alpha = 1.0 - - @pytest.mark.parametrize( - "B, T, H, V", - [ - (8, 128, 1024, 4096), - (3, 47, 31, 123), # random shape - ], - ) - @pytest.mark.parametrize( - "scalar, dtype, atol, rtol", - [ - (1.0, torch.bfloat16, 5e-3, 5e-3), - (1.0, torch.float32, 1e-5, 5e-4), - ], - ) - @pytest.mark.parametrize("bias", [True, False]) - @pytest.mark.parametrize( - "ignore_index, beta, alpha", [(-100, 0.1, 1.0), (42, 0.2, 0.85)] - ) - def test_correctness( - B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, alpha - ): - B = 2 * B # cpo loss requires B to be even - - torch_lm_head_cpo = TorchLMHeadCPO( - H=H, - V=V, - dtype=dtype, - bias=bias, - ignore_index=ignore_index, - beta=beta, - ) - liger_lm_head_cpo = LigerLMHeadCPO( - H=H, - V=V, - dtype=dtype, - bias=bias, - ignore_index=ignore_index, - beta=beta, - ) - - torch_lm_head_cpo.lin.weight.data = liger_lm_head_cpo.lin.weight.data = torch.randn( - V, H, device=device, dtype=dtype - ) - - if bias: - torch_lm_head_cpo.lin.bias.data = liger_lm_head_cpo.lin.bias.data = torch.randn( - V, device=device, dtype=dtype - ) - - _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar - input1 = _input.detach().clone().requires_grad_(True) - input2 = _input.detach().clone().requires_grad_(True) - - target = torch.randint( - 0, - V, - ( - B, - T, - ), - device=device, - dtype=torch.long, - ) - # Assign some random number of elements as ignore_index - num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() - indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] - target.view(-1)[indices_to_assign] = ignore_index - - loss1, aggregated_aux_outputs1 = torch_lm_head_cpo(input1, target) - loss2, aggregated_aux_outputs2 = liger_lm_head_cpo(input2, target) - - assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) - - assert len(aggregated_aux_outputs1) == len(aggregated_aux_outputs2) - - for i in range(len(aggregated_aux_outputs1)): - assert_verbose_allclose( - aggregated_aux_outputs1[i], - aggregated_aux_outputs2[i], - atol=atol, - rtol=rtol, - ) - - loss1.backward() -> loss2.backward() - -test/chunked_loss/test_cpo_loss.py:214: -_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ -../../../miniconda3/lib/python3.12/site-packages/torch/_tensor.py:581: in backward - torch.autograd.backward( -../../../miniconda3/lib/python3.12/site-packages/torch/autograd/__init__.py:347: in backward - _engine_run_backward( -_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ - -t_outputs = (tensor(16.4326, device='cuda:0', grad_fn=),) -args = ((tensor(1., device='cuda:0'),), False, False, ()) -kwargs = {'accumulate_grad': True, 'allow_unreachable': True} -attach_logging_hooks = False - - def _engine_run_backward( - t_outputs: Sequence[Union[torch.Tensor, GradientEdge]], - 
*args: Any, - **kwargs: Any, - ) -> Tuple[torch.Tensor, ...]: - attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG - if attach_logging_hooks: - unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs) - try: -> return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass - t_outputs, *args, **kwargs - ) # Calls into the C++ engine to run the backward pass -E RuntimeError: function LigerFusedLinearCPOFunctionBackward returned an incorrect number of gradients (expected 10, got 9) - -../../../miniconda3/lib/python3.12/site-packages/torch/autograd/graph.py:825: RuntimeError -_ test_correctness[-100-0.1-1.0-True-1.0-dtype1-1e-05-0.0005-8-128-1024-4096] __ - -B = 16, T = 128, H = 1024, V = 4096, scalar = 1.0, dtype = torch.float32 -atol = 1e-05, rtol = 0.0005, bias = True, ignore_index = -100, beta = 0.1 -alpha = 1.0 - - @pytest.mark.parametrize( - "B, T, H, V", - [ - (8, 128, 1024, 4096), - (3, 47, 31, 123), # random shape - ], - ) - @pytest.mark.parametrize( - "scalar, dtype, atol, rtol", - [ - (1.0, torch.bfloat16, 5e-3, 5e-3), - (1.0, torch.float32, 1e-5, 5e-4), - ], - ) - @pytest.mark.parametrize("bias", [True, False]) - @pytest.mark.parametrize( - "ignore_index, beta, alpha", [(-100, 0.1, 1.0), (42, 0.2, 0.85)] - ) - def test_correctness( - B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, alpha - ): - B = 2 * B # cpo loss requires B to be even - - torch_lm_head_cpo = TorchLMHeadCPO( - H=H, - V=V, - dtype=dtype, - bias=bias, - ignore_index=ignore_index, - beta=beta, - ) - liger_lm_head_cpo = LigerLMHeadCPO( - H=H, - V=V, - dtype=dtype, - bias=bias, - ignore_index=ignore_index, - beta=beta, - ) - - torch_lm_head_cpo.lin.weight.data = liger_lm_head_cpo.lin.weight.data = torch.randn( - V, H, device=device, dtype=dtype - ) - - if bias: - torch_lm_head_cpo.lin.bias.data = liger_lm_head_cpo.lin.bias.data = torch.randn( - V, device=device, dtype=dtype - ) - - _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar - input1 = _input.detach().clone().requires_grad_(True) - input2 = _input.detach().clone().requires_grad_(True) - - target = torch.randint( - 0, - V, - ( - B, - T, - ), - device=device, - dtype=torch.long, - ) - # Assign some random number of elements as ignore_index - num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() - indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] - target.view(-1)[indices_to_assign] = ignore_index - - loss1, aggregated_aux_outputs1 = torch_lm_head_cpo(input1, target) - loss2, aggregated_aux_outputs2 = liger_lm_head_cpo(input2, target) - - assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) - - assert len(aggregated_aux_outputs1) == len(aggregated_aux_outputs2) - - for i in range(len(aggregated_aux_outputs1)): - assert_verbose_allclose( - aggregated_aux_outputs1[i], - aggregated_aux_outputs2[i], - atol=atol, - rtol=rtol, - ) - - loss1.backward() -> loss2.backward() - -test/chunked_loss/test_cpo_loss.py:214: -_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ -../../../miniconda3/lib/python3.12/site-packages/torch/_tensor.py:581: in backward - torch.autograd.backward( -../../../miniconda3/lib/python3.12/site-packages/torch/autograd/__init__.py:347: in backward - _engine_run_backward( -_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ - -t_outputs = (tensor(116.1180, device='cuda:0', - grad_fn=),) -args = ((tensor(1., device='cuda:0'),), False, False, ()) -kwargs = 
{'accumulate_grad': True, 'allow_unreachable': True} -attach_logging_hooks = False - - def _engine_run_backward( - t_outputs: Sequence[Union[torch.Tensor, GradientEdge]], - *args: Any, - **kwargs: Any, - ) -> Tuple[torch.Tensor, ...]: - attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG - if attach_logging_hooks: - unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs) - try: -> return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass - t_outputs, *args, **kwargs - ) # Calls into the C++ engine to run the backward pass -E RuntimeError: function LigerFusedLinearCPOFunctionBackward returned an incorrect number of gradients (expected 10, got 9) - -../../../miniconda3/lib/python3.12/site-packages/torch/autograd/graph.py:825: RuntimeError -___ test_correctness[-100-0.1-1.0-True-1.0-dtype1-1e-05-0.0005-3-47-31-123] ____ - -B = 6, T = 47, H = 31, V = 123, scalar = 1.0, dtype = torch.float32 -atol = 1e-05, rtol = 0.0005, bias = True, ignore_index = -100, beta = 0.1 -alpha = 1.0 - - @pytest.mark.parametrize( - "B, T, H, V", - [ - (8, 128, 1024, 4096), - (3, 47, 31, 123), # random shape - ], - ) - @pytest.mark.parametrize( - "scalar, dtype, atol, rtol", - [ - (1.0, torch.bfloat16, 5e-3, 5e-3), - (1.0, torch.float32, 1e-5, 5e-4), - ], - ) - @pytest.mark.parametrize("bias", [True, False]) - @pytest.mark.parametrize( - "ignore_index, beta, alpha", [(-100, 0.1, 1.0), (42, 0.2, 0.85)] - ) - def test_correctness( - B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, alpha - ): - B = 2 * B # cpo loss requires B to be even - - torch_lm_head_cpo = TorchLMHeadCPO( - H=H, - V=V, - dtype=dtype, - bias=bias, - ignore_index=ignore_index, - beta=beta, - ) - liger_lm_head_cpo = LigerLMHeadCPO( - H=H, - V=V, - dtype=dtype, - bias=bias, - ignore_index=ignore_index, - beta=beta, - ) - - torch_lm_head_cpo.lin.weight.data = liger_lm_head_cpo.lin.weight.data = torch.randn( - V, H, device=device, dtype=dtype - ) - - if bias: - torch_lm_head_cpo.lin.bias.data = liger_lm_head_cpo.lin.bias.data = torch.randn( - V, device=device, dtype=dtype - ) - - _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar - input1 = _input.detach().clone().requires_grad_(True) - input2 = _input.detach().clone().requires_grad_(True) - - target = torch.randint( - 0, - V, - ( - B, - T, - ), - device=device, - dtype=torch.long, - ) - # Assign some random number of elements as ignore_index - num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() - indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] - target.view(-1)[indices_to_assign] = ignore_index - - loss1, aggregated_aux_outputs1 = torch_lm_head_cpo(input1, target) - loss2, aggregated_aux_outputs2 = liger_lm_head_cpo(input2, target) - - assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) - - assert len(aggregated_aux_outputs1) == len(aggregated_aux_outputs2) - - for i in range(len(aggregated_aux_outputs1)): - assert_verbose_allclose( - aggregated_aux_outputs1[i], - aggregated_aux_outputs2[i], - atol=atol, - rtol=rtol, - ) - - loss1.backward() -> loss2.backward() - -test/chunked_loss/test_cpo_loss.py:214: -_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ -../../../miniconda3/lib/python3.12/site-packages/torch/_tensor.py:581: in backward - torch.autograd.backward( -../../../miniconda3/lib/python3.12/site-packages/torch/autograd/__init__.py:347: in backward - _engine_run_backward( -_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
_ _ _ _ _ _ _ _ _ _ _ - -t_outputs = (tensor(16.4060, device='cuda:0', grad_fn=),) -args = ((tensor(1., device='cuda:0'),), False, False, ()) -kwargs = {'accumulate_grad': True, 'allow_unreachable': True} -attach_logging_hooks = False - - def _engine_run_backward( - t_outputs: Sequence[Union[torch.Tensor, GradientEdge]], - *args: Any, - **kwargs: Any, - ) -> Tuple[torch.Tensor, ...]: - attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG - if attach_logging_hooks: - unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs) - try: -> return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass - t_outputs, *args, **kwargs - ) # Calls into the C++ engine to run the backward pass -E RuntimeError: function LigerFusedLinearCPOFunctionBackward returned an incorrect number of gradients (expected 10, got 9) - -../../../miniconda3/lib/python3.12/site-packages/torch/autograd/graph.py:825: RuntimeError -_ test_correctness[-100-0.1-1.0-False-1.0-dtype0-0.005-0.005-8-128-1024-4096] __ - -B = 16, T = 128, H = 1024, V = 4096, scalar = 1.0, dtype = torch.bfloat16 -atol = 0.005, rtol = 0.005, bias = False, ignore_index = -100, beta = 0.1 -alpha = 1.0 - - @pytest.mark.parametrize( - "B, T, H, V", - [ - (8, 128, 1024, 4096), - (3, 47, 31, 123), # random shape - ], - ) - @pytest.mark.parametrize( - "scalar, dtype, atol, rtol", - [ - (1.0, torch.bfloat16, 5e-3, 5e-3), - (1.0, torch.float32, 1e-5, 5e-4), - ], - ) - @pytest.mark.parametrize("bias", [True, False]) - @pytest.mark.parametrize( - "ignore_index, beta, alpha", [(-100, 0.1, 1.0), (42, 0.2, 0.85)] - ) - def test_correctness( - B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, alpha - ): - B = 2 * B # cpo loss requires B to be even - - torch_lm_head_cpo = TorchLMHeadCPO( - H=H, - V=V, - dtype=dtype, - bias=bias, - ignore_index=ignore_index, - beta=beta, - ) - liger_lm_head_cpo = LigerLMHeadCPO( - H=H, - V=V, - dtype=dtype, - bias=bias, - ignore_index=ignore_index, - beta=beta, - ) - - torch_lm_head_cpo.lin.weight.data = liger_lm_head_cpo.lin.weight.data = torch.randn( - V, H, device=device, dtype=dtype - ) - - if bias: - torch_lm_head_cpo.lin.bias.data = liger_lm_head_cpo.lin.bias.data = torch.randn( - V, device=device, dtype=dtype - ) - - _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar - input1 = _input.detach().clone().requires_grad_(True) - input2 = _input.detach().clone().requires_grad_(True) - - target = torch.randint( - 0, - V, - ( - B, - T, - ), - device=device, - dtype=torch.long, - ) - # Assign some random number of elements as ignore_index - num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() - indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] - target.view(-1)[indices_to_assign] = ignore_index - - loss1, aggregated_aux_outputs1 = torch_lm_head_cpo(input1, target) - loss2, aggregated_aux_outputs2 = liger_lm_head_cpo(input2, target) - - assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) - - assert len(aggregated_aux_outputs1) == len(aggregated_aux_outputs2) - - for i in range(len(aggregated_aux_outputs1)): - assert_verbose_allclose( - aggregated_aux_outputs1[i], - aggregated_aux_outputs2[i], - atol=atol, - rtol=rtol, - ) - - loss1.backward() -> loss2.backward() - -test/chunked_loss/test_cpo_loss.py:214: -_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ -../../../miniconda3/lib/python3.12/site-packages/torch/_tensor.py:581: in backward - torch.autograd.backward( 
-../../../miniconda3/lib/python3.12/site-packages/torch/autograd/__init__.py:347: in backward
-    _engine_run_backward(
-_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
-
->       return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
-E       RuntimeError: function LigerFusedLinearCPOFunctionBackward returned an incorrect number of gradients (expected 10, got 9)
-
-../../../miniconda3/lib/python3.12/site-packages/torch/autograd/graph.py:825: RuntimeError
-
-[each of the 16 failing test_correctness parametrizations raises this same RuntimeError from loss2.backward() at test/chunked_loss/test_cpo_loss.py:214]
-
------------------------------ Captured stderr call -----------------------------
-W1211 23:18:09.758000 217831 site-packages/torch/_dynamo/convert_frame.py:844] [0/8] torch._dynamo hit config.cache_size_limit (8)
-W1211 23:18:09.758000 217831 site-packages/torch/_dynamo/convert_frame.py:844] [0/8] function: 'fused_fwd_bwd' (/home/ryan/Documents/GitHub/Liger-Kernel/src/liger_kernel/chunked_loss/fused_linear_preference.py:103)
-W1211 23:18:09.758000 217831 site-packages/torch/_dynamo/convert_frame.py:844] [0/8] last reason: 0/0: L['compute_loss'].keywords['ignore_index'] == -100
-W1211 23:18:09.758000 217831 site-packages/torch/_dynamo/convert_frame.py:844] [0/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
-W1211 23:18:09.758000 217831 site-packages/torch/_dynamo/convert_frame.py:844] [0/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html.
-=============================== warnings summary ===============================
-../../../miniconda3/lib/python3.12/site-packages/_pytest/config/__init__.py:1441
-  /home/ryan/miniconda3/lib/python3.12/site-packages/_pytest/config/__init__.py:1441: PytestConfigWarning: Unknown config option: asyncio_mode
-
-    self._warn_or_fail_if_strict(f"Unknown config option: {key}\n")
-
--- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
-=========================== short test summary info ============================
-FAILED test/chunked_loss/test_cpo_loss.py::test_correctness[-100-0.1-1.0-True-1.0-dtype0-0.005-0.005-8-128-1024-4096]
-FAILED test/chunked_loss/test_cpo_loss.py::test_correctness[-100-0.1-1.0-True-1.0-dtype0-0.005-0.005-3-47-31-123]
-FAILED test/chunked_loss/test_cpo_loss.py::test_correctness[-100-0.1-1.0-True-1.0-dtype1-1e-05-0.0005-8-128-1024-4096]
-FAILED test/chunked_loss/test_cpo_loss.py::test_correctness[-100-0.1-1.0-True-1.0-dtype1-1e-05-0.0005-3-47-31-123]
-FAILED test/chunked_loss/test_cpo_loss.py::test_correctness[-100-0.1-1.0-False-1.0-dtype0-0.005-0.005-8-128-1024-4096]
-FAILED test/chunked_loss/test_cpo_loss.py::test_correctness[-100-0.1-1.0-False-1.0-dtype0-0.005-0.005-3-47-31-123]
-FAILED test/chunked_loss/test_cpo_loss.py::test_correctness[-100-0.1-1.0-False-1.0-dtype1-1e-05-0.0005-8-128-1024-4096]
-FAILED test/chunked_loss/test_cpo_loss.py::test_correctness[-100-0.1-1.0-False-1.0-dtype1-1e-05-0.0005-3-47-31-123]
-FAILED test/chunked_loss/test_cpo_loss.py::test_correctness[42-0.2-0.85-True-1.0-dtype0-0.005-0.005-8-128-1024-4096]
-FAILED test/chunked_loss/test_cpo_loss.py::test_correctness[42-0.2-0.85-True-1.0-dtype0-0.005-0.005-3-47-31-123]
-FAILED test/chunked_loss/test_cpo_loss.py::test_correctness[42-0.2-0.85-True-1.0-dtype1-1e-05-0.0005-8-128-1024-4096]
-FAILED test/chunked_loss/test_cpo_loss.py::test_correctness[42-0.2-0.85-True-1.0-dtype1-1e-05-0.0005-3-47-31-123]
-FAILED test/chunked_loss/test_cpo_loss.py::test_correctness[42-0.2-0.85-False-1.0-dtype0-0.005-0.005-8-128-1024-4096]
-FAILED test/chunked_loss/test_cpo_loss.py::test_correctness[42-0.2-0.85-False-1.0-dtype0-0.005-0.005-3-47-31-123]
-FAILED test/chunked_loss/test_cpo_loss.py::test_correctness[42-0.2-0.85-False-1.0-dtype1-1e-05-0.0005-8-128-1024-4096]
-FAILED 
test/chunked_loss/test_cpo_loss.py::test_correctness[42-0.2-0.85-False-1.0-dtype1-1e-05-0.0005-3-47-31-123] -=================== 16 failed, 8 passed, 1 warning in 6.01s ==================== From 37d7fc4646dc778c99b4a6e510930cf56728d1e8 Mon Sep 17 00:00:00 2001 From: ryankert01 Date: Thu, 12 Dec 2024 11:02:14 +0800 Subject: [PATCH 07/10] checkstyle --- src/liger_kernel/chunked_loss/cpo_loss.py | 7 ++++--- src/liger_kernel/chunked_loss/dpo_loss.py | 7 ++++--- src/liger_kernel/chunked_loss/fused_linear_preference.py | 4 ++-- src/liger_kernel/chunked_loss/orpo_loss.py | 7 ++++--- src/liger_kernel/chunked_loss/simpo_loss.py | 7 ++++--- 5 files changed, 18 insertions(+), 14 deletions(-) diff --git a/src/liger_kernel/chunked_loss/cpo_loss.py b/src/liger_kernel/chunked_loss/cpo_loss.py index 7daae3816..6e5e263ef 100644 --- a/src/liger_kernel/chunked_loss/cpo_loss.py +++ b/src/liger_kernel/chunked_loss/cpo_loss.py @@ -1,6 +1,7 @@ +from typing import Optional + import torch import torch.nn.functional as F -from typing import Optional from liger_kernel.chunked_loss.fused_linear_preference import ( LigerFusedLinearPreferenceBase, @@ -62,7 +63,7 @@ def forward( beta=beta, compute_nll_loss=compute_nll_loss, compiled=compiled, - softcap=softcap + softcap=softcap, ) @staticmethod @@ -83,7 +84,7 @@ def __init__( alpha: float = 1.0, compute_nll_loss: bool = True, compiled: bool = True, - softcap: Optional[float] = None + softcap: Optional[float] = None, ): """ Args: diff --git a/src/liger_kernel/chunked_loss/dpo_loss.py b/src/liger_kernel/chunked_loss/dpo_loss.py index 4c3e9feee..fd17ea2ba 100644 --- a/src/liger_kernel/chunked_loss/dpo_loss.py +++ b/src/liger_kernel/chunked_loss/dpo_loss.py @@ -1,6 +1,7 @@ +from typing import Optional + import torch import torch.nn.functional as F -from typing import Optional from liger_kernel.chunked_loss.fused_linear_preference import ( LigerFusedLinearPreferenceBase, @@ -67,7 +68,7 @@ def forward( compute_nll_loss=True, compiled=True, use_ref_model=True, - softcap=None + softcap=None, ): return LigerFusedLinearPreferenceBase.forward( ctx=ctx, @@ -104,7 +105,7 @@ def __init__( compute_nll_loss: bool = True, compiled: bool = True, use_ref_model: bool = False, - softcap: Optional[float] = None + softcap: Optional[float] = None, ): """ Args: diff --git a/src/liger_kernel/chunked_loss/fused_linear_preference.py b/src/liger_kernel/chunked_loss/fused_linear_preference.py index 87398e3f6..d3ffc4870 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_preference.py +++ b/src/liger_kernel/chunked_loss/fused_linear_preference.py @@ -345,7 +345,7 @@ def _compute_loss( bias=bias, ignore_index=ignore_index, compute_nll_loss=compute_nll_loss, - softcap=softcap + softcap=softcap, ) chosen_nll_loss = ( chosen_nll_loss @@ -373,7 +373,7 @@ def _compute_loss( ref_bias, ignore_index=ignore_index, compute_nll_loss=False, # We don't need NLL loss for the reference model - softcap=softcap + softcap=softcap, ) loss_kwargs["ref_chosen_logps"] = ref_chosen_logps loss_kwargs["ref_rejected_logps"] = ref_rejected_logps diff --git a/src/liger_kernel/chunked_loss/orpo_loss.py b/src/liger_kernel/chunked_loss/orpo_loss.py index 9c652a010..594d49f0f 100644 --- a/src/liger_kernel/chunked_loss/orpo_loss.py +++ b/src/liger_kernel/chunked_loss/orpo_loss.py @@ -1,6 +1,7 @@ +from typing import Optional + import torch import torch.nn.functional as F -from typing import Optional from liger_kernel.chunked_loss.fused_linear_preference import ( LigerFusedLinearPreferenceBase, @@ -58,7 +59,7 @@ def forward( 
beta=0.1, compute_nll_loss=True, compiled=True, - softcap=None + softcap=None, ): return LigerFusedLinearPreferenceBase.forward( ctx=ctx, @@ -91,7 +92,7 @@ def __init__( beta: float = 0.1, compute_nll_loss: bool = True, compiled: bool = True, - softcap: Optional[float] = None + softcap: Optional[float] = None, ): """ Args: diff --git a/src/liger_kernel/chunked_loss/simpo_loss.py b/src/liger_kernel/chunked_loss/simpo_loss.py index 8a378288d..1953b04ca 100644 --- a/src/liger_kernel/chunked_loss/simpo_loss.py +++ b/src/liger_kernel/chunked_loss/simpo_loss.py @@ -1,6 +1,7 @@ +from typing import Optional + import torch import torch.nn.functional as F -from typing import Optional from liger_kernel.chunked_loss.fused_linear_preference import ( LigerFusedLinearPreferenceBase, @@ -67,7 +68,7 @@ def forward( beta=beta, compiled=compiled, gamma=gamma, - softcap=softcap + softcap=softcap, ) @staticmethod @@ -89,7 +90,7 @@ def __init__( compute_nll_loss: bool = True, compiled: bool = True, gamma: float = 0.5, - softcap: Optional[float] = None + softcap: Optional[float] = None, ): """ Args: From 0df301c3233fd20c960a0148d4faa0e0d375672d Mon Sep 17 00:00:00 2001 From: ryankert01 Date: Thu, 12 Dec 2024 23:23:35 +0800 Subject: [PATCH 08/10] fix dpo loss conflicts --- src/liger_kernel/chunked_loss/dpo_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/liger_kernel/chunked_loss/dpo_loss.py b/src/liger_kernel/chunked_loss/dpo_loss.py index 8900221d6..d22103f57 100644 --- a/src/liger_kernel/chunked_loss/dpo_loss.py +++ b/src/liger_kernel/chunked_loss/dpo_loss.py @@ -92,7 +92,7 @@ def forward( @staticmethod def backward(ctx, *grad_output): grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] - return *grads, None, None, None, None, None, None, None, None + return *grads, None, None, None, None, None, None, None, None, None class LigerFusedLinearDPOLoss(torch.nn.Module): From 6d2a9350dfcf15dbafdb0ee578ec5d917319b1ca Mon Sep 17 00:00:00 2001 From: ryankert01 Date: Fri, 13 Dec 2024 02:31:34 +0800 Subject: [PATCH 09/10] softcap + cpo test --- .../chunked_loss/fused_linear_preference.py | 12 +++------- test/chunked_loss/test_cpo_loss.py | 22 ++++++++++++++----- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/src/liger_kernel/chunked_loss/fused_linear_preference.py b/src/liger_kernel/chunked_loss/fused_linear_preference.py index c65702253..aca37bb1a 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_preference.py +++ b/src/liger_kernel/chunked_loss/fused_linear_preference.py @@ -183,6 +183,9 @@ def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None): if compiled: fused_fwd_bwd = torch.compile(fused_fwd_bwd) + if softcap is not None: + _input = softcap * torch.tanh(_input / softcap) + len_chosen = target.shape[0] // 2 chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE)) _chosen_input_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0) @@ -284,16 +287,11 @@ def chunk_forward( bias=None, ignore_index=-100, compute_nll_loss=True, - softcap=None, ): len_chosen_chunk = target_chunk.shape[0] // 2 logits_chunk = input_chunk @ weight.t() if bias is not None: logits_chunk = logits_chunk + bias - if softcap is not None: - logits_chunk = logits_chunk / softcap - logits_chunk = torch.tanh(logits_chunk) - logits_chunk = logits_chunk * softcap log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1) chosen_nll_loss = 0.0 @@ -343,7 +341,6 @@ def _compute_loss( ref_input_chunk=None, ref_weight=None, ref_bias=None, - softcap=None, 
**loss_kwargs, ): """ @@ -362,7 +359,6 @@ def _compute_loss( use_ref_model (bool): Whether to use a reference model for the alignment loss. ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size). ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,). - softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap). loss_kwargs (dict): Additional arguments for the loss function. """ ( @@ -378,7 +374,6 @@ def _compute_loss( bias=bias, ignore_index=ignore_index, compute_nll_loss=compute_nll_loss, - softcap=softcap, ) chosen_nll_loss = ( chosen_nll_loss @@ -406,7 +401,6 @@ def _compute_loss( ref_bias, ignore_index=ignore_index, compute_nll_loss=False, # We don't need NLL loss for the reference model - softcap=softcap, ) loss_kwargs["ref_chosen_logps"] = ref_chosen_logps loss_kwargs["ref_rejected_logps"] = ref_rejected_logps diff --git a/test/chunked_loss/test_cpo_loss.py b/test/chunked_loss/test_cpo_loss.py index f0fef7734..3ca2ca1bd 100644 --- a/test/chunked_loss/test_cpo_loss.py +++ b/test/chunked_loss/test_cpo_loss.py @@ -1,5 +1,5 @@ from test.utils import HFAlignmentLoss, assert_verbose_allclose, set_seed -from typing import Tuple +from typing import Tuple, Optional import pytest import torch @@ -88,6 +88,7 @@ def __init__( alpha: float = 1.0, loss_type: str = "sigmoid", simpo_gamma: float = 0.5, + softcap: Optional[float] = None, ): super().__init__() self.lin = torch.nn.Linear( @@ -99,8 +100,12 @@ def __init__( loss_type=loss_type, simpo_gamma=simpo_gamma, ).get_batch_loss_metrics + self.softcap = softcap def forward(self, x, y): + logits = self.lin(x).to(torch.float32) + if self.softcap is not None and self.softcap != 0.0: + logits = self.softcap * torch.tanh(logits / self.softcap) return self.cpo_loss(self.lin.weight, x, y, self.lin.bias) @@ -114,13 +119,14 @@ def __init__( ignore_index: int = -100, beta: float = 0.1, alpha: float = 1.0, + softcap: Optional[float] = None, ): super().__init__() self.lin = torch.nn.Linear( in_features=H, out_features=V, bias=bias, dtype=dtype ) self.cpo_loss = LigerFusedLinearCPOLoss( - ignore_index=ignore_index, beta=beta, alpha=alpha + ignore_index=ignore_index, beta=beta, alpha=alpha, softcap=softcap ) def forward(self, x, y): @@ -135,10 +141,12 @@ def forward(self, x, y): ], ) @pytest.mark.parametrize( - "scalar, dtype, atol, rtol", + "scalar, dtype, atol, rtol, softcap", [ - (1.0, torch.bfloat16, 5e-3, 5e-3), - (1.0, torch.float32, 1e-5, 5e-4), + (1.0, torch.bfloat16, 5e-3, 5e-3, None), + (1.0, torch.float32, 1e-5, 5e-4, None), + (1.0, torch.bfloat16, 5e-3, 5e-3, 30), + (1.0, torch.float32, 5e-3, 5e-3, 30), ], ) @pytest.mark.parametrize("bias", [True, False]) @@ -146,7 +154,7 @@ def forward(self, x, y): "ignore_index, beta, alpha", [(-100, 0.1, 1.0), (42, 0.2, 0.85)] ) def test_correctness( - B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, alpha + B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, alpha, softcap ): B = 2 * B # cpo loss requires B to be even @@ -157,6 +165,7 @@ def test_correctness( bias=bias, ignore_index=ignore_index, beta=beta, + softcap=softcap, ) liger_lm_head_cpo = LigerLMHeadCPO( H=H, @@ -165,6 +174,7 @@ def test_correctness( bias=bias, ignore_index=ignore_index, beta=beta, + softcap=softcap, ) torch_lm_head_cpo.lin.weight.data = liger_lm_head_cpo.lin.weight.data = torch.randn( From cb215b353801aecacd544f3551b4431ed044448d Mon Sep 17 00:00:00 2001 From: ryankert01 Date: Fri, 13 Dec 2024 02:45:48 +0800 
Subject: [PATCH 10/10] Revert "softcap + cpo test" This reverts commit 6d2a9350dfcf15dbafdb0ee578ec5d917319b1ca. --- .../chunked_loss/fused_linear_preference.py | 12 +++++++--- test/chunked_loss/test_cpo_loss.py | 22 +++++-------------- 2 files changed, 15 insertions(+), 19 deletions(-) diff --git a/src/liger_kernel/chunked_loss/fused_linear_preference.py b/src/liger_kernel/chunked_loss/fused_linear_preference.py index aca37bb1a..c65702253 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_preference.py +++ b/src/liger_kernel/chunked_loss/fused_linear_preference.py @@ -183,9 +183,6 @@ def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None): if compiled: fused_fwd_bwd = torch.compile(fused_fwd_bwd) - if softcap is not None: - _input = softcap * torch.tanh(_input / softcap) - len_chosen = target.shape[0] // 2 chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE)) _chosen_input_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0) @@ -287,11 +284,16 @@ def chunk_forward( bias=None, ignore_index=-100, compute_nll_loss=True, + softcap=None, ): len_chosen_chunk = target_chunk.shape[0] // 2 logits_chunk = input_chunk @ weight.t() if bias is not None: logits_chunk = logits_chunk + bias + if softcap is not None: + logits_chunk = logits_chunk / softcap + logits_chunk = torch.tanh(logits_chunk) + logits_chunk = logits_chunk * softcap log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1) chosen_nll_loss = 0.0 @@ -341,6 +343,7 @@ def _compute_loss( ref_input_chunk=None, ref_weight=None, ref_bias=None, + softcap=None, **loss_kwargs, ): """ @@ -359,6 +362,7 @@ def _compute_loss( use_ref_model (bool): Whether to use a reference model for the alignment loss. ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size). ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,). + softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap). loss_kwargs (dict): Additional arguments for the loss function. 
""" ( @@ -374,6 +378,7 @@ def _compute_loss( bias=bias, ignore_index=ignore_index, compute_nll_loss=compute_nll_loss, + softcap=softcap, ) chosen_nll_loss = ( chosen_nll_loss @@ -401,6 +406,7 @@ def _compute_loss( ref_bias, ignore_index=ignore_index, compute_nll_loss=False, # We don't need NLL loss for the reference model + softcap=softcap, ) loss_kwargs["ref_chosen_logps"] = ref_chosen_logps loss_kwargs["ref_rejected_logps"] = ref_rejected_logps diff --git a/test/chunked_loss/test_cpo_loss.py b/test/chunked_loss/test_cpo_loss.py index 3ca2ca1bd..f0fef7734 100644 --- a/test/chunked_loss/test_cpo_loss.py +++ b/test/chunked_loss/test_cpo_loss.py @@ -1,5 +1,5 @@ from test.utils import HFAlignmentLoss, assert_verbose_allclose, set_seed -from typing import Tuple, Optional +from typing import Tuple import pytest import torch @@ -88,7 +88,6 @@ def __init__( alpha: float = 1.0, loss_type: str = "sigmoid", simpo_gamma: float = 0.5, - softcap: Optional[float] = None, ): super().__init__() self.lin = torch.nn.Linear( @@ -100,12 +99,8 @@ def __init__( loss_type=loss_type, simpo_gamma=simpo_gamma, ).get_batch_loss_metrics - self.softcap = softcap def forward(self, x, y): - logits = self.lin(x).to(torch.float32) - if self.softcap is not None and self.softcap != 0.0: - logits = self.softcap * torch.tanh(logits / self.softcap) return self.cpo_loss(self.lin.weight, x, y, self.lin.bias) @@ -119,14 +114,13 @@ def __init__( ignore_index: int = -100, beta: float = 0.1, alpha: float = 1.0, - softcap: Optional[float] = None, ): super().__init__() self.lin = torch.nn.Linear( in_features=H, out_features=V, bias=bias, dtype=dtype ) self.cpo_loss = LigerFusedLinearCPOLoss( - ignore_index=ignore_index, beta=beta, alpha=alpha, softcap=softcap + ignore_index=ignore_index, beta=beta, alpha=alpha ) def forward(self, x, y): @@ -141,12 +135,10 @@ def forward(self, x, y): ], ) @pytest.mark.parametrize( - "scalar, dtype, atol, rtol, softcap", + "scalar, dtype, atol, rtol", [ - (1.0, torch.bfloat16, 5e-3, 5e-3, None), - (1.0, torch.float32, 1e-5, 5e-4, None), - (1.0, torch.bfloat16, 5e-3, 5e-3, 30), - (1.0, torch.float32, 5e-3, 5e-3, 30), + (1.0, torch.bfloat16, 5e-3, 5e-3), + (1.0, torch.float32, 1e-5, 5e-4), ], ) @pytest.mark.parametrize("bias", [True, False]) @@ -154,7 +146,7 @@ def forward(self, x, y): "ignore_index, beta, alpha", [(-100, 0.1, 1.0), (42, 0.2, 0.85)] ) def test_correctness( - B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, alpha, softcap + B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, alpha ): B = 2 * B # cpo loss requires B to be even @@ -165,7 +157,6 @@ def test_correctness( bias=bias, ignore_index=ignore_index, beta=beta, - softcap=softcap, ) liger_lm_head_cpo = LigerLMHeadCPO( H=H, @@ -174,7 +165,6 @@ def test_correctness( bias=bias, ignore_index=ignore_index, beta=beta, - softcap=softcap, ) torch_lm_head_cpo.lin.weight.data = liger_lm_head_cpo.lin.weight.data = torch.randn(