Skip to content

Commit

Permalink
Move unary operators around shape transformation (#2958)
Browse files Browse the repository at this point in the history
  • Loading branch information
pfultz2 authored and Ted Themistokleous committed Apr 25, 2024
1 parent c8e801e commit e9119dc
Show file tree
Hide file tree
Showing 2 changed files with 205 additions and 101 deletions.
127 changes: 97 additions & 30 deletions src/simplify_reshapes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -645,41 +645,108 @@ struct find_reshape_cont
}
};

// match sequence of transpose --> contiguous --> reshaper_op
auto match_transpose_contiguous_reshaper()
{
return match::name({"reshape", "squeeze", "unsqueeze"})(
match::used_once(),
match::args(
match::name("contiguous")(
match::used_once(), match::args(match::transpose_shape().bind("trans_ins")))
.bind("cont_ins")))
.bind("reshaper_ins");
};

// finds the pattern of transpose --> contiguous --> reshaper_op --> unary
// application of this matcher moves the unary operation before the contiguous so it becomes
// transpose --> unary --> contiguous --> reshaper_op. later pointwise sub-module can be created out
// of unary --> contiguous --> reshaper_op. Such pattern appears in depthToSpace or spaceToDepth
// operator.
struct find_transpose_contiguous_reshaper_unary
struct find_unary_shape_transforms
{
static const auto& shape_transforms()
{
static const std::unordered_set<std::string> names = {
"flatten",
"reshape",
"squeeze",
"unsqueeze",
"transpose",
"broadcast",
"multibroadcast",
};
return names;
}
auto matcher() const
{
return pointwise(match::used_once(),
match::nargs(1),
match::args(match_transpose_contiguous_reshaper()));
auto output_not_pointwise =
match::none_of(match::skip_output(match::name("contiguous"))(match::pointwise()));
auto input_has_shape_transform =
match::args(match::skip(match::name("contiguous"))(match::name(shape_transforms())));
return match::pointwise(
match::used_once(), input_has_shape_transform, output_not_pointwise);
}

void apply(module& m, const match::matcher_result& r) const
static bool is_shape_transform(instruction_ref ins)
{
return ins->inputs().size() == 1 and
(contains(shape_transforms(), ins->name()) or ins->name() == "contiguous");
}

static bool can_fuse_unary(instruction_ref ins)
{
return ins->name() == "@literal" or
ins->get_operator().attributes().contains("pointwise") or
contains(ins->name(), "reduce");
}

void apply(module& m, const match::matcher_result& mr) const
{
auto ins = r.result;
auto reshaper_ins = r.instructions["reshaper_ins"];
auto trans_ins = r.instructions["trans_ins"];
auto cont_ins = r.instructions["cont_ins"];
auto unary_ins = m.insert_instruction(cont_ins, ins->get_operator(), trans_ins);
// older cont and reshape are removed by deadcode elimination
m.replace_instruction(ins, reshaper_ins->get_operator(), unary_ins);
auto ins = mr.result;
if(ins->outputs().empty())
return;
auto input = ins->inputs().front();
auto output = ins->outputs().front();

auto insert_ops = [&](const auto& ops, instruction_ref z) {
for(const auto& op : ops)
{
z = m.insert_instruction(ins, op, z);
}
return z;
};

std::vector<operation> xops;
auto x = input;
while(is_shape_transform(x))
{
xops.push_back(x->get_operator());
x = x->inputs().front();
}
std::reverse(xops.begin(), xops.end());

std::vector<operation> yops;
auto y = output;
auto last_transform = m.end();
while(is_shape_transform(y) and y->outputs().size() == 1)
{
yops.push_back(y->get_operator());
last_transform = y;
y = y->outputs().front();
}

bool move_up = can_fuse_unary(x);
bool move_down = can_fuse_unary(y);

if(move_up and move_down)
{
if(x->name() == "@literal")
move_down = false; // NOLINT(bugprone-branch-clone)
else if(yops.empty())
move_up = false;
else
move_down = false;
}
else if(not move_up and not move_down)
{
if(not yops.empty())
move_up = true;
}

if(move_up)
{
auto z = m.insert_instruction(ins, ins->get_operator(), x);
z = insert_ops(xops, z);
m.replace_instruction(ins, z);
}
else if(move_down and not yops.empty())
{
auto z = insert_ops(yops, input);
m.replace_instruction(last_transform, ins->get_operator(), z);
}
}
};

