From 6e496c1f39a7d5df35a97924922973ee67d57206 Mon Sep 17 00:00:00 2001
From: Charlie Lin
Date: Mon, 15 Apr 2024 21:45:21 -0400
Subject: [PATCH] Parse ONNX Expand to `broadcast_with_dims` (#2799)

---
 src/CMakeLists.txt                       |   1 +
 .../migraphx/op/broadcast_with_dims.hpp  |  84 +++++++++
 src/onnx/parse_expand.cpp                |  20 ++-
 src/simplify_dyn_ops.cpp                 |  30 ++++
 test/onnx/expand_dyn_test.onnx           |  19 ++
 test/onnx/gen_onnx.py                    |  11 ++
 test/onnx/parse/expand_test.cpp          |  16 +-
 test/op_shape_test.cpp                   |  37 ++++
 test/ref/broadcast_with_dims.cpp         | 163 ++++++++++++++++++
 test/simplify_dyn_ops_test.cpp           |  50 ++++++
 10 files changed, 424 insertions(+), 7 deletions(-)
 create mode 100644 src/include/migraphx/op/broadcast_with_dims.hpp
 create mode 100644 test/onnx/expand_dyn_test.onnx
 create mode 100644 test/ref/broadcast_with_dims.cpp

diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index f269aa8cb38..f38380225d7 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -142,6 +142,7 @@ register_migraphx_ops(
     atan
     broadcast
     broadcast_for_dot
+    broadcast_with_dims
     capture
     ceil
     clip
diff --git a/src/include/migraphx/op/broadcast_with_dims.hpp b/src/include/migraphx/op/broadcast_with_dims.hpp
new file mode 100644
index 00000000000..66267008c26
--- /dev/null
+++ b/src/include/migraphx/op/broadcast_with_dims.hpp
@@ -0,0 +1,84 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+#ifndef MIGRAPHX_GUARD_OPERATORS_BROADCAST_WITH_DIMS_HPP
+#define MIGRAPHX_GUARD_OPERATORS_BROADCAST_WITH_DIMS_HPP
+
+#include <migraphx/argument.hpp>
+#include <migraphx/check_shapes.hpp>
+#include <migraphx/common.hpp>
+#include <migraphx/config.hpp>
+#include <migraphx/shape.hpp>
+
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {
+namespace op {
+
+/**
+ * Broadcast the input tensor to the shape defined by the values of the second input.
+ * Used as `broadcast_with_dims(input_tensor, dims)`, where `dims` is a vector of integer
+ * dimensions. `input_tensor` must be broadcastable with `dims`, otherwise this operator will
+ * throw at compute time. This operator can be replaced with `multibroadcast(input_tensor)` if
+ * the `dims` vector is constant.
+ *
+ * Example:
+ *   input_tensor shape: lens = {2, 3}, strides = {3, 1}
+ *   dims = [4, 1, 3]
+ *   output shape: lens = {4, 2, 3}, strides = {0, 3, 1}
+ */
+struct broadcast_with_dims
+{
+    std::string name() const { return "broadcast_with_dims"; }
+
+    shape compute_shape(const std::vector<shape>& inputs) const
+    {
+        migraphx::check_shapes{inputs, *this, true}.has(2);
+        // check that the second input has a static shape
+        (void)migraphx::check_shapes{inputs.begin() + 1, inputs.end(), *this, false};
+        // the output tensor rank is the greater of the input tensor rank and the length of dims
+        const auto& input_tensor_shape = inputs.at(0);
+        const auto& dims_shape         = inputs.at(1);
+        size_t out_ndim      = std::max(input_tensor_shape.ndim(), dims_shape.lens().at(0));
+        std::size_t max_int  = std::numeric_limits<std::size_t>::max();
+        std::vector<shape::dynamic_dimension> dyn_dims(out_ndim,
+                                                       shape::dynamic_dimension{0, max_int});
+        return {input_tensor_shape.type(), dyn_dims};
+    }
+
+    argument compute(const shape& output_shape, const std::vector<argument>& args) const
+    {
+        auto s0             = args.at(0).get_shape();
+        const auto& in_lens = s0.lens();
+        std::vector<std::size_t> dims_input(output_shape.ndim());
+        args.at(1).visit([&](auto a) { dims_input.assign(a.begin(), a.end()); });
+        auto out_lens  = compute_broadcasted_lens(in_lens, dims_input);
+        auto out_shape = make_bcast_shape(s0, out_lens);
+        return args[0].reshape(out_shape);
+    }
+};
+
+} // namespace op
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
+
+#endif
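
The shape and compute logic above leans on MIGraphX's `compute_broadcasted_lens` helper. For readers who want the broadcasting rule spelled out, here is a minimal standalone sketch of the right-aligned, NumPy-style semantics the new operator assumes; it is not MIGraphX's implementation, and `broadcast_lens_sketch` is a hypothetical name used only for this illustration.

    // Standalone illustration of the right-aligned broadcasting rule that
    // `broadcast_with_dims` relies on. Not MIGraphX code; written only for this note.
    #include <algorithm>
    #include <cstddef>
    #include <iostream>
    #include <stdexcept>
    #include <utility>
    #include <vector>

    // Right-align the two length vectors and take the larger extent per axis;
    // mismatched extents are only allowed when one of them is 1.
    std::vector<std::size_t> broadcast_lens_sketch(std::vector<std::size_t> a,
                                                   std::vector<std::size_t> b)
    {
        if(a.size() < b.size())
            std::swap(a, b);
        // pad the shorter shape with leading 1s so both have the same rank
        b.insert(b.begin(), a.size() - b.size(), 1);
        std::vector<std::size_t> out(a.size());
        for(std::size_t i = 0; i < a.size(); ++i)
        {
            if(a[i] == b[i] or a[i] == 1 or b[i] == 1)
                out[i] = std::max(a[i], b[i]);
            else
                throw std::runtime_error("shapes are not broadcastable");
        }
        return out;
    }

    int main()
    {
        // mirrors the example in the operator's doc comment:
        // input lens {2, 3} broadcast with dims {4, 1, 3} -> {4, 2, 3}
        auto out = broadcast_lens_sketch({2, 3}, {4, 1, 3});
        for(auto d : out)
            std::cout << d << ' ';
        std::cout << '\n'; // prints: 4 2 3
        return 0;
    }

When the rule fails (for example lens {2, 3} against dims {6}), the sketch throws, which is the same condition under which `broadcast_with_dims` throws at compute time.
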
diff --git a/src/onnx/parse_expand.cpp b/src/onnx/parse_expand.cpp
index ee11a7f4070..0762fbb6243 100644
--- a/src/onnx/parse_expand.cpp
+++ b/src/onnx/parse_expand.cpp
@@ -1,7 +1,7 @@
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -43,11 +43,19 @@ struct parse_expand : op_parser<parse_expand>
     {
         auto in_lens             = args[0]->get_shape().lens();
         migraphx::argument arg_s = args[1]->eval();
-        check_arg_empty(arg_s, "Expand: dynamic shape is not supported");
-        std::vector<std::size_t> dims;
-        arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
-        auto out_lens = compute_broadcasted_lens(in_lens, dims);
-        return info.add_instruction(make_op("multibroadcast", {{"out_lens", out_lens}}), args[0]);
+        if(arg_s.empty())
+        {
+            // variable dims input
+            return info.add_instruction(make_op("broadcast_with_dims"), args[0], args[1]);
+        }
+        else
+        {
+            std::vector<std::size_t> dims;
+            arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
+            auto out_lens = compute_broadcasted_lens(in_lens, dims);
+            return info.add_instruction(make_op("multibroadcast", {{"out_lens", out_lens}}),
+                                        args[0]);
+        }
     }
 };
diff --git a/src/simplify_dyn_ops.cpp b/src/simplify_dyn_ops.cpp
index 5a3d5bce186..3785c19609f 100644
--- a/src/simplify_dyn_ops.cpp
+++ b/src/simplify_dyn_ops.cpp
@@ -33,6 +33,35 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 
+/**
+ * Convert broadcast_with_dims operators with a static input tensor and a constant `dims` input
+ * into a multibroadcast operator with a static output shape attribute.
+ *
+ */
+struct find_broadcast_with_dims_static
+{
+    auto matcher() const
+    {
+        return match::name("broadcast_with_dims")(match::nargs(2),
+                                                  match::arg(0)(match::static_shape()),
+                                                  match::arg(1)(match::is_constant()));
+    }
+
+    void apply(module& m, const match::matcher_result& mr) const
+    {
+        auto ins    = mr.result;
+        auto inputs = ins->inputs();
+
+        // read the values of arg(1) to create the input to multibroadcast
+        std::vector<std::size_t> sizes_vec;
+        inputs.at(1)->eval().visit(
+            [&](auto output) { sizes_vec.assign(output.begin(), output.end()); });
+
+        m.replace_instruction(
+            ins, make_op("multibroadcast", {{"out_lens", sizes_vec}}), inputs.at(0));
+    }
+};
+
 /**
  * Convert a Resize op. with Nearest mode to an implementation using Gather op.
  * From: resize[scales={...}/sizes={...},](static, constant)
@@ -586,6 +615,7 @@ struct simplify_select_module_output_shape
 void simplify_dyn_ops::apply(module& m) const
 {
     match::find_matches(m,
+                        find_broadcast_with_dims_static{},
                         find_resize_static{},
                         find_static_dimensions_of{},
                         find_const_alloc_reshapes{},
diff --git a/test/onnx/expand_dyn_test.onnx b/test/onnx/expand_dyn_test.onnx
new file mode 100644
index 00000000000..80701d8a491
--- /dev/null
+++ b/test/onnx/expand_dyn_test.onnx
@@ -0,0 +1,19 @@
+ expand_dyn_test:q
+
+x
+dimsy"Expandexpand_dyn_testZ
+x
+
+
+
+Z
+dims
+
+
+b
+y
+
+
+
+
+B
\ No newline at end of file
diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py
index 9db60e68541..c1f6b1fe6e6 100644
--- a/test/onnx/gen_onnx.py
+++ b/test/onnx/gen_onnx.py
@@ -2203,6 +2203,17 @@ def expand_test():
     return ([shape_const, node], [x], [y])
 
 
+@onnx_test()
+def expand_dyn_test():
+    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 1, 1])
+    dims_in = helper.make_tensor_value_info('dims', TensorProto.INT64, [4])
+    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3, 4, 5])
+
+    node = onnx.helper.make_node('Expand', inputs=['x', 'dims'], outputs=['y'])
+
+    return ([node], [x, dims_in], [y])
+
+
 @onnx_test(True)
 def external_constant_test():
     x = np.array([0, 1, 2])
diff --git a/test/onnx/parse/expand_test.cpp b/test/onnx/parse/expand_test.cpp
index 504a31095b9..892ad222022 100644
--- a/test/onnx/parse/expand_test.cpp
+++ b/test/onnx/parse/expand_test.cpp
@@ -1,7 +1,7 @@
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -37,3 +37,17 @@ TEST_CASE(expand_test)
     auto prog = optimize_onnx("expand_test.onnx");
     EXPECT(p == prog);
 }
+
+TEST_CASE(expand_dyn_test)
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    migraphx::shape s(migraphx::shape::float_type, {3, 1, 1});
+    auto param = mm->add_parameter("x", s);
+    migraphx::shape ss(migraphx::shape::int64_type, {4});
+    auto dims = mm->add_parameter("dims", ss);
+    mm->add_instruction(migraphx::make_op("broadcast_with_dims"), param, dims);
+
+    auto prog = optimize_onnx("expand_dyn_test.onnx");
+    EXPECT(p == prog);
+}
diff --git a/test/op_shape_test.cpp b/test/op_shape_test.cpp
index f5ae4259fc2..1f961115d0c 100644
--- a/test/op_shape_test.cpp
+++ b/test/op_shape_test.cpp
@@ -942,6 +942,43 @@ TEST_CASE(broadcast_for_dot_dyn2)
                  s0);
 }
 
+TEST_CASE(broadcast_with_dims0)
+{
+    using migraphx::shape;
+    shape s0{migraphx::shape::float_type, {2, 4}};
+    shape s1{migraphx::shape::int64_type, {4}};
+    std::size_t max_int = std::numeric_limits<std::size_t>::max();
+    std::vector<shape::dynamic_dimension> dyn_dims(4, shape::dynamic_dimension{0, max_int});
+    expect_shape(
+        shape{shape::float_type, dyn_dims}, migraphx::make_op("broadcast_with_dims"), s0, s1);
+}
+
+TEST_CASE(broadcast_with_dims1)
+{
+    using migraphx::shape;
+    shape s0{migraphx::shape::int32_type, {1, 2, 4}};
+    shape s1{migraphx::shape::int64_type, {1}};
+    std::size_t max_int = std::numeric_limits<std::size_t>::max();
+    std::vector<shape::dynamic_dimension> dyn_dims(3, shape::dynamic_dimension{0, max_int});
+    expect_shape(shape{migraphx::shape::int32_type, dyn_dims},
+                 migraphx::make_op("broadcast_with_dims"),
+                 s0,
+                 s1);
+}
+
+TEST_CASE(broadcast_with_dims2)
+{
+    using migraphx::shape;
+    shape s0{migraphx::shape::float_type, {{1, 4}, {2, 2}, {4, 4}}};
+    shape s1{migraphx::shape::int64_type, {4}};
+    std::size_t max_int = std::numeric_limits<std::size_t>::max();
+    std::vector<shape::dynamic_dimension> dyn_dims(4, shape::dynamic_dimension{0, max_int});
+    expect_shape(shape{migraphx::shape::float_type, dyn_dims},
+                 migraphx::make_op("broadcast_with_dims"),
+                 s0,
+                 s1);
+}
+
 TEST_CASE(flatten_shape)
 {
     migraphx::shape input{migraphx::shape::float_type, {2, 4, 6, 8}};
diff --git a/test/ref/broadcast_with_dims.cpp b/test/ref/broadcast_with_dims.cpp
new file mode 100644
index 00000000000..80ef52a4cac
--- /dev/null
+++ b/test/ref/broadcast_with_dims.cpp
@@ -0,0 +1,163 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+#include <migraphx/instruction.hpp>
+#include <migraphx/literal.hpp>
+#include <migraphx/make_op.hpp>
+#include <migraphx/program.hpp>
+#include <migraphx/register_target.hpp>
+#include <migraphx/verify.hpp>
+#include <numeric>
+
+#include "test.hpp"
+
+TEST_CASE(broadcast_with_dims_static0)
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    migraphx::shape input_shape{migraphx::shape::int32_type, {2}};
+    migraphx::shape dims_shape{migraphx::shape::int64_type, {2}};
+    auto input_param = mm->add_parameter("x", input_shape);
+    auto dims_param  = mm->add_parameter("dims", dims_shape);
+    mm->add_instruction(migraphx::make_op("broadcast_with_dims"), input_param, dims_param);
+    p.compile(migraphx::make_target("ref"));
+
+    std::vector<int32_t> input_data{-3, 3};
+    std::vector<int64_t> dims_data{2, 1};
+    migraphx::parameter_map params;
+    params["x"]    = migraphx::argument(input_shape, input_data.data());
+    params["dims"] = migraphx::argument(dims_shape, dims_data.data());
+    auto result = p.eval(params).back();
+    auto output = result.get<int32_t>();
+    EXPECT(output.get_shape().lens() == std::vector<std::size_t>{2, 2});
+    EXPECT(output.get_shape().strides() == std::vector<std::size_t>{0, 1});
+    EXPECT(output(0, 0) == -3);
+    EXPECT(output(0, 1) == 3);
+    EXPECT(output(1, 0) == -3);
+    EXPECT(output(1, 1) == 3);
+}
+
+TEST_CASE(broadcast_with_dims_static1)
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    migraphx::shape input_shape{migraphx::shape::float_type, {2, 1}, {1, 0}};
+    migraphx::shape dims_shape{migraphx::shape::int64_type, {1}};
+    auto input_param = mm->add_parameter("x", input_shape);
+    auto dims_param  = mm->add_parameter("dims", dims_shape);
+    mm->add_instruction(migraphx::make_op("broadcast_with_dims"), input_param, dims_param);
+    p.compile(migraphx::make_target("ref"));
+
+    std::vector<float> input_data{7, 11};
+    std::vector<int64_t> dims_data{3};
+    migraphx::parameter_map params;
+    params["x"]    = migraphx::argument(input_shape, input_data.data());
+    params["dims"] = migraphx::argument(dims_shape, dims_data.data());
+    auto result = p.eval(params).back();
+    auto output = result.get<float>();
+    EXPECT(output.get_shape().lens() == std::vector<std::size_t>{2, 3});
+    EXPECT(output.get_shape().strides() == std::vector<std::size_t>{1, 0});
+    EXPECT(migraphx::float_equal(output(0, 0), 7.f));
+    EXPECT(migraphx::float_equal(output(0, 1), 7.f));
+    EXPECT(migraphx::float_equal(output(0, 2), 7.f));
+    EXPECT(migraphx::float_equal(output(1, 0), 11.f));
+    EXPECT(migraphx::float_equal(output(1, 1), 11.f));
+    EXPECT(migraphx::float_equal(output(1, 2), 11.f));
+}
+
+TEST_CASE(broadcast_with_dims_static2)
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    migraphx::shape input_shape{migraphx::shape::float_type, {2, 3}, {3, 1}};
+    migraphx::shape dims_shape{migraphx::shape::int64_type, {3}};
+    auto input_param = mm->add_parameter("x", input_shape);
+    auto dims_param  = mm->add_parameter("dims", dims_shape);
+    mm->add_instruction(migraphx::make_op("broadcast_with_dims"), input_param, dims_param);
+    p.compile(migraphx::make_target("ref"));
+
+    std::vector<float> input_data(6);
+    std::iota(input_data.begin(), input_data.end(), 0.0);
+    std::vector<int64_t> dims_data{4, 2, 3};
+    migraphx::parameter_map params;
+    params["x"]    = migraphx::argument(input_shape, input_data.data());
+    params["dims"] = migraphx::argument(dims_shape, dims_data.data());
+    auto result = p.eval(params).back();
+    auto output = result.get<float>();
+    EXPECT(output.get_shape().lens() == std::vector<std::size_t>{4, 2, 3});
+    EXPECT(output.get_shape().strides() == std::vector<std::size_t>{0, 3, 1});
+    std::vector<float> results_vector;
+    result.visit([&](auto x) { results_vector.assign(x.begin(), x.end()); });
+    std::vector<float> gold = {
+        0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0,
+        0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0,
+    };
+    EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
+}
+
+TEST_CASE(broadcast_with_dims_dyn)
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    std::vector<migraphx::shape::dynamic_dimension> dds = {{2, 4}};
+    migraphx::shape input_shape{migraphx::shape::int32_type, dds};
+    migraphx::shape dims_shape{migraphx::shape::int64_type, {2}};
+    auto input_param = mm->add_parameter("x", input_shape);
+    auto dims_param  = mm->add_parameter("dims", dims_shape);
+    mm->add_instruction(migraphx::make_op("broadcast_with_dims"), input_param, dims_param);
+    p.compile(migraphx::make_target("ref"));
+
+    std::vector<int32_t> input_data{-3, 3};
+    std::vector<int64_t> dims_data{2, 2};
+    migraphx::shape input_static_shape{migraphx::shape::int32_type, {2}};
+    migraphx::parameter_map params;
+    params["x"]    = migraphx::argument(input_static_shape, input_data.data());
+    params["dims"] = migraphx::argument(dims_shape, dims_data.data());
+    auto result = p.eval(params).back();
+    auto output = result.get<int32_t>();
+    EXPECT(output.get_shape().lens() == std::vector<std::size_t>{2, 2});
+    EXPECT(output.get_shape().strides() == std::vector<std::size_t>{0, 1});
+    EXPECT(output(0, 0) == -3);
+    EXPECT(output(0, 1) == 3);
+    EXPECT(output(1, 0) == -3);
+    EXPECT(output(1, 1) == 3);
+}
+
+TEST_CASE(broadcast_with_dims_mismatch)
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    migraphx::shape input_shape{migraphx::shape::float_type, {2, 3}};
+    migraphx::shape dims_shape{migraphx::shape::int64_type, {1}};
+    auto input_param = mm->add_parameter("x", input_shape);
+    auto dims_param  = mm->add_parameter("dims", dims_shape);
+    mm->add_instruction(migraphx::make_op("broadcast_with_dims"), input_param, dims_param);
+    p.compile(migraphx::make_target("ref"));
+
+    std::vector<float> input_data{3, 9};
+    std::vector<int64_t> dims_data{6};
+    migraphx::parameter_map params;
+    params["x"]    = migraphx::argument(input_shape, input_data.data());
+    params["dims"] = migraphx::argument(dims_shape, dims_data.data());
+    EXPECT(test::throws([&] { std::ignore = p.eval(params).back(); }));
+}
diff --git a/test/simplify_dyn_ops_test.cpp b/test/simplify_dyn_ops_test.cpp
index caad3cb2771..ebc79b0bce4 100644
--- a/test/simplify_dyn_ops_test.cpp
+++ b/test/simplify_dyn_ops_test.cpp
@@ -34,6 +34,56 @@ void run_pass(migraphx::module& m)
     migraphx::run_passes(m, {migraphx::simplify_dyn_ops{}, migraphx::dead_code_elimination{}});
 }
 
+TEST_CASE(broadcast_with_dims)
+{
+    migraphx::module m0;
+    {
+        // the X input
+        migraphx::shape sx{migraphx::shape::float_type, {3, 1, 1}};
+        auto inx = m0.add_parameter("x", sx);
+
+        // the shape input; broadcast to this
+        migraphx::shape dims_s{migraphx::shape::int64_type, {4}};
+        std::vector<int64_t> dims = {2, 3, 4, 5};
+        auto out_dims = m0.add_literal(migraphx::literal{dims_s, dims});
+
+        auto r = m0.add_instruction(migraphx::make_op("broadcast_with_dims"), inx, out_dims);
+        m0.add_return({r});
+    }
+    run_pass(m0);
+
+    migraphx::module m1;
+    {
+        migraphx::shape sx{migraphx::shape::float_type, {3, 1, 1}};
+        auto inx = m1.add_parameter("x", sx);
+
+        auto r = m1.add_instruction(
+            migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 4, 5}}}), inx);
+        m1.add_return({r});
+    }
+    EXPECT(m0 == m1);
+}
+
+TEST_CASE(broadcast_with_dims_invalid)
+{
+    migraphx::module m0;
+    {
+        // X input shape is not broadcastable to the given shape
+        migraphx::shape sx{migraphx::shape::float_type, {3, 1, 2}};
+        auto inx = m0.add_parameter("x", sx);
+
+        // the shape input; broadcast to this
+        migraphx::shape dims_s{migraphx::shape::int64_type, {4}};
+        std::vector<int64_t> dims = {2, 3, 4, 5};
+        auto out_dims = m0.add_literal(migraphx::literal{dims_s, dims});
+
+        auto r = m0.add_instruction(migraphx::make_op("broadcast_with_dims"), inx, out_dims);
+        m0.add_return({r});
+    }
+    // the replacement will be rejected by the multibroadcast operation
+    EXPECT(test::throws([&] { run_pass(m0); }));
+}
+
 TEST_CASE(resize)
 {
     migraphx::module m0;