Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split dynamic shape parsing update #3034

Merged
merged 20 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 117 additions & 53 deletions src/onnx/parse_split.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,75 +49,139 @@ struct parse_split : op_parser<parse_split>
axis = parser.parse_value(info.attributes.at("axis")).at<int>();
}

auto lens = args[0]->get_shape().lens();
int64_t n_rank = lens.size();
int64_t tuned_axis = tune_axis(n_rank, axis, opd.op_name);
const auto& input_shape = args[0]->get_shape();
// axis over which the split occurs (split_axis)
int64_t tuned_axis = tune_axis(input_shape.ndim(), axis, opd.op_name);

std::vector<int64_t> vec_splits;
if(contains(info.attributes, "split"))
{
literal s = parser.parse_value(info.attributes.at("split"));
s.visit([&](auto v) { vec_splits.assign(v.begin(), v.end()); });
}
else if(args.size() == 2)
auto split_axis_is_fixed = [&]() {
return input_shape.dyn_dims().at(tuned_axis).is_fixed();
};
Comment on lines +181 to +183
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this can simply be a bool variable.


if(input_shape.dynamic() and not split_axis_is_fixed())
{
auto s = args[1]->eval();
check_arg_empty(s, "Split: dynamic shape is not supported");
s.visit([&](auto v) { vec_splits.assign(v.begin(), v.end()); });
if(contains(info.attributes, "split"))
{
MIGRAPHX_THROW("PARSE_SPLIT: dynamic input and non-fixed split axis and `split` "
"attribute not supported");
}
if(args.size() == 2)
{
MIGRAPHX_THROW("PARSE_SPLIT: dynamic input and non-fixed split axis and `split` "
"input not supported");
}

std::size_t num_outputs = info.num_outputs;
std::vector<instruction_ref> ret_ins(num_outputs);

// Doing shape calculations for the splits in the graph
auto split_dim = info.add_instruction(
make_op("dimensions_of", {{"start", tuned_axis}, {"end", tuned_axis + 1}}),
args[0]);
shape int64_scalar_shape{shape::int64_type, {1}, {0}};
auto num_outputs_lit = info.add_literal(literal{int64_scalar_shape, {num_outputs}});
auto num_outputs_minus_1_lit =
info.add_literal(literal{int64_scalar_shape, {num_outputs - 1}});
// (A + (B - 1)) / B == ceil(A / B)
auto chunk_size = info.add_instruction(
make_op("div"),
info.add_instruction(make_op("add"), split_dim, num_outputs_minus_1_lit),
num_outputs_lit);
for(int n = 0; n < num_outputs - 1; ++n)
{
// slice(input, starts = {n * chunk_size}, ends = {(n+1) * chunk_size}); axes =
// {tuned_axis}
ret_ins.at(n) = info.add_instruction(
make_op("slice", {{"axes", {tuned_axis}}}),
args[0],
info.add_instruction(make_op("mul"),
chunk_size,
info.add_literal(literal{int64_scalar_shape, {n}})),
info.add_instruction(make_op("mul"),
chunk_size,
info.add_literal(literal{int64_scalar_shape, {n + 1}})));
}
// last slice: slice(input, starts = {n * chunk_size}); ends = max_int, axes =
// {tuned_axis}
ret_ins.at(num_outputs - 1) = info.add_instruction(
make_op("slice",
{{"axes", {tuned_axis}}, {"ends", {std::numeric_limits<int64_t>::max()}}}),
args[0],
info.add_instruction(
make_op("mul"),
chunk_size,
info.add_literal(literal{int64_scalar_shape, {num_outputs - 1}})));
return ret_ins;
CharlieL7 marked this conversation as resolved.
Show resolved Hide resolved
}
// no split attribute, input is equally divided
else
{
std::size_t num_outputs = info.num_outputs;
// the num_outputs attribute seems to be redundant since we already have
// node_info::num_outputs, but we can still perform an error check
if(contains(info.attributes, "num_outputs"))
// either static shape or fixed dynamic_dimension for split axis
auto tuned_axis_len = input_shape.to_static(0).lens().at(tuned_axis);
std::vector<int64_t> vec_splits;
if(contains(info.attributes, "split"))
{
literal s = parser.parse_value(info.attributes.at("split"));
s.visit([&](auto v) { vec_splits.assign(v.begin(), v.end()); });
}
else if(args.size() == 2)
{
num_outputs =
parser.parse_value(info.attributes.at("num_outputs")).at<std::size_t>();
if(num_outputs != info.num_outputs)
auto s = args[1]->eval();
check_arg_empty(s, "PARSE_SPLIT: non-constant `split` input is not supported");
s.visit([&](auto v) { vec_splits.assign(v.begin(), v.end()); });
}
// no split attribute, input is equally divided
else
{
std::size_t num_outputs = info.num_outputs;
// the num_outputs attribute seems to be redundant since we already have
// node_info::num_outputs, but we can still perform an error check
if(contains(info.attributes, "num_outputs"))
{
num_outputs =
parser.parse_value(info.attributes.at("num_outputs")).at<std::size_t>();
if(num_outputs != info.num_outputs)
{
MIGRAPHX_THROW("PARSE_SPLIT: num_outputs attribute " +
std::to_string(num_outputs) +
" doesn't match actual number of outputs " +
std::to_string(info.num_outputs) + "!");
}
}
if(tuned_axis_len % num_outputs == 0)
CharlieL7 marked this conversation as resolved.
Show resolved Hide resolved
{
std::size_t chunk_size = tuned_axis_len / num_outputs;
vec_splits.resize(num_outputs, chunk_size);
}
else
{
MIGRAPHX_THROW("PARSE_SPLIT: num_outputs attribute " +
std::to_string(num_outputs) +
" doesn't match actual number of outputs " +
std::to_string(info.num_outputs) + "!");
std::size_t chunk_size = tuned_axis_len / num_outputs + 1;
std::size_t last_chunk_size = tuned_axis_len - chunk_size * (num_outputs - 1);
vec_splits.resize(num_outputs - 1, chunk_size);
vec_splits.push_back(last_chunk_size);
}
}

if(lens[tuned_axis] % num_outputs == 0)
if(std::accumulate(vec_splits.begin(), vec_splits.end(), int64_t(0)) !=
CharlieL7 marked this conversation as resolved.
Show resolved Hide resolved
static_cast<int64_t>(tuned_axis_len))
{
std::size_t chunk_size = lens[tuned_axis] / num_outputs;
vec_splits.resize(num_outputs, chunk_size);
MIGRAPHX_THROW(
"PARSE_SPLIT: sum of split attribute unequal to dim size of axis! tuned axis:" +
std::to_string(tuned_axis_len) + " Output " + to_string_range(vec_splits) +
" Rank " + std::to_string(input_shape.ndim()));
}
else

std::vector<instruction_ref> ret_ins;
int64_t start = 0;
for(auto sl : vec_splits)
{
std::size_t chunk_size = lens[tuned_axis] / num_outputs + 1;
std::size_t last_chunk_size = lens[tuned_axis] - chunk_size * (num_outputs - 1);
vec_splits.resize(num_outputs - 1, chunk_size);
vec_splits.push_back(last_chunk_size);
ret_ins.push_back(info.add_instruction(
make_op("slice",
{{"axes", {axis}}, {"starts", {start}}, {"ends", {start + sl}}}),
args[0]));
start += sl;
}
}

