Skip to content

Commit

Permalink
[XLA:GPU][NFC] Move addition of double buffering passes to a separate…
Browse files Browse the repository at this point in the history
… function.

PiperOrigin-RevId: 678626186
  • Loading branch information
golechwierowicz authored and Google-ML-Automation committed Sep 25, 2024
1 parent e57ba4d commit e56fee4
Showing 1 changed file with 38 additions and 30 deletions.
68 changes: 38 additions & 30 deletions xla/service/gpu/gpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1045,6 +1045,43 @@ absl::Status RunFusionPasses(HloModule* hlo_module,
return absl::OkStatus();
}

// Adds unrolling while loop optimization. Mostly to get rid of extra D2D
// copies, but also there are some performance benefits (better comm-compute
// overlap) when collectives are present within a while loop.
void AddDoubleBufferingPasses(const DebugOptions& opts,
HloPassPipeline& pipeline) {
std::optional<DoubleBufferLoopUnrolling::UnrollStrategy> unroll_strategy =
std::nullopt;
// Support old flag.
if (opts.xla_gpu_enable_while_loop_double_buffering()) {
unroll_strategy = DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer;
}
// Support new flag setting style, override the old one.
if (opts.xla_gpu_enable_while_loop_unrolling() ==
DebugOptions::WHILE_LOOP_UNROLLING_DOUBLE_BUFFER) {
unroll_strategy = DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer;
}
if (opts.xla_gpu_enable_while_loop_unrolling() ==
DebugOptions::WHILE_LOOP_UNROLLING_FULL_UNROLL) {
LOG_IF(WARNING, unroll_strategy != std::nullopt)
<< "Overriding double buffering set via "
"`xla_gpu_enable_while_loop_double_buffering` flag.";
unroll_strategy = DoubleBufferLoopUnrolling::UnrollStrategy::kFullUnroll;
}
if (opts.xla_gpu_enable_while_loop_unrolling() ==
DebugOptions::WHILE_LOOP_UNROLLING_AUTO_UNROLL &&
opts.xla_gpu_enable_heuristic_pass_configuration() &&
!opts.xla_gpu_enable_while_loop_double_buffering()) {
unroll_strategy = DoubleBufferLoopUnrolling::UnrollStrategy::kAuto;
}
if (unroll_strategy != std::nullopt) {
pipeline.AddPass<WhileLoopSimplifier>();
pipeline.AddPass<DoubleBufferLoopUnrolling>(*unroll_strategy);
pipeline.AddPass<TupleSimplifier>();
pipeline.AddPass<HloDCE>();
}
}

absl::Status RunPostFusionPasses(
HloModule* hlo_module,
std::function<absl::Status(HloPassPipeline*, const DebugOptions&)>
Expand Down Expand Up @@ -1077,36 +1114,7 @@ absl::Status RunPostFusionPasses(
pipeline.AddPass<AllReduceBlueConnect>(blueconnect_num_devices_per_host);
}

std::optional<DoubleBufferLoopUnrolling::UnrollStrategy> unroll_strategy =
std::nullopt;
// Support old flag.
if (opts.xla_gpu_enable_while_loop_double_buffering()) {
unroll_strategy = DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer;
}
// Support new flag setting style, override the old one.
if (opts.xla_gpu_enable_while_loop_unrolling() ==
DebugOptions::WHILE_LOOP_UNROLLING_DOUBLE_BUFFER) {
unroll_strategy = DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer;
}
if (opts.xla_gpu_enable_while_loop_unrolling() ==
DebugOptions::WHILE_LOOP_UNROLLING_FULL_UNROLL) {
LOG_IF(WARNING, unroll_strategy != std::nullopt)
<< "Overriding double buffering set via "
"`xla_gpu_enable_while_loop_double_buffering` flag.";
unroll_strategy = DoubleBufferLoopUnrolling::UnrollStrategy::kFullUnroll;
}
if (opts.xla_gpu_enable_while_loop_unrolling() ==
DebugOptions::WHILE_LOOP_UNROLLING_AUTO_UNROLL &&
opts.xla_gpu_enable_heuristic_pass_configuration() &&
!opts.xla_gpu_enable_while_loop_double_buffering()) {
unroll_strategy = DoubleBufferLoopUnrolling::UnrollStrategy::kAuto;
}
if (unroll_strategy != std::nullopt) {
pipeline.AddPass<WhileLoopSimplifier>();
pipeline.AddPass<DoubleBufferLoopUnrolling>(*unroll_strategy);
pipeline.AddPass<TupleSimplifier>();
pipeline.AddPass<HloDCE>();
}
AddDoubleBufferingPasses(opts, pipeline);

return pipeline.Run(hlo_module).status();
}
Expand Down

0 comments on commit e56fee4

Please sign in to comment.