Skip to content

Commit

Permalink
Pointwise + Concat fusion (#2785)
Browse files Browse the repository at this point in the history
  • Loading branch information
umangyadav authored Mar 13, 2024
1 parent 4bedf9e commit 21b4d70
Show file tree
Hide file tree
Showing 6 changed files with 307 additions and 42 deletions.
79 changes: 71 additions & 8 deletions src/fuse_concat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

unsigned int get_noop_counter()
{
static unsigned int counter = 0;
return counter++;
}

struct fused_concat
{
int64_t axis = 0;
Expand Down Expand Up @@ -98,6 +104,64 @@ struct fused_concat
MIGRAPHX_REGISTER_OP(fused_concat);

namespace {
struct find_concat_pointwise
{
auto matcher() const
{
return match::name("concat")(
match::used_once(),
match::any_of[match::inputs()](match::name("pointwise")(match::used_once())));
}

void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto concat_ins = r.result;

std::vector<instruction_ref> inputs;
size_t num_noops = 0;
for(auto input : concat_ins->inputs())
{
if(input->name() == "pointwise" and input->outputs().size() == 1)
{
inputs.insert(inputs.end(), input->inputs().begin(), input->inputs().end());
}
else
{
num_noops++;
inputs.push_back(input);
}
}
if(num_noops > std::max(size_t{1}, concat_ins->inputs().size() / 4))
{
return;
}
std::vector<module_ref> module_inputs;
std::transform(concat_ins->inputs().begin(),
concat_ins->inputs().end(),
std::back_inserter(module_inputs),
[&](instruction_ref input) {
if(input->name() == "pointwise" and input->outputs().size() == 1)
{
auto* pm = input->module_inputs().front();
return mpm.create_module("concat:" + pm->name(), *pm);
}
auto* pm = mpm.create_module("concat:noop" +
std::to_string(get_noop_counter()));
auto x = pm->add_parameter("x0", shape{input->get_shape().type()});
pm->add_return({x});
return pm;
});
auto* post_pm = mpm.create_module("noop:concat" + std::to_string(get_noop_counter()));
auto x = post_pm->add_parameter("!x0", shape{concat_ins->get_shape().type()});
post_pm->add_return({x});
module_inputs.push_back(post_pm);
mpm.get_module().replace_instruction(
concat_ins,
make_op("fused_concat", concat_ins->normalized_operator().to_value()),
inputs,
module_inputs);
}
};

struct find_pointwise_concat_pointwise
{
Expand All @@ -119,7 +183,7 @@ struct find_pointwise_concat_pointwise
std::vector<instruction_ref> inputs;
for(auto input : concat_ins->inputs())
{
if(input->name() == "pointwise")
if(input->name() == "pointwise" and input->outputs().size() == 1)
inputs.insert(inputs.end(), input->inputs().begin(), input->inputs().end());
else
inputs.push_back(input);
Expand All @@ -130,22 +194,19 @@ struct find_pointwise_concat_pointwise
[&](auto input) { return input != concat_ins; });

std::vector<module_ref> module_inputs;
static unsigned int counter = 0;
std::transform(concat_ins->inputs().begin(),
concat_ins->inputs().end(),
std::back_inserter(module_inputs),
[&](instruction_ref input) {
if(input->name() == "pointwise")
if(input->name() == "pointwise" and input->outputs().size() == 1)
{
auto* pm = input->module_inputs().front();
return mpm.create_module("concat:" + pm->name(), *pm);
}
auto* pm =
mpm.create_module("concat:identity" + std::to_string(counter++));

auto* pm = mpm.create_module("concat:noop" +
std::to_string(get_noop_counter()));
auto x = pm->add_parameter("x0", shape{input->get_shape().type()});
auto id = pm->add_instruction(make_op("identity"), x);
pm->add_return({id});
pm->add_return({x});
return pm;
});

Expand Down Expand Up @@ -173,6 +234,8 @@ struct find_pointwise_concat_pointwise
void fuse_concat::apply(module_pass_manager& mpm) const
{
match::find_matches(mpm, find_pointwise_concat_pointwise{});
mpm.run_pass(migraphx::dead_code_elimination{});
match::find_matches(mpm, find_concat_pointwise{});
}

} // namespace MIGRAPHX_INLINE_NS
Expand Down
2 changes: 2 additions & 0 deletions src/targets/gpu/compile_gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,8 @@ std::string generate_name_from_ops(const module& m, const std::string& postname)
auto op_names = get_op_names(m);
if(not postname.empty())
op_names.push_back(postname);
if(op_names.empty())
return "noop";
return join_strings(op_names, "_");
}

Expand Down
167 changes: 149 additions & 18 deletions test/fuse_concat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ concat_arg<F> arg(std::string name, std::vector<migraphx::instruction_ref> input

template <class Arg, class... Args>
migraphx::instruction_ref
add_concat(migraphx::program& p, std::size_t axis, Arg post_arg, Args... args)
add_pointwise_concat(migraphx::program& p, std::size_t axis, Arg post_arg, Args... args)
{
std::vector<migraphx::module_ref> module_inputs;
std::vector<migraphx::instruction_ref> ins_inputs;
Expand Down Expand Up @@ -81,7 +81,99 @@ add_concat(migraphx::program& p, std::size_t axis, Arg post_arg, Args... args)
migraphx::make_op("fused_concat", {{"axis", axis}}), ins_inputs, module_inputs);
}

TEST_CASE(simple_pointwise_concat)
TEST_CASE(simple_concat_pointwise)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add = add_pointwise(p1, "main:pointwise0", {x, y}, single_pointwise("add"));
auto sub = add_pointwise(p1, "main:pointwise1", {x, y}, single_pointwise("sub"));
auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), add, sub);
mm->add_return({concat});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto fused_concat =
add_pointwise_concat(p2,
1,
arg("noop:concat0", {}, noop_pointwise()),
arg("concat:main:pointwise0", {x, y}, single_pointwise("add")),
arg("concat:main:pointwise1", {x, y}, single_pointwise("sub")));
mm->add_return({fused_concat});
}
EXPECT(p1 == p2);
}

