From bcc98dcd1cd334a1aa833a1055a840bcd2ac87f5 Mon Sep 17 00:00:00 2001
From: Bixia Zheng
Date: Tue, 24 Sep 2024 11:37:20 -0700
Subject: [PATCH] [xla] Avoid repeatedly traversing computations in a module
 by processing the computations in post-order.

PiperOrigin-RevId: 678332958
---
 .../while_loop_all_reduce_code_motion.cc      | 163 +++++++++---------
 1 file changed, 81 insertions(+), 82 deletions(-)

diff --git a/xla/service/while_loop_all_reduce_code_motion.cc b/xla/service/while_loop_all_reduce_code_motion.cc
index c9b34c702efc3..c67a34628cc40 100644
--- a/xla/service/while_loop_all_reduce_code_motion.cc
+++ b/xla/service/while_loop_all_reduce_code_motion.cc
@@ -936,7 +936,7 @@ absl::StatusOr<bool> WhileLoopAllReduceCodeMotion::Run(
     HloModule* module,
     const absl::flat_hash_set<absl::string_view>& execution_threads) {
   bool is_changed = false;
-  bool run_next_pass = true;
+
   // In case of MPMD, all-reduces might be cross-module and should preserve
   // their channel ID. Do not move all-reduces in this case since the channel
   // ID might be changed.
@@ -965,96 +965,95 @@ absl::StatusOr<bool> WhileLoopAllReduceCodeMotion::Run(
   // loop. We recursively sink the all-reduce through nested while loops if
   // applicable by repeating this process.
   uint32_t count_all_reduce = 0, count_reduce_scatter = 0;
-  while (run_next_pass) {
-    run_next_pass = false;
-    std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
+  // We process all callees of a computation before processing the computation,
+  // so that when we process a computation, the all-reduce instructions that
+  // need to be hoisted to the computation from its callees have been hoisted.
+  for (HloComputation* computation :
+       module->MakeComputationPostOrder(execution_threads)) {
     // A computation could be the while body of multiple while instructions,
     // so we start from the computation and find all of its callers that is a
     // kWhile if there is any.
-    for (HloComputation* computation :
-         module->computations(execution_threads)) {
-      std::vector<HloInstruction*> computation_callers =
-          call_graph->GetComputationCallers(computation);
-      std::vector<HloInstruction*> while_caller_instructions;
-      for (HloInstruction* caller_instruction : computation_callers) {
-        // For simplicity, we only support while instructions whose shape is
-        // tuple.
-        if (caller_instruction->opcode() == HloOpcode::kWhile &&
-            caller_instruction->shape().IsTuple() &&
-            caller_instruction->while_body() == computation) {
-          while_caller_instructions.push_back(caller_instruction);
-        }
-      }
-      // Skip to next computation if this computation is not the while body of
-      // any while instruction.
-      if (while_caller_instructions.empty()) {
-        continue;
+    std::vector<HloInstruction*> computation_callers =
+        call_graph->GetComputationCallers(computation);
+    std::vector<HloInstruction*> while_caller_instructions;
+    for (HloInstruction* caller_instruction : computation_callers) {
+      // For simplicity, we only support while instructions whose shape is
+      // tuple.
+      if (caller_instruction->opcode() == HloOpcode::kWhile &&
+          caller_instruction->shape().IsTuple() &&
+          caller_instruction->while_body() == computation) {
+        while_caller_instructions.push_back(caller_instruction);
       }
-      std::vector<HloAllReduceInstructionBase*> while_body_all_reduces;
-      for (HloInstruction* while_body_instruction :
-           computation->MakeInstructionPostOrder()) {
-        HloOpcode op = while_body_instruction->opcode();
-        const bool is_candidate =
-            (op == HloOpcode::kAllReduce) ||
-            (enable_reduce_scatter_ && op == HloOpcode::kReduceScatter);
-        if (!is_candidate) {
-          continue;
-        }
-        auto* all_reduce_instruction =
-            Cast<HloAllReduceInstructionBase>(while_body_instruction);
-        if (all_reduce_instruction->constrain_layout()) {
-          return false;
-        } else {
-          while_body_all_reduces.push_back(all_reduce_instruction);
-        }
-      }
-      HloInstructionMap<std::vector<AccumulationContext>>
-          all_reduce_to_accumulations;
-      for (HloAllReduceInstructionBase* all_reduce : while_body_all_reduces) {
-        auto movable_all_reduce_context = IsAllReduceMovable(
-            all_reduce, computation, cross_replica_replication_analysis,
-            cross_partition_replication_analysis);
-        if (movable_all_reduce_context.is_movable) {
-          all_reduce_to_accumulations[all_reduce] =
-              std::move(movable_all_reduce_context.accumulation_contexts);
-        }
-        VLOG(3) << "WhileLoopAllReduceCodeMotion, all-reduce: "
-                << all_reduce->ToString()
-                << " is_movable: " << movable_all_reduce_context.is_movable
-                << " while loop: " << while_caller_instructions.front()->name()
-                << " num_accumulations: "
-                << (movable_all_reduce_context.is_movable
-                        ? all_reduce_to_accumulations[all_reduce].size()
-                        : 0);
-      }
-      if (all_reduce_to_accumulations.empty()) {
+    }
+    // Skip to next computation if this computation is not the while body of
+    // any while instruction.
+    if (while_caller_instructions.empty()) {
+      continue;
+    }
+    std::vector<HloAllReduceInstructionBase*> while_body_all_reduces;
+    for (HloInstruction* while_body_instruction :
+         computation->MakeInstructionPostOrder()) {
+      HloOpcode op = while_body_instruction->opcode();
+      const bool is_candidate =
+          (op == HloOpcode::kAllReduce) ||
+          (enable_reduce_scatter_ && op == HloOpcode::kReduceScatter);
+      if (!is_candidate) {
         continue;
       }
-      // For each while instruction calling this computation, create the
-      // corresponding all-reduces after the while loop.
-      for (HloInstruction* while_instruction : while_caller_instructions) {
-        TF_RETURN_IF_ERROR(AddSinkedAllReducesAndReplaceWhile(
-            while_instruction, all_reduce_to_accumulations));
-        is_changed = true;
-        run_next_pass = true;
+      auto* all_reduce_instruction =
+          Cast<HloAllReduceInstructionBase>(while_body_instruction);
+      if (all_reduce_instruction->constrain_layout()) {
+        return false;
+      } else {
+        while_body_all_reduces.push_back(all_reduce_instruction);
       }
-      // At last, remove the old all-reduce instructions in the while body.
-      for (const auto& all_reduce_accumulations_pair :
-           all_reduce_to_accumulations) {
-        HloInstruction* all_reduce = all_reduce_accumulations_pair.first;
-        if (all_reduce->opcode() == HloOpcode::kAllReduce) {
-          count_all_reduce++;
-        } else {
-          count_reduce_scatter++;
-        }
-        TF_RETURN_IF_ERROR(computation->ReplaceInstructionWithDifferentShape(
-            all_reduce, all_reduce->mutable_operand(0)));
+    }
+    HloInstructionMap<std::vector<AccumulationContext>>
+        all_reduce_to_accumulations;
+    for (HloAllReduceInstructionBase* all_reduce : while_body_all_reduces) {
+      auto movable_all_reduce_context = IsAllReduceMovable(
+          all_reduce, computation, cross_replica_replication_analysis,
+          cross_partition_replication_analysis);
+      if (movable_all_reduce_context.is_movable) {
+        all_reduce_to_accumulations[all_reduce] =
+            std::move(movable_all_reduce_context.accumulation_contexts);
       }
-      // Needs to rebuild the call graph or we could access removed
-      // instructions.
-      if (run_next_pass) {
-        break;
+      VLOG(3) << "WhileLoopAllReduceCodeMotion, all-reduce: "
+              << all_reduce->ToString()
+              << " is_movable: " << movable_all_reduce_context.is_movable
+              << " while loop: " << while_caller_instructions.front()->name()
+              << " num_accumulations: "
+              << (movable_all_reduce_context.is_movable
+                      ? all_reduce_to_accumulations[all_reduce].size()
+                      : 0);
+    }
+    if (all_reduce_to_accumulations.empty()) {
+      continue;
+    }
+    // For each while instruction calling this computation, create the
+    // corresponding all-reduces after the while loop.
+    for (HloInstruction* while_instruction : while_caller_instructions) {
+      TF_RETURN_IF_ERROR(AddSinkedAllReducesAndReplaceWhile(
+          while_instruction, all_reduce_to_accumulations));
+      is_changed = true;
+    }
+    // At last, remove the old all-reduce instructions in the while body.
+    for (const auto& all_reduce_accumulations_pair :
+         all_reduce_to_accumulations) {
+      HloInstruction* all_reduce = all_reduce_accumulations_pair.first;
+      if (all_reduce->opcode() == HloOpcode::kAllReduce) {
+        count_all_reduce++;
+      } else {
+        count_reduce_scatter++;
       }
+      TF_RETURN_IF_ERROR(computation->ReplaceInstructionWithDifferentShape(
+          all_reduce, all_reduce->mutable_operand(0)));
+    }
+    // Needs to rebuild the call graph after we remove instructions to avoid
+    // accessing removed instructions.
+    if (!all_reduce_to_accumulations.empty()) {
+      call_graph = CallGraph::Build(module);
     }
   }
   VLOG(2) << "Hoisted " << count_all_reduce << " all-reduce and "
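
The heart of the change is a traversal-order argument: MakeComputationPostOrder() visits callees before their callers, so an all-reduce sunk out of an inner while body lands in the enclosing computation before that computation is itself examined, and one pass replaces the old run_next_pass fixed-point loop (the rebuilt CallGraph only guards against touching instructions removed earlier in the same pass). A minimal standalone sketch of that property follows; the Computation struct, caller pointer, and movable counter are hypothetical stand-ins for illustration, not XLA APIs.

#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for an HLO computation: `caller` is the computation
// whose while-instruction calls this one, and `movable` counts the
// all-reduce ops that can be sunk out of this computation's while loop.
struct Computation {
  std::string name;
  Computation* caller = nullptr;
  int movable = 0;
};

int main() {
  Computation entry{"entry"};
  Computation outer{"outer_while_body", &entry};
  Computation inner{"inner_while_body", &outer, /*movable=*/1};

  // Callees before callers, mirroring module->MakeComputationPostOrder().
  std::vector<Computation*> post_order = {&inner, &outer, &entry};

  // One traversal suffices: sinking out of `inner` deposits the all-reduce
  // into `outer` before `outer` is visited, so it is sunk again through the
  // nested while loop without restarting the scan, which the old code
  // achieved by breaking out and re-running the whole module walk.
  for (Computation* c : post_order) {
    if (c->movable > 0 && c->caller != nullptr) {
      std::cout << "sinking " << c->movable << " all-reduce(s) from "
                << c->name << " into " << c->caller->name << "\n";
      c->caller->movable += c->movable;
      c->movable = 0;
    }
  }
  return 0;
}

Running the sketch prints one sink step per nesting level in a single loop; with an arbitrary visit order (say entry, outer, inner), the hoist out of outer could only happen on a second sweep, which is exactly the repeated traversal the patch removes.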