From 74cfb8e83135573be65ce7cec0897c23804e864d Mon Sep 17 00:00:00 2001 From: Harsha H S Date: Mon, 21 Oct 2024 01:22:23 -0700 Subject: [PATCH] PR #18062: [ROCm] Fix gemm_rewriter_test for AMD GCN Arch Imported from GitHub PR https://github.com/openxla/xla/pull/18062 https://github.com/openxla/xla/pull/16841 removes scaling factor constants in gemm_rewriter for FP8 data types. This patch addresses the same in the gemm_rewriter_test Copybara import of the project: -- be4da5b8de0785d43e18dbdb0773307870084e32 by Harsha HS : [ROCm] Fix gemm_rewriter_test for AMD GCN Arch https://github.com/openxla/xla/pull/16841 removes scaling factor constants in gemm_rewriter for FP8 data types. This patch addresses the same in the gemm_rewriter_test Merging this change closes #18062 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/18062 from ROCm:ci_fix_gemm_rewriter_fp8_tests_20241008 be4da5b8de0785d43e18dbdb0773307870084e32 PiperOrigin-RevId: 688034088 --- .../gpu/transforms/gemm_rewriter_test.cc | 25 +++++++------------ 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/xla/service/gpu/transforms/gemm_rewriter_test.cc b/xla/service/gpu/transforms/gemm_rewriter_test.cc index e34c22eb3cd1b..b4c6be7177a67 100644 --- a/xla/service/gpu/transforms/gemm_rewriter_test.cc +++ b/xla/service/gpu/transforms/gemm_rewriter_test.cc @@ -6107,9 +6107,9 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledDF8) { ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P2_INV:%[^ ]+]] = f32[] divide([[C0]], [[P2]]) ; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[C2:%[^ ]+]] = f32[] constant(1) +; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f32[] constant(1) ; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]], [[C2]]), -; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]], [[C2]]), +; CHECK-GCN-NEXT: [[OUT:%[^ 
]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6325,8 +6325,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDF8) { ; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4) ; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]]) ; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[P4_INV]]), -; CHECK-GCN-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]]), +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6435,8 +6434,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDReluActivationF8) { ; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4) ; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]]) ; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[P4_INV]]), -; CHECK-CGN-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]]), +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6516,7 +6514,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDMatrixBiasWithDAmaxF8) { ; CHECK-PTX: [[P4:%[^ ]+]] = f16[] parameter(5) ; CHECK-PTX: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], 
[[P1_TRANSPOSE]], [[C0]], [[DUMMY0:%[^ ]+]], [[DUMMY1:%[^ ]+]], /*index=5*/[[DUMMY2:%[^ ]+]]), ; CHECK-NOT: output_to_operand_aliasing -; CHECK-GCN: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[DUMMY0:%[^ ]+]], [[DUMMY1:%[^ ]+]], /*index=5*/[[DUMMY2:%[^ ]+]]), +; CHECK-GCN: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[DUMMY0:%[^ ]+]], [[DUMMY1:%[^ ]+]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6589,8 +6587,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDVectorBiasF8) { ; CHECK-PTX-NEXT: [[DV:%[^ ]+]] = f16[] divide([[C2]], [[P4]]) ; CHECK-PTX-NEXT: [[CV2:%[^ ]+]] = f32[] convert([[DV]]) ; CHECK-PTX: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[VB]], /*index=5*/[[CV2]]), -; CHECK-GCN: [[C:%[^ ]+]] = f32[] constant(1) -; CHECK-GCN: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[C]], /*index=5*/[[C]], [[VB]]), +; CHECK-GCN: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[VB]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -7405,8 +7402,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDWithDAmaxF8) { ; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4) ; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]]) ; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[P4_INV]]), -; CHECK-GCN-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]]), +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, 
s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -7487,8 +7483,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f16[] divide([[C2]], [[P4]]) ; CHECK-PTX-NEXT: [[P4_INV_CONVERT:%[^ ]+]] = f32[] convert([[P4_INV]]) ; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_CONVERT]], [[P3_CONVERT]], [[P4_INV_CONVERT]]), -; CHECK-CGN-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_CONVERT]], [[P3_CONVERT]], [[C1]]), +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_CONVERT]], [[P3_CONVERT]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -7566,8 +7561,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4) ; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]]) ; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[P4_INV]]), -; CHECK-CGN-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]]), +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -7837,7 +7831,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, FnuzTypeF8) { ; CHECK-GCN-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]) ; CHECK-GCN-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-GCN-NEXT: [[P3:%[^ ]+]] = 
f32[] parameter(3) -; CHECK-GCN-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) ; CHECK-PTX: custom_call_target="<>", ; CHECK-GCN: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={