Double buffering improvements
- Split the LDS reads and the MFMA/WMMA into two independent loops
- Put them into two separate stages (so that they can be executed in
parallel; see the sketch below)

This is to make our pipeline similar to what CK is doing in:
- https://github.com/ROCm/composable_kernel/blob/6d073d31bbc7d39d8b170d549f2af61970378150/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp
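As a rough illustration of the restructuring (a minimal sketch in plain C++ rather than the rocMLIR builders used in this change; all names and sizes are made up), the fused read-and-multiply loop is replaced by a read-only loop into a per-thread register tile followed by a compute-only loop, so the pipeliner can schedule the two as separate, overlapping stages:

#include <array>
#include <cstddef>

constexpr std::size_t mRepeats = 2, nRepeats = 2, kPerThread = 8;

// Stage "LDSRead": copy this thread's A and B slices out of the shared tile
// into registers; no math happens here.
void readStage(const float *ldsA, const float *ldsB,
               std::array<float, mRepeats * kPerThread> &regsA,
               std::array<float, nRepeats * kPerThread> &regsB) {
  for (std::size_t i = 0; i < regsA.size(); ++i)
    regsA[i] = ldsA[i];
  for (std::size_t i = 0; i < regsB.size(); ++i)
    regsB[i] = ldsB[i];
}

// Stage "MMA": accumulate from registers only, mirroring the
// mRepeats x nRepeats x kPerThread loop nest introduced below.
void computeStage(const std::array<float, mRepeats * kPerThread> &regsA,
                  const std::array<float, nRepeats * kPerThread> &regsB,
                  std::array<float, mRepeats * nRepeats> &regsC) {
  for (std::size_t i = 0; i < mRepeats; ++i)
    for (std::size_t j = 0; j < nRepeats; ++j)
      for (std::size_t k = 0; k < kPerThread; ++k)
        regsC[i * nRepeats + j] +=
            regsA[i * kPerThread + k] * regsB[j * kPerThread + k];
}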
giuseros committed May 13, 2024
1 parent 6361349 commit c9930bc
Showing 2 changed files with 250 additions and 20 deletions.
185 changes: 165 additions & 20 deletions mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp
@@ -2399,6 +2399,112 @@ struct GridwiseGemmAccelRewritePattern
: public OpRewritePattern<GridwiseGemmAccelOp> {
using OpRewritePattern<GridwiseGemmAccelOp>::OpRewritePattern;

// Generate only the compute loop, i.e., we assume here that all
// the data we need has already been read from LDS into registers
void generateComputeLoop(
Location loc, PatternRewriter &b,
const std::unique_ptr<rock::accel::AccelEmitter> &accelEmitterPtr,
Value regsA, Value regsB, Value regsC, StringAttr arch,
GemmFeaturesAttr features,
const RockAccelTuningParamAttrInterface tuningParams) const {

rock::accel::AccelEmitterParams params = accelEmitterPtr->getParams();
int64_t mRepeats = params.mRepeats;
int64_t nRepeats = params.nRepeats;
int64_t kBasePerThread = params.kBasePerThread;

auto mLoop = b.create<affine::AffineForOp>(loc, 0, mRepeats);
{
OpBuilder::InsertionGuard guard(b);
b.setInsertionPointToStart(mLoop.getBody());
Value i = mLoop.getInductionVar();

auto nLoop = b.create<affine::AffineForOp>(loc, 0, nRepeats);
{
OpBuilder::InsertionGuard guard(b);
b.setInsertionPointToStart(nLoop.getBody());
Value j = nLoop.getInductionVar();

// regsC += regsA * regsB
auto kLoop = b.create<affine::AffineForOp>(loc, 0, kBasePerThread);
{
OpBuilder::InsertionGuard guard(b);
b.setInsertionPointToStart(kLoop.getBody());
Value viewA =
accelEmitterPtr->generateThreadwiseViewBufferA(b, loc, regsA);
Value viewB =
accelEmitterPtr->generateThreadwiseViewBufferB(b, loc, regsB);
Value viewC =
accelEmitterPtr->generateThreadwiseViewBufferC(b, loc, regsC);
Value k = kLoop.getInductionVar();
b.create<ThreadwiseAccelGemmOp>(loc, viewA, viewB, viewC,
ValueRange{i, j, k}, arch, features,
tuningParams);
}
}
}
}

// Generate the read loop from LDS, i.e., read A[0:mRepeats, 0:kBasePerThread]
// and B[0:nRepeats, 0:kBasePerThread] into registers before entering the MMA loop
void generateReadLoop(
Location loc, PatternRewriter &b,
const std::unique_ptr<rock::accel::AccelEmitter> &accelEmitterPtr,
Value tid, Value ldsAView, Value ldsBView, Value regsA, Value regsB,
Value regsC, int64_t blockSize, int64_t inMPerThread,
int64_t inNPerThread, bool rotateMWithK, bool rotateNWithK) const {

// wrapLDSBufferForLoad reads a single set of Ks into private memory,
// i.e., A/B[m/n, 0:kBasePerThread]
Value ldsA = accelEmitterPtr->wrapLDSBufferForLoad(
b, loc, ldsAView, blockSize, inMPerThread, "m", rotateMWithK);

Value ldsB = accelEmitterPtr->wrapLDSBufferForLoad(
b, loc, ldsBView, blockSize, inNPerThread, "n", rotateNWithK);

rock::accel::AccelEmitterParams params = accelEmitterPtr->getParams();
int64_t mRepeats = params.mRepeats;
int64_t nRepeats = params.nRepeats;
int64_t kBasePerThread = params.kBasePerThread;

// We enhance the transformation from wrapLDSBufferForLoad with a builder
// that, given a single flat index, splits it into "m" ("n") and "k" and lets
// tid pass through. Those indices are then fed to wrapLDSBufferForLoad, which
// computes the right transform.

// Read from LDS buffer for A
{
TopDownTMBuilder mkBuilder(b, {"tid", "mk"},
{blockSize, mRepeats * kBasePerThread}, loc);
mkBuilder.passThrough("tid");
mkBuilder.merge({"m", "k"}, {1, 2}, "mk", {mRepeats, kBasePerThread});
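// Worked example (assuming the merge relation mk = m * kBasePerThread + k):
// with mRepeats = 2 and kBasePerThread = 8, the flat index mk = 11 splits
// into m = 11 / 8 = 1 and k = 11 % 8 = 3, i.e. consecutive mk values walk
// the k dimension first and then step to the next m repeat.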

auto [ldsBufferA, ldsTransformsA, ignoreA] = rock::untransform(b, ldsA);
ldsTransformsA = rock::prependUpperViews(
b, b.getArrayAttr({mkBuilder.get()}), ldsTransformsA);
ldsA = rock::transform(b, ldsBufferA, ldsTransformsA);
b.create<ThreadwiseReadIntoOp>(loc, ldsA, regsA, b.getArrayAttr({}),
ValueRange{tid}, /*forceUnroll=*/true,
/*useIndexDiffs=*/true);
}

// Read from LDS buffer for B
{
TopDownTMBuilder nkBuilder(b, {"tid", "nk"},
{blockSize, nRepeats * kBasePerThread}, loc);
nkBuilder.passThrough("tid");
nkBuilder.merge({"n", "k"}, {1, 2}, "nk", {nRepeats, kBasePerThread});

auto [ldsBufferB, ldsTransformsB, ignoreB] = rock::untransform(b, ldsB);
ldsTransformsB = rock::prependUpperViews(
b, b.getArrayAttr({nkBuilder.get()}), ldsTransformsB);
ldsB = rock::transform(b, ldsBufferB, ldsTransformsB);
b.create<ThreadwiseReadIntoOp>(loc, ldsB, regsB, b.getArrayAttr({}),
ValueRange{tid}, /*forceUnroll=*/true,
/*useIndexDiffs=*/true);
}
}

LogicalResult matchAndRewrite(GridwiseGemmAccelOp op,
PatternRewriter &b) const override {
Location loc = op.getLoc();
@@ -2692,11 +2798,21 @@ struct GridwiseGemmAccelRewritePattern
Value ldsViewForGemmB = viewBufferAs(b, ldsByteBufferB, ldsReadTypeB);
int64_t nOutputVectors = nResultVectors * mRepeats * nRepeats;

// TODO: add a heuristic to decide whether the II should be 1 or 2. For now
// this is not worth it, since any form of double buffering results in poor
// assembly being generated, so we need to stick with II=2.
int64_t initiationInterval = 2;

// Logic to setup buffers for blockwise_gemm_accel.
auto arrayA =
gpuAlloc(b, loc, kBasePerThread, argTypeA, AddressSpace::Private);
auto arrayB =
gpuAlloc(b, loc, kBasePerThread, argTypeB, AddressSpace::Private);
int64_t arrayALen = kBasePerThread;
int64_t arrayBLen = kBasePerThread;
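// With initiationInterval == 1 the LDS read and the MMA run as separate
// pipeline stages, so each thread must keep its whole
// A[0:mRepeats, 0:kBasePerThread] and B[0:nRepeats, 0:kBasePerThread] tiles
// resident in registers rather than a single k-slice at a time.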
if (initiationInterval == 1) {
arrayALen *= mRepeats;
arrayBLen *= nRepeats;
}

auto arrayA = gpuAlloc(b, loc, arrayALen, argTypeA, AddressSpace::Private);
auto arrayB = gpuAlloc(b, loc, arrayBLen, argTypeB, AddressSpace::Private);
auto regCAllocOp =
gpuAlloc(b, loc, nOutputVectors, accVectorType, AddressSpace::Private);

@@ -2709,8 +2825,9 @@ struct GridwiseGemmAccelRewritePattern
BlockwiseGemmAccelOp blockwiseGemmAccelOp;

auto loopOp = b.create<scf::ForOp>(loc, zeroConstantOp, nIterations, step);
loopOp->setAttr(PipelineAttr::getMnemonic(),
rock::PipelineAttr::get(b.getContext(), 2));
loopOp->setAttr(
PipelineAttr::getMnemonic(),
rock::PipelineAttr::get(b.getContext(), initiationInterval));
{
PatternRewriter::InsertionGuard guard(b);
b.setInsertionPointToStart(loopOp.getBody());
@@ -2772,20 +2889,48 @@ struct GridwiseGemmAccelRewritePattern
b.create<rock::YieldOp>(loc);
}

// Emit blockwise GEMM.
auto stage2 = b.create<StageOp>(loc, "MMA");
{
PatternRewriter::InsertionGuard guard(b);
b.setInsertionPointToStart(&stage2.getRegion().emplaceBlock());
blockwiseGemmAccelOp = b.create<BlockwiseGemmAccelOp>(
loc, ldsViewForGemmA, ldsViewForGemmB,
b.getI32IntegerAttr(copyMPerThread),
b.getI32IntegerAttr(copyNPerThread),
(rotateMWithK ? b.getUnitAttr() : nullptr),
(rotateNWithK ? b.getUnitAttr() : nullptr), arrayA, arrayB,
regCAllocOp, op.getArchAttr(), op.getFeaturesAttr(),
op.getBlockSizeAttr(), op.getParamsAttr());
b.create<rock::YieldOp>(loc);
if (initiationInterval > 1) {
// Emit blockwise GEMM. This will load data from LDS and
// compute the MMA at the same time
auto stage2 = b.create<StageOp>(loc, "MMA");
{
PatternRewriter::InsertionGuard guard(b);
b.setInsertionPointToStart(&stage2.getRegion().emplaceBlock());
blockwiseGemmAccelOp = b.create<BlockwiseGemmAccelOp>(
loc, ldsViewForGemmA, ldsViewForGemmB,
b.getI32IntegerAttr(copyMPerThread),
b.getI32IntegerAttr(copyNPerThread),
(rotateMWithK ? b.getUnitAttr() : nullptr),
(rotateNWithK ? b.getUnitAttr() : nullptr), arrayA, arrayB,
regCAllocOp, op.getArchAttr(), op.getFeaturesAttr(),
op.getBlockSizeAttr(), op.getParamsAttr());
b.create<rock::YieldOp>(loc);
}
} else {
// If we are running double-buffered pipelines, it makes sense to also
// parallelize the LDSRead/MMA stages. We do this here by splitting the
// MMA loop into two separate stages.
auto stage2 = b.create<StageOp>(loc, "LDSRead");
{
// Read from LDS into registers
PatternRewriter::InsertionGuard guard(b);
b.setInsertionPointToStart(&stage2.getRegion().emplaceBlock());
generateReadLoop(loc, b, accelEmitterPtr, tid, ldsViewForGemmA,
ldsViewForGemmB, arrayA, arrayB, regCAllocOp,
blockSize, copyMPerThread, copyNPerThread,
rotateMWithK, rotateNWithK);
b.create<rock::YieldOp>(loc);
}
auto stage3 = b.create<StageOp>(loc, "MMA");
{
// Compute the matrix-multiplication
PatternRewriter::InsertionGuard guard(b);
b.setInsertionPointToStart(&stage3.getRegion().emplaceBlock());
generateComputeLoop(loc, b, accelEmitterPtr, arrayA, arrayB,
regCAllocOp, op.getArchAttr(),
op.getFeaturesAttr(), tuningParams);
b.create<rock::YieldOp>(loc);
}
}
}

85 changes: 85 additions & 0 deletions mlir/test/Dialect/Rock/test_rock_pipeline.mlir
@@ -246,3 +246,88 @@ func.func @rock_pipeline_4_stages_ii_2(%input : memref<16xi8, #gpu.address_space<global>>, %output : memref<16xi8, #gpu.address_space<global>>)
memref.store %out, %output[%c0] : memref<16xi8, #gpu.address_space<global>>
return
}

// CHECK-LABEL: rock_pipeline_4_stages_ii_1
func.func @rock_pipeline_4_stages_ii_1(%input : memref<16xi8, #gpu.address_space<global>>, %output : memref<16xi8, #gpu.address_space<global>>){
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : i8
%c16 = arith.constant 16 : index

%rawLds = rock.alloc() : memref<16xi8, #gpu.address_space<workgroup>>
%rawReg0 = rock.alloc() : memref<16xi8, #gpu.address_space<private>>
%rawReg1 = rock.alloc() : memref<16xi8, #gpu.address_space<private>>
%rawReg2 = rock.alloc() : memref<16xi8, #gpu.address_space<private>>

%lds = memref.view %rawLds[%c0][] : memref<16xi8, #gpu.address_space<workgroup>> to memref<16xi8, #gpu.address_space<workgroup>>
%reg0 = memref.view %rawReg0[%c0][] : memref<16xi8, #gpu.address_space<private>> to memref<16xi8, #gpu.address_space<private>>
%reg1 = memref.view %rawReg1[%c0][] : memref<16xi8, #gpu.address_space<private>> to memref<16xi8, #gpu.address_space<private>>
%reg2 = memref.view %rawReg2[%c0][] : memref<16xi8, #gpu.address_space<private>> to memref<16xi8, #gpu.address_space<private>>
// CHECK: %[[rawLds0:.*]] = rock.alloc() : memref<16xi8, #gpu.address_space<workgroup>>
// CHECK: %[[rawLds1:.*]] = rock.alloc() : memref<16xi8, #gpu.address_space<workgroup>>
// CHECK: %[[rawReg0:.*]] = rock.alloc() : memref<16xi8, #gpu.address_space<private>>
// CHECK: %[[rawReg1:.*]] = rock.alloc() : memref<16xi8, #gpu.address_space<private>>
// CHECK: %[[rawReg2:.*]] = rock.alloc() : memref<16xi8, #gpu.address_space<private>>
// CHECK: %[[ldsView0:.*]] = memref.view %[[rawLds0]]
// CHECK: %[[ldsView1:.*]] = memref.view %[[rawLds1]]
// CHECK: %[[regView0:.*]] = memref.view %[[rawReg0]]
// CHECK: %[[regView1:.*]] = memref.view %[[rawReg1]]
// CHECK: %[[regView2:.*]] = memref.view %[[rawReg2]]

// Please note how we swap S0/S1 and S2/S3 to avoid private multi-buffers
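// (Only the LDS buffer is duplicated below: two workgroup allocations back
// the LDS views, while each register view keeps a single buffer, as the
// single-operand extract_multibuffer ops show.)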
// CHECK: name = "S0"
// CHECK: name = "S1"
// CHECK: name = "S0"
// CHECK: name = "__fwd_barrier__"
// CHECK: name = "S1"
// CHECK: name = "S0"
// CHECK: name = "S2"
// CHECK: scf.for
// CHECK: name = "__fwd_barrier__"
// CHECK: rock.extract_multibuffer(%[[regView0]])
// CHECK: rock.extract_multibuffer(%[[ldsView0]], %[[ldsView1]])
// CHECK: name = "S1"
// CHECK: rock.extract_multibuffer(%[[regView0]])
// CHECK: name = "S0"
// CHECK: rock.extract_multibuffer(%[[regView1]])
// CHECK: name = "S3"
// CHECK: rock.extract_multibuffer(%[[ldsView0]], %[[ldsView1]])
// CHECK: rock.extract_multibuffer(%[[regView1]])
// CHECK: name = "S2"
// CHECK: name = "__fwd_barrier__"
// CHECK: name = "S1"
// CHECK: name = "S3"
// CHECK: name = "S2"
// CHECK: name = "__fwd_barrier__"
// CHECK: name = "S3"
// CHECK: name = "S2"
// CHECK: name = "S3"
scf.for %arg3 = %c0 to %c16 step %c1 {
rock.stage {
%tmp = memref.load %input[%arg3] : memref<16xi8, #gpu.address_space<global>>
memref.store %tmp, %reg0[%arg3] : memref<16xi8, #gpu.address_space<private>>
rock.yield
}{name="S0"}
rock.stage {
%tmp = memref.load %reg0[%arg3] : memref<16xi8, #gpu.address_space<private>>
memref.store %tmp, %lds[%arg3] : memref<16xi8, #gpu.address_space<workgroup>>
rock.yield
}{name="S1"}
rock.stage {
%tmp = memref.load %lds[%arg3] : memref<16xi8, #gpu.address_space<workgroup>>
%comp = arith.addi %tmp, %c2 : i8
memref.store %tmp, %reg1[%arg3] : memref<16xi8, #gpu.address_space<private>>
rock.yield
}{name="S2"}
rock.stage {
%tmp = memref.load %reg1[%arg3] : memref<16xi8, #gpu.address_space<private>>
%comp = arith.addi %tmp, %c2 : i8
memref.store %comp, %reg2[%arg3] : memref<16xi8, #gpu.address_space<private>>
rock.yield
}{name="S3"}
}{pipeline = #rock.pipeline<1>}

%out = memref.load %reg2[%c0] : memref<16xi8, #gpu.address_space<private>>
memref.store %out, %output[%c0] : memref<16xi8, #gpu.address_space<global>>
return
}
