From be4658c59f3fbb29eee2232d770674b191a929c3 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Wed, 11 Oct 2023 06:57:33 -0500 Subject: [PATCH 1/2] Fix MLIR input fusion non-std shapes from squeeze, flatten and unsqueeze Currently, we see MLIR partition candidates recieving non-standard shape due to not fusing in squeeze, flatten and unsqueeze ops. These ops could be canonicalized to reshape without introducing additional ops as long as MLIR backend is concerned. --- src/targets/gpu/fuse_mlir.cpp | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index ef4158c49fb..7adcb384250 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -131,9 +131,16 @@ fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op) for(instruction_ref input : gemm_based_op->inputs()) { std::vector op_stream; - while(contains({"slice", "transpose", "contiguous", "reshape"}, input->name())) + while(contains( + {"slice", "transpose", "contiguous", "reshape", "squeeze", "flatten", "unsqueeze"}, + input->name())) { - op_stream.push_back(input->get_operator()); + operation op = input->get_operator(); + if(contains({"squeeze", "flatten", "unsqueeze"}, input->name())) + { + op = migraphx::make_op("reshape", {{"dims", input->get_shape().lens()}}); + } + op_stream.push_back(op); input = input->inputs().at(0); } top_inputs.push_back(input); From 6239ef4ce5354bf54de56460df6cd5f15a6af246 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Thu, 12 Oct 2023 05:49:28 -0500 Subject: [PATCH 2/2] * add tests --- test/verify/test_flatten_dot_relu.cpp | 46 ++++++++++++++++++++++++ test/verify/test_squeeze_conv_relu.cpp | 45 +++++++++++++++++++++++ test/verify/test_unsqueeze_conv_relu.cpp | 45 +++++++++++++++++++++++ 3 files changed, 136 insertions(+) create mode 100644 test/verify/test_flatten_dot_relu.cpp create mode 100644 test/verify/test_squeeze_conv_relu.cpp create mode 100644 test/verify/test_unsqueeze_conv_relu.cpp diff --git a/test/verify/test_flatten_dot_relu.cpp b/test/verify/test_flatten_dot_relu.cpp new file mode 100644 index 00000000000..c09ea6b9d2c --- /dev/null +++ b/test/verify/test_flatten_dot_relu.cpp @@ -0,0 +1,46 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_flatten_dot_relu : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto a = + mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 3, 5}}); + a = mm->add_instruction(migraphx::make_op("flatten", {{"axis", 3}}), a); + auto b = + mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {1, 5, 3, 3, 1}}); + b = mm->add_instruction(migraphx::make_op("flatten", {{"axis", 3}}), b); + auto dot = mm->add_instruction(migraphx::make_op("dot"), a, b); + mm->add_instruction(migraphx::make_op("relu"), dot); + return p; + } +}; diff --git a/test/verify/test_squeeze_conv_relu.cpp b/test/verify/test_squeeze_conv_relu.cpp new file mode 100644 index 00000000000..232a333e770 --- /dev/null +++ b/test/verify/test_squeeze_conv_relu.cpp @@ -0,0 +1,45 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_squeeze_conv_relu : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = + mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 1, 3, 3}}); + input = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), input); + auto weights = + mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights); + mm->add_instruction(migraphx::make_op("relu"), conv); + return p; + } +}; diff --git a/test/verify/test_unsqueeze_conv_relu.cpp b/test/verify/test_unsqueeze_conv_relu.cpp new file mode 100644 index 00000000000..eba2085c568 --- /dev/null +++ b/test/verify/test_unsqueeze_conv_relu.cpp @@ -0,0 +1,45 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_unsqueeze_conv_relu : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = + mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 3, 3}}); + auto weights = + mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {3, 3, 3}}); + weights = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), weights); + auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights); + mm->add_instruction(migraphx::make_op("relu"), conv); + return p; + } +};