From 9077e742f15aa3bb5bb166b0781207dda9bcf484 Mon Sep 17 00:00:00 2001 From: Paul Fultz II Date: Mon, 18 Mar 2024 16:29:14 -0500 Subject: [PATCH] Improve reduction fusion with reshape operators (#2698) --- src/common_dims.cpp | 51 ++++- src/fuse_pointwise.cpp | 53 +----- src/fuse_reduce.cpp | 99 +++++++++- src/include/migraphx/common_dims.hpp | 17 +- src/include/migraphx/module.hpp | 23 ++- src/include/migraphx/rewrite_reshapes.hpp | 206 +++++++++++++++++++++ src/module.cpp | 69 +++++-- test/common_dims.cpp | 29 ++- test/fuse_pointwise.cpp | 43 ++++- test/fuse_reduce.cpp | 215 +++++++++++++++++++++- 10 files changed, 714 insertions(+), 91 deletions(-) create mode 100644 src/include/migraphx/rewrite_reshapes.hpp diff --git a/src/common_dims.cpp b/src/common_dims.cpp index a113e6b4386..3c3e8db1c3a 100644 --- a/src/common_dims.cpp +++ b/src/common_dims.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -43,12 +43,6 @@ static auto compute_end_dim(Iterator start, Iterator last, std::size_t dim) return it; } -template -static auto elements(const Range& r) -{ - return std::accumulate(r.begin(), r.end(), std::size_t{1}, std::multiplies<>{}); -} - struct common_dim_state { common_dim_state(const std::vector& pdims, @@ -152,5 +146,48 @@ common_dims common_dims::compute(const std::vector& dims1, return cd; } +const std::vector>* common_dims::get_axes_map(std::size_t n) const +{ + if(axes_map1.size() == n) + return &axes_map1; + if(axes_map2.size() == n) + return &axes_map2; + return nullptr; +} + +std::vector +common_dims::get_dimensions_for(const std::vector& idims) const +{ + if(dims.size() == idims.size()) + return idims; + if(elements(dims) == elements(idims)) + return dims; + // Bail for now since its ambiguous which axes map can be used + // TODO: Check for similiarity + if(axes_map1.size() == axes_map2.size()) + return {}; + const auto* axes_map = get_axes_map(idims.size()); + if(axes_map == nullptr) + return {}; + auto xdims = dims; + for(auto i : range(axes_map->size())) + { + auto dim = idims[i]; + const auto& axes = (*axes_map)[i]; + if(axes.size() == 1) + { + xdims[axes.front()] = dim; + } + else if(dim == 1) + { + for(auto axis : axes) + xdims[axis] = 1; + } + } + if(elements(xdims) == elements(idims)) + return xdims; + return {}; +} + } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/fuse_pointwise.cpp b/src/fuse_pointwise.cpp index 90ad475f7e2..15e78d338d5 100644 --- a/src/fuse_pointwise.cpp +++ b/src/fuse_pointwise.cpp @@ -25,14 +25,13 @@ #include #include #include -#include #include #include #include #include #include #include -#include +#include #include MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_POINTWISE_FUSION) @@ -193,52 +192,13 @@ static bool find_pointwise_modules(module& m) } return changed; } + namespace { -struct find_pointwise_reshape_pointwise +struct pointwise_reshape : rewrite_reshapes_base { - auto matcher() const - { - auto reshape = - match::name("reshape", "squeeze", "unsqueeze", "flatten")(match::used_once()); - auto skip_contiguous = [](auto... ms) { - return match::arg(0)(match::skip(match::name("contiguous")(match::used_once()))(ms...)); - }; - auto pointwise = match::name("pointwise")(match::used_once()); - auto reshape_pointwise = reshape(skip_contiguous(pointwise.bind("x"))).bind("reshape"); - return match::name("pointwise")(match::any_of[match::inputs()](reshape_pointwise)); - } - - void apply(module& m, const match::matcher_result& r) const - { - auto ins = r.result; - auto x_ins = r.instructions["x"]; - auto reshape_ins = r.instructions["reshape"]; - - auto cd = common_dims::compute(ins->get_shape().lens(), x_ins->get_shape().lens()); - if(cd.dims.empty()) - return; - - auto reshape_input = [&](const auto& ins_to_insert) { - return [&](auto input) { - return m.insert_instruction( - ins_to_insert, make_op("reshape", {{"dims", cd.dims}}), input); - }; - }; - auto x_inputs = x_ins->inputs(); - std::transform(x_inputs.begin(), x_inputs.end(), x_inputs.begin(), reshape_input(x_ins)); - auto new_x_ins = - m.insert_instruction(x_ins, x_ins->get_operator(), x_inputs, x_ins->module_inputs()); - - auto inputs = ins->inputs(); - std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) { - if(input == reshape_ins) - return new_x_ins; - return reshape_input(ins)(input); - }); - auto pw = m.insert_instruction(ins, ins->get_operator(), inputs, ins->module_inputs()); - m.replace_instruction(ins, make_op("reshape", {{"dims", ins->get_shape().lens()}}), pw); - } + static std::string name() { return "pointwise"; } }; + } // namespace void fuse_pointwise::apply(module_pass_manager& mpm) const @@ -252,8 +212,7 @@ void fuse_pointwise::apply(module_pass_manager& mpm) const } for(int i = 0; i < 8; i++) { - match::find_matches(mpm.get_module(), find_pointwise_reshape_pointwise{}); - mpm.run_pass(simplify_reshapes{1}); + mpm.run_pass(rewrite_reshapes{}); if(not find_pointwise_modules(mpm.get_module())) break; mpm.run_pass(dead_code_elimination{}); diff --git a/src/fuse_reduce.cpp b/src/fuse_reduce.cpp index 6f2e6a1b862..4e1602a1964 100644 --- a/src/fuse_reduce.cpp +++ b/src/fuse_reduce.cpp @@ -32,6 +32,7 @@ #include #include #include +#include #include #include @@ -100,11 +101,11 @@ get_ins_param_map(const std::vector& inputs, const_module_ref s } static void insert_params(module_ref sm, - instruction_ref ins, + const std::vector& inputs, std::unordered_map& map_ins) { auto n = sm->get_parameter_shapes().size(); - for(auto input : ins->inputs()) + for(auto input : inputs) { if(contains(map_ins, input)) continue; @@ -117,7 +118,7 @@ static auto insert_ins_in_submodule(module_ref sm, instruction_ref ins, std::unordered_map& map_ins) { - insert_params(sm, ins, map_ins); + insert_params(sm, ins->inputs(), map_ins); return sm->add_instructions({ins}, map_ins); } @@ -129,17 +130,37 @@ static auto insert_ins_in_submodule(module_ref sm, instruction_ref ins) static auto insert_module_in_submodule(module_ref sm, - instruction_ref ins, - std::unordered_map& map_ins) + const std::vector& inputs, + module_ref m, + std::unordered_map& map_ins, + module::inserter insert = nullptr) { - insert_params(sm, ins, map_ins); - auto* m = ins->module_inputs().front(); - auto param_map = get_ins_param_map(ins->inputs(), m); + insert_params(sm, inputs, map_ins); + auto param_map = get_ins_param_map(inputs, m); for(auto&& [input, param] : param_map) { map_ins[param] = map_ins.at(input); } - return sm->add_instructions(m, map_ins); + return sm->add_instructions(m, map_ins, std::move(insert)); +} + +static auto +insert_module_in_submodule(module_ref sm, + instruction_ref ins, + std::unordered_map& map_ins, + module::inserter insert = nullptr) +{ + return insert_module_in_submodule( + sm, ins->inputs(), ins->module_inputs().front(), map_ins, std::move(insert)); +} + +static auto insert_module_in_submodule(module_ref sm, + const std::vector& inputs, + module_ref m, + module::inserter insert = nullptr) +{ + std::unordered_map map_ins; + return insert_module_in_submodule(sm, inputs, m, map_ins, std::move(insert)); } static std::vector @@ -332,6 +353,65 @@ struct find_reduce_reduce } }; +struct reduce_reshape : rewrite_reshapes_base +{ + static std::string name() { return "fused_reduce"; } + + template + static auto transform_op(Transform t) + { + return [=](module& m, + instruction_ref ins, + const operation& op, + const std::vector& inputs, + const std::vector& mod_args) { + auto new_op = t(op); + return m.insert_instruction(ins, new_op, inputs, mod_args); + }; + } + + template + static instruction_ref insert(module_pass_manager& mpm, + instruction_ref ins, + const std::vector& inputs, + const AxesMap& am) + { + auto op = any_cast(ins->get_operator()); + std::vector axes; + for(auto axis : op.axes) + { + auto new_axes = am.at(axis); + axes.insert(axes.end(), new_axes.begin(), new_axes.end()); + } + std::sort(axes.begin(), axes.end()); + auto dims = base_dims(inputs); + auto* oldm = ins->module_inputs().front(); + auto* sm = mpm.create_module(oldm->name() + "_reshape"); + insert_module_in_submodule(sm, inputs, oldm, transform_op([&](const operation& sop) { + if(contains(sop.name(), "reduce")) + return make_op(sop.name(), {{"axes", axes}}); + if(sop.name() == "multibroadcast") + return make_op("multibroadcast", {{"out_lens", dims}}); + assert(sop.name() == "pointwise"); + return sop; + })); + return mpm.get_module().insert_instruction(ins, fused_reduce{axes}, inputs, {sm}); + } + + static std::vector base_dims(const std::vector& inputs) + { + auto input = std::max_element(inputs.begin(), inputs.end(), by(std::less<>{}, [](auto i) { + return i->get_shape().elements(); + })); + return (*input)->get_shape().lens(); + } + + static std::vector base_dims(instruction_ref ins) + { + return base_dims(ins->inputs()); + } +}; + } // namespace void fuse_reduce::apply(module_pass_manager& mpm) const @@ -340,6 +420,7 @@ void fuse_reduce::apply(module_pass_manager& mpm) const mpm.run_pass(dead_code_elimination{}); for(int i = 0; i < 4; i++) { + mpm.run_pass(rewrite_reshapes{}); match::find_matches( mpm, find_reduce_pointwise{}, find_pointwise_reduce{}, find_reduce_reduce{}); mpm.run_pass(dead_code_elimination{}); diff --git a/src/include/migraphx/common_dims.hpp b/src/include/migraphx/common_dims.hpp index 2d65eb14abe..b8dd3007530 100644 --- a/src/include/migraphx/common_dims.hpp +++ b/src/include/migraphx/common_dims.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -26,6 +26,8 @@ #include #include +#include +#include #include namespace migraphx { @@ -39,11 +41,24 @@ struct MIGRAPHX_EXPORT common_dims { static common_dims compute(const std::vector& dims1, const std::vector& dims2); + + /// Map the dimensions into the common higher dimensional space. The + /// dimension doesnt need to have the same number of elements as the + /// common dimension. + std::vector get_dimensions_for(const std::vector& idims) const; + /// Get the corresponding axes map based on the rank of tensor + const std::vector>* get_axes_map(std::size_t n) const; std::vector dims; std::vector> axes_map1; std::vector> axes_map2; }; +template +auto elements(const Range& r) +{ + return std::accumulate(r.begin(), r.end(), std::size_t{1}, std::multiplies<>{}); +} + } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx #endif // MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index 7a650c79914..838ee9faff5 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -55,6 +55,11 @@ using ins_dep_map = std::unordered_map& inputs, + const std::vector& mod_args)>; module(const std::string& name = ""); // move constructor @@ -122,32 +127,38 @@ struct MIGRAPHX_EXPORT module std::vector add_instructions(const std::vector& instructions, - std::unordered_map map_ins = {}); + std::unordered_map map_ins = {}, + inserter insert = nullptr); std::vector add_instructions(const_module_ref m, - std::unordered_map map_ins = {}); + std::unordered_map map_ins = {}, + inserter insert = nullptr); std::vector add_instructions(instruction_ref start, instruction_ref last, - std::unordered_map map_ins = {}); + std::unordered_map map_ins = {}, + inserter insert = nullptr); std::vector insert_instructions(instruction_ref ins, const std::vector& instructions, - std::unordered_map map_ins = {}); + std::unordered_map map_ins = {}, + inserter insert = nullptr); std::vector insert_instructions(instruction_ref ins, const_module_ref m, - std::unordered_map map_ins = {}); + std::unordered_map map_ins = {}, + inserter insert = nullptr); std::vector insert_instructions(instruction_ref ins, instruction_ref start, instruction_ref last, - std::unordered_map map_ins = {}); + std::unordered_map map_ins = {}, + inserter insert = nullptr); template instruction_ref add_literal(Ts&&... xs) diff --git a/src/include/migraphx/rewrite_reshapes.hpp b/src/include/migraphx/rewrite_reshapes.hpp new file mode 100644 index 00000000000..bd6ed1ec27d --- /dev/null +++ b/src/include/migraphx/rewrite_reshapes.hpp @@ -0,0 +1,206 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ +#ifndef MIGRAPHX_GUARD_MIGRAPHX_REWRITE_RESHAPES_HPP +#define MIGRAPHX_GUARD_MIGRAPHX_REWRITE_RESHAPES_HPP + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct rewrite_reshapes_base +{ + template + static instruction_ref insert(module_pass_manager& mpm, + instruction_ref ins, + const std::vector& inputs, + const AxesMap&) + { + return mpm.get_module().insert_instruction( + ins, ins->get_operator(), inputs, ins->module_inputs()); + } + + template + static bool supports(instruction_ref, std::vector&, const AxesMap&) + { + return true; + } + + static std::vector base_dims(instruction_ref ins) + { + return ins->get_shape().lens(); + } +}; + +template +struct rewrite_reshapes +{ + std::string name() const { return "rewrite_reshapes"; } + struct find_op_reshape_op + { + std::string op1; + std::string op2; + + auto matcher() const + { + auto reshape = + match::name("reshape", "squeeze", "unsqueeze", "flatten")(match::used_once()); + auto skip_contiguous = [](auto... ms) { + return match::arg(0)(match::skip( + match::name("contiguous", "multibroadcast")(match::used_once()))(ms...)); + }; + auto pointwise = match::name(op1)(match::used_once()); + auto reshape_pointwise = reshape(skip_contiguous(pointwise.bind("x"))).bind("reshape"); + return match::name(op2)(match::any_of[match::inputs()](reshape_pointwise)); + } + + template + static instruction_ref find_input_if(instruction_ref start, instruction_ref last, F f) + { + while(start != last) + { + if(f(start)) + return start; + if(start->inputs().size() != 1) + return last; + start = start->inputs().front(); + } + return last; + } + + static bool match_input(instruction_ref ins, instruction_ref x_ins) + { + if(ins->inputs().empty()) + return false; + auto input = ins->inputs().front(); + if(input->name() == "contiguous") + return match_input(input, x_ins); + return x_ins == input; + } + + void apply(module_pass_manager& mpm, const match::matcher_result& r) const + { + auto ins = r.result; + auto x_ins = r.instructions["x"]; + auto reshape_ins = r.instructions["reshape"]; + + auto broadcast_ins = find_input_if( + reshape_ins, x_ins, [&](auto i) { return i->name() == "multibroadcast"; }); + const bool has_broadcast = broadcast_ins != x_ins; + if(has_broadcast and not match_input(broadcast_ins, x_ins)) + return; + + auto dims1 = T::base_dims(ins); + auto dims2 = T::base_dims(x_ins); + + if(elements(dims1) != elements(dims2)) + return; + + auto cd = common_dims::compute(T::base_dims(ins), T::base_dims(x_ins)); + if(cd.dims.empty()) + return; + + if(ins->name() != "pointwise" and not T::supports(ins, cd.dims, cd.axes_map1)) + return; + if(x_ins->name() != "pointwise" and not T::supports(x_ins, cd.dims, cd.axes_map2)) + return; + + auto reshape_input = [&](const auto& ins_to_insert) { + return [&](auto input) { + auto dims = cd.get_dimensions_for(input->get_shape().lens()); + return mpm.get_module().insert_instruction( + ins_to_insert, make_op("reshape", {{"dims", dims}}), input); + }; + }; + auto x_inputs = x_ins->inputs(); + std::transform( + x_inputs.begin(), x_inputs.end(), x_inputs.begin(), reshape_input(x_ins)); + auto new_x_ins = insert(mpm, x_ins, x_inputs, cd.axes_map2); + if(has_broadcast) + { + new_x_ins = mpm.get_module().insert_instruction( + x_ins, make_op("multibroadcast", {{"out_lens", cd.dims}}), new_x_ins); + } + + auto inputs = ins->inputs(); + std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) { + if(input == reshape_ins) + return new_x_ins; + return reshape_input(ins)(input); + }); + auto pw = insert(mpm, ins, inputs, cd.axes_map1); + mpm.get_module().replace_instruction( + ins, make_op("reshape", {{"dims", ins->get_shape().lens()}}), pw); + } + + static bool same_dims(instruction_ref ins) + { + return all_of(ins->inputs(), [&](auto input) { + return input->get_shape().lens() == ins->get_shape().lens(); + }); + } + + template + static instruction_ref insert(module_pass_manager& mpm, + instruction_ref ins, + const std::vector& inputs, + const AxesMap& am) + { + if(ins->name() == "pointwise") + return mpm.get_module().insert_instruction( + ins, ins->get_operator(), inputs, ins->module_inputs()); + return T::insert(mpm, ins, inputs, am); + } + }; + + void apply(module_pass_manager& mpm) const + { + if(T::name() == "pointwise") + { + match::find_matches(mpm, find_op_reshape_op{"pointwise", T::name()}); + } + else + { + match::find_matches(mpm, + find_op_reshape_op{"pointwise", T::name()}, + find_op_reshape_op{T::name(), "pointwise"}, + find_op_reshape_op{T::name(), T::name()}); + } + mpm.run_pass(simplify_reshapes{1}); + mpm.run_pass(eliminate_common_subexpression{}); + mpm.run_pass(dead_code_elimination{}); + } +}; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_MIGRAPHX_REWRITE_RESHAPES_HPP diff --git a/src/module.cpp b/src/module.cpp index 5091e92320f..f0fa8ac8a7b 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -199,12 +199,13 @@ void module::assign(const module& m) } } -template +template static std::vector -insert_generic_instructions(module& m, - instruction_ref ins, - Range&& instructions, - std::unordered_map map_ins) +insert_generic_instructions_impl(module& m, + instruction_ref ins, + Range&& instructions, + std::unordered_map map_ins, + Inserter insert) { assert(m.has_instruction(ins) or is_end(ins, m.end())); std::vector mod_outputs; @@ -246,7 +247,7 @@ insert_generic_instructions(module& m, break; } - copy_ins = m.insert_instruction(ins, sins->get_operator(), copy_inputs, mod_args); + copy_ins = insert(m, ins, sins->get_operator(), copy_inputs, mod_args); } map_ins[sins] = copy_ins; } @@ -255,6 +256,27 @@ insert_generic_instructions(module& m, return mod_outputs; } +template +static std::vector +insert_generic_instructions(module& m, + instruction_ref ins, + Range&& instructions, + std::unordered_map map_ins, + module::inserter insert) +{ + if(insert == nullptr) + return insert_generic_instructions_impl(m, + ins, + static_cast(instructions), + std::move(map_ins), + [](module& mm, auto&&... xs) { + return mm.insert_instruction( + std::forward(xs)...); + }); + return insert_generic_instructions_impl( + m, ins, static_cast(instructions), std::move(map_ins), insert); +} + instruction_ref module::add_instruction(const operation& op, std::vector args) { return insert_instruction(impl->instructions.end(), op, std::move(args)); @@ -401,50 +423,61 @@ instruction_ref module::move_instructions(instruction_ref src, instruction_ref d std::vector module::add_instructions(const std::vector& instructions, - std::unordered_map map_ins) + std::unordered_map map_ins, + module::inserter insert) { - return this->insert_instructions(this->end(), instructions, std::move(map_ins)); + return this->insert_instructions( + this->end(), instructions, std::move(map_ins), std::move(insert)); } std::vector module::add_instructions(const_module_ref m, - std::unordered_map map_ins) + std::unordered_map map_ins, + module::inserter insert) { - return this->insert_instructions(this->end(), m, std::move(map_ins)); + return this->insert_instructions(this->end(), m, std::move(map_ins), std::move(insert)); } std::vector module::add_instructions(instruction_ref start, instruction_ref last, - std::unordered_map map_ins) + std::unordered_map map_ins, + module::inserter insert) { - return this->insert_instructions(this->end(), start, last, std::move(map_ins)); + return this->insert_instructions( + this->end(), start, last, std::move(map_ins), std::move(insert)); } std::vector module::insert_instructions(instruction_ref ins, const std::vector& instructions, - std::unordered_map map_ins) + std::unordered_map map_ins, + module::inserter insert) { - return insert_generic_instructions(*this, ins, instructions, std::move(map_ins)); + return insert_generic_instructions( + *this, ins, instructions, std::move(map_ins), std::move(insert)); } std::vector module::insert_instructions(instruction_ref ins, const_module_ref m, - std::unordered_map map_ins) + std::unordered_map map_ins, + module::inserter insert) { - return insert_generic_instructions(*this, ins, iterator_for(*m), std::move(map_ins)); + return insert_generic_instructions( + *this, ins, iterator_for(*m), std::move(map_ins), std::move(insert)); } std::vector module::insert_instructions(instruction_ref ins, instruction_ref start, instruction_ref last, - std::unordered_map map_ins) + std::unordered_map map_ins, + module::inserter insert) { auto r = range(start, last); - return insert_generic_instructions(*this, ins, iterator_for(r), std::move(map_ins)); + return insert_generic_instructions( + *this, ins, iterator_for(r), std::move(map_ins), std::move(insert)); } instruction_ref module::add_literal(literal l) { return insert_literal(begin(), std::move(l)); } diff --git a/test/common_dims.cpp b/test/common_dims.cpp index 5b70c86fcd4..1458822d45f 100644 --- a/test/common_dims.cpp +++ b/test/common_dims.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -34,12 +34,27 @@ TEST_CASE(common_d1_less) EXPECT(cd.axes_map2 == axes_map{{0}, {1, 2}, {3}}); } +void verify_common(const migraphx::common_dims& cd) +{ + EXPECT(cd.get_dimensions_for({2, 32, 40, 8, 8}) == std::vector{2, 32, 40, 8, 8}); + EXPECT(cd.get_dimensions_for({64, 2560}) == std::vector{2, 32, 40, 8, 8}); + EXPECT(cd.get_dimensions_for({2, 32, 1}) == std::vector{2, 32, 1, 1, 1}); + EXPECT(cd.get_dimensions_for({2, 1, 2560}) == std::vector{2, 1, 40, 8, 8}); + EXPECT(cd.get_dimensions_for({2, 8, 2560}) == std::vector{2, 8, 40, 8, 8}); + EXPECT(cd.get_dimensions_for({2, 1, 8, 8}) == std::vector{2, 1, 1, 8, 8}); + EXPECT(cd.get_dimensions_for({2, 32, 8}).empty()); + EXPECT(cd.get_dimensions_for({2, 8, 8, 8}).empty()); + EXPECT(cd.get_dimensions_for({2, 1, 40, 8, 8}) == std::vector{2, 1, 40, 8, 8}); + EXPECT(cd.get_dimensions_for({2, 32, 256, 8, 8}) == std::vector{2, 32, 256, 8, 8}); +} + TEST_CASE(common1) { auto cd = migraphx::common_dims::compute({2, 32, 2560}, {2, 1280, 8, 8}); EXPECT(cd.dims == std::vector{2, 32, 40, 8, 8}); EXPECT(cd.axes_map1 == axes_map{{0}, {1}, {2, 3, 4}}); EXPECT(cd.axes_map2 == axes_map{{0}, {1, 2}, {3}, {4}}); + verify_common(cd); } TEST_CASE(common2) @@ -48,6 +63,18 @@ TEST_CASE(common2) EXPECT(cd.dims == std::vector{2, 32, 40, 8, 8}); EXPECT(cd.axes_map1 == axes_map{{0}, {1, 2}, {3}, {4}}); EXPECT(cd.axes_map2 == axes_map{{0}, {1}, {2, 3, 4}}); + verify_common(cd); +} + +TEST_CASE(common_same_dims) +{ + auto cd = migraphx::common_dims::compute({{2, 32, 4}}, {64, 2, 2}); + EXPECT(cd.dims == std::vector{2, 32, 2, 2}); + EXPECT(cd.get_dimensions_for({64, 2, 2}) == std::vector{2, 32, 2, 2}); + EXPECT(cd.get_dimensions_for({2, 32, 4}) == std::vector{2, 32, 2, 2}); + // TODO: CHeck for similiarity + EXPECT(cd.get_dimensions_for({2, 32, 1}).empty()); + EXPECT(cd.get_dimensions_for({64, 2, 1}).empty()); } TEST_CASE(common_error1) diff --git a/test/fuse_pointwise.cpp b/test/fuse_pointwise.cpp index 538609ed5bc..42b2f0f419c 100644 --- a/test/fuse_pointwise.cpp +++ b/test/fuse_pointwise.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -401,6 +401,47 @@ TEST_CASE(add_reshape_add) EXPECT(p1.sort() == p2.sort()); } +TEST_CASE(add_contiguous_reshape_add) +{ + auto s1 = + migraphx::shape::from_permutation(migraphx::shape::float_type, {3, 10, 16}, {0, 2, 1}); + auto s2 = migraphx::shape{migraphx::shape::float_type, {3, 40, 2, 2}}; + auto s3 = migraphx::shape{migraphx::shape::float_type, {3, 10, 4, 2, 2}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s1); + auto y = mm->add_parameter("y", s1); + auto z = mm->add_parameter("z", s2); + auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y); + auto contiguous = mm->add_instruction(migraphx::make_op("contiguous"), add1); + auto reshape = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), contiguous); + auto add2 = mm->add_instruction(migraphx::make_op("add"), reshape, z); + mm->add_return({add2}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s1); + auto y = mm->add_parameter("y", s1); + auto z = mm->add_parameter("z", s2); + auto x2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), x); + auto y2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), y); + auto z2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), z); + auto fadd = + add_pointwise(p2, "main:pointwise0", {x2, y2, z2}, [=](auto* pm, const auto& inputs) { + auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); + return pm->add_instruction(migraphx::make_op("add"), add1, inputs[2]); + }); + auto reshape = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), fadd); + mm->add_return({reshape}); + } + EXPECT(p1.sort() == p2.sort()); +} + TEST_CASE(add_reshape_add_nonstandard) { migraphx::shape s1 = diff --git a/test/fuse_reduce.cpp b/test/fuse_reduce.cpp index 24617bc1546..b277168a6a5 100644 --- a/test/fuse_reduce.cpp +++ b/test/fuse_reduce.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -327,4 +327,217 @@ TEST_CASE(reduce_reduce_broadcast) EXPECT(p1 == p2); } +TEST_CASE(reduce_reshape_pointwise1) +{ + migraphx::shape s1{migraphx::shape::float_type, {64, 4}}; + migraphx::shape s2{migraphx::shape::float_type, {8, 8, 2, 2}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s1); + auto y = mm->add_parameter("y", s2); + auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), x); + auto rsumb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), rsum); + auto rsumr = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), rsumb); + auto add = add_pointwise(p1, "main:pointwise0", {rsumr, y}, single_pointwise("add")); + mm->add_return({add}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s1); + auto y = mm->add_parameter("y", s2); + auto xr = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), x); + auto add = add_reduce( + p2, + "main:reduce_sum0_reshape:main:pointwise0", + {xr, y}, + {2, 3}, + [&](auto* rm, const auto& inputs, const auto& axes) { + auto rsum = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), + inputs[0]); + auto rsumb = rm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", s2.lens()}}), rsum); + return add_pointwise( + p2, rm, "main:pointwise0", {rsumb, inputs[1]}, single_pointwise("add")); + }); + mm->add_return({add}); + } + EXPECT(p1 == p2); +} + +TEST_CASE(reduce_reshape_pointwise2) +{ + migraphx::shape s1{migraphx::shape::float_type, {2, 32, 40960}}; + migraphx::shape s2{migraphx::shape::float_type, {2, 320, 64, 64}}; + migraphx::shape s3{migraphx::shape::float_type, {2, 32, 10, 64, 64}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s1); + auto y = mm->add_parameter("y", s2); + auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), x); + auto rsumb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), rsum); + auto rsumr = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), rsumb); + auto add = add_pointwise(p1, "main:pointwise0", {rsumr, y}, single_pointwise("add")); + mm->add_return({add}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s1); + auto y = mm->add_parameter("y", s2); + auto xr = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), x); + auto yr = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), y); + auto add = add_reduce( + p2, + "main:reduce_sum0_reshape:main:pointwise0", + {xr, yr}, + {2, 3, 4}, + [&](auto* rm, const auto& inputs, const auto& axes) { + auto rsum = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), + inputs[0]); + auto rsumb = rm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", s3.lens()}}), rsum); + return add_pointwise( + p2, rm, "main:pointwise0", {rsumb, inputs[1]}, single_pointwise("add")); + }); + auto addr = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), add); + mm->add_return({addr}); + } + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(reduce_contiguous_reshape_pointwise) +{ + migraphx::shape s1 = + migraphx::shape::from_permutation(migraphx::shape::float_type, {2, 32, 40960}, {1, 0, 2}); + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 320, 64, 64}}; + auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 32, 10, 64, 64}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s1); + auto y = mm->add_parameter("y", s2); + auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), x); + auto rsumc = mm->add_instruction(migraphx::make_op("contiguous"), rsum); + auto rsumb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), rsumc); + auto rsumr = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), rsumb); + auto add = add_pointwise(p1, "main:pointwise0", {rsumr, y}, single_pointwise("add")); + mm->add_return({add}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s1); + auto y = mm->add_parameter("y", s2); + auto xr = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), x); + auto yr = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), y); + auto add = add_reduce( + p2, + "main:reduce_sum0_reshape:main:pointwise0", + {xr, yr}, + {2, 3, 4}, + [&](auto* rm, const auto& inputs, const auto& axes) { + auto rsum = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), + inputs[0]); + auto rsumb = rm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", s3.lens()}}), rsum); + return add_pointwise( + p2, rm, "main:pointwise0", {rsumb, inputs[1]}, single_pointwise("add")); + }); + auto addr = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), add); + mm->add_return({addr}); + } + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(reduce_reshape_reduce) +{ + migraphx::shape s1{migraphx::shape::float_type, {2, 32, 4096}}; + migraphx::shape s1r{migraphx::shape::float_type, {2, 32, 1}}; + migraphx::shape s2{migraphx::shape::float_type, {4, 16, 64, 64}}; + migraphx::shape s2r{migraphx::shape::float_type, {4, 16, 1, 1}}; + migraphx::shape s3{migraphx::shape::float_type, {2, 2, 16, 64, 64}}; + migraphx::shape s3r{migraphx::shape::float_type, {2, 2, 16, 1, 1}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x1 = mm->add_parameter("x1", s1); + auto x2 = mm->add_parameter("x2", s1r); + auto y = mm->add_parameter("y", s2); + auto rsum1 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), x1); + auto rsum1_add = add_pointwise(p1, "main:pointwise0", {rsum1, x2}, single_pointwise("add")); + + auto rsum1_addb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), rsum1_add); + auto rsum1_sub = + add_pointwise(p1, "main:pointwise1", {rsum1_addb, x1}, single_pointwise("sub")); + auto rsum2 = + mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), rsum1_sub); + auto rsum2b = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), rsum2); + auto rsum2_sub = + add_pointwise(p1, "main:pointwise2", {rsum2b, x1}, single_pointwise("sub")); + auto rsum2_subr = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), rsum2_sub); + auto rsum3 = + mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2, 3}}}), rsum2_subr); + auto rsum3b = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", s2.lens()}}), rsum3); + auto rsum3_add = add_pointwise(p1, "main:pointwise3", {rsum3b, y}, single_pointwise("add")); + mm->add_return({rsum3_add}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x1 = mm->add_parameter("x1", s1); + auto x2 = mm->add_parameter("x2", s1r); + auto y = mm->add_parameter("y", s2); + auto x1r = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), x1); + auto x2r = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3r.lens()}}), x2); + auto freduce = add_reduce( + p2, + "main:pointwise2:main:reduce_sum2_reshape_reshape:main:reduce_sum1:main:reduce_sum0:" + "main:pointwise0:main:pointwise1_reshape", + {x1r, x2r}, + {3, 4}, + [&](auto* rm, const auto& inputs, const auto& axes) { + auto rsum1 = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), + inputs[0]); + auto add = add_pointwise( + p2, rm, "main:pointwise0", {rsum1, inputs[1]}, single_pointwise("add")); + auto addb = rm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", s3.lens()}}), add); + auto sub1 = add_pointwise( + p2, rm, "main:pointwise1", {addb, inputs[0]}, single_pointwise("sub")); + auto rsum2 = + rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), sub1); + auto rsum2b = rm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", s3.lens()}}), rsum2); + auto sub2 = add_pointwise( + p2, rm, "main:pointwise2", {rsum2b, inputs[0]}, single_pointwise("sub")); + return rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), sub2); + }); + auto freducer = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2r.lens()}}), freduce); + // TODO: Fuse the last add as well + auto freducerb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", s2.lens()}}), freducer); + auto add = add_pointwise(p2, "main:pointwise3", {freducerb, y}, single_pointwise("add")); + mm->add_return({add}); + } + EXPECT(p1.sort() == p2.sort()); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); }