Skip to content

Commit

Permalink
Merge branch 'main' into help-fix
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexandreEichenberger authored Oct 17, 2024
2 parents abc902f + 625ba71 commit ae9592d
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 72 deletions.
77 changes: 43 additions & 34 deletions src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
bool neverHas8 = outputDims[E1].isLiteralAndSmallerThan(8);
bool hasOnly64 =
outputDims[E1].isLiteral() && (outputDims[E1].getLiteral() % 64 == 0);
bool hasOnly8 =
outputDims[E1].isLiteral() && (outputDims[E1].getLiteral() % 8 == 0);

// Parallel...
if (enableParallel) {
Expand Down Expand Up @@ -170,6 +172,11 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
IndexExpr inputDataOffset = SymIE(inputOffset);
IndexExpr inputTileOffset = inputDataOffset.floorDiv(64);

// Buffer for small leftovers (used when E1 % 8 != 0)
Value bufferF32;
if (!hasOnly8)
bufferF32 = create.mem.alignedAlloc(bufferType);

// Prefetch
#if PREFETCH_CSU
DimsExpr prefetchAF = inputAF;
Expand Down Expand Up @@ -225,7 +232,7 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
// tile.
for (int64_t i = 0; i < unrollVL; ++i) {
vecF16[i] = create.vec.loadIE(vecF16Type, inputAsTx64,
{SymIE(inputTileOffset), l + (i * archVL)}, {});
{SymIE(inputTileOffset), l + (i * archVL)});
}
// Convert back to f32.
for (int64_t i = 0; i < unrollVL; ++i) {
Expand Down Expand Up @@ -271,7 +278,7 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
// Load f16 values from input via reinterpreted data
// tile.
Value vecF16 = create.vec.loadIE(vecF16Type,
inputAsTx64, {SymIE(inputTileOffset), l}, {});
inputAsTx64, {SymIE(inputTileOffset), l});
// Convert back to f32.
auto convertOp =
rewriter.create<ZLowConvertDLF16ToF32VectorOp>(
Expand All @@ -286,35 +293,37 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
{litArchVLHalf.getValue()});
});
}
// Deal with the last values: compute f32 using simd.
IndexExpr remainingScalarValues = tripCount % archVL;
IndexExpr lastL = tripCount - remainingScalarValues;
Value vecF16 = create.vec.loadIE(vecF16Type, inputAsTx64,
{SymIE(inputTileOffset), lastL}, {});
// Convert back to f32.
auto convertOp =
rewriter.create<ZLowConvertDLF16ToF32VectorOp>(loc, vecF16);
Value vecF32H = convertOp.getResult(0);
Value vecF32L = convertOp.getResult(1);
// Save into archVL value buffer.
Value bufferF32 = create.mem.alignedAlloca(bufferType);
create.vec.storeIE(vecF32H, bufferF32, {litZero});
create.vec.storeIE(vecF32L, bufferF32, {litArchVLHalf});
// Save the remaining values as scalars.
create.scf.forLoop(litZero.getValue(),
remainingScalarValues.getValue(), 1,
[&](SCFBuilder b, ValueRange loopInd) {
MDBuilder create(b);
IndexExprScope innerScope(b, &middleScope);
Value loopIndex = loopInd[0];
IndexExpr l = DimIE(loopIndex);
// Load converted value.
Value f32 = create.krnl.loadIE(bufferF32, {l});
DimsExpr outputAF = SymListIE(inputAF);
outputAF[E1] = outputAF[E1] + SymIE(lastL);
outputAF[E1] = outputAF[E1] + l;
create.krnl.storeIE(f32, alloc, outputAF);
});
if (!hasOnly8) {
// Deal with the last <8 values: compute f32 using simd.
IndexExpr remainingScalarValues = tripCount % archVL;
IndexExpr lastL = tripCount - remainingScalarValues;
Value vecF16 = create.vec.loadIE(
vecF16Type, inputAsTx64, {SymIE(inputTileOffset), lastL});
// Convert back to f32.
auto convertOp =
rewriter.create<ZLowConvertDLF16ToF32VectorOp>(
loc, vecF16);
Value vecF32H = convertOp.getResult(0);
Value vecF32L = convertOp.getResult(1);
// Save into archVL value buffer.
create.vec.storeIE(vecF32H, bufferF32, {litZero});
create.vec.storeIE(vecF32L, bufferF32, {litArchVLHalf});
// Save the remaining values as scalars.
create.scf.forLoop(litZero.getValue(),
remainingScalarValues.getValue(), 1,
[&](SCFBuilder b, ValueRange loopInd) {
MDBuilder create(b);
IndexExprScope innerScope(b, &middleScope);
Value loopIndex = loopInd[0];
IndexExpr l = DimIE(loopIndex);
// Load converted value.
Value f32 = create.krnl.loadIE(bufferF32, {l});
DimsExpr outputAF = SymListIE(inputAF);
outputAF[E1] = outputAF[E1] + SymIE(lastL);
outputAF[E1] = outputAF[E1] + l;
create.krnl.storeIE(f32, alloc, outputAF);
});
}
});
});
rewriter.eraseOp(unstickOp);
Expand Down Expand Up @@ -456,11 +465,11 @@ class StickExpansionPattern : public OpRewritePattern<ZLowStickOp> {
MDBuilder create(b);
IndexExprScope outerScope(create.krnl, &allocScope);
DimsExpr outerIndices;
getIndexExprList<SymIE>(loopInd, outerIndices);
getIndexExprList<DimIE>(loopInd, outerIndices);
DimsExpr memAF = outerIndices;
memAF[E1] = memAF[E1] * 64; // Loop index for E1 is in tiles of 64.
Value allocOffset = create.krnl.getLinearOffsetIndexIE(alloc, memAF);
IndexExpr allocTileIndex = SymIE(allocOffset).floorDiv(64);
IndexExpr allocTileIndex = DimIE(allocOffset).floorDiv(64);
#if PREFETCH_CSU
DimsExpr prefetchAF = memAF;
// Prefetch current lines.
Expand All @@ -483,7 +492,7 @@ class StickExpansionPattern : public OpRewritePattern<ZLowStickOp> {
MDBuilder create(b);
DimsExpr inputAF;
IndexExprScope innerScope(create.krnl, &outerScope);
SymIE l(loopInd[0]);
DimIE l(loopInd[0]);
getIndexExprList<SymIE>(memAF, inputAF);
// E1: add the "l" local E1 offset.
inputAF[E1] = inputAF[E1] + l;
Expand Down
Loading

0 comments on commit ae9592d

Please sign in to comment.