From 6dc309cf90bac77565bdbfbc3b2848112ab4a030 Mon Sep 17 00:00:00 2001 From: Shiv Date: Mon, 4 Mar 2024 19:06:06 +0000 Subject: [PATCH] combine reshape-dot matchers --- src/simplify_reshapes.cpp | 141 +++++++++++++++++--------------------- 1 file changed, 64 insertions(+), 77 deletions(-) diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index 345a616ddf0..c10182280b2 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -931,55 +931,6 @@ struct find_const_multibroadcast } }; -struct find_reshape_reshape_dot -{ - auto matcher() const - { - return match::name("dot")(match::used_once(), - match::args(match::name("reshape").bind("inp_rsp1"), - match::name("reshape").bind("inp_rsp2"))); - } - - // Gemm axis should not be altered by the reshape - auto is_valid_reshape(instruction_ref in, instruction_ref rsp) const - { - auto in_lens = in->get_shape().lens(); - auto rsp_lens = rsp->get_shape().lens(); - - return std::equal(rsp_lens.end() - 2, rsp_lens.end(), in_lens.end() - 2, in_lens.end()); - } - - // Batch dims should match for both inputs - auto is_valid_inputs(instruction_ref in1, instruction_ref in2) const - { - auto in1_lens = in1->get_shape().lens(); - auto in2_lens = in2->get_shape().lens(); - - return ( - in1_lens.size() == in2_lens.size() and - std::equal(in1_lens.begin(), in1_lens.end() - 2, in2_lens.begin(), in2_lens.end() - 2)); - } - - void apply(module& m, const match::matcher_result& r) const - { - auto dot = r.result; - auto inp_rsp1 = r.instructions["inp_rsp1"]; - auto inp_rsp2 = r.instructions["inp_rsp2"]; - - auto dot_lens = dot->get_shape().lens(); - - auto inp1 = inp_rsp1->inputs().front(); - auto inp2 = inp_rsp2->inputs().front(); - - if(not(is_valid_reshape(inp1, inp_rsp1) and is_valid_reshape(inp2, inp_rsp2) and - is_valid_inputs(inp1, inp2))) - return; - - auto new_dot = m.insert_instruction(dot, dot->get_operator(), inp1, inp2); - m.replace_instruction(dot, make_op("reshape", {{"dims", dot_lens}}), new_dot); - } -}; - // Move convert before reshape when preceeding a dot op struct find_reshape_convert_dot { @@ -1011,16 +962,33 @@ struct find_reshape_dot match::skip_broadcasts(match::any().bind("other")))); } + // Gemm axis should not be altered by the reshape + auto is_valid_reshape(instruction_ref inp, instruction_ref rsp, size_t dot_axis) const + { + auto inp_lens = inp->get_shape().lens(); + auto rsp_lens = rsp->get_shape().lens(); + + return (inp_lens.size() >= dot_axis and + rsp_lens[rsp_lens.size() - dot_axis] == inp_lens[inp_lens.size() - dot_axis]); + } + + // Same batch dims + auto has_same_batch_dims(instruction_ref in1, instruction_ref in2) const + { + auto in1_lens = in1->get_shape().lens(); + auto in2_lens = in2->get_shape().lens(); + + return ( + in1_lens.size() == in2_lens.size() and + std::equal(in1_lens.begin(), in1_lens.end() - 2, in2_lens.begin(), in2_lens.end() - 2)); + } + void apply(module& m, const match::matcher_result& r) const { auto dot = r.result; auto rsp = r.instructions["rsp"]; auto other = r.instructions["other"]; - auto other_lens = other->get_shape().lens(); - if(other_lens.size() > 2) - return; - auto rsp_lens = rsp->get_shape().lens(); auto inp = rsp->inputs().front(); auto inp_lens = inp->get_shape().lens(); @@ -1029,44 +997,64 @@ struct find_reshape_dot bool flipped = rsp == dot->inputs().back(); size_t dot_axis = (flipped) ? 2 : 1; - if(inp_lens.size() < dot_axis or - rsp_lens[rsp_lens.size() - dot_axis] != inp_lens[inp_lens.size() - dot_axis]) + if(not is_valid_reshape(inp, rsp, dot_axis)) return; - std::vector new_other_lens{inp_lens.begin(), inp_lens.end() - 2}; - operation new_bc_op; - - auto bc_other = (flipped) ? dot->inputs().front() : dot->inputs().back(); - auto bc_other_lens = bc_other->get_shape().lens(); - new_other_lens.insert(new_other_lens.end(), bc_other_lens.end() - 2, bc_other_lens.end()); - - // if the original weight is one dimensional, look at the original broadcast - // to determine the correct broadcast axis - if(other_lens.size() == 1) + instruction_ref new_other; + if(other->get_operator().name() == "reshape") { - auto bc_other_strides = bc_other->get_shape().strides(); - auto it = std::find_if( - bc_other_strides.begin(), bc_other_strides.end(), [&](auto i) { return i != 0; }); - auto orig_bc_axis = std::distance(bc_other_strides.begin(), it); + auto other_inp = other->inputs().front(); + size_t other_dot_axis = (flipped) ? 1 : 2; + if(not is_valid_reshape(other_inp, other, other_dot_axis) or + not has_same_batch_dims(inp, other_inp)) + return; - auto new_bc_axis = new_other_lens.size() - (bc_other_lens.size() - orig_bc_axis); - new_bc_op = make_op("broadcast", {{"axis", new_bc_axis}, {"out_lens", new_other_lens}}); + new_other = other_inp; } else { - new_bc_op = make_op("multibroadcast", {{"out_lens", new_other_lens}}); - } + auto other_lens = other->get_shape().lens(); + if(other_lens.size() > 2) + return; + + std::vector new_other_lens{inp_lens.begin(), inp_lens.end() - 2}; + operation new_bc_op; + + auto bc_other = (flipped) ? dot->inputs().front() : dot->inputs().back(); + auto bc_other_lens = bc_other->get_shape().lens(); + new_other_lens.insert( + new_other_lens.end(), bc_other_lens.end() - 2, bc_other_lens.end()); - auto new_bc_other = m.insert_instruction(dot, new_bc_op, other); + // if the original weight is one dimensional, look at the original broadcast + // to determine the correct broadcast axis + if(other_lens.size() == 1) + { + auto bc_other_strides = bc_other->get_shape().strides(); + auto it = std::find_if(bc_other_strides.begin(), + bc_other_strides.end(), + [&](auto i) { return i != 0; }); + auto orig_bc_axis = std::distance(bc_other_strides.begin(), it); + + auto new_bc_axis = new_other_lens.size() - (bc_other_lens.size() - orig_bc_axis); + new_bc_op = + make_op("broadcast", {{"axis", new_bc_axis}, {"out_lens", new_other_lens}}); + } + else + { + new_bc_op = make_op("multibroadcast", {{"out_lens", new_other_lens}}); + } + + new_other = m.insert_instruction(dot, new_bc_op, other); + } instruction_ref new_dot; if(flipped) { - new_dot = m.insert_instruction(dot, make_op("dot"), new_bc_other, inp); + new_dot = m.insert_instruction(dot, make_op("dot"), new_other, inp); } else { - new_dot = m.insert_instruction(dot, make_op("dot"), inp, new_bc_other); + new_dot = m.insert_instruction(dot, make_op("dot"), inp, new_other); } m.replace_instruction( dot, make_op("reshape", {{"dims", dot->get_shape().lens()}}), new_dot); @@ -1095,7 +1083,6 @@ void simplify_reshapes::apply(module& m) const find_slice_transpose{}, find_reshape_convert_dot{}, find_transpose_contiguous_reshaper_unary{}, - find_reshape_reshape_dot{}, find_reshape_dot{}, find_scalar_multibroadcast_reshape_or_transpose{}); dead_code_elimination{}.apply(m);