From 97ef4ef2f8385c2f2d63cfb05b9a06d6c87879c5 Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Fri, 15 Nov 2024 12:15:02 -0800 Subject: [PATCH] PR #18825: [GPU] GEMM fusions: allow fusing effective parameters and their broadcasts in the epilogues. Imported from GitHub PR https://github.com/openxla/xla/pull/18825 Copybara import of the project: -- 37dc0d2f706bf681b1eff2088eb8d1000abf79b8 by Ilia Sergachev : [GPU] GEMM fusions: allow fusing effective parameters and their broadcasts in the epilogues. Merging this change closes #18825 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/18825 from openxla:gemm_fusion_effective_parameters 37dc0d2f706bf681b1eff2088eb8d1000abf79b8 PiperOrigin-RevId: 696963468 --- xla/hlo/utils/hlo_query.cc | 7 +++++++ xla/hlo/utils/hlo_query.h | 4 ++++ .../gpu/transforms/gemm_fusion_test.cc | 19 +++++++++++++++++++ xla/service/gpu/triton_tiling_propagation.cc | 13 +++++++++---- 4 files changed, 39 insertions(+), 4 deletions(-) diff --git a/xla/hlo/utils/hlo_query.cc b/xla/hlo/utils/hlo_query.cc index 90b6ddfd4d2b2..1fd417a8a2cdd 100644 --- a/xla/hlo/utils/hlo_query.cc +++ b/xla/hlo/utils/hlo_query.cc @@ -178,6 +178,13 @@ bool IsBroadcastOfParameter(const HloInstruction& instr) { instr.operand(0)->opcode() == HloOpcode::kParameter; } +bool IsEffectiveParameter(const HloInstruction& instr) { + return instr.opcode() == HloOpcode::kParameter || + ((instr.opcode() == HloOpcode::kBitcast || + instr.opcode() == HloOpcode::kGetTupleElement) && + IsEffectiveParameter(*instr.operand(0))); +} + HloInstruction* GetFirstInstructionWithOpcode(const HloComputation& computation, const HloOpcode opcode) { auto instructions = computation.instructions(); diff --git a/xla/hlo/utils/hlo_query.h b/xla/hlo/utils/hlo_query.h index f219594024dc7..3612baf0803d5 100644 --- a/xla/hlo/utils/hlo_query.h +++ b/xla/hlo/utils/hlo_query.h @@ -79,6 +79,10 @@ bool IsBroadcastOfScalarConstant(const HloInstruction& instr); // Returns whether the `instr` is a broadcast 
and its input is a parameter. bool IsBroadcastOfParameter(const HloInstruction& instr); +// Returns true for a parameter or a parameter followed by a chain of no-op +// instructions (bitcast, get-tuple-element). +bool IsEffectiveParameter(const HloInstruction&); + // Returns first HLO of the computation with the opcode, otherwise nullptr. HloInstruction* GetFirstInstructionWithOpcode(const HloComputation& computation, HloOpcode opcode); diff --git a/xla/service/gpu/transforms/gemm_fusion_test.cc b/xla/service/gpu/transforms/gemm_fusion_test.cc index ebda001f2a7ba..509cc8d76b320 100644 --- a/xla/service/gpu/transforms/gemm_fusion_test.cc +++ b/xla/service/gpu/transforms/gemm_fusion_test.cc @@ -1238,6 +1238,25 @@ ENTRY e { m::Parameter(), m::Parameter())))); } +TEST_F(GemmFusionTest, BroadcastsOfParametersAreFusedAsEpilogueInputs) { + auto module = ParseAndReturnVerifiedModule(R"( +e { + p0 = f16[4,55] parameter(0) + p1 = f16[123,55] parameter(1) + d = f16[4,123] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={1} + p2 = (f16[123,1], f16[456]) parameter(2) + g = get-tuple-element(p2), index=0 + t = f16[123] bitcast(g) + b = f16[4,123] broadcast(t), dimensions={1} + m = f16[4,123] multiply(d, b) +})") + .value(); + EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch((m::Fusion(m::Parameter(), m::Parameter(), + m::GetTupleElement())))); +} + // A test fixture class for testing the threshold for small matrices. 
class SmallDotGemmFusionTest : public GemmFusionTest { public: diff --git a/xla/service/gpu/triton_tiling_propagation.cc b/xla/service/gpu/triton_tiling_propagation.cc index d891d1cfe2134..87b1c9c58507c 100644 --- a/xla/service/gpu/triton_tiling_propagation.cc +++ b/xla/service/gpu/triton_tiling_propagation.cc @@ -1091,11 +1091,16 @@ GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible( if (i == *src_operand_index) { continue; } - // Currently only broadcasts of scalars or parameters are accepted as - // other inputs of non-unary operations in the output fusion. + // Currently only + // - effective parameters + // - broadcasts of effective parameters + // - broadcasts of scalars + // are accepted as other inputs of non-unary operations in + // the output fusion. if ((operand->opcode() == HloOpcode::kBroadcast && - ShapeUtil::IsScalar(operand->operand(0)->shape())) || - operand->opcode() == HloOpcode::kParameter) { + (ShapeUtil::IsScalar(operand->operand(0)->shape()) || + hlo_query::IsEffectiveParameter(*operand->operand(0)))) || + hlo_query::IsEffectiveParameter(*operand)) { continue; } return FusionDecision::Forbid(