Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite reduce mean/variance #2883

Merged
merged 49 commits into from
Apr 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
b3f8fc1
Rewrite reduce mean variance
pfultz2 Feb 19, 2024
b5e17af
Format
pfultz2 Feb 19, 2024
6736534
Add codegen
pfultz2 Feb 20, 2024
d2064c1
Format
pfultz2 Feb 20, 2024
c3244b3
Improve fusions with broadcast
pfultz2 Feb 20, 2024
97c9b78
Format
pfultz2 Feb 20, 2024
04bf76c
Get inputs
pfultz2 Feb 20, 2024
4389985
Format
pfultz2 Feb 20, 2024
299274f
Remove semicolon
pfultz2 Feb 20, 2024
7fdfc9c
Optimize before
pfultz2 Feb 20, 2024
2c4b949
Format
pfultz2 Feb 20, 2024
defca20
Fixes
pfultz2 Feb 21, 2024
d7b9593
Format
pfultz2 Feb 21, 2024
cd67d8f
Check shape is the same as well
pfultz2 Feb 21, 2024
974b44f
Skip reduce_mean
pfultz2 Feb 21, 2024
2aab79e
Rewrite reduce_mean
pfultz2 Feb 21, 2024
05fdeed
Format
pfultz2 Feb 21, 2024
c916f61
Fix div by zero
pfultz2 Feb 21, 2024
611051e
Format
pfultz2 Feb 21, 2024
86ef51e
Merge branch 'develop' into reduce-mean-variance
pfultz2 Feb 21, 2024
fcc1781
Merge
pfultz2 Mar 12, 2024
f0fbe43
Format
pfultz2 Mar 12, 2024
b678037
Use explict constructor
pfultz2 Mar 12, 2024
b1abc01
Merge branch 'develop' into reduce-mean-variance
pfultz2 Mar 20, 2024
2d6face
Add license
pfultz2 Mar 20, 2024
3d44fb3
Format
pfultz2 Mar 20, 2024
e7dd3a0
Fix tidy warnings
pfultz2 Mar 20, 2024
ca3bd58
Format
pfultz2 Mar 20, 2024
1f0dab3
Add test
pfultz2 Mar 26, 2024
5fe11f1
Format
pfultz2 Mar 26, 2024
a82c410
Fix test case
pfultz2 Mar 26, 2024
e819d9d
Rewrite test case
pfultz2 Mar 26, 2024
ca7c7c4
Format
pfultz2 Mar 26, 2024
da23ade
Update docs
pfultz2 Mar 28, 2024
459a337
Fix parallel reduce
pfultz2 Mar 29, 2024
8f2f2c1
Format
pfultz2 Mar 29, 2024
b345606
Add test
pfultz2 Mar 29, 2024
152b052
Format
pfultz2 Mar 29, 2024
d0f2ed2
Update license
pfultz2 Mar 29, 2024
07f6ca5
Rename variables
pfultz2 Apr 11, 2024
dce056f
Format
pfultz2 Apr 11, 2024
75c18b7
Remove comment
pfultz2 Apr 11, 2024
1d94de4
Add more unit tests
pfultz2 Apr 11, 2024
792a0a5
Format
pfultz2 Apr 11, 2024
fba41ea
Fix compile error with test
pfultz2 Apr 24, 2024
7ba5dd4
Handle contiguous output
pfultz2 Apr 24, 2024
cc0c435
Format
pfultz2 Apr 24, 2024
98b69d4
Format
pfultz2 Apr 26, 2024
52fe8e9
Merge branch 'develop' into reduce-mean-variance
pfultz2 Apr 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/dev/env_vars.rst
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,11 @@ Disables use of the rocMLIR library.
Set to "1", "enable", "enabled", "yes", or "true" to use.
Uses ``hip_copy_to_gpu`` with a new ``literal`` instruction rather than using ``hip_copy_literal{}``.

.. envvar:: MIGRAPHX_DISABLE_LAYERNORM_FUSION

Set to "1", "enable", "enabled", "yes", or "true" to use.
Disables layrnorm fusion.

Compilation traces
----------------------

Expand Down
4 changes: 2 additions & 2 deletions src/dom_info.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -31,7 +31,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

