Skip to content

Commit

Permalink
Fix expand parsing (#3027)
Browse files Browse the repository at this point in the history
  • Loading branch information
CharlieL7 authored May 6, 2024
1 parent 2bdd02d commit 5306a70
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 23 deletions.
8 changes: 7 additions & 1 deletion src/onnx/parse_expand.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ struct parse_expand : op_parser<parse_expand>
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
auto in_lens = args[0]->get_shape().lens();
migraphx::argument arg_s = args[1]->eval();
if(arg_s.empty())
{
Expand All @@ -50,6 +49,13 @@ struct parse_expand : op_parser<parse_expand>
}
else
{
const shape& shape_0 = args[0]->get_shape();
if(shape_0.dynamic())
{
MIGRAPHX_THROW(
"PARSE_EXPAND: dynamic input tensor with fixed dims input not supported");
}
const auto& in_lens = shape_0.lens();
std::vector<std::size_t> dims;
arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
auto out_lens = compute_broadcasted_lens(in_lens, dims);
Expand Down
Binary file added test/onnx/expand_dyn_input_dyn_output_test.onnx
Binary file not shown.
Binary file added test/onnx/expand_dyn_input_static_dims_throw.onnx
Binary file not shown.
19 changes: 0 additions & 19 deletions test/onnx/expand_dyn_test.onnx

This file was deleted.

19 changes: 19 additions & 0 deletions test/onnx/expand_static_input_dyn_output_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
 #expand_static_input_dyn_output_test:�

x
dimsy"Expand#expand_static_input_dyn_output_testZ
x



Z
dims


b
y




B
37 changes: 36 additions & 1 deletion test/onnx/gen_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2204,7 +2204,7 @@ def expand_test():


@onnx_test()
def expand_dyn_test():
def expand_static_input_dyn_output_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 1, 1])
dims_in = helper.make_tensor_value_info('dims', TensorProto.INT64, [4])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3, 4, 5])
Expand All @@ -2214,6 +2214,41 @@ def expand_dyn_test():
return ([node], [x, dims_in], [y])


@onnx_test()
def expand_dyn_input_dyn_output_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [None, 1, 1])
dims_in = helper.make_tensor_value_info('dims', TensorProto.INT64, [4])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3, 4, 5])

node = onnx.helper.make_node('Expand', inputs=['x', 'dims'], outputs=['y'])

return ([node], [x, dims_in], [y])


@onnx_test()
def expand_dyn_input_static_dims_throw():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [None, 1, 1])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 4])

shape_val = np.array([3, 4, 4]).astype(np.int64)
shape_ts = helper.make_tensor(name='shape_tensor',
data_type=TensorProto.INT32,
dims=shape_val.shape,
vals=shape_val.flatten().astype(int))
shape_const = helper.make_node(
'Constant',
inputs=[],
outputs=['shape'],
value=shape_ts,
)

node = onnx.helper.make_node('Expand',
inputs=['x', 'shape'],
outputs=['y'])

return ([shape_const, node], [x], [y])


@onnx_test(True)
def external_constant_test():
x = np.array([0, 1, 2])
Expand Down
29 changes: 27 additions & 2 deletions test/onnx/parse/expand_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ TEST_CASE(expand_test)
EXPECT(p == prog);
}

TEST_CASE(expand_dyn_test)
TEST_CASE(expand_static_input_dyn_output_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
Expand All @@ -48,6 +48,31 @@ TEST_CASE(expand_dyn_test)
auto dims = mm->add_parameter("dims", ss);
mm->add_instruction(migraphx::make_op("broadcast_with_dims"), param, dims);

auto prog = optimize_onnx("expand_dyn_test.onnx");
auto prog = optimize_onnx("expand_static_input_dyn_output_test.onnx");
EXPECT(p == prog);
}

TEST_CASE(expand_dyn_input_dyn_output_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s(migraphx::shape::float_type, {{3, 8}, {1, 1}, {1, 1}});
auto param = mm->add_parameter("x", s);
migraphx::shape ss(migraphx::shape::int64_type, {4});
auto dims = mm->add_parameter("dims", ss);
auto ret = mm->add_instruction(migraphx::make_op("broadcast_with_dims"), param, dims);
mm->add_return({ret});

migraphx::onnx_options options;
options.default_dyn_dim_value = {3, 8};
auto prog = parse_onnx("expand_dyn_input_dyn_output_test.onnx", options);
EXPECT(p == prog);
}

TEST_CASE(expand_dyn_input_static_dims_throw)
{
migraphx::onnx_options options;
options.default_dyn_dim_value = {3, 8};
EXPECT(test::throws(
[&] { migraphx::parse_onnx("expand_dyn_input_static_dims_throw.onnx", options); }));
}

0 comments on commit 5306a70

Please sign in to comment.