Skip to content

Commit

Permalink
support tanh
Browse files Browse the repository at this point in the history
Signed-off-by: Thomas Ziereis <[email protected]>
  • Loading branch information
ziereis committed Jan 1, 2025
1 parent 20507b7 commit c95d127
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 4 deletions.
3 changes: 3 additions & 0 deletions iree/turbine/kernel/ops/wave_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ def reciprocal(src: "Register") -> "Register":
def abs(src: "Register") -> "Register":
...

def tanh(src: "Register") -> "Register":
...

def maximum(lhs: "Register", rhs: "Register") -> "Register":
...
Expand Down Expand Up @@ -661,6 +663,7 @@ def infer_type(self):
@define_interface_op("exp2")
@define_interface_op("reciprocal")
@define_interface_op("abs")
@define_interface_op("tanh")
@define_py_op(operator.neg)
@dataclass
class UnaryPyOp(CustomOp, ABC):
Expand Down
10 changes: 10 additions & 0 deletions iree/turbine/kernel/wave/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
log2,
reciprocal,
abs,
tanh,
maximum,
get_custom,
get_result,
Expand Down Expand Up @@ -1160,6 +1161,15 @@ def handle_abs(source: Value) -> OpResult:
raise ValidationError(f"Found unhandled operand type for abs: {element_type}")
return abs

@handle_unary_op(tanh)
def handle_abs(source: Value) -> OpResult:
element_type = get_type_or_element_type(source.type)
if _is_float_type(element_type):
result = math_d.tanh(source)
else:
raise ValidationError(f"Found unhandled operand type for tanh: {element_type}")
return result


###############################################################################
# Control Flow ops
Expand Down
12 changes: 8 additions & 4 deletions lit_tests/kernel/wave/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,7 @@ def test(
res = tkw.reciprocal(res)
res = tkw.abs(res)
res_b = tkw.abs(b_reg)
res = tkw.tanh(res)
tkw.write(res, a, elements_per_thread=4)
tkw.write(res_b, b, elements_per_thread=4)

Expand All @@ -748,12 +749,15 @@ def test(
# CHECK: %[[EXP2:.+]] = math.exp2 %[[NEG]]

# Testing reciprocal
# %[[ONES:.+]] = arith.constant dense<1.000000e+00> : vector<4xf16>
# %[[RECIPROCAL:.+]] = arith.divf %[[ONES]], %[[EXP2]] : vector<4xf16>
# CHECK: %[[ONES:.+]] = arith.constant dense<1.000000e+00> : vector<4xf16>
# CHECK: %[[RECIPROCAL:.+]] = arith.divf %[[ONES]], %[[EXP2]] : vector<4xf16>

# Testing abs
# %[[ABSF:.+]] = math.absf %[[RECIPROCAL]]
# %[[ABSI:.+]] = math.absi
# CHECK: %[[ABSF:.+]] = math.absf %[[RECIPROCAL]]
# CHECK: %[[ABSI:.+]] = math.absi

# Tests tanh
# CHECK: %[[TANH:.+]] = math.tanh %[[ABSF]]


@run_test
Expand Down

0 comments on commit c95d127

Please sign in to comment.