Introduce "no_tf32" param for targets, to disable kernels using numerically less accurate tf32 (#874)

Summary:
Pull Request resolved: #874

Addressing GitHub issue #872: "Option for choosing fp32 gemm backend implementation"

As reported by GitHub user zhekunz2, small numerical discrepancies between PyTorch's and AITemplate's GEMM could be observed on GPUs >= SM80 (A100 and above), where GEMM kernels using TF32 could be selected. Most of the time these kernels are a good choice due to their performance and relatively good accuracy, but sometimes full fp32 accuracy is required. This diff therefore introduces a "no_tf32" option that can be passed to detect_target, which prevents the selection of Cutlass GEMM kernels that use TF32.

Example usage, as in this new unit test (slightly modified code from the initial report):

```
def test_rrr_no_tf32(self):
    # Test accuracy with tf32 disabled.
    # This test uses a smaller numerical tolerance level than the others.
    allow_tf32_bak = torch.backends.cuda.matmul.allow_tf32
    torch.backends.cuda.matmul.allow_tf32 = False
    try:
        test_dtype = torch.float32
        test_dtype_str = "float32"
        A = torch.rand((64, 64), dtype=test_dtype).cuda()
        B = torch.rand((64, 64), dtype=test_dtype).cuda()
        result_cuda = torch.matmul(A, B)
        target = detect_target(no_tf32=True)  # Disable tf32 for accuracy
        A_ait = Tensor(
            shape=[64, 64], dtype=test_dtype_str, name="input_0", is_input=True
        )
        B_ait = Tensor(
            shape=[64, 64], dtype=test_dtype_str, name="input_1", is_input=True
        )
        OP = ops.gemm_rrr()
        Y = OP(A_ait, B_ait)
        Y._attrs["name"] = "output_0"
        Y._attrs["is_output"] = True
        module = compile_model(Y, target, "./tmp", "gemm_rrr_no_tf32")
        inputs = {
            "input_0": A.clone().detach().cuda(),
            "input_1": B.clone().detach().cuda(),
        }
        result_ait = torch.empty([64, 64], dtype=test_dtype, device="cuda")
        module.run_with_tensors(inputs, [result_ait])
        torch.testing.assert_close(result_cuda, result_ait)
    finally:
        torch.backends.cuda.matmul.allow_tf32 = allow_tf32_bak
```

Differential Revision: D48034389

fbshipit-source-id: 98ab08e4be7ad9156496d4f809110c459e2b842d
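For intuition about where the reported discrepancies come from: TF32 keeps fp32's 8-bit exponent but only 10 mantissa bits (vs. fp32's 23), so GEMM inputs are effectively rounded before multiplication. The following CPU-only sketch (not part of this diff; a hypothetical illustration using NumPy) emulates that mantissa truncation via bit masking and compares an fp32 matmul against one with TF32-rounded inputs:

```python
import numpy as np

def to_tf32(x):
    # Round float32 values to TF32 precision: keep the sign, the full 8-bit
    # exponent, and the top 10 mantissa bits; round-to-nearest on the 13
    # mantissa bits that TF32 drops.
    u = x.astype(np.float32).view(np.uint32)
    u = (u + np.uint32(0x1000)) & np.uint32(0xFFFFE000)
    return u.view(np.float32)

rng = np.random.default_rng(0)
A = rng.random((64, 64), dtype=np.float32)
B = rng.random((64, 64), dtype=np.float32)

exact = A @ B                      # plain fp32 matmul
approx = to_tf32(A) @ to_tf32(B)   # inputs rounded to TF32 precision

# Small but nonzero discrepancy, of the kind the report describes.
print(np.abs(exact - approx).max())
```

This is only a model of the input rounding; real TF32 tensor-core GEMMs still accumulate in fp32, which is why the error stays small and usually acceptable, and why "no_tf32" is offered as an opt-in for cases that need bitwise-closer agreement with fp32.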