[LLVMGPU] Nuke logic that is trying to simplify thread id arithmetic (iree-org#16507)

This logic never helped anyway, and now that we have a pipeline that
uses multiple subgroups per workgroup, it is just a bug. Remove it.
qedawkins authored Feb 21, 2024
1 parent ede9135 commit 4a80ee3
Showing 3 changed files with 28 additions and 58 deletions.
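
Why removing this matters, in concrete terms: the pass linearizes the thread ids as x + workgroupSize[0] * y + workgroupSize[1] * workgroupSize[0] * z, and the deleted logic then clamped each thread id to the workgroup bounds and reduced the linear id to linearId % subgroup_size before handing it to the distribution options. With more than one subgroup per workgroup, that modulo collapses threads from different subgroups onto the same lane. A standalone C++ sketch of the aliasing, not part of the patch, using values from the multi-subgroup test below (workgroup_size = [64, 2, 1], subgroup_size = 64):

```cpp
// Illustrative only: mirrors the linearization expression from the pass and
// the lane-id reduction that this commit deletes; plain C++, no IREE APIs.
#include <array>
#include <cstdint>
#include <iostream>

int main() {
  std::array<int64_t, 3> workgroupSize = {64, 2, 1};
  const int64_t subgroupSize = 64;

  // Same formula as the AffineExpr built by the pass.
  auto linearId = [&](int64_t x, int64_t y, int64_t z) {
    return x + workgroupSize[0] * y + workgroupSize[1] * workgroupSize[0] * z;
  };

  int64_t t0 = linearId(5, 0, 0); // thread (5, 0, 0) -> subgroup 0, linear id 5
  int64_t t1 = linearId(5, 1, 0); // thread (5, 1, 0) -> subgroup 1, linear id 69

  std::cout << "linear ids: " << t0 << ", " << t1 << "\n"; // 5, 69 (distinct)
  std::cout << "lane ids:   " << t0 % subgroupSize << ", "
            << t1 % subgroupSize << "\n"; // 5, 5 (collision)
  return 0;
}
```

After this commit the pass hands the full linear thread id to ContractionVectorLayoutOptions and leaves any split into subgroup and lane components to the layout side; the updated test now checks an affine.delinearize_index on that id instead of a precomputed lane id.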
@@ -315,28 +315,8 @@ struct LLVMGPUVectorDistributePass
void runOnOperation() override {
auto func = getOperation();

-std::optional<int64_t> maybeSubgroupSize = std::nullopt;
-if (func->hasAttr("subgroup_size")) {
-maybeSubgroupSize =
-llvm::cast<IntegerAttr>(func->getAttr("subgroup_size")).getInt();
-} else {
-maybeSubgroupSize = getSubgroupSize(func);
-}
-if (!maybeSubgroupSize) {
-func.emitError() << "subgroup size required for vector distribution";
-return signalPassFailure();
-}
-
-OpBuilder builder(func);
-builder.setInsertionPointToStart(&func.getFunctionBody().front());
-SmallVector<OpFoldResult> threadGrid = {
-builder.createOrFold<gpu::ThreadIdOp>(func.getLoc(), gpu::Dimension::x),
-builder.createOrFold<gpu::ThreadIdOp>(func.getLoc(), gpu::Dimension::y),
-builder.createOrFold<gpu::ThreadIdOp>(func.getLoc(),
-gpu::Dimension::z)};
-
std::array<int64_t, 3> workgroupSize;
if (func->hasAttr("subgroup_size")) {
if (func->hasAttr("workgroup_size")) {
auto tmpSizes =
llvm::cast<ArrayAttr>(func->getAttr("workgroup_size")).getValue();
for (auto [i, size] : llvm::enumerate(tmpSizes)) {
@@ -361,36 +341,20 @@ struct LLVMGPUVectorDistributePass
// Construct the expression for linearizing the thread indices.
AffineExpr linearId =
x + workgroupSize[0] * y + workgroupSize[1] * workgroupSize[0] * z;
-AffineExpr laneId = linearId % *maybeSubgroupSize;
-
-// This all needs some kind of simplification; the arithmetic it produces
-// doest not get folded away as nicely as it could.
-AffineMap idMap = AffineMap::getMultiDimIdentityMap(2, func.getContext());
-
-// Clamp the thread indices to the workgroup sizes.
-OpFoldResult c0 =
-builder.createOrFold<arith::ConstantIndexOp>(func.getLoc(), 0);
-threadGrid[0] = affine::makeComposedFoldedAffineMax(
-builder, func.getLoc(), idMap, {threadGrid[0], c0});
-threadGrid[1] = affine::makeComposedFoldedAffineMax(
-builder, func.getLoc(), idMap, {threadGrid[1], c0});
-threadGrid[2] = affine::makeComposedFoldedAffineMax(
-builder, func.getLoc(), idMap, {threadGrid[2], c0});
-
-OpFoldResult dimX = builder.getIndexAttr(workgroupSize[0] - 1);
-OpFoldResult dimY = builder.getIndexAttr(workgroupSize[1] - 1);
-OpFoldResult dimZ = builder.getIndexAttr(workgroupSize[2] - 1);
-threadGrid[0] = affine::makeComposedFoldedAffineMin(
-builder, func.getLoc(), idMap, {threadGrid[0], dimX});
-threadGrid[1] = affine::makeComposedFoldedAffineMin(
-builder, func.getLoc(), idMap, {threadGrid[1], dimY});
-threadGrid[2] = affine::makeComposedFoldedAffineMin(
-builder, func.getLoc(), idMap, {threadGrid[2], dimZ});
-Value laneVal = affine::makeComposedAffineApply(builder, func.getLoc(),
-laneId, threadGrid);

+OpBuilder builder(func);
+builder.setInsertionPointToStart(&func.getFunctionBody().front());
+SmallVector<OpFoldResult> threadGrid = {
+builder.createOrFold<gpu::ThreadIdOp>(func.getLoc(), gpu::Dimension::x),
+builder.createOrFold<gpu::ThreadIdOp>(func.getLoc(), gpu::Dimension::y),
+builder.createOrFold<gpu::ThreadIdOp>(func.getLoc(),
+gpu::Dimension::z)};
+
+Value linearThreadIdVal = affine::makeComposedAffineApply(
+builder, func.getLoc(), linearId, threadGrid);
+
ContractionVectorLayoutOptions options(func, workgroupSize, scheduleAttr,
-laneVal, testLayout);
+linearThreadIdVal, testLayout);
if (failed(distributeVectorOps(func, options.getPatterns(), options))) {
func->emitOpError() << "failed to distribute";
return signalPassFailure();
@@ -6,7 +6,7 @@ func.func @matmul_256x256x256(%lhs: memref<16x256xf16, strided<[256, 1], offset:
attributes {
mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mfma_layout<F16_16x16x16_F32>,
subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 1, subgroup_n_tile_count = 1, subgroup_k_tile_count = 2>,
-subgroup_size = 64, workgroup_size = [64, 1, 1]} {
+workgroup_size = [64, 1, 1]} {
%alloc = memref.alloc() : memref<32x16xf16, #gpu.address_space<workgroup>>
%alloc_0 = memref.alloc() : memref<16x32xf16, #gpu.address_space<workgroup>>
%cst = arith.constant 0.000000e+00 : f16
@@ -60,7 +60,7 @@ func.func @matmul_256x256x256(%lhs: memref<16x256xf16, strided<[256, 1], offset:
attributes {
mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mfma_layout<F16_16x16x16_F32>,
subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 1, subgroup_n_tile_count = 1, subgroup_k_tile_count = 2>,
-subgroup_size = 64, workgroup_size = [64, 1, 1]} {
+workgroup_size = [64, 1, 1]} {
%alloc = memref.alloc() : memref<32x16xf16, #gpu.address_space<workgroup>>
%alloc_0 = memref.alloc() : memref<16x32xf16, #gpu.address_space<workgroup>>
%cst = arith.constant 0.000000e+00 : f16
@@ -91,17 +91,23 @@ func.func @matmul_256x256x256(%lhs: memref<16x256xf16, strided<[256, 1], offset:
return
}

-// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 + s1 * 64 + s2 * 64)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d1, d0)>

// CHECK-LABEL: func.func @matmul_256x256x256
+// CHECK: %[[TIDX:.+]] = gpu.thread_id x
+// CHECK: %[[TIDY:.+]] = gpu.thread_id y
+// CHECK: %[[TIDZ:.+]] = gpu.thread_id z
+// CHECK: %[[LIN_ID:.+]] = affine.apply #[[$MAP]]()[%[[TIDX]], %[[TIDY]], %[[TIDZ]]]
// CHECK: %[[RHS_ALLOC:.+]] = memref.alloc() : memref<32x16xf16, #gpu.address_space<workgroup>>
// CHECK: %[[LHS_ALLOC:.+]] = memref.alloc() : memref<16x32xf16, #gpu.address_space<workgroup>>
+// CHECK: affine.delinearize_index %[[LIN_ID]]
// CHECK: %[[INIT_READ:.+]] = vector.transfer_read %{{.*}} memref<16x16xf32, {{.*}}>, vector<4x1xf32>
// CHECK: %[[INIT_TRANSP:.+]] = vector.transpose %[[INIT_READ]], [1, 0]
// CHECK: %[[INIT:.+]] = vector.insert_strided_slice %[[INIT_TRANSP]]
// CHECK: scf.for {{.*}} = %c0 to %c256 step %c32 iter_args({{.*}} = %[[INIT]]) -> (vector<1x1x1x1x1x4xf32>)
// CHECK: %[[LLOAD:.+]] = vector.transfer_read {{.*}} : memref<16x256xf16, {{.*}}>, vector<1x8xf16>
-// CHECK: %[[RLOAD:.+]] = vector.transfer_read {{.*}} permutation_map = #[[$MAP]]} : memref<16x256xf16, {{.*}}>, vector<8x1xf16>
+// CHECK: %[[RLOAD:.+]] = vector.transfer_read {{.*}} permutation_map = #[[$MAP1]]} : memref<16x256xf16, {{.*}}>, vector<8x1xf16>
// CHECK: vector.transfer_write %[[LLOAD]], %[[LHS_ALLOC]]{{.*}} : vector<1x8xf16>, memref<16x32xf16, #gpu.address_space<workgroup>>
// CHECK: vector.transfer_write %[[RLOAD]], %[[RHS_ALLOC]]{{.*}} : vector<8x1xf16>, memref<32x16xf16, #gpu.address_space<workgroup>>
// CHECK: gpu.barrier
@@ -4,7 +4,7 @@ func.func @matmul_96x64x16_mm(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf16>,
mma_schedule = #iree_gpu.mma_schedule<
intrinsic = #iree_gpu.mfma_layout<F16_32x32x8_F32>,
subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 3, subgroup_n_tile_count = 2, subgroup_k_tile_count = 2>,
-subgroup_size = 64, workgroup_size = [64, 1, 1]} {
+workgroup_size = [64, 1, 1]} {
%0 = vector.contract {
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
@@ -31,7 +31,7 @@ func.func @matmul_96x64x16_mmt(%lhs: vector<96x16xf16>, %rhs: vector<64x16xf16>,
mma_schedule = #iree_gpu.mma_schedule<
intrinsic = #iree_gpu.mfma_layout<F16_32x32x8_F32>,
subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 3, subgroup_n_tile_count = 2, subgroup_k_tile_count = 2>,
-subgroup_size = 64, workgroup_size = [64, 1, 1]} {
+workgroup_size = [64, 1, 1]} {
%0 = vector.contract {
indexing_maps = [affine_map<(m, n, k) -> (m, k)>, affine_map<(m, n, k) -> (n, k)>, affine_map<(m, n, d2) -> (m, n)>],
iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
@@ -58,7 +58,7 @@ func.func @matmul_192x64x16_mmt_multisubgroup(%lhs: vector<192x16xf16>, %rhs: ve
mma_schedule = #iree_gpu.mma_schedule<
intrinsic = #iree_gpu.mfma_layout<F16_32x32x8_F32>,
subgroup_m_count = 2, subgroup_n_count = 1, subgroup_m_tile_count = 3, subgroup_n_tile_count = 2, subgroup_k_tile_count = 2>,
-subgroup_size = 64, workgroup_size = [64, 2, 1]} {
+workgroup_size = [64, 2, 1]} {
%0 = vector.contract {
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
@@ -78,7 +78,7 @@ func.func @matmul_16x16x256_read(%lhs: memref<16x256xf16, strided<[256, 1], offs
attributes {
mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mfma_layout<F16_16x16x16_F32>,
subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 1, subgroup_n_tile_count = 1, subgroup_k_tile_count = 2>,
-subgroup_size = 64, workgroup_size = [64, 1, 1]} {
+workgroup_size = [64, 1, 1]} {
%alloc = memref.alloc() : memref<32x16xf16, #gpu.address_space<workgroup>>
%alloc_0 = memref.alloc() : memref<16x32xf16, #gpu.address_space<workgroup>>
%cst = arith.constant 0.000000e+00 : f16
@@ -137,7 +137,7 @@ func.func @matmul_16x16x256_read_permute(%lhs: memref<16x256xf16, strided<[256,
attributes {
mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mfma_layout<F16_16x16x16_F32>,
subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 1, subgroup_n_tile_count = 1, subgroup_k_tile_count = 2>,
-subgroup_size = 64, workgroup_size = [64, 1, 1]} {
+workgroup_size = [64, 1, 1]} {
%alloc = memref.alloc() : memref<32x16xf16, #gpu.address_space<workgroup>>
%alloc_0 = memref.alloc() : memref<16x32xf16, #gpu.address_space<workgroup>>
%cst = arith.constant 0.000000e+00 : f16
