From 5a67a00bcbb53731bbf53db7801fa16c8c9eb9f2 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Mon, 5 Aug 2024 21:17:48 +0800 Subject: [PATCH] [Unity][Frontend] Add Sqrt Op (#17228) * Update op.py * Update test_frontend_nn_op.py * Update op.py with annotation * Update core.py(typo in annotation) --- python/tvm/relax/frontend/nn/core.py | 2 +- python/tvm/relax/frontend/nn/op.py | 22 ++++++++++++++++++++++ tests/python/relax/test_frontend_nn_op.py | 6 ++++-- 3 files changed, 27 insertions(+), 3 deletions(-) diff --git a/python/tvm/relax/frontend/nn/core.py b/python/tvm/relax/frontend/nn/core.py index 3511c38a2b7c..21118b1cb8af 100644 --- a/python/tvm/relax/frontend/nn/core.py +++ b/python/tvm/relax/frontend/nn/core.py @@ -17,7 +17,7 @@ """The core infra for nn.Module, which includes the following pieces: - Tensor, a wrapper on top of relax.Expr whose struct_info is a TensorStructInfo, providing more convenient access shape and dtype information. - Tensor is always symbolc and not bound to any concrete values. + Tensor is always symbolic and not bound to any concrete values. - Parameter, a special tensor which could be bound or not bound to concrete values. - Module, a container of nn.Parameters and sub nn.Modules. - Effect, a non-user-facing class that encloses potential side effects, for example, IO, diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py index e1ba4483c741..17a40a8cce57 100644 --- a/python/tvm/relax/frontend/nn/op.py +++ b/python/tvm/relax/frontend/nn/op.py @@ -1486,6 +1486,28 @@ def square(x: Tensor, name: str = "square") -> Tensor: return wrap_nested(_op.square(x._expr), name) +def sqrt(x: Tensor, name: str = "sqrt") -> Tensor: + """Computes the element-wise sqrt of the input tensor. + + Parameters + ---------- + x : Tensor + The input tensor. + + name : str + Name hint. + + Returns + ------- + result : Tensor + The computed result. + Note + ---- + The input tensor is required to have float dtype + """ + return wrap_nested(_op.sqrt(x._expr), name) + + def get_timestep_embedding( x: Tensor, embedding_dim: int, diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py index a632a867432b..6c3269195498 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -31,7 +31,8 @@ def test_unary(): class Model(Module): def test(self, x: Tensor): z0 = op.square(x) - return (x,) + z1 = op.sqrt(x) + return (z0, z1) # fmt: off @R.function @@ -39,7 +40,8 @@ def test(x: R.Tensor((1, 10), dtype="float32"), _io: R.Object): R.func_attr({"num_input": 2}) with R.dataflow(): square: R.Tensor((1, 10), dtype="float32") = R.square(x) - gv1 = (x,), (_io,) + sqrt: R.Tensor((1, 10), dtype="float32") = R.sqrt(x) + gv1 = (square, sqrt), (_io,) R.output(gv1) return gv1 # fmt: on