Skip to content

Commit

Permalink
match gemm_softmax_gemm when there is no scale (#2748)
Browse files Browse the repository at this point in the history
  • Loading branch information
shivadbhavsar authored Feb 16, 2024
1 parent 55f79bb commit c8b6c6d
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 16 deletions.
3 changes: 2 additions & 1 deletion src/targets/gpu/include/migraphx/gpu/prefuse_ops.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -36,6 +36,7 @@ namespace gpu {

/// GPU pass reported under the name "gpu::prefuse_ops"; the transformation
/// itself is implemented out of line in apply().
struct MIGRAPHX_GPU_EXPORT prefuse_ops
{
    /// Off by default. apply() forwards this flag to the gemm-softmax-gemm
    /// matcher. Kept as the first (and only) data member so that
    /// `prefuse_ops{true}` continues to set it.
    bool enable_attention = false;

    /// Identifier of this pass.
    std::string name() const
    {
        return "gpu::prefuse_ops";
    }

    /// Runs the pass through the given module pass manager.
    void apply(module_pass_manager& mpm) const;
};
Expand Down
47 changes: 32 additions & 15 deletions src/targets/gpu/prefuse_ops.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -162,36 +162,53 @@ auto is_mlir_gemm()
});
}

/// Builds a predicate matcher that accepts an instruction only when it is a
/// "dot" and `enable_attention` is set; every other instruction is rejected.
/// The flag is captured by value, so the matcher does not refer back to the
/// caller's variable.
auto is_test_gemm(bool enable_attention)
{
    return match::make_basic_pred_matcher([enable_attention](instruction_ref ins) {
        return ins->name() == "dot" and enable_attention;
    });
}

struct find_gemm_softmax_gemm
{
bool enable_attention = false;

auto matcher() const
{
auto gemm1 = match::skip(match::name("contiguous"))(
match::name("dot")(match::any_of(is_ck_gemm(), is_mlir_gemm()).bind("gemm1")));
auto mul = match::name("mul")(
auto gemm1 = match::skip(match::name("contiguous"))(match::name("dot")(
match::any_of(is_ck_gemm(), is_mlir_gemm(), is_test_gemm(enable_attention))
.bind("gemm1")));
auto mul = match::name("mul")(
match::nargs(2), match::either_arg(0, 1)(match::is_constant().bind("scale"), gemm1));
auto softmax = match::name("softmax")(match::arg(0)(mul)).bind("softmax");
auto softmax =
match::name("softmax")(match::arg(0)(match::any_of(mul, gemm1))).bind("softmax");

return match::name("dot")(match::any_of(is_ck_gemm(), is_mlir_gemm()).bind("gemm2"))(
match::arg(0)(softmax));
return match::name("dot")(
match::any_of(is_ck_gemm(), is_mlir_gemm(), is_test_gemm(enable_attention))
.bind("gemm2"))(match::arg(0)(softmax));
}

void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto ins = r.result;
auto gemm2_ins = r.instructions["gemm2"];
auto gemm1_ins = r.instructions["gemm1"];
auto scale_lit = r.instructions["scale"];

float scale = 1.0;
scale_lit->eval().visit([&](const auto s) {
if(contains(r.instructions, "scale"))
{
auto scale_lit = r.instructions["scale"];
// CK only supports single-valued scale
if(std::all_of(
s.begin() + 1, s.end(), [&](auto v) { return float_equal(v, s.front()); }))
scale_lit->eval().visit([&](const auto s) {
// CK only supports single-valued scale
if(not std::all_of(
s.begin() + 1, s.end(), [&](auto v) { return float_equal(v, s.front()); }))
return;
scale = s.front();
else
return;
});
});
}

auto inputs = gemm1_ins->inputs(); // A, B
inputs.push_back(gemm2_ins->inputs().back()); // B1
Expand All @@ -208,7 +225,7 @@ void prefuse_ops::apply(module_pass_manager& mpm) const
match::find_matches(mpm.get_module(), find_layernorm{});
mpm.run_pass(dead_code_elimination{});
match::find_matches(mpm.get_module(), find_add_layernorm{});
match::find_matches(mpm, find_gemm_softmax_gemm{});
match::find_matches(mpm, find_gemm_softmax_gemm{enable_attention});
}

} // namespace gpu
Expand Down
143 changes: 143 additions & 0 deletions test/gpu/prefuse_ops.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/gpu/fuse_mlir.hpp>
#include <migraphx/gpu/gemm_softmax_gemm.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>

// Test-local stand-in for the operator the pass is expected to insert: it
// inherits everything from migraphx::gpu::gemm_softmax_gemm and only replaces
// the reported operator name.
struct pre_gemm_softmax_gemm : migraphx::gpu::gemm_softmax_gemm
{
    std::string name() const
    {
        return "gpu::pre_gemm_softmax_gemm";
    }
};

// Applies prefuse_ops (constructed with `true`) followed by dead code
// elimination to the given module, modifying it in place.
void run_pass(migraphx::module& m)
{
    migraphx::run_passes(m,
                         {
                             migraphx::gpu::prefuse_ops{true},
                             migraphx::dead_code_elimination{},
                         });
}

