Skip to content

Commit

Permalink
kAsyncStart is missing from CalculatePostOrderScheduleHelper() which …
Browse files Browse the repository at this point in the history
…will

not initialize the ordinal/priority correctly for some instructions. Later
priority-queue-based worklist may not work in the correct order as a result.

PiperOrigin-RevId: 696968628
  • Loading branch information
Google-ML-Automation committed Nov 15, 2024
1 parent 97ef4ef commit 4c7e8aa
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 1 deletion.
7 changes: 6 additions & 1 deletion xla/hlo/analysis/hlo_dataflow_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ int64_t CalculatePostOrderScheduleHelper(
int64_t ordinal = start_ordinal;
for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) {
if (instruction->opcode() == HloOpcode::kCall ||
instruction->opcode() == HloOpcode::kAsyncStart ||
instruction->opcode() == HloOpcode::kConditional) {
for (const HloComputation* called_computation :
instruction->called_computations()) {
Expand All @@ -75,6 +76,8 @@ int64_t CalculatePostOrderScheduleHelper(
// flatten (meaning we could have multiple callers for one computation). In
// that case the oridinal_map will see the instruction multiple times. We
// consider that case to be ok as it only shows up in unit tests.
VLOG(4) << "Add instruction " << instruction->name()
<< " to ordinal map with ordinal " << ordinal;
ordinal_map->insert({instruction, ordinal++});
}
return ordinal;
Expand Down Expand Up @@ -1293,6 +1296,8 @@ void HloDataflowAnalysis::Propagate() {
auto add_to_worklist = [&priority_map, &worklist,
&workset](HloInstruction* instruction) {
if (workset.insert(instruction).second) {
VLOG(4) << "Add " << instruction->name() << " to worklist with priority "
<< priority_map[instruction];
worklist.emplace(priority_map[instruction], instruction);
}
};
Expand All @@ -1316,7 +1321,7 @@ void HloDataflowAnalysis::Propagate() {

workset.erase(workset.find(instruction));

VLOG(3) << "Worklist top: " << instruction->name();
VLOG(4) << "Worklist top: " << instruction->name();
XLA_VLOG_LINES(3, ToString());

if (!UpdateInstructionValueSet(instruction)) {
Expand Down
62 changes: 62 additions & 0 deletions xla/hlo/analysis/hlo_dataflow_analysis_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1189,6 +1189,68 @@ ENTRY %main (a: f32[4096], b: f32[4096]) -> f32[4096] {
}
}

TEST_P(HloDataflowAnalysisTest, AsyncCallWithConditional) {
std::string hlo_str = R"(
HloModule AsyncCall
%cond_computation.1 (param_0: f32[4096]) -> f32[4096] {
ROOT %param_0 = f32[4096]{0} parameter(0)
}
%cond_computation.2 (param_1: f32[4096]) -> f32[4096] {
%param_1 = f32[4096]{0} parameter(0)
ROOT %negate_1 = f32[4096]{0} negate(f32[4096]{0} %param_1)
}
%called_computation (param_0: pred[], param_1: f32[4096]) -> f32[4096] {
%param_0 = pred[] parameter(0)
%param_1 = f32[4096]{0} parameter(1)
ROOT %conditional = f32[4096]{0} conditional(pred[] %param_0, f32[4096]{0} %param_1, f32[4096]{0} %param_1), true_computation=%cond_computation.1, false_computation=%cond_computation.2
}
ENTRY %main (a: f32[4096], pred: pred[]) -> f32[4096] {
%a = f32[4096]{0} parameter(1)
%p = pred[] parameter(0)
%async-start = ((pred[], f32[4096]{0}), f32[4096]{0}, u32[]) call-start(pred[] %p, f32[4096]{0} %a), to_apply=%called_computation
ROOT %async-done = f32[4096]{0} call-done(%async-start)
}
)";
TF_ASSERT_OK_AND_ASSIGN(
module_, ParseAndReturnVerifiedModule(hlo_str, GetModuleConfigForTest()));

bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

const HloInstruction* a = FindInstruction(module_.get(), "a");
const HloInstruction* p = FindInstruction(module_.get(), "p");
// const HloInstruction* async_done =
// FindInstruction(module_.get(), "async-done");
const HloInstruction* conditional =
FindInstruction(module_.get(), "conditional");

for (std::string async_name : {"async-start", "async-done"}) {
const HloInstruction* async_op = FindInstruction(module_.get(), async_name);
const HloComputation* called_computation =
async_op->async_wrapped_instruction()->called_computations()[0];
const HloInstruction* parameter0 =
called_computation->parameter_instruction(0);
EXPECT_FALSE(analysis.ValueIsDefinedAt(parameter0));
EXPECT_THAT(HloValuesAt(parameter0),
UnorderedElementsAre(&analysis.GetValueDefinedAt(p)));
const HloInstruction* parameter1 =
called_computation->parameter_instruction(1);
EXPECT_FALSE(analysis.ValueIsDefinedAt(parameter1));
EXPECT_THAT(HloValuesAt(parameter1),
UnorderedElementsAre(&analysis.GetValueDefinedAt(a)));
if (ssa_form) {
EXPECT_EQ(HloValuesAt(conditional).size(), 1);
EXPECT_TRUE(HloValuesAt(conditional)[0]->is_phi());
} else {
EXPECT_EQ(HloValuesAt(conditional).size(), 2);
}
}
}

TEST_P(HloDataflowAnalysisTest, TupleShapedAsyncOp) {
std::string hlo_str = R"(
HloModule module
Expand Down

0 comments on commit 4c7e8aa

Please sign in to comment.