diff --git a/docs/dev/onnx_operators.rst b/docs/dev/onnx_operators.rst index 570205c0034..5acfd1923a5 100644 --- a/docs/dev/onnx_operators.rst +++ b/docs/dev/onnx_operators.rst @@ -705,7 +705,13 @@ Operator Support Matrix +--------------------------+-----------+-----------------+------------------------------+ | STFT | ❌ | | | +--------------------------+-----------+-----------------+------------------------------+ -| Scan | 👷 | 👷 | | +| Scan | ✅ | UINT8, UINT16, | ``identity``, | +| | | UINT32, UINT64, | ``sequence`` | +| | | INT8, INT16, | datatypes are | +| | | INT32, INT64, | not supported, | +| | | FP8, FP16, | Number of iterations has | +| | | FP32, FP64 | upper-bound | +| | | | Version 8 not supported | +--------------------------+-----------+-----------------+------------------------------+ | Scatter (deprecated) | ✅ | BOOL, UINT8, | | | | | UINT16, UINT32, | | diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index aa851cef74b..523a6aee561 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -234,6 +234,7 @@ register_migraphx_ops( rsqrt run_on_target scalar + scan_slice scatter_none scatter_add scatter_mul diff --git a/src/include/migraphx/op/loop.hpp b/src/include/migraphx/op/loop.hpp index 969c16ef7cd..a1b76181414 100644 --- a/src/include/migraphx/op/loop.hpp +++ b/src/include/migraphx/op/loop.hpp @@ -41,12 +41,14 @@ namespace op { struct loop { - int64_t max_iterations = 10; + int64_t max_iterations = 10; + std::vector scan_output_directions = {}; template static auto reflect(Self& self, F f) { - return pack(f(self.max_iterations, "max_iterations")); + return pack(f(self.max_iterations, "max_iterations"), + f(self.scan_output_directions, "scan_output_directions")); } std::string name() const { return "loop"; } @@ -97,7 +99,9 @@ struct loop void append(const std::vector& iter_state, const std::vector& concatenated_outputs, - int iter) const + const std::vector& scan_output_dirs, + int64_t curr_iter, + int64_t num_iters) const { assert(iter_state.size() == concatenated_outputs.size()); for(auto i : range(iter_state.size())) @@ -105,11 +109,14 @@ struct loop const auto& iter_stat = iter_state.at(i); const auto& scan_out = concatenated_outputs.at(i); + auto dir = scan_output_dirs.empty() ? 0 : scan_output_dirs[i]; + auto idx = (1 - dir) * curr_iter + dir * (num_iters - 1 - curr_iter); + auto* in_data = iter_stat.data(); auto* out_data = scan_out.data(); std::size_t out_size = iter_stat.get_shape().bytes(); - assert((iter + 1) * out_size <= scan_out.get_shape().bytes()); - std::copy(in_data, in_data + out_size, out_data + iter * out_size); + assert((idx + 1) * out_size <= scan_out.get_shape().bytes()); + std::copy(in_data, in_data + out_size, out_data + idx * out_size); } } @@ -153,7 +160,7 @@ struct loop cpy_args.push_back(argument(out_shape)); // run loop - return run_loop(ref_loop{max_iterations}, ctx, cpy_args, mods, run); + return run_loop(ref_loop{max_iterations}, scan_output_directions, ctx, cpy_args, mods, run); } }; diff --git a/src/include/migraphx/op/scan_slice.hpp b/src/include/migraphx/op/scan_slice.hpp new file mode 100644 index 00000000000..6bb5e56c4c7 --- /dev/null +++ b/src/include/migraphx/op/scan_slice.hpp @@ -0,0 +1,91 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#ifndef MIGRAPHX_GUARD_OPERATORS_SCAN_SLICE_HPP +#define MIGRAPHX_GUARD_OPERATORS_SCAN_SLICE_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct scan_slice : op_name +{ + int64_t axis = 0; + int64_t direction = 0; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.axis, "axis"), f(self.direction, "direction")); + } + + value attributes() const + { + value normalize_axes = value::object{}; + normalize_axes["axis"] = value::array{normalize_attribute::include_min}; + return {{"normalize_axes", normalize_axes}}; + } + + shape normalize_compute_shape(std::vector inputs) const + { + check_shapes{inputs, *this}.has(2); + auto input_shape = inputs[0]; + auto new_lens = input_shape.lens(); + new_lens[axis] = 1; + + return shape{input_shape.type(), new_lens, input_shape.strides()}; + } + + argument compute(shape output_shape, std::vector args) const + { + auto input = args[0]; + auto input_sh = input.get_shape(); + + int64_t idx; + args[1].visit([&](auto i) { idx = i.front(); }); + const auto max_idx = input_sh.lens()[axis] - 1; + if(idx > max_idx or idx < 0) + MIGRAPHX_THROW("ScanSlice: index {" + std::to_string(idx) + "} out of range [0, " + + std::to_string(max_idx) + "]"); + idx = (1 - direction) * idx + direction * (max_idx - idx); + + auto offset = idx * input_sh.strides().at(axis) * input_sh.type_size(); + return {output_shape, [=] { return input.data() + offset; }}; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/operators.hpp b/src/include/migraphx/operators.hpp index 5166d97eecf..752fe67d7a9 100644 --- a/src/include/migraphx/operators.hpp +++ b/src/include/migraphx/operators.hpp @@ -113,6 +113,7 @@ #include #include #include +#include #include #include #include diff --git a/src/include/migraphx/run_loop.hpp b/src/include/migraphx/run_loop.hpp index 859d7dfad7f..e12b82fd2ff 100644 --- a/src/include/migraphx/run_loop.hpp +++ b/src/include/migraphx/run_loop.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -24,6 +24,7 @@ #ifndef MIGRAPHX_GUARD_RTGLIB_RUN_LOOP_HPP #define MIGRAPHX_GUARD_RTGLIB_RUN_LOOP_HPP +#include #include #include #include @@ -39,6 +40,7 @@ inline namespace MIGRAPHX_INLINE_NS { template argument run_loop(const LoopModel& model, + const std::vector& scan_output_directions, T& ctx, std::vector args, const std::vector& mods, @@ -101,9 +103,13 @@ argument run_loop(const LoopModel& model, auto output_index = out_param_indices[name]; if(output_index > dep_num) { + int64_t dir = scan_output_directions.empty() + ? 0 + : scan_output_directions[output_index - dep_num - 1]; + auto idx = (1 - dir) * iter + dir * (iter_num - 1 - iter); const auto& arg = out_args.at(output_index); - assert((iter + 1) * ps.bytes() <= arg.get_shape().bytes()); - params[name] = argument(ps, arg.data() + iter * ps.bytes()); + assert((idx + 1) * ps.bytes() <= arg.get_shape().bytes()); + params[name] = argument(ps, arg.data() + idx * ps.bytes()); } else { @@ -123,7 +129,7 @@ argument run_loop(const LoopModel& model, std::copy(dep_out.begin(), dep_out.end(), out_args.begin()); std::vector mod_scan_outs(mod_args.begin() + 1 + dep_num, mod_args.end()); - model.append(mod_scan_outs, scan_outputs, iter); + model.append(mod_scan_outs, scan_outputs, scan_output_directions, iter, iter_num); } out_args.erase(out_args.begin()); diff --git a/src/onnx/parse_scan.cpp b/src/onnx/parse_scan.cpp new file mode 100644 index 00000000000..15db9e1f473 --- /dev/null +++ b/src/onnx/parse_scan.cpp @@ -0,0 +1,324 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_scan : op_parser +{ + std::vector operators() const { return {{"Scan"}}; } + + std::vector parse(const op_desc& /*opd*/, + onnx_parser& parser, + onnx_parser::node_info info, + std::vector args) const + { + if(parser.opset_version == 8) + MIGRAPHX_THROW("Scan: Opset 8 version not supported"); + + check_for_required_attributes(info, {"body", "num_scan_inputs"}); + + const auto& body_graph = info.attributes["body"].g(); + auto* body = parser.prog.create_module(info.name + "_scan"); + parser.parse_graph(body, body_graph); + + // Scan has: + // N + M inputs (N state variables, M scan inputs) + // N + K outputs (N state variables, K scan outputs) + // Same input and output counts apply for body + auto body_outs = body->get_returns(); + const auto m = info.attributes["num_scan_inputs"].i(); + const auto n = args.size() - m; + const auto k = body_outs.size() - n; + + std::vector body_params; + transform(body->get_parameter_names(), + std::back_inserter(body_params), + [&](const auto& name) { return body->get_parameter(name); }); + + if(auto num_body_params = body_params.size(); num_body_params != n + m) + MIGRAPHX_THROW("Scan: Number of inputs to body {" + std::to_string(num_body_params) + + "} does not match number of inputs to Scan {" + std::to_string(n + m) + + "}"); + + const auto scan_input_axes = parse_axes(info, "scan_input_axes", m, args.begin() + n, 0); + const auto scan_input_directions = parse_dirs(info, "scan_input_directions", m); + const auto scan_output_axes = + parse_axes(info, "scan_output_axes", k, body_outs.begin() + n, 1); + const auto scan_output_directions = parse_dirs(info, "scan_output_directions", k); + + // Check that scan axes lens are the same across all scan inputs + size_t num_iters = args[n]->get_shape().lens()[scan_input_axes[0]]; + for(auto i = 1; i < m; ++i) + if(args[n + i]->get_shape().lens()[scan_input_axes[i]] != num_iters) + MIGRAPHX_THROW( + "Scan: Lengths of scan_input_axes do not match across all scan inputs.\n" + "Scan input shapes: " + + to_string_range( + to_shapes(std::vector(args.begin() + n, args.end()))) + + "\nScan input axes: " + to_string_range(scan_input_axes)); + + if(num_iters > parser.max_loop_iterations) + MIGRAPHX_THROW("Scan: Number of required iterations {" + std::to_string(num_iters) + + "} would exceed the maximum iteration limit {" + + std::to_string(parser.max_loop_iterations) + "}"); + + // Check that state variable shapes match between the Scan node and its body attribute + for(auto i = 0; i < n; ++i) + if(args[i]->get_shape() != body_params[i]->get_shape()) + MIGRAPHX_THROW("Scan: State input " + std::to_string(i) + " shape {" + + to_string(args[i]->get_shape()) + + "} does not match corresponding body input shape {" + + to_string(body_params[i]->get_shape()) + "}"); + + // Check that the shapes of scan inputs sliced across scan input axes match the shapes of + // the body attribute scan inputs + for(auto i = 0; i < m; ++i) + { + auto node_shape = args[i + n]->get_shape(); + auto node_lens = node_shape.lens(); + node_lens.erase(node_lens.begin() + scan_input_axes[i]); + auto slice_sh = shape(node_shape.type(), std::move(node_lens)); + if(body_params[i + n]->get_shape() != slice_sh) + MIGRAPHX_THROW("Slice: Sliced scan input " + std::to_string(i) + " shape {" + + to_string(slice_sh) + + "} does not match corresponding body input shape {" + + to_string(body_params[i + n]->get_shape()) + "}"); + } + + modify_body(body, args, n, m, scan_input_axes, scan_input_directions); + + auto max_iter_lit = info.add_literal(literal{shape{shape::int64_type}, {num_iters}}); + auto cond_lit = info.add_literal(literal{shape{shape::bool_type}, {true}}); + std::vector loop_args{max_iter_lit, cond_lit}; + loop_args.insert(loop_args.end(), args.begin(), args.begin() + n); + + auto loop = + info.add_instruction(make_op("loop", + {{"max_iterations", num_iters}, + {"scan_output_directions", scan_output_directions}}), + loop_args, + {body}); + + std::vector ret; + ret.reserve(n + k); + for(auto i = 0; i < n; ++i) + ret.push_back(info.add_instruction(make_op("get_tuple_elem", {{"index", i}}), loop)); + + for(auto i = 0; i < k; ++i) + { + auto o = info.add_instruction(make_op("get_tuple_elem", {{"index", i + n}}), loop); + // Loop concatenates scan axes along axis 0 which is inserted/unsqueezed, e.g. a body + // scan output(from a single iteration) of shape {2, 2} is first expanded to {1, 2, 2}, + // and then concatenated with body scan outputs from previous iterations. For n + // iterations of the loop, this will end up producing a scan output of shape {n, 2, 2}. + // + // The scan_output_axes attribute of Scan can define an axis other than zero as the + // concatenation axis. Using the previous scenario, for a body scan output of + // shape {2,2}, with the scan output axis being 1, it is unsqueezed to {2, 1, 2}. The + // final concatenation is then of shape {2, n, 2}. + // + // Since Loop only concatenates along the unsqueezed axis 0, a transpose is necessary to + // place axis 0 in the appropriate scan_output_axis position + auto perm = make_perm_for_scan_out(o->get_shape().ndim(), scan_output_axes[i]); + ret.push_back(info.add_instruction(make_op("transpose", {{"permutation", perm}}), o)); + } + + return ret; + } + + void check_for_required_attributes(onnx_parser::node_info& info, + const std::vector& attribute_names) const + { + auto it = std::find_if( + attribute_names.cbegin(), attribute_names.cend(), [&](const std::string& name) { + return not contains(info.attributes, name); + }); + if(it != attribute_names.cend()) + MIGRAPHX_THROW("Scan: " + *it + " attribute required"); + } + + std::vector parse_vector_attribute(onnx_parser::node_info& info, + const std::string& attr_name, + size_t expected_size) const + { + if(not contains(info.attributes, attr_name)) + return {}; + + std::vector res; + auto&& attr = info.attributes[attr_name].ints(); + if(attr.size() != expected_size) + MIGRAPHX_THROW("Scan: " + attr_name + " size is " + to_string(attr.size()) + + ", should be " + to_string(expected_size)); + res.assign(attr.begin(), attr.end()); + + return res; + } + + std::vector + parse_dirs(onnx_parser::node_info& info, const std::string& name, size_t expected_size) const + { + auto dirs = parse_vector_attribute(info, name, expected_size); + if(dirs.empty()) + return std::vector(expected_size, 0); // NOLINT + + if(any_of(dirs, [](auto i) { return i != 0 and i != 1; })) + MIGRAPHX_THROW("Scan: " + name + + " may contain only 1s and 0s, actual values: " + to_string_range(dirs)); + + return dirs; + } + + int64_t normalize_axis(int64_t axis, int64_t rank, const std::string& attr_name) const + { + if(axis < -rank or axis >= rank) + MIGRAPHX_THROW("Scan: " + attr_name + " axis value {" + to_string(axis) + + "} out of range [" + to_string(-rank) + ", " + to_string(rank) + ")"); + + return axis < 0 ? rank + axis : axis; + } + + std::vector parse_axes(onnx_parser::node_info& info, + const std::string& name, + long expected_size, + std::vector::iterator ins_begin, + size_t rank_offset) const + { + auto axes = parse_vector_attribute(info, name, expected_size); + if(axes.empty()) + return std::vector(expected_size, 0); // NOLINT + + std::transform(axes.begin(), + axes.end(), + ins_begin, + axes.begin(), + [&](int64_t axis, instruction_ref arg) { + return normalize_axis(axis, arg->get_shape().ndim() + rank_offset, name); + }); + + return axes; + } + + // Alter the Scan body to match a body that Loop would expect. + // + // Loop body inputs: iteration_num, condition, loop_state_variables + // Scan body inputs: loop_state_variables, scan_input_slices + // iteration_num and condition parameters are prepended to the Scan body parameter list, while + // scan_input_slices are removed from parameters. + // Instead, scan_inputs are used directly in Scan body(as values from enclosing scope), and + // together with iteration_num passed to the scan_slice operator which produces slices that are + // used instead of the scan_inputs_slices. + // + // Loop body outputs: condition, loop_state_variables, scan_output_slices + // Scan body outputs: loop_state_variables, scan_output_slices + // The inserted Scan body condition parameter is prepended to the Scan body returns + void modify_body(module_ref mod, + const std::vector& args, + int64_t n, + int64_t m, + const std::vector& scan_input_axes, + const std::vector& scan_input_directions) const + { + std::vector params; + params.reserve(n + m); + transform(mod->get_parameter_names(), + std::back_inserter(params), + [&](const std::string& name) { return mod->get_parameter(name); }); + + // iteration_num, condition, and duplicate loop_state_variables are appended to parameters. + // References to the original loop_state_variables in other instructions are then replaced + // with references to the duplicate ones, after which the originals are removed. + // + // References to the scan_input_slices are replaced with references to inserted + // scan_slice->squeeze instructions, after which the scan_input_slices parameters are + // removed. + auto iter_param = mod->add_parameter("iter", shape{shape::int64_type}); + auto cond_param = mod->add_parameter("cond", shape{shape::bool_type}); + std::vector new_params; + new_params.reserve(n); + for(auto i = 0; i < n; ++i) + new_params.push_back( + mod->add_parameter("state_var" + std::to_string(i), params[i]->get_shape())); + + for(auto i = 0; i < params.size(); ++i) + { + if(i < n) + { + mod->replace_instruction(params[i], new_params[i]); + } + else + { + auto scan_axis = scan_input_axes[i - n]; + auto scan_dir = scan_input_directions[i - n]; + auto new_ins = mod->insert_instruction( + params[i], + make_op("scan_slice", {{"axis", scan_axis}, {"direction", scan_dir}}), + args[i], + iter_param); + new_ins = mod->insert_instruction( + params[i], make_op("squeeze", {{"axes", {scan_axis}}}), new_ins); + mod->replace_instruction(params[i], new_ins); + } + mod->remove_instruction(params[i]); + } + + auto returns = mod->get_returns(); + returns.insert(returns.begin(), cond_param); + mod->replace_return(returns); + } + + // Creates permutation so that axis 0 will be permuted to position axis, while maintaining the + // relative ordering of all the other axes. + // e.g. for rank = 4, axis = 2, the created perm is: [1, 2, 0, 3] + std::vector make_perm_for_scan_out(int64_t rank, int64_t axis) const + { + std::vector perm(rank); + std::iota(perm.begin(), perm.end(), 0); + std::copy(perm.begin() + 1, perm.begin() + 1 + axis, perm.begin()); + perm[axis] = 0; + + return perm; + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/loop.cpp b/src/targets/gpu/loop.cpp index ee2731938d7..ad5fc210c7c 100644 --- a/src/targets/gpu/loop.cpp +++ b/src/targets/gpu/loop.cpp @@ -21,6 +21,7 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ +#include #include #include #include @@ -56,7 +57,13 @@ struct gpu_loop copy_to_gpu(ctx, arg_src, dst); } - void append(const std::vector&, const std::vector&, int) const {} + void append(const std::vector&, + const std::vector&, + const std::vector&, + int64_t, + int64_t) const + { + } void set_zero(context& ctx, const std::vector& concatenated_outputs, int iter) const { @@ -111,7 +118,7 @@ hip_loop::compute(context& ctx, const std::function( module_ref&, const std::unordered_map&)>& run) const { - return run_loop(gpu_loop{op.max_iterations}, ctx, args, mods, run); + return run_loop(gpu_loop{op.max_iterations}, op.scan_output_directions, ctx, args, mods, run); } } // namespace gpu diff --git a/src/targets/gpu/lowering.cpp b/src/targets/gpu/lowering.cpp index 477b1ebeac9..457aaa1d706 100644 --- a/src/targets/gpu/lowering.cpp +++ b/src/targets/gpu/lowering.cpp @@ -119,6 +119,7 @@ struct miopen_apply add_convolution_backwards_op(); add_select_module_op(); add_reshape_lazy_op(); + add_scan_slice_op(); } void copy_params() const @@ -484,6 +485,17 @@ struct miopen_apply return mod->replace_instruction(ins, make_op("gpu::contiguous"), after_contiguous_args); }); } + + void add_scan_slice_op() + { + apply_map.emplace("scan_slice", [=](instruction_ref ins) { + auto inputs = ins->inputs(); + auto cpu_idx = mod->insert_instruction(ins, make_op("hip::copy_from_gpu"), inputs[1]); + inputs[1] = mod->insert_instruction(ins, make_op("hip::sync_stream"), cpu_idx); + return mod->replace_instruction( + ins, mod->insert_instruction(ins, ins->get_operator(), inputs)); + }); + } }; void lowering::apply(module_pass_manager& mpm) const diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index 8e4f4837d18..94cc1aa8734 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -13107,3 +13107,366 @@ def where_mixed_test(): outputs=['z']) return ([node], [c, x, y], [z]) + + +def scan_test(scan_input_axes=[0, 0], + scan_input_directions=[0, 0], + scan_output_axes=[0, 0], + scan_output_directions=[0, 0]): + sum_in = helper.make_tensor_value_info("sum_in", TensorProto.FLOAT, [2, 2]) + scan_in1 = helper.make_tensor_value_info("scan_in1", TensorProto.FLOAT, + [2, 2]) + scan_in2 = helper.make_tensor_value_info("scan_in2", TensorProto.FLOAT, + [1]) + sum_out = helper.make_tensor_value_info("sum_out", TensorProto.FLOAT, + [2, 2]) + scan_out1 = helper.make_tensor_value_info("scan_out1", TensorProto.FLOAT, + [2, 2]) + scan_out2 = helper.make_tensor_value_info("scan_out2", TensorProto.FLOAT, + [2]) + add1 = helper.make_node("Add", + inputs=["sum_in", "scan_in1"], + outputs=["add1_out"]) + add2 = helper.make_node("Add", + inputs=["add1_out", "scan_in2"], + outputs=["sum_out"]) + id = helper.make_node("Identity", + inputs=["sum_out"], + outputs=["scan_out1"]) + reduce_sum = helper.make_node("ReduceSum", + axes=[0], + keepdims=0, + inputs=["sum_out"], + outputs=["scan_out2"]) + scan_body = helper.make_graph([add1, add2, id, reduce_sum], "scan_body", + [sum_in, scan_in1, scan_in2], + [sum_out, scan_out1, scan_out2]) + + init_state = helper.make_tensor_value_info("init_state", TensorProto.FLOAT, + [2, 2]) + scan_ins1_sh = [2, 2, 2] + scan_ins1_sh[scan_input_axes[0]] = 3 + scan_ins1 = helper.make_tensor_value_info("scan_ins1", TensorProto.FLOAT, + scan_ins1_sh) + scan_ins2_sh = [1, 1] + scan_ins2_sh[scan_input_axes[1]] = 3 + scan_ins2 = helper.make_tensor_value_info("scan_ins2", TensorProto.FLOAT, + scan_ins2_sh) + + final_state = helper.make_tensor_value_info("final_state", + TensorProto.FLOAT, [2, 2]) + scan_outs1_sh = [2, 2, 2] + scan_outs1_sh[scan_output_axes[0]] = 3 + scan_outs1 = helper.make_tensor_value_info("scan_outs1", TensorProto.FLOAT, + scan_outs1_sh) + scan_outs2_sh = [2, 2] + scan_outs2_sh[scan_output_axes[1]] = 3 + scan_outs2 = helper.make_tensor_value_info("scan_outs2", TensorProto.FLOAT, + scan_outs2_sh) + node = helper.make_node( + "Scan", + inputs=["init_state", "scan_ins1", "scan_ins2"], + outputs=["final_state", "scan_outs1", "scan_outs2"], + num_scan_inputs=2, + scan_input_axes=scan_input_axes, + scan_input_directions=scan_input_directions, + scan_output_axes=scan_output_axes, + scan_output_directions=scan_output_directions, + body=scan_body, + ) + + return ([node], [init_state, scan_ins1, + scan_ins2], [final_state, scan_outs1, scan_outs2]) + + +@onnx_test() +def scan_test1(): + return scan_test() + + +@onnx_test() +def scan_test2(): + return scan_test(scan_output_directions=[1, 0]) + + +@onnx_test() +def scan_test3(): + return scan_test(scan_output_axes=[1, -1]) + + +@onnx_test() +def scan_test4(): + return scan_test(scan_input_directions=[1, 0]) + + +@onnx_test() +def scan_test5(): + return scan_test(scan_input_axes=[2, -1]) + + +@onnx_test() +def scan_test6(): + return scan_test(scan_input_axes=[-2, 0], + scan_input_directions=[0, 1], + scan_output_directions=[1, 1], + scan_output_axes=[2, 1]) + + +@onnx_test() +def scan_test7(): + sum_in = helper.make_tensor_value_info("sum_in", TensorProto.FLOAT, [2, 2]) + scan_in = helper.make_tensor_value_info("scan_in", TensorProto.FLOAT, + [2, 2]) + sum_out = helper.make_tensor_value_info("sum_out", TensorProto.FLOAT, + [2, 2]) + scan_out = helper.make_tensor_value_info("scan_out", TensorProto.FLOAT, + [2, 2]) + add1 = helper.make_node("Add", + inputs=["sum_in", "scan_in"], + outputs=["add1_out"]) + add2 = helper.make_node("Add", + inputs=["add1_out", "scan_in"], + outputs=["sum_out"]) + id = helper.make_node("Identity", inputs=["scan_in"], outputs=["scan_out"]) + + scan_body = helper.make_graph([add1, add2, id], "scan_body", + [sum_in, scan_in], [sum_out, scan_out]) + + init_state = helper.make_tensor_value_info("init_state", TensorProto.FLOAT, + [2, 2]) + scan_ins_sh = [3, 2, 2] + scan_ins = helper.make_tensor_value_info("scan_ins", TensorProto.FLOAT, + scan_ins_sh) + + final_state = helper.make_tensor_value_info("final_state", + TensorProto.FLOAT, [2, 2]) + scan_outs_sh = [3, 2, 2] + scan_outs = helper.make_tensor_value_info("scan_outs", TensorProto.FLOAT, + scan_outs_sh) + node = helper.make_node( + "Scan", + inputs=["init_state", "scan_ins"], + outputs=["final_state", "scan_outs"], + num_scan_inputs=1, + scan_input_axes=[0], + scan_input_directions=[0], + scan_output_axes=[0], + scan_output_directions=[0], + body=scan_body, + ) + + return ([node], [init_state, scan_ins], [final_state, scan_outs]) + + +def scan_negative_test(scan_input_axes=[0], + scan_input_directions=[0], + scan_output_axes=[0], + scan_output_directions=[0]): + sum_in = helper.make_tensor_value_info("sum_in", TensorProto.FLOAT, [2, 2]) + scan_in = helper.make_tensor_value_info("scan_in", TensorProto.FLOAT, + [2, 2]) + sum_out = helper.make_tensor_value_info("sum_out", TensorProto.FLOAT, + [2, 2]) + scan_out = helper.make_tensor_value_info("scan_out", TensorProto.FLOAT, + [2, 2]) + add = helper.make_node("Add", + inputs=["sum_in", "scan_in"], + outputs=["sum_out"]) + id = helper.make_node("Identity", inputs=["sum_out"], outputs=["scan_out"]) + scan_body = helper.make_graph([add, id], "scan_body", [sum_in, scan_in], + [sum_out, scan_out]) + + init_state = helper.make_tensor_value_info("init_state", TensorProto.FLOAT, + [2, 2]) + scan_ins = helper.make_tensor_value_info("scan_ins", TensorProto.FLOAT, + [3, 2, 2]) + + final_state = helper.make_tensor_value_info("final_state", + TensorProto.FLOAT, [2, 2]) + scan_outs = helper.make_tensor_value_info("scan_outs", TensorProto.FLOAT, + [3, 2, 2]) + node = helper.make_node( + "Scan", + inputs=["init_state", "scan_ins"], + outputs=["final_state", "scan_outs"], + num_scan_inputs=1, + scan_input_axes=scan_input_axes, + scan_input_directions=scan_input_directions, + scan_output_axes=scan_output_axes, + scan_output_directions=scan_output_directions, + body=scan_body, + ) + + return ([node], [init_state, scan_ins], [final_state, scan_outs]) + + +@onnx_test() +def scan_invalid_input_axes_len_test(): + return scan_negative_test(scan_input_axes=[0, 0]) + + +@onnx_test() +def scan_invalid_input_dirs_len_test(): + return scan_negative_test(scan_input_directions=[0, 0]) + + +@onnx_test() +def scan_invalid_output_axes_len_test(): + return scan_negative_test(scan_output_axes=[0, 0]) + + +@onnx_test() +def scan_invalid_output_dirs_len_test(): + return scan_negative_test(scan_output_directions=[0, 0]) + + +@onnx_test() +def scan_invalid_input_axes_vals_test(): + return scan_negative_test(scan_input_axes=[3]) + + +@onnx_test() +def scan_invalid_input_dirs_vals_test(): + return scan_negative_test(scan_input_directions=[2]) + + +@onnx_test() +def scan_invalid_output_axes_vals_test(): + return scan_negative_test(scan_output_axes=[-4]) + + +@onnx_test() +def scan_invalid_output_dirs_vals_test(): + return scan_negative_test(scan_output_directions=[-1]) + + +@onnx_test() +def scan_arg_count_mismatch_test(): + sum_in = helper.make_tensor_value_info("sum_in", TensorProto.FLOAT, [2, 2]) + scan_in1 = helper.make_tensor_value_info("scan_in1", TensorProto.FLOAT, + [2, 2]) + sum_out = helper.make_tensor_value_info("sum_out", TensorProto.FLOAT, + [2, 2]) + scan_out = helper.make_tensor_value_info("scan_out", TensorProto.FLOAT, + [2, 2]) + add = helper.make_node("Add", + inputs=["sum_in", "scan_in1"], + outputs=["sum_out"]) + id = helper.make_node("Identity", inputs=["sum_out"], outputs=["scan_out"]) + scan_body = helper.make_graph([add, id], "scan_body", [sum_in, scan_in1], + [sum_out, scan_out]) + + init_state = helper.make_tensor_value_info("init_state", TensorProto.FLOAT, + [2, 2]) + scan_ins1 = helper.make_tensor_value_info("scan_ins1", TensorProto.FLOAT, + [3, 2, 2]) + scan_ins2 = helper.make_tensor_value_info("scan_ins2", TensorProto.FLOAT, + [2, 3, 2]) + + final_state = helper.make_tensor_value_info("final_state", + TensorProto.FLOAT, [2, 2]) + scan_outs = helper.make_tensor_value_info("scan_outs", TensorProto.FLOAT, + [3, 2, 2]) + node = helper.make_node( + "Scan", + inputs=["init_state", "scan_ins1", "scan_ins2"], + outputs=["final_state", "scan_outs"], + num_scan_inputs=2, + scan_input_axes=[0, 0], + scan_input_directions=[0, 0], + scan_output_axes=[0], + scan_output_directions=[0], + body=scan_body, + ) + return ([node], [init_state, scan_ins1, + scan_ins2], [final_state, scan_outs]) + + +@onnx_test() +def scan_input_axes_lens_mismatch_test(): + sum_in = helper.make_tensor_value_info("sum_in", TensorProto.FLOAT, [2, 2]) + scan_in1 = helper.make_tensor_value_info("scan_in1", TensorProto.FLOAT, + [2, 2]) + scan_in2 = helper.make_tensor_value_info("scan_in2", TensorProto.FLOAT, + [2, 2]) + sum_out = helper.make_tensor_value_info("sum_out", TensorProto.FLOAT, + [2, 2]) + scan_out = helper.make_tensor_value_info("scan_out", TensorProto.FLOAT, + [2, 2]) + add = helper.make_node("Add", + inputs=["sum_in", "scan_in1"], + outputs=["sum_out"]) + id = helper.make_node("Identity", inputs=["sum_out"], outputs=["scan_out"]) + scan_body = helper.make_graph([add, id], "scan_body", + [sum_in, scan_in1, scan_in2], + [sum_out, scan_out]) + + init_state = helper.make_tensor_value_info("init_state", TensorProto.FLOAT, + [2, 2]) + scan_ins1 = helper.make_tensor_value_info("scan_ins1", TensorProto.FLOAT, + [3, 2, 2]) + scan_ins2 = helper.make_tensor_value_info("scan_ins2", TensorProto.FLOAT, + [2, 3, 2]) + + final_state = helper.make_tensor_value_info("final_state", + TensorProto.FLOAT, [2, 2]) + scan_outs = helper.make_tensor_value_info("scan_outs", TensorProto.FLOAT, + [3, 2, 2]) + node = helper.make_node( + "Scan", + inputs=["init_state", "scan_ins1", "scan_ins2"], + outputs=["final_state", "scan_outs"], + num_scan_inputs=2, + scan_input_axes=[0, 0], + scan_input_directions=[0, 0], + scan_output_axes=[0], + scan_output_directions=[0], + body=scan_body, + ) + return ([node], [init_state, scan_ins1, + scan_ins2], [final_state, scan_outs]) + + +@onnx_test() +def scan_arg_shapes_mismatch_test(): + sum_in = helper.make_tensor_value_info("sum_in", TensorProto.FLOAT, [2, 2]) + scan_in1 = helper.make_tensor_value_info("scan_in1", TensorProto.FLOAT, + [2, 2]) + scan_in2 = helper.make_tensor_value_info("scan_in2", TensorProto.FLOAT, + [2, 2]) + sum_out = helper.make_tensor_value_info("sum_out", TensorProto.FLOAT, + [2, 2]) + scan_out = helper.make_tensor_value_info("scan_out", TensorProto.FLOAT, + [2, 2]) + add = helper.make_node("Add", + inputs=["sum_in", "scan_in1"], + outputs=["sum_out"]) + id = helper.make_node("Identity", inputs=["sum_out"], outputs=["scan_out"]) + scan_body = helper.make_graph([add, id], "scan_body", + [sum_in, scan_in1, scan_in2], + [sum_out, scan_out]) + + init_state = helper.make_tensor_value_info("init_state", TensorProto.FLOAT, + [2, 2]) + scan_ins1 = helper.make_tensor_value_info("scan_ins1", TensorProto.FLOAT, + [3, 2, 2]) + scan_ins2 = helper.make_tensor_value_info("scan_ins2", TensorProto.FLOAT, + [3, 2]) + + final_state = helper.make_tensor_value_info("final_state", + TensorProto.FLOAT, [2, 2]) + scan_outs = helper.make_tensor_value_info("scan_outs", TensorProto.FLOAT, + [3, 2, 2]) + node = helper.make_node( + "Scan", + inputs=["init_state", "scan_ins1", "scan_ins2"], + outputs=["final_state", "scan_outs"], + num_scan_inputs=2, + scan_input_axes=[0, 0], + scan_input_directions=[0, 0], + scan_output_axes=[0], + scan_output_directions=[0], + body=scan_body, + ) + return ([node], [init_state, scan_ins1, + scan_ins2], [final_state, scan_outs]) diff --git a/test/onnx/parse/scan_test.cpp b/test/onnx/parse/scan_test.cpp new file mode 100644 index 00000000000..5ed5e980b71 --- /dev/null +++ b/test/onnx/parse/scan_test.cpp @@ -0,0 +1,138 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include + +TEST_CASE(scan_test) +{ + namespace mgx = migraphx; + migraphx::program prog; + auto* mm = prog.get_main_module(); + auto init_state = mm->add_parameter("init_state", mgx::shape{mgx::shape::float_type, {2, 2}}); + auto scan_ins1 = mm->add_parameter("scan_ins1", mgx::shape{mgx::shape::float_type, {2, 3, 2}}); + auto scan_ins2 = mm->add_parameter("scan_ins2", mgx::shape{mgx::shape::float_type, {3, 1}}); + + auto* body = prog.create_module("Scan_3_scan"); + auto iter = body->add_parameter("iter", mgx::shape{mgx::shape::int64_type}); + auto cond = body->add_parameter("cond", mgx::shape{mgx::shape::bool_type}); + auto sum_in = body->add_parameter("state_var0", mgx::shape{mgx::shape::float_type, {2, 2}}); + + auto scan_in2 = body->add_instruction( + mgx::make_op("scan_slice", {{"axis", 0}, {"direction", 1}}), scan_ins2, iter); + scan_in2 = body->add_instruction(mgx::make_op("squeeze", {{"axes", {0}}}), scan_in2); + auto scan_in1 = body->add_instruction( + mgx::make_op("scan_slice", {{"axis", 1}, {"direction", 0}}), scan_ins1, iter); + scan_in1 = body->add_instruction(mgx::make_op("squeeze", {{"axes", {1}}}), scan_in1); + + auto add1 = mgx::add_common_op(*body, mgx::make_op("add"), {sum_in, scan_in1}); + auto add2 = mgx::add_common_op(*body, mgx::make_op("add"), {add1, scan_in2}); + auto id = body->add_instruction(mgx::make_op("identity"), add2); + auto reduce_sum = body->add_instruction(mgx::make_op("reduce_sum", {{"axes", {0}}}), add2); + reduce_sum = body->add_instruction(mgx::make_op("squeeze", {{"axes", {0}}}), reduce_sum); + body->add_return({cond, add2, id, reduce_sum}); + + auto iter_lit = mm->add_literal(mgx::literal{mgx::shape{mgx::shape::int64_type}, {3}}); + auto cond_lit = mm->add_literal(mgx::literal{mgx::shape{mgx::shape::bool_type}, {true}}); + auto loop = mm->add_instruction( + mgx::make_op("loop", {{"max_iterations", 3}, {"scan_output_directions", {1, 1}}}), + {iter_lit, cond_lit, init_state}, + {body}); + + auto final_state = mm->add_instruction(mgx::make_op("get_tuple_elem", {{"index", 0}}), loop); + auto scan_outs1 = mm->add_instruction(mgx::make_op("get_tuple_elem", {{"index", 1}}), loop); + scan_outs1 = + mm->add_instruction(mgx::make_op("transpose", {{"permutation", {1, 2, 0}}}), scan_outs1); + auto scan_outs2 = mm->add_instruction(mgx::make_op("get_tuple_elem", {{"index", 2}}), loop); + scan_outs2 = + mm->add_instruction(mgx::make_op("transpose", {{"permutation", {1, 0}}}), scan_outs2); + mm->add_return({final_state, scan_outs1, scan_outs2}); + + auto prog_gold = read_onnx("scan_test6.onnx"); + EXPECT(prog == prog_gold); +} + +TEST_CASE(scan_invalid_input_axes_len_test) +{ + EXPECT(test::throws( + [] { read_onnx("scan_invalid_input_axes_len_test.onnx"); }, "scan_input_axes")); +} + +TEST_CASE(scan_invalid_input_dirs_len_test) +{ + EXPECT(test::throws( + [] { read_onnx("scan_invalid_input_dirs_len_test.onnx"); }, "scan_input_directions")); +} + +TEST_CASE(scan_invalid_output_axes_len_test) +{ + EXPECT(test::throws( + [] { read_onnx("scan_invalid_output_axes_len_test.onnx"); }, "scan_output_axes")); +} + +TEST_CASE(scan_invalid_output_dirs_len_test) +{ + EXPECT(test::throws( + [] { read_onnx("scan_invalid_output_dirs_len_test.onnx"); }, "scan_output_directions")); +} + +TEST_CASE(scan_invalid_input_axes_vals_test) +{ + EXPECT(test::throws( + [] { read_onnx("scan_invalid_input_axes_vals_test.onnx"); }, "scan_input_axes")); +} + +TEST_CASE(scan_invalid_input_dirs_vals_test) +{ + EXPECT(test::throws( + [] { read_onnx("scan_invalid_input_dirs_vals_test.onnx"); }, "scan_input_directions")); +} + +TEST_CASE(scan_invalid_output_axes_vals_test) +{ + EXPECT(test::throws( + [] { read_onnx("scan_invalid_output_axes_vals_test.onnx"); }, "scan_output_axes")); +} + +TEST_CASE(scan_invalid_output_dirs_vals_test) +{ + EXPECT(test::throws( + [] { read_onnx("scan_invalid_output_dirs_vals_test.onnx"); }, "scan_output_directions")); +} + +TEST_CASE(scan_arg_count_mismatch_test) +{ + EXPECT(test::throws([] { read_onnx("scan_arg_count_mismatch_test.onnx"); })); +} + +TEST_CASE(scan_arg_shapes_mismatch_test) +{ + EXPECT(test::throws([] { read_onnx("scan_arg_shapes_mismatch_test.onnx"); })); +} + +TEST_CASE(scan_input_axes_lens_mismatch_test) +{ + EXPECT(test::throws([] { read_onnx("scan_input_axes_lens_mismatch_test.onnx"); })); +} diff --git a/test/onnx/scan_arg_count_mismatch_test.onnx b/test/onnx/scan_arg_count_mismatch_test.onnx new file mode 100644 index 00000000000..c14171c2268 Binary files /dev/null and b/test/onnx/scan_arg_count_mismatch_test.onnx differ diff --git a/test/onnx/scan_arg_shapes_mismatch_test.onnx b/test/onnx/scan_arg_shapes_mismatch_test.onnx new file mode 100644 index 00000000000..1226c00bb68 Binary files /dev/null and b/test/onnx/scan_arg_shapes_mismatch_test.onnx differ diff --git a/test/onnx/scan_input_axes_lens_mismatch_test.onnx b/test/onnx/scan_input_axes_lens_mismatch_test.onnx new file mode 100644 index 00000000000..1ddb7b6cf55 Binary files /dev/null and b/test/onnx/scan_input_axes_lens_mismatch_test.onnx differ diff --git a/test/onnx/scan_invalid_input_axes_len_test.onnx b/test/onnx/scan_invalid_input_axes_len_test.onnx new file mode 100644 index 00000000000..3049740b74a Binary files /dev/null and b/test/onnx/scan_invalid_input_axes_len_test.onnx differ diff --git a/test/onnx/scan_invalid_input_axes_vals_test.onnx b/test/onnx/scan_invalid_input_axes_vals_test.onnx new file mode 100644 index 00000000000..10b3aa58246 Binary files /dev/null and b/test/onnx/scan_invalid_input_axes_vals_test.onnx differ diff --git a/test/onnx/scan_invalid_input_dirs_len_test.onnx b/test/onnx/scan_invalid_input_dirs_len_test.onnx new file mode 100644 index 00000000000..05ea4609596 Binary files /dev/null and b/test/onnx/scan_invalid_input_dirs_len_test.onnx differ diff --git a/test/onnx/scan_invalid_input_dirs_vals_test.onnx b/test/onnx/scan_invalid_input_dirs_vals_test.onnx new file mode 100644 index 00000000000..b1ea12b2975 Binary files /dev/null and b/test/onnx/scan_invalid_input_dirs_vals_test.onnx differ diff --git a/test/onnx/scan_invalid_output_axes_len_test.onnx b/test/onnx/scan_invalid_output_axes_len_test.onnx new file mode 100644 index 00000000000..0b649d7cb65 Binary files /dev/null and b/test/onnx/scan_invalid_output_axes_len_test.onnx differ diff --git a/test/onnx/scan_invalid_output_axes_vals_test.onnx b/test/onnx/scan_invalid_output_axes_vals_test.onnx new file mode 100644 index 00000000000..000aee54a0a Binary files /dev/null and b/test/onnx/scan_invalid_output_axes_vals_test.onnx differ diff --git a/test/onnx/scan_invalid_output_dirs_len_test.onnx b/test/onnx/scan_invalid_output_dirs_len_test.onnx new file mode 100644 index 00000000000..0739fab0537 Binary files /dev/null and b/test/onnx/scan_invalid_output_dirs_len_test.onnx differ diff --git a/test/onnx/scan_invalid_output_dirs_vals_test.onnx b/test/onnx/scan_invalid_output_dirs_vals_test.onnx new file mode 100644 index 00000000000..87f0ae82e44 Binary files /dev/null and b/test/onnx/scan_invalid_output_dirs_vals_test.onnx differ diff --git a/test/onnx/scan_test1.onnx b/test/onnx/scan_test1.onnx new file mode 100644 index 00000000000..ef11c728bda Binary files /dev/null and b/test/onnx/scan_test1.onnx differ diff --git a/test/onnx/scan_test2.onnx b/test/onnx/scan_test2.onnx new file mode 100644 index 00000000000..a33ea4c5451 Binary files /dev/null and b/test/onnx/scan_test2.onnx differ diff --git a/test/onnx/scan_test3.onnx b/test/onnx/scan_test3.onnx new file mode 100644 index 00000000000..aa7c2ec33a9 Binary files /dev/null and b/test/onnx/scan_test3.onnx differ diff --git a/test/onnx/scan_test4.onnx b/test/onnx/scan_test4.onnx new file mode 100644 index 00000000000..af1fac2edc4 Binary files /dev/null and b/test/onnx/scan_test4.onnx differ diff --git a/test/onnx/scan_test5.onnx b/test/onnx/scan_test5.onnx new file mode 100644 index 00000000000..9baa8397885 Binary files /dev/null and b/test/onnx/scan_test5.onnx differ diff --git a/test/onnx/scan_test6.onnx b/test/onnx/scan_test6.onnx new file mode 100644 index 00000000000..2de8cae29d2 Binary files /dev/null and b/test/onnx/scan_test6.onnx differ diff --git a/test/onnx/scan_test7.onnx b/test/onnx/scan_test7.onnx new file mode 100644 index 00000000000..ccda6f8c7a0 Binary files /dev/null and b/test/onnx/scan_test7.onnx differ diff --git a/test/onnx/verify/scan_test.cpp b/test/onnx/verify/scan_test.cpp new file mode 100644 index 00000000000..899eea5b72c --- /dev/null +++ b/test/onnx/verify/scan_test.cpp @@ -0,0 +1,199 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static migraphx::shape make_shape(const std::vector& lens) +{ + return migraphx::shape{migraphx::shape::float_type, lens}; +} + +auto scan_test(const std::string& test_file, + migraphx::shape scan_ins1_sh, + migraphx::shape scan_ins2_sh) +{ + auto prog = read_onnx(test_file); + prog.compile(migraphx::make_target("ref")); + + migraphx::parameter_map pm; + + migraphx::shape init_state_sh{migraphx::shape::float_type, {2, 2}}; + std::vector init_state(init_state_sh.elements(), 0); + pm["init_state"] = migraphx::argument(init_state_sh, init_state.data()); + + std::vector scan_ins1(scan_ins1_sh.elements()); + std::iota(scan_ins1.begin(), scan_ins1.end(), 1); + pm["scan_ins1"] = migraphx::argument(scan_ins1_sh, scan_ins1.data()); + + std::vector scan_ins2(scan_ins2_sh.elements()); + std::iota(scan_ins2.begin(), scan_ins2.end(), 0); + pm["scan_ins2"] = migraphx::argument(scan_ins2_sh, scan_ins2.data()); + + auto result = prog.eval(pm); + EXPECT(result.size() == 3); + return std::make_tuple(result[0], result[1], result[2]); +} + +TEST_CASE(scan_test1) +{ + auto [final_state, scan_out1, scan_out2] = + scan_test("scan_test1.onnx", make_shape({3, 2, 2}), make_shape({3, 1})); + + EXPECT(final_state.get_shape() == make_shape({2, 2})); + std::vector final_state_gold{18, 21, 24, 27}; + EXPECT(final_state.to_vector() == final_state_gold); + + EXPECT(scan_out1.get_shape() == make_shape({3, 2, 2})); + std::vector scan_out1_gold{1, 2, 3, 4, 7, 9, 11, 13, 18, 21, 24, 27}; + EXPECT(scan_out1.to_vector() == scan_out1_gold); + + EXPECT(scan_out2.get_shape() == make_shape({3, 2})); + std::vector scan_out2_gold{4, 6, 18, 22, 42, 48}; + EXPECT(scan_out2.to_vector() == scan_out2_gold); +} + +TEST_CASE(scan_test2) +{ + auto [final_state, scan_out1, scan_out2] = + scan_test("scan_test2.onnx", make_shape({3, 2, 2}), make_shape({3, 1})); + + EXPECT(final_state.get_shape() == make_shape({2, 2})); + std::vector final_state_gold{18, 21, 24, 27}; + EXPECT(final_state.to_vector() == final_state_gold); + + EXPECT(scan_out1.get_shape() == make_shape({3, 2, 2})); + std::vector scan_out1_gold{18, 21, 24, 27, 7, 9, 11, 13, 1, 2, 3, 4}; + EXPECT(scan_out1.to_vector() == scan_out1_gold); + + EXPECT(scan_out2.get_shape() == make_shape({3, 2})); + std::vector scan_out2_gold{4, 6, 18, 22, 42, 48}; + EXPECT(scan_out2.to_vector() == scan_out2_gold); +} + +TEST_CASE(scan_test3) +{ + auto [final_state, scan_out1, scan_out2] = + scan_test("scan_test3.onnx", make_shape({3, 2, 2}), make_shape({3, 1})); + + EXPECT(final_state.get_shape() == make_shape({2, 2})); + std::vector final_state_gold{18, 21, 24, 27}; + EXPECT(final_state.to_vector() == final_state_gold); + + EXPECT(scan_out1.get_shape() == make_shape({2, 3, 2})); + std::vector scan_out1_gold{1, 2, 7, 9, 18, 21, 3, 4, 11, 13, 24, 27}; + EXPECT(scan_out1.to_vector() == scan_out1_gold); + + EXPECT(scan_out2.get_shape() == make_shape({2, 3})); + std::vector scan_out2_gold{4, 18, 42, 6, 22, 48}; + EXPECT(scan_out2.to_vector() == scan_out2_gold); +} + +TEST_CASE(scan_test4) +{ + auto [final_state, scan_out1, scan_out2] = + scan_test("scan_test4.onnx", make_shape({3, 2, 2}), make_shape({3, 1})); + + EXPECT(final_state.get_shape() == make_shape({2, 2})); + std::vector final_state_gold{18, 21, 24, 27}; + EXPECT(final_state.to_vector() == final_state_gold); + + EXPECT(scan_out1.get_shape() == make_shape({3, 2, 2})); + std::vector scan_out1_gold{9, 10, 11, 12, 15, 17, 19, 21, 18, 21, 24, 27}; + EXPECT(scan_out1.to_vector() == scan_out1_gold); + + EXPECT(scan_out2.get_shape() == make_shape({3, 2})); + std::vector scan_out2_gold{20, 22, 34, 38, 42, 48}; + EXPECT(scan_out2.to_vector() == scan_out2_gold); +} + +TEST_CASE(scan_test5) +{ + auto [final_state, scan_out1, scan_out2] = + scan_test("scan_test5.onnx", make_shape({2, 2, 3}), make_shape({1, 3})); + + EXPECT(final_state.get_shape() == make_shape({2, 2})); + std::vector final_state_gold{9, 18, 27, 36}; + EXPECT(final_state.to_vector() == final_state_gold); + + EXPECT(scan_out1.get_shape() == make_shape({3, 2, 2})); + std::vector scan_out1_gold{1, 4, 7, 10, 4, 10, 16, 22, 9, 18, 27, 36}; + EXPECT(scan_out1.to_vector() == scan_out1_gold); + + EXPECT(scan_out2.get_shape() == make_shape({3, 2})); + std::vector scan_out2_gold{8, 14, 20, 32, 36, 54}; + EXPECT(scan_out2.to_vector() == scan_out2_gold); +} + +TEST_CASE(scan_test6) +{ + auto [final_state, scan_out1, scan_out2] = + scan_test("scan_test6.onnx", make_shape({2, 3, 2}), make_shape({3, 1})); + + EXPECT(final_state.get_shape() == make_shape({2, 2})); + std::vector final_state_gold{12, 15, 30, 33}; + EXPECT(final_state.to_vector() == final_state_gold); + + EXPECT(scan_out1.get_shape() == make_shape({2, 2, 3})); + std::vector scan_out1_gold{12, 7, 3, 15, 9, 4, 30, 19, 9, 33, 21, 10}; + EXPECT(scan_out1.to_vector() == scan_out1_gold); + + EXPECT(scan_out2.get_shape() == make_shape({2, 3})); + std::vector scan_out2_gold{42, 26, 12, 48, 30, 14}; + EXPECT(scan_out2.to_vector() == scan_out2_gold); +} + +TEST_CASE(scan_test7) +{ + auto prog = read_onnx("scan_test7.onnx"); + prog.compile(migraphx::make_target("ref")); + + migraphx::parameter_map pm; + + migraphx::shape init_state_sh{migraphx::shape::float_type, {2, 2}}; + std::vector init_state(init_state_sh.elements(), 0); + pm["init_state"] = migraphx::argument(init_state_sh, init_state.data()); + + migraphx::shape scan_ins_sh{migraphx::shape::float_type, {3, 2, 2}}; + std::vector scan_ins(scan_ins_sh.elements()); + std::iota(scan_ins.begin(), scan_ins.end(), 1); + pm["scan_ins"] = migraphx::argument(scan_ins_sh, scan_ins.data()); + + auto result = prog.eval(pm); + EXPECT(result.size() == 2); + + EXPECT(result[0].get_shape() == make_shape({2, 2})); + std::vector final_state_gold{30, 36, 42, 48}; + EXPECT(result[0].to_vector() == final_state_gold); + + EXPECT(result[1].get_shape() == make_shape({3, 2, 2})); + EXPECT(result[1].to_vector() == scan_ins); +} diff --git a/test/op_shape_test.cpp b/test/op_shape_test.cpp index 99723493237..5b2c00f243b 100644 --- a/test/op_shape_test.cpp +++ b/test/op_shape_test.cpp @@ -4239,6 +4239,50 @@ TEST_CASE(slice_dyn_shape5) input); } +TEST_CASE(test_scan_slice1) +{ + migraphx::shape input{migraphx::shape::float_type, {2, 3, 4}}; + migraphx::shape axis_input{migraphx::shape::int64_type}; + migraphx::shape expected{migraphx::shape::float_type, {1, 3, 4}}; + expect_shape(expected, + migraphx::make_op("scan_slice", {{"axis", 0}, {"direction", 0}}), + input, + axis_input); +} + +TEST_CASE(test_scan_slice2) +{ + migraphx::shape input{migraphx::shape::float_type, {4, 6, 5}}; + migraphx::shape axis_input{migraphx::shape::int64_type}; + migraphx::shape expected{migraphx::shape::float_type, {4, 1, 5}, {30, 5, 1}}; + expect_shape(expected, + migraphx::make_op("scan_slice", {{"axis", 1}, {"direction", 0}}), + input, + axis_input); +} + +TEST_CASE(test_scan_slice3) +{ + migraphx::shape input{migraphx::shape::float_type, {2, 5, 7}}; + migraphx::shape axis_input{migraphx::shape::int64_type}; + migraphx::shape expected{migraphx::shape::float_type, {2, 5, 1}, {35, 7, 1}}; + expect_shape(expected, + migraphx::make_op("scan_slice", {{"axis", -1}, {"direction", 0}}), + input, + axis_input); +} + +TEST_CASE(test_scan_slice4) +{ + migraphx::shape input{migraphx::shape::float_type, {2, 5, 7}}; + migraphx::shape axis_input{migraphx::shape::int64_type}; + migraphx::shape expected{migraphx::shape::float_type, {1, 5, 7}, {35, 7, 1}}; + expect_shape(expected, + migraphx::make_op("scan_slice", {{"axis", -3}, {"direction", 1}}), + input, + axis_input); +} + TEST_CASE(softmax_dyn0) { migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {3, 3}, {4, 4}, {5, 5}}}; diff --git a/test/py/onnx_backend_test.py b/test/py/onnx_backend_test.py index 3f9867b4072..914ff0ac5b3 100644 --- a/test/py/onnx_backend_test.py +++ b/test/py/onnx_backend_test.py @@ -143,7 +143,6 @@ def disabled_tests_onnx_1_7_0(backend_test): backend_test.exclude(r'test_resize_upsample_sizes_cubic_cpu') backend_test.exclude(r'test_reversesequence_batch_cpu') backend_test.exclude(r'test_reversesequence_time_cpu') - backend_test.exclude(r'test_scan9_sum_cpu') backend_test.exclude(r'test_scan_sum_cpu') backend_test.exclude(r'test_slice_cpu') backend_test.exclude(r'test_slice_end_out_of_bounds_cpu') diff --git a/test/ref/scan_slice.cpp b/test/ref/scan_slice.cpp new file mode 100644 index 00000000000..f65385bf310 --- /dev/null +++ b/test/ref/scan_slice.cpp @@ -0,0 +1,180 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +static migraphx::shape make_shape(const std::vector& lens) +{ + return migraphx::shape{migraphx::shape::int32_type, lens}; +} + +migraphx::program make_scan_slice_program(int64_t axis, int64_t direction) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape data_sh{migraphx::shape::int32_type, {2, 2, 2}}; + std::vector data(data_sh.elements()); + std::iota(data.begin(), data.end(), 0); + auto data_lit = mm->add_literal(migraphx::literal{data_sh, data}); + + migraphx::shape idx_sh{migraphx::shape::int64_type, {1}}; + auto idx_param = mm->add_parameter("idx", idx_sh); + + mm->add_instruction(migraphx::make_op("scan_slice", {{"axis", axis}, {"direction", direction}}), + data_lit, + idx_param); + + p.compile(migraphx::make_target("ref")); + + return p; +} + +TEST_CASE(scan_slice_test_1) +{ + auto p = make_scan_slice_program(0, 0); + + migraphx::parameter_map pm; + int64_t idx = 0; + migraphx::shape idx_sh{migraphx::shape::int64_type, {1}}; + pm["idx"] = migraphx::argument{idx_sh, &idx}; + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({1, 2, 2})); + EXPECT(result.to_vector() == std::vector{0, 1, 2, 3}); + + idx = 1; + result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({1, 2, 2})); + EXPECT(result.to_vector() == std::vector{4, 5, 6, 7}); +} + +TEST_CASE(scan_slice_test_2) +{ + auto p = make_scan_slice_program(1, 0); + + migraphx::parameter_map pm; + int64_t idx = 0; + migraphx::shape idx_sh{migraphx::shape::int64_type, {1}}; + pm["idx"] = migraphx::argument{idx_sh, &idx}; + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({2, 1, 2})); + EXPECT(result.to_vector() == std::vector{0, 1, 4, 5}); + + idx = 1; + result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({2, 1, 2})); + EXPECT(result.to_vector() == std::vector{2, 3, 6, 7}); +} + +TEST_CASE(scan_slice_test_3) +{ + auto p = make_scan_slice_program(2, 0); + + migraphx::parameter_map pm; + int64_t idx = 0; + migraphx::shape idx_sh{migraphx::shape::int64_type, {1}}; + pm["idx"] = migraphx::argument{idx_sh, &idx}; + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({2, 2, 1})); + EXPECT(result.to_vector() == std::vector{0, 2, 4, 6}); + + idx = 1; + result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({2, 2, 1})); + EXPECT(result.to_vector() == std::vector{1, 3, 5, 7}); +} + +TEST_CASE(scan_slice_test_4) +{ + auto p = make_scan_slice_program(-3, 0); + + migraphx::parameter_map pm; + int64_t idx = 0; + migraphx::shape idx_sh{migraphx::shape::int64_type, {1}}; + pm["idx"] = migraphx::argument{idx_sh, &idx}; + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({1, 2, 2})); + EXPECT(result.to_vector() == std::vector{0, 1, 2, 3}); + + idx = 1; + result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({1, 2, 2})); + EXPECT(result.to_vector() == std::vector{4, 5, 6, 7}); +} + +TEST_CASE(scan_slice_test_5) +{ + auto p = make_scan_slice_program(0, 1); + + migraphx::parameter_map pm; + int64_t idx = 0; + migraphx::shape idx_sh{migraphx::shape::int64_type, {1}}; + pm["idx"] = migraphx::argument{idx_sh, &idx}; + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({1, 2, 2})); + EXPECT(result.to_vector() == std::vector{4, 5, 6, 7}); + + idx = 1; + result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({1, 2, 2})); + EXPECT(result.to_vector() == std::vector{0, 1, 2, 3}); +} + +TEST_CASE(scan_slice_test_6) +{ + auto p = make_scan_slice_program(-2, 1); + + migraphx::parameter_map pm; + int64_t idx = 0; + migraphx::shape idx_sh{migraphx::shape::int64_type, {1}}; + pm["idx"] = migraphx::argument{idx_sh, &idx}; + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({2, 1, 2})); + EXPECT(result.to_vector() == std::vector{2, 3, 6, 7}); + + idx = 1; + result = p.eval(pm).back(); + EXPECT(result.get_shape() == make_shape({2, 1, 2})); + EXPECT(result.to_vector() == std::vector{0, 1, 4, 5}); +} + +TEST_CASE(scan_slice_test_7) +{ + auto p = make_scan_slice_program(0, 0); + + migraphx::parameter_map pm; + int64_t idx = 2; + migraphx::shape idx_sh{migraphx::shape::int64_type, {1}}; + pm["idx"] = migraphx::argument{idx_sh, &idx}; + + EXPECT(test::throws([&] { p.eval(pm); })); +} diff --git a/test/run_loop_test.cpp b/test/run_loop_test.cpp index 7c1f148c17e..c371ef59c63 100644 --- a/test/run_loop_test.cpp +++ b/test/run_loop_test.cpp @@ -156,7 +156,7 @@ struct test_loop_op cpy_args.push_back(migraphx::argument(s_cond)); cpy_args.push_back(migraphx::argument(out_shape)); // run loop - return run_loop(test_loop{max_iterations}, ctx, cpy_args, mods, run); + return run_loop(test_loop{max_iterations}, {}, ctx, cpy_args, mods, run); } }; diff --git a/test/verify/test_scan_slice.cpp b/test/verify/test_scan_slice.cpp new file mode 100644 index 00000000000..08f99938a72 --- /dev/null +++ b/test/verify/test_scan_slice.cpp @@ -0,0 +1,62 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +template +struct test_scan_slice_base : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape data_sh{migraphx::shape::int32_type, {2, 2, 2}}; + auto data_param = mm->add_parameter("data", data_sh); + migraphx::shape idx_sh{migraphx::shape::int64_type, {1}}; + auto idx_lit = mm->add_literal(migraphx::literal{idx_sh, {Idx}}); + + mm->add_instruction( + migraphx::make_op("scan_slice", {{"axis", Axis}, {"direction", Direction}}), + data_param, + idx_lit); + + return p; + } +}; + +struct test_scan_slice1 : test_scan_slice_base +{ +}; + +struct test_scan_slice2 : test_scan_slice_base +{ +}; + +struct test_scan_slice3 : test_scan_slice_base +{ +};