From 0da31730c278148f6bd8169e406a680b7296466f Mon Sep 17 00:00:00 2001 From: shivadbhavsar <105248561+shivadbhavsar@users.noreply.github.com> Date: Fri, 31 May 2024 13:53:17 -0700 Subject: [PATCH] Prevent collapsing batch dims in dot ops with constants (#2823) --- src/simplify_reshapes.cpp | 100 +++++++++++++++---- test/simplify_reshapes_test.cpp | 166 ++++++++++++++++++++++++++++++++ 2 files changed, 246 insertions(+), 20 deletions(-) diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index a0a952d6aac..d8787bf0d93 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -1012,26 +1012,28 @@ struct find_scalar_multibroadcast_reshape_or_transpose } }; -struct find_reshape_reshape_dot +struct find_reshape_dot { auto matcher() const { - return match::name("dot")(match::used_once(), - match::args(match::name("reshape").bind("inp_rsp1"), - match::name("reshape").bind("inp_rsp2"))); + return match::name("dot")( + match::used_once(), + match::either_arg(0, 1)(match::name("reshape").bind("rsp"), + match::skip_broadcasts(match::any().bind("other")))); } // Gemm axis should not be altered by the reshape - auto is_valid_reshape(instruction_ref in, instruction_ref rsp) const + auto is_valid_reshape(instruction_ref inp, instruction_ref rsp, size_t dot_axis) const { - auto in_lens = in->get_shape().lens(); + auto inp_lens = inp->get_shape().lens(); auto rsp_lens = rsp->get_shape().lens(); - return std::equal(rsp_lens.end() - 2, rsp_lens.end(), in_lens.end() - 2, in_lens.end()); + return (inp_lens.size() >= dot_axis and + rsp_lens[rsp_lens.size() - dot_axis] == inp_lens[inp_lens.size() - dot_axis]); } - // Batch dims should match for both inputs - auto is_valid_inputs(instruction_ref in1, instruction_ref in2) const + // Same batch dims + auto has_same_batch_dims(instruction_ref in1, instruction_ref in2) const { auto in1_lens = in1->get_shape().lens(); auto in2_lens = in2->get_shape().lens(); @@ -1043,21 +1045,79 @@ struct find_reshape_reshape_dot void apply(module& m, const match::matcher_result& r) const { - auto dot = r.result; - auto inp_rsp1 = r.instructions["inp_rsp1"]; - auto inp_rsp2 = r.instructions["inp_rsp2"]; + auto dot = r.result; + auto rsp = r.instructions["rsp"]; + auto other = r.instructions["other"]; - auto dot_lens = dot->get_shape().lens(); + auto rsp_lens = rsp->get_shape().lens(); + auto inp = rsp->inputs().front(); + auto inp_lens = inp->get_shape().lens(); - auto inp1 = inp_rsp1->inputs().front(); - auto inp2 = inp_rsp2->inputs().front(); + // Gemm axis should not be altered by the reshape + bool flipped = rsp == dot->inputs().back(); + size_t dot_axis = (flipped) ? 2 : 1; - if(not(is_valid_reshape(inp1, inp_rsp1) and is_valid_reshape(inp2, inp_rsp2) and - is_valid_inputs(inp1, inp2))) + if(not is_valid_reshape(inp, rsp, dot_axis)) return; - auto new_dot = m.insert_instruction(dot, dot->get_operator(), inp1, inp2); - m.replace_instruction(dot, make_op("reshape", {{"dims", dot_lens}}), new_dot); + instruction_ref new_other; + if(other->get_operator().name() == "reshape") + { + auto other_inp = other->inputs().front(); + size_t other_dot_axis = (flipped) ? 1 : 2; + if(not is_valid_reshape(other_inp, other, other_dot_axis) or + not has_same_batch_dims(inp, other_inp)) + return; + + new_other = other_inp; + } + else + { + auto other_lens = other->get_shape().lens(); + if(other_lens.size() > 2) + return; + + std::vector new_other_lens{inp_lens.begin(), inp_lens.end() - 2}; + operation new_bc_op; + + auto bc_other = (flipped) ? dot->inputs().front() : dot->inputs().back(); + auto bc_other_lens = bc_other->get_shape().lens(); + new_other_lens.insert( + new_other_lens.end(), bc_other_lens.end() - 2, bc_other_lens.end()); + + // if the original weight is one dimensional, look at the original broadcast + // to determine the correct broadcast axis + if(other_lens.size() == 1) + { + auto bc_other_strides = bc_other->get_shape().strides(); + auto it = std::find_if(bc_other_strides.begin(), + bc_other_strides.end(), + [&](auto i) { return i != 0; }); + auto orig_bc_axis = std::distance(bc_other_strides.begin(), it); + + auto new_bc_axis = new_other_lens.size() - (bc_other_lens.size() - orig_bc_axis); + new_bc_op = + make_op("broadcast", {{"axis", new_bc_axis}, {"out_lens", new_other_lens}}); + } + else + { + new_bc_op = make_op("multibroadcast", {{"out_lens", new_other_lens}}); + } + + new_other = m.insert_instruction(dot, new_bc_op, other); + } + + instruction_ref new_dot; + if(flipped) + { + new_dot = m.insert_instruction(dot, make_op("dot"), new_other, inp); + } + else + { + new_dot = m.insert_instruction(dot, make_op("dot"), inp, new_other); + } + m.replace_instruction( + dot, make_op("reshape", {{"dims", dot->get_shape().lens()}}), new_dot); } }; @@ -1081,7 +1141,7 @@ void simplify_reshapes::apply(module& m) const find_broadcast_transpose{}, find_slice_transpose{}, find_unary_shape_transforms{}, - find_reshape_reshape_dot{}, + find_reshape_dot{}, find_scalar_multibroadcast_reshape_or_transpose{}); dead_code_elimination{}.apply(m); } diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index f0d100821ab..41f2071ea3a 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -2266,4 +2266,170 @@ TEST_CASE(reshape_reshape_dot_gemm_axis) EXPECT(m1.sort() == m2.sort()); } +TEST_CASE(reshape_dot) +{ + migraphx::shape s_inp{migraphx::shape::float_type, {2, 8, 8, 32}}; + migraphx::shape s_w{migraphx::shape::float_type, {32, 32}}; + + migraphx::module m1; + { + auto inp = m1.add_parameter("inp", s_inp); + auto rsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 64, 32}}}), inp); + auto w = m1.add_literal(migraphx::generate_literal(s_w)); + auto w_bc = + m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 32, 32}}}), w); + auto dot = m1.add_instruction(migraphx::make_op("dot"), rsp, w_bc); + m1.add_return({dot}); + }; + run_pass(m1); + + migraphx::module m2; + { + auto inp = m2.add_parameter("inp", s_inp); + auto w = m2.add_literal(migraphx::generate_literal(s_w)); + auto w_bc = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 8, 32, 32}}}), w); + auto dot = m2.add_instruction(migraphx::make_op("dot"), inp, w_bc); + auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 64, 32}}}), dot); + m2.add_return({rsp}); + }; + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(reshape_dot_flipped) +{ + migraphx::shape s_inp{migraphx::shape::float_type, {2, 8, 8, 32}}; + migraphx::shape s_w{migraphx::shape::float_type, {16, 8}}; + + migraphx::module m1; + { + auto inp = m1.add_parameter("inp", s_inp); + auto rsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {16, 8, 32}}}), inp); + auto w = m1.add_literal(migraphx::generate_literal(s_w)); + auto w_bc = + m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {16, 16, 8}}}), w); + auto dot = m1.add_instruction(migraphx::make_op("dot"), w_bc, rsp); + m1.add_return({dot}); + }; + run_pass(m1); + + migraphx::module m2; + { + auto inp = m2.add_parameter("inp", s_inp); + auto w = m2.add_literal(migraphx::generate_literal(s_w)); + auto w_bc = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 8, 16, 8}}}), w); + auto dot = m2.add_instruction(migraphx::make_op("dot"), w_bc, inp); + auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {16, 16, 32}}}), dot); + m2.add_return({rsp}); + }; + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(reshape_dot_dot_axis) +{ + migraphx::shape s_inp{migraphx::shape::float_type, {2, 8, 8, 4}}; + migraphx::shape s_w{migraphx::shape::float_type, {32, 32}}; + + migraphx::module m1; + { + auto inp = m1.add_parameter("inp", s_inp); + auto rsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 8, 32}}}), inp); + auto w = m1.add_literal(migraphx::generate_literal(s_w)); + auto w_bc = + m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 32, 32}}}), w); + auto dot = m1.add_instruction(migraphx::make_op("dot"), rsp, w_bc); + m1.add_return({dot}); + }; + + migraphx::module m2 = m1; + run_pass(m1); + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(reshape_dot_flipped_dot_axis) +{ + migraphx::shape s_inp{migraphx::shape::float_type, {2, 8, 8, 32}}; + migraphx::shape s_w{migraphx::shape::float_type, {8, 64}}; + + migraphx::module m1; + { + auto inp = m1.add_parameter("inp", s_inp); + auto rsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 64, 32}}}), inp); + auto w = m1.add_literal(migraphx::generate_literal(s_w)); + auto w_bc = + m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 8, 64}}}), w); + auto dot = m1.add_instruction(migraphx::make_op("dot"), w_bc, rsp); + m1.add_return({dot}); + }; + + migraphx::module m2 = m1; + run_pass(m1); + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(reshape_dot_broadcast) +{ + migraphx::shape s_inp{migraphx::shape::float_type, {2, 8, 8, 32}}; + migraphx::shape s_w{migraphx::shape::float_type, {32}}; + + migraphx::module m1; + { + auto inp = m1.add_parameter("inp", s_inp); + auto rsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 64, 32}}}), inp); + auto w = m1.add_literal(migraphx::generate_literal(s_w)); + auto w_bc = m1.add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 32, 32}}}), w); + auto dot = m1.add_instruction(migraphx::make_op("dot"), rsp, w_bc); + m1.add_return({dot}); + }; + run_pass(m1); + + migraphx::module m2; + { + auto inp = m2.add_parameter("inp", s_inp); + auto w = m2.add_literal(migraphx::generate_literal(s_w)); + auto w_bc = m2.add_instruction( + migraphx::make_op("broadcast", {{"axis", 2}, {"out_lens", {2, 8, 32, 32}}}), w); + auto dot = m2.add_instruction(migraphx::make_op("dot"), inp, w_bc); + auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 64, 32}}}), dot); + m2.add_return({rsp}); + }; + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(reshape_dot_broadcast_2) +{ + migraphx::shape s_inp{migraphx::shape::float_type, {2, 8, 8, 32}}; + migraphx::shape s_w{migraphx::shape::float_type, {32}}; + + migraphx::module m1; + { + auto inp = m1.add_parameter("inp", s_inp); + auto rsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {128, 32}}}), inp); + auto w = m1.add_literal(migraphx::generate_literal(s_w)); + auto w_bc = m1.add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {32, 32}}}), w); + auto dot = m1.add_instruction(migraphx::make_op("dot"), rsp, w_bc); + m1.add_return({dot}); + }; + run_pass(m1); + + migraphx::module m2; + { + auto inp = m2.add_parameter("inp", s_inp); + auto w = m2.add_literal(migraphx::generate_literal(s_w)); + auto w_bc = m2.add_instruction( + migraphx::make_op("broadcast", {{"axis", 3}, {"out_lens", {2, 8, 32, 32}}}), w); + auto dot = m2.add_instruction(migraphx::make_op("dot"), inp, w_bc); + auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {128, 32}}}), dot); + m2.add_return({rsp}); + }; + + EXPECT(m1.sort() == m2.sort()); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); }