Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP][Test Only] Transpose dma on L2 target side #809

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,67 @@ LogicalResult processInputs(Operation *op, SmallVector<OpFoldResult> &offsets,
return success();
}

LogicalResult packL3ToL2(IREE::LinalgExt::PackOp packOp,
SmallVector<OpFoldResult> &offsets,
SmallVector<OpFoldResult> &sizes,
SmallVector<OpFoldResult> &strides) {
MLIRContext *ctx = packOp.getContext();

llvm::ArrayRef<int64_t> permutation = packOp.getOuterDimsPerm();
llvm::ArrayRef<int64_t> innerTiles = packOp.getStaticInnerTiles();

SmallVector<OpFoldResult> innerSizes;
SmallVector<OpFoldResult> innerStrides;
SmallVector<OpFoldResult> innerOffsets;
auto innerDimsPos = packOp.getInnerDimsPos();

int numOuterDims = sizes.size() - innerTiles.size();
SmallVector<OpFoldResult> outerOffsets = SmallVector<OpFoldResult>(
offsets.begin(), offsets.begin() + numOuterDims);
SmallVector<OpFoldResult> outerStrides = SmallVector<OpFoldResult>(
strides.begin(), strides.begin() + numOuterDims);
SmallVector<OpFoldResult> outerSizes =
SmallVector<OpFoldResult>(sizes.begin(), sizes.begin() + numOuterDims);

// Apply inverse permutation to the outer dims if permutation provided (if
// permutation not provided, it is identity, and therefore so is the inverse).
if (!permutation.empty()) {
SmallVector<int64_t> inversePermutation =
invertPermutationVector(permutation);
applyPermutationToVector(outerStrides, inversePermutation);
applyPermutationToVector(outerSizes, inversePermutation);
applyPermutationToVector(outerOffsets, inversePermutation);
}
// Do the unpacking on the Outer dims.
llvm::SmallDenseMap<int64_t, int64_t> outerDimsIndexMap;
// Intialize the indexing of each outer dim.
for (int i = 0; i < numOuterDims; i++) {
outerDimsIndexMap[i] = i;
}
for (int i = 0; i < innerTiles.size(); i++) {
// Insert inner dims adjacent to there corresponding outer dims.
outerSizes.insert(
outerSizes.begin() + outerDimsIndexMap[innerDimsPos[i]] + 1,
getAsIndexOpFoldResult(ctx, innerTiles[i]));
outerStrides.insert(
outerStrides.begin() + outerDimsIndexMap[innerDimsPos[i]] + 1,
strides[numOuterDims + i]);
outerOffsets.insert(
outerOffsets.begin() + outerDimsIndexMap[innerDimsPos[i]] + 1,
offsets[numOuterDims + i]);
// Update the map as all the dimensions inner to the innerDimsPos[i] are now
// shifted by 1.
for (int j = innerDimsPos[i] + 1; j < numOuterDims; j++) {
outerDimsIndexMap[j]++;
}
}
// Make the outer dims as the final returned dims
offsets = outerOffsets;
strides = outerStrides;
sizes = outerSizes;
return success();
}

/// Rewrite the pack/unpack op 'op' as a DMA operation. The function arguments
/// 'input', 'output', and 'innerTiles' are the input, output, and inner tile
/// of 'op'. If 'op' is not a pack/unpack op, or if it determined to not
Expand Down Expand Up @@ -283,10 +344,6 @@ LogicalResult rewriteAsDma(IRRewriter &rewriter, Operation *op, Value input,
return failure();
}

if (!succeeded(processInputs(op, srcOffsets, srcShape, srcBaseStrides))) {
return failure();
}

// Prepare destination DMA inputs.
SmallVector<OpFoldResult> dstOffsets;
SmallVector<OpFoldResult> dstBaseStrides;
Expand All @@ -295,6 +352,23 @@ LogicalResult rewriteAsDma(IRRewriter &rewriter, Operation *op, Value input,
return failure();
}

uint32_t srcMemspace =
cast<MemRefType>(input.getType()).getMemorySpaceAsInt();
uint32_t dstMemspace =
cast<MemRefType>(output.getType()).getMemorySpaceAsInt();

if (auto packOp = dyn_cast<IREE::LinalgExt::PackOp>(op) && srcMemspace == 0 &&
dstMemspace == 1) {
if (!succeeded(packL3ToL2(dyn_cast<IREE::LinalgExt::PackOp>(op), dstOffsets,
dstShape, dstBaseStrides))) {
return failure();
}
} else {
if (!succeeded(processInputs(op, srcOffsets, srcShape, srcBaseStrides))) {
return failure();
}
}

// Create logical objectFifos from source and destination memrefs.
Value srcVal = sourceOp->getResult(0);
Value dstVal = dstOp->getResult(0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,24 @@ struct SubsumeLoopIntoDMA
return false;
};

auto circularUsersInSameScope =
[&](Value result,
SmallVector<AMDAIE::DoublyStridedOpInterface> users) -> bool {
bool currentUser = false;
for (AMDAIE::DoublyStridedOpInterface userOp : llvm::reverse(users)) {
if (isa<AMDAIE::NpuCircularDmaCpyNdOp>(userOp) &&
userOp != op.getOperation()) {
return true;
}
if (userOp == op.getOperation()) {
currentUser = true;
continue;
}
if (currentUser) return true;
}
return false;
};

uint8_t sourceMemspaceInt;
uint8_t targetMemspaceInt;
if (auto npuDmaOp = dyn_cast<AMDAIE::NpuDmaCpyNdOp>(op.getOperation())) {
Expand Down Expand Up @@ -525,7 +543,17 @@ struct SubsumeLoopIntoDMA
return rewriter.notifyMatchFailure(
op, "should operate on an `amdaie.connection` op");
}
if (hasUsersInSameScope(connectionOp.getResult())) {
// Walk the parentOp and get users of the connection op in order.
Value dma = npuCircularDmaOp.getConnection();
SmallVector<AMDAIE::DoublyStridedOpInterface> dmaUsers;
parentOp->walk([&](AMDAIE::DoublyStridedOpInterface op) {
auto connection = dyn_cast_if_present<AMDAIE::ConnectionOp>(
op->getOperand(0).getDefiningOp());
if (connection == npuCircularDmaOp.getConnectionOp()) {
dmaUsers.push_back(op);
}
});
if (circularUsersInSameScope(dma, dmaUsers)) {
return rewriter.notifyMatchFailure(
op,
"Has users of same DMA in scope, analysis to check validity of "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,8 @@ void addAMDAIEObjectFifoLoweringPasses(OpPassManager &passManager) {
passManager.addPass(createCanonicalizerPass());

passManager.addPass(createAMDAIEDmaCompositionPass());
passManager.addPass(createAMDAIECanonicalizeDoublyStridedOpPass());
//passManager.addPass(createAMDAIEDmaCompositionPass());
passManager.addPass(createCSEPass());
passManager.addPass(createCanonicalizerPass());
passManager.addPass(createAMDAIEDmaCSEPass());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ iree_lit_test_suite(
"pack_and_transpose_level1.mlir"
"pack_and_transpose_level2.mlir"
"pack_to_air.mlir"
"convert_to_dma.mlir"
"convert_to_dma_failures.mlir"
"pad.mlir"
"peel_for_loop.mlir"
Expand Down
Loading