if(std::accumulate(vec_splits.begin(), vec_splits.end(), int64_t(0)) !=
static_cast<int64_t>(lens[tuned_axis]))
{
MIGRAPHX_THROW(
"PARSE_SPLIT: sum of split attribute unequal to dim size of axis! tuned axis:" +
std::to_string(lens[tuned_axis]) + " Output " + to_string_range(vec_splits) +
" Rank " + std::to_string(n_rank) + " Len outs " + to_string_range(lens));
return ret_ins;
}

std::vector<instruction_ref> ret_ins;
int64_t start = 0;
for(auto sl : vec_splits)
{
ret_ins.push_back(info.add_instruction(
make_op("slice", {{"axes", {axis}}, {"starts", {start}}, {"ends", {start + sl}}}),
args[0]));
start += sl;
}

return ret_ins;
}
};

Expand Down
67 changes: 67 additions & 0 deletions test/onnx/gen_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -10827,6 +10827,73 @@ def split_test_invalid_num_outputs():

return ([node], [x], [y1, y2, y3, y4])

@onnx_test()
def split_dyn_input_fixed_split_axis_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [None, 15])
y1 = helper.make_tensor_value_info('y1', TensorProto.FLOAT, [None, 5])
y2 = helper.make_tensor_value_info('y2', TensorProto.FLOAT, [None, 5])
y3 = helper.make_tensor_value_info('y3', TensorProto.FLOAT, [None, 5])

node = onnx.helper.make_node('Split',
inputs=['x'],
outputs=['y1', 'y2', 'y3'],
axis=1)

return ([node], [x], [y1, y2, y3])

@onnx_test()
def split_dyn_input_dyn_split_axis_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [None, 15])
y1 = helper.make_tensor_value_info('y1', TensorProto.FLOAT, [None, 5])
y2 = helper.make_tensor_value_info('y2', TensorProto.FLOAT, [None, 5])
y3 = helper.make_tensor_value_info('y3', TensorProto.FLOAT, [None, 5])

node = onnx.helper.make_node('Split',
inputs=['x'],
outputs=['y1', 'y2', 'y3'],
axis=0)

return ([node], [x], [y1, y2, y3])

@onnx_test()
def split_dyn_input_split_attr_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [None, 15])
y1 = helper.make_tensor_value_info('y1', TensorProto.FLOAT, [None, 5])
y2 = helper.make_tensor_value_info('y2', TensorProto.FLOAT, [None, 5])
y3 = helper.make_tensor_value_info('y3', TensorProto.FLOAT, [None, 5])

