From 2d66515c0dd760b1e1dddda23e7029e67e568aaa Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Mon, 2 Dec 2024 16:46:31 +0800 Subject: [PATCH 01/14] Add weight support for LigerCrossEntropy --- src/liger_kernel/ops/cross_entropy.py | 49 +++++++++++++- .../transformers/cross_entropy.py | 3 + src/liger_kernel/transformers/functional.py | 1 + test/transformers/test_cross_entropy.py | 67 ++++++++++++++++++- 4 files changed, 116 insertions(+), 4 deletions(-) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index 2a980c69e..41b5757d6 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -27,11 +27,14 @@ def liger_cross_entropy_kernel( X_stride, Y_ptr, Y_stride, + weight_ptr, + weight_stride, loss_ptr, z_loss_ptr, loss_stride, n_cols, n_non_ignore, + sum_of_non_ignore_weight, ignore_index, lse_square_scale: tl.constexpr, label_smoothing: tl.constexpr, @@ -39,6 +42,7 @@ def liger_cross_entropy_kernel( softcap, RETURN_Z_LOSS: tl.constexpr, BLOCK_SIZE: tl.constexpr, + HAS_WEIGHT: tl.constexpr, HAS_SOFTCAPPING: tl.constexpr, ): """ @@ -86,6 +90,9 @@ def liger_cross_entropy_kernel( loss_ptr += program_id * loss_stride z_loss_ptr += program_id * loss_stride + if HAS_WEIGHT: + weight = tl.load(weight_ptr + y).cast(tl.float32) + # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax) # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867 @@ -162,7 +169,12 @@ def liger_cross_entropy_kernel( X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing)) # reduction scale if reduction == "mean": - X_block = X_block / (n_non_ignore) + if HAS_WEIGHT: + X_block = X_block / (sum_of_non_ignore_weight) + else: + X_block = X_block / (n_non_ignore) + if HAS_WEIGHT: + X_block = X_block * weight # chain rule # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap)) if HAS_SOFTCAPPING: @@ -201,8 +213,16 @@ def liger_cross_entropy_kernel( loss += z_loss # Normalize the loss by the number of non-ignored elements if reduction is "mean" if reduction == "mean": - z_loss = z_loss / n_non_ignore - loss = loss / n_non_ignore + if HAS_WEIGHT: + z_loss = z_loss / sum_of_non_ignore_weight + loss = loss / sum_of_non_ignore_weight + else: + z_loss = z_loss / n_non_ignore + loss = loss / n_non_ignore + + if HAS_WEIGHT: + z_loss = z_loss * weight + loss = loss * weight tl.store(loss_ptr, loss) if RETURN_Z_LOSS == _TRUE: @@ -224,6 +244,7 @@ def liger_cross_entropy_kernel( def cross_entropy_forward( _input, target, + weight, ignore_index, lse_square_scale, label_smoothing, @@ -254,6 +275,21 @@ def cross_entropy_forward( z_loss_1d = loss_1d # dummy ptr when return_z_loss == False n_non_ignore = (target != ignore_index).sum().item() + sum_of_non_ignore_weight = n_non_ignore + if weight is not None: + assert ( + weight.shape[0] == V + ), f"If given, weight has to be a Tensor of size V. Got: {weight.shape}" + assert torch.is_floating_point( + weight + ), f"If given, weight has to be a Tensor of floating point dtype. 
Got: {weight.dtype}" + selected_weight = torch.gather(weight, dim=-1, index=target) + if ignore_index >= 0 and ignore_index < V: + sum_of_non_ignore_weight = selected_weight.sum().item() + else: + sum_of_non_ignore_weight = selected_weight.sum().item() + if weight.stride(-1) != 1: + weight = weight.contiguous() # ensure _input and target are contiguous in the last dimension if _input.stride(-1) != 1: @@ -267,18 +303,22 @@ def cross_entropy_forward( X_stride=_input.stride(-2), Y_ptr=target, Y_stride=target.stride(-1), # always 1 + weight_ptr=weight if weight is not None else _input, # dummy if None + weight_stride=weight.stride(-1) if weight is not None else 0, loss_ptr=loss_1d, z_loss_ptr=z_loss_1d, loss_stride=loss_1d.stride(-1), # always 1 n_cols=V, n_non_ignore=n_non_ignore, ignore_index=ignore_index, + sum_of_non_ignore_weight=sum_of_non_ignore_weight, lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, reduction=reduction, softcap=softcap if softcap is not None else 0.0, RETURN_Z_LOSS=return_z_loss, BLOCK_SIZE=BLOCK_SIZE, + HAS_WEIGHT=True if weight is not None else False, HAS_SOFTCAPPING=True if softcap is not None else False, # TODO: 32 seems to give the best performance # Performance is quite sensitive to num_warps @@ -329,6 +369,7 @@ def forward( ctx, _input: torch.Tensor, target: torch.Tensor, + weight: Optional[torch.FloatTensor], ignore_index: int = -100, lse_square_scale: float = 0.0, label_smoothing: float = 0.0, @@ -356,6 +397,7 @@ def forward( loss, z_loss, _input = cross_entropy_forward( _input, target, + weight, ignore_index, lse_square_scale, label_smoothing, @@ -397,4 +439,5 @@ def backward(ctx, grad_output, grad_ouput2): None, None, None, + None, ) diff --git a/src/liger_kernel/transformers/cross_entropy.py b/src/liger_kernel/transformers/cross_entropy.py index 7bd27edd6..f3e51808c 100644 --- a/src/liger_kernel/transformers/cross_entropy.py +++ b/src/liger_kernel/transformers/cross_entropy.py @@ -8,6 +8,7 @@ class LigerCrossEntropyLoss(torch.nn.Module): def __init__( self, + weight: Optional[torch.FloatTensor] = None, ignore_index: int = -100, lse_square_scale: float = 0.0, label_smoothing: float = 0.0, @@ -30,6 +31,7 @@ def __init__( assert ( softcap is None or softcap > 0 ), f"softcap must greater than 0.0 or None. 
Got: {softcap}" + self.weight = weight self.ignore_index = ignore_index self.lse_square_scale = lse_square_scale self.label_smoothing = label_smoothing @@ -41,6 +43,7 @@ def forward(self, _input: torch.Tensor, target: torch.Tensor): loss, z_loss = LigerCrossEntropyFunction.apply( _input, target, + self.weight, self.ignore_index, self.lse_square_scale, self.label_smoothing, diff --git a/src/liger_kernel/transformers/functional.py b/src/liger_kernel/transformers/functional.py index 45ad6159a..5d6086caa 100644 --- a/src/liger_kernel/transformers/functional.py +++ b/src/liger_kernel/transformers/functional.py @@ -34,6 +34,7 @@ def liger_cross_entropy( loss, z_loss = LigerCrossEntropyFunction.apply( input, target, + weight, ignore_index, lse_square_scale, label_smoothing, diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py index 28e3ec5dc..1df09a321 100644 --- a/test/transformers/test_cross_entropy.py +++ b/test/transformers/test_cross_entropy.py @@ -301,6 +301,27 @@ def _test_correctness_with_z_loss_with_other_params_once( assert_verbose_allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) +def _test_correctness_with_weight_once( + target_ce, B, T, V, reduction, weight, scalar, dtype, atol, rtol +): + torch.manual_seed(0) + torch_ce = CrossEntropyLoss(weight=weight, reduction=reduction) + + _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar + _input = _tensor.detach().clone().requires_grad_(True) + _input2 = _tensor.detach().clone().requires_grad_(True) + + target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) + + output = torch_ce(_input, target) + output2 = target_ce(_input2, target) + assert torch.allclose(output, output2, atol=atol, rtol=rtol) + + output.backward() + output2.backward() + assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) + + def _test_correctness_not_last_layer_once( target_ce, B, T, V, reduction, scalar, dtype, atol, rtol ): @@ -345,6 +366,7 @@ def _test_correctness_functional( y1, y1_z = liger_cross_entropy( x1, target, + None, ignore_index=0, lse_square_scale=1e-4, label_smoothing=0.1, @@ -353,7 +375,7 @@ def _test_correctness_functional( return_z_loss=True, ) y2, y2_z = LigerCrossEntropyFunction.apply( - x2, target, 0, 1e-4, 0.1, "mean", 30.0, True + x2, target, None, 0, 1e-4, 0.1, "mean", 30.0, True ) assert torch.allclose(y1, y2, atol=atol, rtol=rtol) @@ -687,6 +709,41 @@ def test_correctness_with_z_loss_with_other_params_once( ) +@pytest.mark.parametrize( + "B, T, V", + [ + (2, 4096, 32000), # llama2, mistral + # # weird shapes + (3, 423, 32000), + ], +) +@pytest.mark.parametrize("weight", [0.5, 0.1]) +@pytest.mark.parametrize("reduction", ["sum", "mean"]) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + pytest.param( + 1.0, + torch.bfloat16, + 1e-8, + 5e-2, + marks=pytest.mark.skipif( + not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + ), + ), + (1.0, torch.float32, 1e-8, 1e-6), + ], +) +def test_correctness_with_weight_once( + B, T, V, weight, reduction, scalar, dtype, atol, rtol +): + weight = torch.rand(V, device=device, dtype=dtype) + test_ce = LigerCrossEntropyLoss(weight=weight, reduction=reduction) + _test_correctness_with_weight_once( + test_ce, B, T, V, reduction, weight, scalar, dtype, atol, rtol + ) + + @pytest.mark.parametrize( "B, T, V", [ @@ -746,17 +803,21 @@ def test_float32_internal(): X_stride=X_bf16.stride(-2), Y_ptr=Y, Y_stride=Y.stride(-1), + weight_ptr=X_bf16, # dummy ptr, not used + 
weight_stride=X_bf16.stride(-2), z_loss_ptr=loss_bf16, # dummy ptr, not used loss_ptr=loss_bf16, loss_stride=loss_bf16.stride(-1), n_cols=n_cols, n_non_ignore=n_non_ignore, + sum_of_non_ignore_weight=n_non_ignore, # not used ignore_index=ignore_index, lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, reduction=reduction, softcap=softcap, RETURN_Z_LOSS=0, # False + HAS_WEIGHT=False, HAS_SOFTCAPPING=False, BLOCK_SIZE=BLOCK_SIZE, num_warps=32, @@ -770,17 +831,21 @@ def test_float32_internal(): X_stride=X_fp32.stride(-2), Y_ptr=Y, Y_stride=Y.stride(-1), + weight_ptr=X_fp32, # dummy ptr, not used + weight_stride=X_fp32.stride(-2), loss_ptr=loss_fp32, z_loss_ptr=loss_fp32, # dummy ptr, not used loss_stride=loss_fp32.stride(-1), n_cols=n_cols, n_non_ignore=n_non_ignore, + sum_of_non_ignore_weight=n_non_ignore, # not used ignore_index=ignore_index, lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, reduction=reduction, softcap=softcap, RETURN_Z_LOSS=0, # False + HAS_WEIGHT=False, HAS_SOFTCAPPING=False, BLOCK_SIZE=BLOCK_SIZE, num_warps=32, From dbe42378356ec123aba55a608a30e1502089813c Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Mon, 2 Dec 2024 20:12:48 +0800 Subject: [PATCH 02/14] Update cross_entropy_kernel args in flce --- src/liger_kernel/ops/cross_entropy.py | 8 +++++--- src/liger_kernel/ops/fused_linear_cross_entropy.py | 4 ++++ 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index 41b5757d6..3e1782c1f 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -283,11 +283,13 @@ def cross_entropy_forward( assert torch.is_floating_point( weight ), f"If given, weight has to be a Tensor of floating point dtype. 
Got: {weight.dtype}" - selected_weight = torch.gather(weight, dim=-1, index=target) if ignore_index >= 0 and ignore_index < V: - sum_of_non_ignore_weight = selected_weight.sum().item() + weight_mask = torch.ones_like(weight) + weight_mask[ignore_index] = 0 + selected_weight = torch.gather(weight * weight_mask, dim=-1, index=target) else: - sum_of_non_ignore_weight = selected_weight.sum().item() + selected_weight = torch.gather(weight, dim=-1, index=target) + sum_of_non_ignore_weight = selected_weight.sum().item() if weight.stride(-1) != 1: weight = weight.contiguous() diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py index 191a2b3d2..a3d0406f1 100644 --- a/src/liger_kernel/ops/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py @@ -83,17 +83,21 @@ def fused_linear_cross_entropy_forward( X_stride=logits_chunk.stride(-2), Y_ptr=target_chunk, Y_stride=target_chunk.stride(-1), # always 1 + weight_ptr=_input, # dummy ptr, not used + weight_stride=0, loss_ptr=loss_1d_slice, z_loss_ptr=loss_1d_slice, # dummy ptr, not used loss_stride=loss_1d_slice.stride(-1), # always 1 n_cols=V, n_non_ignore=n_non_ignore, + sum_of_non_ignore_weight=n_non_ignore, ignore_index=ignore_index, lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, reduction=reduction, softcap=softcap if softcap is not None else 0.0, RETURN_Z_LOSS=0, # False + HAS_WEIGHT=False, HAS_SOFTCAPPING=True if softcap is not None else False, BLOCK_SIZE=BLOCK_SIZE, num_warps=32 if not is_hip() else 16, From e77018258c89cbb13e1fe2006da77cfb4f64086f Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Mon, 2 Dec 2024 20:26:34 +0800 Subject: [PATCH 03/14] Add comments --- src/liger_kernel/ops/cross_entropy.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index 3e1782c1f..35644bf68 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -54,11 +54,14 @@ def liger_cross_entropy_kernel( X_stride (int): The stride of the input tensor. Y_ptr: Pointer to target tensor. Y_stride (int): The stride of the target tensor. + weight_ptr: Pointer to weight tensor. + weight_stride (int): The stride of the weight tesnor. loss_ptr: Pointer to tensor to store the loss. z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0. loss_stride (int): The stride of the loss tensor. n_cols (int): The number of columns in the input tensor. n_non_ignore (int): The number of non-ignored elements in the batch. + sum_of_non_ignore_weight (float): The denominator when `reduction="mean"` if `weight` is given. ignore_index (int): The index to ignore in the target. label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training. @@ -66,6 +69,7 @@ def liger_cross_entropy_kernel( reduction (str): The string for the reduction to apply softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap). BLOCK_SIZE (int): The block size for Triton operations. + HAS_WEIGHT (bool): The boolean value to dteremine whether assigning weight to each of the classes. HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not. """ @@ -386,6 +390,7 @@ def forward( ctx : The context object. 
_input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size. target (tensor): The target tensor of shape (BT) where each value is in [0, V-1]. + weight(Tensor, optional): a manual rescaling weight given to each class. If given, has to be a Tensor of size C and floating point dtype ignore_index (int): The index to ignore in the target. lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training. label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. From f38e1e26e9bdf2d6aee772aeef2a58a2c5aa8e1c Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Tue, 3 Dec 2024 01:01:46 +0800 Subject: [PATCH 04/14] Add complete test with other params --- src/liger_kernel/ops/cross_entropy.py | 6 +- test/transformers/test_cross_entropy.py | 163 ++++++++++++++++++++++-- 2 files changed, 155 insertions(+), 14 deletions(-) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index 35644bf68..d3e834895 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -28,7 +28,6 @@ def liger_cross_entropy_kernel( Y_ptr, Y_stride, weight_ptr, - weight_stride, loss_ptr, z_loss_ptr, loss_stride, @@ -69,7 +68,7 @@ def liger_cross_entropy_kernel( reduction (str): The string for the reduction to apply softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap). BLOCK_SIZE (int): The block size for Triton operations. - HAS_WEIGHT (bool): The boolean value to dteremine whether assigning weight to each of the classes. + HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes. HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not. """ @@ -310,7 +309,6 @@ def cross_entropy_forward( Y_ptr=target, Y_stride=target.stride(-1), # always 1 weight_ptr=weight if weight is not None else _input, # dummy if None - weight_stride=weight.stride(-1) if weight is not None else 0, loss_ptr=loss_1d, z_loss_ptr=z_loss_1d, loss_stride=loss_1d.stride(-1), # always 1 @@ -390,7 +388,7 @@ def forward( ctx : The context object. _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size. target (tensor): The target tensor of shape (BT) where each value is in [0, V-1]. - weight(Tensor, optional): a manual rescaling weight given to each class. If given, has to be a Tensor of size C and floating point dtype + weight(Tensor, optional): a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype ignore_index (int): The index to ignore in the target. lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training. label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. 
diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py index 1df09a321..fd8df6ea9 100644 --- a/test/transformers/test_cross_entropy.py +++ b/test/transformers/test_cross_entropy.py @@ -20,6 +20,7 @@ class CrossEntropyWithZLoss(torch.nn.Module): def __init__( self, + weight=None, lse_square_scale=0.0, reduction="mean", ignore_index=-100, @@ -28,6 +29,7 @@ def __init__( dtype=torch.float32, ): super().__init__() + self.weight = weight self.lse_square_scale = lse_square_scale self.reduction = reduction self.ignore_index = ignore_index @@ -38,10 +40,24 @@ def __init__( def forward(self, logits, targets): # Loss calculations are all in float32 logits = logits.to(torch.float32) + HAS_WEIGHT = True if self.weight is not None else False + if HAS_WEIGHT: + self.weight = self.weight.to(torch.float32) + if self.ignore_index >= 0 and self.ignore_index < logits.shape[-1]: + weight_mask = torch.ones_like(self.weight) + weight_mask[self.ignore_index] = 0 + selected_weight = torch.gather( + self.weight * weight_mask, dim=-1, index=targets + ) + del weight_mask + else: + selected_weight = torch.gather(self.weight, dim=-1, index=targets) + sum_of_non_ignore_weight = selected_weight.sum().item() # Standard cross entropy loss ce_loss = F.cross_entropy( logits, targets, + weight=self.weight, reduction=self.reduction, label_smoothing=self.label_smoothing, ignore_index=self.ignore_index, @@ -54,9 +70,14 @@ def forward(self, logits, targets): z_loss = torch.where( targets != self.ignore_index, self.lse_square_scale * (lse**2), 0.0 ) - z_loss = z_loss.to(logits.dtype) + if HAS_WEIGHT: + z_loss = z_loss * selected_weight + if self.reduction == "mean": - z_loss = z_loss.sum() / (targets != self.ignore_index).sum() + if HAS_WEIGHT: + z_loss = z_loss.sum() / sum_of_non_ignore_weight + else: + z_loss = z_loss.sum() / (targets != self.ignore_index).sum() elif self.reduction == "sum": z_loss = z_loss.sum() else: @@ -185,13 +206,15 @@ def _test_correctness_with_softcap_once( _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar # upcasting to match liger's casting strategy - _input = _tensor.to(torch.float32).detach().clone().requires_grad_(True) + _input = _tensor.detach().clone().requires_grad_(True) _input2 = _tensor.detach().clone().requires_grad_(True) target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) # downcasting to original dtype - output = torch_ce(softcap * torch.tanh(_input / softcap), target).to(dtype) + output = torch_ce( + softcap * torch.tanh(_input.to(torch.float32) / softcap), target + ).to(dtype) output2 = target_ce(_input2, target) assert torch.allclose(output, output2, atol=atol, rtol=rtol) @@ -322,6 +345,59 @@ def _test_correctness_with_weight_once( assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) +def _test_correctness_with_weight_with_other_params_once( + target_ce, + B, + T, + V, + reduction, + weight, + lse_square_scale, + ignore_index, + label_smoothing, + softcap, + scalar, + dtype, + atol, + rtol, +): + torch.manual_seed(0) + torch_ce = CrossEntropyWithZLoss( + weight=weight, + lse_square_scale=lse_square_scale, + ignore_index=ignore_index, + reduction=reduction, + label_smoothing=label_smoothing, + dtype=dtype, + ) + + _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar + # upcasting to match liger's casting strategy + _input = _tensor.detach().clone().requires_grad_(True) + _input2 = _tensor.detach().clone().requires_grad_(True) + + target = torch.randint(0, V, (B * T,), 
device=device, dtype=torch.long) + + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint( + 1, B * T // 2, (1,) + ).item() # Random number of elements to set to ignore_index + indices_to_assign = torch.randperm(B * T)[ + :num_elements_to_assign + ] # Randomly select indices + target[indices_to_assign] = ignore_index + + output = torch_ce( + softcap * torch.tanh(_input.to(torch.float32) / softcap), target + ).to(dtype) + output2 = target_ce(_input2, target) + assert torch.allclose(output, output2, atol=atol, rtol=rtol) + + output.backward() + output2.backward() + assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) + + def _test_correctness_not_last_layer_once( target_ce, B, T, V, reduction, scalar, dtype, atol, rtol ): @@ -717,7 +793,6 @@ def test_correctness_with_z_loss_with_other_params_once( (3, 423, 32000), ], ) -@pytest.mark.parametrize("weight", [0.5, 0.1]) @pytest.mark.parametrize("reduction", ["sum", "mean"]) @pytest.mark.parametrize( "scalar, dtype, atol, rtol", @@ -734,9 +809,7 @@ def test_correctness_with_z_loss_with_other_params_once( (1.0, torch.float32, 1e-8, 1e-6), ], ) -def test_correctness_with_weight_once( - B, T, V, weight, reduction, scalar, dtype, atol, rtol -): +def test_correctness_with_weight_once(B, T, V, reduction, scalar, dtype, atol, rtol): weight = torch.rand(V, device=device, dtype=dtype) test_ce = LigerCrossEntropyLoss(weight=weight, reduction=reduction) _test_correctness_with_weight_once( @@ -744,6 +817,78 @@ def test_correctness_with_weight_once( ) +@pytest.mark.parametrize( + "B, T, V", + [ + (2, 4096, 3200), # llama2, mistral + # # weird shapes + (3, 423, 3200), + ], +) +@pytest.mark.parametrize("reduction", ["sum", "mean"]) +@pytest.mark.parametrize( + "ignore_index, lse_square_scale, label_smoothing, softcap", + [ + (-100, 1e-4, 0.1, 30.0), + (42, 1e-5, 0.2, 40.0), + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + pytest.param( + 1.0, + torch.bfloat16, + 1e-8, + 5e-2, + marks=pytest.mark.skipif( + not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + ), + ), + (1.0, torch.float32, 1e-8, 1e-6), + ], +) +def test_correctness_with_weight_with_other_params_once( + B, + T, + V, + reduction, + lse_square_scale, + ignore_index, + label_smoothing, + softcap, + scalar, + dtype, + atol, + rtol, +): + weight = torch.rand(V, device=device, dtype=torch.float32) # match softcap casting + test_ce = LigerCrossEntropyLoss( + weight=weight, + lse_square_scale=lse_square_scale, + reduction=reduction, + ignore_index=ignore_index, + label_smoothing=label_smoothing, + softcap=softcap, + ) + _test_correctness_with_weight_with_other_params_once( + test_ce, + B, + T, + V, + reduction, + weight, + lse_square_scale, + ignore_index, + label_smoothing, + softcap, + scalar, + dtype, + atol, + rtol, + ) + + @pytest.mark.parametrize( "B, T, V", [ @@ -804,7 +949,6 @@ def test_float32_internal(): Y_ptr=Y, Y_stride=Y.stride(-1), weight_ptr=X_bf16, # dummy ptr, not used - weight_stride=X_bf16.stride(-2), z_loss_ptr=loss_bf16, # dummy ptr, not used loss_ptr=loss_bf16, loss_stride=loss_bf16.stride(-1), @@ -832,7 +976,6 @@ def test_float32_internal(): Y_ptr=Y, Y_stride=Y.stride(-1), weight_ptr=X_fp32, # dummy ptr, not used - weight_stride=X_fp32.stride(-2), loss_ptr=loss_fp32, z_loss_ptr=loss_fp32, # dummy ptr, not used loss_stride=loss_fp32.stride(-1), From 45f6c1f9db0938a3d380a86452e3d237056ccbda Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Tue, 
3 Dec 2024 02:07:08 +0800 Subject: [PATCH 05/14] Fix invalid range access bug --- src/liger_kernel/ops/cross_entropy.py | 12 +++++------- test/transformers/test_cross_entropy.py | 21 ++++++++++++--------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index d3e834895..ea67ba6cd 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -277,7 +277,8 @@ def cross_entropy_forward( else: z_loss_1d = loss_1d # dummy ptr when return_z_loss == False - n_non_ignore = (target != ignore_index).sum().item() + target_mask = target != ignore_index + n_non_ignore = target_mask.sum().item() sum_of_non_ignore_weight = n_non_ignore if weight is not None: assert ( @@ -286,12 +287,9 @@ def cross_entropy_forward( assert torch.is_floating_point( weight ), f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}" - if ignore_index >= 0 and ignore_index < V: - weight_mask = torch.ones_like(weight) - weight_mask[ignore_index] = 0 - selected_weight = torch.gather(weight * weight_mask, dim=-1, index=target) - else: - selected_weight = torch.gather(weight, dim=-1, index=target) + selected_weight = torch.where( + target_mask, torch.gather(weight, dim=0, index=target * target_mask), 0.0 + ) sum_of_non_ignore_weight = selected_weight.sum().item() if weight.stride(-1) != 1: weight = weight.contiguous() diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py index fd8df6ea9..ff8302569 100644 --- a/test/transformers/test_cross_entropy.py +++ b/test/transformers/test_cross_entropy.py @@ -41,18 +41,17 @@ def forward(self, logits, targets): # Loss calculations are all in float32 logits = logits.to(torch.float32) HAS_WEIGHT = True if self.weight is not None else False + + target_mask = targets != self.ignore_index if HAS_WEIGHT: self.weight = self.weight.to(torch.float32) - if self.ignore_index >= 0 and self.ignore_index < logits.shape[-1]: - weight_mask = torch.ones_like(self.weight) - weight_mask[self.ignore_index] = 0 - selected_weight = torch.gather( - self.weight * weight_mask, dim=-1, index=targets - ) - del weight_mask - else: - selected_weight = torch.gather(self.weight, dim=-1, index=targets) + selected_weight = torch.where( + target_mask, + torch.gather(self.weight, dim=-1, index=targets * target_mask), + 0.0, + ) sum_of_non_ignore_weight = selected_weight.sum().item() + # Standard cross entropy loss ce_loss = F.cross_entropy( logits, @@ -71,7 +70,11 @@ def forward(self, logits, targets): targets != self.ignore_index, self.lse_square_scale * (lse**2), 0.0 ) if HAS_WEIGHT: + # print(f"{z_loss.shape=}") z_loss = z_loss * selected_weight + # print(f"{selected_weight.shape=}") + # print(f"{selected_weight[targets == self.ignore_index]=}") + # print(f"{selected_weight[targets != self.ignore_index]=}") if self.reduction == "mean": if HAS_WEIGHT: From a1a4f0ac85a831a8b00b6d3ff984b962a15c92b9 Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Sun, 15 Dec 2024 13:39:34 +0800 Subject: [PATCH 06/14] Refactor variable names and computation of target's weights --- src/liger_kernel/ops/cross_entropy.py | 58 ++++++++++++++------------- 1 file changed, 30 insertions(+), 28 deletions(-) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index ea67ba6cd..1998e7831 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ 
-33,7 +33,7 @@ def liger_cross_entropy_kernel( loss_stride, n_cols, n_non_ignore, - sum_of_non_ignore_weight, + weight_sum, ignore_index, lse_square_scale: tl.constexpr, label_smoothing: tl.constexpr, @@ -54,13 +54,13 @@ def liger_cross_entropy_kernel( Y_ptr: Pointer to target tensor. Y_stride (int): The stride of the target tensor. weight_ptr: Pointer to weight tensor. - weight_stride (int): The stride of the weight tesnor. + weight_stride (int): The stride of the weight tensor. loss_ptr: Pointer to tensor to store the loss. z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0. loss_stride (int): The stride of the loss tensor. n_cols (int): The number of columns in the input tensor. - n_non_ignore (int): The number of non-ignored elements in the batch. - sum_of_non_ignore_weight (float): The denominator when `reduction="mean"` if `weight` is given. + n_non_ignore (flaot): The number of non-ignored elements or the sum of non-ignored target's weights in the batch + weight_sum (float): The sum of weigh tensor ignore_index (int): The index to ignore in the target. label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training. @@ -94,7 +94,7 @@ def liger_cross_entropy_kernel( z_loss_ptr += program_id * loss_stride if HAS_WEIGHT: - weight = tl.load(weight_ptr + y).cast(tl.float32) + weight_y = tl.load(weight_ptr + y).cast(tl.float32) # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax) # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867 @@ -126,7 +126,15 @@ def liger_cross_entropy_kernel( block_max = tl.max(X_block) if label_smoothing > 0: # scale X beforehand to avoid overflow - scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0)) + if HAS_WEIGHT: + weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols) + scaled_x_sum += tl.sum( + tl.where(X_offsets < n_cols, -eps * X_block * weight_block, 0.0) + ) + else: + scaled_x_sum += tl.sum( + tl.where(X_offsets < n_cols, -eps * X_block, 0.0) + ) m_new = tl.maximum(m, block_max) d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new)) m = m_new @@ -172,12 +180,9 @@ def liger_cross_entropy_kernel( X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing)) # reduction scale if reduction == "mean": - if HAS_WEIGHT: - X_block = X_block / (sum_of_non_ignore_weight) - else: - X_block = X_block / (n_non_ignore) + X_block = X_block / (n_non_ignore) if HAS_WEIGHT: - X_block = X_block * weight + X_block = X_block * weight_y # chain rule # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap)) if HAS_SOFTCAPPING: @@ -197,6 +202,8 @@ def liger_cross_entropy_kernel( # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1 # So we can safely calculate log (softmax(X_y)) without overflow loss = lse - ori_X_y + if HAS_WEIGHT: + loss = weight_y * loss # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p) @@ -207,7 +214,10 @@ def liger_cross_entropy_kernel( # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516 # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087 if label_smoothing > 0: - smooth_loss = scaled_x_sum + 
label_smoothing * lse + if HAS_WEIGHT: + smooth_loss = scaled_x_sum + eps * lse * weight_sum + else: + smooth_loss = scaled_x_sum + label_smoothing * lse loss = loss * (1 - label_smoothing) + smooth_loss # An auxiliary loss, z_loss @@ -216,17 +226,8 @@ def liger_cross_entropy_kernel( loss += z_loss # Normalize the loss by the number of non-ignored elements if reduction is "mean" if reduction == "mean": - if HAS_WEIGHT: - z_loss = z_loss / sum_of_non_ignore_weight - loss = loss / sum_of_non_ignore_weight - else: - z_loss = z_loss / n_non_ignore - loss = loss / n_non_ignore - - if HAS_WEIGHT: - z_loss = z_loss * weight - loss = loss * weight - + z_loss = z_loss / n_non_ignore + loss = loss / n_non_ignore tl.store(loss_ptr, loss) if RETURN_Z_LOSS == _TRUE: tl.store(z_loss_ptr, z_loss) @@ -279,7 +280,7 @@ def cross_entropy_forward( target_mask = target != ignore_index n_non_ignore = target_mask.sum().item() - sum_of_non_ignore_weight = n_non_ignore + weight_sum = weight.sum().item() if weight is not None: assert ( weight.shape[0] == V @@ -287,10 +288,11 @@ def cross_entropy_forward( assert torch.is_floating_point( weight ), f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}" - selected_weight = torch.where( - target_mask, torch.gather(weight, dim=0, index=target * target_mask), 0.0 + n_non_ignore = ( + torch.gather(weight, dim=0, index=target.masked_select(target_mask)) + .sum() + .item() ) - sum_of_non_ignore_weight = selected_weight.sum().item() if weight.stride(-1) != 1: weight = weight.contiguous() @@ -313,7 +315,7 @@ def cross_entropy_forward( n_cols=V, n_non_ignore=n_non_ignore, ignore_index=ignore_index, - sum_of_non_ignore_weight=sum_of_non_ignore_weight, + weight_sum=weight_sum, lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, reduction=reduction, From 2e6ded2f8f0b01291673a39c90f1fa82c3ce0656 Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Sun, 15 Dec 2024 13:50:29 +0800 Subject: [PATCH 07/14] Fix unit test --- test/transformers/test_cross_entropy.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py index f3113aaae..8c402254e 100644 --- a/test/transformers/test_cross_entropy.py +++ b/test/transformers/test_cross_entropy.py @@ -346,8 +346,8 @@ def _test_correctness_with_weight_once( output2 = target_ce(_input2, target) assert torch.allclose(output, output2, atol=atol, rtol=rtol) - output.backward() - output2.backward() + output.backward(gradient=torch.ones_like(output)) + output2.backward(gradient=torch.ones_like(output)) assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) @@ -397,11 +397,11 @@ def _test_correctness_with_weight_with_other_params_once( softcap * torch.tanh(_input.to(torch.float32) / softcap), target ).to(dtype) output2 = target_ce(_input2, target) - assert torch.allclose(output, output2, atol=atol, rtol=rtol) + assert_verbose_allclose(output, output2, atol=atol, rtol=rtol) - output.backward() - output2.backward() - assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) + output.backward(gradient=torch.ones_like(output)) + output2.backward(gradient=torch.ones_like(output)) + assert_verbose_allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) def _test_correctness_not_last_layer_once( @@ -831,7 +831,7 @@ def test_correctness_with_weight_once(B, T, V, reduction, scalar, dtype, atol, r (3, 423, 3200), ], ) 
-@pytest.mark.parametrize("reduction", ["sum", "mean"]) +@pytest.mark.parametrize("reduction", ["sum", "mean", "none"]) @pytest.mark.parametrize( "ignore_index, lse_square_scale, label_smoothing, softcap", [ From d54ce805bbda26e8ccef55981a715db940b3d2d2 Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Sun, 22 Dec 2024 12:25:34 +0800 Subject: [PATCH 08/14] Block invalid operation when weight is None --- src/liger_kernel/ops/cross_entropy.py | 2 +- test/transformers/test_cross_entropy.py | 22 +++++++++++----------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index d9a981912..91a5be15e 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -280,7 +280,7 @@ def cross_entropy_forward( target_mask = target != ignore_index n_non_ignore = target_mask.sum().item() - weight_sum = weight.sum().item() + weight_sum = weight.sum().item() if weight is not None else 0 if weight is not None: assert ( weight.shape[0] == V diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py index 8c402254e..1d68d02f8 100644 --- a/test/transformers/test_cross_entropy.py +++ b/test/transformers/test_cross_entropy.py @@ -835,22 +835,22 @@ def test_correctness_with_weight_once(B, T, V, reduction, scalar, dtype, atol, r @pytest.mark.parametrize( "ignore_index, lse_square_scale, label_smoothing, softcap", [ - (-100, 1e-4, 0.1, 30.0), - (42, 1e-5, 0.2, 40.0), + (-100, 0, 0.1, 30.0), + # (42, 1e-5, 0.2, 40.0), ], ) @pytest.mark.parametrize( "scalar, dtype, atol, rtol", [ - pytest.param( - 1.0, - torch.bfloat16, - 1e-8, - 5e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), - ), + # pytest.param( + # 1.0, + # torch.bfloat16, + # 1e-8, + # 5e-2, + # marks=pytest.mark.skipif( + # not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + # ), + # ), (1.0, torch.float32, 1e-8, 1e-6), ], ) From cbaf88fece9f19fbcbed0e160dd9d1cc65089092 Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Sun, 22 Dec 2024 13:23:30 +0800 Subject: [PATCH 09/14] Update gradients calculation for weighted smooth loss --- src/liger_kernel/ops/cross_entropy.py | 65 ++++++++++++++++++------- test/transformers/test_cross_entropy.py | 26 ++-------- 2 files changed, 52 insertions(+), 39 deletions(-) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index 91a5be15e..5702561f8 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -33,6 +33,7 @@ def liger_cross_entropy_kernel( loss_stride, n_cols, n_non_ignore, + n_sum_non_ignore_weight, weight_sum, ignore_index, lse_square_scale: tl.constexpr, @@ -170,20 +171,42 @@ def liger_cross_entropy_kernel( if HAS_SOFTCAPPING: intermediate = tanh(X_block / softcap) X_block = softcap * intermediate - # softmax(x_i) - X_block = tl.exp(X_block - m) / d - # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i) - X_block += 2 * lse_square_scale * lse * X_block - # smoothing term - X_block += -eps - # special handle dx_y - X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing)) - # reduction scale - if reduction == "mean": - X_block = X_block / (n_non_ignore) - if HAS_WEIGHT: - X_block = X_block * weight_y - # chain rule + + if not HAS_WEIGHT: + # softmax(x_i) + X_block = tl.exp(X_block - m) / d + # 
derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i) + X_block += 2 * lse_square_scale * lse * X_block + # smoothing term + X_block += -eps + # special handle dx_y + X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing)) + # reduction scale + if reduction == "mean": + X_block = X_block / n_non_ignore + else: + weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols) + softmax_X = tl.exp(X_block - m) / d + # derivative of original_loss + dloss_ori = (1 - label_smoothing) * softmax_X + # specially handle dx_y + dloss_ori = tl.where( + X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing) + ) + dloss_ori = dloss_ori * weight_y + # derivative of smooth_loss + dloss_smooth = eps * (-weight_block + softmax_X * weight_sum) + # derivative of z-loss + dz_loss = 2 * lse_square_scale * lse * softmax_X + # reduction scale + if reduction == "mean": + dloss_ori = dloss_ori / n_sum_non_ignore_weight + dloss_smooth = dloss_smooth / n_sum_non_ignore_weight + dz_loss = dz_loss / n_non_ignore + # derivative of total_loss + X_block = dloss_ori + dloss_smooth + dz_loss + + # chain rule softcapping # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap)) if HAS_SOFTCAPPING: X_block = X_block * (1 - intermediate * intermediate) @@ -223,11 +246,15 @@ def liger_cross_entropy_kernel( # An auxiliary loss, z_loss # Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html z_loss = lse_square_scale * lse * lse - loss += z_loss # Normalize the loss by the number of non-ignored elements if reduction is "mean" if reduction == "mean": + if HAS_WEIGHT: + loss = loss / n_sum_non_ignore_weight + else: + loss = loss / n_non_ignore z_loss = z_loss / n_non_ignore - loss = loss / n_non_ignore + loss += z_loss + tl.store(loss_ptr, loss) if RETURN_Z_LOSS == _TRUE: tl.store(z_loss_ptr, z_loss) @@ -280,7 +307,8 @@ def cross_entropy_forward( target_mask = target != ignore_index n_non_ignore = target_mask.sum().item() - weight_sum = weight.sum().item() if weight is not None else 0 + n_sum_non_ignore_weight = n_non_ignore + weight_sum = weight.sum().item() if weight is not None else 0.0 if weight is not None: assert ( weight.shape[0] == V @@ -288,7 +316,7 @@ def cross_entropy_forward( assert torch.is_floating_point( weight ), f"If given, weight has to be a Tensor of floating point dtype. 
Got: {weight.dtype}" - n_non_ignore = ( + n_sum_non_ignore_weight = ( torch.gather(weight, dim=0, index=target.masked_select(target_mask)) .sum() .item() @@ -314,6 +342,7 @@ def cross_entropy_forward( loss_stride=loss_1d.stride(-1), # always 1 n_cols=V, n_non_ignore=n_non_ignore, + n_sum_non_ignore_weight=n_sum_non_ignore_weight, ignore_index=ignore_index, weight_sum=weight_sum, lse_square_scale=lse_square_scale, diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py index 1d68d02f8..67d977481 100644 --- a/test/transformers/test_cross_entropy.py +++ b/test/transformers/test_cross_entropy.py @@ -41,17 +41,8 @@ def __init__( def forward(self, logits, targets): # Loss calculations are all in float32 logits = logits.to(torch.float32) - HAS_WEIGHT = True if self.weight is not None else False target_mask = targets != self.ignore_index - if HAS_WEIGHT: - self.weight = self.weight.to(torch.float32) - selected_weight = torch.where( - target_mask, - torch.gather(self.weight, dim=-1, index=targets * target_mask), - 0.0, - ) - sum_of_non_ignore_weight = selected_weight.sum().item() # Standard cross entropy loss ce_loss = F.cross_entropy( @@ -70,18 +61,9 @@ def forward(self, logits, targets): z_loss = torch.where( targets != self.ignore_index, self.lse_square_scale * (lse**2), 0.0 ) - if HAS_WEIGHT: - # print(f"{z_loss.shape=}") - z_loss = z_loss * selected_weight - # print(f"{selected_weight.shape=}") - # print(f"{selected_weight[targets == self.ignore_index]=}") - # print(f"{selected_weight[targets != self.ignore_index]=}") if self.reduction == "mean": - if HAS_WEIGHT: - z_loss = z_loss.sum() / sum_of_non_ignore_weight - else: - z_loss = z_loss.sum() / (targets != self.ignore_index).sum() + z_loss = z_loss.sum() / target_mask.sum() elif self.reduction == "sum": z_loss = z_loss.sum() else: @@ -960,7 +942,8 @@ def test_float32_internal(): loss_stride=loss_bf16.stride(-1), n_cols=n_cols, n_non_ignore=n_non_ignore, - sum_of_non_ignore_weight=n_non_ignore, # not used + n_sum_non_ignore_weight=n_non_ignore, # not used + weight_sum=0.0, # not used ignore_index=ignore_index, lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, @@ -987,7 +970,8 @@ def test_float32_internal(): loss_stride=loss_fp32.stride(-1), n_cols=n_cols, n_non_ignore=n_non_ignore, - sum_of_non_ignore_weight=n_non_ignore, # not used + n_sum_non_ignore_weight=n_non_ignore, # not used + weight_sum=n_non_ignore, # not used ignore_index=ignore_index, lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, From ec134fcb9848fb374a149006a22497d9f96a13db Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Sun, 22 Dec 2024 14:06:56 +0800 Subject: [PATCH 10/14] Clean up --- src/liger_kernel/ops/cross_entropy.py | 25 +++++++++++---------- test/transformers/test_cross_entropy.py | 30 ++++++++++++------------- 2 files changed, 28 insertions(+), 27 deletions(-) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index 5702561f8..d9f7947b2 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -33,7 +33,7 @@ def liger_cross_entropy_kernel( loss_stride, n_cols, n_non_ignore, - n_sum_non_ignore_weight, + sum_non_ignore_weight, weight_sum, ignore_index, lse_square_scale: tl.constexpr, @@ -55,19 +55,19 @@ def liger_cross_entropy_kernel( Y_ptr: Pointer to target tensor. Y_stride (int): The stride of the target tensor. weight_ptr: Pointer to weight tensor. 
- weight_stride (int): The stride of the weight tensor. loss_ptr: Pointer to tensor to store the loss. z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0. loss_stride (int): The stride of the loss tensor. n_cols (int): The number of columns in the input tensor. - n_non_ignore (flaot): The number of non-ignored elements or the sum of non-ignored target's weights in the batch - weight_sum (float): The sum of weigh tensor + n_non_ignore (flaot): The number of non-ignored elements in the batch. + sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch. + weight_sum (float): The sum of weight tensor. ignore_index (int): The index to ignore in the target. label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training. - RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1. reduction (str): The string for the reduction to apply softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap). + RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1. BLOCK_SIZE (int): The block size for Triton operations. HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes. HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not. @@ -200,8 +200,8 @@ def liger_cross_entropy_kernel( dz_loss = 2 * lse_square_scale * lse * softmax_X # reduction scale if reduction == "mean": - dloss_ori = dloss_ori / n_sum_non_ignore_weight - dloss_smooth = dloss_smooth / n_sum_non_ignore_weight + dloss_ori = dloss_ori / sum_non_ignore_weight + dloss_smooth = dloss_smooth / sum_non_ignore_weight dz_loss = dz_loss / n_non_ignore # derivative of total_loss X_block = dloss_ori + dloss_smooth + dz_loss @@ -249,7 +249,7 @@ def liger_cross_entropy_kernel( # Normalize the loss by the number of non-ignored elements if reduction is "mean" if reduction == "mean": if HAS_WEIGHT: - loss = loss / n_sum_non_ignore_weight + loss = loss / sum_non_ignore_weight else: loss = loss / n_non_ignore z_loss = z_loss / n_non_ignore @@ -307,8 +307,8 @@ def cross_entropy_forward( target_mask = target != ignore_index n_non_ignore = target_mask.sum().item() - n_sum_non_ignore_weight = n_non_ignore - weight_sum = weight.sum().item() if weight is not None else 0.0 + sum_non_ignore_weight = n_non_ignore + weight_sum = 0.0 if weight is not None: assert ( weight.shape[0] == V @@ -316,11 +316,12 @@ def cross_entropy_forward( assert torch.is_floating_point( weight ), f"If given, weight has to be a Tensor of floating point dtype. 
Got: {weight.dtype}" - n_sum_non_ignore_weight = ( + sum_non_ignore_weight = ( torch.gather(weight, dim=0, index=target.masked_select(target_mask)) .sum() .item() ) + weight_sum = weight.sum().item() if weight.stride(-1) != 1: weight = weight.contiguous() @@ -342,7 +343,7 @@ def cross_entropy_forward( loss_stride=loss_1d.stride(-1), # always 1 n_cols=V, n_non_ignore=n_non_ignore, - n_sum_non_ignore_weight=n_sum_non_ignore_weight, + sum_non_ignore_weight=sum_non_ignore_weight, ignore_index=ignore_index, weight_sum=weight_sum, lse_square_scale=lse_square_scale, diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py index 67d977481..5c050b983 100644 --- a/test/transformers/test_cross_entropy.py +++ b/test/transformers/test_cross_entropy.py @@ -808,31 +808,31 @@ def test_correctness_with_weight_once(B, T, V, reduction, scalar, dtype, atol, r @pytest.mark.parametrize( "B, T, V", [ - (2, 4096, 3200), # llama2, mistral + (2, 4096, 32000), # llama2, mistral # # weird shapes - (3, 423, 3200), + (3, 423, 32000), ], ) @pytest.mark.parametrize("reduction", ["sum", "mean", "none"]) @pytest.mark.parametrize( "ignore_index, lse_square_scale, label_smoothing, softcap", [ - (-100, 0, 0.1, 30.0), - # (42, 1e-5, 0.2, 40.0), + (-100, 1e-4, 0.1, 30.0), + (42, 1e-5, 0.2, 40.0), ], ) @pytest.mark.parametrize( "scalar, dtype, atol, rtol", [ - # pytest.param( - # 1.0, - # torch.bfloat16, - # 1e-8, - # 5e-2, - # marks=pytest.mark.skipif( - # not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - # ), - # ), + pytest.param( + 1.0, + torch.bfloat16, + 1e-8, + 5e-2, + marks=pytest.mark.skipif( + not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + ), + ), (1.0, torch.float32, 1e-8, 1e-6), ], ) @@ -942,7 +942,7 @@ def test_float32_internal(): loss_stride=loss_bf16.stride(-1), n_cols=n_cols, n_non_ignore=n_non_ignore, - n_sum_non_ignore_weight=n_non_ignore, # not used + sum_non_ignore_weight=n_non_ignore, # not used weight_sum=0.0, # not used ignore_index=ignore_index, lse_square_scale=lse_square_scale, @@ -970,7 +970,7 @@ def test_float32_internal(): loss_stride=loss_fp32.stride(-1), n_cols=n_cols, n_non_ignore=n_non_ignore, - n_sum_non_ignore_weight=n_non_ignore, # not used + sum_non_ignore_weight=n_non_ignore, # not used weight_sum=n_non_ignore, # not used ignore_index=ignore_index, lse_square_scale=lse_square_scale, From 7ed4dd9903c77646504c6e98c3cedd19e27e6858 Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Sun, 22 Dec 2024 14:25:39 +0800 Subject: [PATCH 11/14] Update flce --- .../ops/fused_linear_cross_entropy.py | 83 ++++++++++++------- src/liger_kernel/transformers/functional.py | 2 + .../fused_linear_cross_entropy.py | 3 + .../test_fused_linear_cross_entropy.py | 23 +++-- 4 files changed, 72 insertions(+), 39 deletions(-) diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py index a3d0406f1..341ed3199 100644 --- a/src/liger_kernel/ops/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py @@ -19,6 +19,7 @@ def fused_linear_cross_entropy_forward( _input, weight, target, + ce_weight=None, bias=None, ignore_index=-100, lse_square_scale=0.0, @@ -54,7 +55,25 @@ def fused_linear_cross_entropy_forward( loss_1d = torch.zeros(BT, dtype=torch.float32, device=device) # NOTE: skip .item() here to avoid CUDA synchronization - total_n_non_ignore = (target != ignore_index).sum() + target_mask = target != ignore_index + 
total_n_non_ignore = target_mask.sum().item() + total_sum_non_ignore_ce_weight = total_n_non_ignore + ce_weight_sum = 0.0 + if ce_weight is not None: + assert ( + ce_weight.shape[0] == V + ), f"If given, weight has to be a Tensor of size V. Got: {ce_weight.shape}" + assert torch.is_floating_point( + ce_weight + ), f"If given, weight has to be a Tensor of floating point dtype. Got: {ce_weight.dtype}" + total_sum_non_ignore_ce_weight = ( + torch.gather(ce_weight, dim=0, index=target.masked_select(target_mask)) + .sum() + .item() + ) + ce_weight_sum = ce_weight.sum().item() + if ce_weight.stride(-1) != 1: + ce_weight = ce_weight.contiguous() for chunk_id in range(num_chunks): start_idx = chunk_id * chunk_size @@ -71,7 +90,6 @@ def fused_linear_cross_entropy_forward( # unreduced loss loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size, - n_non_ignore = (target_chunk != ignore_index).sum().item() # ensure _input and target are contiguous logits_chunk = logits_chunk.contiguous() @@ -83,14 +101,14 @@ def fused_linear_cross_entropy_forward( X_stride=logits_chunk.stride(-2), Y_ptr=target_chunk, Y_stride=target_chunk.stride(-1), # always 1 - weight_ptr=_input, # dummy ptr, not used - weight_stride=0, + weight_ptr=ce_weight, # dummy ptr, not used loss_ptr=loss_1d_slice, z_loss_ptr=loss_1d_slice, # dummy ptr, not used loss_stride=loss_1d_slice.stride(-1), # always 1 n_cols=V, - n_non_ignore=n_non_ignore, - sum_of_non_ignore_weight=n_non_ignore, + n_non_ignore=total_n_non_ignore, + sum_non_ignore_weight=total_sum_non_ignore_ce_weight, + weight_sum=ce_weight_sum, ignore_index=ignore_index, lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, @@ -103,19 +121,8 @@ def fused_linear_cross_entropy_forward( num_warps=32 if not is_hip() else 16, ) - # gradient of logits_chunk is computed in-place by the above triton kernel and is of shape: chunk_size x V - # thus grad_input[start_idx: end_idx] should be of shape: chunk_size x H - # additionally, since we are chunking the inputs, observe that the loss and gradients are calculated only - # on `n_non_ignore` tokens. However, the gradient of the input should be calculated for all tokens. - # Thus, we need an additional scaling factor of (n_non_ignore/total_n_non_ignore) to scale the gradients. - - if reduction == "mean": - alpha = n_non_ignore / total_n_non_ignore if total_n_non_ignore > 0 else 0.0 - else: - alpha = 1.0 - - loss_1d[start_idx:end_idx] = loss_1d_slice * alpha - grad_logits_chunk = logits_chunk * alpha # chunk_size x V + loss_1d[start_idx:end_idx] = loss_1d_slice + grad_logits_chunk = logits_chunk # chunk_size x V grad_input[start_idx:end_idx] = grad_logits_chunk @ weight @@ -125,7 +132,7 @@ def fused_linear_cross_entropy_forward( mat1=logits_chunk.t(), mat2=_input_chunk, out=grad_weight, - alpha=alpha, + alpha=1.0, beta=1.0, ) @@ -134,7 +141,7 @@ def fused_linear_cross_entropy_forward( input=grad_bias, other=logits_chunk.sum(dim=0), out=grad_bias, - alpha=alpha, + alpha=1.0, ) loss = torch.sum(loss_1d) @@ -199,6 +206,7 @@ def forward( weight, target, bias=None, + ce_weight=None, ignore_index=-100, lse_square_scale=0.0, label_smoothing=0.0, @@ -218,21 +226,23 @@ def forward( target: (B*T) where each value is in [0, V-1] weight: (V, H) where V is the number of classes bias: (V) where V is the number of classes + ce_weight: a manual rescaling weight given to each class. 
If given, has to be a Tensor of size V and floating point dtype ignore_index: the index to ignore in the target label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. reduction: reduction to apply """ loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward( - _input, - weight, - target, - bias, - ignore_index, - lse_square_scale, - label_smoothing, - reduction, - softcap, + _input=_input, + weight=weight, + target=target, + bias=bias, + ce_weight=ce_weight, + ignore_index=ignore_index, + lse_square_scale=lse_square_scale, + label_smoothing=label_smoothing, + reduction=reduction, + softcap=softcap, ) # downcast to dtype and store for backward ctx.save_for_backward( @@ -249,4 +259,15 @@ def backward(ctx, grad_output): grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward( grad_output, grad_input, grad_weight, grad_bias ) - return (grad_input, grad_weight, None, grad_bias, None, None, None, None, None) + return ( + grad_input, + grad_weight, + None, + grad_bias, + None, + None, + None, + None, + None, + None, + ) diff --git a/src/liger_kernel/transformers/functional.py b/src/liger_kernel/transformers/functional.py index 5d6086caa..60d472129 100644 --- a/src/liger_kernel/transformers/functional.py +++ b/src/liger_kernel/transformers/functional.py @@ -52,6 +52,7 @@ def liger_fused_linear_cross_entropy( weight, target, bias=None, + ce_weight=None, ignore_index: int = -100, lse_square_scale: float = 0.0, label_smoothing: float = 0.0, @@ -63,6 +64,7 @@ def liger_fused_linear_cross_entropy( weight, target, bias, + ce_weight, ignore_index, lse_square_scale, label_smoothing, diff --git a/src/liger_kernel/transformers/fused_linear_cross_entropy.py b/src/liger_kernel/transformers/fused_linear_cross_entropy.py index 7df79d309..c13148f91 100644 --- a/src/liger_kernel/transformers/fused_linear_cross_entropy.py +++ b/src/liger_kernel/transformers/fused_linear_cross_entropy.py @@ -10,6 +10,7 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module): def __init__( self, + ce_weight: Optional[torch.FloatTensor] = None, ignore_index: int = -100, lse_square_scale: float = 0.0, label_smoothing: float = 0.0, @@ -28,6 +29,7 @@ def __init__( assert ( softcap is None or softcap > 0 ), f"softcap must greater than 0.0 or None. 
Got: {softcap}" + self.ce_weight = ce_weight self.ignore_index = ignore_index self.lse_square_scale = lse_square_scale self.label_smoothing = label_smoothing @@ -40,6 +42,7 @@ def forward(self, lin_weight, _input, target, bias=None): lin_weight, target, bias, + self.ce_weight, self.ignore_index, self.lse_square_scale, self.label_smoothing, diff --git a/test/transformers/test_fused_linear_cross_entropy.py b/test/transformers/test_fused_linear_cross_entropy.py index a6bcd4d8b..8909d9337 100644 --- a/test/transformers/test_fused_linear_cross_entropy.py +++ b/test/transformers/test_fused_linear_cross_entropy.py @@ -41,6 +41,7 @@ def __init__( V: int, dtype: torch.dtype, bias: bool = False, + ce_weight: Optional[torch.FloatTensor] = None, ignore_index: int = -100, lse_square_scale: float = 0.0, label_smoothing: float = 0.0, @@ -52,6 +53,7 @@ def __init__( in_features=H, out_features=V, bias=bias, dtype=dtype ) self.ce_loss = CrossEntropyWithZLoss( + weight=ce_weight, ignore_index=ignore_index, lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, @@ -72,6 +74,7 @@ def __init__( H: int, V: int, dtype: torch.dtype, + ce_weight: Optional[torch.FloatTensor] = None, bias: bool = False, ignore_index: int = -100, lse_square_scale: float = 0.0, @@ -84,6 +87,7 @@ def __init__( in_features=H, out_features=V, bias=bias, dtype=dtype ) self.ce_loss = LigerFusedLinearCrossEntropyLoss( + ce_weight=ce_weight, ignore_index=ignore_index, lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, @@ -118,15 +122,11 @@ def forward(self, x, y): ) @pytest.mark.parametrize("bias", [True, False]) @pytest.mark.parametrize( - "label_smoothing, ignore_index, lse_square_scale, softcap", + "has_ce_weight, label_smoothing, ignore_index, lse_square_scale, softcap", [ - (0, -100, 0, None), - ( - 0.1, - 42, - 1e-4, - 30.0, - ), # Pass non-default values once to ensure all params work along + (False, 0, -100, 0, None), + # Pass non-default values once to ensure all params work along + (True, 0.1, 42, 1e-4, 30.0), ], ) def test_correctness( @@ -137,6 +137,7 @@ def test_correctness( scalar, dtype, bias, + has_ce_weight, lse_square_scale, label_smoothing, ignore_index, @@ -145,10 +146,15 @@ def test_correctness( atol, rtol, ): + if has_ce_weight: + ce_weight = torch.rand(V, device=device, dtype=torch.float32) + else: + ce_weight = None torch_lm_head_ce = TorchLMHeadCE( H=H, V=V, bias=bias, + ce_weight=ce_weight, lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, ignore_index=ignore_index, @@ -160,6 +166,7 @@ def test_correctness( H=H, V=V, bias=bias, + ce_weight=ce_weight, lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, ignore_index=ignore_index, From 5535a6091b0842129961e21ca3f23fade7de4005 Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Sun, 22 Dec 2024 14:49:51 +0800 Subject: [PATCH 12/14] Fix kernel arugments in flce --- src/liger_kernel/ops/fused_linear_cross_entropy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py index 341ed3199..15481c34d 100644 --- a/src/liger_kernel/ops/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py @@ -101,7 +101,7 @@ def fused_linear_cross_entropy_forward( X_stride=logits_chunk.stride(-2), Y_ptr=target_chunk, Y_stride=target_chunk.stride(-1), # always 1 - weight_ptr=ce_weight, # dummy ptr, not used + weight_ptr=ce_weight if ce_weight is 
not None else _input, # dummy if None loss_ptr=loss_1d_slice, z_loss_ptr=loss_1d_slice, # dummy ptr, not used loss_stride=loss_1d_slice.stride(-1), # always 1 @@ -115,7 +115,7 @@ def fused_linear_cross_entropy_forward( reduction=reduction, softcap=softcap if softcap is not None else 0.0, RETURN_Z_LOSS=0, # False - HAS_WEIGHT=False, + HAS_WEIGHT=True if ce_weight is not None else False, HAS_SOFTCAPPING=True if softcap is not None else False, BLOCK_SIZE=BLOCK_SIZE, num_warps=32 if not is_hip() else 16, From f2f7b48e71043622e7bdb455e31523d97611d591 Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Sat, 28 Dec 2024 14:59:38 +0800 Subject: [PATCH 13/14] Clean up and checkstyle --- src/liger_kernel/ops/cross_entropy.py | 22 ++------ .../ops/fused_linear_cross_entropy.py | 8 +-- .../transformers/cross_entropy.py | 4 +- .../fused_linear_cross_entropy.py | 4 +- test/transformers/test_cross_entropy.py | 37 ++++---------- weight_ce.py | 50 ------------------- 6 files changed, 18 insertions(+), 107 deletions(-) delete mode 100644 weight_ce.py diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index c7f049b56..4139ce608 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -130,13 +130,9 @@ def liger_cross_entropy_kernel( # scale X beforehand to avoid overflow if HAS_WEIGHT: weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols) - scaled_x_sum += tl.sum( - tl.where(X_offsets < n_cols, -eps * X_block * weight_block, 0.0) - ) + scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block * weight_block, 0.0)) else: - scaled_x_sum += tl.sum( - tl.where(X_offsets < n_cols, -eps * X_block, 0.0) - ) + scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0)) m_new = tl.maximum(m, block_max) d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new)) m = m_new @@ -191,9 +187,7 @@ def liger_cross_entropy_kernel( # derivative of original_loss dloss_ori = (1 - label_smoothing) * softmax_X # specially handle dx_y - dloss_ori = tl.where( - X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing) - ) + dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing)) dloss_ori = dloss_ori * weight_y # derivative of smooth_loss dloss_smooth = eps * (-weight_block + softmax_X * weight_sum) @@ -307,17 +301,11 @@ def cross_entropy_forward( sum_non_ignore_weight = n_non_ignore weight_sum = 0.0 if weight is not None: - assert ( - weight.shape[0] == V - ), f"If given, weight has to be a Tensor of size V. Got: {weight.shape}" + assert weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {weight.shape}" assert torch.is_floating_point( weight ), f"If given, weight has to be a Tensor of floating point dtype. 
Got: {weight.dtype}" - sum_non_ignore_weight = ( - torch.gather(weight, dim=0, index=target.masked_select(target_mask)) - .sum() - .item() - ) + sum_non_ignore_weight = torch.gather(weight, dim=0, index=target.masked_select(target_mask)).sum().item() weight_sum = weight.sum().item() if weight.stride(-1) != 1: weight = weight.contiguous() diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py index 26de6591a..41b223865 100644 --- a/src/liger_kernel/ops/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py @@ -54,16 +54,12 @@ def fused_linear_cross_entropy_forward( total_sum_non_ignore_ce_weight = total_n_non_ignore ce_weight_sum = 0.0 if ce_weight is not None: - assert ( - ce_weight.shape[0] == V - ), f"If given, weight has to be a Tensor of size V. Got: {ce_weight.shape}" + assert ce_weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {ce_weight.shape}" assert torch.is_floating_point( ce_weight ), f"If given, weight has to be a Tensor of floating point dtype. Got: {ce_weight.dtype}" total_sum_non_ignore_ce_weight = ( - torch.gather(ce_weight, dim=0, index=target.masked_select(target_mask)) - .sum() - .item() + torch.gather(ce_weight, dim=0, index=target.masked_select(target_mask)).sum().item() ) ce_weight_sum = ce_weight.sum().item() if ce_weight.stride(-1) != 1: diff --git a/src/liger_kernel/transformers/cross_entropy.py b/src/liger_kernel/transformers/cross_entropy.py index f3e51808c..d72fc3b00 100644 --- a/src/liger_kernel/transformers/cross_entropy.py +++ b/src/liger_kernel/transformers/cross_entropy.py @@ -28,9 +28,7 @@ def __init__( "sum", "none", }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}" - assert ( - softcap is None or softcap > 0 - ), f"softcap must greater than 0.0 or None. Got: {softcap}" + assert softcap is None or softcap > 0, f"softcap must greater than 0.0 or None. Got: {softcap}" self.weight = weight self.ignore_index = ignore_index self.lse_square_scale = lse_square_scale diff --git a/src/liger_kernel/transformers/fused_linear_cross_entropy.py b/src/liger_kernel/transformers/fused_linear_cross_entropy.py index 0c6ce0328..1de352e6c 100644 --- a/src/liger_kernel/transformers/fused_linear_cross_entropy.py +++ b/src/liger_kernel/transformers/fused_linear_cross_entropy.py @@ -24,9 +24,7 @@ def __init__( "sum", "none", }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}" - assert ( - softcap is None or softcap > 0 - ), f"softcap must greater than 0.0 or None. Got: {softcap}" + assert softcap is None or softcap > 0, f"softcap must greater than 0.0 or None. 
Got: {softcap}" self.ce_weight = ce_weight self.ignore_index = ignore_index self.lse_square_scale = lse_square_scale diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py index 7b8d1a9d0..b88033f2a 100644 --- a/test/transformers/test_cross_entropy.py +++ b/test/transformers/test_cross_entropy.py @@ -58,9 +58,7 @@ def forward(self, logits, targets): lse = torch.logsumexp(logits, dim=-1) # Z-loss term - z_loss = torch.where( - targets != self.ignore_index, self.lse_square_scale * (lse**2), 0.0 - ) + z_loss = torch.where(targets != self.ignore_index, self.lse_square_scale * (lse**2), 0.0) if self.reduction == "mean": z_loss = z_loss.sum() / target_mask.sum() @@ -292,9 +290,7 @@ def _test_correctness_with_z_loss_with_other_params_once( assert_verbose_allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) -def _test_correctness_with_weight_once( - target_ce, B, T, V, reduction, weight, scalar, dtype, atol, rtol -): +def _test_correctness_with_weight_once(target_ce, B, T, V, reduction, weight, scalar, dtype, atol, rtol): torch.manual_seed(0) torch_ce = CrossEntropyLoss(weight=weight, reduction=reduction) @@ -350,14 +346,10 @@ def _test_correctness_with_weight_with_other_params_once( num_elements_to_assign = torch.randint( 1, B * T // 2, (1,) ).item() # Random number of elements to set to ignore_index - indices_to_assign = torch.randperm(B * T)[ - :num_elements_to_assign - ] # Randomly select indices + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] # Randomly select indices target[indices_to_assign] = ignore_index - output = torch_ce( - softcap * torch.tanh(_input.to(torch.float32) / softcap), target - ).to(dtype) + output = torch_ce(softcap * torch.tanh(_input.to(torch.float32) / softcap), target).to(dtype) output2 = target_ce(_input2, target) assert_verbose_allclose(output, output2, atol=atol, rtol=rtol) @@ -366,10 +358,7 @@ def _test_correctness_with_weight_with_other_params_once( assert_verbose_allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) -def _test_correctness_not_last_layer_once( - target_ce, B, T, V, reduction, scalar, dtype, atol, rtol -): - +def _test_correctness_not_last_layer_once(target_ce, B, T, V, reduction, scalar, dtype, atol, rtol): torch_ce = CrossEntropyLoss(reduction=reduction) _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar @@ -417,9 +406,7 @@ def _test_correctness_functional( softcap=30.0, return_z_loss=True, ) - y2, y2_z = LigerCrossEntropyFunction.apply( - x2, target, None, 0, 1e-4, 0.1, "mean", 30.0, True - ) + y2, y2_z = LigerCrossEntropyFunction.apply(x2, target, None, 0, 1e-4, 0.1, "mean", 30.0, True) assert torch.allclose(y1, y2, atol=atol, rtol=rtol) assert torch.allclose(y1_z, y2_z, atol=atol, rtol=rtol) @@ -751,9 +738,7 @@ def test_correctness_with_z_loss_with_other_params_once( def test_correctness_with_weight_once(B, T, V, reduction, scalar, dtype, atol, rtol): weight = torch.rand(V, device=device, dtype=dtype) test_ce = LigerCrossEntropyLoss(weight=weight, reduction=reduction) - _test_correctness_with_weight_once( - test_ce, B, T, V, reduction, weight, scalar, dtype, atol, rtol - ) + _test_correctness_with_weight_once(test_ce, B, T, V, reduction, weight, scalar, dtype, atol, rtol) @pytest.mark.parametrize( @@ -780,9 +765,7 @@ def test_correctness_with_weight_once(B, T, V, reduction, scalar, dtype, atol, r torch.bfloat16, 1e-8, 5e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), + marks=pytest.mark.skipif(not 
supports_bfloat16(), reason="bfloat16 not supported on this GPU"), ), (1.0, torch.float32, 1e-8, 1e-6), ], @@ -845,9 +828,7 @@ def test_correctness_with_weight_with_other_params_once( torch.bfloat16, 1e-8, 5e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), ), (1.0, torch.float32, 1e-8, 1e-6), ], diff --git a/weight_ce.py b/weight_ce.py deleted file mode 100644 index 3431abe5a..000000000 --- a/weight_ce.py +++ /dev/null @@ -1,50 +0,0 @@ -import torch -import torch.nn as nn - -# Example data: 3 classes -logits = torch.tensor( - [ - [2.0, 0.5, 0.1], # Prediction logits for sample 1 - [0.1, 1.5, 2.1], # Prediction logits for sample 2 - [1.0, 2.0, 0.1], - ] -) # Prediction logits for sample 3 - -targets = torch.tensor([0, 2, 1]) # Ground truth labels - -# Define CrossEntropyLoss without weights -criterion_no_weight = nn.CrossEntropyLoss(reduction="none", label_smoothing=0.1) - -# Define CrossEntropyLoss with weights -weights = torch.tensor([0.7, 1.0, 1.5]) # Assign different weights to each class -criterion_with_weight = nn.CrossEntropyLoss( - weight=weights, reduction="none", label_smoothing=0.1 -) - -# Compute loss without weights -loss_no_weight = criterion_no_weight(logits, targets) - -# Compute loss with weights -loss_with_weight = criterion_with_weight(logits, targets) - -selected_weight = torch.gather(weights, dim=0, index=targets) -print(f"{selected_weight=}") -print("Loss without weights:", loss_no_weight) -print("Loss with weights:", loss_with_weight) -print("====================================================") -# Define CrossEntropyLoss without weights -criterion_no_weight = nn.CrossEntropyLoss(reduction="none") - -# Define CrossEntropyLoss with weights -weights = torch.tensor([0.7, 1.0, 1.5]) # Assign different weights to each class -criterion_with_weight = nn.CrossEntropyLoss(weight=weights, reduction="none") - -# Compute loss without weights -loss_no_weight = criterion_no_weight(logits, targets) - -# Compute loss with weights -loss_with_weight = criterion_with_weight(logits, targets) -print(f"{selected_weight=}") - -print("Loss without weights:", loss_no_weight) -print("Loss with weights:", loss_with_weight) From 2b382f80c2fb80e0891e4fa8ab19058ef45384e0 Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Sat, 28 Dec 2024 15:09:39 +0800 Subject: [PATCH 14/14] add TODO --- src/liger_kernel/ops/fused_linear_cross_entropy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py index 41b223865..b74f8063a 100644 --- a/src/liger_kernel/ops/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py @@ -48,7 +48,7 @@ def fused_linear_cross_entropy_forward( # we use fp32 for loss accumulator loss_1d = torch.zeros(BT, dtype=torch.float32, device=device) - # NOTE: skip .item() here to avoid CUDA synchronization + # TODO: evaluate how CUDA synchronization caused by .item() affects the speed target_mask = target != ignore_index total_n_non_ignore = target_mask.sum().item() total_sum_non_ignore_ce_weight = total_n_non_ignore
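
Note on the "mean" reduction with a class weight: the denominator is the sum of the weights gathered at the non-ignored targets (sum_non_ignore_weight / total_sum_non_ignore_ce_weight above), not the count of non-ignored tokens. A minimal PyTorch sketch, not part of the patches, that cross-checks this against torch.nn.functional.cross_entropy; tensor names and sizes are illustrative:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
BT, V, ignore_index = 8, 16, -100
logits = torch.randn(BT, V)
target = torch.randint(0, V, (BT,))
target[0] = ignore_index              # exercise the ignore path
class_weight = torch.rand(V)          # per-class weight of size V

# With reduction="none", each per-token loss is already scaled by the
# weight of its target class, and ignored positions contribute 0.
per_token = F.cross_entropy(
    logits, target, weight=class_weight,
    ignore_index=ignore_index, reduction="none",
)

# "mean" divides by the summed weights of the non-ignored targets,
# the same gather-over-masked-targets pattern used in the forward passes above.
target_mask = target != ignore_index
sum_non_ignore_weight = torch.gather(
    class_weight, dim=0, index=target.masked_select(target_mask)
).sum()
manual_mean = per_token.sum() / sum_non_ignore_weight

reference = F.cross_entropy(
    logits, target, weight=class_weight,
    ignore_index=ignore_index, reduction="mean",
)
assert torch.allclose(manual_mean, reference, atol=1e-6)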
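
The dloss_ori / dloss_smooth terms in the PATCH 13 kernel hunk match, as far as I can tell, the gradient of the weighted label-smoothed loss, grad_k = w_y * (1 - s) * (p_k - 1[k == y]) + (s / V) * (p_k * sum(w) - w_k), where s is label_smoothing, p is softmax(x), and eps in the kernel stands for s / n_cols. A small autograd cross-check under those assumptions, again not part of the patches and with illustrative names:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
V, smoothing = 8, 0.1
x = torch.randn(1, V, requires_grad=True)   # one token, V classes
y = torch.tensor([3])
w = torch.rand(V)

# reduction="sum" keeps the per-token gradient free of the weighted-mean denominator
loss = F.cross_entropy(x, y, weight=w, label_smoothing=smoothing, reduction="sum")
loss.backward()

p = torch.softmax(x.detach(), dim=-1)
one_hot = F.one_hot(y, num_classes=V).to(p.dtype)
# dloss_ori: usual smoothed-CE term, scaled by the target-class weight w_y
grad_ori = w[y] * (1.0 - smoothing) * (p - one_hot)
# dloss_smooth: uniform-smoothing term, weighted per class
grad_smooth = (smoothing / V) * (p * w.sum() - w)
assert torch.allclose(x.grad, grad_ori + grad_smooth, atol=1e-5)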
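
For completeness, a usage sketch of the new arguments mirroring the tests above: LigerCrossEntropyLoss takes weight, while LigerFusedLinearCrossEntropyLoss takes ce_weight and receives the lm-head weight in forward. The import path assumes both classes are re-exported from liger_kernel.transformers, and a CUDA device is assumed:

import torch
from liger_kernel.transformers import LigerCrossEntropyLoss, LigerFusedLinearCrossEntropyLoss

BT, H, V = 32, 64, 128
device = "cuda"
ce_weight = torch.rand(V, device=device, dtype=torch.float32)
target = torch.randint(0, V, (BT,), device=device)

# Plain cross entropy on precomputed logits
ce = LigerCrossEntropyLoss(weight=ce_weight, reduction="mean")
logits = torch.randn(BT, V, device=device, requires_grad=True)
loss = ce(logits, target)
loss.backward()

# Fused linear + cross entropy: pass the lm-head weight and the hidden states
flce = LigerFusedLinearCrossEntropyLoss(ce_weight=ce_weight, reduction="mean")
lm_head = torch.nn.Linear(H, V, bias=False, device=device)
hidden = torch.randn(BT, H, device=device, requires_grad=True)
loss = flce(lm_head.weight, hidden, target)
loss.backward()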