Skip to content

Commit

Permalink
Improve reduction fusion with reshape operators (#2698)
Browse files Browse the repository at this point in the history
  • Loading branch information
pfultz2 authored Mar 18, 2024
1 parent 5c2302e commit 9077e74
Show file tree
Hide file tree
Showing 10 changed files with 714 additions and 91 deletions.
51 changes: 44 additions & 7 deletions src/common_dims.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -43,12 +43,6 @@ static auto compute_end_dim(Iterator start, Iterator last, std::size_t dim)
return it;
}

template <class Range>
static auto elements(const Range& r)
{
return std::accumulate(r.begin(), r.end(), std::size_t{1}, std::multiplies<>{});
}

struct common_dim_state
{
common_dim_state(const std::vector<std::size_t>& pdims,
Expand Down Expand Up @@ -152,5 +146,48 @@ common_dims common_dims::compute(const std::vector<std::size_t>& dims1,
return cd;
}

const std::vector<std::vector<std::size_t>>* common_dims::get_axes_map(std::size_t n) const
{
if(axes_map1.size() == n)
return &axes_map1;
if(axes_map2.size() == n)
return &axes_map2;
return nullptr;
}

std::vector<std::size_t>
common_dims::get_dimensions_for(const std::vector<std::size_t>& idims) const
{
if(dims.size() == idims.size())
return idims;
if(elements(dims) == elements(idims))
return dims;
// Bail for now since its ambiguous which axes map can be used
// TODO: Check for similiarity
if(axes_map1.size() == axes_map2.size())
return {};
const auto* axes_map = get_axes_map(idims.size());
if(axes_map == nullptr)
return {};
auto xdims = dims;
for(auto i : range(axes_map->size()))
{
auto dim = idims[i];
const auto& axes = (*axes_map)[i];
if(axes.size() == 1)
{
xdims[axes.front()] = dim;
}
else if(dim == 1)
{
for(auto axis : axes)
xdims[axis] = 1;
}
}
if(elements(xdims) == elements(idims))
return xdims;
return {};
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
53 changes: 6 additions & 47 deletions src/fuse_pointwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,13 @@
#include <migraphx/pass_manager.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/common_dims.hpp>
#include <migraphx/rewrite_reshapes.hpp>
#include <iterator>

MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_POINTWISE_FUSION)
Expand Down Expand Up @@ -193,52 +192,13 @@ static bool find_pointwise_modules(module& m)
}
return changed;
}

namespace {
struct find_pointwise_reshape_pointwise
struct pointwise_reshape : rewrite_reshapes_base
{
auto matcher() const
{
auto reshape =
match::name("reshape", "squeeze", "unsqueeze", "flatten")(match::used_once());
auto skip_contiguous = [](auto... ms) {
return match::arg(0)(match::skip(match::name("contiguous")(match::used_once()))(ms...));
};
auto pointwise = match::name("pointwise")(match::used_once());
auto reshape_pointwise = reshape(skip_contiguous(pointwise.bind("x"))).bind("reshape");
return match::name("pointwise")(match::any_of[match::inputs()](reshape_pointwise));
}

void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
auto reshape_ins = r.instructions["reshape"];

auto cd = common_dims::compute(ins->get_shape().lens(), x_ins->get_shape().lens());
if(cd.dims.empty())
return;

auto reshape_input = [&](const auto& ins_to_insert) {
return [&](auto input) {
return m.insert_instruction(
ins_to_insert, make_op("reshape", {{"dims", cd.dims}}), input);
};
};
auto x_inputs = x_ins->inputs();
std::transform(x_inputs.begin(), x_inputs.end(), x_inputs.begin(), reshape_input(x_ins));
auto new_x_ins =
m.insert_instruction(x_ins, x_ins->get_operator(), x_inputs, x_ins->module_inputs());

auto inputs = ins->inputs();
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
if(input == reshape_ins)
return new_x_ins;
return reshape_input(ins)(input);
});
auto pw = m.insert_instruction(ins, ins->get_operator(), inputs, ins->module_inputs());
m.replace_instruction(ins, make_op("reshape", {{"dims", ins->get_shape().lens()}}), pw);
}
static std::string name() { return "pointwise"; }
};

} // namespace