bool dominator_info::strictly_dominate(instruction_ref ins1, instruction_ref ins2)
bool dominator_info::strictly_dominate(instruction_ref ins1, instruction_ref ins2) const
{
if(ins1 == ins2)
return false;
Expand Down
47 changes: 46 additions & 1 deletion src/fuse_reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <migraphx/fuse_reduce.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
Expand Down Expand Up @@ -185,6 +186,41 @@ static void create_reduce_modules(module_pass_manager& mpm)
}
}

namespace {

instruction_ref get_broadcast_output(instruction_ref broadcast)
{
if(broadcast->outputs().size() != 1)
return broadcast;
auto output = broadcast->outputs().front();
if(output->name() == "contiguous")
return get_broadcast_output(output);
return output;
}

MIGRAPHX_PRED_MATCHER(used_once_except_broadcast, instruction_ref ins)
umangyadav marked this conversation as resolved.
Show resolved Hide resolved
{
if(ins->outputs().size() == 1)
return true;
if(ins->outputs().size() == 2)
{
auto is_broadcast = [](instruction_ref output) {
return contains(output->name(), "broadcast");
};
auto broadcast = std::find_if(ins->outputs().begin(), ins->outputs().end(), is_broadcast);
if(broadcast == ins->outputs().end())
return false;
auto non_broadcast =
std::find_if_not(ins->outputs().begin(), ins->outputs().end(), is_broadcast);
if(non_broadcast == ins->outputs().end())
return false;
auto output = get_broadcast_output(*broadcast);
return output == *non_broadcast;
}

return false;
}
} // namespace
template <class... Ms>
static auto match_broadcast(Ms... ms)
{
Expand All @@ -202,12 +238,18 @@ static auto any_input(Ms... ms)

static auto match_broadcastable_input(const std::string& op, const std::string& name)
{
auto match_op = match::name(op)(match::used_once()).bind(name);
auto match_op = match::name(op)(used_once_except_broadcast()).bind(name);
auto match_op_input = any_input(match_op, match::used_once());
auto broadcast_match_op_input = any_input(match_broadcast(match_op), match::used_once());
return match::any_of(match_op_input, broadcast_match_op_input);
}

static void finalize_reduce_module(module_ref m)
{
eliminate_common_subexpression{}.apply(*m);
dead_code_elimination{}.apply(*m);
}

namespace {
struct find_pointwise_reduce
{
Expand Down Expand Up @@ -242,6 +284,7 @@ struct find_pointwise_reduce

// Insert fused_reduce
rm->add_return(insert_module_in_submodule(rm, reduce, map_ins));
finalize_reduce_module(rm);

auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins);
mpm.get_module().replace_instruction(reduce, reduce->get_operator(), new_inputs, {rm});
Expand Down Expand Up @@ -283,6 +326,7 @@ struct find_reduce_pointwise

auto out = insert_ins_in_submodule(rm, pw, map_ins);
rm->replace_return(out);
finalize_reduce_module(rm);

auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins);
mpm.get_module().replace_instruction(pw, reduce->get_operator(), new_inputs, {rm});
Expand Down Expand Up @@ -327,6 +371,7 @@ struct find_reduce_reduce

auto out = insert_module_in_submodule(rm, reduce1, map_ins);
rm->replace_return(out);
finalize_reduce_module(rm);

auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins);
mpm.get_module().replace_instruction(reduce1, reduce1->get_operator(), new_inputs, {rm});
Expand Down
4 changes: 2 additions & 2 deletions src/include/migraphx/dom_info.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -36,7 +36,7 @@ struct module;

struct MIGRAPHX_EXPORT dominator_info
{
bool strictly_dominate(instruction_ref ins1, instruction_ref ins2);
bool strictly_dominate(instruction_ref ins1, instruction_ref ins2) const;

std::unordered_map<instruction_ref, instruction_ref> ins2idom;
};
Expand Down
76 changes: 60 additions & 16 deletions src/rewrite_reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,44 +55,87 @@ struct find_softmax
}
};

