From 904ceafa3b79b638b9276f835f304101690f0759 Mon Sep 17 00:00:00 2001 From: music-dino <111048524+music-dino@users.noreply.github.com> Date: Wed, 31 Jul 2024 15:05:01 +0200 Subject: [PATCH] Add support for Scan operator (#2936) --- docs/dev/onnx_operators.rst | 8 +- src/CMakeLists.txt | 1 + src/include/migraphx/op/loop.hpp | 19 +- src/include/migraphx/op/scan_slice.hpp | 91 +++++ src/include/migraphx/operators.hpp | 1 + src/include/migraphx/run_loop.hpp | 14 +- src/onnx/parse_scan.cpp | 324 ++++++++++++++++ src/targets/gpu/loop.cpp | 11 +- src/targets/gpu/lowering.cpp | 12 + test/onnx/gen_onnx.py | 363 ++++++++++++++++++ test/onnx/parse/scan_test.cpp | 138 +++++++ test/onnx/scan_arg_count_mismatch_test.onnx | Bin 0 -> 634 bytes test/onnx/scan_arg_shapes_mismatch_test.onnx | Bin 0 -> 660 bytes .../scan_input_axes_lens_mismatch_test.onnx | Bin 0 -> 674 bytes .../scan_invalid_input_axes_len_test.onnx | Bin 0 -> 592 bytes .../scan_invalid_input_axes_vals_test.onnx | Bin 0 -> 592 bytes .../scan_invalid_input_dirs_len_test.onnx | Bin 0 -> 592 bytes .../scan_invalid_input_dirs_vals_test.onnx | Bin 0 -> 592 bytes .../scan_invalid_output_axes_len_test.onnx | Bin 0 -> 594 bytes .../scan_invalid_output_axes_vals_test.onnx | Bin 0 -> 603 bytes .../scan_invalid_output_dirs_len_test.onnx | Bin 0 -> 594 bytes .../scan_invalid_output_dirs_vals_test.onnx | Bin 0 -> 603 bytes test/onnx/scan_test1.onnx | Bin 0 -> 793 bytes test/onnx/scan_test2.onnx | Bin 0 -> 793 bytes test/onnx/scan_test3.onnx | Bin 0 -> 802 bytes test/onnx/scan_test4.onnx | Bin 0 -> 793 bytes test/onnx/scan_test5.onnx | Bin 0 -> 802 bytes test/onnx/scan_test6.onnx | Bin 0 -> 802 bytes test/onnx/scan_test7.onnx | Bin 0 -> 582 bytes test/onnx/verify/scan_test.cpp | 199 ++++++++++ test/op_shape_test.cpp | 44 +++ test/py/onnx_backend_test.py | 1 - test/ref/scan_slice.cpp | 180 +++++++++ test/run_loop_test.cpp | 2 +- test/verify/test_scan_slice.cpp | 62 +++ 35 files changed, 1455 insertions(+), 15 deletions(-) create mode 100644 src/include/migraphx/op/scan_slice.hpp create mode 100644 src/onnx/parse_scan.cpp create mode 100644 test/onnx/parse/scan_test.cpp create mode 100644 test/onnx/scan_arg_count_mismatch_test.onnx create mode 100644 test/onnx/scan_arg_shapes_mismatch_test.onnx create mode 100644 test/onnx/scan_input_axes_lens_mismatch_test.onnx create mode 100644 test/onnx/scan_invalid_input_axes_len_test.onnx create mode 100644 test/onnx/scan_invalid_input_axes_vals_test.onnx create mode 100644 test/onnx/scan_invalid_input_dirs_len_test.onnx create mode 100644 test/onnx/scan_invalid_input_dirs_vals_test.onnx create mode 100644 test/onnx/scan_invalid_output_axes_len_test.onnx create mode 100644 test/onnx/scan_invalid_output_axes_vals_test.onnx create mode 100644 test/onnx/scan_invalid_output_dirs_len_test.onnx create mode 100644 test/onnx/scan_invalid_output_dirs_vals_test.onnx create mode 100644 test/onnx/scan_test1.onnx create mode 100644 test/onnx/scan_test2.onnx create mode 100644 test/onnx/scan_test3.onnx create mode 100644 test/onnx/scan_test4.onnx create mode 100644 test/onnx/scan_test5.onnx create mode 100644 test/onnx/scan_test6.onnx create mode 100644 test/onnx/scan_test7.onnx create mode 100644 test/onnx/verify/scan_test.cpp create mode 100644 test/ref/scan_slice.cpp create mode 100644 test/verify/test_scan_slice.cpp diff --git a/docs/dev/onnx_operators.rst b/docs/dev/onnx_operators.rst index 570205c0034..5acfd1923a5 100644 --- a/docs/dev/onnx_operators.rst +++ b/docs/dev/onnx_operators.rst @@ -705,7 +705,13 @@ Operator Support Matrix +--------------------------+-----------+-----------------+------------------------------+ | STFT | ❌ | | | +--------------------------+-----------+-----------------+------------------------------+ -| Scan | 👷 | 👷 | | +| Scan | ✅ | UINT8, UINT16, | ``identity``, | +| | | UINT32, UINT64, | ``sequence`` | +| | | INT8, INT16, | datatypes are | +| | | INT32, INT64, | not supported, | +| | | FP8, FP16, | Number of iterations has | +| | | FP32, FP64 | upper-bound | +| | | | Version 8 not supported | +--------------------------+-----------+-----------------+------------------------------+ | Scatter (deprecated) | ✅ | BOOL, UINT8, | | | | | UINT16, UINT32, | | diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index aa851cef74b..523a6aee561 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -234,6 +234,7 @@ register_migraphx_ops( rsqrt run_on_target scalar + scan_slice scatter_none scatter_add scatter_mul diff --git a/src/include/migraphx/op/loop.hpp b/src/include/migraphx/op/loop.hpp index 969c16ef7cd..a1b76181414 100644 --- a/src/include/migraphx/op/loop.hpp +++ b/src/include/migraphx/op/loop.hpp @@ -41,12 +41,14 @@ namespace op { struct loop { - int64_t max_iterations = 10; + int64_t max_iterations = 10; + std::vector scan_output_directions = {}; template static auto reflect(Self& self, F f) { - return pack(f(self.max_iterations, "max_iterations")); + return pack(f(self.max_iterations, "max_iterations"), + f(self.scan_output_directions, "scan_output_directions")); } std::string name() const { return "loop"; } @@ -97,7 +99,9 @@ struct loop void append(const std::vector& iter_state, const std::vector& concatenated_outputs, - int iter) const + const std::vector& scan_output_dirs, + int64_t curr_iter, + int64_t num_iters) const { assert(iter_state.size() == concatenated_outputs.size()); for(auto i : range(iter_state.size())) @@ -105,11 +109,14 @@ struct loop const auto& iter_stat = iter_state.at(i); const auto& scan_out = concatenated_outputs.at(i); + auto dir = scan_output_dirs.empty() ? 0 : scan_output_dirs[i]; + auto idx = (1 - dir) * curr_iter + dir * (num_iters - 1 - curr_iter); + auto* in_data = iter_stat.data(); auto* out_data = scan_out.data(); std::size_t out_size = iter_stat.get_shape().bytes(); - assert((iter + 1) * out_size <= scan_out.get_shape().bytes()); - std::copy(in_data, in_data + out_size, out_data + iter * out_size); + assert((idx + 1) * out_size <= scan_out.get_shape().bytes()); + std::copy(in_data, in_data + out_size, out_data + idx * out_size); } } @@ -153,7 +160,7 @@ struct loop cpy_args.push_back(argument(out_shape)); // run loop - return run_loop(ref_loop{max_iterations}, ctx, cpy_args, mods, run); + return run_loop(ref_loop{max_iterations}, scan_output_directions, ctx, cpy_args, mods, run); } }; diff --git a/src/include/migraphx/op/scan_slice.hpp b/src/include/migraphx/op/scan_slice.hpp new file mode 100644 index 00000000000..6bb5e56c4c7 --- /dev/null +++ b/src/include/migraphx/op/scan_slice.hpp @@ -0,0 +1,91 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#ifndef MIGRAPHX_GUARD_OPERATORS_SCAN_SLICE_HPP +#define MIGRAPHX_GUARD_OPERATORS_SCAN_SLICE_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct scan_slice : op_name +{ + int64_t axis = 0; + int64_t direction = 0; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.axis, "axis"), f(self.direction, "direction")); + } + + value attributes() const + { + value normalize_axes = value::object{}; + normalize_axes["axis"] = value::array{normalize_attribute::include_min}; + return {{"normalize_axes", normalize_axes}}; + } + + shape normalize_compute_shape(std::vector inputs) const + { + check_shapes{inputs, *this}.has(2); + auto input_shape = inputs[0]; + auto new_lens = input_shape.lens(); + new_lens[axis] = 1; + + return shape{input_shape.type(), new_lens, input_shape.strides()}; + } + + argument compute(shape output_shape, std::vector args) const + { + auto input = args[0]; + auto input_sh = input.get_shape(); + + int64_t idx; + args[1].visit([&](auto i) { idx = i.front(); }); + const auto max_idx = input_sh.lens()[axis] - 1; + if(idx > max_idx or idx < 0) + MIGRAPHX_THROW("ScanSlice: index {" + std::to_string(idx) + "} out of range [0, " + + std::to_string(max_idx) + "]"); + idx = (1 - direction) * idx + direction * (max_idx - idx); + + auto offset = idx * input_sh.strides().at(axis) * input_sh.type_size(); + return {output_shape, [=] { return input.data() + offset; }}; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/operators.hpp b/src/include/migraphx/operators.hpp index 5166d97eecf..752fe67d7a9 100644 --- a/src/include/migraphx/operators.hpp +++ b/src/include/migraphx/operators.hpp @@ -113,6 +113,7 @@ #include #include #include +#include #include #include #include diff --git a/src/include/migraphx/run_loop.hpp b/src/include/migraphx/run_loop.hpp index 859d7dfad7f..e12b82fd2ff 100644 --- a/src/include/migraphx/run_loop.hpp +++ b/src/include/migraphx/run_loop.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -24,6 +24,7 @@ #ifndef MIGRAPHX_GUARD_RTGLIB_RUN_LOOP_HPP #define MIGRAPHX_GUARD_RTGLIB_RUN_LOOP_HPP +#include #include #include #include @@ -39,6 +40,7 @@ inline namespace MIGRAPHX_INLINE_NS { template argument run_loop(const LoopModel& model, + const std::vector& scan_output_directions, T& ctx, std::vector args, const std::vector& mods, @@ -101,9 +103,13 @@ argument run_loop(const LoopModel& model, auto output_index = out_param_indices[name]; if(output_index > dep_num) { + int64_t dir = scan_output_directions.empty() + ? 0 + : scan_output_directions[output_index - dep_num - 1]; + auto idx = (1 - dir) * iter + dir * (iter_num - 1 - iter); const auto& arg = out_args.at(output_index); - assert((iter + 1) * ps.bytes() <= arg.get_shape().bytes()); - params[name] = argument(ps, arg.data() + iter * ps.bytes()); + assert((idx + 1) * ps.bytes() <= arg.get_shape().bytes()); + params[name] = argument(ps, arg.data() + idx * ps.bytes()); } else { @@ -123,7 +129,7 @@ argument run_loop(const LoopModel& model, std::copy(dep_out.begin(), dep_out.end(), out_args.begin()); std::vector mod_scan_outs(mod_args.begin() + 1 + dep_num, mod_args.end()); - model.append(mod_scan_outs, scan_outputs, iter); + model.append(mod_scan_outs, scan_outputs, scan_output_directions, iter, iter_num); } out_args.erase(out_args.begin()); diff --git a/src/onnx/parse_scan.cpp b/src/onnx/parse_scan.cpp new file mode 100644 index 00000000000..15db9e1f473 --- /dev/null +++ b/src/onnx/parse_scan.cpp @@ -0,0 +1,324 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_scan : op_parser +{ + std::vector operators() const { return {{"Scan"}}; } + + std::vector parse(const op_desc& /*opd*/, + onnx_parser& parser, + onnx_parser::node_info info, + std::vector args) const + { + if(parser.opset_version == 8) + MIGRAPHX_THROW("Scan: Opset 8 version not supported"); + + check_for_required_attributes(info, {"body", "num_scan_inputs"}); + + const auto& body_graph = info.attributes["body"].g(); + auto* body = parser.prog.create_module(info.name + "_scan"); + parser.parse_graph(body, body_graph); + + // Scan has: + // N + M inputs (N state variables, M scan inputs) + // N + K outputs (N state variables, K scan outputs) + // Same input and output counts apply for body + auto body_outs = body->get_returns(); + const auto m = info.attributes["num_scan_inputs"].i(); + const auto n = args.size() - m; + const auto k = body_outs.size() - n; + + std::vector body_params; + transform(body->get_parameter_names(), + std::back_inserter(body_params), + [&](const auto& name) { return body->get_parameter(name); }); + + if(auto num_body_params = body_params.size(); num_body_params != n + m) + MIGRAPHX_THROW("Scan: Number of inputs to body {" + std::to_string(num_body_params) + + "} does not match number of inputs to Scan {" + std::to_string(n + m) + + "}"); + + const auto scan_input_axes = parse_axes(info, "scan_input_axes", m, args.begin() + n, 0); + const auto scan_input_directions = parse_dirs(info, "scan_input_directions", m); + const auto scan_output_axes = + parse_axes(info, "scan_output_axes", k, body_outs.begin() + n, 1); + const auto scan_output_directions = parse_dirs(info, "scan_output_directions", k); + + // Check that scan axes lens are the same across all scan inputs + size_t num_iters = args[n]->get_shape().lens()[scan_input_axes[0]]; + for(auto i = 1; i < m; ++i) + if(args[n + i]->get_shape().lens()[scan_input_axes[i]] != num_iters) + MIGRAPHX_THROW( + "Scan: Lengths of scan_input_axes do not match across all scan inputs.\n" + "Scan input shapes: " + + to_string_range( + to_shapes(std::vector(args.begin() + n, args.end()))) + + "\nScan input axes: " + to_string_range(scan_input_axes)); + + if(num_iters > parser.max_loop_iterations) + MIGRAPHX_THROW("Scan: Number of required iterations {" + std::to_string(num_iters) + + "} would exceed the maximum iteration limit {" + + std::to_string(parser.max_loop_iterations) + "}"); + + // Check that state variable shapes match between the Scan node and its body attribute + for(auto i = 0; i < n; ++i) + if(args[i]->get_shape() != body_params[i]->get_shape()) + MIGRAPHX_THROW("Scan: State input " + std::to_string(i) + " shape {" + + to_string(args[i]->get_shape()) + + "} does not match corresponding body input shape {" + + to_string(body_params[i]->get_shape()) + "}"); + + // Check that the shapes of scan inputs sliced across scan input axes match the shapes of + // the body attribute scan inputs + for(auto i = 0; i < m; ++i) + { + auto node_shape = args[i + n]->get_shape(); + auto node_lens = node_shape.lens(); + node_lens.erase(node_lens.begin() + scan_input_axes[i]); + auto slice_sh = shape(node_shape.type(), std::move(node_lens)); + if(body_params[i + n]->get_shape() != slice_sh) + MIGRAPHX_THROW("Slice: Sliced scan input " + std::to_string(i) + " shape {" + + to_string(slice_sh) + + "} does not match corresponding body input shape {" + + to_string(body_params[i + n]->get_shape()) + "}"); + } + + modify_body(body, args, n, m, scan_input_axes, scan_input_directions); + + auto max_iter_lit = info.add_literal(literal{shape{shape::int64_type}, {num_iters}}); + auto cond_lit = info.add_literal(literal{shape{shape::bool_type}, {true}}); + std::vector loop_args{max_iter_lit, cond_lit}; + loop_args.insert(loop_args.end(), args.begin(), args.begin() + n); + + auto loop = + info.add_instruction(make_op("loop", + {{"max_iterations", num_iters}, + {"scan_output_directions", scan_output_directions}}), + loop_args, + {body}); + + std::vector ret; + ret.reserve(n + k); + for(auto i = 0; i < n; ++i) + ret.push_back(info.add_instruction(make_op("get_tuple_elem", {{"index", i}}), loop)); + + for(auto i = 0; i < k; ++i) + { + auto o = info.add_instruction(make_op("get_tuple_elem", {{"index", i + n}}), loop); + // Loop concatenates scan axes along axis 0 which is inserted/unsqueezed, e.g. a body + // scan output(from a single iteration) of shape {2, 2} is first expanded to {1, 2, 2}, + // and then concatenated with body scan outputs from previous iterations. For n + // iterations of the loop, this will end up producing a scan output of shape {n, 2, 2}. + // + // The scan_output_axes attribute of Scan can define an axis other than zero as the + // concatenation axis. Using the previous scenario, for a body scan output of + // shape {2,2}, with the scan output axis being 1, it is unsqueezed to {2, 1, 2}. The + // final concatenation is then of shape {2, n, 2}. + // + // Since Loop only concatenates along the unsqueezed axis 0, a transpose is necessary to + // place axis 0 in the appropriate scan_output_axis position + auto perm = make_perm_for_scan_out(o->get_shape().ndim(), scan_output_axes[i]); + ret.push_back(info.add_instruction(make_op("transpose", {{"permutation", perm}}), o)); + } + + return ret; + } + + void check_for_required_attributes(onnx_parser::node_info& info, + const std::vector& attribute_names) const + { + auto it = std::find_if( + attribute_names.cbegin(), attribute_names.cend(), [&](const std::string& name) { + return not contains(info.attributes, name); + }); + if(it != attribute_names.cend()) + MIGRAPHX_THROW("Scan: " + *it + " attribute required"); + } + + std::vector parse_vector_attribute(onnx_parser::node_info& info, + const std::string& attr_name, + size_t expected_size) const + { + if(not contains(info.attributes, attr_name)) + return {}; + + std::vector res; + auto&& attr = info.attributes[attr_name].ints(); + if(attr.size() != expected_size) + MIGRAPHX_THROW("Scan: " + attr_name + " size is " + to_string(attr.size()) + + ", should be " + to_string(expected_size)); + res.assign(attr.begin(), attr.end()); + + return res; + } + + std::vector + parse_dirs(onnx_parser::node_info& info, const std::string& name, size_t expected_size) const + { + auto dirs = parse_vector_attribute(info, name, expected_size); + if(dirs.empty()) + return std::vector(expected_size, 0); // NOLINT + + if(any_of(dirs, [](auto i) { return i != 0 and i != 1; })) + MIGRAPHX_THROW("Scan: " + name + + " may contain only 1s and 0s, actual values: " + to_string_range(dirs)); + + return dirs; + } + + int64_t normalize_axis(int64_t axis, int64_t rank, const std::string& attr_name) const + { + if(axis < -rank or axis >= rank) + MIGRAPHX_THROW("Scan: " + attr_name + " axis value {" + to_string(axis) + + "} out of range [" + to_string(-rank) + ", " + to_string(rank) + ")"); + + return axis < 0 ? rank + axis : axis; + } + + std::vector parse_axes(onnx_parser::node_info& info, + const std::string& name, + long expected_size, + std::vector::iterator ins_begin, + size_t rank_offset) const + { + auto axes = parse_vector_attribute(info, name, expected_size); + if(axes.empty()) + return std::vector(expected_size, 0); // NOLINT + + std::transform(axes.begin(), + axes.end(), + ins_begin, + axes.begin(), + [&](int64_t axis, instruction_ref arg) { + return normalize_axis(axis, arg->get_shape().ndim() + rank_offset, name); + }); + + return axes; + } + + // Alter the Scan body to match a body that Loop would expect. + // + // Loop body inputs: iteration_num, condition, loop_state_variables + // Scan body inputs: loop_state_variables, scan_input_slices + // iteration_num and condition parameters are prepended to the Scan body parameter list, while + // scan_input_slices are removed from parameters. + // Instead, scan_inputs are used directly in Scan body(as values from enclosing scope), and + // together with iteration_num passed to the scan_slice operator which produces slices that are + // used instead of the scan_inputs_slices. + // + // Loop body outputs: condition, loop_state_variables, scan_output_slices + // Scan body outputs: loop_state_variables, scan_output_slices + // The inserted Scan body condition parameter is prepended to the Scan body returns + void modify_body(module_ref mod, + const std::vector& args, + int64_t n, + int64_t m, + const std::vector& scan_input_axes, + const std::vector& scan_input_directions) const + { + std::vector params; + params.reserve(n + m); + transform(mod->get_parameter_names(), + std::back_inserter(params), + [&](const std::string& name) { return mod->get_parameter(name); }); + + // iteration_num, condition, and duplicate loop_state_variables are appended to parameters. + // References to the original loop_state_variables in other instructions are then replaced + // with references to the duplicate ones, after which the originals are removed. + // + // References to the scan_input_slices are replaced with references to inserted + // scan_slice->squeeze instructions, after which the scan_input_slices parameters are + // removed. + auto iter_param = mod->add_parameter("iter", shape{shape::int64_type}); + auto cond_param = mod->add_parameter("cond", shape{shape::bool_type}); + std::vector new_params; + new_params.reserve(n); + for(auto i = 0; i < n; ++i) + new_params.push_back( + mod->add_parameter("state_var" + std::to_string(i), params[i]->get_shape())); + + for(auto i = 0; i < params.size(); ++i) + { + if(i < n) + { + mod->replace_instruction(params[i], new_params[i]); + } + else + { + auto scan_axis = scan_input_axes[i - n]; + auto scan_dir = scan_input_directions[i - n]; + auto new_ins = mod->insert_instruction( + params[i], + make_op("scan_slice", {{"axis", scan_axis}, {"direction", scan_dir}}), + args[i], + iter_param); + new_ins = mod->insert_instruction( + params[i], make_op("squeeze", {{"axes", {scan_axis}}}), new_ins); + mod->replace_instruction(params[i], new_ins); + } + mod->remove_instruction(params[i]); + } + + auto returns = mod->get_returns(); + returns.insert(returns.begin(), cond_param); + mod->replace_return(returns); + } + + // Creates permutation so that axis 0 will be permuted to position axis, while maintaining the + // relative ordering of all the other axes. + // e.g. for rank = 4, axis = 2, the created perm is: [1, 2, 0, 3] + std::vector make_perm_for_scan_out(int64_t rank, int64_t axis) const + { + std::vector perm(rank); + std::iota(perm.begin(), perm.end(), 0); + std::copy(perm.begin() + 1, perm.begin() + 1 + axis, perm.begin()); + perm[axis] = 0; + + return perm; + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/loop.cpp b/src/targets/gpu/loop.cpp index ee2731938d7..ad5fc210c7c 100644 --- a/src/targets/gpu/loop.cpp +++ b/src/targets/gpu/loop.cpp @@ -21,6 +21,7 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ +#include #include #include #include @@ -56,7 +57,13 @@ struct gpu_loop copy_to_gpu(ctx, arg_src, dst); } - void append(const std::vector&, const std::vector&, int) const {} + void append(const std::vector&, + const std::vector&, + const std::vector&, + int64_t, + int64_t) const + { + } void set_zero(context& ctx, const std::vector& concatenated_outputs, int iter) const { @@ -111,7 +118,7 @@ hip_loop::compute(context& ctx, const std::function( module_ref&, const std::unordered_map&)>& run) const { - return run_loop(gpu_loop{op.max_iterations}, ctx, args, mods, run); + return run_loop(gpu_loop{op.max_iterations}, op.scan_output_directions, ctx, args, mods, run); } } // namespace gpu diff --git a/src/targets/gpu/lowering.cpp b/src/targets/gpu/lowering.cpp index 477b1ebeac9..457aaa1d706 100644 --- a/src/targets/gpu/lowering.cpp +++ b/src/targets/gpu/lowering.cpp @@ -119,6 +119,7 @@ struct miopen_apply add_convolution_backwards_op(); add_select_module_op(); add_reshape_lazy_op(); + add_scan_slice_op(); } void copy_params() const @@ -484,6 +485,17 @@ struct miopen_apply return mod->replace_instruction(ins, make_op("gpu::contiguous"), after_contiguous_args); }); } + + void add_scan_slice_op() + { + apply_map.emplace("scan_slice", [=](instruction_ref ins) { + auto inputs = ins->inputs(); + auto cpu_idx = mod->insert_instruction(ins, make_op("hip::copy_from_gpu"), inputs[1]); + inputs[1] = mod->insert_instruction(ins, make_op("hip::sync_stream"), cpu_idx); + return mod->replace_instruction( + ins, mod->insert_instruction(ins, ins->get_operator(), inputs)); + }); + } }; void lowering::apply(module_pass_manager& mpm) const diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index 8e4f4837d18..94cc1aa8734 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -13107,3 +13107,366 @@ def where_mixed_test(): outputs=['z']) return ([node], [c, x, y], [z]) + + +def scan_test(scan_input_axes=[0, 0], + scan_input_directions=[0, 0], + scan_output_axes=[0, 0], + scan_output_directions=[0, 0]): + sum_in = helper.make_tensor_value_info("sum_in", TensorProto.FLOAT, [2, 2]) + scan_in1 = helper.make_tensor_value_info("scan_in1", TensorProto.FLOAT, + [2, 2]) + scan_in2 = helper.make_tensor_value_info("scan_in2", TensorProto.FLOAT, + [1]) + sum_out = helper.make_tensor_value_info("sum_out", TensorProto.FLOAT, + [2, 2]) + scan_out1 = helper.make_tensor_value_info("scan_out1", TensorProto.FLOAT, + [2, 2]) + scan_out2 = helper.make_tensor_value_info("scan_out2", TensorProto.FLOAT, + [2]) + add1 = helper.make_node("Add", + inputs=["sum_in", "scan_in1"], + outputs=["add1_out"]) + add2 = helper.make_node("Add", + inputs=["add1_out", "scan_in2"], + outputs=["sum_out"]) + id = helper.make_node("Identity", + inputs=["sum_out"], + outputs=["scan_out1"]) + reduce_sum = helper.make_node("ReduceSum", + axes=[0], + keepdims=0, + inputs=["sum_out"], + outputs=["scan_out2"]) + scan_body = helper.make_graph([add1, add2, id, reduce_sum], "scan_body", + [sum_in, scan_in1, scan_in2], + [sum_out, scan_out1, scan_out2]) + + init_state = helper.make_tensor_value_info("init_state", TensorProto.FLOAT, + [2, 2]) + scan_ins1_sh = [2, 2, 2] + scan_ins1_sh[scan_input_axes[0]] = 3 + scan_ins1 = helper.make_tensor_value_info("scan_ins1", TensorProto.FLOAT, + scan_ins1_sh) + scan_ins2_sh = [1, 1] + scan_ins2_sh[scan_input_axes[1]] = 3 + scan_ins2 = helper.make_tensor_value_info("scan_ins2", TensorProto.FLOAT, + scan_ins2_sh) + + final_state = helper.make_tensor_value_info("final_state", + TensorProto.FLOAT, [2, 2]) + scan_outs1_sh = [2, 2, 2] + scan_outs1_sh[scan_output_axes[0]] = 3 + scan_outs1 = helper.make_tensor_value_info("scan_outs1", TensorProto.FLOAT, + scan_outs1_sh) + scan_outs2_sh = [2, 2] + scan_outs2_sh[scan_output_axes[1]] = 3 + scan_outs2 = helper.make_tensor_value_info("scan_outs2", TensorProto.FLOAT, + scan_outs2_sh) + node = helper.make_node( + "Scan", + inputs=["init_state", "scan_ins1", "scan_ins2"], + outputs=["final_state", "scan_outs1", "scan_outs2"], + num_scan_inputs=2, + scan_input_axes=scan_input_axes, + scan_input_directions=scan_input_directions, + scan_output_axes=scan_output_axes, + scan_output_directions=scan_output_directions, + body=scan_body, + ) + + return ([node], [init_state, scan_ins1, + scan_ins2], [final_state, scan_outs1, scan_outs2]) + + +@onnx_test() +def scan_test1(): + return scan_test() + + +@onnx_test() +def scan_test2(): + return scan_test(scan_output_directions=[1, 0]) + + +@onnx_test() +def scan_test3(): + return scan_test(scan_output_axes=[1, -1]) + + +@onnx_test() +def scan_test4(): + return scan_test(scan_input_directions=[1, 0]) + + +@onnx_test() +def scan_test5(): + return scan_test(scan_input_axes=[2, -1]) + + +@onnx_test() +def scan_test6(): + return scan_test(scan_input_axes=[-2, 0], + scan_input_directions=[0, 1], + scan_output_directions=[1, 1], + scan_output_axes=[2, 1]) + + +@onnx_test() +def scan_test7(): + sum_in = helper.make_tensor_value_info("sum_in", TensorProto.FLOAT, [2, 2]) + scan_in = helper.make_tensor_value_info("scan_in", TensorProto.FLOAT, + [2, 2]) + sum_out = helper.make_tensor_value_info("sum_out", TensorProto.FLOAT, + [2, 2]) + scan_out = helper.make_tensor_value_info("scan_out", TensorProto.FLOAT, + [2, 2]) + add1 = helper.make_node("Add", + inputs=["sum_in", "scan_in"], + outputs=["add1_out"]) + add2 = helper.make_node("Add", + inputs=["add1_out", "scan_in"], + outputs=["sum_out"]) + id = helper.make_node("Identity", inputs=["scan_in"], outputs=["scan_out"]) + + scan_body = helper.make_graph([add1, add2, id], "scan_body", + [sum_in, scan_in], [sum_out, scan_out]) + + init_state = helper.make_tensor_value_info("init_state", TensorProto.FLOAT, + [2, 2]) + scan_ins_sh = [3, 2, 2] + scan_ins = helper.make_tensor_value_info("scan_ins", TensorProto.FLOAT, + scan_ins_sh) + + final_state = helper.make_tensor_value_info("final_state", + TensorProto.FLOAT, [2, 2]) + scan_outs_sh = [3, 2, 2] + scan_outs = helper.make_tensor_value_info("scan_outs", TensorProto.FLOAT, + scan_outs_sh) + node = helper.make_node( + "Scan", + inputs=["init_state", "scan_ins"], + outputs=["final_state", "scan_outs"], + num_scan_inputs=1, + scan_input_axes=[0], + scan_input_directions=[0], + scan_output_axes=[0], + scan_output_directions=[0], + body=scan_body, + ) + + return ([node], [init_state, scan_ins], [final_state, scan_outs]) + + +def scan_negative_test(scan_input_axes=[0], + scan_input_directions=[0], + scan_output_axes=[0], + scan_output_directions=[0]): + sum_in = helper.make_tensor_value_info("sum_in", TensorProto.FLOAT, [2, 2]) + scan_in = helper.make_tensor_value_info("scan_in", TensorProto.FLOAT, + [2, 2]) + sum_out = helper.make_tensor_value_info("sum_out", TensorProto.FLOAT, + [2, 2]) + scan_out = helper.make_tensor_value_info("scan_out", TensorProto.FLOAT, + [2, 2]) + add = helper.make_node("Add", + inputs=["sum_in", "scan_in"], + outputs=["sum_out"]) + id = helper.make_node("Identity", inputs=["sum_out"], outputs=["scan_out"]) + scan_body = helper.make_graph([add, id], "scan_body", [sum_in, scan_in], + [sum_out, scan_out]) + + init_state = helper.make_tensor_value_info("init_state", TensorProto.FLOAT, + [2, 2]) + scan_ins = helper.make_tensor_value_info("scan_ins", TensorProto.FLOAT, + [3, 2, 2]) + + final_state = helper.make_tensor_value_info("final_state", + TensorProto.FLOAT, [2, 2]) + scan_outs = helper.make_tensor_value_info("scan_outs", TensorProto.FLOAT, + [3, 2, 2]) + node = helper.make_node( + "Scan", + inputs=["init_state", "scan_ins"], + outputs=["final_state", "scan_outs"], + num_scan_inputs=1, + scan_input_axes=scan_input_axes, + scan_input_directions=scan_input_directions, + scan_output_axes=scan_output_axes, + scan_output_directions=scan_output_directions, + body=scan_body, + ) + + return ([node], [init_state, scan_ins], [final_state, scan_outs]) + + +@onnx_test() +def scan_invalid_input_axes_len_test(): + return scan_negative_test(scan_input_axes=[0, 0]) + + +@onnx_test() +def scan_invalid_input_dirs_len_test(): + return scan_negative_test(scan_input_directions=[0, 0]) + + +@onnx_test() +def scan_invalid_output_axes_len_test(): + return scan_negative_test(scan_output_axes=[0, 0]) + + +@onnx_test() +def scan_invalid_output_dirs_len_test(): + return scan_negative_test(scan_output_directions=[0, 0]) + + +@onnx_test() +def scan_invalid_input_axes_vals_test(): + return scan_negative_test(scan_input_axes=[3]) + + +@onnx_test() +def scan_invalid_input_dirs_vals_test(): + return scan_negative_test(scan_input_directions=[2]) + + +@onnx_test() +def scan_invalid_output_axes_vals_test(): + return scan_negative_test(scan_output_axes=[-4]) + + +@onnx_test() +def scan_invalid_output_dirs_vals_test(): + return scan_negative_test(scan_output_directions=[-1]) + + +@onnx_test() +def scan_arg_count_mismatch_test(): + sum_in = helper.make_tensor_value_info("sum_in", TensorProto.FLOAT, [2, 2]) + scan_in1 = helper.make_tensor_value_info("scan_in1", TensorProto.FLOAT, + [2, 2]) + sum_out = helper.make_tensor_value_info("sum_out", TensorProto.FLOAT, + [2, 2]) + scan_out = helper.make_tensor_value_info("scan_out", TensorProto.FLOAT, + [2, 2]) + add = helper.make_node("Add", + inputs=["sum_in", "scan_in1"], + outputs=["sum_out"]) + id = helper.make_node("Identity", inputs=["sum_out"], outputs=["scan_out"]) + scan_body = helper.make_graph([add, id], "scan_body", [sum_in, scan_in1], + [sum_out, scan_out]) + + init_state = helper.make_tensor_value_info("init_state", TensorProto.FLOAT, + [2, 2]) + scan_ins1 = helper.make_tensor_value_info("scan_ins1", TensorProto.FLOAT, + [3, 2, 2]) + scan_ins2 = helper.make_tensor_value_info("scan_ins2", TensorProto.FLOAT, + [2, 3, 2]) + + final_state = helper.make_tensor_value_info("final_state", + TensorProto.FLOAT, [2, 2]) + scan_outs = helper.make_tensor_value_info("scan_outs", TensorProto.FLOAT, + [3, 2, 2]) + node = helper.make_node( + "Scan", + inputs=["init_state", "scan_ins1", "scan_ins2"], + outputs=["final_state", "scan_outs"], + num_scan_inputs=2, + scan_input_axes=[0, 0], + scan_input_directions=[0, 0], + scan_output_axes=[0], + scan_output_directions=[0], + body=scan_body, + ) + return ([node], [init_state, scan_ins1, + scan_ins2], [final_state, scan_outs]) + + +@onnx_test() +def scan_input_axes_lens_mismatch_test(): + sum_in = helper.make_tensor_value_info("sum_in", TensorProto.FLOAT, [2, 2]) + scan_in1 = helper.make_tensor_value_info("scan_in1", TensorProto.FLOAT, + [2, 2]) + scan_in2 = helper.make_tensor_value_info("scan_in2", TensorProto.FLOAT, + [2, 2]) + sum_out = helper.make_tensor_value_info("sum_out", TensorProto.FLOAT, + [2, 2]) + scan_out = helper.make_tensor_value_info("scan_out", TensorProto.FLOAT, + [2, 2]) + add = helper.make_node("Add", + inputs=["sum_in", "scan_in1"], + outputs=["sum_out"]) + id = helper.make_node("Identity", inputs=["sum_out"], outputs=["scan_out"]) + scan_body = helper.make_graph([add, id], "scan_body", + [sum_in, scan_in1, scan_in2], + [sum_out, scan_out]) + + init_state = helper.make_tensor_value_info("init_state", TensorProto.FLOAT, + [2, 2]) + scan_ins1 = helper.make_tensor_value_info("scan_ins1", TensorProto.FLOAT, + [3, 2, 2]) + scan_ins2 = helper.make_tensor_value_info("scan_ins2", TensorProto.FLOAT, + [2, 3, 2]) + + final_state = helper.make_tensor_value_info("final_state", + TensorProto.FLOAT, [2, 2]) + scan_outs = helper.make_tensor_value_info("scan_outs", TensorProto.FLOAT, + [3, 2, 2]) + node = helper.make_node( + "Scan", + inputs=["init_state", "scan_ins1", "scan_ins2"], + outputs=["final_state", "scan_outs"], + num_scan_inputs=2, + scan_input_axes=[0, 0], + scan_input_directions=[0, 0], + scan_output_axes=[0], + scan_output_directions=[0], + body=scan_body, + ) + return ([node], [init_state, scan_ins1, + scan_ins2], [final_state, scan_outs]) + + +@onnx_test() +def scan_arg_shapes_mismatch_test(): + sum_in = helper.make_tensor_value_info("sum_in", TensorProto.FLOAT, [2, 2]) + scan_in1 = helper.make_tensor_value_info("scan_in1", TensorProto.FLOAT, + [2, 2]) + scan_in2 = helper.make_tensor_value_info("scan_in2", TensorProto.FLOAT, + [2, 2]) + sum_out = helper.make_tensor_value_info("sum_out", TensorProto.FLOAT, + [2, 2]) + scan_out = helper.make_tensor_value_info("scan_out", TensorProto.FLOAT, + [2, 2]) + add = helper.make_node("Add", + inputs=["sum_in", "scan_in1"], + outputs=["sum_out"]) + id = helper.make_node("Identity", inputs=["sum_out"], outputs=["scan_out"]) + scan_body = helper.make_graph([add, id], "scan_body", + [sum_in, scan_in1, scan_in2], + [sum_out, scan_out]) + + init_state = helper.make_tensor_value_info("init_state", TensorProto.FLOAT, + [2, 2]) + scan_ins1 = helper.make_tensor_value_info("scan_ins1", TensorProto.FLOAT, + [3, 2, 2]) + scan_ins2 = helper.make_tensor_value_info("scan_ins2", TensorProto.FLOAT, + [3, 2]) + + final_state = helper.make_tensor_value_info("final_state", + TensorProto.FLOAT, [2, 2]) + scan_outs = helper.make_tensor_value_info("scan_outs", TensorProto.FLOAT, + [3, 2, 2]) + node = helper.make_node( + "Scan", + inputs=["init_state", "scan_ins1", "scan_ins2"], + outputs=["final_state", "scan_outs"], + num_scan_inputs=2, + scan_input_axes=[0, 0], + scan_input_directions=[0, 0], + scan_output_axes=[0], + scan_output_directions=[0], + body=scan_body, + ) + return ([node], [init_state, scan_ins1, + scan_ins2], [final_state, scan_outs]) diff --git a/test/onnx/parse/scan_test.cpp b/test/onnx/parse/scan_test.cpp new file mode 100644 index 00000000000..5ed5e980b71 --- /dev/null +++ b/test/onnx/parse/scan_test.cpp @@ -0,0 +1,138 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include + +TEST_CASE(scan_test) +{ + namespace mgx = migraphx; + migraphx::program prog; + auto* mm = prog.get_main_module(); + auto init_state = mm->add_parameter("init_state", mgx::shape{mgx::shape::float_type, {2, 2}}); + auto scan_ins1 = mm->add_parameter("scan_ins1", mgx::shape{mgx::shape::float_type, {2, 3, 2}}); + auto scan_ins2 = mm->add_parameter("scan_ins2", mgx::shape{mgx::shape::float_type, {3, 1}}); + + auto* body = prog.create_module("Scan_3_scan"); + auto iter = body->add_parameter("iter", mgx::shape{mgx::shape::int64_type}); + auto cond = body->add_parameter("cond", mgx::shape{mgx::shape::bool_type}); + auto sum_in = body->add_parameter("state_var0", mgx::shape{mgx::shape::float_type, {2, 2}}); + + auto scan_in2 = body->add_instruction( + mgx::make_op("scan_slice", {{"axis", 0}, {"direction", 1}}), scan_ins2, iter); + scan_in2 = body->add_instruction(mgx::make_op("squeeze", {{"axes", {0}}}), scan_in2); + auto scan_in1 = body->add_instruction( + mgx::make_op("scan_slice", {{"axis", 1}, {"direction", 0}}), scan_ins1, iter); + scan_in1 = body->add_instruction(mgx::make_op("squeeze", {{"axes", {1}}}), scan_in1); + + auto add1 = mgx::add_common_op(*body, mgx::make_op("add"), {sum_in, scan_in1}); + auto add2 = mgx::add_common_op(*body, mgx::make_op("add"), {add1, scan_in2}); + auto id = body->add_instruction(mgx::make_op("identity"), add2); + auto reduce_sum = body->add_instruction(mgx::make_op("reduce_sum", {{"axes", {0}}}), add2); + reduce_sum = body->add_instruction(mgx::make_op("squeeze", {{"axes", {0}}}), reduce_sum); + body->add_return({cond, add2, id, reduce_sum}); + + auto iter_lit = mm->add_literal(mgx::literal{mgx::shape{mgx::shape::int64_type}, {3}}); + auto cond_lit = mm->add_literal(mgx::literal{mgx::shape{mgx::shape::bool_type}, {true}}); + auto loop = mm->add_instruction( + mgx::make_op("loop", {{"max_iterations", 3}, {"scan_output_directions", {1, 1}}}), + {iter_lit, cond_lit, init_state}, + {body}); + + auto final_state = mm->add_instruction(mgx::make_op("get_tuple_elem", {{"index", 0}}), loop); + auto scan_outs1 = mm->add_instruction(mgx::make_op("get_tuple_elem", {{"index", 1}}), loop); + scan_outs1 = + mm->add_instruction(mgx::make_op("transpose", {{"permutation", {1, 2, 0}}}), scan_outs1); + auto scan_outs2 = mm->add_instruction(mgx::make_op("get_tuple_elem", {{"index", 2}}), loop); + scan_outs2 = + mm->add_instruction(mgx::make_op("transpose", {{"permutation", {1, 0}}}), scan_outs2); + mm->add_return({final_state, scan_outs1, scan_outs2}); + + auto prog_gold = read_onnx("scan_test6.onnx"); + EXPECT(prog == prog_gold); +} + +TEST_CASE(scan_invalid_input_axes_len_test) +{ + EXPECT(test::throws( + [] { read_onnx("scan_invalid_input_axes_len_test.onnx"); }, "scan_input_axes")); +} + +TEST_CASE(scan_invalid_input_dirs_len_test) +{ + EXPECT(test::throws( + [] { read_onnx("scan_invalid_input_dirs_len_test.onnx"); }, "scan_input_directions")); +} + +TEST_CASE(scan_invalid_output_axes_len_test) +{ + EXPECT(test::throws( + [] { read_onnx("scan_invalid_output_axes_len_test.onnx"); }, "scan_output_axes")); +} + +TEST_CASE(scan_invalid_output_dirs_len_test) +{ + EXPECT(test::throws( + [] { read_onnx("scan_invalid_output_dirs_len_test.onnx"); }, "scan_output_directions")); +} + +TEST_CASE(scan_invalid_input_axes_vals_test) +{ + EXPECT(test::throws( + [] { read_onnx("scan_invalid_input_axes_vals_test.onnx"); }, "scan_input_axes")); +} + +TEST_CASE(scan_invalid_input_dirs_vals_test) +{ + EXPECT(test::throws( + [] { read_onnx("scan_invalid_input_dirs_vals_test.onnx"); }, "scan_input_directions")); +} + +TEST_CASE(scan_invalid_output_axes_vals_test) +{ + EXPECT(test::throws( + [] { read_onnx("scan_invalid_output_axes_vals_test.onnx"); }, "scan_output_axes")); +} + +TEST_CASE(scan_invalid_output_dirs_vals_test) +{ + EXPECT(test::throws( + [] { read_onnx("scan_invalid_output_dirs_vals_test.onnx"); }, "scan_output_directions")); +} + +TEST_CASE(scan_arg_count_mismatch_test) +{ + EXPECT(test::throws([] { read_onnx("scan_arg_count_mismatch_test.onnx"); })); +} + +TEST_CASE(scan_arg_shapes_mismatch_test) +{ + EXPECT(test::throws([] { read_onnx("scan_arg_shapes_mismatch_test.onnx"); })); +} + +TEST_CASE(scan_input_axes_lens_mismatch_test) +{ + EXPECT(test::throws([] { read_onnx("scan_input_axes_lens_mismatch_test.onnx"); })); +} diff --git a/test/onnx/scan_arg_count_mismatch_test.onnx b/test/onnx/scan_arg_count_mismatch_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..c14171c22688671bf490902a4590826617a176c8 GIT binary patch literal 634 zcmaJ;OHRWu5N$qCcc^G&LHol1q{^kyDwx5}65k_hZAS1;99yL`W(#STuabMKE4N zeDcgVeQPPWr4*$UTl-15``-rGagiRvF9UAjSd=+o98l*J;CWq9{H=ALrr<&U#x12r z&|{i$$>gDPuk9kY_rj;JL*#03iUG8acWr^NB$aCT>Z5B_q!U< zH7mpPU097ZLO(LGo9$+mG54?GTtiZ}&l1TBKHwMNG`MQ0>^f;J^kA(6x>CXaZ!OyB w8=DRi-HeM1Fa}2jw^X57LwFh?)ChI54Jr2E&%UQA_Ifpx(DXxt2h%zG1xV$gh5!Hn literal 0 HcmV?d00001 diff --git a/test/onnx/scan_arg_shapes_mismatch_test.onnx b/test/onnx/scan_arg_shapes_mismatch_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..1226c00bb68a4642bcf4141a7b0ed6b3cc93176c GIT binary patch literal 660 zcmah{yH3L}6s?m$-AhFyBT9G#P{(Y^$W+*v*|J#918RvlRqaajZ}J8F4IjYHN8r4g z5P=vXIX(~joNGKGu8f(WZ1Y0pimg*gGcGfx;uTe?RCiw%yqf@Aa7CqJDuqcU;6nP{ zG9X9KT(H-gLTXCBRdQiHX?yPzffeP+F8Dxj3x{%>Q7!;CPCglzJDR^R?~?>vX#Yt~ zDG~gbq(X7EYurn_&}qHMF-$Qc7z}KbUwF}T{~HU2y6C(+dmxKiOMRQ6tUw3e9A-kN zXfmyJc*iJ^;W;qt_9%Nx<#Mzf6=>{T!)Xsm_$G}N&&3`;1G~cYf!eN-mO>ZieW0(@ z@INCBulnO0CJAr4srrO~ty`VyM$LV*mmQ&BUqy&q`0)^i6}IL&k^Ig4!LU!h0m;Xv Aod5s; literal 0 HcmV?d00001 diff --git a/test/onnx/scan_input_axes_lens_mismatch_test.onnx b/test/onnx/scan_input_axes_lens_mismatch_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..1ddb7b6cf5511f1a639ec0bb0717273ef0ebffcc GIT binary patch literal 674 zcma)4%T9za6rBP(%yr_B8b>suh9x^#x*FZNcE^VBsIkdVj4g@tH++GA;|I(~xVJny zR1y;xB+&EfJr{d{eV(T*q3o>|N%}!}vL%csCFNx*^4CNVF7CfA_%s11qaxuV6$BQ8 ziihKCgil^5OSe4*_mrv@+_#=(MzBX(I&4U{_*N}p~lU&e>afM%i)8R5ut*94zu+{-xDSiL5 z*Z5lB>2#dGdpth=xOe@vo|gT)|xcIVr(()>M!BqZ}11~k6EAu zY)m{5!uRmLU;DmzO77%76Db$(S;47(zg8;EKA21kCQ`+udie6-(*@vyt5m8?F~CM6 z$=Nd(S<#Kjv>lg~l56ityT(TZUS86}{WpRYoXYA&F96@DiQf$L?`!vwQiwF@zDeib z5WJ&ID6S61ym3py-in;V3?qWU!A|oe*Mqr*@&*}u8{L=->4mmqjZWeP%!O8(Q|o+W zh-&1-A-Tb3xPgJh{V2s?;wd`brXfe+xvcupK8 zSPTr7tHr%;i}eZ6+EJC@=YKpD!yNT5m>^Js-JNJgjUT$1K{;aSH$qEOgE$@4E3@<^!HwL_5)SboBsd+ literal 0 HcmV?d00001 diff --git a/test/onnx/scan_invalid_input_dirs_len_test.onnx b/test/onnx/scan_invalid_input_dirs_len_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..05ea46095969cf7e489968f8cdb51eb7d76221b3 GIT binary patch literal 592 zcmah`u};G<6vRm&?xTXph*A(~QOAsoOu)v>mc>eHBUuU#!VaP~e*-I;a%j-d}i zYOCzVn>)9BzO!OSFr|c12x#C|e%_mpsM1m1eoS5X58=SEV@vk4860a=`d4EiX zIK0Pr@h;8nUQ6O^0jC2Tvy^MAYK?5ipFy(aahwag7p5>95-TNW#+gJh{V2s?;wd`brXfe+xvcy3%J zL<|g;mcmd4(GhVT&pD6On8Hn$}(FEaf2 zN$Gsk!hE!YtJ<1H_=vXYH-WIIqJ z>hKN|`Fm+r2MvkR8JzcU46|HYRcW)Lj=g|*$Ko)Tbtg+;(kFU4=lYMLvRS{R;ZL&b gUTqv=%x#RB1^euZ1b&q{3~3Hsz5EaTJ&n(P09Q?%00000 literal 0 HcmV?d00001 diff --git a/test/onnx/scan_invalid_output_axes_len_test.onnx b/test/onnx/scan_invalid_output_axes_len_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..0b649d7cb6563103dc7a9de659ba5f7ef8f01ac9 GIT binary patch literal 594 zcma)3u};G<6s(g#+(!kG5mk#CDmr##WGZaTY+0Oohw>$Ni|pc_A`Qy5I{(wsnWUTfSpK@ zlP4kbvV+OAg^h>WM32UH`;8zfDt5T}La>HoRX>|a5IZ%AEnuE+!aK$wHAHtOoxf%9 z1LIN)eQ@U1EnDHY5exwi?8%mcxknY2a`rxT=@v2rea8kJZRc<%jne(vHBLq|IEE!lm##*4uHCUAFvXHtJfja<8`^4bN5XUgrz4$<=W4A*Emh}ULe#prtj~J&^=%*IhG`75@d*d;pUSeu lxGwij*~O?p1`#GW#q|6>#WIFpwYqUFVJMpYF@NhtCqI&irV#)D literal 0 HcmV?d00001 diff --git a/test/onnx/scan_invalid_output_dirs_len_test.onnx b/test/onnx/scan_invalid_output_dirs_len_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..0739fab0537f8bbb44d2e6fb0b0d706e69e80cd1 GIT binary patch literal 594 zcma)(u};G<6h(0oi2JA@GNKgJ649|EBNMPOvt_Z8+DMjygRq0>%9mu~7x(~vjOWBv zLd3w}!T08#7vD!G?8u&;4QHkyK7l9NGO#O_LAogh%I{`mmg?B>05{>T8I=*x0 zgOJ)PyY+BK%jP>2JAx@Cj6y&Ice3*yKC%jDdHmSA@DIX)V`4-0*E2ZQnDl>*{g^pv zNZ|Bd`VQqHnSz5agxUF24vY;7C``I0{0iW}yY}dhFi|0dpjpGi&@U zLmxz;j5J$kZr##pYsF4rObMglQP1x5vNIo0xuu-_kh*do!hu7_nj9t*I8~^0uf~A% zl+?ttc!%-qz0h|DH3^d$ocC}HvsfEhD#ug5vEhFzOK1JM khJQ-0dj&FxFu^gV=eH^5bNE&3Fs=o3MYBKjD=#|x0pJp*761SM literal 0 HcmV?d00001 diff --git a/test/onnx/scan_test1.onnx b/test/onnx/scan_test1.onnx new file mode 100644 index 0000000000000000000000000000000000000000..ef11c728bdaa5ced4d7c05a3ab30cd3bf389ef4c GIT binary patch literal 793 zcmZuvOHRWu6eMn-?rTLWBTCArg)FmkmPkmzjxF6~6D75gET~%%N1}TU(sOVM4!~h} zX}f`}v@&60A|s)*L6uH;{pz#so6^xozETJ* zQj_q;yYJdl6!!iI*zrQF{9gihaFV3+P-&orEN3(lf``ax5hRXbxS zzgcb%5$@6#A+=Fv1rG;s-|f6gkZ#ug&ir!+8S{>Ql~)T4+oo$44;O;AjQqwx|;rHlaOkvrEBs zvq!r(1(Ui?oG}<+NIeXxRRvrog4OdI>}EQnR1b%D)nMxfg9cuxd4JuRRH?INfzR5KHC{{{= zMW&T+ya%p5MPcucfSoM(D)=RE2PZ0m+v@{;7h=ifTsb%(rDpCq zbSlH_6nfNR6bL2x6JFNt{5@<~V*vx8$g`1SC8NgIj$Cq8rkh`~tCw6fpeab!*0DE) zeu~+YaW3&EI5;Ps1N~+e71tK`j^W%uxcn3|Ef*=KQQ#vOY;d%Ka$D313{A-3S{ts} zrSRJ9(e6#*q^Xn77z{9^9){Ga1FjOm>IDWnU$BxJ*lOTznzvAt8c)DZ9#BuvTy@Ew JJb5OAvwxTh@{j zEY~`AYiG|kXGrY)5HO=9-+12yZlE2n7lBkjb6S_tkaIV{X}k5nIoMDZ$|jF>lCiR9 z48=FgjgB(_`YObV%SCKqf#jTt zX5T6Vvm@wJgHa%qB(Hy7s&n_SQHgmBfa08qGe&e)>1xnL7rDF1CQZKRv;ymdWNREc zBN!<3Cbx5imtbL?I2H`5T99pM<@`yNE0-}CqFa4*t5rH(Sc4h#YST%B5naJf6q~kQLzZsb1;bgf VCwG|sAwqG*COh650D9e1~7*2lz4G z(kF#QllGo_PR>0&w8KEn!#L1F>HCN87W|k1GL||}I@AI>J5eUowck*;I{ z%e7W+<=wN*845dp1k7m3H{LISJ7}r(B9JlAOx9&IfR+D;$nyAUfb7s|o`$utws zzLgthr_iGYqd+K09{;j%=N@3g0`nLEMV^TqBN~;yYI4z4mTq!MT`#&Q!KxtH8Yj*W z`Y~pc#ks&zG%YhG6J;QFNaQ2@2mZw$@B{o8 zy?vKrlcv*q?ma#C^w2hgXb~u1i%8$zeKz6C1dvK1RydSePGFJ{`&_}nrxy$%$c)2y0LANq(#9(+0dr_?0dIM>?aRm%# V$==An5TdwzlRf&Rg~H!n>3lZ=iZZZ&m7uj5G?}bYZ2+|oA)Msm;h2q?MFJ$0@?{GRpgXMmvx?`3SQEd zj9|6VkyAOlmN`da`-^}XuK3pdBrt_$v|0L60nKO~Mgz_rAE&Lx9p_*{Q7D>R){4i9 zo-yQimdZoKvFHOAN=v&X?MU3-b6Gz*cmOm2lQEpHJ8h1hy{{y z#+tohE|?ubml})$p(OsjvqGKo2OAWa%K#|OnK)yFqf%FmElcfp*NYGYkY(u)`>lZURm>#ZNGoA&qx;^!%3xu%U7|`a;4CgnGqW^f_{RdCge>y zND570P=zY9M$%9Xm3L7MRDC&bQs&BG4EpF+H*>j>IGtOA8FVYtF=&LBuouO;tyhqy Z8&|+^?z6}5*a%TvzR4bag5Hz<@h>>M%F_S< literal 0 HcmV?d00001 diff --git a/test/onnx/scan_test7.onnx b/test/onnx/scan_test7.onnx new file mode 100644 index 0000000000000000000000000000000000000000..ccda6f8c7a096e3ff270804df8e32b966850c5d4 GIT binary patch literal 582 zcmZuuy-ve07{o~+?necY5v8DppN`pJL1F?nX0|L*QX9zB1om%HzL-*<;c1mvO?m10ue-hC(V6$5a=RV7ue7~r1eLX!O#F6!3-CZis1mrBl) zXRR-O5F|}YS2t?}m#`<7Zxt7iS~bBnrPr07pT`fB!WC=*HG@+A!ogv|CKzP-vyS>|IANMIDL^ z?y#!sf84zw%@`Onf5H`K^^Rzp-%T^4#l?K&As4 Z>u};tNKFoZnhk@x0M|A7NB +#include +#include +#include +#include +#include +#include +#include +#include + +static migraphx::shape make_shape(const std::vector& lens) +{ + return migraphx::shape{migraphx::shape::float_type, lens}; +} + +auto scan_test(const std::string& test_file, + migraphx::shape scan_ins1_sh, + migraphx::shape scan_ins2_sh) +{ + auto prog = read_onnx(test_file); + prog.compile(migraphx::make_target("ref")); + + migraphx::parameter_map pm; + + migraphx::shape init_state_sh{migraphx::shape::float_type, {2, 2}}; + std::vector init_state(init_state_sh.elements(), 0); + pm["init_state"] = migraphx::argument(init_state_sh, init_state.data()); + + std::vector scan_ins1(scan_ins1_sh.elements()); + std::iota(scan_ins1.begin(), scan_ins1.end(), 1); + pm["scan_ins1"] = migraphx::argument(scan_ins1_sh, scan_ins1.data()); + + std::vector scan_ins2(scan_ins2_sh.elements()); + std::iota(scan_ins2.begin(), scan_ins2.end(), 0); + pm["scan_ins2"] = migraphx::argument(scan_ins2_sh, scan_ins2.data()); + + auto result = prog.eval(pm); + EXPECT(result.size() == 3); + return std::make_tuple(result[0], result[1], result[2]); +} + +TEST_CASE(scan_test1) +{ + auto [final_state, scan_out1, scan_out2] = + scan_test("scan_test1.onnx", make_shape({3, 2, 2}), make_shape({3, 1})); + + EXPECT(final_state.get_shape() == make_shape({2, 2})); + std::vector final_state_gold{18, 21, 24, 27}; + EXPECT(final_state.to_vector() == final_state_gold); + + EXPECT(scan_out1.get_shape() == make_shape({3, 2, 2})); + std::vector scan_out1_gold{1, 2, 3, 4, 7, 9, 11, 13, 18, 21, 24, 27}; + EXPECT(scan_out1.to_vector() == scan_out1_gold); + + EXPECT(scan_out2.get_shape() == make_shape({3, 2})); + std::vector scan_out2_gold{4, 6, 18, 22, 42, 48}; + EXPECT(scan_out2.to_vector() == scan_out2_gold); +} + +TEST_CASE(scan_test2) +{ + auto [final_state, scan_out1, scan_out2] = + scan_test("scan_test2.onnx", make_shape({3, 2, 2}), make_shape({3, 1})); + + EXPECT(final_state.get_shape() == make_shape({2, 2})); + std::vector final_state_gold{18, 21, 24, 27}; + EXPECT(final_state.to_vector() == final_state_gold); + + EXPECT(scan_out1.get_shape() == make_shape({3, 2, 2})); + std::vector scan_out1_gold{18, 21, 24, 27, 7, 9, 11, 13, 1, 2, 3, 4}; + EXPECT(scan_out1.to_vector() == scan_out1_gold); + + EXPECT(scan_out2.get_shape() == make_shape({3, 2})); + std::vector scan_out2_gold{4, 6, 18, 22, 42, 48}; + EXPECT(scan_out2.to_vector() == scan_out2_gold); +} + +TEST_CASE(scan_test3) +{ + auto [final_state, scan_out1, scan_out2] = + scan_test("scan_test3.onnx", make_shape({3, 2, 2}), make_shape({3, 1})); + + EXPECT(final_state.get_shape() == make_shape({2, 2})); + std::vector final_state_gold{18, 21, 24, 27}; + EXPECT(final_state.to_vector() == final_state_gold); + + EXPECT(scan_out1.get_shape() == make_shape({2, 3, 2})); + std::vector scan_out1_gold{1, 2, 7, 9, 18, 21, 3, 4, 11, 13, 24, 27}; + EXPECT(scan_out1.to_vector() == scan_out1_gold); + + EXPECT(scan_out2.get_shape() == make_shape({2, 3})); + std::vector scan_out2_gold{4, 18, 42, 6, 22, 48}; + EXPECT(scan_out2.to_vector() == scan_out2_gold); +} + +TEST_CASE(scan_test4) +{ + auto [final_state, scan_out1, scan_out2] = + scan_test("scan_test4.onnx", make_shape({3, 2, 2}), make_shape({3, 1})); + + EXPECT(final_state.get_shape() == make_shape({2, 2})); + std::vector final_state_gold{18, 21, 24, 27}; + EXPECT(final_state.to_vector() == final_state_gold); + + EXPECT(scan_out1.get_shape() == make_shape({3, 2, 2})); + std::vector scan_out1_gold{9, 10, 11, 12, 15, 17, 19, 21, 18, 21, 24, 27}; + EXPECT(scan_out1.to_vector() == scan_out1_gold); + + EXPECT(scan_out2.get_shape() == make_shape({3, 2})); + std::vector scan_out2_gold{20, 22, 34, 38, 42, 48}; + EXPECT(scan_out2.to_vector() == scan_out2_gold); +} + +TEST_CASE(scan_test5) +{ + auto [final_state, scan_out1, scan_out2] = + scan_test("scan_test5.onnx", make_shape({2, 2, 3}), make_shape({1, 3})); + + EXPECT(final_state.get_shape() == make_shape({2, 2})); + std::vector final_state_gold{9, 18, 27, 36}; + EXPECT(final_state.to_vector() == final_state_gold); + + EXPECT(scan_out1.get_shape() == make_shape({3, 2, 2})); + std::vector scan_out1_gold{1, 4, 7, 10, 4, 10, 16, 22, 9, 18, 27, 36}; + EXPECT(scan_out1.to_vector() == scan_out1_gold); + + EXPECT(scan_out2.get_shape() == make_shape({3, 2})); + std::vector scan_out2_gold{8, 14, 20, 32, 36, 54}; + EXPECT(scan_out2.to_vector() == scan_out2_gold); +} + +TEST_CASE(scan_test6) +{ + auto [final_state, scan_out1, scan_out2] = + scan_test("scan_test6.onnx", make_shape({2, 3, 2}), make_shape({3, 1})); + + EXPECT(final_state.get_shape() == make_shape({2, 2})); + std::vector final_state_gold{12, 15, 30, 33}; + EXPECT(final_state.to_vector() == final_state_gold); + + EXPECT(scan_out1.get_shape() == make_shape({2, 2, 3})); + std::vector scan_out1_gold{12, 7, 3, 15, 9, 4, 30, 19, 9, 33, 21, 10}; + EXPECT(scan_out1.to_vector() == scan_out1_gold); + + EXPECT(scan_out2.get_shape() == make_shape({2, 3})); + std::vector scan_out2_gold{42, 26, 12, 48, 30, 14}; + EXPECT(scan_out2.to_vector() == scan_out2_gold); +} + +TEST_CASE(scan_test7) +{ + auto prog = read_onnx("scan_test7.onnx"); + prog.compile(migraphx::make_target("ref")); + + migraphx::parameter_map pm; + + migraphx::shape init_state_sh{migraphx::shape::float_type, {2, 2}}; + std::vector init_state(init_state_sh.elements(), 0); + pm["init_state"] = migraphx::argument(init_state_sh, init_state.data()); + + migraphx::shape scan_ins_sh{migraphx::shape::float_type, {3, 2, 2}}; + std::vector scan_ins(scan_ins_sh.elements()); + std::iota(scan_ins.begin(), scan_ins.end(), 1); + pm["scan_ins"] = migraphx::argument(scan_ins_sh, scan_ins.data()); + + auto result = prog.eval(pm); + EXPECT(result.size() == 2); + + EXPECT(result[0].get_shape() == make_shape({2, 2})); + std::vector final_state_gold{30, 36, 42, 48}; + EXPECT(result[0].to_vector() == final_state_gold); + + EXPECT(result[1].get_shape() == make_shape({3, 2, 2})); + EXPECT(result[1].to_vector() == scan_ins); +} diff --git a/test/op_shape_test.cpp b/test/op_shape_test.cpp index 99723493237..5b2c00f243b 100644 --- a/test/op_shape_test.cpp +++ b/test/op_shape_test.cpp @@ -4239,6 +4239,50 @@ TEST_CASE(slice_dyn_shape5) input); } +TEST_CASE(test_scan_slice1) +{ + migraphx::shape input{migraphx::shape::float_type, {2, 3, 4}}; + migraphx::shape axis_input{migraphx::shape::int64_type}; + migraphx::shape expected{migraphx::shape::float_type, {1, 3, 4}}; + expect_shape(expected, + migraphx::make_op("scan_slice", {{"axis", 0}, {"direction", 0}}), + input, + axis_input); +} + +TEST_CASE(test_scan_slice2) +{ + migraphx::shape input{migraphx::shape::float_type, {4, 6, 5}}; + migraphx::shape axis_input{migraphx::shape::int64_type}; + migraphx::shape expected{migraphx::shape::float_type, {4, 1, 5}, {30, 5, 1}}; + expect_shape(expected, + migraphx::make_op("scan_slice", {{"axis", 1}, {"direction", 0}}), + input, + axis_input); +} + +TEST_CASE(test_scan_slice3) +{ + migraphx::shape input{migraphx::shape::float_type, {2, 5, 7}}; + migraphx::shape axis_input{migraphx::shape::int64_type}; + migraphx::shape expected{migraphx::shape::float_type, {2, 5, 1}, {35, 7, 1}}; + expect_shape(expected, + migraphx::make_op("scan_slice", {{"axis", -1}, {"direction", 0}}), + input, + axis_input); +} + +TEST_CASE(test_scan_slice4) +{ + migraphx::shape input{migraphx::shape::float_type, {2, 5, 7}}; + migraphx::shape axis_input{migraphx::shape::int64_type}; + migraphx::shape expected{migraphx::shape::float_type, {1, 5, 7}, {35, 7, 1}}; + expect_shape(expected, + migraphx::make_op("scan_slice", {{"axis", -3}, {"direction", 1}}), + input, + axis_input); +} + TEST_CASE(softmax_dyn0) { migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {3, 3}, {4, 4}, {5, 5}}}; diff --git a/test/py/onnx_backend_test.py b/test/py/onnx_backend_test.py index 3f9867b4072..914ff0ac5b3 100644 --- a/test/py/onnx_backend_test.py +++ b/test/py/onnx_backend_test.py @@ -143,7 +143,6 @@ def disabled_tests_onnx_1_7_0(backend_test): backend_test.exclude(r'test_resize_upsample_sizes_cubic_cpu') backend_test.exclude(r'test_reversesequence_batch_cpu') backend_test.exclude(r'test_reversesequence_time_cpu') - backend_test.exclude(r'test_scan9_sum_cpu') backend_test.exclude(r'test_scan_sum_cpu') backend_test.exclude(r'test_slice_cpu') backend_test.exclude(r'test_slice_end_out_of_bounds_cpu') diff --git a/test/ref/scan_slice.cpp b/test/ref/scan_slice.cpp new file mode 100644 index 00000000000..f65385bf310 --- /dev/null +++ b/test/ref/scan_slice.cpp @@ -0,0 +1,180 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +static migraphx::shape make_shape(const std::vector& lens) +{ + return migraphx::shape{migraphx::shape::int32_type, lens}; +} + +migraphx::program make_scan_slice_program(int64_t axis, int64_t direction) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape data_sh{migraphx::shape::int32_type, {2, 2, 2}}; + std::vector data(data_sh.elements()); + std::iota(data.begin(), data.end(), 0); + auto data_lit = mm->add_literal(migraphx::literal{data_sh, data}); + + migraphx::shape idx_sh{migraphx::shape::int64_type, {1}}; + auto idx_param = mm->add_parameter("idx", idx_sh); + + mm->add_instruction(migraphx::make_op("scan_slice", {{"axis", axis}, {"direction", direction}}), + data_lit, + idx_param); + + p.compile(migraphx::make_target("ref")); + + return p; +} + +TEST_CASE(scan_slice_test_1) +{ + auto p = make_scan_slice_program(0, 0); + + migraphx::parameter_map pm; + int64_t idx = 0; + migraphx::shape idx_sh{migraphx::shape::int64_type, {1}}; + pm["idx"] = migraphx::argument{idx_sh, &idx}; + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({1, 2, 2})); + EXPECT(result.to_vector() == std::vector{0, 1, 2, 3}); + + idx = 1; + result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({1, 2, 2})); + EXPECT(result.to_vector() == std::vector{4, 5, 6, 7}); +} + +TEST_CASE(scan_slice_test_2) +{ + auto p = make_scan_slice_program(1, 0); + + migraphx::parameter_map pm; + int64_t idx = 0; + migraphx::shape idx_sh{migraphx::shape::int64_type, {1}}; + pm["idx"] = migraphx::argument{idx_sh, &idx}; + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({2, 1, 2})); + EXPECT(result.to_vector() == std::vector{0, 1, 4, 5}); + + idx = 1; + result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({2, 1, 2})); + EXPECT(result.to_vector() == std::vector{2, 3, 6, 7}); +} + +TEST_CASE(scan_slice_test_3) +{ + auto p = make_scan_slice_program(2, 0); + + migraphx::parameter_map pm; + int64_t idx = 0; + migraphx::shape idx_sh{migraphx::shape::int64_type, {1}}; + pm["idx"] = migraphx::argument{idx_sh, &idx}; + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({2, 2, 1})); + EXPECT(result.to_vector() == std::vector{0, 2, 4, 6}); + + idx = 1; + result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({2, 2, 1})); + EXPECT(result.to_vector() == std::vector{1, 3, 5, 7}); +} + +TEST_CASE(scan_slice_test_4) +{ + auto p = make_scan_slice_program(-3, 0); + + migraphx::parameter_map pm; + int64_t idx = 0; + migraphx::shape idx_sh{migraphx::shape::int64_type, {1}}; + pm["idx"] = migraphx::argument{idx_sh, &idx}; + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({1, 2, 2})); + EXPECT(result.to_vector() == std::vector{0, 1, 2, 3}); + + idx = 1; + result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({1, 2, 2})); + EXPECT(result.to_vector() == std::vector{4, 5, 6, 7}); +} + +TEST_CASE(scan_slice_test_5) +{ + auto p = make_scan_slice_program(0, 1); + + migraphx::parameter_map pm; + int64_t idx = 0; + migraphx::shape idx_sh{migraphx::shape::int64_type, {1}}; + pm["idx"] = migraphx::argument{idx_sh, &idx}; + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({1, 2, 2})); + EXPECT(result.to_vector() == std::vector{4, 5, 6, 7}); + + idx = 1; + result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({1, 2, 2})); + EXPECT(result.to_vector() == std::vector{0, 1, 2, 3}); +} + +TEST_CASE(scan_slice_test_6) +{ + auto p = make_scan_slice_program(-2, 1); + + migraphx::parameter_map pm; + int64_t idx = 0; + migraphx::shape idx_sh{migraphx::shape::int64_type, {1}}; + pm["idx"] = migraphx::argument{idx_sh, &idx}; + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({2, 1, 2})); + EXPECT(result.to_vector() == std::vector{2, 3, 6, 7}); + + idx = 1; + result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({2, 1, 2})); + EXPECT(result.to_vector() == std::vector{0, 1, 4, 5}); +} + +TEST_CASE(scan_slice_test_7) +{ + auto p = make_scan_slice_program(0, 0); + + migraphx::parameter_map pm; + int64_t idx = 2; + migraphx::shape idx_sh{migraphx::shape::int64_type, {1}}; + pm["idx"] = migraphx::argument{idx_sh, &idx}; + + EXPECT(test::throws([&] { p.eval(pm); })); +} diff --git a/test/run_loop_test.cpp b/test/run_loop_test.cpp index 7c1f148c17e..c371ef59c63 100644 --- a/test/run_loop_test.cpp +++ b/test/run_loop_test.cpp @@ -156,7 +156,7 @@ struct test_loop_op cpy_args.push_back(migraphx::argument(s_cond)); cpy_args.push_back(migraphx::argument(out_shape)); // run loop - return run_loop(test_loop{max_iterations}, ctx, cpy_args, mods, run); + return run_loop(test_loop{max_iterations}, {}, ctx, cpy_args, mods, run); } }; diff --git a/test/verify/test_scan_slice.cpp b/test/verify/test_scan_slice.cpp new file mode 100644 index 00000000000..08f99938a72 --- /dev/null +++ b/test/verify/test_scan_slice.cpp @@ -0,0 +1,62 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +template +struct test_scan_slice_base : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape data_sh{migraphx::shape::int32_type, {2, 2, 2}}; + auto data_param = mm->add_parameter("data", data_sh); + migraphx::shape idx_sh{migraphx::shape::int64_type, {1}}; + auto idx_lit = mm->add_literal(migraphx::literal{idx_sh, {Idx}}); + + mm->add_instruction( + migraphx::make_op("scan_slice", {{"axis", Axis}, {"direction", Direction}}), + data_param, + idx_lit); + + return p; + } +}; + +struct test_scan_slice1 : test_scan_slice_base +{ +}; + +struct test_scan_slice2 : test_scan_slice_base +{ +}; + +struct test_scan_slice3 : test_scan_slice_base +{ +};