Skip to content

Commit

Permalink
[Relax][Frontend][Onnx] Cast Op special handling for ShapeExpr input (#…
Browse files Browse the repository at this point in the history
…17061)

Co-authored-by: tsu-bin <[email protected]>
  • Loading branch information
tsu-bin and tsu-bin authored Jun 4, 2024
1 parent 1c05902 commit f5d3fc2
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,11 @@ class Cast(OnnxOpConverter):
@classmethod
def _impl_v13(cls, bb, inputs, attr, params):
to_type = get_type(attr["to"])
if isinstance(inputs[0], relax.ShapeExpr):
shape = inputs[0]
if all([isinstance(x, tir.IntImm) for x in shape]):
shape = [int(x) for x in shape]
return relax.const(shape, to_type)
if isinstance(inputs[0], relax.Constant):
output = inputs[0].data.numpy().astype(to_type)
return relax.const(output, to_type)
Expand Down Expand Up @@ -2210,6 +2215,7 @@ def _construct_nodes(self, graph: onnx.onnx_ml_pb2.GraphProto):
"Concat",
"Equal",
"Where",
"Cast",
]
for i, inp in enumerate(inputs):
if (
Expand Down

0 comments on commit f5d3fc2

Please sign in to comment.