Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move unary operators around shape transformation #2958

Merged
merged 11 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 97 additions & 30 deletions src/simplify_reshapes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -645,41 +645,108 @@ struct find_reshape_cont
}
};

// match sequence of transpose --> contiguous --> reshaper_op
auto match_transpose_contiguous_reshaper()
{
return match::name({"reshape", "squeeze", "unsqueeze"})(
match::used_once(),
match::args(
match::name("contiguous")(
match::used_once(), match::args(match::transpose_shape().bind("trans_ins")))
.bind("cont_ins")))
.bind("reshaper_ins");
};

// finds the pattern of transpose --> contiguous --> reshaper_op --> unary
// application of this matcher moves the unary operation before the contiguous so it becomes
// transpose --> unary --> contiguous --> reshaper_op. later pointwise sub-module can be created out
// of unary --> contiguous --> reshaper_op. Such pattern appears in depthToSpace or spaceToDepth
// operator.
struct find_transpose_contiguous_reshaper_unary
struct find_unary_shape_transforms
{
static const auto& shape_transforms()
{
static const std::unordered_set<std::string> names = {
"flatten",
"reshape",
"squeeze",
"unsqueeze",
"transpose",
"broadcast",
"multibroadcast",
};
return names;
}
auto matcher() const
{
return pointwise(match::used_once(),
match::nargs(1),
match::args(match_transpose_contiguous_reshaper()));
auto output_not_pointwise =
match::none_of(match::skip_output(match::name("contiguous"))(match::pointwise()));
auto input_has_shape_transform =
match::args(match::skip(match::name("contiguous"))(match::name(shape_transforms())));
return match::pointwise(
shivadbhavsar marked this conversation as resolved.
Show resolved Hide resolved
match::used_once(), input_has_shape_transform, output_not_pointwise);
}

void apply(module& m, const match::matcher_result& r) const
static bool is_shape_transform(instruction_ref ins)
{
return ins->inputs().size() == 1 and
(contains(shape_transforms(), ins->name()) or ins->name() == "contiguous");
}

static bool can_fuse_unary(instruction_ref ins)
{
return ins->name() == "@literal" or
ins->get_operator().attributes().contains("pointwise") or
contains(ins->name(), "reduce");
}

void apply(module& m, const match::matcher_result& mr) const
{
auto ins = r.result;
auto reshaper_ins = r.instructions["reshaper_ins"];
auto trans_ins = r.instructions["trans_ins"];
auto cont_ins = r.instructions["cont_ins"];
auto unary_ins = m.insert_instruction(cont_ins, ins->get_operator(), trans_ins);
// older cont and reshape are removed by deadcode elimination
m.replace_instruction(ins, reshaper_ins->get_operator(), unary_ins);
auto ins = mr.result;
if(ins->outputs().empty())
return;
auto input = ins->inputs().front();
auto output = ins->outputs().front();

auto insert_ops = [&](const auto& ops, instruction_ref z) {
for(const auto& op : ops)
{
z = m.insert_instruction(ins, op, z);
}
return z;
};

std::vector<operation> xops;
auto x = input;
while(is_shape_transform(x))
{
xops.push_back(x->get_operator());
x = x->inputs().front();
}
std::reverse(xops.begin(), xops.end());

std::vector<operation> yops;
auto y = output;
auto last_transform = m.end();
while(is_shape_transform(y) and y->outputs().size() == 1)
{
yops.push_back(y->get_operator());
last_transform = y;
y = y->outputs().front();
}

bool move_up = can_fuse_unary(x);
bool move_down = can_fuse_unary(y);

if(move_up and move_down)
{
if(x->name() == "@literal")
move_down = false; // NOLINT(bugprone-branch-clone)
else if(yops.empty())
move_up = false;
TedThemistokleous marked this conversation as resolved.
Show resolved Hide resolved
else
move_down = false;
}
else if(not move_up and not move_down)
{
if(not yops.empty())
move_up = true;
}

if(move_up)
{
auto z = m.insert_instruction(ins, ins->get_operator(), x);
z = insert_ops(xops, z);
m.replace_instruction(ins, z);
}
else if(move_down and not yops.empty())
{
auto z = insert_ops(yops, input);
m.replace_instruction(last_transform, ins->get_operator(), z);
}
}
};

