Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pointwise + Concat fusion #2785

Merged
merged 21 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 61 additions & 8 deletions src/fuse_concat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

unsigned int get_noop_counter()
{
static unsigned int counter = 0;
return counter++;
}

struct fused_concat
{
int64_t axis = 0;
Expand Down Expand Up @@ -98,6 +104,54 @@ struct fused_concat
MIGRAPHX_REGISTER_OP(fused_concat);

namespace {
struct find_concat_pointwise
{
auto matcher() const
{
return match::name("concat")(
match::used_once(),
match::any_of[match::inputs()](match::name("pointwise")(match::used_once())));
}

void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto concat_ins = r.result;

std::vector<instruction_ref> inputs;
for(auto input : concat_ins->inputs())
{
if(input->name() == "pointwise" and input->outputs().size() == 1)
inputs.insert(inputs.end(), input->inputs().begin(), input->inputs().end());
else
inputs.push_back(input);
}
std::vector<module_ref> module_inputs;
std::transform(concat_ins->inputs().begin(),
concat_ins->inputs().end(),
std::back_inserter(module_inputs),
[&](instruction_ref input) {
if(input->name() == "pointwise" and input->outputs().size() == 1)
{
auto* pm = input->module_inputs().front();
return mpm.create_module("concat:" + pm->name(), *pm);
}
auto* pm = mpm.create_module("concat:noop" +
std::to_string(get_noop_counter()));
auto x = pm->add_parameter("x0", shape{input->get_shape().type()});
pm->add_return({x});
return pm;
});
auto* post_pm = mpm.create_module("noop:concat" + std::to_string(get_noop_counter()));
auto x = post_pm->add_parameter("!x0", shape{concat_ins->get_shape().type()});
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this prefixed with a !? I would prefer not to have it start with a special character because this will become _x0 in the C++.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wasn't sure. I copied same logic from pointwise_concat_pointwise fusion. I can remove it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It uses x0 in that pass.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

auto param = rm->add_parameter("!" + concat_param_name, concat_param->get_shape());

I am referring this line.

post_pm->add_return({x});
module_inputs.push_back(post_pm);
mpm.get_module().replace_instruction(
concat_ins,
make_op("fused_concat", concat_ins->normalized_operator().to_value()),
inputs,
module_inputs);
}
};

struct find_pointwise_concat_pointwise
{
Expand All @@ -119,7 +173,7 @@ struct find_pointwise_concat_pointwise
std::vector<instruction_ref> inputs;
for(auto input : concat_ins->inputs())
{
if(input->name() == "pointwise")
if(input->name() == "pointwise" and input->outputs().size() == 1)
inputs.insert(inputs.end(), input->inputs().begin(), input->inputs().end());
else
inputs.push_back(input);
Expand All @@ -130,22 +184,19 @@ struct find_pointwise_concat_pointwise
[&](auto input) { return input != concat_ins; });

std::vector<module_ref> module_inputs;
static unsigned int counter = 0;
std::transform(concat_ins->inputs().begin(),
concat_ins->inputs().end(),
std::back_inserter(module_inputs),
[&](instruction_ref input) {
if(input->name() == "pointwise")
if(input->name() == "pointwise" and input->outputs().size() == 1)
{
auto* pm = input->module_inputs().front();
return mpm.create_module("concat:" + pm->name(), *pm);
}
auto* pm =
mpm.create_module("concat:identity" + std::to_string(counter++));

auto* pm = mpm.create_module("concat:noop" +
std::to_string(get_noop_counter()));
auto x = pm->add_parameter("x0", shape{input->get_shape().type()});
auto id = pm->add_instruction(make_op("identity"), x);
pm->add_return({id});
pm->add_return({x});
return pm;
});

Expand Down Expand Up @@ -173,6 +224,8 @@ struct find_pointwise_concat_pointwise
void fuse_concat::apply(module_pass_manager& mpm) const
{
match::find_matches(mpm, find_pointwise_concat_pointwise{});
mpm.run_pass(migraphx::dead_code_elimination{});
match::find_matches(mpm, find_concat_pointwise{});
}

} // namespace MIGRAPHX_INLINE_NS
Expand Down
2 changes: 2 additions & 0 deletions src/targets/gpu/compile_gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,8 @@ std::string generate_name_from_ops(const module& m, const std::string& postname)
auto op_names = get_op_names(m);
if(not postname.empty())
op_names.push_back(postname);
if(op_names.empty())
return "noop";
return join_strings(op_names, "_");
}

