Add support for Shrink ONNX operator
gyulaz-htec committed Oct 16, 2023
1 parent 650ba45 commit 3f3ac05
Showing 12 changed files with 435 additions and 3 deletions.
85 changes: 85 additions & 0 deletions src/onnx/parse_shrink.cpp
@@ -0,0 +1,85 @@
/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

struct parse_shrink : op_parser<parse_shrink>
{
    std::vector<op_desc> operators() const { return {{"Shrink"}}; }

    instruction_ref parse(const op_desc&,
                          const onnx_parser& parser,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        float bias = 0.0;
        if(contains(info.attributes, "bias"))
        {
            bias = parser.parse_value(info.attributes.at("bias")).at<float>();
        }
        float lambd = 0.5;
        if(contains(info.attributes, "lambd"))
        {
            lambd = parser.parse_value(info.attributes.at("lambd")).at<float>();
        }

        auto x = args[0];
        auto x_shape = x->get_shape();
        auto x_type = x_shape.type();
        auto lit_bias = info.add_literal(bias);
        auto lit_neg_lambd = info.add_literal(-lambd);
        auto lit_lambd = info.add_literal(lambd);

        auto x_plus_bias = info.add_common_op("add", x, lit_bias);
        auto x_min_bias = info.add_common_op("sub", x, lit_bias);

        auto cond1 = info.add_common_op("less", x, lit_neg_lambd);
        auto cond2_a = info.add_common_op("not", cond1);
        auto cond2_b = info.add_common_op("greater", x, lit_lambd);
        auto cond2 = info.add_common_op("logical_and", cond2_a, cond2_b);

        auto mul1 = info.add_instruction(make_op("convert", {{"target_type", x_type}}), cond1);
        auto mul2 = info.add_instruction(make_op("convert", {{"target_type", x_type}}), cond2);

        auto first = info.add_common_op("mul", mul1, x_plus_bias);
        auto second = info.add_common_op("mul", mul2, x_min_bias);
        auto ret = info.add_common_op("add", first, second);
        if(ret->get_shape().type() != x_type)
        {
            ret = info.add_instruction(make_op("convert", {{"target_type", x_type}}), ret);
        }
        return ret;
    }
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
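
For reference, the parser above decomposes Shrink, which has no dedicated MIGraphX operator, into elementwise primitives: two boolean masks pick out the x < -lambd and x > lambd regions, the masks are converted to the input type and used as 0/1 multipliers, and the two products are summed so everything outside both regions collapses to zero. A minimal NumPy sketch of the same decomposition (the function name is illustrative, not part of the MIGraphX API; the defaults match the ONNX attribute defaults used above):

import numpy as np

def shrink_decomposed(x, lambd=0.5, bias=0.0):
    # cond1 selects x < -lambd; cond2 selects x > lambd, but only where cond1
    # is false, so the two regions never overlap (this matters when lambd < 0).
    cond1 = x < -lambd
    cond2 = np.logical_and(np.logical_not(cond1), x > lambd)
    # Masks become 0/1 multipliers in the input dtype, mirroring the "convert"
    # followed by "mul" and "add" instructions emitted by parse_shrink.
    return cond1.astype(x.dtype) * (x + bias) + cond2.astype(x.dtype) * (x - bias)

x = np.array([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=np.float32)
print(shrink_decomposed(x, lambd=1.5))            # hard shrink: [-2.  0.  0.  0.  2.]
print(shrink_decomposed(x, lambd=1.5, bias=1.5))  # soft shrink: [-0.5  0.  0.  0.  0.5]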
95 changes: 95 additions & 0 deletions test/onnx/gen_onnx.py
@@ -6718,6 +6718,101 @@ def shape_gather_test():
return ([node_const, node_shape, node_gather], [x], [z])


@onnx_test()
def shrink_hard_test():
    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [5])
    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [5])

    node = onnx.helper.make_node(
        "Shrink",
        inputs=["x"],
        outputs=["y"],
        lambd=1.5,
    )

    return ([node], [x], [y])


@onnx_test()
def shrink_soft_test():
    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [5])
    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [5])

    node = onnx.helper.make_node(
        "Shrink",
        inputs=["x"],
        outputs=["y"],
        lambd=1.5,
        bias=1.5,
    )

    return ([node], [x], [y])


@onnx_test()
def shrink_verify_test():
    x = helper.make_tensor_value_info('x', TensorProto.FLOAT16, [5])
    y = helper.make_tensor_value_info('y', TensorProto.FLOAT16, [5])

    node = onnx.helper.make_node(
        "Shrink",
        inputs=["x"],
        outputs=["y"],
        lambd=-5.0,
        bias=1.0,
    )

    return ([node], [x], [y])


@onnx_test()
def shrink_verify2_test():
    x = helper.make_tensor_value_info('x', TensorProto.FLOAT16, [5])
    y = helper.make_tensor_value_info('y', TensorProto.FLOAT16, [5])

    node = onnx.helper.make_node(
        "Shrink",
        inputs=["x"],
        outputs=["y"],
        lambd=-6.0,
        bias=5.0,
    )

    return ([node], [x], [y])


@onnx_test()
def shrink_int8_test():
    x = helper.make_tensor_value_info('x', TensorProto.INT8, [3, 3])
    y = helper.make_tensor_value_info('y', TensorProto.INT8, [3, 3])

    node = onnx.helper.make_node(
        "Shrink",
        inputs=["x"],
        outputs=["y"],
        lambd=1.5,
        bias=1.5,
    )

    return ([node], [x], [y])


@onnx_test()
def shrink_uint8_test():
    x = helper.make_tensor_value_info('x', TensorProto.UINT8, [3, 3])
    y = helper.make_tensor_value_info('y', TensorProto.UINT8, [3, 3])

    node = onnx.helper.make_node(
        "Shrink",
        inputs=["x"],
        outputs=["y"],
        lambd=5.0,
        bias=-4.5,
    )

    return ([node], [x], [y])
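
Each generator above only returns a ([nodes], [inputs], [outputs]) triple; the @onnx_test() decorator (defined elsewhere in gen_onnx.py, not shown in this diff) is what turns that triple into the saved .onnx models consumed by the C++ tests. As a rough hand-rolled equivalent for shrink_hard_test (the graph name and the default opset produced by make_model are my assumptions, not taken from the decorator), the plain onnx.helper calls would look like:

import onnx
from onnx import helper, TensorProto

x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [5])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [5])
node = helper.make_node("Shrink", inputs=["x"], outputs=["y"], lambd=1.5)

# Wrap the single node into a graph/model and serialize it next to the tests.
graph = helper.make_graph([node], "shrink_hard_test", [x], [y])
model = helper.make_model(graph)
onnx.checker.check_model(model)
onnx.save(model, "shrink_hard_test.onnx")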


@onnx_test()
def sign_test():
x = helper.make_tensor_value_info('x', TensorProto.DOUBLE, [10, 5])
67 changes: 67 additions & 0 deletions test/onnx/onnx_test.cpp
@@ -6610,6 +6610,73 @@ TEST_CASE(shape_gather_test)
EXPECT(p == prog);
}

TEST_CASE(shrink_hard_test)
{
    migraphx::program p;
    float bias = 0.0;
    float lambd = 1.5;
    std::vector<size_t> lens{5};
    auto* mm = p.get_main_module();
    auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, lens});
    auto lit_bias = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {bias}});
    auto lit_neg_lambd = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {-lambd}});
    auto lit_lambd = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {lambd}});

    auto x_plus_bias = add_common_op(*mm, migraphx::make_op("add"), {x, lit_bias});
    auto x_min_bias = add_common_op(*mm, migraphx::make_op("sub"), {x, lit_bias});

    auto cond1 = add_common_op(*mm, migraphx::make_op("less"), {x, lit_neg_lambd});
    auto cond2_a = add_common_op(*mm, migraphx::make_op("not"), {cond1});
    auto cond2_b = add_common_op(*mm, migraphx::make_op("greater"), {x, lit_lambd});
    auto cond2 = add_common_op(*mm, migraphx::make_op("logical_and"), {cond2_a, cond2_b});

    auto mul1 = mm->add_instruction(
        migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), cond1);
    auto mul2 = mm->add_instruction(
        migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), cond2);

    auto first = add_common_op(*mm, migraphx::make_op("mul"), {mul1, x_plus_bias});
    auto second = add_common_op(*mm, migraphx::make_op("mul"), {mul2, x_min_bias});
    add_common_op(*mm, migraphx::make_op("add"), {first, second});
    auto prog = optimize_onnx("shrink_hard_test.onnx");
    EXPECT(p == prog);
}

TEST_CASE(shrink_int8_test)
{
    migraphx::program p;
    float bias = 1.5;
    float lambd = 1.5;
    std::vector<size_t> lens{3, 3};
    auto* mm = p.get_main_module();
    auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::int8_type, lens});
    auto lit_bias = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {bias}});
    auto lit_neg_lambd = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {-lambd}});
    auto lit_lambd = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {lambd}});

    auto x_plus_bias = add_common_op(*mm, migraphx::make_op("add"), {x, lit_bias});
    auto x_min_bias = add_common_op(*mm, migraphx::make_op("sub"), {x, lit_bias});

    auto cond1 = add_common_op(*mm, migraphx::make_op("less"), {x, lit_neg_lambd});
    auto cond2_a = add_common_op(*mm, migraphx::make_op("not"), {cond1});
    auto cond2_b = add_common_op(*mm, migraphx::make_op("greater"), {x, lit_lambd});
    auto cond2 = add_common_op(*mm, migraphx::make_op("logical_and"), {cond2_a, cond2_b});

    auto mul1 = mm->add_instruction(
        migraphx::make_op("convert", {{"target_type", migraphx::shape::int8_type}}), cond1);
    auto mul2 = mm->add_instruction(
        migraphx::make_op("convert", {{"target_type", migraphx::shape::int8_type}}), cond2);

    auto first = add_common_op(*mm, migraphx::make_op("mul"), {mul1, x_plus_bias});
    auto second = add_common_op(*mm, migraphx::make_op("mul"), {mul2, x_min_bias});
    auto ret = add_common_op(*mm, migraphx::make_op("add"), {first, second});
    mm->add_instruction(migraphx::make_op("convert", {{"target_type", migraphx::shape::int8_type}}),
                        ret);
    auto prog = optimize_onnx("shrink_int8_test.onnx");

    EXPECT(p == prog);
}

TEST_CASE(sign_test)
{
migraphx::program p;
Binary file added test/onnx/shrink_hard_test.onnx
Binary file not shown.
Binary file added test/onnx/shrink_int8_test.onnx
Binary file not shown.
Binary file added test/onnx/shrink_soft_test.onnx
Binary file not shown.
Binary file added test/onnx/shrink_uint8_test.onnx
Binary file not shown.
Binary file added test/onnx/shrink_verify2_test.onnx
Binary file not shown.
Binary file added test/onnx/shrink_verify_test.onnx
Binary file not shown.
106 changes: 106 additions & 0 deletions test/onnx/verify_onnx.cpp
@@ -1708,6 +1708,112 @@ TEST_CASE(selu_test)
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}

TEST_CASE(shrink_hard_test)
{
    migraphx::program p = migraphx::parse_onnx("shrink_hard_test.onnx");
    p.compile(migraphx::make_target("ref"));

    migraphx::shape s{migraphx::shape::float_type, {5}};
    std::vector<float> data{-2, -1, 0, 1, 2};
    migraphx::parameter_map pp;
    pp["x"] = migraphx::argument(s, data.data());

    auto result = p.eval(pp).back();
    std::vector<float> result_vector;
    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
    std::vector<float> gold = {-2, 0, 0, 0, 2};
    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}

TEST_CASE(shrink_soft_test)
{
    migraphx::program p = migraphx::parse_onnx("shrink_soft_test.onnx");
    p.compile(migraphx::make_target("ref"));

    migraphx::shape s{migraphx::shape::float_type, {5}};
    std::vector<float> data{-2, -1, 0, 1, 2};
    migraphx::parameter_map pp;
    pp["x"] = migraphx::argument(s, data.data());

    auto result = p.eval(pp).back();
    std::vector<float> result_vector;
    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
    std::vector<float> gold = {-0.5, 0, 0, 0, 0.5};
    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}

TEST_CASE(shrink_verify_test)
{
    migraphx::program p = migraphx::parse_onnx("shrink_verify_test.onnx");
    p.compile(migraphx::make_target("ref"));

    migraphx::shape s{migraphx::shape::half_type, {5}};
    std::vector<float> tmp = {-10.0, -5.0, 0.0, 5.0, 10.0};
    std::vector<migraphx::half> data{tmp.cbegin(), tmp.cend()};
    migraphx::parameter_map pp;
    pp["x"] = migraphx::argument(s, data.data());

    auto result = p.eval(pp).back();
    std::vector<migraphx::half> result_vector;
    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
    tmp = {-9.0, -4.0, 1.0, 4.0, 9.0};
    std::vector<migraphx::half> gold{tmp.cbegin(), tmp.cend()};
    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}

TEST_CASE(shrink_verify2_test)
{
    migraphx::program p = migraphx::parse_onnx("shrink_verify2_test.onnx");
    p.compile(migraphx::make_target("ref"));

    migraphx::shape s{migraphx::shape::half_type, {5}};
    std::vector<float> tmp = {-10.0, -5.0, 0.0, 5.0, 10.0};
    std::vector<migraphx::half> data{tmp.cbegin(), tmp.cend()};
    migraphx::parameter_map pp;
    pp["x"] = migraphx::argument(s, data.data());

    auto result = p.eval(pp).back();
    std::vector<migraphx::half> result_vector;
    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
    tmp = {-5.0, 0.0, 5.0, 10.0, 5.0};
    std::vector<migraphx::half> gold{tmp.cbegin(), tmp.cend()};
    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}

TEST_CASE(shrink_int8_test)
{
    migraphx::program p = migraphx::parse_onnx("shrink_int8_test.onnx");
    p.compile(migraphx::make_target("ref"));

    migraphx::shape s{migraphx::shape::int8_type, {3, 3}};
    std::vector<int8_t> data{-4, -3, -2, -1, 0, 1, 2, 3, 4};
    migraphx::parameter_map pp;
    pp["x"] = migraphx::argument(s, data.data());

    auto result = p.eval(pp).back();
    std::vector<int8_t> result_vector;
    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
    std::vector<int8_t> gold = {-2, -1, 0, 0, 0, 0, 0, 1, 2};
    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}

TEST_CASE(shrink_uint8_test)
{
    migraphx::program p = migraphx::parse_onnx("shrink_uint8_test.onnx");
    p.compile(migraphx::make_target("ref"));

    migraphx::shape s{migraphx::shape::uint8_type, {3, 3}};
    std::vector<uint8_t> data{1, 2, 3, 4, 5, 6, 7, 8, 9};
    migraphx::parameter_map pp;
    pp["x"] = migraphx::argument(s, data.data());

    auto result = p.eval(pp).back();
    std::vector<uint8_t> result_vector;
    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
    std::vector<uint8_t> gold = {0, 0, 0, 0, 0, 10, 11, 12, 13};
    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
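
The two float16 verify cases above use a negative lambd, so the x < -lambd branch covers most of the input range and nothing gets zeroed out; they mainly exercise the branch priority in the parser (the x + bias branch wins when both conditions hold, because cond2 excludes cond1). A small NumPy check with the same priority reproduces the gold vectors; this is a reference sketch only, not part of the test suite:

import numpy as np

def shrink_ref(x, lambd=0.5, bias=0.0):
    # x < -lambd takes precedence over x > lambd, matching parse_shrink.
    return np.where(x < -lambd, x + bias,
                    np.where(x > lambd, x - bias, 0)).astype(x.dtype)

x = np.array([-10.0, -5.0, 0.0, 5.0, 10.0], dtype=np.float16)
print(shrink_ref(x, lambd=-5.0, bias=1.0))  # [-9. -4.  1.  4.  9.]  -> shrink_verify_test gold
print(shrink_ref(x, lambd=-6.0, bias=5.0))  # [-5.  0.  5. 10.  5.]  -> shrink_verify2_test gold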

TEST_CASE(size_verify_test)
{
migraphx::program p = migraphx::parse_onnx("size_verify_test.onnx");
3 changes: 0 additions & 3 deletions test/py/onnx_backend_test.py
@@ -249,8 +249,6 @@ def disabled_tests_onnx_1_7_0(backend_test):
    backend_test.exclude(r'test_reversesequence_time_cpu')
    backend_test.exclude(r'test_scan9_sum_cpu')
    backend_test.exclude(r'test_scan_sum_cpu')
    backend_test.exclude(r'test_shrink_hard_cpu')
    backend_test.exclude(r'test_shrink_soft_cpu')
    backend_test.exclude(r'test_slice_cpu')
    backend_test.exclude(r'test_slice_default_axes_cpu')
    backend_test.exclude(r'test_slice_default_steps_cpu')
@@ -463,7 +461,6 @@ def disabled_tests_onnx_1_7_0(backend_test):
    backend_test.exclude(r'test_sequence_model6_cpu')
    backend_test.exclude(r'test_sequence_model7_cpu')
    backend_test.exclude(r'test_sequence_model8_cpu')
    backend_test.exclude(r'test_shrink_cpu')
    backend_test.exclude(r'test_strnorm_model_monday_casesensintive_lower_cpu')
    backend_test.exclude(
        r'test_strnorm_model_monday_casesensintive_nochangecase_cpu')