Add support for a variety of (data tiled) convolution strategies #63

Merged · 12 commits · Jul 20, 2023
@@ -46,6 +46,84 @@ class TransposeUnitDimToShapeCast
}
};

// TODO: Move this upstream
// Hoists a vector.bitcast op to the output of the enclosing scf.if
//
// This transforms IR like:
// %0 = scf.if %1 -> (vector<16xi8>) {
//   %2 = memref.load %src[%c0] : memref<?xvector<4xi32>>
//   %3 = vector.bitcast %2 : vector<4xi32> to vector<16xi8>
//   scf.yield %3 : vector<16xi8>
// } else {
//   scf.yield %cst : vector<16xi8>
// }
// Into:
// %0 = scf.if %1 -> (vector<4xi32>) {
//   %2 = memref.load %src[%c0] : memref<?xvector<4xi32>>
//   scf.yield %2 : vector<4xi32>
// } else {
//   %3 = vector.bitcast %cst : vector<16xi8> to vector<4xi32>
//   scf.yield %3 : vector<4xi32>
// }
// %4 = vector.bitcast %0 : vector<4xi32> to vector<16xi8>
struct BubbleUpBitCastOfScfIf : public OpRewritePattern<scf::IfOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(scf::IfOp ifOp,
PatternRewriter &rewriter) const override {
// Bail on more than one result for now.
scf::YieldOp thenYield = ifOp.thenYield();
if (!thenYield || thenYield.getNumOperands() != 1)
return failure();
auto bitcastOp = thenYield.getOperand(0).getDefiningOp<vector::BitCastOp>();
// Bail out if the value yielded by the then branch is not defined by a bitcast.
if (!bitcastOp)
return failure();

VectorType castSrcType = bitcastOp.getSourceVectorType();
VectorType castDstType = bitcastOp.getResultVectorType();
assert(castSrcType.getRank() == castDstType.getRank());
// Skip 0-D vector.
if (castSrcType.getRank() == 0)
return failure();

int64_t castSrcLastDim = castSrcType.getShape().back();
int64_t castDstLastDim = castDstType.getShape().back();
// Require casting to at least as many elements (a narrower or equal element type).
if (castSrcLastDim > castDstLastDim)
return failure();

Location loc = ifOp.getLoc();

auto bitcastedIfOp =
rewriter.create<scf::IfOp>(loc, castSrcType, ifOp.getCondition());
bitcastedIfOp.getThenRegion().takeBody(ifOp.getThenRegion());
bitcastedIfOp.getElseRegion().takeBody(ifOp.getElseRegion());

scf::YieldOp newThenYield = bitcastedIfOp.thenYield();
auto newBitcastOp =
newThenYield.getOperand(0).getDefiningOp<vector::BitCastOp>();

newThenYield.setOperand(0, newBitcastOp.getSource());

auto newBitcast = rewriter.create<vector::BitCastOp>(
loc, castDstType, bitcastedIfOp.getResult(0));

scf::YieldOp elseYield = bitcastedIfOp.elseYield();
if (elseYield) {
OpBuilder::InsertionGuard elseGuard(rewriter);
rewriter.setInsertionPoint(elseYield);

Value yieldSrc = elseYield.getOperand(0);
auto elseBitcast =
rewriter.create<vector::BitCastOp>(loc, castSrcType, yieldSrc);
elseYield.setOperand(0, elseBitcast);
}
rewriter.replaceOp(ifOp, newBitcast);
return success();
}
};

static void loopInvariantCodeMotion(func::FuncOp funcOp) {
// Walk through all loops in a function in innermost-loop-first order. This
// way, we first LICM from the inner loop, and place the ops in
@@ -89,6 +167,7 @@ struct OptimizeVectorTransferPass
{
RewritePatternSet patterns(&getContext());
vector::populateBubbleVectorBitCastOpPatterns(patterns);
patterns.add<BubbleUpBitCastOfScfIf>(&getContext());
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
@@ -269,6 +269,12 @@ void transform_dialect::ApplyPrepareVectorToMMAPatternsOp::populatePatterns(
populatePrepareVectorToMMAPatterns(patterns, getUseNvGpu());
}

void transform_dialect::ApplySwapTensorPadWithExtractSliceOp::populatePatterns(
RewritePatternSet &patterns) {
patterns.insert<linalg::ExtractSliceOfPadTensorSwapPattern>(
patterns.getContext(), [](tensor::ExtractSliceOp) { return false; });
}

//===---------------------------------------------------------------------===//
// ApplyCommonSubexpressionEliminationOp
//===---------------------------------------------------------------------===//
@@ -519,4 +519,17 @@ def IREEPopulateWorkgroupCountRegionUsingNumThreadsSliceOp :
}];
}

def ApplySwapTensorPadWithExtractSliceOp : Op<Transform_Dialect,
"apply_patterns.iree.swap_tensor_pad",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Populate patterns that swap a tensor.pad op with its consumer
tensor.extract_slice ops.
}];

let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
let assemblyFormat = "attr-dict";
}

#endif // IREE_COMPILER_CODEGEN_COMMON_TRANSFORMEXTENSIONS_COMMONEXTENSIONS
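A minimal usage sketch of the new apply_patterns.iree.swap_tensor_pad op inside a user-provided transform script; the enclosing sequence, the matched func.func, and the value names are illustrative assumptions rather than part of this change:

transform.sequence failures(propagate) {
^bb0(%variant_op: !transform.any_op):
  %func = transform.structured.match ops{["func.func"]} in %variant_op
    : (!transform.any_op) -> !transform.any_op
  // Swap tensor.pad ops with their consumer tensor.extract_slice ops.
  transform.apply_patterns to %func {
    transform.apply_patterns.iree.swap_tensor_pad
  } : !transform.any_op
}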
24 changes: 24 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -692,6 +692,30 @@ static LogicalResult setTransformDialectConfig(func::FuncOp entryPoint,
gpuModel.hasTF32TensorCore = targetInfo.hasTF32TensorCore;
gpuModel.hasMmaSync = targetInfo.hasMmaSync;

// Populates a subset of the fragment combinations supported in MLIR lowerings
// to NVVM (which is itself a subset of what LLVM supports) based on what the
// pipeline currently supports.
// TODO: avoid hard coding this and populate based on hardware capabilities.
// TODO: add missing supported configs once the pipeline supports it.
MLIRContext *context = entryPoint.getContext();
Type f32Type = Float32Type::get(context);
Type f16Type = Float16Type::get(context);

iree_compiler::gpu::MMAConfig f16f32AccConfig = {
/*m=*/16, /*n=*/16, /*k=*/16,
/*aType=*/f16Type, /*bType=*/f16Type, /*cType=*/f32Type};
iree_compiler::gpu::MMAConfig f16f16AccConfig = {
/*m=*/16, /*n=*/16, /*k=*/16,
/*aType=*/f16Type, /*bType=*/f16Type, /*cType=*/f16Type};
gpuModel.supportedWMMAConfigs = {f16f32AccConfig, f16f16AccConfig};

if (targetInfo.hasTF32TensorCore) {
iree_compiler::gpu::MMAConfig tf32WmmaConfig = {
/*m=*/16, /*n=*/16, /*k=*/8,
/*aType=*/f32Type, /*bType=*/f32Type, /*cType=*/f32Type};
gpuModel.supportedWMMAConfigs.push_back(tf32WmmaConfig);
}

if (failed(iree_compiler::gpu::matchAndSetTransformStrategy(entryPoint, op,
gpuModel)))
return failure();
@@ -126,9 +126,10 @@ void transform_dialect::MapNestedForallToGpuThreadsOp::build(
void transform_dialect::MapNestedForallToGpuThreadsOp::build(
OpBuilder &builder, OperationState &state, Value target,
ArrayRef<int64_t> workgroupDims, ArrayRef<int64_t> warpDims,
int64_t subgroupSize) {
std::optional<int64_t> subgroupSize) {
build(builder, state, {}, target, workgroupDims, warpDims,
builder.getI64IntegerAttr(subgroupSize));
subgroupSize ? builder.getI64IntegerAttr(*subgroupSize)
: IntegerAttr());
}

void transform_dialect::MapNestedForallToGpuThreadsOp::getEffects(
@@ -110,7 +110,7 @@ def MapNestedForallToGpuThreadsOp :
OpBuilder<(ins "Value":$target,
"ArrayRef<int64_t>":$workgroup_dims,
"ArrayRef<int64_t>":$warp_dims,
"int64_t":$subgroupSize)>
"std::optional<int64_t>":$subgroupSize)>
];
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
Expand Down
78 changes: 71 additions & 7 deletions compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
@@ -13,6 +13,7 @@
#include "iree/compiler/Codegen/Common/UserConfig.h"
#include "iree/compiler/Codegen/Dialect/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/SPIRV/Utils.h"
#include "iree/compiler/Codegen/TransformStrategies/GPU/Strategies.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "llvm/ADT/ArrayRef.h"
@@ -51,6 +52,11 @@ llvm::cl::opt<std::string> clSPIRVTransformDialectFileName(
"MLIR file containing a transform dialect specification to apply"),
llvm::cl::init(""));

llvm::cl::opt<bool> clSPIRVEnableTransformDialectJit(
"iree-spirv-enable-transform-dialect-jit",
llvm::cl::desc("enable the usage of the transform dialect JIT"),
llvm::cl::init(false));

using CodeGenPipeline = IREE::Codegen::DispatchLoweringPassPipeline;

//===----------------------------------------------------------------------===//
@@ -1515,6 +1521,68 @@ static LogicalResult setDefaultOpConfig(spirv::ResourceLimitsAttr limits,
workgroupSize);
}

//===----------------------------------------------------------------------===//
// Transform Dialect Specialized Configurations
//===----------------------------------------------------------------------===//

static LogicalResult
setTransformDialectConfig(func::FuncOp entryPoint, Operation *op,
const spirv::TargetEnv &targetEnv) {
if (!clSPIRVEnableTransformDialectJit &&
clSPIRVTransformDialectFileName.empty()) {
return failure();
}

MLIRContext *context = entryPoint.getContext();

// Prefer a transform script file if provided.
if (!clSPIRVTransformDialectFileName.empty()) {
auto translationInfo = IREE::Codegen::TranslationInfoAttr::get(
context, CodeGenPipeline::TransformDialectCodegen);
LLVM_DEBUG(llvm::dbgs() << "using user specified transform dialect...\n");
return setTranslationInfo(entryPoint, translationInfo);
}

auto translationInfo = IREE::Codegen::TranslationInfoAttr::get(
context, CodeGenPipeline::TransformDialectCodegen);

spirv::ResourceLimitsAttr limits = targetEnv.getResourceLimits();

// TODO: unify the target information into one structure.
iree_compiler::gpu::GPUModel gpuModel;
gpuModel.hasWarpShuffle =
targetEnv.allows(spirv::Capability::GroupNonUniformShuffle);
gpuModel.hasTF32TensorCore = false;
gpuModel.hasMmaSync = false;
gpuModel.minSubgroupSize = limits.getMinSubgroupSize();
gpuModel.maxSubgroupSize = limits.getMaxSubgroupSize();
gpuModel.maxWorkGroupInvocations = limits.getMaxComputeWorkgroupInvocations();

// Populate the supported WMMA fragment combinations from the target
// environment, and infer tf32 support from the supported fragment types.
Type f32Type = Float32Type::get(context);
auto properties = limits.getCooperativeMatrixPropertiesNv()
.getAsRange<spirv::CooperativeMatrixPropertiesNVAttr>();
for (auto property : properties) {
if (property.getScope().getValue() != spirv::Scope::Subgroup)
continue;
gpuModel.supportedWMMAConfigs.push_back(iree_compiler::gpu::MMAConfig{
property.getMSize(), property.getNSize(), property.getKSize(),
property.getAType(), property.getBType(), property.getCType()});
if (property.getAType() == f32Type && property.getBType() == f32Type)
gpuModel.hasTF32TensorCore = true;
}

if (failed(iree_compiler::gpu::matchAndSetTransformStrategy(entryPoint, op,
gpuModel)))
return failure();
return setTranslationInfo(entryPoint, translationInfo);
}

//===----------------------------------------------------------------------===//
// Configuration Dispatcher
//===----------------------------------------------------------------------===//
@@ -1531,13 +1599,9 @@ static LogicalResult setSPIRVOpConfig(const spirv::TargetEnv &targetEnv,
return setUserConfig(entryPointFn, rootOp, compilationInfo);
}

if (!clSPIRVTransformDialectFileName.empty()) {
MLIRContext *context = entryPointFn.getContext();
auto translationInfo = IREE::Codegen::TranslationInfoAttr::get(
context, CodeGenPipeline::TransformDialectCodegen);
LLVM_DEBUG(llvm::dbgs() << "using user specified transform dialect...\n");

return setTranslationInfo(entryPointFn, translationInfo);
// First try to see if there is a matching transform dialect configuration.
if (succeeded(setTransformDialectConfig(entryPointFn, rootOp, targetEnv))) {
return success();
}

// Otherwise, try to find a proper CodeGen configuration to tile and vectorize for
@@ -109,6 +109,7 @@ void mlir::iree_compiler::createTransformRegion(
(void)sequence;
LDBG("transformation script:\n");
LDBG("verification: " << sequence.verify().succeeded() << "\n");
LLVM_DEBUG(sequence.dump());
}

//===----------------------------------------------------------------------===//
@@ -290,7 +291,13 @@ Value mlir::iree_compiler::buildPad(
Value mlir::iree_compiler::buildVectorize(ImplicitLocOpBuilder &b, Value funcH,
bool applyCleanups,
bool vectorizePadding,
bool vectorizeNdExtract) {
bool vectorizeNdExtract,
bool useIreePadHandling) {
if (useIreePadHandling) {
funcH = b.create<transform::ApplyRegisteredPassOp>(
funcH.getType(), funcH,
b.getStringAttr("iree-codegen-vectorize-tensor-pad"));
}
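// When `useIreePadHandling` is set, the op created above prints roughly as
//   %f = transform.apply_registered_pass "iree-codegen-vectorize-tensor-pad"
//       to %func : (!transform.any_op) -> !transform.any_op
// (names and types here are illustrative), running the pad-vectorization pass
// ahead of the vectorize op created below.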
funcH = b.create<VectorizeOp>(funcH, vectorizePadding, vectorizeNdExtract);
if (applyCleanups) {
iree_compiler::buildCanonicalizationAndEnablingTransforms(b, funcH);
@@ -188,7 +188,8 @@ Value buildPad(ImplicitLocOpBuilder &b, Value opH,
/// If `applyCleanups` is true, also apply cleanup patterns.
Value buildVectorize(ImplicitLocOpBuilder &b, Value funcH,
bool applyCleanups = false, bool vectorizePadding = false,
bool vectorizeNdExtract = false);
bool vectorizeNdExtract = false,
bool useIreePadHandling = false);

/// Build transform IR that applies lowering of masked vector transfer
/// operations and subsequent cleanup patterns (fold-memref-aliases).
@@ -91,6 +91,26 @@ void AbstractGemmLikeStrategy::initDefaultValues(const GPUModel &gpuModel) {
cliOptionsSpecified = true;
}

/// If not specified, select instructions to target for compute.
if (!useMmaSync && !useWmma && !useFma) {
/// First, try to use tensor core.
if (getLhsElementalType() == getRhsElementalType()) {
/// Currently all supported targets at least have WMMA.
/// TODO: Handle targets without tensor core.
if (gpuModel.hasMmaSync)
useMmaSync = true;
else
useWmma = true;
} else {
/// Mixed precision only supported by fma.
useFma = true;
}
}

/// Prefer smaller subgroup sizes for tensor core strategies.
if (!useFma)
targetSubgroupSize = gpuModel.minSubgroupSize;

/// Default configuration based on hardware properties and problem bit widths.
if (clBlockTileSizes.getNumOccurrences()) {
blockTileSizes =
@@ -105,7 +125,7 @@ void AbstractGemmLikeStrategy::initDefaultValues(const GPUModel &gpuModel) {
// Infer from warp counts if present.
if (clNumWarps.getNumOccurrences()) {
numThreads = SmallVector<int64_t>(clNumWarps.begin(), clNumWarps.end());
numThreads[0] *= gpuModel.subgroupSize;
numThreads[0] *= getSubgroupSize();
} else {
numThreads = SmallVector<int64_t>{64, 2, 1};
}
@@ -114,7 +134,7 @@ void AbstractGemmLikeStrategy::initDefaultValues(const GPUModel &gpuModel) {
numWarps = SmallVector<int64_t>(clNumWarps.begin(), clNumWarps.end());
} else {
numWarps = numThreads;
numWarps[0] = mlir::ceilDiv(numWarps[0], gpuModel.subgroupSize);
numWarps[0] = mlir::ceilDiv(numWarps[0], getSubgroupSize());
}
if (clUseAsyncCopies.getNumOccurrences())
useAsyncCopies = clUseAsyncCopies;
@@ -126,21 +146,6 @@ void AbstractGemmLikeStrategy::initDefaultValues(const GPUModel &gpuModel) {
useWmma = clUseWmma;
if (clUseFma.getNumOccurrences())
useFma = clUseFma;
/// If not specified, select instructions to target for compute.
if (!useMmaSync && !useWmma && !useFma) {
/// First, try to use tensor core.
if (getLhsElementalType() == getRhsElementalType()) {
/// Currently all supported targets at least have WMMA.
/// TODO: Handle targets without tensor core.
if (gpuModel.hasMmaSync)
useMmaSync = true;
else
useWmma = true;
} else {
/// Mixed precision only supported by fma.
useFma = true;
}
}
if (clReductionTileSize.getNumOccurrences()) {
reductionTileSize = clReductionTileSize;
} else {
@@ -175,7 +180,7 @@ AbstractGemmLikeStrategy::getZeroPadAttrFromElementalTypes(OpBuilder &b) const {

LogicalResult
AbstractGemmLikeStrategy::validate(const GPUModel &gpuModel) const {
if (totalNumThreads() != totalNumWarps() * gpuModel.subgroupSize) {
if (totalNumThreads() != totalNumWarps() * getSubgroupSize()) {
llvm::errs() << "Number of threads specified by warps must match total "
"number of threads\n";
return failure();