Skip to content

Commit

Permalink
Add scales attribute parse in upsample for older opset versions
Browse files Browse the repository at this point in the history
  • Loading branch information
attila-dusnoki-htec committed Oct 16, 2023
1 parent 650ba45 commit c3b28d3
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 13 deletions.
41 changes: 28 additions & 13 deletions src/onnx/parse_resize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,23 +265,38 @@ struct parse_resize : op_parser<parse_resize>
": dynamic input scale is not supported!");

arg_scale.visit([&](const auto& v) { vec_scale.assign(v.begin(), v.end()); });
if(in_lens.size() != vec_scale.size())
{
MIGRAPHX_THROW("PARSE_" + opd.op_name +
": ranks of input and scale are different!");
}

std::transform(in_lens.begin(),
in_lens.end(),
vec_scale.begin(),
out_lens.begin(),
[&](auto idx, auto scale) {
return static_cast<std::size_t>(idx * scale);
});
}
}
}

// scales still missing, must be an attribute
if(vec_scale.empty())
{
if(contains(info.attributes, "scales"))
{
copy(info.attributes["scales"].floats(), std::back_inserter(vec_scale));
}
else
{
MIGRAPHX_THROW("PARSE_" + opd.op_name + ": scale not provided!");
}
}
if(in_lens.size() != vec_scale.size())
{
MIGRAPHX_THROW("PARSE_" + opd.op_name + ": ranks of input and scale are different!");
}

// if the output was not calculated yet, we update it based on the scales
if(all_of(out_lens.cbegin(), out_lens.cend(), [](auto o) { return o == 0; }))
{
std::transform(
in_lens.begin(),
in_lens.end(),
vec_scale.begin(),
out_lens.begin(),
[&](auto idx, auto scale) { return static_cast<std::size_t>(idx * scale); });
}

shape out_s{in_s.type(), out_lens};
std::size_t out_elements = out_s.elements();
auto idx_op = get_original_idx_op(coord_trans_mode);
Expand Down
16 changes: 16 additions & 0 deletions test/onnx/gen_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -8158,6 +8158,22 @@ def upsample_test():
return ([node], [X], [Y], [scale_tensor])


@onnx_test()
def upsample_ver7_test():
X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 1, 2, 2])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1, 1, 4, 6])

node = onnx.helper.make_node(
'Upsample',
inputs=['X'],
outputs=['Y'],
mode='nearest',
scales=[1.0, 1.0, 2.0, 3.0]
)

return ([node], [X], [Y])


@onnx_test()
def variable_batch_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT,
Expand Down
21 changes: 21 additions & 0 deletions test/onnx/onnx_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7640,6 +7640,27 @@ TEST_CASE(upsample_test)
EXPECT(p == prog);
}

TEST_CASE(upsample_ver7_test)
{
migraphx::program p;
auto* mm = p.get_main_module();

migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}};
auto ix = mm->add_parameter("X", sx);

migraphx::shape si{migraphx::shape::int32_type, {1, 1, 4, 6}};
std::vector<int> ind = {0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3};

auto li = mm->add_literal(migraphx::literal(si, ind));
auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), ix);
auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp, li);
mm->add_return({r});

auto prog = migraphx::parse_onnx("upsample_ver7_test.onnx");

EXPECT(p == prog);
}

TEST_CASE(unknown_test_throw_print_error)
{
migraphx::onnx_options options;
Expand Down
Binary file added test/onnx/upsample_ver7_test.onnx
Binary file not shown.

0 comments on commit c3b28d3

Please sign in to comment.