Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve reduction fusion with reshape operators #2698

Merged
merged 41 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
d6791f3
Add custom inserter
pfultz2 Dec 22, 2023
42739b0
Format
pfultz2 Dec 22, 2023
350554c
Refactor the reshape rewriting
pfultz2 Jan 2, 2024
59f3ecf
Format
pfultz2 Jan 2, 2024
02e9989
Try to rewrite reshapes for reduce as well
pfultz2 Jan 2, 2024
f2d0745
Format
pfultz2 Jan 2, 2024
ef386c2
Sort axes
pfultz2 Jan 2, 2024
bcbe021
Merge branch 'develop' into fuse-reshape-pointwise
pfultz2 Jan 26, 2024
9249bdd
Add unit test
pfultz2 Jan 27, 2024
2c0f919
Format
pfultz2 Jan 27, 2024
0fea84f
Handle base shape
pfultz2 Jan 29, 2024
33eea67
Format
pfultz2 Jan 29, 2024
9c9fe65
Fixes for fusion
pfultz2 Jan 30, 2024
2fc67f7
Format
pfultz2 Jan 30, 2024
ff95292
Fix test
pfultz2 Jan 30, 2024
b974caa
Improve testing
pfultz2 Jan 30, 2024
b563470
Format
pfultz2 Jan 30, 2024
4e7a5d7
Merge branch 'develop' into fuse-reshape-pointwise
pfultz2 Feb 6, 2024
96ad06f
Update license
pfultz2 Feb 12, 2024
af71ca2
Format
pfultz2 Feb 12, 2024
9246cf0
Fix tidy warnings
pfultz2 Feb 13, 2024
464ce7e
Add unit tests
pfultz2 Feb 13, 2024
73b056b
Format
pfultz2 Feb 13, 2024
8eb47f7
Add tests
pfultz2 Feb 15, 2024
a335c13
Format
pfultz2 Feb 15, 2024
9ac82fc
Update year for test
pfultz2 Feb 15, 2024
9807e66
Add default
pfultz2 Feb 16, 2024
2b4cd04
Improve test coverage and fix bugs
pfultz2 Feb 16, 2024
147526e
Format
pfultz2 Feb 16, 2024
a59086c
Merge branch 'develop' into fuse-reshape-pointwise
pfultz2 Feb 28, 2024
34da10f
Add another test
pfultz2 Mar 2, 2024
d5a51f2
Format
pfultz2 Mar 2, 2024
ad0544c
Add TODOs
pfultz2 Mar 2, 2024
e40eb2f
Remove comment
pfultz2 Mar 18, 2024
69ec320
Add test for contiguous skip
pfultz2 Mar 18, 2024
7f07a85
Format
pfultz2 Mar 18, 2024
6e1c2d9
Update license
pfultz2 Mar 18, 2024
cec782d
Merge branch 'develop' into fuse-reshape-pointwise
pfultz2 Mar 18, 2024
96713e1
Format
pfultz2 Mar 18, 2024
895acf4
Add missing header
pfultz2 Mar 18, 2024
4c3fca8
Merge branch 'develop' into fuse-reshape-pointwise
pfultz2 Mar 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 44 additions & 7 deletions src/common_dims.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -43,12 +43,6 @@ static auto compute_end_dim(Iterator start, Iterator last, std::size_t dim)
return it;
}

template <class Range>
static auto elements(const Range& r)
{
return std::accumulate(r.begin(), r.end(), std::size_t{1}, std::multiplies<>{});
}

struct common_dim_state
{
common_dim_state(const std::vector<std::size_t>& pdims,
Expand Down Expand Up @@ -152,5 +146,48 @@ common_dims common_dims::compute(const std::vector<std::size_t>& dims1,
return cd;
}

const std::vector<std::vector<std::size_t>>* common_dims::get_axes_map(std::size_t n) const
umangyadav marked this conversation as resolved.
Show resolved Hide resolved
{
if(axes_map1.size() == n)
return &axes_map1;
if(axes_map2.size() == n)
return &axes_map2;
return nullptr;
}

/// Map idims into the common higher-dimensional space described by this
/// common_dims. Returns an empty vector when no unambiguous mapping exists.
std::vector<std::size_t>
common_dims::get_dimensions_for(const std::vector<std::size_t>& idims) const
{
    // Same rank as the common dims: idims can be used directly
    if(dims.size() == idims.size())
        return idims;
    // Same total element count: the common dims are a pure reshape of idims
    if(elements(dims) == elements(idims))
        return dims;
    // Bail for now since it's ambiguous which axes map can be used
    // TODO: Check for similarity
    if(axes_map1.size() == axes_map2.size())
        return {};
    const auto* axes_map = get_axes_map(idims.size());
    if(axes_map == nullptr)
        return {};
    // Start from the common dims and overwrite the axes that idims determines
    auto xdims = dims;
    for(auto i : range(axes_map->size()))
    {
        auto dim = idims[i];
        const auto& axes = (*axes_map)[i];
        if(axes.size() == 1)
        {
            // One-to-one mapping: copy the dimension through
            xdims[axes.front()] = dim;
        }
        else if(dim == 1)
        {
            // A dimension of 1 collapses every mapped axis to 1
            for(auto axis : axes)
                xdims[axis] = 1;
        }
    }
    // Only valid if the adjusted dims still describe the same element count
    if(elements(xdims) == elements(idims))
        return xdims;
    return {};
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
53 changes: 6 additions & 47 deletions src/fuse_pointwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,13 @@
#include <migraphx/pass_manager.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/common_dims.hpp>
#include <migraphx/rewrite_reshapes.hpp>
#include <iterator>

MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_POINTWISE_FUSION)
Expand Down Expand Up @@ -193,52 +192,13 @@ static bool find_pointwise_modules(module& m)
}
return changed;
}