Expand Down Expand Up @@ -967,7 +1034,7 @@ void simplify_reshapes::apply(module& m) const
find_transpose_slice{},
find_broadcast_transpose{},
find_slice_transpose{},
find_transpose_contiguous_reshaper_unary{},
find_unary_shape_transforms{},
find_reshape_reshape_dot{},
find_scalar_multibroadcast_reshape_or_transpose{});
dead_code_elimination{}.apply(m);
Expand Down
179 changes: 108 additions & 71 deletions test/simplify_reshapes_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1616,119 +1616,156 @@ TEST_CASE(reshape_cont_nonpw)
EXPECT(m1 == create_module());
}

TEST_CASE(transpose_contiguous_reshape_unary)
TEST_CASE(reshape_unary_transpose)
{
auto s = migraphx::shape{migraphx::shape::float_type, {2, 8, 5, 5}};
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 8, 5, 5}});
auto reshape_ins1 =
auto x = m1.add_parameter("x", s);
auto reshape_ins =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), x);
auto transpose_ins = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), reshape_ins1);
auto cont_ins = m1.add_instruction(migraphx::make_op("contiguous"), transpose_ins);
auto reshape_ins2 =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), cont_ins);
auto relu = m1.add_instruction(migraphx::make_op("relu"), reshape_ins2);
m1.add_instruction(pass_op{}, relu);
auto relu = m1.add_instruction(migraphx::make_op("relu"), reshape_ins);
auto transpose = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), relu);
m1.add_instruction(pass_op{}, transpose);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 8, 5, 5}});
auto reshape_ins1 =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), x);
auto transpose_ins = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), reshape_ins1);
auto relu = m2.add_instruction(migraphx::make_op("relu"), transpose_ins);
auto reshape_ins2 =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), relu);
m2.add_instruction(pass_op{}, reshape_ins2);
auto x = m2.add_parameter("x", s);
auto relu = m2.add_instruction(migraphx::make_op("relu"), x);
auto reshape_ins =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), relu);
auto transpose = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), reshape_ins);
m2.add_instruction(pass_op{}, transpose);
}
EXPECT(m1 == m2);
}

TEST_CASE(transpose_contiguous_reshape_unary_attributes)
TEST_CASE(reshape_unary_last)
{
auto s = migraphx::shape{migraphx::shape::float_type, {2, 8, 5, 5}};
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::half_type, {2, 8, 5, 5}});
auto reshape_ins1 =
auto x = m1.add_parameter("x", s);
auto reshape_ins =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), x);
auto transpose_ins = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), reshape_ins1);
auto cont_ins = m1.add_instruction(migraphx::make_op("contiguous"), transpose_ins);
auto reshape_ins2 =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), cont_ins);
auto conv = m1.add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}),
reshape_ins2);
m1.add_instruction(pass_op{}, conv);
m1.add_instruction(migraphx::make_op("relu"), reshape_ins);
}
migraphx::module m2 = m1;
run_pass(m1);
EXPECT(m1 == m2);
}

TEST_CASE(pointwise_reshape_unary_pointwise)
{
auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 8, 5, 5}};
auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 2, 2, 2, 5, 5}};
migraphx::module m1;
{
auto x = m1.add_parameter("x", s1);
auto y = m1.add_parameter("y", s1);
auto z = m1.add_parameter("z", s2);
auto mul = m1.add_instruction(migraphx::make_op("mul"), x, y);
auto reshape_ins =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), mul);
auto relu = m1.add_instruction(migraphx::make_op("relu"), reshape_ins);
auto pw = m1.add_instruction(migraphx::make_op("add"), z, relu);
m1.add_instruction(pass_op{}, pw);
}
migraphx::module m2 = m1;
run_pass(m1);
EXPECT(m1 == m2);
}

