
Fuse inputs with mlir #3010

Merged
merged 35 commits on Jul 16, 2024
Commits (35):
d39f832  Add fuse mthods to module (pfultz2, Apr 18, 2024)
d2d3bae  Format (pfultz2, Apr 18, 2024)
af83509  Add some initial code (pfultz2, Apr 18, 2024)
ac47954  Format (pfultz2, Apr 18, 2024)
c9407aa  Reuse find_inputs (pfultz2, Apr 18, 2024)
ea41fb9  Format (pfultz2, Apr 18, 2024)
61d788c  Enable with env var (pfultz2, Apr 25, 2024)
cad9d3d  Merge branch 'develop' into mlir-fuse-inputs (pfultz2, Apr 26, 2024)
52a6a0e  Merge branch 'develop' into mlir-fuse-inputs (pfultz2, Apr 29, 2024)
66e9d31  Update comments (pfultz2, May 14, 2024)
9107b26  Format (pfultz2, May 14, 2024)
cbf3afc  Fix param_utils (pfultz2, May 14, 2024)
3947949  Filter supported ops (pfultz2, May 14, 2024)
ffdba3c  Add another comment (pfultz2, May 14, 2024)
6f99033  Merge (pfultz2, May 22, 2024)
48712c9  Handle scalars (pfultz2, May 23, 2024)
c3cf902  Format (pfultz2, May 23, 2024)
28b013e  Merge (pfultz2, Jun 15, 2024)
f9df6fa  Merge branch 'develop' into mlir-fuse-inputs (pfultz2, Jun 18, 2024)
51d3ea9  Add description (pfultz2, Jun 18, 2024)
0e60085  Update src/targets/gpu/fuse_mlir.cpp (causten, Jun 19, 2024)
fd0b7f7  Add doc (pfultz2, Jun 20, 2024)
9c4d659  Add input fusion to jenkins (pfultz2, Jun 20, 2024)
5eaaed3  Add unit test for fuse module (pfultz2, Jun 21, 2024)
1edac2d  Format (pfultz2, Jun 21, 2024)
4a825cb  Merge branch 'develop' into mlir-fuse-inputs (causten, Jun 21, 2024)
1ebdaf1  Add unit test (pfultz2, Jun 21, 2024)
d035f3b  Format (pfultz2, Jun 21, 2024)
271ea78  Add verify test (pfultz2, Jun 21, 2024)
43e76f5  Format (pfultz2, Jun 21, 2024)
b357f94  Fix tidy issue (pfultz2, Jun 21, 2024)
5e848ba  Fix parameter name (pfultz2, Jun 22, 2024)
ee50c26  Merge branch 'develop' into mlir-fuse-inputs (causten, Jul 2, 2024)
5a7f247  Merge branch 'develop' into mlir-fuse-inputs (umangyadav, Jul 10, 2024)
dd7985f  Merge branch 'develop' into mlir-fuse-inputs (umangyadav, Jul 15, 2024)
2 changes: 1 addition & 1 deletion Jenkinsfile
@@ -144,7 +144,7 @@ rocmtest clang_debug: rocmnode('mi100+') { cmake_build ->
}
}, mlir_debug: rocmnode('mi100+') { cmake_build ->
stage('MLIR Debug') {
withEnv(['MIGRAPHX_ENABLE_EXTRA_MLIR=1', 'MIGRAPHX_MLIR_USE_SPECIFIC_OPS=fused,attention,convolution,dot']) {
withEnv(['MIGRAPHX_ENABLE_EXTRA_MLIR=1', 'MIGRAPHX_MLIR_USE_SPECIFIC_OPS=fused,attention,convolution,dot', 'MIGRAPHX_ENABLE_MLIR_INPUT_FUSION=1']) {
def sanitizers = "undefined"
// Note: the -fno-sanitize= is copied from upstream LLVM_UBSAN_FLAGS.
def debug_flags_cxx = "-g -O2 -fsanitize=${sanitizers} -fno-sanitize=vptr,function -fno-sanitize-recover=${sanitizers}"
5 changes: 5 additions & 0 deletions docs/dev/env_vars.rst
@@ -273,6 +273,11 @@ Performs exhaustive tuning for MLIR.
Set to an integer greater than 1.
Limits the number of solutions available to MLIR for tuning.

.. envvar:: MIGRAPHX_ENABLE_MLIR_INPUT_FUSION

Enables input fusions in MLIR.
Set to "1", "enable", "enabled", "yes", or "true" to use.
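The accepted values above amount to a small truthiness check. A minimal standalone sketch; `is_truthy` is a hypothetical helper for illustration only, not MIGraphX's actual environment-variable parser, which may treat casing differently:

```cpp
#include <algorithm>
#include <cassert>
#include <cctype>
#include <string>

// Hypothetical checker for the documented values; MIGraphX's real parsing
// lives in its env-var utilities and may differ in detail.
bool is_truthy(std::string v)
{
    std::transform(v.begin(), v.end(), v.begin(),
                   [](unsigned char c) { return std::tolower(c); });
    return v == "1" or v == "enable" or v == "enabled" or v == "yes" or
           v == "true";
}
```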

CK vars
-----------

130 changes: 26 additions & 104 deletions src/fuse_reduce.cpp
@@ -34,6 +34,7 @@
#include <migraphx/ranges.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/rewrite_reshapes.hpp>
#include <migraphx/param_utils.hpp>
#include <iterator>
#include <map>

@@ -91,93 +92,14 @@ MIGRAPHX_PRED_MATCHER(input_output_ndim_match, instruction_ref ins)
return input_shape.ndim() == output_shape.ndim();
}

static void insert_params(module_ref sm,
const std::vector<instruction_ref>& inputs,
std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{
auto n = sm->get_parameter_shapes().size();
for(auto input : inputs)
{
if(contains(map_ins, input))
continue;
map_ins[input] =
sm->add_parameter("x" + std::to_string(n++), input->get_shape().as_standard());
}
}

static auto insert_ins_in_submodule(module_ref sm,
instruction_ref ins,
std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{
insert_params(sm, ins->inputs(), map_ins);
return sm->add_instructions({ins}, &map_ins);
}

static auto insert_ins_in_submodule(module_ref sm, instruction_ref ins)
{
std::unordered_map<instruction_ref, instruction_ref> map_ins;
return insert_ins_in_submodule(sm, ins, map_ins);
}

static auto
insert_module_in_submodule(module_ref sm,
const std::vector<instruction_ref>& inputs,
module_ref m,
std::unordered_map<instruction_ref, instruction_ref>& map_ins,
module::inserter insert = nullptr)
{
insert_params(sm, inputs, map_ins);
auto param_map = m->get_ins_param_map(inputs);
for(auto&& [input, param] : param_map)
{
map_ins[param] = map_ins.at(input);
}
return sm->add_instructions(m, &map_ins, std::move(insert));
}

static auto
insert_module_in_submodule(module_ref sm,
instruction_ref ins,
std::unordered_map<instruction_ref, instruction_ref>& map_ins,
module::inserter insert = nullptr)
std::unordered_map<instruction_ref, instruction_ref>* map_ins = nullptr,
module::inserter insert = nullptr)
{
return insert_module_in_submodule(
sm, ins->inputs(), ins->module_inputs().front(), map_ins, std::move(insert));
}

static auto insert_module_in_submodule(module_ref sm,
const std::vector<instruction_ref>& inputs,
module_ref m,
module::inserter insert = nullptr)
{
std::unordered_map<instruction_ref, instruction_ref> map_ins;
return insert_module_in_submodule(sm, inputs, m, map_ins, std::move(insert));
}

static std::vector<instruction_ref>
find_inputs(const_module_ref sm,
const module& parent,
const std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{
std::vector<instruction_ref> result;
std::map<std::string, instruction_ref> names;
for(auto&& [input, param] : map_ins)
{
if(not sm->has_instruction(param))
continue;
if(param->name() != "@param")
continue;
if(not parent.has_instruction(input))
continue;
auto v = param->get_operator().to_value();
auto name = v.at("parameter").to<std::string>();
names[name] = input;
}
std::transform(names.begin(), names.end(), std::back_inserter(result), [](const auto& p) {
return p.second;
});
assert(result.size() == sm->get_parameter_shapes().size());
return result;
assert(ins->module_inputs().size() == 1);
return sm->fuse(*ins->module_inputs().front(), ins->inputs(), map_ins, std::move(insert));
}

static void create_reduce_modules(module_pass_manager& mpm)
@@ -194,7 +116,7 @@ static void create_reduce_modules(module_pass_manager& mpm)
mpm.create_module(mpm.get_module().name() + ":" + ins->name() + std::to_string(n++));
rm->set_bypass();

rm->add_return(insert_ins_in_submodule(rm, ins));
rm->add_return(rm->fuse({ins}));
auto v = ins->get_operator().to_value();
mpm.get_module().replace_instruction(
ins, make_op("fused_reduce", {{"axes", v["axes"]}}), ins->inputs(), {rm});
@@ -286,23 +208,23 @@ struct find_pointwise_reduce
rm->set_bypass();
std::unordered_map<instruction_ref, instruction_ref> map_ins;
// Insert pointwise
auto rins = insert_ins_in_submodule(rm, input, map_ins).front();
auto rins = rm->fuse({input}, &map_ins).front();
map_ins[input] = rins;

if(contains(r.instructions, "broadcast"))
{
auto broadcast = r.instructions["broadcast"];
auto fbroadcast = r.instructions["final_broadcast"];
map_ins[broadcast] = insert_ins_in_submodule(rm, broadcast, map_ins).front();
map_ins[broadcast] = rm->fuse({broadcast}, &map_ins).front();
if(fbroadcast != broadcast)
map_ins[fbroadcast] = map_ins[broadcast];
}

// Insert fused_reduce
rm->add_return(insert_module_in_submodule(rm, reduce, map_ins));
rm->add_return(insert_module_in_submodule(rm, reduce, &map_ins));
finalize_reduce_module(rm);

auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins);
auto new_inputs = find_inputs(map_ins, &mpm.get_module(), rm);
mpm.get_module().replace_instruction(reduce, reduce->get_operator(), new_inputs, {rm});
}
};
@@ -327,24 +249,24 @@ struct find_reduce_pointwise
rm->set_bypass();
std::unordered_map<instruction_ref, instruction_ref> map_ins;
// Copy module instructions
insert_module_in_submodule(rm, reduce, map_ins);
insert_module_in_submodule(rm, reduce, &map_ins);
if(contains(r.instructions, "broadcast"))
{
auto broadcast = r.instructions["broadcast"];
map_ins[broadcast->inputs().front()] = rm->get_returns().front();
auto bout = insert_ins_in_submodule(rm, broadcast, map_ins);
auto bout = rm->fuse({broadcast}, &map_ins);
map_ins[input] = bout.front();
}
else
{
map_ins[input] = rm->get_returns().front();
}

auto out = insert_ins_in_submodule(rm, pw, map_ins);
auto out = rm->fuse({pw}, &map_ins);
rm->replace_return(out);
finalize_reduce_module(rm);

auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins);
auto new_inputs = find_inputs(map_ins, &mpm.get_module(), rm);
mpm.get_module().replace_instruction(pw, reduce->get_operator(), new_inputs, {rm});
}
};
@@ -372,24 +294,24 @@ struct find_reduce_reduce

std::unordered_map<instruction_ref, instruction_ref> map_ins;
// Copy reduce1 instructions
insert_module_in_submodule(rm, reduce2, map_ins);
insert_module_in_submodule(rm, reduce2, &map_ins);
if(contains(r.instructions, "broadcast"))
{
auto broadcast = r.instructions["broadcast"];
map_ins[broadcast->inputs().front()] = rm->get_returns().front();
auto bout = insert_ins_in_submodule(rm, broadcast, map_ins);
auto bout = rm->fuse({broadcast}, &map_ins);
map_ins[input] = bout.front();
}
else
{
map_ins[input] = rm->get_returns().front();
}

auto out = insert_module_in_submodule(rm, reduce1, map_ins);
auto out = insert_module_in_submodule(rm, reduce1, &map_ins);
rm->replace_return(out);
finalize_reduce_module(rm);

auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins);
auto new_inputs = find_inputs(map_ins, &mpm.get_module(), rm);
mpm.get_module().replace_instruction(reduce1, reduce1->get_operator(), new_inputs, {rm});
}
};
@@ -429,14 +351,14 @@ struct reduce_reshape : rewrite_reshapes_base
auto* oldm = ins->module_inputs().front();
auto* sm = mpm.create_module(oldm->name() + "_reshape");
sm->set_bypass();
insert_module_in_submodule(sm, inputs, oldm, transform_op([&](const operation& sop) {
if(contains(sop.name(), "reduce"))
return make_op(sop.name(), {{"axes", axes}});
if(sop.name() == "multibroadcast")
return make_op("multibroadcast", {{"out_lens", dims}});
assert(sop.name() == "pointwise");
return sop;
}));
sm->fuse(*oldm, inputs, nullptr, transform_op([&](const operation& sop) {
(Contributor suggested change: insert sm->set_bypass(); immediately before this sm->fuse(...) call.)
if(contains(sop.name(), "reduce"))
return make_op(sop.name(), {{"axes", axes}});
if(sop.name() == "multibroadcast")
return make_op("multibroadcast", {{"out_lens", dims}});
assert(sop.name() == "pointwise");
return sop;
}));
return mpm.get_module().insert_instruction(ins, fused_reduce{axes}, inputs, {sm});
}

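The transform_op callback in reduce_reshape above rewrites reduce axes and multibroadcast output dims while pointwise ops pass through unchanged. The same remapping as a standalone sketch with a toy op struct (toy names, not the migraphx operation API):

```cpp
#include <cassert>
#include <string>
#include <vector>

// Toy operation: a name plus one integer attribute list standing in for
// either reduce axes or multibroadcast output lengths.
struct toy_op
{
    std::string name;
    std::vector<int> attr;
};

// Mirror of the callback: reduces get the reshaped axes, multibroadcast
// gets the reshaped output dims, anything else is returned unchanged.
toy_op remap(const toy_op& sop, const std::vector<int>& axes, const std::vector<int>& dims)
{
    if(sop.name.find("reduce") != std::string::npos)
        return {sop.name, axes};
    if(sop.name == "multibroadcast")
        return {"multibroadcast", dims};
    return sop;
}
```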
15 changes: 15 additions & 0 deletions src/include/migraphx/module.hpp
@@ -245,6 +245,21 @@ struct MIGRAPHX_EXPORT module
const std::vector<instruction_ref>& splits1,
const std::vector<instruction_ref>& splits2) const;

// Fuse the instructions into the module by inserting the instructions and
// parameters for any missing inputs.
(Contributor comment: "Is it obvious to people who know the codebase what the inputs and outputs of these methods are?")
std::vector<instruction_ref>
fuse(const std::vector<instruction_ref>& inss,
     std::unordered_map<instruction_ref, instruction_ref>* map_ins = nullptr,
     inserter insert = nullptr);
(Member comment: "Need some unit-tests")

// Fuse another module into this module by inserting the instructions and
// parameters from the module
std::vector<instruction_ref>
fuse(const module& m,
const std::vector<instruction_ref>& inputs,
std::unordered_map<instruction_ref, instruction_ref>* map_ins = nullptr,
inserter insert = nullptr);

void debug_print() const;
void debug_print(instruction_ref ins) const;
void debug_print(instruction_ref ins,
8 changes: 8 additions & 0 deletions src/include/migraphx/param_utils.hpp
@@ -27,6 +27,7 @@

#include <migraphx/config.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/module_ref.hpp>
#include <vector>
#include <string>

@@ -37,6 +38,13 @@ MIGRAPHX_EXPORT std::string param_name(std::size_t i, const std::string& prefix

void sort_params(std::vector<instruction_ref>& params);

// Find the inputs for a module by finding instructions that are mapped to the
// parameters in the module
std::vector<instruction_ref>
find_inputs(const std::unordered_map<instruction_ref, instruction_ref>& map_ins,
const_module_ref parent,
const_module_ref sub);

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_PARAM_UTILS_HPP
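find_inputs, declared above, recovers the parent-side inputs ordered by the submodule's parameter names. A standalone sketch of that ordering step, with strings standing in for instruction_refs (toy names and types, not the real signature):

```cpp
#include <cassert>
#include <map>
#include <string>
#include <unordered_map>
#include <vector>

// Toy stand-in: map_ins maps a parent instruction (here a string) to the
// submodule parameter name it was bound to. The result lists the parent
// instructions sorted by parameter name, which is the order in which the
// submodule expects its inputs.
std::vector<std::string>
toy_find_inputs(const std::unordered_map<std::string, std::string>& map_ins)
{
    std::map<std::string, std::string> names; // param name -> parent input
    for(const auto& [input, param] : map_ins)
        names[param] = input;
    std::vector<std::string> result;
    result.reserve(names.size());
    for(const auto& p : names)
        result.push_back(p.second);
    return result;
}
```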
57 changes: 57 additions & 0 deletions src/module.cpp
@@ -979,6 +979,63 @@ std::array<module::with_inputs, 3> module::split(const std::vector<instruction_r
return {{std::move(mods1[0]), std::move(mods2[0]), std::move(mods2[1])}};
}

// Insert parameters into the module based on the input instructions and then
// update the map_ins to map the input to the parameter.
static void insert_params(module& m,
const std::vector<instruction_ref>& inputs,
std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{
auto n = m.get_parameter_shapes().size();
for(auto input : inputs)
{
if(contains(map_ins, input))
continue;
map_ins[input] = m.add_parameter(param_name(n++), input->get_shape().as_standard());
}
}

std::vector<instruction_ref>
module::fuse(const std::vector<instruction_ref>& inss,
std::unordered_map<instruction_ref, instruction_ref>* map_ins,
module::inserter insert)
{
std::unordered_map<instruction_ref, instruction_ref> default_map_ins;
if(map_ins == nullptr)
map_ins = &default_map_ins;
std::vector<instruction_ref> inputs;
for(auto ins : inss)
{
for(auto input : ins->inputs())
{
if(contains(inss, input))
continue;
if(contains(inputs, input))
continue;
inputs.push_back(input);
}
}
insert_params(*this, inputs, *map_ins);
return this->add_instructions(inss, map_ins, std::move(insert));
}

std::vector<instruction_ref>
module::fuse(const module& m,
const std::vector<instruction_ref>& inputs,
std::unordered_map<instruction_ref, instruction_ref>* map_ins,
module::inserter insert)
{
std::unordered_map<instruction_ref, instruction_ref> default_map_ins;
if(map_ins == nullptr)
map_ins = &default_map_ins;
insert_params(*this, inputs, *map_ins);
auto param_map = m.get_ins_param_map(inputs);
for(auto&& [input, param] : param_map)
{
(*map_ins)[param] = map_ins->at(input);
}
return this->add_instructions(&m, map_ins, std::move(insert));
}

void module_with_inputs::replace(instruction_ref ins, instruction_ref rep)
{
auto it = std::find(inputs.begin(), inputs.end(), ins);
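The input-collection loop in module::fuse above keeps only values defined outside the fused instruction set, in first-seen order and without duplicates. The same logic as a standalone sketch, with ints standing in for instruction_refs:

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Keep only inputs that are external to the fused set inss, preserving
// first-seen order and skipping duplicates. ins_inputs[i] holds the
// inputs of inss[i].
std::vector<int> external_inputs(const std::vector<int>& inss,
                                 const std::vector<std::vector<int>>& ins_inputs)
{
    auto contains = [](const std::vector<int>& v, int x) {
        return std::find(v.begin(), v.end(), x) != v.end();
    };
    std::vector<int> inputs;
    for(const auto& ii : ins_inputs)
    {
        for(int input : ii)
        {
            if(contains(inss, input))
                continue; // produced inside the fused set
            if(contains(inputs, input))
                continue; // already captured
            inputs.push_back(input);
        }
    }
    return inputs;
}
```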
29 changes: 29 additions & 0 deletions src/param_utils.cpp
@@ -25,6 +25,9 @@
#include <migraphx/param_utils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/builtin.hpp>
#include <migraphx/module.hpp>
#include <migraphx/ranges.hpp>
#include <map>
#include <cmath>

namespace migraphx {
@@ -49,5 +52,31 @@ void sort_params(std::vector<instruction_ref>& params)
}));
}

std::vector<instruction_ref>
find_inputs(const std::unordered_map<instruction_ref, instruction_ref>& map_ins,
(Contributor comment: "This function needs a descriptive comment or no one else will ever be able to use it.")
const_module_ref parent,
const_module_ref sub)
{
std::vector<instruction_ref> result;
std::map<std::string, instruction_ref> names;
for(auto&& [input, param] : map_ins)
{
if(sub != nullptr and not sub->has_instruction(param))
continue;
if(param->name() != "@param")
continue;
if(parent != nullptr and not parent->has_instruction(input))
continue;
auto v = param->get_operator().to_value();
auto name = v.at("parameter").to<std::string>();
names[name] = input;
}
std::transform(names.begin(), names.end(), std::back_inserter(result), [](const auto& p) {
return p.second;
});
assert(not sub or result.size() == sub->get_parameter_shapes().size());
(Resolved conversation on lines +74 to +77)
Member: If sub == nullptr you can just do early return
Author: Early return where?
Member: just when it starts the body of find_inputs()
Author: Then that will skip getting the parameters. It's meant to be optional. If sub is null then it will assume all parameters come from the submodule.
return result;
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
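The resolved thread above explains why a null parent or sub is not an early-return case: a null filter means "accept everything", not "do nothing". A minimal standalone sketch of that convention (toy function, not the migraphx API):

```cpp
#include <cassert>
#include <functional>
#include <vector>

// Toy version of the optional-filter convention used by find_inputs:
// a null predicate accepts every element instead of returning early.
std::vector<int> keep_if(const std::vector<int>& xs,
                         const std::function<bool(int)>& pred)
{
    std::vector<int> out;
    for(int x : xs)
        if(pred == nullptr or pred(x)) // null filter accepts all
            out.push_back(x);
    return out;
}
```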