diff --git a/src/onnx/parse_shrink.cpp b/src/onnx/parse_shrink.cpp
new file mode 100644
index 00000000000..669706425c7
--- /dev/null
+++ b/src/onnx/parse_shrink.cpp
@@ -0,0 +1,85 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+#include <migraphx/onnx/op_parser.hpp>
+#include <migraphx/onnx/checks.hpp>
+#include <migraphx/ranges.hpp>
+#include <migraphx/make_op.hpp>
+#include <migraphx/instruction.hpp>
+
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {
+namespace onnx {
+
+struct parse_shrink : op_parser<parse_shrink>
+{
+    std::vector<op_desc> operators() const { return {{"Shrink"}}; }
+
+    instruction_ref parse(const op_desc&,
+                          const onnx_parser& parser,
+                          const onnx_parser::node_info& info,
+                          std::vector<instruction_ref> args) const
+    {
+        float bias = 0.0;
+        if(contains(info.attributes, "bias"))
+        {
+            bias = parser.parse_value(info.attributes.at("bias")).at<float>();
+        }
+        float lambd = 0.5;
+        if(contains(info.attributes, "lambd"))
+        {
+            lambd = parser.parse_value(info.attributes.at("lambd")).at<float>();
+        }
+
+        auto x             = args[0];
+        auto x_shape       = x->get_shape();
+        auto x_type        = x_shape.type();
+        auto lit_bias      = info.add_literal(bias);
+        auto lit_neg_lambd = info.add_literal(-lambd);
+        auto lit_lambd     = info.add_literal(lambd);
+
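+        // ONNX Shrink: y = x + bias if x < -lambd, y = x - bias if x > lambd,
+        // and y = 0 otherwise. Computed branch-free below: each branch value is
+        // multiplied by its boolean mask (converted to the input type) and the
+        // two products are summed.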
+        auto x_plus_bias = info.add_common_op("add", x, lit_bias);
+        auto x_min_bias  = info.add_common_op("sub", x, lit_bias);
+
+        auto cond1   = info.add_common_op("less", x, lit_neg_lambd);
+        auto cond2_a = info.add_common_op("not", cond1);
+        auto cond2_b = info.add_common_op("greater", x, lit_lambd);
+        auto cond2   = info.add_common_op("logical_and", cond2_a, cond2_b);
+
+        auto mul1 = info.add_instruction(make_op("convert", {{"target_type", x_type}}), cond1);
+        auto mul2 = info.add_instruction(make_op("convert", {{"target_type", x_type}}), cond2);
+
+        auto first  = info.add_common_op("mul", mul1, x_plus_bias);
+        auto second = info.add_common_op("mul", mul2, x_min_bias);
+        auto ret    = info.add_common_op("add", first, second);
+        if(ret->get_shape().type() != x_type)
+        {
+            ret = info.add_instruction(make_op("convert", {{"target_type", x_type}}), ret);
+        }
+        return ret;
+    }
+};
+
+} // namespace onnx
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py
index 73d42789de3..85633b33d49 100644
--- a/test/onnx/gen_onnx.py
+++ b/test/onnx/gen_onnx.py
@@ -6718,6 +6718,101 @@ def shape_gather_test():
     return ([node_const, node_shape, node_gather], [x], [z])
 
 
+@onnx_test()
+def shrink_hard_test():
+    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [5])
+    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [5])
+
+    node = onnx.helper.make_node(
+        "Shrink",
+        inputs=["x"],
+        outputs=["y"],
+        lambd=1.5,
+    )
+
+    return ([node], [x], [y])
+
+
+@onnx_test()
+def shrink_soft_test():
+    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [5])
+    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [5])
+
+    node = onnx.helper.make_node(
+        "Shrink",
+        inputs=["x"],
+        outputs=["y"],
+        lambd=1.5,
+        bias=1.5,
+    )
+
+    return ([node], [x], [y])
+
+
+@onnx_test()
+def shrink_verify_test():
+    x = helper.make_tensor_value_info('x', TensorProto.FLOAT16, [5])
+    y = helper.make_tensor_value_info('y', TensorProto.FLOAT16, [5])
+
+    node = onnx.helper.make_node(
+        "Shrink",
+        inputs=["x"],
+        outputs=["y"],
+        lambd=-5.0,
+        bias=1.0,
+    )
+
+    return ([node], [x], [y])
+
+
+@onnx_test()
+def shrink_verify2_test():
+    x = helper.make_tensor_value_info('x', TensorProto.FLOAT16, [5])
+    y = helper.make_tensor_value_info('y', TensorProto.FLOAT16, [5])
+
+    node = onnx.helper.make_node(
+        "Shrink",
+        inputs=["x"],
+        outputs=["y"],
+        lambd=-6.0,
+        bias=5.0,
+    )
+
+    return ([node], [x], [y])
+
+
+@onnx_test()
+def shrink_int8_test():
+    x = helper.make_tensor_value_info('x', TensorProto.INT8, [3, 3])
+    y = helper.make_tensor_value_info('y', TensorProto.INT8, [3, 3])
+
+    node = onnx.helper.make_node(
+        "Shrink",
+        inputs=["x"],
+        outputs=["y"],
+        lambd=1.5,
+        bias=1.5,
+    )
+
+    return ([node], [x], [y])
+
+
+@onnx_test()
+def shrink_uint8_test():
+    x = helper.make_tensor_value_info('x', TensorProto.UINT8, [3, 3])
+    y = helper.make_tensor_value_info('y', TensorProto.UINT8, [3, 3])
+
+    node = onnx.helper.make_node(
+        "Shrink",
+        inputs=["x"],
+        outputs=["y"],
+        lambd=5.0,
+        bias=-4.5,
+    )
+
+    return ([node], [x], [y])
+
+
 @onnx_test()
 def sign_test():
     x = helper.make_tensor_value_info('x', TensorProto.DOUBLE, [10, 5])
diff --git a/test/onnx/onnx_test.cpp b/test/onnx/onnx_test.cpp
index a1fb16caa7f..4299495b1e6 100644
--- a/test/onnx/onnx_test.cpp
+++ b/test/onnx/onnx_test.cpp
@@ -6610,6 +6610,73 @@ TEST_CASE(shape_gather_test)
     EXPECT(p == prog);
 }
 
+TEST_CASE(shrink_hard_test)
+{
+    migraphx::program p;
+    float bias  = 0.0;
+    float lambd = 1.5;
+    std::vector<std::size_t> lens{5};
+    auto* mm = p.get_main_module();
+    auto x   = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, lens});
+    auto lit_bias      = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {bias}});
+    auto lit_neg_lambd = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {-lambd}});
+    auto lit_lambd     = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {lambd}});
+
+    auto x_plus_bias = add_common_op(*mm, migraphx::make_op("add"), {x, lit_bias});
+    auto x_min_bias  = add_common_op(*mm, migraphx::make_op("sub"), {x, lit_bias});
+
+    auto cond1   = add_common_op(*mm, migraphx::make_op("less"), {x, lit_neg_lambd});
+    auto cond2_a = add_common_op(*mm, migraphx::make_op("not"), {cond1});
+    auto cond2_b = add_common_op(*mm, migraphx::make_op("greater"), {x, lit_lambd});
+    auto cond2   = add_common_op(*mm, migraphx::make_op("logical_and"), {cond2_a, cond2_b});
+
+    auto mul1 = mm->add_instruction(
+        migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), cond1);
+    auto mul2 = mm->add_instruction(
+        migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), cond2);
+
+    auto first  = add_common_op(*mm, migraphx::make_op("mul"), {mul1, x_plus_bias});
+    auto second = add_common_op(*mm, migraphx::make_op("mul"), {mul2, x_min_bias});
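+    // No trailing convert here: x is float, so the result type already matches the input.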
+    add_common_op(*mm, migraphx::make_op("add"), {first, second});
+    auto prog = optimize_onnx("shrink_hard_test.onnx");
+    EXPECT(p == prog);
+}
+
+TEST_CASE(shrink_int8_test)
+{
+    migraphx::program p;
+    float bias  = 1.5;
+    float lambd = 1.5;
+    std::vector<std::size_t> lens{3, 3};
+    auto* mm = p.get_main_module();
+    auto x   = mm->add_parameter("x", migraphx::shape{migraphx::shape::int8_type, lens});
+    auto lit_bias      = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {bias}});
+    auto lit_neg_lambd = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {-lambd}});
+    auto lit_lambd     = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {lambd}});
+
+    auto x_plus_bias = add_common_op(*mm, migraphx::make_op("add"), {x, lit_bias});
+    auto x_min_bias  = add_common_op(*mm, migraphx::make_op("sub"), {x, lit_bias});
+
+    auto cond1   = add_common_op(*mm, migraphx::make_op("less"), {x, lit_neg_lambd});
+    auto cond2_a = add_common_op(*mm, migraphx::make_op("not"), {cond1});
+    auto cond2_b = add_common_op(*mm, migraphx::make_op("greater"), {x, lit_lambd});
+    auto cond2   = add_common_op(*mm, migraphx::make_op("logical_and"), {cond2_a, cond2_b});
+
+    auto mul1 = mm->add_instruction(
+        migraphx::make_op("convert", {{"target_type", migraphx::shape::int8_type}}), cond1);
+    auto mul2 = mm->add_instruction(
+        migraphx::make_op("convert", {{"target_type", migraphx::shape::int8_type}}), cond2);
+
+    auto first  = add_common_op(*mm, migraphx::make_op("mul"), {mul1, x_plus_bias});
+    auto second = add_common_op(*mm, migraphx::make_op("mul"), {mul2, x_min_bias});
+    auto ret    = add_common_op(*mm, migraphx::make_op("add"), {first, second});
+    mm->add_instruction(migraphx::make_op("convert", {{"target_type", migraphx::shape::int8_type}}),
+                        ret);
+    auto prog = optimize_onnx("shrink_int8_test.onnx");
+
+    EXPECT(p == prog);
+}
+
 TEST_CASE(sign_test)
 {
     migraphx::program p;
diff --git a/test/onnx/shrink_hard_test.onnx b/test/onnx/shrink_hard_test.onnx
new file mode 100644
index 00000000000..51c88c32cb6
Binary files /dev/null and b/test/onnx/shrink_hard_test.onnx differ
diff --git a/test/onnx/shrink_int8_test.onnx b/test/onnx/shrink_int8_test.onnx
new file mode 100644
index 00000000000..29b85880a52
Binary files /dev/null and b/test/onnx/shrink_int8_test.onnx differ
diff --git a/test/onnx/shrink_soft_test.onnx b/test/onnx/shrink_soft_test.onnx
new file mode 100644
index 00000000000..50d2bd872ea
Binary files /dev/null and b/test/onnx/shrink_soft_test.onnx differ
diff --git a/test/onnx/shrink_uint8_test.onnx b/test/onnx/shrink_uint8_test.onnx
new file mode 100644
index 00000000000..2bf09bdff4f
Binary files /dev/null and b/test/onnx/shrink_uint8_test.onnx differ
diff --git a/test/onnx/shrink_verify2_test.onnx b/test/onnx/shrink_verify2_test.onnx
new file mode 100644
index 00000000000..93e89460ea6
Binary files /dev/null and b/test/onnx/shrink_verify2_test.onnx differ
diff --git a/test/onnx/shrink_verify_test.onnx b/test/onnx/shrink_verify_test.onnx
new file mode 100644
index 00000000000..8478fb84a4a
Binary files /dev/null and b/test/onnx/shrink_verify_test.onnx differ
diff --git a/test/onnx/verify_onnx.cpp b/test/onnx/verify_onnx.cpp
index f74efe08392..0b930061484 100644
--- a/test/onnx/verify_onnx.cpp
+++ b/test/onnx/verify_onnx.cpp
@@ -1708,6 +1708,112 @@ TEST_CASE(selu_test)
     EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
 }
 
+TEST_CASE(shrink_hard_test)
+{
+    migraphx::program p = migraphx::parse_onnx("shrink_hard_test.onnx");
+    p.compile(migraphx::make_target("ref"));
+
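+    // lambd = 1.5, bias defaults to 0: |x| <= 1.5 shrinks to 0, larger magnitudes pass through.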
+    migraphx::shape s{migraphx::shape::float_type, {5}};
+    std::vector<float> data{-2, -1, 0, 1, 2};
+    migraphx::parameter_map pp;
+    pp["x"] = migraphx::argument(s, data.data());
+
+    auto result = p.eval(pp).back();
+    std::vector<float> result_vector;
+    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
+    std::vector<float> gold = {-2, 0, 0, 0, 2};
+    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
+}
+
+TEST_CASE(shrink_soft_test)
+{
+    migraphx::program p = migraphx::parse_onnx("shrink_soft_test.onnx");
+    p.compile(migraphx::make_target("ref"));
+
+    migraphx::shape s{migraphx::shape::float_type, {5}};
+    std::vector<float> data{-2, -1, 0, 1, 2};
+    migraphx::parameter_map pp;
+    pp["x"] = migraphx::argument(s, data.data());
+
+    auto result = p.eval(pp).back();
+    std::vector<float> result_vector;
+    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
+    std::vector<float> gold = {-0.5, 0, 0, 0, 0.5};
+    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
+}
+
+TEST_CASE(shrink_verify_test)
+{
+    migraphx::program p = migraphx::parse_onnx("shrink_verify_test.onnx");
+    p.compile(migraphx::make_target("ref"));
+
+    migraphx::shape s{migraphx::shape::half_type, {5}};
+    std::vector<float> tmp = {-10.0, -5.0, 0.0, 5.0, 10.0};
+    std::vector<migraphx::half> data{tmp.cbegin(), tmp.cend()};
+    migraphx::parameter_map pp;
+    pp["x"] = migraphx::argument(s, data.data());
+
+    auto result = p.eval(pp).back();
+    std::vector<migraphx::half> result_vector;
+    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
+    tmp = {-9.0, -4.0, 1.0, 4.0, 9.0};
+    std::vector<migraphx::half> gold{tmp.cbegin(), tmp.cend()};
+    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
+}
+
+TEST_CASE(shrink_verify2_test)
+{
+    migraphx::program p = migraphx::parse_onnx("shrink_verify2_test.onnx");
+    p.compile(migraphx::make_target("ref"));
+
+    migraphx::shape s{migraphx::shape::half_type, {5}};
+    std::vector<float> tmp = {-10.0, -5.0, 0.0, 5.0, 10.0};
+    std::vector<migraphx::half> data{tmp.cbegin(), tmp.cend()};
+    migraphx::parameter_map pp;
+    pp["x"] = migraphx::argument(s, data.data());
+
+    auto result = p.eval(pp).back();
+    std::vector<migraphx::half> result_vector;
+    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
+    tmp = {-5.0, 0.0, 5.0, 10.0, 5.0};
+    std::vector<migraphx::half> gold{tmp.cbegin(), tmp.cend()};
+    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
+}
+
+TEST_CASE(shrink_int8_test)
+{
+    migraphx::program p = migraphx::parse_onnx("shrink_int8_test.onnx");
+    p.compile(migraphx::make_target("ref"));
+
+    migraphx::shape s{migraphx::shape::int8_type, {3, 3}};
+    std::vector<int8_t> data{-4, -3, -2, -1, 0, 1, 2, 3, 4};
+    migraphx::parameter_map pp;
+    pp["x"] = migraphx::argument(s, data.data());
+
+    auto result = p.eval(pp).back();
+    std::vector<int8_t> result_vector;
+    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
+    std::vector<int8_t> gold = {-2, -1, 0, 0, 0, 0, 0, 1, 2};
+    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
+}
+
+TEST_CASE(shrink_uint8_test)
+{
+    migraphx::program p = migraphx::parse_onnx("shrink_uint8_test.onnx");
+    p.compile(migraphx::make_target("ref"));
+
+    migraphx::shape s{migraphx::shape::uint8_type, {3, 3}};
+    std::vector<uint8_t> data{1, 2, 3, 4, 5, 6, 7, 8, 9};
+    migraphx::parameter_map pp;
+    pp["x"] = migraphx::argument(s, data.data());
+
+    auto result = p.eval(pp).back();
+    std::vector<uint8_t> result_vector;
+    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
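+    // lambd = 5.0, bias = -4.5: inputs <= 5 shrink to 0; 6..9 become x + 4.5 and truncate
+    // to uint8 as 10, 11, 12, 13. No unsigned input can satisfy x < -lambd.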
+    std::vector<uint8_t> gold = {0, 0, 0, 0, 0, 10, 11, 12, 13};
+    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
+}
+
 TEST_CASE(size_verify_test)
 {
     migraphx::program p = migraphx::parse_onnx("size_verify_test.onnx");
diff --git a/test/py/onnx_backend_test.py b/test/py/onnx_backend_test.py
index 3bafd6f2cae..db2dec28e13 100644
--- a/test/py/onnx_backend_test.py
+++ b/test/py/onnx_backend_test.py
@@ -249,8 +249,6 @@ def disabled_tests_onnx_1_7_0(backend_test):
     backend_test.exclude(r'test_reversesequence_time_cpu')
     backend_test.exclude(r'test_scan9_sum_cpu')
     backend_test.exclude(r'test_scan_sum_cpu')
-    backend_test.exclude(r'test_shrink_hard_cpu')
-    backend_test.exclude(r'test_shrink_soft_cpu')
     backend_test.exclude(r'test_slice_cpu')
     backend_test.exclude(r'test_slice_default_axes_cpu')
     backend_test.exclude(r'test_slice_default_steps_cpu')
@@ -463,7 +461,6 @@ def disabled_tests_onnx_1_7_0(backend_test):
     backend_test.exclude(r'test_sequence_model6_cpu')
     backend_test.exclude(r'test_sequence_model7_cpu')
    backend_test.exclude(r'test_sequence_model8_cpu')
-    backend_test.exclude(r'test_shrink_cpu')
     backend_test.exclude(r'test_strnorm_model_monday_casesensintive_lower_cpu')
     backend_test.exclude(
         r'test_strnorm_model_monday_casesensintive_nochangecase_cpu')
diff --git a/test/verify/test_shrink.cpp b/test/verify/test_shrink.cpp
new file mode 100644
index 00000000000..824375b2e63
--- /dev/null
+++ b/test/verify/test_shrink.cpp
@@ -0,0 +1,82 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+#include "verify_program.hpp"
+#include <migraphx/program.hpp>
+#include <migraphx/generate.hpp>
+#include <migraphx/make_op.hpp>
+#include <migraphx/common.hpp>
+#include <migraphx/instruction.hpp>
+#include <migraphx/literal.hpp>
+
+template <migraphx::shape::type_t T>
+struct test_shrink : verify_program<test_shrink<T>>
+{
+    migraphx::program create_program() const
+    {
+        migraphx::program p;
+        float bias  = 1.5;
+        float lambd = 1.5;
+        auto* mm    = p.get_main_module();
+        migraphx::shape is{T, {2, 3}};
+        std::vector<float> data{-3, -2, -1, 0, 1, 2};
+        auto x        = mm->add_literal(migraphx::literal{is, data});
+        auto lit_bias = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {bias}});
+        auto lit_neg_lambd =
+            mm->add_literal(migraphx::literal{migraphx::shape::float_type, {-lambd}});
+        auto lit_lambd = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {lambd}});
+
+        auto x_plus_bias = add_common_op(*mm, migraphx::make_op("add"), {x, lit_bias});
+        auto x_min_bias  = add_common_op(*mm, migraphx::make_op("sub"), {x, lit_bias});
+
+        auto cond1   = add_common_op(*mm, migraphx::make_op("less"), {x, lit_neg_lambd});
+        auto cond2_a = add_common_op(*mm, migraphx::make_op("not"), {cond1});
+        auto cond2_b = add_common_op(*mm, migraphx::make_op("greater"), {x, lit_lambd});
+        auto cond2   = add_common_op(*mm, migraphx::make_op("logical_and"), {cond2_a, cond2_b});
+
+        auto mul1 = mm->add_instruction(migraphx::make_op("convert", {{"target_type", T}}), cond1);
+        auto mul2 = mm->add_instruction(migraphx::make_op("convert", {{"target_type", T}}), cond2);
+
+        auto first  = add_common_op(*mm, migraphx::make_op("mul"), {mul1, x_plus_bias});
+        auto second = add_common_op(*mm, migraphx::make_op("mul"), {mul2, x_min_bias});
+        auto ret    = add_common_op(*mm, migraphx::make_op("add"), {first, second});
+        if(ret->get_shape().type() != T)
+        {
+            mm->add_instruction(migraphx::make_op("convert", {{"target_type", T}}), ret);
+        }
+        return p;
+    }
+};
+
+template struct test_shrink<migraphx::shape::double_type>;
+template struct test_shrink<migraphx::shape::float_type>;
+template struct test_shrink<migraphx::shape::half_type>;
+template struct test_shrink<migraphx::shape::int8_type>;
+template struct test_shrink<migraphx::shape::uint8_type>;
+template struct test_shrink<migraphx::shape::int16_type>;
+template struct test_shrink<migraphx::shape::uint16_type>;
+template struct test_shrink<migraphx::shape::int32_type>;
+template struct test_shrink<migraphx::shape::uint32_type>;
+template struct test_shrink<migraphx::shape::int64_type>;
+template struct test_shrink<migraphx::shape::uint64_type>;
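
Reviewer note: the gold vectors in verify_onnx.cpp can be cross-checked against the ONNX
Shrink definition with a small reference sketch. This is not part of the patch; it assumes
NumPy is available, and the helper name shrink_reference is illustrative only:

    import numpy as np

    def shrink_reference(x, bias=0.0, lambd=0.5):
        # y = x + bias where x < -lambd, y = x - bias where x > lambd, else 0
        x = np.asarray(x, dtype=np.float64)
        return np.where(x < -lambd, x + bias, np.where(x > lambd, x - bias, 0.0))

    # shrink_reference([-2, -1, 0, 1, 2], bias=1.5, lambd=1.5) -> [-0.5, 0, 0, 0, 0.5],
    # matching the shrink_soft_test gold above.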