From c5b573f91b13fabe2671f71708f4b82a5bb8db63 Mon Sep 17 00:00:00 2001 From: Gyula Zakor Date: Mon, 25 Sep 2023 11:19:52 +0000 Subject: [PATCH] Add support for Shrink ONNX operator --- src/onnx/parse_shrink.cpp | 102 ++++++++++++++++++++++++++++ test/onnx/gen_onnx.py | 31 +++++++++ test/onnx/onnx_test.cpp | 68 +++++++++++++++++++ test/onnx/shrink_hard_test.onnx | Bin 0 -> 111 bytes test/onnx/shrink_soft_test.onnx | Bin 0 -> 127 bytes test/onnx/verify_onnx.cpp | 34 ++++++++++ test/ref/shrink.cpp | 113 ++++++++++++++++++++++++++++++++ test/verify/test_shrink.cpp | 64 ++++++++++++++++++ 8 files changed, 412 insertions(+) create mode 100644 src/onnx/parse_shrink.cpp create mode 100644 test/onnx/shrink_hard_test.onnx create mode 100644 test/onnx/shrink_soft_test.onnx create mode 100644 test/ref/shrink.cpp create mode 100644 test/verify/test_shrink.cpp diff --git a/src/onnx/parse_shrink.cpp b/src/onnx/parse_shrink.cpp new file mode 100644 index 00000000000..4af3a9bd0be --- /dev/null +++ b/src/onnx/parse_shrink.cpp @@ -0,0 +1,102 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_shrink : op_parser +{ + std::vector operators() const { return {{"Shrink"}}; } + + instruction_ref parse(const op_desc&, + const onnx_parser& parser, + const onnx_parser::node_info& info, + std::vector args) const + { + float bias = 0.0; + if(contains(info.attributes, "bias")) + { + bias = parser.parse_value(info.attributes.at("bias")).at(); + } + float lambd = 0.5; + if(contains(info.attributes, "lambd")) + { + lambd = parser.parse_value(info.attributes.at("lambd")).at(); + } + + auto x_shape = args[0]->get_shape(); + auto lit_bias = + info.add_literal(migraphx::literal{migraphx::shape{x_shape.type()}, {bias}}); + auto lit_neg_lambd = + info.add_literal(migraphx::literal{migraphx::shape{x_shape.type()}, {-lambd}}); + auto lit_lambd = + info.add_literal(migraphx::literal{migraphx::shape{x_shape.type()}, {lambd}}); + auto lit_zero = info.add_literal(migraphx::literal{migraphx::shape{x_shape.type()}, {0}}); + + instruction_ref mb_bias; + instruction_ref mb_neg_lambd; + instruction_ref mb_lambd; + instruction_ref mb_zero; + if(x_shape.dynamic()) + { + mb_bias = info.add_instruction(migraphx::make_op("multibroadcast"), lit_bias, args[0]); + mb_neg_lambd = + info.add_instruction(migraphx::make_op("multibroadcast"), lit_neg_lambd, args[0]); + mb_lambd = + info.add_instruction(migraphx::make_op("multibroadcast"), lit_lambd, args[0]); + mb_zero = info.add_instruction(migraphx::make_op("multibroadcast"), lit_zero, args[0]); + } + else + { + mb_bias = info.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", x_shape.lens()}}), lit_bias); + mb_neg_lambd = info.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", x_shape.lens()}}), lit_neg_lambd); + mb_lambd = info.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", x_shape.lens()}}), lit_lambd); + mb_zero = info.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", x_shape.lens()}}), lit_zero); + } + + auto condition_1 = info.add_instruction(migraphx::make_op("less"), args[0], mb_neg_lambd); + auto condition_2 = info.add_instruction(migraphx::make_op("greater"), args[0], mb_lambd); + + auto x_plus_bias = info.add_instruction(migraphx::make_op("add"), args[0], mb_bias); + auto x_min_bias = info.add_instruction(migraphx::make_op("sub"), args[0], mb_bias); + + auto filtered = + info.add_instruction(migraphx::make_op("where"), condition_1, x_plus_bias, mb_zero); + return info.add_instruction(migraphx::make_op("where"), condition_2, x_min_bias, filtered); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index d29ca03e9b9..896c63872ad 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -6423,6 +6423,37 @@ def shape_gather_test(): return ([node_const, node_shape, node_gather], [x], [z]) +@onnx_test() +def shrink_hard_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [5]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [5]) + + node = onnx.helper.make_node( + "Shrink", + inputs=["x"], + outputs=["y"], + lambd=1.5, + ) + + return ([node], [x], [y]) + + +@onnx_test() +def shrink_soft_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [5]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [5]) + + node = onnx.helper.make_node( + "Shrink", + inputs=["x"], + outputs=["y"], + lambd=1.5, + bias=1.5, + ) + + return ([node], [x], [y]) + + @onnx_test() def sign_test(): x = helper.make_tensor_value_info('x', TensorProto.DOUBLE, [10, 5]) diff --git a/test/onnx/onnx_test.cpp b/test/onnx/onnx_test.cpp index 8c545bf0b3b..cf4f1b0e314 100644 --- a/test/onnx/onnx_test.cpp +++ b/test/onnx/onnx_test.cpp @@ -6334,6 +6334,74 @@ TEST_CASE(shape_gather_test) EXPECT(p == prog); } +TEST_CASE(shrink_hard_test) +{ + migraphx::program p; + float bias = 0.0; + float lambd = 1.5; + std::vector lens{5}; + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, lens}); + auto lit_bias = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {bias}}); + auto lit_neg_lambd = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {-lambd}}); + auto lit_lambd = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {lambd}}); + auto lit_zero = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0}}); + + auto mb_bias = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), lit_bias); + auto mb_neg_lambd = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", lens}}), lit_neg_lambd); + auto mb_lambd = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), lit_lambd); + auto mb_zero = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), lit_zero); + + auto condition_1 = mm->add_instruction(migraphx::make_op("less"), input, mb_neg_lambd); + auto condition_2 = mm->add_instruction(migraphx::make_op("greater"), input, mb_lambd); + auto x_plus_bias = mm->add_instruction(migraphx::make_op("add"), input, mb_bias); + auto x_min_bias = mm->add_instruction(migraphx::make_op("sub"), input, mb_bias); + auto branch_1 = + mm->add_instruction(migraphx::make_op("where"), condition_1, x_plus_bias, mb_zero); + mm->add_instruction(migraphx::make_op("where"), condition_2, x_min_bias, branch_1); + + auto prog = optimize_onnx("shrink_hard_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(shrink_soft_test) +{ + migraphx::program p; + float bias = 1.5; + float lambd = 1.5; + std::vector lens{5}; + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, lens}); + auto lit_bias = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {bias}}); + auto lit_neg_lambd = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {-lambd}}); + auto lit_lambd = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {lambd}}); + auto lit_zero = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0}}); + + auto mb_bias = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), lit_bias); + auto mb_neg_lambd = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", lens}}), lit_neg_lambd); + auto mb_lambd = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), lit_lambd); + auto mb_zero = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), lit_zero); + + auto condition_1 = mm->add_instruction(migraphx::make_op("less"), input, mb_neg_lambd); + auto condition_2 = mm->add_instruction(migraphx::make_op("greater"), input, mb_lambd); + auto x_plus_bias = mm->add_instruction(migraphx::make_op("add"), input, mb_bias); + auto x_min_bias = mm->add_instruction(migraphx::make_op("sub"), input, mb_bias); + auto branch_1 = + mm->add_instruction(migraphx::make_op("where"), condition_1, x_plus_bias, mb_zero); + mm->add_instruction(migraphx::make_op("where"), condition_2, x_min_bias, branch_1); + auto prog = optimize_onnx("shrink_soft_test.onnx"); + + EXPECT(p == prog); +} + TEST_CASE(sign_test) { migraphx::program p; diff --git a/test/onnx/shrink_hard_test.onnx b/test/onnx/shrink_hard_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..51c88c32cb6120249efeaa4bf9ae883c48988e57 GIT binary patch literal 111 zcmdAPfMJWf>j- literal 0 HcmV?d00001 diff --git a/test/onnx/shrink_soft_test.onnx b/test/onnx/shrink_soft_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..50d2bd872ea346e357636b88ece783c407661daf GIT binary patch literal 127 zcmd data{-2, -1, 0, 1, 2}; + migraphx::parameter_map pp; + pp["x"] = migraphx::argument(s, data.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + std::vector gold = {-2, 0, 0, 0, 2}; + EXPECT(migraphx::verify::verify_range(result_vector, gold)); +} + +TEST_CASE(shrink_soft_test) +{ + migraphx::program p = migraphx::parse_onnx("shrink_soft_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape s{migraphx::shape::float_type, {5}}; + std::vector data{-2, -1, 0, 1, 2}; + migraphx::parameter_map pp; + pp["x"] = migraphx::argument(s, data.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + std::vector gold = {-0.5, 0, 0, 0, 0.5}; + EXPECT(migraphx::verify::verify_range(result_vector, gold)); +} + TEST_CASE(size_verify_test) { migraphx::program p = migraphx::parse_onnx("size_verify_test.onnx"); diff --git a/test/ref/shrink.cpp b/test/ref/shrink.cpp new file mode 100644 index 00000000000..05b7c290c06 --- /dev/null +++ b/test/ref/shrink.cpp @@ -0,0 +1,113 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include +#include +#include +#include +#include +#include + +#include + +TEST_CASE(shrink_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector lens{2, 3}; + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + auto input = + mm->add_literal(migraphx::literal{s, {-5.342, -2.134, -1.028, 0.145, 1.498, 2.887}}); + float lambd = 1.5; + float bias = 0.0; + + auto lit_bias = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {bias}}); + auto lit_neg_lambd = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {-lambd}}); + auto lit_lambd = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {lambd}}); + auto lit_zero = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0}}); + + auto mb_bias = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), lit_bias); + auto mb_neg_lambd = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", lens}}), lit_neg_lambd); + auto mb_lambd = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), lit_lambd); + auto mb_zero = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), lit_zero); + + auto condition_1 = mm->add_instruction(migraphx::make_op("less"), input, mb_neg_lambd); + auto condition_2 = mm->add_instruction(migraphx::make_op("greater"), input, mb_lambd); + auto x_plus_bias = mm->add_instruction(migraphx::make_op("add"), input, mb_bias); + auto x_min_bias = mm->add_instruction(migraphx::make_op("sub"), input, mb_bias); + auto branch_1 = + mm->add_instruction(migraphx::make_op("where"), condition_1, x_plus_bias, mb_zero); + mm->add_instruction(migraphx::make_op("where"), condition_2, x_min_bias, branch_1); + + p.compile(migraphx::make_target("ref")); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = {-5.342, -2.134, 0, 0, 0, 2.887}; + EXPECT(migraphx::verify::verify_range(results_vector, gold)); +} + +TEST_CASE(shrink_dyn_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape::dynamic_dimension dd{3, 8}; + migraphx::shape s{migraphx::shape::float_type, {dd}}; + auto input = mm->add_parameter("X", s); + float lambd = 1.5; + float bias = 1.5; + auto lit_bias = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {bias}}); + auto lit_neg_lambd = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {-lambd}}); + auto lit_lambd = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {lambd}}); + auto lit_zero = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0}}); + + auto mb_bias = mm->add_instruction(migraphx::make_op("multibroadcast"), lit_bias, input); + auto mb_neg_lambd = + mm->add_instruction(migraphx::make_op("multibroadcast"), lit_neg_lambd, input); + auto mb_lambd = mm->add_instruction(migraphx::make_op("multibroadcast"), lit_lambd, input); + auto mb_zero = mm->add_instruction(migraphx::make_op("multibroadcast"), lit_zero, input); + + auto condition_1 = mm->add_instruction(migraphx::make_op("less"), input, mb_neg_lambd); + auto condition_2 = mm->add_instruction(migraphx::make_op("greater"), input, mb_lambd); + auto x_plus_bias = mm->add_instruction(migraphx::make_op("add"), input, mb_bias); + auto x_min_bias = mm->add_instruction(migraphx::make_op("sub"), input, mb_bias); + auto branch_1 = + mm->add_instruction(migraphx::make_op("where"), condition_1, x_plus_bias, mb_zero); + mm->add_instruction(migraphx::make_op("where"), condition_2, x_min_bias, branch_1); + + p.compile(migraphx::make_target("ref")); + + std::vector input_data{-5.342, -2.134, -1.028, 0.145, 1.498, 2.887, 3.934}; + migraphx::parameter_map params0; + migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {7}}; + params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data()); + auto result = p.eval(params0).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = {-3.842, -0.634, 0, 0, 0, 1.387, 2.434}; + EXPECT(migraphx::verify::verify_range(results_vector, gold)); +} diff --git a/test/verify/test_shrink.cpp b/test/verify/test_shrink.cpp new file mode 100644 index 00000000000..8d2f2ca1f75 --- /dev/null +++ b/test/verify/test_shrink.cpp @@ -0,0 +1,64 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_shrink : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + float bias = 0.0; + float lambd = 1.5; + std::vector lens{2, 3, 4, 5}; + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, lens}); + auto lit_bias = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {bias}}); + auto lit_neg_lambd = + mm->add_literal(migraphx::literal{migraphx::shape::float_type, {-lambd}}); + auto lit_lambd = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {lambd}}); + auto lit_zero = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0}}); + + auto mb_bias = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", lens}}), lit_bias); + auto mb_neg_lambd = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", lens}}), lit_neg_lambd); + auto mb_lambd = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", lens}}), lit_lambd); + auto mb_zero = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", lens}}), lit_zero); + + auto condition_1 = mm->add_instruction(migraphx::make_op("less"), input, mb_neg_lambd); + auto condition_2 = mm->add_instruction(migraphx::make_op("greater"), input, mb_lambd); + auto x_plus_bias = mm->add_instruction(migraphx::make_op("add"), input, mb_bias); + auto x_min_bias = mm->add_instruction(migraphx::make_op("sub"), input, mb_bias); + auto branch_1 = + mm->add_instruction(migraphx::make_op("where"), condition_1, x_plus_bias, mb_zero); + mm->add_instruction(migraphx::make_op("where"), condition_2, x_min_bias, branch_1); + return p; + } +};