Concat - multibroadcast fix #3096

Merged
merged 13 commits into from
May 21, 2024
73 changes: 59 additions & 14 deletions src/simplify_reshapes.cpp
@@ -244,6 +244,21 @@ struct find_nested_slice
     }
 };

+/**
+ * Example case
+ * From:
+ * param0: lens = [3, 4], strides = [4, 1]
+ * param1: lens = [3, 4], strides = [4, 1]
+ * mb0: multibroadcast(param0, output_lens = [2, 3, 4])
+ * mb1: multibroadcast(param1, output_lens = [2, 3, 4])
+ * concat(mb0, mb1, axis = 2)
+ *
+ * To:
+ * param0: lens = [3, 4], strides = [4, 1]
+ * param1: lens = [3, 4], strides = [4, 1]
+ * con0: concat(param0, param1, axis = 1)
+ * multibroadcast(con0, lens = [2, 3, 4])
+ */
 struct find_concat_multibroadcasts
 {
     auto matcher() const
@@ -253,32 +268,62 @@ struct find_concat_multibroadcasts

     void apply(module& m, const match::matcher_result& mr) const
     {
-        auto ins = mr.result;
-        auto op = any_cast<op::concat>(ins->get_operator());
-        auto out_lens = ins->get_shape().lens();
-        auto inputs = ins->inputs();
-        auto in_strides = inputs.front()->get_shape().strides();
+        auto concat_ins = mr.result;
+        auto concat_op = any_cast<op::concat>(concat_ins->get_operator());
+        auto concat_out_lens = concat_ins->get_shape().lens();
+        auto concat_inputs = concat_ins->inputs();
+        auto front_mb_strides = concat_inputs.front()->get_shape().strides();
+        assert(concat_op.axis >= 0);

         // Only apply when concat axis is not a broadcasted dimension
-        if(std::any_of(inputs.begin(), inputs.end(), [&](auto i) {
-               return i->get_shape().strides()[op.axis] == 0;
+        if(std::any_of(concat_inputs.begin(), concat_inputs.end(), [&](auto i) {
+               return i->get_shape().strides()[concat_op.axis] == 0;
            }))
         {
             return;
         }

-        // Use inputs of multibroadcast ops as inputs to new concat op
-        std::transform(inputs.begin(), inputs.end(), inputs.begin(), [](auto i) {
+        // Get the inputs of multibroadcast ops. Will be used as inputs to new concat op
+        std::vector<instruction_ref> mb_inputs(concat_inputs.size());
+        std::transform(concat_inputs.begin(), concat_inputs.end(), mb_inputs.begin(), [](auto i) {
             return i->inputs().front();
         });

+        // Check that the inputs into the multibroadcasts have the same rank
+        const auto& first_shape = mb_inputs.front()->get_shape();
+        if(not std::all_of(mb_inputs.begin() + 1, mb_inputs.end(), [&](auto mb_in) {
+               return mb_in->get_shape().ndim() == first_shape.ndim();
+           }))
+        {
+            return;
+        }
+
         // Reduce axis by number of leading broadcasted dimensions
-        if(inputs.front()->get_shape().lens().size() < out_lens.size())
-            op.axis -= std::count(in_strides.begin(), in_strides.begin() + op.axis, 0);
+        if(mb_inputs.front()->get_shape().lens().size() < concat_out_lens.size())
+        {
+            concat_op.axis -=
+                std::count(front_mb_strides.begin(), front_mb_strides.begin() + concat_op.axis, 0);
+        }

-        auto concat = m.insert_instruction(ins, op, inputs);
-        m.replace_instruction(
-            ins, migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), concat);
+        // Inputs to multibroadcasts should have the same dimensions except for the axis to
+        // concatenate over
+        const auto& front_in_lens = mb_inputs.front()->get_shape().lens();
+        if(not std::all_of(mb_inputs.begin() + 1, mb_inputs.end(), [&](auto input_to_mb) {
+               const auto& lens = input_to_mb->get_shape().lens();
+               return std::equal(
+                          lens.begin(), lens.begin() + concat_op.axis, front_in_lens.begin()) and
+                      std::equal(lens.begin() + concat_op.axis + 1,
+                                 lens.end(),
+                                 front_in_lens.begin() + concat_op.axis + 1);
+           }))
+        {
+            return;
+        }
+
+        auto new_concat_ins = m.insert_instruction(concat_ins, concat_op, mb_inputs);
+        m.replace_instruction(concat_ins,
+                              migraphx::make_op("multibroadcast", {{"out_lens", concat_out_lens}}),
+                              new_concat_ins);
     }
 };

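To make the axis arithmetic in `find_concat_multibroadcasts` concrete, here is a minimal stand-alone sketch that models shapes as plain vectors rather than MIGraphX types. The helper names (`broadcast_strides`, `adjust_axis`, `same_lens_except`) are hypothetical, and the stride model assumes packed row-major inputs with right-aligned dimensions. It only illustrates the checks in this change and is not the library's implementation.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

using dims = std::vector<std::size_t>;

// Assumption: the input is packed row-major and dimensions are right-aligned,
// so every added or expanded dimension gets stride 0 in the broadcasted view.
dims broadcast_strides(const dims& in_lens, const dims& out_lens)
{
    assert(in_lens.size() <= out_lens.size());
    const std::size_t offset = out_lens.size() - in_lens.size();
    dims strides(out_lens.size(), 0);
    std::size_t stride = 1;
    for(std::size_t i = in_lens.size(); i > 0; --i)
    {
        const std::size_t d = i - 1;
        assert(in_lens[d] == out_lens[d + offset] || in_lens[d] == 1);
        if(in_lens[d] == out_lens[d + offset])
            strides[d + offset] = stride;
        stride *= in_lens[d];
    }
    return strides;
}

// Shift the concat axis left by the number of zero strides that precede it,
// applied only when the un-broadcasted input has a lower rank than the output.
std::size_t adjust_axis(std::size_t axis, std::size_t in_rank, const dims& bcast_strides)
{
    assert(axis < bcast_strides.size());
    if(in_rank < bcast_strides.size())
    {
        const auto zeros = std::count(bcast_strides.begin(),
                                      bcast_strides.begin() + static_cast<std::ptrdiff_t>(axis),
                                      std::size_t{0});
        axis -= static_cast<std::size_t>(zeros);
    }
    return axis;
}

// True when both shapes have the same rank and agree on every dimension except `axis`.
bool same_lens_except(const dims& a, const dims& b, std::size_t axis)
{
    if(a.size() != b.size() || axis >= a.size())
        return false;
    const auto ax = static_cast<std::ptrdiff_t>(axis);
    return std::equal(a.begin(), a.begin() + ax, b.begin()) &&
           std::equal(a.begin() + ax + 1, a.end(), b.begin() + ax + 1);
}
```

With the shapes from the doc comment, `broadcast_strides({3, 4}, {2, 3, 4})` yields `{0, 4, 1}` and `adjust_axis(2, 2, ...)` yields `1`, matching the `axis = 2` to `axis = 1` rewrite. For the shapes in `concat_multibroadcasts5`, `same_lens_except({1, 1, 1, 1, 64}, {1, 1, 60, 64, 192}, 4)` is false, which corresponds to the early return.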
109 changes: 108 additions & 1 deletion test/simplify_reshapes_test.cpp
@@ -918,9 +918,10 @@ TEST_CASE(concat_multibroadcasts3)
     EXPECT(new_concat->get_operator().to_value()["axis"].to<int>() == 2);
 }

-// Broadcasted batch dim, axis is broadcasted dim
Contributor

We're going to have a problem with these tests in the future if any change to simplify_reshapes breaks one of them, in that it's not made clear which of the simplify_reshapes matchers each test is intended to match. There's currently also no way to check whether a test that's expected to match but then exit without doing anything actually matched. For now, suggest you add to the descriptive comment for each test: what matcher it's supposed to match (or not match). The current test names are very similar to struct find_concat_multibroadcasts but just different enough that the intent might not be clear to someone new.

A similar problem would come up if the order of passes in simplify_reshapes changes and causes a match where there was none before, or vice versa--a different matcher modifies an instruction graph before the intended matcher can see it.

Collaborator Author

> A similar problem would come up if the order of passes in simplify_reshapes changes and causes a match where there was none before, or vice versa--a different matcher modifies an instruction graph before the intended matcher can see it.

That's possible, but quite unlikely given how small these test cases are. The only way I can think of to actually prevent this is to have passes that only run one matcher.

Collaborator Author

> We're going to have a problem with these tests in the future if any change to simplify_reshapes breaks one of them, in that it's not made clear which of the simplify_reshapes matchers each test is intended to match.

I don't follow your logic here. All we would have to do is turn on MIGRAPHX_TRACE_MATCHES.
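
The two ideas in this thread (a pass that runs only one matcher, and a way to tell that a test matched even when the matcher exits without changing anything) can be pictured with a toy model. Everything below is hypothetical: `toy_instruction`, `toy_matcher`, and `run_single_matcher` are illustrative stand-ins and do not correspond to MIGraphX types or functions.

```cpp
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

// Toy stand-ins for an instruction and a module; not MIGraphX types.
struct toy_instruction
{
    std::string name;
};
using toy_module = std::vector<toy_instruction>;

// A matcher is a predicate plus a rewrite that may decide to do nothing.
struct toy_matcher
{
    std::string name;
    std::function<bool(const toy_instruction&)> matches;
    std::function<void(toy_instruction&)> apply;
};

// Run exactly one matcher over the module and report how many instructions it
// matched, so a test can distinguish "never matched" from "matched but skipped".
std::size_t run_single_matcher(toy_module& m, const toy_matcher& f)
{
    std::size_t hits = 0;
    for(auto& ins : m)
    {
        if(f.matches(ins))
        {
            ++hits;
            f.apply(ins);
        }
    }
    return hits;
}
```

A test built this way could assert both that the hit count is 1 and that the module is unchanged, which is the "matched but skipped" situation the new test comments describe.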

+// matched by find_concat_multibroadcasts but it skips this case
 TEST_CASE(concat_multibroadcasts4)
 {
+    // Broadcasted batch dim, axis is broadcasted dim
     std::vector<std::size_t> in_lens = {3, 4};
     std::vector<std::size_t> mbcast_lens = {2, 3, 4};
     const int axis = 0;
@@ -930,6 +931,112 @@ TEST_CASE(concat_multibroadcasts4)
     EXPECT(m1 == m);
 }

+// Matched by find_concat_multibroadcasts but skipped because dimensions other than concat axis do
+// not match
+TEST_CASE(concat_multibroadcasts5)
+{
+    migraphx::module m;
+    auto s0 = migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1, 64}};
+    auto s1 = migraphx::shape{migraphx::shape::float_type, {1, 1, 60, 64, 192}};
+    auto x = m.add_parameter("x", s0);
+    auto y = m.add_parameter("y", s1);
+    std::vector<std::size_t> mb_lens0 = {1, 12, 60, 64, 64};
+    std::vector<std::size_t> mb_lens1 = {1, 12, 60, 64, 192};
+    auto mb_x = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens0}}), x);
+    auto mb_y = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens1}}), y);
+    auto concat_xy = m.add_instruction(migraphx::make_op("concat", {{"axis", 4}}), mb_x, mb_y);
+    m.add_return({concat_xy});
+    auto m_original = m;
+    run_pass(m);
+    EXPECT(m == m_original);
+}
+
+// Matched by find_concat_multibroadcasts but skipped because parameter inputs are not the same
+// rank.
+TEST_CASE(concat_multibroadcasts6)
+{
+    migraphx::module m;
+    auto s0 = migraphx::shape{migraphx::shape::float_type, {64}};
+    auto s1 = migraphx::shape{migraphx::shape::float_type, {60, 64, 192}};
+    auto x = m.add_parameter("x", s0);
+    auto y = m.add_parameter("y", s1);
+    std::vector<std::size_t> mb_lens0 = {12, 60, 64, 64};
+    std::vector<std::size_t> mb_lens1 = {12, 60, 64, 192};
+    auto mb_x = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens0}}), x);
+    auto mb_y = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens1}}), y);
+    auto concat_xy = m.add_instruction(migraphx::make_op("concat", {{"axis", 3}}), mb_x, mb_y);
+    m.add_return({concat_xy});
+    auto m_original = m;
+    run_pass(m);
+    EXPECT(m == m_original);
+}
+
+// Concat axis moved to 2 because rank(in_dims) < rank(out_dims)
+// Matched by find_concat_multibroadcasts but skipped because the dimensions
+// other than the concat axis are not the same.
+// TODO: has common broadcast axes, so can be simplified by moving multibroadcast up to have a
+// smaller concat.
+TEST_CASE(concat_multibroadcasts7)
+{
+    migraphx::module m;
+    auto s0 = migraphx::shape{migraphx::shape::float_type, {1, 1, 64}};
+    auto s1 = migraphx::shape{migraphx::shape::float_type, {60, 64, 192}};
+    auto x = m.add_parameter("x", s0);
+    auto y = m.add_parameter("y", s1);
+    std::vector<std::size_t> mb_lens0 = {1, 12, 60, 64, 64};
+    std::vector<std::size_t> mb_lens1 = {1, 12, 60, 64, 192};
+    auto mb_x = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens0}}), x);
+    auto mb_y = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens1}}), y);
+    auto concat_xy = m.add_instruction(migraphx::make_op("concat", {{"axis", 4}}), mb_x, mb_y);
+    m.add_return({concat_xy});
+    auto m_original = m;
+    run_pass(m);
+    EXPECT(m == m_original);
+}
+
+// Shape of inputs to multibroadcasts do not have the same rank.
+// Matched by find_concat_multibroadcasts but skipped.
+// TODO: has a common broadcast axis, so can be simplified by moving multibroadcast up to have a
+// smaller concat.
+TEST_CASE(concat_multibroadcasts8)
+{
+    migraphx::module m;
+    auto s0 = migraphx::shape{migraphx::shape::float_type, {64, 64}};
+    auto s1 = migraphx::shape{migraphx::shape::float_type, {60, 1, 192}};
+    auto x = m.add_parameter("x", s0);
+    auto y = m.add_parameter("y", s1);
+    std::vector<std::size_t> mb_lens0 = {60, 64, 64};
+    std::vector<std::size_t> mb_lens1 = {60, 64, 192};
+    auto mb_x = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens0}}), x);
+    auto mb_y = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens1}}), y);
+    auto concat_xy = m.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), mb_x, mb_y);
+    m.add_return({concat_xy});
+    auto m_original = m;
+    run_pass(m);
+    EXPECT(m == m_original);
+}
+
+// Shape of inputs to multibroadcasts do not have a common broadcast axis.
+// Matched by find_concat_multibroadcasts, but skipped because the dimensions other than
+// the concat axis are not the same.
+TEST_CASE(concat_multibroadcasts9)
+{
+    migraphx::module m;
+    auto s0 = migraphx::shape{migraphx::shape::float_type, {1, 64, 64}};
+    auto s1 = migraphx::shape{migraphx::shape::float_type, {60, 1, 192}};
+    auto x = m.add_parameter("x", s0);
+    auto y = m.add_parameter("y", s1);
+    std::vector<std::size_t> mb_lens0 = {60, 64, 64};
+    std::vector<std::size_t> mb_lens1 = {60, 64, 192};
+    auto mb_x = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens0}}), x);
+    auto mb_y = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens1}}), y);
+    auto concat_xy = m.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), mb_x, mb_y);
+    m.add_return({concat_xy});
+    auto m_original = m;
+    run_pass(m);
+    EXPECT(m == m_original);
+}
Collaborator

Can you add TODO comments on these tests that we will simplify them in the future?

Collaborator Author

Not sure I understand what you mean by simplify these in the future. I'm only intending to change what the comment says.

Collaborator Author

Do you mean when we rewrite the matcher to do a broadcast before the concat if there's at least one common broadcast axis?
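
As a rough illustration of what detecting a "common broadcast axis" might involve, the sketch below treats an axis as common when its stride is 0 in every broadcasted input. That definition is an assumption made for this example only and may not be what the reviewers have in mind; `common_broadcast_axes` is a hypothetical helper, not part of MIGraphX.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using dims = std::vector<std::size_t>;

// Return the axes whose stride is 0 in every broadcasted input.
// Assumption: this is only one possible reading of "common broadcast axis".
std::vector<std::size_t> common_broadcast_axes(const std::vector<dims>& bcast_strides)
{
    std::vector<std::size_t> axes;
    if(bcast_strides.empty())
        return axes;
    const std::size_t rank = bcast_strides.front().size();
    for(std::size_t axis = 0; axis < rank; ++axis)
    {
        bool all_zero = true;
        for(const auto& s : bcast_strides)
        {
            assert(s.size() == rank);
            if(s[axis] != 0)
                all_zero = false;
        }
        if(all_zero)
            axes.push_back(axis);
    }
    return axes;
}
```

For the shapes in `concat_multibroadcasts7`, and assuming packed inputs, the broadcasted strides would be `{0, 0, 0, 0, 1}` for `x` and `{0, 0, 12288, 192, 1}` for `y`, so this helper reports axes 0 and 1. A rewritten matcher could, in principle, defer the broadcast over those axes until after a smaller concat.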


 TEST_CASE(concat_transpose1)
 {
     migraphx::module m;