diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index 099a7cd35a1..a519925f1b3 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -645,41 +645,108 @@ struct find_reshape_cont } }; -// match sequence of transpose --> contiguous --> reshaper_op -auto match_transpose_contiguous_reshaper() -{ - return match::name({"reshape", "squeeze", "unsqueeze"})( - match::used_once(), - match::args( - match::name("contiguous")( - match::used_once(), match::args(match::transpose_shape().bind("trans_ins"))) - .bind("cont_ins"))) - .bind("reshaper_ins"); -}; - -// finds the pattern of transpose --> contiguous --> reshaper_op --> unary -// application of this matcher moves the unary operation before the contiguous so it becomes -// transpose --> unary --> contiguous --> reshaper_op. later pointwise sub-module can be created out -// of unary --> contiguous --> reshaper_op. Such pattern appears in depthToSpace or spaceToDepth -// operator. -struct find_transpose_contiguous_reshaper_unary +struct find_unary_shape_transforms { + static const auto& shape_transforms() + { + static const std::unordered_set names = { + "flatten", + "reshape", + "squeeze", + "unsqueeze", + "transpose", + "broadcast", + "multibroadcast", + }; + return names; + } auto matcher() const { - return pointwise(match::used_once(), - match::nargs(1), - match::args(match_transpose_contiguous_reshaper())); + auto output_not_pointwise = + match::none_of(match::skip_output(match::name("contiguous"))(match::pointwise())); + auto input_has_shape_transform = + match::args(match::skip(match::name("contiguous"))(match::name(shape_transforms()))); + return match::pointwise( + match::used_once(), input_has_shape_transform, output_not_pointwise); } - void apply(module& m, const match::matcher_result& r) const + static bool is_shape_transform(instruction_ref ins) + { + return ins->inputs().size() == 1 and + (contains(shape_transforms(), ins->name()) or ins->name() == "contiguous"); + } + + static bool can_fuse_unary(instruction_ref ins) + { + return ins->name() == "@literal" or + ins->get_operator().attributes().contains("pointwise") or + contains(ins->name(), "reduce"); + } + + void apply(module& m, const match::matcher_result& mr) const { - auto ins = r.result; - auto reshaper_ins = r.instructions["reshaper_ins"]; - auto trans_ins = r.instructions["trans_ins"]; - auto cont_ins = r.instructions["cont_ins"]; - auto unary_ins = m.insert_instruction(cont_ins, ins->get_operator(), trans_ins); - // older cont and reshape are removed by deadcode elimination - m.replace_instruction(ins, reshaper_ins->get_operator(), unary_ins); + auto ins = mr.result; + if(ins->outputs().empty()) + return; + auto input = ins->inputs().front(); + auto output = ins->outputs().front(); + + auto insert_ops = [&](const auto& ops, instruction_ref z) { + for(const auto& op : ops) + { + z = m.insert_instruction(ins, op, z); + } + return z; + }; + + std::vector xops; + auto x = input; + while(is_shape_transform(x)) + { + xops.push_back(x->get_operator()); + x = x->inputs().front(); + } + std::reverse(xops.begin(), xops.end()); + + std::vector yops; + auto y = output; + auto last_transform = m.end(); + while(is_shape_transform(y) and y->outputs().size() == 1) + { + yops.push_back(y->get_operator()); + last_transform = y; + y = y->outputs().front(); + } + + bool move_up = can_fuse_unary(x); + bool move_down = can_fuse_unary(y); + + if(move_up and move_down) + { + if(x->name() == "@literal") + move_down = false; // NOLINT(bugprone-branch-clone) + else if(yops.empty()) + move_up = false; + else + move_down = false; + } + else if(not move_up and not move_down) + { + if(not yops.empty()) + move_up = true; + } + + if(move_up) + { + auto z = m.insert_instruction(ins, ins->get_operator(), x); + z = insert_ops(xops, z); + m.replace_instruction(ins, z); + } + else if(move_down and not yops.empty()) + { + auto z = insert_ops(yops, input); + m.replace_instruction(last_transform, ins->get_operator(), z); + } } }; @@ -967,7 +1034,7 @@ void simplify_reshapes::apply(module& m) const find_transpose_slice{}, find_broadcast_transpose{}, find_slice_transpose{}, - find_transpose_contiguous_reshaper_unary{}, + find_unary_shape_transforms{}, find_reshape_reshape_dot{}, find_scalar_multibroadcast_reshape_or_transpose{}); dead_code_elimination{}.apply(m); diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index 3a50d2a036d..f28c3ff8bde 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -1616,119 +1616,156 @@ TEST_CASE(reshape_cont_nonpw) EXPECT(m1 == create_module()); } -TEST_CASE(transpose_contiguous_reshape_unary) +TEST_CASE(reshape_unary_transpose) { + auto s = migraphx::shape{migraphx::shape::float_type, {2, 8, 5, 5}}; migraphx::module m1; { - auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 8, 5, 5}}); - auto reshape_ins1 = + auto x = m1.add_parameter("x", s); + auto reshape_ins = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), x); - auto transpose_ins = m1.add_instruction( - migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), reshape_ins1); - auto cont_ins = m1.add_instruction(migraphx::make_op("contiguous"), transpose_ins); - auto reshape_ins2 = - m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), cont_ins); - auto relu = m1.add_instruction(migraphx::make_op("relu"), reshape_ins2); - m1.add_instruction(pass_op{}, relu); + auto relu = m1.add_instruction(migraphx::make_op("relu"), reshape_ins); + auto transpose = m1.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), relu); + m1.add_instruction(pass_op{}, transpose); } run_pass(m1); migraphx::module m2; { - auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 8, 5, 5}}); - auto reshape_ins1 = - m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), x); - auto transpose_ins = m2.add_instruction( - migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), reshape_ins1); - auto relu = m2.add_instruction(migraphx::make_op("relu"), transpose_ins); - auto reshape_ins2 = - m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), relu); - m2.add_instruction(pass_op{}, reshape_ins2); + auto x = m2.add_parameter("x", s); + auto relu = m2.add_instruction(migraphx::make_op("relu"), x); + auto reshape_ins = + m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), relu); + auto transpose = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), reshape_ins); + m2.add_instruction(pass_op{}, transpose); } EXPECT(m1 == m2); } -TEST_CASE(transpose_contiguous_reshape_unary_attributes) +TEST_CASE(reshape_unary_last) { + auto s = migraphx::shape{migraphx::shape::float_type, {2, 8, 5, 5}}; migraphx::module m1; { - auto x = m1.add_parameter("x", {migraphx::shape::half_type, {2, 8, 5, 5}}); - auto reshape_ins1 = + auto x = m1.add_parameter("x", s); + auto reshape_ins = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), x); - auto transpose_ins = m1.add_instruction( - migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), reshape_ins1); - auto cont_ins = m1.add_instruction(migraphx::make_op("contiguous"), transpose_ins); - auto reshape_ins2 = - m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), cont_ins); - auto conv = m1.add_instruction( - migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), - reshape_ins2); - m1.add_instruction(pass_op{}, conv); + m1.add_instruction(migraphx::make_op("relu"), reshape_ins); + } + migraphx::module m2 = m1; + run_pass(m1); + EXPECT(m1 == m2); +} + +TEST_CASE(pointwise_reshape_unary_pointwise) +{ + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 8, 5, 5}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 2, 2, 2, 5, 5}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s1); + auto y = m1.add_parameter("y", s1); + auto z = m1.add_parameter("z", s2); + auto mul = m1.add_instruction(migraphx::make_op("mul"), x, y); + auto reshape_ins = + m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), mul); + auto relu = m1.add_instruction(migraphx::make_op("relu"), reshape_ins); + auto pw = m1.add_instruction(migraphx::make_op("add"), z, relu); + m1.add_instruction(pass_op{}, pw); + } + migraphx::module m2 = m1; + run_pass(m1); + EXPECT(m1 == m2); +} + +TEST_CASE(literal_reshape_unary_transpose_pointwise) +{ + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 8, 5, 5}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 2, 5, 2, 5, 2}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s2); + auto one = m1.add_literal(migraphx::generate_literal(s1)); + auto reshape_ins = + m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), one); + auto relu = m1.add_instruction(migraphx::make_op("relu"), reshape_ins); + auto transpose = m1.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), relu); + auto pw = m1.add_instruction(migraphx::make_op("add"), x, transpose); + m1.add_instruction(pass_op{}, pw); } run_pass(m1); migraphx::module m2; { - auto x = m2.add_parameter("x", {migraphx::shape::half_type, {2, 8, 5, 5}}); - auto reshape_ins1 = - m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), x); - auto transpose_ins = m2.add_instruction( - migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), reshape_ins1); - auto conv = m2.add_instruction( - migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), - transpose_ins); - auto reshape_ins2 = - m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), conv); - m2.add_instruction(pass_op{}, reshape_ins2); + auto x = m2.add_parameter("x", s2); + auto one = m2.add_literal(migraphx::generate_literal(s1)); + auto relu = m2.add_instruction(migraphx::make_op("relu"), one); + auto reshape_ins = + m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), relu); + auto transpose = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), reshape_ins); + auto pw = m2.add_instruction(migraphx::make_op("add"), x, transpose); + m2.add_instruction(pass_op{}, pw); } EXPECT(m1 == m2); } -TEST_CASE(transpose_contiguous_squeeze_unary) +TEST_CASE(reshape_unary_transpose_pointwise) { + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 8, 5, 5}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 2, 5, 2, 5, 2}}; migraphx::module m1; { - auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 8, 1, 5}}); - auto transpose_ins = - m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x); - auto cont_ins = m1.add_instruction(migraphx::make_op("contiguous"), transpose_ins); - auto sq_ins = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), cont_ins); - auto rsqrt = m1.add_instruction(migraphx::make_op("rsqrt"), sq_ins); - m1.add_instruction(pass_op{}, rsqrt); + auto x = m1.add_parameter("x", s1); + auto y = m1.add_parameter("y", s2); + auto reshape_ins = + m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), x); + auto relu = m1.add_instruction(migraphx::make_op("relu"), reshape_ins); + auto transpose = m1.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), relu); + auto add = m1.add_instruction(migraphx::make_op("add"), transpose, y); + m1.add_instruction(pass_op{}, add); } run_pass(m1); migraphx::module m2; { - auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 8, 1, 5}}); - auto transpose_ins = - m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x); - auto rsqrt = m2.add_instruction(migraphx::make_op("rsqrt"), transpose_ins); - auto sq_ins = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), rsqrt); - m2.add_instruction(pass_op{}, sq_ins); + auto x = m2.add_parameter("x", s1); + auto y = m2.add_parameter("y", s2); + auto reshape_ins = + m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), x); + auto transpose = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), reshape_ins); + auto relu = m2.add_instruction(migraphx::make_op("relu"), transpose); + auto add = m2.add_instruction(migraphx::make_op("add"), relu, y); + m2.add_instruction(pass_op{}, add); } EXPECT(m1 == m2); } -TEST_CASE(transpose_contiguous_unsqueeze_unary) +TEST_CASE(pointwise_reshape_unary) { + auto s = migraphx::shape{migraphx::shape::float_type, {2, 8, 5, 5}}; migraphx::module m1; { - auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 8, 5, 5}}); - auto transpose_ins = - m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x); - auto cont_ins = m1.add_instruction(migraphx::make_op("contiguous"), transpose_ins); - auto unsq_ins = - m1.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), cont_ins); - auto round = m1.add_instruction(migraphx::make_op("nearbyint"), unsq_ins); - m1.add_instruction(pass_op{}, round); + auto x = m1.add_parameter("x", s); + auto y = m1.add_parameter("y", s); + auto add = m1.add_instruction(migraphx::make_op("add"), x, y); + auto reshape_ins = + m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), add); + auto relu = m1.add_instruction(migraphx::make_op("relu"), reshape_ins); + m1.add_instruction(pass_op{}, relu); } run_pass(m1); migraphx::module m2; { - auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 8, 5, 5}}); - auto transpose_ins = - m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x); - auto round = m2.add_instruction(migraphx::make_op("nearbyint"), transpose_ins); - auto unsq_ins = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), round); - m2.add_instruction(pass_op{}, unsq_ins); + auto x = m2.add_parameter("x", s); + auto y = m2.add_parameter("y", s); + auto add = m2.add_instruction(migraphx::make_op("add"), x, y); + auto relu = m2.add_instruction(migraphx::make_op("relu"), add); + auto reshape_ins = + m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), relu); + m2.add_instruction(pass_op{}, reshape_ins); } EXPECT(m1 == m2); }