Skip to content

Commit

Permalink
Added bias term to attenOp for rocMLIR
Browse files: browse the repository at this point in the history
  • Loading branch information
ravil-mobile committed Feb 15, 2024
1 parent f413364 commit 54a6c4a
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions src/targets/gpu/prefuse_ops.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -170,7 +170,10 @@ struct find_gemm_softmax_gemm
match::name("dot")(match::any_of(is_ck_gemm(), is_mlir_gemm()).bind("gemm1")));
auto mul = match::name("mul")(
match::nargs(2), match::either_arg(0, 1)(match::is_constant().bind("scale"), gemm1));
auto softmax = match::name("softmax")(match::arg(0)(mul)).bind("softmax");
auto add = match::name("add")(match::nargs(2),
match::any_arg(0, 1)(match::none_of(mul).bind("bias")));
auto softmax =
match::name("softmax")(match::arg(0)(match::any_of(mul, add))).bind("softmax");

return match::name("dot")(match::any_of(is_ck_gemm(), is_mlir_gemm()).bind("gemm2"))(
match::arg(0)(softmax));
Expand All @@ -193,7 +196,11 @@ struct find_gemm_softmax_gemm
return;
});

auto inputs = gemm1_ins->inputs(); // A, B
auto inputs = gemm1_ins->inputs(); // A, B
if(auto it = r.instructions.find("bias"); it != r.instructions.end())
{
inputs.push_back(it->second);
}
inputs.push_back(gemm2_ins->inputs().back()); // B1

mpm.get_module().replace_instruction(
Expand Down

0 comments on commit 54a6c4a

Please sign in to comment.