Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split dynamic shape parsing update #3034

Merged
merged 20 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 135 additions & 63 deletions src/onnx/parse_split.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,131 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

auto parse_dyn_split(const onnx_parser::node_info& info,
const std::vector<instruction_ref>& args,
int64_t tuned_axis)
{
if(contains(info.attributes, "split"))
{
MIGRAPHX_THROW("PARSE_SPLIT: dynamic input and non-fixed split axis and `split` "
"attribute not supported");
}
if(args.size() == 2)
{
MIGRAPHX_THROW("PARSE_SPLIT: dynamic input and non-fixed split axis and `split` "
"input not supported");
}

std::size_t num_outputs = info.num_outputs;
std::vector<instruction_ref> ret_ins(num_outputs);

// Doing shape calculations for the splits in the graph
auto split_dim = info.add_instruction(
make_op("dimensions_of", {{"start", tuned_axis}, {"end", tuned_axis + 1}}), args[0]);
shape int64_scalar_shape{shape::int64_type, {1}, {0}};
auto num_outputs_lit = info.add_literal(literal{int64_scalar_shape, {num_outputs}});
auto num_outputs_minus_1_lit = info.add_literal(literal{int64_scalar_shape, {num_outputs - 1}});
// (A + (B - 1)) / B == ceil(A / B)
auto chunk_size = info.add_instruction(
make_op("div"),
info.add_instruction(make_op("add"), split_dim, num_outputs_minus_1_lit),
num_outputs_lit);
for(int n = 0; n < num_outputs - 1; ++n)
{
// slice(input, starts = {n * chunk_size}, ends = {(n+1) * chunk_size}); axes =
// {tuned_axis}
ret_ins.at(n) = info.add_instruction(
make_op("slice", {{"axes", {tuned_axis}}}),
args[0],
info.add_instruction(
make_op("mul"), chunk_size, info.add_literal(literal{int64_scalar_shape, {n}})),
info.add_instruction(make_op("mul"),
chunk_size,
info.add_literal(literal{int64_scalar_shape, {n + 1}})));
}
// last slice: slice(input, starts = {n * chunk_size}); ends = max_int, axes =
// {tuned_axis}
ret_ins.at(num_outputs - 1) = info.add_instruction(
make_op("slice", {{"axes", {tuned_axis}}, {"ends", {std::numeric_limits<int64_t>::max()}}}),
args[0],
info.add_instruction(make_op("mul"),
chunk_size,
info.add_literal(literal{int64_scalar_shape, {num_outputs - 1}})));
return ret_ins;
}

