Skip to content

Commit

Permalink
Prevent collapsing batch dims in dot ops with constants (ROCm#2823)
Browse files Browse the repository at this point in the history
  • Loading branch information
shivadbhavsar authored and lajagapp committed Jul 8, 2024
1 parent 5434e87 commit d40d999
Show file tree
Hide file tree
Showing 2 changed files with 246 additions and 20 deletions.
100 changes: 80 additions & 20 deletions src/simplify_reshapes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1012,26 +1012,28 @@ struct find_scalar_multibroadcast_reshape_or_transpose
}
};

struct find_reshape_reshape_dot
struct find_reshape_dot
{
// Builds the pattern this rewrite fires on.
// NOTE(review): this span comes from a diff rendered without +/- markers, so two return
// statements appear back to back. Presumably the first one (binding "inp_rsp1"/"inp_rsp2")
// is the removed version and the second one is its replacement - confirm against the commit.
auto matcher() const
{
// Pattern 1: a dot that is used once and whose two arguments are both reshapes.
return match::name("dot")(match::used_once(),
match::args(match::name("reshape").bind("inp_rsp1"),
match::name("reshape").bind("inp_rsp2")));
// Pattern 2: a dot that is used once where either argument (position 0 or 1) is a reshape,
// bound as "rsp"; the remaining argument is bound as "other" through match::skip_broadcasts.
return match::name("dot")(
match::used_once(),
match::either_arg(0, 1)(match::name("reshape").bind("rsp"),
match::skip_broadcasts(match::any().bind("other"))));
}

// Gemm axis should not be altered by the reshape
// NOTE(review): this span comes from a diff rendered without +/- markers. The two-parameter
// signature, `in_lens` and the std::equal return presumably belong to the removed version;
// the three-parameter signature, `inp_lens` and the dot_axis comparison to the new one - confirm.
auto is_valid_reshape(instruction_ref in, instruction_ref rsp) const
auto is_valid_reshape(instruction_ref inp, instruction_ref rsp, size_t dot_axis) const
{
auto in_lens = in->get_shape().lens();
auto inp_lens = inp->get_shape().lens();
auto rsp_lens = rsp->get_shape().lens();

// Two-parameter form: the last two dims of the reshape output must equal the last two dims
// of the reshape input.
return std::equal(rsp_lens.end() - 2, rsp_lens.end(), in_lens.end() - 2, in_lens.end());
// Three-parameter form: dot_axis is counted from the end of the shape (1 = last dim,
// 2 = second to last). The dim at that position must be the same before and after the reshape.
// NOTE(review): only inp_lens.size() is checked against dot_axis; rsp_lens.size() is indexed
// without a matching check - confirm the reshape output always has at least dot_axis dims.
return (inp_lens.size() >= dot_axis and
rsp_lens[rsp_lens.size() - dot_axis] == inp_lens[inp_lens.size() - dot_axis]);
}

// Batch dims should match for both inputs
auto is_valid_inputs(instruction_ref in1, instruction_ref in2) const
// Same batch dims
auto has_same_batch_dims(instruction_ref in1, instruction_ref in2) const
{
auto in1_lens = in1->get_shape().lens();
auto in2_lens = in2->get_shape().lens();
Expand All @@ -1043,21 +1045,79 @@ struct find_reshape_reshape_dot

// Rewrites a matched dot so that it consumes the reshape's input directly and re-applies a
// reshape to the original dot output dims afterwards.
// NOTE(review): this span comes from a diff rendered without +/- markers, so statements from
// the removed version (using inp_rsp1/inp_rsp2/inp1/inp2/dot_lens) are interleaved with the
// new version (using rsp/other/inp/flipped) - confirm which lines survive in the commit.
void apply(module& m, const match::matcher_result& r) const
{
auto dot = r.result;
auto inp_rsp1 = r.instructions["inp_rsp1"];
auto inp_rsp2 = r.instructions["inp_rsp2"];
auto dot = r.result;
auto rsp = r.instructions["rsp"];
auto other = r.instructions["other"];

auto dot_lens = dot->get_shape().lens();
auto rsp_lens = rsp->get_shape().lens();
// inp is what the bound reshape consumes; the rewritten dot uses it in place of the reshape.
auto inp = rsp->inputs().front();
auto inp_lens = inp->get_shape().lens();

auto inp1 = inp_rsp1->inputs().front();
auto inp2 = inp_rsp2->inputs().front();
// Gemm axis should not be altered by the reshape
// flipped is true when the bound reshape is the dot's last input. The axis checked is then
// the second-to-last dim (dot_axis 2); otherwise it is the last dim (dot_axis 1).
bool flipped = rsp == dot->inputs().back();
size_t dot_axis = (flipped) ? 2 : 1;

if(not(is_valid_reshape(inp1, inp_rsp1) and is_valid_reshape(inp2, inp_rsp2) and
is_valid_inputs(inp1, inp2)))
if(not is_valid_reshape(inp, rsp, dot_axis))
return;

auto new_dot = m.insert_instruction(dot, dot->get_operator(), inp1, inp2);
m.replace_instruction(dot, make_op("reshape", {{"dims", dot_lens}}), new_dot);
instruction_ref new_other;
if(other->get_operator().name() == "reshape")
{
// The other operand is a reshape too: it is checked on the opposite dot axis and both
// reshape inputs must pass has_same_batch_dims; its input is then used directly.
auto other_inp = other->inputs().front();
size_t other_dot_axis = (flipped) ? 1 : 2;
if(not is_valid_reshape(other_inp, other, other_dot_axis) or
not has_same_batch_dims(inp, other_inp))
return;

new_other = other_inp;
}
else
{
// The other operand is not a reshape; give up when it has more than two dims.
auto other_lens = other->get_shape().lens();
if(other_lens.size() > 2)
return;

// New broadcast target starts with every dim of inp except its last two.
// NOTE(review): inp_lens.end() - 2 assumes inp has at least two dims, while
// is_valid_reshape only requires inp_lens.size() >= dot_axis (which may be 1) - confirm.
std::vector<size_t> new_other_lens{inp_lens.begin(), inp_lens.end() - 2};
operation new_bc_op;

// bc_other is the dot's actual operand on the non-reshape side (i.e. not the value bound
// as "other"); its last two dims are appended to the new broadcast target.
auto bc_other = (flipped) ? dot->inputs().front() : dot->inputs().back();
auto bc_other_lens = bc_other->get_shape().lens();
new_other_lens.insert(
new_other_lens.end(), bc_other_lens.end() - 2, bc_other_lens.end());

// if the original weight is one dimensional, look at the original broadcast
// to determine the correct broadcast axis
if(other_lens.size() == 1)
{
// The first non-zero stride of bc_other is taken as the axis the 1-D value was
// broadcast along.
auto bc_other_strides = bc_other->get_shape().strides();
auto it = std::find_if(bc_other_strides.begin(),
bc_other_strides.end(),
[&](auto i) { return i != 0; });
auto orig_bc_axis = std::distance(bc_other_strides.begin(), it);

// Keep that axis at the same distance from the end of the shape in the new lens.
auto new_bc_axis = new_other_lens.size() - (bc_other_lens.size() - orig_bc_axis);
new_bc_op =
make_op("broadcast", {{"axis", new_bc_axis}, {"out_lens", new_other_lens}});
}
else
{
new_bc_op = make_op("multibroadcast", {{"out_lens", new_other_lens}});
}

new_other = m.insert_instruction(dot, new_bc_op, other);
}

// Insert the new dot before the old one, keeping the original operand order.
instruction_ref new_dot;
if(flipped)
{
new_dot = m.insert_instruction(dot, make_op("dot"), new_other, inp);
}
else
{
new_dot = m.insert_instruction(dot, make_op("dot"), inp, new_other);
}
// Replace the old dot with a reshape of the new dot back to the old dot's output dims.
m.replace_instruction(
dot, make_op("reshape", {{"dims", dot->get_shape().lens()}}), new_dot);
}
};

Expand All @@ -1081,7 +1141,7 @@ void simplify_reshapes::apply(module& m) const
find_broadcast_transpose{},
find_slice_transpose{},
find_unary_shape_transforms{},
find_reshape_reshape_dot{},
find_reshape_dot{},
find_scalar_multibroadcast_reshape_or_transpose{});
dead_code_elimination{}.apply(m);
}
Expand Down
166 changes: 166 additions & 0 deletions test/simplify_reshapes_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2266,4 +2266,170 @@ TEST_CASE(reshape_reshape_dot_gemm_axis)
EXPECT(m1.sort() == m2.sort());
}

