Apple's cross entropy computation #391

Open
fzyzcjy opened this issue Nov 17, 2024 · 8 comments

@fzyzcjy

fzyzcjy commented Nov 17, 2024

Hi, thanks for the library! Today I came across a paper, https://openreview.net/forum?id=E4Fk3YuG56 (code: https://github.com/apple/ml-cross-entropy), which seems to discuss a way to compute cross entropy. I'm sharing it here in case it is useful for this repository.

@leng-yue

leng-yue commented Dec 3, 2024

Hi @fzyzcjy, did you try the original repo? Does it work as expected?

@fzyzcjy
Author

fzyzcjy commented Dec 3, 2024

Hi, no, I have not tried it yet.

@andersonbcdefg

I ran into problems installing and using Apple's kernel from that repo. I assume their Triton implementation is sound and does what it says it does, but it's research code and isn't well tested across many versions and platforms. It would be amazing to have it be part of liger-kernel, because everything here is well tested and "just works" out of the box.

@ByronHsu
Collaborator

ByronHsu commented Dec 5, 2024

We are more than happy to host and maintain innovative kernels like https://github.com/apple/ml-cross-entropy. @erikwijmans, are you interested in collaborating? We are committed to long-term maintenance at the company level.

@andersonbcdefg

FYI @ByronHsu, this is a very simple reproduction of why the CCE kernel isn't working for me. I know you use Modal for CI, so you should be able to reproduce this pretty easily. I wish I could debug it myself, but I am not a Triton god like you. This is the simplest setup I can think of: a fresh install of only the cut-cross-entropy package, then an attempt to import linear_cross_entropy, at which point the error happens. I did not make any modifications to the code.

This uses python version 3.10, triton==3.1.0, torch==2.5.1.

import modal

image = modal.Image.debian_slim(python_version="3.10").apt_install("git").pip_install(
    "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git"
)

app = modal.App("cut-cross-entropy")

@app.function(
    image=image,
    gpu=modal.gpu.A10G()
)
def test_cce():
    from cut_cross_entropy import linear_cross_entropy
    print("success!")

Backtrace:

Traceback (most recent call last):
  File "/root/test.py", line 14, in test_cce
    from cut_cross_entropy import linear_cross_entropy
  File "/usr/local/lib/python3.10/site-packages/cut_cross_entropy/__init__.py", line 2, in <module>
    from cut_cross_entropy.linear_cross_entropy import (
  File "/usr/local/lib/python3.10/site-packages/cut_cross_entropy/linear_cross_entropy.py", line 20, in <module>
    from cut_cross_entropy.cce import cce_linear_cross_entropy
  File "/usr/local/lib/python3.10/site-packages/cut_cross_entropy/cce.py", line 7, in <module>
    from cut_cross_entropy.cce_backward import cce_backward_kernel
  File "/usr/local/lib/python3.10/site-packages/cut_cross_entropy/cce_backward.py", line 81, in <module>
    def _cce_backward_kernel(
  File "/usr/local/lib/python3.10/site-packages/triton/runtime/jit.py", line 882, in jit
    return decorator(fn)
  File "/usr/local/lib/python3.10/site-packages/triton/runtime/jit.py", line 871, in decorator
    return JITFunction(
  File "/usr/local/lib/python3.10/site-packages/triton/runtime/jit.py", line 717, in __init__
    self.src = self.src[re.search(r"^def\s+\w+\s*\(", self.src, re.MULTILINE).start():]
AttributeError: 'NoneType' object has no attribute 'start'
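
The last frame of the traceback calls `.start()` directly on the result of `re.search`, so the `AttributeError` means the pattern found no line beginning with `def` in the source text it was given. Below is a minimal sketch of that failure mode using only the standard library. The two source strings and the `strip_to_def` helper are hypothetical and only illustrate the regex behaviour; they are not Triton's code and do not explain why the source text lacked a `def` line in this case.

```python
import re

# The pattern copied from the last frame of the traceback above.
DEF_PATTERN = re.compile(r"^def\s+\w+\s*\(", re.MULTILINE)

# Hypothetical source text: a decorator followed by the function definition.
complete_src = (
    "@decorator\n"
    "def kernel(x):\n"
    "    return x\n"
)

# Hypothetical source text that ends before any `def` line is reached.
truncated_src = (
    "@decorator(\n"
    "    {\"FLAG\": lambda args: True}\n"
    ")\n"
)


def strip_to_def(src: str) -> str:
    """Return src starting at the first line that begins with `def`."""
    match = DEF_PATTERN.search(src)
    if match is None:
        # Checking for None gives a readable error instead of AttributeError.
        raise ValueError("no line starting with 'def' found in source text")
    return src[match.start():]


print(strip_to_def(complete_src).splitlines()[0])  # def kernel(x):
print(DEF_PATTERN.search(truncated_src))           # None
```

Checking the match for `None` before calling `.start()` does not fix the underlying problem, but it turns the failure into a message that points at the missing `def` line.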

@andersonbcdefg

OK, it turns out the problem is related to Triton's regex search over the kernel's source code in combination with stacking multiple decorators. The fix is to comment out the decorators on _cce_backward_kernel:

# @cce_backward_autotune()
# @triton.heuristics(
#     {
#         "EVEN_D": lambda args: (args["D"] % args["BLOCK_D"]) == 0,
#         "MM_BACK_BLOCK_D": lambda args: args["BLOCK_D"] * 2,
#         "MM_BACK_EVEN_D": lambda args: (args["D"] % (args["BLOCK_D"] * 2)) == 0,
#         "HAS_VALIDS": lambda args: args["Valids"] is not None,
#         "HAS_VOCAB_ORDERING": lambda args: args["VocabOrdering"] is not None,
#         "FILTER_GRAD": lambda args: args["filter_eps"] is not None,
#         "HAS_TARGETS": lambda args: args["Targets"] is not None,
#         "HAS_SOFTCAP": lambda args: args["softcap"] is not None,
#         "ITEM_DO": lambda args: args["dOut"].numel() == 1,
#         "GROUP_B": lambda args: 8,
#     }
# )
# @triton.jit
def _cce_backward_kernel(

...and instead apply them "manually" like this:

_cce_backward_kernel = triton.jit(_cce_backward_kernel)
_cce_backward_kernel = triton.heuristics(
    {
        "EVEN_D": lambda args: (args["D"] % args["BLOCK_D"]) == 0,
        "MM_BACK_BLOCK_D": lambda args: args["BLOCK_D"] * 2,
        "MM_BACK_EVEN_D": lambda args: (args["D"] % (args["BLOCK_D"] * 2)) == 0,
        "HAS_VALIDS": lambda args: args["Valids"] is not None,
        "HAS_VOCAB_ORDERING": lambda args: args["VocabOrdering"] is not None,
        "FILTER_GRAD": lambda args: args["filter_eps"] is not None,
        "HAS_TARGETS": lambda args: args["Targets"] is not None,
        "HAS_SOFTCAP": lambda args: args["softcap"] is not None,
        "ITEM_DO": lambda args: args["dOut"].numel() == 1,
        "GROUP_B": lambda args: 8,
    }
)(_cce_backward_kernel)
_cce_backward_kernel = cce_backward_autotune()(_cce_backward_kernel)
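
The manual form above keeps the same application order as the stacked `@` syntax: the decorator closest to the `def` runs first, so `triton.jit` is applied before `triton.heuristics`, and `cce_backward_autotune()` last. A minimal sketch with plain Python, assuming a hypothetical `tag` decorator factory that only records the order (it stands in for the real decorators and does not use Triton):

```python
applied = []


def tag(name):
    """Hypothetical decorator factory that records when it wraps a function."""
    def decorator(fn):
        applied.append(name)
        return fn
    return decorator


# Stacked form: decorators are applied bottom-up, starting nearest the def.
@tag("autotune")
@tag("heuristics")
@tag("jit")
def stacked():
    return "ok"


stacked_order = list(applied)
applied.clear()


# Manual form: the same decorators applied one at a time, innermost first.
def manual():
    return "ok"


manual = tag("jit")(manual)
manual = tag("heuristics")(manual)
manual = tag("autotune")(manual)
manual_order = list(applied)

print(stacked_order)                  # ['jit', 'heuristics', 'autotune']
print(stacked_order == manual_order)  # True
```

The sketch shows only that the two forms wrap the function in the same order; it does not reproduce the source-lookup behaviour that made the stacked form fail here.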

@ccdv-ai

ccdv-ai commented Dec 10, 2024

This has been merged into https://github.com/apple/ml-cross-entropy.
Any update on CCE integration?

@amazingvince

amazingvince commented Dec 11, 2024

I got CCE working with transformers, but it was a hacked mess. The unsloth guys just announced that it's now supported in their new blog post here.

Here is the patch to the forward:
https://github.com/unslothai/unsloth/blob/main/unsloth/models/llama.py#L932

Here is where they bring in CCE:
https://github.com/unslothai/unsloth-zoo/blob/main/unsloth_zoo/loss_utils.py#L139

I am not an expert on the kernel side but can help with the integration.
