From 5434e87165afa25fe967ceeffeca23bf6f156f28 Mon Sep 17 00:00:00 2001 From: Charlie Lin Date: Fri, 31 May 2024 16:52:11 -0400 Subject: [PATCH] `Split` dynamic shape parsing update (#3034) --- src/onnx/parse_split.cpp | 198 ++++++++++++------ test/onnx/gen_onnx.py | 71 +++++++ test/onnx/parse/split_dyn_input.cpp | 107 ++++++++++ test/onnx/parse/split_minus_axis_test.cpp | 6 +- .../split_dyn_input_dyn_split_axis_test.onnx | Bin 0 -> 199 bytes ...split_dyn_input_fixed_split_axis_test.onnx | Bin 0 -> 203 bytes .../onnx/split_dyn_input_split_attr_test.onnx | Bin 0 -> 209 bytes .../split_dyn_input_split_input_test.onnx | Bin 0 -> 249 bytes test/onnx/verify/split_dyn_verify_test.cpp | 101 +++++++++ 9 files changed, 417 insertions(+), 66 deletions(-) create mode 100644 test/onnx/parse/split_dyn_input.cpp create mode 100644 test/onnx/split_dyn_input_dyn_split_axis_test.onnx create mode 100644 test/onnx/split_dyn_input_fixed_split_axis_test.onnx create mode 100644 test/onnx/split_dyn_input_split_attr_test.onnx create mode 100644 test/onnx/split_dyn_input_split_input_test.onnx create mode 100644 test/onnx/verify/split_dyn_verify_test.cpp diff --git a/src/onnx/parse_split.cpp b/src/onnx/parse_split.cpp index 8de16f80fad..5285e7230d5 100644 --- a/src/onnx/parse_split.cpp +++ b/src/onnx/parse_split.cpp @@ -34,6 +34,131 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace onnx { +auto parse_dyn_split(const onnx_parser::node_info& info, + const std::vector& args, + int64_t tuned_axis) +{ + if(contains(info.attributes, "split")) + { + MIGRAPHX_THROW("PARSE_SPLIT: dynamic input and non-fixed split axis and `split` " + "attribute not supported"); + } + if(args.size() == 2) + { + MIGRAPHX_THROW("PARSE_SPLIT: dynamic input and non-fixed split axis and `split` " + "input not supported"); + } + + std::size_t num_outputs = info.num_outputs; + std::vector ret_ins(num_outputs); + + // Doing shape calculations for the splits in the graph + auto split_dim = info.add_instruction( + make_op("dimensions_of", {{"start", tuned_axis}, {"end", tuned_axis + 1}}), args[0]); + shape int64_scalar_shape{shape::int64_type, {1}, {0}}; + auto num_outputs_lit = info.add_literal(literal{int64_scalar_shape, {num_outputs}}); + auto num_outputs_minus_1_lit = info.add_literal(literal{int64_scalar_shape, {num_outputs - 1}}); + // (A + (B - 1)) / B == ceil(A / B) + auto chunk_size = info.add_instruction( + make_op("div"), + info.add_instruction(make_op("add"), split_dim, num_outputs_minus_1_lit), + num_outputs_lit); + for(int n = 0; n < num_outputs - 1; ++n) + { + // slice(input, starts = {n * chunk_size}, ends = {(n+1) * chunk_size}); axes = + // {tuned_axis} + ret_ins.at(n) = info.add_instruction( + make_op("slice", {{"axes", {tuned_axis}}}), + args[0], + info.add_instruction( + make_op("mul"), chunk_size, info.add_literal(literal{int64_scalar_shape, {n}})), + info.add_instruction(make_op("mul"), + chunk_size, + info.add_literal(literal{int64_scalar_shape, {n + 1}}))); + } + // last slice: slice(input, starts = {n * chunk_size}); ends = max_int, axes = + // {tuned_axis} + ret_ins.at(num_outputs - 1) = info.add_instruction( + make_op("slice", {{"axes", {tuned_axis}}, {"ends", {std::numeric_limits::max()}}}), + args[0], + info.add_instruction(make_op("mul"), + chunk_size, + info.add_literal(literal{int64_scalar_shape, {num_outputs - 1}}))); + return ret_ins; +} + +auto parse_static_split(const onnx_parser::node_info& info, + const onnx_parser& parser, + const std::vector& args, + int64_t tuned_axis) +{ + const auto& input_shape = args[0]->get_shape(); + // either static shape or fixed dynamic_dimension for split axis + auto tuned_axis_len = input_shape.to_static(0).lens().at(tuned_axis); + std::vector vec_splits; + if(contains(info.attributes, "split")) + { + literal s = parser.parse_value(info.attributes.at("split")); + s.visit([&](auto v) { vec_splits.assign(v.begin(), v.end()); }); + } + else if(args.size() == 2) + { + auto s = args[1]->eval(); + check_arg_empty(s, "PARSE_SPLIT: non-constant `split` input is not supported"); + s.visit([&](auto v) { vec_splits.assign(v.begin(), v.end()); }); + } + // no split attribute, input is equally divided + else + { + std::size_t num_outputs = info.num_outputs; + // the num_outputs attribute seems to be redundant since we already have + // node_info::num_outputs, but we can still perform an error check + if(contains(info.attributes, "num_outputs")) + { + num_outputs = parser.parse_value(info.attributes.at("num_outputs")).at(); + if(num_outputs != info.num_outputs) + { + MIGRAPHX_THROW("PARSE_SPLIT: num_outputs attribute " + std::to_string(num_outputs) + + " doesn't match actual number of outputs " + + std::to_string(info.num_outputs) + "!"); + } + } + if(tuned_axis_len % num_outputs == 0) + { + std::size_t chunk_size = tuned_axis_len / num_outputs; + vec_splits.resize(num_outputs, chunk_size); + } + else + { + std::size_t chunk_size = tuned_axis_len / num_outputs + 1; + std::size_t last_chunk_size = tuned_axis_len - chunk_size * (num_outputs - 1); + vec_splits.resize(num_outputs - 1, chunk_size); + vec_splits.push_back(last_chunk_size); + } + } + + if(std::accumulate(vec_splits.begin(), vec_splits.end(), int64_t(0)) != + static_cast(tuned_axis_len)) + { + MIGRAPHX_THROW( + "PARSE_SPLIT: sum of split attribute unequal to dim size of axis! tuned axis:" + + std::to_string(tuned_axis_len) + " Output " + to_string_range(vec_splits) + " Rank " + + std::to_string(input_shape.ndim())); + } + + std::vector ret_ins; + int64_t start = 0; + for(auto sl : vec_splits) + { + ret_ins.push_back(info.add_instruction( + make_op("slice", {{"axes", {tuned_axis}}, {"starts", {start}}, {"ends", {start + sl}}}), + args[0])); + start += sl; + } + + return ret_ins; +} + struct parse_split : op_parser { std::vector operators() const { return {{"Split"}}; } @@ -49,75 +174,22 @@ struct parse_split : op_parser axis = parser.parse_value(info.attributes.at("axis")).at(); } - auto lens = args[0]->get_shape().lens(); - int64_t n_rank = lens.size(); - int64_t tuned_axis = tune_axis(n_rank, axis, opd.op_name); + const auto& input_shape = args[0]->get_shape(); + // axis over which the split occurs (split_axis) + int64_t tuned_axis = tune_axis(input_shape.ndim(), axis, opd.op_name); - std::vector vec_splits; - if(contains(info.attributes, "split")) - { - literal s = parser.parse_value(info.attributes.at("split")); - s.visit([&](auto v) { vec_splits.assign(v.begin(), v.end()); }); - } - else if(args.size() == 2) - { - auto s = args[1]->eval(); - check_arg_empty(s, "Split: dynamic shape is not supported"); - s.visit([&](auto v) { vec_splits.assign(v.begin(), v.end()); }); - } - // no split attribute, input is equally divided - else - { - std::size_t num_outputs = info.num_outputs; - // the num_outputs attribute seems to be redundant since we already have - // node_info::num_outputs, but we can still perform an error check - if(contains(info.attributes, "num_outputs")) - { - num_outputs = - parser.parse_value(info.attributes.at("num_outputs")).at(); - if(num_outputs != info.num_outputs) - { - MIGRAPHX_THROW("PARSE_SPLIT: num_outputs attribute " + - std::to_string(num_outputs) + - " doesn't match actual number of outputs " + - std::to_string(info.num_outputs) + "!"); - } - } - - if(lens[tuned_axis] % num_outputs == 0) - { - std::size_t chunk_size = lens[tuned_axis] / num_outputs; - vec_splits.resize(num_outputs, chunk_size); - } - else - { - std::size_t chunk_size = lens[tuned_axis] / num_outputs + 1; - std::size_t last_chunk_size = lens[tuned_axis] - chunk_size * (num_outputs - 1); - vec_splits.resize(num_outputs - 1, chunk_size); - vec_splits.push_back(last_chunk_size); - } - } + auto split_axis_is_fixed = [&]() { + return input_shape.dyn_dims().at(tuned_axis).is_fixed(); + }; - if(std::accumulate(vec_splits.begin(), vec_splits.end(), int64_t(0)) != - static_cast(lens[tuned_axis])) + if(input_shape.dynamic() and not split_axis_is_fixed()) { - MIGRAPHX_THROW( - "PARSE_SPLIT: sum of split attribute unequal to dim size of axis! tuned axis:" + - std::to_string(lens[tuned_axis]) + " Output " + to_string_range(vec_splits) + - " Rank " + std::to_string(n_rank) + " Len outs " + to_string_range(lens)); + return parse_dyn_split(info, args, tuned_axis); } - - std::vector ret_ins; - int64_t start = 0; - for(auto sl : vec_splits) + else { - ret_ins.push_back(info.add_instruction( - make_op("slice", {{"axes", {axis}}, {"starts", {start}}, {"ends", {start + sl}}}), - args[0])); - start += sl; + return parse_static_split(info, parser, args, tuned_axis); } - - return ret_ins; } }; diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index e1effe70183..641714e5af4 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -10828,6 +10828,77 @@ def split_test_invalid_num_outputs(): return ([node], [x], [y1, y2, y3, y4]) +@onnx_test() +def split_dyn_input_fixed_split_axis_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [None, 15]) + y1 = helper.make_tensor_value_info('y1', TensorProto.FLOAT, [None, 5]) + y2 = helper.make_tensor_value_info('y2', TensorProto.FLOAT, [None, 5]) + y3 = helper.make_tensor_value_info('y3', TensorProto.FLOAT, [None, 5]) + + node = onnx.helper.make_node('Split', + inputs=['x'], + outputs=['y1', 'y2', 'y3'], + axis=1) + + return ([node], [x], [y1, y2, y3]) + + +@onnx_test() +def split_dyn_input_dyn_split_axis_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [None, 15]) + y1 = helper.make_tensor_value_info('y1', TensorProto.FLOAT, [None, 5]) + y2 = helper.make_tensor_value_info('y2', TensorProto.FLOAT, [None, 5]) + y3 = helper.make_tensor_value_info('y3', TensorProto.FLOAT, [None, 5]) + + node = onnx.helper.make_node('Split', + inputs=['x'], + outputs=['y1', 'y2', 'y3'], + axis=0) + + return ([node], [x], [y1, y2, y3]) + + +@onnx_test() +def split_dyn_input_split_attr_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [None, 15]) + y1 = helper.make_tensor_value_info('y1', TensorProto.FLOAT, [None, 5]) + y2 = helper.make_tensor_value_info('y2', TensorProto.FLOAT, [None, 5]) + y3 = helper.make_tensor_value_info('y3', TensorProto.FLOAT, [None, 5]) + + node = onnx.helper.make_node('Split', + inputs=['x'], + outputs=['y1', 'y2', 'y3'], + axis=0, + split=[7, 4, 4]) + + return ([node], [x], [y1, y2, y3]) + + +@onnx_test() +def split_dyn_input_split_input_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [None, 15]) + y1 = helper.make_tensor_value_info('y1', TensorProto.FLOAT, [None, 5]) + y2 = helper.make_tensor_value_info('y2', TensorProto.FLOAT, [None, 5]) + y3 = helper.make_tensor_value_info('y3', TensorProto.FLOAT, [None, 5]) + + split = np.ones(3) * 5 + split_tensor = helper.make_tensor(name="split", + data_type=TensorProto.INT64, + dims=split.shape, + vals=split.astype(np.int64)) + const_node = helper.make_node("Constant", + inputs=[], + outputs=['split'], + value=split_tensor) + + node = onnx.helper.make_node('Split', + inputs=['x', 'split'], + outputs=['y1', 'y2', 'y3'], + axis=0) + + return ([const_node, node], [x], [y1, y2, y3]) + + @onnx_test() def sqrt_test(): x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 15]) diff --git a/test/onnx/parse/split_dyn_input.cpp b/test/onnx/parse/split_dyn_input.cpp new file mode 100644 index 00000000000..7587a4dc4c6 --- /dev/null +++ b/test/onnx/parse/split_dyn_input.cpp @@ -0,0 +1,107 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include + +TEST_CASE(split_dyn_input_fixed_split_axis_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = + mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {{10, 30}, {15, 15}}}); + auto r1 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {5}}}), input); + auto r2 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {5}}, {"ends", {10}}}), input); + auto r3 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {10}}, {"ends", {15}}}), input); + mm->add_return({r1, r2, r3}); + + migraphx::onnx_options options; + options.default_dyn_dim_value = {10, 30}; + auto prog = read_onnx("split_dyn_input_fixed_split_axis_test.onnx", options); + EXPECT(p == prog); +} + +TEST_CASE(split_dyn_input_dyn_split_axis_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = + mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {{10, 30}, {15, 15}}}); + auto split_dim = + mm->add_instruction(migraphx::make_op("dimensions_of", {{"start", 0}, {"end", 1}}), input); + migraphx::shape int64_scalar_shape{migraphx::shape::int64_type, {1}, {0}}; + auto num_outputs_lit = mm->add_literal(migraphx::literal{int64_scalar_shape, {3}}); + auto num_outputs_minus_1_lit = mm->add_literal(migraphx::literal{int64_scalar_shape, {2}}); + auto chunk_size = mm->add_instruction( + migraphx::make_op("div"), + mm->add_instruction(migraphx::make_op("add"), split_dim, num_outputs_minus_1_lit), + num_outputs_lit); + auto r1 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0}}}), + input, + mm->add_instruction(migraphx::make_op("mul"), + chunk_size, + mm->add_literal(migraphx::literal{int64_scalar_shape, {0}})), + mm->add_instruction(migraphx::make_op("mul"), + chunk_size, + mm->add_literal(migraphx::literal{int64_scalar_shape, {1}}))); + auto r2 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0}}}), + input, + mm->add_instruction(migraphx::make_op("mul"), + chunk_size, + mm->add_literal(migraphx::literal{int64_scalar_shape, {1}})), + mm->add_instruction(migraphx::make_op("mul"), + chunk_size, + mm->add_literal(migraphx::literal{int64_scalar_shape, {2}}))); + auto r3 = mm->add_instruction( + migraphx::make_op("slice", + {{"axes", {0}}, {"ends", {std::numeric_limits::max()}}}), + input, + mm->add_instruction(migraphx::make_op("mul"), + chunk_size, + mm->add_literal(migraphx::literal{int64_scalar_shape, {2}}))); + mm->add_return({r1, r2, r3}); + + migraphx::onnx_options options; + options.default_dyn_dim_value = {10, 30}; + auto prog = read_onnx("split_dyn_input_dyn_split_axis_test.onnx", options); + EXPECT(p == prog); +} + +TEST_CASE(split_dyn_input_split_attr_error) +{ + migraphx::onnx_options options; + options.default_dyn_dim_value = {10, 30}; + EXPECT(test::throws([&] { read_onnx("split_dyn_input_split_attr_test.onnx", options); })); +} + +TEST_CASE(split_dyn_input_split_input_error) +{ + migraphx::onnx_options options; + options.default_dyn_dim_value = {10, 30}; + EXPECT(test::throws([&] { read_onnx("split_dyn_input_split_input_test.onnx", options); })); +} diff --git a/test/onnx/parse/split_minus_axis_test.cpp b/test/onnx/parse/split_minus_axis_test.cpp index c1200eb7878..6de7e1ff809 100644 --- a/test/onnx/parse/split_minus_axis_test.cpp +++ b/test/onnx/parse/split_minus_axis_test.cpp @@ -30,11 +30,11 @@ TEST_CASE(split_minus_axis_test) auto* mm = p.get_main_module(); auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}}); auto r1 = mm->add_instruction( - migraphx::make_op("slice", {{"axes", {-1}}, {"starts", {0}}, {"ends", {5}}}), input); + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {5}}}), input); auto r2 = mm->add_instruction( - migraphx::make_op("slice", {{"axes", {-1}}, {"starts", {5}}, {"ends", {10}}}), input); + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {5}}, {"ends", {10}}}), input); auto r3 = mm->add_instruction( - migraphx::make_op("slice", {{"axes", {-1}}, {"starts", {10}}, {"ends", {15}}}), input); + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {10}}, {"ends", {15}}}), input); mm->add_return({r1, r2, r3}); auto prog = read_onnx("split_minus_axis_test.onnx"); diff --git a/test/onnx/split_dyn_input_dyn_split_axis_test.onnx b/test/onnx/split_dyn_input_dyn_split_axis_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..dbd86ea7c3cc2d3d630328d143042e2505d47bf6 GIT binary patch literal 199 zcmdew4 zgFza#xVczB>LeHzFftKkSd<{hBpxm<4n`q1E(R_p4*nz|E|AH{Vys{>BP?RZsA5h` G0>S`Rp)GX) literal 0 HcmV?d00001 diff --git a/test/onnx/split_dyn_input_fixed_split_axis_test.onnx b/test/onnx/split_dyn_input_fixed_split_axis_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..d704ca3d7d337b18adf456fe1fe53402def98211 GIT binary patch literal 203 zcmdhXY1QK9Vgad>U|hh+M5J+1f*`YaxVShNh1j?lxR^NjlZ3cHW+RKSg2jxm Nh#8}bIWY+c0|0{YF9rYr literal 0 HcmV?d00001 diff --git a/test/onnx/split_dyn_input_split_attr_test.onnx b/test/onnx/split_dyn_input_split_attr_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..3a4e769d53a0153dfcd943ca78415716ced52d3b GIT binary patch literal 209 zcmdZkhwfuTpWx-Y+MXnOdR}4 XLR=twki}TRVn$fRj8Vm$m;{6Yi4QI@ literal 0 HcmV?d00001 diff --git a/test/onnx/split_dyn_input_split_input_test.onnx b/test/onnx/split_dyn_input_split_input_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..0fa68022dc3ad964e2f510a9f075bf5bb6eb6bb5 GIT binary patch literal 249 zcmd7sg8v(Jg5^FHXHZ5*0mc)w8 zVhM%?j7)@G5G4pSO^An!i-S>!jf;VciGx2$hzn#avKT8^%m|B^F{+polYlS)IYT+B literal 0 HcmV?d00001 diff --git a/test/onnx/verify/split_dyn_verify_test.cpp b/test/onnx/verify/split_dyn_verify_test.cpp new file mode 100644 index 00000000000..da362a93f97 --- /dev/null +++ b/test/onnx/verify/split_dyn_verify_test.cpp @@ -0,0 +1,101 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include + +TEST_CASE(split_dyn_input_fixed_split_axis_test) +{ + migraphx::onnx_options options; + options.default_dyn_dim_value = {10, 30}; + auto p = read_onnx("split_dyn_input_fixed_split_axis_test.onnx", options); + p.compile(migraphx::make_target("ref")); + migraphx::shape data_shape{migraphx::shape::float_type, {10, 15}}; + std::vector data(150, 1.23); + migraphx::parameter_map pm; + pm["x"] = migraphx::argument(data_shape, data.data()); + auto results = p.eval(pm); + std::vector result_vector; + std::vector gold(50, 1.23); + + results.at(0).visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); + + results.at(1).visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); + + results.at(2).visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(split_dyn_input_dyn_split_axis_test0) +{ + migraphx::onnx_options options; + options.default_dyn_dim_value = {10, 30}; + auto p = read_onnx("split_dyn_input_dyn_split_axis_test.onnx", options); + p.compile(migraphx::make_target("ref")); + migraphx::shape data_shape{migraphx::shape::float_type, {12, 15}}; + std::vector data(180, 1.23); + migraphx::parameter_map pm; + pm["x"] = migraphx::argument(data_shape, data.data()); + auto results = p.eval(pm); + std::vector result_vector; + std::vector gold(60, 1.23); + + results.at(0).visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); + + results.at(1).visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); + + results.at(2).visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +// different static shape that doesn't split evenly +TEST_CASE(split_dyn_input_dyn_split_axis_test1) +{ + migraphx::onnx_options options; + options.default_dyn_dim_value = {10, 30}; + auto p = read_onnx("split_dyn_input_dyn_split_axis_test.onnx", options); + p.compile(migraphx::make_target("ref")); + migraphx::shape data_shape{migraphx::shape::float_type, {20, 15}}; + std::vector data(300, 1.23); + migraphx::parameter_map pm; + pm["x"] = migraphx::argument(data_shape, data.data()); + auto results = p.eval(pm); + std::vector result_vector; + std::vector gold_1(105, 1.23); + + results.at(0).visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold_1)); + + results.at(1).visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold_1)); + + std::vector gold_2(90, 1.23); + results.at(2).visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold_2)); +}