From 26ac6f3645342e13ce6973c324afaa63e60c5d02 Mon Sep 17 00:00:00 2001 From: Patrik Persson Date: Fri, 1 Nov 2024 15:38:07 +0100 Subject: [PATCH 1/3] updated expand to support dynamic relax.ShapeExpr updated slice to convert PrimExpr to PrimValue before sending values to relax.op.strided_slice --- .../tvm/relax/frontend/onnx/onnx_frontend.py | 43 ++++++++++++++----- 1 file changed, 33 insertions(+), 10 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 611f4348d55e..4d905da89225 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -909,7 +909,7 @@ def _impl_v14(cls, bb, inputs, attr, params): if len(inputs) > 1: k = get_constant(inputs[1], params) if isinstance(k, relax.Constant): - k = int(k.data.numpy()[0]) + k = int(k.data.numpy().item()) else: raise ValueError("Currently only support constant k for Trilu op.") else: @@ -1588,6 +1588,16 @@ def _impl_v13(cls, bb, inputs, attr, params): return bb.emit_te(topi.split, inputs[0], indices, axis=attr.get("axis", 0)) +def get_prim_value_list(values): + new_values = [] + for v in list(values): + if isinstance(v, relax.expr.PrimExpr): + new_values.append(relax.PrimValue(v)) + else: + new_values.append(v) + return new_values + + class Slice(OnnxOpConverter): """Converts an onnx Splice node into an equivalent Relax expression.""" @@ -1641,7 +1651,12 @@ def _impl_v13(cls, bb, inputs, attr, params): assume_inbound = not all( [isinstance(param, (tir.IntImm, int)) for param in [*starts, *ends, *steps]] ) - # return relax.op.strided_slice(data, axes, starts, ends, steps) + + # Converting PrimExpr to PrimValue since relax.op.strided_slice does not accept PrimExpr + starts = get_prim_value_list(starts) + ends = get_prim_value_list(ends) + steps = get_prim_value_list(steps) + return relax.op.strided_slice( data, axes, starts, ends, steps, assume_inbound=assume_inbound ) @@ -1730,9 +1745,21 @@ class Expand(OnnxOpConverter): def _impl_v13(cls, bb, inputs, attr, params): data = inputs[0] shape = inputs[1] - if isinstance(shape, relax.ShapeExpr): - return relax.op.broadcast_to(data, shape) + data_shape = [dim for dim in data.struct_info.shape] + target_shape = [dim for dim in shape.values] + data_shape = [1] * (len(target_shape) - len(data_shape)) + data_shape + assert len(data_shape) == len(target_shape) + # Fix small target shapes + for i, s in enumerate(target_shape): + if isinstance(s, tvm.tir.IntImm) and ( + (isinstance(data_shape[i], tvm.tir.IntImm) and s < data_shape[i]) + or s.value == -1 + ): + target_shape[i] = data_shape[i] + if target_shape == data_shape: + return data + return relax.op.broadcast_to(data, relax.ShapeExpr(target_shape)) # If possible, directly expand to constant shape. if isinstance(shape, relax.Constant): @@ -2688,15 +2715,11 @@ def _impl_v11(cls, bb, inputs, attr, params): mode = attr.get("mode", b"DCR").decode("utf-8") b, c, h, w = inputs[0].struct_info.shape if mode == "DCR": - x = relax.op.reshape( - inputs[0], (b, block_size, block_size, c // (block_size**2), h, w) - ) + x = relax.op.reshape(inputs[0], (b, block_size, block_size, c // (block_size**2), h, w)) x = relax.op.permute_dims(x, [0, 3, 4, 1, 5, 2]) return relax.op.reshape(x, (b, c // (block_size**2), h * block_size, w * block_size)) elif mode == "CRD": - x = relax.op.reshape( - inputs[0], (b, c // (block_size**2), block_size, block_size, h, w) - ) + x = relax.op.reshape(inputs[0], (b, c // (block_size**2), block_size, block_size, h, w)) x = relax.op.permute_dims(x, [0, 1, 4, 2, 5, 3]) return relax.op.reshape(x, (b, c // (block_size**2), h * block_size, w * block_size)) else: From 2847ef3ac4c2e304f8aae478397485b74176dffa Mon Sep 17 00:00:00 2001 From: Patrik Persson Date: Fri, 1 Nov 2024 16:00:35 +0100 Subject: [PATCH 2/3] added test for dynamic shape expression in test_expand --- tests/python/relax/test_frontend_onnx.py | 52 +++++++++++++++++------- 1 file changed, 37 insertions(+), 15 deletions(-) diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 9faa441138fc..c130bf43730b 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -1507,10 +1507,6 @@ def test_topk(axis: int, largest: int): @pytest.mark.parametrize("dynamic", [False, True]) def test_expand(dynamic): - if dynamic: - # TODO: Support dynamic shape for Expand - pytest.skip("Dynamic expand is not supported yet") - def _test_expand(name, data, shape, ref_data): shape_array = np.array(shape) shape_node = onnx.helper.make_node( @@ -1541,17 +1537,43 @@ def _test_expand(name, data, shape, ref_data): model = helper.make_model(graph, producer_name=name) check_correctness(model, inputs={"in": data}) - in_shape = (3, 1) - shape = (3, 4) - data = np.random.uniform(size=in_shape).astype(np.float32) - ref_data = np.tile(data, 4) - _test_expand("expand_with_dim_unchanged_test", data, shape, ref_data) - - in_shape = (3, 1) - shape = (1, 3, 4) - data = np.random.uniform(size=in_shape).astype(np.float32) - ref_data = np.tile(data, (1, 1, 4)) - _test_expand("expand_with_diff_dim", data, shape, ref_data) + def _test_expand_dynamic_shapeexpr(name, data, shape_data, shape, ref_data): + shape_node = onnx.helper.make_node("Shape", inputs=["in_2"], outputs=["shape"]) + expand_node = helper.make_node("Expand", ["in", "shape"], ["out"]) + in_shape = list(data.shape) + out_shape = list(ref_data.shape) + graph = helper.make_graph( + [shape_node, expand_node], + "expand_test", + inputs=[ + helper.make_tensor_value_info("in", TensorProto.FLOAT, in_shape), + helper.make_tensor_value_info("in_2", TensorProto.FLOAT, shape), + ], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, out_shape)], + ) + + model = helper.make_model(graph, producer_name=name) + check_correctness(model, inputs={"in": data, "in_2": shape_data}) + + if not dynamic: + in_shape = (3, 1) + shape = (3, 4) + data = np.random.uniform(size=in_shape).astype(np.float32) + ref_data = np.tile(data, 4) + _test_expand("expand_with_dim_unchanged_test", data, shape, ref_data) + + in_shape = (3, 1) + shape = (1, 3, 4) + data = np.random.uniform(size=in_shape).astype(np.float32) + ref_data = np.tile(data, (1, 1, 4)) + _test_expand("expand_with_diff_dim", data, shape, ref_data) + else: + in_shape = (1, 32, 32) + shape = ("batch", 32, 32) + data = np.random.uniform(size=in_shape).astype(np.float32) + shape_data = np.random.uniform(size=(64, 32, 32)).astype(np.float32) + ref_data = np.tile(data, (64, 1, 1)) + _test_expand_dynamic_shapeexpr("expand_with_dynamic_dim", data, shape_data, shape, ref_data) # TODO(jwfromm) Current approach to dynamic expand is technically not well formed. Reenable once fixed. From 598d60d720776028916f0ffc4ec6825e2eee9878 Mon Sep 17 00:00:00 2001 From: Patrik Persson Date: Fri, 1 Nov 2024 16:28:16 +0100 Subject: [PATCH 3/3] updated formatting removed unnecessary list comprehension --- python/tvm/relax/frontend/onnx/onnx_frontend.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 4d905da89225..cbd633324a75 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1746,11 +1746,11 @@ def _impl_v13(cls, bb, inputs, attr, params): data = inputs[0] shape = inputs[1] if isinstance(shape, relax.ShapeExpr): - data_shape = [dim for dim in data.struct_info.shape] - target_shape = [dim for dim in shape.values] + data_shape = list(data.struct_info.shape) + target_shape = list(shape.values) data_shape = [1] * (len(target_shape) - len(data_shape)) + data_shape assert len(data_shape) == len(target_shape) - # Fix small target shapes + # Fix small target shapes or target shapes assigned to -1 for i, s in enumerate(target_shape): if isinstance(s, tvm.tir.IntImm) and ( (isinstance(data_shape[i], tvm.tir.IntImm) and s < data_shape[i]) @@ -2715,11 +2715,15 @@ def _impl_v11(cls, bb, inputs, attr, params): mode = attr.get("mode", b"DCR").decode("utf-8") b, c, h, w = inputs[0].struct_info.shape if mode == "DCR": - x = relax.op.reshape(inputs[0], (b, block_size, block_size, c // (block_size**2), h, w)) + x = relax.op.reshape( + inputs[0], (b, block_size, block_size, c // (block_size**2), h, w) + ) x = relax.op.permute_dims(x, [0, 3, 4, 1, 5, 2]) return relax.op.reshape(x, (b, c // (block_size**2), h * block_size, w * block_size)) elif mode == "CRD": - x = relax.op.reshape(inputs[0], (b, c // (block_size**2), block_size, block_size, h, w)) + x = relax.op.reshape( + inputs[0], (b, c // (block_size**2), block_size, block_size, h, w) + ) x = relax.op.permute_dims(x, [0, 1, 4, 2, 5, 3]) return relax.op.reshape(x, (b, c // (block_size**2), h * block_size, w * block_size)) else: