diff --git a/docs/dev/env_vars.rst b/docs/dev/env_vars.rst
index 7f0c554a30f..16a324b8d9a 100644
--- a/docs/dev/env_vars.rst
+++ b/docs/dev/env_vars.rst
@@ -107,6 +107,10 @@ Disables the ``schedule`` pass.
 Set to "1", "enable", "enabled", "yes", or "true" to use.
 Disables the ``fuse_reduce`` pass.
 
+.. envvar:: MIGRAPHX_ENABLE_SPLIT_REDUCE
+Set to "1", "enable", "enabled", "yes", or "true" to use.
+Enables the ``split_reduce`` pass.
+
 .. envvar:: MIGRAPHX_ENABLE_NHWC
 
 Set to "1", "enable", "enabled", "yes", or "true" to use.
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 89000b1781e..f269aa8cb38 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -77,6 +77,7 @@ add_library(migraphx
     operation.cpp
     optimize_module.cpp
     pad_calc.cpp
+    param_utils.cpp
     pass.cpp
     pass_manager.cpp
     permutation.cpp
@@ -94,6 +95,7 @@ add_library(migraphx
     replace_allocate.cpp
     rewrite_reduce.cpp
     simplify_qdq.cpp
+    split_reduce.cpp
     sqlite.cpp
     rewrite_gelu.cpp
     rewrite_low_precision.cpp
diff --git a/src/fuse_pointwise.cpp b/src/fuse_pointwise.cpp
index 15e78d338d5..10b560ff52e 100644
--- a/src/fuse_pointwise.cpp
+++ b/src/fuse_pointwise.cpp
@@ -162,7 +162,7 @@ static std::vector<instruction_ref> append_pointwise_module(instruction_ref ins,
             input_map[input] = map_ins[param];
         }
     }
-    pm->replace_return(pm->insert_instructions(last, xm, map_ins));
+    pm->replace_return(pm->insert_instructions(last, xm, &map_ins));
     return inputs;
 }
 
diff --git a/src/fuse_reduce.cpp b/src/fuse_reduce.cpp
index 4e1602a1964..7f60d5ebe70 100644
--- a/src/fuse_reduce.cpp
+++ b/src/fuse_reduce.cpp
@@ -83,23 +83,6 @@ struct fused_reduce
 };
 MIGRAPHX_REGISTER_OP(fused_reduce);
 
-static std::unordered_map<instruction_ref, instruction_ref>
-get_ins_param_map(const std::vector<instruction_ref>& inputs, const_module_ref sm)
-{
-    std::unordered_map<instruction_ref, instruction_ref> result;
-    auto names = sm->get_parameter_names();
-    std::sort(names.begin(), names.end());
-    assert(names.size() == inputs.size());
-    std::transform(names.begin(),
-                   names.end(),
-                   inputs.begin(),
-                   std::inserter(result, result.end()),
-                   [&](const auto& name, auto input) {
-                       return std::make_pair(input, sm->get_parameter(name));
-                   });
-    return result;
-}
-
 static void insert_params(module_ref sm,
                           const std::vector<instruction_ref>& inputs,
                           std::unordered_map<instruction_ref, instruction_ref>& map_ins)
@@ -119,7 +102,7 @@ static auto insert_ins_in_submodule(module_ref sm,
                                     std::unordered_map<instruction_ref, instruction_ref>& map_ins)
 {
     insert_params(sm, ins->inputs(), map_ins);
-    return sm->add_instructions({ins}, map_ins);
+    return sm->add_instructions({ins}, &map_ins);
 }
 
 static auto insert_ins_in_submodule(module_ref sm, instruction_ref ins)
@@ -136,12 +119,12 @@ insert_module_in_submodule(module_ref sm,
                            module::inserter insert = nullptr)
 {
     insert_params(sm, inputs, map_ins);
-    auto param_map = get_ins_param_map(inputs, m);
+    auto param_map = m->get_ins_param_map(inputs);
     for(auto&& [input, param] : param_map)
     {
         map_ins[param] = map_ins.at(input);
     }
-    return sm->add_instructions(m, map_ins, std::move(insert));
+    return sm->add_instructions(m, &map_ins, std::move(insert));
 }
 
 static auto
diff --git a/src/include/migraphx/liveness.hpp b/src/include/migraphx/liveness.hpp
new file mode 100644
index 00000000000..6d9715a8a10
--- /dev/null
+++ b/src/include/migraphx/liveness.hpp
@@ -0,0 +1,77 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ +#ifndef MIGRAPHX_GUARD_MIGRAPHX_LIVENESS_HPP +#define MIGRAPHX_GUARD_MIGRAPHX_LIVENESS_HPP + +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +// This will do liveness analysis on the module, and it will call the +// function `f` with the instruction and the set of the other instructions +// that are live +template +void liveness(const module& m, F f) +{ + auto implicit_deps = m.calc_implicit_deps(); + std::unordered_set live_set; + auto rp = reverse(m); + for(auto rins : iterator_for(rp)) // NOLINT + { + // The base iterator is one ahead, so we need to use the previous iterator + auto ins = std::prev(rins.base()); + // Add live variables + auto add_live_variables = [&](const auto& inputs) { + for(auto input : inputs) + { + auto i = instruction::get_output_alias(input); + // Skip if variable comes from parent + if(not m.has_instruction(i)) + continue; + live_set.insert(i); + } + }; + add_live_variables(ins->inputs()); + add_live_variables(implicit_deps[ins]); + // Remove last usage + auto it = live_set.find(ins); + if(it != live_set.end()) + { + live_set.erase(it); + f(ins, live_set); + } + } +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_MIGRAPHX_LIVENESS_HPP diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index 838ee9faff5..f9d41121159 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -50,6 +50,8 @@ struct module_impl; using parameter_map = std::unordered_map; using ins_dep_map = std::unordered_map>; +struct module_with_inputs; + /** * @brief Stores the instruction stream */ @@ -127,38 +129,38 @@ struct MIGRAPHX_EXPORT module std::vector add_instructions(const std::vector& instructions, - std::unordered_map map_ins = {}, - inserter insert = nullptr); + std::unordered_map* map_ins = nullptr, + inserter insert = nullptr); std::vector add_instructions(const_module_ref m, - std::unordered_map map_ins = {}, - inserter insert = nullptr); + std::unordered_map* map_ins = nullptr, + inserter insert = nullptr); std::vector add_instructions(instruction_ref start, instruction_ref last, - std::unordered_map map_ins = {}, - inserter insert = nullptr); + std::unordered_map* map_ins = nullptr, + inserter insert = nullptr); std::vector insert_instructions(instruction_ref ins, const std::vector& instructions, - std::unordered_map 
map_ins = {}, - inserter insert = nullptr); + std::unordered_map* map_ins = nullptr, + inserter insert = nullptr); std::vector insert_instructions(instruction_ref ins, const_module_ref m, - std::unordered_map map_ins = {}, - inserter insert = nullptr); + std::unordered_map* map_ins = nullptr, + inserter insert = nullptr); std::vector insert_instructions(instruction_ref ins, instruction_ref start, instruction_ref last, - std::unordered_map map_ins = {}, - inserter insert = nullptr); + std::unordered_map* map_ins = nullptr, + inserter insert = nullptr); template instruction_ref add_literal(Ts&&... xs) @@ -186,6 +188,8 @@ struct MIGRAPHX_EXPORT module instruction_ref get_parameter(std::string name) const; + std::vector get_parameters() const; + void rename_parameter(instruction_ref ins, const std::string& name); std::unordered_map get_parameter_shapes() const; @@ -205,6 +209,28 @@ struct MIGRAPHX_EXPORT module void finalize(std::vector& contexts); + /// Create a mapping from the input instruction to the corresponding + /// parameter instruction. Use the `reverse` flag to reverse the lookup + /// to be from parameter instruction to input instread. + std::unordered_map + get_ins_param_map(const std::vector& inputs, bool reverse = false) const; + + using with_inputs = module_with_inputs; + + /// This will split the module into two parts at the instruction splits. + /// Each split instruction becomes an input parameter in the second + /// module. As such the inputs instructions to the second module will use + /// the split instructions as input placeholders that can be replaced + /// later. + std::array split(const std::vector& args, + const std::vector& splits) const; + + /// This will split the module in 3 parts using different split + /// instruction for each additional module. + std::array split(const std::vector& args, + const std::vector& splits1, + const std::vector& splits2) const; + void debug_print() const; void debug_print(instruction_ref ins) const; void debug_print(instruction_ref ins, @@ -266,6 +292,21 @@ struct MIGRAPHX_EXPORT module std::unique_ptr impl; }; +struct module_with_inputs +{ + module mod; + std::vector inputs; + /// Replace the instruction in the inputs with rep + void replace(instruction_ref ins, instruction_ref rep); + /// Replace the input instructions using the map_ins to lookup the replacement + void replace(const std::unordered_map& map_ins); + + /// Replace the input instructions of the keys with the instructions + /// passed as values. Both vectors should be in the same order. + void replace(const std::vector& keys, + const std::vector& values); +}; + inline module& get_module(module& m) { return m; } } // namespace MIGRAPHX_INLINE_NS diff --git a/src/include/migraphx/param_utils.hpp b/src/include/migraphx/param_utils.hpp new file mode 100644 index 00000000000..1552c28300b --- /dev/null +++ b/src/include/migraphx/param_utils.hpp @@ -0,0 +1,42 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ +#ifndef MIGRAPHX_GUARD_MIGRAPHX_PARAM_UTILS_HPP +#define MIGRAPHX_GUARD_MIGRAPHX_PARAM_UTILS_HPP + +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +std::string param_name(std::size_t i, const std::string& prefix = "x"); + +void sort_params(std::vector& params); + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_MIGRAPHX_PARAM_UTILS_HPP diff --git a/src/include/migraphx/pass_manager.hpp b/src/include/migraphx/pass_manager.hpp index d930cc2416f..fdbdc123a12 100644 --- a/src/include/migraphx/pass_manager.hpp +++ b/src/include/migraphx/pass_manager.hpp @@ -39,7 +39,7 @@ struct module_pass_manager module_pass_manager(const module_pass_manager&) = delete; virtual module& get_module() = 0; virtual module* create_module(const std::string& name) = 0; - virtual module* create_module(const std::string& name, const module& m) = 0; + virtual module* create_module(const std::string& name, module m) = 0; virtual module* get_common_parent() = 0; virtual module* get_root_module() = 0; virtual void run_pass(const pass& p) = 0; diff --git a/src/include/migraphx/split_reduce.hpp b/src/include/migraphx/split_reduce.hpp new file mode 100644 index 00000000000..687f363b0d2 --- /dev/null +++ b/src/include/migraphx/split_reduce.hpp @@ -0,0 +1,51 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ +#ifndef MIGRAPHX_GUARD_MIGRAPHX_SPLIT_REDUCE_HPP +#define MIGRAPHX_GUARD_MIGRAPHX_SPLIT_REDUCE_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct module_pass_manager; + +/// For large reductions that are larger than the split_size, this pass will +/// split the fused_reduce operators so that the reduction will happen across +/// multiple compute units gaining better occupancy for targets with many +/// compute units. Since the reduction is split across compute units, any +/// elementwise operators will be split into separate operators as well due to +/// needing global synchronization. +struct MIGRAPHX_EXPORT split_reduce +{ + std::size_t split_size = 8192; + std::string name() const { return "split_reduce"; } + void apply(module_pass_manager& mpm) const; +}; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_MIGRAPHX_SPLIT_REDUCE_HPP diff --git a/src/memory_coloring.cpp b/src/memory_coloring.cpp index 733e39d28d2..3753f549000 100644 --- a/src/memory_coloring.cpp +++ b/src/memory_coloring.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -43,42 +44,6 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DEBUG_MEMORY_COLORING); using instruction_set = std::unordered_set; using instruction_set_map = std::unordered_map; -// This will do liveness analysis on the module, and it will call the -// function `f` with the instruction and the set of the other instructions -// that are live -template -void liveness(const module& m, F f) -{ - auto implicit_deps = m.calc_implicit_deps(); - instruction_set live_set; - auto rp = reverse(m); - for(auto rins : iterator_for(rp)) // NOLINT - { - // The base iterator is one ahead, so we need to use the previous iterator - auto ins = std::prev(rins.base()); - // Add live variables - auto add_live_variables = [&](const auto& inputs) { - for(auto input : inputs) - { - auto i = instruction::get_output_alias(input); - // Skip if variable comes from parent - if(not m.has_instruction(i)) - continue; - live_set.insert(i); - } - }; - add_live_variables(ins->inputs()); - add_live_variables(implicit_deps[ins]); - // Remove last usage - auto it = live_set.find(ins); - if(it != live_set.end()) - { - live_set.erase(it); - f(ins, live_set); - } - } -} - // This will build the conflict table or interference graph. This is // essentially a map from one instruction to a set of instruction that are // used together. Each instruction will be the allocation instruction. diff --git a/src/module.cpp b/src/module.cpp index f0fa8ac8a7b..8c28c15d22d 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -22,6 +22,7 @@ * THE SOFTWARE. 
*/ #include +#include #include #include #include @@ -33,11 +34,13 @@ #include #include #include +#include #include #include #include #include #include +#include #include #include #include @@ -204,7 +207,7 @@ static std::vector insert_generic_instructions_impl(module& m, instruction_ref ins, Range&& instructions, - std::unordered_map map_ins, + std::unordered_map& map_ins, Inserter insert) { assert(m.has_instruction(ins) or is_end(ins, m.end())); @@ -261,20 +264,16 @@ static std::vector insert_generic_instructions(module& m, instruction_ref ins, Range&& instructions, - std::unordered_map map_ins, + std::unordered_map& map_ins, module::inserter insert) { if(insert == nullptr) - return insert_generic_instructions_impl(m, - ins, - static_cast(instructions), - std::move(map_ins), - [](module& mm, auto&&... xs) { - return mm.insert_instruction( - std::forward(xs)...); - }); + return insert_generic_instructions_impl( + m, ins, static_cast(instructions), map_ins, [](module& mm, auto&&... xs) { + return mm.insert_instruction(std::forward(xs)...); + }); return insert_generic_instructions_impl( - m, ins, static_cast(instructions), std::move(map_ins), insert); + m, ins, static_cast(instructions), map_ins, insert); } instruction_ref module::add_instruction(const operation& op, std::vector args) @@ -423,61 +422,71 @@ instruction_ref module::move_instructions(instruction_ref src, instruction_ref d std::vector module::add_instructions(const std::vector& instructions, - std::unordered_map map_ins, + std::unordered_map* map_ins, module::inserter insert) { - return this->insert_instructions( - this->end(), instructions, std::move(map_ins), std::move(insert)); + return this->insert_instructions(this->end(), instructions, map_ins, std::move(insert)); } std::vector module::add_instructions(const_module_ref m, - std::unordered_map map_ins, + std::unordered_map* map_ins, module::inserter insert) { - return this->insert_instructions(this->end(), m, std::move(map_ins), std::move(insert)); + return this->insert_instructions(this->end(), m, map_ins, std::move(insert)); } std::vector module::add_instructions(instruction_ref start, instruction_ref last, - std::unordered_map map_ins, + std::unordered_map* map_ins, module::inserter insert) { - return this->insert_instructions( - this->end(), start, last, std::move(map_ins), std::move(insert)); + return this->insert_instructions(this->end(), start, last, map_ins, std::move(insert)); } std::vector module::insert_instructions(instruction_ref ins, const std::vector& instructions, - std::unordered_map map_ins, + std::unordered_map* map_ins, module::inserter insert) { - return insert_generic_instructions( - *this, ins, instructions, std::move(map_ins), std::move(insert)); + std::unordered_map default_map_ins; + return insert_generic_instructions(*this, + ins, + instructions, + map_ins == nullptr ? default_map_ins : *map_ins, + std::move(insert)); } std::vector module::insert_instructions(instruction_ref ins, const_module_ref m, - std::unordered_map map_ins, + std::unordered_map* map_ins, module::inserter insert) { - return insert_generic_instructions( - *this, ins, iterator_for(*m), std::move(map_ins), std::move(insert)); + std::unordered_map default_map_ins; + return insert_generic_instructions(*this, + ins, + iterator_for(*m), + map_ins == nullptr ? 
default_map_ins : *map_ins, + std::move(insert)); } std::vector module::insert_instructions(instruction_ref ins, instruction_ref start, instruction_ref last, - std::unordered_map map_ins, + std::unordered_map* map_ins, module::inserter insert) { auto r = range(start, last); - return insert_generic_instructions( - *this, ins, iterator_for(r), std::move(map_ins), std::move(insert)); + std::unordered_map default_map_ins; + return insert_generic_instructions(*this, + ins, + iterator_for(r), + map_ins == nullptr ? default_map_ins : *map_ins, + std::move(insert)); } instruction_ref module::add_literal(literal l) { return insert_literal(begin(), std::move(l)); } @@ -519,6 +528,7 @@ instruction_ref module::insert_parameter(instruction_ref ins, std::string name, instruction_ref module::replace_return(std::vector args) { + assert(std::all_of(args.begin(), args.end(), [&](auto ins) { return has_instruction(ins); })); auto last = std::prev(this->end()); // If there is no return then add a return if(last->name() != "@return") @@ -589,6 +599,16 @@ instruction_ref module::get_parameter(std::string name) const return this->end(); } +std::vector module::get_parameters() const +{ + std::vector result; + auto refs = iterator_for(*this); + std::copy_if(refs.begin(), refs.end(), std::back_inserter(result), [&](instruction_ref ins) { + return ins->name() == "@param"; + }); + return result; +} + void module::rename_parameter(instruction_ref ins, const std::string& name) { assert(ins->name() == "@param"); @@ -732,6 +752,200 @@ void module::finalize(std::vector& contexts) << std::endl; } +std::unordered_map +module::get_ins_param_map(const std::vector& inputs, bool reverse) const +{ + std::unordered_map result; + auto params = this->get_parameters(); + assert(params.size() == inputs.size()); + sort_params(params); + if(reverse) + { + std::transform( + params.begin(), + params.end(), + inputs.begin(), + std::inserter(result, result.end()), + [&](instruction_ref param, auto input) { return std::make_pair(param, input); }); + } + else + { + std::transform( + params.begin(), + params.end(), + inputs.begin(), + std::inserter(result, result.end()), + [&](instruction_ref param, auto input) { return std::make_pair(input, param); }); + } + return result; +} + +static std::vector +select_params(const std::vector& instructions, + const std::unordered_map& param_map) +{ + std::vector result; + std::vector params; + std::copy_if(instructions.begin(), + instructions.end(), + std::back_inserter(params), + [&](instruction_ref ins) { return contains(param_map, ins); }); + sort_params(params); + std::transform(params.begin(), + params.end(), + std::back_inserter(result), + [&](instruction_ref ins) { return param_map.at(ins); }); + return result; +} + +static std::array +generic_split(const module& m, + const std::vector& args, + const std::vector& splits, + std::unordered_map* map_ins = nullptr) +{ + std::unordered_map param_map = + m.get_ins_param_map(args, true); + + std::unordered_set selected_instructions; + fix([&](auto self, const std::vector& inputs) { + for(auto input : inputs) + { + if(contains(selected_instructions, input)) + continue; + selected_instructions.insert(input); + self(input->inputs()); + } + })(splits); + + std::vector instructions1; + // TODO: copy_if + for(auto ins : iterator_for(m)) + { + if(not contains(selected_instructions, ins)) + continue; + instructions1.push_back(ins); + } + + std::vector inputs1 = select_params(instructions1, param_map); + module m1; + std::unordered_map map_ins1; + 
m1.add_instructions(instructions1, &map_ins1); + std::vector outputs; + std::transform(splits.begin(), + splits.end(), + std::back_inserter(outputs), + [&](instruction_ref ins) { return map_ins1.at(ins); }); + m1.add_return(outputs); + + std::vector instructions2; + for(auto ins : iterator_for(m)) + { + if(contains(selected_instructions, ins)) + continue; + // Input params can be used in both modules + std::vector input_params; + std::copy_if(ins->inputs().begin(), + ins->inputs().end(), + std::back_inserter(input_params), + [&](instruction_ref input) { + if(input->name() != "@param") + return false; + return not contains(instructions2, input); + }); + instructions2.insert(instructions2.end(), input_params.begin(), input_params.end()); + instructions2.push_back(ins); + } + + std::vector inputs2 = select_params(instructions2, param_map); + inputs2.insert(inputs2.begin(), splits.begin(), splits.end()); + module m2; + std::size_t n = 0; + std::unordered_map map_ins2; + for(auto ins : splits) + map_ins2[ins] = m2.add_parameter(param_name(n++), ins->get_shape().as_standard()); + for(auto ins : iterator_for(m)) + { + if(ins->name() != "@param") + continue; + if(not contains(instructions2, ins)) + continue; + map_ins2[ins] = m2.add_parameter(param_name(n++), ins->get_shape().as_standard()); + } + auto r = m2.add_instructions(instructions2, &map_ins2); + m2.add_return(r); + if(map_ins != nullptr) + *map_ins = map_ins2; + return {{{std::move(m1), std::move(inputs1)}, {std::move(m2), std::move(inputs2)}}}; +} + +std::array module::split(const std::vector& args, + const std::vector& splits) const +{ + return generic_split(*this, args, splits); +} + +std::array module::split(const std::vector& args, + const std::vector& splits1, + const std::vector& splits2) const +{ + std::unordered_map map_ins; + auto mods1 = generic_split(*this, args, splits1, &map_ins); + + assert(all_of(mods1[0].inputs, [&](auto ins) { return contains(args, ins); })); + assert(all_of(mods1[1].inputs, + [&](auto ins) { return contains(args, ins) or contains(splits1, ins); })); + + std::vector new_splits2; + std::transform(splits2.begin(), splits2.end(), std::back_inserter(new_splits2), [&](auto ins) { + return map_ins.at(ins); + }); + + auto mods2 = mods1[1].mod.split(mods1[1].inputs, new_splits2); + // Replace new splits with old splits + mods2[1].replace(new_splits2, splits2); + + assert(all_of(mods2[0].inputs, + [&](auto ins) { return contains(args, ins) or contains(splits1, ins); })); + assert(all_of(mods2[1].inputs, [&](auto ins) { + return contains(args, ins) or contains(splits1, ins) or contains(splits2, ins); + })); + + return {{std::move(mods1[0]), std::move(mods2[0]), std::move(mods2[1])}}; +} + +void module_with_inputs::replace(instruction_ref ins, instruction_ref rep) +{ + auto it = std::find(inputs.begin(), inputs.end(), ins); + if(it == inputs.end()) + return; + assert((*it)->get_shape().lens() == rep->get_shape().lens()); + *it = rep; +} +void module_with_inputs::replace( + const std::unordered_map& map_ins) +{ + for(auto& ins : inputs) + { + if(not contains(map_ins, ins)) + continue; + assert(ins->get_shape().lens() == map_ins.at(ins)->get_shape().lens()); + ins = map_ins.at(ins); + } +} +void module_with_inputs::replace(const std::vector& keys, + const std::vector& values) +{ + for(auto& ins : inputs) + { + auto it = std::find(keys.begin(), keys.end(), ins); + if(it == keys.end()) + continue; + assert(ins->get_shape().lens() == values[it - keys.begin()]->get_shape().lens()); + ins = values[it - keys.begin()]; + } 
+} + void module::debug_print() const { std::cout << *this << std::endl; } void module::debug_print(instruction_ref ins, diff --git a/src/param_utils.cpp b/src/param_utils.cpp new file mode 100644 index 00000000000..61302a0afba --- /dev/null +++ b/src/param_utils.cpp @@ -0,0 +1,47 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +std::string param_name(std::size_t i, const std::string& prefix) +{ + assert(i < 10); + return prefix + std::to_string(i); +} + +void sort_params(std::vector& params) +{ + std::sort(params.begin(), params.end(), by(std::less<>{}, [](instruction_ref ins) { + const auto& param = any_cast(ins->get_operator()); + return param.parameter; + })); +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/pass_manager.cpp b/src/pass_manager.cpp index be62bc4c01e..3748e094773 100644 --- a/src/pass_manager.cpp +++ b/src/pass_manager.cpp @@ -99,10 +99,10 @@ struct module_pm : module_pass_manager return prog->create_module(name); } - virtual module* create_module(const std::string& name, const module& m) override + virtual module* create_module(const std::string& name, module m) override { assert(prog); - return prog->create_module(name, m); + return prog->create_module(name, std::move(m)); } virtual module* get_common_parent() override { return common_parent; } diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp new file mode 100644 index 00000000000..f2bcf343131 --- /dev/null +++ b/src/split_reduce.cpp @@ -0,0 +1,210 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct split_fused_reduce +{ + std::vector axes{}; + std::string assign = "assign_none"; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.axes, "axes"), f(self.assign, "assign")); + } + + value attributes() const { return {{"prefill", 0}}; } + + shape compute_shape(const std::vector& inputs, std::vector mods) const + { + if(mods.size() != 1) + MIGRAPHX_THROW("should have one submodule."); + const auto* sm = mods.front(); + if(sm->get_output_shapes().size() != 1) + MIGRAPHX_THROW("Only one output supported"); + auto names = sm->get_parameter_names(); + check_shapes{inputs, *this}.has(names.size()).same_ndims(); + std::sort(names.begin(), names.end()); + auto shapes = sm->get_parameter_shapes(); + // Check dimension matches for each input + if(not equal(names, inputs, [&](const auto& name, const auto& input) { + return shapes.at(name).lens() == input.lens(); + })) + MIGRAPHX_THROW("Dimenstion does not match the submodule."); + const auto& s = inputs.at(0); + auto lens = s.lens(); + if(lens != sm->get_output_shapes().front().lens()) + { + for(const auto& axis : axes) + { + lens[axis] = 1; + } + } + + return shape::from_permutation( + sm->get_output_shapes().front().type(), lens, find_permutation(inputs)); + } + + std::string name() const { return "split_fused_reduce"; } +}; +MIGRAPHX_REGISTER_OP(split_fused_reduce); + +static bool is_reduce(const instruction& ins) { return contains(ins.name(), "reduce"); } + +static std::vector find_split(const_module_ref rm) +{ + std::vector result; + auto reduce_ins = std::find_if(rm->begin(), rm->end(), &is_reduce); + if(reduce_ins == rm->end()) + return result; + // Bail if there is more than one reduce for now + // TODO: Support multiple reductions + if(std::any_of(std::next(reduce_ins), rm->end(), &is_reduce)) + return result; + // Only handle reduce_sum for now + // TODO: Support other reduction types + if(reduce_ins->name() != "reduce_sum") + return result; + result.push_back(reduce_ins); + return result; +} + +static std::vector get_alive(const_module_ref rm, + const std::vector& splits) +{ + std::vector result; + bool stop = false; + liveness(*rm, [&](auto ins, const auto& live_set) { + if(stop) + return; + if(not contains(splits, ins)) + return; + std::copy_if(live_set.begin(), + live_set.end(), + std::back_inserter(result), + [](instruction_ref live) { return live->name() != "@param"; }); + stop = true; + }); + return result; +} + +static std::string assign_op(const std::vector& splits) +{ + static std::unordered_map m = { + {"reduce_sum", "assign_add"}, + {"reduce_mean", "assign_add"}, + {"reduce_prod", "assign_mul"}, + {"reduce_max", "assign_max"}, + {"reduce_min", "assign_min"}, + }; + return m.at(splits.front()->name()); +} + +static std::vector +insert_module_inline(module& m, instruction_ref ins, const 
module::with_inputs& mwi) +{ + auto param_map = mwi.mod.get_ins_param_map(mwi.inputs, true); + return m.insert_instructions(ins, &mwi.mod, ¶m_map); +} + +static std::size_t get_reduce_size(const_module_ref rm) +{ + auto ins = std::find_if(rm->begin(), rm->end(), &is_reduce); + assert(ins != rm->end()); + return ins->inputs().front()->get_shape().elements() / ins->get_shape().elements(); +} + +void split_reduce::apply(module_pass_manager& mpm) const +{ + for(auto ins : iterator_for(mpm.get_module())) + { + if(ins->name() != "fused_reduce") + continue; + auto* rm = ins->module_inputs().front(); + if(get_reduce_size(rm) < split_size) + continue; + auto splits = find_split(rm); + if(splits.empty()) + continue; + // Only use split reduce with float for now + // TODO: Support half and other data types + if(not std::all_of(splits.begin(), splits.end(), [](instruction_ref split) { + return split->get_shape().type() == shape::float_type; + })) + continue; + auto v = ins->get_operator().to_value(); + auto axes = v["axes"].to_vector(); + + auto alive = get_alive(rm, splits); + + std::array mods; + if(not alive.empty()) + { + auto mods3 = rm->split(ins->inputs(), alive, splits); + auto r = insert_module_inline(mpm.get_module(), ins, mods3[0]); + mods3[1].replace(alive, r); + mods3[2].replace(alive, r); + mods = {std::move(mods3[1]), std::move(mods3[2])}; + } + else + { + mods = rm->split(ins->inputs(), splits); + } + + auto* splitm = mpm.create_module(rm->name() + "_split", std::move(mods[0].mod)); + splitm->set_bypass(); + + // Insert split reduce + auto split_reduce = mpm.get_module().insert_instruction( + ins, + make_op("split_fused_reduce", {{"axes", axes}, {"assign", assign_op(splits)}}), + mods[0].inputs, + {splitm}); + + mods[1].replace(splits.front(), split_reduce); + auto replaced = insert_module_inline(mpm.get_module(), ins, mods[1]); + assert(replaced.size() == 1); + mpm.get_module().replace_instruction(ins, replaced.front()); + } +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/split_single_dyn_dim.cpp b/src/split_single_dyn_dim.cpp index 0bd21521ed5..a5707c38674 100644 --- a/src/split_single_dyn_dim.cpp +++ b/src/split_single_dyn_dim.cpp @@ -140,7 +140,7 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const auto static_shape = dyn_param_shape.to_static(dim_size); map_ins[dyn_param] = submod->add_parameter(dd_check.dyn_param_str, static_shape); } - auto outputs = submod->add_instructions(mm, map_ins); + auto outputs = submod->add_instructions(mm, &map_ins); submod->add_return({outputs}); submodules.push_back(submod); } diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index d0897187336..bbea66f24ad 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -358,7 +358,7 @@ fold_pointwise_mod(instruction_ref pm_ins, pm->get_parameter(name), parent_mod->add_parameter(name, input->get_shape().as_standard())); }); - return parent_mod->insert_instructions(parent_mod->end(), pm, param_map); + return parent_mod->insert_instructions(parent_mod->end(), pm, ¶m_map); } // Whitelist supported fusion options, including imposing type constraints diff --git a/src/targets/gpu/hip.cpp b/src/targets/gpu/hip.cpp index b5306e681ed..49505bcf8be 100644 --- a/src/targets/gpu/hip.cpp +++ b/src/targets/gpu/hip.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -37,6 +37,7 @@ inline namespace MIGRAPHX_INLINE_NS { namespace gpu { MIGRAPHX_REGISTER_OP(hip_allocate) +MIGRAPHX_REGISTER_OP(hip_fill) MIGRAPHX_REGISTER_OP(hip_sync_stream) MIGRAPHX_REGISTER_OP(hip_copy_to_gpu) MIGRAPHX_REGISTER_OP(hip_copy_from_gpu) @@ -246,6 +247,14 @@ void gpu_sync() void gpu_sync(const context& ctx) { ctx.finish(); } +void hip_async_memset(context& ctx, const argument& dst, int value) +{ + std::size_t dst_size = dst.get_shape().bytes(); + auto status = hipMemsetAsync(dst.data(), value, dst_size, ctx.get_stream().get()); + if(status != hipSuccess) + MIGRAPHX_THROW("Gpu fill failed: " + hip_error(status)); +} + void hip_async_copy(context& ctx, const argument& src, const argument& dst, hipMemcpyKind kind) { std::size_t src_size = src.get_shape().bytes(); @@ -293,6 +302,13 @@ argument get_preallocation(context& ctx, const std::string& id) return ctx.get_current_device().preallocations.at(id); } +void gpu_fill(context& ctx, const argument& dst, int value) +{ + // TODO: Handle non-packed tensor when value is not 0 + assert(dst.get_shape().packed() and value == 0); + hip_async_memset(ctx, dst, value); +} + void store_preallocated_param(context& ctx, const std::string& id, const argument& a) { ctx.get_current_device().preallocations[id] = a; diff --git a/src/targets/gpu/include/migraphx/gpu/hip.hpp b/src/targets/gpu/include/migraphx/gpu/hip.hpp index d51df82005a..acd7525d620 100644 --- a/src/targets/gpu/include/migraphx/gpu/hip.hpp +++ b/src/targets/gpu/include/migraphx/gpu/hip.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -59,6 +59,8 @@ MIGRAPHX_GPU_EXPORT void copy_from_gpu(context& ctx, const argument& src, const MIGRAPHX_GPU_EXPORT argument get_preallocation(context& ctx, const std::string& id); +MIGRAPHX_GPU_EXPORT void gpu_fill(context& ctx, const argument& dst, int value = 0); + struct hip_allocate { shape s; @@ -81,6 +83,30 @@ struct hip_allocate } }; +struct hip_fill +{ + int value = 0; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.value, "value")); + } + + std::string name() const { return "hip::fill"; } + shape compute_shape(const std::vector& inputs) const + { + check_shapes{inputs, *this}.has(1); + return inputs.front(); + } + argument compute(context& ctx, const shape&, const std::vector& args) const + { + gpu_fill(ctx, args.front(), value); + return args.front(); + } + std::ptrdiff_t output_alias(const std::vector&) const { return 0; } +}; + struct hip_sync_stream { diff --git a/src/targets/gpu/jit/reduce.cpp b/src/targets/gpu/jit/reduce.cpp index fe00cb3fcc4..0dc8e34b855 100644 --- a/src/targets/gpu/jit/reduce.cpp +++ b/src/targets/gpu/jit/reduce.cpp @@ -27,6 +27,8 @@ #include #include #include +#include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -131,6 +133,67 @@ static std::size_t compute_subwave_size(context& ctx, std::size_t n) return wavefront_size; } +/// This will adjust the input shapes so a partial reduction is done per workgroup. 
+/// This is done by splitting the reduction axis so each split group becomes
+/// part of the batch. So if we want to do a split reduction of a tensor
+/// {K}, then this will create a tensor of {K/N, N} where N is the number of
+/// split groups. To compute the number of split groups it finds the largest
+/// divisor that can divide K to make it less than min_size.
+static std::vector<shape> split_reduce(const std::vector<shape>& inputs,
+                                       std::size_t min_size = 1024)
+{
+    std::vector<shape> result;
+    auto input_shape         = inputs.front();
+    const auto& reduce_shape = inputs[inputs.size() - 2];
+    const auto& output_shape = inputs[inputs.size() - 1];
+
+    auto is = range(reduce_shape.lens().size());
+    using array_type = std::array<std::size_t, 2>;
+    auto initial = array_type{std::numeric_limits<std::size_t>::max(),
+                              std::numeric_limits<std::size_t>::max()};
+    auto faxis = transform_accumulate(
+        is.begin(), is.end(), initial, MIGRAPHX_LIFT(std::min), [&](auto i) -> array_type {
+            if(input_shape.lens()[i] == output_shape.lens()[i])
+                return initial;
+            return {input_shape.strides()[i], std::size_t(i)};
+        })[1];
+
+    assert(faxis < reduce_shape.lens().size());
+
+    std::size_t n = 1;
+    auto r        = input_shape.lens()[faxis];
+    auto factors  = make_array(2, 3, 5, 7, 11);
+    while(r > min_size)
+    {
+        // NOLINTNEXTLINE(readability-qualified-auto)
+        auto it = std::find_if(factors.begin(), factors.end(), [&](auto d) { return r % d == 0; });
+        if(it == factors.end())
+            break;
+        r /= *it;
+        n *= *it;
+    }
+    assert(n != 1);
+    std::transform(
+        inputs.begin(), inputs.end(), std::back_inserter(result), [&](const shape& s) -> shape {
+            auto lens    = s.lens();
+            auto strides = s.strides();
+
+            lens.push_back(n);
+            if(lens[faxis] == 1)
+            {
+                strides.push_back(0);
+            }
+            else
+            {
+                lens[faxis] /= n;
+                strides.push_back(strides[faxis] * lens[faxis]);
+            }
+
+            return {s.type(), lens, strides};
+        });
+    return reduce_dims(normalize_permutation(result));
+}
+
 struct simple_reduce_compiler : compiler<simple_reduce_compiler>
 {
     std::vector<std::string> names() const
@@ -231,7 +294,7 @@ extern "C" {
 MIGRAPHX_GLOBAL void ${kernel}(${params})
 {
     transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, auto...
xs) { - fused_reduce(y, partial(${lambda})(xs...)); + fused_reduce(y, ${assign}{}, partial(${lambda})(xs...)); }); } @@ -243,15 +306,18 @@ MIGRAPHX_GLOBAL void ${kernel}(${params}) struct fused_reduce_compiler : compiler { - std::vector names() const { return {"fused_reduce"}; } + std::vector names() const { return {"fused_reduce", "split_fused_reduce"}; } operation compile_op(context& ctx, const std::vector& inputs, const value& v) const { + auto assign = v.get("assign", "assign_none"); auto axes = v.at("axes").to_vector(); auto virtual_inputs = inputs; virtual_inputs.push_back(get_reduced_shape(inputs.front(), axes)); virtual_inputs.push_back(get_output_shape(inputs.front(), axes)); virtual_inputs = reduce_dims(normalize_permutation(virtual_inputs)); + if(assign != "assign_none") + virtual_inputs = split_reduce(virtual_inputs); auto reduce_output_shape = virtual_inputs.back(); virtual_inputs.pop_back(); auto reduction_shape = virtual_inputs.back(); @@ -301,13 +367,14 @@ struct fused_reduce_compiler : compiler auto src = interpolate_string( fused_reduce_kernel, {{"kernel", options.kernel_name}, - {"params", enum_params(inputs.size(), "void * private_p")}, - {"args", enum_params(inputs.size(), "private_p")}, - {"algo", algo}, - {"reduced", "decltype(" + generate_make_shape(reduce_output_shape) + ")"}, - {"lambda", v.at("lambda").to()}, - {"transformers", make_transformer_args(vec)}, - {"preamble", v.get("preamble", std::string{})}}); + {"params", enum_params(inputs.size(), "void * private_p")}, + {"args", enum_params(inputs.size(), "private_p")}, + {"assign", assign}, + {"algo", algo}, + {"reduced", "decltype(" + generate_make_shape(reduce_output_shape) + ")"}, + {"lambda", v.at("lambda").to()}, + {"transformers", make_transformer_args(vec)}, + {"preamble", v.get("preamble", std::string{})}}); options.emplace_param("-Wno-float-equal"); return compile_hip_code_object(src, options); } diff --git a/src/targets/gpu/jit/scatter.hpp b/src/targets/gpu/jit/scatter.hpp index ff615331684..2902aa9c599 100644 --- a/src/targets/gpu/jit/scatter.hpp +++ b/src/targets/gpu/jit/scatter.hpp @@ -48,9 +48,7 @@ struct scatter_compiler : compiler options.output = inputs.back(); options.kernel_name = derived().get_kernel_name(op); options.virtual_inputs = inputs; - // The compiler protests the inequality comparison in assign_mul when pertaining to floating - // point, despite it making sense in the context. Thus the warning removal. 
- options.emplace_param("-Wno-float-equal"); + options.emplace_param("-DMIGRAPHX_ALLOW_ATOMIC_CAS=1"); const auto src = derived().make_interpolated_string(op); return prepend_copy_data_to_output(compile_hip_code_object(src, options)); diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp index e37f7423147..b2fb0f4b00f 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp @@ -28,6 +28,7 @@ #include #include #include +#include #include namespace migraphx { @@ -730,18 +731,18 @@ simple_reduce(Op op, T init, Input input, Output output, ReadInput read, WriteOu }); } -template -__device__ void fused_reduce(Output output, F f) +template +__device__ void fused_reduce(Output output, Assign assign, F f) { Algo::template run([&](auto out_idx, auto r) { auto result = f(r, out_idx); if constexpr(reduce::is_inner_storage{}) { - r.inner([&](auto& y, auto x) { y = x; })(output, result); + r.inner([&](auto& y, auto x) { assign(y, x); })(output, result); } else { - r.outer([&] { output[out_idx] = implicit_conversion(result); }); + r.outer([&] { assign(output[out_idx], implicit_conversion(result)); }); } }); } diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/scatter_reduction_modes.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/scatter_reduction_modes.hpp index 4081c313ad4..b0236f92f2c 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/scatter_reduction_modes.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/scatter_reduction_modes.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -25,6 +25,17 @@ #define MIGRAPHX_GUARD_KERNELS_SCATTER_REDUCTION_MODES_HPP #include +#include +#include + +#ifndef MIGRAPHX_ALLOW_ATOMIC_CAS +// NOLINTNEXTLINE +#define MIGRAPHX_ALLOW_ATOMIC_CAS 0 +#endif + +// NOLINTNEXTLINE +#define MIGRAPHX_ATOMIC_CAS_WARNING() \ + MIGRAPHX_ASSERT(MIGRAPHX_ALLOW_ATOMIC_CAS and "Using atomicCAS is slow") namespace migraphx { @@ -42,7 +53,15 @@ struct assign_add template MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const { - atomicAdd(&x, y); + if constexpr(is_same{} or is_same{}) + { + unsafeAtomicAdd(&x, T(y)); + } + else + { + MIGRAPHX_ATOMIC_CAS_WARNING(); + atomicAdd(&x, T(y)); + } } }; @@ -51,13 +70,17 @@ struct assign_mul template MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const { + MIGRAPHX_ATOMIC_CAS_WARNING(); T old = x; T assumed; do { assumed = old; old = atomicCAS(&x, assumed, assumed * y); +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wfloat-equal" } while(assumed != old); +#pragma clang diagnostic pop } }; @@ -66,7 +89,7 @@ struct assign_max template MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const { - atomicMax(&x, y); + atomicMax(&x, T(y)); } }; @@ -75,7 +98,7 @@ struct assign_min template MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const { - atomicMin(&x, y); + atomicMin(&x, T(y)); } }; diff --git a/src/targets/gpu/lowering.cpp b/src/targets/gpu/lowering.cpp index fcde59841fd..51d924f6a5f 100644 --- a/src/targets/gpu/lowering.cpp +++ b/src/targets/gpu/lowering.cpp @@ -174,10 +174,23 @@ struct miopen_apply { check_shape(s, insert_custom_op(it, attrs)); } + if(attrs.contains("prefill")) + { + insert_fill(it, attrs.at("prefill")); + } } copy_params(); } + void insert_fill(instruction_ref ins, value v) const + { + instruction_ref alloc = instruction::get_output_alias(ins, true); + if(alloc == ins) + return; + auto fill = mod->insert_instruction(ins, make_op("hip::fill", {{"value", v}}), alloc); + instruction::replace_argument(ins, alloc, fill); + } + instruction_ref insert_custom_op(instruction_ref ins, const value& attrs) const { const auto& custom_op = ins->get_operator(); diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index cc0a136892d..b3ed59df57b 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -54,6 +54,7 @@ #include #include #include +#include #include #include #include @@ -77,6 +78,7 @@ namespace gpu { MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_REDUCE_FUSION) +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_SPLIT_REDUCE) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC) #ifndef _WIN32 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK) @@ -162,6 +164,8 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti dead_code_elimination{}, enable_pass(not enabled(MIGRAPHX_DISABLE_REDUCE_FUSION{}), fuse_reduce{}), dead_code_elimination{}, + enable_pass(enabled(MIGRAPHX_ENABLE_SPLIT_REDUCE{}), split_reduce{}), + dead_code_elimination{}, fuse_concat{}, dead_code_elimination{}, #ifndef _WIN32 diff --git a/test/gpu/mlir.cpp b/test/gpu/mlir.cpp index ae3726729d4..7ea6c137cad 100644 --- a/test/gpu/mlir.cpp +++ b/test/gpu/mlir.cpp @@ -79,7 +79,7 @@ migraphx::module create_mlir_submodule(const migraphx::module& mmlir) auto param = mmlir.get_parameter(name); map_ins[param] = m.add_parameter(name, 
param->get_shape().as_standard()); } - auto y = m.add_instructions(&mmlir, map_ins); + auto y = m.add_instructions(&mmlir, &map_ins); m.add_return(y); return m; } diff --git a/test/module_test.cpp b/test/module_test.cpp index a21e565a7b1..f44c089d30d 100644 --- a/test/module_test.cpp +++ b/test/module_test.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -309,6 +309,16 @@ TEST_CASE(parameter_name_order) EXPECT(param_names == names1); } +struct map_ins +{ + using type = std::unordered_map; + map_ins(std::initializer_list x) : m(x) {} + + operator type*() { return &m; } + + type m; +}; + TEST_CASE(insert_instructions_module) { migraphx::shape s{migraphx::shape::int32_type, {1}}; @@ -321,7 +331,7 @@ TEST_CASE(insert_instructions_module) auto x2 = m2.add_parameter("x2", s); m2.add_instruction(migraphx::make_op("sqrt"), {x2}); - m1.insert_instructions(sqrt, &m2, {{x2, x1}}); + m1.insert_instructions(sqrt, &m2, map_ins{{x2, x1}}); EXPECT(std::prev(sqrt)->name() == "sqrt"); EXPECT(std::count_if(m1.begin(), m1.end(), [](auto&& ins) { return ins.name() == "sqrt"; }) == @@ -343,7 +353,7 @@ TEST_CASE(add_instructions_module) auto x2 = m2.add_parameter("x2", s); m2.add_instruction(migraphx::make_op("sqrt"), {x2}); - m1.add_instructions(&m2, {{x2, x1}}); + m1.add_instructions(&m2, map_ins{{x2, x1}}); EXPECT(std::count_if(m1.begin(), m1.end(), [](auto&& ins) { return ins.name() == "sqrt"; }) == 2); @@ -364,7 +374,7 @@ TEST_CASE(add_instructions_range) auto x2 = m2.add_parameter("x2", s); auto sqrt2 = m2.add_instruction(migraphx::make_op("sqrt"), {x2}); - m1.add_instructions(sqrt2, m2.end(), {{x2, x1}}); + m1.add_instructions(sqrt2, m2.end(), map_ins{{x2, x1}}); EXPECT(std::any_of( m1.begin(), m1.end(), [&](auto&& ins) { return migraphx::contains(ins.inputs(), x1); })); @@ -387,7 +397,7 @@ TEST_CASE(add_instructions_vector) auto x2 = m2.add_parameter("x2", s); auto sqrt2 = m2.add_instruction(migraphx::make_op("sqrt"), {x2}); - m1.add_instructions({sqrt2}, {{x2, x1}}); + m1.add_instructions({sqrt2}, map_ins{{x2, x1}}); EXPECT(std::any_of( m1.begin(), m1.end(), [&](auto&& ins) { return migraphx::contains(ins.inputs(), x1); })); @@ -448,4 +458,115 @@ TEST_CASE(multiple_module_dependency) EXPECT((sub->validate() == sub->end())); } +TEST_CASE(module_split2) +{ + migraphx::shape s{migraphx::shape::float_type, {1}}; + migraphx::module input_m; + std::vector inputs; + { + auto x1 = input_m.add_parameter("x1", s); + auto x2 = input_m.add_parameter("x2", s); + auto x3 = input_m.add_parameter("x3", s); + auto sx1 = input_m.add_instruction(migraphx::make_op("sqrt"), x1); + auto sx2 = input_m.add_instruction(migraphx::make_op("sqrt"), x2); + auto sx3 = input_m.add_instruction(migraphx::make_op("sqrt"), x3); + inputs = {sx1, sx2, sx3}; + } + migraphx::module m; + std::vector splits; + { + auto x1 = m.add_parameter("x1", s); + auto x2 = m.add_parameter("x2", s); + auto x3 = m.add_parameter("x3", s); + auto add = m.add_instruction(migraphx::make_op("add"), x1, x2); + auto mul = m.add_instruction(migraphx::make_op("mul"), add, x3); + m.add_return({mul}); + splits.push_back(add); + } + auto mods = m.split(inputs, splits); + + migraphx::module m1; + { + auto x1 = m1.add_parameter("x1", s); + auto x2 = 
m1.add_parameter("x2", s); + auto add = m1.add_instruction(migraphx::make_op("add"), x1, x2); + m1.add_return({add}); + } + migraphx::module m2; + { + auto x0 = m2.add_parameter("x0", s); + auto x1 = m2.add_parameter("x1", s); + auto mul = m2.add_instruction(migraphx::make_op("mul"), x0, x1); + m2.add_return({mul}); + } + EXPECT(mods[0].mod.sort() == m1.sort()); + EXPECT(mods[1].mod.sort() == m2.sort()); + + EXPECT(bool{mods[0].inputs[0] == inputs[0]}); + EXPECT(bool{mods[0].inputs[1] == inputs[1]}); + + EXPECT(bool{mods[1].inputs[0] == splits.front()}); + EXPECT(bool{mods[1].inputs[1] == inputs[2]}); +} + +TEST_CASE(module_split3) +{ + migraphx::shape s{migraphx::shape::float_type, {1}}; + migraphx::module input_m; + std::vector inputs; + { + auto x1 = input_m.add_parameter("x1", s); + auto x2 = input_m.add_parameter("x2", s); + auto sx1 = input_m.add_instruction(migraphx::make_op("sqrt"), x1); + auto sx2 = input_m.add_instruction(migraphx::make_op("sqrt"), x2); + inputs = {sx1, sx2}; + } + migraphx::module m; + std::vector splits1; + std::vector splits2; + { + auto x1 = m.add_parameter("x1", s); + auto x2 = m.add_parameter("x2", s); + auto mul = m.add_instruction(migraphx::make_op("mul"), x1, x2); + auto sqrt = m.add_instruction(migraphx::make_op("sqrt"), mul); + auto add = m.add_instruction(migraphx::make_op("add"), sqrt, mul); + m.add_return({add}); + splits1.push_back(mul); + splits2.push_back(sqrt); + } + auto mods = m.split(inputs, splits1, splits2); + + migraphx::module m1; + { + auto x1 = m1.add_parameter("x1", s); + auto x2 = m1.add_parameter("x2", s); + auto mul = m1.add_instruction(migraphx::make_op("mul"), x1, x2); + m1.add_return({mul}); + } + migraphx::module m2; + { + auto x0 = m2.add_parameter("x0", s); + auto sqrt = m2.add_instruction(migraphx::make_op("sqrt"), x0); + m2.add_return({sqrt}); + } + migraphx::module m3; + { + auto x0 = m3.add_parameter("x0", s); + auto x1 = m3.add_parameter("x1", s); + auto add = m3.add_instruction(migraphx::make_op("add"), x0, x1); + m3.add_return({add}); + } + EXPECT(mods[0].mod.sort() == m1.sort()); + EXPECT(mods[1].mod.sort() == m2.sort()); + EXPECT(mods[2].mod.sort() == m3.sort()); + + EXPECT(bool{mods[0].inputs[0] == inputs[0]}); + EXPECT(bool{mods[0].inputs[1] == inputs[1]}); + + EXPECT(bool{mods[1].inputs[0] == splits1.front()}); + + EXPECT(bool{mods[2].inputs[0] == splits2.front()}); + EXPECT(bool{mods[2].inputs[1] == splits1.front()}); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/split_reduce.cpp b/test/split_reduce.cpp new file mode 100644 index 00000000000..f757391216e --- /dev/null +++ b/test/split_reduce.cpp @@ -0,0 +1,227 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. 
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+#include <migraphx/split_reduce.hpp>
+#include <migraphx/fuse_pointwise.hpp>
+#include <migraphx/fuse_reduce.hpp>
+#include <migraphx/dead_code_elimination.hpp>
+#include <migraphx/pass_manager.hpp>
+#include <migraphx/instruction.hpp>
+#include <migraphx/program.hpp>
+#include <migraphx/make_op.hpp>
+#include <migraphx/iterator_for.hpp>
+
+#include <test.hpp>
+#include <pointwise.hpp>
+
+void run_pass(migraphx::program& p)
+{
+    migraphx::run_passes(p,
+                         {migraphx::fuse_pointwise{},
+                          migraphx::fuse_reduce{},
+                          migraphx::split_reduce{.split_size = 8192},
+                          migraphx::dead_code_elimination{}});
+}
+
+void run_fuse_pass(migraphx::program& p)
+{
+    migraphx::run_passes(
+        p,
+        {migraphx::fuse_pointwise{}, migraphx::fuse_reduce{}, migraphx::dead_code_elimination{}});
+}
+
+bool all_instructions_are_local(const migraphx::module& m)
+{
+    return std::all_of(m.begin(), m.end(), [&](const auto& ins) {
+        return std::all_of(ins.inputs().begin(), ins.inputs().end(), [&](auto input) {
+            return m.has_instruction(input);
+        });
+    });
+}
+
+template <class F>
+migraphx::instruction_ref add_reduce(migraphx::program& p,
+                                     const std::string& name,
+                                     std::vector<migraphx::instruction_ref> inputs,
+                                     const std::vector<int64_t>& axes,
+                                     const std::string& assign,
+                                     F f)
+{
+    auto* rm = p.create_module(name);
+    auto* mm = p.get_main_module();
+    rm->set_bypass();
+    std::vector<migraphx::instruction_ref> params;
+    std::transform(inputs.begin(), inputs.end(), std::back_inserter(params), [&](auto input) {
+        return rm->add_parameter(
+            "x" + std::to_string(params.size()),
+            migraphx::shape{input->get_shape().type(), input->get_shape().lens()});
+    });
+    auto r = f(rm, params, axes);
+    rm->add_return({r});
+    EXPECT(all_instructions_are_local(*rm));
+    return mm->add_instruction(
+        migraphx::make_op("split_fused_reduce", {{"axes", axes}, {"assign", assign}}),
+        inputs,
+        {rm});
+}
+
+inline auto single_reduce(const std::string& name)
+{
+    return [=](auto* rm, const auto& inputs, const auto& axes) {
+        return rm->add_instruction(migraphx::make_op(name, {{"axes", axes}}), inputs);
+    };
+}
+
+TEST_CASE(single)
+{
+    migraphx::shape s{migraphx::shape::float_type, {2, 3, 327680}};
+    migraphx::program p1;
+    {
+        auto* mm = p1.get_main_module();
+        auto x = mm->add_parameter("x", s);
+        auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), x);
+        mm->add_return({rsum});
+    }
+    run_pass(p1);
+    migraphx::program p2;
+    {
+        auto* mm = p2.get_main_module();
+        auto x = mm->add_parameter("x", s);
+        auto rsum = add_reduce(
+            p2, "main:reduce_sum0_split", {x}, {2}, "assign_add", single_reduce("reduce_sum"));
+        mm->add_return({rsum});
+    }
+    EXPECT(p1 == p2);
+}
+
+TEST_CASE(fused)
+{
+    migraphx::shape s{migraphx::shape::float_type, {2, 3, 327680}};
+    migraphx::program p1;
+    {
+        auto* mm = p1.get_main_module();
+        auto x = mm->add_parameter("x", s);
+        auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), x);
+        auto rsumb = mm->add_instruction(
+            migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum);
+        auto add = mm->add_instruction(migraphx::make_op("add"), x, rsumb);
+        mm->add_return({add});
+    }
+    run_pass(p1);
+    migraphx::program p2;
+    {
+        auto* mm = p2.get_main_module();
+        auto x = mm->add_parameter("x", s);
+        auto rsum = add_reduce(p2,
+                               "main:reduce_sum0:main:pointwise0_split",
+                               {x},
+                               {2},
+                               "assign_add",
+                               single_reduce("reduce_sum"));
+        auto rsumb = mm->add_instruction(
+            migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum);
+        auto add = add_pointwise(p2, mm, "main:pointwise0", {x, rsumb}, single_pointwise("add"));
+        mm->add_return({add});
+    }
+    EXPECT(p1 == p2);
+}
+
+TEST_CASE(small)
+{
+    migraphx::shape s{migraphx::shape::float_type, {2, 3, 1024}};
+    migraphx::program p1;
+    {
+        auto* mm = p1.get_main_module();
+        auto x = mm->add_parameter("x", s);
+        auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), x);
+        auto rsumb = mm->add_instruction(
+            migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum);
+        auto add = mm->add_instruction(migraphx::make_op("add"), x, rsumb);
+        mm->add_return({add});
+    }
+    migraphx::program p2 = p1;
+    run_fuse_pass(p2);
+    run_pass(p1);
+
+    EXPECT(p1 == p2);
+}
+
+TEST_CASE(split_pointwise)
+{
+    migraphx::shape s{migraphx::shape::float_type, {2, 3, 327680}};
+    migraphx::program p1;
+    {
+        auto* mm = p1.get_main_module();
+        auto x = mm->add_parameter("x", s);
+        auto sqrt = mm->add_instruction(migraphx::make_op("sqrt"), x);
+        auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), sqrt);
+        auto rsumb = mm->add_instruction(
+            migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum);
+        auto add = mm->add_instruction(migraphx::make_op("add"), sqrt, rsumb);
+        mm->add_return({add});
+    }
+    run_pass(p1);
+    migraphx::program p2;
+    {
+        auto* mm = p2.get_main_module();
+        auto x = mm->add_parameter("x", s);
+        auto sqrt = add_pointwise(p2, mm, "main:pointwise0", {x}, single_pointwise("sqrt"));
+        auto rsum = add_reduce(p2,
+                               "main:pointwise0:main:reduce_sum0:main:pointwise1_split",
+                               {sqrt},
+                               {2},
+                               "assign_add",
+                               single_reduce("reduce_sum"));
+        auto rsumb = mm->add_instruction(
+            migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum);
+        auto add = add_pointwise(p2, mm, "main:pointwise1", {sqrt, rsumb}, single_pointwise("add"));
+        mm->add_return({add});
+    }
+    EXPECT(p1 == p2);
+}
+
+TEST_CASE(sequence_reduces)
+{
+    migraphx::shape s{migraphx::shape::float_type, {2, 3, 327680}};
+    migraphx::program p1;
+    {
+        auto* mm = p1.get_main_module();
+        auto x = mm->add_parameter("x", s);
+        auto rsum1 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), x);
+        auto rsum1b = mm->add_instruction(
+            migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum1);
+        auto sub = mm->add_instruction(migraphx::make_op("sub"), x, rsum1b);
+        auto rsum2 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), sub);
+        auto rsum2b = mm->add_instruction(
+            migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum2);
+        auto add = mm->add_instruction(migraphx::make_op("add"), rsum2b, x);
+        mm->add_return({add});
+    }
+    migraphx::program p2 = p1;
+    run_fuse_pass(p2);
+    run_pass(p1);
+
+    EXPECT(p1 == p2);
+}
+
+int main(int argc, const char* argv[]) { test::run(argc, argv); }
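Note: the tests in this file compare IR structure only (EXPECT(p1 == p2) after the passes run). A numerical cross-check can also be useful while developing a pass like split_reduce; the sketch below is not part of this change, assumes the reference target is linked in, and only uses the generic program APIs named in its includes.

    // Illustrative sketch: run a program copy on the reference target with
    // deterministic inputs, so a pre-pass and a post-pass program can also be
    // compared numerically within a tolerance.
    #include <migraphx/generate.hpp>
    #include <migraphx/program.hpp>
    #include <migraphx/register_target.hpp>
    #include <vector>

    static std::vector<float> eval_on_ref(migraphx::program p) // take a copy; compile mutates it
    {
        p.compile(migraphx::make_target("ref"));
        migraphx::parameter_map params;
        for(const auto& [name, shape] : p.get_parameter_shapes())
            params[name] = migraphx::generate_argument(shape, 1); // fixed seed -> repeatable data
        std::vector<float> output;
        p.eval(params).back().visit([&](auto v) { output.assign(v.begin(), v.end()); });
        return output;
    }
    // usage: compute eval_on_ref on both programs and compare element-wise with a tolerance.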