Skip to content

Commit

Permalink
[xla:gpu][NFC] Refactor and rename in address_computation_fusion_rewriter for clarity
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 619954961
  • Loading branch information
tyb0807 authored and copybara-github committed Mar 28, 2024
1 parent ac143d0 commit cf3ccda
Showing 1 changed file with 21 additions and 20 deletions.
41 changes: 21 additions & 20 deletions xla/service/gpu/address_computation_fusion_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ Status CreateRootTuple(HloInstruction* hero, HloComputation::Builder& builder,
}

absl::StatusOr<HloComputation*> CreateFusionBody(
HloModule* module, absl::Span<HloInstruction* const> operand_matches,
HloModule* module, absl::Span<HloInstruction* const> sliced_operand_paths,
DefUseDataflowPaths sliced_user_paths,
absl::Span<HloInstruction* const> captures) {
HloComputation::Builder builder("address-computation");
Expand Down Expand Up @@ -342,7 +342,7 @@ absl::StatusOr<HloComputation*> CreateFusionBody(
// Instructions in the pattern are already topologically sorted: we visited
// them following the use-def path and then reversed the list.
HloInstruction* hero;
for (HloInstruction* instr : operand_matches) {
for (HloInstruction* instr : sliced_operand_paths) {
instr_mapping[instr] = builder.AddInstruction(
instr->CloneWithNewOperands(instr->shape(), mapped_operands(instr)));
hero = instr;
Expand Down Expand Up @@ -441,36 +441,36 @@ absl::StatusOr<bool> AddressComputationFusionRewriter::Run(
if (matches.empty()) return false;

HloSchedule& schedule = module->schedule();
for (auto& kv : matches) {
auto& [operand_matches, sliced_user_paths] = kv.second;
std::vector<HloInstruction*> matches;
absl::c_copy(operand_matches, std::back_inserter(matches));
for (auto& [hero, paths] : matches) {
auto& [sliced_operand_paths, sliced_user_paths] = paths;
std::vector<HloInstruction*> matched_instrs;
absl::c_copy(sliced_operand_paths, std::back_inserter(matched_instrs));

for (auto& sliced_user_path : sliced_user_paths)
absl::c_copy(sliced_user_path, std::back_inserter(matches));
absl::c_copy(sliced_user_path, std::back_inserter(matched_instrs));

auto captures = GetPatternCaptures(matches);
auto captures = GetPatternCaptures(matched_instrs);

TF_ASSIGN_OR_RETURN(HloComputation * fusion_body,
CreateFusionBody(module, operand_matches,
CreateFusionBody(module, sliced_operand_paths,
sliced_user_paths, captures));

TF_ASSIGN_OR_RETURN(HloInstruction * fusion,
CreateFusionInstruction(module, kv.first, captures,
CreateFusionInstruction(module, hero, captures,
fusion_body, dynamic));

// As we are running after scheduling we have to keep it valid.
HloComputation* parent = kv.first->parent();
HloComputation* parent = hero->parent();
// Update schedule to replace the custom call instruction with the fusion
// instruction.
// Removal of the rest of the instructions in the sequence is handled by
// schedule update below.
HloInstructionSequence& sequence = schedule.GetOrCreateSequence(parent);
sequence.replace_instruction(kv.first, fusion);
sequence.replace_instruction(hero, fusion);

if (fusion->shape().IsTuple()) {
TF_RETURN_IF_ERROR(parent->ReplaceInstructionWithDifferentShape(
const_cast<HloInstruction*>(kv.first), fusion));
const_cast<HloInstruction*>(hero), fusion));
for (auto& sliced_user_path : sliced_user_paths) {
auto old_gte =
Cast<HloGetTupleElementInstruction>(sliced_user_path.front());
Expand All @@ -481,28 +481,29 @@ absl::StatusOr<bool> AddressComputationFusionRewriter::Run(
parent->ReplaceInstruction(sliced_user_path.back(), gte));
}
} else {
auto* old_instr = const_cast<HloInstruction*>(kv.first);
auto* instr_to_be_replaced = const_cast<HloInstruction*>(hero);
if (sliced_user_paths.empty()) {
// The only case where a tuple-shaped original hero op is fused into a
// non-tuple-shaped fusion is when only one element of the original
// tuple is used. In that case, we need to replace that single
// get-tuple-element (instead of the hero op) with the fusion
// instruction.
if (kv.first->shape().IsTuple()) {
if (kv.first->user_count() != 1 ||
if (hero->shape().IsTuple()) {
if (hero->user_count() != 1 ||
!DynCast<HloGetTupleElementInstruction>(
kv.first->users().front())) {
hero->users().front())) {
return absl::InternalError(
"Expect a single get-tuple-element user of the original "
"tuple-shaped hero op when address computation fusion does "
"not return a tuple");
}
old_instr = kv.first->users().front();
instr_to_be_replaced = hero->users().front();
}
} else {
old_instr = sliced_user_paths.front().back();
instr_to_be_replaced = sliced_user_paths.front().back();
}
TF_RETURN_IF_ERROR(parent->ReplaceInstruction(old_instr, fusion));
TF_RETURN_IF_ERROR(
parent->ReplaceInstruction(instr_to_be_replaced, fusion));
}
}

Expand Down

0 comments on commit cf3ccda

Please sign in to comment.