diff --git a/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp b/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp
index 66acf8cb670a..f26f3223f3ce 100644
--- a/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp
@@ -46,6 +46,84 @@ class TransposeUnitDimToShapeCast
   }
 };
 
+// TODO: Move this upstream.
+// Hoists a vector.bitcast op to the output of the enclosing scf.if.
+//
+// This transforms IR like:
+//   %0 = scf.if %1 -> (vector<16xi8>) {
+//     %2 = memref.load %4[%c0] : memref<?xvector<4xi32>>
+//     %3 = vector.bitcast %2 : vector<4xi32> to vector<16xi8>
+//     scf.yield %3 : vector<16xi8>
+//   } else {
+//     scf.yield %cst : vector<16xi8>
+//   }
+// Into:
+//   %0 = scf.if %1 -> (vector<4xi32>) {
+//     %2 = memref.load %4[%c0] : memref<?xvector<4xi32>>
+//     scf.yield %2 : vector<4xi32>
+//   } else {
+//     %3 = vector.bitcast %cst : vector<16xi8> to vector<4xi32>
+//     scf.yield %3 : vector<4xi32>
+//   }
+//   %5 = vector.bitcast %0 : vector<4xi32> to vector<16xi8>
+struct BubbleUpBitCastOfScfIf : public OpRewritePattern<scf::IfOp> {
+  using OpRewritePattern<scf::IfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(scf::IfOp ifOp,
+                                PatternRewriter &rewriter) const override {
+    // Bail on more than one result for now.
+    scf::YieldOp thenYield = ifOp.thenYield();
+    if (!thenYield || thenYield.getNumOperands() != 1)
+      return failure();
+    auto bitcastOp =
+        thenYield.getOperand(0).getDefiningOp<vector::BitCastOp>();
+    // Bail out if the then-branch yield is not fed by a bitcast.
+    if (!bitcastOp)
+      return failure();
+
+    VectorType castSrcType = bitcastOp.getSourceVectorType();
+    VectorType castDstType = bitcastOp.getResultVectorType();
+    assert(castSrcType.getRank() == castDstType.getRank());
+    // Skip 0-D vectors.
+    if (castSrcType.getRank() == 0)
+      return failure();
+
+    int64_t castSrcLastDim = castSrcType.getShape().back();
+    int64_t castDstLastDim = castDstType.getShape().back();
+    // Require casting to the same number of elements or more.
+    if (castSrcLastDim > castDstLastDim)
+      return failure();
+
+    Location loc = ifOp.getLoc();
+
+    auto bitcastedIfOp =
+        rewriter.create<scf::IfOp>(loc, castSrcType, ifOp.getCondition());
+    bitcastedIfOp.getThenRegion().takeBody(ifOp.getThenRegion());
+    bitcastedIfOp.getElseRegion().takeBody(ifOp.getElseRegion());
+
+    scf::YieldOp newThenYield = bitcastedIfOp.thenYield();
+    auto newBitcastOp =
+        newThenYield.getOperand(0).getDefiningOp<vector::BitCastOp>();
+
+    newThenYield.setOperand(0, newBitcastOp.getSource());
+
+    auto newBitcast = rewriter.create<vector::BitCastOp>(
+        loc, castDstType, bitcastedIfOp.getResult(0));
+
+    scf::YieldOp elseYield = bitcastedIfOp.elseYield();
+    if (elseYield) {
+      OpBuilder::InsertionGuard elseGuard(rewriter);
+      rewriter.setInsertionPoint(elseYield);
+
+      Value yieldSrc = elseYield.getOperand(0);
+      auto elseBitcast =
+          rewriter.create<vector::BitCastOp>(loc, castSrcType, yieldSrc);
+      elseYield.setOperand(0, elseBitcast);
+    }
+    rewriter.replaceOp(ifOp, newBitcast);
+    return success();
+  }
+};
+
 static void loopInvariantCodeMotion(func::FuncOp funcOp) {
   // Walk through all loops in a function in innermost-loop-first order.
This // way, we first LICM from the inner loop, and place the ops in @@ -89,6 +167,7 @@ struct OptimizeVectorTransferPass { RewritePatternSet patterns(&getContext()); vector::populateBubbleVectorBitCastOpPatterns(patterns); + patterns.add(&getContext()); if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { return signalPassFailure(); } diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp index 0ffc77e0f336..2dc5eb0b7aa1 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp @@ -269,6 +269,12 @@ void transform_dialect::ApplyPrepareVectorToMMAPatternsOp::populatePatterns( populatePrepareVectorToMMAPatterns(patterns, getUseNvGpu()); } +void transform_dialect::ApplySwapTensorPadWithExtractSliceOp::populatePatterns( + RewritePatternSet &patterns) { + patterns.insert( + patterns.getContext(), [](tensor::ExtractSliceOp) { return false; }); +} + //===---------------------------------------------------------------------===// // ApplyCommonSubexpressionEliminationOp //===---------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td index 219e7426a3ed..c61cb9398da1 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td +++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td @@ -519,4 +519,17 @@ def IREEPopulateWorkgroupCountRegionUsingNumThreadsSliceOp : }]; } +def ApplySwapTensorPadWithExtractSliceOp : Op, + ReportTrackingListenerFailuresOpTrait]> { + let description = [{ + Populate patterns to swap tensor pad with consumer tensor.extract_slice + operations. + }]; + + let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect"; + let assemblyFormat = "attr-dict"; +} + #endif // IREE_COMPILER_CODEGEN_COMMON_TRANSFORMEXTENSIONS_COMMONEXTENSIONS diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp index 17189e927bfb..bd192b3cbb5c 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp @@ -692,6 +692,30 @@ static LogicalResult setTransformDialectConfig(func::FuncOp entryPoint, gpuModel.hasTF32TensorCore = targetInfo.hasTF32TensorCore; gpuModel.hasMmaSync = targetInfo.hasMmaSync; + // Populates a subset of the fragment combinations supported in MLIR lowerings + // to NVVM (which is itself a subset of what LLVM supports) based on what the + // pipeline currently supports. + // TODO: avoid hard coding this and populate based on hardware capabilities. + // TODO: add missing supported configs once the pipeline supports it. 
+ MLIRContext *context = entryPoint.getContext(); + Type f32Type = Float32Type::get(context); + Type f16Type = Float16Type::get(context); + + iree_compiler::gpu::MMAConfig f16f32AccConfig = { + /*m=*/16, /*n=*/16, /*k=*/16, + /*aType=*/f16Type, /*bType=*/f16Type, /*cType=*/f32Type}; + iree_compiler::gpu::MMAConfig f16f16AccConfig = { + /*m=*/16, /*n=*/16, /*k=*/16, + /*aType=*/f16Type, /*bType=*/f16Type, /*cType=*/f16Type}; + gpuModel.supportedWMMAConfigs = {f16f32AccConfig, f16f16AccConfig}; + + if (targetInfo.hasTF32TensorCore) { + iree_compiler::gpu::MMAConfig tf32WmmaConfig = { + /*m=*/16, /*n=*/16, /*k=*/8, + /*aType=*/f32Type, /*bType=*/f32Type, /*cType=*/f32Type}; + gpuModel.supportedWMMAConfigs.push_back(tf32WmmaConfig); + } + if (failed(iree_compiler::gpu::matchAndSetTransformStrategy(entryPoint, op, gpuModel))) return failure(); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp index f101d90d2993..5c6b2f747682 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp @@ -126,9 +126,10 @@ void transform_dialect::MapNestedForallToGpuThreadsOp::build( void transform_dialect::MapNestedForallToGpuThreadsOp::build( OpBuilder &builder, OperationState &state, Value target, ArrayRef workgroupDims, ArrayRef warpDims, - int64_t subgroupSize) { + std::optional subgroupSize) { build(builder, state, {}, target, workgroupDims, warpDims, - builder.getI64IntegerAttr(subgroupSize)); + subgroupSize ? builder.getI64IntegerAttr(*subgroupSize) + : IntegerAttr()); } void transform_dialect::MapNestedForallToGpuThreadsOp::getEffects( diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td index 8956a6ba77ab..ce43cd560068 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td @@ -110,7 +110,7 @@ def MapNestedForallToGpuThreadsOp : OpBuilder<(ins "Value":$target, "ArrayRef":$workgroup_dims, "ArrayRef":$warp_dims, - "int64_t":$subgroupSize)> + "std::optional":$subgroupSize)> ]; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp index fd97fe5c4f9a..b81d4c9a584a 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp @@ -13,6 +13,7 @@ #include "iree/compiler/Codegen/Common/UserConfig.h" #include "iree/compiler/Codegen/Dialect/IREECodegenAttrs.h" #include "iree/compiler/Codegen/SPIRV/Utils.h" +#include "iree/compiler/Codegen/TransformStrategies/GPU/Strategies.h" #include "iree/compiler/Codegen/Utils/GPUUtils.h" #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "llvm/ADT/ArrayRef.h" @@ -51,6 +52,11 @@ llvm::cl::opt clSPIRVTransformDialectFileName( "MLIR file containing a transform dialect specification to apply"), llvm::cl::init("")); +llvm::cl::opt clSPIRVEnableTransformDialectJit( + "iree-spirv-enable-transform-dialect-jit", + llvm::cl::desc("enable the usage of the transform dialect JIT"), + llvm::cl::init(false)); + using CodeGenPipeline 
= IREE::Codegen::DispatchLoweringPassPipeline; //===----------------------------------------------------------------------===// @@ -1515,6 +1521,68 @@ static LogicalResult setDefaultOpConfig(spirv::ResourceLimitsAttr limits, workgroupSize); } +//===----------------------------------------------------------------------===// +// Transform Dialect Specialized Configurations +//===----------------------------------------------------------------------===// + +static LogicalResult +setTransformDialectConfig(func::FuncOp entryPoint, Operation *op, + const spirv::TargetEnv &targetEnv) { + if (!clSPIRVEnableTransformDialectJit && + clSPIRVTransformDialectFileName.empty()) { + return failure(); + } + + MLIRContext *context = entryPoint.getContext(); + + // Prefer a transform script file if provided. + if (!clSPIRVTransformDialectFileName.empty()) { + auto translationInfo = IREE::Codegen::TranslationInfoAttr::get( + context, CodeGenPipeline::TransformDialectCodegen); + LLVM_DEBUG(llvm::dbgs() << "using user specified transform dialect...\n"); + return setTranslationInfo(entryPoint, translationInfo); + } + + auto translationInfo = IREE::Codegen::TranslationInfoAttr::get( + entryPoint.getContext(), + IREE::Codegen::DispatchLoweringPassPipeline::TransformDialectCodegen); + if (!clSPIRVTransformDialectFileName.empty()) { + return setTranslationInfo(entryPoint, translationInfo); + } + + spirv::ResourceLimitsAttr limits = targetEnv.getResourceLimits(); + + // TODO: unify the target informations into one structure. + iree_compiler::gpu::GPUModel gpuModel; + gpuModel.hasWarpShuffle = + targetEnv.allows(spirv::Capability::GroupNonUniformShuffle); + gpuModel.hasTF32TensorCore = false; + gpuModel.hasMmaSync = false; + gpuModel.minSubgroupSize = limits.getMinSubgroupSize(); + gpuModel.maxSubgroupSize = limits.getMaxSubgroupSize(); + gpuModel.maxWorkGroupInvocations = limits.getMaxComputeWorkgroupInvocations(); + + // Populates the supported WMMA fragment combinations from the target + // environment. Infer tf32 support from the list of supported fragment types. 
+ Type f32Type = Float32Type::get(context); + auto properties = limits.getCooperativeMatrixPropertiesNv() + .getAsRange(); + for (auto property : properties) { + if (property.getScope().getValue() != spirv::Scope::Subgroup) + continue; + gpuModel.supportedWMMAConfigs.push_back(iree_compiler::gpu::MMAConfig{ + property.getMSize(), property.getNSize(), property.getKSize(), + property.getAType(), property.getBType(), property.getCType()}); + if (property.getAType() == f32Type && property.getBType() == f32Type) + gpuModel.hasTF32TensorCore = true; + } + + if (failed(iree_compiler::gpu::matchAndSetTransformStrategy(entryPoint, op, + gpuModel))) + return failure(); + return setTranslationInfo(entryPoint, translationInfo); +} + //===----------------------------------------------------------------------===// // Configuration Dispatcher //===----------------------------------------------------------------------===// @@ -1531,13 +1599,9 @@ static LogicalResult setSPIRVOpConfig(const spirv::TargetEnv &targetEnv, return setUserConfig(entryPointFn, rootOp, compilationInfo); } - if (!clSPIRVTransformDialectFileName.empty()) { - MLIRContext *context = entryPointFn.getContext(); - auto translationInfo = IREE::Codegen::TranslationInfoAttr::get( - context, CodeGenPipeline::TransformDialectCodegen); - LLVM_DEBUG(llvm::dbgs() << "using user specified transform dialect...\n"); - - return setTranslationInfo(entryPointFn, translationInfo); + // First try to see if there is a matching transform dialect configuration. + if (succeeded(setTransformDialectConfig(entryPointFn, rootOp, targetEnv))) { + return success(); } // First try to find a proper CodeGen configuration to tile and vectorize for diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/Common/Common.cpp b/compiler/src/iree/compiler/Codegen/TransformStrategies/Common/Common.cpp index ac55b3d7d843..aea31cfede85 100644 --- a/compiler/src/iree/compiler/Codegen/TransformStrategies/Common/Common.cpp +++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/Common/Common.cpp @@ -109,6 +109,7 @@ void mlir::iree_compiler::createTransformRegion( (void)sequence; LDBG("transformation script:\n"); LDBG("verification: " << sequence.verify().succeeded() << "\n"); + LLVM_DEBUG(sequence.dump()); } //===----------------------------------------------------------------------===// @@ -290,7 +291,13 @@ Value mlir::iree_compiler::buildPad( Value mlir::iree_compiler::buildVectorize(ImplicitLocOpBuilder &b, Value funcH, bool applyCleanups, bool vectorizePadding, - bool vectorizeNdExtract) { + bool vectorizeNdExtract, + bool useIreePadHandling) { + if (useIreePadHandling) { + funcH = b.create( + funcH.getType(), funcH, + b.getStringAttr("iree-codegen-vectorize-tensor-pad")); + } funcH = b.create(funcH, vectorizePadding, vectorizeNdExtract); if (applyCleanups) { iree_compiler::buildCanonicalizationAndEnablingTransforms(b, funcH); diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/Common/Common.h b/compiler/src/iree/compiler/Codegen/TransformStrategies/Common/Common.h index e0ff79c2017e..5f91b6010cdc 100644 --- a/compiler/src/iree/compiler/Codegen/TransformStrategies/Common/Common.h +++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/Common/Common.h @@ -188,7 +188,8 @@ Value buildPad(ImplicitLocOpBuilder &b, Value opH, /// If `applyCleanups` is true, also apply cleanup patterns. 
Value buildVectorize(ImplicitLocOpBuilder &b, Value funcH, bool applyCleanups = false, bool vectorizePadding = false, - bool vectorizeNdExtract = false); + bool vectorizeNdExtract = false, + bool useIreePadHandling = false); /// Build transform IR that applies lowering of masked vector transfer /// operations and subsequent cleanup patterns (fold-memref-aliases). diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/AbstractGemmLikeStrategy.cpp b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/AbstractGemmLikeStrategy.cpp index 8787c6359ddc..bc992ddeb712 100644 --- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/AbstractGemmLikeStrategy.cpp +++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/AbstractGemmLikeStrategy.cpp @@ -91,6 +91,26 @@ void AbstractGemmLikeStrategy::initDefaultValues(const GPUModel &gpuModel) { cliOptionsSpecified = true; } + /// If not specified, select instructions to target for compute. + if (!useMmaSync && !useWmma && !useFma) { + /// First, try to use tensor core. + if (getLhsElementalType() == getRhsElementalType()) { + /// Currently all supported targets at least have WMMA. + /// TODO: Handle targets without tensor core. + if (gpuModel.hasMmaSync) + useMmaSync = true; + else + useWmma = true; + } else { + /// Mixed precision only supported by fma. + useFma = true; + } + } + + /// Prefer smaller subgroup sizes for tensor core strategies. + if (!useFma) + targetSubgroupSize = gpuModel.minSubgroupSize; + /// Default configuration based on hardware properties and problem bit widths. if (clBlockTileSizes.getNumOccurrences()) { blockTileSizes = @@ -105,7 +125,7 @@ void AbstractGemmLikeStrategy::initDefaultValues(const GPUModel &gpuModel) { // Infer from warp counts if present. if (clNumWarps.getNumOccurrences()) { numThreads = SmallVector(clNumWarps.begin(), clNumWarps.end()); - numThreads[0] *= gpuModel.subgroupSize; + numThreads[0] *= getSubgroupSize(); } else { numThreads = SmallVector{64, 2, 1}; } @@ -114,7 +134,7 @@ void AbstractGemmLikeStrategy::initDefaultValues(const GPUModel &gpuModel) { numWarps = SmallVector(clNumWarps.begin(), clNumWarps.end()); } else { numWarps = numThreads; - numWarps[0] = mlir::ceilDiv(numWarps[0], gpuModel.subgroupSize); + numWarps[0] = mlir::ceilDiv(numWarps[0], getSubgroupSize()); } if (clUseAsyncCopies.getNumOccurrences()) useAsyncCopies = clUseAsyncCopies; @@ -126,21 +146,6 @@ void AbstractGemmLikeStrategy::initDefaultValues(const GPUModel &gpuModel) { useWmma = clUseWmma; if (clUseFma.getNumOccurrences()) useFma = clUseFma; - /// If not specified, select instructions to target for compute. - if (!useMmaSync && !useWmma && !useFma) { - /// First, try to use tensor core. - if (getLhsElementalType() == getRhsElementalType()) { - /// Currently all supported targets at least have WMMA. - /// TODO: Handle targets without tensor core. - if (gpuModel.hasMmaSync) - useMmaSync = true; - else - useWmma = true; - } else { - /// Mixed precision only supported by fma. 
- useFma = true; - } - } if (clReductionTileSize.getNumOccurrences()) { reductionTileSize = clReductionTileSize; } else { @@ -175,7 +180,7 @@ AbstractGemmLikeStrategy::getZeroPadAttrFromElementalTypes(OpBuilder &b) const { LogicalResult AbstractGemmLikeStrategy::validate(const GPUModel &gpuModel) const { - if (totalNumThreads() != totalNumWarps() * gpuModel.subgroupSize) { + if (totalNumThreads() != totalNumWarps() * getSubgroupSize()) { llvm::errs() << "Number of threads specified by warps must match total " "number of threads\n"; return failure(); diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/AbstractGemmLikeStrategy.h b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/AbstractGemmLikeStrategy.h index 8a3e5e40b917..5f459648834b 100644 --- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/AbstractGemmLikeStrategy.h +++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/AbstractGemmLikeStrategy.h @@ -38,6 +38,14 @@ struct AbstractGemmLikeStrategy : GPUStrategy { /// override the user's choices. bool cliOptionsSpecified = false; + /// Non-default subgroup size to use configured based on hardware supported + /// values. + std::optional targetSubgroupSize = std::nullopt; + + int64_t getSubgroupSize() const { + return targetSubgroupSize ? *targetSubgroupSize : subgroupSize; + } + //===--------------------------------------------------------------------===// // Parameters that control the tiling and mapping. //===--------------------------------------------------------------------===// @@ -94,16 +102,20 @@ struct AbstractGemmLikeStrategy : GPUStrategy { return getResElementalType().getIntOrFloatBitWidth(); } - bool alignedLhs() const { + virtual bool alignedLhs() const { return m() % blockTileM() == 0 && k() % reductionTileSize == 0; } - bool alignedRhs() const { + virtual bool alignedRhs() const { return n() % blockTileN() == 0 && k() % reductionTileSize == 0; } - bool alignedRes() const { + virtual bool alignedRes() const { return m() % blockTileM() == 0 && n() % blockTileN() == 0; } + virtual bool hasLhsCopy() const { return true; } + virtual bool hasRhsCopy() const { return true; } + virtual bool hasResCopy() const { return true; } + virtual MappingInfo lhsCopyMapping() const = 0; virtual LogicalResult validateLhsCopyMapping() const = 0; virtual MappingInfo rhsCopyMapping() const = 0; diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/BUILD.bazel index 1ed72bde857e..857a14a148a5 100644 --- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/BUILD.bazel @@ -18,7 +18,10 @@ iree_compiler_cc_library( "AbstractGemmLikeStrategy.cpp", "Common.cpp", "ConvolutionImplicitGemmStrategy.cpp", + "ConvolutionStrategy.cpp", + "ConvolutionTensorCoreStrategy.cpp", "CopyMapping.cpp", + "DataTiledMatmulStrategy.cpp", "MappingInfo.cpp", "MatmulTensorCoreStrategy.cpp", "PadStrategy.cpp", @@ -30,7 +33,10 @@ iree_compiler_cc_library( "AbstractGemmLikeStrategy.h", "Common.h", "ConvolutionImplicitGemmStrategy.h", + "ConvolutionStrategy.h", + "ConvolutionTensorCoreStrategy.h", "CopyMapping.h", + "DataTiledMatmulStrategy.h", "MappingInfo.h", "MatmulTensorCoreStrategy.h", "PadStrategy.h", diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/CMakeLists.txt index 6eef68d20ad2..e33719f75763 
100644 --- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/CMakeLists.txt @@ -17,7 +17,10 @@ iree_cc_library( "AbstractGemmLikeStrategy.h" "Common.h" "ConvolutionImplicitGemmStrategy.h" + "ConvolutionStrategy.h" + "ConvolutionTensorCoreStrategy.h" "CopyMapping.h" + "DataTiledMatmulStrategy.h" "MappingInfo.h" "MatmulTensorCoreStrategy.h" "PadStrategy.h" @@ -28,7 +31,10 @@ iree_cc_library( "AbstractGemmLikeStrategy.cpp" "Common.cpp" "ConvolutionImplicitGemmStrategy.cpp" + "ConvolutionStrategy.cpp" + "ConvolutionTensorCoreStrategy.cpp" "CopyMapping.cpp" + "DataTiledMatmulStrategy.cpp" "MappingInfo.cpp" "MatmulTensorCoreStrategy.cpp" "PadStrategy.cpp" diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Common.cpp b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Common.cpp index 3f006c68556c..6021a830e040 100644 --- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Common.cpp +++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Common.cpp @@ -142,9 +142,12 @@ static std::pair computeSplitPoint(int64_t upperBound, /// func.func. Value mlir::iree_compiler::gpu::buildMapToBlockAndThreads( ImplicitLocOpBuilder &b, Value funcH, ArrayRef blockSize, - ArrayRef warpDims) { + ArrayRef warpDims, std::optional subgroupSize) { b.create(funcH); - b.create(funcH, blockSize, warpDims); + auto mapToThreadsOp = + b.create(funcH, blockSize, warpDims); + if (subgroupSize) + mapToThreadsOp.setSubgroupSize(*subgroupSize); return funcH; } @@ -427,35 +430,48 @@ mlir::iree_compiler::gpu::buildDistributeMatmulCopies( variantH, tensor::ParallelInsertSliceOp::getOperationName()); copyBackOpH = b.create( insertSliceH.getType(), insertSliceH); - } else { + } else if (strategy.hasResCopy()) { Value resH = b.create( paddedMatmulOpH.getType(), paddedMatmulOpH, b.getI64IntegerAttr(2)); copyBackOpH = b.create(resH.getType(), resH); } - Value lhsH = b.create( - paddedMatmulOpH.getType(), paddedMatmulOpH, b.getI64IntegerAttr(0)); - Value rhsH = b.create( - paddedMatmulOpH.getType(), paddedMatmulOpH, b.getI64IntegerAttr(1)); + Value lhsH, rhsH; + if (strategy.hasLhsCopy()) { + lhsH = b.create( + paddedMatmulOpH.getType(), paddedMatmulOpH, b.getI64IntegerAttr(0)); + } + if (strategy.hasRhsCopy()) { + rhsH = b.create( + paddedMatmulOpH.getType(), paddedMatmulOpH, b.getI64IntegerAttr(1)); + } // Rewrite aligned pads as destination passing (linalg.copy) - if (strategy.alignedLhs() && strategy.packingDimensions[0]) + if (strategy.alignedLhs() && strategy.packingDimensions[0] && + strategy.hasLhsCopy()) lhsH = b.create(lhsH.getType(), lhsH); - if (strategy.alignedRhs() && strategy.packingDimensions[1]) + if (strategy.alignedRhs() && strategy.packingDimensions[1] && + strategy.hasRhsCopy()) rhsH = b.create(rhsH.getType(), rhsH); - MappingInfo lhsCopyMapping = strategy.lhsCopyMapping(); - Value lhsCopyOpH = buildDistributeOnePadOrCopyWithNumThreads( - b, variantH, lhsH, /*numThreads=*/lhsCopyMapping.numThreads, - /*threadDimMapping=*/lhsCopyMapping.threadMapping, - /*foldIfBranch=*/!strategy.alignedLhs()); + Value lhsCopyOpH = lhsH; + if (strategy.hasLhsCopy()) { + MappingInfo lhsCopyMapping = strategy.lhsCopyMapping(); + lhsCopyOpH = buildDistributeOnePadOrCopyWithNumThreads( + b, variantH, lhsH, /*numThreads=*/lhsCopyMapping.numThreads, + /*threadDimMapping=*/lhsCopyMapping.threadMapping, + /*foldIfBranch=*/!strategy.alignedLhs()); + } - MappingInfo rhsCopyMapping = 
strategy.rhsCopyMapping(); - Value rhsCopyOpH = buildDistributeOnePadOrCopyWithNumThreads( - b, variantH, rhsH, /*numThreads=*/rhsCopyMapping.numThreads, - /*threadDimMapping=*/rhsCopyMapping.threadMapping, - /*foldIfBranch=*/!strategy.alignedRhs()); + Value rhsCopyOpH = rhsH; + if (strategy.hasRhsCopy()) { + MappingInfo rhsCopyMapping = strategy.rhsCopyMapping(); + rhsCopyOpH = buildDistributeOnePadOrCopyWithNumThreads( + b, variantH, rhsH, /*numThreads=*/rhsCopyMapping.numThreads, + /*threadDimMapping=*/rhsCopyMapping.threadMapping, + /*foldIfBranch=*/!strategy.alignedRhs()); + } if (!strategy.alignedRes()) { MappingInfo resCopyMapping = strategy.resCopyMapping(); @@ -589,7 +605,7 @@ Value mlir::iree_compiler::gpu::buildConvertToTensorCoreOp( } /* else nothing to do for fma here */ // Post-hoc elimiation of barriers. - funcH = b.create(funcH); + // funcH = b.create(funcH); return funcH; } diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Common.h b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Common.h index 1a6746aab0ba..cdf43771d3f5 100644 --- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Common.h +++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Common.h @@ -73,9 +73,10 @@ int64_t adjustNumberOfWarpsForBlockShuffle(int64_t numWarpsToUse, /// Takes an optional `warpDims` argument to specify the number of warp /// dimensions to consider along various dimensions and avoid second-guessing /// how the mapping to warps should occur. -Value buildMapToBlockAndThreads(ImplicitLocOpBuilder &b, Value funcH, - ArrayRef blockSize, - ArrayRef warpDims = {}); +Value buildMapToBlockAndThreads( + ImplicitLocOpBuilder &b, Value funcH, ArrayRef blockSize, + ArrayRef warpDims = {}, + std::optional subgroupSize = std::nullopt); /// Post-bufferization vector distribution with rank-reduction. /// Takes a handle to a func.func and returns an updated handle to a diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/ConvolutionStrategy.cpp b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/ConvolutionStrategy.cpp new file mode 100644 index 000000000000..300110239a2b --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/ConvolutionStrategy.cpp @@ -0,0 +1,326 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/TransformStrategies/GPU/ConvolutionStrategy.h" + +#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h" +#include "iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h" +#include "iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.h" +#include "iree/compiler/Codegen/TransformStrategies/Common/Common.h" +#include "iree/compiler/Codegen/TransformStrategies/GPU/Common.h" +#include "iree/compiler/Codegen/TransformStrategies/GPU/MappingInfo.h" +#include "iree/compiler/Codegen/TransformStrategies/GPU/Strategies.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" +#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Transform/IR/TransformAttrs.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/IR/TransformOps.h" +#include "mlir/Dialect/Transform/IR/TransformTypes.h" +#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/Support/MathExtras.h" + +using namespace mlir; + +#define DEBUG_TYPE "iree-transform-builder" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +// TODO: significantly better namespacing. +using iree_compiler::buildPad; +using iree_compiler::buildSelectFirstNonEmpty; +using iree_compiler::buildTileFuseDistToForallWithNumThreads; +using iree_compiler::buildTileFuseDistToForallWithTileSizes; +using iree_compiler::TileToForallAndFuseAndDistributeResult; +using iree_compiler::gpu::buildBufferize; +using iree_compiler::gpu::buildConvertToAsyncCopies; +using iree_compiler::gpu::buildConvertToTensorCoreOp; +using iree_compiler::gpu::buildDistributeMatmulCopies; +using iree_compiler::gpu::ConvolutionStrategy; +using iree_compiler::gpu::MappingInfo; +using iree_compiler::IREE::transform_dialect::ApplyBufferOptimizationsOp; +using iree_compiler::IREE::transform_dialect:: + ApplyFoldReshapeIntoTensorHalInterfacePatternsOp; +using iree_compiler::IREE::transform_dialect::EliminateGpuBarriersOp; +using iree_compiler::IREE::transform_dialect:: + IREEPopulateWorkgroupCountRegionUsingNumThreadsSliceOp; +using transform::FuseIntoContainingOp; +using transform::MatchOp; +using transform_ext::RegisterMatchCallbacksOp; + +static llvm::cl::list clBlockTileSizes( + "td-convolution-strategy-blk-sizes", + llvm::cl::desc("block tile size for dims (x,y,z) for the transform " + "dialect convolution strategy"), + llvm::cl::CommaSeparated); +static llvm::cl::list clNumThreads( + "td-convolution-strategy-num-threads", + llvm::cl::desc("number of threads for dims (x,y,z) for the transform " + "dialect convolution strategy"), + llvm::cl::CommaSeparated); +static llvm::cl::list clNumWarps( + "td-convolution-strategy-num-warps", + llvm::cl::desc("number of warps for dims (x,y,z) for the transform " + "dialect convolution strategy"), + llvm::cl::CommaSeparated); + +void ConvolutionStrategy::initDefaultValues(const GPUModel &gpuModel) { + blockTileSizes = + SmallVector{clBlockTileSizes.begin(), clBlockTileSizes.end()}; + numThreads = SmallVector{clNumThreads.begin(), clNumThreads.end()}; + numWarps = 
SmallVector{clNumWarps.begin(), clNumWarps.end()}; + + /// Default configuration based on hardware properties and problem bit widths. + if (clBlockTileSizes.getNumOccurrences()) { + blockTileSizes = + SmallVector(clBlockTileSizes.begin(), clBlockTileSizes.end()); + } else { + blockTileSizes = SmallVector{4, 16, 1}; + while ( + captures + .convolutionOpSizes[captures.convolutionDims.outputImage.front()] % + blockTileSizes[0]) + blockTileSizes[0] /= 2; + } + + if (clNumThreads.getNumOccurrences()) { + numThreads = SmallVector(clNumThreads.begin(), clNumThreads.end()); + } else { + // Infer from warp counts if present. + if (clNumWarps.getNumOccurrences()) { + numThreads = SmallVector(clNumWarps.begin(), clNumWarps.end()); + numThreads[0] *= subgroupSize; + } else { + numThreads = SmallVector{64, 1, 1}; + } + } + if (clNumWarps.getNumOccurrences()) { + numWarps = SmallVector(clNumWarps.begin(), clNumWarps.end()); + } else { + numWarps = numThreads; + numWarps[0] = mlir::ceilDiv(numWarps[0], subgroupSize); + } +} + +LLVM_DUMP_METHOD void ConvolutionStrategy::dump() const { print(llvm::errs()); } + +void ConvolutionStrategy::print(llvm::raw_ostream &os) const { + os << "\n--- Convolution strategy ---\n"; + os << "- block tile sizes: {"; + bool isFirst = true; + for (int64_t blockTileSize : blockTileSizes) { + if (!isFirst) + os << ", "; + os << blockTileSize; + isFirst = false; + } + os << "}\n"; + os << "- number of threads: {"; + isFirst = true; + for (int64_t numThreadsForDim : numThreads) { + if (!isFirst) + os << ", "; + os << numThreadsForDim; + isFirst = false; + } + os << "}\n"; + + os << "- number of warps: {"; + isFirst = true; + for (int64_t numWarpsForDim : numWarps) { + if (!isFirst) + os << ", "; + os << numWarpsForDim; + isFirst = false; + } + os << "\n-- Derived quantities --\n"; + os << "- block mapping:\n"; + getBlockMapping().print(os << " -> "); + os << "- compute mapping:\n"; + computeMapping().print(os << " -> "); +} + +// TODO: implement validator. +LogicalResult ConvolutionStrategy::validate(const GPUModel &gpuModel) const { + return success(); +} + +static std::tuple +buildConvolutionStrategyBlockDistribution(ImplicitLocOpBuilder &b, + Value variantH, + const ConvolutionStrategy &strategy) { + // Step 1. Call the matcher. Note that this is the same matcher as used to + // trigger this compilation path, so it must always apply. + b.create(); + auto [padH, fillH, convH, maybeTrailingH] = unpackRegisteredMatchCallback<4>( + b, "convolution", transform::FailurePropagationMode::Propagate, variantH); + + // Step 2. Create the block/mapping tiling level and fusee. + auto [fusionTargetH, fusionGroupH] = + buildSelectFirstNonEmpty(b, maybeTrailingH, convH); + MappingInfo blockMapping = strategy.getBlockMapping(); + TileToForallAndFuseAndDistributeResult tileResult = + buildTileFuseDistToForallWithTileSizes( + /*builder=*/b, + /*variantH=*/variantH, + /*rootH=*/fusionTargetH, + /*opsToFuseH=*/fusionGroupH, + /*tileSizes=*/ + getAsOpFoldResult(b.getI64ArrayAttr(blockMapping.tileSizes)), + /*threadDimMapping=*/ + b.getArrayAttr(blockMapping.threadMapping)); + + auto [blockConvH, maybeBlockTrailingH] = buildSelectFirstNonEmpty( + b, tileResult.resultingFusedOpsHandles.front(), tileResult.tiledOpH); + + Value fusedPadH = + b.create(padH, tileResult.forallH).getFusedOp(); + Value fusedFillH = + b.create(fillH, tileResult.forallH).getFusedOp(); + + // Handle the workgroup count region. 
+ b.create( + tileResult.forallH); + + return std::make_tuple(fusedPadH, fusedFillH, blockConvH, maybeBlockTrailingH, + tileResult.forallH); +} + +/// Builds the common part of the schedule for matmuls and batched matmuls. +static void buildCommonConvolutionLikeThreadSchedule( + ImplicitLocOpBuilder &b, Value variantH, Value padH, Value fillH, + Value convH, Value trailingH, const ConvolutionStrategy &strategy) { + using mlir::iree_compiler::buildLowerVectorMasksAndCleanup; + using mlir::iree_compiler::buildTileFuseToScfFor; + using namespace mlir::iree_compiler::gpu; + + // Tile the outer input channel dimension. + if (strategy.captures.convolutionDims.inputChannel.size() > 1) { + SmallVector tileSizes( + strategy.captures.convolutionDims.outputChannel.size(), 0); + tileSizes.append(strategy.captures.convolutionDims.outputImage.size(), 0); + // tileSizes.append(strategy.captures.convolutionDims.filterLoop.size(), 0); + tileSizes.push_back(1); + + // Avoid canonicalizing before the pad to avoid folding away the + // extract_slice on the output needed to hoist the output pad. + auto tileReductionResult = buildTileFuseToScfFor( + b, variantH, convH, {}, getAsOpFoldResult(b.getI64ArrayAttr(tileSizes)), + /*canonicalize=*/false); + convH = tileReductionResult.tiledOpH; + } + + Value funcH = + b.create(variantH, func::FuncOp::getOperationName()); + iree_compiler::buildCanonicalizationAndEnablingTransforms(b, funcH); + + // Step 5. Tile the filter loop dimensions. + SmallVector tileSizes( + strategy.captures.convolutionDims.outputChannel.size(), 0); + tileSizes.append(strategy.captures.convolutionDims.outputImage.size(), 0); + tileSizes.append(strategy.captures.convolutionDims.filterLoop.size(), 1); + + auto tileReductionResult = buildTileFuseToScfFor( + b, variantH, convH, {}, getAsOpFoldResult(b.getI64ArrayAttr(tileSizes)), + /*canonicalize=*/true); + Value filterTiledConvH = tileReductionResult.tiledOpH; + + // Step 6. Distribute to threads: SIMT programming model. + MappingInfo computeMapping = strategy.computeMapping(); + buildTileFuseDistToForallWithNumThreads( + b, variantH, filterTiledConvH, ValueRange(), + getAsOpFoldResult(b.getI64ArrayAttr(computeMapping.numThreads)), + b.getArrayAttr(computeMapping.threadMapping)); + buildTileFuseDistToForallWithNumThreads( + b, variantH, fillH, ValueRange(), + getAsOpFoldResult(b.getI64ArrayAttr(computeMapping.numThreads)), + b.getArrayAttr(computeMapping.threadMapping)); + buildTileFuseDistToForallWithNumThreads( + b, variantH, trailingH, ValueRange(), + getAsOpFoldResult(b.getI64ArrayAttr(computeMapping.numThreads)), + b.getArrayAttr(computeMapping.threadMapping)); + + // Step 7. Apply vectorization + cleanups to what remains. + b.create(funcH, [](OpBuilder &b, Location loc) { + b.create(loc); + b.create(loc); + b.create(loc); + }); + funcH = iree_compiler::buildVectorize(b, funcH, + /*vectorizeNdExtract=*/false, + /*vectorizePadding=*/false, + /*useIreePadHandling=*/true, + /*applyCleanups=*/true); + + // Step 8. Bufferize and drop HAL descriptor from memref ops. + variantH = buildBufferize(b, variantH); + + // Step 9. Post-bufferization mapping to blocks and threads. + // Need to match again since bufferize invalidated all handles. + // TODO: assumes a single func::FuncOp to transform, needs hardening. 
+ funcH = b.create(variantH, func::FuncOp::getOperationName()); + funcH = buildMapToBlockAndThreads(b, funcH, + /*blockSize=*/strategy.numThreads, + /*warpDims=*/strategy.numWarps, + /*subgroupSize=*/strategy.subgroupSize); + // This currently spins forever. + // funcH = b.create(funcH); + + // Step 10. Cleanup. + iree_compiler::buildCanonicalizationAndEnablingTransforms(b, funcH); + b.create(funcH); + b.create(funcH, [](OpBuilder &b, Location loc) { + b.create(loc); + }); + b.create(funcH, [](OpBuilder &b, Location loc) { + b.create(loc); + }); + iree_compiler::buildCanonicalizationAndEnablingTransforms(b, funcH); + + // Value forH = b.create( + // transform::OperationType::get(b.getContext(), "scf.for"), funcH, + // b.getStrArrayAttr({scf::ForOp::getOperationName()}), + // /*matchInterfaceEnum=*/transform::MatchInterfaceEnumAttr(), + // /*opAttrs=*/DictionaryAttr(), + // /*filterResultType=*/TypeAttr()); + // // TODO: At this time, this synchronization is needed for applying the + // // HoistRedundantVectorTransfersOp transform correctly. This is because the + // // transform does not take parallelism into accound. + // // In the future, HoistRedundantVectorTransfersOp + SynchronizeLoopOp need + // to + // // be replaced by a single transform. + // b.create(forH); + + // TODO: not a functional style transform and avoid returning funcH. + // funcH = b.create( + // transform::AnyOpType::get(b.getContext()), funcH); + iree_compiler::buildCanonicalizationAndEnablingTransforms(b, funcH); + b.create(funcH); + + // // Post-hoc elimiation of barriers. + // funcH = b.create(funcH); + + // Step 11. Late lowerings and cleanups. + buildLowerVectorMasksAndCleanup(b, funcH); +} + +void iree_compiler::gpu::buildConvolutionStrategy( + ImplicitLocOpBuilder &b, Value variantH, + const ConvolutionStrategy &strategy) { + LLVM_DEBUG(strategy.print(DBGS())); + + // Step 1. Apply block-level part of the strategy, keeps everything fused. + auto [padH, fillH, convH, trailingH, forall] = + buildConvolutionStrategyBlockDistribution(b, variantH, strategy); + buildCommonConvolutionLikeThreadSchedule(b, variantH, padH, fillH, convH, + trailingH, strategy); +} diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/ConvolutionStrategy.h b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/ConvolutionStrategy.h new file mode 100644 index 000000000000..08b3c9999b0e --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/ConvolutionStrategy.h @@ -0,0 +1,142 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_CODEGEN_TRANSFORM_DIALECT_STRATEGIES_GPU_CONVOLUTION_STRATEGY_H_ +#define IREE_COMPILER_CODEGEN_TRANSFORM_DIALECT_STRATEGIES_GPU_CONVOLUTION_STRATEGY_H_ + +#include "iree-dialects/Transforms/TransformMatchers.h" +#include "iree/compiler/Codegen/TransformStrategies/Common/Common.h" +#include "iree/compiler/Codegen/TransformStrategies/GPU/Common.h" +#include "iree/compiler/Codegen/TransformStrategies/GPU/CopyMapping.h" +#include "iree/compiler/Codegen/TransformStrategies/GPU/MappingInfo.h" +#include "iree/compiler/Codegen/TransformStrategies/GPU/Strategies.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Support/LogicalResult.h" + +namespace llvm { +class raw_ostream; +} + +namespace mlir { +namespace iree_compiler { +namespace gpu { + +struct GPUModel; + +class ConvolutionStrategy : public GPUStrategy { +public: + ConvolutionStrategy(MLIRContext *context, + const transform_ext::MatchedConvolutionCaptures &captures, + const GPUModel &gpuModel) + : GPUStrategy(gpuModel), ctx(context), captures(captures) { + initDefaultValues(gpuModel); + } + + ConvolutionStrategy(const ConvolutionStrategy &) = default; + ConvolutionStrategy &operator=(const ConvolutionStrategy &) = default; + + /// Constructor quantities. + MLIRContext *ctx; + transform_ext::MatchedConvolutionCaptures captures; + + /// Initialize values from the CLI. + void initDefaultValues(const GPUModel &gpuModel); + + LogicalResult validate(const GPUModel &gpuModel) const; + + //===--------------------------------------------------------------------===// + // Parameters that control the tiling and mapping. + //===--------------------------------------------------------------------===// + + /// Tile sizes for the workgroup / determines grid size for all known + /// reduction strategies. The initial values are set by initDefaultValues(); + SmallVector blockTileSizes; + int64_t reductionTileSize; + SmallVector numThreads; + SmallVector numWarps; + + /// Common values based on derived quantities. + int64_t totalNumThreads() const { + int64_t res = 1; + for (auto v : numThreads) + res *= v; + return res; + } + + int64_t totalNumWarps() const { + int64_t res = 1; + for (auto v : numWarps) + res *= v; + return res; + } + + int64_t blockTileH() const { + assert(blockTileSizes.size() >= 2 && "need at least 2 tile sizes"); + return blockTileSizes[0]; + } + int64_t blockTileW() const { + assert(blockTileSizes.size() >= 2 && "need at least 2 tile sizes"); + return blockTileSizes[1]; + } + + // int64_t numWarpsX() const { + // assert(numWarps.size() >= 2 && "need at least 2 warp sizes"); + // return numWarps[0]; + // } + // int64_t numWarpsY() const { + // assert(numWarps.size() >= 2 && "need at least 2 warp sizes"); + // return numWarps[1]; + // } + + MappingInfo getBlockMapping() const { + SmallVector tileSizes; + SmallVector threadMapping = {blockY(ctx), blockX(ctx)}; + // Outer output channel. + if (captures.convolutionDims.outputChannel.size() == 2) { + tileSizes.push_back(1); + threadMapping = {blockZ(ctx), blockY(ctx), blockX(ctx)}; + } + // Image height. + tileSizes.push_back(blockTileH()); + // Image width. 
+ tileSizes.push_back(blockTileW()); + return MappingInfo{/*numThreads=*/{}, + /*tileSizes=*/tileSizes, + /*threadMapping=*/threadMapping, + /*vectorSize=*/std::nullopt}; + } + + MappingInfo computeMapping() const { + int64_t innerOcTileSize = + captures + .convolutionOpSizes[captures.convolutionDims.outputChannel.back()]; + MappingInfo mapping = CopyMapping::getMappingInfo( + ctx, totalNumThreads(), + /*alignment=*/innerOcTileSize, + {blockTileH(), blockTileW(), innerOcTileSize}); + if (captures.convolutionDims.outputChannel.size() == 2) { + mapping.tileSizes.insert(mapping.tileSizes.begin(), 1); + mapping.numThreads.insert(mapping.numThreads.begin(), 0); + } + return mapping; + // return MappingInfo{ + // /*numThreads=*/captures.convolutionDims.outputChannel.size() == 2 + // ? SmallVector{0, 0, numWarpsY(), numWarpsX()} + // : SmallVector{0, numWarpsY(), numWarpsX()}, + // /*tileSizes=*/{}, + // /*threadMapping=*/{warpY(ctx), warpX(ctx)}, + // /*vectorSize=*/std::nullopt}; + } + + void print(llvm::raw_ostream &os) const; + LLVM_DUMP_METHOD void dump() const; +}; + +} // namespace gpu +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_CODEGEN_TRANSFORM_DIALECT_STRATEGIES_GPU_CONVOLUTION_STRATEGY_H_ diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/ConvolutionTensorCoreStrategy.cpp b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/ConvolutionTensorCoreStrategy.cpp new file mode 100644 index 000000000000..c08d703a01ee --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/ConvolutionTensorCoreStrategy.cpp @@ -0,0 +1,289 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/TransformStrategies/GPU/ConvolutionTensorCoreStrategy.h" + +#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h" +#include "iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h" +#include "iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.h" +#include "iree/compiler/Codegen/TransformStrategies/Common/Common.h" +#include "iree/compiler/Codegen/TransformStrategies/GPU/Common.h" +#include "iree/compiler/Codegen/TransformStrategies/GPU/MappingInfo.h" +#include "iree/compiler/Codegen/TransformStrategies/GPU/Strategies.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Transform/IR/TransformAttrs.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/IR/TransformOps.h" +#include "mlir/Dialect/Transform/IR/TransformTypes.h" +#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" + +using namespace mlir; + +#define DEBUG_TYPE "iree-transform-builder" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +// TODO: significantly better namespacing. 
+using iree_compiler::buildPad; +using iree_compiler::buildSelectFirstNonEmpty; +using iree_compiler::buildTileFuseDistToForallWithNumThreads; +using iree_compiler::buildTileFuseDistToForallWithTileSizes; +using iree_compiler::TileToForallAndFuseAndDistributeResult; +using iree_compiler::gpu::buildBufferize; +using iree_compiler::gpu::buildConvertToAsyncCopies; +using iree_compiler::gpu::buildConvertToTensorCoreOp; +using iree_compiler::gpu::buildDistributeMatmulCopies; +using iree_compiler::gpu::buildHoistOutputPaddingOp; +using iree_compiler::gpu::DataTiledConvolutionStrategy; +using iree_compiler::gpu::MappingInfo; +using iree_compiler::gpu::scaleUpByBitWidth; +using iree_compiler::IREE::transform_dialect:: + ApplyCommonSubexpressionEliminationOp; +using iree_compiler::IREE::transform_dialect:: + ApplyFoldReshapeIntoTensorHalInterfacePatternsOp; +using iree_compiler::IREE::transform_dialect:: + ApplySwapTensorPadWithExtractSliceOp; +using iree_compiler::IREE::transform_dialect::EliminateGpuBarriersOp; +using iree_compiler::IREE::transform_dialect:: + IREEPopulateWorkgroupCountRegionUsingNumThreadsSliceOp; +using transform::FuseIntoContainingOp; +using transform::MatchOp; +using transform_ext::RegisterMatchCallbacksOp; + +void DataTiledConvolutionStrategy::initDefaultValues(const GPUModel &gpuModel) { + // Set the configuration for padding the matmul. + paddingValueTypes = {captures.inputElementType, captures.filterElementType, + captures.outputElementType}; + paddingDimensions = {0, 1, 2}; + packingDimensions = {1, 0, 1}; + + // Pull in tile configs from flags. + AbstractGemmLikeStrategy::initDefaultValues(gpuModel); + if (!cliOptionsSpecified) { + numThreads = SmallVector{32, 1, 1}; + numWarps = SmallVector{1, 1, 1}; + blockTileSizes[0] = 64; + blockTileSizes[1] = 1; + while ( + captures + .convolutionOpSizes[captures.convolutionDims.outputImage.back()] % + blockTileSizes[0]) { + blockTileSizes[0] /= 2; + } + useWmma = true; + } +} + +LLVM_DUMP_METHOD void DataTiledConvolutionStrategy::dump() const { + print(llvm::errs()); +} + +void DataTiledConvolutionStrategy::print(llvm::raw_ostream &os) const { + os << "\n--- Data Tiled Convolution strategy ---\n"; + AbstractGemmLikeStrategy::print(os); +} + +// TODO: implement validator. +LogicalResult +DataTiledConvolutionStrategy::validate(const GPUModel &gpuModel) const { + return success(); +} + +static std::tuple +buildDataTiledConvolutionStrategyBlockDistribution( + ImplicitLocOpBuilder &b, Value variantH, + const DataTiledConvolutionStrategy &strategy) { + // Step 1. Call the matcher. Note that this is the same matcher as used to + // trigger this compilation path, so it must always apply. + b.create(); + auto [padH, fillH, convH, maybeTrailingH] = unpackRegisteredMatchCallback<4>( + b, "convolution", transform::FailurePropagationMode::Propagate, variantH); + + // Step 2. Create the block/mapping tiling level and fusee. 
+ auto [fusionTargetH, fusionGroupH] = + buildSelectFirstNonEmpty(b, maybeTrailingH, convH); + MappingInfo blockMapping = strategy.getBlockMapping(); + TileToForallAndFuseAndDistributeResult tileResult = + buildTileFuseDistToForallWithTileSizes( + /*builder=*/b, + /*variantH=*/variantH, + /*rootH=*/fusionTargetH, + /*opsToFuseH=*/fusionGroupH, + /*tileSizes=*/ + getAsOpFoldResult(b.getI64ArrayAttr(blockMapping.tileSizes)), + /*threadDimMapping=*/ + b.getArrayAttr(blockMapping.threadMapping)); + + auto [blockConvH, maybeBlockTrailingH] = buildSelectFirstNonEmpty( + b, tileResult.resultingFusedOpsHandles.front(), tileResult.tiledOpH); + + Value fusedPadH = + b.create(padH, tileResult.forallH).getFusedOp(); + Value fusedFillH = + b.create(fillH, tileResult.forallH).getFusedOp(); + + // Handle the workgroup count region. + b.create( + tileResult.forallH); + + return std::make_tuple(fusedPadH, fusedFillH, blockConvH, maybeBlockTrailingH, + tileResult.forallH); +} + +/// Builds the common part of the schedule for matmuls and batched matmuls. +static void buildCommonConvolutionLikeThreadSchedule( + ImplicitLocOpBuilder &b, Value variantH, Value padH, Value fillH, + Value convH, Value trailingH, + const DataTiledConvolutionStrategy &strategy) { + using mlir::iree_compiler::buildLowerVectorMasksAndCleanup; + using mlir::iree_compiler::buildTileFuseToScfFor; + using namespace mlir::iree_compiler::gpu; + + // Tile the outer input channel dimension. + if (strategy.captures.convolutionDims.inputChannel.size() > 1) { + SmallVector tileSizes( + strategy.captures.convolutionDims.outputChannel.size(), 0); + tileSizes.append(strategy.captures.convolutionDims.outputImage.size(), 0); + // tileSizes.append(strategy.captures.convolutionDims.filterLoop.size(), 0); + tileSizes.push_back(1); + + // Avoid canonicalizing before the pad to avoid folding away the + // extract_slice on the output needed to hoist the output pad. + auto tileReductionResult = buildTileFuseToScfFor( + b, variantH, convH, {}, getAsOpFoldResult(b.getI64ArrayAttr(tileSizes)), + /*canonicalize=*/false); + convH = tileReductionResult.tiledOpH; + } + + // Step 2. Pad the (batch) matmul op. + auto paddedConvOpH = buildPad( + b, convH, strategy.getZeroPadAttrFromElementalTypes(b).getValue(), + strategy.paddingDimensions, strategy.packingDimensions); + + // Step 3. Hoist the padding of the output operand above the reduction loop. + // The resulting fillOp will be mapped with the contraction using an SIMD + // programming model. + Value fillOpH = fillH; + if (!strategy.alignedRes()) { + fillOpH = buildHoistOutputPaddingOp(b, variantH, paddedConvOpH); + } + + // Running canonicalization is required here to enable aligned pads to become + // linalg.copy ops when rewriting in DPS. + Value funcH = + b.create(variantH, func::FuncOp::getOperationName()); + iree_compiler::buildCanonicalizationAndEnablingTransforms(b, funcH); + + // Step 4. Distribute pad and copies: SIMT programming model. + // auto [lhsCopyOpH, rhsCopyOpH, copyBackOpH] = + buildDistributeMatmulCopies(b, variantH, paddedConvOpH, strategy); + + // Step 5. Tile the filter loop dimensions. 
+ SmallVector tileSizes( + strategy.captures.convolutionDims.outputChannel.size(), 0); + tileSizes.append(strategy.captures.convolutionDims.inputChannel.size() - 1, + 0); + tileSizes.append(strategy.captures.convolutionDims.outputImage.size(), 0); + tileSizes.append(strategy.captures.convolutionDims.filterLoop.size(), 1); + + auto tileReductionResult = + buildTileFuseToScfFor(b, variantH, paddedConvOpH, {}, + getAsOpFoldResult(b.getI64ArrayAttr(tileSizes)), + /*canonicalize=*/true); + Value filterTiledConvH = tileReductionResult.tiledOpH; + + // Step 6. Distribute to warps: SIMD programming model. + // TODO: get the number of warps from strategy. + MappingInfo computeMapping = strategy.computeMapping(); + buildTileFuseDistToForallWithNumThreads( + b, variantH, filterTiledConvH, ValueRange(), + getAsOpFoldResult(b.getI64ArrayAttr(computeMapping.numThreads)), + b.getArrayAttr(computeMapping.threadMapping)); + + // Step 6.5 Distribute to threads: SIMT programming model. + MappingInfo resCopyMapping = strategy.resCopyMapping(); + fillOpH = b.create(variantH, linalg::FillOp::getOperationName()); + buildTileFuseDistToForallWithNumThreads( + b, variantH, fillOpH, ValueRange(), + getAsOpFoldResult(b.getI64ArrayAttr(resCopyMapping.numThreads)), + b.getArrayAttr(resCopyMapping.threadMapping)); + buildTileFuseDistToForallWithNumThreads( + b, variantH, trailingH, ValueRange(), + getAsOpFoldResult(b.getI64ArrayAttr(resCopyMapping.numThreads)), + b.getArrayAttr(resCopyMapping.threadMapping)); + + // Step 7. Apply vectorization + cleanups to what remains. + b.create(funcH, [](OpBuilder &b, Location loc) { + b.create(loc); + }); + funcH = b.create( + funcH.getType(), funcH, + b.getStringAttr("iree-codegen-concretize-pad-result-shape")); + b.create(funcH); + b.create(funcH, [](OpBuilder &b, Location loc) { + b.create(loc); + }); + funcH = b.create( + funcH.getType(), funcH, + b.getStringAttr("iree-codegen-concretize-pad-result-shape")); + b.create(funcH); + b.create(funcH, [](OpBuilder &b, Location loc) { + b.create(loc); + }); + funcH = b.create( + funcH.getType(), funcH, + b.getStringAttr("iree-codegen-concretize-pad-result-shape")); + b.create(funcH); + b.create(funcH, [](OpBuilder &b, Location loc) { + b.create(loc); + b.create(loc); + b.create(loc); + }); + funcH = iree_compiler::buildVectorize(b, funcH, + /*vectorizeNdExtract=*/false, + /*vectorizePadding=*/false, + /*useIreePadHandling=*/true, + /*applyCleanups=*/true); + + // Step 8. Bufferize and drop HAL descriptor from memref ops. + variantH = buildBufferize(b, variantH); + + // Step 9. Post-bufferization mapping to blocks and threads. + // Need to match again since bufferize invalidated all handles. + // TODO: assumes a single func::FuncOp to transform, needs hardening. + funcH = b.create(variantH, func::FuncOp::getOperationName()); + funcH = + buildMapToBlockAndThreads(b, funcH, + /*blockSize=*/strategy.numThreads, + /*warpDims=*/strategy.numWarps, + /*subgroupSize=*/strategy.targetSubgroupSize); + // funcH = b.create(funcH); + + // Step 10. Convert to tensor core ops. + // TODO: avoid consuming handles and returning here. + funcH = buildConvertToTensorCoreOp(b, funcH, strategy); + + // Step 11. Late lowerings and cleanups. + buildLowerVectorMasksAndCleanup(b, funcH); +} + +void iree_compiler::gpu::buildConvolutionTensorCoreStrategy( + ImplicitLocOpBuilder &b, Value variantH, + const DataTiledConvolutionStrategy &strategy) { + LLVM_DEBUG(strategy.print(DBGS())); + + // Step 1. Apply block-level part of the strategy, keeps everything fused. 
+ auto [padH, fillH, convH, trailingH, forall] = + buildDataTiledConvolutionStrategyBlockDistribution(b, variantH, strategy); + buildCommonConvolutionLikeThreadSchedule(b, variantH, padH, fillH, convH, + trailingH, strategy); +} diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/ConvolutionTensorCoreStrategy.h b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/ConvolutionTensorCoreStrategy.h new file mode 100644 index 000000000000..6f9b7af73a4d --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/ConvolutionTensorCoreStrategy.h @@ -0,0 +1,214 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_CODEGEN_TRANSFORM_DIALECT_STRATEGIES_GPU_TENSOR_CORE_CONVOLUTION_STRATEGY_H_ +#define IREE_COMPILER_CODEGEN_TRANSFORM_DIALECT_STRATEGIES_GPU_TENSOR_CORE_CONVOLUTION_STRATEGY_H_ + +#include "iree-dialects/Transforms/TransformMatchers.h" +#include "iree/compiler/Codegen/TransformStrategies/Common/Common.h" +#include "iree/compiler/Codegen/TransformStrategies/GPU/AbstractGemmLikeStrategy.h" +#include "iree/compiler/Codegen/TransformStrategies/GPU/Common.h" +#include "iree/compiler/Codegen/TransformStrategies/GPU/CopyMapping.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Support/MathExtras.h" + +namespace llvm { +class raw_ostream; +} + +namespace mlir { +namespace iree_compiler { +namespace gpu { + +struct GPUModel; + +class DataTiledConvolutionStrategy : public AbstractGemmLikeStrategy { +public: + DataTiledConvolutionStrategy( + MLIRContext *context, + const transform_ext::MatchedConvolutionCaptures &captures, + const GPUModel &gpuModel) + : AbstractGemmLikeStrategy(gpuModel), ctx(context), captures(captures) { + initDefaultValues(gpuModel); + } + + DataTiledConvolutionStrategy(const DataTiledConvolutionStrategy &) = default; + DataTiledConvolutionStrategy & + operator=(const DataTiledConvolutionStrategy &) = default; + + /// Constructor quantities. + MLIRContext *ctx; + transform_ext::MatchedConvolutionCaptures captures; + + /// Initialize values from the CLI. Set cliOptionsSpecified to true if the + /// default CLI values have been overriden. 
+ void initDefaultValues(const GPUModel &gpuModel) override; + + LogicalResult validate(const GPUModel &gpuModel) const override; + + int64_t m() const override { + int64_t imgElements = 1; + for (auto i : captures.convolutionDims.outputImage) { + imgElements *= captures.convolutionOpSizes[i]; + } + return imgElements; + } + int64_t n() const override { + int64_t ocElements = 1; + for (auto i : captures.convolutionDims.outputChannel) { + ocElements *= captures.convolutionOpSizes[i]; + } + return ocElements; + } + int64_t k() const override { + int64_t icElements = 1; + for (auto i : captures.convolutionDims.outputChannel) { + icElements *= captures.convolutionOpSizes[i]; + } + return icElements; + } + + int64_t blockTileM() const override { + assert(blockTileSizes.size() >= 2 && "need at least 2 tile sizes"); + return blockTileSizes[0]; + } + int64_t blockTileN() const override { + assert(blockTileSizes.size() >= 2 && "need at least 2 tile sizes"); + return blockTileSizes[1]; + } + + int64_t numWarpsX() const override { + assert(numWarps.size() >= 2 && "need at least 2 warp sizes"); + return numWarps[0]; + } + int64_t numWarpsY() const override { + assert(numWarps.size() >= 2 && "need at least 2 warp sizes"); + return numWarps[1]; + } + + Type getLhsElementalType() const override { + return captures.inputElementType; + } + Type getRhsElementalType() const override { + return captures.filterElementType; + } + Type getResElementalType() const override { + return captures.outputElementType; + } + + virtual bool alignedLhs() const { return true; } + virtual bool alignedRhs() const { return true; } + virtual bool alignedRes() const { return true; } + + bool hasLhsCopy() const override { return true; } + // Filter is not copied. + bool hasRhsCopy() const override { return false; } + bool hasResCopy() const override { + return captures.convolutionDims.inputChannel.size() == 2; + } + + MappingInfo getBlockMapping() const override { + SmallVector tileSizes; + SmallVector threadMapping = {blockY(ctx), blockX(ctx)}; + // Outer output channel. + if (captures.convolutionDims.outputChannel.size() == 2) { + tileSizes.push_back(blockTileN()); + threadMapping = {blockZ(ctx), blockY(ctx), blockX(ctx)}; + } + // Image height. + tileSizes.push_back(1); + // Image width. + tileSizes.push_back(blockTileM()); + return MappingInfo{/*numThreads=*/{}, + /*tileSizes=*/tileSizes, + /*threadMapping=*/threadMapping, + /*vectorSize=*/std::nullopt}; + } + + MappingInfo lhsCopyMapping() const override { + int64_t inputTileH = + captures.convolutionOpSizes[captures.convolutionDims.filterLoop[0]]; + int64_t inputTileW = + captures.convolutionOpSizes[captures.convolutionDims.filterLoop[1]] + + blockTileM() - 1; + int64_t icInnerTileSize = + captures + .convolutionOpSizes[captures.convolutionDims.inputChannel.back()]; + MappingInfo mapping = CopyMapping::getMappingInfo( + ctx, totalNumThreads(), + /*alignment=*/k(), + /*copySizes=*/ + ArrayRef{inputTileH, inputTileW, icInnerTileSize}, + /*favorPredication=*/true, + /*elementalBitWidth=*/lhsElementalBitWidth(), + /*favorLazyOuterDistributing=*/true); + if (captures.convolutionDims.inputChannel.size() == 2) { + mapping.tileSizes.insert(mapping.tileSizes.begin(), 1); + mapping.numThreads.insert(mapping.numThreads.begin(), 0); + } + return mapping; + } + // TODO: Write a validator. + LogicalResult validateLhsCopyMapping() const override { return success(); } + + // Filter is not copied. 
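+  // The RHS (filter) copy mapping is therefore left empty and its validator
+  // trivially succeeds.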
+ MappingInfo rhsCopyMapping() const override { return MappingInfo(); } + LogicalResult validateRhsCopyMapping() const override { return success(); } + + MappingInfo resCopyMapping() const override { + int64_t outputTileH = 1; + int64_t outputTileW = blockTileM(); + int64_t ocInnerTileSize = + captures + .convolutionOpSizes[captures.convolutionDims.outputChannel.back()]; + MappingInfo mapping = CopyMapping::getMappingInfo( + ctx, totalNumThreads(), + /*alignment=*/n(), + /*copySizes=*/ + ArrayRef{outputTileH, outputTileW, ocInnerTileSize}, + /*favorPredication=*/false, + /*elementalBitWidth=*/resElementalBitWidth(), + /*favorLazyOuterDistributing=*/false); + if (captures.convolutionDims.outputChannel.size() == 2) { + mapping.tileSizes.insert(mapping.tileSizes.begin(), 1); + mapping.numThreads.insert(mapping.numThreads.begin(), 0); + } + return mapping; + } + // TODO: Write a validator. + LogicalResult validateResCopyMapping() const override { return success(); } + + // COMPUTE is of size mxn. + MappingInfo computeMapping() const override { + // FMA disabled. + // if (useFma) { + // // When using FMA we don't need to map to warps, instead just match + // what + // // the copy does. + // return CopyMapping::getMappingInfo(ctx, totalNumThreads(), + // /*alignment=*/n(), + // {blockTileM(), blockTileN()}); + // } + return MappingInfo{ + /*numThreads=*/captures.convolutionDims.outputChannel.size() == 2 + ? SmallVector{0, 0, numWarpsY(), numWarpsX()} + : SmallVector{0, numWarpsY(), numWarpsX()}, + /*tileSizes=*/{}, + /*threadMapping=*/{warpY(ctx), warpX(ctx)}, + /*vectorSize=*/std::nullopt}; + } + + void print(llvm::raw_ostream &os) const override; + LLVM_DUMP_METHOD void dump() const override; +}; + +} // namespace gpu +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_CODEGEN_TRANSFORM_DIALECT_STRATEGIES_GPU_TENSOR_CORE_CONVOLUTION_STRATEGY_H_ diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/CopyMapping.cpp b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/CopyMapping.cpp index c3a957be348a..f71e2e50cc23 100644 --- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/CopyMapping.cpp +++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/CopyMapping.cpp @@ -29,11 +29,10 @@ int64_t iree_compiler::gpu::CopyMapping::maxContiguousElementsToTransfer( } FailureOr -iree_compiler::gpu::CopyMapping::numThreadsForCopy(int totalNumThreads, - int64_t alignment, - ArrayRef sizes, - bool favorPredication, - int64_t elementalBitWidth) { +iree_compiler::gpu::CopyMapping::numThreadsForCopy( + int totalNumThreads, int64_t alignment, ArrayRef sizes, + bool favorPredication, int64_t elementalBitWidth, + bool favorLazyOuterDistributing) { LDBG("\nSTART numThreadsForCopy, favorPredication: " << favorPredication); LLVM_DEBUG(llvm::interleaveComma(sizes, DBGS() << "--sizes: "); llvm::dbgs() << "\n";); @@ -81,19 +80,35 @@ iree_compiler::gpu::CopyMapping::numThreadsForCopy(int totalNumThreads, SmallVector scaledSizes{sizes.begin(), sizes.end()}; scaledSizes.back() /= actualVectorSize; - int64_t numThreadsRemaining = totalNumThreads; - LDBG("--numThreadsRemaining: " << numThreadsRemaining); SmallVector factors; - for (auto s : llvm::reverse(scaledSizes)) { - int64_t gcd = std::gcd(numThreadsRemaining, s); - factors.push_back(gcd); - numThreadsRemaining /= gcd; - LDBG("--new factors: " << gcd); + if (favorLazyOuterDistributing) { + int64_t numThreadsUsed = 1; + for (auto s : scaledSizes) { + int newThreads = 1; + for (auto maybeFactor : 
llvm::seq(1l, s + 1)) { + if (maybeFactor * numThreadsUsed > totalNumThreads) + break; + if (s % maybeFactor == 0) + newThreads = maybeFactor; + } + factors.push_back(newThreads); + numThreadsUsed *= newThreads; + LDBG("--new factors: " << newThreads); + LDBG("--numThreadsUsed: " << numThreadsUsed); + } + } else { + int64_t numThreadsRemaining = totalNumThreads; LDBG("--numThreadsRemaining: " << numThreadsRemaining); + for (auto s : llvm::reverse(scaledSizes)) { + int64_t gcd = std::gcd(numThreadsRemaining, s); + factors.push_back(gcd); + numThreadsRemaining /= gcd; + LDBG("--new factors: " << gcd); + LDBG("--numThreadsRemaining: " << numThreadsRemaining); + } + std::reverse(factors.begin(), factors.end()); } - std::reverse(factors.begin(), factors.end()); - LLVM_DEBUG(llvm::interleaveComma(factors, DBGS() << "numThreads: "); llvm::dbgs() << "\n"; LDBG("actualVectorSize: " << actualVectorSize);); @@ -104,12 +119,12 @@ iree_compiler::gpu::CopyMapping::numThreadsForCopy(int totalNumThreads, iree_compiler::gpu::MappingInfo iree_compiler::gpu::CopyMapping::getMappingInfo( MLIRContext *ctx, int totalNumThreads, int64_t alignment, ArrayRef copySizes, bool favorPredication, - int64_t elementalBitWidth) { + int64_t elementalBitWidth, bool favorLazyOuterDistributing) { assert(!copySizes.empty() && copySizes.size() <= 3 && "only 1,2,3-D copies are supported for now"); - FailureOr maybeCopyMapping = - CopyMapping::numThreadsForCopy(totalNumThreads, alignment, copySizes, - favorPredication, elementalBitWidth); + FailureOr maybeCopyMapping = CopyMapping::numThreadsForCopy( + totalNumThreads, alignment, copySizes, favorPredication, + elementalBitWidth, favorLazyOuterDistributing); // If failed, try again with predication; this must succeed. if (failed(maybeCopyMapping)) { assert(!favorPredication && diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/CopyMapping.h b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/CopyMapping.h index 3918becc4d96..06e2f244a4b0 100644 --- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/CopyMapping.h +++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/CopyMapping.h @@ -57,7 +57,8 @@ struct CopyMapping { static FailureOr numThreadsForCopy(int totalNumThreads, int64_t alignment, ArrayRef sizes, bool favorPredication, - int64_t elementalBitWidth = 32); + int64_t elementalBitWidth = 32, + bool favorLazyOuterDistributing = false); /// Greedily compute the MappingInfo to use to perform a copy of `sizes` /// elements of bitwidth `elementalBitWidth`. @@ -75,7 +76,8 @@ struct CopyMapping { static MappingInfo getMappingInfo(MLIRContext *ctx, int totalNumThreads, int64_t alignment, ArrayRef sizes, bool favorPredication = false, - int64_t elementalBitWidth = 32); + int64_t elementalBitWidth = 32, + bool favorLazyOuterDistributing = false); }; } // namespace gpu diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/DataTiledMatmulStrategy.cpp b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/DataTiledMatmulStrategy.cpp new file mode 100644 index 000000000000..8dbbc55f6cb9 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/DataTiledMatmulStrategy.cpp @@ -0,0 +1,284 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/TransformStrategies/GPU/DataTiledMatmulStrategy.h" + +#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h" +#include "iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h" +#include "iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.h" +#include "iree/compiler/Codegen/TransformStrategies/Common/Common.h" +#include "iree/compiler/Codegen/TransformStrategies/GPU/Common.h" +#include "iree/compiler/Codegen/TransformStrategies/GPU/MappingInfo.h" +#include "iree/compiler/Codegen/TransformStrategies/GPU/Strategies.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Transform/IR/TransformAttrs.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/IR/TransformOps.h" +#include "mlir/Dialect/Transform/IR/TransformTypes.h" +#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" + +using namespace mlir; + +#define DEBUG_TYPE "iree-transform-builder" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +// TODO: significantly better namespacing. +using iree_compiler::buildPad; +using iree_compiler::buildSelectFirstNonEmpty; +using iree_compiler::buildTileFuseDistToForallWithNumThreads; +using iree_compiler::buildTileFuseDistToForallWithTileSizes; +using iree_compiler::TileToForallAndFuseAndDistributeResult; +using iree_compiler::gpu::BatchMatmulStrategy; +using iree_compiler::gpu::buildBufferize; +using iree_compiler::gpu::buildConvertToAsyncCopies; +using iree_compiler::gpu::buildConvertToTensorCoreOp; +using iree_compiler::gpu::buildDistributeMatmulCopies; +using iree_compiler::gpu::buildHoistOutputPaddingOp; +using iree_compiler::gpu::buildMatmulVectorization; +using iree_compiler::gpu::buildMultiBuffering; +using iree_compiler::gpu::buildPipelineSharedMemoryCopies; +using iree_compiler::gpu::DataTiledMatmulStrategy; +using iree_compiler::gpu::MappingInfo; +using iree_compiler::gpu::scaleUpByBitWidth; +using iree_compiler::IREE::transform_dialect:: + ApplyFoldReshapeIntoTensorHalInterfacePatternsOp; +using iree_compiler::IREE::transform_dialect::EliminateGpuBarriersOp; +using iree_compiler::IREE::transform_dialect:: + IREEPopulateWorkgroupCountRegionUsingNumThreadsSliceOp; +using transform::FuseIntoContainingOp; +using transform::MatchOp; +using transform_ext::RegisterMatchCallbacksOp; + +void DataTiledMatmulStrategy::initDefaultValues(const GPUModel &gpuModel) { + // Set the configuration for padding the matmul. + paddingValueTypes = {captures.lhsElementType, captures.rhsElementType, + captures.outputElementType}; + paddingDimensions = {0, 1, 2}; + packingDimensions = {1, 1, 1}; + + // Pull in tile configs from flags. + AbstractGemmLikeStrategy::initDefaultValues(gpuModel); + + // Data tiled strategies have specific requirements so adjust here. + + // Consolidate the warps/threads along X. + numWarps[0] *= numWarps[1]; + numWarps[1] = 1; + numThreads[0] *= numThreads[1]; + numThreads[1] *= 1; + // BlockTileN is effectively the inner tile. + blockTileSizes[1] = captures.matmulOpSizes[captures.contractionDims.n.back()]; + // Adjust downwards to force alignment along M. 
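+  // Halve the M block tile until it divides the problem's M extent, shrinking
+  // the warp/thread count along X in lockstep while more than one warp remains
+  // (e.g. m() == 24 with an initial tile of 32 settles at a tile of 8).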
+ while (m() % blockTileSizes[0]) { + blockTileSizes[0] /= 2; + if (numWarps[0] > 1) { + numWarps[0] /= 2; + numThreads[0] /= 2; + } + } + // Reduction tile size is unused. + reductionTileSize = 1; + // Force wmma. + useWmma = true; + useMmaSync = false; + useFma = false; + // Disable pipelining. + useAsyncCopies = false; + pipelineDepth = 0; + if (gpuModel.minSubgroupSize) + targetSubgroupSize = *gpuModel.minSubgroupSize; +} + +LLVM_DUMP_METHOD void DataTiledMatmulStrategy::dump() const { + print(llvm::errs()); +} + +void DataTiledMatmulStrategy::print(llvm::raw_ostream &os) const { + os << "\n--- Data Tiled Matmul strategy ---\n"; + AbstractGemmLikeStrategy::print(os); +} + +// TODO: Implement a validator. +LogicalResult +DataTiledMatmulStrategy::validate(const GPUModel &gpuModel) const { + return success(); +} + +static std::tuple +buildDataTiledMatmulStrategyBlockDistribution( + ImplicitLocOpBuilder &b, Value variantH, + const DataTiledMatmulStrategy &strategy) { + // Step 1. Call the matcher. Note that this is the same matcher as used to + // trigger this compilation path, so it must always apply. + b.create(); + auto [fillH, matmulH, maybeTrailingH] = unpackRegisteredMatchCallback<3>( + b, "contraction", transform::FailurePropagationMode::Propagate, variantH); + + // Step 2. Create the block/mapping tiling level and fusee. + auto [fusionTargetH, fusionGroupH] = + buildSelectFirstNonEmpty(b, maybeTrailingH, matmulH); + MappingInfo blockMapping = strategy.getBlockMapping(); + TileToForallAndFuseAndDistributeResult tileResult = + buildTileFuseDistToForallWithTileSizes( + /*builder=*/b, + /*variantH=*/variantH, + /*rootH=*/fusionTargetH, + /*opsToFuseH=*/fusionGroupH, + /*tileSizes=*/ + getAsOpFoldResult(b.getI64ArrayAttr(blockMapping.tileSizes)), + /*threadDimMapping=*/ + b.getArrayAttr(blockMapping.threadMapping)); + + auto [blockMatmulH, maybeBlockTrailingH] = buildSelectFirstNonEmpty( + b, tileResult.resultingFusedOpsHandles.front(), tileResult.tiledOpH); + + Value fusedFillH = + b.create(fillH, tileResult.forallH).getFusedOp(); + + // Handle the workgroup count region. + b.create( + tileResult.forallH); + + // TODO: handle trailing op. + return std::make_tuple(fusedFillH, blockMatmulH, maybeBlockTrailingH, + tileResult.forallH); +} + +/// Builds the common part of the schedule for matmuls and batched matmuls. +static void +buildCommonMatmulLikeThreadSchedule(ImplicitLocOpBuilder &b, Value variantH, + Value fillH, Value matmulH, Value trailingH, + const DataTiledMatmulStrategy &strategy) { + using mlir::iree_compiler::buildLowerVectorMasksAndCleanup; + using mlir::iree_compiler::buildTileFuseToScfFor; + using namespace mlir::iree_compiler::gpu; + + // Tile the reduction loop (last in the list). + SmallVector tileSizes(strategy.captures.matmulOpSizes.size() - + strategy.captures.contractionDims.k.size(), + 0); + if (strategy.captures.contractionDims.k.size() == 2) { + tileSizes.push_back(1); + } else { + tileSizes.push_back( + strategy.captures + .matmulOpSizes[strategy.captures.contractionDims.k.back()]); + } + + // Avoid canonicalizing before the pad to avoid folding away the extract_slice + // on the output needed to hoist the output pad. + auto tileReductionResult = buildTileFuseToScfFor( + b, variantH, matmulH, {}, getAsOpFoldResult(b.getI64ArrayAttr(tileSizes)), + /*canonicalize=*/false); + + // Step 2. Pad the (batch) matmul op. 
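+  // Padding uses zero attributes derived from the captured elemental types and
+  // the dimensions configured in initDefaultValues (paddingDimensions = {0, 1, 2},
+  // packingDimensions = {1, 1, 1}).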
+ auto paddedMatmulOpH = + buildPad(b, tileReductionResult.tiledOpH, + strategy.getZeroPadAttrFromElementalTypes(b).getValue(), + strategy.paddingDimensions, strategy.packingDimensions); + + // Step 3. Hoist the padding of the output operand above the reduction loop. + // The resulting fillOp will be mapped with the contraction using an SIMD + // programming model. + Value fillOpH = fillH; + if (!strategy.alignedRes()) { + fillOpH = buildHoistOutputPaddingOp(b, variantH, paddedMatmulOpH); + } + + // Running canonicalization is required here to enable aligned pads to become + // linalg.copy ops when rewriting in DPS. + Value funcH = + b.create(variantH, func::FuncOp::getOperationName()); + iree_compiler::buildCanonicalizationAndEnablingTransforms(b, funcH); + + // Step 4. Distribute pad and copies: SIMT programming model. + auto [lhsCopyOpH, rhsCopyOpH, copyBackOpH] = + buildDistributeMatmulCopies(b, variantH, paddedMatmulOpH, strategy); + + // Step 5. Distribute to warps: SIMD programming model. + // TODO: get the number of warps from strategy. + MappingInfo computeMapping = strategy.computeMapping(); + buildTileFuseDistToForallWithNumThreads( + b, variantH, paddedMatmulOpH, ValueRange(), + getAsOpFoldResult(b.getI64ArrayAttr(computeMapping.numThreads)), + b.getArrayAttr(computeMapping.threadMapping)); + buildTileFuseDistToForallWithNumThreads( + b, variantH, fillOpH, ValueRange(), + getAsOpFoldResult(b.getI64ArrayAttr(computeMapping.numThreads)), + b.getArrayAttr(computeMapping.threadMapping)); + + // Step 5.5. Distribute to threads: SIMT programming model. + MappingInfo resCopyMapping = strategy.resCopyMapping(); + buildTileFuseDistToForallWithNumThreads( + b, variantH, trailingH, ValueRange(), + getAsOpFoldResult(b.getI64ArrayAttr(resCopyMapping.numThreads)), + b.getArrayAttr(resCopyMapping.threadMapping)); + + // Step 6. Rank-reduce and vectorize. + b.create(funcH, [](OpBuilder &b, Location loc) { + b.create(loc); + b.create(loc); + b.create(loc); + }); + buildMatmulVectorization(b, variantH, lhsCopyOpH, rhsCopyOpH, copyBackOpH, + strategy); + + // Step 7. Bufferize and drop HAL descriptor from memref ops. + variantH = buildBufferize(b, variantH); + + // Step 8. Post-bufferization mapping to blocks and threads. + // Need to match again since bufferize invalidated all handles. + // TODO: assumes a single func::FuncOp to transform, needs hardening. + funcH = b.create(variantH, func::FuncOp::getOperationName()); + funcH = + buildMapToBlockAndThreads(b, funcH, + /*blockSize=*/strategy.numThreads, + /*warpDims=*/strategy.numWarps, + /*subgroupSize=*/strategy.targetSubgroupSize); + funcH = b.create(funcH); + + // Step 9. Convert to tensor core ops. + // TODO: avoid consuming handles and returning here. + funcH = buildConvertToTensorCoreOp(b, funcH, strategy); + + // TODO: Support pipelining strategies without async copy (e.g. store to + // shared memory in stage 0). + if (strategy.useAsyncCopies) { + // Step 10. Multi-buffering. + if (strategy.pipelineDepth > 1) + buildMultiBuffering(b, funcH, strategy); + + // Step 11. Convert to async copies. + // TODO: avoid consuming handles and returning here. + funcH = buildConvertToAsyncCopies(b, funcH, strategy); + + // Step 12. Pipeline shared memory copies. + if (strategy.pipelineDepth > 1) + buildPipelineSharedMemoryCopies(b, funcH, strategy); + } + + // Step 13. Late lowerings and cleanups. 
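+  // (Steps 10-12 above only run when useAsyncCopies is set; the data tiled
+  // defaults disable it in initDefaultValues.)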
+ buildLowerVectorMasksAndCleanup(b, funcH); +} + +void iree_compiler::gpu::buildDataTiledMatmulStrategy( + ImplicitLocOpBuilder &b, Value variantH, + const DataTiledMatmulStrategy &strategy) { + LLVM_DEBUG(strategy.print(DBGS())); + + // Step 1. Apply block-level part of the strategy, keeps everything fused. + auto [fillH, matmulH, maybeTiledTrailingHBlock, forall] = + buildDataTiledMatmulStrategyBlockDistribution(b, variantH, strategy); + buildCommonMatmulLikeThreadSchedule(b, variantH, fillH, matmulH, + maybeTiledTrailingHBlock, strategy); +} diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/DataTiledMatmulStrategy.h b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/DataTiledMatmulStrategy.h new file mode 100644 index 000000000000..c5c1507bb6ed --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/DataTiledMatmulStrategy.h @@ -0,0 +1,183 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_CODEGEN_TRANSFORM_DIALECT_STRATEGIES_GPU_DATA_TILED_MATMUL_STRATEGY_H_ +#define IREE_COMPILER_CODEGEN_TRANSFORM_DIALECT_STRATEGIES_GPU_DATA_TILED_MATMUL_STRATEGY_H_ + +#include "iree-dialects/Transforms/TransformMatchers.h" +#include "iree/compiler/Codegen/TransformStrategies/Common/Common.h" +#include "iree/compiler/Codegen/TransformStrategies/GPU/AbstractGemmLikeStrategy.h" +#include "iree/compiler/Codegen/TransformStrategies/GPU/Common.h" +#include "iree/compiler/Codegen/TransformStrategies/GPU/CopyMapping.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Support/MathExtras.h" + +namespace llvm { +class raw_ostream; +} + +namespace mlir { +namespace iree_compiler { +namespace gpu { + +struct GPUModel; + +class DataTiledMatmulStrategy : public AbstractGemmLikeStrategy { +public: + DataTiledMatmulStrategy(MLIRContext *context, + const transform_ext::MatchedMatmulCaptures &captures, + const GPUModel &gpuModel) + : AbstractGemmLikeStrategy(gpuModel), ctx(context), captures(captures) { + initDefaultValues(gpuModel); + } + + DataTiledMatmulStrategy(const DataTiledMatmulStrategy &) = default; + DataTiledMatmulStrategy &operator=(const DataTiledMatmulStrategy &) = default; + + /// Constructor quantities. + MLIRContext *ctx; + transform_ext::MatchedMatmulCaptures captures; + + /// Initialize values from the CLI. Set cliOptionsSpecified to true if the + /// default CLI values have been overriden. 
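+  /// The data tiled defaults consolidate warps and threads along X, pin the N
+  /// block tile to the inner (data tiled) N size, force WMMA, and disable
+  /// pipelining/async copies; see DataTiledMatmulStrategy.cpp.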
+ void initDefaultValues(const GPUModel &gpuModel) override; + + LogicalResult validate(const GPUModel &gpuModel) const override; + + int64_t m() const override { + int64_t mElements = 1; + for (auto i : captures.contractionDims.m) { + mElements *= captures.matmulOpSizes[i]; + } + return mElements; + } + int64_t n() const override { + int64_t nElements = 1; + for (auto i : captures.contractionDims.n) { + nElements *= captures.matmulOpSizes[i]; + } + return nElements; + } + int64_t k() const override { + int64_t kElements = 1; + for (auto i : captures.contractionDims.k) { + kElements *= captures.matmulOpSizes[i]; + } + return kElements; + } + + int64_t blockTileM() const override { return blockTileSizes[0]; } + int64_t blockTileN() const override { + return captures.matmulOpSizes[captures.contractionDims.n.back()]; + } + + int64_t numWarpsX() const override { return numWarps[0]; } + int64_t numWarpsY() const override { return 1; } + + Type getLhsElementalType() const override { return captures.lhsElementType; } + Type getRhsElementalType() const override { return captures.rhsElementType; } + Type getResElementalType() const override { + return captures.outputElementType; + } + + MappingInfo getBlockMapping() const override { + SmallVector tileSizes; + SmallVector threadMapping = {blockX(ctx)}; + // Outer output channel. + if (captures.contractionDims.n.size() == 2) { + tileSizes.push_back(1); + threadMapping = {blockY(ctx), blockX(ctx)}; + } + tileSizes.push_back(blockTileM()); + return MappingInfo{/*numThreads=*/{}, + /*tileSizes=*/tileSizes, + /*threadMapping=*/threadMapping, + /*vectorSize=*/std::nullopt}; + } + + // LHS copy is of size mxk. + MappingInfo lhsCopyMapping() const override { + int64_t kInnerTileSize = + captures.matmulOpSizes[captures.contractionDims.k.back()]; + return CopyMapping::getMappingInfo( + ctx, totalNumThreads(), + /*alignment=*/k(), + /*copySizes=*/captures.contractionDims.k.size() == 2 + ? ArrayRef{1, blockTileM(), kInnerTileSize} + : ArrayRef{blockTileM(), kInnerTileSize}, + /*favorPredication=*/false, + /*elementalBitWidth=*/lhsElementalBitWidth()); + } + // TODO: Implement validator. + LogicalResult validateLhsCopyMapping() const override { return success(); } + + // RHS copy is of size kxn. + MappingInfo rhsCopyMapping() const override { + int64_t kInnerTileSize = + captures.matmulOpSizes[captures.contractionDims.k.back()]; + int64_t nInnerTileSize = + captures.matmulOpSizes[captures.contractionDims.n.back()]; + MappingInfo mapping = CopyMapping::getMappingInfo( + ctx, totalNumThreads(), + /*alignment=*/k(), + /*copySizes=*/ArrayRef{nInnerTileSize, kInnerTileSize}, + /*favorPredication=*/false, + /*elementalBitWidth=*/rhsElementalBitWidth()); + if (captures.contractionDims.n.size() == 2) { + mapping.tileSizes.insert(mapping.tileSizes.begin(), 1); + mapping.numThreads.insert(mapping.numThreads.begin(), 0); + } + if (captures.contractionDims.k.size() == 2) { + mapping.tileSizes.insert(mapping.tileSizes.begin(), 1); + mapping.numThreads.insert(mapping.numThreads.begin(), 0); + } + return mapping; + } + // TODO: Implement validator. + LogicalResult validateRhsCopyMapping() const override { return success(); } + + // RES copy is of size mxn. + MappingInfo resCopyMapping() const override { + int64_t nInnerTileSize = + captures.matmulOpSizes[captures.contractionDims.n.back()]; + return CopyMapping::getMappingInfo( + ctx, totalNumThreads(), + /*alignment=*/n(), + /*copySizes=*/captures.contractionDims.n.size() == 2 + ? 
ArrayRef{1, blockTileM(), nInnerTileSize} + : ArrayRef{blockTileM(), nInnerTileSize}, + /*favorPredication=*/false, + /*elementalBitWidth=*/resElementalBitWidth()); + } + // TODO: Implement validator. + LogicalResult validateResCopyMapping() const override { return success(); } + + // COMPUTE is of size mxn. + MappingInfo computeMapping() const override { + if (useFma) { + // When using FMA we don't need to map to warps, instead just match what + // the copy does. + return resCopyMapping(); + } + return MappingInfo{/*numThreads=*/captures.contractionDims.n.size() == 2 + ? SmallVector{0, numWarpsX()} + : SmallVector{numWarpsX()}, + /*tileSizes=*/{}, + /*threadMapping=*/{warpX(ctx)}, + /*vectorSize=*/std::nullopt}; + } + + void print(llvm::raw_ostream &os) const override; + LLVM_DUMP_METHOD void dump() const override; +}; + +} // namespace gpu +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_CODEGEN_TRANSFORM_DIALECT_STRATEGIES_GPU_DATA_TILED_MATMUL_STRATEGY_H_ diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/MatmulTensorCoreStrategy.cpp b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/MatmulTensorCoreStrategy.cpp index f92f8f21631d..f8b303dde33e 100644 --- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/MatmulTensorCoreStrategy.cpp +++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/MatmulTensorCoreStrategy.cpp @@ -96,6 +96,29 @@ LogicalResult MatmulStrategy::validate(const GPUModel &gpuModel) const { return failure(); } + if (useMmaSync) { + if (!gpuModel.hasMmaSync) { + LDBG("--Matmul strategy target does not support MMA.SYNC operations\n"); + return failure(); + } + } else { + // Verify WMMA. + // Hard coded to reflect current WMMA unrolling support. + int reqM = 16; + int reqN = 16; + int reqK = lhsElementType.isF32() ? 8 : 16; + if (llvm::all_of(gpuModel.supportedWMMAConfigs, + [&](iree_compiler::gpu::MMAConfig config) { + return config.m != reqM || config.n != reqN || + config.k != reqK || + config.aType != lhsElementType || + config.bType != rhsElementType || + config.cType != resElementType; + })) { + LDBG("--Matmul strategy failed wmma type check\n"); + return failure(); + } + } return success(); } @@ -238,8 +261,11 @@ buildCommonMatmulLikeThreadSchedule(ImplicitLocOpBuilder &b, Value variantH, // Need to match again since bufferize invalidated all handles. // TODO: assumes a single func::FuncOp to transform, needs hardening. funcH = b.create(variantH, func::FuncOp::getOperationName()); - funcH = buildMapToBlockAndThreads(b, funcH, strategy.numThreads, - strategy.numWarps); + funcH = + buildMapToBlockAndThreads(b, funcH, + /*blockSize=*/strategy.numThreads, + /*warpDims=*/strategy.numWarps, + /*subgroupSize=*/strategy.targetSubgroupSize); funcH = b.create(funcH); // Step 9. Convert to tensor core ops. diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/PadStrategy.h b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/PadStrategy.h index 8c3e0dab5e1e..cba7da3dde43 100644 --- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/PadStrategy.h +++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/PadStrategy.h @@ -20,7 +20,7 @@ namespace gpu { struct PadConfig {}; /// Simple padding strategy. 
-class PadStrategy : GPUStrategy { +class PadStrategy : public GPUStrategy { public: PadStrategy(MLIRContext *context, const transform_ext::MatchedPadCaptures &captures, diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Strategies.cpp b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Strategies.cpp index 32b9490fafb0..c346de3619c0 100644 --- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Strategies.cpp +++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Strategies.cpp @@ -15,6 +15,9 @@ #include "iree/compiler/Codegen/TransformStrategies/Common/Common.h" #include "iree/compiler/Codegen/TransformStrategies/GPU/Common.h" #include "iree/compiler/Codegen/TransformStrategies/GPU/ConvolutionImplicitGemmStrategy.h" +#include "iree/compiler/Codegen/TransformStrategies/GPU/ConvolutionStrategy.h" +#include "iree/compiler/Codegen/TransformStrategies/GPU/ConvolutionTensorCoreStrategy.h" +#include "iree/compiler/Codegen/TransformStrategies/GPU/DataTiledMatmulStrategy.h" #include "iree/compiler/Codegen/TransformStrategies/GPU/MatmulTensorCoreStrategy.h" #include "iree/compiler/Codegen/TransformStrategies/GPU/PadStrategy.h" #include "iree/compiler/Codegen/TransformStrategies/GPU/SmallReductionStrategy.h" @@ -55,6 +58,14 @@ llvm::cl::opt clGPUEnableTransformDialectImplicitGemmStrategy( "iree-codegen-llvmgpu-enable-transform-dialect-implicit-gemm-strategy", llvm::cl::desc("activate the convolution implicit gemm strategy"), llvm::cl::init(false)); +llvm::cl::opt clGPUEnableTransformDialectConvolutionTensorCoreStrategy( + "iree-codegen-llvmgpu-enable-transform-dialect-convolution-tensorcore-" + "strategy", + llvm::cl::desc("activate the convolution tensorcore strategy"), + llvm::cl::init(true)); +llvm::cl::opt clGPUEnableTransformDialectConvolutionStrategy( + "iree-codegen-llvmgpu-enable-transform-dialect-convolution-strategy", + llvm::cl::desc("activate the convolution strategy"), llvm::cl::init(true)); llvm::cl::opt clGPUEnableTransformDialectAlignedMatmul( "iree-codegen-llvmgpu-enable-transform-dialect-aligned-matmul", llvm::cl::desc( @@ -73,10 +84,18 @@ llvm::cl::opt clGPUEnableTransformDialectBatchMatmulStrategy( llvm::cl::desc("activate the batch matmul strategy, additional " "configuration flags are shared with matmul"), llvm::cl::init(false)); +llvm::cl::opt clGPUEnableTransformDialectDataTiledMatmulStrategy( + "iree-codegen-llvmgpu-enable-transform-dialect-data-tiled-matmul-strategy", + llvm::cl::desc("activate the data tiled matmul strategy, additional " + "configuration flags are shared with matmul"), + llvm::cl::init(true)); // TODO: significantly better namespacing. using iree_compiler::gpu::AbstractGemmLikeStrategy; using iree_compiler::gpu::BatchMatmulStrategy; +using iree_compiler::gpu::ConvolutionStrategy; +using iree_compiler::gpu::DataTiledConvolutionStrategy; +using iree_compiler::gpu::DataTiledMatmulStrategy; using iree_compiler::gpu::GPUModel; using iree_compiler::gpu::ImplicitGemmStrategy; using iree_compiler::gpu::kCudaMaxVectorLoadBitWidth; @@ -539,6 +558,58 @@ static LogicalResult matchAndSetMatmulStrategy(func::FuncOp entryPoint, return success(); } +static DataTiledMatmulStrategy +getDataTiledMatmulConfig(MLIRContext *context, MatchedMatmulCaptures &captures, + const GPUModel &gpuModel) { + return DataTiledMatmulStrategy(context, captures, gpuModel); +} + +/// Match the supported batch matmuls and set the transform dialect strategy for +/// them. 
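+/// A data tiled matmul is recognized as a contraction with no batch dimension,
+/// a single M dimension, and either N or K split into an outer and an inner
+/// (data tiled) dimension; see the problem type check below.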
+static LogicalResult +matchAndSetDataTiledMatmulStrategy(func::FuncOp entryPoint, linalg::LinalgOp op, + const GPUModel &gpuModel) { + if (!clGPUEnableTransformDialectDataTiledMatmulStrategy) { + LDBG("--Data tiled matmul strategy flag turned off\n"); + return failure(); + } + + StructuredOpMatcher *fill; + StructuredOpMatcher *dtm; + StructuredOpMatcher *trailing; + transform_ext::MatchedMatmulCaptures captures; + transform_ext::MatcherContext matcherContext; + transform_ext::makeAnyContractionMatcher(matcherContext, dtm, fill, trailing, + captures, + /*mustMatchEntireFunc=*/true); + if (!matchPattern(op, *dtm)) { + LDBG("--Data tiled matmul strategy failed to match\n"); + return failure(); + } + + if (captures.contractionDims.batch.size() != 0 || + captures.contractionDims.m.size() != 1 || + (captures.contractionDims.n.size() != 2 && + captures.contractionDims.k.size() != 2)) { + LDBG("--Data tiled matmul failed problem type check\n"); + return failure(); + } + + DataTiledMatmulStrategy strategy = + getDataTiledMatmulConfig(entryPoint->getContext(), captures, gpuModel); + if (failed(strategy.validate(gpuModel))) { + LDBG("--Data tiled matmul strategy failed to validate\n"); + return failure(); + } + + iree_compiler::createTransformRegion( + entryPoint, [&](ImplicitLocOpBuilder &b, Value variantH) { + return iree_compiler::gpu::buildDataTiledMatmulStrategy(b, variantH, + strategy); + }); + return success(); +} + //===--------------------------------------------------------------------===// // Convolution strategies. //===--------------------------------------------------------------------===// @@ -546,7 +617,7 @@ static LogicalResult matchAndSetMatmulStrategy(func::FuncOp entryPoint, /// precedence over other heuristics. In the future, this could be lifted to /// e.g. `gpuModel` or higher up in some transform dialect database summary of /// "known good things". -static FailureOr applyKnownGoodConvolutionConfigurations( +static FailureOr applyKnownGoodImplicitGemmConfigurations( const transform_ext::MatchedConvolutionCaptures &captures, const GPUModel &gpuModel) { return failure(); @@ -585,15 +656,15 @@ static void failSafeOverrides(ImplicitGemmStrategy &strategy, /// The configurations below have been determined empirically. // TODO: Significantly improve these heuristics. static ImplicitGemmStrategy -getConvolutionConfig(MLIRContext *context, - const transform_ext::MatchedConvolutionCaptures &captures, - const GPUModel &gpuModel) { +getImplicitGemmConfig(MLIRContext *context, + const transform_ext::MatchedConvolutionCaptures &captures, + const GPUModel &gpuModel) { ImplicitGemmStrategy strategy(context, captures, gpuModel); if (strategy.cliOptionsSpecified) return strategy; auto maybeHardcodedConfiguration = - applyKnownGoodConvolutionConfigurations(captures, gpuModel); + applyKnownGoodImplicitGemmConfigurations(captures, gpuModel); if (succeeded(maybeHardcodedConfiguration)) return *maybeHardcodedConfiguration; @@ -605,21 +676,22 @@ getConvolutionConfig(MLIRContext *context, return strategy; } -static LogicalResult matchAndSetConvolutionStrategy(func::FuncOp entryPoint, - linalg::LinalgOp op, - const GPUModel &gpuModel) { +static LogicalResult matchAndSetConvolutionImplicitGemmStrategy( + func::FuncOp entryPoint, linalg::LinalgOp op, const GPUModel &gpuModel) { if (!clGPUEnableTransformDialectImplicitGemmStrategy) { LDBG("--Implicit gemm strategy flag turned off\n"); return failure(); } // 1. Match a reduction and surrounding ops. 
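+  // The convolution matcher now also captures an optional tensor.pad feeding
+  // the convolution input; the implicit gemm path bails out below when such a
+  // pad is present.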
+ CapturingOpMatcher *pad; StructuredOpMatcher *fill; StructuredOpMatcher *convolution; StructuredOpMatcher *trailing; transform_ext::MatchedConvolutionCaptures captures; transform_ext::MatcherContext matcherContext; - makeConvolutionMatcher(matcherContext, convolution, fill, trailing, captures, + makeConvolutionMatcher(matcherContext, convolution, pad, fill, trailing, + captures, /*mustMatchEntireFunc=*/true); if (!matchPattern(op, *convolution)) { LDBG("--Implicit gemm strategy fail to match\n"); @@ -631,7 +703,7 @@ static LogicalResult matchAndSetConvolutionStrategy(func::FuncOp entryPoint, // - Mandatory fill op. // - Require minimum tile alignment due to img2col. // - Otherwise, we take it. - if (!fill->getCaptured() || trailing->getCaptured()) { + if (!fill->getCaptured() || trailing->getCaptured() || pad->getCaptured()) { LDBG("--Implicit gemm strategy fill / trailing preconditions failed\n"); return failure(); } @@ -673,7 +745,7 @@ static LogicalResult matchAndSetConvolutionStrategy(func::FuncOp entryPoint, } iree_compiler::gpu::ImplicitGemmStrategy strategy = - getConvolutionConfig(op->getContext(), captures, gpuModel); + getImplicitGemmConfig(op->getContext(), captures, gpuModel); // Validate the strategy configuration against the compilation target. if (failed(strategy.validate(gpuModel))) { @@ -693,6 +765,177 @@ static LogicalResult matchAndSetConvolutionStrategy(func::FuncOp entryPoint, return success(); } +static FailureOr +applyKnownGoodConvolutionConfigurations( + const transform_ext::MatchedConvolutionCaptures &captures, + const GPUModel &gpuModel) { + return failure(); +} + +static void failSafeOverrides(DataTiledConvolutionStrategy &strategy, + const GPUModel &gpuModel) {} + +/// The configurations below have been determined empirically. +// TODO: Significantly improve these heuristics. +static DataTiledConvolutionStrategy +getDataTiledConvolutionConfig(MLIRContext *context, + const transform_ext::MatchedConvolutionCaptures &captures, + const GPUModel &gpuModel) { + DataTiledConvolutionStrategy strategy(context, captures, gpuModel); + if (strategy.cliOptionsSpecified) + return strategy; + + auto maybeHardcodedConfiguration = + applyKnownGoodConvolutionConfigurations(captures, gpuModel); + if (succeeded(maybeHardcodedConfiguration)) + return *maybeHardcodedConfiguration; + + // TODO: encode a decision tree of reasonnable heuristics here. + + // Apply failsafe overrides to avoid identified bad corner cases. + failSafeOverrides(strategy, gpuModel); + + return strategy; +} + +static LogicalResult matchAndSetDataTiledConvolutionStrategy( + func::FuncOp entryPoint, linalg::LinalgOp op, const GPUModel &gpuModel) { + if (!clGPUEnableTransformDialectConvolutionTensorCoreStrategy) { + LDBG("--Convolution strategy flag turned off\n"); + return failure(); + } + + // 1. Match a reduction and surrounding ops. 
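+  // Match a convolution with a mandatory fill, an optional producer pad, and
+  // an optional trailing op; the preconditions below require 2-D output image
+  // and filter loops and no batch dimension.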
+ CapturingOpMatcher *pad; + StructuredOpMatcher *fill; + StructuredOpMatcher *convolution; + StructuredOpMatcher *trailing; + transform_ext::MatchedConvolutionCaptures captures; + transform_ext::MatcherContext matcherContext; + makeConvolutionMatcher(matcherContext, convolution, pad, fill, trailing, + captures, + /*mustMatchEntireFunc=*/true); + if (!matchPattern(op, *convolution)) { + LDBG("--Convolution strategy fail to match\n"); + return failure(); + } + + if (!fill->getCaptured()) { + LDBG("--Convolution strategy capture precondition failed\n"); + return failure(); + } + + if (captures.convolutionDims.outputImage.size() != 2) { + return failure(); + } + if (captures.convolutionDims.filterLoop.size() != 2) { + return failure(); + } + if (captures.convolutionDims.batch.size() != 0) { + return failure(); + } + + // int64_t channelSize = 1; + // for (auto dim : captures.convolutionDims.outputChannel) + // channelSize *= captures.convolutionOpSizes[dim]; + // int64_t imageSize = 1; + // for (auto dim : captures.convolutionDims.outputImage) + // imageSize *= captures.convolutionOpSizes[dim]; + + // int64_t derivedK = 1; + // for (auto dim : captures.convolutionDims.filterLoop) + // derivedK *= captures.convolutionOpSizes[dim]; + // for (auto dim : captures.convolutionDims.inputChannel) + // derivedK *= captures.convolutionOpSizes[dim]; + + iree_compiler::gpu::DataTiledConvolutionStrategy strategy = + getDataTiledConvolutionConfig(op->getContext(), captures, gpuModel); + + // Validate the strategy configuration against the compilation target. + if (failed(strategy.validate(gpuModel))) { + LDBG("--Convolution strategy failed to validate\n"); + return failure(); + } + + // 2. Construct the configuration and the strategy builder. + // TODO: Generalize along the HW axis. + auto strategyBuilder = [&](ImplicitLocOpBuilder &b, Value variant) { + return buildConvolutionTensorCoreStrategy(b, variant, strategy); + }; + + // 3. Build strategy embedded into the IR. + mlir::iree_compiler::createTransformRegion(entryPoint, strategyBuilder); + + return success(); +} + +/// The configurations below have been determined empirically. +// TODO: Significantly improve these heuristics. +static ConvolutionStrategy +getDirectConvolutionConfig(MLIRContext *context, + const transform_ext::MatchedConvolutionCaptures &captures, + const GPUModel &gpuModel) { + return ConvolutionStrategy(context, captures, gpuModel); +} + +static LogicalResult matchAndSetDirectConvolutionStrategy( + func::FuncOp entryPoint, linalg::LinalgOp op, const GPUModel &gpuModel) { + if (!clGPUEnableTransformDialectConvolutionStrategy) { + LDBG("--Convolution strategy flag turned off\n"); + return failure(); + } + + // 1. Match a reduction and surrounding ops. 
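+  // Same matcher as above, but this path additionally requires that no
+  // producer pad was captured.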
+ CapturingOpMatcher *pad; + StructuredOpMatcher *fill; + StructuredOpMatcher *convolution; + StructuredOpMatcher *trailing; + transform_ext::MatchedConvolutionCaptures captures; + transform_ext::MatcherContext matcherContext; + makeConvolutionMatcher(matcherContext, convolution, pad, fill, trailing, + captures, + /*mustMatchEntireFunc=*/true); + if (!matchPattern(op, *convolution)) { + LDBG("--Convolution strategy fail to match\n"); + return failure(); + } + + if (!fill->getCaptured() || pad->getCaptured()) { + LDBG("--Convolution strategy capture preconditions failed\n"); + return failure(); + } + + if (captures.convolutionDims.outputImage.size() != 2) { + return failure(); + } + if (captures.convolutionDims.filterLoop.size() != 2) { + return failure(); + } + if (captures.convolutionDims.batch.size() != 0) { + return failure(); + } + + iree_compiler::gpu::ConvolutionStrategy strategy = + getDirectConvolutionConfig(op->getContext(), captures, gpuModel); + + // Validate the strategy configuration against the compilation target. + if (failed(strategy.validate(gpuModel))) { + LDBG("--Convolution strategy failed to validate\n"); + return failure(); + } + + // 2. Construct the configuration and the strategy builder. + // TODO: Generalize along the HW axis. + auto strategyBuilder = [&](ImplicitLocOpBuilder &b, Value variant) { + return buildConvolutionStrategy(b, variant, strategy); + }; + + // 3. Build strategy embedded into the IR. + mlir::iree_compiler::createTransformRegion(entryPoint, strategyBuilder); + + return success(); +} + //===--------------------------------------------------------------------===// // Pad strategies. //===--------------------------------------------------------------------===// @@ -808,7 +1051,22 @@ LogicalResult mlir::iree_compiler::gpu::matchAndSetTransformStrategy( return success(); } if (succeeded( - matchAndSetConvolutionStrategy(entryPoint, linalgOp, gpuModel))) { + matchAndSetDataTiledMatmulStrategy(entryPoint, linalgOp, gpuModel))) { + LDBG("Activate data tiled matmul\n"); + return success(); + } + if (succeeded(matchAndSetDataTiledConvolutionStrategy(entryPoint, linalgOp, + gpuModel))) { + LDBG("Activate convolution\n"); + return success(); + } + if (succeeded(matchAndSetConvolutionImplicitGemmStrategy(entryPoint, linalgOp, + gpuModel))) { + LDBG("Activate convolution\n"); + return success(); + } + if (succeeded( + matchAndSetDirectConvolutionStrategy(entryPoint, linalgOp, gpuModel))) { LDBG("Activate convolution\n"); return success(); } diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Strategies.h b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Strategies.h index 852d5f16dc7a..f606b2336cd0 100644 --- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Strategies.h +++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Strategies.h @@ -17,8 +17,11 @@ namespace iree_compiler { namespace gpu { /// Forward declarations of all supported strategies. -struct BatchMatmulStrategy; -struct MatmulStrategy; +class BatchMatmulStrategy; +class MatmulStrategy; +class ConvolutionStrategy; +class DataTiledConvolutionStrategy; +class DataTiledMatmulStrategy; class PadStrategy; class SmallReductionStrategy; class StagedReductionStrategy; @@ -26,6 +29,18 @@ class StagedReductionStrategy; static constexpr int64_t kCudaWarpSize = 32; static constexpr int64_t kCudaMaxNumThreads = 1024; +/// Placeholder for representing supported WMMA/Cooperative Matrix +/// configurations. 
This is a reflection of +/// SPIRV_CooperativeMatrixPropertiesNVArrayAttr. +struct MMAConfig { + int64_t m; + int64_t n; + int64_t k; + Type aType; + Type bType; + Type cType; +}; + /// Placeholder for some hardware model proxy that contains relevant information /// to configure the strategies. In the future, this will need to be /// driven by some contract with the runtime. @@ -34,11 +49,14 @@ struct GPUModel { llvm::StringRef model = kDefaultGPU; /// TODO: Support a range of subgroup sizes. int64_t subgroupSize = kCudaWarpSize; + std::optional minSubgroupSize = std::nullopt; + std::optional maxSubgroupSize = std::nullopt; int64_t maxWorkGroupInvocations = kCudaMaxNumThreads; int64_t maxWorkGroupSize[3] = {1024, 1024, 64}; bool hasWarpShuffle = false; bool hasTF32TensorCore = false; bool hasMmaSync = false; + SmallVector supportedWMMAConfigs = {}; }; //===--------------------------------------------------------------------===// @@ -73,6 +91,23 @@ void buildMatmulTensorCoreStrategy(ImplicitLocOpBuilder &b, Value variantH, void buildBatchMatmulStrategy(ImplicitLocOpBuilder &b, Value variantH, const BatchMatmulStrategy &strategy); +//===--------------------------------------------------------------------===// +// Data tiled matmul strategies. +//===--------------------------------------------------------------------===// +/// Entry point to build the transform IR corresponding to an FMA-based strategy +/// for linalg.fill + linalg.batch_matmul. +void buildDataTiledMatmulStrategy(ImplicitLocOpBuilder &b, Value variantH, + const DataTiledMatmulStrategy &strategy); + +//===--------------------------------------------------------------------===// +// Convolution strategies. +//===--------------------------------------------------------------------===// +void buildConvolutionTensorCoreStrategy( + ImplicitLocOpBuilder &b, Value variantH, + const DataTiledConvolutionStrategy &strategy); +void buildConvolutionStrategy(ImplicitLocOpBuilder &b, Value variantH, + const ConvolutionStrategy &strategy); + //===--------------------------------------------------------------------===// // Pad strategies. //===--------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp index eca1e5ad4043..133f9a24b79c 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp @@ -578,7 +578,7 @@ std::optional> getWmmaNativeVectorSize(Operation *op) { int64_t m = 16; int64_t n = 16; if (auto contract = dyn_cast(op)) { - int64_t k = contract.getLhsType().getElementType().isF16() ? 16 : 8; + int64_t k = contract.getLhsType().getElementType().isF32() ? 
8 : 16; SmallVector nativeSize(contract.getIteratorTypes().size() - 3, 1); nativeSize.append({m, n, k}); return nativeSize; diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/DetachElementwiseFromNamedOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/DetachElementwiseFromNamedOps.cpp index 9e228d0f4ccc..f3b05e96c5a6 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/DetachElementwiseFromNamedOps.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/DetachElementwiseFromNamedOps.cpp @@ -39,7 +39,10 @@ struct DetachElementwisePattern LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp, PatternRewriter &rewriter) const override { if (!linalg::isaContractionOpInterface(linalgOp) && - !isa(*linalgOp)) { + !isa(*linalgOp) && + !linalg::detail::getMatchConvolutionMessage( + mlir::linalg::detail::isConvolutionInterfaceImpl(linalgOp)) + .empty()) { return failure(); } if (!linalgOp.hasTensorSemantics()) diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp index b050654a8d8b..7ea303de8cc2 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp @@ -629,8 +629,11 @@ isFusableWithProducer(OpOperand &operand, return false; } + mlir::linalg::detail::ConvolutionDimensions ignore; if (options.fusePadWithConsumers && isa(producer) && - isa(consumer)) { + linalg::detail::getMatchConvolutionMessage( + linalg::detail::isConvolutionInterfaceImpl(consumer, &ignore)) + .empty()) { return true; } diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h index 93249b28a322..1e2c537eea4b 100644 --- a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h +++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h @@ -1118,6 +1118,13 @@ void makeBatchMatmulMatcher(transform_ext::MatcherContext &matcherContext, transform_ext::MatchedMatmulCaptures &captures, bool mustMatchEntireFunc); +void makeAnyContractionMatcher(MatcherContext &matcherContext, + StructuredOpMatcher *&dtmCapture, + StructuredOpMatcher *&fillCapture, + StructuredOpMatcher *&trailingCapture, + MatchedMatmulCaptures &captures, + bool mustMatchEntireFunc); + /// Create a group of matchers for a different code sequence of operations /// matching exactly a softmax operation. /// @@ -1133,6 +1140,7 @@ void makeSoftmaxMatcher(MatcherContext &context, struct MatchedConvolutionCaptures { Type inputElementType, filterElementType, outputElementType; mlir::linalg::detail::ConvolutionDimensions convolutionDims = {}; + SmallVector padOpSizes = {}; SmallVector convolutionOpSizes = {}; SmallVector trailingOpSizes = {}; int64_t maybeTrailingOutputElementalTypeBitWidth = 0; @@ -1149,6 +1157,7 @@ struct MatchedConvolutionCaptures { /// tileable operations in the functions are captured. 
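+/// `padCapture` is set to a matcher for an optional tensor.pad producing the
+/// convolution input; the sizes of its padded result are captured in
+/// `padOpSizes`.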
void makeConvolutionMatcher(MatcherContext &context, StructuredOpMatcher *&convolutionCapture, + CapturingOpMatcher *&padCapture, StructuredOpMatcher *&fillCapture, StructuredOpMatcher *&trailingCapture, MatchedConvolutionCaptures &captures, diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp index ed3dfd68c4ec..4386a9ea86fe 100644 --- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp +++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp @@ -696,10 +696,11 @@ convolutionCallback(transform_ext::MatchCallbackResult &res, Location loc, << "expected one handle to one operation"; } + transform_ext::CapturingOpMatcher *pad; transform_ext::StructuredOpMatcher *pattern, *fill, *trailing; transform_ext::MatchedConvolutionCaptures ignore; transform_ext::MatcherContext matcherContext; - makeConvolutionMatcher(matcherContext, pattern, fill, trailing, ignore, + makeConvolutionMatcher(matcherContext, pattern, pad, fill, trailing, ignore, /*mustMatchEntireFunc=*/true); // TODO: need a mechanism for this to go around the entire IR, @@ -713,6 +714,9 @@ convolutionCallback(transform_ext::MatchCallbackResult &res, Location loc, // TODO: notify properly. LLVM_DEBUG({ + DBGS() << "pad:\n"; + if (pad->getCaptured()) + DBGS() << pad->getCaptured() << "\n"; DBGS() << "fill:\n"; if (fill->getCaptured()) DBGS() << fill->getCaptured() << "\n"; @@ -722,6 +726,7 @@ convolutionCallback(transform_ext::MatchCallbackResult &res, Location loc, DBGS() << trailing->getCaptured() << "\n"; }); + res.addPotentiallyEmptyPayloadGroup(pad->getCaptured()); res.addPotentiallyEmptyPayloadGroup(fill->getCaptured()); res.addPayloadGroup({pattern->getCaptured()}); res.addPotentiallyEmptyPayloadGroup(trailing->getCaptured()); @@ -908,6 +913,62 @@ batchMatmulCallback(transform_ext::MatchCallbackResult &res, Location loc, return emitSilenceableFailure(loc) << "failed to match batch matmul"; } +/// Match callback for linalg.batch_matmul and its linalg.generic equivalent fed +/// by a linalg.fill. +/// +/// Input handles: +/// +/// - the container op, must be associated with one operation. +/// +/// Output handles: +/// +/// - the fill op initializing the output; +/// - the main compute op. +static DiagnosedSilenceableFailure +anyContractionCallback(transform_ext::MatchCallbackResult &res, Location loc, + const mlir::transform::TransformState &state, + ValueRange handles) { + if (handles.size() != 1 || + !llvm::hasSingleElement(state.getPayloadOps(handles[0]))) { + return emitSilenceableFailure(loc) + << "expected one handle to one operation"; + } + + transform_ext::StructuredOpMatcher *pattern, *fill, *trailing; + transform_ext::MatchedMatmulCaptures ignore; + transform_ext::MatcherContext matcherContext; + transform_ext::makeAnyContractionMatcher(matcherContext, pattern, fill, + trailing, ignore, + /*mustMatchEntireFunc*/ true); + + // TODO: need a mechanism for this to go around the entire IR, + // potentially with list matches for each group. 
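+  // Walk the payload op and stop at the first contraction that matches; the
+  // resulting payload groups are {fill, contraction, optional trailing}.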
+ Operation *root = *state.getPayloadOps(handles[0]).begin(); + + WalkResult walkResult = root->walk([&](Operation *op) { + pattern->resetCapture(); + if (!matchPattern(op, *pattern)) + return WalkResult::advance(); + + // TODO: notify properly + LLVM_DEBUG({ + DBGS() << "fill:" << fill->getCaptured() << "\n"; + DBGS() << "pattern: " << pattern->getCaptured() << "\n"; + if (trailing->getCaptured()) + DBGS() << "trailing:" << trailing->getCaptured() << "\n"; + }); + + res.addPayloadGroup({fill->getCaptured()}); + res.addPayloadGroup({pattern->getCaptured()}); + res.addPotentiallyEmptyPayloadGroup(trailing->getCaptured()); + return WalkResult::interrupt(); + }); + + if (walkResult.wasInterrupted()) + return DiagnosedSilenceableFailure::success(); + return emitSilenceableFailure(loc) << "failed to match batch matmul"; +} + /// Match callback for a tensor.pad. Matches *the first* occurrence of such pad /// within an op associated with the given handle. /// @@ -975,6 +1036,7 @@ DiagnosedSilenceableFailure transform_ext::RegisterMatchCallbacksOp::apply( registry.registerCallback("convolution", convolutionCallback); registry.registerCallback("matmul", matmulCallback); registry.registerCallback("batch_matmul", batchMatmulCallback); + registry.registerCallback("contraction", anyContractionCallback); registry.registerCallback("pad", wrapAsEntireFuncMatch(padCallback)); registry.registerCallback("reduction", wrapAsEntireFuncMatch(reductionCallback)); diff --git a/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp b/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp index 6ed12a83110e..d7c0bc709aef 100644 --- a/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp +++ b/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp @@ -1526,6 +1526,33 @@ void transform_ext::makeBatchMatmulMatcher( bmm = bmm.allTilableOpsCaptured(); } +void transform_ext::makeAnyContractionMatcher( + transform_ext::MatcherContext &matcherContext, + transform_ext::StructuredOpMatcher *&dtmCapture, + transform_ext::StructuredOpMatcher *&fillCapture, + transform_ext::StructuredOpMatcher *&trailingCapture, + transform_ext::MatchedMatmulCaptures &captures, bool mustMatchEntireFunc) { + auto &dtm = + transform_ext::m_StructuredOp(matcherContext) + .contractionDims(CaptureContractionDims(captures.contractionDims)) + .dim(AllDims(), CaptureDims(captures.matmulOpSizes)) + .input(NumEqualsTo(2)) + .input(0, CaptureElementType(captures.lhsElementType)) + .input(1, CaptureElementType(captures.rhsElementType)) + .output(0, CaptureElementType(captures.outputElementType)); + dtmCapture = &dtm; + + auto &fill = transform_ext::m_StructuredOp(matcherContext); + dtm = dtm.output(0, fill); + fillCapture = &fill; + + auto &trailing = m_StructuredOp(matcherContext); + dtm = dtm.result(0, HasAnyUse(), trailing, OptionalMatch()); + if (mustMatchEntireFunc) + dtm = dtm.allTilableOpsCaptured(); + trailingCapture = &trailing; +} + /// Match sum(%src, broadcast(%reduction)) static void matchSubBroadcast(transform_ext::MatcherContext &matcherContext, @@ -1704,13 +1731,14 @@ void transform_ext::makeSoftmaxMatcher( void transform_ext::makeConvolutionMatcher( transform_ext::MatcherContext &matcherContext, transform_ext::StructuredOpMatcher *&convolutionCapture, + transform_ext::CapturingOpMatcher *&padCapture, transform_ext::StructuredOpMatcher *&fillCapture, transform_ext::StructuredOpMatcher *&trailingCapture, MatchedConvolutionCaptures &captures, bool mustMatchEntireFunc) 
{ // The core part of the matcher is anchored on a particular convolution op. auto &convolution = - m_StructuredOp( - matcherContext) + m_StructuredOp(matcherContext) // Capture convolution dim classifications. .convolutionDims(CaptureConvDims(captures.convolutionDims)) // Capture op sizes. @@ -1721,6 +1749,15 @@ void transform_ext::makeConvolutionMatcher( .output(0, CaptureElementType(captures.outputElementType)); convolutionCapture = &convolution; + auto &value = transform_ext::m_ShapedValue(matcherContext); + value.dim(transform_ext::AllDims(), + transform_ext::CaptureDims(captures.padOpSizes)); + auto &pad = transform_ext::m_tensorPad(matcherContext) + .result(0, value) + .yieldsExternalValue(); + convolution = convolution.input(0, pad, OptionalMatch()); + padCapture = &pad; + // Optional FillOp to create the unique output of the convolution. auto &fill = m_StructuredOp(matcherContext) .output(0, CaptureElementTypeBitWidth( @@ -1757,10 +1794,11 @@ void transform_ext::makeConvolutionMatcher( transform_ext::MatcherContext &context, StructuredOpMatcher *&convolutionCapture, MatchedConvolutionCaptures &captures, bool mustMatchEntireFunc) { + CapturingOpMatcher *pad; StructuredOpMatcher *fill; StructuredOpMatcher *trailing; - makeConvolutionMatcher(context, convolutionCapture, fill, trailing, captures, - mustMatchEntireFunc); + makeConvolutionMatcher(context, convolutionCapture, pad, fill, trailing, + captures, mustMatchEntireFunc); } void transform_ext::makePadMatcher(MatcherContext &context,