PR #19275: [NVIDIA] Add fixes for supporting determinism expander for high-dimensional scatter operation and a flag to disable it

Imported from GitHub PR #19275

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op is expanded into a deterministic form in xla/service/ScatterExpander.cc, but because that takes a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach that rewrites the scatter operation into an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.
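For intuition, here is a minimal, sequential C++ sketch of the prefix-scan idea (illustrative only: the real pass rewrites HLO so that the sort and the scan run in parallel on the GPU, and every name below is made up for the example):

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

// Deterministic scatter-add on a 1D operand: sort the (index, update) pairs
// by index, run an inclusive prefix scan that restarts at every new index,
// and let only the last element of each equal-index segment write its
// accumulated value. The result no longer depends on the order in which
// duplicate indices are combined.
void DeterministicScatterAdd(std::vector<float>& operand,
                             const std::vector<int>& indices,
                             const std::vector<float>& updates) {
  // Step 1: sort update positions by their target index (stable for ties).
  std::vector<size_t> order(indices.size());
  for (size_t i = 0; i < order.size(); ++i) order[i] = i;
  std::stable_sort(order.begin(), order.end(),
                   [&](size_t a, size_t b) { return indices[a] < indices[b]; });

  // Step 2: segmented inclusive prefix scan over the sorted updates.
  std::vector<float> scanned(updates.size());
  for (size_t i = 0; i < order.size(); ++i) {
    const bool same_segment =
        i > 0 && indices[order[i]] == indices[order[i - 1]];
    scanned[i] = (same_segment ? scanned[i - 1] : 0.0f) + updates[order[i]];
  }

  // Step 3: only the last element of each segment carries the full sum.
  for (size_t i = 0; i < order.size(); ++i) {
    const bool last_in_segment =
        i + 1 == order.size() || indices[order[i + 1]] != indices[order[i]];
    if (last_in_segment) operand[indices[order[i]]] += scanned[i];
  }
}

int main() {
  std::vector<float> operand(4, 0.0f);
  DeterministicScatterAdd(operand, /*indices=*/{2, 0, 2, 2},
                          /*updates=*/{1.0f, 5.0f, 2.0f, 3.0f});
  for (float v : operand) std::cout << v << ' ';  // prints "5 0 6 0"
  std::cout << '\n';
}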

This PR builds on top of #17886 and fixes the issues reported in the reverted PR #18326: the changes in #18326 could not handle various kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of the 1D and multi-dimensional scatter operations to make the code easier to maintain, adds multiple tests covering various scatter dimension numbers, and thoroughly handles all the different kinds of dimension numbers.

Moreover, this PR adds an option, `xla_gpu_enable_scatter_determinism_expander`, whose default value is true. It ensures that, in the unlikely event that something goes wrong with the changes in this PR, users can easily disable the `scatter_determinism_expander` pass without being blocked.
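Since this is a standard XLA debug option, it can typically be flipped at run time through the XLA_FLAGS environment variable rather than a rebuild, e.g. XLA_FLAGS=--xla_gpu_enable_scatter_determinism_expander=false (the exact invocation depends on how your framework passes flags to XLA).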

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
3b7b56a by Chenhao Jiang <[email protected]>:

PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations

Imported from GitHub PR #18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op is expanded into a deterministic form in xla/service/ScatterExpander.cc, but because that takes a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach that rewrites the scatter operation into an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d4 by Chenhao Jiang <[email protected]>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

PiperOrigin-RevId: 691023328

--
126c952 by Chenhao Jiang <[email protected]>:

Add the scatter-indices-to-operand-space mapping
and change the offset column-wise permutation
based on scatter_dims_to_operand_dims, so that
they can be added together correctly.
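For illustration, here is a toy sketch of that mapping (the helper, shapes, and values are hypothetical and only mirror the idea; they are not code from the pass):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Place each component of a scatter index at the operand dimension named by
// scatter_dims_to_operand_dims, leaving the remaining dimensions at zero, so
// that the per-element window offset can simply be added to it.
std::vector<int64_t> IndexToOperandSpace(
    const std::vector<int64_t>& scatter_index,
    const std::vector<int64_t>& scatter_dims_to_operand_dims,
    int64_t operand_rank) {
  std::vector<int64_t> operand_index(operand_rank, 0);
  for (size_t i = 0; i < scatter_index.size(); ++i) {
    operand_index[scatter_dims_to_operand_dims[i]] = scatter_index[i];
  }
  return operand_index;
}

int main() {
  // Rank-3 operand; the two index columns address operand dims 2 and 0.
  std::vector<int64_t> start = IndexToOperandSpace({5, 1}, {2, 0}, 3);
  std::vector<int64_t> offset = {0, 3, 2};  // offset of one update element
  for (int d = 0; d < 3; ++d) std::cout << start[d] + offset[d] << ' ';
  std::cout << '\n';  // prints "1 3 7": the operand element being updated
}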

--
1ecb608 by Chenhao Jiang <[email protected]>:

Fix the scatter determinism expander for various dimension numbers

--
985079f by Chenhao Jiang <[email protected]>:

Add a flag for enabling the scatter_determinism_expander on GPU.

Merging this change closes #19275

FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4
PiperOrigin-RevId: 696078761
serach24 authored and Google-ML-Automation committed Nov 13, 2024
1 parent 619adc0 commit 896734c
Showing 6 changed files with 1,201 additions and 142 deletions.
11 changes: 11 additions & 0 deletions xla/debug_options_flags.cc
@@ -293,6 +293,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
   opts.set_xla_gpu_dot_merger_threshold_mb(32);
   opts.set_xla_enable_fast_math(false);
   opts.set_xla_gpu_experimental_parallel_collective_overlap_limit(1);
+  opts.set_xla_gpu_enable_scatter_determinism_expander(true);
   return opts;
 }

@@ -2046,6 +2047,16 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
       debug_options->xla_gpu_experimental_parallel_collective_overlap_limit(),
       "This controls how many in-flight collectives "
       "latency hiding scheduler can schedule."));
+  flag_list->push_back(tsl::Flag(
+      "xla_gpu_enable_scatter_determinism_expander",
+      bool_setter_for(
+          &DebugOptions::set_xla_gpu_enable_scatter_determinism_expander),
+      debug_options->xla_gpu_enable_scatter_determinism_expander(),
+      "Enable the scatter determinism expander, an optimized pass that "
+      "rewrites scatter operations to ensure deterministic behavior with high "
+      "performance. "
+      "Note that even when this flag is disabled, scatter operations may still "
+      "be deterministic, although with additional overhead."));
 }  // NOLINT(readability/fn_size)

 // Allocates flag_values and flag_objects; this function must not be called more
1 change: 1 addition & 0 deletions xla/service/BUILD
@@ -2121,6 +2121,7 @@ cc_library(
"//xla/hlo/transforms:op_expander_pass",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:statusor",
],
4 changes: 3 additions & 1 deletion xla/service/gpu/gpu_compiler.cc
@@ -711,7 +711,9 @@ absl::Status RunOptimizationPasses(
   if (RequireDeterminism(hlo_module->config())) {
     // Scatter can be indeterministic if indices are not unique or a non
     // associative combiner function is used. Eliminate these Scatter ops.
-    pipeline.AddPass<ScatterDeterminismExpander>();
+    if (debug_options.xla_gpu_enable_scatter_determinism_expander()) {
+      pipeline.AddPass<ScatterDeterminismExpander>();
+    }
     pipeline.AddPass<ScatterExpander>(
         ScatterExpander::kEliminateIndeterministicScatters);
   }