TEST_CASE(literal_reshape_unary_transpose_pointwise)
{
auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 8, 5, 5}};
auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 2, 5, 2, 5, 2}};
migraphx::module m1;
{
auto x = m1.add_parameter("x", s2);
auto one = m1.add_literal(migraphx::generate_literal(s1));
auto reshape_ins =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), one);
auto relu = m1.add_instruction(migraphx::make_op("relu"), reshape_ins);
auto transpose = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), relu);
auto pw = m1.add_instruction(migraphx::make_op("add"), x, transpose);
m1.add_instruction(pass_op{}, pw);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::half_type, {2, 8, 5, 5}});
auto reshape_ins1 =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), x);
auto transpose_ins = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), reshape_ins1);
auto conv = m2.add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}),
transpose_ins);
auto reshape_ins2 =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), conv);
m2.add_instruction(pass_op{}, reshape_ins2);
auto x = m2.add_parameter("x", s2);
auto one = m2.add_literal(migraphx::generate_literal(s1));
auto relu = m2.add_instruction(migraphx::make_op("relu"), one);
auto reshape_ins =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), relu);
auto transpose = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), reshape_ins);
auto pw = m2.add_instruction(migraphx::make_op("add"), x, transpose);
m2.add_instruction(pass_op{}, pw);
}
EXPECT(m1 == m2);
}

TEST_CASE(transpose_contiguous_squeeze_unary)
TEST_CASE(reshape_unary_transpose_pointwise)
{
auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 8, 5, 5}};
auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 2, 5, 2, 5, 2}};
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 8, 1, 5}});
auto transpose_ins =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x);
auto cont_ins = m1.add_instruction(migraphx::make_op("contiguous"), transpose_ins);
auto sq_ins = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), cont_ins);
auto rsqrt = m1.add_instruction(migraphx::make_op("rsqrt"), sq_ins);
m1.add_instruction(pass_op{}, rsqrt);
auto x = m1.add_parameter("x", s1);
auto y = m1.add_parameter("y", s2);
auto reshape_ins =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), x);
auto relu = m1.add_instruction(migraphx::make_op("relu"), reshape_ins);
auto transpose = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), relu);
auto add = m1.add_instruction(migraphx::make_op("add"), transpose, y);
m1.add_instruction(pass_op{}, add);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 8, 1, 5}});
auto transpose_ins =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x);
auto rsqrt = m2.add_instruction(migraphx::make_op("rsqrt"), transpose_ins);
auto sq_ins = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), rsqrt);
m2.add_instruction(pass_op{}, sq_ins);
auto x = m2.add_parameter("x", s1);
auto y = m2.add_parameter("y", s2);
auto reshape_ins =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), x);
auto transpose = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), reshape_ins);
auto relu = m2.add_instruction(migraphx::make_op("relu"), transpose);
auto add = m2.add_instruction(migraphx::make_op("add"), relu, y);
m2.add_instruction(pass_op{}, add);
}
EXPECT(m1 == m2);
}

TEST_CASE(transpose_contiguous_unsqueeze_unary)
TEST_CASE(pointwise_reshape_unary)
{
auto s = migraphx::shape{migraphx::shape::float_type, {2, 8, 5, 5}};
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 8, 5, 5}});
auto transpose_ins =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x);
auto cont_ins = m1.add_instruction(migraphx::make_op("contiguous"), transpose_ins);
auto unsq_ins =
m1.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), cont_ins);
auto round = m1.add_instruction(migraphx::make_op("nearbyint"), unsq_ins);
m1.add_instruction(pass_op{}, round);
auto x = m1.add_parameter("x", s);
auto y = m1.add_parameter("y", s);
auto add = m1.add_instruction(migraphx::make_op("add"), x, y);
auto reshape_ins =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), add);
auto relu = m1.add_instruction(migraphx::make_op("relu"), reshape_ins);
m1.add_instruction(pass_op{}, relu);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 8, 5, 5}});
auto transpose_ins =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x);
auto round = m2.add_instruction(migraphx::make_op("nearbyint"), transpose_ins);
auto unsq_ins = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), round);
m2.add_instruction(pass_op{}, unsq_ins);
auto x = m2.add_parameter("x", s);
auto y = m2.add_parameter("y", s);
auto add = m2.add_instruction(migraphx::make_op("add"), x, y);
auto relu = m2.add_instruction(migraphx::make_op("relu"), add);
auto reshape_ins =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), relu);
m2.add_instruction(pass_op{}, reshape_ins);
}
EXPECT(m1 == m2);
}
Expand Down
Loading