diff --git a/python/aitemplate/backend/cuda/gemm_epilogue_vistor/common_dual_gemm.py b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/common_dual_gemm.py index 72f7e86e1..bc1efd4ed 100644 --- a/python/aitemplate/backend/cuda/gemm_epilogue_vistor/common_dual_gemm.py +++ b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/common_dual_gemm.py @@ -211,6 +211,18 @@ def default_fproc( ): return ret acc_type = cutlass_lib.library.DataType.f32 + + if ( + "no_tf32" in Target.current()._kwargs + and data_type == "float" + and Target.current()._kwargs["no_tf32"] + ): + if ( + op.tile_description.math_instruction.element_a + == cutlass_lib.library.DataType.tf32 + ): + return ret + # check target use fp16 acc if "use_fp16_acc" in Target.current()._kwargs and data_type == "cutlass::half_t": if Target.current()._kwargs["use_fp16_acc"]: diff --git a/python/aitemplate/backend/cuda/gemm_universal/common.py b/python/aitemplate/backend/cuda/gemm_universal/common.py index cda2b94ee..2781489b7 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/common.py +++ b/python/aitemplate/backend/cuda/gemm_universal/common.py @@ -1319,6 +1319,18 @@ def default_fproc( ): return ret acc_type = cutlass_lib.library.DataType.f32 + + if ( + "no_tf32" in Target.current()._kwargs + and data_type == "float" + and Target.current()._kwargs["no_tf32"] + ): + if ( + op.tile_description.math_instruction.element_a + == cutlass_lib.library.DataType.tf32 + ): + return ret + # check target use fp16 acc if "use_fp16_acc" in Target.current()._kwargs and data_type == "cutlass::half_t": if Target.current()._kwargs["use_fp16_acc"]: diff --git a/tests/unittest/ops/test_gemm_no_tf32.py b/tests/unittest/ops/test_gemm_no_tf32.py new file mode 100644 index 000000000..4b3fd5b55 --- /dev/null +++ b/tests/unittest/ops/test_gemm_no_tf32.py @@ -0,0 +1,71 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+import torch
+
+from aitemplate.compiler import compile_model, ops
+from aitemplate.frontend import Tensor
+from aitemplate.testing import detect_target
+from aitemplate.testing.test_utils import filter_test_cases_by_test_env
+
+
+@unittest.skipIf(detect_target().name() == "rocm", "Not supported by ROCM.")
+@unittest.skipIf(
+    detect_target().name() == "cuda" and int(detect_target()._arch) < 80,
+    "Not supported by CUDA < SM80.",
+)
+class GEMMNoTF32TestCase(unittest.TestCase):
+    def test_rrr_no_tf32(self):
+        """AIT gemm_rrr compiled with no_tf32=True must match torch.matmul."""
+        dim = 64
+        dtype, dtype_name = torch.float32, "float32"
+        saved_allow_tf32 = torch.backends.cuda.matmul.allow_tf32
+        # The PyTorch reference result must not be computed with TF32 either.
+        torch.backends.cuda.matmul.allow_tf32 = False
+        try:
+            lhs = torch.rand((dim, dim), dtype=dtype).cuda()
+            rhs = torch.rand((dim, dim), dtype=dtype).cuda()
+            expected = torch.matmul(lhs, rhs)
+
+            target = detect_target(no_tf32=True)  # Disable tf32 for accuracy
+            lhs_ait = Tensor(
+                shape=[dim, dim], dtype=dtype_name, name="input_0", is_input=True
+            )
+            rhs_ait = Tensor(
+                shape=[dim, dim], dtype=dtype_name, name="input_1", is_input=True
+            )
+            out = ops.gemm_rrr()(lhs_ait, rhs_ait)
+            out._attrs["name"] = "output_0"
+            out._attrs["is_output"] = True
+            module = compile_model(out, target, "./tmp", "gemm_rrr_no_tf32")
+            feeds = {
+                "input_0": lhs.clone().detach().cuda(),
+                "input_1": rhs.clone().detach().cuda(),
+            }
+            actual = torch.empty([dim, dim], dtype=dtype, device="cuda")
+            module.run_with_tensors(feeds, [actual])
+            # No rtol/atol are passed, so assert_close applies its defaults.
+            torch.testing.assert_close(expected, actual)
+        finally:
+            # Put the global flag back to the value it had on entry.
+            torch.backends.cuda.matmul.allow_tf32 = saved_allow_tf32
+
+
+filter_test_cases_by_test_env(GEMMNoTF32TestCase)
+
+
+if __name__ == "__main__":
+    unittest.main()