From 7460ac30849147ab55c7224410cc6d7cf0fea70f Mon Sep 17 00:00:00 2001 From: Umang Yadav <29876643+umangyadav@users.noreply.github.com> Date: Fri, 15 Mar 2024 10:12:52 -0400 Subject: [PATCH] Fuse layout with pointwise op (#2886) --- src/targets/gpu/fuse_ops.cpp | 12 ++-- test/gpu/fuse_ops.cpp | 83 ++++++++++++++++++++++++ test/verify/test_pointwise_conv_nhwc.cpp | 55 ++++++++++++++++ 3 files changed, 146 insertions(+), 4 deletions(-) create mode 100644 test/verify/test_pointwise_conv_nhwc.cpp diff --git a/src/targets/gpu/fuse_ops.cpp b/src/targets/gpu/fuse_ops.cpp index e4b28067f19..a76e6a97666 100644 --- a/src/targets/gpu/fuse_ops.cpp +++ b/src/targets/gpu/fuse_ops.cpp @@ -767,11 +767,15 @@ struct find_contiguous } }; -struct find_pointwise_contiguous +struct find_pointwise_layout_contiguous { auto matcher() const { - return match::name("gpu::contiguous")(match::arg(0)(precompile_name("pointwise"))); + auto is_layout = precompile_name("layout")( + match::arg(0)(match::used_once(), precompile_name("pointwise"))); + auto is_contiguous = match::name("gpu::contiguous")( + match::arg(0)(match::used_once(), precompile_name("pointwise"))); + return match::any_of(is_layout, is_contiguous); } void apply(module& m, const match::matcher_result& r) const @@ -782,7 +786,7 @@ struct find_pointwise_contiguous auto args = pw->inputs(); args.back() = alloc; - // Ensure the output shape of the pointwise module is contiguous + // Ensure the output shape of the pointwise module retains the memory layout auto pw_op_val = pw->get_operator().to_value(); pw_op_val["output_shape"] = to_value(ins->get_shape()); @@ -852,7 +856,7 @@ struct find_concat_pointwise void fuse_ops::apply(module& m) const { - match::find_matches(m, find_pointwise_contiguous{}); + match::find_matches(m, find_pointwise_layout_contiguous{}); run_passes(m, {dead_code_elimination{}}); match::find_matches(m, find_conv_pointwise{ctx}, find_conv_bias_relu{ctx}, find_conv_bias{ctx}); run_passes(m, {dead_code_elimination{}}); diff --git a/test/gpu/fuse_ops.cpp b/test/gpu/fuse_ops.cpp index 51223181bb3..c300c61316c 100644 --- a/test/gpu/fuse_ops.cpp +++ b/test/gpu/fuse_ops.cpp @@ -160,6 +160,89 @@ TEST_CASE(pointwise_contiguous) EXPECT(p1 == p2); } +TEST_CASE(pointwise_layout_convolution) +{ + migraphx::shape s1{migraphx::shape::float_type, {2, 320, 128, 128}}; + migraphx::shape s2{migraphx::shape::float_type, {320, 320, 3, 3}, {2880, 1, 960, 320}}; + migraphx::shape s3{migraphx::shape::float_type, {2, 320, 128, 128}, {5242880, 1, 40960, 320}}; + // workspace for gpu::convolution, memory space can change based on gfx arch and rocm version, + // For the unit-test just use some random number. + migraphx::shape s4{migraphx::shape::int8_type, {41943040}}; + + auto create_program = [=]() { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x1 = mm->add_parameter("x1", s1); + auto x2 = mm->add_parameter("x2", s1); + auto weights = mm->add_parameter("weights", s2); + auto alloc = migraphx::make_op("allocate", {{"shape", to_value(s1)}}); + auto alloc_ins = mm->add_instruction(alloc); + auto* pwm = create_pointwise_module( + p, "main:pointwise0", {x1, x2}, [=](auto* pm, const auto& inputs) { + auto mul_ins = pm->add_instruction(migraphx::make_op("mul"), inputs[0], inputs[1]); + return pm->add_instruction(migraphx::make_op("sigmoid"), mul_ins); + }); + auto pw_ins = + mm->add_instruction(make_precompile_op("pointwise"), {x1, x2, alloc_ins}, {pwm}); + + auto alloc_ins2 = + mm->add_instruction(migraphx::make_op("allocate", {{"shape", to_value(s3)}})); + auto layout_op = migraphx::make_op("layout", {{"permutation", {0, 2, 3, 1}}}); + auto layout_ins = mm->add_instruction(make_precompile_op(layout_op), {pw_ins, alloc_ins2}); + auto conv_op = migraphx::make_op("convolution", {{"padding", {1, 1, 1, 1}}}); + auto alloc_ins3 = + mm->add_instruction(migraphx::make_op("allocate", {{"shape", to_value(s4)}})); + auto alloc_ins4 = + mm->add_instruction(migraphx::make_op("allocate", {{"shape", to_value(s3)}})); + auto conv = + mm->add_instruction(migraphx::make_op("gpu::convolution", {{"op", conv_op.to_value()}}), + layout_ins, + weights, + alloc_ins3, + alloc_ins4); + mm->add_return({conv}); + return p; + }; + auto create_fused_program = [=]() { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x1 = mm->add_parameter("x1", s1); + auto x2 = mm->add_parameter("x2", s1); + auto weights = mm->add_parameter("weights", s2); + auto alloc = migraphx::make_op("allocate", {{"shape", to_value(s3)}}); + auto alloc_ins = mm->add_instruction(alloc); + auto* pwm = create_pointwise_module( + p, "main:pointwise0", {x1, x2}, [=](auto* pm, const auto& inputs) { + auto mul_ins = pm->add_instruction(migraphx::make_op("mul"), inputs[0], inputs[1]); + return pm->add_instruction(migraphx::make_op("sigmoid"), mul_ins); + }); + auto pw_op = migraphx::make_op("pointwise"); + auto pre_comp_op = migraphx::make_op( + "gpu::precompile_op", + {{"op", migraphx::to_value(pw_op)}, {"output_shape", migraphx::to_value(s3)}}); + + auto pw_ins = mm->add_instruction(pre_comp_op, {x1, x2, alloc_ins}, {pwm}); + + auto conv_op = migraphx::make_op("convolution", {{"padding", {1, 1, 1, 1}}}); + auto alloc_ins2 = + mm->add_instruction(migraphx::make_op("allocate", {{"shape", to_value(s4)}})); + auto alloc_ins3 = + mm->add_instruction(migraphx::make_op("allocate", {{"shape", to_value(s3)}})); + auto conv = + mm->add_instruction(migraphx::make_op("gpu::convolution", {{"op", conv_op.to_value()}}), + pw_ins, + weights, + alloc_ins2, + alloc_ins3); + mm->add_return({conv}); + return p; + }; + migraphx::program p1 = create_program(); + run_pass(p1); + migraphx::program p2 = create_fused_program(); + EXPECT(p1 == p2); +} + TEST_CASE(concat_pointwise_contiguous) { migraphx::shape s1 = migraphx::shape::from_permutation( diff --git a/test/verify/test_pointwise_conv_nhwc.cpp b/test/verify/test_pointwise_conv_nhwc.cpp new file mode 100644 index 00000000000..12be39134f2 --- /dev/null +++ b/test/verify/test_pointwise_conv_nhwc.cpp @@ -0,0 +1,55 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +template +struct test_pointwise_conv_nhwc : verify_program> +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", {DType, {1, 8, 4, 4}}); + auto w = mm->add_literal(migraphx::generate_literal({DType, {2, 8, 3, 3}}, 1)); + auto v = mm->add_parameter("v", {DType, {2, 8, 3, 3}}); + auto y = mm->add_parameter("y", {DType, {1, 8, 4, 4}}); + auto mul = mm->add_instruction(migraphx::make_op("mul"), x, y); + auto sigmoid = mm->add_instruction(migraphx::make_op("sigmoid"), mul); + auto layout_ins = mm->add_instruction( + migraphx::make_op("layout", {{"permutation", {0, 2, 3, 1}}}), sigmoid); + auto add_ins = mm->add_instruction(migraphx::make_op("add"), w, v); + auto layout_w = mm->add_instruction( + migraphx::make_op("layout", {{"permutation", {0, 2, 3, 1}}}), add_ins); + mm->add_instruction(migraphx::make_op("convolution"), layout_ins, layout_w); + return p; + } + std::string section() const { return "conv"; } +}; + +template struct test_pointwise_conv_nhwc; +template struct test_pointwise_conv_nhwc;