Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fuse Split-Reduce with MLIR #3319

Merged
merged 195 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from 179 commits
Commits
Show all changes
195 commits
Select commit Hold shift + click to select a range
7e24411
Add atomic ops
pfultz2 Apr 15, 2024
244d8b8
Add missing header
pfultz2 Apr 15, 2024
c53c40a
Add support for half type
pfultz2 Apr 15, 2024
d39f832
Add fuse mthods to module
pfultz2 Apr 18, 2024
d2d3bae
Format
pfultz2 Apr 18, 2024
af83509
Add some initial code
pfultz2 Apr 18, 2024
ac47954
Format
pfultz2 Apr 18, 2024
c9407aa
Reuse find_inputs
pfultz2 Apr 18, 2024
ea41fb9
Format
pfultz2 Apr 18, 2024
15c06b5
Merge branch 'develop' into split-reduce2
pfultz2 Apr 20, 2024
3931cfc
Handle two reductions
pfultz2 Apr 20, 2024
0370543
Format
pfultz2 Apr 20, 2024
d4db0f6
Handle multi outputs in split reduce
pfultz2 Apr 20, 2024
5b37853
Format
pfultz2 Apr 20, 2024
ac747b2
Split two reductions
pfultz2 Apr 20, 2024
2f7e96c
Format
pfultz2 Apr 20, 2024
acb291b
Merge
pfultz2 Apr 24, 2024
c6a7caa
Add split fix
pfultz2 Apr 25, 2024
25442a5
Fix bug with live instruction after split
pfultz2 Apr 25, 2024
3b04922
Format
pfultz2 Apr 25, 2024
1cfa65e
Remove debug prints
pfultz2 Apr 25, 2024
61d788c
Enable with env var
pfultz2 Apr 25, 2024
cad9d3d
Merge branch 'develop' into mlir-fuse-inputs
pfultz2 Apr 26, 2024
52a6a0e
Merge branch 'develop' into mlir-fuse-inputs
pfultz2 Apr 29, 2024
6533429
Fix merge conflict
pfultz2 May 7, 2024
36af65c
Merge branch 'develop' into split-reduce2
pfultz2 May 13, 2024
78161de
Use reaches
pfultz2 May 13, 2024
166f7c9
Merge branch 'develop' into split-reduce2
pfultz2 May 14, 2024
6647d4b
Remvoe dominator
pfultz2 May 14, 2024
66e9d31
Update comments
pfultz2 May 14, 2024
9107b26
Format
pfultz2 May 14, 2024
cbf3afc
Fix param_utils
pfultz2 May 14, 2024
3947949
Filter supported ops
pfultz2 May 14, 2024
ffdba3c
Add another comment
pfultz2 May 14, 2024
019bb0d
Add test for multi out split reduce
pfultz2 May 15, 2024
c33f7fd
Format
pfultz2 May 15, 2024
86df8f1
Add dominator back
pfultz2 May 16, 2024
7e5babf
Format
pfultz2 May 16, 2024
6f99033
Merge
pfultz2 May 22, 2024
48712c9
Handle scalars
pfultz2 May 23, 2024
c3cf902
Format
pfultz2 May 23, 2024
28b013e
Merge
pfultz2 Jun 15, 2024
d40bbae
Merge branch 'develop' into split-reduce2
pfultz2 Jun 15, 2024
f9df6fa
Merge branch 'develop' into mlir-fuse-inputs
pfultz2 Jun 18, 2024
51d3ea9
Add description
pfultz2 Jun 18, 2024
0e60085
Update src/targets/gpu/fuse_mlir.cpp
causten Jun 19, 2024
fd0b7f7
Add doc
pfultz2 Jun 20, 2024
43ff58b
Merge branch 'develop' into split-reduce2
pfultz2 Jun 20, 2024
9c4d659
Add input fusion to jenkins
pfultz2 Jun 20, 2024
5eaaed3
Add unit test for fuse module
pfultz2 Jun 21, 2024
1edac2d
Format
pfultz2 Jun 21, 2024
4a825cb
Merge branch 'develop' into mlir-fuse-inputs
causten Jun 21, 2024
1ebdaf1
Add unit test
pfultz2 Jun 21, 2024
d035f3b
Format
pfultz2 Jun 21, 2024
8c4b8f0
Rename type
pfultz2 Jun 21, 2024
271ea78
Add verify test
pfultz2 Jun 21, 2024
43e76f5
Format
pfultz2 Jun 21, 2024
b357f94
Fix tidy issue
pfultz2 Jun 21, 2024
fbb630e
Fix tidy
pfultz2 Jun 22, 2024
a3ff01a
Format
pfultz2 Jun 22, 2024
7686c3d
Merge branch 'develop' into split-reduce2
pfultz2 Jun 22, 2024
5e848ba
Fix parameter name
pfultz2 Jun 22, 2024
c7ff9a7
Merge branch 'split-reduce2' of github.com:ROCmSoftwarePlatform/AMDMI…
pfultz2 Jun 22, 2024
3964597
Add line
pfultz2 Jun 22, 2024
efb1f76
Format
pfultz2 Jun 22, 2024
f3b2b95
Merge branch 'develop' into split-reduce2
causten Jun 26, 2024
ee50c26
Merge branch 'develop' into mlir-fuse-inputs
causten Jul 2, 2024
c0c51c5
Merge branch 'develop' into split-reduce2
causten Jul 2, 2024
5a7f247
Merge branch 'develop' into mlir-fuse-inputs
umangyadav Jul 10, 2024
ae29e39
add reshapes to fused mlir
umangyadav Jul 10, 2024
470984d
use fuse instead of fold_pointwise
umangyadav Jul 11, 2024
593b119
Passes make check
umangyadav Jul 11, 2024
d49cfe3
pull in changes for find_dot_slice
umangyadav Jul 12, 2024
c12c6bc
add unittest
umangyadav Jul 15, 2024
55c3c6d
add verify test
umangyadav Jul 15, 2024
1f76cc5
debugging
umangyadav Jul 16, 2024
a238d2a
add lowering for contiguous
umangyadav Jul 16, 2024
e26120b
use input_rep_map
umangyadav Jul 16, 2024
c8b06d5
add eliminate_contiguous
umangyadav Jul 16, 2024
64642c9
Add lowering for reshape
umangyadav Jul 16, 2024
0149594
Merge branch 'develop' into mlir-reshape
umangyadav Jul 16, 2024
886fc1b
Fix cppcheck
umangyadav Jul 16, 2024
04e37ad
fix tidy
umangyadav Jul 16, 2024
e533627
Merge branch 'develop' into mlir-reshape
umangyadav Jul 16, 2024
2409622
fixes
umangyadav Jul 16, 2024
8a008b6
rename test file
umangyadav Jul 16, 2024
2a75820
formatting
umangyadav Jul 16, 2024
96ac474
fix SLES
umangyadav Jul 16, 2024
7a65f2e
Merge branch 'develop' into mlir-reshape
umangyadav Jul 16, 2024
a46bbaa
fix test
umangyadav Jul 18, 2024
f4b3211
Merge branch 'develop' into mlir-reshape
umangyadav Jul 18, 2024
b88f6bd
use anonymous namespace
umangyadav Jul 19, 2024
518fce3
Merge branch 'develop' into mlir-reshape
umangyadav Jul 23, 2024
c9f5201
multi use case
umangyadav Jul 22, 2024
8a44a13
fix replace
umangyadav Jul 22, 2024
c984b83
clean up
umangyadav Jul 23, 2024
ea3fdb7
add test
umangyadav Jul 23, 2024
329955b
add multi use case
umangyadav Jul 23, 2024
d14cd66
revert test change
umangyadav Jul 23, 2024
1e981a2
add verify test
umangyadav Jul 23, 2024
2a1c4cd
fix return
umangyadav Jul 23, 2024
9a9c2c4
Foramtting
umangyadav Jul 23, 2024
9c50be6
Merge branch 'add_multi_use' into mlir-split-reduce
umangyadav Jul 23, 2024
bb76528
Add missing elipsis
pfultz2 Jul 23, 2024
51d3c5f
Add licenses
pfultz2 Jul 23, 2024
cb909a4
Format
pfultz2 Jul 23, 2024
3f4ef63
split-reduce fusion working
umangyadav Jul 24, 2024
374e74b
Merge branch 'develop' into split-reduce2
pfultz2 Jul 24, 2024
0f785f0
Update test/split_reduce.cpp
pfultz2 Jul 24, 2024
97e4861
Update test/split_reduce.cpp
pfultz2 Jul 24, 2024
32140c9
Fix test
pfultz2 Jul 24, 2024
daa607c
Format
pfultz2 Jul 24, 2024
bca36d8
Merge branch 'split-reduce2' of github.com:ROCmSoftwarePlatform/AMDMI…
pfultz2 Jul 24, 2024
94d9456
refactor pieces
umangyadav Jul 25, 2024
072f8dc
formatting
umangyadav Jul 25, 2024
0a2a8d8
renamed
umangyadav Jul 25, 2024
0a06260
refactor
umangyadav Jul 25, 2024
dff3dd4
remove debug
umangyadav Jul 25, 2024
ff94e04
add logic for checking is mlir_split_reduce
umangyadav Jul 25, 2024
9540c78
add logic for is_reduce in header files
umangyadav Jul 25, 2024
00fef22
Format
pfultz2 Jul 25, 2024
207f94e
add TODO
umangyadav Jul 25, 2024
f022edb
add assert
umangyadav Jul 25, 2024
1397e09
Merge branch 'develop' into mlir-reshape
umangyadav Jul 25, 2024
244e62e
remove else
umangyadav Jul 26, 2024
e4c9eb9
remove else
umangyadav Jul 26, 2024
ba53ce4
Merge branch 'develop' into split-reduce2
causten Jul 26, 2024
64ed1ec
Merge branch 'develop' into mlir-reshape
umangyadav Jul 29, 2024
cd8762f
Merge branch 'mlir-reshape' into add_multi_use
umangyadav Jul 29, 2024
7e356a6
Merge branch 'develop' into mlir-reshape
umangyadav Jul 29, 2024
6b81657
Merge branch 'mlir-reshape' into add_multi_use
umangyadav Jul 29, 2024
a1c5ad7
use mlir for the reshapes
umangyadav Jul 30, 2024
ca11ca4
fuse reshapes with dot
umangyadav Jul 30, 2024
9a7aa0b
remove header
umangyadav Jul 30, 2024
d3ab2af
remove changes for module split
umangyadav Jul 30, 2024
2e7c2d8
Merge branch 'mlir-reshape' into add_multi_use
umangyadav Jul 30, 2024
83fd160
flatten outputs
umangyadav Jul 30, 2024
a784df3
Merge branch 'add_multi_use' into mlir-split-reduce
umangyadav Jul 30, 2024
5d9fe2a
Merge branch 'split-reduce2' into mlir-split-reduce
umangyadav Jul 30, 2024
662a29d
disable test
umangyadav Jul 30, 2024
d4dd7af
remove TODO
umangyadav Jul 30, 2024
c1cba50
Update TODO
pfultz2 Jul 30, 2024
805793d
add verify test
umangyadav Jul 30, 2024
74496ba
Merge branch 'develop' into split-reduce2
umangyadav Jul 30, 2024
fd5a9a1
increase reduce limite, disable rewrite_reduce to reduce_sum
umangyadav Jul 30, 2024
bd1eca3
Get correct data type for lane reductions
pfultz2 Jul 30, 2024
06f54fa
Merge remote-tracking branch 'origin/lane-parallel-reduce' into mlir-…
umangyadav Jul 30, 2024
c5032ff
Merge remote-tracking branch 'origin/split-reduce2' into mlir-split-r…
umangyadav Jul 30, 2024
631127a
enable test again
umangyadav Jul 30, 2024
e82daf1
revert back split size
umangyadav Jul 31, 2024
eb4f262
add MIGRAPHX_EXPORT For the reaches
umangyadav Jul 31, 2024
4589e09
Merge branch 'split-reduce2' into mlir-split-reduce
umangyadav Jul 31, 2024
1ac328b
add test for the MLIR slow bench
umangyadav Jul 31, 2024
ddbf8ba
Merge branch 'develop' into add_multi_use
umangyadav Jul 31, 2024
68a8afb
fix merge
umangyadav Jul 31, 2024
7e83db3
fix unit-test
umangyadav Jul 31, 2024
ec3dc3f
Merge branch 'add_multi_use' into mlir-split-reduce
umangyadav Jul 31, 2024
c5b70b7
merge fixes
umangyadav Jul 31, 2024
ca7df92
fix return bug enable rewrite_reduce
umangyadav Jul 31, 2024
9f56e6a
fix wiring
umangyadav Jul 31, 2024
f1550b1
fix output shape
umangyadav Jul 31, 2024
f276db5
remove debug prints
umangyadav Jul 31, 2024
2076920
add env flag for the reduce fusion
umangyadav Aug 1, 2024
43a22e5
add doc
umangyadav Aug 1, 2024
a4d546d
formatting
umangyadav Aug 1, 2024
c64d2ee
fix cppcheck
umangyadav Aug 1, 2024
40325f9
update problem_key && jenkins
umangyadav Aug 1, 2024
67ea3c6
change EPS
umangyadav Aug 1, 2024
8ebbb0e
Merge remote-tracking branch 'origin/develop' into mlir-split-reduce
umangyadav Aug 1, 2024
ece936f
Merge branch 'develop' into mlir-split-reduce
umangyadav Aug 12, 2024
c5c4c72
Merge branch 'develop' into mlir-split-reduce
umangyadav Aug 12, 2024
5b51efd
merge fixes
umangyadav Aug 12, 2024
5e828ee
fix tidy
umangyadav Aug 12, 2024
6e78168
Merge branch 'develop' into mlir-split-reduce
umangyadav Aug 12, 2024
69fef78
change EPS For half and fp8
umangyadav Aug 13, 2024
70063f9
Merge branch 'develop' into mlir-split-reduce
umangyadav Aug 13, 2024
34c539f
Merge branch 'develop' into mlir-split-reduce
umangyadav Aug 14, 2024
1ebf2a3
address review comments
umangyadav Aug 14, 2024
102a246
formattimg
umangyadav Aug 14, 2024
f967f7d
Merge remote-tracking branch 'origin/develop' into mlir-split-reduce
umangyadav Aug 16, 2024
335be33
address review comments, add dump_mlir test
umangyadav Aug 16, 2024
112b14a
formatting
umangyadav Aug 16, 2024
57c550e
fix typo
umangyadav Aug 16, 2024
b02eb78
fix tidy
umangyadav Aug 16, 2024
86b98aa
add test
umangyadav Aug 16, 2024
df96690
add reduce.hpp header
umangyadav Aug 16, 2024
8e0acd0
add multi use unit-test
umangyadav Aug 16, 2024
a5733c5
fix licensing
umangyadav Aug 16, 2024
93d24bf
Merge branch 'develop' into mlir-split-reduce
umangyadav Aug 16, 2024
070da3d
revert problem_key changes
umangyadav Aug 16, 2024
dc71b68
add one more test
umangyadav Aug 16, 2024
1b68e45
use auto_add_return
umangyadav Aug 16, 2024
4e043c7
use `insert_inline()`
umangyadav Aug 17, 2024
848d807
fix cppcheck
umangyadav Aug 17, 2024
94e112a
Formatting
umangyadav Aug 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ rocmtest clang_debug: rocmnode('mi100+') { cmake_build ->
}
}, mlir_debug: rocmnode('mi100+') { cmake_build ->
stage('MLIR Debug') {
withEnv(['MIGRAPHX_ENABLE_EXTRA_MLIR=1', 'MIGRAPHX_MLIR_USE_SPECIFIC_OPS=fused,attention,convolution,dot', 'MIGRAPHX_ENABLE_MLIR_INPUT_FUSION=1', 'MIGRAPHX_MLIR_ENABLE_SPLITK=1']) {
withEnv(['MIGRAPHX_ENABLE_EXTRA_MLIR=1', 'MIGRAPHX_MLIR_USE_SPECIFIC_OPS=fused,attention,convolution,dot', 'MIGRAPHX_ENABLE_MLIR_INPUT_FUSION=1', 'MIGRAPHX_MLIR_ENABLE_SPLITK=1', 'MIGRAPHX_ENABLE_MLIR_REDUCE_FUSION=1', 'MIGRAPHX_ENABLE_SPLIT_REDUCE=1','MIGRAPHX_DISABLE_LAYERNORM_FUSION=1']) {
def sanitizers = "undefined"
// Note: the -fno-sanitize= is copied from upstream LLVM_UBSAN_FLAGS.
def debug_flags_cxx = "-g -O2 -fsanitize=${sanitizers} -fno-sanitize=vptr,function -fno-sanitize-recover=${sanitizers}"
Expand Down
5 changes: 5 additions & 0 deletions docs/dev/env_vars.rst
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,11 @@ Limits the number of solutions available to MLIR for tuning.
Set to "1", "enable", "enabled", "yes", or "true" to use.
Enable input fusions in MLIR.

.. envvar:: MIGRAPHX_ENABLE_MLIR_REDUCE_FUSION

Set to "1", "enable", "enabled", "yes", or "true" to use.
Enable reduction fusions in MLIR.

.. envvar:: MIGRAPHX_MLIR_ENABLE_SPLITK

Set to "1", "enable", "enabled", "yes", or "true" to use.
Expand Down
4 changes: 2 additions & 2 deletions src/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1041,8 +1041,8 @@ module::fuse(const module& m,
if(map_ins == nullptr)
map_ins = &default_map_ins;
insert_params(*this, inputs, *map_ins);
auto param_map = m.get_ins_param_map(inputs);
for(auto&& [input, param] : param_map)
auto param_map = m.get_ins_param_map(inputs, true);
for(auto&& [param, input] : param_map)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this flipped?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are cases where same input instruction is mapped to multiple parameters.
e.g. split_fused_reduce(x, y, x)

In those cases, having mapping from input--> param would de-duplicate it and only add single parameter.

Later

copy_ins = m.add_parameter(name, s);

Here it won't find parameter in the map_ins and would try to add it and fails

{
(*map_ins)[param] = map_ins->at(input);
}
Expand Down
151 changes: 151 additions & 0 deletions src/targets/gpu/fuse_mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ namespace gpu {

MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_EXTRA_MLIR);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_REDUCE_FUSION);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MLIR);
/**
* @brief Declares a new MIGraphX environment variable which forces to generate
Expand Down Expand Up @@ -386,13 +387,59 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
return false;
}

bool is_reduce_op_supported_by_mlir(const instruction& i)
{
using type_t = shape::type_t;
const auto& name = i.name();
const auto result_type = i.get_shape().type();
const std::initializer_list<type_t> allowed_types = {
type_t::float_type, type_t::half_type, type_t::fp8e4m3fnuz_type};
// Preliminary type check.
if(not contains(allowed_types, result_type))
{
return false;
}
const std::initializer_list<std::string> reduce_ops = {"reduce_mean", "reduce_sum"};
return contains(reduce_ops, i.name());
}

// A separate function so we can remove operators that are supported by mlir
// but not supported for an input fusion.
bool is_pointwise_op_supported_by_mlir_for_input(const instruction& i)
{
return is_pointwise_op_supported_by_mlir(i);
}

MIGRAPHX_PRED_MATCHER(mlir_split_reduce, instruction_ref ins)
{
if(ins->name() != "split_fused_reduce")
return false;
auto* mod_arg = ins->module_inputs().front();
auto supported_reshapes = reshaper_names();
supported_reshapes.erase("slice");
std::unordered_set<std::string> builtins = {"@param", "@literal", "@return"};
for(const auto i : iterator_for(*mod_arg))
{
if(is_reduce(*i))
{
if(not is_reduce_op_supported_by_mlir(*i))
return false;
}
else if(i->name() == "pointwise")
{
if(not std::all_of(i->module_inputs().front()->begin(),
i->module_inputs().front()->end(),
&is_pointwise_op_supported_by_mlir))
return false;
}
else if(not contains(reshaper_names(), i->name()) and not contains(builtins, i->name()))
{
return false;
}
}
return true;
}

MIGRAPHX_PRED_MATCHER(mlir_pointwise, instruction_ref ins)
{
if(ins->name() != "pointwise")
Expand Down Expand Up @@ -423,6 +470,100 @@ std::vector<instruction_ref> mlir_contiguous(module_pass_manager& mpm,
return result;
}

struct find_mlir_split_reduce
{
mlir_mode conv_mode = mlir_mode::none;
mlir_mode dot_mode = mlir_mode::none;
auto matcher() const
{
auto dot_or_conv = match::name("gpu::mlir_op");
// TODO: Handle reshapes inbetween
return mlir_split_reduce()(match::any_of[match::inputs()](dot_or_conv.bind("gemm")));
}

void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto reduce_ins = r.result;
auto gemm_ins = r.instructions["gemm"];
assert(gemm_ins->get_shape().sub_shapes().empty());
auto* rm = reduce_ins->module_inputs().front();
auto names = rm->get_parameter_names();
std::sort(names.begin(), names.end());
module_ref gemm_old_mm = gemm_ins->module_inputs().front();
module_ref mm =
mpm.create_module(gemm_old_mm->name() + "_split_fused_reduce", *gemm_old_mm);
// remove last return instruction
if(std::prev(mm->end())->name() == "@return")
{
mm->remove_instruction(std::prev(mm->end()));
}
mm->set_bypass();
std::unordered_map<instruction_ref, instruction_ref> param_map;
param_map[gemm_ins] = std::prev(mm->end());
bool gemm_has_multi_outs = gemm_ins->outputs().size() > 1;
auto return_vals =
mm->fuse(*rm,
reduce_ins->inputs(),
&param_map,
[&](module& main_mod,
instruction_ref pos,
const operation& op,
const std::vector<instruction_ref>& inputs,
const std::vector<module_ref>& mod_args) {
if(op.name() == "pointwise")
{
for(const auto& skip_param : inputs)
{
if(not contains(param_map, skip_param))
{
param_map[skip_param] =
skip_param; // skip adding parameter for inputs of
// pointwise inside split_fused_reduce
}
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this needed? If its not in the param_map then that seems like a bug.

Copy link
Member Author

@umangyadav umangyadav Aug 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fusing pointwise module arg inside the fused_mlir_module. Note that pointwise module is submodule to split_fused_reduce module as well.

Inputs to pointwise would be available in map_ins internally to fuse method.

(*map_ins)[param] = map_ins->at(input);

But they won't be passed down when "inserter" is invoked.

copy_ins = insert(m, ins, sins->get_operator(), copy_inputs, mod_args);

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see thats because you are reusing the outer param_map. I am not sure that is needed though. You could just use another param_map:

auto param_map_2 = create_param_map_with_literals(&main_mod, sub_pm, op.compute_shape(to_shapes(inputs), mod_args));
return main_mod.fuse(*sub_pm, inputs, &param_map_2).front();

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

adding params to same param_map would handle case if pointwise module accesses instruction from its' parent split_fused_reduce module.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also need to add inputs of the pointwise op into param_map to avoid adding params for that. Adding to same param_map solves both issues.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

adding params to same param_map would handle case if pointwise module accesses instruction from its' parent split_fused_reduce module.

That should never happen. Since the pointwise module is used to generate a single GPU kernel(or function in the case of fused_reduce) it should never access instructions from the parent scope.

I also need to add inputs of the pointwise op into param_map to avoid adding params for that.

Sure, you can use get_ins_param_map for that.

auto param_map_2 = sub_pm->get_ins_param_map(inputs, true);
auto literal_param_map = create_param_map_with_literals(&main_mod, sub_pm, op.compute_shape(to_shapes(inputs), mod_args));
param_map_2.insert(literal_param_map.begin(), literal_param_map.end());
return main_mod.fuse(*sub_pm, inputs, &param_map_2).front();

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, you can use get_ins_param_map for that.

No, I can't it will create map from param -> input it won't skip adding parameters for the inputs.

if(contains(map_ins, input))

I need map from input --> input

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Made some changes.

auto* sub_pm = mod_args.front();
auto param_map_2 = create_param_map_with_literals(
&main_mod, sub_pm, op.compute_shape(to_shapes(inputs), mod_args));
param_map.insert(param_map_2.begin(), param_map_2.end());
return main_mod.fuse(*sub_pm, inputs, &param_map).front();
}
return main_mod.insert_instruction(pos, op, inputs, mod_args);
});
if(gemm_has_multi_outs)
{
return_vals.insert(return_vals.end(), param_map[gemm_ins]);
}
mm->add_return(return_vals);
std::vector<instruction_ref> inputs;
std::copy_if(reduce_ins->inputs().begin(),
reduce_ins->inputs().end(),
std::back_inserter(inputs),
[&](auto input) { return input != gemm_ins; });
inputs.insert(inputs.end(), gemm_ins->inputs().begin(), gemm_ins->inputs().end());
if(gemm_has_multi_outs)
{
auto fused_ins = mpm.get_module().insert_instruction(
reduce_ins, mlir_op{gemm_ins->get_operator()}, mlir_contiguous(mpm, inputs), {mm});
auto dot_ins = mpm.get_module().insert_instruction(
reduce_ins,
migraphx::make_op("get_tuple_elem", {{"index", return_vals.size() - 1}}),
fused_ins);

mpm.get_module().replace_instruction(gemm_ins, dot_ins);
for(const auto outs : reduce_ins->outputs())
{
assert(outs->get_operator().name() == "get_tuple_elem");
mpm.get_module().replace_instruction(outs, outs->get_operator(), fused_ins);
}
}
else
{
mpm.get_module().replace_instruction(
reduce_ins, mlir_op{gemm_ins->get_operator()}, mlir_contiguous(mpm, inputs), {mm});
}
}
};

struct find_mlir_fused_ops
{
mlir_mode conv_mode = mlir_mode::none;
Expand Down Expand Up @@ -714,15 +855,25 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
mpm,
find_mlir_fused_ops{.conv_mode = get_mode("fused_convolution", mlir_mode::fast),
.dot_mode = get_mode("fused_dot", mlir_mode::fast)});

match::find_matches(
mpm,
find_mlir_standalone_convolution_op{get_mode("convolution", mlir_mode::fast)},
find_mlir_standalone_dot_op{get_mode("dot", mlir_mode::fast)});

mpm.run_pass(dead_code_elimination{});
if(enabled(MIGRAPHX_ENABLE_MLIR_REDUCE_FUSION{}))
{
match::find_matches(
mpm,
find_mlir_split_reduce{.conv_mode = get_mode("fused_convolution", mlir_mode::fast),
.dot_mode = get_mode("fused_dot", mlir_mode::fast)});
}

if(enabled(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION{}))
{
match::find_matches(mpm, find_pointwise_mlir{});
}
#else
(void)mpm;
#endif
Expand Down
2 changes: 2 additions & 0 deletions src/targets/gpu/include/migraphx/gpu/mlir.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ struct MIGRAPHX_GPU_EXPORT mlir_code_object
std::vector<value> prefill_values = {};
};

MIGRAPHX_GPU_EXPORT bool is_reduce(const instruction& ins);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont see this used outside of mlir.cpp. I think it can be removed from the header.

Copy link
Member Author

@umangyadav umangyadav Aug 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is being used on both fuse_mlir.cpp and mlir.cpp


MIGRAPHX_GPU_EXPORT mlir_code_object compile_mlir(const context& migraphx_ctx,
module m,
const std::vector<shape>& in_shapes,
Expand Down
9 changes: 8 additions & 1 deletion src/targets/gpu/jit/mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,14 @@ struct mlir_compiler : compiler<mlir_compiler>
dot_mlir_inputs.push_back(mod_splits[0].mod.get_output_shapes().front());
mlir_code_object cop1 = compile_mlir(ctx, mod_splits[0].mod, dot_mlir_inputs, solution);
auto pw_shapes = to_shapes(mod_splits[1].inputs);
pw_shapes.push_back(mod_splits[1].mod.get_output_shapes().front());
if(mod_splits[1].mod.get_output_shapes().size() == 1)
{
pw_shapes.push_back(mod_splits[1].mod.get_output_shapes().front());
}
else
{
pw_shapes.push_back(shape{mod_splits[1].mod.get_output_shapes()});
}
assert(pw_shapes.back() == ins->get_shape());
auto pw_mod = create_pointwise_module(&mod_splits[1].mod);
auto cop2 = compile_pointwise(ctx, pw_shapes, &pw_mod);
Expand Down
71 changes: 68 additions & 3 deletions src/targets/gpu/mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,16 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <algorithm>
#include <cstdint>
#include <migraphx/algorithm.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/gpu/mlir.hpp>
#include <mlir-c/Dialect/RockEnums.h>
#include <numeric>
#include <ostream>

#ifdef MIGRAPHX_MLIR
Expand Down Expand Up @@ -951,11 +956,60 @@ struct mlir_program
std::string sym_name;
};

bool is_reduce(const instruction& ins) { return contains(ins.name(), "reduce"); }

static void rewrite_reduce(module& m)
{
for(auto i : iterator_for(m))
{
if(is_reduce(*i))
{
auto reduce_op = i->get_operator().to_value();
auto reduce_axes = reduce_op["axes"].to_vector<size_t>();
auto reduce_lens = i->get_shape().lens();
auto in_shape = i->inputs().front()->get_shape();
auto in_lens = in_shape.lens();
assert(in_shape.standard());
assert(reduce_lens.size() == in_lens.size());
assert(std::adjacent_find(
reduce_axes.begin(), reduce_axes.end(), [](auto axis_1, auto axis_2) {
return axis_2 - axis_1 > 1;
}) == reduce_axes.end());

std::vector<int64_t> new_rsp_dims;
std::vector<int64_t> new_reduce_axes;
for(const auto axis : range(in_shape.ndim()))
{
if(reduce_lens[axis] == in_lens[axis])
{
new_rsp_dims.push_back(in_lens[axis]);
}
else if(new_reduce_axes.empty())
{
assert(reduce_lens[axis] == 1);
new_rsp_dims.push_back(-1);
new_reduce_axes.push_back(axis);
}
}
auto rsp_ins = m.insert_instruction(
i, migraphx::make_op("reshape", {{"dims", new_rsp_dims}}), i->inputs().front());
auto collapsed_reduce = m.insert_instruction(
i, migraphx::make_op("reduce_sum", {{"axes", new_reduce_axes}}), rsp_ins);
auto rsp_back = m.insert_instruction(
i, migraphx::make_op("reshape", {{"dims", reduce_lens}}), collapsed_reduce);
m.replace_instruction(i, rsp_back);
}
}
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
}

bool is_module_fusible(const module& m, const context& migraphx_ctx, const value& solution)
{
auto mm = m;
rewrite_reduce(mm);
mlir_program mp;
mp.set_gpu_properties(migraphx_ctx);
mp.parse(m);
mp.parse(mm);
mp.run_high_level_pipeline();
return mlirIsModuleFusible(mp.mmodule.get(), make_mlir_string_ref(*solution.if_string()));
}
Expand Down Expand Up @@ -988,6 +1042,7 @@ std::string dump_mlir(const module& m, const std::vector<shape>& inputs)
mr = &mm;
adjust_param_shapes(mm, inputs);
}
rewrite_reduce(mm);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This wont dump correctly if the inputs are empty. Probably should take the module by value in the function and remove the mm and mr variables.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This wont dump correctly if the inputs are empty.

Why is that ?
rewrite_reduce is just reshaping reduction ops. It doesn't touch inputs.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This wont dump correctly if the inputs are empty.

Why is that ? rewrite_reduce is just reshaping reduction ops. It doesn't touch inputs.

It rewrites the mm variable which is only set to the m variable(ie the module passed to the function) when the inputs are not empty.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

mlir_program mp;
mp.parse(*mr);
auto mod_op = mlirModuleGetOperation(mp.mmodule.get());
Expand All @@ -1002,6 +1057,7 @@ mlir_code_object compile_mlir(const context& migraphx_ctx,
const value& solution)
{
adjust_param_shapes(m, in_shapes);
rewrite_reduce(m);
const bool trace = enabled(MIGRAPHX_TRACE_MLIR{});

static std::mutex mutex;
Expand Down Expand Up @@ -1081,12 +1137,21 @@ tuning_config get_tuning_config_mlir(const context& migraphx_ctx,
bool exhaustive)
{
adjust_param_shapes(m, inputs);

rewrite_reduce(m);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably make one function(like mlir_rewrites) that does adjust_param_shapes and rewrite_reduce together. This we have once place to add any additional rewrites in the future.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can do that after this PR.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mlir_program mp;
mp.set_gpu_properties(migraphx_ctx);
mp.parse(m);
auto tc = mp.get_tuning_config(exhaustive);

std::string problem_config = tc.problem.to<std::string>();
for(const auto i : iterator_for(m))
{
if(starts_with(i->name(), "@"))
{
continue;
}
problem_config += " " + i->name();
}
tc.problem = problem_config;
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had these changes to experiment and work around problem_cache issue. Reverting these

const bool trace = enabled(MIGRAPHX_TRACE_MLIR{});
static std::mutex mutex;
if(trace)
Expand Down
Loading
Loading