Add weight support for LigerCrossEntropy #420

Open · wants to merge 14 commits into base: main
54 changes: 50 additions & 4 deletions src/liger_kernel/ops/cross_entropy.py
@@ -27,18 +27,21 @@ def liger_cross_entropy_kernel(
X_stride,
Y_ptr,
Y_stride,
weight_ptr,
loss_ptr,
z_loss_ptr,
loss_stride,
n_cols,
n_non_ignore,
sum_of_non_ignore_weight,
ignore_index,
lse_square_scale: tl.constexpr,
label_smoothing: tl.constexpr,
reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time
softcap,
RETURN_Z_LOSS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
HAS_WEIGHT: tl.constexpr,
HAS_SOFTCAPPING: tl.constexpr,
):
"""
@@ -50,18 +53,22 @@ def liger_cross_entropy_kernel(
X_stride (int): The stride of the input tensor.
Y_ptr: Pointer to target tensor.
Y_stride (int): The stride of the target tensor.
weight_ptr: Pointer to weight tensor.
loss_ptr: Pointer to tensor to store the loss.
z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
loss_stride (int): The stride of the loss tensor.
n_cols (int): The number of columns in the input tensor.
n_non_ignore (int): The number of non-ignored elements in the batch.
sum_of_non_ignore_weight (float): The sum of the class weights over all non-ignored targets, used as the denominator when `reduction="mean"` and `weight` is given.
ignore_index (int): The index to ignore in the target.
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
lse_square_scale (float): The scale factor applied to (logsumexp(_input))^2, added to the loss for training stability.
RETURN_Z_LOSS (int): Whether to store the z loss in z_loss_ptr; must be 0 or 1.
reduction (str): The reduction to apply.
softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
BLOCK_SIZE (int): The block size for Triton operations.
HAS_WEIGHT (bool): Whether a per-class weight tensor is provided.
HAS_SOFTCAPPING (bool): Whether soft-capping is applied.
"""

@@ -86,6 +93,9 @@ def liger_cross_entropy_kernel(
loss_ptr += program_id * loss_stride
z_loss_ptr += program_id * loss_stride

if HAS_WEIGHT:
    weight = tl.load(weight_ptr + y).cast(tl.float32)

# Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
# Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867

@@ -162,7 +172,12 @@
X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
# reduction scale
if reduction == "mean":
-    X_block = X_block / (n_non_ignore)
+    if HAS_WEIGHT:
+        X_block = X_block / (sum_of_non_ignore_weight)
+    else:
+        X_block = X_block / (n_non_ignore)
+if HAS_WEIGHT:
+    X_block = X_block * weight
# chain rule
# d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
if HAS_SOFTCAPPING:
@@ -201,8 +216,16 @@ def liger_cross_entropy_kernel(
loss += z_loss
# Normalize the loss by the number of non-ignored elements if reduction is "mean"
if reduction == "mean":
-    z_loss = z_loss / n_non_ignore
-    loss = loss / n_non_ignore
+    if HAS_WEIGHT:
+        z_loss = z_loss / sum_of_non_ignore_weight
+        loss = loss / sum_of_non_ignore_weight
+    else:
+        z_loss = z_loss / n_non_ignore
+        loss = loss / n_non_ignore
+
+if HAS_WEIGHT:
+    z_loss = z_loss * weight
+    loss = loss * weight

tl.store(loss_ptr, loss)
if RETURN_Z_LOSS == _TRUE:
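A note on the scaling above: under reduction="mean" with a weight tensor, each row's gradient comes out as (softmax(x) - onehot(y)) * weight[y] / sum_of_non_ignore_weight, which is what the two HAS_WEIGHT branches implement. A small eager-PyTorch check of that identity (toy values; no ignore_index, z-loss, or smoothing):

import torch
import torch.nn.functional as F

B, V = 4, 10
x = torch.randn(B, V, requires_grad=True)
y = torch.randint(0, V, (B,))
w = torch.rand(V)

F.cross_entropy(x, y, weight=w, reduction="mean").backward()

# expected row gradient: (softmax(x) - onehot(y)) * w[y] / w[y].sum()
probs = F.softmax(x, dim=-1)
onehot = F.one_hot(y, V).to(x.dtype)
expected = (probs - onehot) * w[y].unsqueeze(1) / w[y].sum()
assert torch.allclose(x.grad, expected, atol=1e-6)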
@@ -224,6 +247,7 @@
def cross_entropy_forward(
_input,
target,
weight,
ignore_index,
lse_square_scale,
label_smoothing,
@@ -253,7 +277,22 @@ def cross_entropy_forward(
else:
z_loss_1d = loss_1d # dummy ptr when return_z_loss == False

-n_non_ignore = (target != ignore_index).sum().item()
+target_mask = target != ignore_index
+n_non_ignore = target_mask.sum().item()
+sum_of_non_ignore_weight = n_non_ignore
+if weight is not None:
+    assert (
+        weight.shape[0] == V
+    ), f"If given, weight has to be a Tensor of size V. Got: {weight.shape}"
+    assert torch.is_floating_point(
+        weight
+    ), f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}"
+    selected_weight = torch.where(
+        target_mask, torch.gather(weight, dim=0, index=target * target_mask), 0.0
+    )
+    sum_of_non_ignore_weight = selected_weight.sum().item()
Tcc0403 (Collaborator, Author) commented on Dec 8, 2024:

    We can rewrite it with torch.masked_select:

        sum_of_non_ignore_weight = (
            torch.gather(weight, dim=0, index=target.masked_select(target_mask))
            .sum()
            .item()
        )

    Refer to torch's impl mentioned above.
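Both formulations produce the same sum, since the masked gather only drops the rows that torch.where would zero out anyway; a quick self-contained check with toy values:

import torch

V, ignore_index = 10, -100
weight = torch.rand(V)
target = torch.tensor([1, 3, ignore_index, 7])
target_mask = target != ignore_index

# current version: clamp ignored indices to 0, gather, then zero them out
a = torch.where(target_mask, torch.gather(weight, 0, target * target_mask), 0.0).sum()
# suggested version: drop ignored rows before the gather
b = torch.gather(weight, 0, target.masked_select(target_mask)).sum()
assert torch.allclose(a, b)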

+if weight.stride(-1) != 1:
+    weight = weight.contiguous()

# ensure _input and target are contiguous in the last dimension
if _input.stride(-1) != 1:
@@ -267,18 +306,21 @@ def cross_entropy_forward(
X_stride=_input.stride(-2),
Y_ptr=target,
Y_stride=target.stride(-1), # always 1
weight_ptr=weight if weight is not None else _input, # dummy if None
loss_ptr=loss_1d,
z_loss_ptr=z_loss_1d,
loss_stride=loss_1d.stride(-1), # always 1
n_cols=V,
n_non_ignore=n_non_ignore,
ignore_index=ignore_index,
sum_of_non_ignore_weight=sum_of_non_ignore_weight,
lse_square_scale=lse_square_scale,
label_smoothing=label_smoothing,
reduction=reduction,
softcap=softcap if softcap is not None else 0.0,
RETURN_Z_LOSS=return_z_loss,
BLOCK_SIZE=BLOCK_SIZE,
HAS_WEIGHT=True if weight is not None else False,
HAS_SOFTCAPPING=True if softcap is not None else False,
# TODO: 32 seems to give the best performance
# Performance is quite sensitive to num_warps
@@ -329,6 +371,7 @@ def forward(
ctx,
_input: torch.Tensor,
target: torch.Tensor,
weight: Optional[torch.FloatTensor],
ignore_index: int = -100,
lse_square_scale: float = 0.0,
label_smoothing: float = 0.0,
@@ -343,6 +386,7 @@
ctx : The context object.
_input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
weight (Tensor, optional): A manual rescaling weight given to each class. If given, it has to be a Tensor of size V and of floating point dtype.
ignore_index (int): The index to ignore in the target.
lse_square_scale (float): The scale factor applied to (logsumexp(_input))^2, added to the loss for training stability.
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
@@ -356,6 +400,7 @@
loss, z_loss, _input = cross_entropy_forward(
_input,
target,
weight,
ignore_index,
lse_square_scale,
label_smoothing,
@@ -397,4 +442,5 @@ def backward(ctx, grad_output, grad_ouput2):
None,
None,
None,
None,
)
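As a sanity check on the new denominator, the "mean" convention above (divide the weighted per-row losses by sum_of_non_ignore_weight) is the same one torch.nn.functional.cross_entropy uses. A toy-sized sketch in plain PyTorch — this mirrors the formula, it does not exercise the Triton kernel:

import torch
import torch.nn.functional as F

B, V, ignore_index = 8, 16, -100
logits = torch.randn(B, V)
target = torch.randint(0, V, (B,))
target[0] = ignore_index                     # one ignored row
weight = torch.rand(V)

mask = target != ignore_index
per_row = F.cross_entropy(logits, target, weight=weight,
                          ignore_index=ignore_index, reduction="none")
# "mean" divides by the summed class weights of the non-ignored targets,
# i.e. sum_of_non_ignore_weight above
manual = per_row.sum() / weight[target[mask]].sum()
reference = F.cross_entropy(logits, target, weight=weight,
                            ignore_index=ignore_index, reduction="mean")
assert torch.allclose(manual, reference)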
4 changes: 4 additions & 0 deletions src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -83,17 +83,21 @@ def fused_linear_cross_entropy_forward(
X_stride=logits_chunk.stride(-2),
Y_ptr=target_chunk,
Y_stride=target_chunk.stride(-1), # always 1
weight_ptr=_input, # dummy ptr, not used
loss_ptr=loss_1d_slice,
z_loss_ptr=loss_1d_slice, # dummy ptr, not used
loss_stride=loss_1d_slice.stride(-1), # always 1
n_cols=V,
n_non_ignore=n_non_ignore,
sum_of_non_ignore_weight=n_non_ignore,
ignore_index=ignore_index,
lse_square_scale=lse_square_scale,
label_smoothing=label_smoothing,
reduction=reduction,
softcap=softcap if softcap is not None else 0.0,
RETURN_Z_LOSS=0, # False
HAS_WEIGHT=False,
HAS_SOFTCAPPING=True if softcap is not None else False,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=32 if not is_hip() else 16,
3 changes: 3 additions & 0 deletions src/liger_kernel/transformers/cross_entropy.py
@@ -8,6 +8,7 @@
class LigerCrossEntropyLoss(torch.nn.Module):
def __init__(
self,
weight: Optional[torch.FloatTensor] = None,
ignore_index: int = -100,
lse_square_scale: float = 0.0,
label_smoothing: float = 0.0,
@@ -30,6 +31,7 @@ def __init__(
assert (
    softcap is None or softcap > 0
), f"softcap must be greater than 0.0 or None. Got: {softcap}"
self.weight = weight
self.ignore_index = ignore_index
self.lse_square_scale = lse_square_scale
self.label_smoothing = label_smoothing
@@ -41,6 +43,7 @@ def forward(self, _input: torch.Tensor, target: torch.Tensor):
loss, z_loss = LigerCrossEntropyFunction.apply(
_input,
target,
self.weight,
self.ignore_index,
self.lse_square_scale,
self.label_smoothing,
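A usage sketch of the new constructor argument, assuming the package's top-level LigerCrossEntropyLoss export and a CUDA device (the kernel is Triton-based); shapes and values are illustrative only:

import torch
from liger_kernel.transformers import LigerCrossEntropyLoss

V = 4096
weight = torch.rand(V, device="cuda")        # one rescaling weight per class
loss_fn = LigerCrossEntropyLoss(weight=weight, ignore_index=-100)

logits = torch.randn(8 * 256, V, device="cuda", requires_grad=True)
target = torch.randint(0, V, (8 * 256,), device="cuda")
target[::7] = -100                           # some ignored positions

loss = loss_fn(logits, target)
loss.backward()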
1 change: 1 addition & 0 deletions src/liger_kernel/transformers/functional.py
@@ -34,6 +34,7 @@ def liger_cross_entropy(
loss, z_loss = LigerCrossEntropyFunction.apply(
input,
target,
weight,
ignore_index,
lse_square_scale,
label_smoothing,