[XLA:GPU] Add fp8 layout support to assign contracting dim to be minor most.

This is important for the performance of both Triton and cuBLASLt FP8 GEMMs. Without this constraint, GPU kernel requirements force XLA to insert an additional, expensive transpose operation before the quantized GEMM.

PiperOrigin-RevId: 679141198
Google-ML-Automation committed Sep 26, 2024
1 parent 3f1d29c commit f7d0b34
Showing 2 changed files with 33 additions and 1 deletion.
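
For context, the affected pattern looks roughly like the following JAX sketch. This is a hypothetical illustration, not part of the commit; it mirrors the shapes and dtypes used by the new test, and the variable names are invented:

import jax
import jax.numpy as jnp

# FP8 (f8e4m3fn) operands with the same shapes as in the test below.
x = jnp.zeros((128, 5120), dtype=jnp.float8_e4m3fn)
w = jnp.zeros((5120, 10240), dtype=jnp.float8_e4m3fn)

# dot_general with lhs_contracting_dims={1}, rhs_contracting_dims={0};
# under jit, XLA:GPU lowers this to a cuBLASLt or Triton FP8 GEMM.
f = jax.jit(lambda a, b: jax.lax.dot_general(
    a, b, dimension_numbers=(((1,), (0,)), ((), ())),
    preferred_element_type=jnp.float32))
y = f(x, w)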
5 changes: 4 additions & 1 deletion xla/service/gpu/transforms/layout_assignment.cc
@@ -361,8 +361,11 @@ absl::Status GpuLayoutAssignment::AddBackendConstraints(
           output_shape.dimensions_size() == 2 &&
           lhs_shape.dimensions_size() == 2 &&
           rhs_shape.dimensions_size() == 2);
+      bool is_fp8_to_fp8 =
+          (lhs_shape.element_type() == PrimitiveType::F8E4M3FN &&
+           rhs_shape.element_type() == PrimitiveType::F8E4M3FN);
 
-      if (is_s8_to_s32 ||
+      if (is_s8_to_s32 || is_fp8_to_fp8 ||
           (is_bf16_to_bf16 &&
            debug_options.xla_gpu_ensure_minor_dot_contraction_dims())) {
         TF_RETURN_IF_ERROR(SetOperandMajorToMinorLayout(
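The effect of the new constraint can be illustrated with a small Python sketch. This is an approximation of the expected result, not XLA's implementation (SetOperandMajorToMinorLayout is the real mechanism): each dot operand's minor-to-major order lists its contracting dimensions first, i.e. makes them minor-most.

# Illustrative only: build a minor-to-major order that puts the
# contracting dimensions minor-most (XLA lists the minor-most dim first).
def contracting_minor_to_major(rank, contracting_dims):
    rest = [d for d in reversed(range(rank)) if d not in contracting_dims]
    return list(contracting_dims) + rest

print(contracting_minor_to_major(2, [1]))  # [1, 0] -- LHS layout in the test below
print(contracting_minor_to_major(2, [0]))  # [0, 1] -- RHS layout in the test below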
29 changes: 29 additions & 0 deletions xla/service/gpu/transforms/layout_assignment_test.cc
@@ -655,6 +655,35 @@ ENTRY main {
       LayoutUtil::MakeLayout({1, 3, 2, 0}).minor_to_major());
 }
 
+TEST_F(LayoutAssignmentTest, AutoLayoutE4M3ContractingMinorFirst) {
+  const char* hlo = R"(
+HloModule jit_dot_general_f8e4m3fn
+ENTRY main {
+  p0 = f8e4m3fn[128,5120] parameter(0)
+  p1 = f8e4m3fn[5120,10240] parameter(1)
+  ROOT dot = f32[128,10240] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloModule> m,
+      ParseAndReturnUnverifiedModule(
+          hlo, {}, HloParserOptions().set_fill_missing_layouts(false)));
+  ComputationLayout computation_layout(
+      m->entry_computation()->ComputeProgramShape(),
+      /*ignore_layouts=*/false);
+  GpuLayoutAssignment layout_assignment(
+      &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
+  EXPECT_THAT(layout_assignment.Run(m.get()), IsOkAndHolds(true));
+  EXPECT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::Dot(m::Parameter(0).WithShape(F8E4M3FN, {128, 5120}, {1, 0}),
+                 m::Parameter(1).WithShape(F8E4M3FN, {5120, 10240}, {0, 1}))
+              .WithShape(F32, {128, 10240}, {1, 0})));
+}
+
 TEST_F(LayoutAssignmentTest, VariadicReduceSameOperandLayout) {
   const char* module_str = R"(
 HloModule variadic_reduce
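A note on reading the expected layouts in the matcher above: XLA's minor-to-major notation lists the minor-most dimension first, so {1, 0} on the LHS and {0, 1} on the RHS each place that operand's contracting dimension minor-most, which is precisely the constraint this change adds.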