auto parse_static_split(const onnx_parser::node_info& info,
const onnx_parser& parser,
const std::vector<instruction_ref>& args,
int64_t tuned_axis)
{
const auto& input_shape = args[0]->get_shape();
// either static shape or fixed dynamic_dimension for split axis
auto tuned_axis_len = input_shape.to_static(0).lens().at(tuned_axis);
std::vector<int64_t> vec_splits;
if(contains(info.attributes, "split"))
{
literal s = parser.parse_value(info.attributes.at("split"));
s.visit([&](auto v) { vec_splits.assign(v.begin(), v.end()); });
}
else if(args.size() == 2)
{
auto s = args[1]->eval();
check_arg_empty(s, "PARSE_SPLIT: non-constant `split` input is not supported");
s.visit([&](auto v) { vec_splits.assign(v.begin(), v.end()); });
}
// no split attribute, input is equally divided
else
{
std::size_t num_outputs = info.num_outputs;
// the num_outputs attribute seems to be redundant since we already have
// node_info::num_outputs, but we can still perform an error check
if(contains(info.attributes, "num_outputs"))
{
num_outputs = parser.parse_value(info.attributes.at("num_outputs")).at<std::size_t>();
if(num_outputs != info.num_outputs)
{
MIGRAPHX_THROW("PARSE_SPLIT: num_outputs attribute " + std::to_string(num_outputs) +
" doesn't match actual number of outputs " +
std::to_string(info.num_outputs) + "!");
}
}
if(tuned_axis_len % num_outputs == 0)
{
std::size_t chunk_size = tuned_axis_len / num_outputs;
vec_splits.resize(num_outputs, chunk_size);
}
else
{
std::size_t chunk_size = tuned_axis_len / num_outputs + 1;
std::size_t last_chunk_size = tuned_axis_len - chunk_size * (num_outputs - 1);
vec_splits.resize(num_outputs - 1, chunk_size);
vec_splits.push_back(last_chunk_size);
Comment on lines +135 to +136
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can potentially lead to resizing and memory allocation twice for the vector. Ideally, you can allocate vector for the splits once for the size (num_outputs) and then put values in that

}
}

if(std::accumulate(vec_splits.begin(), vec_splits.end(), int64_t(0)) !=
static_cast<int64_t>(tuned_axis_len))
{
MIGRAPHX_THROW(
"PARSE_SPLIT: sum of split attribute unequal to dim size of axis! tuned axis:" +
std::to_string(tuned_axis_len) + " Output " + to_string_range(vec_splits) + " Rank " +
std::to_string(input_shape.ndim()));
}

std::vector<instruction_ref> ret_ins;
int64_t start = 0;
for(auto sl : vec_splits)
{
ret_ins.push_back(info.add_instruction(
make_op("slice", {{"axes", {tuned_axis}}, {"starts", {start}}, {"ends", {start + sl}}}),
args[0]));
start += sl;
}

return ret_ins;
}

struct parse_split : op_parser<parse_split>
{
std::vector<op_desc> operators() const { return {{"Split"}}; }
Expand All @@ -49,75 +174,22 @@ struct parse_split : op_parser<parse_split>
axis = parser.parse_value(info.attributes.at("axis")).at<int>();
}

auto lens = args[0]->get_shape().lens();
int64_t n_rank = lens.size();
int64_t tuned_axis = tune_axis(n_rank, axis, opd.op_name);
const auto& input_shape = args[0]->get_shape();
// axis over which the split occurs (split_axis)
int64_t tuned_axis = tune_axis(input_shape.ndim(), axis, opd.op_name);

std::vector<int64_t> vec_splits;
if(contains(info.attributes, "split"))
{
literal s = parser.parse_value(info.attributes.at("split"));
s.visit([&](auto v) { vec_splits.assign(v.begin(), v.end()); });
}
else if(args.size() == 2)
{
auto s = args[1]->eval();
check_arg_empty(s, "Split: dynamic shape is not supported");
s.visit([&](auto v) { vec_splits.assign(v.begin(), v.end()); });
}
// no split attribute, input is equally divided
else
{
std::size_t num_outputs = info.num_outputs;
// the num_outputs attribute seems to be redundant since we already have
// node_info::num_outputs, but we can still perform an error check
if(contains(info.attributes, "num_outputs"))
{
num_outputs =
parser.parse_value(info.attributes.at("num_outputs")).at<std::size_t>();
if(num_outputs != info.num_outputs)
{
MIGRAPHX_THROW("PARSE_SPLIT: num_outputs attribute " +
std::to_string(num_outputs) +
" doesn't match actual number of outputs " +
std::to_string(info.num_outputs) + "!");
}
}

if(lens[tuned_axis] % num_outputs == 0)
{
std::size_t chunk_size = lens[tuned_axis] / num_outputs;
vec_splits.resize(num_outputs, chunk_size);
}
else
{
std::size_t chunk_size = lens[tuned_axis] / num_outputs + 1;
std::size_t last_chunk_size = lens[tuned_axis] - chunk_size * (num_outputs - 1);
vec_splits.resize(num_outputs - 1, chunk_size);
vec_splits.push_back(last_chunk_size);
}
}
auto split_axis_is_fixed = [&]() {
return input_shape.dyn_dims().at(tuned_axis).is_fixed();
};
Comment on lines +181 to +183
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this can simply be a bool variable.


