diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 86c77538e8fd..ba121b7ec4fa 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -442,6 +442,11 @@ class Cast(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr, params): to_type = get_type(attr["to"]) + if isinstance(inputs[0], relax.ShapeExpr): + shape = inputs[0] + if all([isinstance(x, tir.IntImm) for x in shape]): + shape = [int(x) for x in shape] + return relax.const(shape, to_type) if isinstance(inputs[0], relax.Constant): output = inputs[0].data.numpy().astype(to_type) return relax.const(output, to_type) @@ -2210,6 +2215,7 @@ def _construct_nodes(self, graph: onnx.onnx_ml_pb2.GraphProto): "Concat", "Equal", "Where", + "Cast", ] for i, inp in enumerate(inputs): if (