From da3c0013b74c03b45f65215b9aef7d13e38cf6ff Mon Sep 17 00:00:00 2001 From: Shiv Date: Sat, 10 Feb 2024 00:41:51 +0000 Subject: [PATCH 1/9] match gemm_softmax_gemm when there is no scale --- src/targets/gpu/prefuse_ops.cpp | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/src/targets/gpu/prefuse_ops.cpp b/src/targets/gpu/prefuse_ops.cpp index ef3afb38883..79a9cf5f093 100644 --- a/src/targets/gpu/prefuse_ops.cpp +++ b/src/targets/gpu/prefuse_ops.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -170,7 +170,8 @@ struct find_gemm_softmax_gemm match::name("dot")(match::any_of(is_ck_gemm(), is_mlir_gemm()).bind("gemm1"))); auto mul = match::name("mul")( match::nargs(2), match::either_arg(0, 1)(match::is_constant().bind("scale"), gemm1)); - auto softmax = match::name("softmax")(match::arg(0)(mul)).bind("softmax"); + auto softmax = + match::name("softmax")(match::arg(0)(match::any_of(mul, gemm1))).bind("softmax"); return match::name("dot")(match::any_of(is_ck_gemm(), is_mlir_gemm()).bind("gemm2"))( match::arg(0)(softmax)); @@ -181,17 +182,20 @@ struct find_gemm_softmax_gemm auto ins = r.result; auto gemm2_ins = r.instructions["gemm2"]; auto gemm1_ins = r.instructions["gemm1"]; - auto scale_lit = r.instructions["scale"]; float scale = 1.0; - scale_lit->eval().visit([&](const auto s) { - // CK only supports single-valued scale - if(std::all_of( - s.begin() + 1, s.end(), [&](auto v) { return float_equal(v, s.front()); })) - scale = s.front(); - else - return; - }); + if(r.instructions.find("scale") != r.instructions.end()) + { + auto scale_lit = r.instructions["scale"]; + scale_lit->eval().visit([&](const auto s) { + // CK only supports single-valued scale + if(std::all_of( + s.begin() + 1, s.end(), [&](auto v) { return float_equal(v, s.front()); })) + scale = s.front(); + else + return; + }); + } auto inputs = gemm1_ins->inputs(); // A, B inputs.push_back(gemm2_ins->inputs().back()); // B1 From 080a86f9469758cc6b58f20cb7ecde38be7c6d64 Mon Sep 17 00:00:00 2001 From: Shivad Bhavsar Date: Sat, 10 Feb 2024 03:35:55 +0000 Subject: [PATCH 2/9] use contains instead of find --- src/targets/gpu/prefuse_ops.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/targets/gpu/prefuse_ops.cpp b/src/targets/gpu/prefuse_ops.cpp index 79a9cf5f093..6d25a9c8d10 100644 --- a/src/targets/gpu/prefuse_ops.cpp +++ b/src/targets/gpu/prefuse_ops.cpp @@ -184,7 +184,7 @@ struct find_gemm_softmax_gemm auto gemm1_ins = r.instructions["gemm1"]; float scale = 1.0; - if(r.instructions.find("scale") != r.instructions.end()) + if(contains(r.instructions, "scale")) { auto scale_lit = r.instructions["scale"]; scale_lit->eval().visit([&](const auto s) { From d0776718d3bc0e8a78cfb49692bde6992b3c1370 Mon Sep 17 00:00:00 2001 From: Shivad Bhavsar Date: Tue, 13 Feb 2024 00:50:43 +0000 Subject: [PATCH 3/9] prevent attention fusion if scale is non scalar --- src/targets/gpu/prefuse_ops.cpp | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/targets/gpu/prefuse_ops.cpp b/src/targets/gpu/prefuse_ops.cpp index 6d25a9c8d10..3b9ec45b643 100644 --- a/src/targets/gpu/prefuse_ops.cpp +++ b/src/targets/gpu/prefuse_ops.cpp @@ -187,14 +187,20 @@ struct find_gemm_softmax_gemm if(contains(r.instructions, "scale")) { auto scale_lit = r.instructions["scale"]; + // CK only supports single-valued scale + auto is_valid = false; scale_lit->eval().visit([&](const auto s) { - // CK only supports single-valued scale if(std::all_of( s.begin() + 1, s.end(), [&](auto v) { return float_equal(v, s.front()); })) - scale = s.front(); + { + scale = s.front(); + is_valid = true; + } else return; }); + if(not is_valid) + return; } auto inputs = gemm1_ins->inputs(); // A, B From 917139abb147eaa97db3c14a686c2c08d2546ae4 Mon Sep 17 00:00:00 2001 From: Shivad Bhavsar Date: Tue, 13 Feb 2024 17:15:55 +0000 Subject: [PATCH 4/9] formatting --- src/targets/gpu/prefuse_ops.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/targets/gpu/prefuse_ops.cpp b/src/targets/gpu/prefuse_ops.cpp index 3b9ec45b643..a6c6cfa72ce 100644 --- a/src/targets/gpu/prefuse_ops.cpp +++ b/src/targets/gpu/prefuse_ops.cpp @@ -188,7 +188,7 @@ struct find_gemm_softmax_gemm { auto scale_lit = r.instructions["scale"]; // CK only supports single-valued scale - auto is_valid = false; + auto is_valid = false; scale_lit->eval().visit([&](const auto s) { if(std::all_of( s.begin() + 1, s.end(), [&](auto v) { return float_equal(v, s.front()); })) From 88ba6e281f7edf31ec06d3386031839254aade09 Mon Sep 17 00:00:00 2001 From: Shivad Bhavsar Date: Thu, 15 Feb 2024 16:59:40 +0000 Subject: [PATCH 5/9] rewrite unsupported scale condition --- src/targets/gpu/prefuse_ops.cpp | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/src/targets/gpu/prefuse_ops.cpp b/src/targets/gpu/prefuse_ops.cpp index a6c6cfa72ce..8ebd155d8eb 100644 --- a/src/targets/gpu/prefuse_ops.cpp +++ b/src/targets/gpu/prefuse_ops.cpp @@ -188,19 +188,13 @@ struct find_gemm_softmax_gemm { auto scale_lit = r.instructions["scale"]; // CK only supports single-valued scale - auto is_valid = false; scale_lit->eval().visit([&](const auto s) { - if(std::all_of( + // CK only supports single-valued scale + if(not std::all_of( s.begin() + 1, s.end(), [&](auto v) { return float_equal(v, s.front()); })) - { - scale = s.front(); - is_valid = true; - } - else return; + scale = s.front(); }); - if(not is_valid) - return; } auto inputs = gemm1_ins->inputs(); // A, B From 904957ba9ecfb9c02e95e2d9b779c53b0a002afd Mon Sep 17 00:00:00 2001 From: Shivad Bhavsar Date: Thu, 15 Feb 2024 18:16:39 +0000 Subject: [PATCH 6/9] prefuse ops test cases --- test/gpu/prefuse_ops.cpp | 204 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 204 insertions(+) create mode 100644 test/gpu/prefuse_ops.cpp diff --git a/test/gpu/prefuse_ops.cpp b/test/gpu/prefuse_ops.cpp new file mode 100644 index 00000000000..18e360e87b6 --- /dev/null +++ b/test/gpu/prefuse_ops.cpp @@ -0,0 +1,204 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct pre_gemm_softmax_gemm : migraphx::gpu::gemm_softmax_gemm +{ + std::string name() const { return "gpu::pre_gemm_softmax_gemm"; } +}; + +void run_pass(migraphx::program& p) +{ + migraphx::run_passes(p, {migraphx::gpu::prefuse_ops{}, migraphx::dead_code_elimination{}}); +} + +TEST_CASE(find_gemm_softmax_gemm) +{ + migraphx::shape s1{migraphx::shape::float_type, {8, 16, 32}}; + migraphx::shape s2{migraphx::shape::float_type, {8, 32, 16}}; + + auto create_program = [=]() { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", s1); + auto y = mm->add_parameter("y", s2); + auto z = mm->add_parameter("z", s1); + auto scale = mm->add_literal(2.0f); + + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), x, y); + auto scale_mb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", dot1->get_shape().lens()}}), scale); + auto mul = mm->add_instruction(migraphx::make_op("mul"), dot1, scale_mb); + auto sm = mm->add_instruction(migraphx::make_op("softmax", {{"axis", 2}}), mul); + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), sm, z); + mm->add_return({dot2}); + return p; + }; + + auto create_fused_program = [=]() { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", s1); + auto y = mm->add_parameter("y", s2); + auto z = mm->add_parameter("z", s1); + + auto attn = + mm->add_instruction(pre_gemm_softmax_gemm{migraphx::make_op("dot"), 2}, x, y, z); + + mm->add_return({attn}); + return p; + }; + + migraphx::program p1 = create_program(); + migraphx::program p2; + if(migraphx::gpu::mlir_attention_enabled()) + { + p2 = create_fused_program(); + } + else + { + p2 = p1; + } + + run_pass(p1); + + EXPECT(p1 == p2); +} + +TEST_CASE(find_gemm_softmax_gemm_multi_scale) +{ + migraphx::shape s1{migraphx::shape::float_type, {8, 16, 32}}; + migraphx::shape s2{migraphx::shape::float_type, {8, 32, 16}}; + + auto create_program = [=]() { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", s1); + auto y = mm->add_parameter("y", s2); + auto z = mm->add_parameter("z", s1); + auto scale = + mm->add_literal(migraphx::generate_literal({migraphx::shape::float_type, {16}}, 10)); + + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), x, y); + auto scale_mb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", dot1->get_shape().lens()}}), scale); + auto mul = mm->add_instruction(migraphx::make_op("mul"), dot1, scale_mb); + auto sm = mm->add_instruction(migraphx::make_op("softmax", {{"axis", 2}}), mul); + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), sm, z); + mm->add_return({dot2}); + return p; + }; + + auto create_fused_program = [=]() { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", s1); + auto y = mm->add_parameter("y", s2); + auto z = mm->add_parameter("z", s1); + + auto attn = + mm->add_instruction(pre_gemm_softmax_gemm{migraphx::make_op("dot"), 1}, x, y, z); + + mm->add_return({attn}); + return p; + }; + + migraphx::program p1 = create_program(); + migraphx::program p2; + if(migraphx::gpu::mlir_attention_enabled()) + { + p2 = create_fused_program(); + } + else + { + p2 = p1; + } + + run_pass(p1); + + EXPECT(p1 == p2); +} + +TEST_CASE(find_gemm_softmax_gemm_no_scale) +{ + migraphx::shape s1{migraphx::shape::float_type, {8, 16, 32}}; + migraphx::shape s2{migraphx::shape::float_type, {8, 32, 16}}; + + auto create_program = [=]() { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", s1); + auto y = mm->add_parameter("y", s2); + auto z = mm->add_parameter("z", s1); + + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), x, y); + auto sm = mm->add_instruction(migraphx::make_op("softmax", {{"axis", 2}}), dot1); + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), sm, z); + mm->add_return({dot2}); + return p; + }; + + auto create_fused_program = [=]() { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", s1); + auto y = mm->add_parameter("y", s2); + auto z = mm->add_parameter("z", s1); + + auto attn = + mm->add_instruction(pre_gemm_softmax_gemm{migraphx::make_op("dot"), 1}, x, y, z); + + mm->add_return({attn}); + return p; + }; + + migraphx::program p1 = create_program(); + migraphx::program p2; + if(migraphx::gpu::mlir_attention_enabled()) + { + p2 = create_fused_program(); + } + else + { + p2 = p1; + } + + run_pass(p1); + + EXPECT(p1 == p2); +} + + +int main(int argc, const char* argv[]) { test::run(argc, argv); } From c2ea5cf3e51af62374aeb7cdc930d41dd5d5d7d6 Mon Sep 17 00:00:00 2001 From: Shivad Bhavsar Date: Thu, 15 Feb 2024 19:43:54 +0000 Subject: [PATCH 7/9] update test case to use enable_attention flag --- .../gpu/include/migraphx/gpu/prefuse_ops.hpp | 1 + src/targets/gpu/prefuse_ops.cpp | 25 ++- test/gpu/prefuse_ops.cpp | 179 ++++++------------ 3 files changed, 79 insertions(+), 126 deletions(-) diff --git a/src/targets/gpu/include/migraphx/gpu/prefuse_ops.hpp b/src/targets/gpu/include/migraphx/gpu/prefuse_ops.hpp index d9db515c462..00a62b1948b 100644 --- a/src/targets/gpu/include/migraphx/gpu/prefuse_ops.hpp +++ b/src/targets/gpu/include/migraphx/gpu/prefuse_ops.hpp @@ -36,6 +36,7 @@ namespace gpu { struct MIGRAPHX_GPU_EXPORT prefuse_ops { + bool enable_attention = false; std::string name() const { return "gpu::prefuse_ops"; } void apply(module_pass_manager& mpm) const; }; diff --git a/src/targets/gpu/prefuse_ops.cpp b/src/targets/gpu/prefuse_ops.cpp index 8ebd155d8eb..4457a4481d8 100644 --- a/src/targets/gpu/prefuse_ops.cpp +++ b/src/targets/gpu/prefuse_ops.cpp @@ -162,19 +162,32 @@ auto is_mlir_gemm() }); } +auto is_test_gemm(bool enable_attention) +{ + return match::make_basic_pred_matcher([=](instruction_ref ins) { + if(ins->name() != "dot") + return false; + return enable_attention; + }); +} + struct find_gemm_softmax_gemm { + bool enable_attention; + auto matcher() const { - auto gemm1 = match::skip(match::name("contiguous"))( - match::name("dot")(match::any_of(is_ck_gemm(), is_mlir_gemm()).bind("gemm1"))); - auto mul = match::name("mul")( + auto gemm1 = match::skip(match::name("contiguous"))(match::name("dot")( + match::any_of(is_ck_gemm(), is_mlir_gemm(), is_test_gemm(enable_attention)) + .bind("gemm1"))); + auto mul = match::name("mul")( match::nargs(2), match::either_arg(0, 1)(match::is_constant().bind("scale"), gemm1)); auto softmax = match::name("softmax")(match::arg(0)(match::any_of(mul, gemm1))).bind("softmax"); - return match::name("dot")(match::any_of(is_ck_gemm(), is_mlir_gemm()).bind("gemm2"))( - match::arg(0)(softmax)); + return match::name("dot")( + match::any_of(is_ck_gemm(), is_mlir_gemm(), is_test_gemm(enable_attention)) + .bind("gemm2"))(match::arg(0)(softmax)); } void apply(module_pass_manager& mpm, const match::matcher_result& r) const @@ -212,7 +225,7 @@ void prefuse_ops::apply(module_pass_manager& mpm) const match::find_matches(mpm.get_module(), find_layernorm{}); mpm.run_pass(dead_code_elimination{}); match::find_matches(mpm.get_module(), find_add_layernorm{}); - match::find_matches(mpm, find_gemm_softmax_gemm{}); + match::find_matches(mpm, find_gemm_softmax_gemm{enable_attention}); } } // namespace gpu diff --git a/test/gpu/prefuse_ops.cpp b/test/gpu/prefuse_ops.cpp index 18e360e87b6..faf662e56a8 100644 --- a/test/gpu/prefuse_ops.cpp +++ b/test/gpu/prefuse_ops.cpp @@ -39,9 +39,9 @@ struct pre_gemm_softmax_gemm : migraphx::gpu::gemm_softmax_gemm std::string name() const { return "gpu::pre_gemm_softmax_gemm"; } }; -void run_pass(migraphx::program& p) +void run_pass(migraphx::module& m) { - migraphx::run_passes(p, {migraphx::gpu::prefuse_ops{}, migraphx::dead_code_elimination{}}); + migraphx::run_passes(m, {migraphx::gpu::prefuse_ops{true}, migraphx::dead_code_elimination{}}); } TEST_CASE(find_gemm_softmax_gemm) @@ -49,52 +49,32 @@ TEST_CASE(find_gemm_softmax_gemm) migraphx::shape s1{migraphx::shape::float_type, {8, 16, 32}}; migraphx::shape s2{migraphx::shape::float_type, {8, 32, 16}}; - auto create_program = [=]() { - migraphx::program p; - auto* mm = p.get_main_module(); - auto x = mm->add_parameter("x", s1); - auto y = mm->add_parameter("y", s2); - auto z = mm->add_parameter("z", s1); - auto scale = mm->add_literal(2.0f); + migraphx::module m1; + { + auto x = m1.add_parameter("x", s1); + auto y = m1.add_parameter("y", s2); + auto z = m1.add_parameter("z", s1); + auto scale = m1.add_literal(2.0f); - auto dot1 = mm->add_instruction(migraphx::make_op("dot"), x, y); - auto scale_mb = mm->add_instruction( + auto dot1 = m1.add_instruction(migraphx::make_op("dot"), x, y); + auto scale_mb = m1.add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", dot1->get_shape().lens()}}), scale); - auto mul = mm->add_instruction(migraphx::make_op("mul"), dot1, scale_mb); - auto sm = mm->add_instruction(migraphx::make_op("softmax", {{"axis", 2}}), mul); - auto dot2 = mm->add_instruction(migraphx::make_op("dot"), sm, z); - mm->add_return({dot2}); - return p; - }; - - auto create_fused_program = [=]() { - migraphx::program p; - auto* mm = p.get_main_module(); - auto x = mm->add_parameter("x", s1); - auto y = mm->add_parameter("y", s2); - auto z = mm->add_parameter("z", s1); - - auto attn = - mm->add_instruction(pre_gemm_softmax_gemm{migraphx::make_op("dot"), 2}, x, y, z); - - mm->add_return({attn}); - return p; - }; - - migraphx::program p1 = create_program(); - migraphx::program p2; - if(migraphx::gpu::mlir_attention_enabled()) - { - p2 = create_fused_program(); + auto mul = m1.add_instruction(migraphx::make_op("mul"), dot1, scale_mb); + auto sm = m1.add_instruction(migraphx::make_op("softmax", {{"axis", 2}}), mul); + m1.add_instruction(migraphx::make_op("dot"), sm, z); } - else + run_pass(m1); + + migraphx::module m2; { - p2 = p1; - } + auto x = m2.add_parameter("x", s1); + auto y = m2.add_parameter("y", s2); + auto z = m2.add_parameter("z", s1); - run_pass(p1); + m2.add_instruction(pre_gemm_softmax_gemm{migraphx::make_op("dot"), 2}, x, y, z); + } - EXPECT(p1 == p2); + EXPECT(m1 == m2); } TEST_CASE(find_gemm_softmax_gemm_multi_scale) @@ -102,53 +82,33 @@ TEST_CASE(find_gemm_softmax_gemm_multi_scale) migraphx::shape s1{migraphx::shape::float_type, {8, 16, 32}}; migraphx::shape s2{migraphx::shape::float_type, {8, 32, 16}}; - auto create_program = [=]() { - migraphx::program p; - auto* mm = p.get_main_module(); - auto x = mm->add_parameter("x", s1); - auto y = mm->add_parameter("y", s2); - auto z = mm->add_parameter("z", s1); + migraphx::module m1; + { + auto x = m1.add_parameter("x", s1); + auto y = m1.add_parameter("y", s2); + auto z = m1.add_parameter("z", s1); auto scale = - mm->add_literal(migraphx::generate_literal({migraphx::shape::float_type, {16}}, 10)); + m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {16}}, 10)); - auto dot1 = mm->add_instruction(migraphx::make_op("dot"), x, y); - auto scale_mb = mm->add_instruction( + auto dot1 = m1.add_instruction(migraphx::make_op("dot"), x, y); + auto scale_mb = m1.add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", dot1->get_shape().lens()}}), scale); - auto mul = mm->add_instruction(migraphx::make_op("mul"), dot1, scale_mb); - auto sm = mm->add_instruction(migraphx::make_op("softmax", {{"axis", 2}}), mul); - auto dot2 = mm->add_instruction(migraphx::make_op("dot"), sm, z); - mm->add_return({dot2}); - return p; - }; - - auto create_fused_program = [=]() { - migraphx::program p; - auto* mm = p.get_main_module(); - auto x = mm->add_parameter("x", s1); - auto y = mm->add_parameter("y", s2); - auto z = mm->add_parameter("z", s1); - - auto attn = - mm->add_instruction(pre_gemm_softmax_gemm{migraphx::make_op("dot"), 1}, x, y, z); - - mm->add_return({attn}); - return p; - }; - - migraphx::program p1 = create_program(); - migraphx::program p2; - if(migraphx::gpu::mlir_attention_enabled()) - { - p2 = create_fused_program(); + auto mul = m1.add_instruction(migraphx::make_op("mul"), dot1, scale_mb); + auto sm = m1.add_instruction(migraphx::make_op("softmax", {{"axis", 2}}), mul); + m1.add_instruction(migraphx::make_op("dot"), sm, z); } - else + run_pass(m1); + + migraphx::module m2; { - p2 = p1; - } + auto x = m2.add_parameter("x", s1); + auto y = m2.add_parameter("y", s2); + auto z = m2.add_parameter("z", s1); - run_pass(p1); + m2.add_instruction(pre_gemm_softmax_gemm{migraphx::make_op("dot"), 1}, x, y, z); + } - EXPECT(p1 == p2); + EXPECT(m1 == m2); } TEST_CASE(find_gemm_softmax_gemm_no_scale) @@ -156,49 +116,28 @@ TEST_CASE(find_gemm_softmax_gemm_no_scale) migraphx::shape s1{migraphx::shape::float_type, {8, 16, 32}}; migraphx::shape s2{migraphx::shape::float_type, {8, 32, 16}}; - auto create_program = [=]() { - migraphx::program p; - auto* mm = p.get_main_module(); - auto x = mm->add_parameter("x", s1); - auto y = mm->add_parameter("y", s2); - auto z = mm->add_parameter("z", s1); - - auto dot1 = mm->add_instruction(migraphx::make_op("dot"), x, y); - auto sm = mm->add_instruction(migraphx::make_op("softmax", {{"axis", 2}}), dot1); - auto dot2 = mm->add_instruction(migraphx::make_op("dot"), sm, z); - mm->add_return({dot2}); - return p; - }; - - auto create_fused_program = [=]() { - migraphx::program p; - auto* mm = p.get_main_module(); - auto x = mm->add_parameter("x", s1); - auto y = mm->add_parameter("y", s2); - auto z = mm->add_parameter("z", s1); - - auto attn = - mm->add_instruction(pre_gemm_softmax_gemm{migraphx::make_op("dot"), 1}, x, y, z); - - mm->add_return({attn}); - return p; - }; - - migraphx::program p1 = create_program(); - migraphx::program p2; - if(migraphx::gpu::mlir_attention_enabled()) + migraphx::module m1; { - p2 = create_fused_program(); + auto x = m1.add_parameter("x", s1); + auto y = m1.add_parameter("y", s2); + auto z = m1.add_parameter("z", s1); + + auto dot1 = m1.add_instruction(migraphx::make_op("dot"), x, y); + auto sm = m1.add_instruction(migraphx::make_op("softmax", {{"axis", 2}}), dot1); + m1.add_instruction(migraphx::make_op("dot"), sm, z); } - else + run_pass(m1); + + migraphx::module m2; { - p2 = p1; - } + auto x = m2.add_parameter("x", s1); + auto y = m2.add_parameter("y", s2); + auto z = m2.add_parameter("z", s1); - run_pass(p1); + m2.add_instruction(pre_gemm_softmax_gemm{migraphx::make_op("dot"), 1}, x, y, z); + } - EXPECT(p1 == p2); + EXPECT(m1 == m2); } - int main(int argc, const char* argv[]) { test::run(argc, argv); } From 8508e66aaf3af6a4209ac0b6805eeb484e236d9a Mon Sep 17 00:00:00 2001 From: Shivad Bhavsar Date: Thu, 15 Feb 2024 19:50:57 +0000 Subject: [PATCH 8/9] license --- src/targets/gpu/include/migraphx/gpu/prefuse_ops.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/targets/gpu/include/migraphx/gpu/prefuse_ops.hpp b/src/targets/gpu/include/migraphx/gpu/prefuse_ops.hpp index 00a62b1948b..bed64052009 100644 --- a/src/targets/gpu/include/migraphx/gpu/prefuse_ops.hpp +++ b/src/targets/gpu/include/migraphx/gpu/prefuse_ops.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal From 9fcae35bdd83c31651fc3fb853afeef12275c17c Mon Sep 17 00:00:00 2001 From: shivadbhavsar <105248561+shivadbhavsar@users.noreply.github.com> Date: Fri, 16 Feb 2024 10:05:48 -0800 Subject: [PATCH 9/9] set default value --- src/targets/gpu/prefuse_ops.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/targets/gpu/prefuse_ops.cpp b/src/targets/gpu/prefuse_ops.cpp index 4457a4481d8..f7d75891f01 100644 --- a/src/targets/gpu/prefuse_ops.cpp +++ b/src/targets/gpu/prefuse_ops.cpp @@ -173,7 +173,7 @@ auto is_test_gemm(bool enable_attention) struct find_gemm_softmax_gemm { - bool enable_attention; + bool enable_attention = false; auto matcher() const {