struct find_reduce_mean_variance
{
auto matcher() const
{
auto reduce_mean = match::name("reduce_mean");
auto x_minus_mean =
match::name("sub")(match::arg(0)(match::any().bind("x")),
match::arg(1)(match::skip_broadcasts(reduce_mean.bind("mean"))));
auto pow_x_minus_mean =
match::name("pow")(match::arg(0)(x_minus_mean), match::arg(1)(match::has_value(2.0f)));
auto mul_x_minus_mean =
match::name("mul")(match::arg(0)(x_minus_mean), match::arg(1)(x_minus_mean));
auto sqdiff = match::name("sqdiff")(match::either_arg(0, 1)(
match::any().bind("x"), skip_broadcasts(reduce_mean.bind("mean"))));
return reduce_mean(
match::arg(0)(match::any_of(pow_x_minus_mean, mul_x_minus_mean, sqdiff)));
}

void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
auto mean = r.instructions["mean"];

if(ins->get_operator() != mean->get_operator())
return;

if(mean->inputs().front() != x_ins)
return;

auto x2 = m.insert_instruction(ins, make_op("mul"), x_ins, x_ins);
auto mean_x2 = m.insert_instruction(ins, mean->get_operator(), x2);
auto mean_x_2 = m.insert_instruction(ins, make_op("mul"), mean, mean);
m.replace_instruction(ins, make_op("sub"), mean_x2, mean_x_2);
}
};

struct find_reduce_mean
{
auto matcher() const { return match::name("reduce_mean"); }

void apply(module& m, const match::matcher_result& r) const
{
auto reduce_mean = r.result;
auto op = reduce_mean->get_operator().to_value();
auto axes = op["axes"].to_vector<std::int64_t>();
auto input = reduce_mean->inputs().front();
auto ins = r.result;
auto op = ins->get_operator().to_value();
auto axes = op["axes"].to_vector<std::int64_t>();
auto input = ins->inputs().front();

bool is_integral = false;
double max_n = 0;
std::size_t size = 0;
input->get_shape().visit_type([&](auto t) {
is_integral = t.is_integral();
max_n = t.max();
size = t.size();
});

auto n = input->get_shape().elements() / reduce_mean->get_shape().elements();
auto n = input->get_shape().elements() / ins->get_shape().elements();

// avoid overflow (the larger value will be later handled)
if(n >= max_n / 4)
return;
if(n >= max_n / 4 and size < 3)
{
shape::type_t t = is_integral ? shape::int32_type : shape::float_type;
input = m.insert_instruction(ins, make_op("convert", {{"target_type", t}}), input);
}

auto n_literal = m.add_literal(literal{{input->get_shape().type(), {1}}, {n}});
if(is_integral)
{
auto reduce_sum =
m.insert_instruction(reduce_mean, make_op("reduce_sum", {{"axes", axes}}), input);
auto div = insert_common_op(m, reduce_mean, make_op("div"), {reduce_sum, n_literal});
m.replace_instruction(reduce_mean, div);
m.insert_instruction(ins, make_op("reduce_sum", {{"axes", axes}}), input);
auto div = insert_common_op(m, ins, make_op("div"), {reduce_sum, n_literal});
m.replace_instruction(
ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), div);
}
else
{
auto new_input = insert_common_op(m, reduce_mean, make_op("div"), {input, n_literal});
auto reduce_sum = m.insert_instruction(
reduce_mean, make_op("reduce_sum", {{"axes", axes}}), new_input);
m.replace_instruction(reduce_mean, reduce_sum);
auto new_input = insert_common_op(m, ins, make_op("div"), {input, n_literal});
auto reduce_sum =
m.insert_instruction(ins, make_op("reduce_sum", {{"axes", axes}}), new_input);
m.replace_instruction(
ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), reduce_sum);
}
}
};
Expand All @@ -101,7 +144,8 @@ struct find_reduce_mean