// dot -> mul(broadcast scalar literal 2) -> softmax -> dot is expected to be
// replaced by a single pre_gemm_softmax_gemm carrying scale 2.
TEST_CASE(find_gemm_softmax_gemm)
{
    migraphx::shape s1{migraphx::shape::float_type, {8, 16, 32}};
    migraphx::shape s2{migraphx::shape::float_type, {8, 32, 16}};

    // Input graph, rewritten in place by the pass.
    migraphx::module m1;
    {
        auto a      = m1.add_parameter("x", s1);
        auto b      = m1.add_parameter("y", s2);
        auto b1     = m1.add_parameter("z", s1);
        auto factor = m1.add_literal(2.0f);

        auto gemm1       = m1.add_instruction(migraphx::make_op("dot"), a, b);
        const auto& lens = gemm1->get_shape().lens();
        auto factor_bc =
            m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), factor);
        auto scaled = m1.add_instruction(migraphx::make_op("mul"), gemm1, factor_bc);
        auto probs  = m1.add_instruction(migraphx::make_op("softmax", {{"axis", 2}}), scaled);
        m1.add_instruction(migraphx::make_op("dot"), probs, b1);
    }
    run_pass(m1);

    // Expected graph: the three parameters feeding one fused operator.
    migraphx::module m2;
    {
        auto a  = m2.add_parameter("x", s1);
        auto b  = m2.add_parameter("y", s2);
        auto b1 = m2.add_parameter("z", s1);

        m2.add_instruction(pre_gemm_softmax_gemm{migraphx::make_op("dot"), 2}, a, b, b1);
    }

    EXPECT(m1 == m2);
}

// Same pattern as the scalar-scale case, but the scale is a generated
// 16-element literal. The expected graph is still a single fused operator,
// with scale 1.
// NOTE(review): the expected module drops the multi-valued scale entirely and
// uses 1 instead — confirm this fallback is intended rather than leaving the
// pattern unfused.
TEST_CASE(find_gemm_softmax_gemm_multi_scale)
{
    migraphx::shape s1{migraphx::shape::float_type, {8, 16, 32}};
    migraphx::shape s2{migraphx::shape::float_type, {8, 32, 16}};

    // Input graph, rewritten in place by the pass.
    migraphx::module m1;
    {
        auto a  = m1.add_parameter("x", s1);
        auto b  = m1.add_parameter("y", s2);
        auto b1 = m1.add_parameter("z", s1);
        auto factors =
            m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {16}}, 10));

        auto gemm1       = m1.add_instruction(migraphx::make_op("dot"), a, b);
        const auto& lens = gemm1->get_shape().lens();
        auto factors_bc =
            m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), factors);
        auto scaled = m1.add_instruction(migraphx::make_op("mul"), gemm1, factors_bc);
        auto probs  = m1.add_instruction(migraphx::make_op("softmax", {{"axis", 2}}), scaled);
        m1.add_instruction(migraphx::make_op("dot"), probs, b1);
    }
    run_pass(m1);

    // Expected graph: the three parameters feeding one fused operator.
    migraphx::module m2;
    {
        auto a  = m2.add_parameter("x", s1);
        auto b  = m2.add_parameter("y", s2);
        auto b1 = m2.add_parameter("z", s1);

        m2.add_instruction(pre_gemm_softmax_gemm{migraphx::make_op("dot"), 1}, a, b, b1);
    }

    EXPECT(m1 == m2);
}

// dot -> softmax -> dot with no intermediate mul is expected to be replaced
// by a single pre_gemm_softmax_gemm with scale 1.
TEST_CASE(find_gemm_softmax_gemm_no_scale)
{
    migraphx::shape s1{migraphx::shape::float_type, {8, 16, 32}};
    migraphx::shape s2{migraphx::shape::float_type, {8, 32, 16}};

    // Input graph, rewritten in place by the pass.
    migraphx::module m1;
    {
        auto a  = m1.add_parameter("x", s1);
        auto b  = m1.add_parameter("y", s2);
        auto b1 = m1.add_parameter("z", s1);

        auto gemm1 = m1.add_instruction(migraphx::make_op("dot"), a, b);
        auto probs = m1.add_instruction(migraphx::make_op("softmax", {{"axis", 2}}), gemm1);
        m1.add_instruction(migraphx::make_op("dot"), probs, b1);
    }
    run_pass(m1);

    // Expected graph: the three parameters feeding one fused operator.
    migraphx::module m2;
    {
        auto a  = m2.add_parameter("x", s1);
        auto b  = m2.add_parameter("y", s2);
        auto b1 = m2.add_parameter("z", s1);

        m2.add_instruction(pre_gemm_softmax_gemm{migraphx::make_op("dot"), 1}, a, b, b1);
    }

    EXPECT(m1 == m2);
}

// Entry point: hands the command-line arguments to the test driver.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}

0 comments on commit c8b6c6d

Please sign in to comment.