From d7ae4c74fc0363f36fc5c0fdc2d40c2e64d5ae9c Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Fri, 14 Jun 2024 03:24:10 +0900 Subject: [PATCH] [Relax] [PyTorch] Add support for torch.nn.Hardsigmoid (#17085) add hardsigmoid support to fx_frontend --- .../tvm/relax/frontend/torch/fx_translator.py | 10 ++++++ tests/python/relax/test_frontend_from_fx.py | 35 +++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index a5efcce27859..5ed0f18deb9e 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -243,6 +243,14 @@ def _gelu(self, node: fx.node.Node) -> relax.Expr: else: raise KeyError("Unregonized approximate algorithm for gelu: {}.".format(approximate)) + def _hardsigmoid(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + dtype = x.struct_info.dtype + x0 = relax.op.add(x, relax.const(3, dtype)) + x1 = relax.op.clip(x0, 0, 6) + return self.block_builder.emit(relax.op.divide(x1, relax.const(6, dtype))) + def _hardswish(self, node: fx.node.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] @@ -1367,6 +1375,7 @@ def create_convert_map(self): nn.Sigmoid: self._sigmoid, nn.Tanh: lambda node: self.block_builder.emit(relax.op.tanh(self.env[node.args[0]])), nn.SiLU: lambda node: self.block_builder.emit(relax.op.nn.silu(self.env[node.args[0]])), + nn.Hardsigmoid: self._hardsigmoid, nn.Hardswish: self._hardswish, nn.Flatten: self._flatten, nn.BatchNorm2d: self._batch_norm_2d, @@ -1447,6 +1456,7 @@ def create_convert_map(self): "leaky_relu": self._leakyrelu, "gelu": self._gelu, "silu": lambda node: self.block_builder.emit(relax.op.nn.silu(self.env[node.args[0]])), + "hardsigmoid": self._hardsigmoid, "hardswish": self._hardswish, "interpolate": self._interpolate, "size": self._size, diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 49131b5ff891..dd2719f8ce91 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -1416,6 +1416,41 @@ def main( verify_model(SiLU2(), input_info, {}, expected1) +def test_hardsigmoid(): + input_info = [([1, 3, 10, 10], "float32")] + + class Hardsigmoid(torch.nn.Module): + def __init__(self): + super().__init__() + self.hs = torch.nn.Hardsigmoid() + + def forward(self, input): + return self.hs(input) + + class Hardsigmoid2(torch.nn.Module): + def forward(self, input): + return torch.nn.functional.hardsigmoid(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0, R.const(3, "float32")) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(lv, 0, 6) + lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide( + lv1, R.const(6, "float32") + ) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv2 + R.output(gv) + return gv + + verify_model(Hardsigmoid(), input_info, {}, expected1) + verify_model(Hardsigmoid2(), input_info, {}, expected1) + + def test_hardswish(): input_info = [([1, 3, 10, 10], "float32")]