From d39f83252cc6fc8bde53170e7cd4845ddab13313 Mon Sep 17 00:00:00 2001 From: Paul Date: Wed, 17 Apr 2024 19:17:27 -0700 Subject: [PATCH 01/26] Add fuse mthods to module --- src/fuse_reduce.cpp | 95 ++++++++------------------------- src/include/migraphx/module.hpp | 10 ++++ src/module.cpp | 56 +++++++++++++++++++ src/targets/gpu/fuse_mlir.cpp | 16 ++++++ 4 files changed, 103 insertions(+), 74 deletions(-) diff --git a/src/fuse_reduce.cpp b/src/fuse_reduce.cpp index 7f60d5ebe70..cbcca1cad8b 100644 --- a/src/fuse_reduce.cpp +++ b/src/fuse_reduce.cpp @@ -83,67 +83,14 @@ struct fused_reduce }; MIGRAPHX_REGISTER_OP(fused_reduce); -static void insert_params(module_ref sm, - const std::vector& inputs, - std::unordered_map& map_ins) -{ - auto n = sm->get_parameter_shapes().size(); - for(auto input : inputs) - { - if(contains(map_ins, input)) - continue; - map_ins[input] = - sm->add_parameter("x" + std::to_string(n++), input->get_shape().as_standard()); - } -} - -static auto insert_ins_in_submodule(module_ref sm, - instruction_ref ins, - std::unordered_map& map_ins) -{ - insert_params(sm, ins->inputs(), map_ins); - return sm->add_instructions({ins}, &map_ins); -} - -static auto insert_ins_in_submodule(module_ref sm, instruction_ref ins) -{ - std::unordered_map map_ins; - return insert_ins_in_submodule(sm, ins, map_ins); -} - -static auto -insert_module_in_submodule(module_ref sm, - const std::vector& inputs, - module_ref m, - std::unordered_map& map_ins, - module::inserter insert = nullptr) -{ - insert_params(sm, inputs, map_ins); - auto param_map = m->get_ins_param_map(inputs); - for(auto&& [input, param] : param_map) - { - map_ins[param] = map_ins.at(input); - } - return sm->add_instructions(m, &map_ins, std::move(insert)); -} - static auto insert_module_in_submodule(module_ref sm, instruction_ref ins, - std::unordered_map& map_ins, + std::unordered_map* map_ins = nullptr, module::inserter insert = nullptr) { - return insert_module_in_submodule( - sm, ins->inputs(), ins->module_inputs().front(), map_ins, std::move(insert)); -} - -static auto insert_module_in_submodule(module_ref sm, - const std::vector& inputs, - module_ref m, - module::inserter insert = nullptr) -{ - std::unordered_map map_ins; - return insert_module_in_submodule(sm, inputs, m, map_ins, std::move(insert)); + assert(ins->module_inputs().size() == 1); + return sm->fuse(*ins->module_inputs().front(), ins->inputs(), map_ins, std::move(insert)); } static std::vector @@ -186,7 +133,7 @@ static void create_reduce_modules(module_pass_manager& mpm) mpm.create_module(mpm.get_module().name() + ":" + ins->name() + std::to_string(n++)); rm->set_bypass(); - rm->add_return(insert_ins_in_submodule(rm, ins)); + rm->add_return(rm->fuse({ins})); auto v = ins->get_operator().to_value(); mpm.get_module().replace_instruction( ins, make_op("fused_reduce", {{"axes", v["axes"]}}), ins->inputs(), {rm}); @@ -234,17 +181,17 @@ struct find_pointwise_reduce std::unordered_map map_ins; // Insert pointwise - auto rins = insert_ins_in_submodule(rm, input, map_ins).front(); + auto rins = rm->fuse({input}, &map_ins).front(); map_ins[input] = rins; if(contains(r.instructions, "broadcast")) { auto broadcast = r.instructions["broadcast"]; - map_ins[broadcast] = insert_ins_in_submodule(rm, broadcast, map_ins).front(); + map_ins[broadcast] = rm->fuse({broadcast}, &map_ins).front(); } // Insert fused_reduce - rm->add_return(insert_module_in_submodule(rm, reduce, map_ins)); + rm->add_return(insert_module_in_submodule(rm, reduce, &map_ins)); auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins); mpm.get_module().replace_instruction(reduce, reduce->get_operator(), new_inputs, {rm}); @@ -271,12 +218,12 @@ struct find_reduce_pointwise rm->set_bypass(); std::unordered_map map_ins; // Copy module instructions - insert_module_in_submodule(rm, reduce, map_ins); + insert_module_in_submodule(rm, reduce, &map_ins); if(contains(r.instructions, "broadcast")) { auto broadcast = r.instructions["broadcast"]; map_ins[broadcast->inputs().front()] = rm->get_returns().front(); - auto bout = insert_ins_in_submodule(rm, broadcast, map_ins); + auto bout = rm->fuse({broadcast}, &map_ins); map_ins[input] = bout.front(); } else @@ -284,7 +231,7 @@ struct find_reduce_pointwise map_ins[input] = rm->get_returns().front(); } - auto out = insert_ins_in_submodule(rm, pw, map_ins); + auto out = rm->fuse({pw}, &map_ins); rm->replace_return(out); auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins); @@ -315,12 +262,12 @@ struct find_reduce_reduce std::unordered_map map_ins; // Copy reduce1 instructions - insert_module_in_submodule(rm, reduce2, map_ins); + insert_module_in_submodule(rm, reduce2, &map_ins); if(contains(r.instructions, "broadcast")) { auto broadcast = r.instructions["broadcast"]; map_ins[broadcast->inputs().front()] = rm->get_returns().front(); - auto bout = insert_ins_in_submodule(rm, broadcast, map_ins); + auto bout = rm->fuse({broadcast}, &map_ins); map_ins[input] = bout.front(); } else @@ -328,7 +275,7 @@ struct find_reduce_reduce map_ins[input] = rm->get_returns().front(); } - auto out = insert_module_in_submodule(rm, reduce1, map_ins); + auto out = insert_module_in_submodule(rm, reduce1, &map_ins); rm->replace_return(out); auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins); @@ -370,14 +317,14 @@ struct reduce_reshape : rewrite_reshapes_base auto dims = base_dims(inputs); auto* oldm = ins->module_inputs().front(); auto* sm = mpm.create_module(oldm->name() + "_reshape"); - insert_module_in_submodule(sm, inputs, oldm, transform_op([&](const operation& sop) { - if(contains(sop.name(), "reduce")) - return make_op(sop.name(), {{"axes", axes}}); - if(sop.name() == "multibroadcast") - return make_op("multibroadcast", {{"out_lens", dims}}); - assert(sop.name() == "pointwise"); - return sop; - })); + sm->fuse(*oldm, inputs, nullptr, transform_op([&](const operation& sop) { + if(contains(sop.name(), "reduce")) + return make_op(sop.name(), {{"axes", axes}}); + if(sop.name() == "multibroadcast") + return make_op("multibroadcast", {{"out_lens", dims}}); + assert(sop.name() == "pointwise"); + return sop; + })); return mpm.get_module().insert_instruction(ins, fused_reduce{axes}, inputs, {sm}); } diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index f9d41121159..159313dfbc9 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -231,6 +231,16 @@ struct MIGRAPHX_EXPORT module const std::vector& splits1, const std::vector& splits2) const; + std::vector fuse( + const std::vector& inss, + std::unordered_map* map_ins = nullptr, inserter insert = nullptr); + + std::vector + fuse(const module& m, + const std::vector& inputs, + std::unordered_map* map_ins = nullptr, + inserter insert = nullptr); + void debug_print() const; void debug_print(instruction_ref ins) const; void debug_print(instruction_ref ins, diff --git a/src/module.cpp b/src/module.cpp index 8c28c15d22d..9a9f8c0b07d 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -914,6 +914,62 @@ std::array module::split(const std::vector& inputs, + std::unordered_map& map_ins) +{ + auto n = m.get_parameter_shapes().size(); + for(auto input : inputs) + { + if(contains(map_ins, input)) + continue; + map_ins[input] = + m.add_parameter(param_name(n++), input->get_shape().as_standard()); + } +} + +std::vector module::fuse( + const std::vector& inss, + std::unordered_map* map_ins, module::inserter insert) +{ + std::unordered_map default_map_ins; + if(not map_ins) + map_ins = &default_map_ins; + std::vector inputs; + for(auto ins:inss) + { + for(auto input:ins->inputs()) + { + if(contains(inss, input)) + continue; + if(contains(inputs, input)) + continue; + inputs.push_back(input); + } + } + insert_params(*this, inputs, *map_ins); + return this->add_instructions(inss, map_ins, std::move(insert)); +} + +std::vector + module::fuse( + const module& m, + const std::vector& inputs, + std::unordered_map* map_ins, + module::inserter insert) +{ + std::unordered_map default_map_ins; + if(not map_ins) + map_ins = &default_map_ins; + insert_params(*this, inputs, *map_ins); + auto param_map = m.get_ins_param_map(inputs); + for(auto&& [input, param] : param_map) + { + (*map_ins)[param] = map_ins->at(input); + } + return this->add_instructions(&m, map_ins, std::move(insert)); +} + void module_with_inputs::replace(instruction_ref ins, instruction_ref rep) { auto it = std::find(inputs.begin(), inputs.end(), ins); diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index a0a16512358..de1ffa0330d 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -604,6 +604,22 @@ struct find_mlir_attention_fused_ops : public find_mlir_standalone_attention_op } }; +struct find_pointwise_mlir +{ + auto matcher() const + { + return match::name("gpu::mlir_op")(match::any_of[match::inputs()](match::name("pointwise")(match::used_once()).bind("pointwise"))); + } + + void apply(module_pass_manager& mpm, const match::matcher_result& r) const + { + auto ins = r.result; + auto pw = r.instructions["pointwise"]; + + + } +}; + } // namespace #endif // MIGRAPHX_MLIR From d2d3baeaf99a89b04685d5f7b22c32fb36c68fbf Mon Sep 17 00:00:00 2001 From: Paul Date: Wed, 17 Apr 2024 19:17:30 -0700 Subject: [PATCH 02/26] Format --- src/fuse_reduce.cpp | 16 ++++++++-------- src/include/migraphx/module.hpp | 13 +++++++------ src/module.cpp | 23 +++++++++++------------ src/targets/gpu/fuse_mlir.cpp | 7 +++---- 4 files changed, 29 insertions(+), 30 deletions(-) diff --git a/src/fuse_reduce.cpp b/src/fuse_reduce.cpp index cbcca1cad8b..efb897d2d2b 100644 --- a/src/fuse_reduce.cpp +++ b/src/fuse_reduce.cpp @@ -87,7 +87,7 @@ static auto insert_module_in_submodule(module_ref sm, instruction_ref ins, std::unordered_map* map_ins = nullptr, - module::inserter insert = nullptr) + module::inserter insert = nullptr) { assert(ins->module_inputs().size() == 1); return sm->fuse(*ins->module_inputs().front(), ins->inputs(), map_ins, std::move(insert)); @@ -318,13 +318,13 @@ struct reduce_reshape : rewrite_reshapes_base auto* oldm = ins->module_inputs().front(); auto* sm = mpm.create_module(oldm->name() + "_reshape"); sm->fuse(*oldm, inputs, nullptr, transform_op([&](const operation& sop) { - if(contains(sop.name(), "reduce")) - return make_op(sop.name(), {{"axes", axes}}); - if(sop.name() == "multibroadcast") - return make_op("multibroadcast", {{"out_lens", dims}}); - assert(sop.name() == "pointwise"); - return sop; - })); + if(contains(sop.name(), "reduce")) + return make_op(sop.name(), {{"axes", axes}}); + if(sop.name() == "multibroadcast") + return make_op("multibroadcast", {{"out_lens", dims}}); + assert(sop.name() == "pointwise"); + return sop; + })); return mpm.get_module().insert_instruction(ins, fused_reduce{axes}, inputs, {sm}); } diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index 159313dfbc9..2fa9d1aa6e5 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -231,15 +231,16 @@ struct MIGRAPHX_EXPORT module const std::vector& splits1, const std::vector& splits2) const; - std::vector fuse( - const std::vector& inss, - std::unordered_map* map_ins = nullptr, inserter insert = nullptr); + std::vector + fuse(const std::vector& inss, + std::unordered_map* map_ins = nullptr, + inserter insert = nullptr); std::vector fuse(const module& m, - const std::vector& inputs, - std::unordered_map* map_ins = nullptr, - inserter insert = nullptr); + const std::vector& inputs, + std::unordered_map* map_ins = nullptr, + inserter insert = nullptr); void debug_print() const; void debug_print(instruction_ref ins) const; diff --git a/src/module.cpp b/src/module.cpp index 9a9f8c0b07d..73b51bd05e9 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -923,22 +923,22 @@ static void insert_params(module& m, { if(contains(map_ins, input)) continue; - map_ins[input] = - m.add_parameter(param_name(n++), input->get_shape().as_standard()); + map_ins[input] = m.add_parameter(param_name(n++), input->get_shape().as_standard()); } } -std::vector module::fuse( - const std::vector& inss, - std::unordered_map* map_ins, module::inserter insert) +std::vector +module::fuse(const std::vector& inss, + std::unordered_map* map_ins, + module::inserter insert) { std::unordered_map default_map_ins; if(not map_ins) map_ins = &default_map_ins; std::vector inputs; - for(auto ins:inss) + for(auto ins : inss) { - for(auto input:ins->inputs()) + for(auto input : ins->inputs()) { if(contains(inss, input)) continue; @@ -952,11 +952,10 @@ std::vector module::fuse( } std::vector - module::fuse( - const module& m, - const std::vector& inputs, - std::unordered_map* map_ins, - module::inserter insert) +module::fuse(const module& m, + const std::vector& inputs, + std::unordered_map* map_ins, + module::inserter insert) { std::unordered_map default_map_ins; if(not map_ins) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index de1ffa0330d..1f86845d758 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -608,15 +608,14 @@ struct find_pointwise_mlir { auto matcher() const { - return match::name("gpu::mlir_op")(match::any_of[match::inputs()](match::name("pointwise")(match::used_once()).bind("pointwise"))); + return match::name("gpu::mlir_op")(match::any_of[match::inputs()]( + match::name("pointwise")(match::used_once()).bind("pointwise"))); } void apply(module_pass_manager& mpm, const match::matcher_result& r) const { auto ins = r.result; - auto pw = r.instructions["pointwise"]; - - + auto pw = r.instructions["pointwise"]; } }; From af835094b51c3bfb48c82da7b8fdcb92153a9189 Mon Sep 17 00:00:00 2001 From: Paul Date: Thu, 18 Apr 2024 07:29:45 -0700 Subject: [PATCH 03/26] Add some initial code --- src/targets/gpu/fuse_mlir.cpp | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 1f86845d758..87528c06e6d 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -616,6 +616,20 @@ struct find_pointwise_mlir { auto ins = r.result; auto pw = r.instructions["pointwise"]; + + auto* mm = ins->module_inputs().front(); + auto* pm = pw->module_inputs().front(); + + module_ref m = mpm.create_module(pm->name() + ":" + mm->name(), *pm); + m->fuse(*mm, ins->inputs()); + + // TODO: Use find_inputs + auto inputs = pw->inputs(); + inputs.insert(inputs.end(), ins->inputs().begin(), ins->inputs().end()); + + mpm.get_module().replace_instruction( + ins, ins->get_operator(), inputs, {m}); + } }; From ac479548aa7996ee3b783dce856440625c3ec60b Mon Sep 17 00:00:00 2001 From: Paul Date: Thu, 18 Apr 2024 07:29:51 -0700 Subject: [PATCH 04/26] Format --- src/targets/gpu/fuse_mlir.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 87528c06e6d..f96350c1790 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -627,9 +627,7 @@ struct find_pointwise_mlir auto inputs = pw->inputs(); inputs.insert(inputs.end(), ins->inputs().begin(), ins->inputs().end()); - mpm.get_module().replace_instruction( - ins, ins->get_operator(), inputs, {m}); - + mpm.get_module().replace_instruction(ins, ins->get_operator(), inputs, {m}); } }; From c9407aa1c8cb85ec8568c348e5d743a08056131d Mon Sep 17 00:00:00 2001 From: Paul Date: Thu, 18 Apr 2024 13:48:53 -0700 Subject: [PATCH 05/26] Reuse find_inputs --- src/fuse_reduce.cpp | 33 ++++------------------------ src/include/migraphx/param_utils.hpp | 6 +++++ src/param_utils.cpp | 29 ++++++++++++++++++++++++ src/targets/gpu/fuse_mlir.cpp | 16 +++++++++----- 4 files changed, 50 insertions(+), 34 deletions(-) diff --git a/src/fuse_reduce.cpp b/src/fuse_reduce.cpp index efb897d2d2b..1ba50a8cee0 100644 --- a/src/fuse_reduce.cpp +++ b/src/fuse_reduce.cpp @@ -33,6 +33,7 @@ #include #include #include +#include #include #include @@ -93,32 +94,6 @@ insert_module_in_submodule(module_ref sm, return sm->fuse(*ins->module_inputs().front(), ins->inputs(), map_ins, std::move(insert)); } -static std::vector -find_inputs(const_module_ref sm, - const module& parent, - const std::unordered_map& map_ins) -{ - std::vector result; - std::map names; - for(auto&& [input, param] : map_ins) - { - if(not sm->has_instruction(param)) - continue; - if(param->name() != "@param") - continue; - if(not parent.has_instruction(input)) - continue; - auto v = param->get_operator().to_value(); - auto name = v.at("parameter").to(); - names[name] = input; - } - std::transform(names.begin(), names.end(), std::back_inserter(result), [](const auto& p) { - return p.second; - }); - assert(result.size() == sm->get_parameter_shapes().size()); - return result; -} - static void create_reduce_modules(module_pass_manager& mpm) { std::size_t n = 0; @@ -193,7 +168,7 @@ struct find_pointwise_reduce // Insert fused_reduce rm->add_return(insert_module_in_submodule(rm, reduce, &map_ins)); - auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins); + auto new_inputs = find_inputs(map_ins, &mpm.get_module(), rm); mpm.get_module().replace_instruction(reduce, reduce->get_operator(), new_inputs, {rm}); } }; @@ -234,7 +209,7 @@ struct find_reduce_pointwise auto out = rm->fuse({pw}, &map_ins); rm->replace_return(out); - auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins); + auto new_inputs = find_inputs(map_ins, &mpm.get_module(), rm); mpm.get_module().replace_instruction(pw, reduce->get_operator(), new_inputs, {rm}); } }; @@ -278,7 +253,7 @@ struct find_reduce_reduce auto out = insert_module_in_submodule(rm, reduce1, &map_ins); rm->replace_return(out); - auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins); + auto new_inputs = find_inputs(map_ins, &mpm.get_module(), rm); mpm.get_module().replace_instruction(reduce1, reduce1->get_operator(), new_inputs, {rm}); } }; diff --git a/src/include/migraphx/param_utils.hpp b/src/include/migraphx/param_utils.hpp index 1552c28300b..f594f8be7f7 100644 --- a/src/include/migraphx/param_utils.hpp +++ b/src/include/migraphx/param_utils.hpp @@ -27,6 +27,7 @@ #include #include +#include #include #include @@ -37,6 +38,11 @@ std::string param_name(std::size_t i, const std::string& prefix = "x"); void sort_params(std::vector& params); +std::vector +find_inputs(const std::unordered_map& map_ins, + const_module_ref parent, + const_module_ref sub); + } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx #endif // MIGRAPHX_GUARD_MIGRAPHX_PARAM_UTILS_HPP diff --git a/src/param_utils.cpp b/src/param_utils.cpp index 61302a0afba..a3a07acaa26 100644 --- a/src/param_utils.cpp +++ b/src/param_utils.cpp @@ -25,6 +25,9 @@ #include #include #include +#include +#include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -43,5 +46,31 @@ void sort_params(std::vector& params) })); } +std::vector +find_inputs(const std::unordered_map& map_ins, + const_module_ref parent, + const_module_ref sub) +{ + std::vector result; + std::map names; + for(auto&& [input, param] : map_ins) + { + if(sub and not sub->has_instruction(param)) + continue; + if(param->name() != "@param") + continue; + if(parent and not parent->has_instruction(input)) + continue; + auto v = param->get_operator().to_value(); + auto name = v.at("parameter").to(); + names[name] = input; + } + std::transform(names.begin(), names.end(), std::back_inserter(result), [](const auto& p) { + return p.second; + }); + assert(not sub or result.size() == sub->get_parameter_shapes().size()); + return result; +} + } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index f96350c1790..a71451176f9 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -29,6 +29,7 @@ #include #include #include +#include #include namespace migraphx { @@ -620,13 +621,16 @@ struct find_pointwise_mlir auto* mm = ins->module_inputs().front(); auto* pm = pw->module_inputs().front(); - module_ref m = mpm.create_module(pm->name() + ":" + mm->name(), *pm); - m->fuse(*mm, ins->inputs()); + std::unordered_map map_ins; + module_ref m = mpm.create_module(pm->name() + ":" + mm->name()); + m->set_bypass(); + auto rins = m->fuse(*pm, pw->inputs(), &map_ins).front(); + map_ins[pw] = rins; - // TODO: Use find_inputs - auto inputs = pw->inputs(); - inputs.insert(inputs.end(), ins->inputs().begin(), ins->inputs().end()); + auto ret = m->fuse(*mm, ins->inputs(), &map_ins); + m->add_return({ret}); + auto inputs = find_inputs(map_ins, &mpm.get_module(), m); mpm.get_module().replace_instruction(ins, ins->get_operator(), inputs, {m}); } }; @@ -666,6 +670,8 @@ void fuse_mlir::apply(module_pass_manager& mpm) const mpm, find_mlir_standalone_convolution_op{get_mode("convolution", mlir_mode::fast)}, find_mlir_standalone_dot_op{get_mode("dot", mlir_mode::fast)}); + + match::find_matches(mpm, find_pointwise_mlir{}); #else (void)mpm; #endif From ea41fb98676f32eed6c7e3b28afee1fd7902724c Mon Sep 17 00:00:00 2001 From: Paul Date: Thu, 18 Apr 2024 13:48:57 -0700 Subject: [PATCH 06/26] Format --- src/targets/gpu/fuse_mlir.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index a71451176f9..671ddd9042c 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -624,7 +624,7 @@ struct find_pointwise_mlir std::unordered_map map_ins; module_ref m = mpm.create_module(pm->name() + ":" + mm->name()); m->set_bypass(); - auto rins = m->fuse(*pm, pw->inputs(), &map_ins).front(); + auto rins = m->fuse(*pm, pw->inputs(), &map_ins).front(); map_ins[pw] = rins; auto ret = m->fuse(*mm, ins->inputs(), &map_ins); From 61d788cd5788943ab78fc129ce3016c71a231523 Mon Sep 17 00:00:00 2001 From: Paul Date: Thu, 25 Apr 2024 13:19:37 -0700 Subject: [PATCH 07/26] Enable with env var --- src/targets/gpu/fuse_mlir.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 671ddd9042c..3ae0ac35c35 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -40,6 +40,7 @@ struct module; namespace gpu { MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_EXTRA_MLIR); +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MLIR); /** * @brief Declares a new MIGraphX environment variable which forces to generate @@ -671,7 +672,8 @@ void fuse_mlir::apply(module_pass_manager& mpm) const find_mlir_standalone_convolution_op{get_mode("convolution", mlir_mode::fast)}, find_mlir_standalone_dot_op{get_mode("dot", mlir_mode::fast)}); - match::find_matches(mpm, find_pointwise_mlir{}); + if(enabled(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION{})) + match::find_matches(mpm, find_pointwise_mlir{}); #else (void)mpm; #endif From 66e9d31830042be3ef3a64195e2ac6cb7933e7a4 Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 14 May 2024 15:33:28 -0700 Subject: [PATCH 08/26] Update comments --- src/include/migraphx/module.hpp | 6 +++++- src/module.cpp | 4 ++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index 5c7f8aa24a4..dbe98b7f0d7 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -244,12 +244,16 @@ struct MIGRAPHX_EXPORT module std::array split(const std::vector& args, const std::vector& splits1, const std::vector& splits2) const; - + + // Fuse the instruction into the module by inserting the instructions and + // parameters for any missing inputs. std::vector fuse(const std::vector& inss, std::unordered_map* map_ins = nullptr, inserter insert = nullptr); + // Fuse another module into this module by inserting the instructions and + // parameters from the module std::vector fuse(const module& m, const std::vector& inputs, diff --git a/src/module.cpp b/src/module.cpp index 46a5aae7317..7db4125070b 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -998,7 +998,7 @@ module::fuse(const std::vector& inss, module::inserter insert) { std::unordered_map default_map_ins; - if(not map_ins) + if(map_ins == nullptr) map_ins = &default_map_ins; std::vector inputs; for(auto ins : inss) @@ -1023,7 +1023,7 @@ module::fuse(const module& m, module::inserter insert) { std::unordered_map default_map_ins; - if(not map_ins) + if(map_ins == nullptr) map_ins = &default_map_ins; insert_params(*this, inputs, *map_ins); auto param_map = m.get_ins_param_map(inputs); From 9107b26f93bfb7935b1dcf7b4ad74d751921dea0 Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 14 May 2024 15:33:34 -0700 Subject: [PATCH 09/26] Format --- src/include/migraphx/module.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index dbe98b7f0d7..b9c5c2f3541 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -244,7 +244,7 @@ struct MIGRAPHX_EXPORT module std::array split(const std::vector& args, const std::vector& splits1, const std::vector& splits2) const; - + // Fuse the instruction into the module by inserting the instructions and // parameters for any missing inputs. std::vector From cbf3afcacfc50bf3976049a92eed542b3e594ccb Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 14 May 2024 15:35:36 -0700 Subject: [PATCH 10/26] Fix param_utils --- src/include/migraphx/param_utils.hpp | 2 ++ src/param_utils.cpp | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/include/migraphx/param_utils.hpp b/src/include/migraphx/param_utils.hpp index f594f8be7f7..f645229cd10 100644 --- a/src/include/migraphx/param_utils.hpp +++ b/src/include/migraphx/param_utils.hpp @@ -38,6 +38,8 @@ std::string param_name(std::size_t i, const std::string& prefix = "x"); void sort_params(std::vector& params); +// Find the inputs for a module by finding instructions that are mapped to the +// parameters in the module std::vector find_inputs(const std::unordered_map& map_ins, const_module_ref parent, diff --git a/src/param_utils.cpp b/src/param_utils.cpp index a3a07acaa26..4a447a4ae2e 100644 --- a/src/param_utils.cpp +++ b/src/param_utils.cpp @@ -55,11 +55,11 @@ find_inputs(const std::unordered_map& map_ins, std::map names; for(auto&& [input, param] : map_ins) { - if(sub and not sub->has_instruction(param)) + if(sub != nullptr and not sub->has_instruction(param)) continue; if(param->name() != "@param") continue; - if(parent and not parent->has_instruction(input)) + if(parent != nullptr and not parent->has_instruction(input)) continue; auto v = param->get_operator().to_value(); auto name = v.at("parameter").to(); From 3947949760b240174611de1225320eff6de88b6a Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 14 May 2024 15:39:16 -0700 Subject: [PATCH 11/26] Filter supported ops --- src/targets/gpu/fuse_mlir.cpp | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 9d2dcd0b524..f8c5820f5ac 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -392,14 +392,25 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i) return false; } +bool is_pointwise_op_supported_by_mlir_for_input(const instruction& i) +{ + return is_pointwise_op_supported_by_mlir(i); +} + MIGRAPHX_PRED_MATCHER(mlir_pointwise, instruction_ref ins) { if(ins->name() != "pointwise") return false; auto* pm = ins->module_inputs().front(); - return std::all_of(pm->begin(), pm->end(), [&](const auto& i) { - return is_pointwise_op_supported_by_mlir(i); - }); + return std::all_of(pm->begin(), pm->end(), &is_pointwise_op_supported_by_mlir); +} + +MIGRAPHX_PRED_MATCHER(mlir_input_pointwise, instruction_ref ins) +{ + if(ins->name() != "pointwise") + return false; + auto* pm = ins->module_inputs().front(); + return std::all_of(pm->begin(), pm->end(), &is_pointwise_op_supported_by_mlir_for_input); } struct find_mlir_fused_ops @@ -563,7 +574,7 @@ struct find_pointwise_mlir auto matcher() const { return match::name("gpu::mlir_op")(match::any_of[match::inputs()]( - match::name("pointwise")(match::used_once()).bind("pointwise"))); + mlir_input_pointwise(match::used_once()).bind("pointwise"))); } void apply(module_pass_manager& mpm, const match::matcher_result& r) const From ffdba3cae478400570783730234bb178a4990c40 Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 14 May 2024 15:41:20 -0700 Subject: [PATCH 12/26] Add another comment --- src/targets/gpu/fuse_mlir.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index f8c5820f5ac..749b89cdd32 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -392,6 +392,8 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i) return false; } +// A seprate function so we can remove operators that are supported by mlir +// but not supported for an input fusion. bool is_pointwise_op_supported_by_mlir_for_input(const instruction& i) { return is_pointwise_op_supported_by_mlir(i); From 48712c9d6e7e86ebdc92684a76e22f7a6ca3e8bd Mon Sep 17 00:00:00 2001 From: Paul Date: Wed, 22 May 2024 17:07:26 -0700 Subject: [PATCH 13/26] Handle scalars --- src/targets/gpu/fuse_mlir.cpp | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index f24ca3075d4..f91b1cf9052 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -28,6 +28,8 @@ #include #include #include +#include +#include #include #include #include @@ -595,6 +597,16 @@ struct find_pointwise_mlir mlir_input_pointwise(match::used_once()).bind("pointwise"))); } + static instruction_ref insert_pointwise(module& m, + instruction_ref ins, + const operation& op, + const std::vector& inputs, + const std::vector& mod_args) + { + assert(mod_args.empty()); + return insert_common_op(m, ins, op, inputs); + } + void apply(module_pass_manager& mpm, const match::matcher_result& r) const { auto ins = r.result; @@ -606,7 +618,7 @@ struct find_pointwise_mlir std::unordered_map map_ins; module_ref m = mpm.create_module(pm->name() + ":" + mm->name()); m->set_bypass(); - auto rins = m->fuse(*pm, pw->inputs(), &map_ins).front(); + auto rins = m->fuse(*pm, pw->inputs(), &map_ins, &insert_pointwise).front(); map_ins[pw] = rins; auto ret = m->fuse(*mm, ins->inputs(), &map_ins); @@ -653,6 +665,8 @@ void fuse_mlir::apply(module_pass_manager& mpm) const find_mlir_standalone_convolution_op{get_mode("convolution", mlir_mode::fast)}, find_mlir_standalone_dot_op{get_mode("dot", mlir_mode::fast)}); + mpm.run_pass(dead_code_elimination{}); + if(enabled(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION{})) match::find_matches(mpm, find_pointwise_mlir{}); #else From c3cf902c588c9fc7d7d242e4aa9cf94212feea1c Mon Sep 17 00:00:00 2001 From: Paul Date: Wed, 22 May 2024 17:07:32 -0700 Subject: [PATCH 14/26] Format --- src/targets/gpu/fuse_mlir.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index f91b1cf9052..f5af1d3d51f 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -598,10 +598,10 @@ struct find_pointwise_mlir } static instruction_ref insert_pointwise(module& m, - instruction_ref ins, - const operation& op, - const std::vector& inputs, - const std::vector& mod_args) + instruction_ref ins, + const operation& op, + const std::vector& inputs, + const std::vector& mod_args) { assert(mod_args.empty()); return insert_common_op(m, ins, op, inputs); From 51d3ea915bbfbc1ccff595801964d06c4de1a5ef Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 18 Jun 2024 13:13:09 -0700 Subject: [PATCH 15/26] Add description --- src/module.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/module.cpp b/src/module.cpp index afd0185b3ba..9d2229dd222 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -979,6 +979,8 @@ std::array module::split(const std::vector& inputs, std::unordered_map& map_ins) From 0e600850cf00a7e95bbffa4d23260b27522fbb38 Mon Sep 17 00:00:00 2001 From: Chris Austen Date: Wed, 19 Jun 2024 17:03:00 -0400 Subject: [PATCH 16/26] Update src/targets/gpu/fuse_mlir.cpp Co-authored-by: Umang Yadav <29876643+umangyadav@users.noreply.github.com> --- src/targets/gpu/fuse_mlir.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 57ec8d81d81..3e62934101f 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -400,7 +400,7 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i) return false; } -// A seprate function so we can remove operators that are supported by mlir +// A separate function so we can remove operators that are supported by mlir // but not supported for an input fusion. bool is_pointwise_op_supported_by_mlir_for_input(const instruction& i) { From fd0b7f7ce4696ea9bd076fac014c35f4007a7e3f Mon Sep 17 00:00:00 2001 From: Paul Date: Wed, 19 Jun 2024 19:45:32 -0700 Subject: [PATCH 17/26] Add doc --- docs/dev/env_vars.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/dev/env_vars.rst b/docs/dev/env_vars.rst index 70135ad6836..0a7de7c4437 100644 --- a/docs/dev/env_vars.rst +++ b/docs/dev/env_vars.rst @@ -272,6 +272,11 @@ Performs exhaustive tuning for MLIR. Set to an integer greater than 1. Limits the number of solutions available to MLIR for tuning. +.. envvar:: MIGRAPHX_ENABLE_MLIR_INPUT_FUSION + +Set to "1", "enable", "enabled", "yes", or "true" to use. +Enable input fusions in MLIR. + CK vars ----------- From 9c4d6590036aa83013817d6f0a7383b78f8fbddb Mon Sep 17 00:00:00 2001 From: Paul Date: Thu, 20 Jun 2024 16:48:04 -0700 Subject: [PATCH 18/26] Add input fusion to jenkins --- Jenkinsfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Jenkinsfile b/Jenkinsfile index de6fa059a0b..7cb184f51d1 100755 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -144,7 +144,7 @@ rocmtest clang_debug: rocmnode('mi100+') { cmake_build -> } }, mlir_debug: rocmnode('mi100+') { cmake_build -> stage('MLIR Debug') { - withEnv(['MIGRAPHX_ENABLE_EXTRA_MLIR=1', 'MIGRAPHX_MLIR_USE_SPECIFIC_OPS=fused,attention,convolution,dot']) { + withEnv(['MIGRAPHX_ENABLE_EXTRA_MLIR=1', 'MIGRAPHX_MLIR_USE_SPECIFIC_OPS=fused,attention,convolution,dot', 'MIGRAPHX_ENABLE_MLIR_INPUT_FUSION=1']) { def sanitizers = "undefined" // Note: the -fno-sanitize= is copied from upstream LLVM_UBSAN_FLAGS. def debug_flags_cxx = "-g -O2 -fsanitize=${sanitizers} -fno-sanitize=vptr,function -fno-sanitize-recover=${sanitizers}" From 5eaaed3c90874ee55e174063824c8c8bd3dc03c0 Mon Sep 17 00:00:00 2001 From: Paul Date: Thu, 20 Jun 2024 17:25:38 -0700 Subject: [PATCH 19/26] Add unit test for fuse module --- test/module_test.cpp | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/test/module_test.cpp b/test/module_test.cpp index 52981930bbb..a6cb1976fd1 100644 --- a/test/module_test.cpp +++ b/test/module_test.cpp @@ -28,10 +28,11 @@ #include #include #include -#include "test.hpp" #include #include +#include +#include migraphx::program create_program() { @@ -659,4 +660,35 @@ TEST_CASE(module_split3) EXPECT(bool{mods[2].inputs[1] == splits1.front()}); } +TEST_CASE(fuse_module) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::module m1; + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("z", s); + auto add = add_pointwise(p, "main:pointwise0", {x, y}, single_pointwise("add")); + auto mul = add_pointwise(p, "main:pointwise1", {add, z}, single_pointwise("mul")); + + std::unordered_map map_ins; + auto rins = m1.fuse(*add->module_inputs().front(), add->inputs(), &map_ins).front(); + map_ins[add] = rins; + auto ret = m1.fuse(*mul->module_inputs().front(), mul->inputs(), &map_ins); + m1.add_return(ret); + } + migraphx::module m2; + { + auto x = m2.add_parameter("x0", s); + auto y = m2.add_parameter("x1", s); + auto z = m2.add_parameter("x2", s); + auto add = m2.add_instruction(migraphx::make_op("add"), x, y); + auto mul = m2.add_instruction(migraphx::make_op("mul"), add, z); + m2.add_return({mul}); + } + EXPECT(m1 == m2); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } From 1edac2d5a93e27ebd5bd0d4db8c9d7fc0b30867a Mon Sep 17 00:00:00 2001 From: Paul Date: Thu, 20 Jun 2024 17:25:43 -0700 Subject: [PATCH 20/26] Format --- test/module_test.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/test/module_test.cpp b/test/module_test.cpp index a6cb1976fd1..3b910f8dfdf 100644 --- a/test/module_test.cpp +++ b/test/module_test.cpp @@ -666,24 +666,24 @@ TEST_CASE(fuse_module) migraphx::module m1; { migraphx::program p; - auto* mm = p.get_main_module(); - auto x = mm->add_parameter("x", s); - auto y = mm->add_parameter("y", s); - auto z = mm->add_parameter("z", s); + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("z", s); auto add = add_pointwise(p, "main:pointwise0", {x, y}, single_pointwise("add")); auto mul = add_pointwise(p, "main:pointwise1", {add, z}, single_pointwise("mul")); std::unordered_map map_ins; - auto rins = m1.fuse(*add->module_inputs().front(), add->inputs(), &map_ins).front(); + auto rins = m1.fuse(*add->module_inputs().front(), add->inputs(), &map_ins).front(); map_ins[add] = rins; - auto ret = m1.fuse(*mul->module_inputs().front(), mul->inputs(), &map_ins); + auto ret = m1.fuse(*mul->module_inputs().front(), mul->inputs(), &map_ins); m1.add_return(ret); } migraphx::module m2; { - auto x = m2.add_parameter("x0", s); - auto y = m2.add_parameter("x1", s); - auto z = m2.add_parameter("x2", s); + auto x = m2.add_parameter("x0", s); + auto y = m2.add_parameter("x1", s); + auto z = m2.add_parameter("x2", s); auto add = m2.add_instruction(migraphx::make_op("add"), x, y); auto mul = m2.add_instruction(migraphx::make_op("mul"), add, z); m2.add_return({mul}); From 1ebdaf1dc33bd9c47d891528da935ef8bab30739 Mon Sep 17 00:00:00 2001 From: Paul Date: Fri, 21 Jun 2024 14:28:54 -0700 Subject: [PATCH 21/26] Add unit test --- src/targets/gpu/fuse_mlir.cpp | 2 +- test/gpu/fuse_mlir.cpp | 40 +++++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 3e62934101f..d663fa3f1e2 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -607,7 +607,7 @@ struct find_pointwise_mlir instruction_ref ins, const operation& op, const std::vector& inputs, - const std::vector& mod_args) + const std::vector&) { assert(mod_args.empty()); return insert_common_op(m, ins, op, inputs); diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index 6b646720d66..68b248c44ae 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -32,6 +32,8 @@ #include #include +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION); + void run_pass(migraphx::program& p) { migraphx::run_passes( @@ -100,6 +102,44 @@ TEST_CASE(dot_add) EXPECT(p1.sort() == p2.sort()); } +TEST_CASE(add_dot) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto b = mm->add_parameter("b", s); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto add = add_pointwise(p1, "main:pointwise0", {x, y}, single_pointwise("add")); + auto dot = mm->add_instruction(migraphx::make_op("dot"), add, b); + mm->add_return({dot}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto b = mm->add_parameter("b", s); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto fused = + add_mlir(p2, + "main:pointwise0:mlir_dot1", + {x, y, b}, + {"x0", "x1", "x2"}, + [=](auto* pm, const auto& inputs) { + auto add = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); + auto dot = + pm->add_instruction(migraphx::make_op("dot"), add, inputs[2]); + return std::make_tuple(dot, dot); + }); + mm->add_return({fused}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION{})) + return; + EXPECT(p1.sort() == p2.sort()); +} + TEST_CASE(int_quant_dot_abs) { migraphx::shape s_a{migraphx::shape::int8_type, {5, 4}}; From d035f3bb7b155814d7a36c09bd9201987798857d Mon Sep 17 00:00:00 2001 From: Paul Date: Fri, 21 Jun 2024 14:29:01 -0700 Subject: [PATCH 22/26] Format --- test/gpu/fuse_mlir.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index 68b248c44ae..e124b47da84 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -128,9 +128,9 @@ TEST_CASE(add_dot) {x, y, b}, {"x0", "x1", "x2"}, [=](auto* pm, const auto& inputs) { - auto add = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); - auto dot = - pm->add_instruction(migraphx::make_op("dot"), add, inputs[2]); + auto add = + pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); + auto dot = pm->add_instruction(migraphx::make_op("dot"), add, inputs[2]); return std::make_tuple(dot, dot); }); mm->add_return({fused}); From 271ea78d992500fcf00c4f891d0aa977567941d7 Mon Sep 17 00:00:00 2001 From: Paul Date: Fri, 21 Jun 2024 16:41:23 -0700 Subject: [PATCH 23/26] Add verify test --- test/verify/test_add_dot.cpp | 49 ++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 test/verify/test_add_dot.cpp diff --git a/test/verify/test_add_dot.cpp b/test/verify/test_add_dot.cpp new file mode 100644 index 00000000000..eb62d0191fd --- /dev/null +++ b/test/verify/test_add_dot.cpp @@ -0,0 +1,49 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +template +struct test_add_dot : verify_program> +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{DType, {256, 256}}; + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("y", s); + auto add = mm->add_instruction(migraphx::make_op("add"), x, y); + auto dot = mm->add_instruction(migraphx::make_op("dot"), add, z); + mm->add_return({dot}); + return p; + } +}; + +template struct test_add_dot; +template struct test_add_dot; From 43e76f503eaee61e72f64e59b6b57c9599f3d6d9 Mon Sep 17 00:00:00 2001 From: Paul Date: Fri, 21 Jun 2024 16:41:38 -0700 Subject: [PATCH 24/26] Format --- test/verify/test_add_dot.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/verify/test_add_dot.cpp b/test/verify/test_add_dot.cpp index eb62d0191fd..6028e211ced 100644 --- a/test/verify/test_add_dot.cpp +++ b/test/verify/test_add_dot.cpp @@ -35,9 +35,9 @@ struct test_add_dot : verify_program> migraphx::program p; auto* mm = p.get_main_module(); migraphx::shape s{DType, {256, 256}}; - auto x = mm->add_parameter("x", s); - auto y = mm->add_parameter("y", s); - auto z = mm->add_parameter("y", s); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("y", s); auto add = mm->add_instruction(migraphx::make_op("add"), x, y); auto dot = mm->add_instruction(migraphx::make_op("dot"), add, z); mm->add_return({dot}); From b357f943b0210705b6825863271b8511466b285e Mon Sep 17 00:00:00 2001 From: Paul Date: Fri, 21 Jun 2024 16:49:57 -0700 Subject: [PATCH 25/26] Fix tidy issue --- src/targets/gpu/fuse_mlir.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index d663fa3f1e2..e901dc24a2b 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -607,8 +607,10 @@ struct find_pointwise_mlir instruction_ref ins, const operation& op, const std::vector& inputs, - const std::vector&) + const std::vector& mod_args) { + // Only used in assert + (void)mod_args; assert(mod_args.empty()); return insert_common_op(m, ins, op, inputs); } From 5e848ba9758583b4811a1e205b66d67b0be02501 Mon Sep 17 00:00:00 2001 From: Paul Date: Fri, 21 Jun 2024 17:45:49 -0700 Subject: [PATCH 26/26] Fix parameter name --- test/verify/test_add_dot.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/verify/test_add_dot.cpp b/test/verify/test_add_dot.cpp index 6028e211ced..ad23cc5acf6 100644 --- a/test/verify/test_add_dot.cpp +++ b/test/verify/test_add_dot.cpp @@ -37,7 +37,7 @@ struct test_add_dot : verify_program> migraphx::shape s{DType, {256, 256}}; auto x = mm->add_parameter("x", s); auto y = mm->add_parameter("y", s); - auto z = mm->add_parameter("y", s); + auto z = mm->add_parameter("z", s); auto add = mm->add_instruction(migraphx::make_op("add"), x, y); auto dot = mm->add_instruction(migraphx::make_op("dot"), add, z); mm->add_return({dot});