diff --git a/src/simplify_algebra.cpp b/src/simplify_algebra.cpp index 169ca37ad94..25f1a7cba93 100644 --- a/src/simplify_algebra.cpp +++ b/src/simplify_algebra.cpp @@ -723,8 +723,6 @@ struct find_dot_broadcast auto ins = r.result; auto a = ins->inputs()[0]; auto b = ins->inputs()[1]; - if(a->get_operator().name() != b->get_operator().name()) - return; if(ins->get_shape().lens().size() < 3) return; auto nbatch_axes = ins->get_shape().lens().size() - 2; @@ -742,10 +740,13 @@ struct find_dot_broadcast std::vector axes(naxes); std::iota(axes.begin(), axes.end(), 0); - auto insert_broadcast = [&](instruction_ref b_ins) -> instruction_ref { - auto input = b_ins->inputs()[0]; - std::vector lens(b_ins->get_shape().lens().begin() + naxes, - b_ins->get_shape().lens().end()); + auto insert_broadcast = [&](instruction_ref x_ins) -> instruction_ref { + auto input = x_ins->inputs()[0]; + std::vector lens(x_ins->get_shape().lens().begin() + naxes, + x_ins->get_shape().lens().end()); + + if(input->get_shape().lens() == lens) + return input; auto input_naxis = input->get_shape().lens().size(); auto new_bc_naxis = lens.size(); @@ -756,14 +757,15 @@ struct find_dot_broadcast input = m.insert_instruction(ins, make_op("squeeze", {{"axes", axes_to_sq}}), input); } - if(b_ins->name() == "multibroadcast") + + if(x_ins->name() == "multibroadcast") { return m.insert_instruction( ins, make_op("multibroadcast", {{"out_lens", lens}}), input); } - else if(b_ins->name() == "broadcast") + else if(x_ins->name() == "broadcast") { - auto v = b_ins->get_operator().to_value(); + auto v = x_ins->get_operator().to_value(); auto axis = v.at("axis").to() - naxes; return m.insert_instruction( ins, make_op("broadcast", {{"axis", axis}, {"out_lens", lens}}), input); diff --git a/test/simplify_algebra_test.cpp b/test/simplify_algebra_test.cpp index 2aad43a5acb..aa07b6f8866 100644 --- a/test/simplify_algebra_test.cpp +++ b/test/simplify_algebra_test.cpp @@ -3569,6 +3569,65 @@ TEST_CASE(reorder_slice_ins_deps) EXPECT(m == create_module()); } +TEST_CASE(dot_broadcast_different_broadcast1) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::float_type, {64}}); + auto y = m1.add_parameter("y", {migraphx::shape::float_type, {64, 64}}); + auto xb = m1.add_instruction( + migraphx::make_op("broadcast", {{"axis", 3}, {"out_lens", {2, 4, 4, 64}}}), x); + auto yb = m1.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 64, 64}}}), y); + auto dot = m1.add_instruction(migraphx::make_op("dot"), xb, yb); + m1.add_return({dot}); + }; + + migraphx::module m2; + { + auto x = m2.add_parameter("x", {migraphx::shape::float_type, {64}}); + auto y = m2.add_parameter("y", {migraphx::shape::float_type, {64, 64}}); + auto xb = m2.add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {4, 64}}}), x); + auto dot = m2.add_instruction(migraphx::make_op("dot"), xb, y); + auto broadcast = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 4, 64}}}), dot); + m2.add_return({broadcast}); + }; + run_pass(m1); + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(dot_broadcast_different_broadcast2) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::float_type, {384}}); + auto y = m1.add_parameter("y", {migraphx::shape::float_type, {768, 3072}}); + auto xb = m1.add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 384, 768}}}), x); + auto yb = m1.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 768, 3072}}}), y); + auto dot = m1.add_instruction(migraphx::make_op("dot"), xb, yb); + m1.add_return({dot}); + }; + + migraphx::module m2; + { + auto x = m2.add_parameter("x", {migraphx::shape::float_type, {384}}); + auto y = m2.add_parameter("y", {migraphx::shape::float_type, {768, 3072}}); + auto xb = m2.add_instruction( + migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", {384, 768}}}), x); + auto dot = m2.add_instruction(migraphx::make_op("dot"), xb, y); + auto broadcast = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 384, 3072}}}), dot); + m2.add_return({broadcast}); + }; + + run_pass(m1); + EXPECT(m1.sort() == m2.sort()); +} + TEST_CASE(dot_broadcast_different_rank) { migraphx::module m1; @@ -3589,9 +3648,7 @@ TEST_CASE(dot_broadcast_different_rank) auto y = m2.add_parameter("y", {migraphx::shape::float_type, {768, 3072}}); auto xb = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {384, 768}}}), x); - auto yb = - m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {768, 3072}}}), y); - auto dot = m2.add_instruction(migraphx::make_op("dot"), xb, yb); + auto dot = m2.add_instruction(migraphx::make_op("dot"), xb, y); auto broadcast = m2.add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", {2, 384, 3072}}}), dot); m2.add_return({broadcast}); @@ -3622,9 +3679,7 @@ TEST_CASE(dot_broadcast_unsqueezed_input) auto x_sq = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0, 1}}}), x); auto xb = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {8, 8}}}), x_sq); - auto yb = - m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {8, 8}}}), y); - auto dot = m2.add_instruction(migraphx::make_op("dot"), xb, yb); + auto dot = m2.add_instruction(migraphx::make_op("dot"), xb, y); auto broadcast = m2.add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 8, 8}}}), dot); m2.add_return({broadcast});