diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 909c0f6bc26..3a847bbe7ef 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -66,7 +66,7 @@ add_library(migraphx
     insert_pad.cpp
     instruction.cpp
     json.cpp
-    layout_nhwc.cpp
+    layout_convolution.cpp
     lexing.cpp
     load_save.cpp
     make_op.cpp
diff --git a/src/include/migraphx/layout_nhwc.hpp b/src/include/migraphx/layout_convolution.hpp
similarity index 81%
rename from src/include/migraphx/layout_nhwc.hpp
rename to src/include/migraphx/layout_convolution.hpp
index faf097a4d9d..9e45033a8db 100644
--- a/src/include/migraphx/layout_nhwc.hpp
+++ b/src/include/migraphx/layout_convolution.hpp
@@ -21,8 +21,8 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  * THE SOFTWARE.
  */
-#ifndef MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_NHWC_HPP
-#define MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_NHWC_HPP
+#ifndef MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_CONVOLUTION_HPP
+#define MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_CONVOLUTION_HPP
 
 #include <string>
 #include <migraphx/config.hpp>
@@ -34,14 +34,15 @@ inline namespace MIGRAPHX_INLINE_NS {
 struct module_pass_manager;
 
 /**
- * Transform convolutions to nhwc
+ * Transform the layout of convolutions
  */
-struct MIGRAPHX_EXPORT layout_nhwc
+struct MIGRAPHX_EXPORT layout_convolution
 {
-    std::string name() const { return "layout_nhwc"; }
+    bool channels_last = false;
+    std::string name() const { return "layout_convolution"; }
     void apply(module_pass_manager& mpm) const;
 };
 
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
-#endif // MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_NHWC_HPP
+#endif // MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_CONVOLUTION_HPP
diff --git a/src/layout_nhwc.cpp b/src/layout_convolution.cpp
similarity index 62%
rename from src/layout_nhwc.cpp
rename to src/layout_convolution.cpp
index 9d2a0083a34..83acb839ce6 100644
--- a/src/layout_nhwc.cpp
+++ b/src/layout_convolution.cpp
@@ -21,7 +21,7 @@
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
-#include <migraphx/layout_nhwc.hpp>
+#include <migraphx/layout_convolution.hpp>
 #include
 #include
 #include
@@ -32,49 +32,61 @@
 #include
 #include
 #include
+#include
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 
-template <class Predicate>
-std::vector<instruction_ref> find_lasts(const module& m, Predicate pred)
+namespace {
+std::vector<int64_t> get_permutation(instruction_ref ins, const layout_convolution& lc)
 {
-    std::vector<instruction_ref> result;
-    fix([&](auto self, auto ins) {
-        if(pred(ins))
-        {
-            result.push_back(ins);
-            return;
-        }
-        for(auto input : ins->inputs())
-            self(input);
-    })(std::prev(m.end()));
-    return result;
+    if(lc.channels_last)
+    {
+        std::vector<int64_t> perm(ins->get_shape().ndim());
+        std::iota(perm.begin() + 1, perm.end() - 1, 2);
+        perm.back() = 1;
+        return perm;
+    }
+    return find_permutation(ins->inputs().front()->get_shape());
+}
+
+bool skip_layout(const shape& s)
+{
+    return s.ndim() == 1 or s.dynamic() or s.type() == shape::tuple_type;
 }
 
 void preserve_output_layout(module& m)
 {
     auto last = std::prev(m.end());
-    std::vector<instruction_ref> outputs;
     if(last->name() == "@return")
-        outputs = last->inputs();
-    else
-        outputs = {last};
-
-    for(auto output : outputs)
     {
-        auto permutation = find_permutation(output->get_shape());
-        auto layout      = m.insert_instruction(
-            std::next(output), make_op("layout", {{"permutation", permutation}}), output);
-        m.replace_instruction(output, layout);
+        std::vector<instruction_ref> outputs;
+        std::transform(last->inputs().begin(),
+                       last->inputs().end(),
+                       std::back_inserter(outputs),
+                       [&](instruction_ref ins) {
+                           if(skip_layout(ins->get_shape()))
+                               return ins;
+                           auto permutation = find_permutation(ins->get_shape());
+                           return m.insert_instruction(
+                               last, make_op("layout", {{"permutation", permutation}}), ins);
+                       });
+        m.replace_return(outputs);
+    }
+    else if(not skip_layout(last->get_shape()))
+    {
+        auto permutation = find_permutation(last->get_shape());
+        m.add_instruction(make_op("layout", {{"permutation", permutation}}), last);
     }
 }
 
-void transform_convolutions(module& m)
+void transform_convolutions(module& m, const layout_convolution& lc)
 {
     for(auto ins : iterator_for(m))
     {
-        if(ins->name() != "convolution")
+        if(not contains({"convolution", "quant_convolution"}, ins->name()))
+            continue;
+        if(ins->get_shape().dynamic())
             continue;
         if(ins->get_shape().lens().size() != 4)
             continue;
@@ -82,8 +94,9 @@ void transform_convolutions(module& m)
         if(v.at("group").to<int>() > 1)
             continue;
         auto args = ins->inputs();
+        auto perm = get_permutation(ins, lc);
         std::transform(args.begin(), args.end(), args.begin(), [&](const auto& i) {
-            return m.insert_instruction(ins, make_op("layout", {{"permutation", {0, 2, 3, 1}}}), i);
+            return m.insert_instruction(ins, make_op("layout", {{"permutation", perm}}), i);
         });
         auto conv = m.insert_instruction(ins, ins->get_operator(), args);
         auto c    = m.insert_instruction(ins, make_op("contiguous"), conv);
@@ -102,11 +115,12 @@ void remove_layout(module& m)
             m.replace_instruction(ins, ins->inputs().front());
     }
 }
+} // namespace
 
-void layout_nhwc::apply(module_pass_manager& mpm) const
+void layout_convolution::apply(module_pass_manager& mpm) const
 {
     preserve_output_layout(mpm.get_module());
-    transform_convolutions(mpm.get_module());
+    transform_convolutions(mpm.get_module(), *this);
     mpm.run_pass(dead_code_elimination{});
     mpm.run_pass(eliminate_contiguous{"contiguous"});
     mpm.run_pass(dead_code_elimination{});
diff --git a/src/module.cpp b/src/module.cpp
index e7beb7c88bb..c839b98c3cf 100644
--- a/src/module.cpp
+++ b/src/module.cpp
@@ -355,7 +355,6 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref rep)
 {
     impl->changed.notify();
     assert(has_instruction(ins));
-    assert(has_instruction(rep));
     assert(ins != rep);
 
     if(ins == std::prev(this->end()))
@@ -541,7 +540,6 @@ instruction_ref module::insert_parameter(instruction_ref ins, std::string name, shape s)
 instruction_ref module::replace_return(std::vector<instruction_ref> args)
 {
     impl->changed.notify();
-    assert(std::all_of(args.begin(), args.end(), [&](auto ins) { return has_instruction(ins); }));
     auto last = std::prev(this->end());
     // If there is no return then add a return
     if(last->name() != "@return")
diff --git a/src/targets/cpu/target.cpp b/src/targets/cpu/target.cpp
index 6e4e4051a80..e148aa5b6f3 100644
--- a/src/targets/cpu/target.cpp
+++ b/src/targets/cpu/target.cpp
@@ -33,7 +33,6 @@
 #include
 #include
 #include
-#include <migraphx/layout_nhwc.hpp>
 #include
 #include
 #include
diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp
index 3437651f056..db76c5a24ac 100644
--- a/src/targets/gpu/target.cpp
+++ b/src/targets/gpu/target.cpp
@@ -35,7 +35,7 @@
 #include
 #include
 #include
-#include <migraphx/layout_nhwc.hpp>
+#include <migraphx/layout_convolution.hpp>
 #include
 #include
 #include
@@ -182,7 +182,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_options& options) const
         dead_code_elimination{},
         rewrite_gelu{options.fast_math},
         optimize_module{},
-        enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), layout_nhwc{}),
+        layout_convolution{.channels_last = enabled(MIGRAPHX_ENABLE_NHWC{})},
         dead_code_elimination{},
         prefuse_ops{},
         dead_code_elimination{},
diff --git a/test/layout_nhwc.cpp b/test/layout_convolution.cpp
similarity index 58%
rename from test/layout_nhwc.cpp
rename to test/layout_convolution.cpp
index 7dae574d113..64e8830d67b 100644
--- a/test/layout_nhwc.cpp
+++ b/test/layout_convolution.cpp
@@ -21,7 +21,7 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  * THE SOFTWARE.
  */
-#include <migraphx/layout_nhwc.hpp>
+#include <migraphx/layout_convolution.hpp>
 #include
 #include
 #include
@@ -32,9 +32,9 @@
 
 #include <test.hpp>
 
-void run_pass(migraphx::module& m)
+void run_pass(migraphx::module& m, migraphx::layout_convolution lc = {})
 {
-    migraphx::run_passes(m, {migraphx::layout_nhwc{}, migraphx::dead_code_elimination{}});
+    migraphx::run_passes(m, {lc, migraphx::dead_code_elimination{}});
 }
 
 migraphx::operation layout(std::vector<int64_t> permutation = {0, 1, 2, 3})
@@ -47,7 +47,7 @@ migraphx::instruction_ref add_layout_nhwc(migraphx::module& m, migraphx::instruction_ref ins)
     return m.add_instruction(layout({0, 2, 3, 1}), ins);
 }
 