// reshape -> dot with a multibroadcast 2-D literal on the right-hand side.
// Expected result: the dot consumes the un-reshaped input, the literal is multibroadcast to
// {2, 8, 32, 32}, and the reshape to {2, 64, 32} is applied to the dot output.
TEST_CASE(reshape_dot)
{
    const migraphx::shape input_shape{migraphx::shape::float_type, {2, 8, 8, 32}};
    const migraphx::shape weight_shape{migraphx::shape::float_type, {32, 32}};

    // Graph handed to the pass.
    migraphx::module m1;
    auto x1      = m1.add_parameter("inp", input_shape);
    auto x1_flat = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 64, 32}}}), x1);
    auto w1      = m1.add_literal(migraphx::generate_literal(weight_shape));
    auto w1_bc =
        m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 32, 32}}}), w1);
    auto y1 = m1.add_instruction(migraphx::make_op("dot"), x1_flat, w1_bc);
    m1.add_return({y1});
    run_pass(m1);

    // Graph the pass is expected to produce.
    migraphx::module m2;
    auto x2 = m2.add_parameter("inp", input_shape);
    auto w2 = m2.add_literal(migraphx::generate_literal(weight_shape));
    auto w2_bc =
        m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 8, 32, 32}}}), w2);
    auto y2      = m2.add_instruction(migraphx::make_op("dot"), x2, w2_bc);
    auto y2_flat = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 64, 32}}}), y2);
    m2.add_return({y2_flat});

    EXPECT(m1.sort() == m2.sort());
}

