[Feature Request] Need Matmul Attention layer instead of Einsum to support GPU running #502
Comments
Why do you say cudnn GEMM isn't used? Normally it should be. Can you provide an example where cudnn GEMM isn't used in such a case?
@nouiz Yes, TransformerEngine does use cuDNN GEMM. But the JAX (Flax or Praxis) attention layers are built from Einsum ops, which couldn't be lowered to cuDNN GEMM or to the latest cuDNN XLA FMHA kernel. When running the attention layers on GPU, they were only lowered to Triton kernels, according to the XLA dump log... TE currently supports only a limited number of transformer models (MoE, for example, is difficult to support) and does not yet support LoRA SFT. So it may be necessary to optimize the layer composition in the JAX ecosystem itself. Sorry, I'm not sure where to file this request, because it doesn't look like something the TE team should be responsible for. As I understand it, the TE team is only responsible for the 'custom_call' part of JAX.
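For anyone who wants to reproduce the comparison, here is a minimal sketch of the two formulations of the attention logits. The function names, shapes, and the `BTNH` axis layout are assumptions for illustration and are not taken from Praxis or Flax. It only checks that the Einsum and Matmul forms agree numerically and that both appear as `dot_general` in the lowered module; which GPU kernel XLA finally picks depends on the backend and XLA version, so the compiled HLO has to be inspected on the target machine.

```python
import jax
import jax.numpy as jnp


def einsum_logits(q, k):
    # Hypothetical layout: q is [batch, q_len, heads, dim],
    # k is [batch, k_len, heads, dim]; result is [batch, heads, q_len, k_len].
    return jnp.einsum("BTNH,BSNH->BNTS", q, k)


def matmul_logits(q, k):
    # Same contraction written as explicit transposes plus a batched matmul.
    q_t = jnp.transpose(q, (0, 2, 1, 3))  # [batch, heads, q_len, dim]
    k_t = jnp.transpose(k, (0, 2, 3, 1))  # [batch, heads, dim, k_len]
    return jnp.matmul(q_t, k_t)           # [batch, heads, q_len, k_len]


key_q, key_k = jax.random.split(jax.random.PRNGKey(0))
q = jax.random.normal(key_q, (2, 8, 4, 16))
k = jax.random.normal(key_k, (2, 8, 4, 16))

logits_einsum = einsum_logits(q, k)
logits_matmul = matmul_logits(q, k)
max_abs_diff = float(jnp.max(jnp.abs(logits_einsum - logits_matmul)))

# Text of the lowered (pre-optimization) module for each variant.
lowered_einsum = jax.jit(einsum_logits).lower(q, k).as_text()
lowered_matmul = jax.jit(matmul_logits).lower(q, k).as_text()

# Backend-optimized HLO; on GPU this is where the chosen kernels show up.
compiled_einsum = jax.jit(einsum_logits).lower(q, k).compile().as_text()
```

Printing `compiled_einsum` on a GPU machine gives roughly the same information as the XLA dump log mentioned above, without having to set dump flags.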
I think you will be interested in this PR: jax-ml/jax#18814
@nouiz Cool, thank you!
The Einsum kernel in Praxis couldn't be lowered to cuDNN GEMM, and compute performance is seriously affected: the JAX version of the attention layer is much slower than the TensorFlow version.