From 62c303a859344c31920b2dd1f7dd2d4e6b7ff2ba Mon Sep 17 00:00:00 2001 From: charlie Date: Wed, 15 May 2024 15:18:53 -0500 Subject: [PATCH 01/10] Add check that inputs to mb have same dimensions other than the concat axis --- src/simplify_reshapes.cpp | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index a651d6e2432..e3a90956f86 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -267,11 +267,27 @@ struct find_concat_multibroadcasts return; } - // Use inputs of multibroadcast ops as inputs to new concat op + // Get the inputs of multibroadcast ops, will be used as inputs to new concat op std::transform(inputs.begin(), inputs.end(), inputs.begin(), [](auto i) { return i->inputs().front(); }); + // inputs to multibroadcasts should have the same dimensions except for the axis to + // concatenate n skips apply if not true + const auto& first_out_lens = inputs.front()->get_shape().lens(); + for(std::size_t ax = 0; ax < first_out_lens.size(); ++ax) + { + if(ax != op.axis) + { + if(not std::all_of(inputs.begin(), inputs.end(), [&](auto input_to_mb) { + return input_to_mb->get_shape().lens()[ax] == first_out_lens[ax]; + })) + { + return; + } + } + } + // Reduce axis by number of leading broadcasted dimensions if(inputs.front()->get_shape().lens().size() < out_lens.size()) op.axis -= std::count(in_strides.begin(), in_strides.begin() + op.axis, 0); From 647afcc49d2dd2d992af2b484662590ca0ba8841 Mon Sep 17 00:00:00 2001 From: charlie Date: Thu, 16 May 2024 11:15:36 -0500 Subject: [PATCH 02/10] Fix bug and add tests --- src/simplify_reshapes.cpp | 67 ++++++++++++++++++++++----------- test/simplify_reshapes_test.cpp | 45 +++++++++++++++++++++- 2 files changed, 89 insertions(+), 23 deletions(-) diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index e3a90956f86..f7d06213e0b 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -244,6 +244,21 @@ struct find_nested_slice } }; +/** + * Example case + * From: + * param0: lens = [3, 4], strides = [4, 1] + * param1: lens = [3, 4], strides = [4, 1] + * mb0: multibroadcast(param0, output_lens = [2, 3, 4]) + * mb1: multibroadcast(param1, output_lens = [2, 3, 4]) + * concat(mb0, mb1, axis = 2) + * + * To: + * param0: lens = [3, 4], strides = [4, 1] + * param1: lens = [3, 4], strides = [4, 1] + * con0: concat(param0, param1, axis = 1) + * multibroadcast(con0, lens = [2, 3, 4]) + */ struct find_concat_multibroadcasts { auto matcher() const @@ -253,34 +268,46 @@ struct find_concat_multibroadcasts void apply(module& m, const match::matcher_result& mr) const { - auto ins = mr.result; - auto op = any_cast(ins->get_operator()); - auto out_lens = ins->get_shape().lens(); - auto inputs = ins->inputs(); - auto in_strides = inputs.front()->get_shape().strides(); + auto concat_ins = mr.result; + auto concat_op = any_cast(concat_ins->get_operator()); + auto concat_out_lens = concat_ins->get_shape().lens(); + auto concat_inputs = concat_ins->inputs(); + auto front_mb_strides = concat_inputs.front()->get_shape().strides(); // Only apply when concat axis is not a broadcasted dimension - if(std::any_of(inputs.begin(), inputs.end(), [&](auto i) { - return i->get_shape().strides()[op.axis] == 0; + if(std::any_of(concat_inputs.begin(), concat_inputs.end(), [&](auto i) { + return i->get_shape().strides()[concat_op.axis] == 0; })) { return; } - + // Get the inputs of multibroadcast ops, will be used as inputs to new concat op - std::transform(inputs.begin(), inputs.end(), inputs.begin(), [](auto i) { + std::vector mb_inputs(concat_inputs.size()); + std::transform(concat_inputs.begin(), concat_inputs.end(), mb_inputs.begin(), [](auto i) { return i->inputs().front(); }); - // inputs to multibroadcasts should have the same dimensions except for the axis to - // concatenate n skips apply if not true - const auto& first_out_lens = inputs.front()->get_shape().lens(); - for(std::size_t ax = 0; ax < first_out_lens.size(); ++ax) + // Check that the inputs into the multibroadcasts have the same rank + const auto& first_shape = mb_inputs.front()->get_shape(); + if(not std::all_of(mb_inputs.begin() + 1, mb_inputs.end(), [&](auto mb_in) { return mb_in->get_shape().ndim() == first_shape.ndim(); })) + { + return; + } + + // Reduce axis by number of leading broadcasted dimensions + if(mb_inputs.front()->get_shape().lens().size() < concat_out_lens.size()) + concat_op.axis -= std::count(front_mb_strides.begin(), front_mb_strides.begin() + concat_op.axis, 0); + + // Inputs to multibroadcasts should have the same dimensions except for the axis to + // concatenate over + const auto& front_in_lens = mb_inputs.front()->get_shape().lens(); + for(std::size_t ax = 0; ax < front_in_lens.size(); ++ax) { - if(ax != op.axis) + if(ax != concat_op.axis) { - if(not std::all_of(inputs.begin(), inputs.end(), [&](auto input_to_mb) { - return input_to_mb->get_shape().lens()[ax] == first_out_lens[ax]; + if(not std::all_of(mb_inputs.begin(), mb_inputs.end(), [&](auto input_to_mb) { + return input_to_mb->get_shape().lens()[ax] == front_in_lens[ax]; })) { return; @@ -288,13 +315,9 @@ struct find_concat_multibroadcasts } } - // Reduce axis by number of leading broadcasted dimensions - if(inputs.front()->get_shape().lens().size() < out_lens.size()) - op.axis -= std::count(in_strides.begin(), in_strides.begin() + op.axis, 0); - - auto concat = m.insert_instruction(ins, op, inputs); + auto new_concat_ins = m.insert_instruction(concat_ins, concat_op, mb_inputs); m.replace_instruction( - ins, migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), concat); + concat_ins, migraphx::make_op("multibroadcast", {{"out_lens", concat_out_lens}}), new_concat_ins); } }; diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index f28c3ff8bde..1d4f1541918 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -918,9 +918,10 @@ TEST_CASE(concat_multibroadcasts3) EXPECT(new_concat->get_operator().to_value()["axis"].to() == 2); } +// Broadcasted batch dim, axis is broadcasted dim +// skips this case TEST_CASE(concat_multibroadcasts4) { - // Broadcasted batch dim, axis is broadcasted dim std::vector in_lens = {3, 4}; std::vector mbcast_lens = {2, 3, 4}; const int axis = 0; @@ -930,6 +931,48 @@ TEST_CASE(concat_multibroadcasts4) EXPECT(m1 == m); } +// different input parameter shapes +// dimensions other than concat axis do not match +// skips this case +TEST_CASE(concat_multibroadcasts5) +{ + migraphx::module m; + auto s0 = migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1, 64}}; + auto s1 = migraphx::shape{migraphx::shape::float_type, {1, 1, 60, 64, 192}}; + auto x = m.add_parameter("x", s0); + auto y = m.add_parameter("y", s1); + std::vector mb_lens0 = {1, 12, 60, 64, 64}; + std::vector mb_lens1 = {1, 12, 60, 64, 192}; + auto mb_x = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens0}}), x); + auto mb_y = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens1}}), y); + auto concat_xy = m.add_instruction(migraphx::make_op("concat", {{"axis", 4}}), mb_x, mb_y); + m.add_return({concat_xy}); + auto m_original = m; + run_pass(m); + EXPECT(m == m_original); +} + +// concat axis moved because rank(in_dims) < rank(out_dims) +// parameter inputs are not the same rank +// skips this case +TEST_CASE(concat_multibroadcasts6) +{ + migraphx::module m; + auto s0 = migraphx::shape{migraphx::shape::float_type, {64}}; + auto s1 = migraphx::shape{migraphx::shape::float_type, {60, 64, 192}}; + auto x = m.add_parameter("x", s0); + auto y = m.add_parameter("y", s1); + std::vector mb_lens0 = {1, 12, 60, 64, 64}; + std::vector mb_lens1 = {1, 12, 60, 64, 192}; + auto mb_x = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens0}}), x); + auto mb_y = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens1}}), y); + auto concat_xy = m.add_instruction(migraphx::make_op("concat", {{"axis", 4}}), mb_x, mb_y); + m.add_return({concat_xy}); + auto m_original = m; + run_pass(m); + EXPECT(m == m_original); +} + TEST_CASE(concat_transpose1) { migraphx::module m; From 784392e9202b5525463aa9088748cb255de961c6 Mon Sep 17 00:00:00 2001 From: charlie Date: Thu, 16 May 2024 11:15:50 -0500 Subject: [PATCH 03/10] formatting --- src/simplify_reshapes.cpp | 24 ++++++++++++++---------- test/simplify_reshapes_test.cpp | 12 ++++++------ 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index f7d06213e0b..a579cba6179 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -268,11 +268,11 @@ struct find_concat_multibroadcasts void apply(module& m, const match::matcher_result& mr) const { - auto concat_ins = mr.result; - auto concat_op = any_cast(concat_ins->get_operator()); - auto concat_out_lens = concat_ins->get_shape().lens(); - auto concat_inputs = concat_ins->inputs(); - auto front_mb_strides = concat_inputs.front()->get_shape().strides(); + auto concat_ins = mr.result; + auto concat_op = any_cast(concat_ins->get_operator()); + auto concat_out_lens = concat_ins->get_shape().lens(); + auto concat_inputs = concat_ins->inputs(); + auto front_mb_strides = concat_inputs.front()->get_shape().strides(); // Only apply when concat axis is not a broadcasted dimension if(std::any_of(concat_inputs.begin(), concat_inputs.end(), [&](auto i) { @@ -281,7 +281,7 @@ struct find_concat_multibroadcasts { return; } - + // Get the inputs of multibroadcast ops, will be used as inputs to new concat op std::vector mb_inputs(concat_inputs.size()); std::transform(concat_inputs.begin(), concat_inputs.end(), mb_inputs.begin(), [](auto i) { @@ -290,14 +290,17 @@ struct find_concat_multibroadcasts // Check that the inputs into the multibroadcasts have the same rank const auto& first_shape = mb_inputs.front()->get_shape(); - if(not std::all_of(mb_inputs.begin() + 1, mb_inputs.end(), [&](auto mb_in) { return mb_in->get_shape().ndim() == first_shape.ndim(); })) + if(not std::all_of(mb_inputs.begin() + 1, mb_inputs.end(), [&](auto mb_in) { + return mb_in->get_shape().ndim() == first_shape.ndim(); + })) { return; } // Reduce axis by number of leading broadcasted dimensions if(mb_inputs.front()->get_shape().lens().size() < concat_out_lens.size()) - concat_op.axis -= std::count(front_mb_strides.begin(), front_mb_strides.begin() + concat_op.axis, 0); + concat_op.axis -= + std::count(front_mb_strides.begin(), front_mb_strides.begin() + concat_op.axis, 0); // Inputs to multibroadcasts should have the same dimensions except for the axis to // concatenate over @@ -316,8 +319,9 @@ struct find_concat_multibroadcasts } auto new_concat_ins = m.insert_instruction(concat_ins, concat_op, mb_inputs); - m.replace_instruction( - concat_ins, migraphx::make_op("multibroadcast", {{"out_lens", concat_out_lens}}), new_concat_ins); + m.replace_instruction(concat_ins, + migraphx::make_op("multibroadcast", {{"out_lens", concat_out_lens}}), + new_concat_ins); } }; diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index 1d4f1541918..04e3c80af35 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -939,8 +939,8 @@ TEST_CASE(concat_multibroadcasts5) migraphx::module m; auto s0 = migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1, 64}}; auto s1 = migraphx::shape{migraphx::shape::float_type, {1, 1, 60, 64, 192}}; - auto x = m.add_parameter("x", s0); - auto y = m.add_parameter("y", s1); + auto x = m.add_parameter("x", s0); + auto y = m.add_parameter("y", s1); std::vector mb_lens0 = {1, 12, 60, 64, 64}; std::vector mb_lens1 = {1, 12, 60, 64, 192}; auto mb_x = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens0}}), x); @@ -958,10 +958,10 @@ TEST_CASE(concat_multibroadcasts5) TEST_CASE(concat_multibroadcasts6) { migraphx::module m; - auto s0 = migraphx::shape{migraphx::shape::float_type, {64}}; - auto s1 = migraphx::shape{migraphx::shape::float_type, {60, 64, 192}}; - auto x = m.add_parameter("x", s0); - auto y = m.add_parameter("y", s1); + auto s0 = migraphx::shape{migraphx::shape::float_type, {64}}; + auto s1 = migraphx::shape{migraphx::shape::float_type, {60, 64, 192}}; + auto x = m.add_parameter("x", s0); + auto y = m.add_parameter("y", s1); std::vector mb_lens0 = {1, 12, 60, 64, 64}; std::vector mb_lens1 = {1, 12, 60, 64, 192}; auto mb_x = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens0}}), x); From a9d03d0c31959938d6dd79bf300760d1b18a4a4b Mon Sep 17 00:00:00 2001 From: Charlie Lin Date: Thu, 16 May 2024 16:02:07 -0400 Subject: [PATCH 04/10] Update test/simplify_reshapes_test.cpp Co-authored-by: Brian Pickrell <95253842+bpickrel@users.noreply.github.com> --- test/simplify_reshapes_test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index 04e3c80af35..06c133aabeb 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -919,7 +919,7 @@ TEST_CASE(concat_multibroadcasts3) } // Broadcasted batch dim, axis is broadcasted dim -// skips this case +// matched by find_concat_multibroadcasts but it skips this case TEST_CASE(concat_multibroadcasts4) { std::vector in_lens = {3, 4}; From 6ad1822d2ca9bb8a88ca2cf0b58e741c0b2bff51 Mon Sep 17 00:00:00 2001 From: Charlie Lin Date: Thu, 16 May 2024 16:02:14 -0400 Subject: [PATCH 05/10] Update test/simplify_reshapes_test.cpp Co-authored-by: Brian Pickrell <95253842+bpickrel@users.noreply.github.com> --- test/simplify_reshapes_test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index 06c133aabeb..5aa96a111ca 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -932,7 +932,7 @@ TEST_CASE(concat_multibroadcasts4) } // different input parameter shapes -// dimensions other than concat axis do not match +// matched by find_concat_multibroadcasts but dimensions other than concat axis do not match // skips this case TEST_CASE(concat_multibroadcasts5) { From 5367222a37020ec4bca8adaf0122045d01ce32aa Mon Sep 17 00:00:00 2001 From: Charlie Lin Date: Thu, 16 May 2024 16:02:20 -0400 Subject: [PATCH 06/10] Update test/simplify_reshapes_test.cpp Co-authored-by: Brian Pickrell <95253842+bpickrel@users.noreply.github.com> --- test/simplify_reshapes_test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index 5aa96a111ca..b9e2c12f547 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -953,7 +953,7 @@ TEST_CASE(concat_multibroadcasts5) } // concat axis moved because rank(in_dims) < rank(out_dims) -// parameter inputs are not the same rank +// matched by find_concat_multibroadcasts but parameter inputs are not the same rank // skips this case TEST_CASE(concat_multibroadcasts6) { From bfb64531d51caddd7746c225786f68c4fda03d2e Mon Sep 17 00:00:00 2001 From: charlie Date: Sat, 18 May 2024 10:06:39 -0400 Subject: [PATCH 07/10] simplfy algo --- src/simplify_reshapes.cpp | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index a579cba6179..dc029821309 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -282,7 +282,7 @@ struct find_concat_multibroadcasts return; } - // Get the inputs of multibroadcast ops, will be used as inputs to new concat op + // Get the inputs of multibroadcast ops. Will be used as inputs to new concat op std::vector mb_inputs(concat_inputs.size()); std::transform(concat_inputs.begin(), concat_inputs.end(), mb_inputs.begin(), [](auto i) { return i->inputs().front(); @@ -307,15 +307,9 @@ struct find_concat_multibroadcasts const auto& front_in_lens = mb_inputs.front()->get_shape().lens(); for(std::size_t ax = 0; ax < front_in_lens.size(); ++ax) { - if(ax != concat_op.axis) - { - if(not std::all_of(mb_inputs.begin(), mb_inputs.end(), [&](auto input_to_mb) { - return input_to_mb->get_shape().lens()[ax] == front_in_lens[ax]; - })) - { - return; - } - } + const auto& lens = input_to_mb->get_shape().lens(); + return std::equal(lens.begin(), lens.begin() + concat_op.axis, front_in_lens.begin()) and + std::equal(lens.begin() + concat_op.axis + 1, lens.end(), front_in_lens.begin() + concat_op.axis + 1); } auto new_concat_ins = m.insert_instruction(concat_ins, concat_op, mb_inputs); From ed6057eb3d0ed41b0c65414bcd5d078edd8bc113 Mon Sep 17 00:00:00 2001 From: charlie Date: Sat, 18 May 2024 10:09:43 -0500 Subject: [PATCH 08/10] Fixes and more tests --- src/simplify_reshapes.cpp | 15 +++++-- test/simplify_reshapes_test.cpp | 76 ++++++++++++++++++++++++++++++--- 2 files changed, 81 insertions(+), 10 deletions(-) diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index dc029821309..a00178ca016 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -299,17 +299,24 @@ struct find_concat_multibroadcasts // Reduce axis by number of leading broadcasted dimensions if(mb_inputs.front()->get_shape().lens().size() < concat_out_lens.size()) + { concat_op.axis -= std::count(front_mb_strides.begin(), front_mb_strides.begin() + concat_op.axis, 0); + } // Inputs to multibroadcasts should have the same dimensions except for the axis to // concatenate over const auto& front_in_lens = mb_inputs.front()->get_shape().lens(); - for(std::size_t ax = 0; ax < front_in_lens.size(); ++ax) + if(not std::all_of(mb_inputs.begin() + 1, mb_inputs.end(), [&](auto input_to_mb) { + const auto& lens = input_to_mb->get_shape().lens(); + return std::equal( + lens.begin(), lens.begin() + concat_op.axis, front_in_lens.begin()) and + std::equal(lens.begin() + concat_op.axis + 1, + lens.end(), + front_in_lens.begin() + concat_op.axis + 1); + })) { - const auto& lens = input_to_mb->get_shape().lens(); - return std::equal(lens.begin(), lens.begin() + concat_op.axis, front_in_lens.begin()) and - std::equal(lens.begin() + concat_op.axis + 1, lens.end(), front_in_lens.begin() + concat_op.axis + 1); + return; } auto new_concat_ins = m.insert_instruction(concat_ins, concat_op, mb_inputs); diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index b9e2c12f547..efc385f0959 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -931,9 +931,8 @@ TEST_CASE(concat_multibroadcasts4) EXPECT(m1 == m); } -// different input parameter shapes -// matched by find_concat_multibroadcasts but dimensions other than concat axis do not match -// skips this case +// Matched by find_concat_multibroadcasts but skipped because dimensions other than concat axis do +// not match TEST_CASE(concat_multibroadcasts5) { migraphx::module m; @@ -952,9 +951,8 @@ TEST_CASE(concat_multibroadcasts5) EXPECT(m == m_original); } -// concat axis moved because rank(in_dims) < rank(out_dims) -// matched by find_concat_multibroadcasts but parameter inputs are not the same rank -// skips this case +// Matched by find_concat_multibroadcasts but skipped because parameter inputs are not the same +// rank. TEST_CASE(concat_multibroadcasts6) { migraphx::module m; @@ -962,6 +960,29 @@ TEST_CASE(concat_multibroadcasts6) auto s1 = migraphx::shape{migraphx::shape::float_type, {60, 64, 192}}; auto x = m.add_parameter("x", s0); auto y = m.add_parameter("y", s1); + std::vector mb_lens0 = {12, 60, 64, 64}; + std::vector mb_lens1 = {12, 60, 64, 192}; + auto mb_x = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens0}}), x); + auto mb_y = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens1}}), y); + auto concat_xy = m.add_instruction(migraphx::make_op("concat", {{"axis", 3}}), mb_x, mb_y); + m.add_return({concat_xy}); + auto m_original = m; + run_pass(m); + EXPECT(m == m_original); +} + +// Concat axis moved to 2 because rank(in_dims) < rank(out_dims) +// Matched by find_concat_multibroadcasts but skipped because the dimensions +// other than the concat axis are not the same. +// TODO: has common broadcast axes, so can be simplified by moving multibroadcast up to have a +// smaller concat. +TEST_CASE(concat_multibroadcasts7) +{ + migraphx::module m; + auto s0 = migraphx::shape{migraphx::shape::float_type, {1, 1, 64}}; + auto s1 = migraphx::shape{migraphx::shape::float_type, {60, 64, 192}}; + auto x = m.add_parameter("x", s0); + auto y = m.add_parameter("y", s1); std::vector mb_lens0 = {1, 12, 60, 64, 64}; std::vector mb_lens1 = {1, 12, 60, 64, 192}; auto mb_x = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens0}}), x); @@ -973,6 +994,49 @@ TEST_CASE(concat_multibroadcasts6) EXPECT(m == m_original); } +// Shape of inputs to multibroadcasts do not have the same rank. +// Matched by find_concat_multibroadcasts but skipped. +// TODO: has a common broadcast axis, so can be simplified by moving multibroadcast up to have a +// smaller concat. +TEST_CASE(concat_multibroadcasts8) +{ + migraphx::module m; + auto s0 = migraphx::shape{migraphx::shape::float_type, {64, 64}}; + auto s1 = migraphx::shape{migraphx::shape::float_type, {60, 1, 192}}; + auto x = m.add_parameter("x", s0); + auto y = m.add_parameter("y", s1); + std::vector mb_lens0 = {60, 64, 64}; + std::vector mb_lens1 = {60, 64, 192}; + auto mb_x = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens0}}), x); + auto mb_y = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens1}}), y); + auto concat_xy = m.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), mb_x, mb_y); + m.add_return({concat_xy}); + auto m_original = m; + run_pass(m); + EXPECT(m == m_original); +} + +// Shape of inputs to multibroadcasts do not have a common broadcast axis. +// Matched by find_concat_multibroadcasts, but skipped because the dimensions other than +// the concat axis are not the same. +TEST_CASE(concat_multibroadcasts9) +{ + migraphx::module m; + auto s0 = migraphx::shape{migraphx::shape::float_type, {1, 64, 64}}; + auto s1 = migraphx::shape{migraphx::shape::float_type, {60, 1, 192}}; + auto x = m.add_parameter("x", s0); + auto y = m.add_parameter("y", s1); + std::vector mb_lens0 = {60, 64, 64}; + std::vector mb_lens1 = {60, 64, 192}; + auto mb_x = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens0}}), x); + auto mb_y = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens1}}), y); + auto concat_xy = m.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), mb_x, mb_y); + m.add_return({concat_xy}); + auto m_original = m; + run_pass(m); + EXPECT(m == m_original); +} + TEST_CASE(concat_transpose1) { migraphx::module m; From ebf469780967e7c00a39fcd3fa66ecd0e097fcaa Mon Sep 17 00:00:00 2001 From: charlie Date: Mon, 20 May 2024 14:09:36 -0500 Subject: [PATCH 09/10] formatting --- test/simplify_reshapes_test.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index efc385f0959..f0d100821ab 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -938,8 +938,8 @@ TEST_CASE(concat_multibroadcasts5) migraphx::module m; auto s0 = migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1, 64}}; auto s1 = migraphx::shape{migraphx::shape::float_type, {1, 1, 60, 64, 192}}; - auto x = m.add_parameter("x", s0); - auto y = m.add_parameter("y", s1); + auto x = m.add_parameter("x", s0); + auto y = m.add_parameter("y", s1); std::vector mb_lens0 = {1, 12, 60, 64, 64}; std::vector mb_lens1 = {1, 12, 60, 64, 192}; auto mb_x = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens0}}), x); From b3315bf5fe909de1fc8335c24eec64d1f3754bfb Mon Sep 17 00:00:00 2001 From: charlie Date: Mon, 20 May 2024 14:16:05 -0500 Subject: [PATCH 10/10] add assert on concat axis --- src/simplify_reshapes.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index a00178ca016..a0a952d6aac 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -273,6 +273,7 @@ struct find_concat_multibroadcasts auto concat_out_lens = concat_ins->get_shape().lens(); auto concat_inputs = concat_ins->inputs(); auto front_mb_strides = concat_inputs.front()->get_shape().strides(); + assert(concat_op.axis >= 0); // Only apply when concat axis is not a broadcasted dimension if(std::any_of(concat_inputs.begin(), concat_inputs.end(), [&](auto i) {