Skip to content

Commit

Permalink
Add concat fusions (#2460)
Browse files Browse the repository at this point in the history
  • Loading branch information
pfultz2 authored Jan 4, 2024
1 parent 496d44b commit 7532007
Show file tree
Hide file tree
Showing 23 changed files with 607 additions and 83 deletions.
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ add_library(migraphx
eliminate_pad.cpp
env.cpp
file_buffer.cpp
fuse_concat.cpp
fuse_pointwise.cpp
fuse_reduce.cpp
generate.cpp
Expand Down
22 changes: 9 additions & 13 deletions src/cpp_generator.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -48,8 +48,8 @@ cpp_generator::function::set_body(const module& m, const cpp_generator::generate
ss << "// " << ins->get_operator() << " -> " << ins->get_shape() << "\n";
if(ins->name() == "@param")
{
names[ins] =
migraphx::any_cast<migraphx::builtin::param>(ins->get_operator()).parameter;
names[ins] = to_c_id(
migraphx::any_cast<migraphx::builtin::param>(ins->get_operator()).parameter);
}
else if(ins->name() == "@return")
{
Expand Down Expand Up @@ -95,13 +95,13 @@ cpp_generator::function& cpp_generator::function::set_generic_types(const module
std::map<std::string, shape> input_map(pmap.begin(), pmap.end());
std::transform(
input_map.begin(), input_map.end(), std::back_inserter(this->params), [&](auto&& p) {
return param{p.first, "T" + p.first};
return param{p.first, "T" + to_c_id(p.first)};
});

std::transform(input_map.begin(),
input_map.end(),
std::back_inserter(this->tparams),
[&](auto&& p) { return "class T" + p.first; });
[&](auto&& p) { return "class T" + to_c_id(p.first); });
this->return_type = "auto";
return *this;
}
Expand Down Expand Up @@ -200,13 +200,9 @@ cpp_generator::function cpp_generator::generate_module(const module& m,
const generate_module_callback& g)
{
function f;
auto name = transform_string(m.name(), [](char c) {
if(with_char(::isalnum)(c) or c == '_')
return c;
return '_';
});
f.set_name(name).set_types(m).set_body(
m, [&](instruction_ref ins, const auto& names) -> std::string {
f.set_name(to_c_id(m.name()))
.set_types(m)
.set_body(m, [&](instruction_ref ins, const auto& names) -> std::string {
if(ins->name() == "@literal")
{
std::string string_literal;
Expand Down Expand Up @@ -265,7 +261,7 @@ std::string cpp_generator::create_function(const cpp_generator::function& f)
impl->fs << delim;
for(auto&& p : f.params)
{
impl->fs << delim << p.type << " " << p.name;
impl->fs << delim << p.type << " " << to_c_id(p.name);
delim = ',';
}
impl->fs << ") {\n" << f.body << "\n}\n";
Expand Down
172 changes: 172 additions & 0 deletions src/fuse_concat.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/fuse_concat.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/module.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct fused_concat
{
int64_t axis = 0;

std::string name() const { return "fused_concat"; }

template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"));
}

shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
{
check_shapes{inputs, *this}.same_ndims();
if((inputs.size() + 1) == mods.size())
MIGRAPHX_THROW("FUSED_CONCAT: Missing fused modules");
auto input_iter = inputs.begin();
std::vector<shape> concat_inputs;
for(module_ref mod : range(mods.begin(), mods.end() - 1))
{
concat_inputs.push_back(*input_iter);
input_iter += mod->get_parameter_names().size();
}
module_ref post_mod = mods.back();
auto type = std::prev(post_mod->end())->get_shape().type();
const auto& first_shape_lens = concat_inputs.front().lens();
auto mismatch_it =
std::find_if_not(concat_inputs.begin() + 1, concat_inputs.end(), [&](auto s) {
const auto& lens = s.lens();
return std::equal(lens.begin(),
lens.begin() + axis,
first_shape_lens.begin(),
first_shape_lens.begin() + axis) and
std::equal(lens.begin() + axis + 1,
lens.end(),
first_shape_lens.begin() + axis + 1,
first_shape_lens.end());
});
if(mismatch_it != concat_inputs.end())
MIGRAPHX_THROW("FUSED_CONCAT: all input dimensions should match along non-axis of " +
std::to_string(axis) + ": {" + to_string_range(first_shape_lens) +
"} != {" + to_string_range(mismatch_it->lens()) + "}");

std::size_t new_dim_axis = transform_accumulate(
concat_inputs.begin(), concat_inputs.end(), 0, std::plus<>{}, [&](const auto& input) {
return input.lens()[axis];
});
auto new_lens = concat_inputs.front().lens();
new_lens[axis] = new_dim_axis;
return shape::from_permutation(type, new_lens, find_permutation(inputs));
}
};
MIGRAPHX_REGISTER_OP(fused_concat);

namespace {

struct find_pointwise_concat_pointwise
{
auto matcher() const
{
auto concat = match::name("concat")(
match::used_once(),
match::any_of[match::inputs()](match::name("pointwise")(match::used_once())));
return match::name("pointwise")(match::any_of[match::inputs()](concat.bind("concat")));
}

void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto ins = r.result;
auto concat_ins = r.instructions["concat"];

auto concat_arg = std::find(ins->inputs().begin(), ins->inputs().end(), concat_ins) -
ins->inputs().begin();
std::vector<instruction_ref> inputs;
for(auto input : concat_ins->inputs())
{
if(input->name() == "pointwise")
inputs.insert(inputs.end(), input->inputs().begin(), input->inputs().end());
else
inputs.push_back(input);
}
std::copy_if(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(inputs),
[&](auto input) { return input != concat_ins; });

std::vector<module_ref> module_inputs;
static unsigned int counter = 0;
std::transform(concat_ins->inputs().begin(),
concat_ins->inputs().end(),
std::back_inserter(module_inputs),
[&](instruction_ref input) {
if(input->name() == "pointwise")
{
auto* pm = input->module_inputs().front();
return mpm.create_module("concat:" + pm->name(), *pm);
}
auto* pm =
mpm.create_module("concat:identity" + std::to_string(counter++));

auto x = pm->add_parameter("x0", shape{input->get_shape().type()});
auto id = pm->add_instruction(make_op("identity"), x);
pm->add_return({id});
return pm;
});

auto* post_pm = ins->module_inputs().front();
auto* rm = mpm.create_module(post_pm->name() + ":concat", *post_pm);
std::vector<std::string> names = rm->get_parameter_names();
std::sort(names.begin(), names.end());
auto concat_param_name = names[concat_arg];
auto concat_param = rm->get_parameter(concat_param_name);
auto param = rm->add_parameter("!" + concat_param_name, concat_param->get_shape());
rm->replace_instruction(concat_param, param);
rm->remove_instruction(concat_param);

module_inputs.push_back(rm);
mpm.get_module().replace_instruction(
ins,
make_op("fused_concat", concat_ins->normalized_operator().to_value()),
inputs,
module_inputs);
}
};

} // namespace

void fuse_concat::apply(module_pass_manager& mpm) const
{
match::find_matches(mpm, find_pointwise_concat_pointwise{});
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
4 changes: 3 additions & 1 deletion src/fuse_pointwise.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand All @@ -23,6 +23,7 @@
*/
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/instruction.hpp>
Expand Down Expand Up @@ -242,6 +243,7 @@ struct find_pointwise_reshape_pointwise

void fuse_pointwise::apply(module_pass_manager& mpm) const
{
mpm.run_pass(eliminate_identity{});
create_pointwise_modules(mpm);
mpm.run_pass(dead_code_elimination{});
if(enabled(MIGRAPHX_DISABLE_POINTWISE_FUSION{}))
Expand Down
43 changes: 43 additions & 0 deletions src/include/migraphx/fuse_concat.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_FUSE_CONCAT_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_FUSE_CONCAT_HPP

#include <migraphx/config.hpp>
#include <string>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct module_pass_manager;

struct MIGRAPHX_EXPORT fuse_concat
{
std::string name() const { return "fuse_concat"; }
void apply(module_pass_manager& mpm) const;
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_FUSE_CONCAT_HPP
5 changes: 4 additions & 1 deletion src/include/migraphx/module.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -242,7 +242,10 @@ struct MIGRAPHX_EXPORT module
MIGRAPHX_EXPORT friend bool operator==(const module& x, const module& y);
friend bool operator!=(const module& x, const module& y) { return not(x == y); }

friend struct program;

private:
void set_name(const std::string& name);
void assign(const module& m);
void calc_implicit_deps(const module& smod,
const module& pmod,
Expand Down
4 changes: 3 additions & 1 deletion src/include/migraphx/op/identity.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -37,6 +37,8 @@ struct identity
shape compute_shape(std::vector<shape> inputs) const { return inputs.at(0); }
argument compute(shape, std::vector<argument> args) const { return args[0]; }

value attributes() const { return {{"pointwise", true}, {"point_op", "${0}"}}; }

std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};

Expand Down
3 changes: 2 additions & 1 deletion src/include/migraphx/pass_manager.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -39,6 +39,7 @@ struct module_pass_manager
module_pass_manager(const module_pass_manager&) = delete;
virtual module& get_module() = 0;
virtual module* create_module(const std::string& name) = 0;
virtual module* create_module(const std::string& name, const module& m) = 0;
virtual module* get_common_parent() = 0;
virtual module* get_root_module() = 0;
virtual void run_pass(const pass& p) = 0;
Expand Down
3 changes: 2 additions & 1 deletion src/include/migraphx/program.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -136,6 +136,7 @@ struct MIGRAPHX_EXPORT program

// module related api
module* create_module(const std::string& name);
module* create_module(const std::string& name, module m);
module* get_module(const std::string& name);
const module* get_module(const std::string& name) const;

Expand Down
Loading

0 comments on commit 7532007

Please sign in to comment.