TEST_CASE(partial_pointwise_concat)
{
migraphx::shape s1{migraphx::shape::float_type, {1, 4, 8, 8}};
migraphx::shape s2{migraphx::shape::float_type, {1, 4, 16, 16}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s1);
auto y = mm->add_parameter("y", s1);
auto z = mm->add_parameter("z", s2);
auto pooling = mm->add_instruction(
migraphx::make_op("pooling", {{"lengths", {2, 2}}, {"stride", {2, 2}}}), z);
auto add = add_pointwise(p1, "main:pointwise0", {x, y}, single_pointwise("add"));
auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), add, pooling);
mm->add_return({concat});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s1);
auto y = mm->add_parameter("y", s1);
auto z = mm->add_parameter("z", s2);
auto pooling = mm->add_instruction(
migraphx::make_op("pooling", {{"lengths", {2, 2}}, {"stride", {2, 2}}}), z);
auto fused_concat =
add_pointwise_concat(p2,
1,
arg("noop:concat2", {}, noop_pointwise()),
arg("concat:main:pointwise0", {x, y}, single_pointwise("add")),
arg("concat:noop1", {pooling}, noop_pointwise()));
mm->add_return({fused_concat});
}
EXPECT(p1 == p2);
}

TEST_CASE(skip_pointwise_concat)
{
// number of no-ops are two in this case and therefore pointwise+concat fusion wouldn't be
// applicable.
migraphx::shape s1{migraphx::shape::float_type, {1, 4, 8, 8}};
migraphx::shape s2{migraphx::shape::float_type, {1, 4, 16, 16}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto w = mm->add_parameter("w", s1);
auto x = mm->add_parameter("x", s1);
auto y = mm->add_parameter("y", s1);
auto z = mm->add_parameter("z", s2);
auto reduce_ins = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), w);
auto pooling = mm->add_instruction(
migraphx::make_op("pooling", {{"lengths", {2, 2}}, {"stride", {2, 2}}}), z);
auto add = add_pointwise(p1, "main:pointwise0", {x, y}, single_pointwise("add"));
auto concat = mm->add_instruction(
migraphx::make_op("concat", {{"axis", 1}}), reduce_ins, add, pooling);
mm->add_return({concat});
}
migraphx::program p2 = p1;
run_pass(p1);
EXPECT(p1 == p2);
}

