From 8d790befdcf08a8371ad3b0c21095acdf4adc2ef Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Thu, 15 Feb 2024 14:02:52 +0000 Subject: [PATCH 01/15] WIP --- src/fuse_concat.cpp | 50 ++++++++++- test/fuse_concat.cpp | 87 +++++++++++++++----- test/verify/test_pooling_add_concat_relu.cpp | 4 +- 3 files changed, 117 insertions(+), 24 deletions(-) diff --git a/src/fuse_concat.cpp b/src/fuse_concat.cpp index b6e169cef44..4af8f30bfc5 100644 --- a/src/fuse_concat.cpp +++ b/src/fuse_concat.cpp @@ -66,7 +66,7 @@ struct fused_concat module_ref post_mod = mods.back(); // post_mod has one input argument that is result of concat and will get generated from // pre-mods internally. Therefore deduct 1 from post_mod params while asserting. - assert(input_iter + post_mod->get_parameter_names().size() - 1 == inputs.end()); + // assert(input_iter + post_mod->get_parameter_names().size() - 1 == inputs.end()); auto type = std::prev(post_mod->end())->get_shape().type(); const auto& first_shape_lens = concat_inputs.front().lens(); auto mismatch_it = @@ -98,7 +98,53 @@ struct fused_concat MIGRAPHX_REGISTER_OP(fused_concat); namespace { +struct find_concat_pointwise +{ + auto matcher() const + { + return match::name("concat")( + match::any_of[match::inputs()](match::name("pointwise")(match::used_once()))); + } + void apply(module_pass_manager& mpm, const match::matcher_result& r) const + { + auto concat_ins = r.result; + + std::vector inputs; + for(auto input : concat_ins->inputs()) + { + if(input->name() == "pointwise") + inputs.insert(inputs.end(), input->inputs().begin(), input->inputs().end()); + else + inputs.push_back(input); + } + + std::vector module_inputs; + static unsigned int counter = 0; + std::transform(concat_ins->inputs().begin(), + concat_ins->inputs().end(), + std::back_inserter(module_inputs), + [&](instruction_ref input) { + if(input->name() == "pointwise" and input->outputs().size() == 1) + { + auto* pm = input->module_inputs().front(); + return mpm.create_module("concat:" + pm->name(), *pm); + } + auto* pm = + mpm.create_module("concat:identity" + std::to_string(counter++)); + + auto x = pm->add_parameter("x0", shape{input->get_shape().type()}); + auto id = pm->add_instruction(make_op("identity"), x); + pm->add_return({id}); + return pm; + }); + mpm.get_module().replace_instruction( + concat_ins, + make_op("fused_concat", concat_ins->normalized_operator().to_value()), + inputs, + module_inputs); + } +}; struct find_pointwise_concat_pointwise { auto matcher() const @@ -173,6 +219,8 @@ struct find_pointwise_concat_pointwise void fuse_concat::apply(module_pass_manager& mpm) const { match::find_matches(mpm, find_pointwise_concat_pointwise{}); + mpm.run_pass(migraphx::dead_code_elimination{}); + match::find_matches(mpm, find_concat_pointwise{}); } } // namespace MIGRAPHX_INLINE_NS diff --git a/test/fuse_concat.cpp b/test/fuse_concat.cpp index 9d33f5a9a9c..ba8b072d89f 100644 --- a/test/fuse_concat.cpp +++ b/test/fuse_concat.cpp @@ -51,9 +51,25 @@ concat_arg arg(std::string name, std::vector input return {std::move(name), std::move(inputs), std::move(f)}; } +template +migraphx::instruction_ref add_pointwise_concat(migraphx::program& p, std::size_t axis, Args... args) +{ + std::vector module_inputs; + std::vector ins_inputs; + migraphx::each_args( + [&](auto arg) { + module_inputs.push_back(create_pointwise_module(p, arg.name, arg.inputs, arg.f)); + ins_inputs.insert(ins_inputs.end(), arg.inputs.begin(), arg.inputs.end()); + }, + args...); + auto* mm = p.get_main_module(); + return mm->add_instruction( + migraphx::make_op("fused_concat", {{"axis", axis}}), ins_inputs, module_inputs); +} + template migraphx::instruction_ref -add_concat(migraphx::program& p, std::size_t axis, Arg post_arg, Args... args) +add_pointwise_concat_pointwise(migraphx::program& p, std::size_t axis, Arg post_arg, Args... args) { std::vector module_inputs; std::vector ins_inputs; @@ -81,7 +97,36 @@ add_concat(migraphx::program& p, std::size_t axis, Arg post_arg, Args... args) migraphx::make_op("fused_concat", {{"axis", axis}}), ins_inputs, module_inputs); } -TEST_CASE(simple_pointwise_concat) +TEST_CASE(simple_concat_pointwise) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto add = add_pointwise(p1, "main:pointwise0", {x, y}, single_pointwise("add")); + auto sub = add_pointwise(p1, "main:pointwise1", {x, y}, single_pointwise("sub")); + auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), add, sub); + mm->add_return({concat}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto fused_concat = + add_pointwise_concat(p2, + 1, + arg("concat:main:pointwise0", {x, y}, single_pointwise("add")), + arg("concat:main:pointwise1", {x, y}, single_pointwise("sub"))); + mm->add_return({fused_concat}); + } + EXPECT(p1 == p2); +} + +TEST_CASE(simple_pointwise_concat_pointwise) { migraphx::shape s{migraphx::shape::float_type, {2, 3}}; migraphx::program p1; @@ -101,18 +146,18 @@ TEST_CASE(simple_pointwise_concat) auto* mm = p2.get_main_module(); auto x = mm->add_parameter("x", s); auto y = mm->add_parameter("y", s); - auto fused_concat = - add_concat(p2, - 1, - arg("main:pointwise2:concat", {}, single_pointwise("relu")), - arg("concat:main:pointwise0", {x, y}, single_pointwise("add")), - arg("concat:main:pointwise1", {x, y}, single_pointwise("sub"))); + auto fused_concat = add_pointwise_concat_pointwise( + p2, + 1, + arg("main:pointwise2:concat", {}, single_pointwise("relu")), + arg("concat:main:pointwise0", {x, y}, single_pointwise("add")), + arg("concat:main:pointwise1", {x, y}, single_pointwise("sub"))); mm->add_return({fused_concat}); } EXPECT(p1 == p2); } -TEST_CASE(partial_pointwise_concat) +TEST_CASE(partial_pointwise_concat_pointwise) { migraphx::shape s1{migraphx::shape::float_type, {1, 4, 8, 8}}; migraphx::shape s2{migraphx::shape::float_type, {1, 4, 16, 16}}; @@ -138,12 +183,12 @@ TEST_CASE(partial_pointwise_concat) auto z = mm->add_parameter("z", s2); auto pooling = mm->add_instruction( migraphx::make_op("pooling", {{"lengths", {2, 2}}, {"stride", {2, 2}}}), z); - auto fused_concat = - add_concat(p2, - 1, - arg("main:pointwise2:concat", {}, single_pointwise("relu")), - arg("concat:main:pointwise0", {x, y}, single_pointwise("add")), - arg("concat:identity0", {pooling}, single_pointwise("identity"))); + auto fused_concat = add_pointwise_concat_pointwise( + p2, + 1, + arg("main:pointwise2:concat", {}, single_pointwise("relu")), + arg("concat:main:pointwise0", {x, y}, single_pointwise("add")), + arg("concat:identity0", {pooling}, single_pointwise("identity"))); mm->add_return({fused_concat}); } EXPECT(p1 == p2); @@ -171,12 +216,12 @@ TEST_CASE(pointwise_concat_fusion) auto x = mm->add_parameter("x", s1); auto y = mm->add_parameter("y", s2); auto yc = mm->add_instruction(migraphx::make_op("contiguous"), y); - auto fused_concat = - add_concat(p2, - 1, - arg("main:pointwise2:concat", {}, single_pointwise("relu")), - arg("concat:main:pointwise0", {x}, single_pointwise("sigmoid")), - arg("concat:identity1", {yc}, single_pointwise("identity"))); + auto fused_concat = add_pointwise_concat_pointwise( + p2, + 1, + arg("main:pointwise2:concat", {}, single_pointwise("relu")), + arg("concat:main:pointwise0", {x}, single_pointwise("sigmoid")), + arg("concat:identity1", {yc}, single_pointwise("identity"))); mm->add_return({fused_concat}); } EXPECT(p1 == p2); diff --git a/test/verify/test_pooling_add_concat_relu.cpp b/test/verify/test_pooling_add_concat_relu.cpp index 4a000c10871..40d1df6d552 100644 --- a/test/verify/test_pooling_add_concat_relu.cpp +++ b/test/verify/test_pooling_add_concat_relu.cpp @@ -43,8 +43,8 @@ struct test_pooling_add_concat_relu : verify_programadd_instruction(migraphx::make_op("add"), x, y); auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), add, pooling); - auto relu = mm->add_instruction(migraphx::make_op("relu"), concat); - mm->add_return({relu}); + // auto relu = mm->add_instruction(migraphx::make_op("relu"), concat); + mm->add_return({concat}); return p; } }; From ca22cc77deb77f33a0a4f5a33a9112f1a250d938 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 16 Feb 2024 14:52:34 +0000 Subject: [PATCH 02/15] tests working --- src/fuse_concat.cpp | 16 +++--- src/targets/gpu/compile_gen.cpp | 2 + test/fuse_concat.cpp | 55 +++++++------------- test/include/pointwise.hpp | 5 ++ test/verify/test_pooling_add_concat_relu.cpp | 43 +++++++++------ 5 files changed, 63 insertions(+), 58 deletions(-) diff --git a/src/fuse_concat.cpp b/src/fuse_concat.cpp index 4af8f30bfc5..e4f82d52f29 100644 --- a/src/fuse_concat.cpp +++ b/src/fuse_concat.cpp @@ -66,7 +66,7 @@ struct fused_concat module_ref post_mod = mods.back(); // post_mod has one input argument that is result of concat and will get generated from // pre-mods internally. Therefore deduct 1 from post_mod params while asserting. - // assert(input_iter + post_mod->get_parameter_names().size() - 1 == inputs.end()); + assert(input_iter + post_mod->get_parameter_names().size() - 1 == inputs.end()); auto type = std::prev(post_mod->end())->get_shape().type(); const auto& first_shape_lens = concat_inputs.front().lens(); auto mismatch_it = @@ -134,10 +134,13 @@ struct find_concat_pointwise mpm.create_module("concat:identity" + std::to_string(counter++)); auto x = pm->add_parameter("x0", shape{input->get_shape().type()}); - auto id = pm->add_instruction(make_op("identity"), x); - pm->add_return({id}); + pm->add_return({x}); return pm; }); + auto* post_pm = mpm.create_module("noop:concat" + std::to_string(counter++)); + auto x = post_pm->add_parameter("!x0", shape{concat_ins->get_shape().type()}); + post_pm->add_return({x}); + module_inputs.push_back(post_pm); mpm.get_module().replace_instruction( concat_ins, make_op("fused_concat", concat_ins->normalized_operator().to_value()), @@ -186,12 +189,9 @@ struct find_pointwise_concat_pointwise auto* pm = input->module_inputs().front(); return mpm.create_module("concat:" + pm->name(), *pm); } - auto* pm = - mpm.create_module("concat:identity" + std::to_string(counter++)); - + auto* pm = mpm.create_module("concat:noop" + std::to_string(counter++)); auto x = pm->add_parameter("x0", shape{input->get_shape().type()}); - auto id = pm->add_instruction(make_op("identity"), x); - pm->add_return({id}); + pm->add_return({x}); return pm; }); diff --git a/src/targets/gpu/compile_gen.cpp b/src/targets/gpu/compile_gen.cpp index cf3afed9497..2469c6571a8 100644 --- a/src/targets/gpu/compile_gen.cpp +++ b/src/targets/gpu/compile_gen.cpp @@ -388,6 +388,8 @@ std::string generate_name_from_ops(const module& m, const std::string& postname) auto op_names = get_op_names(m); if(not postname.empty()) op_names.push_back(postname); + if(op_names.empty()) + return "noop"; return join_strings(op_names, "_"); } diff --git a/test/fuse_concat.cpp b/test/fuse_concat.cpp index ba8b072d89f..2ea0d24d05e 100644 --- a/test/fuse_concat.cpp +++ b/test/fuse_concat.cpp @@ -51,25 +51,9 @@ concat_arg arg(std::string name, std::vector input return {std::move(name), std::move(inputs), std::move(f)}; } -template -migraphx::instruction_ref add_pointwise_concat(migraphx::program& p, std::size_t axis, Args... args) -{ - std::vector module_inputs; - std::vector ins_inputs; - migraphx::each_args( - [&](auto arg) { - module_inputs.push_back(create_pointwise_module(p, arg.name, arg.inputs, arg.f)); - ins_inputs.insert(ins_inputs.end(), arg.inputs.begin(), arg.inputs.end()); - }, - args...); - auto* mm = p.get_main_module(); - return mm->add_instruction( - migraphx::make_op("fused_concat", {{"axis", axis}}), ins_inputs, module_inputs); -} - template migraphx::instruction_ref -add_pointwise_concat_pointwise(migraphx::program& p, std::size_t axis, Arg post_arg, Args... args) +add_pointwise_concat(migraphx::program& p, std::size_t axis, Arg post_arg, Args... args) { std::vector module_inputs; std::vector ins_inputs; @@ -119,6 +103,7 @@ TEST_CASE(simple_concat_pointwise) auto fused_concat = add_pointwise_concat(p2, 1, + arg("noop:concat0", {}, noop_pointwise()), arg("concat:main:pointwise0", {x, y}, single_pointwise("add")), arg("concat:main:pointwise1", {x, y}, single_pointwise("sub"))); mm->add_return({fused_concat}); @@ -146,12 +131,12 @@ TEST_CASE(simple_pointwise_concat_pointwise) auto* mm = p2.get_main_module(); auto x = mm->add_parameter("x", s); auto y = mm->add_parameter("y", s); - auto fused_concat = add_pointwise_concat_pointwise( - p2, - 1, - arg("main:pointwise2:concat", {}, single_pointwise("relu")), - arg("concat:main:pointwise0", {x, y}, single_pointwise("add")), - arg("concat:main:pointwise1", {x, y}, single_pointwise("sub"))); + auto fused_concat = + add_pointwise_concat(p2, + 1, + arg("main:pointwise2:concat", {}, single_pointwise("relu")), + arg("concat:main:pointwise0", {x, y}, single_pointwise("add")), + arg("concat:main:pointwise1", {x, y}, single_pointwise("sub"))); mm->add_return({fused_concat}); } EXPECT(p1 == p2); @@ -183,12 +168,12 @@ TEST_CASE(partial_pointwise_concat_pointwise) auto z = mm->add_parameter("z", s2); auto pooling = mm->add_instruction( migraphx::make_op("pooling", {{"lengths", {2, 2}}, {"stride", {2, 2}}}), z); - auto fused_concat = add_pointwise_concat_pointwise( - p2, - 1, - arg("main:pointwise2:concat", {}, single_pointwise("relu")), - arg("concat:main:pointwise0", {x, y}, single_pointwise("add")), - arg("concat:identity0", {pooling}, single_pointwise("identity"))); + auto fused_concat = + add_pointwise_concat(p2, + 1, + arg("main:pointwise2:concat", {}, single_pointwise("relu")), + arg("concat:main:pointwise0", {x, y}, single_pointwise("add")), + arg("concat:noop0", {pooling}, noop_pointwise())); mm->add_return({fused_concat}); } EXPECT(p1 == p2); @@ -216,12 +201,12 @@ TEST_CASE(pointwise_concat_fusion) auto x = mm->add_parameter("x", s1); auto y = mm->add_parameter("y", s2); auto yc = mm->add_instruction(migraphx::make_op("contiguous"), y); - auto fused_concat = add_pointwise_concat_pointwise( - p2, - 1, - arg("main:pointwise2:concat", {}, single_pointwise("relu")), - arg("concat:main:pointwise0", {x}, single_pointwise("sigmoid")), - arg("concat:identity1", {yc}, single_pointwise("identity"))); + auto fused_concat = + add_pointwise_concat(p2, + 1, + arg("main:pointwise2:concat", {}, single_pointwise("relu")), + arg("concat:main:pointwise0", {x}, single_pointwise("sigmoid")), + arg("concat:noop1", {yc}, noop_pointwise())); mm->add_return({fused_concat}); } EXPECT(p1 == p2); diff --git a/test/include/pointwise.hpp b/test/include/pointwise.hpp index 1f5282277bc..d101767859d 100644 --- a/test/include/pointwise.hpp +++ b/test/include/pointwise.hpp @@ -67,6 +67,11 @@ migraphx::instruction_ref add_pointwise(migraphx::program& p, return add_pointwise(p, p.get_main_module(), name, inputs, f); } +inline auto noop_pointwise() +{ + return [=](auto*, const auto& inputs) { return inputs; }; +} + inline auto single_pointwise(const std::string& name) { return [=](auto* pm, const auto& inputs) { diff --git a/test/verify/test_pooling_add_concat_relu.cpp b/test/verify/test_pooling_add_concat_relu.cpp index 40d1df6d552..fe7ab91363a 100644 --- a/test/verify/test_pooling_add_concat_relu.cpp +++ b/test/verify/test_pooling_add_concat_relu.cpp @@ -28,23 +28,36 @@ #include #include -struct test_pooling_add_concat_relu : verify_program +migraphx::program create_concat_fusion_program(bool post_pointwise) { - migraphx::program create_program() const + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s1{migraphx::shape::float_type, {1, 4, 8, 8}}; + migraphx::shape s2{migraphx::shape::float_type, {1, 4, 16, 16}}; + auto x = mm->add_parameter("x", s1); + auto y = mm->add_parameter("y", s1); + auto z = mm->add_parameter("z", s2); + auto pooling = mm->add_instruction( + migraphx::make_op("pooling", {{"lengths", {2, 2}}, {"stride", {2, 2}}}), z); + auto add = mm->add_instruction(migraphx::make_op("add"), x, y); + auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), add, pooling); + if(post_pointwise) + { + auto relu = mm->add_instruction(migraphx::make_op("relu"), concat); + mm->add_return({relu}); + } + else { - migraphx::program p; - auto* mm = p.get_main_module(); - migraphx::shape s1{migraphx::shape::float_type, {1, 4, 8, 8}}; - migraphx::shape s2{migraphx::shape::float_type, {1, 4, 16, 16}}; - auto x = mm->add_parameter("x", s1); - auto y = mm->add_parameter("y", s1); - auto z = mm->add_parameter("z", s2); - auto pooling = mm->add_instruction( - migraphx::make_op("pooling", {{"lengths", {2, 2}}, {"stride", {2, 2}}}), z); - auto add = mm->add_instruction(migraphx::make_op("add"), x, y); - auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), add, pooling); - // auto relu = mm->add_instruction(migraphx::make_op("relu"), concat); mm->add_return({concat}); - return p; } + return p; +} +struct test_pooling_add_concat_relu : verify_program +{ + migraphx::program create_program() const { return create_concat_fusion_program(true); } +}; + +struct test_pooling_add_concat : verify_program +{ + migraphx::program create_program() const { return create_concat_fusion_program(false); } }; From c8f66fd0ef3e5499c8cf34cd626f5ae01ae26495 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 16 Feb 2024 14:58:50 +0000 Subject: [PATCH 03/15] add used_once --- src/fuse_concat.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/fuse_concat.cpp b/src/fuse_concat.cpp index e4f82d52f29..966b5cb0a59 100644 --- a/src/fuse_concat.cpp +++ b/src/fuse_concat.cpp @@ -103,6 +103,7 @@ struct find_concat_pointwise auto matcher() const { return match::name("concat")( + match::used_once(), match::any_of[match::inputs()](match::name("pointwise")(match::used_once()))); } @@ -148,6 +149,7 @@ struct find_concat_pointwise module_inputs); } }; + struct find_pointwise_concat_pointwise { auto matcher() const From b8eb2e94e46c8a497cb49c78a47eb2e602b3c29d Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 16 Feb 2024 15:03:06 +0000 Subject: [PATCH 04/15] use noop --- src/fuse_concat.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/fuse_concat.cpp b/src/fuse_concat.cpp index 966b5cb0a59..b164ae54278 100644 --- a/src/fuse_concat.cpp +++ b/src/fuse_concat.cpp @@ -131,9 +131,7 @@ struct find_concat_pointwise auto* pm = input->module_inputs().front(); return mpm.create_module("concat:" + pm->name(), *pm); } - auto* pm = - mpm.create_module("concat:identity" + std::to_string(counter++)); - + auto* pm = mpm.create_module("concat:noop" + std::to_string(counter++)); auto x = pm->add_parameter("x0", shape{input->get_shape().type()}); pm->add_return({x}); return pm; From 521b3b035584ad42d39cc6ba71aba6d38777dd2b Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 16 Feb 2024 15:09:49 +0000 Subject: [PATCH 05/15] add partial case --- test/fuse_concat.cpp | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/test/fuse_concat.cpp b/test/fuse_concat.cpp index 2ea0d24d05e..c97cf52ff22 100644 --- a/test/fuse_concat.cpp +++ b/test/fuse_concat.cpp @@ -142,6 +142,42 @@ TEST_CASE(simple_pointwise_concat_pointwise) EXPECT(p1 == p2); } +TEST_CASE(partial_pointwise_concat) +{ + migraphx::shape s1{migraphx::shape::float_type, {1, 4, 8, 8}}; + migraphx::shape s2{migraphx::shape::float_type, {1, 4, 16, 16}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s1); + auto y = mm->add_parameter("y", s1); + auto z = mm->add_parameter("z", s2); + auto pooling = mm->add_instruction( + migraphx::make_op("pooling", {{"lengths", {2, 2}}, {"stride", {2, 2}}}), z); + auto add = add_pointwise(p1, "main:pointwise0", {x, y}, single_pointwise("add")); + auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), add, pooling); + mm->add_return({concat}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s1); + auto y = mm->add_parameter("y", s1); + auto z = mm->add_parameter("z", s2); + auto pooling = mm->add_instruction( + migraphx::make_op("pooling", {{"lengths", {2, 2}}, {"stride", {2, 2}}}), z); + auto fused_concat = + add_pointwise_concat(p2, + 1, + arg("noop:concat2", {}, noop_pointwise()), + arg("concat:main:pointwise0", {x, y}, single_pointwise("add")), + arg("concat:noop1", {pooling}, noop_pointwise())); + mm->add_return({fused_concat}); + } + EXPECT(p1 == p2); +} + TEST_CASE(partial_pointwise_concat_pointwise) { migraphx::shape s1{migraphx::shape::float_type, {1, 4, 8, 8}}; From 4b5336eaf97948897a57310007b8e723edbdec70 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 16 Feb 2024 15:15:05 +0000 Subject: [PATCH 06/15] licensing fix --- test/include/pointwise.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/include/pointwise.hpp b/test/include/pointwise.hpp index d101767859d..b7d4f38a217 100644 --- a/test/include/pointwise.hpp +++ b/test/include/pointwise.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal From 62b17323bb9f430429cf5171e957f0b64e5f0448 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 16 Feb 2024 15:29:52 +0000 Subject: [PATCH 07/15] add case of multiple uses --- src/fuse_concat.cpp | 6 +++--- test/fuse_concat.cpp | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 3 deletions(-) diff --git a/src/fuse_concat.cpp b/src/fuse_concat.cpp index b164ae54278..96698542676 100644 --- a/src/fuse_concat.cpp +++ b/src/fuse_concat.cpp @@ -114,7 +114,7 @@ struct find_concat_pointwise std::vector inputs; for(auto input : concat_ins->inputs()) { - if(input->name() == "pointwise") + if(input->name() == "pointwise" and input->outputs().size() == 1) inputs.insert(inputs.end(), input->inputs().begin(), input->inputs().end()); else inputs.push_back(input); @@ -168,7 +168,7 @@ struct find_pointwise_concat_pointwise std::vector inputs; for(auto input : concat_ins->inputs()) { - if(input->name() == "pointwise") + if(input->name() == "pointwise" and input->outputs().size() == 1) inputs.insert(inputs.end(), input->inputs().begin(), input->inputs().end()); else inputs.push_back(input); @@ -184,7 +184,7 @@ struct find_pointwise_concat_pointwise concat_ins->inputs().end(), std::back_inserter(module_inputs), [&](instruction_ref input) { - if(input->name() == "pointwise") + if(input->name() == "pointwise" and input->outputs().size() == 1) { auto* pm = input->module_inputs().front(); return mpm.create_module("concat:" + pm->name(), *pm); diff --git a/test/fuse_concat.cpp b/test/fuse_concat.cpp index c97cf52ff22..e027a2169ba 100644 --- a/test/fuse_concat.cpp +++ b/test/fuse_concat.cpp @@ -142,6 +142,46 @@ TEST_CASE(simple_pointwise_concat_pointwise) EXPECT(p1 == p2); } +TEST_CASE(multiple_use_pointwise_concat_pointwise) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto add = add_pointwise(p1, "main:pointwise0", {x, y}, single_pointwise("add")); + auto sub = add_pointwise(p1, "main:pointwise1", {x, y}, single_pointwise("sub")); + auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), add, sub); + auto relu = add_pointwise(p1, "main:pointwise2", {concat}, single_pointwise("relu")); + auto slice = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {4}}}), relu); + auto mul = add_pointwise(p1, "main:pointwise3", {slice, sub}, single_pointwise({"mul"})); + mm->add_return({mul}); + } + run_pass(p1); + p1.debug_print(); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto sub = add_pointwise(p2, "main:pointwise1", {x, y}, single_pointwise("sub")); + auto fused_concat = + add_pointwise_concat(p2, + 1, + arg("main:pointwise2:concat", {}, single_pointwise("relu")), + arg("concat:main:pointwise0", {x, y}, single_pointwise("add")), + arg("concat:noop0", {sub}, noop_pointwise())); + auto slice = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {4}}}), + fused_concat); + auto mul = add_pointwise(p2, "main:pointwise3", {slice, sub}, single_pointwise({"mul"})); + mm->add_return({mul}); + } + EXPECT(p1 == p2); +} + TEST_CASE(partial_pointwise_concat) { migraphx::shape s1{migraphx::shape::float_type, {1, 4, 8, 8}}; From b73ae2f0ff8f3ebed2bd3d9b3b5a70326b1031fb Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 16 Feb 2024 15:30:03 +0000 Subject: [PATCH 08/15] remove prints --- test/fuse_concat.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/test/fuse_concat.cpp b/test/fuse_concat.cpp index e027a2169ba..8d9a2fb0b6f 100644 --- a/test/fuse_concat.cpp +++ b/test/fuse_concat.cpp @@ -160,7 +160,6 @@ TEST_CASE(multiple_use_pointwise_concat_pointwise) mm->add_return({mul}); } run_pass(p1); - p1.debug_print(); migraphx::program p2; { auto* mm = p2.get_main_module(); From 025965354dcbb0a566b3fd081c80ab49472a038e Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 16 Feb 2024 15:51:22 +0000 Subject: [PATCH 09/15] add multiple use verify test --- test/verify/test_pooling_add_concat_relu.cpp | 21 ++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/test/verify/test_pooling_add_concat_relu.cpp b/test/verify/test_pooling_add_concat_relu.cpp index fe7ab91363a..ab38622385c 100644 --- a/test/verify/test_pooling_add_concat_relu.cpp +++ b/test/verify/test_pooling_add_concat_relu.cpp @@ -61,3 +61,24 @@ struct test_pooling_add_concat : verify_program { migraphx::program create_program() const { return create_concat_fusion_program(false); } }; + +struct test_add_sub_concat_slice_mul : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s1{migraphx::shape::float_type, {1, 4, 8, 8}}; + auto x = mm->add_parameter("x", s1); + auto y = mm->add_parameter("y", s1); + auto add = mm->add_instruction(migraphx::make_op("add"), x, y); + auto sub = mm->add_instruction(migraphx::make_op("sub"), x, y); + auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), add, sub); + auto relu = mm->add_instruction(migraphx::make_op("relu"), concat); + auto slice = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {2}}, {"ends", {6}}}), relu); + auto mul = mm->add_instruction(migraphx::make_op("mul"), sub, slice); + mm->add_return({mul}); + return p; + } +}; From 0f0bf96be0cfc9527c1104d0332f3b6cd2b80c8e Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 16 Feb 2024 15:58:34 +0000 Subject: [PATCH 10/15] formatting --- src/fuse_concat.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fuse_concat.cpp b/src/fuse_concat.cpp index 96698542676..098a16d1f1d 100644 --- a/src/fuse_concat.cpp +++ b/src/fuse_concat.cpp @@ -132,7 +132,7 @@ struct find_concat_pointwise return mpm.create_module("concat:" + pm->name(), *pm); } auto* pm = mpm.create_module("concat:noop" + std::to_string(counter++)); - auto x = pm->add_parameter("x0", shape{input->get_shape().type()}); + auto x = pm->add_parameter("x0", shape{input->get_shape().type()}); pm->add_return({x}); return pm; }); From 675f49421b8e0dca80ef1269ad62bd3d6ecf8674 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 16 Feb 2024 16:09:34 +0000 Subject: [PATCH 11/15] Move test to separate file --- test/verify/test_add_sub_concat_slice_mul.cpp | 51 +++++++++++++++++++ test/verify/test_pooling_add_concat_relu.cpp | 21 -------- 2 files changed, 51 insertions(+), 21 deletions(-) create mode 100644 test/verify/test_add_sub_concat_slice_mul.cpp diff --git a/test/verify/test_add_sub_concat_slice_mul.cpp b/test/verify/test_add_sub_concat_slice_mul.cpp new file mode 100644 index 00000000000..245a65fa1ff --- /dev/null +++ b/test/verify/test_add_sub_concat_slice_mul.cpp @@ -0,0 +1,51 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include +#include + +// test for fuse_concat pass where pointwise input to concat has multiple outputs +struct test_add_sub_concat_slice_mul : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s1{migraphx::shape::float_type, {1, 4, 8, 8}}; + auto x = mm->add_parameter("x", s1); + auto y = mm->add_parameter("y", s1); + auto add = mm->add_instruction(migraphx::make_op("add"), x, y); + auto sub = mm->add_instruction(migraphx::make_op("sub"), x, y); + auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), add, sub); + auto relu = mm->add_instruction(migraphx::make_op("relu"), concat); + auto slice = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {2}}, {"ends", {6}}}), relu); + auto mul = mm->add_instruction(migraphx::make_op("mul"), sub, slice); + mm->add_return({mul}); + return p; + } +}; diff --git a/test/verify/test_pooling_add_concat_relu.cpp b/test/verify/test_pooling_add_concat_relu.cpp index ab38622385c..fe7ab91363a 100644 --- a/test/verify/test_pooling_add_concat_relu.cpp +++ b/test/verify/test_pooling_add_concat_relu.cpp @@ -61,24 +61,3 @@ struct test_pooling_add_concat : verify_program { migraphx::program create_program() const { return create_concat_fusion_program(false); } }; - -struct test_add_sub_concat_slice_mul : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto* mm = p.get_main_module(); - migraphx::shape s1{migraphx::shape::float_type, {1, 4, 8, 8}}; - auto x = mm->add_parameter("x", s1); - auto y = mm->add_parameter("y", s1); - auto add = mm->add_instruction(migraphx::make_op("add"), x, y); - auto sub = mm->add_instruction(migraphx::make_op("sub"), x, y); - auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), add, sub); - auto relu = mm->add_instruction(migraphx::make_op("relu"), concat); - auto slice = mm->add_instruction( - migraphx::make_op("slice", {{"axes", {1}}, {"starts", {2}}, {"ends", {6}}}), relu); - auto mul = mm->add_instruction(migraphx::make_op("mul"), sub, slice); - mm->add_return({mul}); - return p; - } -}; From 8e31d6a78bc9d373a151ff67715ad14ec4977656 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 23 Feb 2024 15:30:23 +0000 Subject: [PATCH 12/15] fix counter issue because of static --- src/fuse_concat.cpp | 12 +++--- test/fuse_concat.cpp | 100 +++++++++++++++++++++---------------------- 2 files changed, 57 insertions(+), 55 deletions(-) diff --git a/src/fuse_concat.cpp b/src/fuse_concat.cpp index 098a16d1f1d..577291ea6c1 100644 --- a/src/fuse_concat.cpp +++ b/src/fuse_concat.cpp @@ -121,7 +121,7 @@ struct find_concat_pointwise } std::vector module_inputs; - static unsigned int counter = 0; + static unsigned int counter_cp = 0; std::transform(concat_ins->inputs().begin(), concat_ins->inputs().end(), std::back_inserter(module_inputs), @@ -131,12 +131,13 @@ struct find_concat_pointwise auto* pm = input->module_inputs().front(); return mpm.create_module("concat:" + pm->name(), *pm); } - auto* pm = mpm.create_module("concat:noop" + std::to_string(counter++)); + auto* pm = + mpm.create_module("concat:noop" + std::to_string(counter_cp++)); auto x = pm->add_parameter("x0", shape{input->get_shape().type()}); pm->add_return({x}); return pm; }); - auto* post_pm = mpm.create_module("noop:concat" + std::to_string(counter++)); + auto* post_pm = mpm.create_module("noop:concat" + std::to_string(counter_cp++)); auto x = post_pm->add_parameter("!x0", shape{concat_ins->get_shape().type()}); post_pm->add_return({x}); module_inputs.push_back(post_pm); @@ -179,7 +180,7 @@ struct find_pointwise_concat_pointwise [&](auto input) { return input != concat_ins; }); std::vector module_inputs; - static unsigned int counter = 0; + static unsigned int counter_pcp = 0; std::transform(concat_ins->inputs().begin(), concat_ins->inputs().end(), std::back_inserter(module_inputs), @@ -189,7 +190,8 @@ struct find_pointwise_concat_pointwise auto* pm = input->module_inputs().front(); return mpm.create_module("concat:" + pm->name(), *pm); } - auto* pm = mpm.create_module("concat:noop" + std::to_string(counter++)); + auto* pm = + mpm.create_module("concat:noop" + std::to_string(counter_pcp++)); auto x = pm->add_parameter("x0", shape{input->get_shape().type()}); pm->add_return({x}); return pm; diff --git a/test/fuse_concat.cpp b/test/fuse_concat.cpp index 8d9a2fb0b6f..e9b564008e8 100644 --- a/test/fuse_concat.cpp +++ b/test/fuse_concat.cpp @@ -111,38 +111,43 @@ TEST_CASE(simple_concat_pointwise) EXPECT(p1 == p2); } -TEST_CASE(simple_pointwise_concat_pointwise) +TEST_CASE(partial_pointwise_concat) { - migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::shape s1{migraphx::shape::float_type, {1, 4, 8, 8}}; + migraphx::shape s2{migraphx::shape::float_type, {1, 4, 16, 16}}; migraphx::program p1; { - auto* mm = p1.get_main_module(); - auto x = mm->add_parameter("x", s); - auto y = mm->add_parameter("y", s); + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s1); + auto y = mm->add_parameter("y", s1); + auto z = mm->add_parameter("z", s2); + auto pooling = mm->add_instruction( + migraphx::make_op("pooling", {{"lengths", {2, 2}}, {"stride", {2, 2}}}), z); auto add = add_pointwise(p1, "main:pointwise0", {x, y}, single_pointwise("add")); - auto sub = add_pointwise(p1, "main:pointwise1", {x, y}, single_pointwise("sub")); - auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), add, sub); - auto relu = add_pointwise(p1, "main:pointwise2", {concat}, single_pointwise("relu")); - mm->add_return({relu}); + auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), add, pooling); + mm->add_return({concat}); } run_pass(p1); migraphx::program p2; { - auto* mm = p2.get_main_module(); - auto x = mm->add_parameter("x", s); - auto y = mm->add_parameter("y", s); + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s1); + auto y = mm->add_parameter("y", s1); + auto z = mm->add_parameter("z", s2); + auto pooling = mm->add_instruction( + migraphx::make_op("pooling", {{"lengths", {2, 2}}, {"stride", {2, 2}}}), z); auto fused_concat = add_pointwise_concat(p2, 1, - arg("main:pointwise2:concat", {}, single_pointwise("relu")), + arg("noop:concat2", {}, noop_pointwise()), arg("concat:main:pointwise0", {x, y}, single_pointwise("add")), - arg("concat:main:pointwise1", {x, y}, single_pointwise("sub"))); + arg("concat:noop1", {pooling}, noop_pointwise())); mm->add_return({fused_concat}); } EXPECT(p1 == p2); } -TEST_CASE(multiple_use_pointwise_concat_pointwise) +TEST_CASE(simple_pointwise_concat_pointwise) { migraphx::shape s{migraphx::shape::float_type, {2, 3}}; migraphx::program p1; @@ -154,10 +159,7 @@ TEST_CASE(multiple_use_pointwise_concat_pointwise) auto sub = add_pointwise(p1, "main:pointwise1", {x, y}, single_pointwise("sub")); auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), add, sub); auto relu = add_pointwise(p1, "main:pointwise2", {concat}, single_pointwise("relu")); - auto slice = mm->add_instruction( - migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {4}}}), relu); - auto mul = add_pointwise(p1, "main:pointwise3", {slice, sub}, single_pointwise({"mul"})); - mm->add_return({mul}); + mm->add_return({relu}); } run_pass(p1); migraphx::program p2; @@ -165,23 +167,18 @@ TEST_CASE(multiple_use_pointwise_concat_pointwise) auto* mm = p2.get_main_module(); auto x = mm->add_parameter("x", s); auto y = mm->add_parameter("y", s); - auto sub = add_pointwise(p2, "main:pointwise1", {x, y}, single_pointwise("sub")); auto fused_concat = add_pointwise_concat(p2, 1, arg("main:pointwise2:concat", {}, single_pointwise("relu")), arg("concat:main:pointwise0", {x, y}, single_pointwise("add")), - arg("concat:noop0", {sub}, noop_pointwise())); - auto slice = mm->add_instruction( - migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {4}}}), - fused_concat); - auto mul = add_pointwise(p2, "main:pointwise3", {slice, sub}, single_pointwise({"mul"})); - mm->add_return({mul}); + arg("concat:main:pointwise1", {x, y}, single_pointwise("sub"))); + mm->add_return({fused_concat}); } EXPECT(p1 == p2); } -TEST_CASE(partial_pointwise_concat) +TEST_CASE(partial_pointwise_concat_pointwise) { migraphx::shape s1{migraphx::shape::float_type, {1, 4, 8, 8}}; migraphx::shape s2{migraphx::shape::float_type, {1, 4, 16, 16}}; @@ -195,7 +192,8 @@ TEST_CASE(partial_pointwise_concat) migraphx::make_op("pooling", {{"lengths", {2, 2}}, {"stride", {2, 2}}}), z); auto add = add_pointwise(p1, "main:pointwise0", {x, y}, single_pointwise("add")); auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), add, pooling); - mm->add_return({concat}); + auto relu = add_pointwise(p1, "main:pointwise2", {concat}, single_pointwise("relu")); + mm->add_return({relu}); } run_pass(p1); migraphx::program p2; @@ -209,47 +207,49 @@ TEST_CASE(partial_pointwise_concat) auto fused_concat = add_pointwise_concat(p2, 1, - arg("noop:concat2", {}, noop_pointwise()), + arg("main:pointwise2:concat", {}, single_pointwise("relu")), arg("concat:main:pointwise0", {x, y}, single_pointwise("add")), - arg("concat:noop1", {pooling}, noop_pointwise())); + arg("concat:noop0", {pooling}, noop_pointwise())); mm->add_return({fused_concat}); } EXPECT(p1 == p2); } -TEST_CASE(partial_pointwise_concat_pointwise) +TEST_CASE(multiple_use_pointwise_concat_pointwise) { - migraphx::shape s1{migraphx::shape::float_type, {1, 4, 8, 8}}; - migraphx::shape s2{migraphx::shape::float_type, {1, 4, 16, 16}}; + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; migraphx::program p1; { - auto* mm = p1.get_main_module(); - auto x = mm->add_parameter("x", s1); - auto y = mm->add_parameter("y", s1); - auto z = mm->add_parameter("z", s2); - auto pooling = mm->add_instruction( - migraphx::make_op("pooling", {{"lengths", {2, 2}}, {"stride", {2, 2}}}), z); + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); auto add = add_pointwise(p1, "main:pointwise0", {x, y}, single_pointwise("add")); - auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), add, pooling); + auto sub = add_pointwise(p1, "main:pointwise1", {x, y}, single_pointwise("sub")); + auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), add, sub); auto relu = add_pointwise(p1, "main:pointwise2", {concat}, single_pointwise("relu")); - mm->add_return({relu}); + auto slice = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {4}}}), relu); + auto mul = add_pointwise(p1, "main:pointwise3", {slice, sub}, single_pointwise({"mul"})); + mm->add_return({mul}); } run_pass(p1); migraphx::program p2; { - auto* mm = p2.get_main_module(); - auto x = mm->add_parameter("x", s1); - auto y = mm->add_parameter("y", s1); - auto z = mm->add_parameter("z", s2); - auto pooling = mm->add_instruction( - migraphx::make_op("pooling", {{"lengths", {2, 2}}, {"stride", {2, 2}}}), z); + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto sub = add_pointwise(p2, "main:pointwise1", {x, y}, single_pointwise("sub")); auto fused_concat = add_pointwise_concat(p2, 1, arg("main:pointwise2:concat", {}, single_pointwise("relu")), arg("concat:main:pointwise0", {x, y}, single_pointwise("add")), - arg("concat:noop0", {pooling}, noop_pointwise())); - mm->add_return({fused_concat}); + arg("concat:noop1", {sub}, noop_pointwise())); + auto slice = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {4}}}), + fused_concat); + auto mul = add_pointwise(p2, "main:pointwise3", {slice, sub}, single_pointwise({"mul"})); + mm->add_return({mul}); } EXPECT(p1 == p2); } @@ -281,7 +281,7 @@ TEST_CASE(pointwise_concat_fusion) 1, arg("main:pointwise2:concat", {}, single_pointwise("relu")), arg("concat:main:pointwise0", {x}, single_pointwise("sigmoid")), - arg("concat:noop1", {yc}, noop_pointwise())); + arg("concat:noop2", {yc}, noop_pointwise())); mm->add_return({fused_concat}); } EXPECT(p1 == p2); From f2f7d91d181ae29d1ec49222b9e57aaf72168cc7 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 23 Feb 2024 15:42:30 +0000 Subject: [PATCH 13/15] fix static counter issue --- src/fuse_concat.cpp | 19 +++++++++++-------- test/fuse_concat.cpp | 6 +++--- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/src/fuse_concat.cpp b/src/fuse_concat.cpp index 577291ea6c1..77909d31af6 100644 --- a/src/fuse_concat.cpp +++ b/src/fuse_concat.cpp @@ -35,6 +35,12 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { +unsigned int get_noop_counter() +{ + static unsigned int counter = 0; + return counter++; +} + struct fused_concat { int64_t axis = 0; @@ -119,9 +125,7 @@ struct find_concat_pointwise else inputs.push_back(input); } - std::vector module_inputs; - static unsigned int counter_cp = 0; std::transform(concat_ins->inputs().begin(), concat_ins->inputs().end(), std::back_inserter(module_inputs), @@ -131,13 +135,13 @@ struct find_concat_pointwise auto* pm = input->module_inputs().front(); return mpm.create_module("concat:" + pm->name(), *pm); } - auto* pm = - mpm.create_module("concat:noop" + std::to_string(counter_cp++)); + auto* pm = mpm.create_module("concat:noop" + + std::to_string(get_noop_counter())); auto x = pm->add_parameter("x0", shape{input->get_shape().type()}); pm->add_return({x}); return pm; }); - auto* post_pm = mpm.create_module("noop:concat" + std::to_string(counter_cp++)); + auto* post_pm = mpm.create_module("noop:concat" + std::to_string(get_noop_counter())); auto x = post_pm->add_parameter("!x0", shape{concat_ins->get_shape().type()}); post_pm->add_return({x}); module_inputs.push_back(post_pm); @@ -180,7 +184,6 @@ struct find_pointwise_concat_pointwise [&](auto input) { return input != concat_ins; }); std::vector module_inputs; - static unsigned int counter_pcp = 0; std::transform(concat_ins->inputs().begin(), concat_ins->inputs().end(), std::back_inserter(module_inputs), @@ -190,8 +193,8 @@ struct find_pointwise_concat_pointwise auto* pm = input->module_inputs().front(); return mpm.create_module("concat:" + pm->name(), *pm); } - auto* pm = - mpm.create_module("concat:noop" + std::to_string(counter_pcp++)); + auto* pm = mpm.create_module("concat:noop" + + std::to_string(get_noop_counter())); auto x = pm->add_parameter("x0", shape{input->get_shape().type()}); pm->add_return({x}); return pm; diff --git a/test/fuse_concat.cpp b/test/fuse_concat.cpp index e9b564008e8..b58f94b71d6 100644 --- a/test/fuse_concat.cpp +++ b/test/fuse_concat.cpp @@ -209,7 +209,7 @@ TEST_CASE(partial_pointwise_concat_pointwise) 1, arg("main:pointwise2:concat", {}, single_pointwise("relu")), arg("concat:main:pointwise0", {x, y}, single_pointwise("add")), - arg("concat:noop0", {pooling}, noop_pointwise())); + arg("concat:noop3", {pooling}, noop_pointwise())); mm->add_return({fused_concat}); } EXPECT(p1 == p2); @@ -244,7 +244,7 @@ TEST_CASE(multiple_use_pointwise_concat_pointwise) 1, arg("main:pointwise2:concat", {}, single_pointwise("relu")), arg("concat:main:pointwise0", {x, y}, single_pointwise("add")), - arg("concat:noop1", {sub}, noop_pointwise())); + arg("concat:noop4", {sub}, noop_pointwise())); auto slice = mm->add_instruction( migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {4}}}), fused_concat); @@ -281,7 +281,7 @@ TEST_CASE(pointwise_concat_fusion) 1, arg("main:pointwise2:concat", {}, single_pointwise("relu")), arg("concat:main:pointwise0", {x}, single_pointwise("sigmoid")), - arg("concat:noop2", {yc}, noop_pointwise())); + arg("concat:noop5", {yc}, noop_pointwise())); mm->add_return({fused_concat}); } EXPECT(p1 == p2); From 7e01371c1fbb79810dc8205e6725e4dd028c6c05 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Mon, 11 Mar 2024 18:39:07 +0000 Subject: [PATCH 14/15] skip fusion if number of no-ops are more than 1. --- src/fuse_concat.cpp | 10 ++++++++++ test/fuse_concat.cpp | 26 ++++++++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/src/fuse_concat.cpp b/src/fuse_concat.cpp index 77909d31af6..e31adb6e9d6 100644 --- a/src/fuse_concat.cpp +++ b/src/fuse_concat.cpp @@ -118,12 +118,22 @@ struct find_concat_pointwise auto concat_ins = r.result; std::vector inputs; + size_t num_noops = 0; for(auto input : concat_ins->inputs()) { if(input->name() == "pointwise" and input->outputs().size() == 1) + { inputs.insert(inputs.end(), input->inputs().begin(), input->inputs().end()); + } else + { + num_noops++; inputs.push_back(input); + } + } + if(num_noops > 1) + { + return; } std::vector module_inputs; std::transform(concat_ins->inputs().begin(), diff --git a/test/fuse_concat.cpp b/test/fuse_concat.cpp index b58f94b71d6..d8214e1a49f 100644 --- a/test/fuse_concat.cpp +++ b/test/fuse_concat.cpp @@ -147,6 +147,32 @@ TEST_CASE(partial_pointwise_concat) EXPECT(p1 == p2); } +TEST_CASE(skip_pointwise_concat) +{ + // number of no-ops are two in this case and therefore pointwise+concat fusion wouldn't be + // applicable. + migraphx::shape s1{migraphx::shape::float_type, {1, 4, 8, 8}}; + migraphx::shape s2{migraphx::shape::float_type, {1, 4, 16, 16}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto w = mm->add_parameter("w", s1); + auto x = mm->add_parameter("x", s1); + auto y = mm->add_parameter("y", s1); + auto z = mm->add_parameter("z", s2); + auto reduce_ins = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), w); + auto pooling = mm->add_instruction( + migraphx::make_op("pooling", {{"lengths", {2, 2}}, {"stride", {2, 2}}}), z); + auto add = add_pointwise(p1, "main:pointwise0", {x, y}, single_pointwise("add")); + auto concat = mm->add_instruction( + migraphx::make_op("concat", {{"axis", 1}}), reduce_ins, add, pooling); + mm->add_return({concat}); + } + migraphx::program p2 = p1; + run_pass(p1); + EXPECT(p1 == p2); +} + TEST_CASE(simple_pointwise_concat_pointwise) { migraphx::shape s{migraphx::shape::float_type, {2, 3}}; From cc8ef6bcede9749715d072653db6c6fa02c4bbe8 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Mon, 11 Mar 2024 18:43:37 +0000 Subject: [PATCH 15/15] change rule to allow more than 1 no-ops --- src/fuse_concat.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fuse_concat.cpp b/src/fuse_concat.cpp index e31adb6e9d6..7fcbf934503 100644 --- a/src/fuse_concat.cpp +++ b/src/fuse_concat.cpp @@ -131,7 +131,7 @@ struct find_concat_pointwise inputs.push_back(input); } } - if(num_noops > 1) + if(num_noops > std::max(size_t{1}, concat_ins->inputs().size() / 4)) { return; }