diff --git a/xla/hlo/analysis/hlo_dataflow_analysis.cc b/xla/hlo/analysis/hlo_dataflow_analysis.cc index 2743067dd8802..9d024ffb7e33a 100644 --- a/xla/hlo/analysis/hlo_dataflow_analysis.cc +++ b/xla/hlo/analysis/hlo_dataflow_analysis.cc @@ -58,6 +58,7 @@ int64_t CalculatePostOrderScheduleHelper( int64_t ordinal = start_ordinal; for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) { if (instruction->opcode() == HloOpcode::kCall || + instruction->opcode() == HloOpcode::kAsyncStart || instruction->opcode() == HloOpcode::kConditional) { for (const HloComputation* called_computation : instruction->called_computations()) { @@ -75,6 +76,8 @@ int64_t CalculatePostOrderScheduleHelper( // flatten (meaning we could have multiple callers for one computation). In // that case the oridinal_map will see the instruction multiple times. We // consider that case to be ok as it only shows up in unit tests. + VLOG(4) << "Add instruction " << instruction->name() + << " to ordinal map with ordinal " << ordinal; ordinal_map->insert({instruction, ordinal++}); } return ordinal; @@ -1293,6 +1296,8 @@ void HloDataflowAnalysis::Propagate() { auto add_to_worklist = [&priority_map, &worklist, &workset](HloInstruction* instruction) { if (workset.insert(instruction).second) { + VLOG(4) << "Add " << instruction->name() << " to worklist with priority " + << priority_map[instruction]; worklist.emplace(priority_map[instruction], instruction); } }; @@ -1316,7 +1321,7 @@ void HloDataflowAnalysis::Propagate() { workset.erase(workset.find(instruction)); - VLOG(3) << "Worklist top: " << instruction->name(); + VLOG(4) << "Worklist top: " << instruction->name(); XLA_VLOG_LINES(3, ToString()); if (!UpdateInstructionValueSet(instruction)) { diff --git a/xla/hlo/analysis/hlo_dataflow_analysis_test.cc b/xla/hlo/analysis/hlo_dataflow_analysis_test.cc index ea28bc9d7aa48..a1837b1406ecc 100644 --- a/xla/hlo/analysis/hlo_dataflow_analysis_test.cc +++ b/xla/hlo/analysis/hlo_dataflow_analysis_test.cc 
@@ -1189,6 +1189,80 @@ ENTRY %main (a: f32[4096], b: f32[4096]) -> f32[4096] {
   }
 }
 
+TEST_P(HloDataflowAnalysisTest, AsyncCallWithConditional) {
+  // Runs dataflow analysis on an async call whose called computation contains
+  // a conditional, then checks the values seen at the called computation's
+  // parameters and at the conditional itself.
+  std::string hlo_str = R"(
+HloModule AsyncCall
+
+%cond_computation.1 (param_0: f32[4096]) -> f32[4096] {
+  ROOT %param_0 = f32[4096]{0} parameter(0)
+}
+
+%cond_computation.2 (param_1: f32[4096]) -> f32[4096] {
+  %param_1 = f32[4096]{0} parameter(0)
+  ROOT %negate_1 = f32[4096]{0} negate(f32[4096]{0} %param_1)
+}
+
+%called_computation (param_0: pred[], param_1: f32[4096]) -> f32[4096] {
+  %param_0 = pred[] parameter(0)
+  %param_1 = f32[4096]{0} parameter(1)
+  ROOT %conditional = f32[4096]{0} conditional(pred[] %param_0, f32[4096]{0} %param_1, f32[4096]{0} %param_1), true_computation=%cond_computation.1, false_computation=%cond_computation.2
+}
+
+ENTRY %main (a: f32[4096], pred: pred[]) -> f32[4096] {
+  %a = f32[4096]{0} parameter(1)
+  %p = pred[] parameter(0)
+  %async-start = ((pred[], f32[4096]{0}), f32[4096]{0}, u32[]) call-start(pred[] %p, f32[4096]{0} %a), to_apply=%called_computation
+  ROOT %async-done = f32[4096]{0} call-done(%async-start)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(
+      module_, ParseAndReturnVerifiedModule(hlo_str, GetModuleConfigForTest()));
+
+  const bool ssa_form = GetParam();
+  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
+
+  const HloInstruction* a = FindInstruction(module_.get(), "a");
+  const HloInstruction* p = FindInstruction(module_.get(), "p");
+  const HloInstruction* conditional =
+      FindInstruction(module_.get(), "conditional");
+  // Fail cleanly instead of dereferencing null if a lookup misses.
+  ASSERT_NE(a, nullptr);
+  ASSERT_NE(p, nullptr);
+  ASSERT_NE(conditional, nullptr);
+
+  // Reach the called computation through each async op in turn and check that
+  // its parameters define no values of their own but carry those of `p`/`a`.
+  for (const char* async_name : {"async-start", "async-done"}) {
+    const HloInstruction* async_op = FindInstruction(module_.get(), async_name);
+    ASSERT_NE(async_op, nullptr) << async_name;
+    const HloComputation* called_computation =
+        async_op->async_wrapped_instruction()->called_computations()[0];
+    const HloInstruction* parameter0 =
+        called_computation->parameter_instruction(0);
+    EXPECT_FALSE(analysis.ValueIsDefinedAt(parameter0));
+    EXPECT_THAT(HloValuesAt(parameter0),
+                UnorderedElementsAre(&analysis.GetValueDefinedAt(p)));
+    const HloInstruction* parameter1 =
+        called_computation->parameter_instruction(1);
+    EXPECT_FALSE(analysis.ValueIsDefinedAt(parameter1));
+    EXPECT_THAT(HloValuesAt(parameter1),
+                UnorderedElementsAre(&analysis.GetValueDefinedAt(a)));
+  }
+
+  // The expectations on the conditional do not depend on which async op is
+  // being inspected, so they are checked once, outside the loop.
+  if (ssa_form) {
+    // ASSERT (not EXPECT) so the element access below is never out of range.
+    ASSERT_EQ(HloValuesAt(conditional).size(), 1);
+    EXPECT_TRUE(HloValuesAt(conditional)[0]->is_phi());
+  } else {
+    EXPECT_EQ(HloValuesAt(conditional).size(), 2);
+  }
+}
+
 TEST_P(HloDataflowAnalysisTest, TupleShapedAsyncOp) {
   std::string hlo_str = R"(
 HloModule module