Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent collapsing batch dims in dot ops with constants #2823

Merged
merged 13 commits into from
May 31, 2024
Merged
104 changes: 104 additions & 0 deletions src/simplify_reshapes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -900,6 +900,38 @@
}
};

// Remove unnecessary preceeding size 1 dims for constants
struct find_const_multibroadcast
{
auto matcher() const
{
return match::name("multibroadcast")(
match::arg(0)(match::is_constant()(match::used_once()).bind("constant")));
umangyadav marked this conversation as resolved.
Show resolved Hide resolved
}

void apply(module& m, const match::matcher_result& mr) const
{
auto mbr = mr.result;
auto constant = mr.instructions["constant"];

if(constant->get_shape().lens().size() <= 1)
return;

auto const_lens = constant->get_shape().lens();
auto it = std::find_if(const_lens.begin(), const_lens.end(), [](auto i) { return i != 1; });
auto naxes = std::distance(const_lens.begin(), it);
if(naxes == 0)
return;

std::vector<std::size_t> sq_axes(naxes);
std::iota(sq_axes.begin(), sq_axes.end(), 0);

auto sq_const =
m.insert_instruction(mbr, make_op("squeeze", {{"axes", sq_axes}}), constant);
m.replace_instruction(mbr, mbr->get_operator(), sq_const);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Couldn't we replace it with broadcast instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just removing any unnecessary preceding dims in literals eg. {1, 1, 640, 640) which are later broadcasted to something like {2, 32, 640, 640}. Would broadcast work for this? I thought it only does 1 axis

}
};
shivadbhavsar marked this conversation as resolved.
Show resolved Hide resolved

struct find_reshape_reshape_dot
{
auto matcher() const
Expand Down Expand Up @@ -949,6 +981,76 @@
}
};

struct find_reshape_const_dot
{
auto matcher() const
{
return match::name("dot")(
match::used_once(),
match::either_arg(0, 1)(match::name("reshape").bind("rsp"),
match::skip_broadcasts(match::is_constant().bind("constant"))));
shivadbhavsar marked this conversation as resolved.
Show resolved Hide resolved
}

void apply(module& m, const match::matcher_result& r) const
{
auto dot = r.result;
auto rsp = r.instructions["rsp"];
auto constant = r.instructions["constant"];

auto const_lens = constant->get_shape().lens();
if(const_lens.size() > 2)
return;

auto rsp_lens = rsp->get_shape().lens();
auto inp = rsp->inputs().front();
auto inp_lens = inp->get_shape().lens();

// Gemm axis should not be altered by the reshape
bool flipped = rsp == dot->inputs().back();
size_t dot_axis = (flipped) ? -2 : -1;
if(rsp_lens.end()[dot_axis] != inp_lens.end()[dot_axis])

Check warning on line 1011 in src/simplify_reshapes.cpp

View workflow job for this annotation

GitHub Actions / cppcheck

warning: Iterators to containers from different expressions 'rsp_lens' and 'inp_lens' are used together. [mismatchingContainerExpression]
umangyadav marked this conversation as resolved.
Show resolved Hide resolved
return;

std::vector<size_t> new_const_lens{inp_lens.begin(), inp_lens.end() - 2};
operation new_bc_op;

auto bc_const = (flipped) ? dot->inputs().front() : dot->inputs().back();
auto bc_const_lens = bc_const->get_shape().lens();
new_const_lens.insert(new_const_lens.end(), bc_const_lens.end() - 2, bc_const_lens.end());

// if the orignal weight is one dimensional, look at the original broadcast
umangyadav marked this conversation as resolved.
Show resolved Hide resolved
// to determine the correct broadcast axis
if(const_lens.size() == 1)
{
auto bc_const_strides = bc_const->get_shape().strides();
auto it = std::find_if(
bc_const_strides.begin(), bc_const_strides.end(), [&](auto i) { return i != 0; });
auto orig_bc_axis = std::distance(bc_const_strides.begin(), it);

auto new_bc_axis = new_const_lens.size() - (bc_const_lens.size() - orig_bc_axis);
new_bc_op = make_op("broadcast", {{"axis", new_bc_axis}, {"out_lens", new_const_lens}});
}
else
{
new_bc_op = make_op("multibroadcast", {{"out_lens", new_const_lens}});
}

auto new_bc_const = m.insert_instruction(dot, new_bc_op, constant);

instruction_ref new_dot;
if(flipped)
{
new_dot = m.insert_instruction(dot, make_op("dot"), new_bc_const, inp);
}
else
{
new_dot = m.insert_instruction(dot, make_op("dot"), inp, new_bc_const);
}
m.replace_instruction(
dot, make_op("reshape", {{"dims", dot->get_shape().lens()}}), new_dot);
}
};