// Same idea as reshape_dot, but the reshape feeds the second dot argument and the
// multibroadcast literal is the first one.
// Expected result: dot(literal multibroadcast to {2, 8, 16, 8}, un-reshaped input) followed by
// a reshape to {16, 16, 32}.
TEST_CASE(reshape_dot_flipped)
{
    const migraphx::shape input_shape{migraphx::shape::float_type, {2, 8, 8, 32}};
    const migraphx::shape weight_shape{migraphx::shape::float_type, {16, 8}};

    // Graph handed to the pass.
    migraphx::module m1;
    auto x1      = m1.add_parameter("inp", input_shape);
    auto x1_flat = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {16, 8, 32}}}), x1);
    auto w1      = m1.add_literal(migraphx::generate_literal(weight_shape));
    auto w1_bc =
        m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {16, 16, 8}}}), w1);
    auto y1 = m1.add_instruction(migraphx::make_op("dot"), w1_bc, x1_flat);
    m1.add_return({y1});
    run_pass(m1);

    // Graph the pass is expected to produce.
    migraphx::module m2;
    auto x2 = m2.add_parameter("inp", input_shape);
    auto w2 = m2.add_literal(migraphx::generate_literal(weight_shape));
    auto w2_bc =
        m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 8, 16, 8}}}), w2);
    auto y2      = m2.add_instruction(migraphx::make_op("dot"), w2_bc, x2);
    auto y2_flat = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {16, 16, 32}}}), y2);
    m2.add_return({y2_flat});

    EXPECT(m1.sort() == m2.sort());
}

// Negative case: the reshape changes the last dim of the first dot argument (4 -> 32),
// so the module is expected to come out of the pass unchanged.
TEST_CASE(reshape_dot_dot_axis)
{
    const migraphx::shape input_shape{migraphx::shape::float_type, {2, 8, 8, 4}};
    const migraphx::shape weight_shape{migraphx::shape::float_type, {32, 32}};

    migraphx::module m1;
    auto x      = m1.add_parameter("inp", input_shape);
    auto x_flat = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 8, 32}}}), x);
    auto w      = m1.add_literal(migraphx::generate_literal(weight_shape));
    auto w_bcast =
        m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 32, 32}}}), w);
    auto y = m1.add_instruction(migraphx::make_op("dot"), x_flat, w_bcast);
    m1.add_return({y});

    // Snapshot the module before the pass runs; the two must still compare equal afterwards.
    migraphx::module m2 = m1;
    run_pass(m1);
    EXPECT(m1.sort() == m2.sort());
}

