Skip to content

Commit

Permalink
Insert contiguous on mlir inputs that are not packed or broadcasted (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
pfultz2 authored May 7, 2024
1 parent 49aff39 commit 156122b
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 4 deletions.
1 change: 1 addition & 0 deletions src/simplify_reshapes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ struct find_nop_reshapes
reshapes.insert("multibroadcast");
reshapes.insert("pad");
reshapes.insert("slice");
reshapes.insert("step");
reshapes.insert("transpose");
reshapes.insert("reduce_mean");
reshapes.insert("reduce_max");
Expand Down
24 changes: 20 additions & 4 deletions src/targets/gpu/fuse_mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,20 @@ MIGRAPHX_PRED_MATCHER(mlir_pointwise, instruction_ref ins)
});
}

std::vector<instruction_ref> mlir_contiguous(module_pass_manager& mpm,
const std::vector<instruction_ref>& inputs)
{
std::vector<instruction_ref> result;
std::transform(
inputs.begin(), inputs.end(), std::back_inserter(result), [&](instruction_ref input) {
if(input->get_shape().packed() or input->get_shape().broadcasted())
return input;
return mpm.get_module().insert_instruction(
std::next(input), make_op("contiguous"), input);
});
return result;
}

struct find_mlir_fused_ops
{
mlir_mode conv_mode = mlir_mode::none;
Expand Down Expand Up @@ -432,7 +446,7 @@ struct find_mlir_fused_ops
[&](auto input) { return input != gemm_based_op; });
inputs.insert(inputs.end(), top_inputs.begin(), top_inputs.end());
mpm.get_module().replace_instruction(
ins, mlir_op{gemm_based_op->get_operator()}, inputs, {mm});
ins, mlir_op{gemm_based_op->get_operator()}, mlir_contiguous(mpm, inputs), {mm});
}
};

Expand Down Expand Up @@ -461,8 +475,10 @@ struct find_mlir_standalone_op
auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(
mm, gemm_based_op->inputs(), gemm_based_op->get_operator());
mm->add_return({anchor_op});
mpm.get_module().replace_instruction(
gemm_based_op, mlir_op{gemm_based_op->get_operator()}, top_inputs, {mm});
mpm.get_module().replace_instruction(gemm_based_op,
mlir_op{gemm_based_op->get_operator()},
mlir_contiguous(mpm, top_inputs),
{mm});
}
};

Expand Down Expand Up @@ -541,7 +557,7 @@ struct find_mlir_standalone_attention_op
mm->add_return({ins_to_replace});

mpm.get_module().replace_instruction(
ins_to_be_replaced, mlir_op{gemm1->get_operator()}, inputs, {mm});
ins_to_be_replaced, mlir_op{gemm1->get_operator()}, mlir_contiguous(mpm, inputs), {mm});
}
};

Expand Down
48 changes: 48 additions & 0 deletions test/verify/test_step_dot.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>

struct test_step_dot : verify_program<test_step_dot>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape as{migraphx::shape::float_type, {128, 4, 64, 196}};
migraphx::shape bs{migraphx::shape::float_type, {128, 4, 196, 196}};
auto a = mm->add_parameter("input", as);
auto b = mm->add_literal(migraphx::generate_literal(bs));
auto step =
mm->add_instruction(migraphx::make_op("step", {{"axes", {2}}, {"steps", {2}}}), a);
auto dot = mm->add_instruction(migraphx::make_op("dot"), step, b);
mm->add_return({dot});
return p;
}
std::string section() const { return "gemm"; }
};

0 comments on commit 156122b

Please sign in to comment.