Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

match gemm_softmax_gemm when there is no scale #2748

Merged
merged 13 commits into from
Feb 16, 2024
26 changes: 15 additions & 11 deletions src/targets/gpu/prefuse_ops.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -170,7 +170,8 @@ struct find_gemm_softmax_gemm
match::name("dot")(match::any_of(is_ck_gemm(), is_mlir_gemm()).bind("gemm1")));
auto mul = match::name("mul")(
match::nargs(2), match::either_arg(0, 1)(match::is_constant().bind("scale"), gemm1));
auto softmax = match::name("softmax")(match::arg(0)(mul)).bind("softmax");
auto softmax =
match::name("softmax")(match::arg(0)(match::any_of(mul, gemm1))).bind("softmax");

return match::name("dot")(match::any_of(is_ck_gemm(), is_mlir_gemm()).bind("gemm2"))(
match::arg(0)(softmax));
Expand All @@ -181,17 +182,20 @@ struct find_gemm_softmax_gemm
auto ins = r.result;
auto gemm2_ins = r.instructions["gemm2"];
auto gemm1_ins = r.instructions["gemm1"];
auto scale_lit = r.instructions["scale"];

float scale = 1.0;
scale_lit->eval().visit([&](const auto s) {
// CK only supports single-valued scale
if(std::all_of(
s.begin() + 1, s.end(), [&](auto v) { return float_equal(v, s.front()); }))
scale = s.front();
else
return;
});
if(contains(r.instructions, "scale"))
{
auto scale_lit = r.instructions["scale"];
scale_lit->eval().visit([&](const auto s) {
// CK only supports single-valued scale
if(std::all_of(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be better to flip the logic so that it returns if scale values are different.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking something like this

Suggested change
if(std::all_of(
if(not std::all_of(....) return;
scale = s.front();

Copy link
Contributor Author

@shivadbhavsar shivadbhavsar Feb 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But thats inside a visit, that return will not exit the apply method for this matcher
Oh you dont mean to modify the current functionallity, just rewrite it?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh yeah I saw that too. But looks like you rewrote it and looks fine now.

s.begin() + 1, s.end(), [&](auto v) { return float_equal(v, s.front()); }))
scale = s.front();
else
return;
});
}

auto inputs = gemm1_ins->inputs(); // A, B
inputs.push_back(gemm2_ins->inputs().back()); // B1
Expand Down
Loading