void fuse_pointwise::apply(module_pass_manager& mpm) const
Expand All @@ -252,8 +212,7 @@ void fuse_pointwise::apply(module_pass_manager& mpm) const
}
for(int i = 0; i < 8; i++)
{
match::find_matches(mpm.get_module(), find_pointwise_reshape_pointwise{});
mpm.run_pass(simplify_reshapes{1});
mpm.run_pass(rewrite_reshapes<pointwise_reshape>{});
if(not find_pointwise_modules(mpm.get_module()))
break;
mpm.run_pass(dead_code_elimination{});
Expand Down
99 changes: 90 additions & 9 deletions src/fuse_reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/rewrite_reshapes.hpp>
#include <iterator>
#include <map>

Expand Down Expand Up @@ -100,11 +101,11 @@ get_ins_param_map(const std::vector<instruction_ref>& inputs, const_module_ref s
}

static void insert_params(module_ref sm,
instruction_ref ins,
const std::vector<instruction_ref>& inputs,
std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{
auto n = sm->get_parameter_shapes().size();
for(auto input : ins->inputs())
for(auto input : inputs)
{
if(contains(map_ins, input))
continue;
Expand All @@ -117,7 +118,7 @@ static auto insert_ins_in_submodule(module_ref sm,
instruction_ref ins,
std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{
insert_params(sm, ins, map_ins);
insert_params(sm, ins->inputs(), map_ins);
return sm->add_instructions({ins}, map_ins);
}

Expand All @@ -129,17 +130,37 @@ static auto insert_ins_in_submodule(module_ref sm, instruction_ref ins)

static auto
insert_module_in_submodule(module_ref sm,
instruction_ref ins,
std::unordered_map<instruction_ref, instruction_ref>& map_ins)
const std::vector<instruction_ref>& inputs,
module_ref m,
std::unordered_map<instruction_ref, instruction_ref>& map_ins,
module::inserter insert = nullptr)
{
insert_params(sm, ins, map_ins);
auto* m = ins->module_inputs().front();
auto param_map = get_ins_param_map(ins->inputs(), m);
insert_params(sm, inputs, map_ins);
auto param_map = get_ins_param_map(inputs, m);
for(auto&& [input, param] : param_map)
{
map_ins[param] = map_ins.at(input);
}
return sm->add_instructions(m, map_ins);
return sm->add_instructions(m, map_ins, std::move(insert));
}

static auto
insert_module_in_submodule(module_ref sm,
instruction_ref ins,
std::unordered_map<instruction_ref, instruction_ref>& map_ins,
module::inserter insert = nullptr)
{
return insert_module_in_submodule(
sm, ins->inputs(), ins->module_inputs().front(), map_ins, std::move(insert));
}

static auto insert_module_in_submodule(module_ref sm,
const std::vector<instruction_ref>& inputs,
module_ref m,
module::inserter insert = nullptr)
{
std::unordered_map<instruction_ref, instruction_ref> map_ins;
return insert_module_in_submodule(sm, inputs, m, map_ins, std::move(insert));
}

static std::vector<instruction_ref>
Expand Down Expand Up @@ -332,6 +353,65 @@ struct find_reduce_reduce
}
};

struct reduce_reshape : rewrite_reshapes_base
{
static std::string name() { return "fused_reduce"; }

template <class Transform>
static auto transform_op(Transform t)
{
return [=](module& m,
instruction_ref ins,
const operation& op,
const std::vector<instruction_ref>& inputs,
const std::vector<module_ref>& mod_args) {
auto new_op = t(op);
return m.insert_instruction(ins, new_op, inputs, mod_args);
};
}

template <class AxesMap>
static instruction_ref insert(module_pass_manager& mpm,
instruction_ref ins,
const std::vector<instruction_ref>& inputs,
const AxesMap& am)
{
auto op = any_cast<fused_reduce>(ins->get_operator());
std::vector<int64_t> axes;
for(auto axis : op.axes)
{
auto new_axes = am.at(axis);
axes.insert(axes.end(), new_axes.begin(), new_axes.end());
}
std::sort(axes.begin(), axes.end());
auto dims = base_dims(inputs);
auto* oldm = ins->module_inputs().front();
auto* sm = mpm.create_module(oldm->name() + "_reshape");
insert_module_in_submodule(sm, inputs, oldm, transform_op([&](const operation& sop) {
if(contains(sop.name(), "reduce"))
return make_op(sop.name(), {{"axes", axes}});
if(sop.name() == "multibroadcast")
return make_op("multibroadcast", {{"out_lens", dims}});
assert(sop.name() == "pointwise");
return sop;
}));
return mpm.get_module().insert_instruction(ins, fused_reduce{axes}, inputs, {sm});
}

static std::vector<std::size_t> base_dims(const std::vector<instruction_ref>& inputs)
{
auto input = std::max_element(inputs.begin(), inputs.end(), by(std::less<>{}, [](auto i) {
return i->get_shape().elements();
}));
return (*input)->get_shape().lens();
}

static std::vector<std::size_t> base_dims(instruction_ref ins)
{
return base_dims(ins->inputs());
}
};

} // namespace

void fuse_reduce::apply(module_pass_manager& mpm) const
Expand All @@ -340,6 +420,7 @@ void fuse_reduce::apply(module_pass_manager& mpm) const
mpm.run_pass(dead_code_elimination{});
for(int i = 0; i < 4; i++)
{
mpm.run_pass(rewrite_reshapes<reduce_reshape>{});
match::find_matches(
mpm, find_reduce_pointwise{}, find_pointwise_reduce{}, find_reduce_reduce{});
mpm.run_pass(dead_code_elimination{});
Expand Down
17 changes: 16 additions & 1 deletion src/include/migraphx/common_dims.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand All @@ -26,6 +26,8 @@

#include <migraphx/config.hpp>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

namespace migraphx {
Expand All @@ -39,11 +41,24 @@ struct MIGRAPHX_EXPORT common_dims
{
static common_dims compute(const std::vector<std::size_t>& dims1,
const std::vector<std::size_t>& dims2);

/// Map the dimensions into the common higher dimensional space. The
/// dimension doesnt need to have the same number of elements as the
/// common dimension.
std::vector<std::size_t> get_dimensions_for(const std::vector<std::size_t>& idims) const;
/// Get the corresponding axes map based on the rank of tensor
const std::vector<std::vector<std::size_t>>* get_axes_map(std::size_t n) const;
std::vector<std::size_t> dims;
std::vector<std::vector<std::size_t>> axes_map1;
std::vector<std::vector<std::size_t>> axes_map2;
};

template <class Range>
auto elements(const Range& r)
{
return std::accumulate(r.begin(), r.end(), std::size_t{1}, std::multiplies<>{});
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP
Loading

0 comments on commit 9077e74

Please sign in to comment.