Skip to content

Commit

Permalink
[LLVMGPU] Add multi-row vector reduction configuration (#73)
Browse files Browse the repository at this point in the history
This is to speed up matvec. The new configuration is experimental and
only applied on ROCm targets.
  • Loading branch information
kuhar authored and monorimet committed Jan 11, 2024
1 parent 6df3e0c commit 413bc13
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 168 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ class VectorReductionToGPUPass
bool expandSubgroupReduction,
std::function<int(func::FuncOp)> getWarpSize)
: expandSubgroupReduction(expandSubgroupReduction),
getWarpSize(getWarpSize) {}
getWarpSize(std::move(getWarpSize)) {}

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<scf::SCFDialect, memref::MemRefDialect, gpu::GPUDialect,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,10 @@ hal.executable private @shared_memory_copy {
// CHECK: vector.transfer_write {{.*}} : vector<1xf32>, memref<128x32xf32>
// CHECK: return


// -----

// Check that we multi-row matvec gets distributed across subgroup threads.
// Check that we multi-row matvec gets distributed across subgoroup threads.

#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {target_arch = "gfx940"}>
#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
Expand Down Expand Up @@ -316,7 +317,7 @@ hal.executable private @multirow {
// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<4x8xf16>
// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<4x8xf16>
// CHECK: }
// CHECK-COUNT-12: gpu.shuffle xor
// CHECK: gpu.shuffle xor
// CHECK: scf.if {{.*}} {
// CHECK: vector.transfer_write {{.*}} : vector<4xf16>, memref<1x32000xf16, #hal.descriptor_type<storage_buffer>>
// CHECK: }
Expand Down
27 changes: 12 additions & 15 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Matchers.h"
Expand Down Expand Up @@ -961,24 +961,21 @@ static LogicalResult setWarpReductionConfig(func::FuncOp entryPoint,
// process a few reductions (rows) along the last parallel dimension.
//
// TODO: This is enabled for matvec on ROCm for now. We should
// validate this strategy and extend to more linalg generics and to CUDA.
if (isRocmTarget(entryPoint) &&
llvm::none_of(bounds, ShapedType::isDynamic) && isMatvecLike(op)) {
int64_t lastParallelBound = bounds[parallelDims.back()];
int64_t numParallelReductions = 1;
const int64_t maxParallelFactor = groupSize / 4;
for (int64_t parallelFactor = 2;
(parallelFactor < maxParallelFactor) &&
(lastParallelBound % parallelFactor == 0) &&
(lastParallelBound > parallelFactor);
parallelFactor *= 2) {
numParallelReductions = parallelFactor;
// now.
if (numDynamicReductionDims == 0 && numParallelDims == 2 &&
isRocmTarget(entryPoint)) {
if (*parallelSize && !parallelDims.empty() && groupSize == subgroupSize) {
int maxParallelFactor = 4; // Keeping this conservative for now.
int64_t lastParallelBound = bounds[parallelDims.back()];
if (!ShapedType::isDynamic(lastParallelBound) &&
(lastParallelBound % maxParallelFactor == 0) &&
lastParallelBound > maxParallelFactor) {
workgroupTileSizes.back() = maxParallelFactor;
}
}
workgroupTileSizes.back() = numParallelReductions;
}

std::array<int64_t, 3> workgroupSize = {groupSize, 1, 1};
SmallVector<int64_t> reductionTileSizes(op.getNumLoops(), 0);
int64_t remainingGroupSize = groupSize;
for (int i = reductionDims.size() - 1; i >= 0; --i) {
int64_t dim = reductionDims[i];
Expand Down
151 changes: 1 addition & 150 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matvec.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb", {target_arch = "gf
}
}

// CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 8], [0, 0, 512]{{\]}}>
// CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 4], [0, 0, 512]{{\]}}>
// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<LLVMGPUWarpReduction>
// CHECK-LABEL: hal.executable.export public @vmt
// CHECK-SAME: subgroup_size = 64 : index
Expand All @@ -99,152 +99,3 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb", {target_arch = "gf
// CHECK: func.func @vmt()
// CHECK: linalg.generic
// CHECK-SAME: lowering_config = #[[$CONFIG]]

// -----

#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
#hal.descriptor_set.binding<0, storage_buffer>,
#hal.descriptor_set.binding<1, storage_buffer>,
#hal.descriptor_set.binding<2, storage_buffer>
]>
]>

hal.executable @vmt {
hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb", {target_arch = "gfx1100"}>) {
hal.executable.export @vmt layout(#pipeline_layout)
builtin.module {
func.func @vmt() {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f16
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<1x4096xf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32000x4096xf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<1x32000xf16>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [1, 4096], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<1x4096xf16>> -> tensor<1x4096xf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [32000, 4096], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32000x4096xf16>> -> tensor<32000x4096xf16>
%5 = tensor.empty() : tensor<1x32000xf16>
%6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<1x32000xf16>) -> tensor<1x32000xf16>
%7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%3, %4 : tensor<1x4096xf16>, tensor<32000x4096xf16>) outs(%6 : tensor<1x32000xf16>) {
^bb0(%in: f16, %in_0: f16, %out: f16):
%8 = arith.mulf %in, %in_0 : f16
%9 = arith.addf %out, %8 : f16
linalg.yield %9 : f16
} -> tensor<1x32000xf16>
flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [1, 32000], strides = [1, 1] : tensor<1x32000xf16> -> !flow.dispatch.tensor<writeonly:tensor<1x32000xf16>>
return
}
}
}
}

// CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 8], [0, 0, 512]{{\]}}>
// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<LLVMGPUWarpReduction>
// CHECK-LABEL: hal.executable.export public @vmt
// CHECK-SAME: subgroup_size = 32 : index
// CHECK-SAME: translation_info = #[[$TRANSLATION]]
// CHECK-SAME: workgroup_size = [64 : index, 1 : index, 1 : index]
// CHECK: func.func @vmt()
// CHECK: linalg.generic
// CHECK-SAME: lowering_config = #[[$CONFIG]]

// -----

hal.executable private @i4_dequant_matvec {
hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb", {target_arch = "gfx940"}>) {
hal.executable.export public @i4_dequant_matvec ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer, ReadOnly>, <3, storage_buffer, ReadOnly>, <4, storage_buffer>]>]>) {
^bb0(%arg0: !hal.device):
%x, %y, %z = flow.dispatch.workgroup_count_from_slice
hal.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @i4_dequant_matvec() {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f16
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32x128xf16>>
%4 = hal.interface.binding.subspan set(0) binding(4) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
%5 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [4096, 32, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>> -> tensor<4096x32x128xi4>
%6 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [4096, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4096x32xf16>
%7 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [4096, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4096x32xf16>
%8 = flow.dispatch.tensor.load %3, offsets = [0, 0], sizes = [32, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32x128xf16>> -> tensor<32x128xf16>
%9 = tensor.empty() : tensor<4096xf16>
%10 = tensor.empty() : tensor<4096x32x128xf16>
%11 = linalg.fill ins(%cst : f16) outs(%9 : tensor<4096xf16>) -> tensor<4096xf16>
%12 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%5, %6, %7 : tensor<4096x32x128xi4>, tensor<4096x32xf16>, tensor<4096x32xf16>) outs(%10 : tensor<4096x32x128xf16>) {
^bb0(%in: i4, %in_0: f16, %in_1: f16, %out: f16):
%14 = arith.extui %in : i4 to i32
%15 = arith.uitofp %14 : i32 to f16
%16 = arith.subf %15, %in_1 : f16
%17 = arith.mulf %16, %in_0 : f16
linalg.yield %17 : f16
} -> tensor<4096x32x128xf16>
%13 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"]} ins(%8, %12 : tensor<32x128xf16>, tensor<4096x32x128xf16>) outs(%11 : tensor<4096xf16>) {
^bb0(%in: f16, %in_0: f16, %out: f16):
%14 = arith.mulf %in, %in_0 : f16
%15 = arith.addf %14, %out : f16
linalg.yield %15 : f16
} -> tensor<4096xf16>
flow.dispatch.tensor.store %13, %4, offsets = [0], sizes = [4096], strides = [1] : tensor<4096xf16> -> !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
return
}
}
}
}

// TODO: We should process multiple rows per subgroup.

// CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1], [0, 4, 128]{{\]}}>
// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<LLVMGPUWarpReduction>
// CHECK-LABEL: hal.executable.export public @i4_dequant_matvec
// CHECK-SAME: translation_info = #[[$TRANSLATION]]
// CHECK-SAME: workgroup_size = [64 : index, 1 : index, 1 : index]
// CHECK: func.func @i4_dequant_matvec()
// CHECK: linalg.generic
// CHECK-SAME: lowering_config = #[[$CONFIG]]

// -----

#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
#hal.descriptor_set.binding<0, storage_buffer>,
#hal.descriptor_set.binding<1, storage_buffer>,
#hal.descriptor_set.binding<2, storage_buffer>
]>
]>

hal.executable @not_vmt {
hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb", {target_arch = "gfx940"}>) {
hal.executable.export @not_vmt layout(#pipeline_layout)
builtin.module {
func.func @not_vmt() {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f16
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x4096xf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32000x4096xf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2x32000xf16>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [2, 4096], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2x4096xf16>> -> tensor<2x4096xf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [32000, 4096], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32000x4096xf16>> -> tensor<32000x4096xf16>
%5 = tensor.empty() : tensor<2x32000xf16>
%6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<2x32000xf16>) -> tensor<2x32000xf16>
%7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%3, %4 : tensor<2x4096xf16>, tensor<32000x4096xf16>) outs(%6 : tensor<2x32000xf16>) {
^bb0(%in: f16, %in_0: f16, %out: f16):
%8 = arith.mulf %in, %in_0 : f16
%9 = arith.addf %out, %8 : f16
linalg.yield %9 : f16
} -> tensor<2x32000xf16>
flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [2, 32000], strides = [1, 1] : tensor<2x32000xf16> -> !flow.dispatch.tensor<writeonly:tensor<2x32000xf16>>
return
}
}
}
}

// CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<LLVMGPUMatmulSimt>
// CHECK-LABEL: hal.executable.export public @not_vmt
// CHECK-SAME: subgroup_size = 64 : index
// CHECK-SAME: translation_info = #[[$TRANSLATION]]
// CHECK: func.func @not_vmt()
// CHECK: linalg.generic
// CHECK-SAME: lowering_config = #[[$CONFIG]]

0 comments on commit 413bc13

Please sign in to comment.