void rewrite_reduce::apply(module& m) const
{
match::find_matches(m, find_softmax{}, find_reduce_mean{});
match::find_matches(m, find_softmax{}, find_reduce_mean_variance{});
match::find_matches(m, find_reduce_mean{});
}

} // namespace MIGRAPHX_INLINE_NS
Expand Down
1 change: 1 addition & 0 deletions src/targets/gpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ add_library(migraphx_gpu
nonzero.cpp
pack_args.cpp
prefuse_ops.cpp
prepare_reduce.cpp
perfdb.cpp
pooling.cpp
problem_cache.cpp
Expand Down
58 changes: 42 additions & 16 deletions src/targets/gpu/compile_gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
*/
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/prepare_reduce.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/stringutils.hpp>
Expand Down Expand Up @@ -215,19 +216,21 @@ std::string generate_pointwise(const module& pm, const std::string& name, bool a

std::string reduce_op::str() const
{
return write + "(r.reduce(" + reduction + ", " + init + ", " + read + ")(" + input + "))";
return write + "(r.reduce(" + reduction + ", " + init + ", " + read + ")(" +
join_strings(inputs, ", ") + "))";
}
void reduce_op::set(instruction_ref ins, const operation& op)
void reduce_op::set(const std::string& name, const shape& input, const shape& output)
{
if(op.name() == "reduce_sum")
assert(input.type() != shape::tuple_type);
assert(output.type() != shape::tuple_type);
if(name == "reduce_sum")
{
reduction = "op::sum{}";
}
else if(op.name() == "reduce_mean")
else if(name == "reduce_mean")
{
auto s = ins->inputs().front()->get_shape();
auto reduce_elements = s.elements() / ins->get_shape().elements();
auto reduce_type = s.type();
auto reduce_elements = input.elements() / output.elements();
auto reduce_type = input.type();
reduction = "op::sum{}";
std::string mean = "op::mean<" + std::to_string(reduce_elements) + ">{}";
// Use float accumulator when reduction size is too large for half
Expand All @@ -238,17 +241,17 @@ void reduce_op::set(instruction_ref ins, const operation& op)
else
write = mean;
}
else if(op.name() == "reduce_max")
else if(name == "reduce_max")
{
reduction = "op::max{}";
init = "lowest{}";
}
else if(op.name() == "reduce_min")
else if(name == "reduce_min")
{
reduction = "op::min{}";
init = "highest{}";
}
else if(op.name() == "reduce_prod")
else if(name == "reduce_prod")
{
reduction = "op::product{}";
init = "1";
Expand All @@ -258,7 +261,23 @@ void reduce_op::set(instruction_ref ins, const operation& op)
MIGRAPHX_THROW("Unsupported reduce");
}
}
std::string reduce_op::generate(instruction_ref ins, const std::string& x)

void reduce_op::set(instruction_ref ins, const operation& op)
{
if(op.name() == "gpu::parallel_reduce")
{
auto rop = from_value<operation>(op.to_value().at("op"));
auto input = ins->inputs().front()->get_shape();
auto output = ins->get_shape().sub_shapes().front();
set(rop.name(), input, output);
read = "compose(array_apply(" + read + "), MIGRAPHX_LIFT(make_array))";
}
else
{
set(op.name(), ins->inputs().front()->get_shape(), ins->get_shape());
}
}
std::string reduce_op::generate(instruction_ref ins, const std::vector<std::string>& x)
{
reduce_op r{x};
r.set(ins, ins->get_operator());
Expand Down Expand Up @@ -289,7 +308,7 @@ void preload_params(module& m)
std::string generate_reduce(module m, const std::string& name)
{
preload_params(m);
run_passes(m, {optimize_module{}});
run_passes(m, {optimize_module{}, prepare_reduce{}, optimize_module{}});
m.sort();
cpp_generator g;
auto param_shapes = m.get_parameter_shapes();
Expand All @@ -302,9 +321,9 @@ std::string generate_reduce(module m, const std::string& name)
auto f = g.generate_module(m, [&](instruction_ref ins, const auto& names) {
if(contains(ins->name(), "reduce"))
{
return reduce_op::generate(ins, names.at(ins->inputs().front()));
return reduce_op::generate(ins, cpp_generator::to_args(ins->inputs(), names));
}
else if(ins->name() == "pointwise")
if(ins->name() == "pointwise")
{
TedThemistokleous marked this conversation as resolved.
Show resolved Hide resolved
auto pointwise_name = "pointwise" + std::to_string(i);
i++;
Expand Down Expand Up @@ -346,11 +365,18 @@ std::string generate_reduce(module m, const std::string& name)
{"args", join_strings(args, ", ")},
{"call", call_function}});
}
else if(ins->name() == "multibroadcast")
if(ins->name() == "multibroadcast")
{
return names.at(ins->inputs().front());
}
else if(ins->name() == "identity")
if(ins->name() == "get_tuple_elem")
{
const auto& x = names.at(ins->inputs().front());
auto index = ins->get_operator().to_value()["index"].to<std::size_t>();
return interpolate_string("${x}[${index}]",
{{"x", x}, {"index", std::to_string(index)}});
}
if(ins->name() == "identity")
{
const auto& x = names.at(ins->inputs().front());
return "r.inner(op::id{})(" + x + ")";
Expand Down
Loading
Loading