Skip to content

Commit

Permalink
[XLA] Don't use while/conditional back pointers
Browse files Browse the repository at this point in the history
Specifically HloComputation::WhileCallInstruction() and ConditionalCallInstruction(). These are broken because in too many places we don't update the references between the instruction and computation after cloning.

CallGraph::GetComputationCallers() can be used for the same purposes. Refactor InfeedTokenPropagation to use it.
PiperOrigin-RevId: 679345390
  • Loading branch information
vsytch authored and Google-ML-Automation committed Sep 27, 2024
1 parent ce63b93 commit 9f7cb88
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 124 deletions.
12 changes: 12 additions & 0 deletions xla/hlo/ir/hlo_computation.h
Original file line number Diff line number Diff line change
Expand Up @@ -787,35 +787,47 @@ class HloComputation {
}

// Returns if this computation is a body computation of a while.
[[deprecated(
"This is broken. Use CallGraph::GetComputationCallers() instead")]]
bool IsWhileBodyComputation() const {
return instruction_type() == InstructionType::kWhile;
}

// Returns the owning while call instruction, or nullptr if this is not a
// while call body computation.
[[deprecated(
"This is broken. Use CallGraph::GetComputationCallers() instead")]]
HloInstruction* WhileCallInstruction() const {
return instruction_type() == InstructionType::kWhile ? instruction()
: nullptr;
}

[[deprecated(
"This is broken. Use CallGraph::GetComputationCallers() instead")]]
void SetWhileCallInstruction(HloInstruction* while_call_instruction) {
CHECK(while_call_instruction != nullptr);
CHECK(while_call_instruction->opcode() == HloOpcode::kWhile);
SetInstruction(while_call_instruction, InstructionType::kWhile);
}

// Returns if this computation is a branch computation of a conditional.
[[deprecated(
"This is broken. Use CallGraph::GetComputationCallers() instead")]]
bool IsConditionalBranchComputation() const {
return instruction_type() == InstructionType::kConditional;
}

// Returns the owning conditional call instruction, or nullptr if this is not
// a conditional branch computation.
[[deprecated(
"This is broken. Use CallGraph::GetComputationCallers() instead")]]
HloInstruction* ConditionalCallInstruction() const {
return instruction_type() == InstructionType::kConditional ? instruction()
: nullptr;
}

[[deprecated(
"This is broken. Use CallGraph::GetComputationCallers() instead")]]
void SetConditionalCallInstruction(
HloInstruction* conditional_call_instruction) {
CHECK(conditional_call_instruction != nullptr);
Expand Down
1 change: 1 addition & 0 deletions xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -8649,6 +8649,7 @@ cc_library(
srcs = ["infeed_token_propagation.cc"],
hdrs = ["infeed_token_propagation.h"],
deps = [
":call_graph",
":hlo_dce",
":tuple_simplifier",
"//xla:shape_util",
Expand Down
Loading

0 comments on commit 9f7cb88

Please sign in to comment.