Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move unary operators around shape transformation #2958

Merged
merged 11 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 95 additions & 30 deletions src/simplify_reshapes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -645,41 +645,106 @@
}
};

// match sequence of transpose --> contiguous --> reshaper_op
auto match_transpose_contiguous_reshaper()
{
return match::name({"reshape", "squeeze", "unsqueeze"})(
match::used_once(),
match::args(
match::name("contiguous")(
match::used_once(), match::args(match::transpose_shape().bind("trans_ins")))
.bind("cont_ins")))
.bind("reshaper_ins");
};

// finds the pattern of transpose --> contiguous --> reshaper_op --> unary
// application of this matcher moves the unary operation before the contiguous so it becomes
// transpose --> unary --> contiguous --> reshaper_op. later pointwise sub-module can be created out
// of unary --> contiguous --> reshaper_op. Such pattern appears in depthToSpace or spaceToDepth
// operator.
struct find_transpose_contiguous_reshaper_unary
struct find_unary_shape_transforms
{
static const auto& shape_transforms()
{
static const std::unordered_set<std::string> names = {
"flatten",
"reshape",
"squeeze",
"unsqueeze",
"transpose",
"broadcast",
"multibroadcast",
};
return names;
}
auto matcher() const
{
return pointwise(match::used_once(),
match::nargs(1),
match::args(match_transpose_contiguous_reshaper()));
auto output_not_pointwise =
match::none_of(match::skip_output(match::name("contiguous"))(match::pointwise()));
auto input_has_shape_transform =
match::args(match::skip(match::name("contiguous"))(match::name(shape_transforms())));
return match::pointwise(
shivadbhavsar marked this conversation as resolved.
Show resolved Hide resolved
match::used_once(), input_has_shape_transform, output_not_pointwise);
}

void apply(module& m, const match::matcher_result& r) const
static bool is_shape_transform(instruction_ref ins)
{
auto ins = r.result;
auto reshaper_ins = r.instructions["reshaper_ins"];
auto trans_ins = r.instructions["trans_ins"];
auto cont_ins = r.instructions["cont_ins"];
auto unary_ins = m.insert_instruction(cont_ins, ins->get_operator(), trans_ins);
// older cont and reshape are removed by deadcode elimination
m.replace_instruction(ins, reshaper_ins->get_operator(), unary_ins);
return ins->inputs().size() == 1 and
(contains(shape_transforms(), ins->name()) or ins->name() == "contiguous");
}

static bool can_fuse_unary(instruction_ref ins)
{
return ins->name() == "@literal" or
ins->get_operator().attributes().contains("pointwise") or
contains(ins->name(), "reduce");
}

void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto input = ins->inputs().front();
auto output = ins->outputs().front();

auto insert_ops = [&](const auto& ops, instruction_ref z) {
for(const auto& op : ops)
{
z = m.insert_instruction(ins, op, z);
}
return z;
};

std::vector<operation> xops;
auto x = input;
while(is_shape_transform(x))
{
xops.push_back(x->get_operator());
x = x->inputs().front();
}
std::reverse(xops.begin(), xops.end());

std::vector<operation> yops;
auto y = output;
auto last_transform = m.end();
while(is_shape_transform(y) and y->outputs().size() == 1)
{
yops.push_back(y->get_operator());
last_transform = y;
y = y->outputs().front();
}

bool move_up = can_fuse_unary(x);
bool move_down = can_fuse_unary(y);

if(move_up and move_down)
{
if(x->name() == "@literal")

Check warning on line 724 in src/simplify_reshapes.cpp

View check run for this annotation

Codecov / codecov/patch

src/simplify_reshapes.cpp#L724

Added line #L724 was not covered by tests
move_down = false; // NOLINT(bugprone-branch-clone)
else if(yops.empty())

Check warning on line 726 in src/simplify_reshapes.cpp

View check run for this annotation

Codecov / codecov/patch

src/simplify_reshapes.cpp#L726

Added line #L726 was not covered by tests
move_up = false;
TedThemistokleous marked this conversation as resolved.
Show resolved Hide resolved
else
move_down = false;
}
else if(not move_up and not move_down)
{
if(not yops.empty())
move_up = true;
}

if(move_up)
{
auto z = m.insert_instruction(ins, ins->get_operator(), x);
z = insert_ops(xops, z);
m.replace_instruction(ins, z);
}
else if(move_down and not yops.empty())
{
auto z = insert_ops(yops, input);
m.replace_instruction(last_transform, ins->get_operator(), z);
}
}
};

Expand Down Expand Up @@ -967,7 +1032,7 @@
find_transpose_slice{},
find_broadcast_transpose{},
find_slice_transpose{},
find_transpose_contiguous_reshaper_unary{},
find_unary_shape_transforms{},
find_reshape_reshape_dot{},
find_scalar_multibroadcast_reshape_or_transpose{});
dead_code_elimination{}.apply(m);
Expand Down
131 changes: 50 additions & 81 deletions test/simplify_reshapes_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1616,119 +1616,88 @@ TEST_CASE(reshape_cont_nonpw)
EXPECT(m1 == create_module());
}

