This repository has been archived by the owner on Oct 25, 2023. It is now read-only.

last additions
Valery Chernov committed Mar 29, 2023
1 parent 651dac4 commit 2dbb2bd
Showing 4 changed files with 24 additions and 10 deletions.
python/tvm/relax/frontend/onnx/onnx_frontend.py (13 additions, 4 deletions)
@@ -1336,15 +1336,24 @@ def _default_impl(cls, bb, inputs, attr):
 
     @classmethod
     def _impl_v1(cls, bb, inputs, attr):
-        assert (
-            inputs[0].struct_info.dtype in ["float16", "float", "float32", "float64", "double"]
-        ), "input type is unsupported"
+        assert inputs[0].struct_info.dtype in [
+            "float16",
+            "float",
+            "float32",
+            "float64",
+            "double",
+        ], "input type is unsupported"
         return cls._default_impl(bb, inputs, attr)
 
     @classmethod
     def _impl_v13(cls, bb, inputs, attr):
         assert inputs[0].struct_info.dtype in [
-            "float16", "bfloat16", "float", "float32", "float64", "double"
+            "float16",
+            "bfloat16",
+            "float",
+            "float32",
+            "float64",
+            "double",
         ], "input type is unsupported"
         return cls._default_impl(bb, inputs, attr)
 
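Note: bfloat16 appears only in the opset-13 allow-list above, consistent with ONNX adding bfloat16 support around opset 13. A minimal standalone sketch of the dtype gate these asserts implement (check_input_dtype and the two set names are illustrative, not TVM API):

_V1_DTYPES = {"float16", "float", "float32", "float64", "double"}
_V13_DTYPES = _V1_DTYPES | {"bfloat16"}


def check_input_dtype(dtype: str, opset: int) -> None:
    # Mirrors the assert in the diff: reject dtypes outside the allow-list.
    allowed = _V13_DTYPES if opset >= 13 else _V1_DTYPES
    assert dtype in allowed, "input type is unsupported"


check_input_dtype("float32", opset=1)    # passes
check_input_dtype("bfloat16", opset=13)  # passes; would fail with opset=1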
python/tvm/relax/op/nn/nn.py (6 additions, 6 deletions)
@@ -645,12 +645,12 @@ def batch_norm(
 
 
 def lrn(
-    data : Expr,
-    size : int,
-    axis : int = 1,
-    bias : float = 2,
-    alpha : float = 1e-4,
-    beta : float = 0.75,
+    data: Expr,
+    size: int,
+    axis: int = 1,
+    bias: float = 2,
+    alpha: float = 1e-4,
+    beta: float = 0.75,
 ) -> Expr:
     """This operator takes data as input and does local response normalization.
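For context, the lrn signature above keeps the usual cross-channel LRN semantics: Y = X / (bias + (alpha / size) * sum_window(X ** 2)) ** beta, with a window of length size taken along axis. A hedged NumPy reference under the common odd-window convention (lrn_reference is an illustrative name, not the TVM kernel):

import numpy as np


def lrn_reference(data, size, axis=1, bias=2.0, alpha=1e-4, beta=0.75):
    # Sum of squares over a size-wide window centered on each position of
    # `axis`; an odd window is assumed, as topi's LRN conventionally requires.
    assert size % 2 == 1, "odd window assumed"
    half = (size - 1) // 2
    sq = np.square(data)
    acc = np.zeros_like(data)
    n = data.shape[axis]
    for offset in range(-half, half + 1):
        lo, hi = max(0, -offset), min(n, n - offset)
        src = [slice(None)] * data.ndim
        dst = [slice(None)] * data.ndim
        src[axis] = slice(lo + offset, hi + offset)
        dst[axis] = slice(lo, hi)
        acc[tuple(dst)] += sq[tuple(src)]
    return data / np.power(bias + (alpha / size) * acc, beta)


x = np.random.rand(1, 8, 4, 4).astype("float32")
y = lrn_reference(x, size=5)  # same shape as x, normalized across channels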
src/relax/op/nn/nn.h (3 additions, 0 deletions)
@@ -64,6 +64,9 @@ Expr log_softmax(Expr data, int axis);
 Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var, //
                 int axis, double epsilon, bool center, bool scale);
 
+/*! \brief Compute local response normalization (LRN). */
+Expr lrn(Expr data, int size, int axis, float alpha, float beta, float bias);
+
 /*! \brief Compute layer normalization. */
 Expr layer_norm(Expr data, Expr gamma, Expr beta, Array<Integer> axes, double epsilon, bool center,
                 bool scale);
src/topi/schedule.cc (2 additions, 0 deletions)
@@ -233,6 +233,8 @@ TVM_REGISTER_GENERIC_FUNC(schedule_batch_matmul)
 TVM_REGISTER_GENERIC_FUNC(schedule_batch_norm)
     .set_default(WrapSchedule(topi::generic::default_schedule));
 
+TVM_REGISTER_GENERIC_FUNC(schedule_lrn).set_default(WrapSchedule(topi::generic::default_schedule));
+
 TVM_REGISTER_GENERIC_FUNC(schedule_pool)
     .set_default(WrapSchedule(topi::generic::default_schedule))
     .register_func({"cpu"}, WrapSchedule(topi::x86::default_schedule))
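The registration above gives schedule_lrn a target-agnostic default via set_default, the same pattern schedule_pool follows with its cpu and gpu overrides. A plain-Python analogy of that dispatch (illustrative only, not the TVM API):

class GenericFuncSketch:
    # A default implementation plus optional per-target overrides; dispatch
    # falls back to the default when no override matches, loosely like
    # set_default and register_func in schedule.cc.
    def __init__(self, default):
        self.default = default
        self.impls = {}

    def register_func(self, keys, fn):
        for key in keys:
            self.impls[key] = fn

    def __call__(self, target, *args):
        return self.impls.get(target, self.default)(*args)


schedule_lrn = GenericFuncSketch(lambda outs: f"generic schedule for {outs}")
print(schedule_lrn("cpu", "lrn"))  # no cpu override registered, so default runs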
