Skip to content

Commit

Permalink
Added bias term to attenOp for rocMLIR
Browse files Browse the repository at this point in the history
  • Loading branch information
ravil-mobile committed Feb 16, 2024
1 parent f413364 commit aae95ab
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 31 deletions.
35 changes: 26 additions & 9 deletions src/targets/gpu/fuse_mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/env.hpp>
#include <optional>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
Expand Down Expand Up @@ -103,7 +104,7 @@ struct mlir_op
if(inputs.size() < 2)
MIGRAPHX_THROW("should have at least two inputs.");

auto type = mod->get_output_shapes().front().type();
auto type = mod->get_output_shapes().front().type();
auto mod_params = mod->get_parameter_names();
std::sort(mod_params.begin(), mod_params.end());
std::unordered_map<instruction_ref, shape> mod_ins_shapes;
Expand Down Expand Up @@ -478,15 +479,17 @@ struct find_mlir_standalone_attention_op
static size_t counter = 0;
module_ref mm = mpm.create_module("mlir_" + std::to_string(counter++));
auto gemm_softmax_gemm = r.instructions["gemm_softmax_gemm"];
std::vector<instruction_ref> inputs;
mm->set_bypass();

std::unordered_map<instruction_ref, instruction_ref> ins_map;
auto gemm0_inputs = gemm_softmax_gemm->inputs();
gemm0_inputs.pop_back();
auto orig_inputs = gemm_softmax_gemm->inputs();

std::vector<instruction_ref> gemm0_inputs = {orig_inputs[0], orig_inputs[1]};
auto [gemm0, top_gemm0_inputs] =
fuse_input_ops_and_gemm_based_op(mm, gemm0_inputs, make_op("dot"));

std::vector<instruction_ref> inputs;
inputs.insert(inputs.begin(), top_gemm0_inputs.begin(), top_gemm0_inputs.end());

// handle scale
auto v = gemm_softmax_gemm->get_operator().to_value();
assert(v.contains("scale"));
Expand All @@ -496,18 +499,31 @@ struct find_mlir_standalone_attention_op
make_op("multibroadcast", {{"out_lens", gemm0->get_shape().lens()}}), scale_lit);
auto scaled_gemm0 = mm->add_instruction(make_op("mul"), gemm0, scale_lit_mbcast);

std::optional<instruction_ref> bias{nullopt};
if(orig_inputs.size() == 4)
{
auto bias_input = orig_inputs[2];
instruction_ref bias_param =
mm->add_parameter("bias", bias_input->get_shape().as_standard());
bias = mm->add_instruction(make_op("add"), scaled_gemm0, bias_param);
inputs.push_back(bias_input);
}

auto softmax = mm->add_instruction(
make_op("softmax", {{"axis", gemm0->get_shape().lens().size() - 1}}), scaled_gemm0);
auto [old_upper_v, upper_v_op_stream] =
get_fusable_input_op_stream(gemm_softmax_gemm->inputs()[2]);
make_op("softmax", {{"axis", gemm0->get_shape().lens().size() - 1}}),
bias ? bias.value() : scaled_gemm0);
auto [old_upper_v, upper_v_op_stream] = get_fusable_input_op_stream(orig_inputs.back());
instruction_ref new_upper_v =
mm->add_parameter("z", old_upper_v->get_shape().as_standard());
for(const auto& op : reverse(upper_v_op_stream))
{
new_upper_v = mm->add_instruction(op, {new_upper_v});
}
inputs.push_back(old_upper_v);
auto gemm1 = mm->add_instruction(make_op("dot"), {softmax, new_upper_v});

auto gemm1 = mm->add_instruction(make_op("dot"), {softmax, new_upper_v});

std::unordered_map<instruction_ref, instruction_ref> ins_map;
ins_map[gemm_softmax_gemm] = gemm1;
auto ins_to_replace = gemm1;
auto ins_to_be_replaced = gemm_softmax_gemm;
Expand All @@ -521,6 +537,7 @@ struct find_mlir_standalone_attention_op
ins_to_be_replaced = r.instructions["trailing_pm"];
}
mm->add_return({ins_to_replace});

mpm.get_module().replace_instruction(
ins_to_be_replaced, mlir_op{gemm1->get_operator()}, inputs, {mm});
}
Expand Down
25 changes: 21 additions & 4 deletions src/targets/gpu/include/migraphx/gpu/gemm_softmax_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

#include <migraphx/make_op.hpp>
#include <migraphx/check_shapes.hpp>
#include <sstream>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
Expand Down Expand Up @@ -55,14 +56,30 @@ struct gemm_softmax_gemm
check_shapes{inputs, *this}.same_ndims();
if(inputs.size() < 3)
MIGRAPHX_THROW(name() + ": Expected 3 inputs but got " + to_string(inputs.size()));
auto a = inputs[0];
auto b = inputs[1];
auto b1 = inputs[2];

const bool is_bias_enabled = inputs.size() == 4;
auto a = inputs[0];
auto b = inputs[1];
auto b1 = inputs[is_bias_enabled ? 3 : 2];

