Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inject desired pattern for handling Transpose for fp8 gemm rewrite #17440

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 52 additions & 7 deletions xla/service/gpu/transforms/gemm_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ limitations under the License.
#include "xla/layout.h"
#include "xla/literal.h"
#include "xla/literal_util.h"
#include "xla/permutation_util.h"
#include "xla/primitive_util.h"
#include "xla/service/algorithm_util.h"
#include "xla/service/gpu/backend_configs.pb.h"
Expand Down Expand Up @@ -361,7 +362,8 @@ std::optional<MatchedFp8Param> MatchFp8Param(HloInstruction *instr) {
// dimension. There must be only one contracting and only one non-contracting
// dimension. Keeps the layout the same.
HloInstruction *TransposeMatrix(HloInstruction *instr, int64_t contracting_dim,
absl::Span<const int64_t> batch_dims) {
absl::Span<const int64_t> batch_dims,
bool col_maj = false) {
wenscarl marked this conversation as resolved.
Show resolved Hide resolved
// Identify the dimensional order which describes a transpose of the
// contracting and non-contracting dimensions of the GEMM.
std::vector<int64_t> permutation(instr->shape().dimensions_size(), -1);
Expand All @@ -376,13 +378,52 @@ HloInstruction *TransposeMatrix(HloInstruction *instr, int64_t contracting_dim,
non_contracting_dim = i;
}
}
permutation[non_contracting_dim] = contracting_dim;
permutation[contracting_dim] = non_contracting_dim;
if (!col_maj) {
permutation[non_contracting_dim] = contracting_dim;
permutation[contracting_dim] = non_contracting_dim;

Shape new_shape = ShapeUtil::PermuteDimensions(permutation, instr->shape());
*new_shape.mutable_layout() = instr->shape().layout();
return instr->AddInstruction(
HloInstruction::CreateTranspose(new_shape, instr, permutation));
Shape new_shape = ShapeUtil::PermuteDimensions(permutation, instr->shape());
*new_shape.mutable_layout() = instr->shape().layout();
wenscarl marked this conversation as resolved.
Show resolved Hide resolved

return instr->AddInstruction(
HloInstruction::CreateTranspose(new_shape, instr, permutation));
}

Shape normalized_input_shape =
wenscarl marked this conversation as resolved.
Show resolved Hide resolved
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
instr->shape());
auto a0 = MakeBitcastHlo(instr, normalized_input_shape);

int new_contracting_dim = -1;
int new_non_contracting_dim = -1;
for (int i = 0; i < instr->shape().dimensions_size(); ++i) {
auto dim = LayoutUtil::Major(instr->shape().layout(), i);
if (dim == contracting_dim) {
new_contracting_dim = i;
} else if (dim == non_contracting_dim) {
new_non_contracting_dim = i;
} else {
// Discard the batch dimensions.
permutation[i] = i;
}
}

permutation[new_non_contracting_dim] = new_contracting_dim;
permutation[new_contracting_dim] = new_non_contracting_dim;

Shape transpose_shape =
wenscarl marked this conversation as resolved.
Show resolved Hide resolved
ShapeUtil::PermuteDimensions(permutation, a0->shape());
*transpose_shape.mutable_layout() = a0->shape().layout();
HloInstruction *normalized_transpose = instr->AddInstruction(
HloInstruction::CreateTranspose(transpose_shape, a0, permutation));
std::vector<int64_t> layout_permuation(
instr->shape().layout().minor_to_major().begin(),
instr->shape().layout().minor_to_major().end());
absl::c_reverse(layout_permuation);
auto inv_perm = InversePermutation(layout_permuation);
Shape final_shape = ShapeUtil::PermuteDimensions(inv_perm, transpose_shape);
*final_shape.mutable_layout() = instr->shape().layout();
return MakeBitcastHlo(normalized_transpose, final_shape);
}

// If the bias is a sequence of ops that depend only on broadcasts of
Expand Down Expand Up @@ -1223,8 +1264,12 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
} else {
dim_nums->set_lhs_contracting_dimensions(0, num_batch_dims);
}

a.fp8_input =
TransposeMatrix(a.fp8_input, a_contracting_dims[0], a_batch_dims);

a.fp8_input = RegulateColMajorTransposeMatrixF8(
wenscarl marked this conversation as resolved.
Show resolved Hide resolved
a.fp8_input, a_contracting_dims[0], a_batch_dims);
}

// Similarly, cuBLASLt requires the second operand to be column-major, so
Expand Down
Loading