Expand Down
141 changes: 123 additions & 18 deletions test/fuse_concat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ concat_arg<F> arg(std::string name, std::vector<migraphx::instruction_ref> input

template <class Arg, class... Args>
migraphx::instruction_ref
add_concat(migraphx::program& p, std::size_t axis, Arg post_arg, Args... args)
add_pointwise_concat(migraphx::program& p, std::size_t axis, Arg post_arg, Args... args)
{
std::vector<migraphx::module_ref> module_inputs;
std::vector<migraphx::instruction_ref> ins_inputs;
Expand Down Expand Up @@ -81,7 +81,73 @@ add_concat(migraphx::program& p, std::size_t axis, Arg post_arg, Args... args)
migraphx::make_op("fused_concat", {{"axis", axis}}), ins_inputs, module_inputs);
}

TEST_CASE(simple_pointwise_concat)
TEST_CASE(simple_concat_pointwise)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add = add_pointwise(p1, "main:pointwise0", {x, y}, single_pointwise("add"));
auto sub = add_pointwise(p1, "main:pointwise1", {x, y}, single_pointwise("sub"));
auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), add, sub);
mm->add_return({concat});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto fused_concat =
add_pointwise_concat(p2,
1,
arg("noop:concat0", {}, noop_pointwise()),
arg("concat:main:pointwise0", {x, y}, single_pointwise("add")),
arg("concat:main:pointwise1", {x, y}, single_pointwise("sub")));
mm->add_return({fused_concat});
}
EXPECT(p1 == p2);
}

TEST_CASE(partial_pointwise_concat)
{
migraphx::shape s1{migraphx::shape::float_type, {1, 4, 8, 8}};
migraphx::shape s2{migraphx::shape::float_type, {1, 4, 16, 16}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s1);
auto y = mm->add_parameter("y", s1);
auto z = mm->add_parameter("z", s2);
auto pooling = mm->add_instruction(
migraphx::make_op("pooling", {{"lengths", {2, 2}}, {"stride", {2, 2}}}), z);
auto add = add_pointwise(p1, "main:pointwise0", {x, y}, single_pointwise("add"));
auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), add, pooling);
mm->add_return({concat});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s1);
auto y = mm->add_parameter("y", s1);
auto z = mm->add_parameter("z", s2);
auto pooling = mm->add_instruction(
migraphx::make_op("pooling", {{"lengths", {2, 2}}, {"stride", {2, 2}}}), z);
auto fused_concat =
add_pointwise_concat(p2,
1,
arg("noop:concat2", {}, noop_pointwise()),
arg("concat:main:pointwise0", {x, y}, single_pointwise("add")),
arg("concat:noop1", {pooling}, noop_pointwise()));
mm->add_return({fused_concat});
}
EXPECT(p1 == p2);
}

TEST_CASE(simple_pointwise_concat_pointwise)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
Expand All @@ -102,17 +168,17 @@ TEST_CASE(simple_pointwise_concat)
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto fused_concat =
add_concat(p2,
1,
arg("main:pointwise2:concat", {}, single_pointwise("relu")),
arg("concat:main:pointwise0", {x, y}, single_pointwise("add")),
arg("concat:main:pointwise1", {x, y}, single_pointwise("sub")));
add_pointwise_concat(p2,
1,
arg("main:pointwise2:concat", {}, single_pointwise("relu")),
arg("concat:main:pointwise0", {x, y}, single_pointwise("add")),
arg("concat:main:pointwise1", {x, y}, single_pointwise("sub")));
mm->add_return({fused_concat});
}
EXPECT(p1 == p2);
}