for(const auto& input : inputs)
{
check_gemm_shape(input);
}
return op.compute_shape({op.compute_shape({a, b}), b1});
auto gemm0_shape = op.compute_shape({a, b});
if(is_bias_enabled)
{
auto bias_shape = inputs[2];
if(bias_shape != gemm0_shape)
{
std::stringstream err_msg;
err_msg << name() << ": has inconsistent bias size"
<< ". Expected: " << gemm0_shape << ". Given: " << bias_shape;
MIGRAPHX_THROW(err_msg.str());
}
}

return op.compute_shape({gemm0_shape, b1});
}

static bool is_ck_supported_type(shape::type_t t) { return contains({shape::half_type}, t); }
Expand Down
31 changes: 20 additions & 11 deletions src/targets/gpu/prefuse_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,10 @@ struct find_gemm_softmax_gemm
match::name("dot")(match::any_of(is_ck_gemm(), is_mlir_gemm()).bind("gemm1")));
auto mul = match::name("mul")(
match::nargs(2), match::either_arg(0, 1)(match::is_constant().bind("scale"), gemm1));
auto softmax = match::name("softmax")(match::arg(0)(mul)).bind("softmax");
auto add = match::name("add")(match::nargs(2),
match::any_arg(0, 1)(match::none_of(mul).bind("bias")));
auto softmax =
match::name("softmax")(match::arg(0)(match::any_of(mul, add))).bind("softmax");

return match::name("dot")(match::any_of(is_ck_gemm(), is_mlir_gemm()).bind("gemm2"))(
match::arg(0)(softmax));
Expand All @@ -181,19 +184,25 @@ struct find_gemm_softmax_gemm
auto ins = r.result;
auto gemm2_ins = r.instructions["gemm2"];
auto gemm1_ins = r.instructions["gemm1"];
auto scale_lit = r.instructions["scale"];

float scale = 1.0;
scale_lit->eval().visit([&](const auto s) {
// CK only supports single-valued scale
if(std::all_of(
s.begin() + 1, s.end(), [&](auto v) { return float_equal(v, s.front()); }))
scale = s.front();
else
return;
});
if(auto scale_lit = r.instructions.find("scale"); scale_lit != r.instructions.end())
{
scale_lit->second->eval().visit([&](const auto s) {
// CK only supports single-valued scale
if(std::all_of(
s.begin() + 1, s.end(), [&](auto v) { return float_equal(v, s.front()); }))
scale = s.front();
else
return;
});
}

auto inputs = gemm1_ins->inputs(); // A, B
auto inputs = gemm1_ins->inputs(); // A, B
if(auto bias = r.instructions.find("bias"); bias != r.instructions.end())
{
inputs.push_back(bias->second);
}
inputs.push_back(gemm2_ins->inputs().back()); // B1

mpm.get_module().replace_instruction(
Expand Down
24 changes: 17 additions & 7 deletions test/verify/gemm_softmax_gemm_relu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>

struct gemm_softmax_gemm_relu : verify_program<gemm_softmax_gemm_relu>
template <bool with_bias>

Check warning on line 30 in test/verify/gemm_softmax_gemm_relu.cpp

View workflow job for this annotation

GitHub Actions / tidy

invalid case style for value template parameter 'with_bias' [readability-identifier-naming,-warnings-as-errors]
struct gemm_softmax_gemm_relu : verify_program<gemm_softmax_gemm_relu<with_bias>>
{
migraphx::program create_program() const
{
Expand All @@ -41,16 +42,25 @@ struct gemm_softmax_gemm_relu : verify_program<gemm_softmax_gemm_relu>
auto b1 = mm->add_parameter("3", m1_shape);
std::vector<float> eights(m2_elements, 0.125);
auto eight = mm->add_literal(migraphx::literal{m2_shape, eights});
std::vector<float> zeros(m2_elements, 0);
auto zero = mm->add_literal(migraphx::literal{m2_shape, zeros});

b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b);
auto scale = mm->add_instruction(migraphx::make_op("mul"), gemm1, eight);
auto bias = mm->add_instruction(migraphx::make_op("add"), scale, zero);
auto softmax = mm->add_instruction(migraphx::make_op("softmax", {{"axis", 3}}), bias);
auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b);
auto scale = mm->add_instruction(migraphx::make_op("mul"), gemm1, eight);

std::optional<migraphx::instruction_ref> add_bias{std::nullopt};
if constexpr(with_bias)
{
auto bias = mm->add_parameter("4", m1_shape);
add_bias = mm->add_instruction(migraphx::make_op("add"), scale, bias);
}

auto softmax = mm->add_instruction(migraphx::make_op("softmax", {{"axis", 3}}),
with_bias ? add_bias.value() : scale);
auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), softmax, b1);
mm->add_instruction(migraphx::make_op("relu"), gemm2);
return p;
}
};

template struct gemm_softmax_gemm_relu<false>;
template struct gemm_softmax_gemm_relu<true>;

0 comments on commit aae95ab

Please sign in to comment.