// Negative case with the reshape as the second dot argument: the reshape changes the
// second-to-last dim (8 -> 64), so the module is expected to come out of the pass unchanged.
TEST_CASE(reshape_dot_flipped_dot_axis)
{
    const migraphx::shape input_shape{migraphx::shape::float_type, {2, 8, 8, 32}};
    const migraphx::shape weight_shape{migraphx::shape::float_type, {8, 64}};

    migraphx::module m1;
    auto x      = m1.add_parameter("inp", input_shape);
    auto x_flat = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 64, 32}}}), x);
    auto w      = m1.add_literal(migraphx::generate_literal(weight_shape));
    auto w_bcast =
        m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 8, 64}}}), w);
    auto y = m1.add_instruction(migraphx::make_op("dot"), w_bcast, x_flat);
    m1.add_return({y});

    // Snapshot the module before the pass runs; the two must still compare equal afterwards.
    migraphx::module m2 = m1;
    run_pass(m1);
    EXPECT(m1.sort() == m2.sort());
}

// reshape -> dot where the right-hand side is a 1-D literal expanded with `broadcast` on
// axis 1 of a 3-D shape.
// Expected result: the literal is broadcast on axis 2 of {2, 8, 32, 32}, the dot consumes the
// un-reshaped input, and the reshape to {2, 64, 32} follows the dot.
TEST_CASE(reshape_dot_broadcast)
{
    const migraphx::shape input_shape{migraphx::shape::float_type, {2, 8, 8, 32}};
    const migraphx::shape weight_shape{migraphx::shape::float_type, {32}};

    // Graph handed to the pass.
    migraphx::module m1;
    auto x1      = m1.add_parameter("inp", input_shape);
    auto x1_flat = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 64, 32}}}), x1);
    auto w1      = m1.add_literal(migraphx::generate_literal(weight_shape));
    auto w1_bc   = m1.add_instruction(
        migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 32, 32}}}), w1);
    auto y1 = m1.add_instruction(migraphx::make_op("dot"), x1_flat, w1_bc);
    m1.add_return({y1});
    run_pass(m1);

    // Graph the pass is expected to produce.
    migraphx::module m2;
    auto x2    = m2.add_parameter("inp", input_shape);
    auto w2    = m2.add_literal(migraphx::generate_literal(weight_shape));
    auto w2_bc = m2.add_instruction(
        migraphx::make_op("broadcast", {{"axis", 2}, {"out_lens", {2, 8, 32, 32}}}), w2);
    auto y2      = m2.add_instruction(migraphx::make_op("dot"), x2, w2_bc);
    auto y2_flat = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 64, 32}}}), y2);
    m2.add_return({y2_flat});

    EXPECT(m1.sort() == m2.sort());
}

// Variant of reshape_dot_broadcast where the reshape collapses the input to 2-D ({128, 32})
// and the 1-D literal is broadcast on axis 1 of {32, 32}.
// Expected result: the literal is broadcast on axis 3 of {2, 8, 32, 32}, the dot consumes the
// un-reshaped input, and the reshape to {128, 32} follows the dot.
TEST_CASE(reshape_dot_broadcast_2)
{
    const migraphx::shape input_shape{migraphx::shape::float_type, {2, 8, 8, 32}};
    const migraphx::shape weight_shape{migraphx::shape::float_type, {32}};

    // Graph handed to the pass.
    migraphx::module m1;
    auto x1      = m1.add_parameter("inp", input_shape);
    auto x1_flat = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {128, 32}}}), x1);
    auto w1      = m1.add_literal(migraphx::generate_literal(weight_shape));
    auto w1_bc   = m1.add_instruction(
        migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {32, 32}}}), w1);
    auto y1 = m1.add_instruction(migraphx::make_op("dot"), x1_flat, w1_bc);
    m1.add_return({y1});
    run_pass(m1);

    // Graph the pass is expected to produce.
    migraphx::module m2;
    auto x2    = m2.add_parameter("inp", input_shape);
    auto w2    = m2.add_literal(migraphx::generate_literal(weight_shape));
    auto w2_bc = m2.add_instruction(
        migraphx::make_op("broadcast", {{"axis", 3}, {"out_lens", {2, 8, 32, 32}}}), w2);
    auto y2      = m2.add_instruction(migraphx::make_op("dot"), x2, w2_bc);
    auto y2_flat = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {128, 32}}}), y2);
    m2.add_return({y2_flat});

    EXPECT(m1.sort() == m2.sort());
}

// Test binary entry point: forwards the command-line arguments to test::run.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}

0 comments on commit d40d999

Please sign in to comment.