Feature request: lower jnp.exp2 to tl.math.exp2 in pallas #204
This might be a pretty simple PR if you are up for contributing. Otherwise I will get around to it in a week or so. If interested, take a look at jax_triton/pallas/triton_lowering.py.
Maybe look at an existing unary primitive lowering such as exp or tanh.
The only problem is that I'm unsure whether I can support exp2 in the same way, given that in JAX it's not a lax primitive but rather a function implemented in terms of other lax primitives.
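For reference, the identity behind this comment is 2^x = e^(x·ln 2), so exp2 can be composed from an exp and a multiply. The sketch below is plain Python using only the standard library; it is not JAX's actual implementation of jnp.exp2, and the helper name exp2_via_exp is hypothetical.

```python
import math

LN_2 = math.log(2.0)  # natural log of 2, ~0.6931


def exp2_via_exp(x: float) -> float:
    """Hypothetical helper: compute 2**x using only exp and one multiply."""
    return math.exp(x * LN_2)


for x in (-1.0, 0.0, 3.0, 10.5):
    # Compare the composite against Python's direct power operator.
    print(x, exp2_via_exp(x), 2.0 ** x)
```

When exp2 exists only as such a composite, a lowering pass that dispatches on primitives sees an exp and a multiply, never an exp2 it could map to tl.math.exp2. That is the JAX-side gap the thread says jax-ml/jax#16883 handles.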
Sure, let's do it!
part of fixing jax-ml/jax-triton#204
jax-ml/jax#16883 fixes the JAX side. Once we merge that, I can probably update jax-triton too (unless you want to take it, @hr0nix).
Oh, I was too slow to react. Thanks a lot for the quick fix, guys!
Hello guys, I tried using exp2 for this trick in Triton: For just 13 out of 49K elements! How odd! It may simply be the extreme values that cause large floating-point errors. In my case, I implemented the trick for both the backward and forward passes in https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/attention.py. Implementing it for only the forward or only the backward pass resulted in many more floating-point errors. Just an anecdote; do let me know if any of you succeed with this.
I found the issue! Scaling before/after the float32 conversion makes all the difference! Now implemented in jax-ml/jax#17328.
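The exact change in jax-ml/jax#17328 is not shown here, so the following is only an assumption-labeled illustration of why the order can matter. If the inputs arrive in a low-precision dtype such as float16 (bfloat16 has even fewer significand bits), multiplying by the log2(e) factor before widening rounds the scaled value to float16 precision, while widening to float32 first keeps float32 precision. The helpers to_f16, to_f32, scale_then_upcast and upcast_then_scale are hypothetical; they simulate IEEE rounding with Python's struct module.

```python
import math
import struct

LOG2_E = math.log2(math.e)  # ~1.4427, the factor that turns exp(x) into exp2(x * LOG2_E)


def to_f16(x: float) -> float:
    """Round a Python float to the nearest IEEE float16 value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_f32(x: float) -> float:
    """Round a Python float to the nearest IEEE float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def scale_then_upcast(x16: float) -> float:
    # Multiply in float16: the constant and the product are both rounded to float16.
    # The product of two float16 values is exact in a double, so one rounding suffices.
    return to_f32(to_f16(x16 * to_f16(LOG2_E)))


def upcast_then_scale(x16: float) -> float:
    # Widen to float32 first (exact for float16 inputs), then multiply in float32.
    return to_f32(to_f32(x16) * to_f32(LOG2_E))


x = to_f16(10.0)  # an input that is exactly representable in float16
exact = x * LOG2_E
print(abs(scale_then_upcast(x) - exact))  # error when scaling in float16
print(abs(upcast_then_scale(x) - exact))  # error when scaling in float32
```

For x = 10.0 the float16 path is off by roughly 5e-3 and the float32 path by under 1e-6. After exp2, an absolute error δ in the exponent becomes a relative error of about δ·ln 2 in the result, which is consistent with, though not proof of, the mismatches reported above.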
The flash attention implementation in the Triton repo uses exp2, but I can't use the same trick with Pallas, and I suspect it might be important.
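The trick rests on e^x = 2^(x·log2(e)): a softmax can use exp2 if log2(e) is folded into the softmax scale once. Below is a plain-Python sketch of that identity, not the Triton or Pallas kernel itself; sm_scale, the example scores, and the function names are illustrative.

```python
import math

LOG2_E = math.log2(math.e)  # 1 / ln(2), ~1.4427


def softmax_exp(scores, sm_scale):
    """Reference softmax of sm_scale * scores using exp (max-subtracted for stability)."""
    m = max(scores)
    w = [math.exp(sm_scale * (s - m)) for s in scores]
    total = sum(w)
    return [v / total for v in w]


def softmax_exp2(scores, sm_scale):
    """Same softmax using exp2: log2(e) is folded into the scale once."""
    scale = sm_scale * LOG2_E
    m = max(scores)
    # 2.0 ** y stands in for exp2(y).
    w = [2.0 ** (scale * (s - m)) for s in scores]
    total = sum(w)
    return [v / total for v in w]


scores = [0.5, -1.25, 3.0, 2.0]
sm_scale = 1.0 / math.sqrt(64)  # e.g. 1/sqrt(head_dim) with head_dim = 64
print(softmax_exp(scores, sm_scale))
print(softmax_exp2(scores, sm_scale))
```

Because the scores are multiplied by a scale anyway, folding in log2(e) adds no work. On NVIDIA GPUs exp2 generally corresponds to a native ex2 instruction, which is presumably why the Triton flash attention implementation prefers it.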