Treat `step` as a reshape-like op and make non-packed MLIR inputs contiguous.

Review notes (text before the first `diff --git` line is ignored by patch tools):
  * Angle-bracket contents (template arguments and the <migraphx/...> include
    targets in the new test) were missing from the collapsed text and have been
    reconstructed from usage -- TODO confirm against the upstream commit.
  * Context-line indentation was rebuilt from the project's formatting style;
    apply with whitespace tolerance if a hunk is rejected.
  * Hunk counts/offsets were recomputed for the added comments and reserve();
    the `index` blob hashes are unchanged from the original patch and are stale.

diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp
index a519925f1b3..a651d6e2432 100644
--- a/src/simplify_reshapes.cpp
+++ b/src/simplify_reshapes.cpp
@@ -120,6 +120,7 @@ struct find_nop_reshapes
         reshapes.insert("multibroadcast");
         reshapes.insert("pad");
         reshapes.insert("slice");
+        reshapes.insert("step");
         reshapes.insert("transpose");
         reshapes.insert("reduce_mean");
         reshapes.insert("reduce_max");
diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp
index 09da4758418..db074c94b92 100644
--- a/src/targets/gpu/fuse_mlir.cpp
+++ b/src/targets/gpu/fuse_mlir.cpp
@@ -400,6 +400,26 @@ MIGRAPHX_PRED_MATCHER(mlir_pointwise, instruction_ref ins)
     });
 }
 
+// Builds the input list handed to an mlir_op. Each entry of `inputs` whose
+// shape is packed or broadcasted is forwarded unchanged; every other entry is
+// replaced by a `contiguous` instruction that is inserted directly after it in
+// the module owned by `mpm`. The result has one element per input, same order.
+std::vector<instruction_ref> mlir_contiguous(module_pass_manager& mpm,
+                                             const std::vector<instruction_ref>& inputs)
+{
+    std::vector<instruction_ref> result;
+    // Exactly one output per input, so allocate once instead of growing.
+    result.reserve(inputs.size());
+    std::transform(
+        inputs.begin(), inputs.end(), std::back_inserter(result), [&](instruction_ref input) {
+            if(input->get_shape().packed() or input->get_shape().broadcasted())
+                return input;
+            return mpm.get_module().insert_instruction(
+                std::next(input), make_op("contiguous"), input);
+        });
+    return result;
+}
+
 struct find_mlir_fused_ops
 {
     mlir_mode conv_mode = mlir_mode::none;
@@ -432,7 +452,7 @@ struct find_mlir_fused_ops
                      [&](auto input) { return input != gemm_based_op; });
         inputs.insert(inputs.end(), top_inputs.begin(), top_inputs.end());
         mpm.get_module().replace_instruction(
-            ins, mlir_op{gemm_based_op->get_operator()}, inputs, {mm});
+            ins, mlir_op{gemm_based_op->get_operator()}, mlir_contiguous(mpm, inputs), {mm});
     }
 };
 
@@ -461,8 +481,10 @@ struct find_mlir_standalone_op
         auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(
             mm, gemm_based_op->inputs(), gemm_based_op->get_operator());
         mm->add_return({anchor_op});
-        mpm.get_module().replace_instruction(
-            gemm_based_op, mlir_op{gemm_based_op->get_operator()}, top_inputs, {mm});
+        mpm.get_module().replace_instruction(gemm_based_op,
+                                             mlir_op{gemm_based_op->get_operator()},
+                                             mlir_contiguous(mpm, top_inputs),
+                                             {mm});
     }
 };
 
@@ -541,7 +563,7 @@ struct find_mlir_standalone_attention_op
         mm->add_return({ins_to_replace});
 
         mpm.get_module().replace_instruction(
-            ins_to_be_replaced, mlir_op{gemm1->get_operator()}, inputs, {mm});
+            ins_to_be_replaced, mlir_op{gemm1->get_operator()}, mlir_contiguous(mpm, inputs), {mm});
     }
 };
 
diff --git a/test/verify/test_step_dot.cpp b/test/verify/test_step_dot.cpp
new file mode 100644
index 00000000000..9b297eda87e
--- /dev/null
+++ b/test/verify/test_step_dot.cpp
@@ -0,0 +1,51 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+#include "verify_program.hpp"
+#include <migraphx/program.hpp>
+#include <migraphx/generate.hpp>
+#include <migraphx/make_op.hpp>
+#include <migraphx/instruction.hpp>
+
+// Verify program: the "input" parameter goes through `step` (axes {2},
+// steps {2}) and the result is the left operand of a `dot` with a generated
+// literal; the dot output is returned. Filed under the "gemm" section.
+struct test_step_dot : verify_program<test_step_dot>
+{
+    migraphx::program create_program() const
+    {
+        migraphx::program p;
+        auto* mm = p.get_main_module();
+        migraphx::shape as{migraphx::shape::float_type, {128, 4, 64, 196}};
+        migraphx::shape bs{migraphx::shape::float_type, {128, 4, 196, 196}};
+        auto a = mm->add_parameter("input", as);
+        auto b = mm->add_literal(migraphx::generate_literal(bs));
+        auto step =
+            mm->add_instruction(migraphx::make_op("step", {{"axes", {2}}, {"steps", {2}}}), a);
+        auto dot = mm->add_instruction(migraphx::make_op("dot"), step, b);
+        mm->add_return({dot});
+        return p;
+    }
+    std::string section() const { return "gemm"; }
+};