-
Notifications
You must be signed in to change notification settings - Fork 88
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Concat - multibroadcast fix #3096
Changes from all commits
62c303a
647afcc
784392e
e271392
a9d03d0
6ad1822
5367222
bfb6453
ed6057e
d0a3d5e
ebf4697
b3315bf
b314f90
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -918,9 +918,10 @@ TEST_CASE(concat_multibroadcasts3) | |
EXPECT(new_concat->get_operator().to_value()["axis"].to<int>() == 2); | ||
} | ||
|
||
// Broadcasted batch dim, axis is broadcasted dim | ||
// matched by find_concat_multibroadcasts but it skips this case | ||
TEST_CASE(concat_multibroadcasts4) | ||
{ | ||
// Broadcasted batch dim, axis is broadcasted dim | ||
std::vector<std::size_t> in_lens = {3, 4}; | ||
std::vector<std::size_t> mbcast_lens = {2, 3, 4}; | ||
const int axis = 0; | ||
|
@@ -930,6 +931,112 @@ TEST_CASE(concat_multibroadcasts4) | |
EXPECT(m1 == m); | ||
} | ||
|
||
// Matched by find_concat_multibroadcasts but skipped because dimensions other than concat axis do | ||
// not match | ||
TEST_CASE(concat_multibroadcasts5) | ||
{ | ||
migraphx::module m; | ||
auto s0 = migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1, 64}}; | ||
auto s1 = migraphx::shape{migraphx::shape::float_type, {1, 1, 60, 64, 192}}; | ||
auto x = m.add_parameter("x", s0); | ||
auto y = m.add_parameter("y", s1); | ||
std::vector<std::size_t> mb_lens0 = {1, 12, 60, 64, 64}; | ||
std::vector<std::size_t> mb_lens1 = {1, 12, 60, 64, 192}; | ||
auto mb_x = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens0}}), x); | ||
auto mb_y = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens1}}), y); | ||
auto concat_xy = m.add_instruction(migraphx::make_op("concat", {{"axis", 4}}), mb_x, mb_y); | ||
m.add_return({concat_xy}); | ||
auto m_original = m; | ||
run_pass(m); | ||
EXPECT(m == m_original); | ||
} | ||
|
||
// Matched by find_concat_multibroadcasts but skipped because parameter inputs are not the same | ||
// rank. | ||
TEST_CASE(concat_multibroadcasts6) | ||
{ | ||
migraphx::module m; | ||
auto s0 = migraphx::shape{migraphx::shape::float_type, {64}}; | ||
auto s1 = migraphx::shape{migraphx::shape::float_type, {60, 64, 192}}; | ||
auto x = m.add_parameter("x", s0); | ||
auto y = m.add_parameter("y", s1); | ||
std::vector<std::size_t> mb_lens0 = {12, 60, 64, 64}; | ||
std::vector<std::size_t> mb_lens1 = {12, 60, 64, 192}; | ||
auto mb_x = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens0}}), x); | ||
auto mb_y = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens1}}), y); | ||
auto concat_xy = m.add_instruction(migraphx::make_op("concat", {{"axis", 3}}), mb_x, mb_y); | ||
m.add_return({concat_xy}); | ||
auto m_original = m; | ||
run_pass(m); | ||
EXPECT(m == m_original); | ||
} | ||
|
||
// Concat axis moved to 2 because rank(in_dims) < rank(out_dims) | ||
// Matched by find_concat_multibroadcasts but skipped because the dimensions | ||
// other than the concat axis are not the same. | ||
// TODO: has common broadcast axes, so can be simplified by moving multibroadcast up to have a | ||
// smaller concat. | ||
TEST_CASE(concat_multibroadcasts7) | ||
{ | ||
migraphx::module m; | ||
auto s0 = migraphx::shape{migraphx::shape::float_type, {1, 1, 64}}; | ||
auto s1 = migraphx::shape{migraphx::shape::float_type, {60, 64, 192}}; | ||
auto x = m.add_parameter("x", s0); | ||
auto y = m.add_parameter("y", s1); | ||
std::vector<std::size_t> mb_lens0 = {1, 12, 60, 64, 64}; | ||
std::vector<std::size_t> mb_lens1 = {1, 12, 60, 64, 192}; | ||
auto mb_x = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens0}}), x); | ||
auto mb_y = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens1}}), y); | ||
auto concat_xy = m.add_instruction(migraphx::make_op("concat", {{"axis", 4}}), mb_x, mb_y); | ||
m.add_return({concat_xy}); | ||
auto m_original = m; | ||
run_pass(m); | ||
EXPECT(m == m_original); | ||
} | ||
|
||
// Shape of inputs to multibroadcasts do not have the same rank. | ||
// Matched by find_concat_multibroadcasts but skipped. | ||
// TODO: has a common broadcast axis, so can be simplified by moving multibroadcast up to have a | ||
// smaller concat. | ||
TEST_CASE(concat_multibroadcasts8) | ||
{ | ||
migraphx::module m; | ||
auto s0 = migraphx::shape{migraphx::shape::float_type, {64, 64}}; | ||
auto s1 = migraphx::shape{migraphx::shape::float_type, {60, 1, 192}}; | ||
auto x = m.add_parameter("x", s0); | ||
auto y = m.add_parameter("y", s1); | ||
std::vector<std::size_t> mb_lens0 = {60, 64, 64}; | ||
std::vector<std::size_t> mb_lens1 = {60, 64, 192}; | ||
auto mb_x = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens0}}), x); | ||
auto mb_y = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens1}}), y); | ||
auto concat_xy = m.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), mb_x, mb_y); | ||
m.add_return({concat_xy}); | ||
auto m_original = m; | ||
run_pass(m); | ||
EXPECT(m == m_original); | ||
} | ||
|
||
// Shape of inputs to multibroadcasts do not have a common broadcast axis. | ||
// Matched by find_concat_multibroadcasts, but skipped because the dimensions other than | ||
// the concat axis are not the same. | ||
TEST_CASE(concat_multibroadcasts9) | ||
{ | ||
migraphx::module m; | ||
auto s0 = migraphx::shape{migraphx::shape::float_type, {1, 64, 64}}; | ||
auto s1 = migraphx::shape{migraphx::shape::float_type, {60, 1, 192}}; | ||
auto x = m.add_parameter("x", s0); | ||
auto y = m.add_parameter("y", s1); | ||
std::vector<std::size_t> mb_lens0 = {60, 64, 64}; | ||
std::vector<std::size_t> mb_lens1 = {60, 64, 192}; | ||
auto mb_x = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens0}}), x); | ||
auto mb_y = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens1}}), y); | ||
auto concat_xy = m.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), mb_x, mb_y); | ||
m.add_return({concat_xy}); | ||
auto m_original = m; | ||
run_pass(m); | ||
EXPECT(m == m_original); | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add TODO comments on these test that we will simplify them in the future? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure I understand what you mean by simplify these in the future. I'm only intending to change what the comment says. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you mean when we rewrite the matcher for doing a broadcast before if there's atleast one common broadcast axis? |
||
|
||
TEST_CASE(concat_transpose1) | ||
{ | ||
migraphx::module m; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We're going to have a problem with these tests in the future if any change to simplify_reshapes breaks one of them, in that it's not made clear which of the
simplify_reshapes
matchers each test is intended to match. There's currently also no way to check whether a test that's expected to match but then exit without doing anything actually matched. For now, suggest you add to the descriptive comment for each test: what matcher it's supposed to match (or not match). The current test names are very similar tostruct find_concat_multibroadcasts
but just different enough that the intent might not be clear to someone new.A similar problem would come up if the order of passes in
simplify_reshapes
changes and causes a match where there was none before, or vice versa--a different matcher modifies an instruction graph before the intended matcher can see it.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's possible, but quite unlikely given how small these test cases are. The only way I can think of to actually prevent this is to have passes that only run one matcher.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't follow your logic here. All we would have to do is turn on
MIGRAPHX_TRACE_MATCHES
.