Skip to content
This repository has been archived by the owner on Oct 25, 2023. It is now read-only.

Commit

Permalink
code clean
Browse files Browse the repository at this point in the history
  • Loading branch information
Valery Chernov committed Mar 30, 2023
1 parent 52864b1 commit 9bf08bd
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 6 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relax/transform/legalize_ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def te_silu(x: te.Tensor):


@register_legalize("relax.nn.leaky_relu")
def _nn_softmax(bb: BlockBuilder, call: Call) -> Expr:
def _nn_leaky_relu(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(topi.nn.leaky_relu, call.args[0], call.attrs.alpha)


Expand Down
8 changes: 3 additions & 5 deletions src/relax/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,11 @@ TVM_REGISTER_GLOBAL("relax.op.nn.leaky_relu").set_body_typed(leaky_relu);

StructInfo InferStructInfoLeakyRelu(const Call& call, const BlockBuilder& ctx) {
TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
if (data_sinfo->IsUnknownNdim()) {
return data_sinfo;
}

if (!data_sinfo->IsUnknownDtype() &&
!(data_sinfo->dtype.is_float() || data_sinfo->dtype.is_bfloat16())) {
ctx->ReportFatal(Diagnostic::Error(call) <<
"LeakyRelu requires the input tensor to have float dtype. However, the given input dtype is "
ctx->ReportFatal(Diagnostic::Error(call) << "LeakyRelu requires the input tensor to have float "
"dtype. However, the given input dtype is "
<< data_sinfo->dtype);
}

Expand Down

0 comments on commit 9bf08bd

Please sign in to comment.