Skip to content

Commit

Permalink
combine reshape-dot matchers
Browse files Browse the repository at this point in the history
  • Loading branch information
shivadbhavsar committed Mar 4, 2024
1 parent efa81d9 commit 6dc309c
Showing 1 changed file with 64 additions and 77 deletions.
141 changes: 64 additions & 77 deletions src/simplify_reshapes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -931,55 +931,6 @@ struct find_const_multibroadcast
}
};

struct find_reshape_reshape_dot
{
auto matcher() const
{
return match::name("dot")(match::used_once(),
match::args(match::name("reshape").bind("inp_rsp1"),
match::name("reshape").bind("inp_rsp2")));
}

// Gemm axis should not be altered by the reshape
auto is_valid_reshape(instruction_ref in, instruction_ref rsp) const
{
auto in_lens = in->get_shape().lens();
auto rsp_lens = rsp->get_shape().lens();

return std::equal(rsp_lens.end() - 2, rsp_lens.end(), in_lens.end() - 2, in_lens.end());
}

// Batch dims should match for both inputs
auto is_valid_inputs(instruction_ref in1, instruction_ref in2) const
{
auto in1_lens = in1->get_shape().lens();
auto in2_lens = in2->get_shape().lens();

return (
in1_lens.size() == in2_lens.size() and
std::equal(in1_lens.begin(), in1_lens.end() - 2, in2_lens.begin(), in2_lens.end() - 2));
}

void apply(module& m, const match::matcher_result& r) const
{
auto dot = r.result;
auto inp_rsp1 = r.instructions["inp_rsp1"];
auto inp_rsp2 = r.instructions["inp_rsp2"];

auto dot_lens = dot->get_shape().lens();

auto inp1 = inp_rsp1->inputs().front();
auto inp2 = inp_rsp2->inputs().front();

if(not(is_valid_reshape(inp1, inp_rsp1) and is_valid_reshape(inp2, inp_rsp2) and
is_valid_inputs(inp1, inp2)))
return;

auto new_dot = m.insert_instruction(dot, dot->get_operator(), inp1, inp2);
m.replace_instruction(dot, make_op("reshape", {{"dims", dot_lens}}), new_dot);
}
};

// Move convert before reshape when preceeding a dot op
struct find_reshape_convert_dot
{
Expand Down Expand Up @@ -1011,16 +962,33 @@ struct find_reshape_dot
match::skip_broadcasts(match::any().bind("other"))));
}

// Gemm axis should not be altered by the reshape
auto is_valid_reshape(instruction_ref inp, instruction_ref rsp, size_t dot_axis) const
{
auto inp_lens = inp->get_shape().lens();
auto rsp_lens = rsp->get_shape().lens();

return (inp_lens.size() >= dot_axis and
rsp_lens[rsp_lens.size() - dot_axis] == inp_lens[inp_lens.size() - dot_axis]);
}

// Same batch dims
auto has_same_batch_dims(instruction_ref in1, instruction_ref in2) const
{
auto in1_lens = in1->get_shape().lens();
auto in2_lens = in2->get_shape().lens();

return (
in1_lens.size() == in2_lens.size() and
std::equal(in1_lens.begin(), in1_lens.end() - 2, in2_lens.begin(), in2_lens.end() - 2));
}

