Skip to content

Commit

Permalink
[Relax] [ONNX] Add support for HardSigmoid (#17089)
Browse files Browse the repository at this point in the history
add hardsigmoid support to onnx frontend
  • Loading branch information
mshr-h authored Jun 23, 2024
1 parent e6bfaf8 commit 4ef9011
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
15 changes: 15 additions & 0 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1918,6 +1918,20 @@ def _impl_v1(cls, bb, inputs, attr, params):
) + relax.op.nn.relu(inputs[0])


class HardSigmoid(OnnxOpConverter):
"""Converts an onnx HardSigmoid node into an equivalent Relax expression."""

@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
x = inputs[0]
dtype = x.struct_info.dtype
alpha = float(attr.get("alpha", 0.2))
alpha = relax.const(alpha, dtype=dtype)
beta = float(attr.get("beta", 0.5))
beta = relax.const(beta, dtype=dtype)
return relax.op.clip(relax.op.add(relax.op.multiply(alpha, x), beta), 0, 1)


class HardSwish(OnnxOpConverter):
"""Converts an onnx HardSwish node into an equivalent Relax expression."""

Expand Down Expand Up @@ -2014,6 +2028,7 @@ def _get_convert_map():
"Reciprocal": Reciprocal,
"OneHot": OneHot,
"Elu": Elu,
"HardSigmoid": HardSigmoid,
"HardSwish": HardSwish,
}

Expand Down
6 changes: 6 additions & 0 deletions tests/python/relax/test_frontend_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,12 @@ def test_elu():
verify_unary("Elu", [32, 32])


def test_hardsigmoid():
verify_unary("HardSigmoid", [32, 32])
verify_unary("HardSigmoid", [32, 32], attrs={"alpha": 0.3, "beta": 0.4})
verify_unary("HardSigmoid", [1, 3, 20, 20], attrs={"alpha": 0.5, "beta": 0.6})


def test_hardswish():
verify_unary("HardSwish", [32, 32])

Expand Down

0 comments on commit 4ef9011

Please sign in to comment.