From ded44c33595a2172bc79c65dffe8b3a02f9eb115 Mon Sep 17 00:00:00 2001 From: Paul Date: Sun, 17 Mar 2024 14:26:12 -0700 Subject: [PATCH 01/59] Add initial split reduce pass --- src/CMakeLists.txt | 1 + src/fuse_pointwise.cpp | 2 +- src/fuse_reduce.cpp | 4 +- src/include/migraphx/module.hpp | 12 +- src/include/migraphx/split_reduce.hpp | 22 +++ src/module.cpp | 31 ++-- src/split_reduce.cpp | 207 ++++++++++++++++++++++++++ src/split_single_dyn_dim.cpp | 2 +- src/targets/gpu/fuse_mlir.cpp | 2 +- src/targets/gpu/target.cpp | 5 +- 10 files changed, 263 insertions(+), 25 deletions(-) create mode 100644 src/include/migraphx/split_reduce.hpp create mode 100644 src/split_reduce.cpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 89000b1781e..c4b46308586 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -94,6 +94,7 @@ add_library(migraphx replace_allocate.cpp rewrite_reduce.cpp simplify_qdq.cpp + split_reduce.cpp sqlite.cpp rewrite_gelu.cpp rewrite_low_precision.cpp diff --git a/src/fuse_pointwise.cpp b/src/fuse_pointwise.cpp index 90ad475f7e2..e6f29a251a7 100644 --- a/src/fuse_pointwise.cpp +++ b/src/fuse_pointwise.cpp @@ -163,7 +163,7 @@ static std::vector append_pointwise_module(instruction_ref ins, input_map[input] = map_ins[param]; } } - pm->replace_return(pm->insert_instructions(last, xm, map_ins)); + pm->replace_return(pm->insert_instructions(last, xm, &map_ins)); return inputs; } diff --git a/src/fuse_reduce.cpp b/src/fuse_reduce.cpp index 6f2e6a1b862..2ad9153b6ae 100644 --- a/src/fuse_reduce.cpp +++ b/src/fuse_reduce.cpp @@ -118,7 +118,7 @@ static auto insert_ins_in_submodule(module_ref sm, std::unordered_map& map_ins) { insert_params(sm, ins, map_ins); - return sm->add_instructions({ins}, map_ins); + return sm->add_instructions({ins}, &map_ins); } static auto insert_ins_in_submodule(module_ref sm, instruction_ref ins) @@ -139,7 +139,7 @@ insert_module_in_submodule(module_ref sm, { map_ins[param] = map_ins.at(input); } - return sm->add_instructions(m, map_ins); + return sm->add_instructions(m, &map_ins); } static std::vector diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index 7a650c79914..1a2404ebdae 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -122,32 +122,32 @@ struct MIGRAPHX_EXPORT module std::vector add_instructions(const std::vector& instructions, - std::unordered_map map_ins = {}); + std::unordered_map* map_ins = nullptr); std::vector add_instructions(const_module_ref m, - std::unordered_map map_ins = {}); + std::unordered_map* map_ins = nullptr); std::vector add_instructions(instruction_ref start, instruction_ref last, - std::unordered_map map_ins = {}); + std::unordered_map* map_ins = nullptr); std::vector insert_instructions(instruction_ref ins, const std::vector& instructions, - std::unordered_map map_ins = {}); + std::unordered_map* map_ins = nullptr); std::vector insert_instructions(instruction_ref ins, const_module_ref m, - std::unordered_map map_ins = {}); + std::unordered_map* map_ins = nullptr); std::vector insert_instructions(instruction_ref ins, instruction_ref start, instruction_ref last, - std::unordered_map map_ins = {}); + std::unordered_map* map_ins = nullptr); template instruction_ref add_literal(Ts&&... xs) diff --git a/src/include/migraphx/split_reduce.hpp b/src/include/migraphx/split_reduce.hpp new file mode 100644 index 00000000000..2fe26454b19 --- /dev/null +++ b/src/include/migraphx/split_reduce.hpp @@ -0,0 +1,22 @@ +#ifndef MIGRAPHX_GUARD_MIGRAPHX_SPLIT_REDUCE_HPP +#define MIGRAPHX_GUARD_MIGRAPHX_SPLIT_REDUCE_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct module_pass_manager; + +struct MIGRAPHX_EXPORT split_reduce +{ + std::size_t split_size = 2048; + std::string name() const { return "split_reduce"; } + void apply(module_pass_manager& mpm) const; +}; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_MIGRAPHX_SPLIT_REDUCE_HPP + diff --git a/src/module.cpp b/src/module.cpp index 5091e92320f..37945df5d98 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -204,7 +204,7 @@ static std::vector insert_generic_instructions(module& m, instruction_ref ins, Range&& instructions, - std::unordered_map map_ins) + std::unordered_map& map_ins) { assert(m.has_instruction(ins) or is_end(ins, m.end())); std::vector mod_outputs; @@ -401,50 +401,53 @@ instruction_ref module::move_instructions(instruction_ref src, instruction_ref d std::vector module::add_instructions(const std::vector& instructions, - std::unordered_map map_ins) + std::unordered_map* map_ins) { - return this->insert_instructions(this->end(), instructions, std::move(map_ins)); + return this->insert_instructions(this->end(), instructions, map_ins); } std::vector module::add_instructions(const_module_ref m, - std::unordered_map map_ins) + std::unordered_map* map_ins) { - return this->insert_instructions(this->end(), m, std::move(map_ins)); + return this->insert_instructions(this->end(), m, map_ins); } std::vector module::add_instructions(instruction_ref start, instruction_ref last, - std::unordered_map map_ins) + std::unordered_map* map_ins) { - return this->insert_instructions(this->end(), start, last, std::move(map_ins)); + return this->insert_instructions(this->end(), start, last, map_ins); } std::vector module::insert_instructions(instruction_ref ins, const std::vector& instructions, - std::unordered_map map_ins) + std::unordered_map* map_ins) { - return insert_generic_instructions(*this, ins, instructions, std::move(map_ins)); + std::unordered_map default_map_ins; + return insert_generic_instructions(*this, ins, instructions, map_ins ? *map_ins : default_map_ins); } std::vector module::insert_instructions(instruction_ref ins, const_module_ref m, - std::unordered_map map_ins) + std::unordered_map* map_ins) { - return insert_generic_instructions(*this, ins, iterator_for(*m), std::move(map_ins)); + std::unordered_map default_map_ins; + return insert_generic_instructions(*this, ins, iterator_for(*m), map_ins ? *map_ins : default_map_ins); } std::vector module::insert_instructions(instruction_ref ins, instruction_ref start, instruction_ref last, - std::unordered_map map_ins) + std::unordered_map* map_ins) { auto r = range(start, last); - return insert_generic_instructions(*this, ins, iterator_for(r), std::move(map_ins)); + std::unordered_map default_map_ins; + return insert_generic_instructions(*this, ins, iterator_for(r), map_ins ? *map_ins : default_map_ins); } instruction_ref module::add_literal(literal l) { return insert_literal(begin(), std::move(l)); } @@ -462,6 +465,7 @@ instruction_ref module::add_parameter(std::string name, shape s) instruction_ref module::add_return(std::vector args) { + assert(std::all_of(args.begin(), args.end(), [&](auto ins) { return has_instruction(ins); })); shape instr_shape = compute_shape(builtin::returns{}, args); impl->push_back({builtin::returns{}, instr_shape, std::move(args)}); auto result = std::prev(impl->instructions.end()); @@ -486,6 +490,7 @@ instruction_ref module::insert_parameter(instruction_ref ins, std::string name, instruction_ref module::replace_return(std::vector args) { + assert(std::all_of(args.begin(), args.end(), [&](auto ins) { return has_instruction(ins); })); auto last = std::prev(this->end()); // If there is no return then add a return if(last->name() != "@return") diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp new file mode 100644 index 00000000000..eaf8675e758 --- /dev/null +++ b/src/split_reduce.cpp @@ -0,0 +1,207 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct split_fused_reduce +{ + std::vector axes{}; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.axes, "axes")); + } + + shape compute_shape(const std::vector& inputs, std::vector mods) const + { + if(mods.size() != 1) + MIGRAPHX_THROW("should have one submodule."); + const auto* sm = mods.front(); + if(sm->get_output_shapes().size() != 1) + MIGRAPHX_THROW("Only one output supported"); + auto names = sm->get_parameter_names(); + check_shapes{inputs, *this}.has(names.size()).same_ndims(); + std::sort(names.begin(), names.end()); + auto shapes = sm->get_parameter_shapes(); + // Check dimension matches for each input + if(not equal(names, inputs, [&](const auto& name, const auto& input) { + return shapes.at(name).lens() == input.lens(); + })) + MIGRAPHX_THROW("Dimenstion does not match the submodule."); + const auto& s = inputs.at(0); + auto lens = s.lens(); + if(lens != sm->get_output_shapes().front().lens()) + { + for(const auto& axis : axes) + { + lens[axis] = 1; + } + } + + return shape::from_permutation( + sm->get_output_shapes().front().type(), lens, find_permutation(inputs)); + } + + std::string name() const { return "split_fused_reduce"; } +}; +MIGRAPHX_REGISTER_OP(split_fused_reduce); + + +static bool is_reduce(const instruction& ins) +{ + return contains(ins.name(), "reduce"); +} + +static std::string param_name(std::size_t i, const std::string& prefix = "x") +{ + return prefix + std::to_string(i); +} + +struct module_with_inputs +{ + module mod; + std::vector inputs; +}; + +static std::pair split_module(module_ref m, const std::vector& splits, const std::vector& args) +{ + std::unordered_map param_map; + auto params = m->get_parameter_names(); + std::sort(params.begin(), params.end()); + std::transform(params.begin(), params.end(), args.begin(), std::inserter(param_map, param_map.begin()), [&](const std::string& name, instruction_ref arg) { + return std::make_pair(m->get_parameter(name), arg); + }); + + std::unordered_set selected_instructions; + fix([&](auto self, const std::vector& inputs) { + for(auto input:inputs) + { + if(contains(selected_instructions, input)) + continue; + selected_instructions.insert(input); + self(input->inputs()); + } + })(splits); + + std::vector instructions1; + // TODO: copy_if + for(auto ins:iterator_for(*m)) + { + if(not contains(selected_instructions, ins)) + continue; + instructions1.push_back(ins); + } + + std::vector inputs1; + for(auto ins:instructions1) + { + if(not contains(param_map, ins)) + continue; + inputs1.push_back(param_map[ins]); + } + module m1; + std::unordered_map map_ins1; + m1.add_instructions(instructions1, &map_ins1); + std::vector outputs; + std::transform(splits.begin(), splits.end(), std::back_inserter(outputs), [&](instruction_ref ins) { + return map_ins1.at(ins); + }); + m1.add_return(outputs); + + std::vector instructions2; + for(auto ins:iterator_for(*m)) + { + if(contains(selected_instructions, ins)) + continue; + // Input params can be used in both modules + std::vector input_params; + // TODO: Use join_inserter + std::copy_if(ins->inputs().begin(), ins->inputs().end(), std::back_inserter(input_params), [&](instruction_ref input) { + if(input->name() != "@param") + return false; + return not contains(instructions2, input); + }); + instructions2.insert(instructions2.end(), input_params.begin(), input_params.end()); + instructions2.push_back(ins); + } + + std::vector inputs2; + for(auto ins:instructions2) + { + if(not contains(param_map, ins)) + continue; + inputs2.push_back(param_map[ins]); + } + module m2; + std::unordered_map map_ins2; + std::size_t n = 0; + for(auto ins:splits) + map_ins2[ins] = m2.add_parameter(param_name(n++), ins->get_shape().as_standard()); + for(auto ins:iterator_for(*m)) + { + if(ins->name() != "@param") + continue; + if(not contains(instructions2, ins)) + continue; + map_ins2[ins] = m2.add_parameter(param_name(n++), ins->get_shape().as_standard()); + } + auto r = m2.add_instructions(instructions2, &map_ins2); + m2.add_return(r); + return {{std::move(m1), std::move(inputs1)}, {std::move(m2), std::move(inputs2)}}; +} + +static std::vector find_split(module_ref rm) +{ + std::vector result; + auto reduce_ins = std::find_if(rm->begin(), rm->end(), &is_reduce); + if(reduce_ins == rm->end()) + return result; + // Bail if there is more than one reduce for now + if(std::any_of(std::next(reduce_ins), rm->end(), &is_reduce)) + return result; + result.push_back(reduce_ins); + // TODO: Find instructions that are used again in the module + return result; +} + +void split_reduce::apply(module_pass_manager& mpm) const +{ + for(auto ins:iterator_for(mpm.get_module())) + { + if(ins->name() != "fused_reduce") + continue; + auto* rm = ins->module_inputs().front(); + auto splits = find_split(rm); + if(splits.empty()) + continue; + auto v = ins->get_operator().to_value(); + auto axes = v["axes"].to_vector(); + + auto mp = split_module(rm, splits, ins->inputs()); + auto* m1 = mpm.create_module(rm->name() + "_0", std::move(mp.first.mod)); + auto* m2 = mpm.create_module(rm->name() + "_1", std::move(mp.second.mod)); + m1->set_bypass(); + m2->set_bypass(); + auto split_reduce = mpm.get_module().insert_instruction(ins, make_op("split_fused_reduce", {{"axes", axes}}), mp.first.inputs, {m1}); + std::vector inputs = {split_reduce}; + inputs.insert(inputs.end(), mp.second.inputs.begin(), mp.second.inputs.end()); + mpm.get_module().replace_instruction(ins, make_op("fused_reduce"), inputs, {m2}); + } +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + diff --git a/src/split_single_dyn_dim.cpp b/src/split_single_dyn_dim.cpp index 0bd21521ed5..a5707c38674 100644 --- a/src/split_single_dyn_dim.cpp +++ b/src/split_single_dyn_dim.cpp @@ -140,7 +140,7 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const auto static_shape = dyn_param_shape.to_static(dim_size); map_ins[dyn_param] = submod->add_parameter(dd_check.dyn_param_str, static_shape); } - auto outputs = submod->add_instructions(mm, map_ins); + auto outputs = submod->add_instructions(mm, &map_ins); submod->add_return({outputs}); submodules.push_back(submod); } diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 88615ffc694..fe9ec540658 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -357,7 +357,7 @@ fold_pointwise_mod(instruction_ref pm_ins, pm->get_parameter(name), parent_mod->add_parameter(name, input->get_shape().as_standard())); }); - return parent_mod->insert_instructions(parent_mod->end(), pm, param_map); + return parent_mod->insert_instructions(parent_mod->end(), pm, ¶m_map); } // Whitelist supported fusion options, including imposing type constraints diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index cc0a136892d..d16077c9d12 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -54,6 +54,7 @@ #include #include #include +#include #include #include #include @@ -76,7 +77,7 @@ inline namespace MIGRAPHX_INLINE_NS { namespace gpu { MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS) -MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_REDUCE_FUSION) +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_REDUCE_FUSION)MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_SPLIT_REDUCE) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC) #ifndef _WIN32 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK) @@ -162,6 +163,8 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti dead_code_elimination{}, enable_pass(not enabled(MIGRAPHX_DISABLE_REDUCE_FUSION{}), fuse_reduce{}), dead_code_elimination{}, + enable_pass(enabled(MIGRAPHX_ENABLE_SPLIT_REDUCE{}), split_reduce{}), + dead_code_elimination{}, fuse_concat{}, dead_code_elimination{}, #ifndef _WIN32 From cc84a582618a6350ece1821bc7a5a41a8f9f9e69 Mon Sep 17 00:00:00 2001 From: Paul Date: Sun, 17 Mar 2024 14:26:19 -0700 Subject: [PATCH 02/59] Format --- src/include/migraphx/split_reduce.hpp | 1 - src/module.cpp | 9 ++-- src/split_reduce.cpp | 68 +++++++++++++++------------ src/targets/gpu/target.cpp | 10 ++-- 4 files changed, 49 insertions(+), 39 deletions(-) diff --git a/src/include/migraphx/split_reduce.hpp b/src/include/migraphx/split_reduce.hpp index 2fe26454b19..a7a21ca1bc8 100644 --- a/src/include/migraphx/split_reduce.hpp +++ b/src/include/migraphx/split_reduce.hpp @@ -19,4 +19,3 @@ struct MIGRAPHX_EXPORT split_reduce } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx #endif // MIGRAPHX_GUARD_MIGRAPHX_SPLIT_REDUCE_HPP - diff --git a/src/module.cpp b/src/module.cpp index 37945df5d98..12f5c55ca66 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -427,7 +427,8 @@ module::insert_instructions(instruction_ref ins, std::unordered_map* map_ins) { std::unordered_map default_map_ins; - return insert_generic_instructions(*this, ins, instructions, map_ins ? *map_ins : default_map_ins); + return insert_generic_instructions( + *this, ins, instructions, map_ins ? *map_ins : default_map_ins); } std::vector @@ -436,7 +437,8 @@ module::insert_instructions(instruction_ref ins, std::unordered_map* map_ins) { std::unordered_map default_map_ins; - return insert_generic_instructions(*this, ins, iterator_for(*m), map_ins ? *map_ins : default_map_ins); + return insert_generic_instructions( + *this, ins, iterator_for(*m), map_ins ? *map_ins : default_map_ins); } std::vector @@ -447,7 +449,8 @@ module::insert_instructions(instruction_ref ins, { auto r = range(start, last); std::unordered_map default_map_ins; - return insert_generic_instructions(*this, ins, iterator_for(r), map_ins ? *map_ins : default_map_ins); + return insert_generic_instructions( + *this, ins, iterator_for(r), map_ins ? *map_ins : default_map_ins); } instruction_ref module::add_literal(literal l) { return insert_literal(begin(), std::move(l)); } diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp index eaf8675e758..beab604a101 100644 --- a/src/split_reduce.cpp +++ b/src/split_reduce.cpp @@ -11,7 +11,6 @@ #include #include - namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -59,11 +58,7 @@ struct split_fused_reduce }; MIGRAPHX_REGISTER_OP(split_fused_reduce); - -static bool is_reduce(const instruction& ins) -{ - return contains(ins.name(), "reduce"); -} +static bool is_reduce(const instruction& ins) { return contains(ins.name(), "reduce"); } static std::string param_name(std::size_t i, const std::string& prefix = "x") { @@ -76,18 +71,25 @@ struct module_with_inputs std::vector inputs; }; -static std::pair split_module(module_ref m, const std::vector& splits, const std::vector& args) +static std::pair +split_module(module_ref m, + const std::vector& splits, + const std::vector& args) { std::unordered_map param_map; auto params = m->get_parameter_names(); std::sort(params.begin(), params.end()); - std::transform(params.begin(), params.end(), args.begin(), std::inserter(param_map, param_map.begin()), [&](const std::string& name, instruction_ref arg) { - return std::make_pair(m->get_parameter(name), arg); - }); + std::transform(params.begin(), + params.end(), + args.begin(), + std::inserter(param_map, param_map.begin()), + [&](const std::string& name, instruction_ref arg) { + return std::make_pair(m->get_parameter(name), arg); + }); std::unordered_set selected_instructions; fix([&](auto self, const std::vector& inputs) { - for(auto input:inputs) + for(auto input : inputs) { if(contains(selected_instructions, input)) continue; @@ -98,7 +100,7 @@ static std::pair split_module(module_ref std::vector instructions1; // TODO: copy_if - for(auto ins:iterator_for(*m)) + for(auto ins : iterator_for(*m)) { if(not contains(selected_instructions, ins)) continue; @@ -106,7 +108,7 @@ static std::pair split_module(module_ref } std::vector inputs1; - for(auto ins:instructions1) + for(auto ins : instructions1) { if(not contains(param_map, ins)) continue; @@ -116,30 +118,34 @@ static std::pair split_module(module_ref std::unordered_map map_ins1; m1.add_instructions(instructions1, &map_ins1); std::vector outputs; - std::transform(splits.begin(), splits.end(), std::back_inserter(outputs), [&](instruction_ref ins) { - return map_ins1.at(ins); - }); + std::transform(splits.begin(), + splits.end(), + std::back_inserter(outputs), + [&](instruction_ref ins) { return map_ins1.at(ins); }); m1.add_return(outputs); std::vector instructions2; - for(auto ins:iterator_for(*m)) + for(auto ins : iterator_for(*m)) { if(contains(selected_instructions, ins)) continue; // Input params can be used in both modules std::vector input_params; // TODO: Use join_inserter - std::copy_if(ins->inputs().begin(), ins->inputs().end(), std::back_inserter(input_params), [&](instruction_ref input) { - if(input->name() != "@param") - return false; - return not contains(instructions2, input); - }); + std::copy_if(ins->inputs().begin(), + ins->inputs().end(), + std::back_inserter(input_params), + [&](instruction_ref input) { + if(input->name() != "@param") + return false; + return not contains(instructions2, input); + }); instructions2.insert(instructions2.end(), input_params.begin(), input_params.end()); instructions2.push_back(ins); } std::vector inputs2; - for(auto ins:instructions2) + for(auto ins : instructions2) { if(not contains(param_map, ins)) continue; @@ -148,9 +154,9 @@ static std::pair split_module(module_ref module m2; std::unordered_map map_ins2; std::size_t n = 0; - for(auto ins:splits) + for(auto ins : splits) map_ins2[ins] = m2.add_parameter(param_name(n++), ins->get_shape().as_standard()); - for(auto ins:iterator_for(*m)) + for(auto ins : iterator_for(*m)) { if(ins->name() != "@param") continue; @@ -179,23 +185,24 @@ static std::vector find_split(module_ref rm) void split_reduce::apply(module_pass_manager& mpm) const { - for(auto ins:iterator_for(mpm.get_module())) + for(auto ins : iterator_for(mpm.get_module())) { if(ins->name() != "fused_reduce") continue; - auto* rm = ins->module_inputs().front(); + auto* rm = ins->module_inputs().front(); auto splits = find_split(rm); if(splits.empty()) continue; - auto v = ins->get_operator().to_value(); + auto v = ins->get_operator().to_value(); auto axes = v["axes"].to_vector(); - auto mp = split_module(rm, splits, ins->inputs()); + auto mp = split_module(rm, splits, ins->inputs()); auto* m1 = mpm.create_module(rm->name() + "_0", std::move(mp.first.mod)); auto* m2 = mpm.create_module(rm->name() + "_1", std::move(mp.second.mod)); m1->set_bypass(); m2->set_bypass(); - auto split_reduce = mpm.get_module().insert_instruction(ins, make_op("split_fused_reduce", {{"axes", axes}}), mp.first.inputs, {m1}); + auto split_reduce = mpm.get_module().insert_instruction( + ins, make_op("split_fused_reduce", {{"axes", axes}}), mp.first.inputs, {m1}); std::vector inputs = {split_reduce}; inputs.insert(inputs.end(), mp.second.inputs.begin(), mp.second.inputs.end()); mpm.get_module().replace_instruction(ins, make_op("fused_reduce"), inputs, {m2}); @@ -204,4 +211,3 @@ void split_reduce::apply(module_pass_manager& mpm) const } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx - diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index d16077c9d12..2d63f3430bb 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -77,13 +77,15 @@ inline namespace MIGRAPHX_INLINE_NS { namespace gpu { MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS) -MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_REDUCE_FUSION)MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_SPLIT_REDUCE) -MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC) +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_REDUCE_FUSION) +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_SPLIT_REDUCE) + MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC) #ifndef _WIN32 -MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK) + MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK) #endif -std::vector target::get_passes(migraphx::context& gctx, const compile_options& options) const + std::vector target::get_passes(migraphx::context& gctx, + const compile_options& options) const { auto& ctx = any_cast(gctx); ctx.set_exhaustive_tune_flag(options.exhaustive_tune); From 35ffc5a5406d1e526e8f9fd972cdb08a804affd8 Mon Sep 17 00:00:00 2001 From: Paul Date: Sun, 17 Mar 2024 14:26:37 -0700 Subject: [PATCH 03/59] Format --- src/targets/gpu/target.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index 2d63f3430bb..b3ed59df57b 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -79,13 +79,12 @@ namespace gpu { MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_REDUCE_FUSION) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_SPLIT_REDUCE) - MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC) +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC) #ifndef _WIN32 - MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK) +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK) #endif - std::vector target::get_passes(migraphx::context& gctx, - const compile_options& options) const +std::vector target::get_passes(migraphx::context& gctx, const compile_options& options) const { auto& ctx = any_cast(gctx); ctx.set_exhaustive_tune_flag(options.exhaustive_tune); From cad365ac75294f45b312ff124fab05f588655f11 Mon Sep 17 00:00:00 2001 From: Paul Date: Sun, 17 Mar 2024 14:47:23 -0700 Subject: [PATCH 04/59] Fixes --- src/split_reduce.cpp | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp index beab604a101..cc9f0373b66 100644 --- a/src/split_reduce.cpp +++ b/src/split_reduce.cpp @@ -195,17 +195,35 @@ void split_reduce::apply(module_pass_manager& mpm) const continue; auto v = ins->get_operator().to_value(); auto axes = v["axes"].to_vector(); + // TODO: Check reduction size auto mp = split_module(rm, splits, ins->inputs()); auto* m1 = mpm.create_module(rm->name() + "_0", std::move(mp.first.mod)); auto* m2 = mpm.create_module(rm->name() + "_1", std::move(mp.second.mod)); m1->set_bypass(); m2->set_bypass(); + + // Insert split reduce auto split_reduce = mpm.get_module().insert_instruction( ins, make_op("split_fused_reduce", {{"axes", axes}}), mp.first.inputs, {m1}); + std::vector inputs = {split_reduce}; inputs.insert(inputs.end(), mp.second.inputs.begin(), mp.second.inputs.end()); - mpm.get_module().replace_instruction(ins, make_op("fused_reduce"), inputs, {m2}); + auto param_names = m2->get_parameter_names(); + std::sort(param_names.begin(), param_names.end()); + + // TODO: Use get_ins_param_map function + std::unordered_map param_map; + std::transform(param_names.begin(), + param_names.end(), + inputs.begin(), + std::inserter(param_map, param_map.begin()), + [&](const std::string& name, instruction_ref input) { + return std::make_pair(m2->get_parameter(name), input); + }); + auto replaced = mpm.get_module().insert_instructions(ins, m2, ¶m_map); + assert(replaced.size() == 1); + mpm.get_module().replace_instruction(ins, replaced.front()); } } From 363afbe1494bbfb82cdd31d46e34d77eaba9a6b2 Mon Sep 17 00:00:00 2001 From: Paul Date: Sun, 17 Mar 2024 14:47:29 -0700 Subject: [PATCH 05/59] Format --- src/split_reduce.cpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp index cc9f0373b66..fc90e58b3b3 100644 --- a/src/split_reduce.cpp +++ b/src/split_reduce.cpp @@ -202,7 +202,7 @@ void split_reduce::apply(module_pass_manager& mpm) const auto* m2 = mpm.create_module(rm->name() + "_1", std::move(mp.second.mod)); m1->set_bypass(); m2->set_bypass(); - + // Insert split reduce auto split_reduce = mpm.get_module().insert_instruction( ins, make_op("split_fused_reduce", {{"axes", axes}}), mp.first.inputs, {m1}); @@ -215,12 +215,12 @@ void split_reduce::apply(module_pass_manager& mpm) const // TODO: Use get_ins_param_map function std::unordered_map param_map; std::transform(param_names.begin(), - param_names.end(), - inputs.begin(), - std::inserter(param_map, param_map.begin()), - [&](const std::string& name, instruction_ref input) { - return std::make_pair(m2->get_parameter(name), input); - }); + param_names.end(), + inputs.begin(), + std::inserter(param_map, param_map.begin()), + [&](const std::string& name, instruction_ref input) { + return std::make_pair(m2->get_parameter(name), input); + }); auto replaced = mpm.get_module().insert_instructions(ins, m2, ¶m_map); assert(replaced.size() == 1); mpm.get_module().replace_instruction(ins, replaced.front()); From 8f007f7d34eba5ded7f52dc3e06dcba82a70e8d9 Mon Sep 17 00:00:00 2001 From: Paul Date: Sun, 17 Mar 2024 17:50:41 -0700 Subject: [PATCH 06/59] Implement split reduce_kernel --- src/split_reduce.cpp | 18 +++++- src/targets/gpu/jit/reduce.cpp | 62 ++++++++++++++++++- .../include/migraphx/kernels/reduce.hpp | 10 +-- .../kernels/scatter_reduction_modes.hpp | 6 +- 4 files changed, 85 insertions(+), 11 deletions(-) diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp index fc90e58b3b3..f193ca35d32 100644 --- a/src/split_reduce.cpp +++ b/src/split_reduce.cpp @@ -17,11 +17,13 @@ inline namespace MIGRAPHX_INLINE_NS { struct split_fused_reduce { std::vector axes{}; + std::string assign = "assign_none"; template static auto reflect(Self& self, F f) { - return pack(f(self.axes, "axes")); + return pack(f(self.axes, "axes"), + f(self.assign, "assign")); } shape compute_shape(const std::vector& inputs, std::vector mods) const @@ -183,6 +185,18 @@ static std::vector find_split(module_ref rm) return result; } +static std::string assign_op(const std::vector& splits) +{ + static std::unordered_map m = { + {"reduce_sum", "assign_add"}, + {"reduce_mean", "assign_add"}, + {"reduce_prod", "assign_mul"}, + {"reduce_max", "assign_max"}, + {"reduce_min", "assign_min"}, + }; + return m.at(splits.front()->name()); +} + void split_reduce::apply(module_pass_manager& mpm) const { for(auto ins : iterator_for(mpm.get_module())) @@ -205,7 +219,7 @@ void split_reduce::apply(module_pass_manager& mpm) const // Insert split reduce auto split_reduce = mpm.get_module().insert_instruction( - ins, make_op("split_fused_reduce", {{"axes", axes}}), mp.first.inputs, {m1}); + ins, make_op("split_fused_reduce", {{"axes", axes}, {"assign", assign_op(splits)}}), mp.first.inputs, {m1}); std::vector inputs = {split_reduce}; inputs.insert(inputs.end(), mp.second.inputs.begin(), mp.second.inputs.end()); diff --git a/src/targets/gpu/jit/reduce.cpp b/src/targets/gpu/jit/reduce.cpp index fe00cb3fcc4..18c0802f7a5 100644 --- a/src/targets/gpu/jit/reduce.cpp +++ b/src/targets/gpu/jit/reduce.cpp @@ -27,6 +27,8 @@ #include #include #include +#include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -131,6 +133,58 @@ static std::size_t compute_subwave_size(context& ctx, std::size_t n) return wavefront_size; } +static std::vector split_reduce(const std::vector& inputs, std::size_t min_size = 1024) +{ + std::vector result; + auto input_shape = inputs.front(); + auto reduce_shape = inputs[inputs.size() - 2]; + auto output_shape = inputs[inputs.size() - 1]; + + auto is = range(reduce_shape.lens().size()); + using array_type = std::array; + auto initial = array_type{std::numeric_limits::max(), std::numeric_limits::max()}; + auto faxis = transform_accumulate(is.begin(), is.end(), initial, MIGRAPHX_LIFT(std::min), [&](auto i) -> array_type { + if(input_shape.lens()[i] == output_shape.lens()[i]) + return initial; + return {input_shape.strides()[i], std::size_t(i)}; + })[1]; + + assert(faxis < reduce_shape.lens().size()); + + std::size_t n = 1; + auto r = input_shape.lens()[faxis]; + auto factors = make_array(2, 3, 5, 7, 11); + while(r > min_size) + { + auto it = std::find_if(factors.begin(), factors.end(), [&](auto d) { + return r % d == 0; + }); + if (it == factors.end()) + break; + r /= *it; + n *= *it; + } + assert(n != 1); + std::transform(inputs.begin(), inputs.end(), std::back_inserter(result), [&](const shape& s) -> shape { + auto lens = s.lens(); + auto strides = s.strides(); + + lens.push_back(n); + if(lens[faxis] == 1) + { + strides.push_back(0); + } + else + { + lens[faxis] /= n; + strides.push_back(strides[faxis] * lens[faxis]); + } + + return {s.type(), lens, strides}; + }); + return reduce_dims(normalize_permutation(result)); +} + struct simple_reduce_compiler : compiler { std::vector names() const @@ -231,7 +285,7 @@ extern "C" { MIGRAPHX_GLOBAL void ${kernel}(${params}) { transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, auto... xs) { - fused_reduce(y, partial(${lambda})(xs...)); + fused_reduce(y, ${assign}{}, partial(${lambda})(xs...)); }); } @@ -243,15 +297,18 @@ MIGRAPHX_GLOBAL void ${kernel}(${params}) struct fused_reduce_compiler : compiler { - std::vector names() const { return {"fused_reduce"}; } + std::vector names() const { return {"fused_reduce", "split_fused_reduce"}; } operation compile_op(context& ctx, const std::vector& inputs, const value& v) const { + auto assign = v.get("assign", "assign_none"); auto axes = v.at("axes").to_vector(); auto virtual_inputs = inputs; virtual_inputs.push_back(get_reduced_shape(inputs.front(), axes)); virtual_inputs.push_back(get_output_shape(inputs.front(), axes)); virtual_inputs = reduce_dims(normalize_permutation(virtual_inputs)); + if(assign != "assign_none") + virtual_inputs = split_reduce(virtual_inputs); auto reduce_output_shape = virtual_inputs.back(); virtual_inputs.pop_back(); auto reduction_shape = virtual_inputs.back(); @@ -303,6 +360,7 @@ struct fused_reduce_compiler : compiler {{"kernel", options.kernel_name}, {"params", enum_params(inputs.size(), "void * private_p")}, {"args", enum_params(inputs.size(), "private_p")}, + {"assign", assign}, {"algo", algo}, {"reduced", "decltype(" + generate_make_shape(reduce_output_shape) + ")"}, {"lambda", v.at("lambda").to()}, diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp index e37f7423147..9adf8832beb 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp @@ -28,6 +28,7 @@ #include #include #include +#include #include namespace migraphx { @@ -730,21 +731,22 @@ simple_reduce(Op op, T init, Input input, Output output, ReadInput read, WriteOu }); } -template -__device__ void fused_reduce(Output output, F f) +template +__device__ void fused_reduce(Output output, Assign assign, F f) { Algo::template run([&](auto out_idx, auto r) { auto result = f(r, out_idx); if constexpr(reduce::is_inner_storage{}) { - r.inner([&](auto& y, auto x) { y = x; })(output, result); + r.inner([&](auto& y, auto x) { assign(y, x); })(output, result); } else { - r.outer([&] { output[out_idx] = implicit_conversion(result); }); + r.outer([&] { assign(output[out_idx], implicit_conversion(result)); }); } }); } + } // namespace migraphx #endif // MIGRAPHX_GUARD_KERNELS_REDUCE_HPP diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/scatter_reduction_modes.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/scatter_reduction_modes.hpp index 4081c313ad4..3a6ea756baa 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/scatter_reduction_modes.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/scatter_reduction_modes.hpp @@ -42,7 +42,7 @@ struct assign_add template MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const { - atomicAdd(&x, y); + atomicAdd(&x, T(y)); } }; @@ -66,7 +66,7 @@ struct assign_max template MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const { - atomicMax(&x, y); + atomicMax(&x, T(y)); } }; @@ -75,7 +75,7 @@ struct assign_min template MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const { - atomicMin(&x, y); + atomicMin(&x, T(y)); } }; From aed9edb91c12a21734868fac88649786f3dd7b1f Mon Sep 17 00:00:00 2001 From: Paul Date: Sun, 17 Mar 2024 17:50:52 -0700 Subject: [PATCH 07/59] Format --- src/split_reduce.cpp | 8 +- src/targets/gpu/jit/reduce.cpp | 80 ++++++++++--------- .../include/migraphx/kernels/reduce.hpp | 1 - 3 files changed, 46 insertions(+), 43 deletions(-) diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp index f193ca35d32..c6656603991 100644 --- a/src/split_reduce.cpp +++ b/src/split_reduce.cpp @@ -22,8 +22,7 @@ struct split_fused_reduce template static auto reflect(Self& self, F f) { - return pack(f(self.axes, "axes"), - f(self.assign, "assign")); + return pack(f(self.axes, "axes"), f(self.assign, "assign")); } shape compute_shape(const std::vector& inputs, std::vector mods) const @@ -219,7 +218,10 @@ void split_reduce::apply(module_pass_manager& mpm) const // Insert split reduce auto split_reduce = mpm.get_module().insert_instruction( - ins, make_op("split_fused_reduce", {{"axes", axes}, {"assign", assign_op(splits)}}), mp.first.inputs, {m1}); + ins, + make_op("split_fused_reduce", {{"axes", axes}, {"assign", assign_op(splits)}}), + mp.first.inputs, + {m1}); std::vector inputs = {split_reduce}; inputs.insert(inputs.end(), mp.second.inputs.begin(), mp.second.inputs.end()); diff --git a/src/targets/gpu/jit/reduce.cpp b/src/targets/gpu/jit/reduce.cpp index 18c0802f7a5..c3ba3e39b8b 100644 --- a/src/targets/gpu/jit/reduce.cpp +++ b/src/targets/gpu/jit/reduce.cpp @@ -133,55 +133,57 @@ static std::size_t compute_subwave_size(context& ctx, std::size_t n) return wavefront_size; } -static std::vector split_reduce(const std::vector& inputs, std::size_t min_size = 1024) +static std::vector split_reduce(const std::vector& inputs, + std::size_t min_size = 1024) { std::vector result; - auto input_shape = inputs.front(); + auto input_shape = inputs.front(); auto reduce_shape = inputs[inputs.size() - 2]; auto output_shape = inputs[inputs.size() - 1]; - auto is = range(reduce_shape.lens().size()); + auto is = range(reduce_shape.lens().size()); using array_type = std::array; - auto initial = array_type{std::numeric_limits::max(), std::numeric_limits::max()}; - auto faxis = transform_accumulate(is.begin(), is.end(), initial, MIGRAPHX_LIFT(std::min), [&](auto i) -> array_type { - if(input_shape.lens()[i] == output_shape.lens()[i]) - return initial; - return {input_shape.strides()[i], std::size_t(i)}; - })[1]; + auto initial = array_type{std::numeric_limits::max(), + std::numeric_limits::max()}; + auto faxis = transform_accumulate( + is.begin(), is.end(), initial, MIGRAPHX_LIFT(std::min), [&](auto i) -> array_type { + if(input_shape.lens()[i] == output_shape.lens()[i]) + return initial; + return {input_shape.strides()[i], std::size_t(i)}; + })[1]; assert(faxis < reduce_shape.lens().size()); std::size_t n = 1; - auto r = input_shape.lens()[faxis]; - auto factors = make_array(2, 3, 5, 7, 11); + auto r = input_shape.lens()[faxis]; + auto factors = make_array(2, 3, 5, 7, 11); while(r > min_size) { - auto it = std::find_if(factors.begin(), factors.end(), [&](auto d) { - return r % d == 0; - }); - if (it == factors.end()) + auto it = std::find_if(factors.begin(), factors.end(), [&](auto d) { return r % d == 0; }); + if(it == factors.end()) break; r /= *it; n *= *it; } assert(n != 1); - std::transform(inputs.begin(), inputs.end(), std::back_inserter(result), [&](const shape& s) -> shape { - auto lens = s.lens(); - auto strides = s.strides(); + std::transform( + inputs.begin(), inputs.end(), std::back_inserter(result), [&](const shape& s) -> shape { + auto lens = s.lens(); + auto strides = s.strides(); - lens.push_back(n); - if(lens[faxis] == 1) - { - strides.push_back(0); - } - else - { - lens[faxis] /= n; - strides.push_back(strides[faxis] * lens[faxis]); - } + lens.push_back(n); + if(lens[faxis] == 1) + { + strides.push_back(0); + } + else + { + lens[faxis] /= n; + strides.push_back(strides[faxis] * lens[faxis]); + } - return {s.type(), lens, strides}; - }); + return {s.type(), lens, strides}; + }); return reduce_dims(normalize_permutation(result)); } @@ -301,7 +303,7 @@ struct fused_reduce_compiler : compiler operation compile_op(context& ctx, const std::vector& inputs, const value& v) const { - auto assign = v.get("assign", "assign_none"); + auto assign = v.get("assign", "assign_none"); auto axes = v.at("axes").to_vector(); auto virtual_inputs = inputs; virtual_inputs.push_back(get_reduced_shape(inputs.front(), axes)); @@ -358,14 +360,14 @@ struct fused_reduce_compiler : compiler auto src = interpolate_string( fused_reduce_kernel, {{"kernel", options.kernel_name}, - {"params", enum_params(inputs.size(), "void * private_p")}, - {"args", enum_params(inputs.size(), "private_p")}, - {"assign", assign}, - {"algo", algo}, - {"reduced", "decltype(" + generate_make_shape(reduce_output_shape) + ")"}, - {"lambda", v.at("lambda").to()}, - {"transformers", make_transformer_args(vec)}, - {"preamble", v.get("preamble", std::string{})}}); + {"params", enum_params(inputs.size(), "void * private_p")}, + {"args", enum_params(inputs.size(), "private_p")}, + {"assign", assign}, + {"algo", algo}, + {"reduced", "decltype(" + generate_make_shape(reduce_output_shape) + ")"}, + {"lambda", v.at("lambda").to()}, + {"transformers", make_transformer_args(vec)}, + {"preamble", v.get("preamble", std::string{})}}); options.emplace_param("-Wno-float-equal"); return compile_hip_code_object(src, options); } diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp index 9adf8832beb..b2fb0f4b00f 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp @@ -747,6 +747,5 @@ __device__ void fused_reduce(Output output, Assign assign, F f) }); } - } // namespace migraphx #endif // MIGRAPHX_GUARD_KERNELS_REDUCE_HPP From 63be4df239cc0e16207d4e54ad77946f08d3038a Mon Sep 17 00:00:00 2001 From: Paul Date: Mon, 18 Mar 2024 07:41:49 -0700 Subject: [PATCH 08/59] Use unsafeAtomicAdd --- .../include/migraphx/kernels/scatter_reduction_modes.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/scatter_reduction_modes.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/scatter_reduction_modes.hpp index 3a6ea756baa..07d066a1589 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/scatter_reduction_modes.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/scatter_reduction_modes.hpp @@ -42,7 +42,7 @@ struct assign_add template MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const { - atomicAdd(&x, T(y)); + unsafeAtomicAdd(&x, T(y)); } }; From 3f052e97b6de57f6f5fbc23df29284fe5aba8da9 Mon Sep 17 00:00:00 2001 From: Paul Date: Mon, 18 Mar 2024 14:30:57 -0700 Subject: [PATCH 09/59] Use array --- src/split_reduce.cpp | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp index c6656603991..a1d70badae5 100644 --- a/src/split_reduce.cpp +++ b/src/split_reduce.cpp @@ -72,7 +72,7 @@ struct module_with_inputs std::vector inputs; }; -static std::pair +static std::array split_module(module_ref m, const std::vector& splits, const std::vector& args) @@ -167,7 +167,7 @@ split_module(module_ref m, } auto r = m2.add_instructions(instructions2, &map_ins2); m2.add_return(r); - return {{std::move(m1), std::move(inputs1)}, {std::move(m2), std::move(inputs2)}}; + return {{{std::move(m1), std::move(inputs1)}, {std::move(m2), std::move(inputs2)}}}; } static std::vector find_split(module_ref rm) @@ -206,13 +206,17 @@ void split_reduce::apply(module_pass_manager& mpm) const auto splits = find_split(rm); if(splits.empty()) continue; + if(not std::all_of(splits.begin(), splits.end(), [](instruction_ref split) { + return split->get_shape().type() == shape::float_type; + })) + continue; auto v = ins->get_operator().to_value(); auto axes = v["axes"].to_vector(); // TODO: Check reduction size auto mp = split_module(rm, splits, ins->inputs()); - auto* m1 = mpm.create_module(rm->name() + "_0", std::move(mp.first.mod)); - auto* m2 = mpm.create_module(rm->name() + "_1", std::move(mp.second.mod)); + auto* m1 = mpm.create_module(rm->name() + "_0", std::move(mp[0].mod)); + auto* m2 = mpm.create_module(rm->name() + "_1", std::move(mp[1].mod)); m1->set_bypass(); m2->set_bypass(); @@ -220,11 +224,11 @@ void split_reduce::apply(module_pass_manager& mpm) const auto split_reduce = mpm.get_module().insert_instruction( ins, make_op("split_fused_reduce", {{"axes", axes}, {"assign", assign_op(splits)}}), - mp.first.inputs, + mp[0].inputs, {m1}); std::vector inputs = {split_reduce}; - inputs.insert(inputs.end(), mp.second.inputs.begin(), mp.second.inputs.end()); + inputs.insert(inputs.end(), mp[1].inputs.begin(), mp[1].inputs.end()); auto param_names = m2->get_parameter_names(); std::sort(param_names.begin(), param_names.end()); From 2fc9d69f444a80d9c8eedb8b3ea4f907eb113ef9 Mon Sep 17 00:00:00 2001 From: Paul Date: Mon, 18 Mar 2024 14:31:04 -0700 Subject: [PATCH 10/59] Format --- src/split_reduce.cpp | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp index a1d70badae5..49db13de3c9 100644 --- a/src/split_reduce.cpp +++ b/src/split_reduce.cpp @@ -72,10 +72,9 @@ struct module_with_inputs std::vector inputs; }; -static std::array -split_module(module_ref m, - const std::vector& splits, - const std::vector& args) +static std::array split_module(module_ref m, + const std::vector& splits, + const std::vector& args) { std::unordered_map param_map; auto params = m->get_parameter_names(); @@ -207,8 +206,8 @@ void split_reduce::apply(module_pass_manager& mpm) const if(splits.empty()) continue; if(not std::all_of(splits.begin(), splits.end(), [](instruction_ref split) { - return split->get_shape().type() == shape::float_type; - })) + return split->get_shape().type() == shape::float_type; + })) continue; auto v = ins->get_operator().to_value(); auto axes = v["axes"].to_vector(); From 6d4a878170747bd95626aa6376a330fd5966ce47 Mon Sep 17 00:00:00 2001 From: Paul Date: Mon, 18 Mar 2024 14:48:04 -0700 Subject: [PATCH 11/59] Format --- src/include/migraphx/module.hpp | 14 +++++++------- src/module.cpp | 15 ++++++++++----- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index e94ee6c3388..c0cf0868575 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -128,37 +128,37 @@ struct MIGRAPHX_EXPORT module std::vector add_instructions(const std::vector& instructions, std::unordered_map* map_ins = nullptr, - inserter insert = nullptr); + inserter insert = nullptr); std::vector add_instructions(const_module_ref m, std::unordered_map* map_ins = nullptr, - inserter insert = nullptr); + inserter insert = nullptr); std::vector add_instructions(instruction_ref start, instruction_ref last, std::unordered_map* map_ins = nullptr, - inserter insert = nullptr); + inserter insert = nullptr); std::vector insert_instructions(instruction_ref ins, const std::vector& instructions, std::unordered_map* map_ins = nullptr, - inserter insert = nullptr); + inserter insert = nullptr); std::vector insert_instructions(instruction_ref ins, const_module_ref m, std::unordered_map* map_ins = nullptr, - inserter insert = nullptr); - + inserter insert = nullptr); + std::vector insert_instructions(instruction_ref ins, instruction_ref start, instruction_ref last, std::unordered_map* map_ins = nullptr, - inserter insert = nullptr); + inserter insert = nullptr); template instruction_ref add_literal(Ts&&... xs) diff --git a/src/module.cpp b/src/module.cpp index 886452918a4..52d57782823 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -431,7 +431,8 @@ module::add_instructions(const std::vector& instructions, std::vector module::add_instructions(const_module_ref m, - std::unordered_map* map_ins, module::inserter insert) + std::unordered_map* map_ins, + module::inserter insert) { return this->insert_instructions(this->end(), m, map_ins, std::move(insert)); } @@ -439,7 +440,8 @@ module::add_instructions(const_module_ref m, std::vector module::add_instructions(instruction_ref start, instruction_ref last, - std::unordered_map* map_ins, module::inserter insert) + std::unordered_map* map_ins, + module::inserter insert) { return this->insert_instructions(this->end(), start, last, map_ins, std::move(insert)); } @@ -447,7 +449,8 @@ module::add_instructions(instruction_ref start, std::vector module::insert_instructions(instruction_ref ins, const std::vector& instructions, - std::unordered_map* map_ins, module::inserter insert) + std::unordered_map* map_ins, + module::inserter insert) { std::unordered_map default_map_ins; return insert_generic_instructions( @@ -457,7 +460,8 @@ module::insert_instructions(instruction_ref ins, std::vector module::insert_instructions(instruction_ref ins, const_module_ref m, - std::unordered_map* map_ins, module::inserter insert) + std::unordered_map* map_ins, + module::inserter insert) { std::unordered_map default_map_ins; return insert_generic_instructions( @@ -468,7 +472,8 @@ std::vector module::insert_instructions(instruction_ref ins, instruction_ref start, instruction_ref last, - std::unordered_map* map_ins, module::inserter insert) + std::unordered_map* map_ins, + module::inserter insert) { auto r = range(start, last); std::unordered_map default_map_ins; From 17e91285cb18289d75da6c80f34ea53af206b2f8 Mon Sep 17 00:00:00 2001 From: Paul Date: Mon, 18 Mar 2024 15:29:15 -0700 Subject: [PATCH 12/59] Fix test --- src/module.cpp | 6 +++--- test/module_test.cpp | 22 ++++++++++++++++++---- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/src/module.cpp b/src/module.cpp index 52d57782823..e2d40fecad1 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -453,7 +453,7 @@ module::insert_instructions(instruction_ref ins, module::inserter insert) { std::unordered_map default_map_ins; - return insert_generic_instructions( + return insert_generic_instructions_impl( *this, ins, instructions, map_ins ? *map_ins : default_map_ins, std::move(insert)); } @@ -464,7 +464,7 @@ module::insert_instructions(instruction_ref ins, module::inserter insert) { std::unordered_map default_map_ins; - return insert_generic_instructions( + return insert_generic_instructions_impl( *this, ins, iterator_for(*m), map_ins ? *map_ins : default_map_ins, std::move(insert)); } @@ -477,7 +477,7 @@ module::insert_instructions(instruction_ref ins, { auto r = range(start, last); std::unordered_map default_map_ins; - return insert_generic_instructions( + return insert_generic_instructions_impl( *this, ins, iterator_for(r), map_ins ? *map_ins : default_map_ins, std::move(insert)); } diff --git a/test/module_test.cpp b/test/module_test.cpp index a21e565a7b1..f6eeff69cb1 100644 --- a/test/module_test.cpp +++ b/test/module_test.cpp @@ -309,6 +309,20 @@ TEST_CASE(parameter_name_order) EXPECT(param_names == names1); } +struct map_ins +{ + map_ins(std::unordered_map x) + : m(std::move(x)) + {} + + operator std::unordered_map*() + { + return &m; + } + + std::unordered_map m; +} + TEST_CASE(insert_instructions_module) { migraphx::shape s{migraphx::shape::int32_type, {1}}; @@ -321,7 +335,7 @@ TEST_CASE(insert_instructions_module) auto x2 = m2.add_parameter("x2", s); m2.add_instruction(migraphx::make_op("sqrt"), {x2}); - m1.insert_instructions(sqrt, &m2, {{x2, x1}}); + m1.insert_instructions(sqrt, &m2, map_ins{{x2, x1}}); EXPECT(std::prev(sqrt)->name() == "sqrt"); EXPECT(std::count_if(m1.begin(), m1.end(), [](auto&& ins) { return ins.name() == "sqrt"; }) == @@ -343,7 +357,7 @@ TEST_CASE(add_instructions_module) auto x2 = m2.add_parameter("x2", s); m2.add_instruction(migraphx::make_op("sqrt"), {x2}); - m1.add_instructions(&m2, {{x2, x1}}); + m1.add_instructions(&m2, map_ins{{x2, x1}}); EXPECT(std::count_if(m1.begin(), m1.end(), [](auto&& ins) { return ins.name() == "sqrt"; }) == 2); @@ -364,7 +378,7 @@ TEST_CASE(add_instructions_range) auto x2 = m2.add_parameter("x2", s); auto sqrt2 = m2.add_instruction(migraphx::make_op("sqrt"), {x2}); - m1.add_instructions(sqrt2, m2.end(), {{x2, x1}}); + m1.add_instructions(sqrt2, m2.end(), map_ins{{x2, x1}}); EXPECT(std::any_of( m1.begin(), m1.end(), [&](auto&& ins) { return migraphx::contains(ins.inputs(), x1); })); @@ -387,7 +401,7 @@ TEST_CASE(add_instructions_vector) auto x2 = m2.add_parameter("x2", s); auto sqrt2 = m2.add_instruction(migraphx::make_op("sqrt"), {x2}); - m1.add_instructions({sqrt2}, {{x2, x1}}); + m1.add_instructions({sqrt2}, map_ins{{x2, x1}}); EXPECT(std::any_of( m1.begin(), m1.end(), [&](auto&& ins) { return migraphx::contains(ins.inputs(), x1); })); From b907737cfbb3a091c385c168f0dca6557f95bd6d Mon Sep 17 00:00:00 2001 From: Paul Date: Mon, 18 Mar 2024 15:29:19 -0700 Subject: [PATCH 13/59] Format --- test/module_test.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/module_test.cpp b/test/module_test.cpp index f6eeff69cb1..c8ae5ffe588 100644 --- a/test/module_test.cpp +++ b/test/module_test.cpp @@ -312,8 +312,9 @@ TEST_CASE(parameter_name_order) struct map_ins { map_ins(std::unordered_map x) - : m(std::move(x)) - {} + : m(std::move(x)) + { + } operator std::unordered_map*() { From 10265027e9c7e1b5fc461184b7de8b9bd46e6302 Mon Sep 17 00:00:00 2001 From: Paul Date: Mon, 18 Mar 2024 15:42:35 -0700 Subject: [PATCH 14/59] Fix tests --- src/module.cpp | 12 ++++++------ test/gpu/mlir.cpp | 2 +- test/module_test.cpp | 11 ++++++----- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/src/module.cpp b/src/module.cpp index e2d40fecad1..b2042c1c6e1 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -261,20 +261,20 @@ static std::vector insert_generic_instructions(module& m, instruction_ref ins, Range&& instructions, - std::unordered_map map_ins, + std::unordered_map& map_ins, module::inserter insert) { if(insert == nullptr) return insert_generic_instructions_impl(m, ins, static_cast(instructions), - std::move(map_ins), + map_ins, [](module& mm, auto&&... xs) { return mm.insert_instruction( std::forward(xs)...); }); return insert_generic_instructions_impl( - m, ins, static_cast(instructions), std::move(map_ins), insert); + m, ins, static_cast(instructions), map_ins, insert); } instruction_ref module::add_instruction(const operation& op, std::vector args) @@ -453,7 +453,7 @@ module::insert_instructions(instruction_ref ins, module::inserter insert) { std::unordered_map default_map_ins; - return insert_generic_instructions_impl( + return insert_generic_instructions( *this, ins, instructions, map_ins ? *map_ins : default_map_ins, std::move(insert)); } @@ -464,7 +464,7 @@ module::insert_instructions(instruction_ref ins, module::inserter insert) { std::unordered_map default_map_ins; - return insert_generic_instructions_impl( + return insert_generic_instructions( *this, ins, iterator_for(*m), map_ins ? *map_ins : default_map_ins, std::move(insert)); } @@ -477,7 +477,7 @@ module::insert_instructions(instruction_ref ins, { auto r = range(start, last); std::unordered_map default_map_ins; - return insert_generic_instructions_impl( + return insert_generic_instructions( *this, ins, iterator_for(r), map_ins ? *map_ins : default_map_ins, std::move(insert)); } diff --git a/test/gpu/mlir.cpp b/test/gpu/mlir.cpp index ae3726729d4..7ea6c137cad 100644 --- a/test/gpu/mlir.cpp +++ b/test/gpu/mlir.cpp @@ -79,7 +79,7 @@ migraphx::module create_mlir_submodule(const migraphx::module& mmlir) auto param = mmlir.get_parameter(name); map_ins[param] = m.add_parameter(name, param->get_shape().as_standard()); } - auto y = m.add_instructions(&mmlir, map_ins); + auto y = m.add_instructions(&mmlir, &map_ins); m.add_return(y); return m; } diff --git a/test/module_test.cpp b/test/module_test.cpp index c8ae5ffe588..14b63863a14 100644 --- a/test/module_test.cpp +++ b/test/module_test.cpp @@ -311,18 +311,19 @@ TEST_CASE(parameter_name_order) struct map_ins { - map_ins(std::unordered_map x) - : m(std::move(x)) + using type = std::unordered_map; + map_ins(std::initializer_list x) + : m(x) { } - operator std::unordered_map*() + operator type*() { return &m; } - std::unordered_map m; -} + type m; +}; TEST_CASE(insert_instructions_module) { From a45febeaeb455aa20abccbea991d0f1e612cc401 Mon Sep 17 00:00:00 2001 From: Paul Date: Mon, 18 Mar 2024 15:42:39 -0700 Subject: [PATCH 15/59] Format --- src/module.cpp | 12 ++++-------- test/module_test.cpp | 10 ++-------- 2 files changed, 6 insertions(+), 16 deletions(-) diff --git a/src/module.cpp b/src/module.cpp index b2042c1c6e1..1d807e983ec 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -265,14 +265,10 @@ insert_generic_instructions(module& m, module::inserter insert) { if(insert == nullptr) - return insert_generic_instructions_impl(m, - ins, - static_cast(instructions), - map_ins, - [](module& mm, auto&&... xs) { - return mm.insert_instruction( - std::forward(xs)...); - }); + return insert_generic_instructions_impl( + m, ins, static_cast(instructions), map_ins, [](module& mm, auto&&... xs) { + return mm.insert_instruction(std::forward(xs)...); + }); return insert_generic_instructions_impl( m, ins, static_cast(instructions), map_ins, insert); } diff --git a/test/module_test.cpp b/test/module_test.cpp index 14b63863a14..d73b1afc7c3 100644 --- a/test/module_test.cpp +++ b/test/module_test.cpp @@ -312,15 +312,9 @@ TEST_CASE(parameter_name_order) struct map_ins { using type = std::unordered_map; - map_ins(std::initializer_list x) - : m(x) - { - } + map_ins(std::initializer_list x) : m(x) {} - operator type*() - { - return &m; - } + operator type*() { return &m; } type m; }; From 369c95228e3fcb6106252b69f2c8f10455d809cd Mon Sep 17 00:00:00 2001 From: Paul Date: Mon, 18 Mar 2024 17:02:04 -0700 Subject: [PATCH 16/59] Some refactoring --- src/fuse_reduce.cpp | 19 +-------- src/include/migraphx/module.hpp | 2 + src/module.cpp | 42 ++++++++++++++++++-- src/split_reduce.cpp | 68 ++++++++++++++++----------------- 4 files changed, 75 insertions(+), 56 deletions(-) diff --git a/src/fuse_reduce.cpp b/src/fuse_reduce.cpp index b33c65c0c0a..7f60d5ebe70 100644 --- a/src/fuse_reduce.cpp +++ b/src/fuse_reduce.cpp @@ -83,23 +83,6 @@ struct fused_reduce }; MIGRAPHX_REGISTER_OP(fused_reduce); -static std::unordered_map -get_ins_param_map(const std::vector& inputs, const_module_ref sm) -{ - std::unordered_map result; - auto names = sm->get_parameter_names(); - std::sort(names.begin(), names.end()); - assert(names.size() == inputs.size()); - std::transform(names.begin(), - names.end(), - inputs.begin(), - std::inserter(result, result.end()), - [&](const auto& name, auto input) { - return std::make_pair(input, sm->get_parameter(name)); - }); - return result; -} - static void insert_params(module_ref sm, const std::vector& inputs, std::unordered_map& map_ins) @@ -136,7 +119,7 @@ insert_module_in_submodule(module_ref sm, module::inserter insert = nullptr) { insert_params(sm, inputs, map_ins); - auto param_map = get_ins_param_map(inputs, m); + auto param_map = m->get_ins_param_map(inputs); for(auto&& [input, param] : param_map) { map_ins[param] = map_ins.at(input); diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index c0cf0868575..85622ab02cd 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -205,6 +205,8 @@ struct MIGRAPHX_EXPORT module void finalize(std::vector& contexts); + std::unordered_map get_ins_param_map(const std::vector& inputs, bool reverse = false) const; + void debug_print() const; void debug_print(instruction_ref ins) const; void debug_print(instruction_ref ins, diff --git a/src/module.cpp b/src/module.cpp index 1d807e983ec..48bdab891c5 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -265,10 +265,14 @@ insert_generic_instructions(module& m, module::inserter insert) { if(insert == nullptr) - return insert_generic_instructions_impl( - m, ins, static_cast(instructions), map_ins, [](module& mm, auto&&... xs) { - return mm.insert_instruction(std::forward(xs)...); - }); + return insert_generic_instructions_impl(m, + ins, + static_cast(instructions), + map_ins, + [](module& mm, auto&&... xs) { + return mm.insert_instruction( + std::forward(xs)...); + }); return insert_generic_instructions_impl( m, ins, static_cast(instructions), map_ins, insert); } @@ -731,6 +735,36 @@ void module::finalize(std::vector& contexts) << std::endl; } +std::unordered_map module::get_ins_param_map(const std::vector& inputs, bool reverse) const +{ + std::unordered_map result; + auto names = this->get_parameter_names(); + std::sort(names.begin(), names.end()); + assert(names.size() == inputs.size()); + if(reverse) + { + std::transform(names.begin(), + names.end(), + inputs.begin(), + std::inserter(result, result.end()), + [&](const auto& name, auto input) { + return std::make_pair(this->get_parameter(name), input); + }); + } + else + { + std::transform(names.begin(), + names.end(), + inputs.begin(), + std::inserter(result, result.end()), + [&](const auto& name, auto input) { + return std::make_pair(input, this->get_parameter(name)); + }); + + } + return result; +} + void module::debug_print() const { std::cout << *this << std::endl; } void module::debug_print(instruction_ref ins, diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp index 49db13de3c9..5623938fcbb 100644 --- a/src/split_reduce.cpp +++ b/src/split_reduce.cpp @@ -10,6 +10,7 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -70,22 +71,42 @@ struct module_with_inputs { module mod; std::vector inputs; + void replace(instruction_ref ins, instruction_ref rep) + { + auto it = std::find(inputs.begin(), inputs.end(), ins); + if(it == inputs.end()) + return; + *it = rep; + } + void replace(const std::unordered_map& map_ins) + { + for(auto& ins:inputs) + { + if(not contains(map_ins, ins)) + continue; + ins = map_ins.at(ins); + } + } }; +static std::vector select_params(const std::vector& instructions, const std::unordered_map& param_map) +{ + std::vector result; + transform_if(instructions.begin(), instructions.end(), std::back_inserter(result), [&](instruction_ref ins) { return contains(param_map, ins); }, [&](instruction_ref ins) { + return param_map.at(ins); + }); + std::sort(result.begin(), result.end(), by(std::less<>{}, [](instruction_ref ins) { + const auto& param = any_cast(ins->get_operator()); + return param.parameter; + })); + return result; +} + static std::array split_module(module_ref m, const std::vector& splits, const std::vector& args) { - std::unordered_map param_map; - auto params = m->get_parameter_names(); - std::sort(params.begin(), params.end()); - std::transform(params.begin(), - params.end(), - args.begin(), - std::inserter(param_map, param_map.begin()), - [&](const std::string& name, instruction_ref arg) { - return std::make_pair(m->get_parameter(name), arg); - }); + std::unordered_map param_map = m->get_ins_param_map(args, true); std::unordered_set selected_instructions; fix([&](auto self, const std::vector& inputs) { @@ -107,13 +128,7 @@ static std::array split_module(module_ref m, instructions1.push_back(ins); } - std::vector inputs1; - for(auto ins : instructions1) - { - if(not contains(param_map, ins)) - continue; - inputs1.push_back(param_map[ins]); - } + std::vector inputs1 = select_params(instructions1, param_map); module m1; std::unordered_map map_ins1; m1.add_instructions(instructions1, &map_ins1); @@ -131,7 +146,6 @@ static std::array split_module(module_ref m, continue; // Input params can be used in both modules std::vector input_params; - // TODO: Use join_inserter std::copy_if(ins->inputs().begin(), ins->inputs().end(), std::back_inserter(input_params), @@ -144,13 +158,7 @@ static std::array split_module(module_ref m, instructions2.push_back(ins); } - std::vector inputs2; - for(auto ins : instructions2) - { - if(not contains(param_map, ins)) - continue; - inputs2.push_back(param_map[ins]); - } + std::vector inputs2 = select_params(instructions2, param_map); module m2; std::unordered_map map_ins2; std::size_t n = 0; @@ -231,15 +239,7 @@ void split_reduce::apply(module_pass_manager& mpm) const auto param_names = m2->get_parameter_names(); std::sort(param_names.begin(), param_names.end()); - // TODO: Use get_ins_param_map function - std::unordered_map param_map; - std::transform(param_names.begin(), - param_names.end(), - inputs.begin(), - std::inserter(param_map, param_map.begin()), - [&](const std::string& name, instruction_ref input) { - return std::make_pair(m2->get_parameter(name), input); - }); + std::unordered_map param_map = m2->get_ins_param_map(inputs, true); auto replaced = mpm.get_module().insert_instructions(ins, m2, ¶m_map); assert(replaced.size() == 1); mpm.get_module().replace_instruction(ins, replaced.front()); From ac0a922e968ac932ae2e829e410a41f233e33e29 Mon Sep 17 00:00:00 2001 From: Paul Date: Mon, 18 Mar 2024 17:02:16 -0700 Subject: [PATCH 17/59] Format --- src/include/migraphx/module.hpp | 3 ++- src/module.cpp | 16 ++++++---------- src/split_reduce.cpp | 27 +++++++++++++++++---------- 3 files changed, 25 insertions(+), 21 deletions(-) diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index 85622ab02cd..3e77378ac49 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -205,7 +205,8 @@ struct MIGRAPHX_EXPORT module void finalize(std::vector& contexts); - std::unordered_map get_ins_param_map(const std::vector& inputs, bool reverse = false) const; + std::unordered_map + get_ins_param_map(const std::vector& inputs, bool reverse = false) const; void debug_print() const; void debug_print(instruction_ref ins) const; diff --git a/src/module.cpp b/src/module.cpp index 48bdab891c5..f1b3f616251 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -265,14 +265,10 @@ insert_generic_instructions(module& m, module::inserter insert) { if(insert == nullptr) - return insert_generic_instructions_impl(m, - ins, - static_cast(instructions), - map_ins, - [](module& mm, auto&&... xs) { - return mm.insert_instruction( - std::forward(xs)...); - }); + return insert_generic_instructions_impl( + m, ins, static_cast(instructions), map_ins, [](module& mm, auto&&... xs) { + return mm.insert_instruction(std::forward(xs)...); + }); return insert_generic_instructions_impl( m, ins, static_cast(instructions), map_ins, insert); } @@ -735,7 +731,8 @@ void module::finalize(std::vector& contexts) << std::endl; } -std::unordered_map module::get_ins_param_map(const std::vector& inputs, bool reverse) const +std::unordered_map +module::get_ins_param_map(const std::vector& inputs, bool reverse) const { std::unordered_map result; auto names = this->get_parameter_names(); @@ -760,7 +757,6 @@ std::unordered_map module::get_ins_param_map(c [&](const auto& name, auto input) { return std::make_pair(input, this->get_parameter(name)); }); - } return result; } diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp index 5623938fcbb..30ade37794f 100644 --- a/src/split_reduce.cpp +++ b/src/split_reduce.cpp @@ -80,7 +80,7 @@ struct module_with_inputs } void replace(const std::unordered_map& map_ins) { - for(auto& ins:inputs) + for(auto& ins : inputs) { if(not contains(map_ins, ins)) continue; @@ -89,16 +89,21 @@ struct module_with_inputs } }; -static std::vector select_params(const std::vector& instructions, const std::unordered_map& param_map) +static std::vector +select_params(const std::vector& instructions, + const std::unordered_map& param_map) { std::vector result; - transform_if(instructions.begin(), instructions.end(), std::back_inserter(result), [&](instruction_ref ins) { return contains(param_map, ins); }, [&](instruction_ref ins) { - return param_map.at(ins); - }); + transform_if( + instructions.begin(), + instructions.end(), + std::back_inserter(result), + [&](instruction_ref ins) { return contains(param_map, ins); }, + [&](instruction_ref ins) { return param_map.at(ins); }); std::sort(result.begin(), result.end(), by(std::less<>{}, [](instruction_ref ins) { - const auto& param = any_cast(ins->get_operator()); - return param.parameter; - })); + const auto& param = any_cast(ins->get_operator()); + return param.parameter; + })); return result; } @@ -106,7 +111,8 @@ static std::array split_module(module_ref m, const std::vector& splits, const std::vector& args) { - std::unordered_map param_map = m->get_ins_param_map(args, true); + std::unordered_map param_map = + m->get_ins_param_map(args, true); std::unordered_set selected_instructions; fix([&](auto self, const std::vector& inputs) { @@ -239,7 +245,8 @@ void split_reduce::apply(module_pass_manager& mpm) const auto param_names = m2->get_parameter_names(); std::sort(param_names.begin(), param_names.end()); - std::unordered_map param_map = m2->get_ins_param_map(inputs, true); + std::unordered_map param_map = + m2->get_ins_param_map(inputs, true); auto replaced = mpm.get_module().insert_instructions(ins, m2, ¶m_map); assert(replaced.size() == 1); mpm.get_module().replace_instruction(ins, replaced.front()); From 02953fed909f54c59320624b8556f1b3b91a2a16 Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 19 Mar 2024 11:37:10 -0700 Subject: [PATCH 18/59] Some more refactoring --- src/CMakeLists.txt | 1 + src/include/migraphx/module.hpp | 2 ++ src/module.cpp | 33 ++++++++++++++++++++++----------- src/split_reduce.cpp | 18 +++++------------- 4 files changed, 30 insertions(+), 24 deletions(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index c4b46308586..f269aa8cb38 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -77,6 +77,7 @@ add_library(migraphx operation.cpp optimize_module.cpp pad_calc.cpp + param_utils.cpp pass.cpp pass_manager.cpp permutation.cpp diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index 3e77378ac49..035372c5af4 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -185,6 +185,8 @@ struct MIGRAPHX_EXPORT module shape get_parameter_shape(std::string name) const; instruction_ref get_parameter(std::string name) const; + + std::vector get_parameters() const; void rename_parameter(instruction_ref ins, const std::string& name); diff --git a/src/module.cpp b/src/module.cpp index f1b3f616251..5cc2bc178e8 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -33,6 +33,7 @@ #include #include #include +#include #include #include #include @@ -588,6 +589,16 @@ instruction_ref module::get_parameter(std::string name) const return this->end(); } +std::vector module::get_parameters() const +{ + std::vector result; + auto refs = iterator_for(*this); + std::copy_if(refs.begin(), refs.end(), std::back_inserter(result), [&](instruction_ref ins) { + return ins->name() == "@param"; + }); + return result; +} + void module::rename_parameter(instruction_ref ins, const std::string& name) { assert(ins->name() == "@param"); @@ -735,27 +746,27 @@ std::unordered_map module::get_ins_param_map(const std::vector& inputs, bool reverse) const { std::unordered_map result; - auto names = this->get_parameter_names(); - std::sort(names.begin(), names.end()); - assert(names.size() == inputs.size()); + auto params = this->get_parameters(); + assert(params.size() == inputs.size()); + sort_params(params); if(reverse) { - std::transform(names.begin(), - names.end(), + std::transform(params.begin(), + params.end(), inputs.begin(), std::inserter(result, result.end()), - [&](const auto& name, auto input) { - return std::make_pair(this->get_parameter(name), input); + [&](instruction_ref param, auto input) { + return std::make_pair(param, input); }); } else { - std::transform(names.begin(), - names.end(), + std::transform(params.begin(), + params.end(), inputs.begin(), std::inserter(result, result.end()), - [&](const auto& name, auto input) { - return std::make_pair(input, this->get_parameter(name)); + [&](instruction_ref param, auto input) { + return std::make_pair(input, param); }); } return result; diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp index 30ade37794f..3230d56de09 100644 --- a/src/split_reduce.cpp +++ b/src/split_reduce.cpp @@ -11,6 +11,7 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -62,11 +63,6 @@ MIGRAPHX_REGISTER_OP(split_fused_reduce); static bool is_reduce(const instruction& ins) { return contains(ins.name(), "reduce"); } -static std::string param_name(std::size_t i, const std::string& prefix = "x") -{ - return prefix + std::to_string(i); -} - struct module_with_inputs { module mod; @@ -100,10 +96,7 @@ select_params(const std::vector& instructions, std::back_inserter(result), [&](instruction_ref ins) { return contains(param_map, ins); }, [&](instruction_ref ins) { return param_map.at(ins); }); - std::sort(result.begin(), result.end(), by(std::less<>{}, [](instruction_ref ins) { - const auto& param = any_cast(ins->get_operator()); - return param.parameter; - })); + sort_params(result); return result; } @@ -165,6 +158,7 @@ static std::array split_module(module_ref m, } std::vector inputs2 = select_params(instructions2, param_map); + inputs2.insert(inputs2.begin(), splits.begin(), splits.end()); module m2; std::unordered_map map_ins2; std::size_t n = 0; @@ -240,13 +234,11 @@ void split_reduce::apply(module_pass_manager& mpm) const mp[0].inputs, {m1}); + mp[1].replace(splits.front(), split_reduce); std::vector inputs = {split_reduce}; inputs.insert(inputs.end(), mp[1].inputs.begin(), mp[1].inputs.end()); - auto param_names = m2->get_parameter_names(); - std::sort(param_names.begin(), param_names.end()); - std::unordered_map param_map = - m2->get_ins_param_map(inputs, true); + m2->get_ins_param_map(mp[1].inputs, true); auto replaced = mpm.get_module().insert_instructions(ins, m2, ¶m_map); assert(replaced.size() == 1); mpm.get_module().replace_instruction(ins, replaced.front()); From ce2a336df83fba1ae80083b79f92d0a9ab9c9beb Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 19 Mar 2024 11:37:17 -0700 Subject: [PATCH 19/59] Format --- src/include/migraphx/module.hpp | 2 +- src/module.cpp | 26 ++++++++++++-------------- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index 035372c5af4..7eff8a93969 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -185,7 +185,7 @@ struct MIGRAPHX_EXPORT module shape get_parameter_shape(std::string name) const; instruction_ref get_parameter(std::string name) const; - + std::vector get_parameters() const; void rename_parameter(instruction_ref ins, const std::string& name); diff --git a/src/module.cpp b/src/module.cpp index 5cc2bc178e8..f0e47754c98 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -751,23 +751,21 @@ module::get_ins_param_map(const std::vector& inputs, bool rever sort_params(params); if(reverse) { - std::transform(params.begin(), - params.end(), - inputs.begin(), - std::inserter(result, result.end()), - [&](instruction_ref param, auto input) { - return std::make_pair(param, input); - }); + std::transform( + params.begin(), + params.end(), + inputs.begin(), + std::inserter(result, result.end()), + [&](instruction_ref param, auto input) { return std::make_pair(param, input); }); } else { - std::transform(params.begin(), - params.end(), - inputs.begin(), - std::inserter(result, result.end()), - [&](instruction_ref param, auto input) { - return std::make_pair(input, param); - }); + std::transform( + params.begin(), + params.end(), + inputs.begin(), + std::inserter(result, result.end()), + [&](instruction_ref param, auto input) { return std::make_pair(input, param); }); } return result; } From be217d25424b7d440981a25043f02592c6d4945b Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 19 Mar 2024 12:14:36 -0700 Subject: [PATCH 20/59] Move split to module class --- src/include/migraphx/module.hpp | 16 ++++- src/module.cpp | 108 +++++++++++++++++++++++++++++ src/split_reduce.cpp | 116 +------------------------------- 3 files changed, 124 insertions(+), 116 deletions(-) diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index 7eff8a93969..6d26d8e0a4f 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -50,6 +50,8 @@ struct module_impl; using parameter_map = std::unordered_map; using ins_dep_map = std::unordered_map>; +struct module_with_inputs; + /** * @brief Stores the instruction stream */ @@ -185,7 +187,7 @@ struct MIGRAPHX_EXPORT module shape get_parameter_shape(std::string name) const; instruction_ref get_parameter(std::string name) const; - + std::vector get_parameters() const; void rename_parameter(instruction_ref ins, const std::string& name); @@ -210,6 +212,10 @@ struct MIGRAPHX_EXPORT module std::unordered_map get_ins_param_map(const std::vector& inputs, bool reverse = false) const; + using with_inputs = module_with_inputs; + + std::array split(const std::vector& args, const std::vector& splits) const; + void debug_print() const; void debug_print(instruction_ref ins) const; void debug_print(instruction_ref ins, @@ -271,6 +277,14 @@ struct MIGRAPHX_EXPORT module std::unique_ptr impl; }; +struct module_with_inputs +{ + module mod; + std::vector inputs; + void replace(instruction_ref ins, instruction_ref rep); + void replace(const std::unordered_map& map_ins); +}; + inline module& get_module(module& m) { return m; } } // namespace MIGRAPHX_INLINE_NS diff --git a/src/module.cpp b/src/module.cpp index f0e47754c98..2a327a818f9 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -22,6 +22,7 @@ * THE SOFTWARE. */ #include +#include #include #include #include @@ -770,6 +771,113 @@ module::get_ins_param_map(const std::vector& inputs, bool rever return result; } +static std::vector +select_params(const std::vector& instructions, + const std::unordered_map& param_map) +{ + std::vector result; + transform_if( + instructions.begin(), + instructions.end(), + std::back_inserter(result), + [&](instruction_ref ins) { return contains(param_map, ins); }, + [&](instruction_ref ins) { return param_map.at(ins); }); + sort_params(result); + return result; +} + +std::array module::split(const std::vector& args, const std::vector& splits) const +{ + std::unordered_map param_map = + this->get_ins_param_map(args, true); + + std::unordered_set selected_instructions; + fix([&](auto self, const std::vector& inputs) { + for(auto input : inputs) + { + if(contains(selected_instructions, input)) + continue; + selected_instructions.insert(input); + self(input->inputs()); + } + })(splits); + + std::vector instructions1; + // TODO: copy_if + for(auto ins : iterator_for(*this)) + { + if(not contains(selected_instructions, ins)) + continue; + instructions1.push_back(ins); + } + + std::vector inputs1 = select_params(instructions1, param_map); + module m1; + std::unordered_map map_ins1; + m1.add_instructions(instructions1, &map_ins1); + std::vector outputs; + std::transform(splits.begin(), + splits.end(), + std::back_inserter(outputs), + [&](instruction_ref ins) { return map_ins1.at(ins); }); + m1.add_return(outputs); + + std::vector instructions2; + for(auto ins : iterator_for(*this)) + { + if(contains(selected_instructions, ins)) + continue; + // Input params can be used in both modules + std::vector input_params; + std::copy_if(ins->inputs().begin(), + ins->inputs().end(), + std::back_inserter(input_params), + [&](instruction_ref input) { + if(input->name() != "@param") + return false; + return not contains(instructions2, input); + }); + instructions2.insert(instructions2.end(), input_params.begin(), input_params.end()); + instructions2.push_back(ins); + } + + std::vector inputs2 = select_params(instructions2, param_map); + inputs2.insert(inputs2.begin(), splits.begin(), splits.end()); + module m2; + std::unordered_map map_ins2; + std::size_t n = 0; + for(auto ins : splits) + map_ins2[ins] = m2.add_parameter(param_name(n++), ins->get_shape().as_standard()); + for(auto ins : iterator_for(*this)) + { + if(ins->name() != "@param") + continue; + if(not contains(instructions2, ins)) + continue; + map_ins2[ins] = m2.add_parameter(param_name(n++), ins->get_shape().as_standard()); + } + auto r = m2.add_instructions(instructions2, &map_ins2); + m2.add_return(r); + return {{{std::move(m1), std::move(inputs1)}, {std::move(m2), std::move(inputs2)}}}; +} + +void module_with_inputs::replace(instruction_ref ins, instruction_ref rep) +{ + auto it = std::find(inputs.begin(), inputs.end(), ins); + if(it == inputs.end()) + return; + *it = rep; +} +void module_with_inputs::replace(const std::unordered_map& map_ins) +{ + for(auto& ins : inputs) + { + if(not contains(map_ins, ins)) + continue; + ins = map_ins.at(ins); + } +} + void module::debug_print() const { std::cout << *this << std::endl; } void module::debug_print(instruction_ref ins, diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp index 3230d56de09..e8908a52c79 100644 --- a/src/split_reduce.cpp +++ b/src/split_reduce.cpp @@ -63,120 +63,6 @@ MIGRAPHX_REGISTER_OP(split_fused_reduce); static bool is_reduce(const instruction& ins) { return contains(ins.name(), "reduce"); } -struct module_with_inputs -{ - module mod; - std::vector inputs; - void replace(instruction_ref ins, instruction_ref rep) - { - auto it = std::find(inputs.begin(), inputs.end(), ins); - if(it == inputs.end()) - return; - *it = rep; - } - void replace(const std::unordered_map& map_ins) - { - for(auto& ins : inputs) - { - if(not contains(map_ins, ins)) - continue; - ins = map_ins.at(ins); - } - } -}; - -static std::vector -select_params(const std::vector& instructions, - const std::unordered_map& param_map) -{ - std::vector result; - transform_if( - instructions.begin(), - instructions.end(), - std::back_inserter(result), - [&](instruction_ref ins) { return contains(param_map, ins); }, - [&](instruction_ref ins) { return param_map.at(ins); }); - sort_params(result); - return result; -} - -static std::array split_module(module_ref m, - const std::vector& splits, - const std::vector& args) -{ - std::unordered_map param_map = - m->get_ins_param_map(args, true); - - std::unordered_set selected_instructions; - fix([&](auto self, const std::vector& inputs) { - for(auto input : inputs) - { - if(contains(selected_instructions, input)) - continue; - selected_instructions.insert(input); - self(input->inputs()); - } - })(splits); - - std::vector instructions1; - // TODO: copy_if - for(auto ins : iterator_for(*m)) - { - if(not contains(selected_instructions, ins)) - continue; - instructions1.push_back(ins); - } - - std::vector inputs1 = select_params(instructions1, param_map); - module m1; - std::unordered_map map_ins1; - m1.add_instructions(instructions1, &map_ins1); - std::vector outputs; - std::transform(splits.begin(), - splits.end(), - std::back_inserter(outputs), - [&](instruction_ref ins) { return map_ins1.at(ins); }); - m1.add_return(outputs); - - std::vector instructions2; - for(auto ins : iterator_for(*m)) - { - if(contains(selected_instructions, ins)) - continue; - // Input params can be used in both modules - std::vector input_params; - std::copy_if(ins->inputs().begin(), - ins->inputs().end(), - std::back_inserter(input_params), - [&](instruction_ref input) { - if(input->name() != "@param") - return false; - return not contains(instructions2, input); - }); - instructions2.insert(instructions2.end(), input_params.begin(), input_params.end()); - instructions2.push_back(ins); - } - - std::vector inputs2 = select_params(instructions2, param_map); - inputs2.insert(inputs2.begin(), splits.begin(), splits.end()); - module m2; - std::unordered_map map_ins2; - std::size_t n = 0; - for(auto ins : splits) - map_ins2[ins] = m2.add_parameter(param_name(n++), ins->get_shape().as_standard()); - for(auto ins : iterator_for(*m)) - { - if(ins->name() != "@param") - continue; - if(not contains(instructions2, ins)) - continue; - map_ins2[ins] = m2.add_parameter(param_name(n++), ins->get_shape().as_standard()); - } - auto r = m2.add_instructions(instructions2, &map_ins2); - m2.add_return(r); - return {{{std::move(m1), std::move(inputs1)}, {std::move(m2), std::move(inputs2)}}}; -} - static std::vector find_split(module_ref rm) { std::vector result; @@ -221,7 +107,7 @@ void split_reduce::apply(module_pass_manager& mpm) const auto axes = v["axes"].to_vector(); // TODO: Check reduction size - auto mp = split_module(rm, splits, ins->inputs()); + auto mp = rm->split(ins->inputs(), splits); auto* m1 = mpm.create_module(rm->name() + "_0", std::move(mp[0].mod)); auto* m2 = mpm.create_module(rm->name() + "_1", std::move(mp[1].mod)); m1->set_bypass(); From 4dbd08a94d26204e609747b6232eef8d2a544eb5 Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 19 Mar 2024 12:14:44 -0700 Subject: [PATCH 21/59] Format --- src/include/migraphx/module.hpp | 5 +++-- src/module.cpp | 6 ++++-- src/split_reduce.cpp | 2 +- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index 6d26d8e0a4f..c49fbbab9e1 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -187,7 +187,7 @@ struct MIGRAPHX_EXPORT module shape get_parameter_shape(std::string name) const; instruction_ref get_parameter(std::string name) const; - + std::vector get_parameters() const; void rename_parameter(instruction_ref ins, const std::string& name); @@ -214,7 +214,8 @@ struct MIGRAPHX_EXPORT module using with_inputs = module_with_inputs; - std::array split(const std::vector& args, const std::vector& splits) const; + std::array split(const std::vector& args, + const std::vector& splits) const; void debug_print() const; void debug_print(instruction_ref ins) const; diff --git a/src/module.cpp b/src/module.cpp index 2a327a818f9..1408f456809 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -786,7 +786,8 @@ select_params(const std::vector& instructions, return result; } -std::array module::split(const std::vector& args, const std::vector& splits) const +std::array module::split(const std::vector& args, + const std::vector& splits) const { std::unordered_map param_map = this->get_ins_param_map(args, true); @@ -868,7 +869,8 @@ void module_with_inputs::replace(instruction_ref ins, instruction_ref rep) return; *it = rep; } -void module_with_inputs::replace(const std::unordered_map& map_ins) +void module_with_inputs::replace( + const std::unordered_map& map_ins) { for(auto& ins : inputs) { diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp index e8908a52c79..8648b008f12 100644 --- a/src/split_reduce.cpp +++ b/src/split_reduce.cpp @@ -107,7 +107,7 @@ void split_reduce::apply(module_pass_manager& mpm) const auto axes = v["axes"].to_vector(); // TODO: Check reduction size - auto mp = rm->split(ins->inputs(), splits); + auto mp = rm->split(ins->inputs(), splits); auto* m1 = mpm.create_module(rm->name() + "_0", std::move(mp[0].mod)); auto* m2 = mpm.create_module(rm->name() + "_1", std::move(mp[1].mod)); m1->set_bypass(); From 88087f286fab2dd70beebd13fb283bc0245fb37f Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 19 Mar 2024 12:57:56 -0700 Subject: [PATCH 22/59] Add missing files --- src/include/migraphx/param_utils.hpp | 18 ++++++++++++++++++ src/param_utils.cpp | 24 ++++++++++++++++++++++++ 2 files changed, 42 insertions(+) create mode 100644 src/include/migraphx/param_utils.hpp create mode 100644 src/param_utils.cpp diff --git a/src/include/migraphx/param_utils.hpp b/src/include/migraphx/param_utils.hpp new file mode 100644 index 00000000000..153f4f2ac86 --- /dev/null +++ b/src/include/migraphx/param_utils.hpp @@ -0,0 +1,18 @@ +#ifndef MIGRAPHX_GUARD_MIGRAPHX_PARAM_UTILS_HPP +#define MIGRAPHX_GUARD_MIGRAPHX_PARAM_UTILS_HPP + +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +std::string param_name(std::size_t i, const std::string& prefix = "x"); + +void sort_params(std::vector& params); + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_MIGRAPHX_PARAM_UTILS_HPP diff --git a/src/param_utils.cpp b/src/param_utils.cpp new file mode 100644 index 00000000000..cc181c9d90b --- /dev/null +++ b/src/param_utils.cpp @@ -0,0 +1,24 @@ +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +std::string param_name(std::size_t i, const std::string& prefix) +{ + return prefix + std::to_string(i); +} + +void sort_params(std::vector& params) +{ + std::sort(params.begin(), params.end(), by(std::less<>{}, [](instruction_ref ins) { + const auto& param = any_cast(ins->get_operator()); + return param.parameter; + })); +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + + From b1cc070d071b45c9424cd026e8fcd21566d23909 Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 19 Mar 2024 12:57:59 -0700 Subject: [PATCH 23/59] Format --- src/param_utils.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/param_utils.cpp b/src/param_utils.cpp index cc181c9d90b..8cf7be6ef80 100644 --- a/src/param_utils.cpp +++ b/src/param_utils.cpp @@ -20,5 +20,3 @@ void sort_params(std::vector& params) } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx - - From d01103d9f228b05018b8d5822ab49c7f4ef7df4d Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 19 Mar 2024 13:06:50 -0700 Subject: [PATCH 24/59] Add test case --- test/split_reduce.cpp | 106 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) create mode 100644 test/split_reduce.cpp diff --git a/test/split_reduce.cpp b/test/split_reduce.cpp new file mode 100644 index 00000000000..653c0c07b94 --- /dev/null +++ b/test/split_reduce.cpp @@ -0,0 +1,106 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +void run_pass(migraphx::program& p) +{ + migraphx::run_passes(p, {migraphx::fuse_pointwise{}, migraphx::fuse_reduce{}, migraphx::split_reduce{}, migraphx::dead_code_elimination{}}); +} + +bool all_instructions_are_local(const migraphx::module& m) +{ + return std::all_of(m.begin(), m.end(), [&](const auto& ins) { + return std::all_of(ins.inputs().begin(), ins.inputs().end(), [&](auto input) { + return m.has_instruction(input); + }); + }); +} + +template +migraphx::instruction_ref add_reduce(migraphx::program& p, + const std::string& name, + std::vector inputs, + const std::vector& axes, + const std::string& assign, + F f) +{ + auto* rm = p.create_module(name); + auto* mm = p.get_main_module(); + rm->set_bypass(); + std::vector params; + std::transform(inputs.begin(), inputs.end(), std::back_inserter(params), [&](auto input) { + return rm->add_parameter( + "x" + std::to_string(params.size()), + migraphx::shape{input->get_shape().type(), input->get_shape().lens()}); + }); + auto r = f(rm, params, axes); + rm->add_return({r}); + EXPECT(all_instructions_are_local(*rm)); + return mm->add_instruction(migraphx::make_op("split_fused_reduce", {{"axes", axes}, {"assign", assign}}), inputs, {rm}); +} + +inline auto single_reduce(const std::string& name) +{ + return [=](auto* rm, const auto& inputs, const auto& axes) { + return rm->add_instruction(migraphx::make_op(name, {{"axes", axes}}), inputs); + }; +} + +TEST_CASE(single) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3, 327680}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), x); + auto rsumb = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum); + auto add = mm->add_instruction(migraphx::make_op("add"), x, rsumb); + mm->add_return({add}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s); + auto rsum = add_reduce(p2, "main:reduce_sum0:main:pointwise0_0", {x}, {2}, "assign_add", single_reduce("reduce_sum")); + auto rsumb = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum); + auto add = add_pointwise(p2, mm, "main:pointwise0", {x, rsumb}, single_pointwise("add")); + mm->add_return({add}); + } + EXPECT(p1 == p2); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } From 792062cb5f3b8e7adffeb3bec01d118237155fe9 Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 19 Mar 2024 13:06:58 -0700 Subject: [PATCH 25/59] Format --- test/split_reduce.cpp | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/test/split_reduce.cpp b/test/split_reduce.cpp index 653c0c07b94..db198fc5b04 100644 --- a/test/split_reduce.cpp +++ b/test/split_reduce.cpp @@ -36,7 +36,11 @@ void run_pass(migraphx::program& p) { - migraphx::run_passes(p, {migraphx::fuse_pointwise{}, migraphx::fuse_reduce{}, migraphx::split_reduce{}, migraphx::dead_code_elimination{}}); + migraphx::run_passes(p, + {migraphx::fuse_pointwise{}, + migraphx::fuse_reduce{}, + migraphx::split_reduce{}, + migraphx::dead_code_elimination{}}); } bool all_instructions_are_local(const migraphx::module& m) @@ -68,7 +72,10 @@ migraphx::instruction_ref add_reduce(migraphx::program& p, auto r = f(rm, params, axes); rm->add_return({r}); EXPECT(all_instructions_are_local(*rm)); - return mm->add_instruction(migraphx::make_op("split_fused_reduce", {{"axes", axes}, {"assign", assign}}), inputs, {rm}); + return mm->add_instruction( + migraphx::make_op("split_fused_reduce", {{"axes", axes}, {"assign", assign}}), + inputs, + {rm}); } inline auto single_reduce(const std::string& name) @@ -85,8 +92,9 @@ TEST_CASE(single) { auto* mm = p1.get_main_module(); auto x = mm->add_parameter("x", s); - auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), x); - auto rsumb = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum); + auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), x); + auto rsumb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum); auto add = mm->add_instruction(migraphx::make_op("add"), x, rsumb); mm->add_return({add}); } @@ -95,8 +103,14 @@ TEST_CASE(single) { auto* mm = p2.get_main_module(); auto x = mm->add_parameter("x", s); - auto rsum = add_reduce(p2, "main:reduce_sum0:main:pointwise0_0", {x}, {2}, "assign_add", single_reduce("reduce_sum")); - auto rsumb = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum); + auto rsum = add_reduce(p2, + "main:reduce_sum0:main:pointwise0_0", + {x}, + {2}, + "assign_add", + single_reduce("reduce_sum")); + auto rsumb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum); auto add = add_pointwise(p2, mm, "main:pointwise0", {x, rsumb}, single_pointwise("add")); mm->add_return({add}); } From 917695f6e452b300510391c81409bbbfcbf6ce89 Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 19 Mar 2024 13:23:23 -0700 Subject: [PATCH 26/59] Move liveness --- src/include/migraphx/liveness.hpp | 53 +++++++++++++++++++++++++++++++ src/memory_coloring.cpp | 37 +-------------------- 2 files changed, 54 insertions(+), 36 deletions(-) create mode 100644 src/include/migraphx/liveness.hpp diff --git a/src/include/migraphx/liveness.hpp b/src/include/migraphx/liveness.hpp new file mode 100644 index 00000000000..1b8761d9e7d --- /dev/null +++ b/src/include/migraphx/liveness.hpp @@ -0,0 +1,53 @@ +#ifndef MIGRAPHX_GUARD_MIGRAPHX_LIVENESS_HPP +#define MIGRAPHX_GUARD_MIGRAPHX_LIVENESS_HPP + +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +// This will do liveness analysis on the module, and it will call the +// function `f` with the instruction and the set of the other instructions +// that are live +template +void liveness(const module& m, F f) +{ + auto implicit_deps = m.calc_implicit_deps(); + std::unordered_set live_set; + auto rp = reverse(m); + for(auto rins : iterator_for(rp)) // NOLINT + { + // The base iterator is one ahead, so we need to use the previous iterator + auto ins = std::prev(rins.base()); + // Add live variables + auto add_live_variables = [&](const auto& inputs) { + for(auto input : inputs) + { + auto i = instruction::get_output_alias(input); + // Skip if variable comes from parent + if(not m.has_instruction(i)) + continue; + live_set.insert(i); + } + }; + add_live_variables(ins->inputs()); + add_live_variables(implicit_deps[ins]); + // Remove last usage + auto it = live_set.find(ins); + if(it != live_set.end()) + { + live_set.erase(it); + f(ins, live_set); + } + } +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_MIGRAPHX_LIVENESS_HPP diff --git a/src/memory_coloring.cpp b/src/memory_coloring.cpp index 733e39d28d2..649d9f5acc5 100644 --- a/src/memory_coloring.cpp +++ b/src/memory_coloring.cpp @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -43,42 +44,6 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DEBUG_MEMORY_COLORING); using instruction_set = std::unordered_set; using instruction_set_map = std::unordered_map; -// This will do liveness analysis on the module, and it will call the -// function `f` with the instruction and the set of the other instructions -// that are live -template -void liveness(const module& m, F f) -{ - auto implicit_deps = m.calc_implicit_deps(); - instruction_set live_set; - auto rp = reverse(m); - for(auto rins : iterator_for(rp)) // NOLINT - { - // The base iterator is one ahead, so we need to use the previous iterator - auto ins = std::prev(rins.base()); - // Add live variables - auto add_live_variables = [&](const auto& inputs) { - for(auto input : inputs) - { - auto i = instruction::get_output_alias(input); - // Skip if variable comes from parent - if(not m.has_instruction(i)) - continue; - live_set.insert(i); - } - }; - add_live_variables(ins->inputs()); - add_live_variables(implicit_deps[ins]); - // Remove last usage - auto it = live_set.find(ins); - if(it != live_set.end()) - { - live_set.erase(it); - f(ins, live_set); - } - } -} - // This will build the conflict table or interference graph. This is // essentially a map from one instruction to a set of instruction that are // used together. Each instruction will be the allocation instruction. From 1f172c86a95b25be16d1435fafff37e8a0a0c533 Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 19 Mar 2024 16:47:32 -0700 Subject: [PATCH 27/59] Seperate previous pointwise module if its used again --- src/include/migraphx/module.hpp | 6 ++++ src/module.cpp | 60 +++++++++++++++++++++++++++---- src/split_reduce.cpp | 63 +++++++++++++++++++++++++-------- test/split_reduce.cpp | 36 ++++++++++++++++++- 4 files changed, 143 insertions(+), 22 deletions(-) diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index c49fbbab9e1..0e6bf9603af 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -217,6 +217,10 @@ struct MIGRAPHX_EXPORT module std::array split(const std::vector& args, const std::vector& splits) const; + std::array split(const std::vector& args, + const std::vector& splits1, + const std::vector& splits2) const; + void debug_print() const; void debug_print(instruction_ref ins) const; void debug_print(instruction_ref ins, @@ -284,6 +288,8 @@ struct module_with_inputs std::vector inputs; void replace(instruction_ref ins, instruction_ref rep); void replace(const std::unordered_map& map_ins); + void replace( + const std::vector& keys, const std::vector& values); }; inline module& get_module(module& m) { return m; } diff --git a/src/module.cpp b/src/module.cpp index 1408f456809..50cc048eaa9 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -786,11 +786,12 @@ select_params(const std::vector& instructions, return result; } -std::array module::split(const std::vector& args, - const std::vector& splits) const +static std::array generic_split(const module& m, const std::vector& args, + const std::vector& splits, + std::unordered_map* map_ins = nullptr) { std::unordered_map param_map = - this->get_ins_param_map(args, true); + m.get_ins_param_map(args, true); std::unordered_set selected_instructions; fix([&](auto self, const std::vector& inputs) { @@ -805,7 +806,7 @@ std::array module::split(const std::vector instructions1; // TODO: copy_if - for(auto ins : iterator_for(*this)) + for(auto ins : iterator_for(m)) { if(not contains(selected_instructions, ins)) continue; @@ -824,7 +825,7 @@ std::array module::split(const std::vector instructions2; - for(auto ins : iterator_for(*this)) + for(auto ins : iterator_for(m)) { if(contains(selected_instructions, ins)) continue; @@ -845,11 +846,11 @@ std::array module::split(const std::vector inputs2 = select_params(instructions2, param_map); inputs2.insert(inputs2.begin(), splits.begin(), splits.end()); module m2; - std::unordered_map map_ins2; std::size_t n = 0; + std::unordered_map map_ins2; for(auto ins : splits) map_ins2[ins] = m2.add_parameter(param_name(n++), ins->get_shape().as_standard()); - for(auto ins : iterator_for(*this)) + for(auto ins : iterator_for(m)) { if(ins->name() != "@param") continue; @@ -859,9 +860,43 @@ std::array module::split(const std::vector module::split(const std::vector& args, + const std::vector& splits) const +{ + return generic_split(*this, args, splits); +} + +std::array module::split(const std::vector& args, + const std::vector& splits1, + const std::vector& splits2) const +{ + std::unordered_map map_ins; + auto mods1 = generic_split(*this, args, splits1, &map_ins); + + assert(all_of(mods1[0].inputs, [&](auto ins) { return contains(args, ins); })); + assert(all_of(mods1[1].inputs, [&](auto ins) { return contains(args, ins) or contains(splits1, ins); })); + + std::vector new_splits2; + std::transform(splits2.begin(), splits2.end(), std::back_inserter(new_splits2), [&](auto ins) { + return map_ins.at(ins); + }); + + auto mods2 = mods1[1].mod.split(mods1[1].inputs, new_splits2); + // Replace new splits with old splits + mods2[1].replace(new_splits2, splits2); + + assert(all_of(mods2[0].inputs, [&](auto ins) { return contains(args, ins) or contains(splits1, ins); })); + assert(all_of(mods2[1].inputs, [&](auto ins) { return contains(args, ins) or contains(splits1, ins) or contains(splits2, ins); })); + + return {{std::move(mods1[0]), std::move(mods2[0]), std::move(mods2[1])}}; +} + void module_with_inputs::replace(instruction_ref ins, instruction_ref rep) { auto it = std::find(inputs.begin(), inputs.end(), ins); @@ -879,6 +914,17 @@ void module_with_inputs::replace( ins = map_ins.at(ins); } } +void module_with_inputs::replace( + const std::vector& keys, const std::vector& values) +{ + for(auto& ins : inputs) + { + auto it = std::find(keys.begin(), keys.end(), ins); + if(it == keys.end()) + continue; + ins = values[it - keys.begin()]; + } +} void module::debug_print() const { std::cout << *this << std::endl; } diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp index 8648b008f12..5a46aa7fba3 100644 --- a/src/split_reduce.cpp +++ b/src/split_reduce.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -77,6 +78,23 @@ static std::vector find_split(module_ref rm) return result; } +static std::vector get_alive(module_ref rm, const std::vector& splits) +{ + std::vector result; + bool stop = false; + liveness(*rm, [&](auto ins, const auto& live_set) { + if(stop) + return; + if(not contains(splits, ins)) + return; + std::copy_if(live_set.begin(), live_set.end(), std::back_inserter(result), [](instruction_ref live) { + return live->name() != "@param"; + }); + stop = true; + }); + return result; +} + static std::string assign_op(const std::vector& splits) { static std::unordered_map m = { @@ -89,6 +107,13 @@ static std::string assign_op(const std::vector& splits) return m.at(splits.front()->name()); } +static std::vector insert_module_inline(module& m, instruction_ref ins, const module::with_inputs& mwi) +{ + auto param_map = + mwi.mod.get_ins_param_map(mwi.inputs, true); + return m.insert_instructions(ins, &mwi.mod, ¶m_map); +} + void split_reduce::apply(module_pass_manager& mpm) const { for(auto ins : iterator_for(mpm.get_module())) @@ -99,6 +124,7 @@ void split_reduce::apply(module_pass_manager& mpm) const auto splits = find_split(rm); if(splits.empty()) continue; + // Only use split reduce with float for now if(not std::all_of(splits.begin(), splits.end(), [](instruction_ref split) { return split->get_shape().type() == shape::float_type; })) @@ -107,25 +133,34 @@ void split_reduce::apply(module_pass_manager& mpm) const auto axes = v["axes"].to_vector(); // TODO: Check reduction size - auto mp = rm->split(ins->inputs(), splits); - auto* m1 = mpm.create_module(rm->name() + "_0", std::move(mp[0].mod)); - auto* m2 = mpm.create_module(rm->name() + "_1", std::move(mp[1].mod)); - m1->set_bypass(); - m2->set_bypass(); + auto alive = get_alive(rm, splits); + + std::array mods; + if(not alive.empty()) + { + auto mods3 = rm->split(ins->inputs(), alive, splits); + auto r = insert_module_inline(mpm.get_module(), ins, mods3[0]); + mods3[1].replace(alive, r); + mods3[2].replace(alive, r); + mods = {std::move(mods3[1]), std::move(mods3[2])}; + } + else + { + mods = rm->split(ins->inputs(), splits); + } + + auto* splitm = mpm.create_module(rm->name() + "_split", std::move(mods[0].mod)); + splitm->set_bypass(); // Insert split reduce auto split_reduce = mpm.get_module().insert_instruction( ins, make_op("split_fused_reduce", {{"axes", axes}, {"assign", assign_op(splits)}}), - mp[0].inputs, - {m1}); - - mp[1].replace(splits.front(), split_reduce); - std::vector inputs = {split_reduce}; - inputs.insert(inputs.end(), mp[1].inputs.begin(), mp[1].inputs.end()); - std::unordered_map param_map = - m2->get_ins_param_map(mp[1].inputs, true); - auto replaced = mpm.get_module().insert_instructions(ins, m2, ¶m_map); + mods[0].inputs, + {splitm}); + + mods[1].replace(splits.front(), split_reduce); + auto replaced = insert_module_inline(mpm.get_module(), ins, mods[1]); assert(replaced.size() == 1); mpm.get_module().replace_instruction(ins, replaced.front()); } diff --git a/test/split_reduce.cpp b/test/split_reduce.cpp index db198fc5b04..7bbd3588630 100644 --- a/test/split_reduce.cpp +++ b/test/split_reduce.cpp @@ -104,7 +104,7 @@ TEST_CASE(single) auto* mm = p2.get_main_module(); auto x = mm->add_parameter("x", s); auto rsum = add_reduce(p2, - "main:reduce_sum0:main:pointwise0_0", + "main:reduce_sum0:main:pointwise0_split", {x}, {2}, "assign_add", @@ -117,4 +117,38 @@ TEST_CASE(single) EXPECT(p1 == p2); } +TEST_CASE(split_pointwise) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3, 327680}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto sqrt = mm->add_instruction(migraphx::make_op("sqrt"), x); + auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), sqrt); + auto rsumb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum); + auto add = mm->add_instruction(migraphx::make_op("add"), sqrt, rsumb); + mm->add_return({add}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s); + auto sqrt = add_pointwise(p2, mm, "main:pointwise0", {x}, single_pointwise("sqrt")); + auto rsum = add_reduce(p2, + "main:pointwise0:main:reduce_sum0:main:pointwise1_split", + {sqrt}, + {2}, + "assign_add", + single_reduce("reduce_sum")); + auto rsumb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum); + auto add = add_pointwise(p2, mm, "main:pointwise1", {sqrt, rsumb}, single_pointwise("add")); + mm->add_return({add}); + } + EXPECT(p1 == p2); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } From d9c2c9aa67543b016cbb074ece780cd591ff054f Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 19 Mar 2024 16:47:39 -0700 Subject: [PATCH 28/59] Format --- src/include/migraphx/module.hpp | 4 ++-- src/module.cpp | 29 +++++++++++++++++------------ src/split_reduce.cpp | 20 +++++++++++--------- test/split_reduce.cpp | 4 ++-- 4 files changed, 32 insertions(+), 25 deletions(-) diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index 0e6bf9603af..7d79f318ba7 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -288,8 +288,8 @@ struct module_with_inputs std::vector inputs; void replace(instruction_ref ins, instruction_ref rep); void replace(const std::unordered_map& map_ins); - void replace( - const std::vector& keys, const std::vector& values); + void replace(const std::vector& keys, + const std::vector& values); }; inline module& get_module(module& m) { return m; } diff --git a/src/module.cpp b/src/module.cpp index 50cc048eaa9..9dda3620f9c 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -786,9 +786,11 @@ select_params(const std::vector& instructions, return result; } -static std::array generic_split(const module& m, const std::vector& args, - const std::vector& splits, - std::unordered_map* map_ins = nullptr) +static std::array +generic_split(const module& m, + const std::vector& args, + const std::vector& splits, + std::unordered_map* map_ins = nullptr) { std::unordered_map param_map = m.get_ins_param_map(args, true); @@ -865,7 +867,6 @@ static std::array generic_split(const module& m, const s return {{{std::move(m1), std::move(inputs1)}, {std::move(m2), std::move(inputs2)}}}; } - std::array module::split(const std::vector& args, const std::vector& splits) const { @@ -873,14 +874,15 @@ std::array module::split(const std::vector module::split(const std::vector& args, - const std::vector& splits1, - const std::vector& splits2) const + const std::vector& splits1, + const std::vector& splits2) const { std::unordered_map map_ins; auto mods1 = generic_split(*this, args, splits1, &map_ins); - + assert(all_of(mods1[0].inputs, [&](auto ins) { return contains(args, ins); })); - assert(all_of(mods1[1].inputs, [&](auto ins) { return contains(args, ins) or contains(splits1, ins); })); + assert(all_of(mods1[1].inputs, + [&](auto ins) { return contains(args, ins) or contains(splits1, ins); })); std::vector new_splits2; std::transform(splits2.begin(), splits2.end(), std::back_inserter(new_splits2), [&](auto ins) { @@ -891,8 +893,11 @@ std::array module::split(const std::vector& keys, const std::vector& values) +void module_with_inputs::replace(const std::vector& keys, + const std::vector& values) { for(auto& ins : inputs) { diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp index 5a46aa7fba3..552e3eec9e9 100644 --- a/src/split_reduce.cpp +++ b/src/split_reduce.cpp @@ -78,7 +78,8 @@ static std::vector find_split(module_ref rm) return result; } -static std::vector get_alive(module_ref rm, const std::vector& splits) +static std::vector get_alive(module_ref rm, + const std::vector& splits) { std::vector result; bool stop = false; @@ -87,9 +88,10 @@ static std::vector get_alive(module_ref rm, const std::vectorname() != "@param"; - }); + std::copy_if(live_set.begin(), + live_set.end(), + std::back_inserter(result), + [](instruction_ref live) { return live->name() != "@param"; }); stop = true; }); return result; @@ -107,10 +109,10 @@ static std::string assign_op(const std::vector& splits) return m.at(splits.front()->name()); } -static std::vector insert_module_inline(module& m, instruction_ref ins, const module::with_inputs& mwi) +static std::vector +insert_module_inline(module& m, instruction_ref ins, const module::with_inputs& mwi) { - auto param_map = - mwi.mod.get_ins_param_map(mwi.inputs, true); + auto param_map = mwi.mod.get_ins_param_map(mwi.inputs, true); return m.insert_instructions(ins, &mwi.mod, ¶m_map); } @@ -139,14 +141,14 @@ void split_reduce::apply(module_pass_manager& mpm) const if(not alive.empty()) { auto mods3 = rm->split(ins->inputs(), alive, splits); - auto r = insert_module_inline(mpm.get_module(), ins, mods3[0]); + auto r = insert_module_inline(mpm.get_module(), ins, mods3[0]); mods3[1].replace(alive, r); mods3[2].replace(alive, r); mods = {std::move(mods3[1]), std::move(mods3[2])}; } else { - mods = rm->split(ins->inputs(), splits); + mods = rm->split(ins->inputs(), splits); } auto* splitm = mpm.create_module(rm->name() + "_split", std::move(mods[0].mod)); diff --git a/test/split_reduce.cpp b/test/split_reduce.cpp index 7bbd3588630..8cea383aad3 100644 --- a/test/split_reduce.cpp +++ b/test/split_reduce.cpp @@ -124,7 +124,7 @@ TEST_CASE(split_pointwise) { auto* mm = p1.get_main_module(); auto x = mm->add_parameter("x", s); - auto sqrt = mm->add_instruction(migraphx::make_op("sqrt"), x); + auto sqrt = mm->add_instruction(migraphx::make_op("sqrt"), x); auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), sqrt); auto rsumb = mm->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum); @@ -136,7 +136,7 @@ TEST_CASE(split_pointwise) { auto* mm = p2.get_main_module(); auto x = mm->add_parameter("x", s); - auto sqrt = add_pointwise(p2, mm, "main:pointwise0", {x}, single_pointwise("sqrt")); + auto sqrt = add_pointwise(p2, mm, "main:pointwise0", {x}, single_pointwise("sqrt")); auto rsum = add_reduce(p2, "main:pointwise0:main:reduce_sum0:main:pointwise1_split", {sqrt}, From 9a86a41fbc5c400d5f35d89480a97cb2ffd5ad98 Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 19 Mar 2024 17:03:26 -0700 Subject: [PATCH 29/59] Check threshold --- src/include/migraphx/split_reduce.hpp | 2 +- src/split_reduce.cpp | 10 +++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/include/migraphx/split_reduce.hpp b/src/include/migraphx/split_reduce.hpp index a7a21ca1bc8..6a7b1c3d0cf 100644 --- a/src/include/migraphx/split_reduce.hpp +++ b/src/include/migraphx/split_reduce.hpp @@ -11,7 +11,7 @@ struct module_pass_manager; struct MIGRAPHX_EXPORT split_reduce { - std::size_t split_size = 2048; + std::size_t split_size = 8192; std::string name() const { return "split_reduce"; } void apply(module_pass_manager& mpm) const; }; diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp index 552e3eec9e9..243ce783ec8 100644 --- a/src/split_reduce.cpp +++ b/src/split_reduce.cpp @@ -116,6 +116,13 @@ insert_module_inline(module& m, instruction_ref ins, const module::with_inputs& return m.insert_instructions(ins, &mwi.mod, ¶m_map); } +static std::size_t get_reduce_size(module_ref rm) +{ + auto ins = std::find_if(rm->begin(), rm->end(), &is_reduce); + assert(ins != rm->end()); + return ins->inputs().front()->get_shape().elements() / ins->get_shape().elements(); +} + void split_reduce::apply(module_pass_manager& mpm) const { for(auto ins : iterator_for(mpm.get_module())) @@ -123,6 +130,8 @@ void split_reduce::apply(module_pass_manager& mpm) const if(ins->name() != "fused_reduce") continue; auto* rm = ins->module_inputs().front(); + if(get_reduce_size(rm) < split_size) + continue; auto splits = find_split(rm); if(splits.empty()) continue; @@ -133,7 +142,6 @@ void split_reduce::apply(module_pass_manager& mpm) const continue; auto v = ins->get_operator().to_value(); auto axes = v["axes"].to_vector(); - // TODO: Check reduction size auto alive = get_alive(rm, splits); From d844e99fc46221e5d24bf56bd6058be9618da987 Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 19 Mar 2024 17:42:20 -0700 Subject: [PATCH 30/59] Add fill for split reduce --- src/split_reduce.cpp | 2 ++ src/targets/gpu/hip.cpp | 15 ++++++++++ src/targets/gpu/include/migraphx/gpu/hip.hpp | 29 ++++++++++++++++++++ src/targets/gpu/lowering.cpp | 11 ++++++++ 4 files changed, 57 insertions(+) diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp index 243ce783ec8..c9cb2e03595 100644 --- a/src/split_reduce.cpp +++ b/src/split_reduce.cpp @@ -28,6 +28,8 @@ struct split_fused_reduce return pack(f(self.axes, "axes"), f(self.assign, "assign")); } + value attributes() const { return {{"zero_init", true}}; } + shape compute_shape(const std::vector& inputs, std::vector mods) const { if(mods.size() != 1) diff --git a/src/targets/gpu/hip.cpp b/src/targets/gpu/hip.cpp index b5306e681ed..697f1b2837d 100644 --- a/src/targets/gpu/hip.cpp +++ b/src/targets/gpu/hip.cpp @@ -37,6 +37,7 @@ inline namespace MIGRAPHX_INLINE_NS { namespace gpu { MIGRAPHX_REGISTER_OP(hip_allocate) +MIGRAPHX_REGISTER_OP(hip_fill) MIGRAPHX_REGISTER_OP(hip_sync_stream) MIGRAPHX_REGISTER_OP(hip_copy_to_gpu) MIGRAPHX_REGISTER_OP(hip_copy_from_gpu) @@ -246,6 +247,14 @@ void gpu_sync() void gpu_sync(const context& ctx) { ctx.finish(); } +void hip_async_memset(context& ctx, const argument& dst, int value) +{ + std::size_t dst_size = dst.get_shape().bytes(); + auto status = hipMemsetAsync(dst.data(), value, dst_size, ctx.get_stream().get()); + if(status != hipSuccess) + MIGRAPHX_THROW("Gpu fill failed: " + hip_error(status)); +} + void hip_async_copy(context& ctx, const argument& src, const argument& dst, hipMemcpyKind kind) { std::size_t src_size = src.get_shape().bytes(); @@ -293,6 +302,12 @@ argument get_preallocation(context& ctx, const std::string& id) return ctx.get_current_device().preallocations.at(id); } +void gpu_fill(context& ctx, const argument& dst, int value) +{ + // TODO: Handle non-packed tensor when value is not 0 + hip_async_memset(ctx, dst, value); +} + void store_preallocated_param(context& ctx, const std::string& id, const argument& a) { ctx.get_current_device().preallocations[id] = a; diff --git a/src/targets/gpu/include/migraphx/gpu/hip.hpp b/src/targets/gpu/include/migraphx/gpu/hip.hpp index d51df82005a..3cf917b4fa2 100644 --- a/src/targets/gpu/include/migraphx/gpu/hip.hpp +++ b/src/targets/gpu/include/migraphx/gpu/hip.hpp @@ -59,6 +59,8 @@ MIGRAPHX_GPU_EXPORT void copy_from_gpu(context& ctx, const argument& src, const MIGRAPHX_GPU_EXPORT argument get_preallocation(context& ctx, const std::string& id); +MIGRAPHX_GPU_EXPORT void gpu_fill(context& ctx, const argument& dst, int value = 0); + struct hip_allocate { shape s; @@ -81,6 +83,33 @@ struct hip_allocate } }; +struct hip_fill +{ + int value = 0; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.value, "value")); + } + + std::string name() const { return "hip::fill"; } + shape compute_shape(const std::vector& inputs) const + { + check_shapes{inputs, *this}.has(1); + return inputs.front(); + } + argument compute(context& ctx, const shape&, const std::vector& args) const + { + gpu_fill(ctx, args.front(), value); + return args.front(); + } + std::ptrdiff_t output_alias(const std::vector&) const + { + return 0; + } +}; + struct hip_sync_stream { diff --git a/src/targets/gpu/lowering.cpp b/src/targets/gpu/lowering.cpp index fcde59841fd..5c931c2776e 100644 --- a/src/targets/gpu/lowering.cpp +++ b/src/targets/gpu/lowering.cpp @@ -174,10 +174,21 @@ struct miopen_apply { check_shape(s, insert_custom_op(it, attrs)); } + if(attrs.get("zero_init", false)) + insert_fill0(it); } copy_params(); } + void insert_fill0(instruction_ref ins) const + { + instruction_ref alloc = instruction::get_output_alias(ins, true); + if(alloc == ins) + return; + auto fill = mod->insert_instruction(ins, make_op("hip::fill"), alloc); + instruction::replace_argument(ins, alloc, fill); + } + instruction_ref insert_custom_op(instruction_ref ins, const value& attrs) const { const auto& custom_op = ins->get_operator(); From 75aecd0a4a786cb6a7ee3ad6a56abaa1fb212c53 Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 19 Mar 2024 17:42:28 -0700 Subject: [PATCH 31/59] Format --- src/targets/gpu/hip.cpp | 2 +- src/targets/gpu/include/migraphx/gpu/hip.hpp | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/targets/gpu/hip.cpp b/src/targets/gpu/hip.cpp index 697f1b2837d..fcad9fc0b23 100644 --- a/src/targets/gpu/hip.cpp +++ b/src/targets/gpu/hip.cpp @@ -250,7 +250,7 @@ void gpu_sync(const context& ctx) { ctx.finish(); } void hip_async_memset(context& ctx, const argument& dst, int value) { std::size_t dst_size = dst.get_shape().bytes(); - auto status = hipMemsetAsync(dst.data(), value, dst_size, ctx.get_stream().get()); + auto status = hipMemsetAsync(dst.data(), value, dst_size, ctx.get_stream().get()); if(status != hipSuccess) MIGRAPHX_THROW("Gpu fill failed: " + hip_error(status)); } diff --git a/src/targets/gpu/include/migraphx/gpu/hip.hpp b/src/targets/gpu/include/migraphx/gpu/hip.hpp index 3cf917b4fa2..3299c451fa4 100644 --- a/src/targets/gpu/include/migraphx/gpu/hip.hpp +++ b/src/targets/gpu/include/migraphx/gpu/hip.hpp @@ -104,10 +104,7 @@ struct hip_fill gpu_fill(ctx, args.front(), value); return args.front(); } - std::ptrdiff_t output_alias(const std::vector&) const - { - return 0; - } + std::ptrdiff_t output_alias(const std::vector&) const { return 0; } }; struct hip_sync_stream From 3da1f321920cbb24f3adba0c4ac534dbf792e98b Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 19 Mar 2024 17:47:29 -0700 Subject: [PATCH 32/59] Add license --- src/include/migraphx/liveness.hpp | 24 +++++++++++++++++++ src/include/migraphx/param_utils.hpp | 24 +++++++++++++++++++ src/include/migraphx/split_reduce.hpp | 24 +++++++++++++++++++ src/memory_coloring.cpp | 2 +- src/param_utils.cpp | 24 +++++++++++++++++++ src/split_reduce.cpp | 24 +++++++++++++++++++ src/targets/gpu/hip.cpp | 2 +- src/targets/gpu/include/migraphx/gpu/hip.hpp | 2 +- .../kernels/scatter_reduction_modes.hpp | 2 +- test/module_test.cpp | 2 +- 10 files changed, 125 insertions(+), 5 deletions(-) diff --git a/src/include/migraphx/liveness.hpp b/src/include/migraphx/liveness.hpp index 1b8761d9e7d..d785bbc7216 100644 --- a/src/include/migraphx/liveness.hpp +++ b/src/include/migraphx/liveness.hpp @@ -1,3 +1,27 @@ +/* +* The MIT License (MIT) +* +* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. +* +* Permission is hereby granted, free of charge, to any person obtaining a copy +* of this software and associated documentation files (the "Software"), to deal +* in the Software without restriction, including without limitation the rights +* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +* copies of the Software, and to permit persons to whom the Software is +* furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +* THE SOFTWARE. +* +*/ #ifndef MIGRAPHX_GUARD_MIGRAPHX_LIVENESS_HPP #define MIGRAPHX_GUARD_MIGRAPHX_LIVENESS_HPP diff --git a/src/include/migraphx/param_utils.hpp b/src/include/migraphx/param_utils.hpp index 153f4f2ac86..7c172647b93 100644 --- a/src/include/migraphx/param_utils.hpp +++ b/src/include/migraphx/param_utils.hpp @@ -1,3 +1,27 @@ +/* +* The MIT License (MIT) +* +* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. +* +* Permission is hereby granted, free of charge, to any person obtaining a copy +* of this software and associated documentation files (the "Software"), to deal +* in the Software without restriction, including without limitation the rights +* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +* copies of the Software, and to permit persons to whom the Software is +* furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +* THE SOFTWARE. +* +*/ #ifndef MIGRAPHX_GUARD_MIGRAPHX_PARAM_UTILS_HPP #define MIGRAPHX_GUARD_MIGRAPHX_PARAM_UTILS_HPP diff --git a/src/include/migraphx/split_reduce.hpp b/src/include/migraphx/split_reduce.hpp index 6a7b1c3d0cf..64478d183a4 100644 --- a/src/include/migraphx/split_reduce.hpp +++ b/src/include/migraphx/split_reduce.hpp @@ -1,3 +1,27 @@ +/* +* The MIT License (MIT) +* +* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. +* +* Permission is hereby granted, free of charge, to any person obtaining a copy +* of this software and associated documentation files (the "Software"), to deal +* in the Software without restriction, including without limitation the rights +* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +* copies of the Software, and to permit persons to whom the Software is +* furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +* THE SOFTWARE. +* +*/ #ifndef MIGRAPHX_GUARD_MIGRAPHX_SPLIT_REDUCE_HPP #define MIGRAPHX_GUARD_MIGRAPHX_SPLIT_REDUCE_HPP diff --git a/src/memory_coloring.cpp b/src/memory_coloring.cpp index 649d9f5acc5..3753f549000 100644 --- a/src/memory_coloring.cpp +++ b/src/memory_coloring.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/param_utils.cpp b/src/param_utils.cpp index 8cf7be6ef80..1dd95ae48a4 100644 --- a/src/param_utils.cpp +++ b/src/param_utils.cpp @@ -1,3 +1,27 @@ +/* +* The MIT License (MIT) +* +* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. +* +* Permission is hereby granted, free of charge, to any person obtaining a copy +* of this software and associated documentation files (the "Software"), to deal +* in the Software without restriction, including without limitation the rights +* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +* copies of the Software, and to permit persons to whom the Software is +* furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +* THE SOFTWARE. +* +*/ #include #include #include diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp index c9cb2e03595..512abd7d4fa 100644 --- a/src/split_reduce.cpp +++ b/src/split_reduce.cpp @@ -1,3 +1,27 @@ +/* +* The MIT License (MIT) +* +* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. +* +* Permission is hereby granted, free of charge, to any person obtaining a copy +* of this software and associated documentation files (the "Software"), to deal +* in the Software without restriction, including without limitation the rights +* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +* copies of the Software, and to permit persons to whom the Software is +* furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +* THE SOFTWARE. +* +*/ #include #include #include diff --git a/src/targets/gpu/hip.cpp b/src/targets/gpu/hip.cpp index fcad9fc0b23..58919e4bbdb 100644 --- a/src/targets/gpu/hip.cpp +++ b/src/targets/gpu/hip.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/targets/gpu/include/migraphx/gpu/hip.hpp b/src/targets/gpu/include/migraphx/gpu/hip.hpp index 3299c451fa4..acd7525d620 100644 --- a/src/targets/gpu/include/migraphx/gpu/hip.hpp +++ b/src/targets/gpu/include/migraphx/gpu/hip.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/scatter_reduction_modes.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/scatter_reduction_modes.hpp index 07d066a1589..93f4bed2fb4 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/scatter_reduction_modes.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/scatter_reduction_modes.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/test/module_test.cpp b/test/module_test.cpp index d73b1afc7c3..360bb7a75ba 100644 --- a/test/module_test.cpp +++ b/test/module_test.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal From 650ac713b94253e3bb5a8e2152e98f490d832be1 Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 19 Mar 2024 17:47:33 -0700 Subject: [PATCH 33/59] Format --- src/include/migraphx/liveness.hpp | 46 +++++++++++++-------------- src/include/migraphx/param_utils.hpp | 46 +++++++++++++-------------- src/include/migraphx/split_reduce.hpp | 46 +++++++++++++-------------- src/param_utils.cpp | 46 +++++++++++++-------------- src/split_reduce.cpp | 46 +++++++++++++-------------- 5 files changed, 115 insertions(+), 115 deletions(-) diff --git a/src/include/migraphx/liveness.hpp b/src/include/migraphx/liveness.hpp index d785bbc7216..6d9715a8a10 100644 --- a/src/include/migraphx/liveness.hpp +++ b/src/include/migraphx/liveness.hpp @@ -1,27 +1,27 @@ /* -* The MIT License (MIT) -* -* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. -* -* Permission is hereby granted, free of charge, to any person obtaining a copy -* of this software and associated documentation files (the "Software"), to deal -* in the Software without restriction, including without limitation the rights -* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -* copies of the Software, and to permit persons to whom the Software is -* furnished to do so, subject to the following conditions: -* -* The above copyright notice and this permission notice shall be included in -* all copies or substantial portions of the Software. -* -* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -* THE SOFTWARE. -* -*/ + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ #ifndef MIGRAPHX_GUARD_MIGRAPHX_LIVENESS_HPP #define MIGRAPHX_GUARD_MIGRAPHX_LIVENESS_HPP diff --git a/src/include/migraphx/param_utils.hpp b/src/include/migraphx/param_utils.hpp index 7c172647b93..1552c28300b 100644 --- a/src/include/migraphx/param_utils.hpp +++ b/src/include/migraphx/param_utils.hpp @@ -1,27 +1,27 @@ /* -* The MIT License (MIT) -* -* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. -* -* Permission is hereby granted, free of charge, to any person obtaining a copy -* of this software and associated documentation files (the "Software"), to deal -* in the Software without restriction, including without limitation the rights -* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -* copies of the Software, and to permit persons to whom the Software is -* furnished to do so, subject to the following conditions: -* -* The above copyright notice and this permission notice shall be included in -* all copies or substantial portions of the Software. -* -* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -* THE SOFTWARE. -* -*/ + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ #ifndef MIGRAPHX_GUARD_MIGRAPHX_PARAM_UTILS_HPP #define MIGRAPHX_GUARD_MIGRAPHX_PARAM_UTILS_HPP diff --git a/src/include/migraphx/split_reduce.hpp b/src/include/migraphx/split_reduce.hpp index 64478d183a4..620096b06b6 100644 --- a/src/include/migraphx/split_reduce.hpp +++ b/src/include/migraphx/split_reduce.hpp @@ -1,27 +1,27 @@ /* -* The MIT License (MIT) -* -* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. -* -* Permission is hereby granted, free of charge, to any person obtaining a copy -* of this software and associated documentation files (the "Software"), to deal -* in the Software without restriction, including without limitation the rights -* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -* copies of the Software, and to permit persons to whom the Software is -* furnished to do so, subject to the following conditions: -* -* The above copyright notice and this permission notice shall be included in -* all copies or substantial portions of the Software. -* -* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -* THE SOFTWARE. -* -*/ + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ #ifndef MIGRAPHX_GUARD_MIGRAPHX_SPLIT_REDUCE_HPP #define MIGRAPHX_GUARD_MIGRAPHX_SPLIT_REDUCE_HPP diff --git a/src/param_utils.cpp b/src/param_utils.cpp index 1dd95ae48a4..5d9560cfee0 100644 --- a/src/param_utils.cpp +++ b/src/param_utils.cpp @@ -1,27 +1,27 @@ /* -* The MIT License (MIT) -* -* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. -* -* Permission is hereby granted, free of charge, to any person obtaining a copy -* of this software and associated documentation files (the "Software"), to deal -* in the Software without restriction, including without limitation the rights -* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -* copies of the Software, and to permit persons to whom the Software is -* furnished to do so, subject to the following conditions: -* -* The above copyright notice and this permission notice shall be included in -* all copies or substantial portions of the Software. -* -* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -* THE SOFTWARE. -* -*/ + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ #include #include #include diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp index 512abd7d4fa..e80d10f68d4 100644 --- a/src/split_reduce.cpp +++ b/src/split_reduce.cpp @@ -1,27 +1,27 @@ /* -* The MIT License (MIT) -* -* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. -* -* Permission is hereby granted, free of charge, to any person obtaining a copy -* of this software and associated documentation files (the "Software"), to deal -* in the Software without restriction, including without limitation the rights -* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -* copies of the Software, and to permit persons to whom the Software is -* furnished to do so, subject to the following conditions: -* -* The above copyright notice and this permission notice shall be included in -* all copies or substantial portions of the Software. -* -* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -* THE SOFTWARE. -* -*/ + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ #include #include #include From be097a49e7c08befa337190394f3cd023cace17e Mon Sep 17 00:00:00 2001 From: Paul Date: Wed, 20 Mar 2024 07:39:53 -0700 Subject: [PATCH 34/59] Fix tidy warnings --- src/include/migraphx/pass_manager.hpp | 2 +- src/module.cpp | 8 ++++---- src/pass_manager.cpp | 4 ++-- src/split_reduce.cpp | 7 +++---- src/targets/gpu/jit/reduce.cpp | 6 +++--- 5 files changed, 13 insertions(+), 14 deletions(-) diff --git a/src/include/migraphx/pass_manager.hpp b/src/include/migraphx/pass_manager.hpp index d930cc2416f..fdbdc123a12 100644 --- a/src/include/migraphx/pass_manager.hpp +++ b/src/include/migraphx/pass_manager.hpp @@ -39,7 +39,7 @@ struct module_pass_manager module_pass_manager(const module_pass_manager&) = delete; virtual module& get_module() = 0; virtual module* create_module(const std::string& name) = 0; - virtual module* create_module(const std::string& name, const module& m) = 0; + virtual module* create_module(const std::string& name, module m) = 0; virtual module* get_common_parent() = 0; virtual module* get_root_module() = 0; virtual void run_pass(const pass& p) = 0; diff --git a/src/module.cpp b/src/module.cpp index 9dda3620f9c..2f4c99575d9 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -452,7 +452,7 @@ module::insert_instructions(instruction_ref ins, { std::unordered_map default_map_ins; return insert_generic_instructions( - *this, ins, instructions, map_ins ? *map_ins : default_map_ins, std::move(insert)); + *this, ins, instructions, map_ins == nullptr ? default_map_ins : *map_ins, std::move(insert)); } std::vector @@ -463,7 +463,7 @@ module::insert_instructions(instruction_ref ins, { std::unordered_map default_map_ins; return insert_generic_instructions( - *this, ins, iterator_for(*m), map_ins ? *map_ins : default_map_ins, std::move(insert)); + *this, ins, iterator_for(*m), map_ins == nullptr ? default_map_ins : *map_ins, std::move(insert)); } std::vector @@ -476,7 +476,7 @@ module::insert_instructions(instruction_ref ins, auto r = range(start, last); std::unordered_map default_map_ins; return insert_generic_instructions( - *this, ins, iterator_for(r), map_ins ? *map_ins : default_map_ins, std::move(insert)); + *this, ins, iterator_for(r), map_ins == nullptr ? default_map_ins : *map_ins, std::move(insert)); } instruction_ref module::add_literal(literal l) { return insert_literal(begin(), std::move(l)); } @@ -862,7 +862,7 @@ generic_split(const module& m, } auto r = m2.add_instructions(instructions2, &map_ins2); m2.add_return(r); - if(map_ins) + if(map_ins != nullptr) *map_ins = map_ins2; return {{{std::move(m1), std::move(inputs1)}, {std::move(m2), std::move(inputs2)}}}; } diff --git a/src/pass_manager.cpp b/src/pass_manager.cpp index be62bc4c01e..3748e094773 100644 --- a/src/pass_manager.cpp +++ b/src/pass_manager.cpp @@ -99,10 +99,10 @@ struct module_pm : module_pass_manager return prog->create_module(name); } - virtual module* create_module(const std::string& name, const module& m) override + virtual module* create_module(const std::string& name, module m) override { assert(prog); - return prog->create_module(name, m); + return prog->create_module(name, std::move(m)); } virtual module* get_common_parent() override { return common_parent; } diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp index e80d10f68d4..14173e9faea 100644 --- a/src/split_reduce.cpp +++ b/src/split_reduce.cpp @@ -29,7 +29,6 @@ #include #include #include -#include #include #include #include @@ -90,7 +89,7 @@ MIGRAPHX_REGISTER_OP(split_fused_reduce); static bool is_reduce(const instruction& ins) { return contains(ins.name(), "reduce"); } -static std::vector find_split(module_ref rm) +static std::vector find_split(const_module_ref rm) { std::vector result; auto reduce_ins = std::find_if(rm->begin(), rm->end(), &is_reduce); @@ -104,7 +103,7 @@ static std::vector find_split(module_ref rm) return result; } -static std::vector get_alive(module_ref rm, +static std::vector get_alive(const_module_ref rm, const std::vector& splits) { std::vector result; @@ -142,7 +141,7 @@ insert_module_inline(module& m, instruction_ref ins, const module::with_inputs& return m.insert_instructions(ins, &mwi.mod, ¶m_map); } -static std::size_t get_reduce_size(module_ref rm) +static std::size_t get_reduce_size(const_module_ref rm) { auto ins = std::find_if(rm->begin(), rm->end(), &is_reduce); assert(ins != rm->end()); diff --git a/src/targets/gpu/jit/reduce.cpp b/src/targets/gpu/jit/reduce.cpp index c3ba3e39b8b..44d6d9bc81c 100644 --- a/src/targets/gpu/jit/reduce.cpp +++ b/src/targets/gpu/jit/reduce.cpp @@ -138,8 +138,8 @@ static std::vector split_reduce(const std::vector& inputs, { std::vector result; auto input_shape = inputs.front(); - auto reduce_shape = inputs[inputs.size() - 2]; - auto output_shape = inputs[inputs.size() - 1]; + const auto& reduce_shape = inputs[inputs.size() - 2]; + const auto& output_shape = inputs[inputs.size() - 1]; auto is = range(reduce_shape.lens().size()); using array_type = std::array; @@ -159,7 +159,7 @@ static std::vector split_reduce(const std::vector& inputs, auto factors = make_array(2, 3, 5, 7, 11); while(r > min_size) { - auto it = std::find_if(factors.begin(), factors.end(), [&](auto d) { return r % d == 0; }); + const auto* it = std::find_if(factors.begin(), factors.end(), [&](auto d) { return r % d == 0; }); if(it == factors.end()) break; r /= *it; From 79ed5175afd491ad57c0859934b3b45f82e5d81d Mon Sep 17 00:00:00 2001 From: Paul Date: Wed, 20 Mar 2024 07:40:00 -0700 Subject: [PATCH 35/59] Format --- src/module.cpp | 21 +++++++++++++++------ src/targets/gpu/jit/reduce.cpp | 3 ++- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/src/module.cpp b/src/module.cpp index 2f4c99575d9..3d9ab706079 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -451,8 +451,11 @@ module::insert_instructions(instruction_ref ins, module::inserter insert) { std::unordered_map default_map_ins; - return insert_generic_instructions( - *this, ins, instructions, map_ins == nullptr ? default_map_ins : *map_ins, std::move(insert)); + return insert_generic_instructions(*this, + ins, + instructions, + map_ins == nullptr ? default_map_ins : *map_ins, + std::move(insert)); } std::vector @@ -462,8 +465,11 @@ module::insert_instructions(instruction_ref ins, module::inserter insert) { std::unordered_map default_map_ins; - return insert_generic_instructions( - *this, ins, iterator_for(*m), map_ins == nullptr ? default_map_ins : *map_ins, std::move(insert)); + return insert_generic_instructions(*this, + ins, + iterator_for(*m), + map_ins == nullptr ? default_map_ins : *map_ins, + std::move(insert)); } std::vector @@ -475,8 +481,11 @@ module::insert_instructions(instruction_ref ins, { auto r = range(start, last); std::unordered_map default_map_ins; - return insert_generic_instructions( - *this, ins, iterator_for(r), map_ins == nullptr ? default_map_ins : *map_ins, std::move(insert)); + return insert_generic_instructions(*this, + ins, + iterator_for(r), + map_ins == nullptr ? default_map_ins : *map_ins, + std::move(insert)); } instruction_ref module::add_literal(literal l) { return insert_literal(begin(), std::move(l)); } diff --git a/src/targets/gpu/jit/reduce.cpp b/src/targets/gpu/jit/reduce.cpp index 44d6d9bc81c..bcdff4ec285 100644 --- a/src/targets/gpu/jit/reduce.cpp +++ b/src/targets/gpu/jit/reduce.cpp @@ -159,7 +159,8 @@ static std::vector split_reduce(const std::vector& inputs, auto factors = make_array(2, 3, 5, 7, 11); while(r > min_size) { - const auto* it = std::find_if(factors.begin(), factors.end(), [&](auto d) { return r % d == 0; }); + const auto* it = + std::find_if(factors.begin(), factors.end(), [&](auto d) { return r % d == 0; }); if(it == factors.end()) break; r /= *it; From 31021f34b8628144918b4699b43944aa05bb0520 Mon Sep 17 00:00:00 2001 From: Paul Date: Wed, 20 Mar 2024 07:40:36 -0700 Subject: [PATCH 36/59] Remvoe assert --- src/module.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/module.cpp b/src/module.cpp index 3d9ab706079..3e211ce8c9a 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -503,7 +503,6 @@ instruction_ref module::add_parameter(std::string name, shape s) instruction_ref module::add_return(std::vector args) { - assert(std::all_of(args.begin(), args.end(), [&](auto ins) { return has_instruction(ins); })); shape instr_shape = compute_shape(builtin::returns{}, args); impl->push_back({builtin::returns{}, instr_shape, std::move(args)}); auto result = std::prev(impl->instructions.end()); From fb1c9ba26c6b10b233bf64402732ff44b74e9dca Mon Sep 17 00:00:00 2001 From: Paul Date: Wed, 20 Mar 2024 15:31:41 -0700 Subject: [PATCH 37/59] Only use for reduce_sum --- src/split_reduce.cpp | 5 ++++- src/targets/gpu/lowering.cpp | 10 ++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp index 14173e9faea..992144df15e 100644 --- a/src/split_reduce.cpp +++ b/src/split_reduce.cpp @@ -51,7 +51,7 @@ struct split_fused_reduce return pack(f(self.axes, "axes"), f(self.assign, "assign")); } - value attributes() const { return {{"zero_init", true}}; } + value attributes() const { return {{"prefill", 0}}; } shape compute_shape(const std::vector& inputs, std::vector mods) const { @@ -98,6 +98,9 @@ static std::vector find_split(const_module_ref rm) // Bail if there is more than one reduce for now if(std::any_of(std::next(reduce_ins), rm->end(), &is_reduce)) return result; + // Only handle reduce_sum for now + if(reduce_ins->name() != "reduce_sum") + return result; result.push_back(reduce_ins); // TODO: Find instructions that are used again in the module return result; diff --git a/src/targets/gpu/lowering.cpp b/src/targets/gpu/lowering.cpp index 5c931c2776e..f666280571c 100644 --- a/src/targets/gpu/lowering.cpp +++ b/src/targets/gpu/lowering.cpp @@ -174,18 +174,20 @@ struct miopen_apply { check_shape(s, insert_custom_op(it, attrs)); } - if(attrs.get("zero_init", false)) - insert_fill0(it); + if(attrs.contains("prefiil")) + { + insert_fill(it, attrs.at("prefill")); + } } copy_params(); } - void insert_fill0(instruction_ref ins) const + void insert_fill(instruction_ref ins, value v) const { instruction_ref alloc = instruction::get_output_alias(ins, true); if(alloc == ins) return; - auto fill = mod->insert_instruction(ins, make_op("hip::fill"), alloc); + auto fill = mod->insert_instruction(ins, make_op("hip::fill", {{"value", v}}), alloc); instruction::replace_argument(ins, alloc, fill); } From b0220b4215af9c1829b27426f46afa6f13ce531a Mon Sep 17 00:00:00 2001 From: Paul Date: Wed, 20 Mar 2024 15:36:27 -0700 Subject: [PATCH 38/59] Fix typo --- src/targets/gpu/lowering.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/targets/gpu/lowering.cpp b/src/targets/gpu/lowering.cpp index f666280571c..51d924f6a5f 100644 --- a/src/targets/gpu/lowering.cpp +++ b/src/targets/gpu/lowering.cpp @@ -174,7 +174,7 @@ struct miopen_apply { check_shape(s, insert_custom_op(it, attrs)); } - if(attrs.contains("prefiil")) + if(attrs.contains("prefill")) { insert_fill(it, attrs.at("prefill")); } From 188602ecded9107cea3fea38a83286ada4766681 Mon Sep 17 00:00:00 2001 From: Paul Date: Thu, 28 Mar 2024 17:55:39 -0700 Subject: [PATCH 39/59] Add doc --- src/include/migraphx/module.hpp | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index 7d79f318ba7..f9d41121159 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -209,14 +209,24 @@ struct MIGRAPHX_EXPORT module void finalize(std::vector& contexts); + /// Create a mapping from the input instruction to the corresponding + /// parameter instruction. Use the `reverse` flag to reverse the lookup + /// to be from parameter instruction to input instread. std::unordered_map get_ins_param_map(const std::vector& inputs, bool reverse = false) const; using with_inputs = module_with_inputs; + /// This will split the module into two parts at the instruction splits. + /// Each split instruction becomes an input parameter in the second + /// module. As such the inputs instructions to the second module will use + /// the split instructions as input placeholders that can be replaced + /// later. std::array split(const std::vector& args, const std::vector& splits) const; + /// This will split the module in 3 parts using different split + /// instruction for each additional module. std::array split(const std::vector& args, const std::vector& splits1, const std::vector& splits2) const; @@ -286,8 +296,13 @@ struct module_with_inputs { module mod; std::vector inputs; + /// Replace the instruction in the inputs with rep void replace(instruction_ref ins, instruction_ref rep); + /// Replace the input instructions using the map_ins to lookup the replacement void replace(const std::unordered_map& map_ins); + + /// Replace the input instructions of the keys with the instructions + /// passed as values. Both vectors should be in the same order. void replace(const std::vector& keys, const std::vector& values); }; From cba899eb3c78bc8ff984cc82ff4d964789b34c27 Mon Sep 17 00:00:00 2001 From: Paul Date: Thu, 28 Mar 2024 18:07:03 -0700 Subject: [PATCH 40/59] Add asserts --- docs/dev/env_vars.rst | 4 ++++ src/module.cpp | 3 +++ src/targets/gpu/hip.cpp | 1 + 3 files changed, 8 insertions(+) diff --git a/docs/dev/env_vars.rst b/docs/dev/env_vars.rst index 7f0c554a30f..16a324b8d9a 100644 --- a/docs/dev/env_vars.rst +++ b/docs/dev/env_vars.rst @@ -107,6 +107,10 @@ Disables the ``schedule`` pass. Set to "1", "enable", "enabled", "yes", or "true" to use. Disables the ``fuse_reduce`` pass. +.. envvar:: MIGRAPHX_ENABLE_SPLIT_REDUCE +Set to "1", "enable", "enabled", "yes", or "true" to use. +Enable split_reduce. + .. envvar:: MIGRAPHX_ENABLE_NHWC Set to "1", "enable", "enabled", "yes", or "true" to use. diff --git a/src/module.cpp b/src/module.cpp index 3e211ce8c9a..137ff7e5bc6 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -915,6 +915,7 @@ void module_with_inputs::replace(instruction_ref ins, instruction_ref rep) auto it = std::find(inputs.begin(), inputs.end(), ins); if(it == inputs.end()) return; + assert((*it)->get_shape().lens() == rep->get_shape().lens()); *it = rep; } void module_with_inputs::replace( @@ -924,6 +925,7 @@ void module_with_inputs::replace( { if(not contains(map_ins, ins)) continue; + assert(ins->get_shape().lens() == map_ins.at(ins)->get_shape().lens()); ins = map_ins.at(ins); } } @@ -935,6 +937,7 @@ void module_with_inputs::replace(const std::vector& keys, auto it = std::find(keys.begin(), keys.end(), ins); if(it == keys.end()) continue; + assert(ins->get_shape().lens() == values[it - keys.begin()]->get_shape().lens()); ins = values[it - keys.begin()]; } } diff --git a/src/targets/gpu/hip.cpp b/src/targets/gpu/hip.cpp index 58919e4bbdb..49505bcf8be 100644 --- a/src/targets/gpu/hip.cpp +++ b/src/targets/gpu/hip.cpp @@ -305,6 +305,7 @@ argument get_preallocation(context& ctx, const std::string& id) void gpu_fill(context& ctx, const argument& dst, int value) { // TODO: Handle non-packed tensor when value is not 0 + assert(dst.get_shape().packed() and value == 0); hip_async_memset(ctx, dst, value); } From bc9827dd1f7e52bedd31a50c95c7a60aabf54c50 Mon Sep 17 00:00:00 2001 From: Paul Date: Thu, 28 Mar 2024 18:12:36 -0700 Subject: [PATCH 41/59] Format --- src/module.cpp | 1 + src/split_reduce.cpp | 2 +- src/targets/gpu/jit/reduce.cpp | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/module.cpp b/src/module.cpp index 137ff7e5bc6..d351370318e 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -40,6 +40,7 @@ #include #include #include +#include #include #include #include diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp index 992144df15e..29ee83c9a11 100644 --- a/src/split_reduce.cpp +++ b/src/split_reduce.cpp @@ -157,7 +157,7 @@ void split_reduce::apply(module_pass_manager& mpm) const { if(ins->name() != "fused_reduce") continue; - auto* rm = ins->module_inputs().front(); + auto* rm = ins->module_inputs().front(); if(get_reduce_size(rm) < split_size) continue; auto splits = find_split(rm); diff --git a/src/targets/gpu/jit/reduce.cpp b/src/targets/gpu/jit/reduce.cpp index bcdff4ec285..ab9e7ac92a2 100644 --- a/src/targets/gpu/jit/reduce.cpp +++ b/src/targets/gpu/jit/reduce.cpp @@ -137,7 +137,7 @@ static std::vector split_reduce(const std::vector& inputs, std::size_t min_size = 1024) { std::vector result; - auto input_shape = inputs.front(); + auto input_shape = inputs.front(); const auto& reduce_shape = inputs[inputs.size() - 2]; const auto& output_shape = inputs[inputs.size() - 1]; From 155f2fff48fbf5952a900d2f3dc5392e4f2a1f2e Mon Sep 17 00:00:00 2001 From: Paul Date: Thu, 4 Apr 2024 12:21:36 -0700 Subject: [PATCH 42/59] Add unit tests for split function --- test/module_test.cpp | 106 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) diff --git a/test/module_test.cpp b/test/module_test.cpp index 360bb7a75ba..12fb1e98b7a 100644 --- a/test/module_test.cpp +++ b/test/module_test.cpp @@ -458,4 +458,110 @@ TEST_CASE(multiple_module_dependency) EXPECT((sub->validate() == sub->end())); } +TEST_CASE(module_split2) +{ + migraphx::shape s{migraphx::shape::float_type, {1}}; + migraphx::module input_m; + std::vector inputs; + { + auto x1 = input_m.add_parameter("x1", s); + auto x2 = input_m.add_parameter("x2", s); + auto x3 = input_m.add_parameter("x3", s); + inputs = {x1, x2, x3}; + } + migraphx::module m; + std::vector splits; + { + auto x1 = m.add_parameter("x1", s); + auto x2 = m.add_parameter("x2", s); + auto x3 = m.add_parameter("x3", s); + auto add = m.add_instruction(migraphx::make_op("add"), x1, x2); + auto mul = m.add_instruction(migraphx::make_op("mul"), add, x3); + m.add_return({mul}); + splits.push_back(add); + } + auto mods = m.split(inputs, splits); + + migraphx::module m1; + { + auto x1 = m1.add_parameter("x1", s); + auto x2 = m1.add_parameter("x2", s); + auto add = m1.add_instruction(migraphx::make_op("add"), x1, x2); + m1.add_return({add}); + } + migraphx::module m2; + { + auto x0 = m2.add_parameter("x0", s); + auto x1 = m2.add_parameter("x1", s); + auto mul = m2.add_instruction(migraphx::make_op("mul"), x0, x1); + m2.add_return({mul}); + } + EXPECT(mods[0].mod.sort() == m1.sort()); + EXPECT(mods[1].mod.sort() == m2.sort()); + + EXPECT(bool{mods[0].inputs[0] == inputs[0]}); + EXPECT(bool{mods[0].inputs[1] == inputs[1]}); + + EXPECT(bool{mods[1].inputs[0] == splits.front()}); + EXPECT(bool{mods[1].inputs[1] == inputs[2]}); +} + +TEST_CASE(module_split3) +{ + migraphx::shape s{migraphx::shape::float_type, {1}}; + migraphx::module input_m; + std::vector inputs; + { + auto x1 = input_m.add_parameter("x1", s); + auto x2 = input_m.add_parameter("x2", s); + inputs = {x1, x2}; + } + migraphx::module m; + std::vector splits1; + std::vector splits2; + { + auto x1 = m.add_parameter("x1", s); + auto x2 = m.add_parameter("x2", s); + auto mul = m.add_instruction(migraphx::make_op("mul"), x1, x2); + auto sqrt = m.add_instruction(migraphx::make_op("sqrt"), mul); + auto add = m.add_instruction(migraphx::make_op("add"), sqrt, mul); + m.add_return({add}); + splits1.push_back(mul); + splits2.push_back(sqrt); + } + auto mods = m.split(inputs, splits1, splits2); + + migraphx::module m1; + { + auto x1 = m1.add_parameter("x1", s); + auto x2 = m1.add_parameter("x2", s); + auto mul = m1.add_instruction(migraphx::make_op("mul"), x1, x2); + m1.add_return({mul}); + } + migraphx::module m2; + { + auto x0 = m2.add_parameter("x0", s); + auto sqrt = m2.add_instruction(migraphx::make_op("sqrt"), x0); + m2.add_return({sqrt}); + } + migraphx::module m3; + { + auto x0 = m3.add_parameter("x0", s); + auto x1 = m3.add_parameter("x1", s); + auto add = m3.add_instruction(migraphx::make_op("add"), x0, x1); + m3.add_return({add}); + } + EXPECT(mods[0].mod.sort() == m1.sort()); + EXPECT(mods[1].mod.sort() == m2.sort()); + EXPECT(mods[2].mod.sort() == m3.sort()); + + EXPECT(bool{mods[0].inputs[0] == inputs[0]}); + EXPECT(bool{mods[0].inputs[1] == inputs[1]}); + + EXPECT(bool{mods[1].inputs[0] == splits1.front()}); + + EXPECT(bool{mods[2].inputs[0] == splits2.front()}); + EXPECT(bool{mods[2].inputs[1] == splits1.front()}); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } From f626268cb9a292b7c8892a91aa8599745308991a Mon Sep 17 00:00:00 2001 From: Paul Date: Thu, 4 Apr 2024 12:21:42 -0700 Subject: [PATCH 43/59] Format --- test/module_test.cpp | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/test/module_test.cpp b/test/module_test.cpp index 12fb1e98b7a..928d1e7aec4 100644 --- a/test/module_test.cpp +++ b/test/module_test.cpp @@ -467,14 +467,14 @@ TEST_CASE(module_split2) auto x1 = input_m.add_parameter("x1", s); auto x2 = input_m.add_parameter("x2", s); auto x3 = input_m.add_parameter("x3", s); - inputs = {x1, x2, x3}; + inputs = {x1, x2, x3}; } migraphx::module m; std::vector splits; { - auto x1 = m.add_parameter("x1", s); - auto x2 = m.add_parameter("x2", s); - auto x3 = m.add_parameter("x3", s); + auto x1 = m.add_parameter("x1", s); + auto x2 = m.add_parameter("x2", s); + auto x3 = m.add_parameter("x3", s); auto add = m.add_instruction(migraphx::make_op("add"), x1, x2); auto mul = m.add_instruction(migraphx::make_op("mul"), add, x3); m.add_return({mul}); @@ -484,15 +484,15 @@ TEST_CASE(module_split2) migraphx::module m1; { - auto x1 = m1.add_parameter("x1", s); - auto x2 = m1.add_parameter("x2", s); + auto x1 = m1.add_parameter("x1", s); + auto x2 = m1.add_parameter("x2", s); auto add = m1.add_instruction(migraphx::make_op("add"), x1, x2); m1.add_return({add}); } migraphx::module m2; { - auto x0 = m2.add_parameter("x0", s); - auto x1 = m2.add_parameter("x1", s); + auto x0 = m2.add_parameter("x0", s); + auto x1 = m2.add_parameter("x1", s); auto mul = m2.add_instruction(migraphx::make_op("mul"), x0, x1); m2.add_return({mul}); } @@ -514,17 +514,17 @@ TEST_CASE(module_split3) { auto x1 = input_m.add_parameter("x1", s); auto x2 = input_m.add_parameter("x2", s); - inputs = {x1, x2}; + inputs = {x1, x2}; } migraphx::module m; std::vector splits1; std::vector splits2; { - auto x1 = m.add_parameter("x1", s); - auto x2 = m.add_parameter("x2", s); - auto mul = m.add_instruction(migraphx::make_op("mul"), x1, x2); + auto x1 = m.add_parameter("x1", s); + auto x2 = m.add_parameter("x2", s); + auto mul = m.add_instruction(migraphx::make_op("mul"), x1, x2); auto sqrt = m.add_instruction(migraphx::make_op("sqrt"), mul); - auto add = m.add_instruction(migraphx::make_op("add"), sqrt, mul); + auto add = m.add_instruction(migraphx::make_op("add"), sqrt, mul); m.add_return({add}); splits1.push_back(mul); splits2.push_back(sqrt); @@ -533,21 +533,21 @@ TEST_CASE(module_split3) migraphx::module m1; { - auto x1 = m1.add_parameter("x1", s); - auto x2 = m1.add_parameter("x2", s); + auto x1 = m1.add_parameter("x1", s); + auto x2 = m1.add_parameter("x2", s); auto mul = m1.add_instruction(migraphx::make_op("mul"), x1, x2); m1.add_return({mul}); } migraphx::module m2; { - auto x0 = m2.add_parameter("x0", s); + auto x0 = m2.add_parameter("x0", s); auto sqrt = m2.add_instruction(migraphx::make_op("sqrt"), x0); m2.add_return({sqrt}); } migraphx::module m3; { - auto x0 = m3.add_parameter("x0", s); - auto x1 = m3.add_parameter("x1", s); + auto x0 = m3.add_parameter("x0", s); + auto x1 = m3.add_parameter("x1", s); auto add = m3.add_instruction(migraphx::make_op("add"), x0, x1); m3.add_return({add}); } From 7c3069ac7a52df0f341de1d843a5cb91809d0bc4 Mon Sep 17 00:00:00 2001 From: Paul Date: Thu, 4 Apr 2024 14:38:06 -0700 Subject: [PATCH 44/59] Add test for small reduce --- test/split_reduce.cpp | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/test/split_reduce.cpp b/test/split_reduce.cpp index 8cea383aad3..eed0586aebc 100644 --- a/test/split_reduce.cpp +++ b/test/split_reduce.cpp @@ -39,7 +39,15 @@ void run_pass(migraphx::program& p) migraphx::run_passes(p, {migraphx::fuse_pointwise{}, migraphx::fuse_reduce{}, - migraphx::split_reduce{}, + migraphx::split_reduce{.split_size=8192}, + migraphx::dead_code_elimination{}}); +} + +void run_fuse_pass(migraphx::program& p) +{ + migraphx::run_passes(p, + {migraphx::fuse_pointwise{}, + migraphx::fuse_reduce{}, migraphx::dead_code_elimination{}}); } @@ -117,6 +125,26 @@ TEST_CASE(single) EXPECT(p1 == p2); } +TEST_CASE(small) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3, 1024}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), x); + auto rsumb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum); + auto add = mm->add_instruction(migraphx::make_op("add"), x, rsumb); + mm->add_return({add}); + } + migraphx::program p2 = p1; + run_fuse_pass(p2); + run_pass(p1); + + EXPECT(p1 == p2); +} + TEST_CASE(split_pointwise) { migraphx::shape s{migraphx::shape::float_type, {2, 3, 327680}}; From bc4155869ef12769dcf86be0f069188d5c00e802 Mon Sep 17 00:00:00 2001 From: Paul Date: Thu, 4 Apr 2024 14:38:12 -0700 Subject: [PATCH 45/59] Format --- test/split_reduce.cpp | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/test/split_reduce.cpp b/test/split_reduce.cpp index eed0586aebc..76b074e5062 100644 --- a/test/split_reduce.cpp +++ b/test/split_reduce.cpp @@ -39,16 +39,15 @@ void run_pass(migraphx::program& p) migraphx::run_passes(p, {migraphx::fuse_pointwise{}, migraphx::fuse_reduce{}, - migraphx::split_reduce{.split_size=8192}, + migraphx::split_reduce{.split_size = 8192}, migraphx::dead_code_elimination{}}); } void run_fuse_pass(migraphx::program& p) { - migraphx::run_passes(p, - {migraphx::fuse_pointwise{}, - migraphx::fuse_reduce{}, - migraphx::dead_code_elimination{}}); + migraphx::run_passes( + p, + {migraphx::fuse_pointwise{}, migraphx::fuse_reduce{}, migraphx::dead_code_elimination{}}); } bool all_instructions_are_local(const migraphx::module& m) From f1e6ab4b3401aa2ddd76a2bfa6844451126e737e Mon Sep 17 00:00:00 2001 From: Paul Date: Thu, 4 Apr 2024 14:46:03 -0700 Subject: [PATCH 46/59] Add docstring --- src/include/migraphx/split_reduce.hpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/include/migraphx/split_reduce.hpp b/src/include/migraphx/split_reduce.hpp index 620096b06b6..687f363b0d2 100644 --- a/src/include/migraphx/split_reduce.hpp +++ b/src/include/migraphx/split_reduce.hpp @@ -33,6 +33,12 @@ inline namespace MIGRAPHX_INLINE_NS { struct module_pass_manager; +/// For large reductions that are larger than the split_size, this pass will +/// split the fused_reduce operators so that the reduction will happen across +/// multiple compute units gaining better occupancy for targets with many +/// compute units. Since the reduction is split across compute units, any +/// elementwise operators will be split into separate operators as well due to +/// needing global synchronization. struct MIGRAPHX_EXPORT split_reduce { std::size_t split_size = 8192; From 9d71f9d7e8d81b65067bfdc35c508c25e8f24fa3 Mon Sep 17 00:00:00 2001 From: Paul Date: Thu, 4 Apr 2024 14:52:32 -0700 Subject: [PATCH 47/59] Remove TODO --- src/split_reduce.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp index 29ee83c9a11..f2c36e83899 100644 --- a/src/split_reduce.cpp +++ b/src/split_reduce.cpp @@ -102,7 +102,6 @@ static std::vector find_split(const_module_ref rm) if(reduce_ins->name() != "reduce_sum") return result; result.push_back(reduce_ins); - // TODO: Find instructions that are used again in the module return result; } From be1d82daf1f45ef086400c03524a814c06205dfb Mon Sep 17 00:00:00 2001 From: Paul Date: Thu, 4 Apr 2024 14:52:57 -0700 Subject: [PATCH 48/59] Add assert --- src/param_utils.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/param_utils.cpp b/src/param_utils.cpp index 5d9560cfee0..61302a0afba 100644 --- a/src/param_utils.cpp +++ b/src/param_utils.cpp @@ -31,6 +31,7 @@ inline namespace MIGRAPHX_INLINE_NS { std::string param_name(std::size_t i, const std::string& prefix) { + assert(i < 10); return prefix + std::to_string(i); } From d19c73472a74837ac23f98ffb283403fc46758d6 Mon Sep 17 00:00:00 2001 From: Paul Date: Thu, 4 Apr 2024 17:19:32 -0700 Subject: [PATCH 49/59] Add more tests and TODOs --- src/module.cpp | 11 ++++------ src/split_reduce.cpp | 3 +++ test/module_test.cpp | 9 ++++++-- test/split_reduce.cpp | 50 +++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 64 insertions(+), 9 deletions(-) diff --git a/src/module.cpp b/src/module.cpp index d351370318e..d49717b749d 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -785,13 +785,10 @@ select_params(const std::vector& instructions, const std::unordered_map& param_map) { std::vector result; - transform_if( - instructions.begin(), - instructions.end(), - std::back_inserter(result), - [&](instruction_ref ins) { return contains(param_map, ins); }, - [&](instruction_ref ins) { return param_map.at(ins); }); - sort_params(result); + std::vector params; + std::copy_if(instructions.begin(), instructions.end(), std::back_inserter(params), [&](instruction_ref ins) { return contains(param_map, ins); }); + sort_params(params); + std::transform(params.begin(), params.end(), std::back_inserter(result), [&](instruction_ref ins) { return param_map.at(ins); }); return result; } diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp index f2c36e83899..f2bcf343131 100644 --- a/src/split_reduce.cpp +++ b/src/split_reduce.cpp @@ -96,9 +96,11 @@ static std::vector find_split(const_module_ref rm) if(reduce_ins == rm->end()) return result; // Bail if there is more than one reduce for now + // TODO: Support multiple reductions if(std::any_of(std::next(reduce_ins), rm->end(), &is_reduce)) return result; // Only handle reduce_sum for now + // TODO: Support other reduction types if(reduce_ins->name() != "reduce_sum") return result; result.push_back(reduce_ins); @@ -163,6 +165,7 @@ void split_reduce::apply(module_pass_manager& mpm) const if(splits.empty()) continue; // Only use split reduce with float for now + // TODO: Support half and other data types if(not std::all_of(splits.begin(), splits.end(), [](instruction_ref split) { return split->get_shape().type() == shape::float_type; })) diff --git a/test/module_test.cpp b/test/module_test.cpp index 928d1e7aec4..1752210b0b1 100644 --- a/test/module_test.cpp +++ b/test/module_test.cpp @@ -467,7 +467,10 @@ TEST_CASE(module_split2) auto x1 = input_m.add_parameter("x1", s); auto x2 = input_m.add_parameter("x2", s); auto x3 = input_m.add_parameter("x3", s); - inputs = {x1, x2, x3}; + auto sx1 = input_m.add_instruction(migraphx::make_op("sqrt"), x1); + auto sx2 = input_m.add_instruction(migraphx::make_op("sqrt"), x2); + auto sx3 = input_m.add_instruction(migraphx::make_op("sqrt"), x3); + inputs = {sx1, sx2, sx3}; } migraphx::module m; std::vector splits; @@ -514,7 +517,9 @@ TEST_CASE(module_split3) { auto x1 = input_m.add_parameter("x1", s); auto x2 = input_m.add_parameter("x2", s); - inputs = {x1, x2}; + auto sx1 = input_m.add_instruction(migraphx::make_op("sqrt"), x1); + auto sx2 = input_m.add_instruction(migraphx::make_op("sqrt"), x2); + inputs = {sx1, sx2}; } migraphx::module m; std::vector splits1; diff --git a/test/split_reduce.cpp b/test/split_reduce.cpp index 76b074e5062..77de81fe2e4 100644 --- a/test/split_reduce.cpp +++ b/test/split_reduce.cpp @@ -93,6 +93,32 @@ inline auto single_reduce(const std::string& name) } TEST_CASE(single) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3, 327680}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), x); + mm->add_return({rsum}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s); + auto rsum = add_reduce(p2, + "main:reduce_sum0_split", + {x}, + {2}, + "assign_add", + single_reduce("reduce_sum")); + mm->add_return({rsum}); + } + EXPECT(p1 == p2); +} + +TEST_CASE(fused) { migraphx::shape s{migraphx::shape::float_type, {2, 3, 327680}}; migraphx::program p1; @@ -178,4 +204,28 @@ TEST_CASE(split_pointwise) EXPECT(p1 == p2); } +TEST_CASE(sequence_reduces) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3, 327680}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto rsum1 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), x); + auto rsum1b = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum1); + auto sub = mm->add_instruction(migraphx::make_op("sub"), x, rsum1b); + auto rsum2 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), sub); + auto rsum2b = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum2); + auto add = mm->add_instruction(migraphx::make_op("add"), rsum2b, x); + mm->add_return({add}); + } + migraphx::program p2 = p1; + run_fuse_pass(p2); + run_pass(p1); + + EXPECT(p1 == p2); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } From 3a9267bb1d74e4ff8bc74bff5daa7f774f631af7 Mon Sep 17 00:00:00 2001 From: Paul Date: Thu, 4 Apr 2024 17:19:41 -0700 Subject: [PATCH 50/59] Format --- src/module.cpp | 10 ++++++++-- test/module_test.cpp | 4 ++-- test/split_reduce.cpp | 24 ++++++++++-------------- 3 files changed, 20 insertions(+), 18 deletions(-) diff --git a/src/module.cpp b/src/module.cpp index d49717b749d..8c28c15d22d 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -786,9 +786,15 @@ select_params(const std::vector& instructions, { std::vector result; std::vector params; - std::copy_if(instructions.begin(), instructions.end(), std::back_inserter(params), [&](instruction_ref ins) { return contains(param_map, ins); }); + std::copy_if(instructions.begin(), + instructions.end(), + std::back_inserter(params), + [&](instruction_ref ins) { return contains(param_map, ins); }); sort_params(params); - std::transform(params.begin(), params.end(), std::back_inserter(result), [&](instruction_ref ins) { return param_map.at(ins); }); + std::transform(params.begin(), + params.end(), + std::back_inserter(result), + [&](instruction_ref ins) { return param_map.at(ins); }); return result; } diff --git a/test/module_test.cpp b/test/module_test.cpp index 1752210b0b1..c766affdc6f 100644 --- a/test/module_test.cpp +++ b/test/module_test.cpp @@ -470,7 +470,7 @@ TEST_CASE(module_split2) auto sx1 = input_m.add_instruction(migraphx::make_op("sqrt"), x1); auto sx2 = input_m.add_instruction(migraphx::make_op("sqrt"), x2); auto sx3 = input_m.add_instruction(migraphx::make_op("sqrt"), x3); - inputs = {sx1, sx2, sx3}; + inputs = {sx1, sx2, sx3}; } migraphx::module m; std::vector splits; @@ -519,7 +519,7 @@ TEST_CASE(module_split3) auto x2 = input_m.add_parameter("x2", s); auto sx1 = input_m.add_instruction(migraphx::make_op("sqrt"), x1); auto sx2 = input_m.add_instruction(migraphx::make_op("sqrt"), x2); - inputs = {sx1, sx2}; + inputs = {sx1, sx2}; } migraphx::module m; std::vector splits1; diff --git a/test/split_reduce.cpp b/test/split_reduce.cpp index 77de81fe2e4..f757391216e 100644 --- a/test/split_reduce.cpp +++ b/test/split_reduce.cpp @@ -97,22 +97,18 @@ TEST_CASE(single) migraphx::shape s{migraphx::shape::float_type, {2, 3, 327680}}; migraphx::program p1; { - auto* mm = p1.get_main_module(); - auto x = mm->add_parameter("x", s); - auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), x); + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), x); mm->add_return({rsum}); } run_pass(p1); migraphx::program p2; { - auto* mm = p2.get_main_module(); - auto x = mm->add_parameter("x", s); - auto rsum = add_reduce(p2, - "main:reduce_sum0_split", - {x}, - {2}, - "assign_add", - single_reduce("reduce_sum")); + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s); + auto rsum = add_reduce( + p2, "main:reduce_sum0_split", {x}, {2}, "assign_add", single_reduce("reduce_sum")); mm->add_return({rsum}); } EXPECT(p1 == p2); @@ -209,12 +205,12 @@ TEST_CASE(sequence_reduces) migraphx::shape s{migraphx::shape::float_type, {2, 3, 327680}}; migraphx::program p1; { - auto* mm = p1.get_main_module(); - auto x = mm->add_parameter("x", s); + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); auto rsum1 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), x); auto rsum1b = mm->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum1); - auto sub = mm->add_instruction(migraphx::make_op("sub"), x, rsum1b); + auto sub = mm->add_instruction(migraphx::make_op("sub"), x, rsum1b); auto rsum2 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), sub); auto rsum2b = mm->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum2); From 7e3f587085ebe9e6852f4ef54ef8b92046e5ebdd Mon Sep 17 00:00:00 2001 From: Paul Date: Fri, 5 Apr 2024 09:51:38 -0700 Subject: [PATCH 51/59] Format --- test/module_test.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/module_test.cpp b/test/module_test.cpp index c766affdc6f..f44c089d30d 100644 --- a/test/module_test.cpp +++ b/test/module_test.cpp @@ -464,9 +464,9 @@ TEST_CASE(module_split2) migraphx::module input_m; std::vector inputs; { - auto x1 = input_m.add_parameter("x1", s); - auto x2 = input_m.add_parameter("x2", s); - auto x3 = input_m.add_parameter("x3", s); + auto x1 = input_m.add_parameter("x1", s); + auto x2 = input_m.add_parameter("x2", s); + auto x3 = input_m.add_parameter("x3", s); auto sx1 = input_m.add_instruction(migraphx::make_op("sqrt"), x1); auto sx2 = input_m.add_instruction(migraphx::make_op("sqrt"), x2); auto sx3 = input_m.add_instruction(migraphx::make_op("sqrt"), x3); @@ -515,8 +515,8 @@ TEST_CASE(module_split3) migraphx::module input_m; std::vector inputs; { - auto x1 = input_m.add_parameter("x1", s); - auto x2 = input_m.add_parameter("x2", s); + auto x1 = input_m.add_parameter("x1", s); + auto x2 = input_m.add_parameter("x2", s); auto sx1 = input_m.add_instruction(migraphx::make_op("sqrt"), x1); auto sx2 = input_m.add_instruction(migraphx::make_op("sqrt"), x2); inputs = {sx1, sx2}; From 3a216c6421eb33b91df715de5aea3b542d89ecab Mon Sep 17 00:00:00 2001 From: Paul Date: Fri, 5 Apr 2024 10:23:54 -0700 Subject: [PATCH 52/59] Fix windows --- src/targets/gpu/jit/reduce.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/targets/gpu/jit/reduce.cpp b/src/targets/gpu/jit/reduce.cpp index ab9e7ac92a2..35a744f9a6a 100644 --- a/src/targets/gpu/jit/reduce.cpp +++ b/src/targets/gpu/jit/reduce.cpp @@ -159,7 +159,8 @@ static std::vector split_reduce(const std::vector& inputs, auto factors = make_array(2, 3, 5, 7, 11); while(r > min_size) { - const auto* it = + // NOLINTNEXTLINE(readability-qualified-auto) + auto it = std::find_if(factors.begin(), factors.end(), [&](auto d) { return r % d == 0; }); if(it == factors.end()) break; From 4db4a59ad9e04494c1e24a259f572e0c753c1f1e Mon Sep 17 00:00:00 2001 From: Paul Date: Fri, 5 Apr 2024 10:23:59 -0700 Subject: [PATCH 53/59] Format --- src/targets/gpu/jit/reduce.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/targets/gpu/jit/reduce.cpp b/src/targets/gpu/jit/reduce.cpp index 35a744f9a6a..61802e45335 100644 --- a/src/targets/gpu/jit/reduce.cpp +++ b/src/targets/gpu/jit/reduce.cpp @@ -160,8 +160,7 @@ static std::vector split_reduce(const std::vector& inputs, while(r > min_size) { // NOLINTNEXTLINE(readability-qualified-auto) - auto it = - std::find_if(factors.begin(), factors.end(), [&](auto d) { return r % d == 0; }); + auto it = std::find_if(factors.begin(), factors.end(), [&](auto d) { return r % d == 0; }); if(it == factors.end()) break; r /= *it; From 0c518399a8d8ed1fec3281ec0dc70bc0daea6e88 Mon Sep 17 00:00:00 2001 From: Paul Date: Mon, 8 Apr 2024 14:17:22 -0700 Subject: [PATCH 54/59] Add docstring --- src/targets/gpu/jit/reduce.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/targets/gpu/jit/reduce.cpp b/src/targets/gpu/jit/reduce.cpp index 61802e45335..bb311421bd6 100644 --- a/src/targets/gpu/jit/reduce.cpp +++ b/src/targets/gpu/jit/reduce.cpp @@ -133,6 +133,12 @@ static std::size_t compute_subwave_size(context& ctx, std::size_t n) return wavefront_size; } +/// This will adjust the input shapes so a partial reduction is done per workgroup. +/// This is done by splitting the reduction axis so each split group becomes +/// part of the batch. So if we want to do a split redution of a tensor +/// {K}, then this will create a tensor of {K/N, N} where N is the number of +/// split groups. To compute the number of split groups it finds the largets +/// divisor that can divide K to make it less than min_size. static std::vector split_reduce(const std::vector& inputs, std::size_t min_size = 1024) { From 614830df4e65e9f45d47d122f060c4925379701d Mon Sep 17 00:00:00 2001 From: Paul Date: Mon, 8 Apr 2024 14:17:35 -0700 Subject: [PATCH 55/59] Format --- src/targets/gpu/jit/reduce.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/targets/gpu/jit/reduce.cpp b/src/targets/gpu/jit/reduce.cpp index bb311421bd6..9d3680dcd7b 100644 --- a/src/targets/gpu/jit/reduce.cpp +++ b/src/targets/gpu/jit/reduce.cpp @@ -138,7 +138,7 @@ static std::size_t compute_subwave_size(context& ctx, std::size_t n) /// part of the batch. So if we want to do a split redution of a tensor /// {K}, then this will create a tensor of {K/N, N} where N is the number of /// split groups. To compute the number of split groups it finds the largets -/// divisor that can divide K to make it less than min_size. +/// divisor that can divide K to make it less than min_size. static std::vector split_reduce(const std::vector& inputs, std::size_t min_size = 1024) { From b934717db51b76f5b27dc40eb18342598b5a48ed Mon Sep 17 00:00:00 2001 From: Paul Fultz II Date: Tue, 9 Apr 2024 09:12:22 -0500 Subject: [PATCH 56/59] Update src/targets/gpu/jit/reduce.cpp Co-authored-by: Umang Yadav <29876643+umangyadav@users.noreply.github.com> --- src/targets/gpu/jit/reduce.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/targets/gpu/jit/reduce.cpp b/src/targets/gpu/jit/reduce.cpp index 9d3680dcd7b..0dc8e34b855 100644 --- a/src/targets/gpu/jit/reduce.cpp +++ b/src/targets/gpu/jit/reduce.cpp @@ -137,7 +137,7 @@ static std::size_t compute_subwave_size(context& ctx, std::size_t n) /// This is done by splitting the reduction axis so each split group becomes /// part of the batch. So if we want to do a split redution of a tensor /// {K}, then this will create a tensor of {K/N, N} where N is the number of -/// split groups. To compute the number of split groups it finds the largets +/// split groups. To compute the number of split groups it finds the largest /// divisor that can divide K to make it less than min_size. static std::vector split_reduce(const std::vector& inputs, std::size_t min_size = 1024) From 936d5f6bcd4f3045598a8d766b3b382e74b21b12 Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 9 Apr 2024 12:58:01 -0700 Subject: [PATCH 57/59] Only use unsafe when available --- src/targets/gpu/jit/scatter.hpp | 4 +--- .../kernels/scatter_reduction_modes.hpp | 22 ++++++++++++++++++- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/src/targets/gpu/jit/scatter.hpp b/src/targets/gpu/jit/scatter.hpp index ff615331684..2902aa9c599 100644 --- a/src/targets/gpu/jit/scatter.hpp +++ b/src/targets/gpu/jit/scatter.hpp @@ -48,9 +48,7 @@ struct scatter_compiler : compiler options.output = inputs.back(); options.kernel_name = derived().get_kernel_name(op); options.virtual_inputs = inputs; - // The compiler protests the inequality comparison in assign_mul when pertaining to floating - // point, despite it making sense in the context. Thus the warning removal. - options.emplace_param("-Wno-float-equal"); + options.emplace_param("-DMIGRAPHX_ALLOW_ATOMIC_CAS=1"); const auto src = derived().make_interpolated_string(op); return prepend_copy_data_to_output(compile_hip_code_object(src, options)); diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/scatter_reduction_modes.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/scatter_reduction_modes.hpp index 93f4bed2fb4..a62a44e9934 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/scatter_reduction_modes.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/scatter_reduction_modes.hpp @@ -25,6 +25,14 @@ #define MIGRAPHX_GUARD_KERNELS_SCATTER_REDUCTION_MODES_HPP #include +#include +#include + +#ifndef MIGRAPHX_ALLOW_ATOMIC_CAS +#define MIGRAPHX_ALLOW_ATOMIC_CAS 0 +#endif + +#define MIGRAPHX_ATOMIC_CAS_WARNING() MIGRAPHX_ASSERT(MIGRAPHX_ALLOW_ATOMIC_CAS and "Using atomicCAS is slow") namespace migraphx { @@ -42,7 +50,15 @@ struct assign_add template MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const { - unsafeAtomicAdd(&x, T(y)); + if constexpr(is_same{} or is_same{}) + { + unsafeAtomicAdd(&x, T(y)); + } + else + { + MIGRAPHX_ATOMIC_CAS_WARNING(); + atomicAdd(&x, T(y)); + } } }; @@ -51,13 +67,17 @@ struct assign_mul template MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const { + MIGRAPHX_ATOMIC_CAS_WARNING(); T old = x; T assumed; do { assumed = old; old = atomicCAS(&x, assumed, assumed * y); +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wfloat-equal" } while(assumed != old); +#pragma clang diagnostic pop } }; From 7bcb56082905f2daf03822200669ce83ea0f6ad7 Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 9 Apr 2024 12:58:08 -0700 Subject: [PATCH 58/59] Format --- .../include/migraphx/kernels/scatter_reduction_modes.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/scatter_reduction_modes.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/scatter_reduction_modes.hpp index a62a44e9934..313978d8bbf 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/scatter_reduction_modes.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/scatter_reduction_modes.hpp @@ -32,7 +32,8 @@ #define MIGRAPHX_ALLOW_ATOMIC_CAS 0 #endif -#define MIGRAPHX_ATOMIC_CAS_WARNING() MIGRAPHX_ASSERT(MIGRAPHX_ALLOW_ATOMIC_CAS and "Using atomicCAS is slow") +#define MIGRAPHX_ATOMIC_CAS_WARNING() \ + MIGRAPHX_ASSERT(MIGRAPHX_ALLOW_ATOMIC_CAS and "Using atomicCAS is slow") namespace migraphx { From 66e8053c7956cc508d9bd4706855e7b48f94bd5c Mon Sep 17 00:00:00 2001 From: Paul Date: Thu, 11 Apr 2024 10:06:58 -0700 Subject: [PATCH 59/59] Suppress tidy warnings --- .../include/migraphx/kernels/scatter_reduction_modes.hpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/scatter_reduction_modes.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/scatter_reduction_modes.hpp index 313978d8bbf..b0236f92f2c 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/scatter_reduction_modes.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/scatter_reduction_modes.hpp @@ -29,9 +29,11 @@ #include #ifndef MIGRAPHX_ALLOW_ATOMIC_CAS +// NOLINTNEXTLINE #define MIGRAPHX_ALLOW_ATOMIC_CAS 0 #endif +// NOLINTNEXTLINE #define MIGRAPHX_ATOMIC_CAS_WARNING() \ MIGRAPHX_ASSERT(MIGRAPHX_ALLOW_ATOMIC_CAS and "Using atomicCAS is slow")