[VectorDistribution] Add vector distribution support for multi-dim reduction with scalars (iree-org#18800)

Splitting iree-org#18519 into four patches.

Depends on iree-org#18784.

This is the second patch, adding the corresponding layout analysis and,
in particular, supporting the case where the reduction is performed
inside an scf.for operation.

The relevant tests are also added.

Since patch 2 includes changes from patch iree-org#18784, the necessary updates
from the first patch have also been included here.
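
For illustration, the newly supported shape is a multi-dim reduction whose accumulator and result are plain scalars rather than vectors (a minimal example adapted from the tests added below):

    %0 = vector.multi_reduction <maximumf>, %src, %acc [0, 1] : vector<32x32xf32> to f32

Here %acc is an f32 scalar; previously the pattern bailed out with "unimplemented: scalar accumulator distribution".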

---------

Signed-off-by: Bangtian Liu <[email protected]>
bangtianliu authored Oct 28, 2024
1 parent 8806173 commit a041798
Showing 10 changed files with 257 additions and 35 deletions.
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
@@ -78,6 +78,7 @@ iree_compiler_cc_library(
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:DialectUtils",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:VectorDialect",
],
)
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -58,6 +58,7 @@ iree_cc_library(
LLVMSupport
MLIRAnalysis
MLIRIR
MLIRSCFDialect
MLIRVectorDialect
iree::compiler::Codegen::Dialect::VectorExt::IR::IREEVectorExtDialect
PUBLIC
@@ -107,11 +107,11 @@ struct DistributeConstants final : OpDistributionPattern<arith::ConstantOp> {
Type elementType = constant.getType().getElementType();
auto vectorType =
VectorType::get(layout.getDistributedShape(), elementType);
-    Operation *distirbutedOp = rewriter.create<arith::ConstantOp>(
+    auto distributedOp = rewriter.create<arith::ConstantOp>(
         constantOp.getLoc(), vectorType,
         SplatElementsAttr::get(vectorType, attr.getSplatValue<Attribute>()));
     replaceOpWithDistributedValues(rewriter, constantOp,
-                                   distirbutedOp->getResult(0));
+                                   distributedOp->getResult(0));
return success();
}
};
@@ -536,8 +536,10 @@ struct DistributeScfFor final : OpDistributionPattern<scf::ForOp> {
SmallVector<Value> newInitArgs;
for (Value initArg : forOp.getInitArgs()) {
if (auto vectorInitArg = dyn_cast<VectorValue>(initArg)) {
-        initArg =
-            getDistributed(rewriter, vectorInitArg, signature[vectorInitArg]);
+        if (isNonZeroRank(vectorInitArg)) {
+          initArg =
+              getDistributed(rewriter, vectorInitArg, signature[vectorInitArg]);
+        }
}
newInitArgs.push_back(initArg);
}
@@ -582,8 +584,14 @@ struct DistributeScfFor final : OpDistributionPattern<scf::ForOp> {
SmallVector<Value> operands;
for (Value operand : yieldOp->getOperands()) {
if (auto vectorOperand = dyn_cast<VectorValue>(operand)) {
-        operand = DistributionPattern::getDistributed(rewriter, vectorOperand,
-                                                      signature[vectorOperand]);
+        // Distributing the operand requires it to have a non-zero rank, meaning
+        // it must have at least one dimension. If the vector has a non-zero
+        // rank, the operand is distributed according to the provided layout
+        // signature.
+        if (isNonZeroRank(vectorOperand)) {
+          operand = DistributionPattern::getDistributed(
+              rewriter, vectorOperand, signature[vectorOperand]);
+        }
}
operands.push_back(operand);
}
@@ -606,8 +614,10 @@ struct DistributeScfFor final : OpDistributionPattern<scf::ForOp> {
for (auto [bbArg, oldInit] : llvm::zip_equal(bbArgs, oldInits)) {
Value val = bbArg;
if (auto oldVectorInit = dyn_cast<VectorValue>(oldInit)) {
-        val = rewriter.create<IREE::VectorExt::ToSIMDOp>(
-            oldVectorInit.getLoc(), oldVectorInit.getType(), val);
+        if (isNonZeroRank(oldVectorInit)) {
+          val = rewriter.create<IREE::VectorExt::ToSIMDOp>(
+              oldVectorInit.getLoc(), oldVectorInit.getType(), val);
+        }
}
replacements.push_back(val);
}
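Note on the isNonZeroRank guards added above: the current distribution model has no partial thread distribution, so a rank-0 vector such as vector<f32> is already available to every thread and is deliberately left undistributed in the loop's init args, yield operands, and result replacements. A reduced sketch of the loop shape this enables (mirroring the distribute_scf_for test below; names are illustrative):

    %out = scf.for %i = %c0 to %c128 step %c1 iter_args(%arg0 = %cst) -> (vector<f32>) {
      // Higher-rank vectors in the body are still rewritten to their
      // per-thread SIMT shapes; the rank-0 iter_arg passes through unchanged.
      scf.yield %d : vector<f32>
    }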
@@ -305,7 +305,9 @@ struct DistributeBroadcast final : OpDistributionPattern<vector::BroadcastOp> {
auto vectorType = VectorType::get(distShape, elementType);

VectorValue srcVector = dyn_cast<VectorValue>(broadcastOp.getSource());
-    if (!srcVector) {
+    // If the srcVector is a scalar (like f32) or a rank-0 vector (like
+    // vector<f32>), we proceed with the scalar distribution branch.
+    if (!srcVector || !isNonZeroRank(srcVector)) {
// The way distribution currently works, there is no partial thread
// distribution, so a scalar is available to all threads. Scalar
// distribution is simply a broadcast from scalar to the distributed
@@ -413,16 +415,10 @@ struct DistributeMultiReduction final
DistributionSignature &signature,
PatternRewriter &rewriter) const override {
VectorValue srcVector = multiReduceOp.getSource();
-    auto accVector = dyn_cast<VectorValue>(multiReduceOp.getAcc());
-    if (!accVector) {
-      return rewriter.notifyMatchFailure(
-          multiReduceOp, "unimplemented: scalar accumulator distribution");
-    }
-    auto resVector = dyn_cast<VectorValue>(multiReduceOp.getResult());
-    if (!resVector) {
-      return rewriter.notifyMatchFailure(
-          multiReduceOp, "unimplemented: scalar result distribution");
-    }
+    Value acc = multiReduceOp.getAcc();
+    Value res = multiReduceOp.getResult();
+    auto accVector = dyn_cast<VectorValue>(acc);
+    auto resVector = dyn_cast<VectorValue>(res);

auto srcLayout = dyn_cast_or_null<NestedLayoutAttr>(signature[srcVector]);
if (!srcLayout) {
@@ -440,8 +436,14 @@

VectorValue disSrc =
getDistributed(rewriter, srcVector, signature[srcVector]);
-    VectorValue disAcc =
-        getDistributed(rewriter, accVector, signature[accVector]);

+    Value disAcc;
+    if (accVector) {
+      disAcc = getDistributed(rewriter, accVector, signature[accVector]);
+    } else {
+      // Scalars are always distributed to all threads already.
+      disAcc = multiReduceOp.getAcc();
+    }

Location loc = multiReduceOp.getLoc();

@@ -462,7 +464,16 @@
auto localReduction = rewriter.create<vector::MultiDimReductionOp>(
loc, disSrc, localInit, distributedReductionMask,
multiReduceOp.getKind());
-    auto locallyReduced = dyn_cast<VectorValue>(localReduction.getResult());

+    VectorValue locallyReduced;
+    if (accVector) {
+      locallyReduced = dyn_cast<VectorValue>(localReduction.getResult());
+    } else {
+      // Broadcast scalar accumulator to vector.
+      VectorType vecType = VectorType::get(ArrayRef{int64_t(1)}, elemTy);
+      locallyReduced = rewriter.create<vector::BroadcastOp>(
+          loc, vecType, localReduction.getResult());
+    }

assert(locallyReduced && "result should have been a vector");

@@ -485,15 +496,30 @@
// reduction.
VectorValue unflattened = rewriter.create<vector::ShapeCastOp>(
loc, shaped, threadReduced.value());

if (!accVector) {
// Broadcast the scalar (e.g., f32) to a vector type (e.g., vector<f32>)
// because the following implementation requires the operand to be a
// vector.
disAcc = rewriter.create<vector::BroadcastOp>(loc, shaped, disAcc);
}

Value accReduction = vector::makeArithReduction(
rewriter, loc, multiReduceOp.getKind(), unflattened, disAcc);
auto accReduced = dyn_cast<VectorValue>(accReduction);
if (!accReduced) {
return failure();
}
-    replaceOpWithDistributedValues(rewriter, multiReduceOp, accReduced);
+    if (resVector) {
+      replaceOpWithDistributedValues(rewriter, multiReduceOp, accReduced);
+    } else {
+      Value accReducedVal = rewriter.create<vector::ExtractOp>(
+          loc, accReduction, ArrayRef{int64_t(0)});
+      replaceOpWithDistributedValues(rewriter, multiReduceOp, accReducedVal);
+    }

-    return failure();
+    return success();
}

FailureOr<VectorValue> doThreadReduction(RewriterBase &rewriter,
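Note: taken together, the scalar path in DistributeMultiReduction lowers a scalar-accumulator reduction roughly as follows (a sketch matching the CHECK lines in the tests below, not verbatim compiler output; value names are illustrative):

    // Per-thread partial reduction of the distributed source down to a scalar.
    %local = vector.multi_reduction <maximumf>, %dis_src, %init [0, 1, 2, 3, 4, 5]
        : vector<2x2x1x1x1x4xf32> to f32
    // Cross-thread reduction over the thread tile.
    %r0 = gpu.subgroup_reduce maximumf %local cluster(size = 16) : (f32) -> f32
    %r1 = gpu.subgroup_reduce maximumf %r0 cluster(size = 4, stride = 16) : (f32) -> f32
    // The scalar accumulator combine is kept on the vector path via a
    // broadcast to vector<1xf32>; the scalar result is then extracted back out.
    %vec = vector.broadcast %r1 : f32 to vector<1xf32>
    %combined = arith.maximumf %vec, %dis_acc : vector<1xf32>
    %res = vector.extract %combined[0] : f32 from vector<1xf32>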
@@ -132,14 +132,16 @@ void DistributionPattern::replaceOpWithDistributedValues(
for (auto [opResult, replacement] :
llvm::zip_equal(op->getOpResults(), values)) {
// If this value is a vector type, it must be converted back to simd.
-    if (isa<VectorType>(replacement.getType())) {
-      auto oldResult = cast<VectorValue>(opResult);
-      // Create a toSIMD op to convert the value back to the simd.
-      rewriter.setInsertionPointAfterValue(oldResult);
-      Value toSIMD = rewriter.create<IREE::VectorExt::ToSIMDOp>(
-          oldResult.getLoc(), oldResult.getType(), replacement);
-      // Add to replacements.
-      replacement = toSIMD;
+    if (auto replacementType = dyn_cast<VectorType>(replacement.getType())) {
+      if (replacementType.getRank() != 0) {
+        auto oldResult = cast<VectorValue>(opResult);
+        // Create a toSIMD op to convert the value back to the simd.
+        rewriter.setInsertionPointAfterValue(oldResult);
+        Value toSIMD = rewriter.create<IREE::VectorExt::ToSIMDOp>(
+            oldResult.getLoc(), oldResult.getType(), replacement);
+        // Add to replacements.
+        replacement = toSIMD;
+      }
     }
replacements.push_back(replacement);
}
@@ -1047,3 +1047,95 @@ builtin.module attributes { transform.with_named_sequence } {
// CHECK: gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 2, stride = 32) : (f32) -> f32
// Accumulator reduction
// CHECK: arith.maximumf %{{.*}}, %{{.*}} : vector<1x1x1xf32>

// -----

#nested = #iree_vector_ext.nested_layout<
subgroup_tile = [1, 1],
batch_tile = [2, 2],
outer_tile = [1, 1],
thread_tile = [16, 4],
element_tile = [1, 4],

subgroup_strides = [1, 1],
thread_strides = [1, 16]
>

func.func @mfma_16x16x16_out_reduced_alldims(%arg0: vector<32x32xf32>, %arg1: f32) -> f32 {
%arg0l = iree_vector_ext.to_layout %arg0 to layout(#nested) : vector<32x32xf32>
%0 = vector.multi_reduction <maximumf>, %arg0l, %arg1 [0, 1] : vector<32x32xf32> to f32
return %0 : f32
}

builtin.module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
%top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
transform.yield
}
}

// CHECK-LABEL: func @mfma_16x16x16_out_reduced_alldims
// Local reduction
// CHECK: vector.multi_reduction <maximumf>, %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5] : vector<2x2x1x1x1x4xf32> to f32
// Global reduction
// CHECK: gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 16) : (f32) -> f32
// CHECK-NEXT: gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 4, stride = 16) : (f32) -> f32
// Accumulator reduction
// CHECK: arith.maximumf %{{.*}}, %{{.*}} : vector<1xf32>

// -----

#layout = #iree_vector_ext.nested_layout<
subgroup_tile = [1, 1],
batch_tile = [2, 2],
outer_tile = [1, 1],
thread_tile = [16, 4],
element_tile = [1, 4],

subgroup_strides = [1, 1],
thread_strides = [1, 16]
>

func.func @distribute_scf_for(%arr: memref<32x32xf16>, %a: vector<32x32xf16>) -> vector<f32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c128 = arith.constant 128 : index
%cst = arith.constant dense<0.000000e+00> : vector<f32>
%cst_0 = arith.constant 0.0 : f16
%out = scf.for %i = %c0 to %c128 step %c1 iter_args(%arg0 = %cst) -> (vector<f32>) {
%root = vector.transfer_read %arr[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<32x32xf16>, vector<32x32xf16>
%rootl = iree_vector_ext.to_layout %root to layout(#layout) : vector<32x32xf16>
%b = arith.addf %rootl, %a : vector<32x32xf16>
%c = arith.extf %b : vector<32x32xf16> to vector<32x32xf32>
%init = vector.extractelement %arg0[] : vector<f32>
%root_red = vector.multi_reduction<add>, %c, %init [0, 1] : vector<32x32xf32> to f32
%d = vector.broadcast %root_red : f32 to vector<f32>
scf.yield %d : vector<f32>
}
return %out : vector<f32>
}

builtin.module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
%top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
transform.yield
}
}

// CHECK-LABEL: func @distribute_scf_for
// CHECK: %[[ROOT:.*]] = arith.constant dense<0.000000e+00> : vector<f32>
// CHECK: iter_args(%[[ARG0:.*]] = %[[ROOT]]) -> (vector<f32>)
// CHECK: %[[A:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<32x32xf16> -> vector<2x2x1x1x1x4xf16>
// CHECK: %[[B:.*]] = arith.addf %{{.*}}, %[[A]]
// CHECK: %[[C:.*]] = arith.extf %[[B]]
// CHECK-NEXT: %[[D:.*]] = vector.extractelement %[[ARG0]][] : vector<f32>
// Local reduction
// CHECK: vector.multi_reduction <add>, %[[C]], %{{.*}} [0, 1, 2, 3, 4, 5] : vector<2x2x1x1x1x4xf32> to f32
// Global reduction
// CHECK: gpu.subgroup_reduce add %{{.*}} cluster(size = 16) : (f32) -> f32
// CHECK-NEXT: gpu.subgroup_reduce add %{{.*}} cluster(size = 4, stride = 16) : (f32) -> f32
// Accumulator reduction
// CHECK: vector.broadcast %[[D]] : f32 to vector<1xf32>
// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1xf32>
56 changes: 53 additions & 3 deletions compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp
@@ -13,6 +13,7 @@
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Diagnostics.h"
Expand Down Expand Up @@ -135,6 +136,9 @@ class EnforceLayout : public DataFlowAnalysis {
RegionBranchPoint branchPoint,
MutableArrayRef<OpOperand> operands);

void visitRegionBranchTerminatorOpInterface(RegionBranchOpInterface branch,
RegionBranchPoint branchPoint);

DistributionLayout *getLatticeElement(Value val);

MLIRContext *ctx;
@@ -662,6 +666,9 @@ static void enforceLayoutToMultiReductionOp(
ArrayRef<DistributionLayout *> operandLattices,
ArrayRef<const DistributionLayout *> resultLattices,
std::function<void(DistributionLayout *, ChangeResult)> update) {
if (resultLattices.empty()) {
return;
}
// Reductions should always propagate value layout to result. Result can
// enforce its layout on init.
const DistributionLayout *result = resultLattices[0];
@@ -727,9 +734,12 @@

auto resultShape = broadcast.getResultVectorType().getShape();
auto inputType = broadcast.getSourceType();
-  assert(isa<VectorType>(inputType) &&
-         "Scalar broadcast not supported for now.");
-  auto inputShape = cast<VectorType>(inputType).getShape();

+  VectorType inputVectorType = dyn_cast<VectorType>(inputType);
+  if (!inputVectorType)
+    return;

+  auto inputShape = inputVectorType.getShape();

SmallVector<bool> reductionMask(resultShape.size(), false);
// Set the trailing dimensions to be reduced.
@@ -994,6 +1004,9 @@ void EnforceLayout::visitOperation(Operation *op) {
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
visitRegionSuccessors(branch, RegionBranchPoint::parent(),
branch->getOpOperands());

// Handle the propagation from scf.for to yield op.
visitRegionBranchTerminatorOpInterface(branch, RegionBranchPoint::parent());
return;
}

@@ -1086,6 +1099,43 @@ void EnforceLayout::visitRegionSuccessors(RegionBranchOpInterface branch,
}
}

void EnforceLayout::visitRegionBranchTerminatorOpInterface(
RegionBranchOpInterface branch, RegionBranchPoint branchPoint) {
SmallVector<RegionSuccessor> successors;
branch.getSuccessorRegions(branchPoint, successors);
if (!branch.hasLoop())
return;
SmallVector<DistributionLayout *> resultLattices;
for (Value result : branch->getResults()) {
DistributionLayout *resultLattice = getLatticeElement(result);
if (resultLattice->isUninitialized())
continue;
resultLattices.push_back(resultLattice);
}

// We do not support multiple results yet.
if (resultLattices.size() != 1)
return;

for (RegionSuccessor successor : successors) {
if (Region *succ = successor.getSuccessor()) {
Operation *terminator = succ->back().getTerminator();
if (scf::YieldOp yieldOp = dyn_cast<scf::YieldOp>(terminator)) {
for (Value operand : yieldOp.getOperands()) {
if (!isa<VectorType>(operand.getType())) {
continue;
}
DistributionLayout *forwardLattice = getLatticeElement(operand);
ChangeResult changed = forwardLattice->resolve(resultLattices[0]);
propagateIfChanged(forwardLattice, changed);
}
}
}
}

return;
}

DistributionLayout *EnforceLayout::getLatticeElement(Value val) {
// Add dependency of operation on the analysis state.
assert(isa<VectorType>(val.getType()) && "Lattice value should be a vector");
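Note: the new terminator handling is what ties the scf.for result layout back to the loop body. For a loop-carried vector value, the layout enforced on the loop result must also reach the value yielded from the body; otherwise the iter_arg and the yielded value could settle on disagreeing layouts. A reduced sketch of the dataflow being resolved (illustrative only):

    %out = scf.for ... iter_args(%arg0 = %init) -> (vector<f32>) {
      ...
      // EnforceLayout resolves the layout of %d against the layout already
      // chosen for %out, so the loop-carried value stays consistent.
      scf.yield %d : vector<f32>
    }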