Skip to content

Commit

Permalink
Concat - multibroadcast fix (#3096)
Browse files Browse the repository at this point in the history
  • Loading branch information
CharlieL7 authored May 21, 2024
1 parent 93d77e9 commit 1f07af9
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 15 deletions.
73 changes: 59 additions & 14 deletions src/simplify_reshapes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,21 @@ struct find_nested_slice
}
};

/**
* Example case
* From:
* param0: lens = [3, 4], strides = [4, 1]
* param1: lens = [3, 4], strides = [4, 1]
* mb0: multibroadcast(param0, output_lens = [2, 3, 4])
* mb1: multibroadcast(param1, output_lens = [2, 3, 4])
* concat(mb0, mb1, axis = 2)
*
* To:
* param0: lens = [3, 4], strides = [4, 1]
* param1: lens = [3, 4], strides = [4, 1]
* con0: concat(param0, param1, axis = 1)
* multibroadcast(con0, lens = [2, 3, 4])
*/
struct find_concat_multibroadcasts
{
auto matcher() const
Expand All @@ -253,32 +268,62 @@ struct find_concat_multibroadcasts

void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto op = any_cast<op::concat>(ins->get_operator());
auto out_lens = ins->get_shape().lens();
auto inputs = ins->inputs();
auto in_strides = inputs.front()->get_shape().strides();
auto concat_ins = mr.result;
auto concat_op = any_cast<op::concat>(concat_ins->get_operator());
auto concat_out_lens = concat_ins->get_shape().lens();
auto concat_inputs = concat_ins->inputs();
auto front_mb_strides = concat_inputs.front()->get_shape().strides();
assert(concat_op.axis >= 0);

// Only apply when concat axis is not a broadcasted dimension
if(std::any_of(inputs.begin(), inputs.end(), [&](auto i) {
return i->get_shape().strides()[op.axis] == 0;
if(std::any_of(concat_inputs.begin(), concat_inputs.end(), [&](auto i) {
return i->get_shape().strides()[concat_op.axis] == 0;
}))
{
return;
}

// Use inputs of multibroadcast ops as inputs to new concat op
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [](auto i) {
// Get the inputs of multibroadcast ops. Will be used as inputs to new concat op
std::vector<instruction_ref> mb_inputs(concat_inputs.size());
std::transform(concat_inputs.begin(), concat_inputs.end(), mb_inputs.begin(), [](auto i) {
return i->inputs().front();
});

// Check that the inputs into the multibroadcasts have the same rank
const auto& first_shape = mb_inputs.front()->get_shape();
if(not std::all_of(mb_inputs.begin() + 1, mb_inputs.end(), [&](auto mb_in) {
return mb_in->get_shape().ndim() == first_shape.ndim();
}))
{
return;
}

// Reduce axis by number of leading broadcasted dimensions
if(inputs.front()->get_shape().lens().size() < out_lens.size())
op.axis -= std::count(in_strides.begin(), in_strides.begin() + op.axis, 0);
if(mb_inputs.front()->get_shape().lens().size() < concat_out_lens.size())
{
concat_op.axis -=
std::count(front_mb_strides.begin(), front_mb_strides.begin() + concat_op.axis, 0);
}

auto concat = m.insert_instruction(ins, op, inputs);
m.replace_instruction(
ins, migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), concat);
// Inputs to multibroadcasts should have the same dimensions except for the axis to
// concatenate over
const auto& front_in_lens = mb_inputs.front()->get_shape().lens();
if(not std::all_of(mb_inputs.begin() + 1, mb_inputs.end(), [&](auto input_to_mb) {
const auto& lens = input_to_mb->get_shape().lens();
return std::equal(
lens.begin(), lens.begin() + concat_op.axis, front_in_lens.begin()) and
std::equal(lens.begin() + concat_op.axis + 1,
lens.end(),
front_in_lens.begin() + concat_op.axis + 1);
}))
{
return;
}

auto new_concat_ins = m.insert_instruction(concat_ins, concat_op, mb_inputs);
m.replace_instruction(concat_ins,
migraphx::make_op("multibroadcast", {{"out_lens", concat_out_lens}}),
new_concat_ins);
}
};

Expand Down
109 changes: 108 additions & 1 deletion test/simplify_reshapes_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -918,9 +918,10 @@ TEST_CASE(concat_multibroadcasts3)
EXPECT(new_concat->get_operator().to_value()["axis"].to<int>() == 2);
}

// Broadcasted batch dim, axis is broadcasted dim
// matched by find_concat_multibroadcasts but it skips this case
TEST_CASE(concat_multibroadcasts4)
{
// Broadcasted batch dim, axis is broadcasted dim
std::vector<std::size_t> in_lens = {3, 4};
std::vector<std::size_t> mbcast_lens = {2, 3, 4};
const int axis = 0;
Expand All @@ -930,6 +931,112 @@ TEST_CASE(concat_multibroadcasts4)
EXPECT(m1 == m);
}

// Matched by find_concat_multibroadcasts but skipped because dimensions other than concat axis do
// not match
TEST_CASE(concat_multibroadcasts5)
{
migraphx::module m;
auto s0 = migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1, 64}};
auto s1 = migraphx::shape{migraphx::shape::float_type, {1, 1, 60, 64, 192}};
auto x = m.add_parameter("x", s0);
auto y = m.add_parameter("y", s1);
std::vector<std::size_t> mb_lens0 = {1, 12, 60, 64, 64};
std::vector<std::size_t> mb_lens1 = {1, 12, 60, 64, 192};
auto mb_x = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens0}}), x);
auto mb_y = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens1}}), y);
auto concat_xy = m.add_instruction(migraphx::make_op("concat", {{"axis", 4}}), mb_x, mb_y);
m.add_return({concat_xy});
auto m_original = m;
run_pass(m);
EXPECT(m == m_original);
}

