Skip to content

Commit

Permalink
Add support for Scan operator (#2936)
Browse files Browse the repository at this point in the history
  • Loading branch information
music-dino authored and TedThemistokleous committed Aug 21, 2024
1 parent 9a3fb0e commit 904ceaf
Show file tree
Hide file tree
Showing 35 changed files with 1,455 additions and 15 deletions.
8 changes: 7 additions & 1 deletion docs/dev/onnx_operators.rst
Original file line number Diff line number Diff line change
Expand Up @@ -705,7 +705,13 @@ Operator Support Matrix
+--------------------------+-----------+-----------------+------------------------------+
| STFT || | |
+--------------------------+-----------+-----------------+------------------------------+
| Scan | 👷 | 👷 | |
| Scan || UINT8, UINT16, | ``identity``, |
| | | UINT32, UINT64, | ``sequence`` |
| | | INT8, INT16, | datatypes are |
| | | INT32, INT64, | not supported, |
| | | FP8, FP16, | Number of iterations has |
| | | FP32, FP64 | upper-bound |
| | | | Version 8 not supported |
+--------------------------+-----------+-----------------+------------------------------+
| Scatter (deprecated) || BOOL, UINT8, | |
| | | UINT16, UINT32, | |
Expand Down
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ register_migraphx_ops(
rsqrt
run_on_target
scalar
scan_slice
scatter_none
scatter_add
scatter_mul
Expand Down
19 changes: 13 additions & 6 deletions src/include/migraphx/op/loop.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,14 @@ namespace op {

struct loop
{
int64_t max_iterations = 10;
int64_t max_iterations = 10;
std::vector<int64_t> scan_output_directions = {};

template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.max_iterations, "max_iterations"));
return pack(f(self.max_iterations, "max_iterations"),
f(self.scan_output_directions, "scan_output_directions"));
}

std::string name() const { return "loop"; }
Expand Down Expand Up @@ -97,19 +99,24 @@ struct loop

void append(const std::vector<argument>& iter_state,
const std::vector<argument>& concatenated_outputs,
int iter) const
const std::vector<int64_t>& scan_output_dirs,
int64_t curr_iter,
int64_t num_iters) const
{
assert(iter_state.size() == concatenated_outputs.size());
for(auto i : range(iter_state.size()))
{
const auto& iter_stat = iter_state.at(i);
const auto& scan_out = concatenated_outputs.at(i);

auto dir = scan_output_dirs.empty() ? 0 : scan_output_dirs[i];
auto idx = (1 - dir) * curr_iter + dir * (num_iters - 1 - curr_iter);

auto* in_data = iter_stat.data();
auto* out_data = scan_out.data();
std::size_t out_size = iter_stat.get_shape().bytes();
assert((iter + 1) * out_size <= scan_out.get_shape().bytes());
std::copy(in_data, in_data + out_size, out_data + iter * out_size);
assert((idx + 1) * out_size <= scan_out.get_shape().bytes());
std::copy(in_data, in_data + out_size, out_data + idx * out_size);
}
}

Expand Down Expand Up @@ -153,7 +160,7 @@ struct loop
cpy_args.push_back(argument(out_shape));

// run loop
return run_loop(ref_loop{max_iterations}, ctx, cpy_args, mods, run);
return run_loop(ref_loop{max_iterations}, scan_output_directions, ctx, cpy_args, mods, run);
}
};

Expand Down
91 changes: 91 additions & 0 deletions src/include/migraphx/op/scan_slice.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_SCAN_SLICE_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCAN_SLICE_HPP

#include <migraphx/op/name.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/normalize_attributes.hpp>
#include <array>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

struct scan_slice : op_name<scan_slice>
{
int64_t axis = 0;
int64_t direction = 0;

template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"), f(self.direction, "direction"));
}

value attributes() const
{
value normalize_axes = value::object{};
normalize_axes["axis"] = value::array{normalize_attribute::include_min};
return {{"normalize_axes", normalize_axes}};
}

shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2);
auto input_shape = inputs[0];
auto new_lens = input_shape.lens();
new_lens[axis] = 1;

return shape{input_shape.type(), new_lens, input_shape.strides()};
}

argument compute(shape output_shape, std::vector<argument> args) const
{
auto input = args[0];
auto input_sh = input.get_shape();

int64_t idx;
args[1].visit([&](auto i) { idx = i.front(); });
const auto max_idx = input_sh.lens()[axis] - 1;
if(idx > max_idx or idx < 0)
MIGRAPHX_THROW("ScanSlice: index {" + std::to_string(idx) + "} out of range [0, " +
std::to_string(max_idx) + "]");
idx = (1 - direction) * idx + direction * (max_idx - idx);

auto offset = idx * input_sh.strides().at(axis) * input_sh.type_size();
return {output_shape, [=] { return input.data() + offset; }};
}
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
1 change: 1 addition & 0 deletions src/include/migraphx/operators.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@
#include <migraphx/op/roialign.hpp>
#include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/scalar.hpp>
#include <migraphx/op/scan_slice.hpp>
#include <migraphx/op/scatter_none.hpp>
#include <migraphx/op/scatter_add.hpp>
#include <migraphx/op/scatter_mul.hpp>
Expand Down
14 changes: 10 additions & 4 deletions src/include/migraphx/run_loop.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand All @@ -24,6 +24,7 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_RUN_LOOP_HPP
#define MIGRAPHX_GUARD_RTGLIB_RUN_LOOP_HPP

#include <migraphx/stringutils.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
Expand All @@ -39,6 +40,7 @@ inline namespace MIGRAPHX_INLINE_NS {

template <class LoopModel, class T>
argument run_loop(const LoopModel& model,
const std::vector<int64_t>& scan_output_directions,
T& ctx,
std::vector<argument> args,
const std::vector<module_ref>& mods,
Expand Down Expand Up @@ -101,9 +103,13 @@ argument run_loop(const LoopModel& model,
auto output_index = out_param_indices[name];
if(output_index > dep_num)
{
int64_t dir = scan_output_directions.empty()
? 0
: scan_output_directions[output_index - dep_num - 1];
auto idx = (1 - dir) * iter + dir * (iter_num - 1 - iter);
const auto& arg = out_args.at(output_index);
assert((iter + 1) * ps.bytes() <= arg.get_shape().bytes());
params[name] = argument(ps, arg.data() + iter * ps.bytes());
assert((idx + 1) * ps.bytes() <= arg.get_shape().bytes());
params[name] = argument(ps, arg.data() + idx * ps.bytes());
}
else
{
Expand All @@ -123,7 +129,7 @@ argument run_loop(const LoopModel& model,
std::copy(dep_out.begin(), dep_out.end(), out_args.begin());

std::vector<argument> mod_scan_outs(mod_args.begin() + 1 + dep_num, mod_args.end());
model.append(mod_scan_outs, scan_outputs, iter);
model.append(mod_scan_outs, scan_outputs, scan_output_directions, iter, iter_num);
}

out_args.erase(out_args.begin());
Expand Down
Loading

0 comments on commit 904ceaf

Please sign in to comment.