TEST_CASE(partial_pointwise_concat)
TEST_CASE(partial_pointwise_concat_pointwise)
{
migraphx::shape s1{migraphx::shape::float_type, {1, 4, 8, 8}};
migraphx::shape s2{migraphx::shape::float_type, {1, 4, 16, 16}};
Expand All @@ -139,16 +205,55 @@ TEST_CASE(partial_pointwise_concat)
auto pooling = mm->add_instruction(
migraphx::make_op("pooling", {{"lengths", {2, 2}}, {"stride", {2, 2}}}), z);
auto fused_concat =
add_concat(p2,
1,
arg("main:pointwise2:concat", {}, single_pointwise("relu")),
arg("concat:main:pointwise0", {x, y}, single_pointwise("add")),
arg("concat:identity0", {pooling}, single_pointwise("identity")));
add_pointwise_concat(p2,
1,
arg("main:pointwise2:concat", {}, single_pointwise("relu")),
arg("concat:main:pointwise0", {x, y}, single_pointwise("add")),
arg("concat:noop3", {pooling}, noop_pointwise()));
mm->add_return({fused_concat});
}
EXPECT(p1 == p2);
}

TEST_CASE(multiple_use_pointwise_concat_pointwise)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add = add_pointwise(p1, "main:pointwise0", {x, y}, single_pointwise("add"));
auto sub = add_pointwise(p1, "main:pointwise1", {x, y}, single_pointwise("sub"));
auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), add, sub);
auto relu = add_pointwise(p1, "main:pointwise2", {concat}, single_pointwise("relu"));
auto slice = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {4}}}), relu);
auto mul = add_pointwise(p1, "main:pointwise3", {slice, sub}, single_pointwise({"mul"}));
mm->add_return({mul});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto sub = add_pointwise(p2, "main:pointwise1", {x, y}, single_pointwise("sub"));
auto fused_concat =
add_pointwise_concat(p2,
1,
arg("main:pointwise2:concat", {}, single_pointwise("relu")),
arg("concat:main:pointwise0", {x, y}, single_pointwise("add")),
arg("concat:noop4", {sub}, noop_pointwise()));
auto slice = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {4}}}),
fused_concat);
auto mul = add_pointwise(p2, "main:pointwise3", {slice, sub}, single_pointwise({"mul"}));
mm->add_return({mul});
}
EXPECT(p1 == p2);
}

TEST_CASE(pointwise_concat_fusion)
{
migraphx::shape s1{migraphx::shape::half_type, {2, 3}};
Expand All @@ -172,11 +277,11 @@ TEST_CASE(pointwise_concat_fusion)
auto y = mm->add_parameter("y", s2);
auto yc = mm->add_instruction(migraphx::make_op("contiguous"), y);
auto fused_concat =
add_concat(p2,
1,
arg("main:pointwise2:concat", {}, single_pointwise("relu")),
arg("concat:main:pointwise0", {x}, single_pointwise("sigmoid")),
arg("concat:identity1", {yc}, single_pointwise("identity")));
add_pointwise_concat(p2,
1,
arg("main:pointwise2:concat", {}, single_pointwise("relu")),
arg("concat:main:pointwise0", {x}, single_pointwise("sigmoid")),
arg("concat:noop5", {yc}, noop_pointwise()));
mm->add_return({fused_concat});
}
EXPECT(p1 == p2);
Expand Down
7 changes: 6 additions & 1 deletion test/include/pointwise.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -67,6 +67,11 @@ migraphx::instruction_ref add_pointwise(migraphx::program& p,
return add_pointwise(p, p.get_main_module(), name, inputs, f);
}

inline auto noop_pointwise()
{
return [=](auto*, const auto& inputs) { return inputs; };
}

inline auto single_pointwise(const std::string& name)
{
return [=](auto* pm, const auto& inputs) {
Expand Down
51 changes: 51 additions & 0 deletions test/verify/test_add_sub_concat_slice_mul.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/pooling.hpp>

// test for fuse_concat pass where pointwise input to concat has multiple outputs
struct test_add_sub_concat_slice_mul : verify_program<test_add_sub_concat_slice_mul>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s1{migraphx::shape::float_type, {1, 4, 8, 8}};
auto x = mm->add_parameter("x", s1);
auto y = mm->add_parameter("y", s1);
auto add = mm->add_instruction(migraphx::make_op("add"), x, y);
auto sub = mm->add_instruction(migraphx::make_op("sub"), x, y);
auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), add, sub);
auto relu = mm->add_instruction(migraphx::make_op("relu"), concat);
auto slice = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {2}}, {"ends", {6}}}), relu);
auto mul = mm->add_instruction(migraphx::make_op("mul"), sub, slice);
mm->add_return({mul});
return p;
}
};
Loading
Loading