diff --git a/xla/service/gpu/transforms/gemm_rewriter.cc b/xla/service/gpu/transforms/gemm_rewriter.cc index 32ed147415b96..d45d4da027d02 100644 --- a/xla/service/gpu/transforms/gemm_rewriter.cc +++ b/xla/service/gpu/transforms/gemm_rewriter.cc @@ -48,6 +48,7 @@ limitations under the License. #include "xla/layout.h" #include "xla/literal.h" #include "xla/literal_util.h" +#include "xla/permutation_util.h" #include "xla/primitive_util.h" #include "xla/service/algorithm_util.h" #include "xla/service/gpu/backend_configs.pb.h" @@ -361,7 +362,8 @@ std::optional<MatchedFp8Param> MatchFp8Param(HloInstruction *instr) { // dimension. There must be only one contracting and only one non-contracting // dimension. Keeps the layout the same. HloInstruction *TransposeMatrix(HloInstruction *instr, int64_t contracting_dim, - absl::Span<const int64_t> batch_dims) { + absl::Span<const int64_t> batch_dims, + bool col_maj = false) { // Identify the dimensional order which describes a transpose of the // contracting and non-contracting dimensions of the GEMM. 
std::vector<int64_t> permutation(instr->shape().dimensions_size(), -1); @@ -376,13 +378,42 @@ HloInstruction *TransposeMatrix(HloInstruction *instr, int64_t contracting_dim, non_contracting_dim = i; } } - permutation[non_contracting_dim] = contracting_dim; - permutation[contracting_dim] = non_contracting_dim; + if (!col_maj) { + permutation[non_contracting_dim] = contracting_dim; + permutation[contracting_dim] = non_contracting_dim; - Shape new_shape = ShapeUtil::PermuteDimensions(permutation, instr->shape()); - *new_shape.mutable_layout() = instr->shape().layout(); - return instr->AddInstruction( - HloInstruction::CreateTranspose(new_shape, instr, permutation)); + Shape new_shape = ShapeUtil::PermuteDimensions(permutation, instr->shape()); + *new_shape.mutable_layout() = instr->shape().layout(); + + return instr->AddInstruction( + HloInstruction::CreateTranspose(new_shape, instr, permutation)); + } + + Shape normalized_input_shape = + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( + instr->shape()); + auto a0 = MakeBitcastHlo(instr, normalized_input_shape); + + std::vector<int64_t> layout_permutation( + instr->shape().layout().minor_to_major().begin(), + instr->shape().layout().minor_to_major().end()); + absl::c_reverse(layout_permutation); + auto inv_perm = InversePermutation(layout_permutation); + + int new_contracting_dim = inv_perm[contracting_dim]; + int new_non_contracting_dim = inv_perm[non_contracting_dim]; + absl::c_iota(permutation, 0); + std::swap(permutation[new_contracting_dim], + permutation[new_non_contracting_dim]); + + Shape transpose_shape = + ShapeUtil::PermuteDimensions(permutation, a0->shape()); + *transpose_shape.mutable_layout() = a0->shape().layout(); + HloInstruction *normalized_transpose = instr->AddInstruction( + HloInstruction::CreateTranspose(transpose_shape, a0, permutation)); + Shape final_shape = ShapeUtil::PermuteDimensions(inv_perm, transpose_shape); + *final_shape.mutable_layout() = instr->shape().layout(); + return 
MakeBitcastHlo(normalized_transpose, final_shape); } // If the bias is a sequence of ops that depend only on broadcasts of @@ -1223,8 +1254,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { } else { dim_nums->set_lhs_contracting_dimensions(0, num_batch_dims); } - a.fp8_input = - TransposeMatrix(a.fp8_input, a_contracting_dims[0], a_batch_dims); + a.fp8_input = TransposeMatrix(a.fp8_input, a_contracting_dims[0], + a_batch_dims, /*col_maj=*/true); } // Similarly, cuBLASLt requires the second operand to be column-major, so diff --git a/xla/service/gpu/transforms/gemm_rewriter_test.cc b/xla/service/gpu/transforms/gemm_rewriter_test.cc index 0a6d6360587ac..f59f97e2c1b85 100644 --- a/xla/service/gpu/transforms/gemm_rewriter_test.cc +++ b/xla/service/gpu/transforms/gemm_rewriter_test.cc @@ -5032,6 +5032,58 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDMatrixBiasF8) { )"); } +TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDColMajorLhsF8) { + const char* hlo_text = R"( +HloModule test + ENTRY test { + x = <>[2,16,32]{1,0,2} parameter(0) + y = <>[2,32,16]{2,1,0} parameter(1) + x_scale = f32[] parameter(2) + y_scale = f32[] parameter(3) + dq_scale = f32[] multiply(x_scale, y_scale) + dq_scale_bcast = f32[2,16,16] broadcast(dq_scale), dimensions={} + out.0 = f32[2,16,16] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0} + ROOT out = f32[2,16,16] multiply(out.0, dq_scale_bcast) + } +)"; + + CheckFp8IfSupported(hlo_text); + RunAndFilecheckHloRewrite( + hlo_text, + GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), + R"( +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[2,16,32], {{.*}}: <>[2,32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[2,16,16] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[2,16,32]{1,0,2} parameter(0) +; CHECK-NEXT: [[P0_BT:%[^ ]+]] = <>[32,2,16]{2,1,0} bitcast([[P0]]) +; CHECK-NEXT: [[P0_TR:%[^ ]+]] = <>[16,2,32]{2,1,0} 
transpose([[P0_BT]]), dimensions={2,1,0} +; CHECK-NEXT: [[P0_BT1:%[^ ]+]] = <>[2,32,16]{1,0,2} bitcast([[P0_TR]]) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[2,32,16]{2,1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[2,16,32]{2,1,0} transpose([[P1]]), dimensions={0,2,1} +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) +; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) +; CHECK-NEXT: [[DQ:%[^ ]+]] = f32[] multiply([[P2]], [[P3]]) +; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) +; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[2,16,16]{2,1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_BT1]], [[P1_TRANSPOSE]], [[DQ]], [[C1]], [[C1]], /*index=5*/[[C1]]), +; CHECK: custom_call_target="__cublas$lt$matmul$f8", +; CHECK: backend_config={ +; CHECK-DAG: "alpha_real":1 +; CHECK-DAG: "alpha_imag":0 +; CHECK-DAG: "beta":0 +; CHECK-DAG: "dot_dimension_numbers":{ +; CHECK-DAG: "lhs_contracting_dimensions":["1"] +; CHECK-DAG: "rhs_contracting_dimensions":["2"] +; CHECK-DAG: "lhs_batch_dimensions":["0"] +; CHECK-DAG: "rhs_batch_dimensions":["0"] +; CHECK-DAG: } +; CHECK-DAG: "precision_config":{ +; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"] +; CHECK-DAG: } +; CHECK-DAG: "epilogue":"DEFAULT" +; CHECK: } + )"); +} + TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8) { const char* hlo_text = R"( HloModule test