Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pointwise + Concat fusion #2785

Merged
merged 21 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 71 additions & 8 deletions src/fuse_concat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

unsigned int get_noop_counter()
{
static unsigned int counter = 0;
return counter++;
}

struct fused_concat
{
int64_t axis = 0;
Expand Down Expand Up @@ -98,6 +104,64 @@ struct fused_concat
MIGRAPHX_REGISTER_OP(fused_concat);

namespace {
struct find_concat_pointwise
{
auto matcher() const
{
return match::name("concat")(
match::used_once(),
match::any_of[match::inputs()](match::name("pointwise")(match::used_once())));
}

void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto concat_ins = r.result;

std::vector<instruction_ref> inputs;
size_t num_noops = 0;
for(auto input : concat_ins->inputs())
{
if(input->name() == "pointwise" and input->outputs().size() == 1)
{
inputs.insert(inputs.end(), input->inputs().begin(), input->inputs().end());
}
else
{
num_noops++;
inputs.push_back(input);
}
}
if(num_noops > std::max(size_t{1}, concat_ins->inputs().size() / 4))
{
return;
}
std::vector<module_ref> module_inputs;
std::transform(concat_ins->inputs().begin(),
concat_ins->inputs().end(),
std::back_inserter(module_inputs),
[&](instruction_ref input) {
if(input->name() == "pointwise" and input->outputs().size() == 1)
{
auto* pm = input->module_inputs().front();
return mpm.create_module("concat:" + pm->name(), *pm);
}
auto* pm = mpm.create_module("concat:noop" +
std::to_string(get_noop_counter()));
auto x = pm->add_parameter("x0", shape{input->get_shape().type()});
pm->add_return({x});
return pm;
});
auto* post_pm = mpm.create_module("noop:concat" + std::to_string(get_noop_counter()));
auto x = post_pm->add_parameter("!x0", shape{concat_ins->get_shape().type()});
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this prefixed with a !? I would prefer not to have it start with a special character because this will become _x0 in the C++.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wasn't sure. I copied same logic from pointwise_concat_pointwise fusion. I can remove it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It uses x0 in that pass.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

auto param = rm->add_parameter("!" + concat_param_name, concat_param->get_shape());

I am referring this line.

post_pm->add_return({x});
module_inputs.push_back(post_pm);
mpm.get_module().replace_instruction(
concat_ins,
make_op("fused_concat", concat_ins->normalized_operator().to_value()),
inputs,
module_inputs);
}
};

struct find_pointwise_concat_pointwise
{
Expand All @@ -119,7 +183,7 @@ struct find_pointwise_concat_pointwise
std::vector<instruction_ref> inputs;
for(auto input : concat_ins->inputs())
{
if(input->name() == "pointwise")
if(input->name() == "pointwise" and input->outputs().size() == 1)
inputs.insert(inputs.end(), input->inputs().begin(), input->inputs().end());
else
inputs.push_back(input);
Expand All @@ -130,22 +194,19 @@ struct find_pointwise_concat_pointwise
[&](auto input) { return input != concat_ins; });

std::vector<module_ref> module_inputs;
static unsigned int counter = 0;
std::transform(concat_ins->inputs().begin(),
concat_ins->inputs().end(),
std::back_inserter(module_inputs),
[&](instruction_ref input) {
if(input->name() == "pointwise")
if(input->name() == "pointwise" and input->outputs().size() == 1)
{
auto* pm = input->module_inputs().front();
return mpm.create_module("concat:" + pm->name(), *pm);
}
auto* pm =
mpm.create_module("concat:identity" + std::to_string(counter++));

auto* pm = mpm.create_module("concat:noop" +
std::to_string(get_noop_counter()));
auto x = pm->add_parameter("x0", shape{input->get_shape().type()});
auto id = pm->add_instruction(make_op("identity"), x);
pm->add_return({id});
pm->add_return({x});
return pm;
});

Expand Down Expand Up @@ -173,6 +234,8 @@ struct find_pointwise_concat_pointwise
void fuse_concat::apply(module_pass_manager& mpm) const
{
match::find_matches(mpm, find_pointwise_concat_pointwise{});
mpm.run_pass(migraphx::dead_code_elimination{});
match::find_matches(mpm, find_concat_pointwise{});
}

} // namespace MIGRAPHX_INLINE_NS
Expand Down
2 changes: 2 additions & 0 deletions src/targets/gpu/compile_gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,8 @@ std::string generate_name_from_ops(const module& m, const std::string& postname)
auto op_names = get_op_names(m);
if(not postname.empty())
op_names.push_back(postname);
if(op_names.empty())
return "noop";
return join_strings(op_names, "_");
}

