Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
PR #19275: [NVIDIA] Add fixes for supporting determinism expander for…
… high-dimensional scatter operation and a flag to disable it Imported from GitHub PR #19275 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of #17886, and has fixed issues reported in the reverted PR #18326. The issue was that the changes in #18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers. Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable the `scatter_determinism_expander` pass without getting blocked. Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- 3b7b56a by Chenhao Jiang <[email protected]>: PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations Imported from GitHub PR #18326 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of #17886 Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- de647d4 by Chenhao Jiang <[email protected]>: Support scatter with non-scalar indices and updates Merging this change closes #18326 PiperOrigin-RevId: 691023328 -- 126c952 by Chenhao Jiang <[email protected]>: Add the scatter indices to operand space mapping and change the offset column-wise permutation based on scatter_dims_to_operand_dims, so that they can add together correctly. -- 1ecb608 by Chenhao Jiang <[email protected]>: Fix the scatter determinism expander for various dimension numbers -- 985079f by Chenhao Jiang <[email protected]>: Add a flag for enabling the scatter_determinism_expander on GPU. Merging this change closes #19275 FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4 PiperOrigin-RevId: 696078761
- Loading branch information