TEST_CASE(simple_pointwise_concat_pointwise)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
Expand All @@ -102,17 +194,17 @@ TEST_CASE(simple_pointwise_concat)
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto fused_concat =
add_concat(p2,
1,
arg("main:pointwise2:concat", {}, single_pointwise("relu")),
arg("concat:main:pointwise0", {x, y}, single_pointwise("add")),
arg("concat:main:pointwise1", {x, y}, single_pointwise("sub")));
add_pointwise_concat(p2,
1,
arg("main:pointwise2:concat", {}, single_pointwise("relu")),
arg("concat:main:pointwise0", {x, y}, single_pointwise("add")),
arg("concat:main:pointwise1", {x, y}, single_pointwise("sub")));
mm->add_return({fused_concat});
}
EXPECT(p1 == p2);
}

TEST_CASE(partial_pointwise_concat)
TEST_CASE(partial_pointwise_concat_pointwise)
{
migraphx::shape s1{migraphx::shape::float_type, {1, 4, 8, 8}};
migraphx::shape s2{migraphx::shape::float_type, {1, 4, 16, 16}};
Expand All @@ -139,16 +231,55 @@ TEST_CASE(partial_pointwise_concat)
auto pooling = mm->add_instruction(
migraphx::make_op("pooling", {{"lengths", {2, 2}}, {"stride", {2, 2}}}), z);
auto fused_concat =
add_concat(p2,
1,
arg("main:pointwise2:concat", {}, single_pointwise("relu")),
arg("concat:main:pointwise0", {x, y}, single_pointwise("add")),
arg("concat:identity0", {pooling}, single_pointwise("identity")));
add_pointwise_concat(p2,
1,
arg("main:pointwise2:concat", {}, single_pointwise("relu")),
arg("concat:main:pointwise0", {x, y}, single_pointwise("add")),
arg("concat:noop3", {pooling}, noop_pointwise()));
mm->add_return({fused_concat});
}
EXPECT(p1 == p2);
}

TEST_CASE(multiple_use_pointwise_concat_pointwise)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add = add_pointwise(p1, "main:pointwise0", {x, y}, single_pointwise("add"));
auto sub = add_pointwise(p1, "main:pointwise1", {x, y}, single_pointwise("sub"));
auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), add, sub);
auto relu = add_pointwise(p1, "main:pointwise2", {concat}, single_pointwise("relu"));
auto slice = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {4}}}), relu);
auto mul = add_pointwise(p1, "main:pointwise3", {slice, sub}, single_pointwise({"mul"}));
mm->add_return({mul});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto sub = add_pointwise(p2, "main:pointwise1", {x, y}, single_pointwise("sub"));
auto fused_concat =
add_pointwise_concat(p2,
1,
arg("main:pointwise2:concat", {}, single_pointwise("relu")),
arg("concat:main:pointwise0", {x, y}, single_pointwise("add")),
arg("concat:noop4", {sub}, noop_pointwise()));
auto slice = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {4}}}),
fused_concat);
auto mul = add_pointwise(p2, "main:pointwise3", {slice, sub}, single_pointwise({"mul"}));
mm->add_return({mul});
}
EXPECT(p1 == p2);
}

TEST_CASE(pointwise_concat_fusion)
{
migraphx::shape s1{migraphx::shape::half_type, {2, 3}};
Expand All @@ -172,11 +303,11 @@ TEST_CASE(pointwise_concat_fusion)
auto y = mm->add_parameter("y", s2);
auto yc = mm->add_instruction(migraphx::make_op("contiguous"), y);
auto fused_concat =
add_concat(p2,
1,
arg("main:pointwise2:concat", {}, single_pointwise("relu")),
arg("concat:main:pointwise0", {x}, single_pointwise("sigmoid")),
arg("concat:identity1", {yc}, single_pointwise("identity")));
add_pointwise_concat(p2,
1,
arg("main:pointwise2:concat", {}, single_pointwise("relu")),
arg("concat:main:pointwise0", {x}, single_pointwise("sigmoid")),
arg("concat:noop5", {yc}, noop_pointwise()));
mm->add_return({fused_concat});
}
EXPECT(p1 == p2);
Expand Down
7 changes: 6 additions & 1 deletion test/include/pointwise.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -67,6 +67,11 @@ migraphx::instruction_ref add_pointwise(migraphx::program& p,
return add_pointwise(p, p.get_main_module(), name, inputs, f);
}

inline auto noop_pointwise()
{
return [=](auto*, const auto& inputs) { return inputs; };
}

inline auto single_pointwise(const std::string& name)
{
return [=](auto* pm, const auto& inputs) {
Expand Down
Loading

0 comments on commit 21b4d70

Please sign in to comment.