Add weight support for LigerCrossEntropy #420

Open
wants to merge 14 commits into main
116 changes: 97 additions & 19 deletions src/liger_kernel/ops/cross_entropy.py
@@ -27,18 +27,22 @@ def liger_cross_entropy_kernel(
X_stride,
Y_ptr,
Y_stride,
weight_ptr,
loss_ptr,
z_loss_ptr,
loss_stride,
n_cols,
n_non_ignore,
sum_non_ignore_weight,
weight_sum,
ignore_index,
lse_square_scale: tl.constexpr,
label_smoothing: tl.constexpr,
reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time
softcap,
RETURN_Z_LOSS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
HAS_WEIGHT: tl.constexpr,
HAS_SOFTCAPPING: tl.constexpr,
):
"""
@@ -50,18 +54,22 @@ def liger_cross_entropy_kernel(
X_stride (int): The stride of the input tensor.
Y_ptr: Pointer to target tensor.
Y_stride (int): The stride of the target tensor.
weight_ptr: Pointer to weight tensor.
loss_ptr: Pointer to tensor to store the loss.
z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
loss_stride (int): The stride of the loss tensor.
n_cols (int): The number of columns in the input tensor.
n_non_ignore (int): The number of non-ignored elements in the batch.
n_non_ignore (float): The number of non-ignored elements in the batch.
sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
weight_sum (float): The sum of weight tensor.
ignore_index (int): The index to ignore in the target.
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
reduction (str): The string for the reduction to apply
softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
BLOCK_SIZE (int): The block size for Triton operations.
HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
"""

@@ -86,6 +94,9 @@ def liger_cross_entropy_kernel(
loss_ptr += program_id * loss_stride
z_loss_ptr += program_id * loss_stride

if HAS_WEIGHT:
weight_y = tl.load(weight_ptr + y).cast(tl.float32)

# Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
# Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867

@@ -116,7 +127,15 @@ def liger_cross_entropy_kernel(
block_max = tl.max(X_block)
if label_smoothing > 0:
# scale X beforehand to avoid overflow
scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
if HAS_WEIGHT:
weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
scaled_x_sum += tl.sum(
tl.where(X_offsets < n_cols, -eps * X_block * weight_block, 0.0)
)
else:
scaled_x_sum += tl.sum(
tl.where(X_offsets < n_cols, -eps * X_block, 0.0)
)
m_new = tl.maximum(m, block_max)
d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
m = m_new
@@ -152,18 +171,42 @@ def liger_cross_entropy_kernel(
if HAS_SOFTCAPPING:
intermediate = tanh(X_block / softcap)
X_block = softcap * intermediate
# softmax(x_i)
X_block = tl.exp(X_block - m) / d
# derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
X_block += 2 * lse_square_scale * lse * X_block
# smoothing term
X_block += -eps
# special handle dx_y
X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
# reduction scale
if reduction == "mean":
X_block = X_block / (n_non_ignore)
# chain rule

if not HAS_WEIGHT:
# softmax(x_i)
X_block = tl.exp(X_block - m) / d
# derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
X_block += 2 * lse_square_scale * lse * X_block
# smoothing term
X_block += -eps
# special handle dx_y
X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
# reduction scale
if reduction == "mean":
X_block = X_block / n_non_ignore
else:
weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
softmax_X = tl.exp(X_block - m) / d
# derivative of original_loss
dloss_ori = (1 - label_smoothing) * softmax_X
# specially handle dx_y
dloss_ori = tl.where(
X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing)
)
dloss_ori = dloss_ori * weight_y
# derivative of smooth_loss
dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
# derivative of z-loss
dz_loss = 2 * lse_square_scale * lse * softmax_X
# reduction scale
if reduction == "mean":
dloss_ori = dloss_ori / sum_non_ignore_weight
dloss_smooth = dloss_smooth / sum_non_ignore_weight
dz_loss = dz_loss / n_non_ignore
# derivative of total_loss
X_block = dloss_ori + dloss_smooth + dz_loss
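For reference, the three gradient terms above follow from differentiating the weighted, label-smoothed loss sketched earlier (the softcap chain rule is applied separately just below), with $p_i = \mathrm{softmax}(X)_i$ and $\varepsilon = \alpha / V$:

$$\frac{\partial \ell}{\partial x_i} \;=\; \underbrace{(1-\alpha)\,w_y\,\bigl(p_i - \mathbf{1}[i=y]\bigr)}_{\texttt{dloss\_ori}} \;+\; \underbrace{\varepsilon\Bigl(p_i\sum_j w_j - w_i\Bigr)}_{\texttt{dloss\_smooth}} \;+\; \underbrace{2\lambda\,\mathrm{lse}\,p_i}_{\texttt{dz\_loss}}$$

Under `reduction="mean"`, the first two terms are scaled by `1 / sum_non_ignore_weight` and the z-loss term by `1 / n_non_ignore`, matching the code above.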
Comment on lines +175 to +207
Contributor

I think it would be better to use `if HAS_WEIGHT` instead of `if not HAS_WEIGHT`, to align with the other behaviors in this change.

Collaborator Author

I thought putting the base case (no weight) first would be better, but I'll consider it, thank you.

Contributor

Oh, I see your point. Never mind, I was being a bit too nitpicky.


# chain rule softcapping
# d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
if HAS_SOFTCAPPING:
X_block = X_block * (1 - intermediate * intermediate)
@@ -182,6 +225,8 @@ def liger_cross_entropy_kernel(
# sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1
# So we can safely calculate log (softmax(X_y)) without overflow
loss = lse - ori_X_y
if HAS_WEIGHT:
loss = weight_y * loss

# Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
# H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
@@ -192,17 +237,23 @@ def liger_cross_entropy_kernel(
# pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
# See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
if label_smoothing > 0:
smooth_loss = scaled_x_sum + label_smoothing * lse
if HAS_WEIGHT:
smooth_loss = scaled_x_sum + eps * lse * weight_sum
else:
smooth_loss = scaled_x_sum + label_smoothing * lse
loss = loss * (1 - label_smoothing) + smooth_loss

# An auxiliary loss, z_loss
# Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html
z_loss = lse_square_scale * lse * lse
loss += z_loss
# Normalize the loss by the number of non-ignored elements if reduction is "mean"
if reduction == "mean":
if HAS_WEIGHT:
loss = loss / sum_non_ignore_weight
else:
loss = loss / n_non_ignore
z_loss = z_loss / n_non_ignore
loss = loss / n_non_ignore
loss += z_loss
Comment on lines 250 to +256
Contributor

Could you clarify the change here? I was wondering whether there is any missing piece for z_loss. I'm not quite sure whether z_loss should be affected by the weights or not.

Collaborator Author

z_loss isn't scaled by the weight right now, so it is divided by the number of non-ignored tokens, unlike the rest of the loss, which is divided by the sum of weights when a weight is given.
That's why I have to do the divisions first before summing them.
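For anyone verifying this, a minimal PyTorch sketch of the reduction semantics described above (label smoothing and softcap ignored for brevity; `reference_loss` is a hypothetical helper, not part of this PR):

```python
import torch
import torch.nn.functional as F

def reference_loss(logits, target, weight, ignore_index=-100, lse_square_scale=0.0):
    # CE part: normalized by the summed weights of the non-ignored targets,
    # matching torch.nn.functional.cross_entropy with reduction="mean".
    ce = F.cross_entropy(logits, target, weight=weight,
                         ignore_index=ignore_index, reduction="mean")
    # z-loss part: unweighted, normalized by the count of non-ignored tokens.
    mask = target != ignore_index
    lse = torch.logsumexp(logits[mask].float(), dim=-1)
    z_loss = lse_square_scale * (lse * lse).sum() / mask.sum()
    return ce + z_loss
```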


tl.store(loss_ptr, loss)
if RETURN_Z_LOSS == _TRUE:
@@ -224,6 +275,7 @@ def liger_cross_entropy_kernel(
def cross_entropy_forward(
_input,
target,
weight,
ignore_index,
lse_square_scale,
label_smoothing,
@@ -253,7 +305,25 @@ def cross_entropy_forward(
else:
z_loss_1d = loss_1d # dummy ptr when return_z_loss == False

n_non_ignore = (target != ignore_index).sum().item()
target_mask = target != ignore_index
n_non_ignore = target_mask.sum().item()
sum_non_ignore_weight = n_non_ignore
weight_sum = 0.0
if weight is not None:
assert (
weight.shape[0] == V
), f"If given, weight has to be a Tensor of size V. Got: {weight.shape}"
assert torch.is_floating_point(
weight
), f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}"
sum_non_ignore_weight = (
torch.gather(weight, dim=0, index=target.masked_select(target_mask))
.sum()
.item()
)
weight_sum = weight.sum().item()
if weight.stride(-1) != 1:
weight = weight.contiguous()

# ensure _input and target are contiguous in the last dimension
if _input.stride(-1) != 1:
@@ -267,18 +337,22 @@ def cross_entropy_forward(
X_stride=_input.stride(-2),
Y_ptr=target,
Y_stride=target.stride(-1), # always 1
weight_ptr=weight if weight is not None else _input, # dummy if None
loss_ptr=loss_1d,
z_loss_ptr=z_loss_1d,
loss_stride=loss_1d.stride(-1), # always 1
n_cols=V,
n_non_ignore=n_non_ignore,
sum_non_ignore_weight=sum_non_ignore_weight,
ignore_index=ignore_index,
weight_sum=weight_sum,
lse_square_scale=lse_square_scale,
label_smoothing=label_smoothing,
reduction=reduction,
softcap=softcap if softcap is not None else 0.0,
RETURN_Z_LOSS=return_z_loss,
BLOCK_SIZE=BLOCK_SIZE,
HAS_WEIGHT=True if weight is not None else False,
HAS_SOFTCAPPING=True if softcap is not None else False,
# TODO: 32 seems to give the best performance
# Performance is quite sensitive to num_warps
@@ -330,6 +404,7 @@ def forward(
ctx,
_input: torch.Tensor,
target: torch.Tensor,
weight: Optional[torch.FloatTensor],
ignore_index: int = -100,
lse_square_scale: float = 0.0,
label_smoothing: float = 0.0,
@@ -344,6 +419,7 @@ def forward(
ctx : The context object.
_input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
weight (Tensor, optional): A manual rescaling weight given to each class. If given, it has to be a Tensor of size V with a floating-point dtype.
ignore_index (int): The index to ignore in the target.
lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
@@ -357,6 +433,7 @@ def forward(
loss, z_loss, _input = cross_entropy_forward(
_input,
target,
weight,
ignore_index,
lse_square_scale,
label_smoothing,
@@ -398,4 +475,5 @@ def backward(ctx, grad_output, grad_ouput2):
None,
None,
None,
None,
)
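For context, a hypothetical usage sketch of the new `weight` argument (module path taken from this diff; the remaining arguments are assumed to keep their defaults, and the function is assumed to return a `(loss, z_loss)` pair, as implied by the two grad outputs in `backward`):

```python
import torch
from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction

B_T, V = 8, 32000
logits = torch.randn(B_T, V, device="cuda", requires_grad=True)
target = torch.randint(0, V, (B_T,), device="cuda")
weight = torch.rand(V, device="cuda")  # per-class rescaling weights, size V

loss, _ = LigerCrossEntropyFunction.apply(logits, target, weight)
loss.backward()
```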