From 237c03240da3dce736d92c8273dc1f9d3be53af5 Mon Sep 17 00:00:00 2001 From: shuw Date: Wed, 18 Sep 2024 12:16:01 -0700 Subject: [PATCH 1/8] Improve TransposeMatrix --- xla/service/gpu/transforms/gemm_rewriter.cc | 66 +++++++++++++++++---- 1 file changed, 55 insertions(+), 11 deletions(-) diff --git a/xla/service/gpu/transforms/gemm_rewriter.cc b/xla/service/gpu/transforms/gemm_rewriter.cc index 32ed147415b96..d29a084eb0ec3 100644 --- a/xla/service/gpu/transforms/gemm_rewriter.cc +++ b/xla/service/gpu/transforms/gemm_rewriter.cc @@ -48,6 +48,7 @@ limitations under the License. #include "xla/layout.h" #include "xla/literal.h" #include "xla/literal_util.h" +#include "xla/permutation_util.h" #include "xla/primitive_util.h" #include "xla/service/algorithm_util.h" #include "xla/service/gpu/backend_configs.pb.h" @@ -361,14 +362,11 @@ std::optional MatchFp8Param(HloInstruction *instr) { // dimension. There must be only one contracting and only one non-contracting // dimension. Keeps the layout the same. HloInstruction *TransposeMatrix(HloInstruction *instr, int64_t contracting_dim, - absl::Span batch_dims) { + absl::Span batch_dims, + bool col_maj = false) { // Identify the dimensional order which describes a transpose of the // contracting and non-contracting dimensions of the GEMM. std::vector permutation(instr->shape().dimensions_size(), -1); - // Discard the batch dimensions. - for (int64_t batch_dim : batch_dims) { - permutation[batch_dim] = batch_dim; - } // Identify the non-contracting dimension. int non_contracting_dim; for (int i = 0; i < instr->shape().dimensions_size(); ++i) { @@ -376,13 +374,55 @@ HloInstruction *TransposeMatrix(HloInstruction *instr, int64_t contracting_dim, non_contracting_dim = i; } } - permutation[non_contracting_dim] = contracting_dim; - permutation[contracting_dim] = non_contracting_dim; + if (!col_maj) { + // Discard the batch dimensions. + for (int64_t batch_dim : batch_dims) { + permutation[batch_dim] = batch_dim; + } + + permutation[non_contracting_dim] = contracting_dim; + permutation[contracting_dim] = non_contracting_dim; + + Shape new_shape = ShapeUtil::PermuteDimensions(permutation, instr->shape()); + *new_shape.mutable_layout() = instr->shape().layout(); + + return instr->AddInstruction( + HloInstruction::CreateTranspose(new_shape, instr, permutation)); + } + + Shape normalized_input_shape = + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(instr->shape()); + auto a0 = MakeBitcastHlo(instr, normalized_input_shape); + + int new_contracting_dim = -1; + int new_non_contracting_dim = -1; + for (int i = 0; i < instr->shape().dimensions_size(); ++i) { + auto dim = LayoutUtil::Major(instr->shape().layout(), i); + if (dim == contracting_dim) { + new_contracting_dim = i; + } else if (dim == non_contracting_dim) { + new_non_contracting_dim = i; + } else { + // Discard the batch dimensions. + permutation[i] = i; + } + } - Shape new_shape = ShapeUtil::PermuteDimensions(permutation, instr->shape()); - *new_shape.mutable_layout() = instr->shape().layout(); - return instr->AddInstruction( - HloInstruction::CreateTranspose(new_shape, instr, permutation)); + permutation[new_non_contracting_dim] = new_contracting_dim; + permutation[new_contracting_dim] = new_non_contracting_dim; + + Shape transpose_shape = + ShapeUtil::PermuteDimensions(permutation, a0->shape()); + *transpose_shape.mutable_layout() = a0->shape().layout(); + HloInstruction *normalized_transpose = instr->AddInstruction( + HloInstruction::CreateTranspose(transpose_shape, a0, permutation)); + std::vector layout_permuation(instr->shape().layout().minor_to_major().begin(), + instr->shape().layout().minor_to_major().end()); + absl::c_reverse(layout_permuation); + auto inv_perm = InversePermutation(layout_permuation); + Shape final_shape = ShapeUtil::PermuteDimensions(inv_perm, transpose_shape); + *final_shape.mutable_layout() = instr->shape().layout(); + return MakeBitcastHlo(normalized_transpose, final_shape); } // If the bias is a sequence of ops that depend only on broadcasts of @@ -1223,8 +1263,12 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { } else { dim_nums->set_lhs_contracting_dimensions(0, num_batch_dims); } + a.fp8_input = TransposeMatrix(a.fp8_input, a_contracting_dims[0], a_batch_dims); + + a.fp8_input = RegulateColMajorTransposeMatrixF8( + a.fp8_input, a_contracting_dims[0], a_batch_dims); } // Similarly, cuBLASLt requires the second operand to be column-major, so From 508cd6928bbc20c1d87818eed4ee6190c6c9f691 Mon Sep 17 00:00:00 2001 From: Shu Wang Date: Wed, 25 Sep 2024 12:50:07 -0500 Subject: [PATCH 2/8] Fix bug of permutation. --- xla/service/gpu/transforms/gemm_rewriter.cc | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/xla/service/gpu/transforms/gemm_rewriter.cc b/xla/service/gpu/transforms/gemm_rewriter.cc index d29a084eb0ec3..5979be94af1b9 100644 --- a/xla/service/gpu/transforms/gemm_rewriter.cc +++ b/xla/service/gpu/transforms/gemm_rewriter.cc @@ -367,6 +367,10 @@ HloInstruction *TransposeMatrix(HloInstruction *instr, int64_t contracting_dim, // Identify the dimensional order which describes a transpose of the // contracting and non-contracting dimensions of the GEMM. std::vector permutation(instr->shape().dimensions_size(), -1); + // Discard the batch dimensions. + for (int64_t batch_dim : batch_dims) { + permutation[batch_dim] = batch_dim; + } // Identify the non-contracting dimension. int non_contracting_dim; for (int i = 0; i < instr->shape().dimensions_size(); ++i) { @@ -375,11 +379,6 @@ HloInstruction *TransposeMatrix(HloInstruction *instr, int64_t contracting_dim, } } if (!col_maj) { - // Discard the batch dimensions. - for (int64_t batch_dim : batch_dims) { - permutation[batch_dim] = batch_dim; - } - permutation[non_contracting_dim] = contracting_dim; permutation[contracting_dim] = non_contracting_dim; @@ -417,7 +416,7 @@ HloInstruction *TransposeMatrix(HloInstruction *instr, int64_t contracting_dim, HloInstruction *normalized_transpose = instr->AddInstruction( HloInstruction::CreateTranspose(transpose_shape, a0, permutation)); std::vector layout_permuation(instr->shape().layout().minor_to_major().begin(), - instr->shape().layout().minor_to_major().end()); + instr->shape().layout().minor_to_major().end()); absl::c_reverse(layout_permuation); auto inv_perm = InversePermutation(layout_permuation); Shape final_shape = ShapeUtil::PermuteDimensions(inv_perm, transpose_shape); From c55e8a9f64c8dac69907ccebce3b8109ddeb2c48 Mon Sep 17 00:00:00 2001 From: shuw Date: Wed, 25 Sep 2024 21:50:17 +0000 Subject: [PATCH 3/8] clang format --- xla/service/gpu/transforms/gemm_rewriter.cc | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/xla/service/gpu/transforms/gemm_rewriter.cc b/xla/service/gpu/transforms/gemm_rewriter.cc index 5979be94af1b9..6d31f90fdaa9d 100644 --- a/xla/service/gpu/transforms/gemm_rewriter.cc +++ b/xla/service/gpu/transforms/gemm_rewriter.cc @@ -390,7 +390,8 @@ HloInstruction *TransposeMatrix(HloInstruction *instr, int64_t contracting_dim, } Shape normalized_input_shape = - ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(instr->shape()); + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( + instr->shape()); auto a0 = MakeBitcastHlo(instr, normalized_input_shape); int new_contracting_dim = -1; @@ -415,8 +416,9 @@ HloInstruction *TransposeMatrix(HloInstruction *instr, int64_t contracting_dim, *transpose_shape.mutable_layout() = a0->shape().layout(); HloInstruction *normalized_transpose = instr->AddInstruction( HloInstruction::CreateTranspose(transpose_shape, a0, permutation)); - std::vector layout_permuation(instr->shape().layout().minor_to_major().begin(), - instr->shape().layout().minor_to_major().end()); + std::vector layout_permuation( + instr->shape().layout().minor_to_major().begin(), + instr->shape().layout().minor_to_major().end()); absl::c_reverse(layout_permuation); auto inv_perm = InversePermutation(layout_permuation); Shape final_shape = ShapeUtil::PermuteDimensions(inv_perm, transpose_shape); From ad0a4ba8054092dd79608865a823c1d432f81b21 Mon Sep 17 00:00:00 2001 From: Shu Wang Date: Thu, 26 Sep 2024 16:29:05 -0500 Subject: [PATCH 4/8] Add unittest. --- xla/service/gpu/transforms/gemm_rewriter.cc | 36 +++++-------- .../gpu/transforms/gemm_rewriter_test.cc | 54 +++++++++++++++++++ 2 files changed, 66 insertions(+), 24 deletions(-) diff --git a/xla/service/gpu/transforms/gemm_rewriter.cc b/xla/service/gpu/transforms/gemm_rewriter.cc index 6d31f90fdaa9d..6f564ed977968 100644 --- a/xla/service/gpu/transforms/gemm_rewriter.cc +++ b/xla/service/gpu/transforms/gemm_rewriter.cc @@ -394,33 +394,23 @@ HloInstruction *TransposeMatrix(HloInstruction *instr, int64_t contracting_dim, instr->shape()); auto a0 = MakeBitcastHlo(instr, normalized_input_shape); - int new_contracting_dim = -1; - int new_non_contracting_dim = -1; - for (int i = 0; i < instr->shape().dimensions_size(); ++i) { - auto dim = LayoutUtil::Major(instr->shape().layout(), i); - if (dim == contracting_dim) { - new_contracting_dim = i; - } else if (dim == non_contracting_dim) { - new_non_contracting_dim = i; - } else { - // Discard the batch dimensions. - permutation[i] = i; - } - } + std::vector layout_permuation( + instr->shape().layout().minor_to_major().begin(), + instr->shape().layout().minor_to_major().end()); + absl::c_reverse(layout_permuation); + auto inv_perm = InversePermutation(layout_permuation); - permutation[new_non_contracting_dim] = new_contracting_dim; - permutation[new_contracting_dim] = new_non_contracting_dim; + int new_contracting_dim = inv_perm[contracting_dim]; + int new_non_contracting_dim = inv_perm[non_contracting_dim]; + absl::c_iota(permutation, 0); + std::swap(permutation[new_contracting_dim], + permutation[new_non_contracting_dim]); Shape transpose_shape = ShapeUtil::PermuteDimensions(permutation, a0->shape()); *transpose_shape.mutable_layout() = a0->shape().layout(); HloInstruction *normalized_transpose = instr->AddInstruction( HloInstruction::CreateTranspose(transpose_shape, a0, permutation)); - std::vector layout_permuation( - instr->shape().layout().minor_to_major().begin(), - instr->shape().layout().minor_to_major().end()); - absl::c_reverse(layout_permuation); - auto inv_perm = InversePermutation(layout_permuation); Shape final_shape = ShapeUtil::PermuteDimensions(inv_perm, transpose_shape); *final_shape.mutable_layout() = instr->shape().layout(); return MakeBitcastHlo(normalized_transpose, final_shape); @@ -1266,10 +1256,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { } a.fp8_input = - TransposeMatrix(a.fp8_input, a_contracting_dims[0], a_batch_dims); - - a.fp8_input = RegulateColMajorTransposeMatrixF8( - a.fp8_input, a_contracting_dims[0], a_batch_dims); + TransposeMatrix(a.fp8_input, a_contracting_dims[0], + a_batch_dims, /*col_maj*/true); } // Similarly, cuBLASLt requires the second operand to be column-major, so diff --git a/xla/service/gpu/transforms/gemm_rewriter_test.cc b/xla/service/gpu/transforms/gemm_rewriter_test.cc index 0a6d6360587ac..c1cc3478be0e1 100644 --- a/xla/service/gpu/transforms/gemm_rewriter_test.cc +++ b/xla/service/gpu/transforms/gemm_rewriter_test.cc @@ -5032,6 +5032,60 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDMatrixBiasF8) { )"); } +TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDColMajorLhsF8) { + const char* hlo_text = R"( + HloModule test + + ENTRY test { + x = <>[16,32]{0,1} parameter(0) + y = <>[32,16]{1,0} parameter(1) + x_scale = f32[] parameter(2) + y_scale = f32[] parameter(3) + dq_scale = f32[] multiply(x_scale, y_scale) + dq_scale_bcast = f32[16,16] broadcast(dq_scale), dimensions={} + out = f32[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT out.1 = f32[16,16] multiply(out, dq_scale_bcast) + } + +)"; + + CheckFp8IfSupported(hlo_text); + RunAndFilecheckHloRewrite( + hlo_text, + GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), + R"( +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[16,32]{0,1} parameter(0) +; CHECK-NEXT: [[P0_BT:%[^ ]+]] = <>[32,16]{1,0} bitcast([[P0]]) +; CHECK-NEXT: [[P0_TR:%[^ ]+]] = <>[16,32]{1,0} transpose([[P0_BT]]), dimensions={1,0} +; CHECK-NEXT: [[P0_BT1:%[^ ]+]] = <>[32,16]{0,1} bitcast([[P0_TR]]) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) +; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) +; CHECK-NEXT: [[DQ:%[^ ]+]] = f32[] multiply([[P2]], [[P3]]) +; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) +; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_BT1]], [[P1_TRANSPOSE]], [[DQ]], [[C1]], [[C1]], /*index=5*/[[C1]]), +; CHECK: custom_call_target="__cublas$lt$matmul$f8", +; CHECK: backend_config={ +; CHECK-DAG: "alpha_real":1 +; CHECK-DAG: "alpha_imag":0 +; CHECK-DAG: "beta":0 +; CHECK-DAG: "dot_dimension_numbers":{ +; CHECK-DAG: "lhs_contracting_dimensions":["0"] +; CHECK-DAG: "rhs_contracting_dimensions":["1"] +; CHECK-DAG: "lhs_batch_dimensions":[] +; CHECK-DAG: "rhs_batch_dimensions":[] +; CHECK-DAG: } +; CHECK-DAG: "precision_config":{ +; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"] +; CHECK-DAG: } +; CHECK-DAG: "epilogue":"DEFAULT" +; CHECK: } + )"); +} + TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8) { const char* hlo_text = R"( HloModule test From 1d45b4d64347c64a9483fd26caf7d8598818b855 Mon Sep 17 00:00:00 2001 From: Shu Wang Date: Thu, 26 Sep 2024 16:31:28 -0500 Subject: [PATCH 5/8] Remove uncessary space. --- xla/service/gpu/transforms/gemm_rewriter.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/xla/service/gpu/transforms/gemm_rewriter.cc b/xla/service/gpu/transforms/gemm_rewriter.cc index 6f564ed977968..b2b01b41c75b0 100644 --- a/xla/service/gpu/transforms/gemm_rewriter.cc +++ b/xla/service/gpu/transforms/gemm_rewriter.cc @@ -1254,7 +1254,6 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { } else { dim_nums->set_lhs_contracting_dimensions(0, num_batch_dims); } - a.fp8_input = TransposeMatrix(a.fp8_input, a_contracting_dims[0], a_batch_dims, /*col_maj*/true); From 78378455e70e439e71da078c3099732a14292d7d Mon Sep 17 00:00:00 2001 From: Shu Wang Date: Fri, 27 Sep 2024 12:22:12 -0500 Subject: [PATCH 6/8] Update unittest. --- xla/service/gpu/transforms/gemm_rewriter.cc | 5 +-- .../gpu/transforms/gemm_rewriter_test.cc | 38 +++++++++---------- 2 files changed, 20 insertions(+), 23 deletions(-) diff --git a/xla/service/gpu/transforms/gemm_rewriter.cc b/xla/service/gpu/transforms/gemm_rewriter.cc index b2b01b41c75b0..d45d4da027d02 100644 --- a/xla/service/gpu/transforms/gemm_rewriter.cc +++ b/xla/service/gpu/transforms/gemm_rewriter.cc @@ -1254,9 +1254,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { } else { dim_nums->set_lhs_contracting_dimensions(0, num_batch_dims); } - a.fp8_input = - TransposeMatrix(a.fp8_input, a_contracting_dims[0], - a_batch_dims, /*col_maj*/true); + a.fp8_input = TransposeMatrix(a.fp8_input, a_contracting_dims[0], + a_batch_dims, /*col_maj*/ true); } // Similarly, cuBLASLt requires the second operand to be column-major, so diff --git a/xla/service/gpu/transforms/gemm_rewriter_test.cc b/xla/service/gpu/transforms/gemm_rewriter_test.cc index c1cc3478be0e1..f59f97e2c1b85 100644 --- a/xla/service/gpu/transforms/gemm_rewriter_test.cc +++ b/xla/service/gpu/transforms/gemm_rewriter_test.cc @@ -5034,19 +5034,17 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDMatrixBiasF8) { TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDColMajorLhsF8) { const char* hlo_text = R"( - HloModule test - +HloModule test ENTRY test { - x = <>[16,32]{0,1} parameter(0) - y = <>[32,16]{1,0} parameter(1) + x = <>[2,16,32]{1,0,2} parameter(0) + y = <>[2,32,16]{2,1,0} parameter(1) x_scale = f32[] parameter(2) y_scale = f32[] parameter(3) dq_scale = f32[] multiply(x_scale, y_scale) - dq_scale_bcast = f32[16,16] broadcast(dq_scale), dimensions={} - out = f32[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - ROOT out.1 = f32[16,16] multiply(out, dq_scale_bcast) + dq_scale_bcast = f32[2,16,16] broadcast(dq_scale), dimensions={} + out.0 = f32[2,16,16] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0} + ROOT out = f32[2,16,16] multiply(out.0, dq_scale_bcast) } - )"; CheckFp8IfSupported(hlo_text); @@ -5055,28 +5053,28 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDColMajorLhsF8) { GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( -; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { -; CHECK-NEXT: [[P0:%[^ ]+]] = <>[16,32]{0,1} parameter(0) -; CHECK-NEXT: [[P0_BT:%[^ ]+]] = <>[32,16]{1,0} bitcast([[P0]]) -; CHECK-NEXT: [[P0_TR:%[^ ]+]] = <>[16,32]{1,0} transpose([[P0_BT]]), dimensions={1,0} -; CHECK-NEXT: [[P0_BT1:%[^ ]+]] = <>[32,16]{0,1} bitcast([[P0_TR]]) -; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[2,16,32], {{.*}}: <>[2,32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[2,16,16] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[2,16,32]{1,0,2} parameter(0) +; CHECK-NEXT: [[P0_BT:%[^ ]+]] = <>[32,2,16]{2,1,0} bitcast([[P0]]) +; CHECK-NEXT: [[P0_TR:%[^ ]+]] = <>[16,2,32]{2,1,0} transpose([[P0_BT]]), dimensions={2,1,0} +; CHECK-NEXT: [[P0_BT1:%[^ ]+]] = <>[2,32,16]{1,0,2} bitcast([[P0_TR]]) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[2,32,16]{2,1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[2,16,32]{2,1,0} transpose([[P1]]), dimensions={0,2,1} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[DQ:%[^ ]+]] = f32[] multiply([[P2]], [[P3]]) ; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_BT1]], [[P1_TRANSPOSE]], [[DQ]], [[C1]], [[C1]], /*index=5*/[[C1]]), +; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[2,16,16]{2,1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_BT1]], [[P1_TRANSPOSE]], [[DQ]], [[C1]], [[C1]], /*index=5*/[[C1]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 ; CHECK-DAG: "alpha_imag":0 ; CHECK-DAG: "beta":0 ; CHECK-DAG: "dot_dimension_numbers":{ -; CHECK-DAG: "lhs_contracting_dimensions":["0"] -; CHECK-DAG: "rhs_contracting_dimensions":["1"] -; CHECK-DAG: "lhs_batch_dimensions":[] -; CHECK-DAG: "rhs_batch_dimensions":[] +; CHECK-DAG: "lhs_contracting_dimensions":["1"] +; CHECK-DAG: "rhs_contracting_dimensions":["2"] +; CHECK-DAG: "lhs_batch_dimensions":["0"] +; CHECK-DAG: "rhs_batch_dimensions":["0"] ; CHECK-DAG: } ; CHECK-DAG: "precision_config":{ ; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"] From b479c2177672a0010ffba1630efdaec5ca4cee26 Mon Sep 17 00:00:00 2001 From: shuw Date: Mon, 30 Sep 2024 15:16:26 +0000 Subject: [PATCH 7/8] Improve TransposeMatrix --- xla/service/gpu/transforms/gemm_rewriter.cc | 30 ++++++++++++--------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/xla/service/gpu/transforms/gemm_rewriter.cc b/xla/service/gpu/transforms/gemm_rewriter.cc index d45d4da027d02..dc113aa971cbd 100644 --- a/xla/service/gpu/transforms/gemm_rewriter.cc +++ b/xla/service/gpu/transforms/gemm_rewriter.cc @@ -362,28 +362,30 @@ std::optional MatchFp8Param(HloInstruction *instr) { // dimension. There must be only one contracting and only one non-contracting // dimension. Keeps the layout the same. HloInstruction *TransposeMatrix(HloInstruction *instr, int64_t contracting_dim, - absl::Span batch_dims, - bool col_maj = false) { + absl::Span batch_dims) { + auto input_shape = instr->shape(); // Identify the dimensional order which describes a transpose of the // contracting and non-contracting dimensions of the GEMM. - std::vector permutation(instr->shape().dimensions_size(), -1); + std::vector permutation(input_shape.dimensions_size(), -1); // Discard the batch dimensions. for (int64_t batch_dim : batch_dims) { permutation[batch_dim] = batch_dim; } // Identify the non-contracting dimension. int non_contracting_dim; - for (int i = 0; i < instr->shape().dimensions_size(); ++i) { + for (int i = 0; i < input_shape.dimensions_size(); ++i) { if (permutation[i] == -1 && contracting_dim != i) { non_contracting_dim = i; } } - if (!col_maj) { + + if (Layout::Equal()(input_shape.layout(), + LayoutUtil::GetDefaultLayoutForShape(input_shape))) { permutation[non_contracting_dim] = contracting_dim; permutation[contracting_dim] = non_contracting_dim; - Shape new_shape = ShapeUtil::PermuteDimensions(permutation, instr->shape()); - *new_shape.mutable_layout() = instr->shape().layout(); + Shape new_shape = ShapeUtil::PermuteDimensions(permutation, input_shape); + *new_shape.mutable_layout() = input_shape.layout(); return instr->AddInstruction( HloInstruction::CreateTranspose(new_shape, instr, permutation)); @@ -391,12 +393,12 @@ HloInstruction *TransposeMatrix(HloInstruction *instr, int64_t contracting_dim, Shape normalized_input_shape = ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( - instr->shape()); + input_shape); auto a0 = MakeBitcastHlo(instr, normalized_input_shape); std::vector layout_permuation( - instr->shape().layout().minor_to_major().begin(), - instr->shape().layout().minor_to_major().end()); + input_shape.layout().minor_to_major().begin(), + input_shape.layout().minor_to_major().end()); absl::c_reverse(layout_permuation); auto inv_perm = InversePermutation(layout_permuation); @@ -409,10 +411,12 @@ HloInstruction *TransposeMatrix(HloInstruction *instr, int64_t contracting_dim, Shape transpose_shape = ShapeUtil::PermuteDimensions(permutation, a0->shape()); *transpose_shape.mutable_layout() = a0->shape().layout(); + HloInstruction *normalized_transpose = instr->AddInstruction( HloInstruction::CreateTranspose(transpose_shape, a0, permutation)); + Shape final_shape = ShapeUtil::PermuteDimensions(inv_perm, transpose_shape); - *final_shape.mutable_layout() = instr->shape().layout(); + *final_shape.mutable_layout() = input_shape.layout(); return MakeBitcastHlo(normalized_transpose, final_shape); } @@ -1254,8 +1258,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { } else { dim_nums->set_lhs_contracting_dimensions(0, num_batch_dims); } - a.fp8_input = TransposeMatrix(a.fp8_input, a_contracting_dims[0], - a_batch_dims, /*col_maj*/ true); + a.fp8_input = + TransposeMatrix(a.fp8_input, a_contracting_dims[0], a_batch_dims); } // Similarly, cuBLASLt requires the second operand to be column-major, so From b63318487153a8668b9f95574b054b0129194c0c Mon Sep 17 00:00:00 2001 From: Shu Wang Date: Tue, 1 Oct 2024 12:16:46 -0500 Subject: [PATCH 8/8] Update unittest shape and BUILD file. --- xla/service/gpu/transforms/BUILD | 1 + xla/service/gpu/transforms/gemm_rewriter.cc | 1 + .../gpu/transforms/gemm_rewriter_test.cc | 20 +++++++++---------- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/xla/service/gpu/transforms/BUILD b/xla/service/gpu/transforms/BUILD index 22a97478a4f70..1757e27a904ff 100644 --- a/xla/service/gpu/transforms/BUILD +++ b/xla/service/gpu/transforms/BUILD @@ -1671,6 +1671,7 @@ cc_library( deps = [ "//xla:literal", "//xla:literal_util", + "//xla:permutation_util", "//xla:shape_util", "//xla:status_macros", "//xla:types", diff --git a/xla/service/gpu/transforms/gemm_rewriter.cc b/xla/service/gpu/transforms/gemm_rewriter.cc index fa6f2c096a03f..14b2e61ad55e8 100644 --- a/xla/service/gpu/transforms/gemm_rewriter.cc +++ b/xla/service/gpu/transforms/gemm_rewriter.cc @@ -46,6 +46,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/layout.h" +#include "xla/layout_util.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/permutation_util.h" diff --git a/xla/service/gpu/transforms/gemm_rewriter_test.cc b/xla/service/gpu/transforms/gemm_rewriter_test.cc index 5d819ce0be80f..721f262822fb4 100644 --- a/xla/service/gpu/transforms/gemm_rewriter_test.cc +++ b/xla/service/gpu/transforms/gemm_rewriter_test.cc @@ -5036,14 +5036,14 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDColMajorLhsF8) { const char* hlo_text = R"( HloModule test ENTRY test { - x = <>[2,16,32]{1,0,2} parameter(0) + x = <>[2,64,32]{1,2,0} parameter(0) y = <>[2,32,16]{2,1,0} parameter(1) x_scale = f32[] parameter(2) y_scale = f32[] parameter(3) dq_scale = f32[] multiply(x_scale, y_scale) - dq_scale_bcast = f32[2,16,16] broadcast(dq_scale), dimensions={} - out.0 = f32[2,16,16] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0} - ROOT out = f32[2,16,16] multiply(out.0, dq_scale_bcast) + dq_scale_bcast = f32[2,64,16] broadcast(dq_scale), dimensions={} + out.0 = f32[2,64,16] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0} + ROOT out = f32[2,64,16] multiply(out.0, dq_scale_bcast) } )"; @@ -5053,18 +5053,18 @@ HloModule test GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( -; CHECK-LABEL: ENTRY %test ({{.*}}: <>[2,16,32], {{.*}}: <>[2,32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[2,16,16] { -; CHECK-NEXT: [[P0:%[^ ]+]] = <>[2,16,32]{1,0,2} parameter(0) -; CHECK-NEXT: [[P0_BT:%[^ ]+]] = <>[32,2,16]{2,1,0} bitcast([[P0]]) -; CHECK-NEXT: [[P0_TR:%[^ ]+]] = <>[16,2,32]{2,1,0} transpose([[P0_BT]]), dimensions={2,1,0} -; CHECK-NEXT: [[P0_BT1:%[^ ]+]] = <>[2,32,16]{1,0,2} bitcast([[P0_TR]]) +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[2,64,32], {{.*}}: <>[2,32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[2,64,16] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[2,64,32]{1,2,0} parameter(0) +; CHECK-NEXT: [[P0_BT:%[^ ]+]] = <>[2,32,64]{2,1,0} bitcast([[P0]]) +; CHECK-NEXT: [[P0_TR:%[^ ]+]] = <>[2,64,32]{2,1,0} transpose([[P0_BT]]), dimensions={0,2,1} +; CHECK-NEXT: [[P0_BT1:%[^ ]+]] = <>[2,32,64]{1,2,0} bitcast([[P0_TR]]) ; CHECK-NEXT: [[P1:%[^ ]+]] = <>[2,32,16]{2,1,0} parameter(1) ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[2,16,32]{2,1,0} transpose([[P1]]), dimensions={0,2,1} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[DQ:%[^ ]+]] = f32[] multiply([[P2]], [[P3]]) ; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[2,16,16]{2,1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_BT1]], [[P1_TRANSPOSE]], [[DQ]], [[C1]], [[C1]], /*index=5*/[[C1]]), +; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[2,64,16]{2,1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_BT1]], [[P1_TRANSPOSE]], [[DQ]], [[C1]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1