From 1f07af9ed893c9af7e5550cf4c8e3c10a8d6b458 Mon Sep 17 00:00:00 2001 From: Charlie Lin Date: Tue, 21 May 2024 05:39:10 -0400 Subject: [PATCH] Concat - multibroadcast fix (#3096) --- src/simplify_reshapes.cpp | 73 +++++++++++++++++---- test/simplify_reshapes_test.cpp | 109 +++++++++++++++++++++++++++++++- 2 files changed, 167 insertions(+), 15 deletions(-) diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index a651d6e2432..a0a952d6aac 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -244,6 +244,21 @@ struct find_nested_slice } }; +/** + * Example case + * From: + * param0: lens = [3, 4], strides = [4, 1] + * param1: lens = [3, 4], strides = [4, 1] + * mb0: multibroadcast(param0, output_lens = [2, 3, 4]) + * mb1: multibroadcast(param1, output_lens = [2, 3, 4]) + * concat(mb0, mb1, axis = 2) + * + * To: + * param0: lens = [3, 4], strides = [4, 1] + * param1: lens = [3, 4], strides = [4, 1] + * con0: concat(param0, param1, axis = 1) + * multibroadcast(con0, lens = [2, 3, 4]) + */ struct find_concat_multibroadcasts { auto matcher() const @@ -253,32 +268,62 @@ struct find_concat_multibroadcasts void apply(module& m, const match::matcher_result& mr) const { - auto ins = mr.result; - auto op = any_cast(ins->get_operator()); - auto out_lens = ins->get_shape().lens(); - auto inputs = ins->inputs(); - auto in_strides = inputs.front()->get_shape().strides(); + auto concat_ins = mr.result; + auto concat_op = any_cast(concat_ins->get_operator()); + auto concat_out_lens = concat_ins->get_shape().lens(); + auto concat_inputs = concat_ins->inputs(); + auto front_mb_strides = concat_inputs.front()->get_shape().strides(); + assert(concat_op.axis >= 0); // Only apply when concat axis is not a broadcasted dimension - if(std::any_of(inputs.begin(), inputs.end(), [&](auto i) { - return i->get_shape().strides()[op.axis] == 0; + if(std::any_of(concat_inputs.begin(), concat_inputs.end(), [&](auto i) { + return i->get_shape().strides()[concat_op.axis] == 0; })) { return; } - // Use inputs of multibroadcast ops as inputs to new concat op - std::transform(inputs.begin(), inputs.end(), inputs.begin(), [](auto i) { + // Get the inputs of multibroadcast ops. Will be used as inputs to new concat op + std::vector mb_inputs(concat_inputs.size()); + std::transform(concat_inputs.begin(), concat_inputs.end(), mb_inputs.begin(), [](auto i) { return i->inputs().front(); }); + // Check that the inputs into the multibroadcasts have the same rank + const auto& first_shape = mb_inputs.front()->get_shape(); + if(not std::all_of(mb_inputs.begin() + 1, mb_inputs.end(), [&](auto mb_in) { + return mb_in->get_shape().ndim() == first_shape.ndim(); + })) + { + return; + } + // Reduce axis by number of leading broadcasted dimensions - if(inputs.front()->get_shape().lens().size() < out_lens.size()) - op.axis -= std::count(in_strides.begin(), in_strides.begin() + op.axis, 0); + if(mb_inputs.front()->get_shape().lens().size() < concat_out_lens.size()) + { + concat_op.axis -= + std::count(front_mb_strides.begin(), front_mb_strides.begin() + concat_op.axis, 0); + } - auto concat = m.insert_instruction(ins, op, inputs); - m.replace_instruction( - ins, migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), concat); + // Inputs to multibroadcasts should have the same dimensions except for the axis to + // concatenate over + const auto& front_in_lens = mb_inputs.front()->get_shape().lens(); + if(not std::all_of(mb_inputs.begin() + 1, mb_inputs.end(), [&](auto input_to_mb) { + const auto& lens = input_to_mb->get_shape().lens(); + return std::equal( + lens.begin(), lens.begin() + concat_op.axis, front_in_lens.begin()) and + std::equal(lens.begin() + concat_op.axis + 1, + lens.end(), + front_in_lens.begin() + concat_op.axis + 1); + })) + { + return; + } + + auto new_concat_ins = m.insert_instruction(concat_ins, concat_op, mb_inputs); + m.replace_instruction(concat_ins, + migraphx::make_op("multibroadcast", {{"out_lens", concat_out_lens}}), + new_concat_ins); } }; diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index f28c3ff8bde..f0d100821ab 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -918,9 +918,10 @@ TEST_CASE(concat_multibroadcasts3) EXPECT(new_concat->get_operator().to_value()["axis"].to() == 2); } +// Broadcasted batch dim, axis is broadcasted dim +// matched by find_concat_multibroadcasts but it skips this case TEST_CASE(concat_multibroadcasts4) { - // Broadcasted batch dim, axis is broadcasted dim std::vector in_lens = {3, 4}; std::vector mbcast_lens = {2, 3, 4}; const int axis = 0; @@ -930,6 +931,112 @@ TEST_CASE(concat_multibroadcasts4) EXPECT(m1 == m); } +// Matched by find_concat_multibroadcasts but skipped because dimensions other than concat axis do +// not match +TEST_CASE(concat_multibroadcasts5) +{ + migraphx::module m; + auto s0 = migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1, 64}}; + auto s1 = migraphx::shape{migraphx::shape::float_type, {1, 1, 60, 64, 192}}; + auto x = m.add_parameter("x", s0); + auto y = m.add_parameter("y", s1); + std::vector mb_lens0 = {1, 12, 60, 64, 64}; + std::vector mb_lens1 = {1, 12, 60, 64, 192}; + auto mb_x = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens0}}), x); + auto mb_y = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens1}}), y); + auto concat_xy = m.add_instruction(migraphx::make_op("concat", {{"axis", 4}}), mb_x, mb_y); + m.add_return({concat_xy}); + auto m_original = m; + run_pass(m); + EXPECT(m == m_original); +} + +// Matched by find_concat_multibroadcasts but skipped because parameter inputs are not the same +// rank. +TEST_CASE(concat_multibroadcasts6) +{ + migraphx::module m; + auto s0 = migraphx::shape{migraphx::shape::float_type, {64}}; + auto s1 = migraphx::shape{migraphx::shape::float_type, {60, 64, 192}}; + auto x = m.add_parameter("x", s0); + auto y = m.add_parameter("y", s1); + std::vector mb_lens0 = {12, 60, 64, 64}; + std::vector mb_lens1 = {12, 60, 64, 192}; + auto mb_x = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens0}}), x); + auto mb_y = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens1}}), y); + auto concat_xy = m.add_instruction(migraphx::make_op("concat", {{"axis", 3}}), mb_x, mb_y); + m.add_return({concat_xy}); + auto m_original = m; + run_pass(m); + EXPECT(m == m_original); +} + +// Concat axis moved to 2 because rank(in_dims) < rank(out_dims) +// Matched by find_concat_multibroadcasts but skipped because the dimensions +// other than the concat axis are not the same. +// TODO: has common broadcast axes, so can be simplified by moving multibroadcast up to have a +// smaller concat. +TEST_CASE(concat_multibroadcasts7) +{ + migraphx::module m; + auto s0 = migraphx::shape{migraphx::shape::float_type, {1, 1, 64}}; + auto s1 = migraphx::shape{migraphx::shape::float_type, {60, 64, 192}}; + auto x = m.add_parameter("x", s0); + auto y = m.add_parameter("y", s1); + std::vector mb_lens0 = {1, 12, 60, 64, 64}; + std::vector mb_lens1 = {1, 12, 60, 64, 192}; + auto mb_x = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens0}}), x); + auto mb_y = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens1}}), y); + auto concat_xy = m.add_instruction(migraphx::make_op("concat", {{"axis", 4}}), mb_x, mb_y); + m.add_return({concat_xy}); + auto m_original = m; + run_pass(m); + EXPECT(m == m_original); +} + +// Shape of inputs to multibroadcasts do not have the same rank. +// Matched by find_concat_multibroadcasts but skipped. +// TODO: has a common broadcast axis, so can be simplified by moving multibroadcast up to have a +// smaller concat. +TEST_CASE(concat_multibroadcasts8) +{ + migraphx::module m; + auto s0 = migraphx::shape{migraphx::shape::float_type, {64, 64}}; + auto s1 = migraphx::shape{migraphx::shape::float_type, {60, 1, 192}}; + auto x = m.add_parameter("x", s0); + auto y = m.add_parameter("y", s1); + std::vector mb_lens0 = {60, 64, 64}; + std::vector mb_lens1 = {60, 64, 192}; + auto mb_x = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens0}}), x); + auto mb_y = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens1}}), y); + auto concat_xy = m.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), mb_x, mb_y); + m.add_return({concat_xy}); + auto m_original = m; + run_pass(m); + EXPECT(m == m_original); +} + +// Shape of inputs to multibroadcasts do not have a common broadcast axis. +// Matched by find_concat_multibroadcasts, but skipped because the dimensions other than +// the concat axis are not the same. +TEST_CASE(concat_multibroadcasts9) +{ + migraphx::module m; + auto s0 = migraphx::shape{migraphx::shape::float_type, {1, 64, 64}}; + auto s1 = migraphx::shape{migraphx::shape::float_type, {60, 1, 192}}; + auto x = m.add_parameter("x", s0); + auto y = m.add_parameter("y", s1); + std::vector mb_lens0 = {60, 64, 64}; + std::vector mb_lens1 = {60, 64, 192}; + auto mb_x = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens0}}), x); + auto mb_y = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_lens1}}), y); + auto concat_xy = m.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), mb_x, mb_y); + m.add_return({concat_xy}); + auto m_original = m; + run_pass(m); + EXPECT(m == m_original); +} + TEST_CASE(concat_transpose1) { migraphx::module m;