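🐛 Describe the bug

```python
loss = LigerFusedLinearCrossEntropyLoss(reduction='none')(model.lm_head.weight, flattened_hidden_states, flattened_target)
```

returns a loss with shape `[]`:

```
tensor(209594.4062, device='cuda:7',
       grad_fn=)
```

The reduction is performed anyway: instead of one loss per token, a single reduced scalar is returned.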
Reproduce

```python
import torch
from liger_kernel.transformers import LigerCrossEntropyLoss, LigerFusedLinearCrossEntropyLoss

device3 = 'cuda'
weight = torch.randn((180000, 4096), device=device3, dtype=torch.float32)
fhidden_states = torch.randn((20, 4096), device=device3, dtype=torch.float32)
ftarget = torch.ones((20,), device=device3, dtype=torch.long)

# With reduction='none' this should print a tensor of shape [20], one loss per row.
loss = LigerFusedLinearCrossEntropyLoss(reduction='none')(weight, fhidden_states, ftarget)
print(loss)
```
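For comparison, plain PyTorch's `F.cross_entropy` shows the shape one would expect. This is an unfused sketch with small, made-up sizes (it builds the full logits matrix, which is exactly what the fused kernel is meant to avoid): with `reduction="none"` it returns one loss per row, while `reduction="sum"` returns a 0-d scalar.

```python
import torch
import torch.nn.functional as F

# Small, made-up sizes so the full logits matrix fits in memory on CPU.
# (The repro above uses a 180000 x 4096 weight on GPU.)
torch.manual_seed(0)
vocab_size, hidden_size, num_tokens = 1000, 64, 20
weight = torch.randn(vocab_size, hidden_size)
hidden_states = torch.randn(num_tokens, hidden_size)
target = torch.ones(num_tokens, dtype=torch.long)

# Unfused reference: project to logits, then apply cross-entropy per row.
logits = hidden_states @ weight.T  # shape [num_tokens, vocab_size]
per_token = F.cross_entropy(logits, target, reduction="none")
summed = F.cross_entropy(logits, target, reduction="sum")

print(per_token.shape)  # torch.Size([20])
print(summed.shape)     # torch.Size([])
```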
Versions

Environment Report:
Operating System: Linux-5.4.0-135-generic-x86_64-with-glibc2.31
Python version: 3.10.14
Liger Kernel version: 0.5.2
PyTorch version: 2.5.1+cu124
CUDA version: 12.4
HIP(ROCm) version: Not available
Triton version: 3.1.0
Transformers version: 4.44.0
XPU version: XPU Not Available
Relevant code: Liger-Kernel/src/liger_kernel/ops/fused_linear_cross_entropy.py, line 136 (commit 15a2f58).

This should be easy to fix by skipping the torch.sum() when reduction is "none", similar to what LigerCrossEntropy does.
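The suggested fix can be illustrated with a plain-PyTorch sketch. This is not Liger's Triton implementation: the function `chunked_linear_cross_entropy` and its `chunk_size` parameter are hypothetical. It only mirrors the general idea of computing logits chunk by chunk and then applying the sum only when `reduction` is not `"none"`. Ignored targets (assumed `ignore_index=-100`, PyTorch's default) contribute 0, and `"mean"` divides by the number of non-ignored tokens, matching `F.cross_entropy`.

```python
import torch
import torch.nn.functional as F


def chunked_linear_cross_entropy(
    weight: torch.Tensor,         # [vocab_size, hidden_size]
    hidden_states: torch.Tensor,  # [num_tokens, hidden_size]
    target: torch.Tensor,         # [num_tokens], int64
    reduction: str = "mean",
    ignore_index: int = -100,
    chunk_size: int = 8,
) -> torch.Tensor:
    """Hypothetical reference: never builds the full [num_tokens, vocab_size] logits."""
    if reduction not in ("none", "sum", "mean"):
        raise ValueError(f"unsupported reduction: {reduction!r}")

    per_token_chunks = []
    for start in range(0, hidden_states.shape[0], chunk_size):
        h = hidden_states[start : start + chunk_size]
        t = target[start : start + chunk_size]
        logits = h @ weight.T  # only one chunk of logits is alive at a time
        # Per-token losses; ignored positions come back as 0.
        per_token_chunks.append(
            F.cross_entropy(logits, t, reduction="none", ignore_index=ignore_index)
        )
    per_token = torch.cat(per_token_chunks)

    # Only reduce when asked to; reduction="none" keeps one loss per token.
    if reduction == "none":
        return per_token
    total = per_token.sum()
    if reduction == "sum":
        return total
    # "mean": average over non-ignored tokens.
    return total / (target != ignore_index).sum()


# Usage with small, made-up sizes.
torch.manual_seed(0)
weight = torch.randn(1000, 64)
hidden = torch.randn(20, 64)
target = torch.randint(0, 1000, (20,))
target[3] = -100  # one ignored position
loss = chunked_linear_cross_entropy(weight, hidden, target, reduction="none")
print(loss.shape)  # torch.Size([20])
```

Keeping the per-token vector and deciding on the reduction at the very end means the `"none"` path adds no extra work; the `"sum"` and `"mean"` paths are a single sum (and division) on top of it.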