PR #18062: [ROCm] Fix gemm_rewriter_test for AMD GCN Arch
Imported from GitHub PR #18062

#16841 removes the scaling-factor constants in gemm_rewriter for FP8 data types. This patch addresses the same in gemm_rewriter_test.
Copybara import of the project:

--
be4da5b by Harsha HS <[email protected]>:

[ROCm] Fix gemm_rewriter_test for AMD GCN Arch

#16841 removes the scaling-factor
constants in gemm_rewriter for FP8 data types. This patch addresses
the same in gemm_rewriter_test.

Merging this change closes #18062

FUTURE_COPYBARA_INTEGRATE_REVIEW=#18062 from ROCm:ci_fix_gemm_rewriter_fp8_tests_20241008 be4da5b
PiperOrigin-RevId: 688022433
hsharsha authored and Google-ML-Automation committed Oct 21, 2024
1 parent 96abe88 commit e3934d6
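
For context, a hedged sketch of the behavioral change these expectations track: after #16841, the rewriter on ROCm no longer materializes f32 constant(1) placeholders for absent FP8 scaling factors, so the __cublas$lt$matmul$f8 custom call simply receives fewer operands. The C++ below only illustrates that shape of change; the function name and types are invented for illustration and are not XLA's API.

#include <optional>
#include <string>
#include <vector>

// Illustrative: collect custom-call operands, appending a scale only when
// one actually exists. The pre-#16841 behavior is noted in the comment.
std::vector<std::string> CollectFp8Operands(
    std::vector<std::string> operands,
    const std::vector<std::optional<std::string>>& scales) {
  for (const std::optional<std::string>& scale : scales) {
    if (scale.has_value()) {
      operands.push_back(*scale);  // a real scale parameter, e.g. [[P2]]
    }
    // Previously, a "f32[] constant(1)" placeholder was appended here when
    // no scale was supplied; the updated CHECK-GCN lines reflect its removal.
  }
  return operands;
}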
Showing 1 changed file with 9 additions and 16 deletions.
xla/service/gpu/transforms/gemm_rewriter_test.cc (9 additions, 16 deletions)
@@ -6107,9 +6107,9 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledDF8) {
 ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
 ; CHECK-NEXT: [[P2_INV:%[^ ]+]] = f32[] divide([[C0]], [[P2]])
 ; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-NEXT: [[C2:%[^ ]+]] = f32[] constant(1)
+; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f32[] constant(1)
 ; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]], [[C2]]),
-; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]], [[C2]]),
+; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":1
@@ -6325,8 +6325,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDF8) {
 ; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4)
 ; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]])
 ; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[P4_INV]]),
-; CHECK-GCN-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]]),
+; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":1
@@ -6435,8 +6434,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDReluActivationF8) {
 ; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4)
 ; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]])
 ; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[P4_INV]]),
-; CHECK-CGN-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]]),
+; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":1
@@ -6516,7 +6514,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDMatrixBiasWithDAmaxF8) {
 ; CHECK-PTX: [[P4:%[^ ]+]] = f16[] parameter(5)
 ; CHECK-PTX: [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[DUMMY0:%[^ ]+]], [[DUMMY1:%[^ ]+]], /*index=5*/[[DUMMY2:%[^ ]+]]),
 ; CHECK-NOT: output_to_operand_aliasing
-; CHECK-GCN: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[DUMMY0:%[^ ]+]], [[DUMMY1:%[^ ]+]], /*index=5*/[[DUMMY2:%[^ ]+]]),
+; CHECK-GCN: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[DUMMY0:%[^ ]+]], [[DUMMY1:%[^ ]+]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":1
@@ -6589,8 +6587,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDVectorBiasF8) {
 ; CHECK-PTX-NEXT: [[DV:%[^ ]+]] = f16[] divide([[C2]], [[P4]])
 ; CHECK-PTX-NEXT: [[CV2:%[^ ]+]] = f32[] convert([[DV]])
 ; CHECK-PTX: [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[VB]], /*index=5*/[[CV2]]),
-; CHECK-GCN: [[C:%[^ ]+]] = f32[] constant(1)
-; CHECK-GCN: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[C]], /*index=5*/[[C]], [[VB]]),
+; CHECK-GCN: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[VB]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":1
@@ -7405,8 +7402,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDWithDAmaxF8) {
 ; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4)
 ; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]])
 ; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[P4_INV]]),
-; CHECK-GCN-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]]),
+; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":1
@@ -7487,8 +7483,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest,
 ; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f16[] divide([[C2]], [[P4]])
 ; CHECK-PTX-NEXT: [[P4_INV_CONVERT:%[^ ]+]] = f32[] convert([[P4_INV]])
 ; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_CONVERT]], [[P3_CONVERT]], [[P4_INV_CONVERT]]),
-; CHECK-CGN-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_CONVERT]], [[P3_CONVERT]], [[C1]]),
+; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_CONVERT]], [[P3_CONVERT]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":1
@@ -7566,8 +7561,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest,
 ; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4)
 ; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]])
 ; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<<F8E4M3>>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[P4_INV]]),
-; CHECK-CGN-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
-; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]]),
+; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]]),
 ; CHECK: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
 ; CHECK-DAG: "alpha_real":1
@@ -7837,7 +7831,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, FnuzTypeF8) {
 ; CHECK-GCN-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]])
 ; CHECK-GCN-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
 ; CHECK-GCN-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
-; CHECK-GCN-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
 ; CHECK-PTX: custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
 ; CHECK-GCN: custom_call_target="__cublas$lt$matmul$f8",
 ; CHECK: backend_config={
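
To exercise the updated expectations locally, running the test target for this file should suffice; the target name below is assumed from the file path, so check the BUILD file in xla/service/gpu/transforms:

bazel test //xla/service/gpu/transforms:gemm_rewriter_test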
