Is there any plan to overlap and fuse GEMM for JAX? #320
Comments
The TE project doesn't plan to add more parallelism than what each framework already supports. TE tries to support as many forms of parallelism as it can. We are actively working on that for JAX.
A few days ago I asked two core developers of Alpa, Yonghao Zhuang and Hao Zhang, and they told me that compiler optimization for DL models is no longer suitable for this era, and that I should ask Nvidia instead. In fact, for example, Alpa's pipeline parallelism is difficult to integrate with JAX, and the sharding constraints that JAX uses to support sequence parallelism are difficult to integrate with JAX. Their final recommendation was not to use Alpa/JAX without TPUs.
But if we just use the Megatron framework, it has a lot of limitations. So is there any roadmap for supporting more frameworks and more techniques, such as https://arxiv.org/abs/2105.05720?
Quick update, we are adding SP in this PR:

At the end of last year we changed how TE parallelizes. It used to use xmap (which meant hardcoding some cases); now it uses custom_partitioning. So all TE operations should now act like native XLA operations and should respect uses of with_sharding_constraint(). This way, end users should be able to trigger all SPMD parallelism just by setting the input/output sharding or by adding with_sharding_constraint() in the right place. The PR above makes this even simpler for SP.

Note that the computation/communication overlap is work that has started in XLA; TE/JAX can't control it. There are some XLA_FLAGS that enable more of it or let you play with some configuration options. Models in JAX-Toolbox use some of them for speedups. We are hoping to enable more of those cases by default over the year.
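To make the sharding-driven workflow concrete, here is a minimal sketch of how an end user might steer SPMD partitioning purely through input shardings and with_sharding_constraint(). It uses plain jax.numpy matmuls as a stand-in for TE operations (an assumption: TE ops are expected to follow the same constraints, but they are not called here), and the mesh axis name "dp", the function name mlp, and the array shapes are all hypothetical.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D mesh over whatever devices are available (a single CPU also works).
# The axis name "dp" is an arbitrary, illustrative choice.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("dp",))

# Shard the leading (batch) dimension across "dp"; replicate the weights.
batch_sharded = NamedSharding(mesh, P("dp", None))
replicated = NamedSharding(mesh, P())

@jax.jit
def mlp(x, w1, w2):
    # Plain matmuls stand in for TE layers in this sketch.
    h = jnp.maximum(x @ w1, 0.0)
    # Ask the partitioner to keep the intermediate activation batch-sharded.
    h = jax.lax.with_sharding_constraint(h, batch_sharded)
    return h @ w2

n = len(devices)
# The batch size is a multiple of the device count so it divides evenly.
x = jax.device_put(jnp.ones((4 * n, 8)), batch_sharded)
w1 = jax.device_put(jnp.ones((8, 16)), replicated)
w2 = jax.device_put(jnp.ones((16, 8)), replicated)

y = mlp(x, w1, w2)
print(y.shape, y.sharding)
```

Changing the PartitionSpec on the inputs or on the constraint is the only thing that should need to change to try a different partitioning; the body of the function stays the same.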
@nouiz I have seen this update, thank you for your work.
The last question was the same as NVIDIA/JAX-Toolbox#502
By combining with Alpa?