-
Notifications
You must be signed in to change notification settings - Fork 89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fuse Split-Reduce with MLIR #3319
Changes from 179 commits
7e24411
244d8b8
c53c40a
d39f832
d2d3bae
af83509
ac47954
c9407aa
ea41fb9
15c06b5
3931cfc
0370543
d4db0f6
5b37853
ac747b2
2f7e96c
acb291b
c6a7caa
25442a5
3b04922
1cfa65e
61d788c
cad9d3d
52a6a0e
6533429
36af65c
78161de
166f7c9
6647d4b
66e9d31
9107b26
cbf3afc
3947949
ffdba3c
019bb0d
c33f7fd
86df8f1
7e5babf
6f99033
48712c9
c3cf902
28b013e
d40bbae
f9df6fa
51d3ea9
0e60085
fd0b7f7
43ff58b
9c4d659
5eaaed3
1edac2d
4a825cb
1ebdaf1
d035f3b
8c4b8f0
271ea78
43e76f5
b357f94
fbb630e
a3ff01a
7686c3d
5e848ba
c7ff9a7
3964597
efb1f76
f3b2b95
ee50c26
c0c51c5
5a7f247
ae29e39
470984d
593b119
d49cfe3
c12c6bc
55c3c6d
1f76cc5
a238d2a
e26120b
c8b06d5
64642c9
0149594
886fc1b
04e37ad
e533627
2409622
8a008b6
2a75820
96ac474
7a65f2e
a46bbaa
f4b3211
b88f6bd
518fce3
c9f5201
8a44a13
c984b83
ea3fdb7
329955b
d14cd66
1e981a2
2a1c4cd
9a9c2c4
9c50be6
bb76528
51d3c5f
cb909a4
3f4ef63
374e74b
0f785f0
97e4861
32140c9
daa607c
bca36d8
94d9456
072f8dc
0a2a8d8
0a06260
dff3dd4
ff94e04
9540c78
00fef22
207f94e
f022edb
1397e09
244e62e
e4c9eb9
ba53ce4
64ed1ec
cd8762f
7e356a6
6b81657
a1c5ad7
ca11ca4
9a7aa0b
d3ab2af
2e7c2d8
83fd160
a784df3
5d9fe2a
662a29d
d4dd7af
c1cba50
805793d
74496ba
fd5a9a1
bd1eca3
06f54fa
c5032ff
631127a
e82daf1
eb4f262
4589e09
1ac328b
ddbf8ba
68a8afb
7e83db3
ec3dc3f
c5b70b7
ca7df92
9f56e6a
f1550b1
f276db5
2076920
43a22e5
a4d546d
c64d2ee
40325f9
67ea3c6
8ebbb0e
ece936f
c5c4c72
5b51efd
5e828ee
6e78168
69fef78
70063f9
34c539f
1ebf2a3
102a246
f967f7d
335be33
112b14a
57c550e
b02eb78
86b98aa
df96690
8e0acd0
a5733c5
93d24bf
070da3d
dc71b68
1b68e45
4e043c7
848d807
94e112a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -43,6 +43,7 @@ namespace gpu { | |||||||
|
||||||||
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_EXTRA_MLIR); | ||||||||
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION); | ||||||||
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_REDUCE_FUSION); | ||||||||
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MLIR); | ||||||||
/** | ||||||||
* @brief Declares a new MIGraphX environment variable which forces to generate | ||||||||
|
@@ -386,13 +387,59 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i) | |||||||
return false; | ||||||||
} | ||||||||
|
||||||||
bool is_reduce_op_supported_by_mlir(const instruction& i) | ||||||||
{ | ||||||||
using type_t = shape::type_t; | ||||||||
const auto& name = i.name(); | ||||||||
const auto result_type = i.get_shape().type(); | ||||||||
const std::initializer_list<type_t> allowed_types = { | ||||||||
type_t::float_type, type_t::half_type, type_t::fp8e4m3fnuz_type}; | ||||||||
// Preliminary type check. | ||||||||
if(not contains(allowed_types, result_type)) | ||||||||
{ | ||||||||
return false; | ||||||||
} | ||||||||
const std::initializer_list<std::string> reduce_ops = {"reduce_mean", "reduce_sum"}; | ||||||||
return contains(reduce_ops, i.name()); | ||||||||
} | ||||||||
|
||||||||
// A separate function so we can remove operators that are supported by mlir | ||||||||
// but not supported for an input fusion. | ||||||||
bool is_pointwise_op_supported_by_mlir_for_input(const instruction& i) | ||||||||
{ | ||||||||
return is_pointwise_op_supported_by_mlir(i); | ||||||||
} | ||||||||
|
||||||||
MIGRAPHX_PRED_MATCHER(mlir_split_reduce, instruction_ref ins) | ||||||||
{ | ||||||||
if(ins->name() != "split_fused_reduce") | ||||||||
return false; | ||||||||
auto* mod_arg = ins->module_inputs().front(); | ||||||||
auto supported_reshapes = reshaper_names(); | ||||||||
supported_reshapes.erase("slice"); | ||||||||
std::unordered_set<std::string> builtins = {"@param", "@literal", "@return"}; | ||||||||
for(const auto i : iterator_for(*mod_arg)) | ||||||||
{ | ||||||||
if(is_reduce(*i)) | ||||||||
{ | ||||||||
if(not is_reduce_op_supported_by_mlir(*i)) | ||||||||
return false; | ||||||||
} | ||||||||
else if(i->name() == "pointwise") | ||||||||
{ | ||||||||
if(not std::all_of(i->module_inputs().front()->begin(), | ||||||||
i->module_inputs().front()->end(), | ||||||||
&is_pointwise_op_supported_by_mlir)) | ||||||||
return false; | ||||||||
} | ||||||||
else if(not contains(reshaper_names(), i->name()) and not contains(builtins, i->name())) | ||||||||
{ | ||||||||
return false; | ||||||||
} | ||||||||
} | ||||||||
return true; | ||||||||
} | ||||||||
|
||||||||
MIGRAPHX_PRED_MATCHER(mlir_pointwise, instruction_ref ins) | ||||||||
{ | ||||||||
if(ins->name() != "pointwise") | ||||||||
|
@@ -423,6 +470,100 @@ std::vector<instruction_ref> mlir_contiguous(module_pass_manager& mpm, | |||||||
return result; | ||||||||
} | ||||||||
|
||||||||
struct find_mlir_split_reduce | ||||||||
{ | ||||||||
mlir_mode conv_mode = mlir_mode::none; | ||||||||
mlir_mode dot_mode = mlir_mode::none; | ||||||||
auto matcher() const | ||||||||
{ | ||||||||
auto dot_or_conv = match::name("gpu::mlir_op"); | ||||||||
// TODO: Handle reshapes in between | ||||||||
return mlir_split_reduce()(match::any_of[match::inputs()](dot_or_conv.bind("gemm"))); | ||||||||
} | ||||||||
|
||||||||
void apply(module_pass_manager& mpm, const match::matcher_result& r) const | ||||||||
{ | ||||||||
auto reduce_ins = r.result; | ||||||||
auto gemm_ins = r.instructions["gemm"]; | ||||||||
assert(gemm_ins->get_shape().sub_shapes().empty()); | ||||||||
auto* rm = reduce_ins->module_inputs().front(); | ||||||||
auto names = rm->get_parameter_names(); | ||||||||
std::sort(names.begin(), names.end()); | ||||||||
module_ref gemm_old_mm = gemm_ins->module_inputs().front(); | ||||||||
module_ref mm = | ||||||||
mpm.create_module(gemm_old_mm->name() + "_split_fused_reduce", *gemm_old_mm); | ||||||||
// remove last return instruction | ||||||||
if(std::prev(mm->end())->name() == "@return") | ||||||||
{ | ||||||||
mm->remove_instruction(std::prev(mm->end())); | ||||||||
} | ||||||||
mm->set_bypass(); | ||||||||
std::unordered_map<instruction_ref, instruction_ref> param_map; | ||||||||
param_map[gemm_ins] = std::prev(mm->end()); | ||||||||
bool gemm_has_multi_outs = gemm_ins->outputs().size() > 1; | ||||||||
auto return_vals = | ||||||||
mm->fuse(*rm, | ||||||||
reduce_ins->inputs(), | ||||||||
¶m_map, | ||||||||
[&](module& main_mod, | ||||||||
instruction_ref pos, | ||||||||
const operation& op, | ||||||||
const std::vector<instruction_ref>& inputs, | ||||||||
const std::vector<module_ref>& mod_args) { | ||||||||
if(op.name() == "pointwise") | ||||||||
{ | ||||||||
for(const auto& skip_param : inputs) | ||||||||
{ | ||||||||
if(not contains(param_map, skip_param)) | ||||||||
{ | ||||||||
param_map[skip_param] = | ||||||||
skip_param; // skip adding parameter for inputs of | ||||||||
// pointwise inside split_fused_reduce | ||||||||
} | ||||||||
} | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is this needed? If its not in the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is fusing pointwise module arg inside the fused_mlir_module. Note that pointwise module is submodule to Inputs to pointwise would be available in Line 1047 in ae2b026
But they won't be passed down when "inserter" is invoked. Line 260 in ae2b026
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see thats because you are reusing the outer param_map. I am not sure that is needed though. You could just use another param_map: auto param_map_2 = create_param_map_with_literals(&main_mod, sub_pm, op.compute_shape(to_shapes(inputs), mod_args));
return main_mod.fuse(*sub_pm, inputs, ¶m_map_2).front(); There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. adding params to same There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I also need to add There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
That should never happen. Since the pointwise module is used to generate a single GPU kernel (or function, in the case of fused_reduce), it should never access instructions from the parent scope.
Sure, you can use auto param_map_2 = sub_pm->get_ins_param_map(inputs, true);
auto literal_param_map = create_param_map_with_literals(&main_mod, sub_pm, op.compute_shape(to_shapes(inputs), mod_args));
param_map_2.insert(literal_param_map.begin(), literal_param_map.end());
return main_mod.fuse(*sub_pm, inputs, ¶m_map_2).front(); There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
No, I can't; it will create a map from Line 1004 in ae2b026
I need map from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Made some changes. |
||||||||
auto* sub_pm = mod_args.front(); | ||||||||
auto param_map_2 = create_param_map_with_literals( | ||||||||
&main_mod, sub_pm, op.compute_shape(to_shapes(inputs), mod_args)); | ||||||||
param_map.insert(param_map_2.begin(), param_map_2.end()); | ||||||||
return main_mod.fuse(*sub_pm, inputs, ¶m_map).front(); | ||||||||
} | ||||||||
return main_mod.insert_instruction(pos, op, inputs, mod_args); | ||||||||
}); | ||||||||
if(gemm_has_multi_outs) | ||||||||
{ | ||||||||
return_vals.insert(return_vals.end(), param_map[gemm_ins]); | ||||||||
} | ||||||||
mm->add_return(return_vals); | ||||||||
std::vector<instruction_ref> inputs; | ||||||||
std::copy_if(reduce_ins->inputs().begin(), | ||||||||
reduce_ins->inputs().end(), | ||||||||
std::back_inserter(inputs), | ||||||||
[&](auto input) { return input != gemm_ins; }); | ||||||||
inputs.insert(inputs.end(), gemm_ins->inputs().begin(), gemm_ins->inputs().end()); | ||||||||
if(gemm_has_multi_outs) | ||||||||
{ | ||||||||
auto fused_ins = mpm.get_module().insert_instruction( | ||||||||
reduce_ins, mlir_op{gemm_ins->get_operator()}, mlir_contiguous(mpm, inputs), {mm}); | ||||||||
auto dot_ins = mpm.get_module().insert_instruction( | ||||||||
reduce_ins, | ||||||||
migraphx::make_op("get_tuple_elem", {{"index", return_vals.size() - 1}}), | ||||||||
fused_ins); | ||||||||
|
||||||||
mpm.get_module().replace_instruction(gemm_ins, dot_ins); | ||||||||
for(const auto outs : reduce_ins->outputs()) | ||||||||
{ | ||||||||
assert(outs->get_operator().name() == "get_tuple_elem"); | ||||||||
mpm.get_module().replace_instruction(outs, outs->get_operator(), fused_ins); | ||||||||
} | ||||||||
} | ||||||||
else | ||||||||
{ | ||||||||
mpm.get_module().replace_instruction( | ||||||||
reduce_ins, mlir_op{gemm_ins->get_operator()}, mlir_contiguous(mpm, inputs), {mm}); | ||||||||
} | ||||||||
} | ||||||||
}; | ||||||||
|
||||||||
struct find_mlir_fused_ops | ||||||||
{ | ||||||||
mlir_mode conv_mode = mlir_mode::none; | ||||||||
|
@@ -714,15 +855,25 @@ void fuse_mlir::apply(module_pass_manager& mpm) const | |||||||
mpm, | ||||||||
find_mlir_fused_ops{.conv_mode = get_mode("fused_convolution", mlir_mode::fast), | ||||||||
.dot_mode = get_mode("fused_dot", mlir_mode::fast)}); | ||||||||
|
||||||||
match::find_matches( | ||||||||
mpm, | ||||||||
find_mlir_standalone_convolution_op{get_mode("convolution", mlir_mode::fast)}, | ||||||||
find_mlir_standalone_dot_op{get_mode("dot", mlir_mode::fast)}); | ||||||||
|
||||||||
mpm.run_pass(dead_code_elimination{}); | ||||||||
if(enabled(MIGRAPHX_ENABLE_MLIR_REDUCE_FUSION{})) | ||||||||
{ | ||||||||
match::find_matches( | ||||||||
mpm, | ||||||||
find_mlir_split_reduce{.conv_mode = get_mode("fused_convolution", mlir_mode::fast), | ||||||||
.dot_mode = get_mode("fused_dot", mlir_mode::fast)}); | ||||||||
} | ||||||||
|
||||||||
if(enabled(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION{})) | ||||||||
{ | ||||||||
match::find_matches(mpm, find_pointwise_mlir{}); | ||||||||
} | ||||||||
#else | ||||||||
(void)mpm; | ||||||||
#endif | ||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -50,6 +50,8 @@ struct MIGRAPHX_GPU_EXPORT mlir_code_object | |
std::vector<value> prefill_values = {}; | ||
}; | ||
|
||
MIGRAPHX_GPU_EXPORT bool is_reduce(const instruction& ins); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I dont see this used outside of mlir.cpp. I think it can be removed from the header. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is being used on both |
||
|
||
MIGRAPHX_GPU_EXPORT mlir_code_object compile_mlir(const context& migraphx_ctx, | ||
module m, | ||
const std::vector<shape>& in_shapes, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,11 +21,16 @@ | |
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN | ||
* THE SOFTWARE. | ||
*/ | ||
#include <algorithm> | ||
#include <cstdint> | ||
#include <migraphx/algorithm.hpp> | ||
#include <migraphx/make_op.hpp> | ||
#include <migraphx/stringutils.hpp> | ||
#include <migraphx/dead_code_elimination.hpp> | ||
#include <migraphx/pass_manager.hpp> | ||
#include <migraphx/gpu/mlir.hpp> | ||
#include <mlir-c/Dialect/RockEnums.h> | ||
#include <numeric> | ||
#include <ostream> | ||
|
||
#ifdef MIGRAPHX_MLIR | ||
|
@@ -951,11 +956,60 @@ struct mlir_program | |
std::string sym_name; | ||
}; | ||
|
||
bool is_reduce(const instruction& ins) { return contains(ins.name(), "reduce"); } | ||
|
||
static void rewrite_reduce(module& m) | ||
{ | ||
for(auto i : iterator_for(m)) | ||
{ | ||
if(is_reduce(*i)) | ||
{ | ||
auto reduce_op = i->get_operator().to_value(); | ||
auto reduce_axes = reduce_op["axes"].to_vector<size_t>(); | ||
auto reduce_lens = i->get_shape().lens(); | ||
auto in_shape = i->inputs().front()->get_shape(); | ||
auto in_lens = in_shape.lens(); | ||
assert(in_shape.standard()); | ||
assert(reduce_lens.size() == in_lens.size()); | ||
assert(std::adjacent_find( | ||
reduce_axes.begin(), reduce_axes.end(), [](auto axis_1, auto axis_2) { | ||
return axis_2 - axis_1 > 1; | ||
}) == reduce_axes.end()); | ||
|
||
std::vector<int64_t> new_rsp_dims; | ||
std::vector<int64_t> new_reduce_axes; | ||
for(const auto axis : range(in_shape.ndim())) | ||
{ | ||
if(reduce_lens[axis] == in_lens[axis]) | ||
{ | ||
new_rsp_dims.push_back(in_lens[axis]); | ||
} | ||
else if(new_reduce_axes.empty()) | ||
{ | ||
assert(reduce_lens[axis] == 1); | ||
new_rsp_dims.push_back(-1); | ||
new_reduce_axes.push_back(axis); | ||
} | ||
} | ||
auto rsp_ins = m.insert_instruction( | ||
i, migraphx::make_op("reshape", {{"dims", new_rsp_dims}}), i->inputs().front()); | ||
auto collapsed_reduce = m.insert_instruction( | ||
i, migraphx::make_op("reduce_sum", {{"axes", new_reduce_axes}}), rsp_ins); | ||
auto rsp_back = m.insert_instruction( | ||
i, migraphx::make_op("reshape", {{"dims", reduce_lens}}), collapsed_reduce); | ||
m.replace_instruction(i, rsp_back); | ||
} | ||
} | ||
migraphx::run_passes(m, {migraphx::dead_code_elimination{}}); | ||
} | ||
|
||
bool is_module_fusible(const module& m, const context& migraphx_ctx, const value& solution) | ||
{ | ||
auto mm = m; | ||
rewrite_reduce(mm); | ||
mlir_program mp; | ||
mp.set_gpu_properties(migraphx_ctx); | ||
mp.parse(m); | ||
mp.parse(mm); | ||
mp.run_high_level_pipeline(); | ||
return mlirIsModuleFusible(mp.mmodule.get(), make_mlir_string_ref(*solution.if_string())); | ||
} | ||
|
@@ -988,6 +1042,7 @@ std::string dump_mlir(const module& m, const std::vector<shape>& inputs) | |
mr = &mm; | ||
adjust_param_shapes(mm, inputs); | ||
} | ||
rewrite_reduce(mm); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This wont dump correctly if the inputs are empty. Probably should take the module by value in the function and remove the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Why is that ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
It rewrites the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
mlir_program mp; | ||
mp.parse(*mr); | ||
auto mod_op = mlirModuleGetOperation(mp.mmodule.get()); | ||
|
@@ -1002,6 +1057,7 @@ mlir_code_object compile_mlir(const context& migraphx_ctx, | |
const value& solution) | ||
{ | ||
adjust_param_shapes(m, in_shapes); | ||
rewrite_reduce(m); | ||
const bool trace = enabled(MIGRAPHX_TRACE_MLIR{}); | ||
|
||
static std::mutex mutex; | ||
|
@@ -1081,12 +1137,21 @@ tuning_config get_tuning_config_mlir(const context& migraphx_ctx, | |
bool exhaustive) | ||
{ | ||
adjust_param_shapes(m, inputs); | ||
|
||
rewrite_reduce(m); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should probably make one function(like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I can do that after this PR. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
mlir_program mp; | ||
mp.set_gpu_properties(migraphx_ctx); | ||
mp.parse(m); | ||
auto tc = mp.get_tuning_config(exhaustive); | ||
|
||
std::string problem_config = tc.problem.to<std::string>(); | ||
for(const auto i : iterator_for(m)) | ||
{ | ||
if(starts_with(i->name(), "@")) | ||
{ | ||
continue; | ||
} | ||
problem_config += " " + i->name(); | ||
} | ||
tc.problem = problem_config; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I had these changes to experiment and work around problem_cache issue. Reverting these |
||
const bool trace = enabled(MIGRAPHX_TRACE_MLIR{}); | ||
static std::mutex mutex; | ||
if(trace) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this flipped?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are cases where same input instruction is mapped to multiple parameters.
e.g.
split_fused_reduce(x, y, x)
In those cases, having mapping from
input--> param
would de-duplicate it and only add a single parameter. Later
AMDMIGraphX/src/module.cpp
Line 238 in ae2b026
Here it won't find the parameter in the
map_ins
map, and it would try to add it, which fails.