Add weight support for LigerCrossEntropy #420
@@ -27,18 +27,22 @@ def liger_cross_entropy_kernel(
     X_stride,
     Y_ptr,
     Y_stride,
+    weight_ptr,
     loss_ptr,
     z_loss_ptr,
     loss_stride,
     n_cols,
     n_non_ignore,
+    sum_non_ignore_weight,
+    weight_sum,
     ignore_index,
     lse_square_scale: tl.constexpr,
     label_smoothing: tl.constexpr,
     reduction: tl.constexpr,  # set it as constexpr since reduction is always known at compile time
     softcap,
     RETURN_Z_LOSS: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
+    HAS_WEIGHT: tl.constexpr,
     HAS_SOFTCAPPING: tl.constexpr,
 ):
     """
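A note on the new flag: because HAS_WEIGHT is a tl.constexpr, Triton specializes the kernel per value and compiles the untaken branch away, so the unweighted path keeps its original cost. A standalone sketch of that pattern (my illustration, not code from this PR; it assumes Triton and a CUDA device):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def scale_kernel(x_ptr, w_ptr, out_ptr, n, HAS_WEIGHT: tl.constexpr, BLOCK: tl.constexpr):
    offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n
    x = tl.load(x_ptr + offsets, mask=mask)
    if HAS_WEIGHT:  # resolved at compile time; the dead branch is eliminated
        w = tl.load(w_ptr + offsets, mask=mask)
        x = x * w
    tl.store(out_ptr + offsets, x, mask=mask)


x = torch.randn(1000, device="cuda")
w = torch.rand(1000, device="cuda")
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 256),)
scale_kernel[grid](x, w, out, x.numel(), HAS_WEIGHT=True, BLOCK=256)
```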
@@ -50,18 +54,22 @@ def liger_cross_entropy_kernel(
     X_stride (int): The stride of the input tensor.
     Y_ptr: Pointer to target tensor.
     Y_stride (int): The stride of the target tensor.
+    weight_ptr: Pointer to weight tensor.
     loss_ptr: Pointer to tensor to store the loss.
     z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
     loss_stride (int): The stride of the loss tensor.
     n_cols (int): The number of columns in the input tensor.
-    n_non_ignore (int): The number of non-ignored elements in the batch.
+    n_non_ignore (float): The number of non-ignored elements in the batch.
+    sum_non_ignore_weight (float): The sum of the non-ignored targets' weights in the batch.
+    weight_sum (float): The sum of the weight tensor.
     ignore_index (int): The index to ignore in the target.
     label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
     lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
-    RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
     reduction (str): The string for the reduction to apply
     softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
+    RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
     BLOCK_SIZE (int): The block size for Triton operations.
+    HAS_WEIGHT (bool): The boolean value to determine whether to assign a weight to each of the classes.
     HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
     """
@@ -86,6 +94,9 @@ def liger_cross_entropy_kernel(
     loss_ptr += program_id * loss_stride
     z_loss_ptr += program_id * loss_stride
 
+    if HAS_WEIGHT:
+        weight_y = tl.load(weight_ptr + y).cast(tl.float32)
+
     # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
     # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867
@@ -116,7 +127,15 @@ def liger_cross_entropy_kernel(
         block_max = tl.max(X_block)
         if label_smoothing > 0:
             # scale X beforehand to avoid overflow
-            scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
+            if HAS_WEIGHT:
+                weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
+                scaled_x_sum += tl.sum(
+                    tl.where(X_offsets < n_cols, -eps * X_block * weight_block, 0.0)
+                )
+            else:
+                scaled_x_sum += tl.sum(
+                    tl.where(X_offsets < n_cols, -eps * X_block, 0.0)
+                )
         m_new = tl.maximum(m, block_max)
         d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
         m = m_new
@@ -152,18 +171,42 @@ def liger_cross_entropy_kernel(
         if HAS_SOFTCAPPING:
             intermediate = tanh(X_block / softcap)
             X_block = softcap * intermediate
-        # softmax(x_i)
-        X_block = tl.exp(X_block - m) / d
-        # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
-        X_block += 2 * lse_square_scale * lse * X_block
-        # smoothing term
-        X_block += -eps
-        # special handle dx_y
-        X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
-        # reduction scale
-        if reduction == "mean":
-            X_block = X_block / (n_non_ignore)
-        # chain rule
+
+        if not HAS_WEIGHT:
+            # softmax(x_i)
+            X_block = tl.exp(X_block - m) / d
+            # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
+            X_block += 2 * lse_square_scale * lse * X_block
+            # smoothing term
+            X_block += -eps
+            # special handle dx_y
+            X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
+            # reduction scale
+            if reduction == "mean":
+                X_block = X_block / n_non_ignore
+        else:
+            weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
+            softmax_X = tl.exp(X_block - m) / d
+            # derivative of original_loss
+            dloss_ori = (1 - label_smoothing) * softmax_X
+            # specially handle dx_y
+            dloss_ori = tl.where(
+                X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing)
+            )
+            dloss_ori = dloss_ori * weight_y
+            # derivative of smooth_loss
+            dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
+            # derivative of z-loss
+            dz_loss = 2 * lse_square_scale * lse * softmax_X
+            # reduction scale
+            if reduction == "mean":
+                dloss_ori = dloss_ori / sum_non_ignore_weight
+                dloss_smooth = dloss_smooth / sum_non_ignore_weight
+                dz_loss = dz_loss / n_non_ignore
+            # derivative of total_loss
+            X_block = dloss_ori + dloss_smooth + dz_loss
+
+        # chain rule softcapping
         # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
         if HAS_SOFTCAPPING:
             X_block = X_block * (1 - intermediate * intermediate)
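As a sanity check on the else branch above: dloss_ori + dloss_smooth is expected to match the autograd gradient of PyTorch's weighted, label-smoothed cross entropy, while dz_loss has no counterpart in F.cross_entropy. A minimal pure-PyTorch reference one could diff against (my sketch; shapes and values are arbitrary):

```python
import torch
import torch.nn.functional as F

# Reference gradient for weighted, label-smoothed CE with mean reduction.
torch.manual_seed(0)
B, V = 4, 8
logits = torch.randn(B, V, requires_grad=True)
target = torch.randint(0, V, (B,))
weight = torch.rand(V)

loss = F.cross_entropy(
    logits, target, weight=weight, reduction="mean", label_smoothing=0.1
)
loss.backward()
# logits.grad is what dloss_ori + dloss_smooth should reproduce per row
# (already divided by the sum of the non-ignored targets' weights).
print(logits.grad)
```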
@@ -182,6 +225,8 @@ def liger_cross_entropy_kernel(
     # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1
     # So we can safely calculate log (softmax(X_y)) without overflow
     loss = lse - ori_X_y
+    if HAS_WEIGHT:
+        loss = weight_y * loss
 
     # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
     # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
@@ -192,17 +237,23 @@ def liger_cross_entropy_kernel(
     # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
     # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
     if label_smoothing > 0:
-        smooth_loss = scaled_x_sum + label_smoothing * lse
+        if HAS_WEIGHT:
+            smooth_loss = scaled_x_sum + eps * lse * weight_sum
+        else:
+            smooth_loss = scaled_x_sum + label_smoothing * lse
         loss = loss * (1 - label_smoothing) + smooth_loss
 
     # An auxiliary loss, z_loss
     # Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html
     z_loss = lse_square_scale * lse * lse
-    loss += z_loss
     # Normalize the loss by the number of non-ignored elements if reduction is "mean"
     if reduction == "mean":
+        if HAS_WEIGHT:
+            loss = loss / sum_non_ignore_weight
+        else:
+            loss = loss / n_non_ignore
         z_loss = z_loss / n_non_ignore
-        loss = loss / n_non_ignore
+    loss += z_loss
Comment on lines 250 to +256

Reviewer: Could you clarify the change here? I was thinking that if there's any missing part for

Author: z_loss isn't scaled by weight right now, so it is divided by the number of non-ignored tokens, unlike the rest of the loss, which is divided by the sum of the weights when a weight is given.
 
     tl.store(loss_ptr, loss)
     if RETURN_Z_LOSS == _TRUE:
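To restate the discussion above in formulas (my reading of the diff, not wording from the PR): with s = label_smoothing, eps = s / V, and lse_i = logsumexp(x_i), each non-ignored row i contributes, under mean reduction,

```latex
\ell_i =
\frac{(1 - s)\, w_{y_i} \bigl(\mathrm{lse}_i - x_{i, y_i}\bigr)
      + \varepsilon \sum_{j=1}^{V} w_j \bigl(\mathrm{lse}_i - x_{i, j}\bigr)}
     {\sum_{k :\, y_k \neq \text{ignore\_index}} w_{y_k}}
+ \frac{\texttt{lse\_square\_scale} \cdot \mathrm{lse}_i^{2}}{n_{\text{non\_ignore}}}
```

i.e. the weighted terms are normalized by sum_non_ignore_weight while the z-loss keeps the plain count, which is exactly the asymmetry the reviewer asked about.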
@@ -224,6 +275,7 @@ def liger_cross_entropy_kernel(
 def cross_entropy_forward(
     _input,
     target,
+    weight,
     ignore_index,
     lse_square_scale,
     label_smoothing,
@@ -253,7 +305,25 @@ def cross_entropy_forward(
     else:
         z_loss_1d = loss_1d  # dummy ptr when return_z_loss == False
 
-    n_non_ignore = (target != ignore_index).sum().item()
+    target_mask = target != ignore_index
+    n_non_ignore = target_mask.sum().item()
+    sum_non_ignore_weight = n_non_ignore
+    weight_sum = 0.0
+    if weight is not None:
+        assert (
+            weight.shape[0] == V
+        ), f"If given, weight has to be a Tensor of size V. Got: {weight.shape}"
+        assert torch.is_floating_point(
+            weight
+        ), f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}"
+        sum_non_ignore_weight = (
+            torch.gather(weight, dim=0, index=target.masked_select(target_mask))
+            .sum()
+            .item()
+        )
+        weight_sum = weight.sum().item()
+        if weight.stride(-1) != 1:
+            weight = weight.contiguous()
 
     # ensure _input and target are contiguous in the last dimension
     if _input.stride(-1) != 1:
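A tiny numeric illustration of the bookkeeping above (values are made up): sum_non_ignore_weight gathers one weight per non-ignored target, while weight_sum is the plain sum over all V classes.

```python
import torch

V = 4
ignore_index = -100
weight = torch.tensor([0.5, 1.0, 2.0, 4.0])
target = torch.tensor([1, 3, ignore_index, 2])

target_mask = target != ignore_index
n_non_ignore = target_mask.sum().item()  # 3
sum_non_ignore_weight = (
    torch.gather(weight, dim=0, index=target.masked_select(target_mask)).sum().item()
)  # weight[1] + weight[3] + weight[2] = 7.0
weight_sum = weight.sum().item()  # 7.5
print(n_non_ignore, sum_non_ignore_weight, weight_sum)
```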
@@ -267,18 +337,22 @@ def cross_entropy_forward(
         X_stride=_input.stride(-2),
         Y_ptr=target,
         Y_stride=target.stride(-1),  # always 1
+        weight_ptr=weight if weight is not None else _input,  # dummy if None
         loss_ptr=loss_1d,
         z_loss_ptr=z_loss_1d,
         loss_stride=loss_1d.stride(-1),  # always 1
         n_cols=V,
         n_non_ignore=n_non_ignore,
+        sum_non_ignore_weight=sum_non_ignore_weight,
         ignore_index=ignore_index,
+        weight_sum=weight_sum,
         lse_square_scale=lse_square_scale,
         label_smoothing=label_smoothing,
         reduction=reduction,
         softcap=softcap if softcap is not None else 0.0,
         RETURN_Z_LOSS=return_z_loss,
         BLOCK_SIZE=BLOCK_SIZE,
+        HAS_WEIGHT=True if weight is not None else False,
         HAS_SOFTCAPPING=True if softcap is not None else False,
         # TODO: 32 seems to give the best performance
         # Performance is quite sensitive to num_warps
@@ -330,6 +404,7 @@ def forward(
     ctx,
     _input: torch.Tensor,
     target: torch.Tensor,
+    weight: Optional[torch.FloatTensor],
     ignore_index: int = -100,
     lse_square_scale: float = 0.0,
     label_smoothing: float = 0.0,
@@ -344,6 +419,7 @@ def forward(
     ctx : The context object.
     _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
     target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
+    weight (Tensor, optional): A manual rescaling weight given to each class. If given, it has to be a Tensor of size V and floating point dtype.
     ignore_index (int): The index to ignore in the target.
     lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
     label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
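A minimal usage sketch of the updated entry point. The import path and the assumption that the trailing arguments keep the defaults shown in the diff are mine, not documented in this PR; the two return values follow from backward taking two gradient outputs, and the Triton kernel needs a CUDA device:

```python
import torch
from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction  # assumed path

BT, V = 8, 32
_input = torch.randn(BT, V, device="cuda", requires_grad=True)
target = torch.randint(0, V, (BT,), device="cuda")
weight = torch.rand(V, device="cuda")  # one rescaling factor per class

# weight is the third positional argument, right after target.
loss, z_loss = LigerCrossEntropyFunction.apply(_input, target, weight)
loss.backward()  # z_loss is unused unless return_z_loss is enabled
```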
@@ -357,6 +433,7 @@ def forward(
     loss, z_loss, _input = cross_entropy_forward(
         _input,
         target,
+        weight,
         ignore_index,
         lse_square_scale,
         label_smoothing,
@@ -398,4 +475,5 @@ def backward(ctx, grad_output, grad_ouput2):
         None,
         None,
         None,
+        None,
     )
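The extra None is needed because torch.autograd.Function.backward must return one gradient per argument of forward, and the class weights are treated as non-differentiable. A toy illustration of that rule (my example, unrelated to the PR's code):

```python
import torch


class ScaleByWeight(torch.autograd.Function):
    """Toy Function with a non-differentiable extra input, mirroring how the
    new weight argument is handled: backward returns None in its slot."""

    @staticmethod
    def forward(ctx, x, weight):
        ctx.save_for_backward(weight)
        return x * weight

    @staticmethod
    def backward(ctx, grad_output):
        (weight,) = ctx.saved_tensors
        return grad_output * weight, None  # one entry per forward input


x = torch.randn(3, requires_grad=True)
w = torch.tensor([0.5, 1.0, 2.0])
ScaleByWeight.apply(x, w).sum().backward()
print(x.grad)  # tensor([0.5000, 1.0000, 2.0000])
```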
Reviewer: I think it would be better to use `if HAS_WEIGHT` instead of `if not HAS_WEIGHT`, to align with all the other behaviors in this change.

Author: I thought putting the base case (no weight) first would be better, but I'll consider it, thank you.

Reviewer: Oh, I see your point. Never mind. I was being a little too nitpicky.