-TEST_CASE(conv_relu)
+TEST_CASE(auto_conv_nchw)
 {
     migraphx::module m1;
     {
@@ -59,9 +59,128 @@ TEST_CASE(conv_relu)
                               {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
             x,
             w);
-        m1.add_instruction(migraphx::make_op("relu"), conv);
+        auto relu = m1.add_instruction(migraphx::make_op("relu"), conv);
+        m1.add_return({relu});
     }
+    migraphx::module m2 = m1;
     run_pass(m1);
+    EXPECT(m1.sort() == m2.sort());
+}
+
+TEST_CASE(auto_conv_nhwc)
+{
+    auto transpose = migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}});
+    migraphx::module m1;
+    {
+        auto x          = m1.add_parameter("x", {migraphx::shape::float_type, {1, 16, 16, 8}});
+        auto xtranspose = m1.add_instruction(transpose, x);
+        auto w = m1.add_literal(
+            migraphx::generate_literal({migraphx::shape::float_type, {16, 3, 3, 8}}));
+        auto wtranspose = m1.add_instruction(transpose, w);
+        auto conv       = m1.add_instruction(
+            migraphx::make_op("convolution",
+                              {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
+            xtranspose,
+            wtranspose);
+        auto relu = m1.add_instruction(migraphx::make_op("relu"), conv);
+        m1.add_return({relu});
+    }
+    migraphx::module m2 = m1;
+    run_pass(m1);
+    EXPECT(m1.sort() == m2.sort());
+}
+
+TEST_CASE(auto_conv_mixed)
+{
+    migraphx::module m1;
+    {
+        auto x = m1.add_parameter("x", {migraphx::shape::float_type, {1, 8, 16, 16}});
+        auto w = m1.add_literal(
+            migraphx::generate_literal({migraphx::shape::float_type, {3, 3, 16, 8}}));
+        auto wtranspose =
+            m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 3, 0, 1}}}), w);
+        auto conv = m1.add_instruction(
+            migraphx::make_op("convolution",
+                              {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
+            x,
+            wtranspose);
+        auto relu = m1.add_instruction(migraphx::make_op("relu"), conv);
+        m1.add_return({relu});
+    }
+    run_pass(m1);
+    migraphx::module m2;
+    {
+        auto x = m2.add_parameter("x", {migraphx::shape::float_type, {1, 8, 16, 16}});
+        auto w = m2.add_literal(
+            migraphx::generate_literal({migraphx::shape::float_type, {3, 3, 16, 8}}));
+        auto wtranspose =
+            m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 3, 0, 1}}}), w);
+        auto wlayout = m2.add_instruction(
+            migraphx::make_op("layout", {{"permutation", {0, 1, 2, 3}}}), wtranspose);
+        auto conv = m2.add_instruction(
+            migraphx::make_op("convolution",
+                              {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
+            x,
+            wlayout);
+        auto relu = m2.add_instruction(migraphx::make_op("relu"), conv);
+        m2.add_return({relu});
+    }
+    EXPECT(m1.sort() == m2.sort());
+}
+
+TEST_CASE(auto_quant_conv_mixed)
+{
+    migraphx::module m1;
+    {
+        auto x = m1.add_parameter("x", {migraphx::shape::int8_type, {1, 8, 16, 16}});
+        auto w =
+            m1.add_literal(migraphx::generate_literal({migraphx::shape::int8_type, {3, 3, 16, 8}}));
+        auto wtranspose =
+            m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 3, 0, 1}}}), w);
+        auto conv = m1.add_instruction(
+            migraphx::make_op("quant_convolution",
+                              {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
+            x,
+            wtranspose);
+        auto relu = m1.add_instruction(migraphx::make_op("relu"), conv);
+        m1.add_return({relu});
+    }
+    run_pass(m1);
+    migraphx::module m2;
+    {
+        auto x = m2.add_parameter("x", {migraphx::shape::int8_type, {1, 8, 16, 16}});
+        auto w =
+            m2.add_literal(migraphx::generate_literal({migraphx::shape::int8_type, {3, 3, 16, 8}}));
+        auto wtranspose =
+            m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 3, 0, 1}}}), w);
+        auto wlayout = m2.add_instruction(
+            migraphx::make_op("layout", {{"permutation", {0, 1, 2, 3}}}), wtranspose);
+        auto conv = m2.add_instruction(
+            migraphx::make_op("quant_convolution",
+                              {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
+            x,
+            wlayout);
+        auto relu = m2.add_instruction(migraphx::make_op("relu"), conv);
+        m2.add_return({relu});
+    }
+    EXPECT(m1.sort() == m2.sort());
+}
+
+TEST_CASE(nhwc_conv_relu)
+{
+    migraphx::module m1;
+    {
+        auto x = m1.add_parameter("x", {migraphx::shape::float_type, {1, 8, 16, 16}});
+        auto w = m1.add_literal(
+            migraphx::generate_literal({migraphx::shape::float_type, {16, 8, 3, 3}}));
+        auto conv = m1.add_instruction(
+            migraphx::make_op("convolution",
+                              {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
+            x,
+            w);
+        m1.add_instruction(migraphx::make_op("relu"), conv);
+    }
+    run_pass(m1, {.channels_last = true});
 
     migraphx::module m2;
     {
@@ -81,7 +200,7 @@ TEST_CASE(conv_relu)
     EXPECT(m1.sort() == m2.sort());
 }
 
-TEST_CASE(conv_add)
+TEST_CASE(nhwc_conv_add)
 {
     migraphx::module m1;
     {
@@ -99,7 +218,7 @@ TEST_CASE(conv_add)
             y);
         m1.add_instruction(migraphx::make_op("add"), conv, b);
     }
-    run_pass(m1);
+    run_pass(m1, {.channels_last = true});
 
     migraphx::module m2;
     {
@@ -114,7 +233,7 @@ TEST_CASE(conv_add)
             {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
             x,
             w);
-        auto b = m2.add_instruction(
+        auto b = m2.add_instruction(
             migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", conv->get_shape().lens()}}),
             y);
         auto add = m2.add_instruction(migraphx::make_op("add"), conv, b);
@@ -123,7 +242,49 @@ TEST_CASE(conv_add)
     EXPECT(m1.sort() == m2.sort());
 }
 
-TEST_CASE(conv_conv)
+TEST_CASE(nhwc_quant_conv_add)
+{
+    migraphx::module m1;
+    {
+        auto x = m1.add_parameter("x", {migraphx::shape::int8_type, {1, 8, 16, 16}});
+        auto w =
+            m1.add_literal(migraphx::generate_literal({migraphx::shape::int8_type, {16, 8, 3, 3}}));
+        auto y = m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {16}}));
+        auto conv = m1.add_instruction(
+            migraphx::make_op("quant_convolution",
+                              {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
+            x,
+            w);
+        auto b = m1.add_instruction(
+            migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", conv->get_shape().lens()}}),
+            y);
+        m1.add_instruction(migraphx::make_op("add"), conv, b);
+    }
+    run_pass(m1, {.channels_last = true});
+
+    migraphx::module m2;
+    {
+        auto x = add_layout_nhwc(
+            m2, m2.add_parameter("x", {migraphx::shape::int8_type, {1, 8, 16, 16}}));
+        auto w = add_layout_nhwc(m2,
+                                 m2.add_literal(migraphx::generate_literal(
+                                     {migraphx::shape::int8_type, {16, 8, 3, 3}})));
+        auto y = m2.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {16}}));
+        auto conv = m2.add_instruction(
+            migraphx::make_op("quant_convolution",
+                              {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
+            x,
+            w);
+        auto b = m2.add_instruction(
+            migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", conv->get_shape().lens()}}),
+            y);
+        auto add = m2.add_instruction(migraphx::make_op("add"), conv, b);
+        m2.add_instruction(layout(), add);
+    }
+    EXPECT(m1.sort() == m2.sort());
+}
+
+TEST_CASE(nhwc_conv_conv)
 {
     migraphx::module m1;
     {
@@ -149,7 +310,7 @@ TEST_CASE(conv_conv)
         auto relu2 = m1.add_instruction(migraphx::make_op("relu"), add2);
         m1.add_return({relu2});
     }
-    run_pass(m1);
+    run_pass(m1, {.channels_last = true});
 
     migraphx::module m2;
     {
@@ -182,7 +343,7 @@ TEST_CASE(conv_conv)
     EXPECT(m1.sort() == m2.sort());
 }
 
-TEST_CASE(conv_reduce)
+TEST_CASE(nhwc_conv_reduce)
 {
     migraphx::module m1;
     {
@@ -201,7 +362,7 @@ TEST_CASE(conv_reduce)
         auto squeeze = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {2, 3}}}), reduce);
         m1.add_return({squeeze});
     }
-    run_pass(m1);
+    run_pass(m1, {.channels_last = true});
 
     migraphx::module m2;
     {