Rewrite reduce mean/variance (#2883)
Rewrites mean/variance to use reduce_mean(x) and reduce_mean(x*x) so that the two reductions can be fused into a single reduction.
pfultz2 authored Apr 27, 2024
1 parent ee68f72 commit 56d341d
Showing 17 changed files with 756 additions and 48 deletions.
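
The rewrite is based on the standard identity relating variance to the first two moments. With $\mu = \operatorname{mean}(x)$ taken over the reduction axes:

$$\operatorname{mean}\left((x - \mu)^2\right) = \operatorname{mean}(x^2) - \mu^2$$

Since `reduce_mean(x)` and `reduce_mean(x*x)` reduce over the same axes of the same input, both can be produced by one fused reduction pass.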
5 changes: 5 additions & 0 deletions docs/dev/env_vars.rst
@@ -131,6 +131,11 @@ Disables use of the rocMLIR library.
Set to "1", "enable", "enabled", "yes", or "true" to use.
Uses ``hip_copy_to_gpu`` with a new ``literal`` instruction rather than using ``hip_copy_literal{}``.

.. envvar:: MIGRAPHX_DISABLE_LAYERNORM_FUSION

Set to "1", "enable", "enabled", "yes", or "true" to use.
Disables layernorm fusion.

Compilation traces
----------------------

4 changes: 2 additions & 2 deletions src/dom_info.cpp
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -31,7 +31,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

bool dominator_info::strictly_dominate(instruction_ref ins1, instruction_ref ins2)
bool dominator_info::strictly_dominate(instruction_ref ins1, instruction_ref ins2) const
{
if(ins1 == ins2)
return false;
47 changes: 46 additions & 1 deletion src/fuse_reduce.cpp
@@ -24,6 +24,7 @@
#include <migraphx/fuse_reduce.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
@@ -185,6 +186,41 @@ static void create_reduce_modules(module_pass_manager& mpm)
}
}

namespace {

instruction_ref get_broadcast_output(instruction_ref broadcast)
{
if(broadcast->outputs().size() != 1)
return broadcast;
auto output = broadcast->outputs().front();
if(output->name() == "contiguous")
return get_broadcast_output(output);
return output;
}

MIGRAPHX_PRED_MATCHER(used_once_except_broadcast, instruction_ref ins)
{
if(ins->outputs().size() == 1)
return true;
if(ins->outputs().size() == 2)
{
auto is_broadcast = [](instruction_ref output) {
return contains(output->name(), "broadcast");
};
auto broadcast = std::find_if(ins->outputs().begin(), ins->outputs().end(), is_broadcast);
if(broadcast == ins->outputs().end())
return false;
auto non_broadcast =
std::find_if_not(ins->outputs().begin(), ins->outputs().end(), is_broadcast);
if(non_broadcast == ins->outputs().end())
return false;
auto output = get_broadcast_output(*broadcast);
return output == *non_broadcast;
}

return false;
}
} // namespace
template <class... Ms>
static auto match_broadcast(Ms... ms)
{
@@ -202,12 +238,18 @@ static auto any_input(Ms... ms)

static auto match_broadcastable_input(const std::string& op, const std::string& name)
{
auto match_op = match::name(op)(match::used_once()).bind(name);
auto match_op = match::name(op)(used_once_except_broadcast()).bind(name);
auto match_op_input = any_input(match_op, match::used_once());
auto broadcast_match_op_input = any_input(match_broadcast(match_op), match::used_once());
return match::any_of(match_op_input, broadcast_match_op_input);
}

static void finalize_reduce_module(module_ref m)
{
eliminate_common_subexpression{}.apply(*m);
dead_code_elimination{}.apply(*m);
}

namespace {
struct find_pointwise_reduce
{
@@ -242,6 +284,7 @@ struct find_pointwise_reduce

// Insert fused_reduce
rm->add_return(insert_module_in_submodule(rm, reduce, map_ins));
finalize_reduce_module(rm);

auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins);
mpm.get_module().replace_instruction(reduce, reduce->get_operator(), new_inputs, {rm});
@@ -283,6 +326,7 @@ struct find_reduce_pointwise

auto out = insert_ins_in_submodule(rm, pw, map_ins);
rm->replace_return(out);
finalize_reduce_module(rm);

auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins);
mpm.get_module().replace_instruction(pw, reduce->get_operator(), new_inputs, {rm});
@@ -327,6 +371,7 @@ struct find_reduce_reduce

auto out = insert_module_in_submodule(rm, reduce1, map_ins);
rm->replace_return(out);
finalize_reduce_module(rm);

auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins);
mpm.get_module().replace_instruction(reduce1, reduce1->get_operator(), new_inputs, {rm});
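
The new `used_once_except_broadcast` predicate above relaxes `match::used_once()`: a matched op may have a second consumer, provided that consumer is a broadcast (possibly followed by a `contiguous`) that ultimately feeds the same instruction as the direct use — exactly the shape that `x - broadcast(mean(x))` produces. A minimal standalone sketch of the same traversal, using a hypothetical toy graph type (`Node` and `main` are illustrative, not MIGraphX API):

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Toy stand-in for an instruction: a name plus its consumers.
struct Node
{
    std::string name;
    std::vector<Node*> outputs;
};

// Follow a broadcast through any "contiguous" copies to its real consumer.
Node* broadcast_consumer(Node* broadcast)
{
    if(broadcast->outputs.size() != 1)
        return broadcast;
    Node* out = broadcast->outputs.front();
    return out->name == "contiguous" ? broadcast_consumer(out) : out;
}

// True when `ins` has one consumer, or two consumers where one is a
// broadcast that ultimately feeds the other (non-broadcast) consumer.
bool used_once_except_broadcast(Node* ins)
{
    if(ins->outputs.size() == 1)
        return true;
    if(ins->outputs.size() != 2)
        return false;
    auto is_broadcast = [](Node* n) { return n->name.find("broadcast") != std::string::npos; };
    auto b  = std::find_if(ins->outputs.begin(), ins->outputs.end(), is_broadcast);
    auto nb = std::find_if_not(ins->outputs.begin(), ins->outputs.end(), is_broadcast);
    if(b == ins->outputs.end() or nb == ins->outputs.end())
        return false;
    return broadcast_consumer(*b) == *nb;
}

int main()
{
    // mean feeds sub directly and again via a broadcast: accepted.
    Node sub{"sub", {}};
    Node bcast{"multibroadcast", {&sub}};
    Node mean{"reduce_mean", {&bcast, &sub}};
    return used_once_except_broadcast(&mean) ? 0 : 1;
}
```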
4 changes: 2 additions & 2 deletions src/include/migraphx/dom_info.hpp
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -36,7 +36,7 @@ struct module;

struct MIGRAPHX_EXPORT dominator_info
{
bool strictly_dominate(instruction_ref ins1, instruction_ref ins2);
bool strictly_dominate(instruction_ref ins1, instruction_ref ins2) const;

std::unordered_map<instruction_ref, instruction_ref> ins2idom;
};
76 changes: 60 additions & 16 deletions src/rewrite_reduce.cpp
@@ -55,44 +55,87 @@ struct find_softmax
}
};

struct find_reduce_mean_variance
{
auto matcher() const
{
auto reduce_mean = match::name("reduce_mean");
auto x_minus_mean =
match::name("sub")(match::arg(0)(match::any().bind("x")),
match::arg(1)(match::skip_broadcasts(reduce_mean.bind("mean"))));
auto pow_x_minus_mean =
match::name("pow")(match::arg(0)(x_minus_mean), match::arg(1)(match::has_value(2.0f)));
auto mul_x_minus_mean =
match::name("mul")(match::arg(0)(x_minus_mean), match::arg(1)(x_minus_mean));
auto sqdiff = match::name("sqdiff")(match::either_arg(0, 1)(
match::any().bind("x"), skip_broadcasts(reduce_mean.bind("mean"))));
return reduce_mean(
match::arg(0)(match::any_of(pow_x_minus_mean, mul_x_minus_mean, sqdiff)));
}

void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
auto mean = r.instructions["mean"];

if(ins->get_operator() != mean->get_operator())
return;

if(mean->inputs().front() != x_ins)
return;

auto x2 = m.insert_instruction(ins, make_op("mul"), x_ins, x_ins);
auto mean_x2 = m.insert_instruction(ins, mean->get_operator(), x2);
auto mean_x_2 = m.insert_instruction(ins, make_op("mul"), mean, mean);
m.replace_instruction(ins, make_op("sub"), mean_x2, mean_x_2);
}
};

struct find_reduce_mean
{
auto matcher() const { return match::name("reduce_mean"); }

void apply(module& m, const match::matcher_result& r) const
{
auto reduce_mean = r.result;
auto op = reduce_mean->get_operator().to_value();
auto axes = op["axes"].to_vector<std::int64_t>();
auto input = reduce_mean->inputs().front();
auto ins = r.result;
auto op = ins->get_operator().to_value();
auto axes = op["axes"].to_vector<std::int64_t>();
auto input = ins->inputs().front();

bool is_integral = false;
double max_n = 0;
std::size_t size = 0;
input->get_shape().visit_type([&](auto t) {
is_integral = t.is_integral();
max_n = t.max();
size = t.size();
});

auto n = input->get_shape().elements() / reduce_mean->get_shape().elements();
auto n = input->get_shape().elements() / ins->get_shape().elements();

// avoid overflow (the larger value will be later handled)
if(n >= max_n / 4)
return;
if(n >= max_n / 4 and size < 3)
{
shape::type_t t = is_integral ? shape::int32_type : shape::float_type;
input = m.insert_instruction(ins, make_op("convert", {{"target_type", t}}), input);
}

auto n_literal = m.add_literal(literal{{input->get_shape().type(), {1}}, {n}});
if(is_integral)
{
auto reduce_sum =
m.insert_instruction(reduce_mean, make_op("reduce_sum", {{"axes", axes}}), input);
auto div = insert_common_op(m, reduce_mean, make_op("div"), {reduce_sum, n_literal});
m.replace_instruction(reduce_mean, div);
m.insert_instruction(ins, make_op("reduce_sum", {{"axes", axes}}), input);
auto div = insert_common_op(m, ins, make_op("div"), {reduce_sum, n_literal});
m.replace_instruction(
ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), div);
}
else
{
auto new_input = insert_common_op(m, reduce_mean, make_op("div"), {input, n_literal});
auto reduce_sum = m.insert_instruction(
reduce_mean, make_op("reduce_sum", {{"axes", axes}}), new_input);
m.replace_instruction(reduce_mean, reduce_sum);
auto new_input = insert_common_op(m, ins, make_op("div"), {input, n_literal});
auto reduce_sum =
m.insert_instruction(ins, make_op("reduce_sum", {{"axes", axes}}), new_input);
m.replace_instruction(
ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), reduce_sum);
}
}
};
@@ -101,7 +144,8 @@ struct find_reduce_mean

void rewrite_reduce::apply(module& m) const
{
match::find_matches(m, find_softmax{}, find_reduce_mean{});
match::find_matches(m, find_softmax{}, find_reduce_mean_variance{});
match::find_matches(m, find_reduce_mean{});
}

} // namespace MIGRAPHX_INLINE_NS
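
A quick numerical check of the two rewrites above, in plain self-contained C++ (not MIGraphX code): variance computed as `mean(x*x) - mean(x)^2` matches the direct definition, and the float path's divide-before-sum ordering yields the same mean while keeping the accumulator small for low-precision types.

```cpp
#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    const std::vector<double> x = {1, 2, 3, 4, 5, 6, 7};
    const double n = x.size();

    double mean = 0, mean_x2 = 0;
    for(double v : x)
    {
        mean    += v / n;       // divide-then-sum, as in the float path above
        mean_x2 += (v * v) / n; // reduce_mean(x*x) in the same pass over the data
    }

    // Direct definition: mean((x - mean)^2)
    double var_direct = 0;
    for(double v : x)
        var_direct += (v - mean) * (v - mean) / n;

    // Rewritten form: mean(x*x) - mean(x)^2
    const double var_rewritten = mean_x2 - mean * mean;

    std::printf("direct=%f rewritten=%f\n", var_direct, var_rewritten);
    assert(std::fabs(var_direct - var_rewritten) < 1e-9);
    return 0;
}
```

One caveat worth noting: the rewritten form can lose precision through cancellation when the variance is small relative to the mean, the usual trade-off for this one-pass formulation.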
1 change: 1 addition & 0 deletions src/targets/gpu/CMakeLists.txt
@@ -143,6 +143,7 @@ add_library(migraphx_gpu
nonzero.cpp
pack_args.cpp
prefuse_ops.cpp
prepare_reduce.cpp
perfdb.cpp
pooling.cpp
problem_cache.cpp
58 changes: 42 additions & 16 deletions src/targets/gpu/compile_gen.cpp
@@ -23,6 +23,7 @@
*/
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/prepare_reduce.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/stringutils.hpp>
@@ -215,19 +216,21 @@ std::string generate_pointwise(const module& pm, const std::string& name, bool a

std::string reduce_op::str() const
{
return write + "(r.reduce(" + reduction + ", " + init + ", " + read + ")(" + input + "))";
return write + "(r.reduce(" + reduction + ", " + init + ", " + read + ")(" +
join_strings(inputs, ", ") + "))";
}
void reduce_op::set(instruction_ref ins, const operation& op)
void reduce_op::set(const std::string& name, const shape& input, const shape& output)
{
if(op.name() == "reduce_sum")
assert(input.type() != shape::tuple_type);
assert(output.type() != shape::tuple_type);
if(name == "reduce_sum")
{
reduction = "op::sum{}";
}
else if(op.name() == "reduce_mean")
else if(name == "reduce_mean")
{
auto s = ins->inputs().front()->get_shape();
auto reduce_elements = s.elements() / ins->get_shape().elements();
auto reduce_type = s.type();
auto reduce_elements = input.elements() / output.elements();
auto reduce_type = input.type();
reduction = "op::sum{}";
std::string mean = "op::mean<" + std::to_string(reduce_elements) + ">{}";
// Use float accumulator when reduction size is too large for half
Expand All @@ -238,17 +241,17 @@ void reduce_op::set(instruction_ref ins, const operation& op)
else
write = mean;
}
else if(op.name() == "reduce_max")
else if(name == "reduce_max")
{
reduction = "op::max{}";
init = "lowest{}";
}
else if(op.name() == "reduce_min")
else if(name == "reduce_min")
{
reduction = "op::min{}";
init = "highest{}";
}
else if(op.name() == "reduce_prod")
else if(name == "reduce_prod")
{
reduction = "op::product{}";
init = "1";
Expand All @@ -258,7 +261,23 @@ void reduce_op::set(instruction_ref ins, const operation& op)
MIGRAPHX_THROW("Unsupported reduce");
}
}
std::string reduce_op::generate(instruction_ref ins, const std::string& x)

void reduce_op::set(instruction_ref ins, const operation& op)
{
if(op.name() == "gpu::parallel_reduce")
{
auto rop = from_value<operation>(op.to_value().at("op"));
auto input = ins->inputs().front()->get_shape();
auto output = ins->get_shape().sub_shapes().front();
set(rop.name(), input, output);
read = "compose(array_apply(" + read + "), MIGRAPHX_LIFT(make_array))";
}
else
{
set(op.name(), ins->inputs().front()->get_shape(), ins->get_shape());
}
}
std::string reduce_op::generate(instruction_ref ins, const std::vector<std::string>& x)
{
reduce_op r{x};
r.set(ins, ins->get_operator());
@@ -289,7 +308,7 @@ void preload_params(module& m)
std::string generate_reduce(module m, const std::string& name)
{
preload_params(m);
run_passes(m, {optimize_module{}});
run_passes(m, {optimize_module{}, prepare_reduce{}, optimize_module{}});
m.sort();
cpp_generator g;
auto param_shapes = m.get_parameter_shapes();
@@ -302,9 +321,9 @@ std::string generate_reduce(module m, const std::string& name)
auto f = g.generate_module(m, [&](instruction_ref ins, const auto& names) {
if(contains(ins->name(), "reduce"))
{
return reduce_op::generate(ins, names.at(ins->inputs().front()));
return reduce_op::generate(ins, cpp_generator::to_args(ins->inputs(), names));
}
else if(ins->name() == "pointwise")
if(ins->name() == "pointwise")
{
auto pointwise_name = "pointwise" + std::to_string(i);
i++;
@@ -346,11 +365,18 @@
{"args", join_strings(args, ", ")},
{"call", call_function}});
}
else if(ins->name() == "multibroadcast")
if(ins->name() == "multibroadcast")
{
return names.at(ins->inputs().front());
}
else if(ins->name() == "identity")
if(ins->name() == "get_tuple_elem")
{
const auto& x = names.at(ins->inputs().front());
auto index = ins->get_operator().to_value()["index"].to<std::size_t>();
return interpolate_string("${x}[${index}]",
{{"x", x}, {"index", std::to_string(index)}});
}
if(ins->name() == "identity")
{
const auto& x = names.at(ins->inputs().front());
return "r.inner(op::id{})(" + x + ")";
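
With `prepare_reduce` in the pipeline, the two `reduce_mean` calls can be merged into a single `gpu::parallel_reduce` whose tuple result is unpacked by the `get_tuple_elem` case above. Conceptually, the fused kernel accumulates both sums in one pass over the input; a toy single-threaded C++ sketch of that data flow (illustrative only, not the generated kernel):

```cpp
#include <array>
#include <cstdio>
#include <vector>

int main()
{
    const std::vector<float> x = {1, 2, 3, 4, 5, 6, 7, 8};
    const float n = x.size();

    // One pass yields both reductions, like parallel_reduce's tuple output:
    // acc[0] ~ reduce_mean(x), acc[1] ~ reduce_mean(x*x)
    std::array<float, 2> acc = {0.0f, 0.0f};
    for(float v : x)
    {
        acc[0] += v / n;
        acc[1] += (v * v) / n;
    }

    const float mean     = acc[0];               // get_tuple_elem index 0
    const float variance = acc[1] - mean * mean; // pointwise epilogue
    std::printf("mean=%f variance=%f\n", mean, variance);
    return 0;
}
```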