Expand Down Expand Up @@ -967,7 +1034,7 @@ void simplify_reshapes::apply(module& m) const
find_transpose_slice{},
find_broadcast_transpose{},
find_slice_transpose{},
find_transpose_contiguous_reshaper_unary{},
find_unary_shape_transforms{},
find_reshape_reshape_dot{},
find_scalar_multibroadcast_reshape_or_transpose{});
dead_code_elimination{}.apply(m);
Expand Down
179 changes: 108 additions & 71 deletions test/simplify_reshapes_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1616,119 +1616,156 @@ TEST_CASE(reshape_cont_nonpw)
EXPECT(m1 == create_module());
}

TEST_CASE(transpose_contiguous_reshape_unary)
TEST_CASE(reshape_unary_transpose)
{
auto s = migraphx::shape{migraphx::shape::float_type, {2, 8, 5, 5}};
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 8, 5, 5}});
auto reshape_ins1 =
auto x = m1.add_parameter("x", s);
auto reshape_ins =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), x);
auto transpose_ins = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), reshape_ins1);
auto cont_ins = m1.add_instruction(migraphx::make_op("contiguous"), transpose_ins);
auto reshape_ins2 =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), cont_ins);
auto relu = m1.add_instruction(migraphx::make_op("relu"), reshape_ins2);
m1.add_instruction(pass_op{}, relu);
auto relu = m1.add_instruction(migraphx::make_op("relu"), reshape_ins);
auto transpose = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), relu);
m1.add_instruction(pass_op{}, transpose);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 8, 5, 5}});
auto reshape_ins1 =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), x);
auto transpose_ins = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), reshape_ins1);
auto relu = m2.add_instruction(migraphx::make_op("relu"), transpose_ins);
auto reshape_ins2 =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), relu);
m2.add_instruction(pass_op{}, reshape_ins2);
auto x = m2.add_parameter("x", s);
auto relu = m2.add_instruction(migraphx::make_op("relu"), x);
auto reshape_ins =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), relu);
auto transpose = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), reshape_ins);
m2.add_instruction(pass_op{}, transpose);
}
EXPECT(m1 == m2);
}

TEST_CASE(transpose_contiguous_reshape_unary_attributes)
TEST_CASE(reshape_unary_last)
{
auto s = migraphx::shape{migraphx::shape::float_type, {2, 8, 5, 5}};
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::half_type, {2, 8, 5, 5}});
auto reshape_ins1 =
auto x = m1.add_parameter("x", s);
auto reshape_ins =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), x);
auto transpose_ins = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), reshape_ins1);
auto cont_ins = m1.add_instruction(migraphx::make_op("contiguous"), transpose_ins);
auto reshape_ins2 =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), cont_ins);
auto conv = m1.add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}),
reshape_ins2);
m1.add_instruction(pass_op{}, conv);
m1.add_instruction(migraphx::make_op("relu"), reshape_ins);
}
migraphx::module m2 = m1;
run_pass(m1);
EXPECT(m1 == m2);
}

TEST_CASE(pointwise_reshape_unary_pointwise)
{
auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 8, 5, 5}};
auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 2, 2, 2, 5, 5}};
migraphx::module m1;
{
auto x = m1.add_parameter("x", s1);
auto y = m1.add_parameter("y", s1);
auto z = m1.add_parameter("z", s2);
auto mul = m1.add_instruction(migraphx::make_op("mul"), x, y);
auto reshape_ins =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), mul);
auto relu = m1.add_instruction(migraphx::make_op("relu"), reshape_ins);
auto pw = m1.add_instruction(migraphx::make_op("add"), z, relu);
m1.add_instruction(pass_op{}, pw);
}
migraphx::module m2 = m1;
run_pass(m1);
EXPECT(m1 == m2);
}