TEST_CASE(transpose_contiguous_reshape_unary)
TEST_CASE(reshape_unary_transpose)
{
auto s = migraphx::shape{migraphx::shape::float_type, {2, 8, 5, 5}};
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 8, 5, 5}});
auto reshape_ins1 =
auto x = m1.add_parameter("x", s);
auto reshape_ins =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), x);
auto transpose_ins = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), reshape_ins1);
auto cont_ins = m1.add_instruction(migraphx::make_op("contiguous"), transpose_ins);
auto reshape_ins2 =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), cont_ins);
auto relu = m1.add_instruction(migraphx::make_op("relu"), reshape_ins2);
m1.add_instruction(pass_op{}, relu);
auto relu = m1.add_instruction(migraphx::make_op("relu"), reshape_ins);
auto transpose = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), relu);
m1.add_instruction(pass_op{}, transpose);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 8, 5, 5}});
auto reshape_ins1 =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), x);
auto transpose_ins = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), reshape_ins1);
auto relu = m2.add_instruction(migraphx::make_op("relu"), transpose_ins);
auto reshape_ins2 =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), relu);
m2.add_instruction(pass_op{}, reshape_ins2);
auto x = m2.add_parameter("x", s);
auto relu = m2.add_instruction(migraphx::make_op("relu"), x);
auto reshape_ins =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), relu);
auto transpose = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), reshape_ins);
m2.add_instruction(pass_op{}, transpose);
}
EXPECT(m1 == m2);
}

TEST_CASE(transpose_contiguous_reshape_unary_attributes)
TEST_CASE(reshape_unary_transpose_pointwise)
{
auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 8, 5, 5}};
auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 2, 5, 2, 5, 2}};
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::half_type, {2, 8, 5, 5}});
auto reshape_ins1 =
auto x = m1.add_parameter("x", s1);
auto y = m1.add_parameter("y", s2);
auto reshape_ins =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), x);
auto transpose_ins = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), reshape_ins1);
auto cont_ins = m1.add_instruction(migraphx::make_op("contiguous"), transpose_ins);
auto reshape_ins2 =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), cont_ins);
auto conv = m1.add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}),
reshape_ins2);
m1.add_instruction(pass_op{}, conv);
auto relu = m1.add_instruction(migraphx::make_op("relu"), reshape_ins);
auto transpose = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), relu);
auto add = m1.add_instruction(migraphx::make_op("add"), transpose, y);
m1.add_instruction(pass_op{}, add);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::half_type, {2, 8, 5, 5}});
auto reshape_ins1 =
auto x = m2.add_parameter("x", s1);
auto y = m2.add_parameter("y", s2);
auto reshape_ins =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), x);
auto transpose_ins = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), reshape_ins1);
auto conv = m2.add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}),
transpose_ins);
auto reshape_ins2 =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), conv);
m2.add_instruction(pass_op{}, reshape_ins2);
}
EXPECT(m1 == m2);
}

TEST_CASE(transpose_contiguous_squeeze_unary)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 8, 1, 5}});
auto transpose_ins =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x);
auto cont_ins = m1.add_instruction(migraphx::make_op("contiguous"), transpose_ins);
auto sq_ins = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), cont_ins);
auto rsqrt = m1.add_instruction(migraphx::make_op("rsqrt"), sq_ins);
m1.add_instruction(pass_op{}, rsqrt);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 8, 1, 5}});
auto transpose_ins =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x);
auto rsqrt = m2.add_instruction(migraphx::make_op("rsqrt"), transpose_ins);
auto sq_ins = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), rsqrt);
m2.add_instruction(pass_op{}, sq_ins);
auto transpose = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), reshape_ins);
auto relu = m2.add_instruction(migraphx::make_op("relu"), transpose);
auto add = m2.add_instruction(migraphx::make_op("add"), relu, y);
m2.add_instruction(pass_op{}, add);
}
EXPECT(m1 == m2);
}

TEST_CASE(transpose_contiguous_unsqueeze_unary)
TEST_CASE(pointwise_reshape_unary)
{
auto s = migraphx::shape{migraphx::shape::float_type, {2, 8, 5, 5}};
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 8, 5, 5}});
auto transpose_ins =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x);
auto cont_ins = m1.add_instruction(migraphx::make_op("contiguous"), transpose_ins);
auto unsq_ins =
m1.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), cont_ins);
auto round = m1.add_instruction(migraphx::make_op("nearbyint"), unsq_ins);
m1.add_instruction(pass_op{}, round);
auto x = m1.add_parameter("x", s);
auto y = m1.add_parameter("y", s);
auto add = m1.add_instruction(migraphx::make_op("add"), x, y);
auto reshape_ins =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), add);
auto relu = m1.add_instruction(migraphx::make_op("relu"), reshape_ins);
m1.add_instruction(pass_op{}, relu);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 8, 5, 5}});
auto transpose_ins =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x);
auto round = m2.add_instruction(migraphx::make_op("nearbyint"), transpose_ins);
auto unsq_ins = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), round);
m2.add_instruction(pass_op{}, unsq_ins);
auto x = m2.add_parameter("x", s);
auto y = m2.add_parameter("y", s);
auto add = m2.add_instruction(migraphx::make_op("add"), x, y);
auto relu = m2.add_instruction(migraphx::make_op("relu"), add);
auto reshape_ins =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), relu);
m2.add_instruction(pass_op{}, reshape_ins);
}
EXPECT(m1 == m2);
}
Expand Down
Loading