namespace {

// Adapter for the generic rewrite_reshapes pass: fuses pointwise modules
// across reshape-like operators. This PR replaced the hand-written
// find_pointwise_reshape_pointwise matcher with this shared mechanism; only
// the operator name to match ("pointwise") is supplied here.
struct pointwise_reshape : rewrite_reshapes_base
{
    static std::string name() { return "pointwise"; }
};

} // namespace

void fuse_pointwise::apply(module_pass_manager& mpm) const
Expand All @@ -252,8 +212,7 @@ void fuse_pointwise::apply(module_pass_manager& mpm) const
}
for(int i = 0; i < 8; i++)
{
match::find_matches(mpm.get_module(), find_pointwise_reshape_pointwise{});
mpm.run_pass(simplify_reshapes{1});
mpm.run_pass(rewrite_reshapes<pointwise_reshape>{});
if(not find_pointwise_modules(mpm.get_module()))
break;
mpm.run_pass(dead_code_elimination{});
Expand Down
99 changes: 90 additions & 9 deletions src/fuse_reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/rewrite_reshapes.hpp>
#include <iterator>
#include <map>

Expand Down Expand Up @@ -100,11 +101,11 @@ get_ins_param_map(const std::vector<instruction_ref>& inputs, const_module_ref s
}

static void insert_params(module_ref sm,
instruction_ref ins,
const std::vector<instruction_ref>& inputs,
std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{
auto n = sm->get_parameter_shapes().size();
for(auto input : ins->inputs())
for(auto input : inputs)
{
if(contains(map_ins, input))
continue;
Expand All @@ -117,7 +118,7 @@ static auto insert_ins_in_submodule(module_ref sm,
instruction_ref ins,
std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{
insert_params(sm, ins, map_ins);
insert_params(sm, ins->inputs(), map_ins);
return sm->add_instructions({ins}, map_ins);
}

Expand All @@ -129,17 +130,37 @@ static auto insert_ins_in_submodule(module_ref sm, instruction_ref ins)

static auto
insert_module_in_submodule(module_ref sm,
instruction_ref ins,
std::unordered_map<instruction_ref, instruction_ref>& map_ins)
const std::vector<instruction_ref>& inputs,
module_ref m,
std::unordered_map<instruction_ref, instruction_ref>& map_ins,
module::inserter insert = nullptr)
{
insert_params(sm, ins, map_ins);
auto* m = ins->module_inputs().front();
auto param_map = get_ins_param_map(ins->inputs(), m);
insert_params(sm, inputs, map_ins);
auto param_map = get_ins_param_map(inputs, m);
for(auto&& [input, param] : param_map)
{
map_ins[param] = map_ins.at(input);
}
return sm->add_instructions(m, map_ins);
TedThemistokleous marked this conversation as resolved.
Show resolved Hide resolved
return sm->add_instructions(m, map_ins, std::move(insert));
}

// Inline the submodule attached to ins into sm, wiring that submodule's
// parameters to the (mapped) inputs of ins.
static auto
insert_module_in_submodule(module_ref sm,
                           instruction_ref ins,
                           std::unordered_map<instruction_ref, instruction_ref>& map_ins,
                           module::inserter insert = nullptr)
{
    auto* mod = ins->module_inputs().front();
    return insert_module_in_submodule(sm, ins->inputs(), mod, map_ins, std::move(insert));
}

// Convenience overload that starts from an empty instruction map.
static auto insert_module_in_submodule(module_ref sm,
                                       const std::vector<instruction_ref>& inputs,
                                       module_ref m,
                                       module::inserter insert = nullptr)
{
    std::unordered_map<instruction_ref, instruction_ref> ins_map;
    return insert_module_in_submodule(sm, inputs, m, ins_map, std::move(insert));
}

static std::vector<instruction_ref>
Expand Down Expand Up @@ -332,6 +353,65 @@ struct find_reduce_reduce
}
};

// Adapter for the generic rewrite_reshapes pass: rebuilds a fused_reduce over
// the common (higher-dimensional) shape so reductions can fuse across
// reshape-like operators.
struct reduce_reshape : rewrite_reshapes_base
{
    static std::string name() { return "fused_reduce"; }

    // Wrap an operation transform t so it can be used as a module::inserter:
    // each copied instruction is inserted with its operation rewritten by t.
    template <class Transform>
    static auto transform_op(Transform t)
    {
        return [=](module& m,
                   instruction_ref ins,
                   const operation& op,
                   const std::vector<instruction_ref>& inputs,
                   const std::vector<module_ref>& mod_args) {
            auto new_op = t(op);
            return m.insert_instruction(ins, new_op, inputs, mod_args);
        };
    }

    // Re-create the fused_reduce at ins over the reshaped inputs. The axes
    // map am translates each original reduce axis into the axes of the
    // common shape.
    template <class AxesMap>
    static instruction_ref insert(module_pass_manager& mpm,
                                  instruction_ref ins,
                                  const std::vector<instruction_ref>& inputs,
                                  const AxesMap& am)
    {
        auto op = any_cast<fused_reduce>(ins->get_operator());
        // Translate every original reduce axis to its common-space axes
        std::vector<int64_t> axes;
        for(auto axis : op.axes)
        {
            auto new_axes = am.at(axis);
            axes.insert(axes.end(), new_axes.begin(), new_axes.end());
        }
        std::sort(axes.begin(), axes.end());
        auto dims = base_dims(inputs);
        auto* oldm = ins->module_inputs().front();
        // Copy the old reduce submodule into a new one, rewriting each op for
        // the common shape: reduces get the translated axes, multibroadcasts
        // get the new output lens, pointwise ops pass through unchanged.
        auto* sm = mpm.create_module(oldm->name() + "_reshape");
        insert_module_in_submodule(sm, inputs, oldm, transform_op([&](const operation& sop) {
            if(contains(sop.name(), "reduce"))
                return make_op(sop.name(), {{"axes", axes}});
            if(sop.name() == "multibroadcast")
                return make_op("multibroadcast", {{"out_lens", dims}});
            assert(sop.name() == "pointwise");
            return sop;
        }));
        return mpm.get_module().insert_instruction(ins, fused_reduce{axes}, inputs, {sm});
    }

    // Base (unreduced) dimensions: the lens of the input with the most elements.
    static std::vector<std::size_t> base_dims(const std::vector<instruction_ref>& inputs)
    {
        auto input = std::max_element(inputs.begin(), inputs.end(), by(std::less<>{}, [](auto i) {
            return i->get_shape().elements();
        }));
        return (*input)->get_shape().lens();
    }

    // Overload taking the instruction whose inputs supply the base dims.
    static std::vector<std::size_t> base_dims(instruction_ref ins)
    {
        return base_dims(ins->inputs());
    }
};

} // namespace

void fuse_reduce::apply(module_pass_manager& mpm) const
Expand All @@ -340,6 +420,7 @@ void fuse_reduce::apply(module_pass_manager& mpm) const
mpm.run_pass(dead_code_elimination{});
for(int i = 0; i < 4; i++)
{
mpm.run_pass(rewrite_reshapes<reduce_reshape>{});
match::find_matches(
mpm, find_reduce_pointwise{}, find_pointwise_reduce{}, find_reduce_reduce{});
mpm.run_pass(dead_code_elimination{});
Expand Down
17 changes: 16 additions & 1 deletion src/include/migraphx/common_dims.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand All @@ -26,6 +26,8 @@

#include <migraphx/config.hpp>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

namespace migraphx {
Expand All @@ -39,11 +41,24 @@ struct MIGRAPHX_EXPORT common_dims
{
static common_dims compute(const std::vector<std::size_t>& dims1,
const std::vector<std::size_t>& dims2);

/// Map the dimensions into the common higher dimensional space. The
/// dimension doesn't need to have the same number of elements as the
/// common dimension.
std::vector<std::size_t> get_dimensions_for(const std::vector<std::size_t>& idims) const;
/// Get the corresponding axes map based on the rank of tensor
const std::vector<std::vector<std::size_t>>* get_axes_map(std::size_t n) const;
std::vector<std::size_t> dims;
std::vector<std::vector<std::size_t>> axes_map1;
std::vector<std::vector<std::size_t>> axes_map2;
};

/// Total number of elements described by a range of dimensions.
/// An empty range yields 1 (the multiplicative identity).
template <class Range>
auto elements(const Range& r)
{
    std::size_t total{1};
    for(const auto& d : r)
        total *= d;
    return total;
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP
Loading
Loading