Fix VNNI affine maps (#998)
Retires the incorrect VNNI affine map representation that used `floordiv 2`
and replaces it with an extra, separate VNNI dimension in matrix operand A:
`expand <M x K> into <M x K/VNNI x VNNI>`.

The updated VNNI representation is propagated through matchers,
lowerings, and tests.
A new test is added to verify that lowering through Linalg-to-SCF loops
and lowering through XSMM produce matching results.
adam-smnk authored Jan 10, 2025
1 parent bd4981e commit d42506b
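To make the representational change concrete, here is a small, self-contained sketch (not part of the commit) that builds both forms of operand A's indexing map with the MLIR C++ API. The 4-d iteration space, the dimension order, and the blocking factor of 2 are assumptions chosen for illustration.

```cpp
// Illustrative only: contrasts the retired floordiv-based VNNI map for
// operand A with the new form that uses an explicit, separate VNNI dimension.
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"

int main() {
  mlir::MLIRContext ctx;
  mlir::Builder b(&ctx);

  // Assumed loop order: (m, n, kOuter, vnni).
  mlir::AffineExpr m = b.getAffineDimExpr(0);
  mlir::AffineExpr k = b.getAffineDimExpr(2);
  mlir::AffineExpr vnni = b.getAffineDimExpr(3);

  // Retired representation: A keeps its <M x K> shape and the map folds the
  // VNNI packing back into K with a floordiv.
  mlir::AffineMap oldMapA = mlir::AffineMap::get(
      /*dimCount=*/4, /*symbolCount=*/0, {m, k.floorDiv(2)}, &ctx);

  // New representation: A is expanded to <M x K/VNNI x VNNI> and indexed with
  // a plain extra dimension instead of the floordiv.
  mlir::AffineMap newMapA = mlir::AffineMap::get(
      /*dimCount=*/4, /*symbolCount=*/0, {m, k, vnni}, &ctx);

  oldMapA.dump(); // (d0, d1, d2, d3) -> (d0, d2 floordiv 2)
  newMapA.dump(); // (d0, d1, d2, d3) -> (d0, d2, d3)
  return 0;
}
```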
Showing 30 changed files with 531 additions and 265 deletions.
7 changes: 4 additions & 3 deletions include/TPP/IR/MatcherUtils.h
@@ -61,10 +61,11 @@ bool isTwoDFillOpWithZeros(linalg::LinalgOp linalgOp,
SmallVectorImpl<Value> *capturedOperands = nullptr);

// Return a pair where the first member is true if and only if the operation
// represents a brgemm in VNNI layout. The second member tells if the brgemm has
// the batch dimension; it has meaning only if the first field is valid.
// represents a matmul (GEMM or BRGEMM) in VNNI layout. The second member tells
// if the brgemm has the batch dimension; it has meaning only if the first field
// is valid.
std::pair<bool, bool>
isBrgemmVnniOp(linalg::GenericOp linalgOp,
isMatmulVnniOp(linalg::GenericOp linalgOp,
SmallVectorImpl<Value> *capturedOperands = nullptr);

} // namespace utils
11 changes: 5 additions & 6 deletions include/TPP/Transforms/Utils/VNNIUtils.h
@@ -22,7 +22,7 @@ class AffineMap;
class VectorType;

namespace linalg {
class GenericOp;
class LinalgOp;
} // namespace linalg

namespace vnni {
@@ -46,11 +46,10 @@ bool isInVnniLayout(VnniOperandRank expectedRank, VectorType vector);

bool isInVnniLayout(int64_t expectedRank, VectorType vector);

// Return the first AffineDimExpr in the map `affineMap`
// with a VNNI layout pattern (AffineDimExpr floordiv VNNI).
FailureOr<AffineDimExpr> isInVnniLayout(linalg::GenericOp linalgOp,
AffineMap affineMap,
int64_t blockingFactor);
// Return true if the operation is in VNNI layout.
// Optionally, the check can be constrained to a specific VNNI blocking factor.
bool isInVnniLayout(linalg::LinalgOp linalgOp,
std::optional<int64_t> blockingFactor = std::nullopt);

} // namespace utils
} // namespace vnni
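A quick, hedged usage sketch for the new `isInVnniLayout` overload declared above (not from the commit); the wrapper function, the example blocking factor, and the `mlir::vnni::utils` qualification are assumptions based on this header.

```cpp
#include <cstdint>
#include <optional>

#include "TPP/Transforms/Utils/VNNIUtils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"

// Hypothetical wrapper, for illustration only.
static bool hasVnniPackedOperands(mlir::linalg::LinalgOp linalgOp,
                                  std::optional<int64_t> factor = std::nullopt) {
  // Without a factor, any valid VNNI blocking is accepted; with one, the
  // operation must use exactly that blocking factor (e.g. 2 for bf16).
  return mlir::vnni::utils::isInVnniLayout(linalgOp, factor);
}
```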
81 changes: 56 additions & 25 deletions lib/TPP/Conversion/ConvertLinalgToXsmm/ConvertLinalgToXsmm.cpp
@@ -453,6 +453,21 @@ static void replaceOpWithGemmLikeOp(RewriterBase &rewriter,
}
auto flags = rewriter.getArrayAttr(gemmFlags);
SmallVector<Value> invokeOperands;
SmallVector<Value> inputs = {linalgOp->getOperands()};

// Collapse VNNI factor dimension for matrix A:
// A <32x8x2> -> A <32x16>
if (brgemmInfo.isVnni) {
auto rankA = cast<ShapedType>(inputs[0].getType()).getRank();
assert(rankA >= 3 && "Invalid A mat rank for VNNI");
SmallVector<ReassociationIndices> reassoc;
for (int64_t index = 0; index < rankA - 2; index++)
reassoc.push_back({index});
reassoc.push_back(ReassociationIndices{rankA - 2, rankA - 1});

inputs[0] =
rewriter.create<memref::CollapseShapeOp>(loc, inputs[0], reassoc);
}

if (batch != 0) {
DenseI64ArrayAttr dims = DenseI64ArrayAttr::get(
@@ -463,8 +478,7 @@ static void replaceOpWithGemmLikeOp(RewriterBase &rewriter,
Value batchDim = rewriter.create<arith::ConstantOp>(
loc, integer64, rewriter.getIntegerAttr(integer64, batch));
invokeOperands.push_back(dispatched);
invokeOperands.append(linalgOp->getOperands().begin(),
linalgOp->getOperands().end());
invokeOperands.append(inputs);
invokeOperands.push_back(batchDim);
rewriter.replaceOpWithNewOp<xsmm::BrgemmOp>(linalgOp, dtype,
invokeOperands);
@@ -474,8 +488,7 @@ static void replaceOpWithGemmLikeOp(RewriterBase &rewriter,
Value dispatched = rewriter.create<xsmm::GemmDispatchOp>(
loc, integer64, dims, flags, dtype);
invokeOperands.push_back(dispatched);
invokeOperands.append(linalgOp->getOperands().begin(),
linalgOp->getOperands().end());
invokeOperands.append(inputs);
rewriter.replaceOpWithNewOp<xsmm::GemmOp>(linalgOp, dtype, invokeOperands);
}
}
@@ -502,7 +515,7 @@ checkStructure(linalg::LinalgOp linalgOp) {
return failure();
}
if (contractionDims->m.size() != 1 || contractionDims->n.size() != 1 ||
(contractionDims->k.size() != 2 && contractionDims->k.size() != 1) ||
contractionDims->k.size() > 3 || contractionDims->k.size() < 1 ||
contractionDims->batch.size() != 0) {
LLVM_DEBUG(llvm::dbgs() << "[checkStructure] Wrong dimensions\n");
return failure();
@@ -575,14 +588,16 @@ static FailureOr<BrgemmInfo> checkAccess(linalg::LinalgOp linalgOp, unsigned m,
auto loops = linalgOp.computeStaticLoopSizes();
int64_t batchVal = (batchPos) ? loops[batchPos.value()] : 0;

bool isVnni = vnni::utils::isInVnniLayout(linalgOp);

BrgemmInfo info{loops[m], loops[n], loops[k], batchVal, *lda,
*ldb, *ldc, strideA, strideB};
*ldb, *ldc, strideA, strideB, isVnni};
return info;
}

// Check if the given generic is mappable to a brgemm xsmm op.
// - It is a contraction, with:
// -- 1 m and 1 n and 2 k dimensions.
// -- 1 m, 1 n, and 2 or 3 (VNNI) k dimensions.
// -- m appears on the LHS and OUT but not in RHS.
// -- n appears on the RHS and OUT but not in LHS.
// -- k and k' appear on the RHS and LHS but not OUT.
@@ -600,8 +615,15 @@ static FailureOr<BrgemmInfo> isMappableToBrgemm(linalg::LinalgOp linalgOp) {
unsigned m = contractionDims->m[0];
unsigned n = contractionDims->n[0];
unsigned k = contractionDims->k.back();

// Check if there is a batch reduce dimension.
// At least one K-dim is the GEMM reduction.
// In case of VNNI layout, there is additional reduction dimension
// representing VNNI blocking factor.
std::optional<unsigned> batch;
if (contractionDims->k.size() == 2)
unsigned numBrgemmReductionDims =
vnni::utils::isInVnniLayout(linalgOp) ? 3 : 2;
if (contractionDims->k.size() == numBrgemmReductionDims)
batch = contractionDims->k.front();

LLVM_DEBUG(llvm::dbgs() << "[isMappableToBrgemm] Candidate dims: "
@@ -772,17 +794,23 @@ makeMinorDimensionsInnerMost(RewriterBase &rewriter, linalg::GenericOp linalgOp,
return linalgOp;
}

if (!isInnerMostDim(operandA, *minorKInCodomainOpA)) {
bool isVnni = vnni::utils::isInVnniLayout(linalgOp);

if (!isVnni && !isInnerMostDim(operandA, *minorKInCodomainOpA)) {
LLVM_DEBUG(llvm::dbgs()
<< "[makeMinorDimensionsInnerMost] emit transpose for A\n");
assert(isInnerMostDim(operandA, *minorMInCodomainOpA));
if (!isInnerMostDim(operandA, *minorMInCodomainOpA))
return failure();
emitTransposeOnOperand(rewriter, linalgOp, operandA, *minorKInCodomainOpA,
*minorMInCodomainOpA);
}
if (!isInnerMostDim(operandB, *minorNInCodomainOpB)) {
// Do not inject transposes in case of VNNI format.
// Otherwise, it breaks later VNNI layout validation.
if (!isVnni && !isInnerMostDim(operandB, *minorNInCodomainOpB)) {
LLVM_DEBUG(llvm::dbgs()
<< "[makeMinorDimensionsInnerMost] emit transpose for B\n");
assert(isInnerMostDim(operandB, *minorKInCodomainOpB));
if (!isInnerMostDim(operandB, *minorKInCodomainOpB))
return failure();
emitTransposeOnOperand(rewriter, linalgOp, operandB, *minorKInCodomainOpB,
*minorNInCodomainOpB);
}
@@ -795,7 +823,7 @@ void ConvertLinalgToXsmm::runOnOperation() {
IRRewriter rewriter(&getContext());

// Enable conversion for linalg.generic to XSMM Brgemm if possible.
auto res = getOperation()->walk([&](linalg::GenericOp genericOp) {
getOperation()->walk([&](linalg::GenericOp genericOp) {
auto contractionDims = checkStructure(genericOp);
// If the generic does not match the structure of a Brgemm op, skip it.
if (failed(contractionDims))
@@ -804,22 +832,18 @@ void ConvertLinalgToXsmm::runOnOperation() {
unsigned n = contractionDims->n[0];
unsigned k = contractionDims->k.back();
std::optional<unsigned> batch;
if (contractionDims->k.size() == 2)
if (contractionDims->k.size() == 3)
contractionDims->k.front();

if (failed(checkAccess(genericOp, m, n, k, batch))) {
// The generic is a Brgemm but the strides of the selected dims (m, n, k)
// are not unit strides. Inject transposes to bring them innermost.
if (failed(makeMinorDimensionsInnerMost(rewriter, genericOp, m, n, k))) {
return WalkResult::interrupt();
return WalkResult::skip();
}
}
return WalkResult::advance();
});
if (res.wasInterrupted()) {
LLVM_DEBUG(llvm::dbgs() << "pass failed!\n");
return signalPassFailure();
}
SmallVector<StringRef> skipPatterns(skipOperations.begin(),
skipOperations.end());
tpp::populateLinalgToXsmmPatterns(patterns, skipPatterns);
@@ -1069,11 +1093,11 @@ struct ConvertGenericToVnniMatmulLikeOp
return rewriter.notifyMatchFailure(genericOp, "expects buffer semantics");
}

auto [isBrgemmOp, hasBatch] = structured_match::utils::isBrgemmVnniOp(
auto [isMatmulVnni, hasBatch] = structured_match::utils::isMatmulVnniOp(
genericOp, /*operands=*/nullptr);
if (!isBrgemmOp) {
if (!isMatmulVnni) {
return rewriter.notifyMatchFailure(
genericOp, "expects an operation mappable to brgemm");
genericOp, "expects an operation mappable to VNNI contraction");
}

Value bufferA = genericOp.getDpsInputs()[0];
@@ -1085,7 +1109,15 @@ struct ConvertGenericToVnniMatmulLikeOp
int64_t kPos = 1;
if (hasBatch)
kPos++;
int64_t k = cast<ShapedType>(bufferA.getType()).getShape()[kPos];
// Take the whole reduction dim size. Account for the VNNI factor (ensured
// by the earlier check) that splits the K dim in the shape.
std::optional<int64_t> vnniFactor =
vnni::utils::getVnniBlockingFactor(bufferB.getType());
if (!vnniFactor)
return rewriter.notifyMatchFailure(genericOp,
"failed to determine VNNI factor");
int64_t k =
cast<ShapedType>(bufferA.getType()).getShape()[kPos] * *vnniFactor;
int64_t batch = 0;
if (hasBatch)
batch = cast<ShapedType>(bufferA.getType()).getShape()[0];
@@ -1107,8 +1139,7 @@ struct ConvertGenericToVnniMatmulLikeOp
if (hasBatch)
leadingDimPosOnAandB++;
int64_t lda = (*stridesOnLhs)[leadingDimPosOnAandB];
int64_t ldb = (*stridesOnRhs)[leadingDimPosOnAandB] /
*vnni::utils::getVnniBlockingFactor(bufferB.getType());
int64_t ldb = (*stridesOnRhs)[leadingDimPosOnAandB] / *vnniFactor;
int64_t ldc = (*stridesOnOutput)[0];

BrgemmInfo brgemmInfo{m, n, k, batch, lda,
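The K-size and leading-dimension arithmetic in `ConvertGenericToVnniMatmulLikeOp` above can be sanity-checked with a tiny standalone sketch; the shapes, the row-major strides, and a `<K/VNNI x N x VNNI>` packing for B are illustrative assumptions, not values taken from the commit.

```cpp
#include <cassert>
#include <cstdint>

int main() {
  const int64_t vnniFactor = 2;          // assumed bf16-style packing factor
  const int64_t shapeA[] = {32, 8, 2};   // A packed as <M x K/VNNI x VNNI>
  const int64_t stridesB[] = {64, 2, 1}; // B packed as <K/VNNI x N x VNNI>, N = 32

  // Full GEMM reduction size: the packed K dimension times the VNNI factor,
  // mirroring `k = shape[kPos] * *vnniFactor` in the lowering above.
  const int64_t k = shapeA[1] * vnniFactor; // 8 * 2 = 16

  // Leading dimension of B in unpacked elements, mirroring
  // `ldb = stride / *vnniFactor`.
  const int64_t ldb = stridesB[0] / vnniFactor; // 64 / 2 = 32 (= N)

  assert(k == 16 && ldb == 32);
  return 0;
}
```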
31 changes: 16 additions & 15 deletions lib/TPP/IR/MatcherUtils.cpp
@@ -35,9 +35,9 @@ getIteratorPos(linalg::LinalgOp linalgOp, AffineMap indexingMap,
return res;
}

// Return true if the linalg.generic can be mapped to a brgemm in VNNI
// format.
std::pair<bool, bool> isBrgemmVnniOp(linalg::GenericOp linalgOp,
// Return true if the linalg.generic can be mapped to a matmul (GEMM or BRGEMM)
// in VNNI format.
std::pair<bool, bool> isMatmulVnniOp(linalg::GenericOp linalgOp,
SmallVectorImpl<Value> *operands) {
bool hasBatch = false;
auto blockingFactor =
@@ -56,8 +56,8 @@ std::pair<bool, bool> isBrgemmVnniOp(linalg::GenericOp linalgOp,
.operation(NumOfLoops(_OR(EqualsTo(5), EqualsTo(4))))
.input(MatchAll(), HasStaticShape())
.output(MatchAll(), HasStaticShape())
.input(MatchOne(0), HasMap(BroadcastableProjectedPermutation(), &mapOperandA))
.input(MatchOne(1), HasMap(Any(), &mapOperandB))
.input(MatchOne(0), HasMap(ProjectedPermutation(), &mapOperandA))
.input(MatchOne(1), HasMap(ProjectedPermutation(), &mapOperandB))
.output(MatchOne(0), HasMap(BroadcastableProjectedPermutation(), &mapOperandC))
.region(MatchOne(0),
WithOpChain<arith::MulFOp, arith::AddFOp>(operands));
@@ -82,17 +82,18 @@ std::pair<bool, bool> isBrgemmVnniOp(linalg::GenericOp linalgOp,

llvm::SmallVector<int64_t> operandAPosIterRed = getIteratorPos(
linalgOp, mapOperandA, mlir::utils::IteratorType::reduction);
if (operandAPosIterRed.size() != 2 && operandAPosIterRed.size() != 1)
unsigned numRedItersA = operandAPosIterRed.size();
if (numRedItersA != 3 && numRedItersA != 2)
return std::make_pair(false, hasBatch);

// Check if there is an extra outer batch reduce K-dim.
// For VNNI format:
// - one inner K-dim is the GEMM reduction
// - one inner K-dim is the VNNI blocking factor
int64_t batchRedIter = std::numeric_limits<int64_t>::max();
int64_t kRedIter = std::numeric_limits<int64_t>::max();
if (operandAPosIterRed.size() == 2) {
if (numRedItersA == 3) {
batchRedIter = operandAPosIterRed[0];
kRedIter = operandAPosIterRed[1];
hasBatch = true;
} else {
kRedIter = operandAPosIterRed[0];
}

// Operand B: One parallel iterator (j) and three reduction ones (batch,
@@ -112,10 +113,10 @@ std::pair<bool, bool> isBrgemmVnniOp(linalg::GenericOp linalgOp,
return std::make_pair(false, hasBatch);
}

auto vnniDim =
vnni::utils::isInVnniLayout(linalgOp, mapOperandB, *blockingFactor);
bool isBrgemmOp = succeeded(vnniDim) && vnniDim->getPosition() == kRedIter;
return std::make_pair(isBrgemmOp, hasBatch);
// At this point, the operation is a valid matmul contraction.
// Finally, ensure that it is in VNNI layout.
bool isVnniMatmul = vnni::utils::isInVnniLayout(linalgOp, *blockingFactor);
return std::make_pair(isVnniMatmul, hasBatch);
}

// Return true if all the operand have the same type, i.e., no implicit
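The reduction-iterator bookkeeping in `isMatmulVnniOp` above boils down to a simple classification; the sketch below is a simplified restatement with hypothetical names, not the matcher itself: two reduction iterators on operand A mean K plus VNNI (plain GEMM), three mean batch plus K plus VNNI (BRGEMM).

```cpp
#include <cassert>
#include <cstddef>
#include <utility>

// Simplified, illustrative restatement of the iterator counting above.
// Returns {isVnniMatmul, hasBatch}; the second member only matters when the
// first is true, matching the documented contract.
static std::pair<bool, bool> classifyVnniContraction(size_t numRedItersA,
                                                     bool vnniLayoutOk) {
  if (numRedItersA != 2 && numRedItersA != 3)
    return {false, false};
  const bool hasBatch = (numRedItersA == 3); // extra batch-reduce dimension
  return {vnniLayoutOk, hasBatch};
}

int main() {
  assert(classifyVnniContraction(2, true) == std::make_pair(true, false));  // GEMM
  assert(classifyVnniContraction(3, true) == std::make_pair(true, true));   // BRGEMM
  assert(classifyVnniContraction(1, true) == std::make_pair(false, false)); // reject
  return 0;
}
```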