void apply(module& m, const match::matcher_result& r) const
{
auto dot = r.result;
auto rsp = r.instructions["rsp"];
auto other = r.instructions["other"];

auto other_lens = other->get_shape().lens();
if(other_lens.size() > 2)
return;

auto rsp_lens = rsp->get_shape().lens();
auto inp = rsp->inputs().front();
auto inp_lens = inp->get_shape().lens();
Expand All @@ -1029,44 +997,64 @@ struct find_reshape_dot
bool flipped = rsp == dot->inputs().back();
size_t dot_axis = (flipped) ? 2 : 1;

if(inp_lens.size() < dot_axis or
rsp_lens[rsp_lens.size() - dot_axis] != inp_lens[inp_lens.size() - dot_axis])
if(not is_valid_reshape(inp, rsp, dot_axis))
return;

std::vector<size_t> new_other_lens{inp_lens.begin(), inp_lens.end() - 2};
operation new_bc_op;

auto bc_other = (flipped) ? dot->inputs().front() : dot->inputs().back();
auto bc_other_lens = bc_other->get_shape().lens();
new_other_lens.insert(new_other_lens.end(), bc_other_lens.end() - 2, bc_other_lens.end());

// if the original weight is one dimensional, look at the original broadcast
// to determine the correct broadcast axis
if(other_lens.size() == 1)
instruction_ref new_other;
if(other->get_operator().name() == "reshape")
{
auto bc_other_strides = bc_other->get_shape().strides();
auto it = std::find_if(
bc_other_strides.begin(), bc_other_strides.end(), [&](auto i) { return i != 0; });
auto orig_bc_axis = std::distance(bc_other_strides.begin(), it);
auto other_inp = other->inputs().front();
size_t other_dot_axis = (flipped) ? 1 : 2;
if(not is_valid_reshape(other_inp, other, other_dot_axis) or
not has_same_batch_dims(inp, other_inp))

Check warning on line 1009 in src/simplify_reshapes.cpp

View check run for this annotation

Codecov / codecov/patch

src/simplify_reshapes.cpp#L1009

Added line #L1009 was not covered by tests
return;

auto new_bc_axis = new_other_lens.size() - (bc_other_lens.size() - orig_bc_axis);
new_bc_op = make_op("broadcast", {{"axis", new_bc_axis}, {"out_lens", new_other_lens}});
new_other = other_inp;
}
else
{
new_bc_op = make_op("multibroadcast", {{"out_lens", new_other_lens}});
}
auto other_lens = other->get_shape().lens();
if(other_lens.size() > 2)
return;

std::vector<size_t> new_other_lens{inp_lens.begin(), inp_lens.end() - 2};
operation new_bc_op;

auto bc_other = (flipped) ? dot->inputs().front() : dot->inputs().back();
auto bc_other_lens = bc_other->get_shape().lens();
new_other_lens.insert(
new_other_lens.end(), bc_other_lens.end() - 2, bc_other_lens.end());

auto new_bc_other = m.insert_instruction(dot, new_bc_op, other);
// if the original weight is one dimensional, look at the original broadcast
// to determine the correct broadcast axis
if(other_lens.size() == 1)
{
auto bc_other_strides = bc_other->get_shape().strides();
auto it = std::find_if(bc_other_strides.begin(),
bc_other_strides.end(),
[&](auto i) { return i != 0; });
auto orig_bc_axis = std::distance(bc_other_strides.begin(), it);

auto new_bc_axis = new_other_lens.size() - (bc_other_lens.size() - orig_bc_axis);
new_bc_op =
make_op("broadcast", {{"axis", new_bc_axis}, {"out_lens", new_other_lens}});
}
else
{
new_bc_op = make_op("multibroadcast", {{"out_lens", new_other_lens}});
}

new_other = m.insert_instruction(dot, new_bc_op, other);
}

instruction_ref new_dot;
if(flipped)
{
new_dot = m.insert_instruction(dot, make_op("dot"), new_bc_other, inp);
new_dot = m.insert_instruction(dot, make_op("dot"), new_other, inp);
}
else
{
new_dot = m.insert_instruction(dot, make_op("dot"), inp, new_bc_other);
new_dot = m.insert_instruction(dot, make_op("dot"), inp, new_other);
}
m.replace_instruction(
dot, make_op("reshape", {{"dims", dot->get_shape().lens()}}), new_dot);
Expand Down Expand Up @@ -1095,7 +1083,6 @@ void simplify_reshapes::apply(module& m) const
find_slice_transpose{},
find_reshape_convert_dot{},
find_transpose_contiguous_reshaper_unary{},
find_reshape_reshape_dot{},
find_reshape_dot{},
find_scalar_multibroadcast_reshape_or_transpose{});
dead_code_elimination{}.apply(m);
Expand Down

0 comments on commit 6dc309c

Please sign in to comment.