
Parse ONNX Expand to broadcast_with_dims #2799

Merged 55 commits on Apr 16, 2024

Commits:
b72ad09 initial (CharlieL7, Nov 27, 2023)
0ef0d0b fixes (CharlieL7, Nov 27, 2023)
47a07c3 add dynamic_dimension.within_range() (CharlieL7, Nov 27, 2023)
bc062ca Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX… (CharlieL7, Nov 27, 2023)
8053390 some progress (CharlieL7, Dec 6, 2023)
3b227ad Merge branch 'develop' of github.com:ROCm/AMDMIGraphX into dot_broadcast (CharlieL7, Jan 9, 2024)
f999ae4 ref tests (CharlieL7, Jan 10, 2024)
be9b8c1 Progress 2 (CharlieL7, Jan 10, 2024)
b4b0490 Test updates and fixes (CharlieL7, Jan 11, 2024)
39c127a Merge branch 'develop' of github.com:ROCm/AMDMIGraphX into dot_broadcast (CharlieL7, Jan 11, 2024)
a586090 Update more tests (CharlieL7, Jan 11, 2024)
0978db3 Fix typo (CharlieL7, Jan 12, 2024)
1b7c8ff Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX… (CharlieL7, Jan 18, 2024)
55aa1d8 License update (CharlieL7, Jan 18, 2024)
3068e0e Merge branch 'dot_broadcast' of github.com:ROCmSoftwarePlatform/AMDMI… (CharlieL7, Jan 18, 2024)
f052cdd Fixes and review updates (CharlieL7, Jan 19, 2024)
9806538 More updates/fixes (CharlieL7, Jan 24, 2024)
ac190ce Formatting and date (CharlieL7, Jan 24, 2024)
3c0a75e Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX… (CharlieL7, Jan 24, 2024)
357d9a3 Fix compile (CharlieL7, Jan 25, 2024)
e43c298 Fix op names (CharlieL7, Feb 5, 2024)
f595a06 Other CI fixes (CharlieL7, Feb 5, 2024)
0d73237 Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX… (CharlieL7, Feb 5, 2024)
2391cb1 Merge branch 'develop' into dot_broadcast (CharlieL7, Feb 9, 2024)
a686c8c Merge branch 'develop' into dot_broadcast (CharlieL7, Feb 13, 2024)
a7346f7 add op (CharlieL7, Feb 13, 2024)
f545145 Progress (CharlieL7, Feb 15, 2024)
a7261f5 start test (CharlieL7, Feb 15, 2024)
f8c5597 Fix compute() and tests (CharlieL7, Feb 19, 2024)
4ecd2b3 Add more tests (CharlieL7, Feb 20, 2024)
38bf0e3 Formatting (CharlieL7, Feb 20, 2024)
d67c1c4 Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX… (CharlieL7, Mar 1, 2024)
9537ef0 Revert resize_linear_non_const_test.cpp (CharlieL7, Mar 1, 2024)
554e850 Revert docstring change for multibroadcast.hpp (CharlieL7, Mar 1, 2024)
4c2643a Merge branch 'expand_dyn_op' into broadcast_with_dims_matcher (bpickrel, Mar 25, 2024)
beeed3a added simplify_dyn_ops matcher to substitute broadcast_with_dims op, … (bpickrel, Mar 26, 2024)
96950d6 tidy changes (bpickrel, Mar 26, 2024)
33d6bc3 Merge branch 'develop' of github.com:ROCm/AMDMIGraphX into expand_dyn_op (CharlieL7, Mar 28, 2024)
d0b9381 formatting (CharlieL7, Mar 28, 2024)
a6c057a Fix tests and function call (CharlieL7, Mar 28, 2024)
e03c4f5 licensing (CharlieL7, Mar 28, 2024)
032674b formatting (CharlieL7, Mar 28, 2024)
5fc2121 tidy ignore line (CharlieL7, Mar 28, 2024)
f46788b Merge branch 'expand_dyn_op' into broadcast_with_dims_matcher (bpickrel, Mar 29, 2024)
ad7998d Comments (bpickrel, Mar 29, 2024)
165d23e cosmetic variable name change (bpickrel, Mar 29, 2024)
39b15ce comment (bpickrel, Mar 29, 2024)
371a873 remove duplicated tests (bpickrel, Apr 1, 2024)
21f42b7 added a negative simplify_dyn_ops test (bpickrel, Apr 8, 2024)
e8ffc81 style; changed vector initialization (bpickrel, Apr 8, 2024)
c575299 update tests (CharlieL7, Apr 10, 2024)
4d31064 Merge branch 'develop' of github.com:ROCm/AMDMIGraphX into expand_dyn_op (CharlieL7, Apr 10, 2024)
3c019cf Merge branch 'develop' into expand_dyn_op (causten, Apr 11, 2024)
f7800ac Merge branch 'broadcast_with_dims_matcher' into expand_dyn_op (bpickrel, Apr 11, 2024)
66e3f0e correct data type in test (bpickrel, Apr 11, 2024)
1 change: 1 addition & 0 deletions src/CMakeLists.txt
@@ -140,6 +140,7 @@ register_migraphx_ops(
atan
broadcast
broadcast_for_dot
broadcast_with_dims
capture
ceil
clip
84 changes: 84 additions & 0 deletions src/include/migraphx/op/broadcast_with_dims.hpp
@@ -0,0 +1,84 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_BROADCAST_WITH_DIMS_HPP
#define MIGRAPHX_GUARD_OPERATORS_BROADCAST_WITH_DIMS_HPP

#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/common.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

/**
* Broadcast the input tensor to the shape defined by the values of the second input.
* Used as `broadcast_with_dims(input_tensor, dims)`, where dims is a vector of integer dimensions.
* `input_tensor` must be broadcastable with `dims`; otherwise this operator will throw at compute time.
* This operator can be replaced with `multibroadcast(input_tensor)` if the `dims` vector is
* constant.
*
* Example:
* input_tensor shape: lens = {2, 3}, strides = {3, 1}
* dims = [4, 1, 3]
* output shape: lens = {4, 2, 3}, strides = {0, 3, 1}
*/
[Review comment from a collaborator] I like this example; make this one of the tests so we can have a 1:1 with expected behavior.

struct broadcast_with_dims
{
std::string name() const { return "broadcast_with_dims"; }

shape compute_shape(const std::vector<shape>& inputs) const
{
migraphx::check_shapes{inputs, *this, true}.has(2);
// check that second input has a static shape
(void)migraphx::check_shapes{inputs.begin() + 1, inputs.end(), *this, false};
// output tensor rank is greater of input_tensor rank or length of dims vector
const auto& input_tensor_shape = inputs.at(0);
const auto& dims_shape = inputs.at(1);
size_t out_ndim = std::max(input_tensor_shape.ndim(), dims_shape.lens().at(0));
std::size_t max_int = std::numeric_limits<std::size_t>::max();
std::vector<shape::dynamic_dimension> dyn_dims(out_ndim,
shape::dynamic_dimension{0, max_int});
return {input_tensor_shape.type(), dyn_dims};
}

argument compute(const shape& output_shape, const std::vector<argument>& args) const
{
auto s0 = args.at(0).get_shape();
const auto& in_lens = s0.lens();
std::vector<std::size_t> dims_input(output_shape.ndim());
args.at(1).visit([&](auto a) { dims_input.assign(a.begin(), a.end()); });
auto out_lens = compute_broadcasted_lens(in_lens, dims_input);
auto out_shape = make_bcast_shape(s0, out_lens);
return args[0].reshape(out_shape);
}
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
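Because the values in `dims` are unknown until runtime, `compute_shape` can only return a fully dynamic shape (every dimension `{0, max_int}`); the real output lengths are computed in `compute` via `compute_broadcasted_lens`. The broadcast rule it applies is the standard multidirectional (NumPy/ONNX) one, sketched here as a hypothetical standalone Python helper (not the MIGraphX API):

```python
def broadcasted_lens(a, b):
    """Multidirectional (NumPy-style) broadcast of two shape vectors.

    Shapes are right-aligned; each aligned pair of dimensions must be
    equal, or one of them must be 1 (which stretches to the other).
    """
    # Right-align the shorter shape by padding with leading 1s.
    n = max(len(a), len(b))
    a = [1] * (n - len(a)) + list(a)
    b = [1] * (n - len(b)) + list(b)
    out = []
    for x, y in zip(a, b):
        if x != y and x != 1 and y != 1:
            raise ValueError(f"incompatible dimensions {x} and {y}")
        out.append(max(x, y))
    return out

# The example from the docstring: lens {2, 3} broadcast with dims [4, 1, 3]
print(broadcasted_lens([2, 3], [4, 1, 3]))  # [4, 2, 3]
```

This matches the docstring example: the output takes strides `{0, 3, 1}` because the new leading dimension is a stride-0 broadcast axis.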
20 changes: 14 additions & 6 deletions src/onnx/parse_expand.cpp
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -43,11 +43,19 @@ struct parse_expand : op_parser<parse_expand>
{
auto in_lens = args[0]->get_shape().lens();
migraphx::argument arg_s = args[1]->eval();
-        check_arg_empty(arg_s, "Expand: dynamic shape is not supported");
-        std::vector<std::size_t> dims;
-        arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
-        auto out_lens = compute_broadcasted_lens(in_lens, dims);
-        return info.add_instruction(make_op("multibroadcast", {{"out_lens", out_lens}}), args[0]);
+        if(arg_s.empty())
+        {
+            // variable dims input
+            return info.add_instruction(make_op("broadcast_with_dims"), args[0], args[1]);
+        }
+        else
+        {
+            std::vector<std::size_t> dims;
+            arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
+            auto out_lens = compute_broadcasted_lens(in_lens, dims);
+            return info.add_instruction(make_op("multibroadcast", {{"out_lens", out_lens}}),
+                                        args[0]);
+        }
}
};
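The parser's decision can be sketched in Python (names and return format are illustrative only, not the MIGraphX API): if the `dims` input evaluates to a constant at parse time, fold it into a static `multibroadcast`; otherwise defer to `broadcast_with_dims`:

```python
import numpy as np

def parse_expand(in_lens, dims_values):
    """Sketch of the Expand-parsing branch.

    dims_values is None when the second input cannot be evaluated at
    parse time (a variable dims tensor); otherwise it holds the
    constant dimension values.
    """
    if dims_values is None:
        # dims unknown until runtime: emit the new dynamic op
        return ("broadcast_with_dims", None)
    # dims known now: constant-fold to a static multibroadcast
    out_lens = list(np.broadcast_shapes(tuple(in_lens), tuple(dims_values)))
    return ("multibroadcast", out_lens)

print(parse_expand([3, 1, 1], None))       # ('broadcast_with_dims', None)
print(parse_expand([2, 3], [4, 1, 3]))     # ('multibroadcast', [4, 2, 3])
```

The key behavioral change over the old parser is the first branch: a non-constant `dims` no longer raises "dynamic shape is not supported".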

30 changes: 30 additions & 0 deletions src/simplify_dyn_ops.cpp
@@ -33,6 +33,35 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

/**
* Convert `broadcast_with_dims` operators that have a static input tensor and a constant `dims`
* input into a `multibroadcast` operator with a static output shape attribute.
*/
struct find_broadcast_with_dims_static
{
auto matcher() const
{
return match::name("broadcast_with_dims")(match::nargs(2),
match::arg(0)(match::static_shape()),
match::arg(1)(match::is_constant()));
}

void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto inputs = ins->inputs();

// read the values of arg(1) to create input to multibroadcast
std::vector<size_t> sizes_vec;
inputs.at(1)->eval().visit(
[&](auto output) { sizes_vec.assign(output.begin(), output.end()); });

m.replace_instruction(
ins, make_op("multibroadcast", {{"out_lens", sizes_vec}}), inputs.at(0));
}
};
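The simplification this matcher performs is sound because once `dims` is a known constant, the output shape is fixed. Numpy can stand in for both ops to illustrate the equivalence (`broadcast_shapes` playing the role of `compute_broadcasted_lens`, `broadcast_to` the role of the static `multibroadcast`):

```python
import numpy as np

# Static input tensor, lens = {2, 3}
x = np.arange(6, dtype=np.float32).reshape(2, 3)
# Constant dims input, as read out of arg(1) by the matcher
dims = np.array([4, 1, 3], dtype=np.int64)

# out_lens, as compute_broadcasted_lens would produce it
out_lens = np.broadcast_shapes(x.shape, tuple(dims))
# The replacement: a broadcast to a fixed output shape (stride-0 view)
y = np.broadcast_to(x, out_lens)
print(out_lens, y.shape)  # (4, 2, 3) (4, 2, 3)
```

Note the guard conditions in the matcher mirror this: `arg(0)` must have a static shape and `arg(1)` must be constant, or the output lengths cannot be computed at compile time.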

/**
* Convert a Resize op. with Nearest mode to an implementation using Gather op.
* From: resize[scales={...}/sizes={...},](static, constant)
@@ -586,6 +615,7 @@ struct simplify_select_module_output_shape
void simplify_dyn_ops::apply(module& m) const
{
match::find_matches(m,
find_broadcast_with_dims_static{},
find_resize_static{},
find_static_dimensions_of{},
find_const_alloc_reshapes{},
19 changes: 19 additions & 0 deletions test/onnx/expand_dyn_test.onnx
@@ -0,0 +1,19 @@
(binary ONNX protobuf for the expand_dyn_test model; not human-readable)
11 changes: 11 additions & 0 deletions test/onnx/gen_onnx.py
@@ -2203,6 +2203,17 @@ def expand_test():
return ([shape_const, node], [x], [y])


@onnx_test()
def expand_dyn_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 1, 1])
dims_in = helper.make_tensor_value_info('dims', TensorProto.INT64, [4])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3, 4, 5])

node = onnx.helper.make_node('Expand', inputs=['x', 'dims'], outputs=['y'])

return ([node], [x, dims_in], [y])
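The generated model only fixes the length of `dims` (4 values), not its contents; the declared output shape `[2, 3, 4, 5]` is consistent with feeding dims = [2, 3, 4, 5] at runtime (an assumed value, chosen to match the value_info), as a quick numpy check shows:

```python
import numpy as np

# x matches the test's input shape [3, 1, 1]; broadcasting it with an
# assumed runtime dims of [2, 3, 4, 5] yields the declared output shape.
x = np.zeros((3, 1, 1), dtype=np.float32)
y = np.broadcast_to(x, (2, 3, 4, 5))
print(y.shape)  # (2, 3, 4, 5)
```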


@onnx_test(True)
def external_constant_test():
x = np.array([0, 1, 2])
16 changes: 15 additions & 1 deletion test/onnx/parse/expand_test.cpp
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -37,3 +37,17 @@ TEST_CASE(expand_test)
auto prog = optimize_onnx("expand_test.onnx");
EXPECT(p == prog);
}

TEST_CASE(expand_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s(migraphx::shape::float_type, {3, 1, 1});
auto param = mm->add_parameter("x", s);
migraphx::shape ss(migraphx::shape::int64_type, {4});
auto dims = mm->add_parameter("dims", ss);
mm->add_instruction(migraphx::make_op("broadcast_with_dims"), param, dims);

auto prog = optimize_onnx("expand_dyn_test.onnx");
EXPECT(p == prog);
}
37 changes: 37 additions & 0 deletions test/op_shape_test.cpp
@@ -942,6 +942,43 @@ TEST_CASE(broadcast_for_dot_dyn2)
s0);
}

TEST_CASE(broadcast_with_dims0)
{
using migraphx::shape;
shape s0{migraphx::shape::float_type, {2, 4}};
shape s1{migraphx::shape::int64_type, {4}};
std::size_t max_int = std::numeric_limits<std::size_t>::max();
std::vector<shape::dynamic_dimension> dyn_dims(4, shape::dynamic_dimension{0, max_int});
expect_shape(
shape{shape::float_type, dyn_dims}, migraphx::make_op("broadcast_with_dims"), s0, s1);
}

TEST_CASE(broadcast_with_dims1)
{
using migraphx::shape;
shape s0{migraphx::shape::int32_type, {1, 2, 4}};
shape s1{migraphx::shape::int64_type, {1}};
std::size_t max_int = std::numeric_limits<std::size_t>::max();
std::vector<shape::dynamic_dimension> dyn_dims(3, shape::dynamic_dimension{0, max_int});
expect_shape(shape{migraphx::shape::int32_type, dyn_dims},
migraphx::make_op("broadcast_with_dims"),
s0,
s1);
}

TEST_CASE(broadcast_with_dims2)
{
using migraphx::shape;
shape s0{migraphx::shape::float_type, {{1, 4}, {2, 2}, {4, 4}}};
shape s1{migraphx::shape::int64_type, {4}};
std::size_t max_int = std::numeric_limits<std::size_t>::max();
std::vector<shape::dynamic_dimension> dyn_dims(4, shape::dynamic_dimension{0, max_int});
expect_shape(shape{migraphx::shape::float_type, dyn_dims},
migraphx::make_op("broadcast_with_dims"),
s0,
s1);
}

TEST_CASE(flatten_shape)
{
migraphx::shape input{migraphx::shape::float_type, {2, 4, 6, 8}};