Fuse layout with pointwise op (#2886)
umangyadav authored Mar 15, 2024
1 parent ef285c9 commit 7460ac3
Showing 3 changed files with 146 additions and 4 deletions.
src/targets/gpu/fuse_ops.cpp (12 changes: 8 additions & 4 deletions)

@@ -767,11 +767,15 @@ struct find_contiguous
     }
 };
 
-struct find_pointwise_contiguous
+struct find_pointwise_layout_contiguous
 {
     auto matcher() const
     {
-        return match::name("gpu::contiguous")(match::arg(0)(precompile_name("pointwise")));
+        auto is_layout = precompile_name("layout")(
+            match::arg(0)(match::used_once(), precompile_name("pointwise")));
+        auto is_contiguous = match::name("gpu::contiguous")(
+            match::arg(0)(match::used_once(), precompile_name("pointwise")));
+        return match::any_of(is_layout, is_contiguous);
     }
 
     void apply(module& m, const match::matcher_result& r) const
@@ -782,7 +786,7 @@ struct find_pointwise_contiguous
         auto args = pw->inputs();
        args.back() = alloc;
 
-        // Ensure the output shape of the pointwise module is contiguous
+        // Ensure the output shape of the pointwise module retains the memory layout
         auto pw_op_val = pw->get_operator().to_value();
         pw_op_val["output_shape"] = to_value(ins->get_shape());
 
@@ -852,7 +856,7 @@ struct find_concat_pointwise
 
 void fuse_ops::apply(module& m) const
 {
-    match::find_matches(m, find_pointwise_contiguous{});
+    match::find_matches(m, find_pointwise_layout_contiguous{});
     run_passes(m, {dead_code_elimination{}});
     match::find_matches(m, find_conv_pointwise{ctx}, find_conv_bias_relu{ctx}, find_conv_bias{ctx});
     run_passes(m, {dead_code_elimination{}});
test/gpu/fuse_ops.cpp (83 changes: 83 additions & 0 deletions)

@@ -160,6 +160,89 @@ TEST_CASE(pointwise_contiguous)
     EXPECT(p1 == p2);
 }
 
+TEST_CASE(pointwise_layout_convolution)
+{
+    migraphx::shape s1{migraphx::shape::float_type, {2, 320, 128, 128}};
+    migraphx::shape s2{migraphx::shape::float_type, {320, 320, 3, 3}, {2880, 1, 960, 320}};
+    migraphx::shape s3{migraphx::shape::float_type, {2, 320, 128, 128}, {5242880, 1, 40960, 320}};
+    // workspace for gpu::convolution, memory space can change based on gfx arch and rocm version,
+    // For the unit-test just use some random number.
+    migraphx::shape s4{migraphx::shape::int8_type, {41943040}};
+
+    auto create_program = [=]() {
+        migraphx::program p;
+        auto* mm     = p.get_main_module();
+        auto x1      = mm->add_parameter("x1", s1);
+        auto x2      = mm->add_parameter("x2", s1);
+        auto weights = mm->add_parameter("weights", s2);
+        auto alloc     = migraphx::make_op("allocate", {{"shape", to_value(s1)}});
+        auto alloc_ins = mm->add_instruction(alloc);
+        auto* pwm      = create_pointwise_module(
+            p, "main:pointwise0", {x1, x2}, [=](auto* pm, const auto& inputs) {
+                auto mul_ins = pm->add_instruction(migraphx::make_op("mul"), inputs[0], inputs[1]);
+                return pm->add_instruction(migraphx::make_op("sigmoid"), mul_ins);
+            });
+        auto pw_ins =
+            mm->add_instruction(make_precompile_op("pointwise"), {x1, x2, alloc_ins}, {pwm});
+
+        auto alloc_ins2 =
+            mm->add_instruction(migraphx::make_op("allocate", {{"shape", to_value(s3)}}));
+        auto layout_op  = migraphx::make_op("layout", {{"permutation", {0, 2, 3, 1}}});
+        auto layout_ins = mm->add_instruction(make_precompile_op(layout_op), {pw_ins, alloc_ins2});
+        auto conv_op    = migraphx::make_op("convolution", {{"padding", {1, 1, 1, 1}}});
+        auto alloc_ins3 =
+            mm->add_instruction(migraphx::make_op("allocate", {{"shape", to_value(s4)}}));
+        auto alloc_ins4 =
+            mm->add_instruction(migraphx::make_op("allocate", {{"shape", to_value(s3)}}));
+        auto conv =
+            mm->add_instruction(migraphx::make_op("gpu::convolution", {{"op", conv_op.to_value()}}),
+                                layout_ins,
+                                weights,
+                                alloc_ins3,
+                                alloc_ins4);
+        mm->add_return({conv});
+        return p;
+    };
+    auto create_fused_program = [=]() {
+        migraphx::program p;
+        auto* mm     = p.get_main_module();
+        auto x1      = mm->add_parameter("x1", s1);
+        auto x2      = mm->add_parameter("x2", s1);
+        auto weights = mm->add_parameter("weights", s2);
+        auto alloc     = migraphx::make_op("allocate", {{"shape", to_value(s3)}});
+        auto alloc_ins = mm->add_instruction(alloc);
+        auto* pwm      = create_pointwise_module(
+            p, "main:pointwise0", {x1, x2}, [=](auto* pm, const auto& inputs) {
+                auto mul_ins = pm->add_instruction(migraphx::make_op("mul"), inputs[0], inputs[1]);
+                return pm->add_instruction(migraphx::make_op("sigmoid"), mul_ins);
+            });
+        auto pw_op       = migraphx::make_op("pointwise");
+        auto pre_comp_op = migraphx::make_op(
+            "gpu::precompile_op",
+            {{"op", migraphx::to_value(pw_op)}, {"output_shape", migraphx::to_value(s3)}});
+
+        auto pw_ins = mm->add_instruction(pre_comp_op, {x1, x2, alloc_ins}, {pwm});
+
+        auto conv_op    = migraphx::make_op("convolution", {{"padding", {1, 1, 1, 1}}});
+        auto alloc_ins2 =
+            mm->add_instruction(migraphx::make_op("allocate", {{"shape", to_value(s4)}}));
+        auto alloc_ins3 =
+            mm->add_instruction(migraphx::make_op("allocate", {{"shape", to_value(s3)}}));
+        auto conv =
+            mm->add_instruction(migraphx::make_op("gpu::convolution", {{"op", conv_op.to_value()}}),
+                                pw_ins,
+                                weights,
+                                alloc_ins2,
+                                alloc_ins3);
+        mm->add_return({conv});
+        return p;
+    };
+    migraphx::program p1 = create_program();
+    run_pass(p1);
+    migraphx::program p2 = create_fused_program();
+    EXPECT(p1 == p2);
+}
+
 TEST_CASE(concat_pointwise_contiguous)
 {
     migraphx::shape s1 = migraphx::shape::from_permutation(
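Aside (not part of the commit): the strided shapes s2 and s3 in the test above encode the NHWC layout produced by the layout op's permutation {0, 2, 3, 1}, which is the layout the fused pointwise op now writes directly. Below is a minimal standalone C++ sketch, assuming only ordinary row-major stride packing of the permuted view, of how s3's strides follow from its NCHW dimensions:

#include <array>
#include <cstddef>
#include <iostream>

int main()
{
    std::array<std::size_t, 4> dims = {2, 320, 128, 128}; // s3 dimensions, NCHW order
    std::array<std::size_t, 4> perm = {0, 2, 3, 1};       // layout permutation: NCHW -> NHWC

    // Shape of the permuted (NHWC) view: {2, 128, 128, 320}.
    std::array<std::size_t, 4> permuted_dims{};
    for(std::size_t i = 0; i < 4; ++i)
        permuted_dims[i] = dims[perm[i]];

    // Packed row-major strides of that NHWC view: {5242880, 40960, 320, 1}.
    std::array<std::size_t, 4> permuted_strides{};
    std::size_t stride = 1;
    for(int i = 3; i >= 0; --i)
    {
        permuted_strides[i] = stride;
        stride *= permuted_dims[i];
    }

    // Map the strides back onto the original NCHW dimension order.
    std::array<std::size_t, 4> strides{};
    for(std::size_t i = 0; i < 4; ++i)
        strides[perm[i]] = permuted_strides[i];

    // Prints "5242880 1 40960 320", the strides of s3 above.
    for(auto s : strides)
        std::cout << s << ' ';
    std::cout << '\n';
}

The same computation on s2's dimensions {320, 320, 3, 3} yields its strides {2880, 1, 960, 320}.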
test/verify/test_pointwise_conv_nhwc.cpp (55 changes: 55 additions & 0 deletions)

@@ -0,0 +1,55 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+#include "verify_program.hpp"
+#include <migraphx/program.hpp>
+#include <migraphx/generate.hpp>
+#include <migraphx/make_op.hpp>
+
+template <migraphx::shape::type_t DType>
+struct test_pointwise_conv_nhwc : verify_program<test_pointwise_conv_nhwc<DType>>
+{
+    migraphx::program create_program() const
+    {
+        migraphx::program p;
+        auto* mm = p.get_main_module();
+        auto x   = mm->add_parameter("x", {DType, {1, 8, 4, 4}});
+        auto w   = mm->add_literal(migraphx::generate_literal({DType, {2, 8, 3, 3}}, 1));
+        auto v   = mm->add_parameter("v", {DType, {2, 8, 3, 3}});
+        auto y   = mm->add_parameter("y", {DType, {1, 8, 4, 4}});
+        auto mul     = mm->add_instruction(migraphx::make_op("mul"), x, y);
+        auto sigmoid = mm->add_instruction(migraphx::make_op("sigmoid"), mul);
+        auto layout_ins = mm->add_instruction(
+            migraphx::make_op("layout", {{"permutation", {0, 2, 3, 1}}}), sigmoid);
+        auto add_ins  = mm->add_instruction(migraphx::make_op("add"), w, v);
+        auto layout_w = mm->add_instruction(
+            migraphx::make_op("layout", {{"permutation", {0, 2, 3, 1}}}), add_ins);
+        mm->add_instruction(migraphx::make_op("convolution"), layout_ins, layout_w);
+        return p;
+    }
+    std::string section() const { return "conv"; }
+};
+
+template struct test_pointwise_conv_nhwc<migraphx::shape::float_type>;
+template struct test_pointwise_conv_nhwc<migraphx::shape::fp8e4m3fnuz_type>;
