From 74333b06b4af18e4dec72098c37b6050033b61f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dino=20Musi=C4=87?= Date: Tue, 12 Mar 2024 15:04:49 +0000 Subject: [PATCH 01/18] Implementation start --- src/CMakeLists.txt | 1 + src/include/migraphx/op/scan.hpp | 140 ++++++++++++++++++++++++++++ src/include/migraphx/operators.hpp | 1 + src/onnx/parse_scan.cpp | 141 +++++++++++++++++++++++++++++ test/py/onnx_backend_test.py | 4 +- 5 files changed, 286 insertions(+), 1 deletion(-) create mode 100644 src/include/migraphx/op/scan.hpp create mode 100644 src/onnx/parse_scan.cpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index c20bc0a0a6c..f5c04a81995 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -226,6 +226,7 @@ register_migraphx_ops( rsqrt run_on_target scalar + scan scatter_none scatter_add scatter_mul diff --git a/src/include/migraphx/op/scan.hpp b/src/include/migraphx/op/scan.hpp new file mode 100644 index 00000000000..2026deed22e --- /dev/null +++ b/src/include/migraphx/op/scan.hpp @@ -0,0 +1,140 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#ifndef MIGRAPHX_GUARD_OPERATORS_LOOP_HPP +#define MIGRAPHX_GUARD_OPERATORS_LOOP_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct scan : op_name +{ + int64_t iterations; + int64_t num_scan_inputs; + int64_t num_state_vars; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.iterations, "iterations"), + f(self.num_scan_inputs, "num_scan_inputs"), + f(self.num_state_vars, "num_state_vars")); + } + + shape compute_shape(const std::vector& inputs, std::vector mods) const + { + std::cout << "SCAN COMPUTE SHAPE" << std::endl; + assert(mods.size() == 1); + check_shapes{inputs, *this}.standard(); + auto mod = mods.front(); + std::cout << "Inputs size: " << inputs.size() << std::endl; + // The module has: + // N + M inputs (M = num_scan_inputs), same as the Scan node itself + // N + K outputs, same as the Scan node itself + auto output_shapes = mod->get_output_shapes(); + std::cout << to_string_range(output_shapes) << std::endl; + // Can't use mod->get_parameter_shapes() like this, parameters can include output parameters + // as well + // auto N = mod->get_parameter_shapes().size() - num_scan_inputs; + auto N = num_state_vars; + std::transform(output_shapes.begin() + N, + output_shapes.end(), + output_shapes.begin() + N, + [&](const auto& s) { + auto lens = s.lens(); + lens.insert(lens.begin(), iterations); + return shape{s.type(), lens}; + }); + + std::cout << "OUTPUT SHAPES: " << std::endl; + std::cout << to_string_range(output_shapes) << std::endl; + + return shape{output_shapes}; + } + + std::unordered_map get_output_params(const module_ref mod) const + { + std::unordered_map ret; + const std::string output_prefix = "#output_"; + + const auto& param_names = mod->get_parameter_names(); + for(const auto& name : param_names) + { + auto n = name.find(output_prefix); + if(n == std::string::npos) + continue; + int idx = std::stoi(name.substr(n + output_prefix.size())); + ret[name] = idx; + } + + return ret; + } + + argument compute(context& ctx, + const shape& out_shape, + const std::vector& args, + const std::vector& mods, + const std::function( + module_ref&, const std::unordered_map&)>& run) const + { + std::cout << "SCAN COMPUTE" << std::endl; + std::cout << args.size() << std::endl; + std::cout << out_shape << std::endl; + assert(mods.size() == 1); + auto mod = mods.front(); + mod->debug_print(); + auto param_shapes = mod->get_parameter_shapes(); + for(const auto& s : param_shapes) + std::cout << s.first << " " << s.second << std::endl; + auto output_params = get_output_params(mod); + for(const auto& p : output_params) + std::cout << p.first << " " << p.second << std::endl; + for(int64_t i = 0; i < iterations; ++i) + { + // Prepare params + // Run + // Set up next iteration + } + return args[0]; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/operators.hpp b/src/include/migraphx/operators.hpp index 5166d97eecf..9d1ea1e94fc 100644 --- a/src/include/migraphx/operators.hpp +++ b/src/include/migraphx/operators.hpp @@ -113,6 +113,7 @@ #include #include #include +#include #include #include #include diff --git a/src/onnx/parse_scan.cpp b/src/onnx/parse_scan.cpp new file mode 100644 index 00000000000..323199ccbf6 --- /dev/null +++ b/src/onnx/parse_scan.cpp @@ -0,0 +1,141 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_scan : op_parser +{ + std::vector operators() const { return {{"Scan"}}; } + + std::vector parse(const op_desc& opd, + onnx_parser& parser, + onnx_parser::node_info info, + std::vector args) const + { + std::cout << "Scan parse" << std::endl; + if(not contains(info.attributes, "body")) + MIGRAPHX_THROW("Scan: body attribute required"); + + if(not contains(info.attributes, "num_scan_inputs")) + MIGRAPHX_THROW("Scan: num_scan_inputs attribute required"); + + const auto& body = info.attributes["body"].g(); + const auto num_scan_inputs = info.attributes["num_scan_inputs"].i(); + + auto sub_mod = parser.prog.create_module(info.name + "_scan"); + (void)parser.parse_graph(sub_mod, body); + + auto param_names = sub_mod->get_parameter_names(); + std::cout << param_names.size() << std::endl; + std::cout << args.size() << std::endl; + + // This does not hold for the opset 8 version of Scan, which has an optional first input + // that the other versions do not have + // if(param_names.size() != args.size()) + // MIGRAPHX_THROW("Scan: Number of inputs to Scan does not match the number of inputs to + // " + // "its subgraph"); + + std::vector scan_input_axes(num_scan_inputs, 0); + if(contains(info.attributes, "scan_input_axes")) + { + auto&& axes = info.attributes["scan_input_axes"].ints(); + // if(axes.size() != num_scan_inputs) + // { + // MIGRAPHX_THROW("Scan: Size of scan_input_axes needs to match num_scan_inputs"); + // } + scan_input_axes.assign(axes.begin(), axes.end()); + // TODO add axes normalization + } + std::cout << "scan_input_axes: " << to_string_range(scan_input_axes) << std::endl; + + // TODO check that scan_input_axes lens match for every arg + auto num_inits = args.size() - num_scan_inputs; + size_t num_iters = args[num_inits]->get_shape().lens()[scan_input_axes[0]]; + + std::cout << "num_scan_inputs: " << num_scan_inputs << std::endl; + std::cout << "args shapes: " << to_string_range(to_shapes(args)) << std::endl; + sub_mod->debug_print(); + auto param_shapes = sub_mod->get_parameter_shapes(); + std::cout << to_string_range(param_names) << std::endl; + for(const auto& s : param_shapes) + std::cout << s.first << " " << s.second << std::endl; + std::cout << "num_iters: " << num_iters << std::endl; + + auto N = args.size() - num_scan_inputs; + std::vector alt_args(args.begin(), args.begin() + N); + for(int64_t i = 0; i < num_iters; ++i) + { + std::transform( + args.begin() + N, args.end(), std::back_inserter(alt_args), [&](const auto& arg) { + auto slice = info.add_instruction( + make_op("slice", {{"axes", {0}}, {"starts", {i}}, {"ends", {i + 1}}}), arg); + return info.add_instruction(make_op("squeeze", {{"axes", {0}}}), slice); + }); + } + + std::cout << "Whole args" << std::endl; + for(const auto& ins : args) + { + ins->debug_print(); + } + std::cout << "Sliced args" << std::endl; + for(const auto& ins : alt_args) + { + ins->debug_print(); + } + + auto ret = info.add_instruction(make_op("scan", + {{"iterations", num_iters}, + {"num_scan_inputs", num_scan_inputs}, + {"num_state_vars", N}}), + alt_args, + {sub_mod}); + + auto out_s = ret->get_shape(); + assert(out_s.type() == shape::tuple_type); + + const auto& vec_shapes = out_s.sub_shapes(); + std::vector out_inss; + for(std::size_t i = 0; i < vec_shapes.size(); ++i) + { + auto r = info.add_instruction(make_op("get_tuple_elem", {{"index", i}}), ret); + out_inss.push_back(r); + } + + return out_inss; + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/test/py/onnx_backend_test.py b/test/py/onnx_backend_test.py index a6956c6199e..f7a20492958 100644 --- a/test/py/onnx_backend_test.py +++ b/test/py/onnx_backend_test.py @@ -814,7 +814,9 @@ def create_backend_test(testname=None, target_device=None): c2.set_device(target_device) backend_test = MIGraphXBackendTest(c2, __name__) - if testname: + if True: + backend_test.include(r'test_scan9_sum_cpu') + elif testname: backend_test.include(testname + '.*') else: # Onnx Operator tests From a3011f377fbaf04051bce8090d8de44bdb2a0cd0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dino=20Musi=C4=87?= Date: Tue, 19 Mar 2024 11:42:13 +0000 Subject: [PATCH 02/18] Implement scan base case --- src/include/migraphx/op/scan.hpp | 67 ++++++++--------- src/onnx/parse_celu.cpp | 1 + src/onnx/parse_scan.cpp | 119 ++++++++++++++++--------------- test/onnx/gen_onnx.py | 47 ++++++++++++ test/onnx/scan_test.onnx | 45 ++++++++++++ test/onnx/verify/scan_test.cpp | 72 +++++++++++++++++++ 6 files changed, 255 insertions(+), 96 deletions(-) create mode 100644 test/onnx/scan_test.onnx create mode 100644 test/onnx/verify/scan_test.cpp diff --git a/src/include/migraphx/op/scan.hpp b/src/include/migraphx/op/scan.hpp index 2026deed22e..7a251a3e23c 100644 --- a/src/include/migraphx/op/scan.hpp +++ b/src/include/migraphx/op/scan.hpp @@ -57,33 +57,21 @@ struct scan : op_name shape compute_shape(const std::vector& inputs, std::vector mods) const { - std::cout << "SCAN COMPUTE SHAPE" << std::endl; assert(mods.size() == 1); check_shapes{inputs, *this}.standard(); auto mod = mods.front(); - std::cout << "Inputs size: " << inputs.size() << std::endl; - // The module has: - // N + M inputs (M = num_scan_inputs), same as the Scan node itself - // N + K outputs, same as the Scan node itself - auto output_shapes = mod->get_output_shapes(); - std::cout << to_string_range(output_shapes) << std::endl; - // Can't use mod->get_parameter_shapes() like this, parameters can include output parameters - // as well - // auto N = mod->get_parameter_shapes().size() - num_scan_inputs; - auto N = num_state_vars; - std::transform(output_shapes.begin() + N, - output_shapes.end(), - output_shapes.begin() + N, - [&](const auto& s) { - auto lens = s.lens(); - lens.insert(lens.begin(), iterations); - return shape{s.type(), lens}; - }); - - std::cout << "OUTPUT SHAPES: " << std::endl; - std::cout << to_string_range(output_shapes) << std::endl; - - return shape{output_shapes}; + // The module has N + K outputs + auto mod_output_shapes = mod->get_output_shapes(); + std::vector op_output_shapes{mod_output_shapes.begin(), + mod_output_shapes.begin() + num_state_vars}; + auto K = mod_output_shapes.size() - num_state_vars; + op_output_shapes.reserve(num_state_vars + iterations * K); + for(auto i = 0; i < iterations; ++i) + op_output_shapes.insert(op_output_shapes.end(), + mod_output_shapes.begin() + num_state_vars, + mod_output_shapes.end()); + + return shape{op_output_shapes}; } std::unordered_map get_output_params(const module_ref mod) const @@ -111,25 +99,28 @@ struct scan : op_name const std::function( module_ref&, const std::unordered_map&)>& run) const { - std::cout << "SCAN COMPUTE" << std::endl; - std::cout << args.size() << std::endl; - std::cout << out_shape << std::endl; assert(mods.size() == 1); auto mod = mods.front(); - mod->debug_print(); auto param_shapes = mod->get_parameter_shapes(); - for(const auto& s : param_shapes) - std::cout << s.first << " " << s.second << std::endl; - auto output_params = get_output_params(mod); - for(const auto& p : output_params) - std::cout << p.first << " " << p.second << std::endl; - for(int64_t i = 0; i < iterations; ++i) + auto param_names = mod->get_parameter_names(); + + auto K = mod->get_output_shapes().size() - num_state_vars; + parameter_map pm; + std::vector ret{args.begin(), args.begin() + num_state_vars}; + for(auto i = 0; i < iterations; ++i) { - // Prepare params - // Run - // Set up next iteration + for(auto j = 0; j < num_state_vars; ++j) + pm[param_names[j]] = ret[j]; + for(auto j = num_state_vars; j < num_state_vars + K; ++j) + pm[param_names[j]] = args[i * K + j]; + + auto mod_output = run(mod, pm); + + std::copy(mod_output.begin(), mod_output.begin() + num_state_vars, ret.begin()); + ret.insert(ret.end(), mod_output.begin() + num_state_vars, mod_output.end()); } - return args[0]; + + return argument{ret}; } }; diff --git a/src/onnx/parse_celu.cpp b/src/onnx/parse_celu.cpp index 3bd8fd62e38..05c434d9bc2 100644 --- a/src/onnx/parse_celu.cpp +++ b/src/onnx/parse_celu.cpp @@ -35,6 +35,7 @@ struct parse_celu : op_parser { std::vector operators() const { return {{"Celu"}}; } + instruction_ref parse(const op_desc&, const onnx_parser&, const onnx_parser::node_info& info, diff --git a/src/onnx/parse_scan.cpp b/src/onnx/parse_scan.cpp index 323199ccbf6..38058859535 100644 --- a/src/onnx/parse_scan.cpp +++ b/src/onnx/parse_scan.cpp @@ -21,6 +21,7 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ +#include "migraphx/instruction_ref.hpp" #include #include #include @@ -41,57 +42,48 @@ struct parse_scan : op_parser onnx_parser::node_info info, std::vector args) const { - std::cout << "Scan parse" << std::endl; + // NOTE Version 8 of the operator differs to all the later versions if(not contains(info.attributes, "body")) MIGRAPHX_THROW("Scan: body attribute required"); if(not contains(info.attributes, "num_scan_inputs")) MIGRAPHX_THROW("Scan: num_scan_inputs attribute required"); - const auto& body = info.attributes["body"].g(); - const auto num_scan_inputs = info.attributes["num_scan_inputs"].i(); - - auto sub_mod = parser.prog.create_module(info.name + "_scan"); + const auto& body = info.attributes["body"].g(); + auto sub_mod = parser.prog.create_module(info.name + "_scan"); (void)parser.parse_graph(sub_mod, body); - auto param_names = sub_mod->get_parameter_names(); - std::cout << param_names.size() << std::endl; - std::cout << args.size() << std::endl; - - // This does not hold for the opset 8 version of Scan, which has an optional first input - // that the other versions do not have - // if(param_names.size() != args.size()) - // MIGRAPHX_THROW("Scan: Number of inputs to Scan does not match the number of inputs to - // " - // "its subgraph"); + const auto num_scan_inputs = info.attributes["num_scan_inputs"].i(); std::vector scan_input_axes(num_scan_inputs, 0); if(contains(info.attributes, "scan_input_axes")) { auto&& axes = info.attributes["scan_input_axes"].ints(); - // if(axes.size() != num_scan_inputs) - // { - // MIGRAPHX_THROW("Scan: Size of scan_input_axes needs to match num_scan_inputs"); - // } scan_input_axes.assign(axes.begin(), axes.end()); - // TODO add axes normalization + // Validate: Size of scan_input_axes must be equal to num_scan_inputs + // Perform: Normalize the axes } - std::cout << "scan_input_axes: " << to_string_range(scan_input_axes) << std::endl; - - // TODO check that scan_input_axes lens match for every arg - auto num_inits = args.size() - num_scan_inputs; - size_t num_iters = args[num_inits]->get_shape().lens()[scan_input_axes[0]]; - - std::cout << "num_scan_inputs: " << num_scan_inputs << std::endl; - std::cout << "args shapes: " << to_string_range(to_shapes(args)) << std::endl; - sub_mod->debug_print(); - auto param_shapes = sub_mod->get_parameter_shapes(); - std::cout << to_string_range(param_names) << std::endl; - for(const auto& s : param_shapes) - std::cout << s.first << " " << s.second << std::endl; - std::cout << "num_iters: " << num_iters << std::endl; - - auto N = args.size() - num_scan_inputs; + // Validate: The scan axis len across each scan_in must be equal + + // TODO + // Parse scan_input_directions + // Validate: Size of scan_input_directions must be equal to num_scan_inputs + // Validate: 0 and 1 are only allowed values + + // TODO + // Parse scan_output_axes + // Validate: Size of scan_output_axes must be equal to K + // Perform: Normalize the axes + // Validate: Values must be in range[0, r-1] + + // TODO + // Parse scan_output_directions + // Validate: Size of scan_output_directions must be equal to K + // Validate: 0 and 1 are only allowed values + + auto N = args.size() - num_scan_inputs; + size_t num_iters = args[N]->get_shape().lens()[scan_input_axes[0]]; + std::vector alt_args(args.begin(), args.begin() + N); for(int64_t i = 0; i < num_iters; ++i) { @@ -103,36 +95,47 @@ struct parse_scan : op_parser }); } - std::cout << "Whole args" << std::endl; - for(const auto& ins : args) + // Inputs: init_states, array of pre-sliced scan_inputs + // N + M * num_iters number of inputs + auto scan = info.add_instruction(make_op("scan", + {{"iterations", num_iters}, + {"num_scan_inputs", num_scan_inputs}, + {"num_state_vars", N}}), + alt_args, + {sub_mod}); + // Outputs: final_states, array of scan_output_elements + // N + K * num_iters number of outputs + + std::vector ret; + auto K = sub_mod->get_output_shapes().size() - N; + for(auto i = 0; i < N; ++i) { - ins->debug_print(); + auto get = info.add_instruction(make_op("get_tuple_elem", {{"index", i}}), scan); + ret.push_back(get); } - std::cout << "Sliced args" << std::endl; - for(const auto& ins : alt_args) + + for(auto i = N; i < N + K; ++i) { - ins->debug_print(); + auto get = info.add_instruction(make_op("get_tuple_elem", {{"index", i}}), scan); + auto usq = info.add_instruction(make_op("unsqueeze", {{"axes", {0}}}), get); + ret.push_back(usq); } - auto ret = info.add_instruction(make_op("scan", - {{"iterations", num_iters}, - {"num_scan_inputs", num_scan_inputs}, - {"num_state_vars", N}}), - alt_args, - {sub_mod}); - - auto out_s = ret->get_shape(); - assert(out_s.type() == shape::tuple_type); - - const auto& vec_shapes = out_s.sub_shapes(); - std::vector out_inss; - for(std::size_t i = 0; i < vec_shapes.size(); ++i) + for(auto i = 1; i < num_iters; ++i) { - auto r = info.add_instruction(make_op("get_tuple_elem", {{"index", i}}), ret); - out_inss.push_back(r); + for(auto j = 0; j < K; ++j) + { + auto tuple_idx = N + i * K + j; + auto get = + info.add_instruction(make_op("get_tuple_elem", {{"index", tuple_idx}}), scan); + auto usq = info.add_instruction(make_op("unsqueeze", {{"axes", {0}}}), get); + auto concat = + info.add_instruction(make_op("concat", {{"axis", 0}}), ret[N + j], usq); + ret[N + j] = concat; + } } - return out_inss; + return ret; } }; diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index 66605a91d35..82f4d07aa2f 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -10684,3 +10684,50 @@ def where_mixed_test(): outputs=['z']) return ([node], [c, x, y], [z]) + + +@onnx_test() +def scan_test(): + sum_in = onnx.helper.make_tensor_value_info( + "sum_in", onnx.TensorProto.FLOAT, [2] + ) + next = onnx.helper.make_tensor_value_info( + "next", onnx.TensorProto.FLOAT, [2] + ) + sum_out = onnx.helper.make_tensor_value_info( + "sum_out", onnx.TensorProto.FLOAT, [2] + ) + scan_out = onnx.helper.make_tensor_value_info( + "scan_out", onnx.TensorProto.FLOAT, [2] + ) + add_node = onnx.helper.make_node( + "Add", inputs=["sum_in", "next"], outputs=["sum_out"] + ) + id_node = onnx.helper.make_node( + "Identity", inputs=["sum_out"], outputs=["scan_out"] + ) + scan_body = onnx.helper.make_graph( + [add_node, id_node], "scan_body", [sum_in, next], [sum_out, scan_out] + ) + + init_state = onnx.helper.make_tensor_value_info( + "init_state", onnx.TensorProto.FLOAT, [2] + ) + scan_ins = onnx.helper.make_tensor_value_info( + "scan_ins", onnx.TensorProto.FLOAT, [3, 2] + ) + final_state = onnx.helper.make_tensor_value_info( + "final_state", onnx.TensorProto.FLOAT, [2] + ) + scan_outs = onnx.helper.make_tensor_value_info( + "scan_outs", onnx.TensorProto.FLOAT, [3, 2] + ) + node = onnx.helper.make_node( + "Scan", + inputs=["init_state", "scan_ins"], + outputs=["final_state", "scan_outs"], + num_scan_inputs=1, + body=scan_body, + ) + + return ([node], [init_state, scan_ins], [final_state, scan_outs]) diff --git a/test/onnx/scan_test.onnx b/test/onnx/scan_test.onnx new file mode 100644 index 00000000000..ae29be2be54 --- /dev/null +++ b/test/onnx/scan_test.onnx @@ -0,0 +1,45 @@ +  scan_test: + + +init_state +scan_ins final_state scan_outs"Scan* +body2 + +sum_in +nextsum_out"Add + +sum_outscan_out"Identity scan_bodyZ +sum_in + + +Z +next + + +b +sum_out + + +b +scan_out + + +* +num_scan_inputs scan_testZ + +init_state + + +Z +scan_ins +  + +b + final_state + + +b + scan_outs +  + +B \ No newline at end of file diff --git a/test/onnx/verify/scan_test.cpp b/test/onnx/verify/scan_test.cpp new file mode 100644 index 00000000000..7d0ac0ecfde --- /dev/null +++ b/test/onnx/verify/scan_test.cpp @@ -0,0 +1,72 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "migraphx/argument.hpp" +#include "migraphx/generate.hpp" +#include "migraphx/module.hpp" +#include "migraphx/onnx.hpp" +#include "migraphx/shape.hpp" +#include +#include +#include + +static migraphx::shape make_shape(const std::vector& lens) +{ + return migraphx::shape{migraphx::shape::float_type, lens}; +} + +static std::vector arg_to_vec(const migraphx::argument& arg) +{ + std::vector ret; + arg.visit([&](auto output) { ret.assign(output.begin(), output.end()); }); + return ret; +} + +TEST_CASE(scan_test) +{ + auto prog = migraphx::parse_onnx("scan_test.onnx"); + prog.compile(migraphx::make_target("ref")); + + migraphx::parameter_map pm; + + migraphx::shape init_state_sh{migraphx::shape::float_type, {2}}; + std::vector init_state{0, 0}; + pm["init_state"] = migraphx::argument(init_state_sh, init_state.data()); + + migraphx::shape scan_ins_sh{migraphx::shape::float_type, {3, 2}}; + std::vector scan_ins{1, 2, 3, 4, 5, 6}; + pm["scan_ins"] = migraphx::argument(scan_ins_sh, scan_ins.data()); + + auto result = prog.eval(pm); + auto final_state = result[0]; + auto scan_out = result[1]; + + EXPECT(final_state.get_shape() == make_shape({2})); + std::vector final_state_gold{9.f, 12.f}; + EXPECT(arg_to_vec(final_state) == final_state_gold); + + EXPECT(scan_out.get_shape() == make_shape({3, 2})); + std::vector scan_out_gold{1.f, 2.f, 4.f, 6.f, 9.f, 12.f}; + EXPECT(arg_to_vec(scan_out) == scan_out_gold); +} From 72d37165df872a8d8f3c4d9420c4e3b79cdab356 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dino=20Musi=C4=87?= Date: Wed, 20 Mar 2024 08:52:41 +0000 Subject: [PATCH 03/18] Implement scan output direction and axes attribute support --- src/include/migraphx/op/scan.hpp | 4 +- src/onnx/parse_scan.cpp | 84 +++++++++++++++++++----- test/onnx/gen_onnx.py | 98 ++++++++++++++++------------ test/onnx/scan_test.onnx | 45 ------------- test/onnx/scan_test1.onnx | Bin 0 -> 564 bytes test/onnx/scan_test2.onnx | Bin 0 -> 597 bytes test/onnx/scan_test3.onnx | Bin 0 -> 633 bytes test/onnx/verify/scan_test.cpp | 106 +++++++++++++++++++++++++++---- 8 files changed, 223 insertions(+), 114 deletions(-) delete mode 100644 test/onnx/scan_test.onnx create mode 100644 test/onnx/scan_test1.onnx create mode 100644 test/onnx/scan_test2.onnx create mode 100644 test/onnx/scan_test3.onnx diff --git a/src/include/migraphx/op/scan.hpp b/src/include/migraphx/op/scan.hpp index 7a251a3e23c..7209188d4a9 100644 --- a/src/include/migraphx/op/scan.hpp +++ b/src/include/migraphx/op/scan.hpp @@ -111,8 +111,8 @@ struct scan : op_name { for(auto j = 0; j < num_state_vars; ++j) pm[param_names[j]] = ret[j]; - for(auto j = num_state_vars; j < num_state_vars + K; ++j) - pm[param_names[j]] = args[i * K + j]; + for(auto j = num_state_vars; j < num_state_vars + num_scan_inputs; ++j) + pm[param_names[j]] = args[i * num_scan_inputs + j]; auto mod_output = run(mod, pm); diff --git a/src/onnx/parse_scan.cpp b/src/onnx/parse_scan.cpp index 38058859535..723addaaf99 100644 --- a/src/onnx/parse_scan.cpp +++ b/src/onnx/parse_scan.cpp @@ -21,7 +21,11 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ +#include "migraphx/errors.hpp" #include "migraphx/instruction_ref.hpp" +#include +#include +#include #include #include #include @@ -54,6 +58,9 @@ struct parse_scan : op_parser (void)parser.parse_graph(sub_mod, body); const auto num_scan_inputs = info.attributes["num_scan_inputs"].i(); + auto N = args.size() - num_scan_inputs; + auto sub_mod_output_shapes = sub_mod->get_output_shapes(); + auto K = sub_mod_output_shapes.size() - N; std::vector scan_input_axes(num_scan_inputs, 0); if(contains(info.attributes, "scan_input_axes")) @@ -70,20 +77,50 @@ struct parse_scan : op_parser // Validate: Size of scan_input_directions must be equal to num_scan_inputs // Validate: 0 and 1 are only allowed values - // TODO - // Parse scan_output_axes - // Validate: Size of scan_output_axes must be equal to K - // Perform: Normalize the axes - // Validate: Values must be in range[0, r-1] + // SCAN OUTPUT AXES + std::vector scan_output_axes(K, 0); + if(contains(info.attributes, "scan_output_axes")) + { + auto&& axes = info.attributes["scan_output_axes"].ints(); + scan_output_axes.assign(axes.begin(), axes.end()); + + if(scan_output_axes.size() != K) + MIGRAPHX_THROW("Number of scan output axes (" + to_string(scan_output_axes.size()) + + ") does not match number of body scan outputs(" + to_string(K) + + ")"); + + std::vector ndims; + ndims.reserve(K); + std::transform(sub_mod_output_shapes.begin() + N, + sub_mod_output_shapes.end(), + std::back_inserter(ndims), + [](const shape& sh) { return sh.ndim() + 1; }); + normalize_axes(scan_output_axes, ndims); + } + std::cout << to_string_range(scan_output_axes) << std::endl; + // SCAN OUTPUT AXES - // TODO - // Parse scan_output_directions - // Validate: Size of scan_output_directions must be equal to K - // Validate: 0 and 1 are only allowed values + // SCAN OUTPUT DIRECTIONS + std::vector scan_output_directions(K, 0); + if(contains(info.attributes, "scan_output_directions")) + { + auto&& dirs = info.attributes["scan_output_directions"].ints(); + scan_output_directions.assign(dirs.begin(), dirs.end()); + + if(scan_output_directions.size() != K) + MIGRAPHX_THROW("Number of scan output directions (" + + to_string(scan_output_directions.size()) + + ") does not match number of body scan outputs(" + to_string(K) + + ")"); + + if(any_of(scan_output_directions, [](auto i) { return i != 0 and i != 1; })) + MIGRAPHX_THROW( + "Scan output directions may contain only 1s and 0s, actual values: " + + to_string_range(scan_output_directions)); + } + // SCAN OUTPUT DIRECTIONS - auto N = args.size() - num_scan_inputs; size_t num_iters = args[N]->get_shape().lens()[scan_input_axes[0]]; - std::vector alt_args(args.begin(), args.begin() + N); for(int64_t i = 0; i < num_iters; ++i) { @@ -107,7 +144,6 @@ struct parse_scan : op_parser // N + K * num_iters number of outputs std::vector ret; - auto K = sub_mod->get_output_shapes().size() - N; for(auto i = 0; i < N; ++i) { auto get = info.add_instruction(make_op("get_tuple_elem", {{"index", i}}), scan); @@ -117,7 +153,8 @@ struct parse_scan : op_parser for(auto i = N; i < N + K; ++i) { auto get = info.add_instruction(make_op("get_tuple_elem", {{"index", i}}), scan); - auto usq = info.add_instruction(make_op("unsqueeze", {{"axes", {0}}}), get); + auto scan_axis = scan_output_axes[i - N]; + auto usq = info.add_instruction(make_op("unsqueeze", {{"axes", {scan_axis}}}), get); ret.push_back(usq); } @@ -128,15 +165,32 @@ struct parse_scan : op_parser auto tuple_idx = N + i * K + j; auto get = info.add_instruction(make_op("get_tuple_elem", {{"index", tuple_idx}}), scan); - auto usq = info.add_instruction(make_op("unsqueeze", {{"axes", {0}}}), get); + auto scan_axis = scan_output_axes[j]; + auto usq = info.add_instruction(make_op("unsqueeze", {{"axes", {scan_axis}}}), get); + auto dir = scan_output_directions[j]; + std::vector concat_args(2, usq); + concat_args[dir] = ret[N + j]; auto concat = - info.add_instruction(make_op("concat", {{"axis", 0}}), ret[N + j], usq); + info.add_instruction(make_op("concat", {{"axis", scan_axis}}), concat_args); ret[N + j] = concat; } } return ret; } + + void normalize_axes(std::vector& axes, const std::vector& ndims) const + { + auto normalize_axis = [=](int64_t axis, int64_t ndim) { + if(axis < -ndim or axis >= ndim) + MIGRAPHX_THROW("Axis value {" + to_string(axis) + "} out of range [" + + to_string(-ndim) + ", " + to_string(ndim) + ")"); + + return axis < 0 ? ndim + axis : axis; + }; + + std::transform(axes.begin(), axes.end(), ndims.begin(), axes.begin(), normalize_axis); + } }; } // namespace onnx diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index 82f4d07aa2f..3226e423c00 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -10686,48 +10686,68 @@ def where_mixed_test(): return ([node], [c, x, y], [z]) -@onnx_test() -def scan_test(): - sum_in = onnx.helper.make_tensor_value_info( - "sum_in", onnx.TensorProto.FLOAT, [2] - ) - next = onnx.helper.make_tensor_value_info( - "next", onnx.TensorProto.FLOAT, [2] - ) - sum_out = onnx.helper.make_tensor_value_info( - "sum_out", onnx.TensorProto.FLOAT, [2] - ) - scan_out = onnx.helper.make_tensor_value_info( - "scan_out", onnx.TensorProto.FLOAT, [2] - ) - add_node = onnx.helper.make_node( - "Add", inputs=["sum_in", "next"], outputs=["sum_out"] - ) - id_node = onnx.helper.make_node( - "Identity", inputs=["sum_out"], outputs=["scan_out"] - ) - scan_body = onnx.helper.make_graph( - [add_node, id_node], "scan_body", [sum_in, next], [sum_out, scan_out] - ) - - init_state = onnx.helper.make_tensor_value_info( - "init_state", onnx.TensorProto.FLOAT, [2] - ) - scan_ins = onnx.helper.make_tensor_value_info( - "scan_ins", onnx.TensorProto.FLOAT, [3, 2] - ) - final_state = onnx.helper.make_tensor_value_info( - "final_state", onnx.TensorProto.FLOAT, [2] - ) - scan_outs = onnx.helper.make_tensor_value_info( - "scan_outs", onnx.TensorProto.FLOAT, [3, 2] - ) - node = onnx.helper.make_node( +def scan_test(scan_output_axes=[0, 0], scan_output_directions=[0, 0]): + sum_in = helper.make_tensor_value_info("sum_in", TensorProto.FLOAT, [2, 2]) + next = helper.make_tensor_value_info("next", TensorProto.FLOAT, [2, 2]) + sum_out = helper.make_tensor_value_info("sum_out", TensorProto.FLOAT, + [2, 2]) + scan_out1 = helper.make_tensor_value_info("scan_out1", TensorProto.FLOAT, + [2, 2]) + scan_out2 = helper.make_tensor_value_info("scan_out2", TensorProto.FLOAT, + [2]) + add = helper.make_node("Add", + inputs=["sum_in", "next"], + outputs=["sum_out"]) + id = helper.make_node("Identity", + inputs=["sum_out"], + outputs=["scan_out1"]) + reduce_sum = helper.make_node("ReduceSum", + axes=[0], + keepdims=0, + inputs=["sum_out"], + outputs=["scan_out2"]) + scan_body = helper.make_graph([add, id, reduce_sum], "scan_body", + [sum_in, next], + [sum_out, scan_out1, scan_out2]) + + init_state = helper.make_tensor_value_info("init_state", TensorProto.FLOAT, + [2, 2]) + scan_ins = helper.make_tensor_value_info("scan_ins", TensorProto.FLOAT, + [3, 2, 2]) + final_state = helper.make_tensor_value_info("final_state", + TensorProto.FLOAT, [2, 2]) + scan_outs1_sh = [2, 2, 2] + scan_outs1_sh[scan_output_axes[0]] = 3 + scan_outs1 = helper.make_tensor_value_info("scan_outs1", TensorProto.FLOAT, + scan_outs1_sh) + scan_outs2_sh = [2, 2] + scan_outs2_sh[scan_output_axes[1]] = 3 + scan_outs2 = helper.make_tensor_value_info("scan_outs2", TensorProto.FLOAT, + scan_outs2_sh) + node = helper.make_node( "Scan", inputs=["init_state", "scan_ins"], - outputs=["final_state", "scan_outs"], + outputs=["final_state", "scan_outs1", "scan_outs2"], num_scan_inputs=1, + scan_output_axes=scan_output_axes, + scan_output_directions=scan_output_directions, body=scan_body, ) - return ([node], [init_state, scan_ins], [final_state, scan_outs]) + return ([node], [init_state, + scan_ins], [final_state, scan_outs1, scan_outs2]) + + +@onnx_test() +def scan_test1(): + return scan_test() + + +@onnx_test() +def scan_test2(): + return scan_test(scan_output_directions=[1, 0]) + + +@onnx_test() +def scan_test3(): + return scan_test(scan_output_axes=[1, -1], scan_output_directions=[0, 1]) diff --git a/test/onnx/scan_test.onnx b/test/onnx/scan_test.onnx deleted file mode 100644 index ae29be2be54..00000000000 --- a/test/onnx/scan_test.onnx +++ /dev/null @@ -1,45 +0,0 @@ -  scan_test: - - -init_state -scan_ins final_state scan_outs"Scan* -body2 - -sum_in -nextsum_out"Add - -sum_outscan_out"Identity scan_bodyZ -sum_in - - -Z -next - - -b -sum_out - - -b -scan_out - - -* -num_scan_inputs scan_testZ - -init_state - - -Z -scan_ins -  - -b - final_state - - -b - scan_outs -  - -B \ No newline at end of file diff --git a/test/onnx/scan_test1.onnx b/test/onnx/scan_test1.onnx new file mode 100644 index 0000000000000000000000000000000000000000..564c8f0768a73b6a5ba812d2af34c17d1ef584b9 GIT binary patch literal 564 zcmZuuO-sW-6g1mtH`8Lv3N>Oal4EYkQ4rLdS9{APB+UwoW?R?|w7 z5#c=p>3g6(EEp; zy*6b2H-VjQvqShru*8Kb@>odNlJ7Memz5dxM$cJ>7229Pm6AaMkG^JHYM~E!Jb{Ov z^TxAc!`DUbPhcl|uAa`0*4UqcesI3aM6O)ZAPr|}C30II>=ILIF$zLS`JzSRY|u|{ zI_)8Ji_YzkFTL)-kO9<)*`?6YI`j}0KC(2>%y=71|zNJYl)LdKtUxw(`VTCI9} z>N^|LEI1%=lAP`7H-ZIRs3MDm1V{3nCZqCXF8cN}W3YtwOx%iOl$DQP8JAk}ZXGDc}#zw@hTpGX>Q2EUkoI>w{dv1X~0H!l+!~ym2bC^mqfAgqY^M0OXj0X2GZ^t*qh>0(rYLQE| zwyM$kDTRxELPLDG>}5t^X6O?`x7%@(JNTF6aI^=gKSZ7FsRjn$dR4TCuMjV2cJ&9q CS&Z!f literal 0 HcmV?d00001 diff --git a/test/onnx/scan_test3.onnx b/test/onnx/scan_test3.onnx new file mode 100644 index 0000000000000000000000000000000000000000..1a49a6f229b44b9d168711097a3c0134476f57cb GIT binary patch literal 633 zcmZuv%SyvQ6g3a+Vgg6w`I+cm$O5fgn+wf%p2q|=;bgDU!3Kde(@r#h@TU~B0<%L$u z-goHRD;ySlnqX%++lL<}%;7{8n?y*kCEsZ}Do=RPw;mXS1+-`CR7{90ee%S()I#s! z{vSN_ooCL9HD48*U<`J;<0?8j7^7ec arg_to_vec(const migraphx::argument& arg) return ret; } -TEST_CASE(scan_test) +TEST_CASE(scan_test1) { - auto prog = migraphx::parse_onnx("scan_test.onnx"); + auto prog = migraphx::parse_onnx("scan_test1.onnx"); prog.compile(migraphx::make_target("ref")); migraphx::parameter_map pm; - migraphx::shape init_state_sh{migraphx::shape::float_type, {2}}; - std::vector init_state{0, 0}; + migraphx::shape init_state_sh{migraphx::shape::float_type, {2, 2}}; + std::vector init_state(4, 0); pm["init_state"] = migraphx::argument(init_state_sh, init_state.data()); - migraphx::shape scan_ins_sh{migraphx::shape::float_type, {3, 2}}; - std::vector scan_ins{1, 2, 3, 4, 5, 6}; + migraphx::shape scan_ins_sh{migraphx::shape::float_type, {3, 2, 2}}; + std::vector scan_ins(12); + std::iota(scan_ins.begin(), scan_ins.end(), 1); pm["scan_ins"] = migraphx::argument(scan_ins_sh, scan_ins.data()); - auto result = prog.eval(pm); + auto result = prog.eval(pm); + EXPECT(result.size() == 3); + + auto final_state = result[0]; + auto scan_out1 = result[1]; + auto scan_out2 = result[2]; + + EXPECT(final_state.get_shape() == make_shape({2, 2})); + std::vector final_state_gold{15, 18, 21, 24}; + EXPECT(arg_to_vec(final_state) == final_state_gold); + + EXPECT(scan_out1.get_shape() == make_shape({3, 2, 2})); + std::vector scan_out1_gold{1, 2, 3, 4, 6, 8, 10, 12, 15, 18, 21, 24}; + EXPECT(arg_to_vec(scan_out1) == scan_out1_gold); + + EXPECT(scan_out2.get_shape() == make_shape({3, 2})); + std::vector scan_out2_gold{4, 6, 16, 20, 36, 42}; + EXPECT(arg_to_vec(scan_out2) == scan_out2_gold); +} + +TEST_CASE(scan_test2) +{ + auto prog = migraphx::parse_onnx("scan_test2.onnx"); + prog.compile(migraphx::make_target("ref")); + + migraphx::parameter_map pm; + + migraphx::shape init_state_sh{migraphx::shape::float_type, {2, 2}}; + std::vector init_state(4, 0); + pm["init_state"] = migraphx::argument(init_state_sh, init_state.data()); + + migraphx::shape scan_ins_sh{migraphx::shape::float_type, {3, 2, 2}}; + std::vector scan_ins(12); + std::iota(scan_ins.begin(), scan_ins.end(), 1); + pm["scan_ins"] = migraphx::argument(scan_ins_sh, scan_ins.data()); + + auto result = prog.eval(pm); + EXPECT(result.size() == 3); + auto final_state = result[0]; - auto scan_out = result[1]; + auto scan_out1 = result[1]; + auto scan_out2 = result[2]; - EXPECT(final_state.get_shape() == make_shape({2})); - std::vector final_state_gold{9.f, 12.f}; + EXPECT(final_state.get_shape() == make_shape({2, 2})); + std::vector final_state_gold{15, 18, 21, 24}; EXPECT(arg_to_vec(final_state) == final_state_gold); - EXPECT(scan_out.get_shape() == make_shape({3, 2})); - std::vector scan_out_gold{1.f, 2.f, 4.f, 6.f, 9.f, 12.f}; - EXPECT(arg_to_vec(scan_out) == scan_out_gold); + EXPECT(scan_out1.get_shape() == make_shape({3, 2, 2})); + std::vector scan_out1_gold{15, 18, 21, 24, 6, 8, 10, 12, 1, 2, 3, 4}; + EXPECT(arg_to_vec(scan_out1) == scan_out1_gold); + + EXPECT(scan_out2.get_shape() == make_shape({3, 2})); + std::vector scan_out2_gold{4, 6, 16, 20, 36, 42}; + EXPECT(arg_to_vec(scan_out2) == scan_out2_gold); +} + +TEST_CASE(scan_test3) +{ + auto prog = migraphx::parse_onnx("scan_test3.onnx"); + prog.compile(migraphx::make_target("ref")); + + migraphx::parameter_map pm; + + migraphx::shape init_state_sh{migraphx::shape::float_type, {2, 2}}; + std::vector init_state(4, 0); + pm["init_state"] = migraphx::argument(init_state_sh, init_state.data()); + + migraphx::shape scan_ins_sh{migraphx::shape::float_type, {3, 2, 2}}; + std::vector scan_ins(12); + std::iota(scan_ins.begin(), scan_ins.end(), 1); + pm["scan_ins"] = migraphx::argument(scan_ins_sh, scan_ins.data()); + + auto result = prog.eval(pm); + EXPECT(result.size() == 3); + + auto final_state = result[0]; + auto scan_out1 = result[1]; + auto scan_out2 = result[2]; + + EXPECT(final_state.get_shape() == make_shape({2, 2})); + std::vector final_state_gold{15, 18, 21, 24}; + EXPECT(arg_to_vec(final_state) == final_state_gold); + + EXPECT(scan_out1.get_shape() == make_shape({2, 3, 2})); + std::vector scan_out1_gold{1, 2, 6, 8, 15, 18, 3, 4, 10, 12, 21, 24}; + EXPECT(arg_to_vec(scan_out1) == scan_out1_gold); + + EXPECT(scan_out2.get_shape() == make_shape({2, 3})); + std::vector scan_out2_gold{36, 16, 4, 42, 20, 6}; + EXPECT(arg_to_vec(scan_out2) == scan_out2_gold); } From f7a7e9fe4100a9c88644daef07996668ed4f433e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dino=20Musi=C4=87?= Date: Wed, 20 Mar 2024 11:16:02 +0000 Subject: [PATCH 04/18] Implement scan_input_axes attribute support --- src/include/migraphx/op/scan.hpp | 2 +- src/onnx/parse_scan.cpp | 104 +++++++++++++++++++++---------- test/onnx/gen_onnx.py | 21 ++++++- test/onnx/scan_test4.onnx | Bin 0 -> 654 bytes test/onnx/scan_test5.onnx | Bin 0 -> 678 bytes test/onnx/verify/scan_test.cpp | 72 +++++++++++++++++++++ 6 files changed, 163 insertions(+), 36 deletions(-) create mode 100644 test/onnx/scan_test4.onnx create mode 100644 test/onnx/scan_test5.onnx diff --git a/src/include/migraphx/op/scan.hpp b/src/include/migraphx/op/scan.hpp index 7209188d4a9..629477ba2b7 100644 --- a/src/include/migraphx/op/scan.hpp +++ b/src/include/migraphx/op/scan.hpp @@ -58,7 +58,7 @@ struct scan : op_name shape compute_shape(const std::vector& inputs, std::vector mods) const { assert(mods.size() == 1); - check_shapes{inputs, *this}.standard(); + // check_shapes{inputs, *this}.standard(); auto mod = mods.front(); // The module has N + K outputs auto mod_output_shapes = mod->get_output_shapes(); diff --git a/src/onnx/parse_scan.cpp b/src/onnx/parse_scan.cpp index 723addaaf99..739d19d9df5 100644 --- a/src/onnx/parse_scan.cpp +++ b/src/onnx/parse_scan.cpp @@ -21,6 +21,7 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ +#include "migraphx/argument.hpp" #include "migraphx/errors.hpp" #include "migraphx/instruction_ref.hpp" #include @@ -57,25 +58,61 @@ struct parse_scan : op_parser auto sub_mod = parser.prog.create_module(info.name + "_scan"); (void)parser.parse_graph(sub_mod, body); - const auto num_scan_inputs = info.attributes["num_scan_inputs"].i(); - auto N = args.size() - num_scan_inputs; auto sub_mod_output_shapes = sub_mod->get_output_shapes(); - auto K = sub_mod_output_shapes.size() - N; + const auto M = info.attributes["num_scan_inputs"].i(); + const auto N = args.size() - M; + const auto K = sub_mod_output_shapes.size() - N; - std::vector scan_input_axes(num_scan_inputs, 0); + // NOTE Does not apply to opset 8 version + if(sub_mod->get_parameter_names().size() != N + M) + MIGRAPHX_THROW("Lorem ipsum"); + + // SCAN INPUT AXES + std::vector scan_input_axes(M, 0); if(contains(info.attributes, "scan_input_axes")) { auto&& axes = info.attributes["scan_input_axes"].ints(); scan_input_axes.assign(axes.begin(), axes.end()); - // Validate: Size of scan_input_axes must be equal to num_scan_inputs - // Perform: Normalize the axes + + if(scan_input_axes.size() != M) + MIGRAPHX_THROW("Number of scan input axes (" + to_string(scan_input_axes.size()) + + ") does not match number of scan inputs(" + to_string(M) + ")"); + + std::vector ndims; + ndims.reserve(M); + std::transform(args.begin() + N, + args.end(), + std::back_inserter(ndims), + [](instruction_ref arg) { return arg->get_shape().ndim(); }); + normalize_axes(scan_input_axes, ndims); } - // Validate: The scan axis len across each scan_in must be equal - // TODO - // Parse scan_input_directions - // Validate: Size of scan_input_directions must be equal to num_scan_inputs - // Validate: 0 and 1 are only allowed values + size_t num_iters = args[N]->get_shape().lens()[scan_input_axes[0]]; + for(auto i = 1; i < M; ++i) + { + if(args[i]->get_shape().lens()[scan_input_axes[i]] != num_iters) + MIGRAPHX_THROW("Lorem ipsum"); + } + // SCAN INPUT AXES + + // SCAN INPUT DIRECTIONS + std::vector scan_input_directions(M, 0); + if(contains(info.attributes, "scan_input_directions")) + { + auto&& dirs = info.attributes["scan_input_directions"].ints(); + scan_input_directions.assign(dirs.begin(), dirs.end()); + + if(scan_input_directions.size() != M) + MIGRAPHX_THROW("Number of scan input directions (" + + to_string(scan_input_directions.size()) + + ") does not match number of scan inputs(" + to_string(M) + ")"); + + if(any_of(scan_input_directions, [](auto i) { return i != 0 and i != 1; })) + MIGRAPHX_THROW( + "Scan output directions may contain only 1s and 0s, actual values: " + + to_string_range(scan_input_directions)); + } + // SCAN INPUT DIRECTIONS // SCAN OUTPUT AXES std::vector scan_output_axes(K, 0); @@ -97,7 +134,6 @@ struct parse_scan : op_parser [](const shape& sh) { return sh.ndim() + 1; }); normalize_axes(scan_output_axes, ndims); } - std::cout << to_string_range(scan_output_axes) << std::endl; // SCAN OUTPUT AXES // SCAN OUTPUT DIRECTIONS @@ -120,30 +156,33 @@ struct parse_scan : op_parser } // SCAN OUTPUT DIRECTIONS - size_t num_iters = args[N]->get_shape().lens()[scan_input_axes[0]]; std::vector alt_args(args.begin(), args.begin() + N); for(int64_t i = 0; i < num_iters; ++i) { - std::transform( - args.begin() + N, args.end(), std::back_inserter(alt_args), [&](const auto& arg) { - auto slice = info.add_instruction( - make_op("slice", {{"axes", {0}}, {"starts", {i}}, {"ends", {i + 1}}}), arg); - return info.add_instruction(make_op("squeeze", {{"axes", {0}}}), slice); - }); + for(auto j = 0; j < M; ++j) + { + auto dir = scan_input_directions[j]; + auto idx = (1 - dir) * i + dir * (num_iters - 1 - i); + auto scan_axis = scan_input_axes[j]; + auto slice = info.add_instruction( + make_op("slice", + {{"axes", {scan_axis}}, {"starts", {idx}}, {"ends", {idx + 1}}}), + args[N + j]); + alt_args.push_back( + info.add_instruction(make_op("squeeze", {{"axes", {scan_axis}}}), slice)); + } } - // Inputs: init_states, array of pre-sliced scan_inputs - // N + M * num_iters number of inputs - auto scan = info.add_instruction(make_op("scan", - {{"iterations", num_iters}, - {"num_scan_inputs", num_scan_inputs}, - {"num_state_vars", N}}), - alt_args, - {sub_mod}); - // Outputs: final_states, array of scan_output_elements - // N + K * num_iters number of outputs + // TODO check that alt_args shapes match sub_mod input parameter shapes + + auto scan = info.add_instruction( + make_op("scan", + {{"iterations", num_iters}, {"num_scan_inputs", M}, {"num_state_vars", N}}), + alt_args, + {sub_mod}); std::vector ret; + ret.reserve(N + K); for(auto i = 0; i < N; ++i) { auto get = info.add_instruction(make_op("get_tuple_elem", {{"index", i}}), scan); @@ -152,7 +191,7 @@ struct parse_scan : op_parser for(auto i = N; i < N + K; ++i) { - auto get = info.add_instruction(make_op("get_tuple_elem", {{"index", i}}), scan); + auto get = info.add_instruction(make_op("get_tuple_elem", {{"index", i}}), scan); auto scan_axis = scan_output_axes[i - N]; auto usq = info.add_instruction(make_op("unsqueeze", {{"axes", {scan_axis}}}), get); ret.push_back(usq); @@ -167,9 +206,8 @@ struct parse_scan : op_parser info.add_instruction(make_op("get_tuple_elem", {{"index", tuple_idx}}), scan); auto scan_axis = scan_output_axes[j]; auto usq = info.add_instruction(make_op("unsqueeze", {{"axes", {scan_axis}}}), get); - auto dir = scan_output_directions[j]; - std::vector concat_args(2, usq); - concat_args[dir] = ret[N + j]; + std::vector concat_args{usq, usq}; + concat_args[scan_output_directions[j]] = ret[N + j]; auto concat = info.add_instruction(make_op("concat", {{"axis", scan_axis}}), concat_args); ret[N + j] = concat; diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index 3226e423c00..fe4bbc1f379 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -10686,7 +10686,10 @@ def where_mixed_test(): return ([node], [c, x, y], [z]) -def scan_test(scan_output_axes=[0, 0], scan_output_directions=[0, 0]): +def scan_test(scan_input_axes=[0], + scan_input_directions=[0], + scan_output_axes=[0, 0], + scan_output_directions=[0, 0]): sum_in = helper.make_tensor_value_info("sum_in", TensorProto.FLOAT, [2, 2]) next = helper.make_tensor_value_info("next", TensorProto.FLOAT, [2, 2]) sum_out = helper.make_tensor_value_info("sum_out", TensorProto.FLOAT, @@ -10712,8 +10715,10 @@ def scan_test(scan_output_axes=[0, 0], scan_output_directions=[0, 0]): init_state = helper.make_tensor_value_info("init_state", TensorProto.FLOAT, [2, 2]) + scan_ins_sh = [2, 2, 2] + scan_ins_sh[scan_input_axes[0]] = 3 scan_ins = helper.make_tensor_value_info("scan_ins", TensorProto.FLOAT, - [3, 2, 2]) + scan_ins_sh) final_state = helper.make_tensor_value_info("final_state", TensorProto.FLOAT, [2, 2]) scan_outs1_sh = [2, 2, 2] @@ -10729,6 +10734,8 @@ def scan_test(scan_output_axes=[0, 0], scan_output_directions=[0, 0]): inputs=["init_state", "scan_ins"], outputs=["final_state", "scan_outs1", "scan_outs2"], num_scan_inputs=1, + scan_input_axes=scan_input_axes, + scan_input_directions=scan_input_directions, scan_output_axes=scan_output_axes, scan_output_directions=scan_output_directions, body=scan_body, @@ -10751,3 +10758,13 @@ def scan_test2(): @onnx_test() def scan_test3(): return scan_test(scan_output_axes=[1, -1], scan_output_directions=[0, 1]) + + +@onnx_test() +def scan_test4(): + return scan_test(scan_input_directions=[1]) + + +@onnx_test() +def scan_test5(): + return scan_test(scan_input_axes=[1]) diff --git a/test/onnx/scan_test4.onnx b/test/onnx/scan_test4.onnx new file mode 100644 index 0000000000000000000000000000000000000000..648e6b5b44fa53793e4947d61057b4a7dca5834e GIT binary patch literal 654 zcmZuvJx{|h5N*;xoohuaBSI@mK^e0(VnRYHY^-$45+${fET~&yM{0jVe*kO$5I=~G z%a>xJ4wmeDzI*T99dRkBRU%`}m40~qb>OE9Af(W-(uw9k3RFl%`|m;~pJlj8+GJYI zjeF=jODqsUx|B;RS;%U_u2cb^%98Psp;7EFlDee%k<)I#s! z=>#77?i**>iZ8Qu&<7{kakc0itzIw$^1=C*iM1MIz^-cMQlfFKgI>UZSd@TG?EFMt zrK}ZCCk=NSx`9S*yIy+TfFT7SNao~3N9!!`U?{Qw%G53TH?~k{Fomlo7Blh1S6XbO zS|HjBCNQea!IZdcxNP|jJQFf9JH-ccywTJgZCiIVGlorJgt7-HyIcE|-on2XXnS%G X{ literal 0 HcmV?d00001 diff --git a/test/onnx/scan_test5.onnx b/test/onnx/scan_test5.onnx new file mode 100644 index 0000000000000000000000000000000000000000..f54fb89f7fb4abc3334ba3dcd548d4fe896dc08d GIT binary patch literal 678 zcmZuvJx{|h6l^|9+}Da$Mub$9f-+`nBo-tjU}L3QmMDphWI^2uJ5u`_n3!1m=lCBu zPNJBIAxeBN@9y6Dp%DgAWIUBr$lJTGA^bQ1j5C=EnMn#rR~Z+0{K|OtUc*h(wv=LJ ze8-`?R%Fq)4Q{?AhxpTm1)PX-lQIr2r+bOL>ZBS&=aCRtLVLy|t)wce5IqyhC6foZ z|EC^@qZdlbg09O=GzK@@Q}Hl3T3$2-^hW89u#E^5U^kqVl_Y3=;8U0&3nQ?RT|J54 zIP2m0ywmQNE}?VV=gTj;VvGTh5_2n|qcx0tm~u6L6Y5Uo7g)+@g?Cu0CsMc scan_out2_gold{36, 16, 4, 42, 20, 6}; EXPECT(arg_to_vec(scan_out2) == scan_out2_gold); } + +TEST_CASE(scan_test4) +{ + auto prog = migraphx::parse_onnx("scan_test4.onnx"); + prog.compile(migraphx::make_target("ref")); + + migraphx::parameter_map pm; + + migraphx::shape init_state_sh{migraphx::shape::float_type, {2, 2}}; + std::vector init_state(4, 0); + pm["init_state"] = migraphx::argument(init_state_sh, init_state.data()); + + migraphx::shape scan_ins_sh{migraphx::shape::float_type, {3, 2, 2}}; + std::vector scan_ins(12); + std::iota(scan_ins.begin(), scan_ins.end(), 1); + pm["scan_ins"] = migraphx::argument(scan_ins_sh, scan_ins.data()); + + auto result = prog.eval(pm); + EXPECT(result.size() == 3); + + auto final_state = result[0]; + auto scan_out1 = result[1]; + auto scan_out2 = result[2]; + + EXPECT(final_state.get_shape() == make_shape({2, 2})); + std::vector final_state_gold{15, 18, 21, 24}; + EXPECT(arg_to_vec(final_state) == final_state_gold); + + EXPECT(scan_out1.get_shape() == make_shape({3, 2, 2})); + std::vector scan_out1_gold{9, 10, 11, 12, 14, 16, 18, 20, 15, 18, 21, 24}; + EXPECT(arg_to_vec(scan_out1) == scan_out1_gold); + + EXPECT(scan_out2.get_shape() == make_shape({3, 2})); + std::vector scan_out2_gold{20, 22, 32, 36, 36, 42}; + EXPECT(arg_to_vec(scan_out2) == scan_out2_gold); +} + +TEST_CASE(scan_test5) +{ + auto prog = migraphx::parse_onnx("scan_test5.onnx"); + prog.compile(migraphx::make_target("ref")); + + migraphx::parameter_map pm; + + migraphx::shape init_state_sh{migraphx::shape::float_type, {2, 2}}; + std::vector init_state(4, 0); + pm["init_state"] = migraphx::argument(init_state_sh, init_state.data()); + + migraphx::shape scan_ins_sh{migraphx::shape::float_type, {2, 3, 2}}; + std::vector scan_ins(12); + std::iota(scan_ins.begin(), scan_ins.end(), 1); + pm["scan_ins"] = migraphx::argument(scan_ins_sh, scan_ins.data()); + + auto result = prog.eval(pm); + EXPECT(result.size() == 3); + + auto final_state = result[0]; + auto scan_out1 = result[1]; + auto scan_out2 = result[2]; + + EXPECT(final_state.get_shape() == make_shape({2, 2})); + std::vector final_state_gold{9, 12, 27, 30}; + EXPECT(arg_to_vec(final_state) == final_state_gold); + + EXPECT(scan_out1.get_shape() == make_shape({3, 2, 2})); + std::vector scan_out1_gold{1, 2, 7, 8, 4, 6, 16, 18, 9, 12, 27, 30}; + EXPECT(arg_to_vec(scan_out1) == scan_out1_gold); + + EXPECT(scan_out2.get_shape() == make_shape({3, 2})); + std::vector scan_out2_gold{8, 10, 20, 24, 36, 42}; + EXPECT(arg_to_vec(scan_out2) == scan_out2_gold); +} From 7793cb0cf4d3b8bb6161890b5c13405df8d1ea3b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dino=20Musi=C4=87?= Date: Tue, 26 Mar 2024 11:50:39 +0000 Subject: [PATCH 05/18] Implement ScanSlice operator --- src/CMakeLists.txt | 1 + src/include/migraphx/op/scan.hpp | 4 +- src/include/migraphx/op/scan_slice.hpp | 96 +++++++++++++ src/include/migraphx/operators.hpp | 1 + src/targets/gpu/lowering.cpp | 12 ++ test/ref/scan_slice.cpp | 187 +++++++++++++++++++++++++ test/verify/test_scan_slice.cpp | 62 ++++++++ 7 files changed, 361 insertions(+), 2 deletions(-) create mode 100644 src/include/migraphx/op/scan_slice.hpp create mode 100644 test/ref/scan_slice.cpp create mode 100644 test/verify/test_scan_slice.cpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index f5c04a81995..52ab9d877d9 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -227,6 +227,7 @@ register_migraphx_ops( run_on_target scalar scan + scan_slice scatter_none scatter_add scatter_mul diff --git a/src/include/migraphx/op/scan.hpp b/src/include/migraphx/op/scan.hpp index 629477ba2b7..738dde181ea 100644 --- a/src/include/migraphx/op/scan.hpp +++ b/src/include/migraphx/op/scan.hpp @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ -#ifndef MIGRAPHX_GUARD_OPERATORS_LOOP_HPP -#define MIGRAPHX_GUARD_OPERATORS_LOOP_HPP +#ifndef MIGRAPHX_GUARD_OPERATORS_SCAN_HPP +#define MIGRAPHX_GUARD_OPERATORS_SCAN_HPP #include #include diff --git a/src/include/migraphx/op/scan_slice.hpp b/src/include/migraphx/op/scan_slice.hpp new file mode 100644 index 00000000000..c31d9ed5d58 --- /dev/null +++ b/src/include/migraphx/op/scan_slice.hpp @@ -0,0 +1,96 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#ifndef MIGRAPHX_GUARD_OPERATORS_SCAN_SLICE_HPP +#define MIGRAPHX_GUARD_OPERATORS_SCAN_SLICE_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct scan_slice : op_name +{ + int64_t axis = 0; + int64_t direction = 0; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.axis, "axis"), f(self.direction, "direction")); + } + + value attributes() const + { + value normalize_axes = value::object{}; + normalize_axes["axis"] = value::array{normalize_attribute::include_min}; + return {{"normalize_axes", normalize_axes}}; + } + + shape normalize_compute_shape(std::vector inputs) const + { + check_shapes{inputs, *this}.has(2); + auto input_shape = inputs[0]; + auto new_lens = input_shape.lens(); + new_lens[axis] = 1; + + return shape{input_shape.type(), new_lens, input_shape.strides()}; + } + + auto compute_offset(const shape& s, int64_t idx) const + { + return idx * s.strides().at(axis) * s.type_size(); + } + + argument compute(shape output_shape, std::vector args) const + { + auto input = args[0]; + auto input_sh = input.get_shape(); + + int64_t idx; + args[1].visit([&](auto i) { idx = i.front(); }); + const auto max_idx = input_sh.lens()[axis] - 1; + if(idx > max_idx or idx < 0) + MIGRAPHX_THROW("ScanSlice: index {" + std::to_string(idx) + "} out of range [0, " + + std::to_string(max_idx) + "]"); + idx = (1 - direction) * idx + direction * (max_idx - idx); + + auto offset = compute_offset(input_sh, idx); + return {output_shape, [=] { return input.data() + offset; }}; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/operators.hpp b/src/include/migraphx/operators.hpp index 9d1ea1e94fc..2d9fa350bde 100644 --- a/src/include/migraphx/operators.hpp +++ b/src/include/migraphx/operators.hpp @@ -114,6 +114,7 @@ #include #include #include +#include #include #include #include diff --git a/src/targets/gpu/lowering.cpp b/src/targets/gpu/lowering.cpp index fcde59841fd..9e85140b68d 100644 --- a/src/targets/gpu/lowering.cpp +++ b/src/targets/gpu/lowering.cpp @@ -112,6 +112,7 @@ struct miopen_apply add_nms_op(); add_select_module_op(); add_reshape_lazy_op(); + add_scan_slice_op(); } void copy_params() const @@ -396,6 +397,17 @@ struct miopen_apply return mod->replace_instruction(ins, make_op("gpu::contiguous"), after_contiguous_args); }); } + + void add_scan_slice_op() + { + apply_map.emplace("scan_slice", [=](instruction_ref ins) { + auto inputs = ins->inputs(); + auto cpu_idx = mod->insert_instruction(ins, make_op("hip::copy_from_gpu"), inputs[1]); + inputs[1] = mod->insert_instruction(ins, make_op("hip::sync_stream"), cpu_idx); + return mod->replace_instruction( + ins, mod->insert_instruction(ins, ins->get_operator(), inputs)); + }); + } }; void lowering::apply(module_pass_manager& mpm) const diff --git a/test/ref/scan_slice.cpp b/test/ref/scan_slice.cpp new file mode 100644 index 00000000000..6dfb53d8a60 --- /dev/null +++ b/test/ref/scan_slice.cpp @@ -0,0 +1,187 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include "migraphx/compile_options.hpp" +#include "migraphx/module.hpp" +#include +#include +#include +#include +#include +#include + +#include + +static migraphx::shape make_shape(const std::vector& lens) +{ + return migraphx::shape{migraphx::shape::int32_type, lens}; +} + +static std::vector arg_to_vec(const migraphx::argument& arg) +{ + std::vector ret; + arg.visit([&](auto output) { ret.assign(output.begin(), output.end()); }); + return ret; +} + +migraphx::program make_scan_slice_program(int64_t axis, int64_t direction) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape data_sh{migraphx::shape::int32_type, {2, 2, 2}}; + std::vector data(data_sh.elements()); + std::iota(data.begin(), data.end(), 0); + auto data_lit = mm->add_literal(migraphx::literal{data_sh, data}); + + migraphx::shape idx_sh{migraphx::shape::int64_type, {1}}; + auto idx_param = mm->add_parameter("idx", idx_sh); + + mm->add_instruction(migraphx::make_op("scan_slice", {{"axis", axis}, {"direction", direction}}), + data_lit, + idx_param); + + p.compile(migraphx::make_target("ref")); + + return p; +} + +TEST_CASE(scan_slice_test_1) +{ + auto p = make_scan_slice_program(0, 0); + + migraphx::parameter_map pm; + int64_t idx = 0; + migraphx::shape idx_sh{migraphx::shape::int64_type, {1}}; + pm["idx"] = migraphx::argument{idx_sh, &idx}; + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({1, 2, 2})); + EXPECT(arg_to_vec(result) == std::vector{0, 1, 2, 3}); + + idx = 1; + result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({1, 2, 2})); + EXPECT(arg_to_vec(result) == std::vector{4, 5, 6, 7}); +} + +TEST_CASE(scan_slice_test_2) +{ + auto p = make_scan_slice_program(1, 0); + + migraphx::parameter_map pm; + int64_t idx = 0; + migraphx::shape idx_sh{migraphx::shape::int64_type, {1}}; + pm["idx"] = migraphx::argument{idx_sh, &idx}; + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({2, 1, 2})); + EXPECT(arg_to_vec(result) == std::vector{0, 1, 4, 5}); + + idx = 1; + result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({2, 1, 2})); + EXPECT(arg_to_vec(result) == std::vector{2, 3, 6, 7}); +} + +TEST_CASE(scan_slice_test_3) +{ + auto p = make_scan_slice_program(2, 0); + + migraphx::parameter_map pm; + int64_t idx = 0; + migraphx::shape idx_sh{migraphx::shape::int64_type, {1}}; + pm["idx"] = migraphx::argument{idx_sh, &idx}; + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({2, 2, 1})); + EXPECT(arg_to_vec(result) == std::vector{0, 2, 4, 6}); + + idx = 1; + result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({2, 2, 1})); + EXPECT(arg_to_vec(result) == std::vector{1, 3, 5, 7}); +} + +TEST_CASE(scan_slice_test_4) +{ + auto p = make_scan_slice_program(-3, 0); + + migraphx::parameter_map pm; + int64_t idx = 0; + migraphx::shape idx_sh{migraphx::shape::int64_type, {1}}; + pm["idx"] = migraphx::argument{idx_sh, &idx}; + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({1, 2, 2})); + EXPECT(arg_to_vec(result) == std::vector{0, 1, 2, 3}); + + idx = 1; + result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({1, 2, 2})); + EXPECT(arg_to_vec(result) == std::vector{4, 5, 6, 7}); +} + +TEST_CASE(scan_slice_test_5) +{ + auto p = make_scan_slice_program(0, 1); + + migraphx::parameter_map pm; + int64_t idx = 0; + migraphx::shape idx_sh{migraphx::shape::int64_type, {1}}; + pm["idx"] = migraphx::argument{idx_sh, &idx}; + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({1, 2, 2})); + EXPECT(arg_to_vec(result) == std::vector{4, 5, 6, 7}); + + idx = 1; + result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({1, 2, 2})); + EXPECT(arg_to_vec(result) == std::vector{0, 1, 2, 3}); +} + +TEST_CASE(scan_slice_test_6) +{ + auto p = make_scan_slice_program(-2, 1); + + migraphx::parameter_map pm; + int64_t idx = 0; + migraphx::shape idx_sh{migraphx::shape::int64_type, {1}}; + pm["idx"] = migraphx::argument{idx_sh, &idx}; + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({2, 1, 2})); + EXPECT(arg_to_vec(result) == std::vector{2, 3, 6, 7}); + + idx = 1; + result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({2, 1, 2})); + EXPECT(arg_to_vec(result) == std::vector{0, 1, 4, 5}); +} + +TEST_CASE(scan_slice_test_7) +{ + auto p = make_scan_slice_program(0, 0); + + migraphx::parameter_map pm; + int64_t idx = 2; + migraphx::shape idx_sh{migraphx::shape::int64_type, {1}}; + pm["idx"] = migraphx::argument{idx_sh, &idx}; + + EXPECT(test::throws([&] { p.eval(pm); })); +} diff --git a/test/verify/test_scan_slice.cpp b/test/verify/test_scan_slice.cpp new file mode 100644 index 00000000000..295f1dc4eb8 --- /dev/null +++ b/test/verify/test_scan_slice.cpp @@ -0,0 +1,62 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +template +struct test_scan_slice_base : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape data_sh{migraphx::shape::int32_type, {2, 2, 2}}; + auto data_param = mm->add_parameter("data", data_sh); + migraphx::shape idx_sh{migraphx::shape::int64_type, {1}}; + auto idx_lit = mm->add_literal(migraphx::literal{idx_sh, {0}}); + + mm->add_instruction( + migraphx::make_op("scan_slice", {{"axis", axis}, {"direction", direction}}), + data_param, + idx_lit); + + return p; + } +}; + +struct test_scan_slice1 : test_scan_slice_base +{ +}; + +struct test_scan_slice2 : test_scan_slice_base +{ +}; + +struct test_scan_slice3: test_scan_slice_base +{ +}; From 3fcbb689a1ae1d6d0884ef4278f6b3ea32a8efc2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dino=20Musi=C4=87?= Date: Tue, 26 Mar 2024 23:27:46 +0000 Subject: [PATCH 06/18] Implement via adapting subgraph to a what a loop expects and using the loop operator --- src/include/migraphx/op/scan_slice.hpp | 7 +- src/onnx/parse_scan.cpp | 181 ++++++++++++++++++------- 2 files changed, 134 insertions(+), 54 deletions(-) diff --git a/src/include/migraphx/op/scan_slice.hpp b/src/include/migraphx/op/scan_slice.hpp index c31d9ed5d58..6bb5e56c4c7 100644 --- a/src/include/migraphx/op/scan_slice.hpp +++ b/src/include/migraphx/op/scan_slice.hpp @@ -66,11 +66,6 @@ struct scan_slice : op_name return shape{input_shape.type(), new_lens, input_shape.strides()}; } - auto compute_offset(const shape& s, int64_t idx) const - { - return idx * s.strides().at(axis) * s.type_size(); - } - argument compute(shape output_shape, std::vector args) const { auto input = args[0]; @@ -84,7 +79,7 @@ struct scan_slice : op_name std::to_string(max_idx) + "]"); idx = (1 - direction) * idx + direction * (max_idx - idx); - auto offset = compute_offset(input_sh, idx); + auto offset = idx * input_sh.strides().at(axis) * input_sh.type_size(); return {output_shape, [=] { return input.data() + offset; }}; } }; diff --git a/src/onnx/parse_scan.cpp b/src/onnx/parse_scan.cpp index 739d19d9df5..50c6b111f11 100644 --- a/src/onnx/parse_scan.cpp +++ b/src/onnx/parse_scan.cpp @@ -24,8 +24,11 @@ #include "migraphx/argument.hpp" #include "migraphx/errors.hpp" #include "migraphx/instruction_ref.hpp" +#include "migraphx/iterator_for.hpp" +#include "migraphx/module_ref.hpp" #include #include +#include #include #include #include @@ -156,65 +159,87 @@ struct parse_scan : op_parser } // SCAN OUTPUT DIRECTIONS - std::vector alt_args(args.begin(), args.begin() + N); - for(int64_t i = 0; i < num_iters; ++i) - { - for(auto j = 0; j < M; ++j) - { - auto dir = scan_input_directions[j]; - auto idx = (1 - dir) * i + dir * (num_iters - 1 - i); - auto scan_axis = scan_input_axes[j]; - auto slice = info.add_instruction( - make_op("slice", - {{"axes", {scan_axis}}, {"starts", {idx}}, {"ends", {idx + 1}}}), - args[N + j]); - alt_args.push_back( - info.add_instruction(make_op("squeeze", {{"axes", {scan_axis}}}), slice)); - } - } + // std::vector alt_args(args.begin(), args.begin() + N); + // for(int64_t i = 0; i < num_iters; ++i) + // { + // for(auto j = 0; j < M; ++j) + // { + // auto dir = scan_input_directions[j]; + // auto idx = (1 - dir) * i + dir * (num_iters - 1 - i); + // auto scan_axis = scan_input_axes[j]; + // auto slice = info.add_instruction( + // make_op("slice", + // {{"axes", {scan_axis}}, {"starts", {idx}}, {"ends", {idx + 1}}}), + // args[N + j]); + // alt_args.push_back( + // info.add_instruction(make_op("squeeze", {{"axes", {scan_axis}}}), slice)); + // } + // } // TODO check that alt_args shapes match sub_mod input parameter shapes - auto scan = info.add_instruction( - make_op("scan", - {{"iterations", num_iters}, {"num_scan_inputs", M}, {"num_state_vars", N}}), - alt_args, - {sub_mod}); + modify_body(sub_mod, args, M, N, scan_input_axes, scan_input_directions); + auto cond_lit = info.add_literal(literal{shape{shape::bool_type}, {true}}); + args.insert(args.begin(), cond_lit); + auto max_iter_lit = info.add_literal(literal{shape{shape::int64_type}, {num_iters}}); + args.insert(args.begin(), max_iter_lit); + + auto loop = + info.add_instruction(make_op("loop", {{"max_iterations", num_iters}}), args, {sub_mod}); std::vector ret; ret.reserve(N + K); - for(auto i = 0; i < N; ++i) + for(std::size_t i = 0; i < N + M + K; ++i) { - auto get = info.add_instruction(make_op("get_tuple_elem", {{"index", i}}), scan); - ret.push_back(get); + if(i >= N and i < N + M) + continue; + auto r = info.add_instruction(make_op("get_tuple_elem", {{"index", i}}), loop); + ret.push_back(r); } + // TODO add transpose for scan outputs - for(auto i = N; i < N + K; ++i) - { - auto get = info.add_instruction(make_op("get_tuple_elem", {{"index", i}}), scan); - auto scan_axis = scan_output_axes[i - N]; - auto usq = info.add_instruction(make_op("unsqueeze", {{"axes", {scan_axis}}}), get); - ret.push_back(usq); - } + return ret; - for(auto i = 1; i < num_iters; ++i) - { - for(auto j = 0; j < K; ++j) - { - auto tuple_idx = N + i * K + j; - auto get = - info.add_instruction(make_op("get_tuple_elem", {{"index", tuple_idx}}), scan); - auto scan_axis = scan_output_axes[j]; - auto usq = info.add_instruction(make_op("unsqueeze", {{"axes", {scan_axis}}}), get); - std::vector concat_args{usq, usq}; - concat_args[scan_output_directions[j]] = ret[N + j]; - auto concat = - info.add_instruction(make_op("concat", {{"axis", scan_axis}}), concat_args); - ret[N + j] = concat; - } - } + // auto scan = info.add_instruction( + // make_op("scan", + // {{"iterations", num_iters}, {"num_scan_inputs", M}, {"num_state_vars", N}}), + // alt_args, + // {sub_mod}); - return ret; + // std::vector ret; + // ret.reserve(N + K); + // for(auto i = 0; i < N; ++i) + // { + // auto get = info.add_instruction(make_op("get_tuple_elem", {{"index", i}}), scan); + // ret.push_back(get); + // } + + // for(auto i = N; i < N + K; ++i) + // { + // auto get = info.add_instruction(make_op("get_tuple_elem", {{"index", i}}), + // scan); auto scan_axis = scan_output_axes[i - N]; auto usq = + // info.add_instruction(make_op("unsqueeze", {{"axes", {scan_axis}}}), get); + // ret.push_back(usq); + // } + + // for(auto i = 1; i < num_iters; ++i) + // { + // for(auto j = 0; j < K; ++j) + // { + // auto tuple_idx = N + i * K + j; + // auto get = + // info.add_instruction(make_op("get_tuple_elem", {{"index", tuple_idx}}), + // scan); + // auto scan_axis = scan_output_axes[j]; + // auto usq = info.add_instruction(make_op("unsqueeze", {{"axes", {scan_axis}}}), + // get); std::vector concat_args{usq, usq}; concat_args[scan_output_directions[j]] = + // ret[N + j]; auto concat = + // info.add_instruction(make_op("concat", {{"axis", scan_axis}}), concat_args); + // ret[N + j] = concat; + // } + // } + + // return ret; } void normalize_axes(std::vector& axes, const std::vector& ndims) const @@ -229,6 +254,66 @@ struct parse_scan : op_parser std::transform(axes.begin(), axes.end(), ndims.begin(), axes.begin(), normalize_axis); } + + void modify_body(module_ref mod, + const std::vector& args, + int64_t M, + int64_t N, + const std::vector& scan_input_axes, + const std::vector& scan_input_directions) const + { + auto param_names = mod->get_parameter_names(); + auto param_shapes = mod->get_parameter_shapes(); + + std::unordered_map> child_ins; + for(auto ins : iterator_for(*mod)) + { + for(const auto& name : param_names) + { + auto param = mod->get_parameter(name); + if(contains(ins->inputs(), param)) + child_ins[name].push_back(ins); + } + } + + auto iter_param = mod->add_parameter("iter", shape{shape::int64_type}); + auto cond_param = mod->add_parameter("cond", shape{shape::bool_type}); + for(auto i = 0; i < M; ++i) + { + auto var = + mod->add_parameter("state_var" + std::to_string(i), param_shapes[param_names[i]]); + auto param = mod->get_parameter(param_names[i]); + for(auto ins : child_ins[param_names[i]]) + ins->replace_argument(ins, param, var); + mod->remove_instruction(param); + } + + std::vector scan_in_params; + scan_in_params.reserve(N); + for(auto i = M; i < M + N; ++i) + { + auto param = mod->get_parameter(param_names[i]); + auto scan_in_param = + mod->add_parameter("scan_in" + std::to_string(i - M), args[i]->get_shape()); + scan_in_params.push_back(scan_in_param); + auto scan_axis = scan_input_axes[i - M]; + auto scan_dir = scan_input_directions[i - M]; + auto scan_in_slice = mod->insert_instruction( + param, + make_op("scan_slice", {{"axis", scan_axis}, {"direction", scan_dir}}), + scan_in_param, + iter_param); + scan_in_slice = mod->insert_instruction( + param, make_op("squeeze", {{"axes", {scan_axis}}}), scan_in_slice); + for(auto ins : child_ins[param_names[i]]) + ins->replace_argument(ins, param, scan_in_slice); + mod->remove_instruction(param); + } + auto returns = mod->get_returns(); + returns.insert(returns.begin(), cond_param); + returns.insert(returns.begin() + M + 1, scan_in_params.begin(), scan_in_params.end()); + mod->replace_return(returns); + } }; } // namespace onnx From 81e0804f90e6c83e3e452b2b17d2acbb15b08f52 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dino=20Musi=C4=87?= Date: Wed, 27 Mar 2024 12:31:50 +0000 Subject: [PATCH 07/18] Add support for scan_output_axes and code refactoring --- src/CMakeLists.txt | 1 - src/include/migraphx/op/scan.hpp | 131 -------------- src/include/migraphx/operators.hpp | 1 - src/onnx/parse_scan.cpp | 267 +++++++++++------------------ test/onnx/gen_onnx.py | 2 +- test/onnx/scan_test3.onnx | Bin 633 -> 687 bytes 6 files changed, 99 insertions(+), 303 deletions(-) delete mode 100644 src/include/migraphx/op/scan.hpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 52ab9d877d9..4860de409e3 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -226,7 +226,6 @@ register_migraphx_ops( rsqrt run_on_target scalar - scan scan_slice scatter_none scatter_add diff --git a/src/include/migraphx/op/scan.hpp b/src/include/migraphx/op/scan.hpp deleted file mode 100644 index 738dde181ea..00000000000 --- a/src/include/migraphx/op/scan.hpp +++ /dev/null @@ -1,131 +0,0 @@ -/* - * The MIT License (MIT) - * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. - */ -#ifndef MIGRAPHX_GUARD_OPERATORS_SCAN_HPP -#define MIGRAPHX_GUARD_OPERATORS_SCAN_HPP - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace migraphx { -inline namespace MIGRAPHX_INLINE_NS { -namespace op { - -struct scan : op_name -{ - int64_t iterations; - int64_t num_scan_inputs; - int64_t num_state_vars; - - template - static auto reflect(Self& self, F f) - { - return pack(f(self.iterations, "iterations"), - f(self.num_scan_inputs, "num_scan_inputs"), - f(self.num_state_vars, "num_state_vars")); - } - - shape compute_shape(const std::vector& inputs, std::vector mods) const - { - assert(mods.size() == 1); - // check_shapes{inputs, *this}.standard(); - auto mod = mods.front(); - // The module has N + K outputs - auto mod_output_shapes = mod->get_output_shapes(); - std::vector op_output_shapes{mod_output_shapes.begin(), - mod_output_shapes.begin() + num_state_vars}; - auto K = mod_output_shapes.size() - num_state_vars; - op_output_shapes.reserve(num_state_vars + iterations * K); - for(auto i = 0; i < iterations; ++i) - op_output_shapes.insert(op_output_shapes.end(), - mod_output_shapes.begin() + num_state_vars, - mod_output_shapes.end()); - - return shape{op_output_shapes}; - } - - std::unordered_map get_output_params(const module_ref mod) const - { - std::unordered_map ret; - const std::string output_prefix = "#output_"; - - const auto& param_names = mod->get_parameter_names(); - for(const auto& name : param_names) - { - auto n = name.find(output_prefix); - if(n == std::string::npos) - continue; - int idx = std::stoi(name.substr(n + output_prefix.size())); - ret[name] = idx; - } - - return ret; - } - - argument compute(context& ctx, - const shape& out_shape, - const std::vector& args, - const std::vector& mods, - const std::function( - module_ref&, const std::unordered_map&)>& run) const - { - assert(mods.size() == 1); - auto mod = mods.front(); - auto param_shapes = mod->get_parameter_shapes(); - auto param_names = mod->get_parameter_names(); - - auto K = mod->get_output_shapes().size() - num_state_vars; - parameter_map pm; - std::vector ret{args.begin(), args.begin() + num_state_vars}; - for(auto i = 0; i < iterations; ++i) - { - for(auto j = 0; j < num_state_vars; ++j) - pm[param_names[j]] = ret[j]; - for(auto j = num_state_vars; j < num_state_vars + num_scan_inputs; ++j) - pm[param_names[j]] = args[i * num_scan_inputs + j]; - - auto mod_output = run(mod, pm); - - std::copy(mod_output.begin(), mod_output.begin() + num_state_vars, ret.begin()); - ret.insert(ret.end(), mod_output.begin() + num_state_vars, mod_output.end()); - } - - return argument{ret}; - } -}; - -} // namespace op -} // namespace MIGRAPHX_INLINE_NS -} // namespace migraphx - -#endif diff --git a/src/include/migraphx/operators.hpp b/src/include/migraphx/operators.hpp index 2d9fa350bde..752fe67d7a9 100644 --- a/src/include/migraphx/operators.hpp +++ b/src/include/migraphx/operators.hpp @@ -113,7 +113,6 @@ #include #include #include -#include #include #include #include diff --git a/src/onnx/parse_scan.cpp b/src/onnx/parse_scan.cpp index 50c6b111f11..0befe8c663d 100644 --- a/src/onnx/parse_scan.cpp +++ b/src/onnx/parse_scan.cpp @@ -26,6 +26,7 @@ #include "migraphx/instruction_ref.hpp" #include "migraphx/iterator_for.hpp" #include "migraphx/module_ref.hpp" +#include "migraphx/onnx/onnx_parser.hpp" #include #include #include @@ -45,50 +46,26 @@ struct parse_scan : op_parser { std::vector operators() const { return {{"Scan"}}; } - std::vector parse(const op_desc& opd, + std::vector parse(const op_desc& /*opd*/, onnx_parser& parser, onnx_parser::node_info info, std::vector args) const { - // NOTE Version 8 of the operator differs to all the later versions - if(not contains(info.attributes, "body")) - MIGRAPHX_THROW("Scan: body attribute required"); + check_for_required_attributes(info, {"body", "num_scan_inputs"}); - if(not contains(info.attributes, "num_scan_inputs")) - MIGRAPHX_THROW("Scan: num_scan_inputs attribute required"); + const auto& body_graph = info.attributes["body"].g(); + auto body = parser.prog.create_module(info.name + "_scan"); + (void)parser.parse_graph(body, body_graph); - const auto& body = info.attributes["body"].g(); - auto sub_mod = parser.prog.create_module(info.name + "_scan"); - (void)parser.parse_graph(sub_mod, body); + auto body_outs = body->get_returns(); + const auto M = info.attributes["num_scan_inputs"].i(); + const auto N = args.size() - M; + const auto K = body_outs.size() - N; - auto sub_mod_output_shapes = sub_mod->get_output_shapes(); - const auto M = info.attributes["num_scan_inputs"].i(); - const auto N = args.size() - M; - const auto K = sub_mod_output_shapes.size() - N; - - // NOTE Does not apply to opset 8 version - if(sub_mod->get_parameter_names().size() != N + M) + if(body->get_parameter_names().size() != N + M) MIGRAPHX_THROW("Lorem ipsum"); - // SCAN INPUT AXES - std::vector scan_input_axes(M, 0); - if(contains(info.attributes, "scan_input_axes")) - { - auto&& axes = info.attributes["scan_input_axes"].ints(); - scan_input_axes.assign(axes.begin(), axes.end()); - - if(scan_input_axes.size() != M) - MIGRAPHX_THROW("Number of scan input axes (" + to_string(scan_input_axes.size()) + - ") does not match number of scan inputs(" + to_string(M) + ")"); - - std::vector ndims; - ndims.reserve(M); - std::transform(args.begin() + N, - args.end(), - std::back_inserter(ndims), - [](instruction_ref arg) { return arg->get_shape().ndim(); }); - normalize_axes(scan_input_axes, ndims); - } + const auto scan_input_axes = parse_axes(info, "scan_input_axes", M, args.begin() + N, 0); size_t num_iters = args[N]->get_shape().lens()[scan_input_axes[0]]; for(auto i = 1; i < M; ++i) @@ -96,163 +73,115 @@ struct parse_scan : op_parser if(args[i]->get_shape().lens()[scan_input_axes[i]] != num_iters) MIGRAPHX_THROW("Lorem ipsum"); } - // SCAN INPUT AXES - // SCAN INPUT DIRECTIONS - std::vector scan_input_directions(M, 0); - if(contains(info.attributes, "scan_input_directions")) - { - auto&& dirs = info.attributes["scan_input_directions"].ints(); - scan_input_directions.assign(dirs.begin(), dirs.end()); - - if(scan_input_directions.size() != M) - MIGRAPHX_THROW("Number of scan input directions (" + - to_string(scan_input_directions.size()) + - ") does not match number of scan inputs(" + to_string(M) + ")"); - - if(any_of(scan_input_directions, [](auto i) { return i != 0 and i != 1; })) - MIGRAPHX_THROW( - "Scan output directions may contain only 1s and 0s, actual values: " + - to_string_range(scan_input_directions)); - } - // SCAN INPUT DIRECTIONS + const auto scan_input_directions = parse_dirs(info, "scan_input_directions", M); - // SCAN OUTPUT AXES - std::vector scan_output_axes(K, 0); - if(contains(info.attributes, "scan_output_axes")) - { - auto&& axes = info.attributes["scan_output_axes"].ints(); - scan_output_axes.assign(axes.begin(), axes.end()); - - if(scan_output_axes.size() != K) - MIGRAPHX_THROW("Number of scan output axes (" + to_string(scan_output_axes.size()) + - ") does not match number of body scan outputs(" + to_string(K) + - ")"); - - std::vector ndims; - ndims.reserve(K); - std::transform(sub_mod_output_shapes.begin() + N, - sub_mod_output_shapes.end(), - std::back_inserter(ndims), - [](const shape& sh) { return sh.ndim() + 1; }); - normalize_axes(scan_output_axes, ndims); - } - // SCAN OUTPUT AXES + const auto scan_output_axes = + parse_axes(info, "scan_output_axes", K, body_outs.begin() + N, 1); - // SCAN OUTPUT DIRECTIONS - std::vector scan_output_directions(K, 0); - if(contains(info.attributes, "scan_output_directions")) - { - auto&& dirs = info.attributes["scan_output_directions"].ints(); - scan_output_directions.assign(dirs.begin(), dirs.end()); - - if(scan_output_directions.size() != K) - MIGRAPHX_THROW("Number of scan output directions (" + - to_string(scan_output_directions.size()) + - ") does not match number of body scan outputs(" + to_string(K) + - ")"); - - if(any_of(scan_output_directions, [](auto i) { return i != 0 and i != 1; })) - MIGRAPHX_THROW( - "Scan output directions may contain only 1s and 0s, actual values: " + - to_string_range(scan_output_directions)); - } - // SCAN OUTPUT DIRECTIONS - - // std::vector alt_args(args.begin(), args.begin() + N); - // for(int64_t i = 0; i < num_iters; ++i) - // { - // for(auto j = 0; j < M; ++j) - // { - // auto dir = scan_input_directions[j]; - // auto idx = (1 - dir) * i + dir * (num_iters - 1 - i); - // auto scan_axis = scan_input_axes[j]; - // auto slice = info.add_instruction( - // make_op("slice", - // {{"axes", {scan_axis}}, {"starts", {idx}}, {"ends", {idx + 1}}}), - // args[N + j]); - // alt_args.push_back( - // info.add_instruction(make_op("squeeze", {{"axes", {scan_axis}}}), slice)); - // } - // } - - // TODO check that alt_args shapes match sub_mod input parameter shapes - - modify_body(sub_mod, args, M, N, scan_input_axes, scan_input_directions); + const auto scan_output_directions = parse_dirs(info, "scan_output_directions", K); + + // TODO check that alt_args shapes match body input parameter shapes + + modify_body(body, args, M, N, scan_input_axes, scan_input_directions); auto cond_lit = info.add_literal(literal{shape{shape::bool_type}, {true}}); args.insert(args.begin(), cond_lit); auto max_iter_lit = info.add_literal(literal{shape{shape::int64_type}, {num_iters}}); args.insert(args.begin(), max_iter_lit); auto loop = - info.add_instruction(make_op("loop", {{"max_iterations", num_iters}}), args, {sub_mod}); + info.add_instruction(make_op("loop", {{"max_iterations", num_iters}}), args, {body}); std::vector ret; ret.reserve(N + K); - for(std::size_t i = 0; i < N + M + K; ++i) + for(std::size_t i = 0; i < N; ++i) { - if(i >= N and i < N + M) - continue; auto r = info.add_instruction(make_op("get_tuple_elem", {{"index", i}}), loop); ret.push_back(r); } - // TODO add transpose for scan outputs + + for(std::size_t i = N + M; i < N + M + K; ++i) + { + auto r = info.add_instruction(make_op("get_tuple_elem", {{"index", i}}), loop); + auto scan_axis = scan_output_axes[i - N - M]; + std::vector perm(r->get_shape().ndim(), 0); + std::iota(perm.begin(), perm.end(), 0); + std::copy(perm.begin() + 1, perm.begin() + 1 + scan_axis, perm.begin()); + perm[scan_axis] = 0; + r = info.add_instruction(make_op("transpose", {{"permutation", perm}}), r); + ret.push_back(r); + } return ret; + } + + void check_for_required_attributes(onnx_parser::node_info& info, + std::vector attribute_names) const + { + for(const auto& name : attribute_names) + if(not contains(info.attributes, name)) + MIGRAPHX_THROW("Scan: " + name + " attribute required"); + } + + std::vector parse_vector_attribute(onnx_parser::node_info& info, + const std::string& attr_name, + size_t expected_size) const + { + if(not contains(info.attributes, attr_name)) + return {}; + + std::vector res; + auto&& attr = info.attributes[attr_name].ints(); + if(attr.size() != expected_size) + MIGRAPHX_THROW("Scan: " + attr_name + " size is " + to_string(attr.size()) + + ", should be " + to_string(expected_size)); + res.assign(attr.begin(), attr.end()); - // auto scan = info.add_instruction( - // make_op("scan", - // {{"iterations", num_iters}, {"num_scan_inputs", M}, {"num_state_vars", N}}), - // alt_args, - // {sub_mod}); - - // std::vector ret; - // ret.reserve(N + K); - // for(auto i = 0; i < N; ++i) - // { - // auto get = info.add_instruction(make_op("get_tuple_elem", {{"index", i}}), scan); - // ret.push_back(get); - // } - - // for(auto i = N; i < N + K; ++i) - // { - // auto get = info.add_instruction(make_op("get_tuple_elem", {{"index", i}}), - // scan); auto scan_axis = scan_output_axes[i - N]; auto usq = - // info.add_instruction(make_op("unsqueeze", {{"axes", {scan_axis}}}), get); - // ret.push_back(usq); - // } - - // for(auto i = 1; i < num_iters; ++i) - // { - // for(auto j = 0; j < K; ++j) - // { - // auto tuple_idx = N + i * K + j; - // auto get = - // info.add_instruction(make_op("get_tuple_elem", {{"index", tuple_idx}}), - // scan); - // auto scan_axis = scan_output_axes[j]; - // auto usq = info.add_instruction(make_op("unsqueeze", {{"axes", {scan_axis}}}), - // get); std::vector concat_args{usq, usq}; concat_args[scan_output_directions[j]] = - // ret[N + j]; auto concat = - // info.add_instruction(make_op("concat", {{"axis", scan_axis}}), concat_args); - // ret[N + j] = concat; - // } - // } - - // return ret; + return res; } - void normalize_axes(std::vector& axes, const std::vector& ndims) const + std::vector + parse_dirs(onnx_parser::node_info& info, const std::string& name, size_t expected_size) const { - auto normalize_axis = [=](int64_t axis, int64_t ndim) { - if(axis < -ndim or axis >= ndim) - MIGRAPHX_THROW("Axis value {" + to_string(axis) + "} out of range [" + - to_string(-ndim) + ", " + to_string(ndim) + ")"); + auto dirs = parse_vector_attribute(info, name, expected_size); + if(dirs.empty()) + return std::vector(expected_size, 0); - return axis < 0 ? ndim + axis : axis; - }; + if(any_of(dirs, [](auto i) { return i != 0 and i != 1; })) + MIGRAPHX_THROW("Scan: " + name + + " may contain only 1s and 0s, actual values: " + to_string_range(dirs)); - std::transform(axes.begin(), axes.end(), ndims.begin(), axes.begin(), normalize_axis); + return dirs; + } + + int64_t normalize_axis(int64_t axis, int64_t rank) const + { + if(axis < -rank or axis >= rank) + MIGRAPHX_THROW("Axis value {" + to_string(axis) + "} out of range [" + + to_string(-rank) + ", " + to_string(rank) + ")"); + + return axis < 0 ? rank + axis : axis; + } + + std::vector parse_axes(onnx_parser::node_info& info, + const std::string& name, + size_t expected_size, + std::vector::iterator ins_begin, + size_t rank_offset) const + { + auto axes = parse_vector_attribute(info, name, expected_size); + if(axes.empty()) + return std::vector(expected_size, 0); + + std::transform(axes.begin(), + axes.end(), + ins_begin, + axes.begin(), + [&](int64_t axis, instruction_ref arg) { + return normalize_axis(axis, arg->get_shape().ndim() + rank_offset); + }); + + return axes; } void modify_body(module_ref mod, @@ -314,7 +243,7 @@ struct parse_scan : op_parser returns.insert(returns.begin() + M + 1, scan_in_params.begin(), scan_in_params.end()); mod->replace_return(returns); } -}; +}; // namespace onnx } // namespace onnx } // namespace MIGRAPHX_INLINE_NS diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index fe4bbc1f379..3d8a84c17d9 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -10757,7 +10757,7 @@ def scan_test2(): @onnx_test() def scan_test3(): - return scan_test(scan_output_axes=[1, -1], scan_output_directions=[0, 1]) + return scan_test(scan_output_axes=[1, -1]) @onnx_test() diff --git a/test/onnx/scan_test3.onnx b/test/onnx/scan_test3.onnx index 1a49a6f229b44b9d168711097a3c0134476f57cb..9255088ea0d1b7e2f1ba04ae9cd9655fd32a1e01 100644 GIT binary patch delta 92 zcmey#vYu6#gHwpBI5{ydz9hA{#Mo*UE7!Y?!ljI+VqE-S;mo{((vtYZiqv8Uh6Rl5 gS~6UsNa87(MXAXpnfZAT*~#sUj*JYGPcm)*09QaB0ssI2 delta 38 ucmZ3_`jbVNgHwpBI5{ydz9hA{#MtTy3)hZ~!ljIp_c1y!GEV-_xB&p{w+!9@ From 18aabadc61aac45d670187d483ff5eb806f38340 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dino=20Musi=C4=87?= Date: Wed, 27 Mar 2024 18:59:22 +0000 Subject: [PATCH 08/18] Refactoring --- src/onnx/parse_scan.cpp | 126 +++++++++++++++++++--------------------- 1 file changed, 59 insertions(+), 67 deletions(-) diff --git a/src/onnx/parse_scan.cpp b/src/onnx/parse_scan.cpp index 0befe8c663d..8a6c86b7a8e 100644 --- a/src/onnx/parse_scan.cpp +++ b/src/onnx/parse_scan.cpp @@ -55,7 +55,7 @@ struct parse_scan : op_parser const auto& body_graph = info.attributes["body"].g(); auto body = parser.prog.create_module(info.name + "_scan"); - (void)parser.parse_graph(body, body_graph); + parser.parse_graph(body, body_graph); auto body_outs = body->get_returns(); const auto M = info.attributes["num_scan_inputs"].i(); @@ -69,10 +69,8 @@ struct parse_scan : op_parser size_t num_iters = args[N]->get_shape().lens()[scan_input_axes[0]]; for(auto i = 1; i < M; ++i) - { if(args[i]->get_shape().lens()[scan_input_axes[i]] != num_iters) MIGRAPHX_THROW("Lorem ipsum"); - } const auto scan_input_directions = parse_dirs(info, "scan_input_directions", M); @@ -83,33 +81,26 @@ struct parse_scan : op_parser // TODO check that alt_args shapes match body input parameter shapes - modify_body(body, args, M, N, scan_input_axes, scan_input_directions); - auto cond_lit = info.add_literal(literal{shape{shape::bool_type}, {true}}); - args.insert(args.begin(), cond_lit); + modify_body(body, args, N, M, scan_input_axes, scan_input_directions); + auto max_iter_lit = info.add_literal(literal{shape{shape::int64_type}, {num_iters}}); - args.insert(args.begin(), max_iter_lit); + auto cond_lit = info.add_literal(literal{shape{shape::bool_type}, {true}}); + std::vector loop_args{max_iter_lit, cond_lit}; + loop_args.insert(loop_args.end(), args.begin(), args.begin() + N); - auto loop = - info.add_instruction(make_op("loop", {{"max_iterations", num_iters}}), args, {body}); + auto loop = info.add_instruction( + make_op("loop", {{"max_iterations", num_iters}}), loop_args, {body}); std::vector ret; ret.reserve(N + K); - for(std::size_t i = 0; i < N; ++i) - { - auto r = info.add_instruction(make_op("get_tuple_elem", {{"index", i}}), loop); - ret.push_back(r); - } + for(auto i = 0; i < N; ++i) + ret.push_back(info.add_instruction(make_op("get_tuple_elem", {{"index", i}}), loop)); - for(std::size_t i = N + M; i < N + M + K; ++i) + for(auto i = 0; i < K; ++i) { - auto r = info.add_instruction(make_op("get_tuple_elem", {{"index", i}}), loop); - auto scan_axis = scan_output_axes[i - N - M]; - std::vector perm(r->get_shape().ndim(), 0); - std::iota(perm.begin(), perm.end(), 0); - std::copy(perm.begin() + 1, perm.begin() + 1 + scan_axis, perm.begin()); - perm[scan_axis] = 0; - r = info.add_instruction(make_op("transpose", {{"permutation", perm}}), r); - ret.push_back(r); + auto o = info.add_instruction(make_op("get_tuple_elem", {{"index", i + N}}), loop); + auto perm = make_perm_for_scan_out(o->get_shape().ndim(), scan_output_axes[i]); + ret.push_back(info.add_instruction(make_op("transpose", {{"permutation", perm}}), o)); } return ret; @@ -186,64 +177,65 @@ struct parse_scan : op_parser void modify_body(module_ref mod, const std::vector& args, - int64_t M, int64_t N, + int64_t M, const std::vector& scan_input_axes, const std::vector& scan_input_directions) const { - auto param_names = mod->get_parameter_names(); - auto param_shapes = mod->get_parameter_shapes(); - - std::unordered_map> child_ins; - for(auto ins : iterator_for(*mod)) - { - for(const auto& name : param_names) - { - auto param = mod->get_parameter(name); - if(contains(ins->inputs(), param)) - child_ins[name].push_back(ins); - } - } + std::vector params; + params.reserve(N + M); + transform(mod->get_parameter_names(), + std::back_inserter(params), + [&](const std::string& name) { return mod->get_parameter(name); }); auto iter_param = mod->add_parameter("iter", shape{shape::int64_type}); auto cond_param = mod->add_parameter("cond", shape{shape::bool_type}); - for(auto i = 0; i < M; ++i) - { - auto var = - mod->add_parameter("state_var" + std::to_string(i), param_shapes[param_names[i]]); - auto param = mod->get_parameter(param_names[i]); - for(auto ins : child_ins[param_names[i]]) - ins->replace_argument(ins, param, var); - mod->remove_instruction(param); - } + std::vector new_params; + new_params.reserve(N); + for(auto i = 0; i < N; ++i) + new_params.push_back( + mod->add_parameter("state_var" + std::to_string(i), params[i]->get_shape())); - std::vector scan_in_params; - scan_in_params.reserve(N); - for(auto i = M; i < M + N; ++i) + for(auto i = 0; i < params.size(); ++i) { - auto param = mod->get_parameter(param_names[i]); - auto scan_in_param = - mod->add_parameter("scan_in" + std::to_string(i - M), args[i]->get_shape()); - scan_in_params.push_back(scan_in_param); - auto scan_axis = scan_input_axes[i - M]; - auto scan_dir = scan_input_directions[i - M]; - auto scan_in_slice = mod->insert_instruction( - param, - make_op("scan_slice", {{"axis", scan_axis}, {"direction", scan_dir}}), - scan_in_param, - iter_param); - scan_in_slice = mod->insert_instruction( - param, make_op("squeeze", {{"axes", {scan_axis}}}), scan_in_slice); - for(auto ins : child_ins[param_names[i]]) - ins->replace_argument(ins, param, scan_in_slice); - mod->remove_instruction(param); + for(auto ins : iterator_for(*mod)) + { + if(not contains(ins->inputs(), params[i])) + continue; + + auto new_ins = i < N ? new_params[i] : args[i]; + if(i >= N) + { + auto scan_axis = scan_input_axes[i - N]; + auto scan_dir = scan_input_directions[i - N]; + new_ins = mod->insert_instruction( + params[i], + make_op("scan_slice", {{"axis", scan_axis}, {"direction", scan_dir}}), + new_ins, + iter_param); + new_ins = mod->insert_instruction( + params[i], make_op("squeeze", {{"axes", {scan_axis}}}), new_ins); + } + ins->replace_argument(ins, params[i], new_ins); + } + mod->remove_instruction(params[i]); } + auto returns = mod->get_returns(); returns.insert(returns.begin(), cond_param); - returns.insert(returns.begin() + M + 1, scan_in_params.begin(), scan_in_params.end()); mod->replace_return(returns); } -}; // namespace onnx + + std::vector make_perm_for_scan_out(int64_t rank, int64_t axis) const + { + std::vector perm(rank); + std::iota(perm.begin(), perm.end(), 0); + std::copy(perm.begin() + 1, perm.begin() + 1 + axis, perm.begin()); + perm[axis] = 0; + + return perm; + } +}; } // namespace onnx } // namespace MIGRAPHX_INLINE_NS From ed1382a0b598327f2c9bf4a14fd4ad3a80605137 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dino=20Musi=C4=87?= Date: Thu, 28 Mar 2024 21:44:44 +0000 Subject: [PATCH 09/18] Add support for scan_output_directions, modify tests to have two scan inputs --- src/include/migraphx/op/loop.hpp | 19 ++-- src/include/migraphx/run_loop.hpp | 12 ++- src/onnx/parse_scan.cpp | 14 ++- src/targets/gpu/loop.cpp | 11 ++- test/onnx/gen_onnx.py | 47 ++++++---- test/onnx/scan_test1.onnx | Bin 564 -> 793 bytes test/onnx/scan_test2.onnx | Bin 597 -> 793 bytes test/onnx/scan_test3.onnx | Bin 687 -> 802 bytes test/onnx/scan_test4.onnx | Bin 654 -> 793 bytes test/onnx/scan_test5.onnx | Bin 678 -> 793 bytes test/onnx/verify/scan_test.cpp | 147 +++++++++--------------------- test/run_loop_test.cpp | 2 +- 12 files changed, 112 insertions(+), 140 deletions(-) diff --git a/src/include/migraphx/op/loop.hpp b/src/include/migraphx/op/loop.hpp index 969c16ef7cd..658368f757f 100644 --- a/src/include/migraphx/op/loop.hpp +++ b/src/include/migraphx/op/loop.hpp @@ -41,12 +41,14 @@ namespace op { struct loop { - int64_t max_iterations = 10; + int64_t max_iterations = 10; + std::vector scan_output_directions = {}; template static auto reflect(Self& self, F f) { - return pack(f(self.max_iterations, "max_iterations")); + return pack(f(self.max_iterations, "max_iterations"), + f(self.scan_output_directions, "scan_output_directions")); } std::string name() const { return "loop"; } @@ -97,7 +99,9 @@ struct loop void append(const std::vector& iter_state, const std::vector& concatenated_outputs, - int iter) const + const std::vector& scan_output_dirs, + int64_t iter, + int64_t iter_num) const { assert(iter_state.size() == concatenated_outputs.size()); for(auto i : range(iter_state.size())) @@ -105,11 +109,14 @@ struct loop const auto& iter_stat = iter_state.at(i); const auto& scan_out = concatenated_outputs.at(i); + auto dir = scan_output_dirs.empty() ? 0 : scan_output_dirs[i]; + auto idx = (1 - dir) * iter + dir * (iter_num - 1 - iter); + auto* in_data = iter_stat.data(); auto* out_data = scan_out.data(); std::size_t out_size = iter_stat.get_shape().bytes(); - assert((iter + 1) * out_size <= scan_out.get_shape().bytes()); - std::copy(in_data, in_data + out_size, out_data + iter * out_size); + assert((idx + 1) * out_size <= scan_out.get_shape().bytes()); + std::copy(in_data, in_data + out_size, out_data + idx * out_size); } } @@ -153,7 +160,7 @@ struct loop cpy_args.push_back(argument(out_shape)); // run loop - return run_loop(ref_loop{max_iterations}, ctx, cpy_args, mods, run); + return run_loop(ref_loop{max_iterations}, scan_output_directions, ctx, cpy_args, mods, run); } }; diff --git a/src/include/migraphx/run_loop.hpp b/src/include/migraphx/run_loop.hpp index 859d7dfad7f..d6b3c51bb37 100644 --- a/src/include/migraphx/run_loop.hpp +++ b/src/include/migraphx/run_loop.hpp @@ -24,6 +24,7 @@ #ifndef MIGRAPHX_GUARD_RTGLIB_RUN_LOOP_HPP #define MIGRAPHX_GUARD_RTGLIB_RUN_LOOP_HPP +#include "stringutils.hpp" #include #include #include @@ -39,6 +40,7 @@ inline namespace MIGRAPHX_INLINE_NS { template argument run_loop(const LoopModel& model, + const std::vector& scan_output_directions, T& ctx, std::vector args, const std::vector& mods, @@ -101,9 +103,13 @@ argument run_loop(const LoopModel& model, auto output_index = out_param_indices[name]; if(output_index > dep_num) { + int64_t dir = scan_output_directions.empty() + ? 0 + : scan_output_directions[output_index - dep_num - 1]; + auto idx = (1 - dir) * iter + dir * (iter_num - 1 - iter); const auto& arg = out_args.at(output_index); - assert((iter + 1) * ps.bytes() <= arg.get_shape().bytes()); - params[name] = argument(ps, arg.data() + iter * ps.bytes()); + assert((idx + 1) * ps.bytes() <= arg.get_shape().bytes()); + params[name] = argument(ps, arg.data() + idx * ps.bytes()); } else { @@ -123,7 +129,7 @@ argument run_loop(const LoopModel& model, std::copy(dep_out.begin(), dep_out.end(), out_args.begin()); std::vector mod_scan_outs(mod_args.begin() + 1 + dep_num, mod_args.end()); - model.append(mod_scan_outs, scan_outputs, iter); + model.append(mod_scan_outs, scan_outputs, scan_output_directions, iter, iter_num); } out_args.erase(out_args.begin()); diff --git a/src/onnx/parse_scan.cpp b/src/onnx/parse_scan.cpp index 8a6c86b7a8e..85b957b14bc 100644 --- a/src/onnx/parse_scan.cpp +++ b/src/onnx/parse_scan.cpp @@ -63,14 +63,14 @@ struct parse_scan : op_parser const auto K = body_outs.size() - N; if(body->get_parameter_names().size() != N + M) - MIGRAPHX_THROW("Lorem ipsum"); + MIGRAPHX_THROW("Lorem ipsum 1"); const auto scan_input_axes = parse_axes(info, "scan_input_axes", M, args.begin() + N, 0); size_t num_iters = args[N]->get_shape().lens()[scan_input_axes[0]]; for(auto i = 1; i < M; ++i) - if(args[i]->get_shape().lens()[scan_input_axes[i]] != num_iters) - MIGRAPHX_THROW("Lorem ipsum"); + if(args[N + i]->get_shape().lens()[scan_input_axes[i]] != num_iters) + MIGRAPHX_THROW("Lorem ipsum 2"); const auto scan_input_directions = parse_dirs(info, "scan_input_directions", M); @@ -88,8 +88,12 @@ struct parse_scan : op_parser std::vector loop_args{max_iter_lit, cond_lit}; loop_args.insert(loop_args.end(), args.begin(), args.begin() + N); - auto loop = info.add_instruction( - make_op("loop", {{"max_iterations", num_iters}}), loop_args, {body}); + auto loop = + info.add_instruction(make_op("loop", + {{"max_iterations", num_iters}, + {"scan_output_directions", scan_output_directions}}), + loop_args, + {body}); std::vector ret; ret.reserve(N + K); diff --git a/src/targets/gpu/loop.cpp b/src/targets/gpu/loop.cpp index ee2731938d7..ad5fc210c7c 100644 --- a/src/targets/gpu/loop.cpp +++ b/src/targets/gpu/loop.cpp @@ -21,6 +21,7 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ +#include #include #include #include @@ -56,7 +57,13 @@ struct gpu_loop copy_to_gpu(ctx, arg_src, dst); } - void append(const std::vector&, const std::vector&, int) const {} + void append(const std::vector&, + const std::vector&, + const std::vector&, + int64_t, + int64_t) const + { + } void set_zero(context& ctx, const std::vector& concatenated_outputs, int iter) const { @@ -111,7 +118,7 @@ hip_loop::compute(context& ctx, const std::function( module_ref&, const std::unordered_map&)>& run) const { - return run_loop(gpu_loop{op.max_iterations}, ctx, args, mods, run); + return run_loop(gpu_loop{op.max_iterations}, op.scan_output_directions, ctx, args, mods, run); } } // namespace gpu diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index 3d8a84c17d9..d6d6f3f0f3c 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -10686,21 +10686,27 @@ def where_mixed_test(): return ([node], [c, x, y], [z]) -def scan_test(scan_input_axes=[0], - scan_input_directions=[0], +def scan_test(scan_input_axes=[0, 0], + scan_input_directions=[0, 0], scan_output_axes=[0, 0], scan_output_directions=[0, 0]): sum_in = helper.make_tensor_value_info("sum_in", TensorProto.FLOAT, [2, 2]) - next = helper.make_tensor_value_info("next", TensorProto.FLOAT, [2, 2]) + scan_in1 = helper.make_tensor_value_info("scan_in1", TensorProto.FLOAT, + [2, 2]) + scan_in2 = helper.make_tensor_value_info("scan_in2", TensorProto.FLOAT, + [1]) sum_out = helper.make_tensor_value_info("sum_out", TensorProto.FLOAT, [2, 2]) scan_out1 = helper.make_tensor_value_info("scan_out1", TensorProto.FLOAT, [2, 2]) scan_out2 = helper.make_tensor_value_info("scan_out2", TensorProto.FLOAT, [2]) - add = helper.make_node("Add", - inputs=["sum_in", "next"], - outputs=["sum_out"]) + add1 = helper.make_node("Add", + inputs=["sum_in", "scan_in1"], + outputs=["add1_out"]) + add2 = helper.make_node("Add", + inputs=["add1_out", "scan_in2"], + outputs=["sum_out"]) id = helper.make_node("Identity", inputs=["sum_out"], outputs=["scan_out1"]) @@ -10709,16 +10715,21 @@ def scan_test(scan_input_axes=[0], keepdims=0, inputs=["sum_out"], outputs=["scan_out2"]) - scan_body = helper.make_graph([add, id, reduce_sum], "scan_body", - [sum_in, next], + scan_body = helper.make_graph([add1, add2, id, reduce_sum], "scan_body", + [sum_in, scan_in1, scan_in2], [sum_out, scan_out1, scan_out2]) init_state = helper.make_tensor_value_info("init_state", TensorProto.FLOAT, [2, 2]) - scan_ins_sh = [2, 2, 2] - scan_ins_sh[scan_input_axes[0]] = 3 - scan_ins = helper.make_tensor_value_info("scan_ins", TensorProto.FLOAT, - scan_ins_sh) + scan_ins1_sh = [2, 2, 2] + scan_ins1_sh[scan_input_axes[0]] = 3 + scan_ins1 = helper.make_tensor_value_info("scan_ins1", TensorProto.FLOAT, + scan_ins1_sh) + scan_ins2_sh = [1, 1] + scan_ins2_sh[scan_input_axes[1]] = 3 + scan_ins2 = helper.make_tensor_value_info("scan_ins2", TensorProto.FLOAT, + scan_ins2_sh) + final_state = helper.make_tensor_value_info("final_state", TensorProto.FLOAT, [2, 2]) scan_outs1_sh = [2, 2, 2] @@ -10731,9 +10742,9 @@ def scan_test(scan_input_axes=[0], scan_outs2_sh) node = helper.make_node( "Scan", - inputs=["init_state", "scan_ins"], + inputs=["init_state", "scan_ins1", "scan_ins2"], outputs=["final_state", "scan_outs1", "scan_outs2"], - num_scan_inputs=1, + num_scan_inputs=2, scan_input_axes=scan_input_axes, scan_input_directions=scan_input_directions, scan_output_axes=scan_output_axes, @@ -10741,8 +10752,8 @@ def scan_test(scan_input_axes=[0], body=scan_body, ) - return ([node], [init_state, - scan_ins], [final_state, scan_outs1, scan_outs2]) + return ([node], [init_state, scan_ins1, + scan_ins2], [final_state, scan_outs1, scan_outs2]) @onnx_test() @@ -10762,9 +10773,9 @@ def scan_test3(): @onnx_test() def scan_test4(): - return scan_test(scan_input_directions=[1]) + return scan_test(scan_input_directions=[1, 0]) @onnx_test() def scan_test5(): - return scan_test(scan_input_axes=[1]) + return scan_test(scan_input_axes=[2, 1]) diff --git a/test/onnx/scan_test1.onnx b/test/onnx/scan_test1.onnx index 564c8f0768a73b6a5ba812d2af34c17d1ef584b9..ef11c728bdaa5ced4d7c05a3ab30cd3bf389ef4c 100644 GIT binary patch literal 793 zcmZuvOHRWu6eMn-?rTLWBTCArg)FmkmPkmzjxF6~6D75gET~%%N1}TU(sOVM4!~h} zX}f`}v@&60A|s)*L6uH;{pz#so6^xozETJ* zQj_q;yYJdl6!!iI*zrQF{9gihaFV3+P-&orEN3(lf``ax5hRXbxS zzgcb%5$@6#A+=Fv1rG;s-|f6gkZ#ug&ir!+8S{>Ql~)T4+oo$44;O;AjQqwx|;rHlaOkvrEBs zvq!r(1(Ui?oG}<+NIeXxRRvrog4OdI>}EQnR1b%D)nMxfg9cuxd4JuRRH?INfzR5KHC{{{= zMW&T+ya%p5MPcucfSoM(D)=RE2PZ0m+v@{;7h=ifTsb%(rDpCq zbSlH_6nfNR6bL2x6JFNt{5@<~V*vx8$g`1SC8NgIj$Cq8rkh`~tCw6fpeab!*0DE) zeu~+YaW3&EI5;Ps1N~+e71tK`j^W%uxcn3|Ef*=KQQ#vOY;d%Ka$D313{A-3S{ts} zrSRJ9(e6#*q^Xn77z{9^9){Ga1FjOm>IDWnU$BxJ*lOTznzvAt8c)DZ9#BuvTy@Ew JJb5OAvwxpkiHjvEKc&*>0uz@a7h7>@ZhU4Q7Y9_ip%6!6N{V58erbskvtvpMml79T0ItYr zV&PwehgHecsi;05?grmgZaz;X2TpU0-7N9ueWNF3>1*QdzOj;6L{7?%E zN=xDsD^iOc7#tWDFtTgOOBsBs0QCGm+Bsl^Tq3mDn8WVl3;#8WbhQj<$E^Ye-)^D>pn%OTt#B*Y~E PG?@ozDl^>l{Y=RKZ%- z(s8-N#Kn@7pHgXbfr(3zi>3TjCB-m4zqCY&*)b)BONk3E09Rx* zv2e1n6r5`)#K*StxI~d;QZkEDlS?x5^NJl99T+ArV$8CYNBGB3NQg@SXgCkh WaAvsW(g+1cNG<~@V4SSYlnelu%}~Su delta 181 zcmbQq*2k*F!70R5oSc{!Uy@o}Vq*1^h3hOc7guIpW=VW;Nn%MV7YA52Gp~4}nU3^y zCN7qw{FF+gZYC}nF1F&*-1y8qE|$F1ijs-TCX0(fm_mG9JRFQd99&EsOdvcthcQE* zaRDQfmJF9D)VPAulK7O&qSWM)%>2A!2gb?#Oj+`B2-}2&xCDS^@BsBQ!_C>olnekP C2`pkiHjvEKc&*>0uz@a7h7>@ZhU4Q7Y9_ip%6!6N{V58erbskvtvpMml79T0ItYr zV&PwehgHecsi;05?grmgZaz;X2TpU0-7N9ueWNF3>1*QdzOj;6L{7?%E zN=xDsD^iOcm>d`vFtTgOOBsBs0QCGm+Bsl^UJads^kE>R@$l+2>k arg_to_vec(const migraphx::argument& arg) return ret; } -TEST_CASE(scan_test1) +auto scan_test(const std::string& test_file, + migraphx::shape scan_ins1_sh, + migraphx::shape scan_ins2_sh) { - auto prog = migraphx::parse_onnx("scan_test1.onnx"); + auto prog = migraphx::parse_onnx(test_file); prog.compile(migraphx::make_target("ref")); migraphx::parameter_map pm; migraphx::shape init_state_sh{migraphx::shape::float_type, {2, 2}}; - std::vector init_state(4, 0); + std::vector init_state(init_state_sh.elements(), 0); pm["init_state"] = migraphx::argument(init_state_sh, init_state.data()); - migraphx::shape scan_ins_sh{migraphx::shape::float_type, {3, 2, 2}}; - std::vector scan_ins(12); - std::iota(scan_ins.begin(), scan_ins.end(), 1); - pm["scan_ins"] = migraphx::argument(scan_ins_sh, scan_ins.data()); + std::vector scan_ins1(scan_ins1_sh.elements()); + std::iota(scan_ins1.begin(), scan_ins1.end(), 1); + pm["scan_ins1"] = migraphx::argument(scan_ins1_sh, scan_ins1.data()); + + std::vector scan_ins2(scan_ins2_sh.elements()); + std::iota(scan_ins2.begin(), scan_ins2.end(), 0); + pm["scan_ins2"] = migraphx::argument(scan_ins2_sh, scan_ins2.data()); auto result = prog.eval(pm); EXPECT(result.size() == 3); + return std::make_tuple(result[0], result[1], result[2]); +} - auto final_state = result[0]; - auto scan_out1 = result[1]; - auto scan_out2 = result[2]; +TEST_CASE(scan_test1) +{ + auto [final_state, scan_out1, scan_out2] = + scan_test("scan_test1.onnx", make_shape({3, 2, 2}), make_shape({3, 1})); EXPECT(final_state.get_shape() == make_shape({2, 2})); - std::vector final_state_gold{15, 18, 21, 24}; + std::vector final_state_gold{18, 21, 24, 27}; EXPECT(arg_to_vec(final_state) == final_state_gold); EXPECT(scan_out1.get_shape() == make_shape({3, 2, 2})); - std::vector scan_out1_gold{1, 2, 3, 4, 6, 8, 10, 12, 15, 18, 21, 24}; + std::vector scan_out1_gold{1, 2, 3, 4, 7, 9, 11, 13, 18, 21, 24, 27}; EXPECT(arg_to_vec(scan_out1) == scan_out1_gold); EXPECT(scan_out2.get_shape() == make_shape({3, 2})); - std::vector scan_out2_gold{4, 6, 16, 20, 36, 42}; + std::vector scan_out2_gold{4, 6, 18, 22, 42, 48}; EXPECT(arg_to_vec(scan_out2) == scan_out2_gold); } TEST_CASE(scan_test2) { - auto prog = migraphx::parse_onnx("scan_test2.onnx"); - prog.compile(migraphx::make_target("ref")); - - migraphx::parameter_map pm; - - migraphx::shape init_state_sh{migraphx::shape::float_type, {2, 2}}; - std::vector init_state(4, 0); - pm["init_state"] = migraphx::argument(init_state_sh, init_state.data()); - - migraphx::shape scan_ins_sh{migraphx::shape::float_type, {3, 2, 2}}; - std::vector scan_ins(12); - std::iota(scan_ins.begin(), scan_ins.end(), 1); - pm["scan_ins"] = migraphx::argument(scan_ins_sh, scan_ins.data()); - - auto result = prog.eval(pm); - EXPECT(result.size() == 3); - - auto final_state = result[0]; - auto scan_out1 = result[1]; - auto scan_out2 = result[2]; + auto [final_state, scan_out1, scan_out2] = + scan_test("scan_test2.onnx", make_shape({3, 2, 2}), make_shape({3, 1})); EXPECT(final_state.get_shape() == make_shape({2, 2})); - std::vector final_state_gold{15, 18, 21, 24}; + std::vector final_state_gold{18, 21, 24, 27}; EXPECT(arg_to_vec(final_state) == final_state_gold); EXPECT(scan_out1.get_shape() == make_shape({3, 2, 2})); - std::vector scan_out1_gold{15, 18, 21, 24, 6, 8, 10, 12, 1, 2, 3, 4}; + std::vector scan_out1_gold{18, 21, 24, 27, 7, 9, 11, 13, 1, 2, 3, 4}; EXPECT(arg_to_vec(scan_out1) == scan_out1_gold); EXPECT(scan_out2.get_shape() == make_shape({3, 2})); - std::vector scan_out2_gold{4, 6, 16, 20, 36, 42}; + std::vector scan_out2_gold{4, 6, 18, 22, 42, 48}; EXPECT(arg_to_vec(scan_out2) == scan_out2_gold); } TEST_CASE(scan_test3) { - auto prog = migraphx::parse_onnx("scan_test3.onnx"); - prog.compile(migraphx::make_target("ref")); - - migraphx::parameter_map pm; - - migraphx::shape init_state_sh{migraphx::shape::float_type, {2, 2}}; - std::vector init_state(4, 0); - pm["init_state"] = migraphx::argument(init_state_sh, init_state.data()); - - migraphx::shape scan_ins_sh{migraphx::shape::float_type, {3, 2, 2}}; - std::vector scan_ins(12); - std::iota(scan_ins.begin(), scan_ins.end(), 1); - pm["scan_ins"] = migraphx::argument(scan_ins_sh, scan_ins.data()); - - auto result = prog.eval(pm); - EXPECT(result.size() == 3); - - auto final_state = result[0]; - auto scan_out1 = result[1]; - auto scan_out2 = result[2]; + auto [final_state, scan_out1, scan_out2] = + scan_test("scan_test3.onnx", make_shape({3, 2, 2}), make_shape({3, 1})); EXPECT(final_state.get_shape() == make_shape({2, 2})); - std::vector final_state_gold{15, 18, 21, 24}; + std::vector final_state_gold{18, 21, 24, 27}; EXPECT(arg_to_vec(final_state) == final_state_gold); EXPECT(scan_out1.get_shape() == make_shape({2, 3, 2})); - std::vector scan_out1_gold{1, 2, 6, 8, 15, 18, 3, 4, 10, 12, 21, 24}; + std::vector scan_out1_gold{1, 2, 7, 9, 18, 21, 3, 4, 11, 13, 24, 27}; EXPECT(arg_to_vec(scan_out1) == scan_out1_gold); EXPECT(scan_out2.get_shape() == make_shape({2, 3})); - std::vector scan_out2_gold{36, 16, 4, 42, 20, 6}; + std::vector scan_out2_gold{4, 18, 42, 6, 22, 48}; EXPECT(arg_to_vec(scan_out2) == scan_out2_gold); } TEST_CASE(scan_test4) { - auto prog = migraphx::parse_onnx("scan_test4.onnx"); - prog.compile(migraphx::make_target("ref")); - - migraphx::parameter_map pm; - - migraphx::shape init_state_sh{migraphx::shape::float_type, {2, 2}}; - std::vector init_state(4, 0); - pm["init_state"] = migraphx::argument(init_state_sh, init_state.data()); - - migraphx::shape scan_ins_sh{migraphx::shape::float_type, {3, 2, 2}}; - std::vector scan_ins(12); - std::iota(scan_ins.begin(), scan_ins.end(), 1); - pm["scan_ins"] = migraphx::argument(scan_ins_sh, scan_ins.data()); - - auto result = prog.eval(pm); - EXPECT(result.size() == 3); - - auto final_state = result[0]; - auto scan_out1 = result[1]; - auto scan_out2 = result[2]; + auto [final_state, scan_out1, scan_out2] = + scan_test("scan_test4.onnx", make_shape({3, 2, 2}), make_shape({3, 1})); EXPECT(final_state.get_shape() == make_shape({2, 2})); - std::vector final_state_gold{15, 18, 21, 24}; + std::vector final_state_gold{18, 21, 24, 27}; EXPECT(arg_to_vec(final_state) == final_state_gold); EXPECT(scan_out1.get_shape() == make_shape({3, 2, 2})); - std::vector scan_out1_gold{9, 10, 11, 12, 14, 16, 18, 20, 15, 18, 21, 24}; + std::vector scan_out1_gold{9, 10, 11, 12, 15, 17, 19, 21, 18, 21, 24, 27}; EXPECT(arg_to_vec(scan_out1) == scan_out1_gold); EXPECT(scan_out2.get_shape() == make_shape({3, 2})); - std::vector scan_out2_gold{20, 22, 32, 36, 36, 42}; + std::vector scan_out2_gold{20, 22, 34, 38, 42, 48}; EXPECT(arg_to_vec(scan_out2) == scan_out2_gold); } TEST_CASE(scan_test5) { - auto prog = migraphx::parse_onnx("scan_test5.onnx"); - prog.compile(migraphx::make_target("ref")); - - migraphx::parameter_map pm; - - migraphx::shape init_state_sh{migraphx::shape::float_type, {2, 2}}; - std::vector init_state(4, 0); - pm["init_state"] = migraphx::argument(init_state_sh, init_state.data()); - - migraphx::shape scan_ins_sh{migraphx::shape::float_type, {2, 3, 2}}; - std::vector scan_ins(12); - std::iota(scan_ins.begin(), scan_ins.end(), 1); - pm["scan_ins"] = migraphx::argument(scan_ins_sh, scan_ins.data()); - - auto result = prog.eval(pm); - EXPECT(result.size() == 3); - - auto final_state = result[0]; - auto scan_out1 = result[1]; - auto scan_out2 = result[2]; + auto [final_state, scan_out1, scan_out2] = + scan_test("scan_test5.onnx", make_shape({2, 2, 3}), make_shape({1, 3})); EXPECT(final_state.get_shape() == make_shape({2, 2})); - std::vector final_state_gold{9, 12, 27, 30}; + std::vector final_state_gold{9, 18, 27, 36}; EXPECT(arg_to_vec(final_state) == final_state_gold); EXPECT(scan_out1.get_shape() == make_shape({3, 2, 2})); - std::vector scan_out1_gold{1, 2, 7, 8, 4, 6, 16, 18, 9, 12, 27, 30}; + std::vector scan_out1_gold{1, 4, 7, 10, 4, 10, 16, 22, 9, 18, 27, 36}; EXPECT(arg_to_vec(scan_out1) == scan_out1_gold); EXPECT(scan_out2.get_shape() == make_shape({3, 2})); - std::vector scan_out2_gold{8, 10, 20, 24, 36, 42}; + std::vector scan_out2_gold{8, 14, 20, 32, 36, 54}; EXPECT(arg_to_vec(scan_out2) == scan_out2_gold); } diff --git a/test/run_loop_test.cpp b/test/run_loop_test.cpp index 7c1f148c17e..c371ef59c63 100644 --- a/test/run_loop_test.cpp +++ b/test/run_loop_test.cpp @@ -156,7 +156,7 @@ struct test_loop_op cpy_args.push_back(migraphx::argument(s_cond)); cpy_args.push_back(migraphx::argument(out_shape)); // run loop - return run_loop(test_loop{max_iterations}, ctx, cpy_args, mods, run); + return run_loop(test_loop{max_iterations}, {}, ctx, cpy_args, mods, run); } }; From 3e574896be01307b876e6bc1605a62acb4edef2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dino=20Musi=C4=87?= Date: Sat, 30 Mar 2024 21:27:14 +0000 Subject: [PATCH 10/18] Implement additional tests, add comments, fix cppcheck and tidy issues --- src/include/migraphx/run_loop.hpp | 2 +- src/onnx/parse_scan.cpp | 161 ++++++++++++++++++++---------- test/onnx/gen_onnx.py | 10 +- test/onnx/parse/scan_test.cpp | 75 ++++++++++++++ test/onnx/scan_test5.onnx | Bin 793 -> 802 bytes test/onnx/scan_test6.onnx | Bin 0 -> 802 bytes test/onnx/verify/scan_test.cpp | 18 ++++ test/verify/test_scan_slice.cpp | 6 +- 8 files changed, 216 insertions(+), 56 deletions(-) create mode 100644 test/onnx/parse/scan_test.cpp create mode 100644 test/onnx/scan_test6.onnx diff --git a/src/include/migraphx/run_loop.hpp b/src/include/migraphx/run_loop.hpp index d6b3c51bb37..ee80f41f4cb 100644 --- a/src/include/migraphx/run_loop.hpp +++ b/src/include/migraphx/run_loop.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/onnx/parse_scan.cpp b/src/onnx/parse_scan.cpp index 85b957b14bc..c73fa690987 100644 --- a/src/onnx/parse_scan.cpp +++ b/src/onnx/parse_scan.cpp @@ -51,42 +51,76 @@ struct parse_scan : op_parser onnx_parser::node_info info, std::vector args) const { + if(parser.opset_version == 8) + MIGRAPHX_THROW("Scan: Opset 8 version not supported"); + check_for_required_attributes(info, {"body", "num_scan_inputs"}); const auto& body_graph = info.attributes["body"].g(); - auto body = parser.prog.create_module(info.name + "_scan"); + auto* body = parser.prog.create_module(info.name + "_scan"); parser.parse_graph(body, body_graph); + // Scan has: + // N + M inputs (N state variables, M scan inputs) + // N + K outputs (N state variables, K scan outputs) + // Same input and output counts apply for body auto body_outs = body->get_returns(); - const auto M = info.attributes["num_scan_inputs"].i(); - const auto N = args.size() - M; - const auto K = body_outs.size() - N; - - if(body->get_parameter_names().size() != N + M) - MIGRAPHX_THROW("Lorem ipsum 1"); - - const auto scan_input_axes = parse_axes(info, "scan_input_axes", M, args.begin() + N, 0); - - size_t num_iters = args[N]->get_shape().lens()[scan_input_axes[0]]; - for(auto i = 1; i < M; ++i) - if(args[N + i]->get_shape().lens()[scan_input_axes[i]] != num_iters) - MIGRAPHX_THROW("Lorem ipsum 2"); - - const auto scan_input_directions = parse_dirs(info, "scan_input_directions", M); - + const auto m = info.attributes["num_scan_inputs"].i(); + const auto n = args.size() - m; + const auto k = body_outs.size() - n; + + std::vector body_params; + transform(body->get_parameter_names(), + std::back_inserter(body_params), + [&](const auto& name) { return body->get_parameter(name); }); + + if(auto num_body_ins = body_params.size(); num_body_ins != n + m) + MIGRAPHX_THROW("Scan: Number of inputs to body {" + std::to_string(num_body_ins) + + "} does not match number of inputs to Scan {" + std::to_string(n + m) + + "}"); + + const auto scan_input_axes = parse_axes(info, "scan_input_axes", m, args.begin() + n, 0); + const auto scan_input_directions = parse_dirs(info, "scan_input_directions", m); const auto scan_output_axes = - parse_axes(info, "scan_output_axes", K, body_outs.begin() + N, 1); - - const auto scan_output_directions = parse_dirs(info, "scan_output_directions", K); - - // TODO check that alt_args shapes match body input parameter shapes + parse_axes(info, "scan_output_axes", k, body_outs.begin() + n, 1); + const auto scan_output_directions = parse_dirs(info, "scan_output_directions", k); + + // Check that scan axes sizes are the same across all scan inputs + size_t num_iters = args[n]->get_shape().lens()[scan_input_axes[0]]; + for(auto i = 1; i < m; ++i) + if(args[n + i]->get_shape().lens()[scan_input_axes[i]] != num_iters) + MIGRAPHX_THROW("Lorem ipsum 1"); + + if(num_iters > parser.max_loop_iterations) + MIGRAPHX_THROW("Scan: Number of required iterations {" + std::to_string(num_iters) + + "} would exceed the maximum iteration limit {" + + std::to_string(parser.max_loop_iterations) + "}"); + + // Check that state variable shapes match between the Scan node and its body attribute + for(auto i = 0; i < n; ++i) + if(args[i]->get_shape() != body_params[i]->get_shape()) + MIGRAPHX_THROW("Scan: State input " + std::to_string(i) + " shape " + + to_string(args[i]->get_shape()) + + " does not match corresponding body input shape " + + to_string(body_params[i]->get_shape())); + + // Check that the shapes of scan inputs sliced across scan input axes match the shapes of + // the body attribute scan inputs + for(auto i = 0; i < m; ++i) + { + auto node_shape = args[i + n]->get_shape(); + auto node_lens = node_shape.lens(); + node_lens.erase(node_lens.begin() + scan_input_axes[i]); + if(body_params[i + n]->get_shape() != shape(node_shape.type(), std::move(node_lens))) + MIGRAPHX_THROW("Lorem ipsum 2"); + } - modify_body(body, args, N, M, scan_input_axes, scan_input_directions); + modify_body(body, args, n, m, scan_input_axes, scan_input_directions); auto max_iter_lit = info.add_literal(literal{shape{shape::int64_type}, {num_iters}}); auto cond_lit = info.add_literal(literal{shape{shape::bool_type}, {true}}); std::vector loop_args{max_iter_lit, cond_lit}; - loop_args.insert(loop_args.end(), args.begin(), args.begin() + N); + loop_args.insert(loop_args.end(), args.begin(), args.begin() + n); auto loop = info.add_instruction(make_op("loop", @@ -96,13 +130,15 @@ struct parse_scan : op_parser {body}); std::vector ret; - ret.reserve(N + K); - for(auto i = 0; i < N; ++i) + ret.reserve(n + k); + for(auto i = 0; i < n; ++i) ret.push_back(info.add_instruction(make_op("get_tuple_elem", {{"index", i}}), loop)); - for(auto i = 0; i < K; ++i) + for(auto i = 0; i < k; ++i) { - auto o = info.add_instruction(make_op("get_tuple_elem", {{"index", i + N}}), loop); + auto o = info.add_instruction(make_op("get_tuple_elem", {{"index", i + n}}), loop); + // Loop scan_outputs are concatenated along axis 0, so it must be transposed to the + // index specified by the corresponding scan_output_axis auto perm = make_perm_for_scan_out(o->get_shape().ndim(), scan_output_axes[i]); ret.push_back(info.add_instruction(make_op("transpose", {{"permutation", perm}}), o)); } @@ -111,11 +147,14 @@ struct parse_scan : op_parser } void check_for_required_attributes(onnx_parser::node_info& info, - std::vector attribute_names) const + const std::vector& attribute_names) const { - for(const auto& name : attribute_names) - if(not contains(info.attributes, name)) - MIGRAPHX_THROW("Scan: " + name + " attribute required"); + auto it = std::find_if( + attribute_names.cbegin(), attribute_names.cend(), [&](const std::string& name) { + return not contains(info.attributes, name); + }); + if(it != attribute_names.cend()) + MIGRAPHX_THROW("Scan: " + *it + " attribute required"); } std::vector parse_vector_attribute(onnx_parser::node_info& info, @@ -136,11 +175,11 @@ struct parse_scan : op_parser } std::vector - parse_dirs(onnx_parser::node_info& info, const std::string& name, size_t expected_size) const + parse_dirs(onnx_parser::node_info& info, const std::string& name, long expected_size) const { auto dirs = parse_vector_attribute(info, name, expected_size); if(dirs.empty()) - return std::vector(expected_size, 0); + return {expected_size, 0}; if(any_of(dirs, [](auto i) { return i != 0 and i != 1; })) MIGRAPHX_THROW("Scan: " + name + @@ -160,13 +199,13 @@ struct parse_scan : op_parser std::vector parse_axes(onnx_parser::node_info& info, const std::string& name, - size_t expected_size, + long expected_size, std::vector::iterator ins_begin, size_t rank_offset) const { auto axes = parse_vector_attribute(info, name, expected_size); if(axes.empty()) - return std::vector(expected_size, 0); + return {expected_size, 0}; std::transform(axes.begin(), axes.end(), @@ -179,26 +218,46 @@ struct parse_scan : op_parser return axes; } + // Alter the Scan body to match a body that Loop would expect. + // + // Loop body inputs: iteration_num, condition, loop_state_variables + // Scan body inputs: loop_state_variables, scan_input_slices + // iteration_num and condition parameters are prepended to the Scan body parameter list, while + // scan_input_slices are removed from parameters. + // Instead, scan_inputs are used directly in Scan body(as values from enclosing scope), and + // together with iteration_num passed to the scan_slice operator which produces slices that are + // used instead of the scan_inputs_slices. + // + // Loop body outputs: condition, loop_state_variables, scan_output_slices + // Scan body outputs: loop_state_variables, scan_output_slices + // The inserted Scan body condition parameter is prepended to the Scan body returns void modify_body(module_ref mod, const std::vector& args, - int64_t N, - int64_t M, + int64_t n, + int64_t m, const std::vector& scan_input_axes, const std::vector& scan_input_directions) const { std::vector params; - params.reserve(N + M); - transform(mod->get_parameter_names(), - std::back_inserter(params), - [&](const std::string& name) { return mod->get_parameter(name); }); - + params.reserve(n + m); + auto param_names = mod->get_parameter_names(); + transform(param_names, std::back_inserter(params), [&](const std::string& name) { + return mod->get_parameter(name); + }); + + // iteration_num, condition, and duplicate loop_state_variables are appended to parameters. + // References to the original loop_state_variables in other instructions are then replaced + // with references to the duplicate ones, after which the originals are removed. + // + // References to the scan_input_slices are replaced with references to inserted + // scan_slice->squeeze instructions, after which the scan_input_slices parameters are + // removed. auto iter_param = mod->add_parameter("iter", shape{shape::int64_type}); auto cond_param = mod->add_parameter("cond", shape{shape::bool_type}); std::vector new_params; - new_params.reserve(N); - for(auto i = 0; i < N; ++i) - new_params.push_back( - mod->add_parameter("state_var" + std::to_string(i), params[i]->get_shape())); + new_params.reserve(n); + for(auto i = 0; i < n; ++i) + new_params.push_back(mod->add_parameter(param_names[i], params[i]->get_shape())); for(auto i = 0; i < params.size(); ++i) { @@ -207,11 +266,11 @@ struct parse_scan : op_parser if(not contains(ins->inputs(), params[i])) continue; - auto new_ins = i < N ? new_params[i] : args[i]; - if(i >= N) + auto new_ins = i < n ? new_params[i] : args[i]; + if(i >= n) { - auto scan_axis = scan_input_axes[i - N]; - auto scan_dir = scan_input_directions[i - N]; + auto scan_axis = scan_input_axes[i - n]; + auto scan_dir = scan_input_directions[i - n]; new_ins = mod->insert_instruction( params[i], make_op("scan_slice", {{"axis", scan_axis}, {"direction", scan_dir}}), diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index d6d6f3f0f3c..41b8d6443dc 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -10778,4 +10778,12 @@ def scan_test4(): @onnx_test() def scan_test5(): - return scan_test(scan_input_axes=[2, 1]) + return scan_test(scan_input_axes=[2, -1]) + + +@onnx_test() +def scan_test6(): + return scan_test(scan_input_axes=[-2, 0], + scan_input_directions=[0, 1], + scan_output_directions=[1, 1], + scan_output_axes=[2, 1]) diff --git a/test/onnx/parse/scan_test.cpp b/test/onnx/parse/scan_test.cpp new file mode 100644 index 00000000000..fbac2ad6d06 --- /dev/null +++ b/test/onnx/parse/scan_test.cpp @@ -0,0 +1,75 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "migraphx/common.hpp" +#include "migraphx/literal.hpp" +#include + +TEST_CASE(scan_test) +{ + namespace mgx = migraphx; + migraphx::program prog; + auto* mm = prog.get_main_module(); + auto init_state = mm->add_parameter("init_state", mgx::shape{mgx::shape::float_type, {2, 2}}); + auto scan_ins1 = mm->add_parameter("scan_ins1", mgx::shape{mgx::shape::float_type, {2, 3, 2}}); + auto scan_ins2 = mm->add_parameter("scan_ins2", mgx::shape{mgx::shape::float_type, {3, 1}}); + + auto* body = prog.create_module("Scan_3_scan"); + auto iter = body->add_parameter("iter", mgx::shape{mgx::shape::int64_type}); + auto cond = body->add_parameter("cond", mgx::shape{mgx::shape::bool_type}); + auto sum_in = body->add_parameter("sum_in", mgx::shape{mgx::shape::float_type, {2, 2}}); + + auto scan_in2 = body->add_instruction( + mgx::make_op("scan_slice", {{"axis", 0}, {"direction", 1}}), scan_ins2, iter); + scan_in2 = body->add_instruction(mgx::make_op("squeeze", {{"axes", {0}}}), scan_in2); + auto scan_in1 = body->add_instruction( + mgx::make_op("scan_slice", {{"axis", 1}, {"direction", 0}}), scan_ins1, iter); + scan_in1 = body->add_instruction(mgx::make_op("squeeze", {{"axes", {1}}}), scan_in1); + + auto add1 = mgx::add_common_op(*body, mgx::make_op("add"), {sum_in, scan_in1}); + auto add2 = mgx::add_common_op(*body, mgx::make_op("add"), {add1, scan_in2}); + auto id = body->add_instruction(mgx::make_op("identity"), add2); + auto reduce_sum = body->add_instruction(mgx::make_op("reduce_sum", {{"axes", {0}}}), add2); + reduce_sum = body->add_instruction(mgx::make_op("squeeze", {{"axes", {0}}}), reduce_sum); + body->add_return({cond, add2, id, reduce_sum}); + + auto iter_lit = mm->add_literal(mgx::literal{mgx::shape{mgx::shape::int64_type}, {3}}); + auto cond_lit = mm->add_literal(mgx::literal{mgx::shape{mgx::shape::bool_type}, {true}}); + auto loop = mm->add_instruction( + mgx::make_op("loop", {{"max_iterations", 3}, {"scan_output_directions", {1, 1}}}), + {iter_lit, cond_lit, init_state}, + {body}); + + auto final_state = mm->add_instruction(mgx::make_op("get_tuple_elem", {{"index", 0}}), loop); + auto scan_outs1 = mm->add_instruction(mgx::make_op("get_tuple_elem", {{"index", 1}}), loop); + scan_outs1 = + mm->add_instruction(mgx::make_op("transpose", {{"permutation", {1, 2, 0}}}), scan_outs1); + auto scan_outs2 = mm->add_instruction(mgx::make_op("get_tuple_elem", {{"index", 2}}), loop); + scan_outs2 = + mm->add_instruction(mgx::make_op("transpose", {{"permutation", {1, 0}}}), scan_outs2); + mm->add_return({final_state, scan_outs1, scan_outs2}); + + auto prog_gold = migraphx::parse_onnx("scan_test6.onnx"); + EXPECT(prog == prog_gold); +} diff --git a/test/onnx/scan_test5.onnx b/test/onnx/scan_test5.onnx index f5a874c02025da42ae2b0f6c068294f2e6ad86a2..9baa839788507213d0cd17c1006ef7b49f9f2aba 100644 GIT binary patch delta 64 zcmbQqwunubgHwpBI5{ydz9hA{#MG*njqC76;bV+)id_6);mo{((vtYZiqv8UCWrqp Ku=yRMHX{JKkQ#{q delta 55 zcmZ3)Hj_=5gHwpBI5{ydz9hA{#MG*VjceCN;bV-V5?uUX;mo{((vtYZiqv8UCWp=6 H8MPS!&h-&@ diff --git a/test/onnx/scan_test6.onnx b/test/onnx/scan_test6.onnx new file mode 100644 index 0000000000000000000000000000000000000000..2de8cae29d267be8218574823aa6c29676d179d2 GIT binary patch literal 802 zcmZuv%TB^T6tzQz;X;Ir5y1yYm)*4s7bYg?&aLdaXg~H!n>3lZ=iZZZ&m7uj5G?}bYZ2+|oA)Msm;h2q?MFJ$0@?{GRpgXMmvx?`3SQEd zj9|6VkyAOlmN`da`-^}XuK3pdBrt_$v|0L60nKO~Mgz_rAE&Lx9p_*{Q7D>R){4i9 zo-yQimdZoKvFHOAN=v&X?MU3-b6Gz*cmOm2lQEpHJ8h1hy{{y z#+tohE|?ubml})$p(OsjvqGKo2OAWa%K#|OnK)yFqf%FmElcfp*NYGYkY(u)`>lZURm>#ZNGoA&qx;^!%3xu%U7|`a;4CgnGqW^f_{RdCge>y zND570P=zY9M$%9Xm3L7MRDC&bQs&BG4EpF+H*>j>IGtOA8FVYtF=&LBuouO;tyhqy Z8&|+^?z6}5*a%TvzR4bag5Hz<@h>>M%F_S< literal 0 HcmV?d00001 diff --git a/test/onnx/verify/scan_test.cpp b/test/onnx/verify/scan_test.cpp index ae62c11f6c9..4695b9e8faa 100644 --- a/test/onnx/verify/scan_test.cpp +++ b/test/onnx/verify/scan_test.cpp @@ -159,3 +159,21 @@ TEST_CASE(scan_test5) std::vector scan_out2_gold{8, 14, 20, 32, 36, 54}; EXPECT(arg_to_vec(scan_out2) == scan_out2_gold); } + +TEST_CASE(scan_test6) +{ + auto [final_state, scan_out1, scan_out2] = + scan_test("scan_test6.onnx", make_shape({2, 3, 2}), make_shape({3, 1})); + + EXPECT(final_state.get_shape() == make_shape({2, 2})); + std::vector final_state_gold{12, 15, 30, 33}; + EXPECT(arg_to_vec(final_state) == final_state_gold); + + EXPECT(scan_out1.get_shape() == make_shape({2, 2, 3})); + std::vector scan_out1_gold{12, 7, 3, 15, 9, 4, 30, 19, 9, 33, 21, 10}; + EXPECT(arg_to_vec(scan_out1) == scan_out1_gold); + + EXPECT(scan_out2.get_shape() == make_shape({2, 3})); + std::vector scan_out2_gold{42, 26, 12, 48, 30, 14}; + EXPECT(arg_to_vec(scan_out2) == scan_out2_gold); +} diff --git a/test/verify/test_scan_slice.cpp b/test/verify/test_scan_slice.cpp index 295f1dc4eb8..3326ebd8173 100644 --- a/test/verify/test_scan_slice.cpp +++ b/test/verify/test_scan_slice.cpp @@ -27,7 +27,7 @@ #include #include -template +template struct test_scan_slice_base : verify_program { migraphx::program create_program() const @@ -38,10 +38,10 @@ struct test_scan_slice_base : verify_program migraphx::shape data_sh{migraphx::shape::int32_type, {2, 2, 2}}; auto data_param = mm->add_parameter("data", data_sh); migraphx::shape idx_sh{migraphx::shape::int64_type, {1}}; - auto idx_lit = mm->add_literal(migraphx::literal{idx_sh, {0}}); + auto idx_lit = mm->add_literal(migraphx::literal{idx_sh, {Idx}}); mm->add_instruction( - migraphx::make_op("scan_slice", {{"axis", axis}, {"direction", direction}}), + migraphx::make_op("scan_slice", {{"axis", Axis}, {"direction", Direction}}), data_param, idx_lit); From 4824f778b60e56b714600e01f980171b93d74b33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dino=20Musi=C4=87?= Date: Mon, 1 Apr 2024 14:26:32 +0000 Subject: [PATCH 11/18] Implement negative tests --- src/onnx/parse_scan.cpp | 31 ++- test/onnx/gen_onnx.py | 214 ++++++++++++++++++ test/onnx/parse/scan_test.cpp | 69 ++++++ test/onnx/scan_arg_count_mismatch_test.onnx | Bin 0 -> 634 bytes test/onnx/scan_arg_shapes_mismatch_test.onnx | Bin 0 -> 660 bytes .../scan_input_axes_lens_mismatch_test.onnx | Bin 0 -> 674 bytes .../scan_invalid_input_axes_len_test.onnx | Bin 0 -> 592 bytes .../scan_invalid_input_axes_vals_test.onnx | Bin 0 -> 592 bytes .../scan_invalid_input_dirs_len_test.onnx | Bin 0 -> 592 bytes .../scan_invalid_input_dirs_vals_test.onnx | Bin 0 -> 592 bytes .../scan_invalid_output_axes_len_test.onnx | Bin 0 -> 594 bytes .../scan_invalid_output_axes_vals_test.onnx | Bin 0 -> 603 bytes .../scan_invalid_output_dirs_len_test.onnx | Bin 0 -> 594 bytes .../scan_invalid_output_dirs_vals_test.onnx | Bin 0 -> 603 bytes test/py/onnx_backend_test.py | 5 +- 15 files changed, 304 insertions(+), 15 deletions(-) create mode 100644 test/onnx/scan_arg_count_mismatch_test.onnx create mode 100644 test/onnx/scan_arg_shapes_mismatch_test.onnx create mode 100644 test/onnx/scan_input_axes_lens_mismatch_test.onnx create mode 100644 test/onnx/scan_invalid_input_axes_len_test.onnx create mode 100644 test/onnx/scan_invalid_input_axes_vals_test.onnx create mode 100644 test/onnx/scan_invalid_input_dirs_len_test.onnx create mode 100644 test/onnx/scan_invalid_input_dirs_vals_test.onnx create mode 100644 test/onnx/scan_invalid_output_axes_len_test.onnx create mode 100644 test/onnx/scan_invalid_output_axes_vals_test.onnx create mode 100644 test/onnx/scan_invalid_output_dirs_len_test.onnx create mode 100644 test/onnx/scan_invalid_output_dirs_vals_test.onnx diff --git a/src/onnx/parse_scan.cpp b/src/onnx/parse_scan.cpp index c73fa690987..a846dc654ec 100644 --- a/src/onnx/parse_scan.cpp +++ b/src/onnx/parse_scan.cpp @@ -85,11 +85,16 @@ struct parse_scan : op_parser parse_axes(info, "scan_output_axes", k, body_outs.begin() + n, 1); const auto scan_output_directions = parse_dirs(info, "scan_output_directions", k); - // Check that scan axes sizes are the same across all scan inputs + // Check that scan axes lens are the same across all scan inputs size_t num_iters = args[n]->get_shape().lens()[scan_input_axes[0]]; for(auto i = 1; i < m; ++i) if(args[n + i]->get_shape().lens()[scan_input_axes[i]] != num_iters) - MIGRAPHX_THROW("Lorem ipsum 1"); + MIGRAPHX_THROW( + "Scan: Lengths of scan_input_axes do not match across all scan inputs.\n" + "Scan input shapes: " + + to_string_range( + to_shapes(std::vector(args.begin() + n, args.end()))) + + "\nScan input axes: " + to_string_range(scan_input_axes)); if(num_iters > parser.max_loop_iterations) MIGRAPHX_THROW("Scan: Number of required iterations {" + std::to_string(num_iters) + @@ -99,10 +104,10 @@ struct parse_scan : op_parser // Check that state variable shapes match between the Scan node and its body attribute for(auto i = 0; i < n; ++i) if(args[i]->get_shape() != body_params[i]->get_shape()) - MIGRAPHX_THROW("Scan: State input " + std::to_string(i) + " shape " + + MIGRAPHX_THROW("Scan: State input " + std::to_string(i) + " shape {" + to_string(args[i]->get_shape()) + - " does not match corresponding body input shape " + - to_string(body_params[i]->get_shape())); + "} does not match corresponding body input shape {" + + to_string(body_params[i]->get_shape()) + "}"); // Check that the shapes of scan inputs sliced across scan input axes match the shapes of // the body attribute scan inputs @@ -111,8 +116,12 @@ struct parse_scan : op_parser auto node_shape = args[i + n]->get_shape(); auto node_lens = node_shape.lens(); node_lens.erase(node_lens.begin() + scan_input_axes[i]); - if(body_params[i + n]->get_shape() != shape(node_shape.type(), std::move(node_lens))) - MIGRAPHX_THROW("Lorem ipsum 2"); + auto slice_sh = shape(node_shape.type(), std::move(node_lens)); + if(body_params[i + n]->get_shape() != slice_sh) + MIGRAPHX_THROW("Slice: Sliced scan input " + std::to_string(i) + " shape {" + + to_string(slice_sh) + + "} does not match corresponding body input shape {" + + to_string(body_params[i + n]->get_shape()) + "}"); } modify_body(body, args, n, m, scan_input_axes, scan_input_directions); @@ -188,11 +197,11 @@ struct parse_scan : op_parser return dirs; } - int64_t normalize_axis(int64_t axis, int64_t rank) const + int64_t normalize_axis(int64_t axis, int64_t rank, const std::string& attr_name) const { if(axis < -rank or axis >= rank) - MIGRAPHX_THROW("Axis value {" + to_string(axis) + "} out of range [" + - to_string(-rank) + ", " + to_string(rank) + ")"); + MIGRAPHX_THROW("Scan: " + attr_name + " axis value {" + to_string(axis) + + "} out of range [" + to_string(-rank) + ", " + to_string(rank) + ")"); return axis < 0 ? rank + axis : axis; } @@ -212,7 +221,7 @@ struct parse_scan : op_parser ins_begin, axes.begin(), [&](int64_t axis, instruction_ref arg) { - return normalize_axis(axis, arg->get_shape().ndim() + rank_offset); + return normalize_axis(axis, arg->get_shape().ndim() + rank_offset, name); }); return axes; diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index 41b8d6443dc..f727992baf4 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -10787,3 +10787,217 @@ def scan_test6(): scan_input_directions=[0, 1], scan_output_directions=[1, 1], scan_output_axes=[2, 1]) + + +def scan_negative_test(scan_input_axes=[0], + scan_input_directions=[0], + scan_output_axes=[0], + scan_output_directions=[0]): + sum_in = helper.make_tensor_value_info("sum_in", TensorProto.FLOAT, [2, 2]) + scan_in = helper.make_tensor_value_info("scan_in", TensorProto.FLOAT, + [2, 2]) + sum_out = helper.make_tensor_value_info("sum_out", TensorProto.FLOAT, + [2, 2]) + scan_out = helper.make_tensor_value_info("scan_out", TensorProto.FLOAT, + [2, 2]) + add = helper.make_node("Add", + inputs=["sum_in", "scan_in"], + outputs=["sum_out"]) + id = helper.make_node("Identity", inputs=["sum_out"], outputs=["scan_out"]) + scan_body = helper.make_graph([add, id], "scan_body", [sum_in, scan_in], + [sum_out, scan_out]) + + init_state = helper.make_tensor_value_info("init_state", TensorProto.FLOAT, + [2, 2]) + scan_ins = helper.make_tensor_value_info("scan_ins", TensorProto.FLOAT, + [3, 2, 2]) + + final_state = helper.make_tensor_value_info("final_state", + TensorProto.FLOAT, [2, 2]) + scan_outs = helper.make_tensor_value_info("scan_outs", TensorProto.FLOAT, + [3, 2, 2]) + node = helper.make_node( + "Scan", + inputs=["init_state", "scan_ins"], + outputs=["final_state", "scan_outs"], + num_scan_inputs=1, + scan_input_axes=scan_input_axes, + scan_input_directions=scan_input_directions, + scan_output_axes=scan_output_axes, + scan_output_directions=scan_output_directions, + body=scan_body, + ) + + return ([node], [init_state, scan_ins], [final_state, scan_outs]) + + +@onnx_test() +def scan_invalid_input_axes_len_test(): + return scan_negative_test(scan_input_axes=[0, 0]) + + +@onnx_test() +def scan_invalid_input_dirs_len_test(): + return scan_negative_test(scan_input_directions=[0, 0]) + + +@onnx_test() +def scan_invalid_output_axes_len_test(): + return scan_negative_test(scan_output_axes=[0, 0]) + + +@onnx_test() +def scan_invalid_output_dirs_len_test(): + return scan_negative_test(scan_output_directions=[0, 0]) + + +@onnx_test() +def scan_invalid_input_axes_vals_test(): + return scan_negative_test(scan_input_axes=[3]) + + +@onnx_test() +def scan_invalid_input_dirs_vals_test(): + return scan_negative_test(scan_input_directions=[2]) + + +@onnx_test() +def scan_invalid_output_axes_vals_test(): + return scan_negative_test(scan_output_axes=[-4]) + + +@onnx_test() +def scan_invalid_output_dirs_vals_test(): + return scan_negative_test(scan_output_directions=[-1]) + + +@onnx_test() +def scan_arg_count_mismatch_test(): + sum_in = helper.make_tensor_value_info("sum_in", TensorProto.FLOAT, [2, 2]) + scan_in1 = helper.make_tensor_value_info("scan_in1", TensorProto.FLOAT, + [2, 2]) + sum_out = helper.make_tensor_value_info("sum_out", TensorProto.FLOAT, + [2, 2]) + scan_out = helper.make_tensor_value_info("scan_out", TensorProto.FLOAT, + [2, 2]) + add = helper.make_node("Add", + inputs=["sum_in", "scan_in1"], + outputs=["sum_out"]) + id = helper.make_node("Identity", inputs=["sum_out"], outputs=["scan_out"]) + scan_body = helper.make_graph([add, id], "scan_body", [sum_in, scan_in1], + [sum_out, scan_out]) + + init_state = helper.make_tensor_value_info("init_state", TensorProto.FLOAT, + [2, 2]) + scan_ins1 = helper.make_tensor_value_info("scan_ins1", TensorProto.FLOAT, + [3, 2, 2]) + scan_ins2 = helper.make_tensor_value_info("scan_ins2", TensorProto.FLOAT, + [2, 3, 2]) + + final_state = helper.make_tensor_value_info("final_state", + TensorProto.FLOAT, [2, 2]) + scan_outs = helper.make_tensor_value_info("scan_outs", TensorProto.FLOAT, + [3, 2, 2]) + node = helper.make_node( + "Scan", + inputs=["init_state", "scan_ins1", "scan_ins2"], + outputs=["final_state", "scan_outs"], + num_scan_inputs=2, + scan_input_axes=[0, 0], + scan_input_directions=[0, 0], + scan_output_axes=[0], + scan_output_directions=[0], + body=scan_body, + ) + return ([node], [init_state, scan_ins1, + scan_ins2], [final_state, scan_outs]) + + +@onnx_test() +def scan_input_axes_lens_mismatch_test(): + sum_in = helper.make_tensor_value_info("sum_in", TensorProto.FLOAT, [2, 2]) + scan_in1 = helper.make_tensor_value_info("scan_in1", TensorProto.FLOAT, + [2, 2]) + scan_in2 = helper.make_tensor_value_info("scan_in2", TensorProto.FLOAT, + [2, 2]) + sum_out = helper.make_tensor_value_info("sum_out", TensorProto.FLOAT, + [2, 2]) + scan_out = helper.make_tensor_value_info("scan_out", TensorProto.FLOAT, + [2, 2]) + add = helper.make_node("Add", + inputs=["sum_in", "scan_in1"], + outputs=["sum_out"]) + id = helper.make_node("Identity", inputs=["sum_out"], outputs=["scan_out"]) + scan_body = helper.make_graph([add, id], "scan_body", + [sum_in, scan_in1, scan_in2], + [sum_out, scan_out]) + + init_state = helper.make_tensor_value_info("init_state", TensorProto.FLOAT, + [2, 2]) + scan_ins1 = helper.make_tensor_value_info("scan_ins1", TensorProto.FLOAT, + [3, 2, 2]) + scan_ins2 = helper.make_tensor_value_info("scan_ins2", TensorProto.FLOAT, + [2, 3, 2]) + + final_state = helper.make_tensor_value_info("final_state", + TensorProto.FLOAT, [2, 2]) + scan_outs = helper.make_tensor_value_info("scan_outs", TensorProto.FLOAT, + [3, 2, 2]) + node = helper.make_node( + "Scan", + inputs=["init_state", "scan_ins1", "scan_ins2"], + outputs=["final_state", "scan_outs"], + num_scan_inputs=2, + scan_input_axes=[0, 0], + scan_input_directions=[0, 0], + scan_output_axes=[0], + scan_output_directions=[0], + body=scan_body, + ) + return ([node], [init_state, scan_ins1, + scan_ins2], [final_state, scan_outs]) + + +@onnx_test() +def scan_arg_shapes_mismatch_test(): + sum_in = helper.make_tensor_value_info("sum_in", TensorProto.FLOAT, [2, 2]) + scan_in1 = helper.make_tensor_value_info("scan_in1", TensorProto.FLOAT, + [2, 2]) + scan_in2 = helper.make_tensor_value_info("scan_in2", TensorProto.FLOAT, + [2, 2]) + sum_out = helper.make_tensor_value_info("sum_out", TensorProto.FLOAT, + [2, 2]) + scan_out = helper.make_tensor_value_info("scan_out", TensorProto.FLOAT, + [2, 2]) + add = helper.make_node("Add", + inputs=["sum_in", "scan_in1"], + outputs=["sum_out"]) + id = helper.make_node("Identity", inputs=["sum_out"], outputs=["scan_out"]) + scan_body = helper.make_graph([add, id], "scan_body", + [sum_in, scan_in1, scan_in2], + [sum_out, scan_out]) + + init_state = helper.make_tensor_value_info("init_state", TensorProto.FLOAT, + [2, 2]) + scan_ins1 = helper.make_tensor_value_info("scan_ins1", TensorProto.FLOAT, + [3, 2, 2]) + scan_ins2 = helper.make_tensor_value_info("scan_ins2", TensorProto.FLOAT, + [3, 2]) + + final_state = helper.make_tensor_value_info("final_state", + TensorProto.FLOAT, [2, 2]) + scan_outs = helper.make_tensor_value_info("scan_outs", TensorProto.FLOAT, + [3, 2, 2]) + node = helper.make_node( + "Scan", + inputs=["init_state", "scan_ins1", "scan_ins2"], + outputs=["final_state", "scan_outs"], + num_scan_inputs=2, + scan_input_axes=[0, 0], + scan_input_directions=[0, 0], + scan_output_axes=[0], + scan_output_directions=[0], + body=scan_body, + ) + return ([node], [init_state, scan_ins1, + scan_ins2], [final_state, scan_outs]) diff --git a/test/onnx/parse/scan_test.cpp b/test/onnx/parse/scan_test.cpp index fbac2ad6d06..b819bc751a3 100644 --- a/test/onnx/parse/scan_test.cpp +++ b/test/onnx/parse/scan_test.cpp @@ -73,3 +73,72 @@ TEST_CASE(scan_test) auto prog_gold = migraphx::parse_onnx("scan_test6.onnx"); EXPECT(prog == prog_gold); } + +TEST_CASE(scan_invalid_input_axes_len_test) +{ + EXPECT(test::throws( + [] { migraphx::parse_onnx("scan_invalid_input_axes_len_test.onnx"); }, "scan_input_axes")); +} + +TEST_CASE(scan_invalid_input_dirs_len_test) +{ + EXPECT(test::throws( + [] { migraphx::parse_onnx("scan_invalid_input_dirs_len_test.onnx"); }, + "scan_input_directions")); +} + +TEST_CASE(scan_invalid_output_axes_len_test) +{ + EXPECT(test::throws( + [] { migraphx::parse_onnx("scan_invalid_output_axes_len_test.onnx"); }, + "scan_output_axes")); +} + +TEST_CASE(scan_invalid_output_dirs_len_test) +{ + EXPECT(test::throws( + [] { migraphx::parse_onnx("scan_invalid_output_dirs_len_test.onnx"); }, + "scan_output_directions")); +} + +TEST_CASE(scan_invalid_input_axes_vals_test) +{ + EXPECT(test::throws( + [] { migraphx::parse_onnx("scan_invalid_input_axes_vals_test.onnx"); }, "scan_input_axes")); +} + +TEST_CASE(scan_invalid_input_dirs_vals_test) +{ + EXPECT(test::throws( + [] { migraphx::parse_onnx("scan_invalid_input_dirs_vals_test.onnx"); }, + "scan_input_directions")); +} + +TEST_CASE(scan_invalid_output_axes_vals_test) +{ + EXPECT(test::throws( + [] { migraphx::parse_onnx("scan_invalid_output_axes_vals_test.onnx"); }, + "scan_output_axes")); +} + +TEST_CASE(scan_invalid_output_dirs_vals_test) +{ + EXPECT(test::throws( + [] { migraphx::parse_onnx("scan_invalid_output_dirs_vals_test.onnx"); }, + "scan_output_directions")); +} + +TEST_CASE(scan_arg_count_mismatch_test) +{ + EXPECT(test::throws([] { migraphx::parse_onnx("scan_arg_count_mismatch_test.onnx"); })); +} + +TEST_CASE(scan_arg_shapes_mismatch_test) +{ + EXPECT(test::throws([] { migraphx::parse_onnx("scan_arg_shapes_mismatch_test.onnx"); })); +} + +TEST_CASE(scan_input_axes_lens_mismatch_test) +{ + EXPECT(test::throws([] { migraphx::parse_onnx("scan_input_axes_lens_mismatch_test.onnx"); })); +} diff --git a/test/onnx/scan_arg_count_mismatch_test.onnx b/test/onnx/scan_arg_count_mismatch_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..c14171c22688671bf490902a4590826617a176c8 GIT binary patch literal 634 zcmaJ;OHRWu5N$qCcc^G&LHol1q{^kyDwx5}65k_hZAS1;99yL`W(#STuabMKE4N zeDcgVeQPPWr4*$UTl-15``-rGagiRvF9UAjSd=+o98l*J;CWq9{H=ALrr<&U#x12r z&|{i$$>gDPuk9kY_rj;JL*#03iUG8acWr^NB$aCT>Z5B_q!U< zH7mpPU097ZLO(LGo9$+mG54?GTtiZ}&l1TBKHwMNG`MQ0>^f;J^kA(6x>CXaZ!OyB w8=DRi-HeM1Fa}2jw^X57LwFh?)ChI54Jr2E&%UQA_Ifpx(DXxt2h%zG1xV$gh5!Hn literal 0 HcmV?d00001 diff --git a/test/onnx/scan_arg_shapes_mismatch_test.onnx b/test/onnx/scan_arg_shapes_mismatch_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..1226c00bb68a4642bcf4141a7b0ed6b3cc93176c GIT binary patch literal 660 zcmah{yH3L}6s?m$-AhFyBT9G#P{(Y^$W+*v*|J#918RvlRqaajZ}J8F4IjYHN8r4g z5P=vXIX(~joNGKGu8f(WZ1Y0pimg*gGcGfx;uTe?RCiw%yqf@Aa7CqJDuqcU;6nP{ zG9X9KT(H-gLTXCBRdQiHX?yPzffeP+F8Dxj3x{%>Q7!;CPCglzJDR^R?~?>vX#Yt~ zDG~gbq(X7EYurn_&}qHMF-$Qc7z}KbUwF}T{~HU2y6C(+dmxKiOMRQ6tUw3e9A-kN zXfmyJc*iJ^;W;qt_9%Nx<#Mzf6=>{T!)Xsm_$G}N&&3`;1G~cYf!eN-mO>ZieW0(@ z@INCBulnO0CJAr4srrO~ty`VyM$LV*mmQ&BUqy&q`0)^i6}IL&k^Ig4!LU!h0m;Xv Aod5s; literal 0 HcmV?d00001 diff --git a/test/onnx/scan_input_axes_lens_mismatch_test.onnx b/test/onnx/scan_input_axes_lens_mismatch_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..1ddb7b6cf5511f1a639ec0bb0717273ef0ebffcc GIT binary patch literal 674 zcma)4%T9za6rBP(%yr_B8b>suh9x^#x*FZNcE^VBsIkdVj4g@tH++GA;|I(~xVJny zR1y;xB+&EfJr{d{eV(T*q3o>|N%}!}vL%csCFNx*^4CNVF7CfA_%s11qaxuV6$BQ8 ziihKCgil^5OSe4*_mrv@+_#=(MzBX(I&4U{_*N}p~lU&e>afM%i)8R5ut*94zu+{-xDSiL5 z*Z5lB>2#dGdpth=xOe@vo|gT)|xcIVr(()>M!BqZ}11~k6EAu zY)m{5!uRmLU;DmzO77%76Db$(S;47(zg8;EKA21kCQ`+udie6-(*@vyt5m8?F~CM6 z$=Nd(S<#Kjv>lg~l56ityT(TZUS86}{WpRYoXYA&F96@DiQf$L?`!vwQiwF@zDeib z5WJ&ID6S61ym3py-in;V3?qWU!A|oe*Mqr*@&*}u8{L=->4mmqjZWeP%!O8(Q|o+W zh-&1-A-Tb3xPgJh{V2s?;wd`brXfe+xvcupK8 zSPTr7tHr%;i}eZ6+EJC@=YKpD!yNT5m>^Js-JNJgjUT$1K{;aSH$qEOgE$@4E3@<^!HwL_5)SboBsd+ literal 0 HcmV?d00001 diff --git a/test/onnx/scan_invalid_input_dirs_len_test.onnx b/test/onnx/scan_invalid_input_dirs_len_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..05ea46095969cf7e489968f8cdb51eb7d76221b3 GIT binary patch literal 592 zcmah`u};G<6vRm&?xTXph*A(~QOAsoOu)v>mc>eHBUuU#!VaP~e*-I;a%j-d}i zYOCzVn>)9BzO!OSFr|c12x#C|e%_mpsM1m1eoS5X58=SEV@vk4860a=`d4EiX zIK0Pr@h;8nUQ6O^0jC2Tvy^MAYK?5ipFy(aahwag7p5>95-TNW#+gJh{V2s?;wd`brXfe+xvcy3%J zL<|g;mcmd4(GhVT&pD6On8Hn$}(FEaf2 zN$Gsk!hE!YtJ<1H_=vXYH-WIIqJ z>hKN|`Fm+r2MvkR8JzcU46|HYRcW)Lj=g|*$Ko)Tbtg+;(kFU4=lYMLvRS{R;ZL&b gUTqv=%x#RB1^euZ1b&q{3~3Hsz5EaTJ&n(P09Q?%00000 literal 0 HcmV?d00001 diff --git a/test/onnx/scan_invalid_output_axes_len_test.onnx b/test/onnx/scan_invalid_output_axes_len_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..0b649d7cb6563103dc7a9de659ba5f7ef8f01ac9 GIT binary patch literal 594 zcma)3u};G<6s(g#+(!kG5mk#CDmr##WGZaTY+0Oohw>$Ni|pc_A`Qy5I{(wsnWUTfSpK@ zlP4kbvV+OAg^h>WM32UH`;8zfDt5T}La>HoRX>|a5IZ%AEnuE+!aK$wHAHtOoxf%9 z1LIN)eQ@U1EnDHY5exwi?8%mcxknY2a`rxT=@v2rea8kJZRc<%jne(vHBLq|IEE!lm##*4uHCUAFvXHtJfja<8`^4bN5XUgrz4$<=W4A*Emh}ULe#prtj~J&^=%*IhG`75@d*d;pUSeu lxGwij*~O?p1`#GW#q|6>#WIFpwYqUFVJMpYF@NhtCqI&irV#)D literal 0 HcmV?d00001 diff --git a/test/onnx/scan_invalid_output_dirs_len_test.onnx b/test/onnx/scan_invalid_output_dirs_len_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..0739fab0537f8bbb44d2e6fb0b0d706e69e80cd1 GIT binary patch literal 594 zcma)(u};G<6h(0oi2JA@GNKgJ649|EBNMPOvt_Z8+DMjygRq0>%9mu~7x(~vjOWBv zLd3w}!T08#7vD!G?8u&;4QHkyK7l9NGO#O_LAogh%I{`mmg?B>05{>T8I=*x0 zgOJ)PyY+BK%jP>2JAx@Cj6y&Ice3*yKC%jDdHmSA@DIX)V`4-0*E2ZQnDl>*{g^pv zNZ|Bd`VQqHnSz5agxUF24vY;7C``I0{0iW}yY}dhFi|0dpjpGi&@U zLmxz;j5J$kZr##pYsF4rObMglQP1x5vNIo0xuu-_kh*do!hu7_nj9t*I8~^0uf~A% zl+?ttc!%-qz0h|DH3^d$ocC}HvsfEhD#ug5vEhFzOK1JM khJQ-0dj&FxFu^gV=eH^5bNE&3Fs=o3MYBKjD=#|x0pJp*761SM literal 0 HcmV?d00001 diff --git a/test/py/onnx_backend_test.py b/test/py/onnx_backend_test.py index f7a20492958..c8638424838 100644 --- a/test/py/onnx_backend_test.py +++ b/test/py/onnx_backend_test.py @@ -161,7 +161,6 @@ def disabled_tests_onnx_1_7_0(backend_test): r'test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric_cpu') backend_test.exclude(r'test_reversesequence_batch_cpu') backend_test.exclude(r'test_reversesequence_time_cpu') - backend_test.exclude(r'test_scan9_sum_cpu') backend_test.exclude(r'test_scan_sum_cpu') backend_test.exclude(r'test_slice_cpu') backend_test.exclude(r'test_slice_default_axes_cpu') @@ -814,9 +813,7 @@ def create_backend_test(testname=None, target_device=None): c2.set_device(target_device) backend_test = MIGraphXBackendTest(c2, __name__) - if True: - backend_test.include(r'test_scan9_sum_cpu') - elif testname: + if testname: backend_test.include(testname + '.*') else: # Onnx Operator tests From 495965f7028e39d01c92ed32ae3e323d2049505c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dino=20Musi=C4=87?= Date: Mon, 1 Apr 2024 14:47:11 +0000 Subject: [PATCH 12/18] Update onnx_operators.rst --- docs/dev/onnx_operators.rst | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/docs/dev/onnx_operators.rst b/docs/dev/onnx_operators.rst index 1ac83dabd3b..d11fc00bf0e 100644 --- a/docs/dev/onnx_operators.rst +++ b/docs/dev/onnx_operators.rst @@ -695,7 +695,13 @@ Operator Support Matrix +--------------------------+-----------+-----------------+------------------------------+ | STFT | ❌ | | | +--------------------------+-----------+-----------------+------------------------------+ -| Scan | 👷 | 👷 | | +| Scan | ✅ | UINT8, UINT16, | ``identity``, | +| | | UINT32, UINT64, | ``sequence`` | +| | | INT8, INT16, | datatypes are | +| | | INT32, INT64, | not supported, | +| | | FP8, FP16, | Number of iterations has | +| | | FP32, FP64 | upper-bound | +| | | | Version 8 not supported | +--------------------------+-----------+-----------------+------------------------------+ | Scatter (deprecated) | ✅ | BOOL, UINT8, | | | | | UINT16, UINT32, | | From 5a106dd19bb330a8687dace358c5e2dcc2116e0d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dino=20Musi=C4=87?= Date: Tue, 2 Apr 2024 06:40:54 +0000 Subject: [PATCH 13/18] Fix format issues and test failures --- src/include/migraphx/run_loop.hpp | 6 +++--- src/onnx/parse_celu.cpp | 1 - src/onnx/parse_scan.cpp | 10 +++++----- test/onnx/parse/scan_test.cpp | 2 +- test/verify/test_scan_slice.cpp | 2 +- 5 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/include/migraphx/run_loop.hpp b/src/include/migraphx/run_loop.hpp index ee80f41f4cb..5e744bb5a8d 100644 --- a/src/include/migraphx/run_loop.hpp +++ b/src/include/migraphx/run_loop.hpp @@ -103,9 +103,9 @@ argument run_loop(const LoopModel& model, auto output_index = out_param_indices[name]; if(output_index > dep_num) { - int64_t dir = scan_output_directions.empty() - ? 0 - : scan_output_directions[output_index - dep_num - 1]; + int64_t dir = scan_output_directions.empty() + ? 0 + : scan_output_directions[output_index - dep_num - 1]; auto idx = (1 - dir) * iter + dir * (iter_num - 1 - iter); const auto& arg = out_args.at(output_index); assert((idx + 1) * ps.bytes() <= arg.get_shape().bytes()); diff --git a/src/onnx/parse_celu.cpp b/src/onnx/parse_celu.cpp index 05c434d9bc2..3bd8fd62e38 100644 --- a/src/onnx/parse_celu.cpp +++ b/src/onnx/parse_celu.cpp @@ -35,7 +35,6 @@ struct parse_celu : op_parser { std::vector operators() const { return {{"Celu"}}; } - instruction_ref parse(const op_desc&, const onnx_parser&, const onnx_parser::node_info& info, diff --git a/src/onnx/parse_scan.cpp b/src/onnx/parse_scan.cpp index a846dc654ec..4097d8e3992 100644 --- a/src/onnx/parse_scan.cpp +++ b/src/onnx/parse_scan.cpp @@ -249,10 +249,9 @@ struct parse_scan : op_parser { std::vector params; params.reserve(n + m); - auto param_names = mod->get_parameter_names(); - transform(param_names, std::back_inserter(params), [&](const std::string& name) { - return mod->get_parameter(name); - }); + transform(mod->get_parameter_names(), + std::back_inserter(params), + [&](const std::string& name) { return mod->get_parameter(name); }); // iteration_num, condition, and duplicate loop_state_variables are appended to parameters. // References to the original loop_state_variables in other instructions are then replaced @@ -266,7 +265,8 @@ struct parse_scan : op_parser std::vector new_params; new_params.reserve(n); for(auto i = 0; i < n; ++i) - new_params.push_back(mod->add_parameter(param_names[i], params[i]->get_shape())); + new_params.push_back( + mod->add_parameter("state_var" + std::to_string(i), params[i]->get_shape())); for(auto i = 0; i < params.size(); ++i) { diff --git a/test/onnx/parse/scan_test.cpp b/test/onnx/parse/scan_test.cpp index b819bc751a3..86eff89739e 100644 --- a/test/onnx/parse/scan_test.cpp +++ b/test/onnx/parse/scan_test.cpp @@ -38,7 +38,7 @@ TEST_CASE(scan_test) auto* body = prog.create_module("Scan_3_scan"); auto iter = body->add_parameter("iter", mgx::shape{mgx::shape::int64_type}); auto cond = body->add_parameter("cond", mgx::shape{mgx::shape::bool_type}); - auto sum_in = body->add_parameter("sum_in", mgx::shape{mgx::shape::float_type, {2, 2}}); + auto sum_in = body->add_parameter("state_var0", mgx::shape{mgx::shape::float_type, {2, 2}}); auto scan_in2 = body->add_instruction( mgx::make_op("scan_slice", {{"axis", 0}, {"direction", 1}}), scan_ins2, iter); diff --git a/test/verify/test_scan_slice.cpp b/test/verify/test_scan_slice.cpp index 3326ebd8173..08f99938a72 100644 --- a/test/verify/test_scan_slice.cpp +++ b/test/verify/test_scan_slice.cpp @@ -57,6 +57,6 @@ struct test_scan_slice2 : test_scan_slice_base { }; -struct test_scan_slice3: test_scan_slice_base +struct test_scan_slice3 : test_scan_slice_base { }; From 51bf14f67a84cd61eb9de9fe99ea473b5b8c9785 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dino=20Musi=C4=87?= Date: Tue, 2 Apr 2024 07:11:58 +0000 Subject: [PATCH 14/18] Add deduction guides to vectors --- test/ref/scan_slice.cpp | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/test/ref/scan_slice.cpp b/test/ref/scan_slice.cpp index 6dfb53d8a60..f5cf617b949 100644 --- a/test/ref/scan_slice.cpp +++ b/test/ref/scan_slice.cpp @@ -76,12 +76,12 @@ TEST_CASE(scan_slice_test_1) pm["idx"] = migraphx::argument{idx_sh, &idx}; auto result = p.eval(pm).back(); EXPECT(result.get_shape() == make_shape({1, 2, 2})); - EXPECT(arg_to_vec(result) == std::vector{0, 1, 2, 3}); + EXPECT(arg_to_vec(result) == std::vector{0, 1, 2, 3}); idx = 1; result = p.eval(pm).back(); EXPECT(result.get_shape() == make_shape({1, 2, 2})); - EXPECT(arg_to_vec(result) == std::vector{4, 5, 6, 7}); + EXPECT(arg_to_vec(result) == std::vector{4, 5, 6, 7}); } TEST_CASE(scan_slice_test_2) @@ -94,12 +94,12 @@ TEST_CASE(scan_slice_test_2) pm["idx"] = migraphx::argument{idx_sh, &idx}; auto result = p.eval(pm).back(); EXPECT(result.get_shape() == make_shape({2, 1, 2})); - EXPECT(arg_to_vec(result) == std::vector{0, 1, 4, 5}); + EXPECT(arg_to_vec(result) == std::vector{0, 1, 4, 5}); idx = 1; result = p.eval(pm).back(); EXPECT(result.get_shape() == make_shape({2, 1, 2})); - EXPECT(arg_to_vec(result) == std::vector{2, 3, 6, 7}); + EXPECT(arg_to_vec(result) == std::vector{2, 3, 6, 7}); } TEST_CASE(scan_slice_test_3) @@ -112,12 +112,12 @@ TEST_CASE(scan_slice_test_3) pm["idx"] = migraphx::argument{idx_sh, &idx}; auto result = p.eval(pm).back(); EXPECT(result.get_shape() == make_shape({2, 2, 1})); - EXPECT(arg_to_vec(result) == std::vector{0, 2, 4, 6}); + EXPECT(arg_to_vec(result) == std::vector{0, 2, 4, 6}); idx = 1; result = p.eval(pm).back(); EXPECT(result.get_shape() == make_shape({2, 2, 1})); - EXPECT(arg_to_vec(result) == std::vector{1, 3, 5, 7}); + EXPECT(arg_to_vec(result) == std::vector{1, 3, 5, 7}); } TEST_CASE(scan_slice_test_4) @@ -130,12 +130,12 @@ TEST_CASE(scan_slice_test_4) pm["idx"] = migraphx::argument{idx_sh, &idx}; auto result = p.eval(pm).back(); EXPECT(result.get_shape() == make_shape({1, 2, 2})); - EXPECT(arg_to_vec(result) == std::vector{0, 1, 2, 3}); + EXPECT(arg_to_vec(result) == std::vector{0, 1, 2, 3}); idx = 1; result = p.eval(pm).back(); EXPECT(result.get_shape() == make_shape({1, 2, 2})); - EXPECT(arg_to_vec(result) == std::vector{4, 5, 6, 7}); + EXPECT(arg_to_vec(result) == std::vector{4, 5, 6, 7}); } TEST_CASE(scan_slice_test_5) @@ -148,12 +148,12 @@ TEST_CASE(scan_slice_test_5) pm["idx"] = migraphx::argument{idx_sh, &idx}; auto result = p.eval(pm).back(); EXPECT(result.get_shape() == make_shape({1, 2, 2})); - EXPECT(arg_to_vec(result) == std::vector{4, 5, 6, 7}); + EXPECT(arg_to_vec(result) == std::vector{4, 5, 6, 7}); idx = 1; result = p.eval(pm).back(); EXPECT(result.get_shape() == make_shape({1, 2, 2})); - EXPECT(arg_to_vec(result) == std::vector{0, 1, 2, 3}); + EXPECT(arg_to_vec(result) == std::vector{0, 1, 2, 3}); } TEST_CASE(scan_slice_test_6) @@ -166,12 +166,12 @@ TEST_CASE(scan_slice_test_6) pm["idx"] = migraphx::argument{idx_sh, &idx}; auto result = p.eval(pm).back(); EXPECT(result.get_shape() == make_shape({2, 1, 2})); - EXPECT(arg_to_vec(result) == std::vector{2, 3, 6, 7}); + EXPECT(arg_to_vec(result) == std::vector{2, 3, 6, 7}); idx = 1; result = p.eval(pm).back(); EXPECT(result.get_shape() == make_shape({2, 1, 2})); - EXPECT(arg_to_vec(result) == std::vector{0, 1, 4, 5}); + EXPECT(arg_to_vec(result) == std::vector{0, 1, 4, 5}); } TEST_CASE(scan_slice_test_7) From b82f193972d837fd32ed6812429b9652df567ba9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dino=20Musi=C4=87?= Date: Tue, 25 Jun 2024 22:31:21 +0000 Subject: [PATCH 15/18] Additional verify test case, code refactoring --- src/include/migraphx/run_loop.hpp | 2 +- src/onnx/parse_scan.cpp | 53 ++++++++++++++---------------- test/onnx/gen_onnx.py | 48 +++++++++++++++++++++++++++ test/onnx/parse/scan_test.cpp | 34 ++++++++----------- test/onnx/scan_test7.onnx | Bin 0 -> 582 bytes test/onnx/verify/scan_test.cpp | 41 +++++++++++++++++++---- test/ref/scan_slice.cpp | 35 ++++++++------------ 7 files changed, 136 insertions(+), 77 deletions(-) create mode 100644 test/onnx/scan_test7.onnx diff --git a/src/include/migraphx/run_loop.hpp b/src/include/migraphx/run_loop.hpp index 5e744bb5a8d..e12b82fd2ff 100644 --- a/src/include/migraphx/run_loop.hpp +++ b/src/include/migraphx/run_loop.hpp @@ -24,7 +24,7 @@ #ifndef MIGRAPHX_GUARD_RTGLIB_RUN_LOOP_HPP #define MIGRAPHX_GUARD_RTGLIB_RUN_LOOP_HPP -#include "stringutils.hpp" +#include #include #include #include diff --git a/src/onnx/parse_scan.cpp b/src/onnx/parse_scan.cpp index 4097d8e3992..b820ef9be0c 100644 --- a/src/onnx/parse_scan.cpp +++ b/src/onnx/parse_scan.cpp @@ -21,12 +21,12 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ -#include "migraphx/argument.hpp" -#include "migraphx/errors.hpp" -#include "migraphx/instruction_ref.hpp" -#include "migraphx/iterator_for.hpp" -#include "migraphx/module_ref.hpp" -#include "migraphx/onnx/onnx_parser.hpp" +#include +#include +#include +#include +#include +#include #include #include #include @@ -74,8 +74,8 @@ struct parse_scan : op_parser std::back_inserter(body_params), [&](const auto& name) { return body->get_parameter(name); }); - if(auto num_body_ins = body_params.size(); num_body_ins != n + m) - MIGRAPHX_THROW("Scan: Number of inputs to body {" + std::to_string(num_body_ins) + + if(auto num_body_params = body_params.size(); num_body_params != n + m) + MIGRAPHX_THROW("Scan: Number of inputs to body {" + std::to_string(num_body_params) + "} does not match number of inputs to Scan {" + std::to_string(n + m) + "}"); @@ -184,11 +184,11 @@ struct parse_scan : op_parser } std::vector - parse_dirs(onnx_parser::node_info& info, const std::string& name, long expected_size) const + parse_dirs(onnx_parser::node_info& info, const std::string& name, size_t expected_size) const { auto dirs = parse_vector_attribute(info, name, expected_size); if(dirs.empty()) - return {expected_size, 0}; + return std::vector(expected_size, 0); if(any_of(dirs, [](auto i) { return i != 0 and i != 1; })) MIGRAPHX_THROW("Scan: " + name + @@ -270,25 +270,22 @@ struct parse_scan : op_parser for(auto i = 0; i < params.size(); ++i) { - for(auto ins : iterator_for(*mod)) + if(i < n) + { + mod->replace_instruction(params[i], new_params[i]); + } + else { - if(not contains(ins->inputs(), params[i])) - continue; - - auto new_ins = i < n ? new_params[i] : args[i]; - if(i >= n) - { - auto scan_axis = scan_input_axes[i - n]; - auto scan_dir = scan_input_directions[i - n]; - new_ins = mod->insert_instruction( - params[i], - make_op("scan_slice", {{"axis", scan_axis}, {"direction", scan_dir}}), - new_ins, - iter_param); - new_ins = mod->insert_instruction( - params[i], make_op("squeeze", {{"axes", {scan_axis}}}), new_ins); - } - ins->replace_argument(ins, params[i], new_ins); + auto scan_axis = scan_input_axes[i - n]; + auto scan_dir = scan_input_directions[i - n]; + auto new_ins = mod->insert_instruction( + params[i], + make_op("scan_slice", {{"axis", scan_axis}, {"direction", scan_dir}}), + args[i], + iter_param); + new_ins = mod->insert_instruction( + params[i], make_op("squeeze", {{"axes", {scan_axis}}}), new_ins); + mod->replace_instruction(params[i], new_ins); } mod->remove_instruction(params[i]); } diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index e11c95ef71a..91812d37766 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -12446,6 +12446,54 @@ def scan_test6(): scan_output_directions=[1, 1], scan_output_axes=[2, 1]) +@onnx_test() +def scan_test7(): + sum_in = helper.make_tensor_value_info("sum_in", TensorProto.FLOAT, [2, 2]) + scan_in = helper.make_tensor_value_info("scan_in", TensorProto.FLOAT, + [2, 2]) + sum_out = helper.make_tensor_value_info("sum_out", TensorProto.FLOAT, + [2, 2]) + scan_out = helper.make_tensor_value_info("scan_out", TensorProto.FLOAT, + [2, 2]) + add1 = helper.make_node("Add", + inputs=["sum_in", "scan_in"], + outputs=["add1_out"]) + add2 = helper.make_node("Add", + inputs=["add1_out", "scan_in"], + outputs=["sum_out"]) + id = helper.make_node("Identity", + inputs=["scan_in"], + outputs=["scan_out"]) + + scan_body = helper.make_graph([add1, add2, id], "scan_body", + [sum_in, scan_in], + [sum_out, scan_out]) + + init_state = helper.make_tensor_value_info("init_state", TensorProto.FLOAT, + [2, 2]) + scan_ins_sh = [3, 2, 2] + scan_ins = helper.make_tensor_value_info("scan_ins", TensorProto.FLOAT, + scan_ins_sh) + + final_state = helper.make_tensor_value_info("final_state", + TensorProto.FLOAT, [2, 2]) + scan_outs_sh = [3, 2, 2] + scan_outs = helper.make_tensor_value_info("scan_outs", TensorProto.FLOAT, + scan_outs_sh) + node = helper.make_node( + "Scan", + inputs=["init_state", "scan_ins"], + outputs=["final_state", "scan_outs"], + num_scan_inputs=1, + scan_input_axes=[0], + scan_input_directions=[0], + scan_output_axes=[0], + scan_output_directions=[0], + body=scan_body, + ) + + return ([node], [init_state, scan_ins], [final_state, scan_outs]) + def scan_negative_test(scan_input_axes=[0], scan_input_directions=[0], diff --git a/test/onnx/parse/scan_test.cpp b/test/onnx/parse/scan_test.cpp index 86eff89739e..5ed5e980b71 100644 --- a/test/onnx/parse/scan_test.cpp +++ b/test/onnx/parse/scan_test.cpp @@ -22,8 +22,8 @@ * THE SOFTWARE. */ -#include "migraphx/common.hpp" -#include "migraphx/literal.hpp" +#include +#include #include TEST_CASE(scan_test) @@ -70,75 +70,69 @@ TEST_CASE(scan_test) mm->add_instruction(mgx::make_op("transpose", {{"permutation", {1, 0}}}), scan_outs2); mm->add_return({final_state, scan_outs1, scan_outs2}); - auto prog_gold = migraphx::parse_onnx("scan_test6.onnx"); + auto prog_gold = read_onnx("scan_test6.onnx"); EXPECT(prog == prog_gold); } TEST_CASE(scan_invalid_input_axes_len_test) { EXPECT(test::throws( - [] { migraphx::parse_onnx("scan_invalid_input_axes_len_test.onnx"); }, "scan_input_axes")); + [] { read_onnx("scan_invalid_input_axes_len_test.onnx"); }, "scan_input_axes")); } TEST_CASE(scan_invalid_input_dirs_len_test) { EXPECT(test::throws( - [] { migraphx::parse_onnx("scan_invalid_input_dirs_len_test.onnx"); }, - "scan_input_directions")); + [] { read_onnx("scan_invalid_input_dirs_len_test.onnx"); }, "scan_input_directions")); } TEST_CASE(scan_invalid_output_axes_len_test) { EXPECT(test::throws( - [] { migraphx::parse_onnx("scan_invalid_output_axes_len_test.onnx"); }, - "scan_output_axes")); + [] { read_onnx("scan_invalid_output_axes_len_test.onnx"); }, "scan_output_axes")); } TEST_CASE(scan_invalid_output_dirs_len_test) { EXPECT(test::throws( - [] { migraphx::parse_onnx("scan_invalid_output_dirs_len_test.onnx"); }, - "scan_output_directions")); + [] { read_onnx("scan_invalid_output_dirs_len_test.onnx"); }, "scan_output_directions")); } TEST_CASE(scan_invalid_input_axes_vals_test) { EXPECT(test::throws( - [] { migraphx::parse_onnx("scan_invalid_input_axes_vals_test.onnx"); }, "scan_input_axes")); + [] { read_onnx("scan_invalid_input_axes_vals_test.onnx"); }, "scan_input_axes")); } TEST_CASE(scan_invalid_input_dirs_vals_test) { EXPECT(test::throws( - [] { migraphx::parse_onnx("scan_invalid_input_dirs_vals_test.onnx"); }, - "scan_input_directions")); + [] { read_onnx("scan_invalid_input_dirs_vals_test.onnx"); }, "scan_input_directions")); } TEST_CASE(scan_invalid_output_axes_vals_test) { EXPECT(test::throws( - [] { migraphx::parse_onnx("scan_invalid_output_axes_vals_test.onnx"); }, - "scan_output_axes")); + [] { read_onnx("scan_invalid_output_axes_vals_test.onnx"); }, "scan_output_axes")); } TEST_CASE(scan_invalid_output_dirs_vals_test) { EXPECT(test::throws( - [] { migraphx::parse_onnx("scan_invalid_output_dirs_vals_test.onnx"); }, - "scan_output_directions")); + [] { read_onnx("scan_invalid_output_dirs_vals_test.onnx"); }, "scan_output_directions")); } TEST_CASE(scan_arg_count_mismatch_test) { - EXPECT(test::throws([] { migraphx::parse_onnx("scan_arg_count_mismatch_test.onnx"); })); + EXPECT(test::throws([] { read_onnx("scan_arg_count_mismatch_test.onnx"); })); } TEST_CASE(scan_arg_shapes_mismatch_test) { - EXPECT(test::throws([] { migraphx::parse_onnx("scan_arg_shapes_mismatch_test.onnx"); })); + EXPECT(test::throws([] { read_onnx("scan_arg_shapes_mismatch_test.onnx"); })); } TEST_CASE(scan_input_axes_lens_mismatch_test) { - EXPECT(test::throws([] { migraphx::parse_onnx("scan_input_axes_lens_mismatch_test.onnx"); })); + EXPECT(test::throws([] { read_onnx("scan_input_axes_lens_mismatch_test.onnx"); })); } diff --git a/test/onnx/scan_test7.onnx b/test/onnx/scan_test7.onnx new file mode 100644 index 0000000000000000000000000000000000000000..ccda6f8c7a096e3ff270804df8e32b966850c5d4 GIT binary patch literal 582 zcmZuuy-ve07{o~+?necY5v8DppN`pJL1F?nX0|L*QX9zB1om%HzL-*<;c1mvO?m10ue-hC(V6$5a=RV7ue7~r1eLX!O#F6!3-CZis1mrBl) zXRR-O5F|}YS2t?}m#`<7Zxt7iS~bBnrPr07pT`fB!WC=*HG@+A!ogv|CKzP-vyS>|IANMIDL^ z?y#!sf84zw%@`Onf5H`K^^Rzp-%T^4#l?K&As4 Z>u};tNKFoZnhk@x0M|A7NB +#include +#include +#include +#include +#include #include #include #include @@ -48,7 +48,7 @@ auto scan_test(const std::string& test_file, migraphx::shape scan_ins1_sh, migraphx::shape scan_ins2_sh) { - auto prog = migraphx::parse_onnx(test_file); + auto prog = read_onnx(test_file); prog.compile(migraphx::make_target("ref")); migraphx::parameter_map pm; @@ -177,3 +177,30 @@ TEST_CASE(scan_test6) std::vector scan_out2_gold{42, 26, 12, 48, 30, 14}; EXPECT(arg_to_vec(scan_out2) == scan_out2_gold); } + +TEST_CASE(scan_test7) +{ + auto prog = read_onnx("scan_test7.onnx"); + prog.compile(migraphx::make_target("ref")); + + migraphx::parameter_map pm; + + migraphx::shape init_state_sh{migraphx::shape::float_type, {2, 2}}; + std::vector init_state(init_state_sh.elements(), 0); + pm["init_state"] = migraphx::argument(init_state_sh, init_state.data()); + + migraphx::shape scan_ins_sh{migraphx::shape::float_type, {3, 2, 2}}; + std::vector scan_ins(scan_ins_sh.elements()); + std::iota(scan_ins.begin(), scan_ins.end(), 1); + pm["scan_ins"] = migraphx::argument(scan_ins_sh, scan_ins.data()); + + auto result = prog.eval(pm); + EXPECT(result.size() == 2); + + EXPECT(result[0].get_shape() == make_shape({2, 2})); + std::vector final_state_gold{30, 36, 42, 48}; + EXPECT(arg_to_vec(result[0]) == final_state_gold); + + EXPECT(result[1].get_shape() == make_shape({3, 2, 2})); + EXPECT(arg_to_vec(result[1]) == scan_ins); +} diff --git a/test/ref/scan_slice.cpp b/test/ref/scan_slice.cpp index f5cf617b949..f65385bf310 100644 --- a/test/ref/scan_slice.cpp +++ b/test/ref/scan_slice.cpp @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ -#include "migraphx/compile_options.hpp" -#include "migraphx/module.hpp" +#include +#include #include #include #include @@ -37,13 +37,6 @@ static migraphx::shape make_shape(const std::vector& lens) return migraphx::shape{migraphx::shape::int32_type, lens}; } -static std::vector arg_to_vec(const migraphx::argument& arg) -{ - std::vector ret; - arg.visit([&](auto output) { ret.assign(output.begin(), output.end()); }); - return ret; -} - migraphx::program make_scan_slice_program(int64_t axis, int64_t direction) { migraphx::program p; @@ -76,12 +69,12 @@ TEST_CASE(scan_slice_test_1) pm["idx"] = migraphx::argument{idx_sh, &idx}; auto result = p.eval(pm).back(); EXPECT(result.get_shape() == make_shape({1, 2, 2})); - EXPECT(arg_to_vec(result) == std::vector{0, 1, 2, 3}); + EXPECT(result.to_vector() == std::vector{0, 1, 2, 3}); idx = 1; result = p.eval(pm).back(); EXPECT(result.get_shape() == make_shape({1, 2, 2})); - EXPECT(arg_to_vec(result) == std::vector{4, 5, 6, 7}); + EXPECT(result.to_vector() == std::vector{4, 5, 6, 7}); } TEST_CASE(scan_slice_test_2) @@ -94,12 +87,12 @@ TEST_CASE(scan_slice_test_2) pm["idx"] = migraphx::argument{idx_sh, &idx}; auto result = p.eval(pm).back(); EXPECT(result.get_shape() == make_shape({2, 1, 2})); - EXPECT(arg_to_vec(result) == std::vector{0, 1, 4, 5}); + EXPECT(result.to_vector() == std::vector{0, 1, 4, 5}); idx = 1; result = p.eval(pm).back(); EXPECT(result.get_shape() == make_shape({2, 1, 2})); - EXPECT(arg_to_vec(result) == std::vector{2, 3, 6, 7}); + EXPECT(result.to_vector() == std::vector{2, 3, 6, 7}); } TEST_CASE(scan_slice_test_3) @@ -112,12 +105,12 @@ TEST_CASE(scan_slice_test_3) pm["idx"] = migraphx::argument{idx_sh, &idx}; auto result = p.eval(pm).back(); EXPECT(result.get_shape() == make_shape({2, 2, 1})); - EXPECT(arg_to_vec(result) == std::vector{0, 2, 4, 6}); + EXPECT(result.to_vector() == std::vector{0, 2, 4, 6}); idx = 1; result = p.eval(pm).back(); EXPECT(result.get_shape() == make_shape({2, 2, 1})); - EXPECT(arg_to_vec(result) == std::vector{1, 3, 5, 7}); + EXPECT(result.to_vector() == std::vector{1, 3, 5, 7}); } TEST_CASE(scan_slice_test_4) @@ -130,12 +123,12 @@ TEST_CASE(scan_slice_test_4) pm["idx"] = migraphx::argument{idx_sh, &idx}; auto result = p.eval(pm).back(); EXPECT(result.get_shape() == make_shape({1, 2, 2})); - EXPECT(arg_to_vec(result) == std::vector{0, 1, 2, 3}); + EXPECT(result.to_vector() == std::vector{0, 1, 2, 3}); idx = 1; result = p.eval(pm).back(); EXPECT(result.get_shape() == make_shape({1, 2, 2})); - EXPECT(arg_to_vec(result) == std::vector{4, 5, 6, 7}); + EXPECT(result.to_vector() == std::vector{4, 5, 6, 7}); } TEST_CASE(scan_slice_test_5) @@ -148,12 +141,12 @@ TEST_CASE(scan_slice_test_5) pm["idx"] = migraphx::argument{idx_sh, &idx}; auto result = p.eval(pm).back(); EXPECT(result.get_shape() == make_shape({1, 2, 2})); - EXPECT(arg_to_vec(result) == std::vector{4, 5, 6, 7}); + EXPECT(result.to_vector() == std::vector{4, 5, 6, 7}); idx = 1; result = p.eval(pm).back(); EXPECT(result.get_shape() == make_shape({1, 2, 2})); - EXPECT(arg_to_vec(result) == std::vector{0, 1, 2, 3}); + EXPECT(result.to_vector() == std::vector{0, 1, 2, 3}); } TEST_CASE(scan_slice_test_6) @@ -166,12 +159,12 @@ TEST_CASE(scan_slice_test_6) pm["idx"] = migraphx::argument{idx_sh, &idx}; auto result = p.eval(pm).back(); EXPECT(result.get_shape() == make_shape({2, 1, 2})); - EXPECT(arg_to_vec(result) == std::vector{2, 3, 6, 7}); + EXPECT(result.to_vector() == std::vector{2, 3, 6, 7}); idx = 1; result = p.eval(pm).back(); EXPECT(result.get_shape() == make_shape({2, 1, 2})); - EXPECT(arg_to_vec(result) == std::vector{0, 1, 4, 5}); + EXPECT(result.to_vector() == std::vector{0, 1, 4, 5}); } TEST_CASE(scan_slice_test_7) From 2687a5a99196de72ef6a48e7080c4bf76c727d59 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dino=20Musi=C4=87?= Date: Wed, 26 Jun 2024 08:22:16 +0000 Subject: [PATCH 16/18] op_shape tests for scan_slice, added comments --- src/include/migraphx/op/loop.hpp | 6 ++-- src/include/migraphx/op/scan_slice.hpp | 1 + src/onnx/parse_scan.cpp | 17 ++++++++-- test/onnx/gen_onnx.py | 16 ++++----- test/onnx/verify/scan_test.cpp | 47 +++++++++++--------------- test/op_shape_test.cpp | 44 ++++++++++++++++++++++++ 6 files changed, 90 insertions(+), 41 deletions(-) diff --git a/src/include/migraphx/op/loop.hpp b/src/include/migraphx/op/loop.hpp index 658368f757f..a1b76181414 100644 --- a/src/include/migraphx/op/loop.hpp +++ b/src/include/migraphx/op/loop.hpp @@ -100,8 +100,8 @@ struct loop void append(const std::vector& iter_state, const std::vector& concatenated_outputs, const std::vector& scan_output_dirs, - int64_t iter, - int64_t iter_num) const + int64_t curr_iter, + int64_t num_iters) const { assert(iter_state.size() == concatenated_outputs.size()); for(auto i : range(iter_state.size())) @@ -110,7 +110,7 @@ struct loop const auto& scan_out = concatenated_outputs.at(i); auto dir = scan_output_dirs.empty() ? 0 : scan_output_dirs[i]; - auto idx = (1 - dir) * iter + dir * (iter_num - 1 - iter); + auto idx = (1 - dir) * curr_iter + dir * (num_iters - 1 - curr_iter); auto* in_data = iter_stat.data(); auto* out_data = scan_out.data(); diff --git a/src/include/migraphx/op/scan_slice.hpp b/src/include/migraphx/op/scan_slice.hpp index 6bb5e56c4c7..9de1ad6bc7c 100644 --- a/src/include/migraphx/op/scan_slice.hpp +++ b/src/include/migraphx/op/scan_slice.hpp @@ -78,6 +78,7 @@ struct scan_slice : op_name MIGRAPHX_THROW("ScanSlice: index {" + std::to_string(idx) + "} out of range [0, " + std::to_string(max_idx) + "]"); idx = (1 - direction) * idx + direction * (max_idx - idx); + std::cout << idx << std::endl; auto offset = idx * input_sh.strides().at(axis) * input_sh.type_size(); return {output_shape, [=] { return input.data() + offset; }}; diff --git a/src/onnx/parse_scan.cpp b/src/onnx/parse_scan.cpp index b820ef9be0c..b3989cfe493 100644 --- a/src/onnx/parse_scan.cpp +++ b/src/onnx/parse_scan.cpp @@ -146,8 +146,18 @@ struct parse_scan : op_parser for(auto i = 0; i < k; ++i) { auto o = info.add_instruction(make_op("get_tuple_elem", {{"index", i + n}}), loop); - // Loop scan_outputs are concatenated along axis 0, so it must be transposed to the - // index specified by the corresponding scan_output_axis + // Loop concatenates scan axes along axis 0 which is inserted/unsqueezed, e.g. a body + // scan output(from a single iteration) of shape {2, 2} is first expanded to {1, 2, 2}, + // and then concatenated with body scan outputs from previous iterations. For n + // iterations of the loop, this will end up producing a scan output of shape {n, 2, 2}. + // + // The scan_output_axes attribute of Scan can define an axis other than zero as the + // concatenation axis. Using the previous scenario, for a body scan output of + // shape {2,2}, with the scan output axis being 1, it is unsqueezed to {2, 1, 2}. The + // final concatenation is then of shape {2, n, 2}. + // + // Since Loop only concatenates along the unsqueezed axis 0, a transpose is necessary to + // place axis 0 in the appropriate scan_output_axis position auto perm = make_perm_for_scan_out(o->get_shape().ndim(), scan_output_axes[i]); ret.push_back(info.add_instruction(make_op("transpose", {{"permutation", perm}}), o)); } @@ -295,6 +305,9 @@ struct parse_scan : op_parser mod->replace_return(returns); } + // Creates permutation so that axis 0 will be permuted to position axis, while maintaining the + // relative ordering of all the other axes. + // e.g. for rank = 4, axis = 2, the created perm is: [1, 2, 0, 3] std::vector make_perm_for_scan_out(int64_t rank, int64_t axis) const { std::vector perm(rank); diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index 91812d37766..614d5c84b03 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -12446,40 +12446,38 @@ def scan_test6(): scan_output_directions=[1, 1], scan_output_axes=[2, 1]) + @onnx_test() def scan_test7(): sum_in = helper.make_tensor_value_info("sum_in", TensorProto.FLOAT, [2, 2]) scan_in = helper.make_tensor_value_info("scan_in", TensorProto.FLOAT, - [2, 2]) + [2, 2]) sum_out = helper.make_tensor_value_info("sum_out", TensorProto.FLOAT, [2, 2]) scan_out = helper.make_tensor_value_info("scan_out", TensorProto.FLOAT, - [2, 2]) + [2, 2]) add1 = helper.make_node("Add", inputs=["sum_in", "scan_in"], outputs=["add1_out"]) add2 = helper.make_node("Add", inputs=["add1_out", "scan_in"], outputs=["sum_out"]) - id = helper.make_node("Identity", - inputs=["scan_in"], - outputs=["scan_out"]) + id = helper.make_node("Identity", inputs=["scan_in"], outputs=["scan_out"]) scan_body = helper.make_graph([add1, add2, id], "scan_body", - [sum_in, scan_in], - [sum_out, scan_out]) + [sum_in, scan_in], [sum_out, scan_out]) init_state = helper.make_tensor_value_info("init_state", TensorProto.FLOAT, [2, 2]) scan_ins_sh = [3, 2, 2] scan_ins = helper.make_tensor_value_info("scan_ins", TensorProto.FLOAT, - scan_ins_sh) + scan_ins_sh) final_state = helper.make_tensor_value_info("final_state", TensorProto.FLOAT, [2, 2]) scan_outs_sh = [3, 2, 2] scan_outs = helper.make_tensor_value_info("scan_outs", TensorProto.FLOAT, - scan_outs_sh) + scan_outs_sh) node = helper.make_node( "Scan", inputs=["init_state", "scan_ins"], diff --git a/test/onnx/verify/scan_test.cpp b/test/onnx/verify/scan_test.cpp index 2b06ebd02be..899eea5b72c 100644 --- a/test/onnx/verify/scan_test.cpp +++ b/test/onnx/verify/scan_test.cpp @@ -37,13 +37,6 @@ static migraphx::shape make_shape(const std::vector& lens) return migraphx::shape{migraphx::shape::float_type, lens}; } -static std::vector arg_to_vec(const migraphx::argument& arg) -{ - std::vector ret; - arg.visit([&](auto output) { ret.assign(output.begin(), output.end()); }); - return ret; -} - auto scan_test(const std::string& test_file, migraphx::shape scan_ins1_sh, migraphx::shape scan_ins2_sh) @@ -77,15 +70,15 @@ TEST_CASE(scan_test1) EXPECT(final_state.get_shape() == make_shape({2, 2})); std::vector final_state_gold{18, 21, 24, 27}; - EXPECT(arg_to_vec(final_state) == final_state_gold); + EXPECT(final_state.to_vector() == final_state_gold); EXPECT(scan_out1.get_shape() == make_shape({3, 2, 2})); std::vector scan_out1_gold{1, 2, 3, 4, 7, 9, 11, 13, 18, 21, 24, 27}; - EXPECT(arg_to_vec(scan_out1) == scan_out1_gold); + EXPECT(scan_out1.to_vector() == scan_out1_gold); EXPECT(scan_out2.get_shape() == make_shape({3, 2})); std::vector scan_out2_gold{4, 6, 18, 22, 42, 48}; - EXPECT(arg_to_vec(scan_out2) == scan_out2_gold); + EXPECT(scan_out2.to_vector() == scan_out2_gold); } TEST_CASE(scan_test2) @@ -95,15 +88,15 @@ TEST_CASE(scan_test2) EXPECT(final_state.get_shape() == make_shape({2, 2})); std::vector final_state_gold{18, 21, 24, 27}; - EXPECT(arg_to_vec(final_state) == final_state_gold); + EXPECT(final_state.to_vector() == final_state_gold); EXPECT(scan_out1.get_shape() == make_shape({3, 2, 2})); std::vector scan_out1_gold{18, 21, 24, 27, 7, 9, 11, 13, 1, 2, 3, 4}; - EXPECT(arg_to_vec(scan_out1) == scan_out1_gold); + EXPECT(scan_out1.to_vector() == scan_out1_gold); EXPECT(scan_out2.get_shape() == make_shape({3, 2})); std::vector scan_out2_gold{4, 6, 18, 22, 42, 48}; - EXPECT(arg_to_vec(scan_out2) == scan_out2_gold); + EXPECT(scan_out2.to_vector() == scan_out2_gold); } TEST_CASE(scan_test3) @@ -113,15 +106,15 @@ TEST_CASE(scan_test3) EXPECT(final_state.get_shape() == make_shape({2, 2})); std::vector final_state_gold{18, 21, 24, 27}; - EXPECT(arg_to_vec(final_state) == final_state_gold); + EXPECT(final_state.to_vector() == final_state_gold); EXPECT(scan_out1.get_shape() == make_shape({2, 3, 2})); std::vector scan_out1_gold{1, 2, 7, 9, 18, 21, 3, 4, 11, 13, 24, 27}; - EXPECT(arg_to_vec(scan_out1) == scan_out1_gold); + EXPECT(scan_out1.to_vector() == scan_out1_gold); EXPECT(scan_out2.get_shape() == make_shape({2, 3})); std::vector scan_out2_gold{4, 18, 42, 6, 22, 48}; - EXPECT(arg_to_vec(scan_out2) == scan_out2_gold); + EXPECT(scan_out2.to_vector() == scan_out2_gold); } TEST_CASE(scan_test4) @@ -131,15 +124,15 @@ TEST_CASE(scan_test4) EXPECT(final_state.get_shape() == make_shape({2, 2})); std::vector final_state_gold{18, 21, 24, 27}; - EXPECT(arg_to_vec(final_state) == final_state_gold); + EXPECT(final_state.to_vector() == final_state_gold); EXPECT(scan_out1.get_shape() == make_shape({3, 2, 2})); std::vector scan_out1_gold{9, 10, 11, 12, 15, 17, 19, 21, 18, 21, 24, 27}; - EXPECT(arg_to_vec(scan_out1) == scan_out1_gold); + EXPECT(scan_out1.to_vector() == scan_out1_gold); EXPECT(scan_out2.get_shape() == make_shape({3, 2})); std::vector scan_out2_gold{20, 22, 34, 38, 42, 48}; - EXPECT(arg_to_vec(scan_out2) == scan_out2_gold); + EXPECT(scan_out2.to_vector() == scan_out2_gold); } TEST_CASE(scan_test5) @@ -149,15 +142,15 @@ TEST_CASE(scan_test5) EXPECT(final_state.get_shape() == make_shape({2, 2})); std::vector final_state_gold{9, 18, 27, 36}; - EXPECT(arg_to_vec(final_state) == final_state_gold); + EXPECT(final_state.to_vector() == final_state_gold); EXPECT(scan_out1.get_shape() == make_shape({3, 2, 2})); std::vector scan_out1_gold{1, 4, 7, 10, 4, 10, 16, 22, 9, 18, 27, 36}; - EXPECT(arg_to_vec(scan_out1) == scan_out1_gold); + EXPECT(scan_out1.to_vector() == scan_out1_gold); EXPECT(scan_out2.get_shape() == make_shape({3, 2})); std::vector scan_out2_gold{8, 14, 20, 32, 36, 54}; - EXPECT(arg_to_vec(scan_out2) == scan_out2_gold); + EXPECT(scan_out2.to_vector() == scan_out2_gold); } TEST_CASE(scan_test6) @@ -167,15 +160,15 @@ TEST_CASE(scan_test6) EXPECT(final_state.get_shape() == make_shape({2, 2})); std::vector final_state_gold{12, 15, 30, 33}; - EXPECT(arg_to_vec(final_state) == final_state_gold); + EXPECT(final_state.to_vector() == final_state_gold); EXPECT(scan_out1.get_shape() == make_shape({2, 2, 3})); std::vector scan_out1_gold{12, 7, 3, 15, 9, 4, 30, 19, 9, 33, 21, 10}; - EXPECT(arg_to_vec(scan_out1) == scan_out1_gold); + EXPECT(scan_out1.to_vector() == scan_out1_gold); EXPECT(scan_out2.get_shape() == make_shape({2, 3})); std::vector scan_out2_gold{42, 26, 12, 48, 30, 14}; - EXPECT(arg_to_vec(scan_out2) == scan_out2_gold); + EXPECT(scan_out2.to_vector() == scan_out2_gold); } TEST_CASE(scan_test7) @@ -199,8 +192,8 @@ TEST_CASE(scan_test7) EXPECT(result[0].get_shape() == make_shape({2, 2})); std::vector final_state_gold{30, 36, 42, 48}; - EXPECT(arg_to_vec(result[0]) == final_state_gold); + EXPECT(result[0].to_vector() == final_state_gold); EXPECT(result[1].get_shape() == make_shape({3, 2, 2})); - EXPECT(arg_to_vec(result[1]) == scan_ins); + EXPECT(result[1].to_vector() == scan_ins); } diff --git a/test/op_shape_test.cpp b/test/op_shape_test.cpp index 61918317fe0..5195bd84929 100644 --- a/test/op_shape_test.cpp +++ b/test/op_shape_test.cpp @@ -4192,6 +4192,50 @@ TEST_CASE(slice_dyn_shape5) input); } +TEST_CASE(test_scan_slice1) +{ + migraphx::shape input{migraphx::shape::float_type, {2, 3, 4}}; + migraphx::shape axis_input{migraphx::shape::int64_type}; + migraphx::shape expected{migraphx::shape::float_type, {1, 3, 4}}; + expect_shape(expected, + migraphx::make_op("scan_slice", {{"axis", 0}, {"direction", 0}}), + input, + axis_input); +} + +TEST_CASE(test_scan_slice2) +{ + migraphx::shape input{migraphx::shape::float_type, {4, 6, 5}}; + migraphx::shape axis_input{migraphx::shape::int64_type}; + migraphx::shape expected{migraphx::shape::float_type, {4, 1, 5}, {30, 5, 1}}; + expect_shape(expected, + migraphx::make_op("scan_slice", {{"axis", 1}, {"direction", 0}}), + input, + axis_input); +} + +TEST_CASE(test_scan_slice3) +{ + migraphx::shape input{migraphx::shape::float_type, {2, 5, 7}}; + migraphx::shape axis_input{migraphx::shape::int64_type}; + migraphx::shape expected{migraphx::shape::float_type, {2, 5, 1}, {35, 7, 1}}; + expect_shape(expected, + migraphx::make_op("scan_slice", {{"axis", -1}, {"direction", 0}}), + input, + axis_input); +} + +TEST_CASE(test_scan_slice4) +{ + migraphx::shape input{migraphx::shape::float_type, {2, 5, 7}}; + migraphx::shape axis_input{migraphx::shape::int64_type}; + migraphx::shape expected{migraphx::shape::float_type, {1, 5, 7}, {35, 7, 1}}; + expect_shape(expected, + migraphx::make_op("scan_slice", {{"axis", -3}, {"direction", 1}}), + input, + axis_input); +} + TEST_CASE(softmax_dyn0) { migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {3, 3}, {4, 4}, {5, 5}}}; From 56e7056b7b099542e498c80060559157fa10aa1a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dino=20Musi=C4=87?= Date: Wed, 26 Jun 2024 10:07:01 +0000 Subject: [PATCH 17/18] Fix clang tidy issue --- src/onnx/parse_scan.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/onnx/parse_scan.cpp b/src/onnx/parse_scan.cpp index b3989cfe493..365e69410d5 100644 --- a/src/onnx/parse_scan.cpp +++ b/src/onnx/parse_scan.cpp @@ -198,7 +198,7 @@ struct parse_scan : op_parser { auto dirs = parse_vector_attribute(info, name, expected_size); if(dirs.empty()) - return std::vector(expected_size, 0); + return std::vector(expected_size, 0); // NOLINT if(any_of(dirs, [](auto i) { return i != 0 and i != 1; })) MIGRAPHX_THROW("Scan: " + name + From 84f9ecbd8c3a6c59517f3450a011710bf461c59b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dino=20Musi=C4=87?= Date: Mon, 22 Jul 2024 06:44:54 +0000 Subject: [PATCH 18/18] Fix braced initializer issue, remove a cout --- src/include/migraphx/op/scan_slice.hpp | 1 - src/onnx/parse_scan.cpp | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/include/migraphx/op/scan_slice.hpp b/src/include/migraphx/op/scan_slice.hpp index 9de1ad6bc7c..6bb5e56c4c7 100644 --- a/src/include/migraphx/op/scan_slice.hpp +++ b/src/include/migraphx/op/scan_slice.hpp @@ -78,7 +78,6 @@ struct scan_slice : op_name MIGRAPHX_THROW("ScanSlice: index {" + std::to_string(idx) + "} out of range [0, " + std::to_string(max_idx) + "]"); idx = (1 - direction) * idx + direction * (max_idx - idx); - std::cout << idx << std::endl; auto offset = idx * input_sh.strides().at(axis) * input_sh.type_size(); return {output_shape, [=] { return input.data() + offset; }}; diff --git a/src/onnx/parse_scan.cpp b/src/onnx/parse_scan.cpp index 365e69410d5..15db9e1f473 100644 --- a/src/onnx/parse_scan.cpp +++ b/src/onnx/parse_scan.cpp @@ -224,7 +224,7 @@ struct parse_scan : op_parser { auto axes = parse_vector_attribute(info, name, expected_size); if(axes.empty()) - return {expected_size, 0}; + return std::vector(expected_size, 0); // NOLINT std::transform(axes.begin(), axes.end(),