From f756875267143f1b4b738f49f21f38caf459b02d Mon Sep 17 00:00:00 2001
From: Hongtao Yu
Date: Mon, 9 Sep 2024 10:44:21 -0700
Subject: [PATCH] Turn on TMA by default for row-wise GEMM

Summary: Enabling the TMA row-wise GEMM by default, as TMA appears to give
quite some speedup across the board, up to 40% for some shapes.

Differential Revision: D62212842
---
 fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
index cdcda6bf8..07765fa21 100644
--- a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
+++ b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
@@ -975,7 +975,7 @@ def matmul_fp8_row(
     allow_tf32: bool = True,
     fp8_fast_accum: bool = True,
     imprecise_acc: bool = False,
-    tma_persistent: bool = False,
+    tma_persistent: bool = True,
 ) -> torch.Tensor:
     """
     Performs matmul on [M, K] and [N, K] fp8 matrices with row-wise scalings [M], [N].
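
Usage note (not part of the patch): a minimal sketch of how the new default is
seen by callers, assuming fbgemm_gpu's experimental Triton module exposes a
row-wise quantization helper quantize_fp8_row alongside matmul_fp8_row; the
shapes and inputs below are illustrative only.

import torch

from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
    matmul_fp8_row,
    quantize_fp8_row,  # assumed helper producing fp8 tensors + per-row scales
)

M, N, K = 1024, 2048, 4096
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)

# Row-wise fp8 quantization: fp8 tensors plus scales of shape [M] and [N],
# matching the [M, K] x [N, K] contract in matmul_fp8_row's docstring.
a_fp8, a_scale = quantize_fp8_row(a)
b_fp8, b_scale = quantize_fp8_row(b)

# After this patch, the TMA persistent kernel path is selected by default.
c = matmul_fp8_row(a_fp8, b_fp8, a_scale, b_scale)

# Callers can still opt out of the TMA path explicitly.
c_no_tma = matmul_fp8_row(a_fp8, b_fp8, a_scale, b_scale, tma_persistent=False)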