Skip to content

Commit

Permalink
Handle different broadcast operators in find_dot_broadcast (#3188)
Browse files Browse the repository at this point in the history
  • Loading branch information
pfultz2 authored Jun 17, 2024
1 parent 5b82104 commit 2c4dd4a
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 15 deletions.
20 changes: 11 additions & 9 deletions src/simplify_algebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -723,8 +723,6 @@ struct find_dot_broadcast
auto ins = r.result;
auto a = ins->inputs()[0];
auto b = ins->inputs()[1];
if(a->get_operator().name() != b->get_operator().name())
return;
if(ins->get_shape().lens().size() < 3)
return;
auto nbatch_axes = ins->get_shape().lens().size() - 2;
Expand All @@ -742,10 +740,13 @@ struct find_dot_broadcast
std::vector<std::size_t> axes(naxes);
std::iota(axes.begin(), axes.end(), 0);

auto insert_broadcast = [&](instruction_ref b_ins) -> instruction_ref {
auto input = b_ins->inputs()[0];
std::vector<std::size_t> lens(b_ins->get_shape().lens().begin() + naxes,
b_ins->get_shape().lens().end());
auto insert_broadcast = [&](instruction_ref x_ins) -> instruction_ref {
auto input = x_ins->inputs()[0];
std::vector<std::size_t> lens(x_ins->get_shape().lens().begin() + naxes,
x_ins->get_shape().lens().end());

if(input->get_shape().lens() == lens)
return input;

auto input_naxis = input->get_shape().lens().size();
auto new_bc_naxis = lens.size();
Expand All @@ -756,14 +757,15 @@ struct find_dot_broadcast
input =
m.insert_instruction(ins, make_op("squeeze", {{"axes", axes_to_sq}}), input);
}
if(b_ins->name() == "multibroadcast")

if(x_ins->name() == "multibroadcast")
{
return m.insert_instruction(
ins, make_op("multibroadcast", {{"out_lens", lens}}), input);
}
else if(b_ins->name() == "broadcast")
else if(x_ins->name() == "broadcast")
{
auto v = b_ins->get_operator().to_value();
auto v = x_ins->get_operator().to_value();
auto axis = v.at("axis").to<std::size_t>() - naxes;
return m.insert_instruction(
ins, make_op("broadcast", {{"axis", axis}, {"out_lens", lens}}), input);
Expand Down
67 changes: 61 additions & 6 deletions test/simplify_algebra_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3569,6 +3569,65 @@ TEST_CASE(reorder_slice_ins_deps)
EXPECT(m == create_module());
}

// A dot whose left operand comes from a `broadcast` and whose right operand
// comes from a `multibroadcast` (i.e. two different broadcast operators).
// After run_pass the dot is expected to operate on the un-batched operands,
// with the batch dimensions restored by a single multibroadcast on its result.
TEST_CASE(dot_broadcast_different_broadcast1)
{
    const migraphx::shape vec_shape{migraphx::shape::float_type, {64}};
    const migraphx::shape mat_shape{migraphx::shape::float_type, {64, 64}};

    // Module under test: both operands are broadcast up to rank 4 before the dot.
    migraphx::module m1;
    {
        auto vec    = m1.add_parameter("x", vec_shape);
        auto mat    = m1.add_parameter("y", mat_shape);
        auto vec_bc = m1.add_instruction(
            migraphx::make_op("broadcast", {{"axis", 3}, {"out_lens", {2, 4, 4, 64}}}), vec);
        auto mat_bc = m1.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 64, 64}}}), mat);
        auto product = m1.add_instruction(migraphx::make_op("dot"), vec_bc, mat_bc);
        m1.add_return({product});
    }

    // Expected module: x is broadcast only to the 2-D {4, 64} operand (axis
    // shifted from 3 to 1), y feeds the dot directly, and the {2, 4} batch
    // dimensions are applied to the dot output.
    migraphx::module m2;
    {
        auto vec    = m2.add_parameter("x", vec_shape);
        auto mat    = m2.add_parameter("y", mat_shape);
        auto vec_bc = m2.add_instruction(
            migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {4, 64}}}), vec);
        auto product  = m2.add_instruction(migraphx::make_op("dot"), vec_bc, mat);
        auto batched = m2.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 4, 64}}}), product);
        m2.add_return({batched});
    }

    run_pass(m1);
    EXPECT(m1.sort() == m2.sort());
}

// Rank-3 variant of the mixed broadcast/multibroadcast dot case: x is expanded
// with `broadcast` along axis 1 and y with `multibroadcast`. After run_pass the
// dot is expected to run on 2-D operands and its result to be multibroadcast
// back to the batched output shape.
TEST_CASE(dot_broadcast_different_broadcast2)
{
    const migraphx::shape vec_shape{migraphx::shape::float_type, {384}};
    const migraphx::shape mat_shape{migraphx::shape::float_type, {768, 3072}};

    // Module under test: both operands carry a leading batch dimension of 2.
    migraphx::module m1;
    {
        auto vec    = m1.add_parameter("x", vec_shape);
        auto mat    = m1.add_parameter("y", mat_shape);
        auto vec_bc = m1.add_instruction(
            migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 384, 768}}}), vec);
        auto mat_bc = m1.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", {2, 768, 3072}}}), mat);
        auto product = m1.add_instruction(migraphx::make_op("dot"), vec_bc, mat_bc);
        m1.add_return({product});
    }

    // Expected module: the broadcast axis drops from 1 to 0 with the batch
    // dimension removed, y is consumed as-is, and the batch dimension is
    // re-introduced on the dot output.
    migraphx::module m2;
    {
        auto vec    = m2.add_parameter("x", vec_shape);
        auto mat    = m2.add_parameter("y", mat_shape);
        auto vec_bc = m2.add_instruction(
            migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", {384, 768}}}), vec);
        auto product  = m2.add_instruction(migraphx::make_op("dot"), vec_bc, mat);
        auto batched = m2.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", {2, 384, 3072}}}), product);
        m2.add_return({batched});
    }

    run_pass(m1);
    EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(dot_broadcast_different_rank)
{
migraphx::module m1;
Expand All @@ -3589,9 +3648,7 @@ TEST_CASE(dot_broadcast_different_rank)
auto y = m2.add_parameter("y", {migraphx::shape::float_type, {768, 3072}});
auto xb =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {384, 768}}}), x);
auto yb =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {768, 3072}}}), y);
auto dot = m2.add_instruction(migraphx::make_op("dot"), xb, yb);
auto dot = m2.add_instruction(migraphx::make_op("dot"), xb, y);
auto broadcast = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 384, 3072}}}), dot);
m2.add_return({broadcast});
Expand Down Expand Up @@ -3622,9 +3679,7 @@ TEST_CASE(dot_broadcast_unsqueezed_input)
auto x_sq = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0, 1}}}), x);
auto xb =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {8, 8}}}), x_sq);
auto yb =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {8, 8}}}), y);
auto dot = m2.add_instruction(migraphx::make_op("dot"), xb, yb);
auto dot = m2.add_instruction(migraphx::make_op("dot"), xb, y);
auto broadcast = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 8, 8}}}), dot);
m2.add_return({broadcast});
Expand Down

0 comments on commit 2c4dd4a

Please sign in to comment.