Expand Down
167 changes: 149 additions & 18 deletions test/fuse_concat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ concat_arg<F> arg(std::string name, std::vector<migraphx::instruction_ref> input

template <class Arg, class... Args>
migraphx::instruction_ref
add_concat(migraphx::program& p, std::size_t axis, Arg post_arg, Args... args)
add_pointwise_concat(migraphx::program& p, std::size_t axis, Arg post_arg, Args... args)
{
std::vector<migraphx::module_ref> module_inputs;
std::vector<migraphx::instruction_ref> ins_inputs;
Expand Down Expand Up @@ -81,7 +81,99 @@ add_concat(migraphx::program& p, std::size_t axis, Arg post_arg, Args... args)
migraphx::make_op("fused_concat", {{"axis", axis}}), ins_inputs, module_inputs);
}

TEST_CASE(simple_pointwise_concat)
TEST_CASE(simple_concat_pointwise)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add = add_pointwise(p1, "main:pointwise0", {x, y}, single_pointwise("add"));
auto sub = add_pointwise(p1, "main:pointwise1", {x, y}, single_pointwise("sub"));
auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), add, sub);
mm->add_return({concat});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto fused_concat =
add_pointwise_concat(p2,
1,
arg("noop:concat0", {}, noop_pointwise()),
arg("concat:main:pointwise0", {x, y}, single_pointwise("add")),
arg("concat:main:pointwise1", {x, y}, single_pointwise("sub")));
mm->add_return({fused_concat});
}
EXPECT(p1 == p2);
}

TEST_CASE(partial_pointwise_concat)
{
migraphx::shape s1{migraphx::shape::float_type, {1, 4, 8, 8}};
migraphx::shape s2{migraphx::shape::float_type, {1, 4, 16, 16}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s1);
auto y = mm->add_parameter("y", s1);
auto z = mm->add_parameter("z", s2);
auto pooling = mm->add_instruction(
migraphx::make_op("pooling", {{"lengths", {2, 2}}, {"stride", {2, 2}}}), z);
auto add = add_pointwise(p1, "main:pointwise0", {x, y}, single_pointwise("add"));
auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), add, pooling);
mm->add_return({concat});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s1);
auto y = mm->add_parameter("y", s1);
auto z = mm->add_parameter("z", s2);
auto pooling = mm->add_instruction(
migraphx::make_op("pooling", {{"lengths", {2, 2}}, {"stride", {2, 2}}}), z);
auto fused_concat =
add_pointwise_concat(p2,
1,
arg("noop:concat2", {}, noop_pointwise()),
arg("concat:main:pointwise0", {x, y}, single_pointwise("add")),
arg("concat:noop1", {pooling}, noop_pointwise()));
mm->add_return({fused_concat});
}
EXPECT(p1 == p2);
}

TEST_CASE(skip_pointwise_concat)
{
// number of no-ops are two in this case and therefore pointwise+concat fusion wouldn't be
// applicable.
migraphx::shape s1{migraphx::shape::float_type, {1, 4, 8, 8}};
migraphx::shape s2{migraphx::shape::float_type, {1, 4, 16, 16}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto w = mm->add_parameter("w", s1);
auto x = mm->add_parameter("x", s1);
auto y = mm->add_parameter("y", s1);
auto z = mm->add_parameter("z", s2);
auto reduce_ins = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), w);
auto pooling = mm->add_instruction(
migraphx::make_op("pooling", {{"lengths", {2, 2}}, {"stride", {2, 2}}}), z);
auto add = add_pointwise(p1, "main:pointwise0", {x, y}, single_pointwise("add"));
auto concat = mm->add_instruction(
migraphx::make_op("concat", {{"axis", 1}}), reduce_ins, add, pooling);
mm->add_return({concat});
}
migraphx::program p2 = p1;
run_pass(p1);
EXPECT(p1 == p2);
}