// Matched by find_concat_multibroadcasts but skipped because parameter inputs are not the same
// rank.
TEST_CASE(concat_multibroadcasts6)
{
migraphx::module m;
auto s0 = migraphx::shape{migraphx::shape::float_type, {64}};
auto s1 = migraphx::shape{migraphx::shape::float_type, {60, 64, 192}};
auto x = m.add_parameter("x", s0);
auto y = m.add_parameter("y", s1);
std::vector<std::size_t> mb_lens0 = {12, 60, 64, 64};
std::vector<std::size_t> mb_lens1 = {12, 60, 64, 192};
auto mb_x = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens0}}), x);
auto mb_y = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens1}}), y);
auto concat_xy = m.add_instruction(migraphx::make_op("concat", {{"axis", 3}}), mb_x, mb_y);
m.add_return({concat_xy});
auto m_original = m;
run_pass(m);
EXPECT(m == m_original);
}

// Concat axis moved to 2 because rank(in_dims) < rank(out_dims)
// Matched by find_concat_multibroadcasts but skipped because the dimensions
// other than the concat axis are not the same.
// TODO: has common broadcast axes, so can be simplified by moving multibroadcast up to have a
// smaller concat.
TEST_CASE(concat_multibroadcasts7)
{
migraphx::module m;
auto s0 = migraphx::shape{migraphx::shape::float_type, {1, 1, 64}};
auto s1 = migraphx::shape{migraphx::shape::float_type, {60, 64, 192}};
auto x = m.add_parameter("x", s0);
auto y = m.add_parameter("y", s1);
std::vector<std::size_t> mb_lens0 = {1, 12, 60, 64, 64};
std::vector<std::size_t> mb_lens1 = {1, 12, 60, 64, 192};
auto mb_x = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens0}}), x);
auto mb_y = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens1}}), y);
auto concat_xy = m.add_instruction(migraphx::make_op("concat", {{"axis", 4}}), mb_x, mb_y);
m.add_return({concat_xy});
auto m_original = m;
run_pass(m);
EXPECT(m == m_original);
}

// Shape of inputs to multibroadcasts do not have the same rank.
// Matched by find_concat_multibroadcasts but skipped.
// TODO: has a common broadcast axis, so can be simplified by moving multibroadcast up to have a
// smaller concat.
TEST_CASE(concat_multibroadcasts8)
{
migraphx::module m;
auto s0 = migraphx::shape{migraphx::shape::float_type, {64, 64}};
auto s1 = migraphx::shape{migraphx::shape::float_type, {60, 1, 192}};
auto x = m.add_parameter("x", s0);
auto y = m.add_parameter("y", s1);
std::vector<std::size_t> mb_lens0 = {60, 64, 64};
std::vector<std::size_t> mb_lens1 = {60, 64, 192};
auto mb_x = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens0}}), x);
auto mb_y = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens1}}), y);
auto concat_xy = m.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), mb_x, mb_y);
m.add_return({concat_xy});
auto m_original = m;
run_pass(m);
EXPECT(m == m_original);
}

// Shape of inputs to multibroadcasts do not have a common broadcast axis.
// Matched by find_concat_multibroadcasts, but skipped because the dimensions other than
// the concat axis are not the same.
TEST_CASE(concat_multibroadcasts9)
{
migraphx::module m;
auto s0 = migraphx::shape{migraphx::shape::float_type, {1, 64, 64}};
auto s1 = migraphx::shape{migraphx::shape::float_type, {60, 1, 192}};
auto x = m.add_parameter("x", s0);
auto y = m.add_parameter("y", s1);
std::vector<std::size_t> mb_lens0 = {60, 64, 64};
std::vector<std::size_t> mb_lens1 = {60, 64, 192};
auto mb_x = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens0}}), x);
auto mb_y = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens1}}), y);
auto concat_xy = m.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), mb_x, mb_y);
m.add_return({concat_xy});
auto m_original = m;
run_pass(m);
EXPECT(m == m_original);
}

TEST_CASE(concat_transpose1)
{
migraphx::module m;
Expand Down

0 comments on commit 1f07af9

Please sign in to comment.