diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 5b9da64ca2b..e4b945360b0 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -28,6 +28,7 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -103,7 +104,7 @@ struct mlir_op if(inputs.size() < 2) MIGRAPHX_THROW("should have at least two inputs."); - auto type = mod->get_output_shapes().front().type(); + auto type = mod->get_output_shapes().front().type(); auto mod_params = mod->get_parameter_names(); std::sort(mod_params.begin(), mod_params.end()); std::unordered_map mod_ins_shapes; @@ -478,15 +479,17 @@ struct find_mlir_standalone_attention_op static size_t counter = 0; module_ref mm = mpm.create_module("mlir_" + std::to_string(counter++)); auto gemm_softmax_gemm = r.instructions["gemm_softmax_gemm"]; - std::vector inputs; mm->set_bypass(); - std::unordered_map ins_map; - auto gemm0_inputs = gemm_softmax_gemm->inputs(); - gemm0_inputs.pop_back(); + auto orig_inputs = gemm_softmax_gemm->inputs(); + + std::vector gemm0_inputs = {orig_inputs[0], orig_inputs[1]}; auto [gemm0, top_gemm0_inputs] = fuse_input_ops_and_gemm_based_op(mm, gemm0_inputs, make_op("dot")); + + std::vector inputs; inputs.insert(inputs.begin(), top_gemm0_inputs.begin(), top_gemm0_inputs.end()); + // handle scale auto v = gemm_softmax_gemm->get_operator().to_value(); assert(v.contains("scale")); @@ -496,10 +499,20 @@ struct find_mlir_standalone_attention_op make_op("multibroadcast", {{"out_lens", gemm0->get_shape().lens()}}), scale_lit); auto scaled_gemm0 = mm->add_instruction(make_op("mul"), gemm0, scale_lit_mbcast); + std::optional bias{nullopt}; + if(orig_inputs.size() == 4) + { + auto bias_input = orig_inputs[2]; + instruction_ref bias_param = + mm->add_parameter("bias", bias_input->get_shape().as_standard()); + bias = mm->add_instruction(make_op("add"), scaled_gemm0, bias_param); + inputs.push_back(bias_input); + } + auto softmax = mm->add_instruction( - make_op("softmax", {{"axis", gemm0->get_shape().lens().size() - 1}}), scaled_gemm0); - auto [old_upper_v, upper_v_op_stream] = - get_fusable_input_op_stream(gemm_softmax_gemm->inputs()[2]); + make_op("softmax", {{"axis", gemm0->get_shape().lens().size() - 1}}), + bias ? bias.value() : scaled_gemm0); + auto [old_upper_v, upper_v_op_stream] = get_fusable_input_op_stream(orig_inputs.back()); instruction_ref new_upper_v = mm->add_parameter("z", old_upper_v->get_shape().as_standard()); for(const auto& op : reverse(upper_v_op_stream)) @@ -507,7 +520,10 @@ struct find_mlir_standalone_attention_op new_upper_v = mm->add_instruction(op, {new_upper_v}); } inputs.push_back(old_upper_v); - auto gemm1 = mm->add_instruction(make_op("dot"), {softmax, new_upper_v}); + + auto gemm1 = mm->add_instruction(make_op("dot"), {softmax, new_upper_v}); + + std::unordered_map ins_map; ins_map[gemm_softmax_gemm] = gemm1; auto ins_to_replace = gemm1; auto ins_to_be_replaced = gemm_softmax_gemm; @@ -521,6 +537,7 @@ struct find_mlir_standalone_attention_op ins_to_be_replaced = r.instructions["trailing_pm"]; } mm->add_return({ins_to_replace}); + mpm.get_module().replace_instruction( ins_to_be_replaced, mlir_op{gemm1->get_operator()}, inputs, {mm}); } diff --git a/src/targets/gpu/include/migraphx/gpu/gemm_softmax_gemm.hpp b/src/targets/gpu/include/migraphx/gpu/gemm_softmax_gemm.hpp index 38f4a1aef7f..0e13e49f18b 100644 --- a/src/targets/gpu/include/migraphx/gpu/gemm_softmax_gemm.hpp +++ b/src/targets/gpu/include/migraphx/gpu/gemm_softmax_gemm.hpp @@ -26,6 +26,7 @@ #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -55,14 +56,30 @@ struct gemm_softmax_gemm check_shapes{inputs, *this}.same_ndims(); if(inputs.size() < 3) MIGRAPHX_THROW(name() + ": Expected 3 inputs but got " + to_string(inputs.size())); - auto a = inputs[0]; - auto b = inputs[1]; - auto b1 = inputs[2]; + + const bool is_bias_enabled = inputs.size() == 4; + auto a = inputs[0]; + auto b = inputs[1]; + auto b1 = inputs[is_bias_enabled ? 3 : 2]; + for(const auto& input : inputs) { check_gemm_shape(input); } - return op.compute_shape({op.compute_shape({a, b}), b1}); + auto gemm0_shape = op.compute_shape({a, b}); + if(is_bias_enabled) + { + auto bias_shape = inputs[2]; + if(bias_shape != gemm0_shape) + { + std::stringstream err_msg; + err_msg << name() << ": has inconsistent bias size" + << ". Expected: " << gemm0_shape << ". Given: " << bias_shape; + MIGRAPHX_THROW(err_msg.str()); + } + } + + return op.compute_shape({gemm0_shape, b1}); } static bool is_ck_supported_type(shape::type_t t) { return contains({shape::half_type}, t); } diff --git a/src/targets/gpu/prefuse_ops.cpp b/src/targets/gpu/prefuse_ops.cpp index ef3afb38883..7055868ef0a 100644 --- a/src/targets/gpu/prefuse_ops.cpp +++ b/src/targets/gpu/prefuse_ops.cpp @@ -170,7 +170,10 @@ struct find_gemm_softmax_gemm match::name("dot")(match::any_of(is_ck_gemm(), is_mlir_gemm()).bind("gemm1"))); auto mul = match::name("mul")( match::nargs(2), match::either_arg(0, 1)(match::is_constant().bind("scale"), gemm1)); - auto softmax = match::name("softmax")(match::arg(0)(mul)).bind("softmax"); + auto add = match::name("add")(match::nargs(2), + match::any_arg(0, 1)(match::none_of(mul).bind("bias"))); + auto softmax = + match::name("softmax")(match::arg(0)(match::any_of(mul, add))).bind("softmax"); return match::name("dot")(match::any_of(is_ck_gemm(), is_mlir_gemm()).bind("gemm2"))( match::arg(0)(softmax)); @@ -181,19 +184,25 @@ struct find_gemm_softmax_gemm auto ins = r.result; auto gemm2_ins = r.instructions["gemm2"]; auto gemm1_ins = r.instructions["gemm1"]; - auto scale_lit = r.instructions["scale"]; float scale = 1.0; - scale_lit->eval().visit([&](const auto s) { - // CK only supports single-valued scale - if(std::all_of( - s.begin() + 1, s.end(), [&](auto v) { return float_equal(v, s.front()); })) - scale = s.front(); - else - return; - }); + if(auto scale_lit = r.instructions.find("scale"); scale_lit != r.instructions.end()) + { + scale_lit->second->eval().visit([&](const auto s) { + // CK only supports single-valued scale + if(std::all_of( + s.begin() + 1, s.end(), [&](auto v) { return float_equal(v, s.front()); })) + scale = s.front(); + else + return; + }); + } - auto inputs = gemm1_ins->inputs(); // A, B + auto inputs = gemm1_ins->inputs(); // A, B + if(auto bias = r.instructions.find("bias"); bias != r.instructions.end()) + { + inputs.push_back(bias->second); + } inputs.push_back(gemm2_ins->inputs().back()); // B1 mpm.get_module().replace_instruction( diff --git a/test/verify/gemm_softmax_gemm_relu.cpp b/test/verify/gemm_softmax_gemm_relu.cpp index f0bdd46460a..f473d5ca036 100644 --- a/test/verify/gemm_softmax_gemm_relu.cpp +++ b/test/verify/gemm_softmax_gemm_relu.cpp @@ -27,7 +27,8 @@ #include #include -struct gemm_softmax_gemm_relu : verify_program +template +struct gemm_softmax_gemm_relu : verify_program> { migraphx::program create_program() const { @@ -41,16 +42,26 @@ struct gemm_softmax_gemm_relu : verify_program auto b1 = mm->add_parameter("3", m1_shape); std::vector eights(m2_elements, 0.125); auto eight = mm->add_literal(migraphx::literal{m2_shape, eights}); - std::vector zeros(m2_elements, 0); - auto zero = mm->add_literal(migraphx::literal{m2_shape, zeros}); b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b); - auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b); - auto scale = mm->add_instruction(migraphx::make_op("mul"), gemm1, eight); - auto bias = mm->add_instruction(migraphx::make_op("add"), scale, zero); - auto softmax = mm->add_instruction(migraphx::make_op("softmax", {{"axis", 3}}), bias); + auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto scale = mm->add_instruction(migraphx::make_op("mul"), gemm1, eight); + + std::optional add_bias{std::nullopt}; + if constexpr(with_bias) + { + std::vector one_tenth(m2_elements, 0.1); + auto bias = mm->add_literal(migraphx::literal{m2_shape, one_tenth}); + add_bias = mm->add_instruction(migraphx::make_op("add"), scale, bias); + } + + auto softmax = mm->add_instruction(migraphx::make_op("softmax", {{"axis", 3}}), + with_bias ? add_bias.value() : scale); auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), softmax, b1); mm->add_instruction(migraphx::make_op("relu"), gemm2); return p; } }; + +template struct gemm_softmax_gemm_relu; +template struct gemm_softmax_gemm_relu;