void simplify_reshapes::apply(module& m) const
{
for(int i = 0; i < depth; i++)
Expand All @@ -960,6 +1062,7 @@
find_reshaper{},
find_reshape_cont{},
find_transpose{},
find_const_multibroadcast{},
find_concat_slice{},
find_concat_transpose{},
find_concat_multibroadcasts{},
Expand All @@ -970,6 +1073,7 @@
find_slice_transpose{},
find_transpose_contiguous_reshaper_unary{},
find_reshape_reshape_dot{},
find_reshape_const_dot{},
find_scalar_multibroadcast_reshape_or_transpose{});
dead_code_elimination{}.apply(m);
}
Expand Down
210 changes: 208 additions & 2 deletions test/simplify_reshapes_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1665,8 +1665,8 @@ TEST_CASE(transpose_contiguous_squeeze_unary)
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 8, 1, 5}});
auto transpose_ins =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x);
auto rsqrt = m2.add_instruction(migraphx::make_op("rsqrt"), transpose_ins);
auto sq_ins = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), rsqrt);
auto rsqrt = m2.add_instruction(migraphx::make_op("rsqrt"), transpose_ins);
auto sq_ins = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), rsqrt);
m2.add_instruction(pass_op{}, sq_ins);
}
EXPECT(m1 == m2);
Expand Down Expand Up @@ -2087,4 +2087,210 @@ TEST_CASE(reshape_reshape_dot_gemm_axis)
EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(const_multibroadcast)
{
migraphx::shape s{migraphx::shape::float_type, {1, 1, 64, 1}};
migraphx::module m1;
{
auto a = m1.add_literal(migraphx::generate_literal(s));
auto mbc = m1.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 32, 64, 64}}}), a);
m1.add_return({mbc});
};
run_pass(m1);

migraphx::module m2;
{
auto a = m2.add_literal(migraphx::generate_literal(s));
auto sq = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0, 1}}}), a);
auto mbc = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 32, 64, 64}}}), sq);
m2.add_return({mbc});
};

EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(const_multibroadcast_no_apply)
{
migraphx::shape s{migraphx::shape::float_type, {2, 1, 64, 1}};
migraphx::module m1;
{
auto a = m1.add_literal(migraphx::generate_literal(s));
auto mbc = m1.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 32, 64, 64}}}), a);
m1.add_return({mbc});
};

migraphx::module m2 = m1;
run_pass(m1);
EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(reshape_const_dot)
{
migraphx::shape s_inp{migraphx::shape::float_type, {2, 8, 8, 32}};
migraphx::shape s_w{migraphx::shape::float_type, {32, 32}};

migraphx::module m1;
{
auto inp = m1.add_parameter("inp", s_inp);
auto rsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 64, 32}}}), inp);
auto w = m1.add_literal(migraphx::generate_literal(s_w));
auto w_bc =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 32, 32}}}), w);
auto dot = m1.add_instruction(migraphx::make_op("dot"), rsp, w_bc);
m1.add_return({dot});
};
run_pass(m1);

migraphx::module m2;
{
auto inp = m2.add_parameter("inp", s_inp);
auto w = m2.add_literal(migraphx::generate_literal(s_w));
auto w_bc = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 8, 32, 32}}}), w);
auto dot = m2.add_instruction(migraphx::make_op("dot"), inp, w_bc);
auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 64, 32}}}), dot);
m2.add_return({rsp});
};

EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(reshape_const_dot_flipped)
{
migraphx::shape s_inp{migraphx::shape::float_type, {2, 8, 8, 32}};
migraphx::shape s_w{migraphx::shape::float_type, {16, 8}};

migraphx::module m1;
{
auto inp = m1.add_parameter("inp", s_inp);
auto rsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {16, 8, 32}}}), inp);
auto w = m1.add_literal(migraphx::generate_literal(s_w));
auto w_bc =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {16, 16, 8}}}), w);
auto dot = m1.add_instruction(migraphx::make_op("dot"), w_bc, rsp);
m1.add_return({dot});
};
run_pass(m1);

