
Add weight support for LigerCrossEntropy #420

Open · wants to merge 14 commits into main
Conversation


@Tcc0403 Tcc0403 commented Dec 2, 2024

Summary

Resolves #404.
Note: the current implementation doesn't apply the weight to the z loss.

Reference: PyTorch's CrossEntropyLoss
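As a minimal sketch of the reference behavior this PR targets (class-level weights with torch.nn.functional.cross_entropy; shapes are illustrative, not from the PR):

import torch
import torch.nn.functional as F

BT, V = 8, 128                       # flattened batch*seq tokens, vocab size (illustrative)
logits = torch.randn(BT, V)
target = torch.randint(0, V, (BT,))
ce_weight = torch.rand(V)            # one weight per class

# With reduction="mean" and a weight, PyTorch divides by the summed weights
# of the non-ignored targets, not by the token count.
ref_loss = F.cross_entropy(logits, target, weight=ce_weight, reduction="mean")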

Testing Done

It hasn't been fully tested with other params.

  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@Tcc0403 Tcc0403 requested review from pramodith and ByronHsu December 2, 2024 12:52

@pramodith pramodith left a comment


Thanks for taking care of this! Had a few minor suggestions.

Another TODO, based on the paper linked in the original issue for this feature: we also need to support a sample-level weight, i.e. a weight that can be applied to each element of the batch. If we have logits of shape (B, S, V), we'd have sample-level weights of shape (B,). This is what's proposed in the C-RLFT paper: https://arxiv.org/abs/2309.11235
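A rough sketch of what sample-level weighting could look like (shapes and the normalization choice are assumptions, not something this PR implements):

import torch
import torch.nn.functional as F

B, S, V = 4, 16, 128
logits = torch.randn(B, S, V)
target = torch.randint(0, V, (B, S))
sample_weight = torch.rand(B)        # one weight per sequence in the batch

# per-token CE, reshaped back to (B, S), then scaled by each sample's weight
per_token = F.cross_entropy(
    logits.view(-1, V), target.view(-1), reduction="none"
).view(B, S)
token_weight = sample_weight[:, None].expand(B, S)
loss = (per_token * token_weight).sum() / token_weight.sum()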

Review threads on src/liger_kernel/ops/cross_entropy.py (5) and test/transformers/test_cross_entropy.py (1), all marked outdated and resolved.

Tcc0403 commented Dec 2, 2024

Feel free to push to this branch, or even take it over and open a new PR; I won't be able to update it often in the next few months. Just trying to take the first step while I have time.

        (1.0, torch.float32, 1e-8, 1e-6),
    ],
)
def test_correctness_with_weight_with_other_params_once(
Collaborator Author


This test somehow doesn't pass. I might be missing something.

Collaborator


So, the issue seems to be with combining label_smoothing with weighted loss. I've been staring at the code and equations for a while now but I can't pinpoint anything that's wrong. Simply multiplying the final loss by the weight of the label token seems like the right thing to do to me.

If not, the only remaining suspect is the scaled_x_sum term, since all the other terms in the smoothed loss are also part of the plain CE loss, which we know works correctly.
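One way to narrow this down is to diff against PyTorch's own handling of weight combined with label_smoothing, since that reference is known to be correct (a debugging sketch, not part of the PR):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
BT, V = 8, 128
logits = torch.randn(BT, V, requires_grad=True)
target = torch.randint(0, V, (BT,))
ce_weight = torch.rand(V)

ref = F.cross_entropy(
    logits, target, weight=ce_weight, label_smoothing=0.1, reduction="mean"
)
ref.backward()
# ref and logits.grad give the expected loss/gradient to compare the kernel against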

Collaborator Author


Figuring out where it doesn't work is a big help! I'll take a look on Saturday.

@pramodith

> Feel free to push to this branch, or even take it over and open a new PR; I won't be able to update it often in the next few months. Just trying to take the first step while I have time.

Gotcha! I'll try wrapping it up, you've done most of the heavy lifting already.


Tcc0403 commented Dec 8, 2024

I took a look at torch's impl, and here's how they compute smooth_loss
https://github.com/pytorch/pytorch/blob/2682e5e0d48a8200c1672b6a42250d3c8de44190/aten/src/ATen/native/LossNLL.cpp#L558

    if (weight.defined()) {
      // Expand weight to the correct number of dims for broadcasting with input / target
      auto weight_broadcast_shape = SmallBuffer<int64_t, 5>(input.dim());
      std::fill(weight_broadcast_shape.begin(), weight_broadcast_shape.end(), 1);
      weight_broadcast_shape[class_dim] = weight.size(0);
      Tensor weight_ = weight.view(weight_broadcast_shape);

      smooth_loss = -(input * weight_).sum(class_dim);
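In PyTorch terms, that C++ block amounts to roughly the following (a paraphrase for readability; as in LossNLL.cpp, the input holds log-probabilities, and the tensors here are only illustrative):

import torch

B, V, S = 4, 128, 16
log_probs = torch.log_softmax(torch.randn(B, V, S), dim=1)
weight = torch.rand(V)
class_dim = 1

# broadcast the class weights along every dim except the class dim
shape = [1] * log_probs.dim()
shape[class_dim] = weight.size(0)      # e.g. (1, V, 1)
weight_ = weight.view(shape)

# every class term of the smoothing loss is scaled by its class weight
smooth_loss = -(log_probs * weight_).sum(class_dim)   # shape (B, S)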

related code blocks in liger:

scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))

if label_smoothing > 0:
    smooth_loss = scaled_x_sum + label_smoothing * lse
    loss = loss * (1 - label_smoothing) + smooth_loss

selected_weight = torch.where(
    target_mask, torch.gather(weight, dim=0, index=target * target_mask), 0.0
)
sum_of_non_ignore_weight = selected_weight.sum().item()

@Tcc0403 Tcc0403 Dec 8, 2024


We can rewrite it with torch.masked_select:

sum_of_non_ignore_weight = (
    torch.gather(weight, dim=0, index=target.masked_select(target_mask))
    .sum()
    .item()
)

Refer to torch's impl mentioned above

@winglian
Contributor

@pramodith anything I can do to help with this PR?

@pramodith

> @pramodith anything I can do to help with this PR?

Hey @winglian, I won't be able to look into this any further, so feel free to take over and see if you can figure out the source of the discrepancy. The tests fail when combining label smoothing with weighted CE.

@Tcc0403 Tcc0403 requested a review from pramodith December 22, 2024 07:31

Tcc0403 commented Dec 22, 2024

I'll make another PR for sample-level weight.

@Tcc0403 Tcc0403 requested a review from austin362667 December 23, 2024 10:19

@bboyleonp666 bboyleonp666 left a comment


Hi @Tcc0403, thanks for your wonderful work. I left some thoughts on this PR, PTAL.

Comment on lines +175 to +207
if not HAS_WEIGHT:
    # softmax(x_i)
    X_block = tl.exp(X_block - m) / d
    # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
    X_block += 2 * lse_square_scale * lse * X_block
    # smoothing term
    X_block += -eps
    # special handle dx_y
    X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
    # reduction scale
    if reduction == "mean":
        X_block = X_block / n_non_ignore
else:
    weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
    softmax_X = tl.exp(X_block - m) / d
    # derivative of original_loss
    dloss_ori = (1 - label_smoothing) * softmax_X
    # specially handle dx_y
    dloss_ori = tl.where(
        X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing)
    )
    dloss_ori = dloss_ori * weight_y
    # derivative of smooth_loss
    dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
    # derivative of z-loss
    dz_loss = 2 * lse_square_scale * lse * softmax_X
    # reduction scale
    if reduction == "mean":
        dloss_ori = dloss_ori / sum_non_ignore_weight
        dloss_smooth = dloss_smooth / sum_non_ignore_weight
        dz_loss = dz_loss / n_non_ignore
    # derivative of total_loss
    X_block = dloss_ori + dloss_smooth + dz_loss
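For review purposes, the weighted (else) branch above corresponds roughly to this per-row gradient in plain PyTorch (a sketch mirroring the kernel's symbol names; not part of the diff):

import torch

def weighted_ce_grad_row(x, y, weight, eps, label_smoothing, lse_square_scale,
                         sum_non_ignore_weight, n_non_ignore, reduction="mean"):
    """Gradient of one row's loss w.r.t. its logits x, with target index y."""
    lse = torch.logsumexp(x, dim=0)
    softmax_x = torch.exp(x - lse)          # equivalent to tl.exp(X_block - m) / d

    # derivative of the (1 - label_smoothing)-scaled CE term, scaled by weight[y]
    dloss_ori = (1 - label_smoothing) * softmax_x
    dloss_ori[y] -= 1 - label_smoothing
    dloss_ori = dloss_ori * weight[y]

    # derivative of the smoothing term, weighted per class
    dloss_smooth = eps * (-weight + softmax_x * weight.sum())

    # derivative of the (unweighted) z-loss term
    dz_loss = 2 * lse_square_scale * lse * softmax_x

    if reduction == "mean":
        dloss_ori = dloss_ori / sum_non_ignore_weight
        dloss_smooth = dloss_smooth / sum_non_ignore_weight
        dz_loss = dz_loss / n_non_ignore
    return dloss_ori + dloss_smooth + dz_loss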
Contributor


I think it would be better to use if HAS_WEIGHT instead of if not HAS_WEIGHT, to align with the other behaviors in this change.

Collaborator Author


I thought putting the base case (no weight) first would be better, but I'll consider it. Thank you.

Contributor


Oh, I see your point. Never mind, I was being a little too nitpicky.

Comment on lines 250 to +256
 if reduction == "mean":
+    if HAS_WEIGHT:
+        loss = loss / sum_non_ignore_weight
+    else:
+        loss = loss / n_non_ignore
     z_loss = z_loss / n_non_ignore
-    loss = loss / n_non_ignore
 loss += z_loss
Contributor


Could you clarify the change here? I was wondering whether there's a missing part for z_loss; I'm not quite sure if z_loss should be affected by weights or not.

Collaborator Author


z_loss isn't scaled by the weight right now, so it is divided by the number of non-ignored tokens, unlike the rest of the loss, which is divided by the sum of weights when a weight is given.
That's why I have to do the divisions first, before summing them.
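In other words, with a class weight the mean reduction effectively becomes loss = weighted_ce_sum / sum_non_ignore_weight + z_loss_sum / n_non_ignore (names illustrative), so the two terms need their own normalizers before they are combined.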

Comment on lines 57 to +59
 # NOTE: skip .item() here to avoid CUDA synchronization
-total_n_non_ignore = (target != ignore_index).sum()
+target_mask = target != ignore_index
+total_n_non_ignore = target_mask.sum().item()

@bboyleonp666 bboyleonp666 Dec 24, 2024


I noticed the comment above about avoiding .item() because of the synchronization issue. Will this change still align with that behavior?

Collaborator Author


I forgot to remove the comment; it doesn't affect the result.
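For context, the behavior under discussion, as a small sketch (not from the PR):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
target = torch.randint(0, 128, (32,), device=device)
ignore_index = -100
target_mask = target != ignore_index

# stays a 0-dim tensor on the device; no host-device synchronization is forced
n_non_ignore_tensor = target_mask.sum()

# .item() copies the scalar to the CPU, which forces the GPU stream to synchronize
n_non_ignore = target_mask.sum().item()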

Comment on lines -223 to +245
-    _input,
-    weight,
-    target,
-    bias,
-    ignore_index,
-    lse_square_scale,
-    label_smoothing,
-    reduction,
-    softcap,
+    _input=_input,
+    weight=weight,
+    target=target,
+    bias=bias,
+    ce_weight=ce_weight,
+    ignore_index=ignore_index,
+    lse_square_scale=lse_square_scale,
+    label_smoothing=label_smoothing,
+    reduction=reduction,
+    softcap=softcap,

This comment was marked as off-topic.

@Tcc0403
Copy link
Collaborator Author

Tcc0403 commented Dec 24, 2024

I'll update it on Saturday. Thanks for your review.


Successfully merging this pull request may close these issues.

Weighted Cross Entropy Loss
4 participants