[XLA:GPU] Do not fuse custom fusions in horizontal_input_fusion.
PiperOrigin-RevId: 679128773
dimitar-asenov authored and Google-ML-Automation committed Sep 26, 2024
1 parent 4de285a commit 400e72c
Showing 2 changed files with 34 additions and 1 deletion.
xla/service/gpu/transforms/horizontal_input_fusion.cc: 3 changes (2 additions, 1 deletion)
@@ -95,7 +95,8 @@ std::vector<HloInstruction*> FindAndSortFusionCandidates(
     // Find out the input fusion instructions whose only consumer is `consumer`.
     // This guarantees that fusing these candidates will never create cycles, as
     // there is no back edge.
-    if (IsInputFusibleReduction(*predecessor) &&
+    if (!predecessor->IsCustomFusion() &&
+        IsInputFusibleReduction(*predecessor) &&
         IsConsumerTheOnlyNonRootUser(*predecessor, *consumer)) {
       if (fusion_instr_set.insert(predecessor).second) {
         fusion_instrs.push_back(predecessor);
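The guard added in this hunk rejects custom fusions (kind=kCustom, such as the Triton fusions in the new test below) before the existing reduction and sole-consumer checks run. The following standalone sketch only illustrates the combined filter: the wrapper name IsHorizontalInputFusionCandidate is hypothetical and does not exist in the repository, while the three predicates it calls are exactly the ones visible in the hunk above.

// Hypothetical sketch of the candidate filter after this commit; not code
// from horizontal_input_fusion.cc.
bool IsHorizontalInputFusionCandidate(const HloInstruction& predecessor,
                                      const HloInstruction& consumer) {
  // New in this commit: custom fusions (kind=kCustom, e.g. Triton fusions)
  // are never treated as horizontal input fusion candidates.
  if (predecessor.IsCustomFusion()) {
    return false;
  }
  // Pre-existing checks: only input-fusible reductions whose sole non-root
  // user is `consumer` qualify, which also rules out fusion-induced cycles.
  return IsInputFusibleReduction(predecessor) &&
         IsConsumerTheOnlyNonRootUser(predecessor, consumer);
}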
xla/service/gpu/transforms/horizontal_input_fusion_test.cc: 32 changes (32 additions, 0 deletions)
@@ -265,6 +265,38 @@ TEST_F(HorizontalInputFusionTest, NonfusionInstrs) {
               GmockMatch(m::Tuple(m::Reduce(), m::Reduce())));
 }
 
+TEST_F(HorizontalInputFusionTest, DoesNotFuseCustomFusions) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+  max {
+    p0 = f16[] parameter(0)
+    p1 = f16[] parameter(1)
+    ROOT max = f16[] maximum(p0, p1)
+  }
+  triton_a {
+    p = f16[128,256] parameter(0)
+    c = f16[] constant(0)
+    ROOT n = f16[128] reduce(p, c), dimensions={1}, to_apply=max
+  }
+  triton_b {
+    p = f16[128,256] parameter(0)
+    c = f16[] constant(0)
+    ROOT n = f16[128] reduce(p, c), dimensions={1}, to_apply=max
+  }
+  ENTRY entry_computation {
+    p = f16[128,256] parameter(0)
+    fa = f16[128] fusion(p), kind=kCustom, calls=triton_a
+    fb = f16[128] fusion(p), kind=kCustom, calls=triton_b
+    ROOT tuple = (f16[128], f16[128]) tuple(fa, fb)
+  }
+)")
+                    .value();
+
+  EXPECT_FALSE(horizontal_input_fusion_.Run(module.get()).value());
+}
+
 }  // namespace
 }  // namespace gpu
 }  // namespace xla

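For contrast, here is a hedged sketch of the complementary case; it is an assumption for illustration, not a test from this commit. With the same two reductions wrapped in ordinary kind=kInput fusions, the new IsCustomFusion() guard would not apply, and the pass would be expected to fuse them horizontally (presumably the behavior that previously, and incorrectly, also applied to the kCustom case). The test name, computation names, and expectation below are illustrative.

// Illustrative sketch only; not part of this commit.
TEST_F(HorizontalInputFusionTest, FusesRegularInputFusionsSketch) {
  auto module = ParseAndReturnVerifiedModule(R"(
  max {
    p0 = f16[] parameter(0)
    p1 = f16[] parameter(1)
    ROOT max = f16[] maximum(p0, p1)
  }
  reduce_a {
    p = f16[128,256] parameter(0)
    c = f16[] constant(0)
    ROOT n = f16[128] reduce(p, c), dimensions={1}, to_apply=max
  }
  reduce_b {
    p = f16[128,256] parameter(0)
    c = f16[] constant(0)
    ROOT n = f16[128] reduce(p, c), dimensions={1}, to_apply=max
  }
  ENTRY entry_computation {
    p = f16[128,256] parameter(0)
    fa = f16[128] fusion(p), kind=kInput, calls=reduce_a
    fb = f16[128] fusion(p), kind=kInput, calls=reduce_b
    ROOT tuple = (f16[128], f16[128]) tuple(fa, fb)
  }
)")
                    .value();

  // Same module shape as DoesNotFuseCustomFusions above, but with kInput
  // fusions; the pass is expected to make a change here.
  EXPECT_TRUE(horizontal_input_fusion_.Run(module.get()).value());
}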