Skip to content

Commit

Permalink
Parse ONNX Expand to broadcast_with_dims (#2799)
Browse files Browse the repository at this point in the history
  • Loading branch information
CharlieL7 authored Apr 16, 2024
1 parent c3c0980 commit 6e496c1
Show file tree
Hide file tree
Showing 10 changed files with 424 additions and 7 deletions.
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ register_migraphx_ops(
atan
broadcast
broadcast_for_dot
broadcast_with_dims
capture
ceil
clip
Expand Down
84 changes: 84 additions & 0 deletions src/include/migraphx/op/broadcast_with_dims.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_BROADCAST_WITH_DIMS_HPP
#define MIGRAPHX_GUARD_OPERATORS_BROADCAST_WITH_DIMS_HPP

#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/common.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

/**
 * Broadcast the input tensor to the shape defined by the values of the second input.
 * Used as `broadcast_with_dims(input_tensor, dims)`, where dims is a vector of integer dimensions.
 * `input_tensor` must be broadcastable with `dims`, otherwise this operator will throw at compute.
 * This operator can be replaced with `multibroadcast(input_tensor)` if the `dims` vector is
 * constant.
 *
 * Example:
 * input_tensor shape: lens = {2, 3}, strides = {3, 1}
 * dims = [4, 1, 3]
 * output shape: lens = {4, 2, 3}, strides = {0, 3, 1}
 */
struct broadcast_with_dims
{
    std::string name() const { return "broadcast_with_dims"; }

    // The values of `dims` are only known at run time, so the strongest
    // compile-time guarantee is the output rank: the larger of the input
    // tensor rank and the number of entries in `dims`. Every output
    // dimension is reported fully dynamic ({0, max}).
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        migraphx::check_shapes{inputs, *this, true}.has(2);
        // check that second input has a static shape; its length gives the dims count
        (void)migraphx::check_shapes{inputs.begin() + 1, inputs.end(), *this, false};
        // output tensor rank is greater of input_tensor rank or length of dims vector
        const auto& input_tensor_shape = inputs.at(0);
        const auto& dims_shape         = inputs.at(1);
        std::size_t out_ndim = std::max(input_tensor_shape.ndim(), dims_shape.lens().at(0));
        std::size_t max_int  = std::numeric_limits<std::size_t>::max();
        std::vector<shape::dynamic_dimension> dyn_dims(out_ndim,
                                                       shape::dynamic_dimension{0, max_int});
        return {input_tensor_shape.type(), dyn_dims};
    }

    argument compute(const shape& output_shape, const std::vector<argument>& args) const
    {
        const auto& s0      = args.at(0).get_shape();
        const auto& in_lens = s0.lens();
        // read the runtime values of `dims`; assign() resizes the vector, so
        // only reserve capacity instead of value-initializing elements
        std::vector<std::size_t> dims_input;
        dims_input.reserve(output_shape.ndim());
        args.at(1).visit([&](auto a) { dims_input.assign(a.begin(), a.end()); });
        // throws if in_lens and dims_input are not broadcast-compatible
        auto out_lens  = compute_broadcasted_lens(in_lens, dims_input);
        auto out_shape = make_bcast_shape(s0, out_lens);
        return args[0].reshape(out_shape);
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
20 changes: 14 additions & 6 deletions src/onnx/parse_expand.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -43,11 +43,19 @@ struct parse_expand : op_parser<parse_expand>
{
auto in_lens = args[0]->get_shape().lens();
migraphx::argument arg_s = args[1]->eval();
check_arg_empty(arg_s, "Expand: dynamic shape is not supported");
std::vector<std::size_t> dims;
arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
auto out_lens = compute_broadcasted_lens(in_lens, dims);
return info.add_instruction(make_op("multibroadcast", {{"out_lens", out_lens}}), args[0]);
if(arg_s.empty())
{
// variable dims input
return info.add_instruction(make_op("broadcast_with_dims"), args[0], args[1]);
}
else
{
std::vector<std::size_t> dims;
arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
auto out_lens = compute_broadcasted_lens(in_lens, dims);
return info.add_instruction(make_op("multibroadcast", {{"out_lens", out_lens}}),
args[0]);
}
}
};

Expand Down
30 changes: 30 additions & 0 deletions src/simplify_dyn_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,35 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

/**
 * Rewrites a `broadcast_with_dims` instruction whose data input has a static
 * shape and whose `dims` input is a compile-time constant into an equivalent
 * `multibroadcast` instruction with a fixed `out_lens` attribute.
 */
struct find_broadcast_with_dims_static
{
    auto matcher() const
    {
        return match::name("broadcast_with_dims")(match::nargs(2),
                                                  match::arg(0)(match::static_shape()),
                                                  match::arg(1)(match::is_constant()));
    }

    void apply(module& m, const match::matcher_result& mr) const
    {
        auto bwd_ins    = mr.result;
        auto data_input = bwd_ins->inputs().at(0);
        auto dims_input = bwd_ins->inputs().at(1);

        // evaluate the constant `dims` input to obtain the static output lengths
        std::vector<std::size_t> out_lens;
        dims_input->eval().visit([&](auto vals) { out_lens.assign(vals.begin(), vals.end()); });

        m.replace_instruction(
            bwd_ins, make_op("multibroadcast", {{"out_lens", out_lens}}), data_input);
    }
};

/**
* Convert a Resize op. with Nearest mode to an implementation using Gather op.
* From: resize[scales={...}/sizes={...},](static, constant)
Expand Down Expand Up @@ -586,6 +615,7 @@ struct simplify_select_module_output_shape
void simplify_dyn_ops::apply(module& m) const
{
match::find_matches(m,
find_broadcast_with_dims_static{},
find_resize_static{},
find_static_dimensions_of{},
find_const_alloc_reshapes{},
Expand Down
19 changes: 19 additions & 0 deletions test/onnx/expand_dyn_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
 expand_dyn_test:q

x
dimsy"Expandexpand_dyn_testZ
x



Z
dims


b
y




B
11 changes: 11 additions & 0 deletions test/onnx/gen_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2203,6 +2203,17 @@ def expand_test():
return ([shape_const, node], [x], [y])


@onnx_test()
def expand_dyn_test():
    # Expand whose `dims` comes in as a graph input (not an initializer), so
    # the parser cannot fold it and must emit broadcast_with_dims.
    node = onnx.helper.make_node('Expand', inputs=['x', 'dims'], outputs=['y'])

    data = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 1, 1])
    dims_param = helper.make_tensor_value_info('dims', TensorProto.INT64, [4])
    out = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3, 4, 5])

    return ([node], [data, dims_param], [out])


@onnx_test(True)
def external_constant_test():
x = np.array([0, 1, 2])
Expand Down
16 changes: 15 additions & 1 deletion test/onnx/parse/expand_test.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -37,3 +37,17 @@ TEST_CASE(expand_test)
auto prog = optimize_onnx("expand_test.onnx");
EXPECT(p == prog);
}

TEST_CASE(expand_dyn_test)
{
    // Expand with a non-constant dims input should parse to broadcast_with_dims
    migraphx::program p;
    auto* mm = p.get_main_module();
    migraphx::shape data_shape(migraphx::shape::float_type, {3, 1, 1});
    migraphx::shape dims_shape(migraphx::shape::int64_type, {4});
    auto data_param = mm->add_parameter("x", data_shape);
    auto dims_param = mm->add_parameter("dims", dims_shape);
    mm->add_instruction(migraphx::make_op("broadcast_with_dims"), data_param, dims_param);

    auto prog = optimize_onnx("expand_dyn_test.onnx");
    EXPECT(p == prog);
}
37 changes: 37 additions & 0 deletions test/op_shape_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -942,6 +942,43 @@ TEST_CASE(broadcast_for_dot_dyn2)
s0);
}

TEST_CASE(broadcast_with_dims0)
{
    using migraphx::shape;
    // static 2D input with a length-4 dims vector -> rank-4 fully dynamic output
    shape data_shape{shape::float_type, {2, 4}};
    shape dims_shape{shape::int64_type, {4}};
    const std::size_t max_val = std::numeric_limits<std::size_t>::max();
    std::vector<shape::dynamic_dimension> expected_dims(4, shape::dynamic_dimension{0, max_val});
    expect_shape(shape{shape::float_type, expected_dims},
                 migraphx::make_op("broadcast_with_dims"),
                 data_shape,
                 dims_shape);
}

TEST_CASE(broadcast_with_dims1)
{
    using migraphx::shape;
    // dims vector shorter than the input rank: output rank stays at 3
    shape data_shape{shape::int32_type, {1, 2, 4}};
    shape dims_shape{shape::int64_type, {1}};
    const std::size_t max_val = std::numeric_limits<std::size_t>::max();
    std::vector<shape::dynamic_dimension> expected_dims(3, shape::dynamic_dimension{0, max_val});
    expect_shape(shape{shape::int32_type, expected_dims},
                 migraphx::make_op("broadcast_with_dims"),
                 data_shape,
                 dims_shape);
}

TEST_CASE(broadcast_with_dims2)
{
    using migraphx::shape;
    // dynamic-shaped input tensor: output rank follows the length-4 dims vector
    shape data_shape{shape::float_type, {{1, 4}, {2, 2}, {4, 4}}};
    shape dims_shape{shape::int64_type, {4}};
    const std::size_t max_val = std::numeric_limits<std::size_t>::max();
    std::vector<shape::dynamic_dimension> expected_dims(4, shape::dynamic_dimension{0, max_val});
    expect_shape(shape{shape::float_type, expected_dims},
                 migraphx::make_op("broadcast_with_dims"),
                 data_shape,
                 dims_shape);
}

TEST_CASE(flatten_shape)
{
migraphx::shape input{migraphx::shape::float_type, {2, 4, 6, 8}};
Expand Down
Loading

0 comments on commit 6e496c1

Please sign in to comment.