if(std::accumulate(vec_splits.begin(), vec_splits.end(), int64_t(0)) !=
static_cast<int64_t>(lens[tuned_axis]))
if(input_shape.dynamic() and not split_axis_is_fixed())
{
MIGRAPHX_THROW(
"PARSE_SPLIT: sum of split attribute unequal to dim size of axis! tuned axis:" +
std::to_string(lens[tuned_axis]) + " Output " + to_string_range(vec_splits) +
" Rank " + std::to_string(n_rank) + " Len outs " + to_string_range(lens));
return parse_dyn_split(info, args, tuned_axis);
}

std::vector<instruction_ref> ret_ins;
int64_t start = 0;
for(auto sl : vec_splits)
else
{
ret_ins.push_back(info.add_instruction(
make_op("slice", {{"axes", {axis}}, {"starts", {start}}, {"ends", {start + sl}}}),
args[0]));
start += sl;
return parse_static_split(info, parser, args, tuned_axis);
}

return ret_ins;
}
};

Expand Down
71 changes: 71 additions & 0 deletions test/onnx/gen_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -10828,6 +10828,77 @@ def split_test_invalid_num_outputs():
return ([node], [x], [y1, y2, y3, y4])


@onnx_test()
def split_dyn_input_fixed_split_axis_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [None, 15])
y1 = helper.make_tensor_value_info('y1', TensorProto.FLOAT, [None, 5])
y2 = helper.make_tensor_value_info('y2', TensorProto.FLOAT, [None, 5])
y3 = helper.make_tensor_value_info('y3', TensorProto.FLOAT, [None, 5])

node = onnx.helper.make_node('Split',
inputs=['x'],
outputs=['y1', 'y2', 'y3'],
axis=1)

return ([node], [x], [y1, y2, y3])


@onnx_test()
def split_dyn_input_dyn_split_axis_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [None, 15])
y1 = helper.make_tensor_value_info('y1', TensorProto.FLOAT, [None, 5])
y2 = helper.make_tensor_value_info('y2', TensorProto.FLOAT, [None, 5])
y3 = helper.make_tensor_value_info('y3', TensorProto.FLOAT, [None, 5])

node = onnx.helper.make_node('Split',
inputs=['x'],
outputs=['y1', 'y2', 'y3'],
axis=0)

return ([node], [x], [y1, y2, y3])


@onnx_test()
def split_dyn_input_split_attr_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [None, 15])
y1 = helper.make_tensor_value_info('y1', TensorProto.FLOAT, [None, 5])
y2 = helper.make_tensor_value_info('y2', TensorProto.FLOAT, [None, 5])
y3 = helper.make_tensor_value_info('y3', TensorProto.FLOAT, [None, 5])

node = onnx.helper.make_node('Split',
inputs=['x'],
outputs=['y1', 'y2', 'y3'],
axis=0,
split=[7, 4, 4])

return ([node], [x], [y1, y2, y3])


@onnx_test()
def split_dyn_input_split_input_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [None, 15])
y1 = helper.make_tensor_value_info('y1', TensorProto.FLOAT, [None, 5])
y2 = helper.make_tensor_value_info('y2', TensorProto.FLOAT, [None, 5])
y3 = helper.make_tensor_value_info('y3', TensorProto.FLOAT, [None, 5])

split = np.ones(3) * 5
split_tensor = helper.make_tensor(name="split",
data_type=TensorProto.INT64,
dims=split.shape,
vals=split.astype(np.int64))
const_node = helper.make_node("Constant",
inputs=[],
outputs=['split'],
value=split_tensor)

node = onnx.helper.make_node('Split',
inputs=['x', 'split'],
outputs=['y1', 'y2', 'y3'],
axis=0)

return ([const_node, node], [x], [y1, y2, y3])


@onnx_test()
def sqrt_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 15])
Expand Down
Loading
Loading