migraphx::module m2;
{
auto inp = m2.add_parameter("inp", s_inp);
auto w = m2.add_literal(migraphx::generate_literal(s_w));
auto w_bc = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 8, 16, 8}}}), w);
auto dot = m2.add_instruction(migraphx::make_op("dot"), w_bc, inp);
auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {16, 16, 32}}}), dot);
m2.add_return({rsp});
};

EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(reshape_const_dot_dot_axis)
{
migraphx::shape s_inp{migraphx::shape::float_type, {2, 8, 8, 4}};
migraphx::shape s_w{migraphx::shape::float_type, {32, 32}};

migraphx::module m1;
{
auto inp = m1.add_parameter("inp", s_inp);
auto rsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 8, 32}}}), inp);
auto w = m1.add_literal(migraphx::generate_literal(s_w));
auto w_bc =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 32, 32}}}), w);
auto dot = m1.add_instruction(migraphx::make_op("dot"), rsp, w_bc);
m1.add_return({dot});
};

migraphx::module m2 = m1;
run_pass(m1);
EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(reshape_const_dot_flipped_dot_axis)
{
migraphx::shape s_inp{migraphx::shape::float_type, {2, 8, 8, 32}};
migraphx::shape s_w{migraphx::shape::float_type, {8, 64}};

migraphx::module m1;
{
auto inp = m1.add_parameter("inp", s_inp);
auto rsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 64, 32}}}), inp);
auto w = m1.add_literal(migraphx::generate_literal(s_w));
auto w_bc =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 8, 64}}}), w);
auto dot = m1.add_instruction(migraphx::make_op("dot"), w_bc, rsp);
m1.add_return({dot});
};

migraphx::module m2 = m1;
run_pass(m1);
EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(reshape_const_dot_broadcast)
{
migraphx::shape s_inp{migraphx::shape::float_type, {2, 8, 8, 32}};
migraphx::shape s_w{migraphx::shape::float_type, {32}};

migraphx::module m1;
{
auto inp = m1.add_parameter("inp", s_inp);
auto rsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 64, 32}}}), inp);
auto w = m1.add_literal(migraphx::generate_literal(s_w));
auto w_bc = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 32, 32}}}), w);
auto dot = m1.add_instruction(migraphx::make_op("dot"), rsp, w_bc);
m1.add_return({dot});
};
run_pass(m1);

migraphx::module m2;
{
auto inp = m2.add_parameter("inp", s_inp);
auto w = m2.add_literal(migraphx::generate_literal(s_w));
auto w_bc = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 2}, {"out_lens", {2, 8, 32, 32}}}), w);
auto dot = m2.add_instruction(migraphx::make_op("dot"), inp, w_bc);
auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 64, 32}}}), dot);
m2.add_return({rsp});
};

EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(reshape_const_dot_broadcast_2)
{
migraphx::shape s_inp{migraphx::shape::float_type, {2, 8, 8, 32}};
migraphx::shape s_w{migraphx::shape::float_type, {32}};

migraphx::module m1;
{
auto inp = m1.add_parameter("inp", s_inp);
auto rsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {128, 32}}}), inp);
auto w = m1.add_literal(migraphx::generate_literal(s_w));
auto w_bc = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {32, 32}}}), w);
auto dot = m1.add_instruction(migraphx::make_op("dot"), rsp, w_bc);
m1.add_return({dot});
};
run_pass(m1);

migraphx::module m2;
{
auto inp = m2.add_parameter("inp", s_inp);
auto w = m2.add_literal(migraphx::generate_literal(s_w));
auto w_bc = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 3}, {"out_lens", {2, 8, 32, 32}}}), w);
auto dot = m2.add_instruction(migraphx::make_op("dot"), inp, w_bc);
auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {128, 32}}}), dot);
m2.add_return({rsp});
};

EXPECT(m1.sort() == m2.sort());
}

int main(int argc, const char* argv[]) { test::run(argc, argv); }
Loading