TEST_CASE(simple_pointwise_concat_pointwise)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
Expand All @@ -102,17 +194,17 @@ TEST_CASE(simple_pointwise_concat)
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto fused_concat =
add_concat(p2,
1,
arg("main:pointwise2:concat", {}, single_pointwise("relu")),
arg("concat:main:pointwise0", {x, y}, single_pointwise("add")),
arg("concat:main:pointwise1", {x, y}, single_pointwise("sub")));
add_pointwise_concat(p2,
1,
arg("main:pointwise2:concat", {}, single_pointwise("relu")),
arg("concat:main:pointwise0", {x, y}, single_pointwise("add")),
arg("concat:main:pointwise1", {x, y}, single_pointwise("sub")));
mm->add_return({fused_concat});
}
EXPECT(p1 == p2);
}

TEST_CASE(partial_pointwise_concat)
TEST_CASE(partial_pointwise_concat_pointwise)
{
migraphx::shape s1{migraphx::shape::float_type, {1, 4, 8, 8}};
migraphx::shape s2{migraphx::shape::float_type, {1, 4, 16, 16}};
Expand All @@ -139,16 +231,55 @@ TEST_CASE(partial_pointwise_concat)
auto pooling = mm->add_instruction(
migraphx::make_op("pooling", {{"lengths", {2, 2}}, {"stride", {2, 2}}}), z);
auto fused_concat =
add_concat(p2,
1,
arg("main:pointwise2:concat", {}, single_pointwise("relu")),
arg("concat:main:pointwise0", {x, y}, single_pointwise("add")),
arg("concat:identity0", {pooling}, single_pointwise("identity")));
add_pointwise_concat(p2,
1,
arg("main:pointwise2:concat", {}, single_pointwise("relu")),
arg("concat:main:pointwise0", {x, y}, single_pointwise("add")),
arg("concat:noop3", {pooling}, noop_pointwise()));
mm->add_return({fused_concat});
}
EXPECT(p1 == p2);
}

TEST_CASE(multiple_use_pointwise_concat_pointwise)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add = add_pointwise(p1, "main:pointwise0", {x, y}, single_pointwise("add"));
auto sub = add_pointwise(p1, "main:pointwise1", {x, y}, single_pointwise("sub"));
auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), add, sub);
auto relu = add_pointwise(p1, "main:pointwise2", {concat}, single_pointwise("relu"));
auto slice = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {4}}}), relu);
auto mul = add_pointwise(p1, "main:pointwise3", {slice, sub}, single_pointwise({"mul"}));
mm->add_return({mul});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto sub = add_pointwise(p2, "main:pointwise1", {x, y}, single_pointwise("sub"));
auto fused_concat =
add_pointwise_concat(p2,
1,
arg("main:pointwise2:concat", {}, single_pointwise("relu")),
arg("concat:main:pointwise0", {x, y}, single_pointwise("add")),
arg("concat:noop4", {sub}, noop_pointwise()));
auto slice = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {4}}}),
fused_concat);
auto mul = add_pointwise(p2, "main:pointwise3", {slice, sub}, single_pointwise({"mul"}));
mm->add_return({mul});
}
EXPECT(p1 == p2);
}

TEST_CASE(pointwise_concat_fusion)
{
migraphx::shape s1{migraphx::shape::half_type, {2, 3}};
Expand All @@ -172,11 +303,11 @@ TEST_CASE(pointwise_concat_fusion)
auto y = mm->add_parameter("y", s2);
auto yc = mm->add_instruction(migraphx::make_op("contiguous"), y);
auto fused_concat =
add_concat(p2,
1,
arg("main:pointwise2:concat", {}, single_pointwise("relu")),
arg("concat:main:pointwise0", {x}, single_pointwise("sigmoid")),
arg("concat:identity1", {yc}, single_pointwise("identity")));
add_pointwise_concat(p2,
1,
arg("main:pointwise2:concat", {}, single_pointwise("relu")),
arg("concat:main:pointwise0", {x}, single_pointwise("sigmoid")),
arg("concat:noop5", {yc}, noop_pointwise()));
mm->add_return({fused_concat});
}
EXPECT(p1 == p2);
Expand Down
7 changes: 6 additions & 1 deletion test/include/pointwise.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -67,6 +67,11 @@ migraphx::instruction_ref add_pointwise(migraphx::program& p,
return add_pointwise(p, p.get_main_module(), name, inputs, f);
}

inline auto noop_pointwise()
{
return [=](auto*, const auto& inputs) { return inputs; };
}

inline auto single_pointwise(const std::string& name)
{
return [=](auto* pm, const auto& inputs) {
Expand Down
Loading
Loading