Skip to content
This repository has been archived by the owner on Oct 25, 2023. It is now read-only.

Commit

Permalink
update dropout for inference on topi side
Browse files Browse the repository at this point in the history
  • Loading branch information
Valery Chernov committed Apr 5, 2023
1 parent a6e94aa commit 0968c00
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 3 deletions.
13 changes: 12 additions & 1 deletion python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2044,7 +2044,7 @@ def _construct_nodes(self, graph: onnx.onnx_ml_pb2.GraphProto):
else:
inputs.append(None)
i_name = self._parse_value_proto(node)
outputs = node.output
outputs = self._fix_outputs(op_name, node.output)
attr["tvm_custom"] = {}
attr["tvm_custom"]["name"] = i_name
attr["tvm_custom"]["num_outputs"] = len(outputs)
Expand Down Expand Up @@ -2162,6 +2162,17 @@ def _convert_operator(
raise NotImplementedError("Operator {} not implemented.".format(op_name))
return sym

def _fix_outputs(self, op_name, outputs):
"""A hack to handle dropout or similar operator that have more than one out
in ONNX.
"""
if op_name == "Dropout":
if len(outputs) == 1:
return outputs
# TODO(vvchernov): support dropout mask?
outputs = outputs[:-1]
return outputs


def from_onnx(
model: onnx.onnx_ml_pb2.GraphProto,
Expand Down
3 changes: 1 addition & 2 deletions python/tvm/relax/transform/legalize_ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,7 @@ def _nn_group_norm(bb: BlockBuilder, call: Call) -> Expr:

@register_legalize("relax.nn.dropout")
def _nn_dropout(bb: BlockBuilder, call: Call) -> Expr:
logging.info("Dropout is handled by frontend translator at this moment and is not legalized.")
return call
return bb.call_te(topi.nn.dropout, call.args[0])


def _te_attention(q: te.Tensor, k: te.Tensor, v: te.Tensor, bias: te.Tensor) -> te.Tensor:
Expand Down
25 changes: 25 additions & 0 deletions python/tvm/topi/nn/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,28 @@ def _compute_channelwise(*indices):
return tvm.tir.Select(xval > 0, xval, xval * slope(indices[axis]))

return te.compute(x.shape, _compute_channelwise)


@tvm.te.tag_scope(tag=tag.ELEMWISE)
def dropout(x):
"""Take dropout of input x for inference.
Parameters
----------
x : tvm.te.Tensor
Input argument.
Returns
-------
y : tvm.te.Tensor
The result.
"""
def _compute_channelwise(*indices):
xval = x(*indices)
return tvm.tir.Cast("bool", xval)

def _compute_identity(*indices):
xval = x(*indices)
return xval

return te.compute(x.shape, _compute_identity), te.compute(x.shape, _compute_channelwise)

0 comments on commit 0968c00

Please sign in to comment.