Skip to content

Commit

Permalink
[Unity][Frontend] Add Sqrt Op (apache#17228)
Browse files Browse the repository at this point in the history
* Update op.py

* Update test_frontend_nn_op.py

* Update op.py with annotation

* Update core.py(typo in annotation)
  • Loading branch information
tlopex authored Aug 5, 2024
1 parent bd7f1f8 commit 5a67a00
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 3 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relax/frontend/nn/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""The core infra for nn.Module, which includes the following pieces:
- Tensor, a wrapper on top of relax.Expr whose struct_info is a TensorStructInfo,
providing more convenient access shape and dtype information.
Tensor is always symbolc and not bound to any concrete values.
Tensor is always symbolic and not bound to any concrete values.
- Parameter, a special tensor which could be bound or not bound to concrete values.
- Module, a container of nn.Parameters and sub nn.Modules.
- Effect, a non-user-facing class that encloses potential side effects, for example, IO,
Expand Down
22 changes: 22 additions & 0 deletions python/tvm/relax/frontend/nn/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1486,6 +1486,28 @@ def square(x: Tensor, name: str = "square") -> Tensor:
return wrap_nested(_op.square(x._expr), name)


def sqrt(x: Tensor, name: str = "sqrt") -> Tensor:
"""Computes the element-wise sqrt of the input tensor.
Parameters
----------
x : Tensor
The input tensor.
name : str
Name hint.
Returns
-------
result : Tensor
The computed result.
Note
----
The input tensor is required to have float dtype
"""
return wrap_nested(_op.sqrt(x._expr), name)


def get_timestep_embedding(
x: Tensor,
embedding_dim: int,
Expand Down
6 changes: 4 additions & 2 deletions tests/python/relax/test_frontend_nn_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,17 @@ def test_unary():
class Model(Module):
def test(self, x: Tensor):
z0 = op.square(x)
return (x,)
z1 = op.sqrt(x)
return (z0, z1)

# fmt: off
@R.function
def test(x: R.Tensor((1, 10), dtype="float32"), _io: R.Object):
R.func_attr({"num_input": 2})
with R.dataflow():
square: R.Tensor((1, 10), dtype="float32") = R.square(x)
gv1 = (x,), (_io,)
sqrt: R.Tensor((1, 10), dtype="float32") = R.sqrt(x)
gv1 = (square, sqrt), (_io,)
R.output(gv1)
return gv1
# fmt: on
Expand Down

0 comments on commit 5a67a00

Please sign in to comment.