TEST_CASE(literal_reshape_unary_transpose_pointwise)
{
auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 8, 5, 5}};
auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 2, 5, 2, 5, 2}};
migraphx::module m1;
{
auto x = m1.add_parameter("x", s2);
auto one = m1.add_literal(migraphx::generate_literal(s1));
auto reshape_ins =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), one);
auto relu = m1.add_instruction(migraphx::make_op("relu"), reshape_ins);
auto transpose = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), relu);
auto pw = m1.add_instruction(migraphx::make_op("add"), x, transpose);
m1.add_instruction(pass_op{}, pw);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::half_type, {2, 8, 5, 5}});
auto reshape_ins1 =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), x);
auto transpose_ins = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), reshape_ins1);
auto conv = m2.add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}),
transpose_ins);
auto reshape_ins2 =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), conv);
m2.add_instruction(pass_op{}, reshape_ins2);
auto x = m2.add_parameter("x", s2);
auto one = m2.add_literal(migraphx::generate_literal(s1));
auto relu = m2.add_instruction(migraphx::make_op("relu"), one);
auto reshape_ins =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), relu);
auto transpose = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), reshape_ins);
auto pw = m2.add_instruction(migraphx::make_op("add"), x, transpose);
m2.add_instruction(pass_op{}, pw);
}
EXPECT(m1 == m2);
}

TEST_CASE(transpose_contiguous_squeeze_unary)
TEST_CASE(reshape_unary_transpose_pointwise)
{
auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 8, 5, 5}};
auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 2, 5, 2, 5, 2}};
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 8, 1, 5}});
auto transpose_ins =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x);
auto cont_ins = m1.add_instruction(migraphx::make_op("contiguous"), transpose_ins);
auto sq_ins = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), cont_ins);
auto rsqrt = m1.add_instruction(migraphx::make_op("rsqrt"), sq_ins);
m1.add_instruction(pass_op{}, rsqrt);
auto x = m1.add_parameter("x", s1);
auto y = m1.add_parameter("y", s2);
auto reshape_ins =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), x);
auto relu = m1.add_instruction(migraphx::make_op("relu"), reshape_ins);
auto transpose = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), relu);
auto add = m1.add_instruction(migraphx::make_op("add"), transpose, y);
m1.add_instruction(pass_op{}, add);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 8, 1, 5}});
auto transpose_ins =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x);
auto rsqrt = m2.add_instruction(migraphx::make_op("rsqrt"), transpose_ins);
auto sq_ins = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), rsqrt);
m2.add_instruction(pass_op{}, sq_ins);
auto x = m2.add_parameter("x", s1);
auto y = m2.add_parameter("y", s2);
auto reshape_ins =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), x);
auto transpose = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), reshape_ins);
auto relu = m2.add_instruction(migraphx::make_op("relu"), transpose);
auto add = m2.add_instruction(migraphx::make_op("add"), relu, y);
m2.add_instruction(pass_op{}, add);
}
EXPECT(m1 == m2);
}

TEST_CASE(transpose_contiguous_unsqueeze_unary)
TEST_CASE(pointwise_reshape_unary)
{
auto s = migraphx::shape{migraphx::shape::float_type, {2, 8, 5, 5}};
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 8, 5, 5}});
auto transpose_ins =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x);
auto cont_ins = m1.add_instruction(migraphx::make_op("contiguous"), transpose_ins);
auto unsq_ins =
m1.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), cont_ins);
auto round = m1.add_instruction(migraphx::make_op("nearbyint"), unsq_ins);
m1.add_instruction(pass_op{}, round);
auto x = m1.add_parameter("x", s);
auto y = m1.add_parameter("y", s);
auto add = m1.add_instruction(migraphx::make_op("add"), x, y);
auto reshape_ins =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), add);
auto relu = m1.add_instruction(migraphx::make_op("relu"), reshape_ins);
m1.add_instruction(pass_op{}, relu);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 8, 5, 5}});
auto transpose_ins =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x);
auto round = m2.add_instruction(migraphx::make_op("nearbyint"), transpose_ins);
auto unsq_ins = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), round);
m2.add_instruction(pass_op{}, unsq_ins);
auto x = m2.add_parameter("x", s);
auto y = m2.add_parameter("y", s);
auto add = m2.add_instruction(migraphx::make_op("add"), x, y);
auto relu = m2.add_instruction(migraphx::make_op("relu"), add);
auto reshape_ins =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), relu);
m2.add_instruction(pass_op{}, reshape_ins);
}
EXPECT(m1 == m2);
}
Expand Down

0 comments on commit e9119dc

Please sign in to comment.