Skip to content

Commit

Permalink
Improve VNNI APIs and verification
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-smnk committed Jan 16, 2025
1 parent c7d5a5d commit ae2dfa6
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 29 deletions.
13 changes: 6 additions & 7 deletions include/TPP/Transforms/Utils/VNNIUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,25 +36,24 @@ enum class VnniOperandRank {
BRGEMM_OUTS = 3
};

// Return the VNNI blocking factor.
// Return the VNNI blocking factor if it can be determined for the given type,
// or zero otherwise.
// Optionally, an operation can be provided to give access to DLTI.
std::optional<int64_t> getVnniBlockingFactor(Type type,
Operation *op = nullptr);
unsigned getVnniBlockingFactor(Type type, Operation *op = nullptr);

// Return true if the shaped type is in VNNI layout with rank `expectedRank`.
// Optionally, the check can be constrained to a specific VNNI blocking factor.
bool isInVnniLayout(VnniOperandRank expectedRank, ShapedType shape,
std::optional<int64_t> blockingFactor = std::nullopt);
unsigned blockingFactor = 0);

// Return true if the shaped type is in VNNI layout with rank `expectedRank`.
// Optionally, the check can be constrained to a specific VNNI blocking factor.
bool isInVnniLayout(int64_t expectedRank, ShapedType shape,
std::optional<int64_t> blockingFactor = std::nullopt);
unsigned blockingFactor = 0);

// Return true if the linalg operation is in VNNI layout.
// Optionally, the check can be constrained to a specific VNNI blocking factor.
bool isInVnniLayout(linalg::LinalgOp linalgOp,
std::optional<int64_t> blockingFactor = std::nullopt);
bool isInVnniLayout(linalg::LinalgOp linalgOp, unsigned blockingFactor = 0);

} // namespace utils
} // namespace vnni
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1068,9 +1068,10 @@ struct ConvertVnniPacking : public OpRewritePattern<linalg::TransposeOp> {
if (failed(stridesOnOutput) || stridesOnOutput->back() != 1)
return failure();
// Adjust ldo based on the VNNI factor.
unaryInfo.ldo =
stridesOnOutput->front() /
*vnni::utils::getVnniBlockingFactor(out.getType(), transposeOp);
auto vnniFactor =
vnni::utils::getVnniBlockingFactor(out.getType(), transposeOp);
assert(vnniFactor && "Failed to get VNNI blocking factor");
unaryInfo.ldo = stridesOnOutput->front() / vnniFactor;
auto flags = rewriter.getArrayAttr(xsmm::UnaryFlagsAttr::get(
rewriter.getContext(), xsmm::UnaryFlags::NONE));
xsmm::UnaryKindAttr kind =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ convertTransposeOp(PatternRewriter &rewriter, Operation *transposeOp,
if (vnni::utils::isInVnniLayout(vnni::utils::VnniOperandRank::TRANSPOSE,
outType)) {
// Adjust ldo based on the VNNI factor.
auto vnniFactor = *vnni::utils::getVnniBlockingFactor(outType, transposeOp);
auto vnniFactor = vnni::utils::getVnniBlockingFactor(outType, transposeOp);
assert(vnniFactor && "Failed to get VNNI blocking factor");
unaryInfo.ldo = unaryInfo.ldo / vnniFactor;
} else {
std::swap(unaryInfo.m, unaryInfo.n);
Expand Down
4 changes: 2 additions & 2 deletions lib/TPP/Transforms/ToBlockLayoutAndBack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ mlir::linalgx::packVNNIMatmulOp(RewriterBase &rewriter,

Location loc = matmulOp.getLoc();
SmallVector<OpFoldResult> tilesOnSmallK = {
rewriter.getI64IntegerAttr(*blockingFactor)};
rewriter.getI64IntegerAttr(blockingFactor)};
SmallVector<std::pair<Value, unsigned>> kOperands;
matmulOp.mapIterationSpaceDimToAllOperandDims(dims->k.back(), kOperands);
if (kOperands.size() != 2)
Expand Down Expand Up @@ -416,7 +416,7 @@ mlir::linalgx::packVNNIBRGemmOp(RewriterBase &rewriter,
"unsupported blocking factor for type");
}
SmallVector<OpFoldResult> tilesOnK = {
rewriter.getI64IntegerAttr(*blockingFactor)};
rewriter.getI64IntegerAttr(blockingFactor)};

Location loc = brgemmOp.getLoc();
// Reshape input A.
Expand Down
34 changes: 18 additions & 16 deletions lib/TPP/Transforms/Utils/VNNIUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ namespace mlir {
namespace vnni {
namespace utils {

std::optional<int64_t> getVnniBlockingFactor(Type type, Operation *op) {
int64_t blockingFactor = 0;
unsigned getVnniBlockingFactor(Type type, Operation *op) {
unsigned blockingFactor = 0;

auto elementType = getElementTypeOrSelf(type);
if (elementType.isBF16()) {
Expand All @@ -37,14 +37,14 @@ std::optional<int64_t> getVnniBlockingFactor(Type type, Operation *op) {
}
}

if (blockingFactor != 0 && blockingFactor % 2 == 0)
return blockingFactor;
// Ensure that the factor is divisible by two.
if (blockingFactor % 2 != 0)
return 0;

return std::nullopt;
return blockingFactor;
}

bool isInVnniLayout(linalg::LinalgOp linalgOp,
std::optional<int64_t> blockingFactor) {
bool isInVnniLayout(linalg::LinalgOp linalgOp, unsigned blockingFactor) {
// Narrow down type operations - VNNI only applies to contractions.
if (!linalg::isaContractionOpInterface(linalgOp))
return false;
Expand Down Expand Up @@ -101,10 +101,12 @@ bool isInVnniLayout(linalg::LinalgOp linalgOp,
// - statically known
// - multiple of 2 or equal to the specified factor
auto vnniDimSize = typeB.getShape().back();
if (!(vnniDimSize != ShapedType::kDynamic &&
typeA.getShape().back() == vnniDimSize &&
(blockingFactor ? vnniDimSize == *blockingFactor
: vnniDimSize % 2 == 0)))
if (vnniDimSize == ShapedType::kDynamic || vnniDimSize == 0 ||
vnniDimSize % 2 != 0)
return false;
if (typeA.getShape().back() != vnniDimSize)
return false;
if (blockingFactor && vnniDimSize != blockingFactor)
return false;

// The split reduction dimension size should also match.
Expand All @@ -115,20 +117,20 @@ bool isInVnniLayout(linalg::LinalgOp linalgOp,
}

bool isInVnniLayout(VnniOperandRank expectedRank, ShapedType shape,
std::optional<int64_t> blockingFactor) {
unsigned blockingFactor) {
return isInVnniLayout(static_cast<int64_t>(expectedRank), shape,
blockingFactor);
}

bool isInVnniLayout(int64_t expectedRank, ShapedType shape,
std::optional<int64_t> blockingFactor) {
unsigned blockingFactor) {
if (shape.getRank() != expectedRank || !shape.getElementType().isBF16())
return false;

if (shape.getShape().back() % 2 != 0)
auto vnniDim = shape.getShape().back();
if (vnniDim == 0 || vnniDim % 2 != 0)
return false;

if (blockingFactor && shape.getShape().back() != *blockingFactor)
if (blockingFactor && vnniDim != blockingFactor)
return false;

return true;
Expand Down

0 comments on commit ae2dfa6

Please sign in to comment.