From 66ee8d6b90f32127f9e31daaee79009d1e28e50f Mon Sep 17 00:00:00 2001
From: Mu-Chu Lee
Date: Wed, 4 Sep 2024 09:43:22 -0700
Subject: [PATCH] Add reciprocal operator (#1023)

Summary:
Pull Request resolved: https://github.com/facebookincubator/AITemplate/pull/1023

Add a reciprocal (1/x) elementwise operator: fx2ait converter, CUDA backend
kernels, a FuncEnum entry, a Python frontend wrapper, and unit tests.

Reviewed By: frank-wei, aakhundov

Differential Revision: D62000543
---
 fx2ait/fx2ait/converters/ait_converters.py    | 11 ++++
 fx2ait/fx2ait/converters/utils.py             |  2 +
 python/aitemplate/backend/backend_spec.py     |  7 +++
 .../backend/cuda/elementwise/custom_math.cuh  | 28 ++++++++++
 .../compiler/ops/common/epilogue.py           |  1 +
 python/aitemplate/compiler/ops/common/math.py |  4 ++
 tests/unittest/ops/test_activation.py         | 54 +++++++++++++++++++
 7 files changed, 107 insertions(+)

diff --git a/fx2ait/fx2ait/converters/ait_converters.py b/fx2ait/fx2ait/converters/ait_converters.py
index 1fc2af753..3c0463ad8 100644
--- a/fx2ait/fx2ait/converters/ait_converters.py
+++ b/fx2ait/fx2ait/converters/ait_converters.py
@@ -166,6 +166,17 @@ def acc_ops_floor(
     return elementwise(FuncEnum.FLOOR)(input_val)
 
 
+@ait_converter(acc_ops.reciprocal)
+def acc_ops_reciprocal(
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> ConverterOutput:
+    input_val = kwargs["input"]
+    return elementwise(FuncEnum.RECIPROCAL)(input_val)
+
+
 @ait_converter(acc_ops.add)
 def acc_ops_add(
     target: Target,
diff --git a/fx2ait/fx2ait/converters/utils.py b/fx2ait/fx2ait/converters/utils.py
index dc8d59a15..df25d9092 100644
--- a/fx2ait/fx2ait/converters/utils.py
+++ b/fx2ait/fx2ait/converters/utils.py
@@ -166,6 +166,8 @@ def get_python_op_from_ait_constant_elementwise_op(
         return operator.floordiv
     elif op_type == FuncEnum.FLOOR:
         return math.floor
+    elif op_type == FuncEnum.RECIPROCAL:
+        return lambda x: 1.0 / x
     else:
         raise RuntimeError(f"{op_type} is not supported yet!")
 
diff --git a/python/aitemplate/backend/backend_spec.py b/python/aitemplate/backend/backend_spec.py
index 1aeeb613a..b53e2750b 100644
--- a/python/aitemplate/backend/backend_spec.py
+++ b/python/aitemplate/backend/backend_spec.py
@@ -326,6 +326,13 @@ class GPUBackendSpec(BackendSpec):
             "bfloat16": "__floor",
             "bfloat16_2": "__floor",
         },
+        FuncEnum.RECIPROCAL: {
+            "float": "__reciprocal",
+            "half": "__hreciprocal",
+            "half2": "__h2reciprocal",
+            "bfloat16": "__breciprocal",
+            "bfloat16_2": "__b2reciprocal",
+        },
         FuncEnum.CELU: {
             "float": "fcelu",
             "half": "hcelu",
diff --git a/python/aitemplate/backend/cuda/elementwise/custom_math.cuh b/python/aitemplate/backend/cuda/elementwise/custom_math.cuh
index eebe344f0..43b5fec6c 100644
--- a/python/aitemplate/backend/cuda/elementwise/custom_math.cuh
+++ b/python/aitemplate/backend/cuda/elementwise/custom_math.cuh
@@ -1043,6 +1043,34 @@ __device__ bfloat16_2 __floor(const bfloat16_2 a) {
 #endif
 }
 
+__device__ float __reciprocal(const float a) {
+  return 1.0f / a;
+}
+
+__device__ half __hreciprocal(const half a) {
+  return __hdiv(1.0f, a);
+}
+
+__device__ bfloat16 __breciprocal(const bfloat16 a) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
+  return __hdiv(1.0f, a);
+#else
+  NOT_IMPLEMENTED();
+#endif
+}
+
+__device__ half2 __h2reciprocal(const half2 a) {
+  return half2(__hdiv(1.0f, a.x), __hdiv(1.0f, a.y));
+}
+
+__device__ bfloat16_2 __b2reciprocal(const bfloat16_2 a) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
+  return bfloat16_2(__hdiv(1.0f, a.x), __hdiv(1.0f, a.y));
+#else
+  NOT_IMPLEMENTED();
+#endif
+}
+
 __device__ float fcelu(const float a, const float alpha) {
   return a > 0.f ? a : alpha * (expf(a / alpha) - 1.0f);
 }
diff --git a/python/aitemplate/compiler/ops/common/epilogue.py b/python/aitemplate/compiler/ops/common/epilogue.py
index f634b44b6..3395fe1be 100644
--- a/python/aitemplate/compiler/ops/common/epilogue.py
+++ b/python/aitemplate/compiler/ops/common/epilogue.py
@@ -67,3 +67,4 @@ class FuncEnum(Enum):
     CELU = 29
     FLOOR = 30
     LOG1P = 31
+    RECIPROCAL = 32
diff --git a/python/aitemplate/compiler/ops/common/math.py b/python/aitemplate/compiler/ops/common/math.py
index b4af047e9..e661954c8 100644
--- a/python/aitemplate/compiler/ops/common/math.py
+++ b/python/aitemplate/compiler/ops/common/math.py
@@ -125,3 +125,7 @@ def celu(tensor: Any) -> Tensor:
 
 def floor(tensor: Any) -> Tensor:
     return OP_REGISTRY.get("FLOOR")(tensor)
+
+
+def reciprocal(tensor: Any) -> Tensor:
+    return OP_REGISTRY.get("RECIPROCAL")(tensor)
diff --git a/tests/unittest/ops/test_activation.py b/tests/unittest/ops/test_activation.py
index 7ec2c7ee7..14e1c0c3b 100644
--- a/tests/unittest/ops/test_activation.py
+++ b/tests/unittest/ops/test_activation.py
@@ -45,6 +45,7 @@
     FuncEnum.RELU: torch.relu,
     FuncEnum.CELU: torch.celu,
     FuncEnum.FLOOR: torch.floor,
+    FuncEnum.RECIPROCAL: torch.reciprocal,
 }
 
 
@@ -437,6 +438,38 @@ def _test_fast_gelu(
         module.run_with_tensors([x1_pt], [x2])
         torch.testing.assert_close(x2, x2_pt, atol=1e-2, rtol=1e-2)
 
+    def _test_reciprocal(
+        self,
+        input_size,
+        test_name="reciprocal",
+        copy_op=False,
+        dtype="float16",
+    ):
+        assert len(input_size) == 2
+        X1 = Tensor(
+            shape=[IntImm(input_size[0]), IntImm(input_size[1])],
+            dtype=dtype,
+            name="input0",
+            is_input=True,
+        )
+        X2_op = ops.elementwise(FuncEnum.RECIPROCAL)
+
+        if copy_op:
+            X2_op = ops.elementwise(**X2_op._get_op_attributes())
+        X2 = X2_op(X1)
+        X2._attrs["is_output"] = True
+        X2._attrs["name"] = "output0"
+
+        target = detect_target()
+        module = compile_model(X2, target, "./tmp", f"{test_name}_{dtype}")
+
+        x1_pt = get_random_torch_tensor(input_size, dtype)
+        x2_pt = torch.reciprocal(x1_pt)
+
+        x2 = torch.empty_like(x2_pt)
+        module.run_with_tensors([x1_pt], [x2])
+        torch.testing.assert_close(x2, x2_pt, atol=1e-2, rtol=1e-2)
+
     @parameterized.expand(
         **filter_test_cases_by_params(
             {
@@ -921,6 +954,27 @@ def test_fast_gelu(self, dtype):
             [256, 128], test_name="fast_gelu_4_copy_op", copy_op=True, dtype=dtype
         )
 
+    @parameterized.expand(
+        **filter_test_cases_by_params(
+            {
+                TestEnv.CUDA_LESS_THAN_SM80: [("float16"), ("float32")],
+                TestEnv.CUDA_SM80: [("bfloat16")],
+                TestEnv.ROCM: [("float16")],
+            }
+        )
+    )
+    def test_reciprocal(self, dtype):
+        self._test_simple_function(
+            [32, 128], FuncEnum.RECIPROCAL, test_name="reciprocal", dtype=dtype
+        )
+        self._test_simple_function(
+            [32, 128],
+            FuncEnum.RECIPROCAL,
+            test_name="reciprocal_copy_op",
+            copy_op=True,
+            dtype=dtype,
+        )
+
 if __name__ == "__main__":
     unittest.main()
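
For reviewers who want to poke at the new op outside the unittest harness, below is a minimal standalone sketch of the same round trip the tests perform. It is not part of the patch: the script name and input values are illustrative, the import paths follow the conventions used elsewhere in the AITemplate test suite, and a CUDA-capable install is assumed.

```python
# reciprocal_demo.py (hypothetical): compile a one-op AIT graph computing
# output0 = 1 / input0 and check it against torch.reciprocal.
import torch

from aitemplate.compiler import compile_model, ops
from aitemplate.compiler.base import IntImm, Tensor
from aitemplate.compiler.ops.common.epilogue import FuncEnum
from aitemplate.testing import detect_target

# Build a single-op graph over a fixed [32, 128] fp16 input.
X = Tensor(
    shape=[IntImm(32), IntImm(128)],
    dtype="float16",
    name="input0",
    is_input=True,
)
Y = ops.elementwise(FuncEnum.RECIPROCAL)(X)
Y._attrs["is_output"] = True
Y._attrs["name"] = "output0"

# Codegen and compile for the detected GPU target.
module = compile_model(Y, detect_target(), "./tmp", "reciprocal_demo")

# Keep inputs in [0.5, 1.5] so the reference 1/x stays well-conditioned.
x_pt = torch.rand(32, 128, dtype=torch.float16, device="cuda") + 0.5
y = torch.empty_like(x_pt)
module.run_with_tensors([x_pt], [y])
torch.testing.assert_close(y, torch.reciprocal(x_pt), atol=1e-2, rtol=1e-2)
print("reciprocal OK")
```

Note that the bfloat16 kernels (`__breciprocal`, `__b2reciprocal`) are compiled behind `__CUDA_ARCH__ >= 800` guards, which is why the test parameterization only exercises bfloat16 under `TestEnv.CUDA_SM80`.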