Add reciprocal operator #1023

Closed. Wants to merge 1 commit.
fx2ait/fx2ait/converters/ait_converters.py (11 additions, 0 deletions)

@@ -166,6 +166,17 @@ def acc_ops_floor(
    return elementwise(FuncEnum.FLOOR)(input_val)


@ait_converter(acc_ops.reciprocal)
def acc_ops_reciprocal(
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> ConverterOutput:
    input_val = kwargs["input"]
    return elementwise(FuncEnum.RECIPROCAL)(input_val)


@ait_converter(acc_ops.add)
def acc_ops_add(
    target: Target,
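For reference, converters like this are normally exercised by tracing a small module whose forward calls torch.reciprocal and lowering it through fx2ait. The sketch below follows that pattern; the AITTestCase/run_test harness and import paths are assumptions based on how the other converter tests are written, not part of this diff.

```python
# Hedged sketch of an end-to-end converter check; harness names are assumed.
import torch
from fx2ait.acc_tracer import acc_ops
from fx2ait.tools.common_fx2ait import AITTestCase  # assumed test helper


class TestReciprocalConverter(AITTestCase):
    def test_reciprocal(self):
        class Reciprocal(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.reciprocal(x)

        # Traces the module, lowers it through acc_ops_reciprocal above,
        # and compares the AIT result against eager PyTorch.
        self.run_test(
            Reciprocal(),
            [torch.randn(2, 3).half().cuda()],
            expected_ops={acc_ops.reciprocal},
        )
```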
fx2ait/fx2ait/converters/utils.py (2 additions, 0 deletions)

@@ -166,6 +166,8 @@ def get_python_op_from_ait_constant_elementwise_op(
        return operator.floordiv
    elif op_type == FuncEnum.FLOOR:
        return math.floor
    elif op_type == FuncEnum.RECIPROCAL:
        # The math module has no reciprocal; fold constants with a plain lambda.
        return lambda x: 1.0 / x
    else:
        raise RuntimeError(f"{op_type} is not supported yet!")
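This helper backs AIT's constant folding of elementwise ops, so the returned callable only needs to match the kernel's numeric semantics. A quick sanity sketch, with import paths assumed from the file locations in this diff:

```python
# Sanity sketch for the constant-folding path; import paths are assumed.
from aitemplate.compiler.ops.common.epilogue import FuncEnum
from fx2ait.converters.utils import get_python_op_from_ait_constant_elementwise_op

recip = get_python_op_from_ait_constant_elementwise_op(FuncEnum.RECIPROCAL)
assert recip(4.0) == 0.25  # mirrors torch.reciprocal on a scalar constant
```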
python/aitemplate/backend/backend_spec.py (7 additions, 0 deletions)

@@ -326,6 +326,13 @@ class GPUBackendSpec(BackendSpec):
            "bfloat16": "__floor",
            "bfloat16_2": "__floor",
        },
        FuncEnum.RECIPROCAL: {
            "float": "__reciprocal",
            "half": "__hreciprocal",
            "half2": "__h2reciprocal",
            "bfloat16": "__breciprocal",
            "bfloat16_2": "__b2reciprocal",
        },
        FuncEnum.CELU: {
            "float": "fcelu",
            "half": "hcelu",
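Each string in this mapping has to name a __device__ function that the backend actually provides; the CUDA definitions for __reciprocal, __hreciprocal, __h2reciprocal, __breciprocal, and __b2reciprocal are added in custom_math.cuh below.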
python/aitemplate/backend/cuda/elementwise/custom_math.cuh (28 additions, 0 deletions)

@@ -1043,6 +1043,34 @@ __device__ bfloat16_2 __floor(const bfloat16_2 a) {
#endif
}

__device__ float __reciprocal(const float a) {
  return 1.0f / a;
}

__device__ half __hreciprocal(const half a) {
  return __hdiv(__float2half(1.0f), a);
}

__device__ bfloat16 __breciprocal(const bfloat16 a) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  return __hdiv(__float2bfloat16(1.0f), a);
#else
  NOT_IMPLEMENTED();
#endif
}

__device__ half2 __h2reciprocal(const half2 a) {
  return half2(__hdiv(__float2half(1.0f), a.x), __hdiv(__float2half(1.0f), a.y));
}

__device__ bfloat16_2 __b2reciprocal(const bfloat16_2 a) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  return bfloat16_2(
      __hdiv(__float2bfloat16(1.0f), a.x), __hdiv(__float2bfloat16(1.0f), a.y));
#else
  NOT_IMPLEMENTED();
#endif
}

__device__ float fcelu(const float a, const float alpha) {
  return a > 0.f ? a : alpha * (expf(a / alpha) - 1.0f);
}
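As with the existing __floor bfloat16 paths shown in the context above, the bfloat16 variants are only compiled for __CUDA_ARCH__ >= 800 and hit NOT_IMPLEMENTED() otherwise, which is why the bfloat16 test case added below is restricted to the SM80 environment.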
python/aitemplate/compiler/ops/common/epilogue.py (1 addition, 0 deletions)

@@ -67,3 +67,4 @@ class FuncEnum(Enum):
    CELU = 29
    FLOOR = 30
    LOG1P = 31
    RECIPROCAL = 32
python/aitemplate/compiler/ops/common/math.py (4 additions, 0 deletions)

@@ -125,3 +125,7 @@ def celu(tensor: Any) -> Tensor:


def floor(tensor: Any) -> Tensor:
    return OP_REGISTRY.get("FLOOR")(tensor)


def reciprocal(tensor: Any) -> Tensor:
    return OP_REGISTRY.get("RECIPROCAL")(tensor)
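With the wrapper above registered, the op can be driven directly from the AIT front end, mirroring what the unit test below does with ops.elementwise(FuncEnum.RECIPROCAL). A minimal sketch follows; the import paths and the ops.reciprocal re-export reflect AITemplate's usual layout and should be treated as assumptions rather than part of the diff.

```python
# Minimal sketch of calling the new reciprocal op from the AIT front end.
# Import paths and the ops.reciprocal re-export are assumed.
import torch

from aitemplate.compiler import compile_model, ops
from aitemplate.compiler.base import IntImm, Tensor
from aitemplate.testing import detect_target

X = Tensor(
    shape=[IntImm(32), IntImm(128)], dtype="float16", name="input0", is_input=True
)
Y = ops.reciprocal(X)  # same as ops.elementwise(FuncEnum.RECIPROCAL)(X)
Y._attrs["is_output"] = True
Y._attrs["name"] = "output0"

module = compile_model(Y, detect_target(), "./tmp", "reciprocal_demo")

# Shift inputs away from zero so the comparison tolerances stay meaningful.
x_pt = torch.rand(32, 128, dtype=torch.float16, device="cuda") + 0.5
y = torch.empty_like(x_pt)
module.run_with_tensors([x_pt], [y])
torch.testing.assert_close(y, torch.reciprocal(x_pt), atol=1e-2, rtol=1e-2)
```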
tests/unittest/ops/test_activation.py (54 additions, 0 deletions)

@@ -45,6 +45,7 @@
    FuncEnum.RELU: torch.relu,
    FuncEnum.CELU: torch.celu,
    FuncEnum.FLOOR: torch.floor,
    FuncEnum.RECIPROCAL: torch.reciprocal,
}


@@ -437,6 +438,38 @@ def _test_fast_gelu(
        module.run_with_tensors([x1_pt], [x2])
        torch.testing.assert_close(x2, x2_pt, atol=1e-2, rtol=1e-2)

    def _test_reciprocal(
        self,
        input_size,
        test_name="reciprocal",
        copy_op=False,
        dtype="float16",
    ):
        assert len(input_size) == 2
        X1 = Tensor(
            shape=[IntImm(input_size[0]), IntImm(input_size[1])],
            dtype=dtype,
            name="input0",
            is_input=True,
        )
        X2_op = ops.elementwise(FuncEnum.RECIPROCAL)

        if copy_op:
            X2_op = ops.elementwise(**X2_op._get_op_attributes())
        X2 = X2_op(X1)
        X2._attrs["is_output"] = True
        X2._attrs["name"] = "output0"

        target = detect_target()
        module = compile_model(X2, target, "./tmp", f"{test_name}_{dtype}")

        x1_pt = get_random_torch_tensor(input_size, dtype)
        x2_pt = torch.reciprocal(x1_pt)

        x2 = torch.empty_like(x2_pt)
        module.run_with_tensors([x1_pt], [x2])
        torch.testing.assert_close(x2, x2_pt, atol=1e-2, rtol=1e-2)

    @parameterized.expand(
        **filter_test_cases_by_params(
            {
@@ -921,6 +954,27 @@ def test_fast_gelu(self, dtype):
            [256, 128], test_name="fast_gelu_4_copy_op", copy_op=True, dtype=dtype
        )

    @parameterized.expand(
        **filter_test_cases_by_params(
            {
                TestEnv.CUDA_LESS_THAN_SM80: [("float16"), ("float32")],
                TestEnv.CUDA_SM80: [("bfloat16")],
                TestEnv.ROCM: [("float16")],
            }
        )
    )
    def test_reciprocal(self, dtype):
        self._test_simple_function(
            [32, 128], FuncEnum.RECIPROCAL, test_name="reciprocal", dtype=dtype
        )
        self._test_simple_function(
            [32, 128],
            FuncEnum.RECIPROCAL,
            test_name="reciprocal_copy_op",
            copy_op=True,
            dtype=dtype,
        )


if __name__ == "__main__":
    unittest.main()
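The new cases run alongside the rest of the activation suite; they can be filtered with the usual pytest selector, for example `pytest tests/unittest/ops/test_activation.py -k reciprocal` (assuming pytest is how these unittest-style tests are normally driven).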