
Feature request: lower jnp.exp2 to tl.math.exp2 in pallas #204

Open
hr0nix opened this issue Jul 18, 2023 · 10 comments

@hr0nix commented Jul 18, 2023

The flash attention implementation in the Triton repo uses exp2 (computing exp(x) as exp2(x * log2(e)), since exp2 maps to a fast hardware instruction), but I can't use the same trick with Pallas, and I suspect it might be important for performance.
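For reference, the trick is just the identity exp(x) = 2^(x * log2(e)). A minimal sketch of the math in plain JAX (the helper and constant names are mine):

```python
import jax.numpy as jnp

LOG2_E = 1.4426950408889634  # log2(e)

def exp_via_exp2(x):
    # exp(x) == 2 ** (x * log2(e)); exp2 maps to a fast hardware
    # instruction on GPU, which is why flash attention uses it.
    return jnp.exp2(x * LOG2_E)
```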

@sharadmv (Collaborator)

This might be a pretty simple PR if you are up for contributing. Otherwise I will get around to it in a week or so.

If interested, take a look at jax_triton/pallas/triton_lowering.py.

@sharadmv (Collaborator)

Maybe look at an existing unary primitive lowering such as exp or tanh.
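Roughly the shape I'd expect the new rule to take, assuming exp2 ends up with its own lax primitive (the wrinkle discussed below). This is a sketch from memory that mirrors the existing unary rules, not the actual file contents; the real rule-table name and rule signature may differ:

```python
# Hypothetical sketch only -- mirrors the pattern of existing unary
# rules (exp, tanh) in jax_triton/pallas/triton_lowering.py.
import triton.language as tl
from jax import lax

triton_lowering_rules = {}  # stands in for the module's rule table

def _exp2_lowering_rule(ctx, a):
    # Lower lax.exp2_p straight to Triton's base-2 exponential.
    return tl.math.exp2(a, _builder=ctx.builder)

triton_lowering_rules[lax.exp2_p] = _exp2_lowering_rule
```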

@hr0nix commented Jul 19, 2023

The only problem is that I'm unsure whether I can support exp2 in the same way, given that in JAX it's not a lax primitive but rather a function implemented in terms of other lax primitives.
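You can see the decomposition by tracing it:

```python
import jax
import jax.numpy as jnp

# Before jax-ml/jax#16883 this printed a jaxpr with a mul (by ln 2)
# followed by an exp -- no single exp2 primitive for a lowering rule
# to pattern-match on.
print(jax.make_jaxpr(jnp.exp2)(1.0))
```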

@sharadmv (Collaborator)

Interesting... we might want to change that in JAX, then! cc: @froystig @mattjj

@hr0nix commented Jul 27, 2023

@froystig @mattjj Hey guys! Can you comment on whether this can be fixed in JAX in the foreseeable future?

mattjj self-assigned this Jul 28, 2023
@mattjj (Collaborator) commented Jul 28, 2023

Sure, let's do it!

mattjj added a commit to mattjj/jax that referenced this issue Jul 28, 2023
@mattjj (Collaborator) commented Jul 28, 2023

jax-ml/jax#16883 fixes the JAX side. Once we merge that, I can probably update jax-triton too (unless you want to take it, @hr0nix).

mattjj added four more commits to mattjj/jax that referenced this issue Jul 28, 2023
mattjj added a commit that referenced this issue Jul 28, 2023
@hr0nix commented Jul 29, 2023

Oh, I was too slow to react. Thanks a lot for the quick fix, guys!

@jon-chuang commented Aug 28, 2023

Hello guys, I tried using exp2 for this trick in Triton:

https://github.com/openai/triton/blob/cc45c3826fecabc2d6882e4bd94464ab39c8f730/python/tutorials/06-fused-attention.py#L186

However, the bwd tests fail:
[screenshot of the failing test output omitted]

For just 13 out of 49K elements! How odd! It may simply be that the extreme values result in large floating-point errors.

In my case, I implemented the trick for both the backward and forward passes in https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/attention.py. Implementing it for just the forward or just the backward pass resulted in many more floating-point errors.

Just as an anecdote - do let me know if any of you guys succeed with this.
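For what it's worth, here is a plausible explanation (my own numeric sketch, not the actual failing test): the rounding error in the float32 product x * log2(e) is amplified roughly in proportion to |x| in the result, so only the most extreme scores drift outside the test tolerance:

```python
import numpy as np

LOG2_E = np.float32(np.log2(np.e))

x = np.float32(80.0)                 # an "extreme" pre-softmax score
ref = np.exp(np.float64(x))          # float64 reference
via_exp2 = np.exp2(x * LOG2_E)       # float32 exp2 trick
direct = np.exp(x)                   # float32 direct exp

print(abs(via_exp2 - ref) / ref)     # relative error grows ~ |x| * eps
print(abs(direct - ref) / ref)       # stays within a few ulps
```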

@jon-chuang commented Aug 30, 2023

I found the issue! Scaling before/after the float32 conversion makes all the difference!

Now implemented in: jax-ml/jax#17328
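To illustrate the point (hypothetical helper names, my own sketch; the exact placement the PR uses is best read from its diff): the two orderings are not numerically equivalent, because applying sm_scale * log2(e) in a low-precision dtype loses bits that the exponential then amplifies:

```python
import jax.numpy as jnp

LOG2_E = 1.4426950408889634  # log2(e)

def scores_scaled_after_upcast(qk, sm_scale):
    # Upcast to float32 first, then fold sm_scale * log2(e) into the
    # scores at full precision before taking exp2.
    return jnp.exp2(qk.astype(jnp.float32) * (sm_scale * LOG2_E))

def scores_scaled_before_upcast(qk, sm_scale):
    # Same math, but the scale is applied in the input's (possibly
    # low-precision) dtype before the conversion; not numerically
    # equivalent to the version above.
    scaled = qk * jnp.asarray(sm_scale * LOG2_E, dtype=qk.dtype)
    return jnp.exp2(scaled.astype(jnp.float32))
```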