node = onnx.helper.make_node('Split',
inputs=['x'],
outputs=['y1', 'y2', 'y3'],
axis=0,
split=[7, 4, 4])

return ([node], [x], [y1, y2, y3])

@onnx_test()
def split_dyn_input_split_input_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [None, 15])
y1 = helper.make_tensor_value_info('y1', TensorProto.FLOAT, [None, 5])
y2 = helper.make_tensor_value_info('y2', TensorProto.FLOAT, [None, 5])
y3 = helper.make_tensor_value_info('y3', TensorProto.FLOAT, [None, 5])

split = np.ones(3) * 5
split_tensor = helper.make_tensor(name="split",
data_type=TensorProto.INT64,
dims=split.shape,
vals=split.astype(np.int64))
const_node = helper.make_node("Constant",
inputs=[],
outputs=['split'],
value=split_tensor)

node = onnx.helper.make_node('Split',
inputs=['x', 'split'],
outputs=['y1', 'y2', 'y3'],
axis=0)

return ([const_node, node], [x], [y1, y2, y3])


@onnx_test()
def sqrt_test():
Expand Down
109 changes: 109 additions & 0 deletions test/onnx/parse/split_dyn_input.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include <onnx_test.hpp>

TEST_CASE(split_dyn_input_fixed_split_axis_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {{10, 30}, {15, 15}}});
auto r1 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {5}}}), input);
auto r2 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {5}}, {"ends", {10}}}), input);
auto r3 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {10}}, {"ends", {15}}}), input);
mm->add_return({r1, r2, r3});

migraphx::onnx_options options;
options.default_dyn_dim_value = {10, 30};
auto prog = migraphx::parse_onnx("split_dyn_input_fixed_split_axis_test.onnx", options);
CharlieL7 marked this conversation as resolved.
Show resolved Hide resolved
EXPECT(p == prog);
}

TEST_CASE(split_dyn_input_dyn_split_axis_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {{10, 30}, {15, 15}}});
auto split_dim =
mm->add_instruction(migraphx::make_op("dimensions_of", {{"start", 0}, {"end", 1}}), input);
migraphx::shape int64_scalar_shape{migraphx::shape::int64_type, {1}, {0}};
auto num_outputs_lit = mm->add_literal(migraphx::literal{int64_scalar_shape, {3}});
auto num_outputs_minus_1_lit = mm->add_literal(migraphx::literal{int64_scalar_shape, {2}});
auto chunk_size = mm->add_instruction(
migraphx::make_op("div"),
mm->add_instruction(migraphx::make_op("add"), split_dim, num_outputs_minus_1_lit),
num_outputs_lit);
auto r1 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}}),
input,
mm->add_instruction(migraphx::make_op("mul"),
chunk_size,
mm->add_literal(migraphx::literal{int64_scalar_shape, {0}})),
mm->add_instruction(migraphx::make_op("mul"),
chunk_size,
mm->add_literal(migraphx::literal{int64_scalar_shape, {1}})));
auto r2 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}}),
input,
mm->add_instruction(migraphx::make_op("mul"),
chunk_size,
mm->add_literal(migraphx::literal{int64_scalar_shape, {1}})),
mm->add_instruction(migraphx::make_op("mul"),
chunk_size,
mm->add_literal(migraphx::literal{int64_scalar_shape, {2}})));
auto r3 = mm->add_instruction(
migraphx::make_op("slice",
{{"axes", {0}}, {"ends", {std::numeric_limits<int64_t>::max()}}}),
input,
mm->add_instruction(migraphx::make_op("mul"),
chunk_size,
mm->add_literal(migraphx::literal{int64_scalar_shape, {2}})));
mm->add_return({r1, r2, r3});

migraphx::onnx_options options;
options.default_dyn_dim_value = {10, 30};
auto prog = migraphx::parse_onnx("split_dyn_input_dyn_split_axis_test.onnx", options);
EXPECT(p == prog);
}

TEST_CASE(split_dyn_input_split_attr_error)
{
migraphx::onnx_options options;
options.default_dyn_dim_value = {10, 30};
EXPECT(test::throws(
[&] { migraphx::parse_onnx("split_dyn_input_split_attr_test.onnx", options); }));
}

TEST_CASE(split_dyn_input_split_input_error)
{
migraphx::onnx_options options;
options.default_dyn_dim_value = {10, 30};
EXPECT(test::throws(
[&] { migraphx::parse_onnx("split_dyn_input_split_input_test.onnx", options); }));
}
Binary file added test/onnx/split_dyn_input_dyn_split_axis_test.onnx
Binary file not shown.
Binary file not shown.
Binary file added test/onnx/split_dyn_input_split_attr_test.onnx
Binary file not shown.
Binary file added test/onnx/split_dyn_input_split_input_test.onnx
Binary file not shown.
Loading
Loading