From 01aed3cd940e586419c4966a41b1b9d44d8f73ee Mon Sep 17 00:00:00 2001
From: Max Dawkins
Date: Wed, 13 Sep 2023 13:44:02 -0400
Subject: [PATCH] [Flow] Add pattern to reassociate dequantization + matmul
 `linalg.generic` ops

Dequantization ops that are consumed by matmuls are currently only fused into
a dispatch region, but we can do even better by reassociating these fused
operations (see https://github.com/openxla/iree/issues/14951). Note that this
pattern does affect precision and is a trade-off between precision and
performance. It is opt-in via
`--iree-flow-enable-quantized-matmul-reassociation`.

This pattern rewrites a sequence of dequantization->matmul `linalg.generic`
ops into a new sequence of `linalg.generic` ops. The new sequence of ops is as
follows:

1. A sequence of `linalg.generic` ops that dynamically quantize the
   non-quantized input to the matmul. This is very cheap in skinny matmul
   cases, where the non-quantized input is small compared to the quantized
   input.
2. A `linalg.generic` op that performs an integer matmul. This is the key
   performance optimization here. On CPU, we want to be doing integer matmuls
   where we can, but for now the matmul needs to be picked up by a
   VectorContractCustomKernel. Eventually it will be better to rewrite to
   `linalg.matmul` here to target ukernels.
3. A final `linalg.generic` op that performs the dequantization scale and
   zero-point math, as well as the remaining reduction of the matmul. The
   matmul from step 2 only reduces within quantized groups, while this op does
   the reduction across groups.
---
 .../Dialect/Flow/Transforms/BUILD.bazel       |   1 +
 .../Dialect/Flow/Transforms/CMakeLists.txt    |   1 +
 .../Transforms/FuseDequantizationMatmul.cpp   | 672 +++++++++++++++++-
 .../compiler/Dialect/Flow/Transforms/Passes.h |   4 +-
 .../test/fuse_dequantization_matmul.mlir      |   2 +-
 .../compiler/GlobalOptimization/Passes.cpp    |  13 +-
 6 files changed, 676 insertions(+), 17 deletions(-)

diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
index 1fbf965b3c40..41642a62b472 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
@@ -83,6 +83,7 @@ iree_compiler_cc_library(
     ],
     deps = [
         ":PassesIncGen",
+        "//compiler/src/iree/compiler/Codegen/Dialect:IREECodegenDialect",
         "//compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow",
         "//compiler/src/iree/compiler/Dialect/Flow/IR",
         "//compiler/src/iree/compiler/Dialect/HAL/IR",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
index 3d931217c0fa..90224e250f66 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -113,6 +113,7 @@ iree_cc_library(
     MLIRTransformDialectTransforms
     MLIRTransformUtils
     MLIRTransforms
+    iree::compiler::Codegen::Dialect::IREECodegenDialect
     iree::compiler::Dialect::Flow::Conversion::TensorToFlow
     iree::compiler::Dialect::Flow::IR
     iree::compiler::Dialect::HAL::IR
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FuseDequantizationMatmul.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FuseDequantizationMatmul.cpp
index 51494accd4ae..9b8a5ec362db 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FuseDequantizationMatmul.cpp
+++
b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FuseDequantizationMatmul.cpp @@ -4,18 +4,24 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +#include "iree/compiler/Codegen/Dialect/IREECodegenAttrs.h" #include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h" #include "iree/compiler/Dialect/Flow/Transforms/Passes.h" #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Tensor/Utils/Utils.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#define DEBUG_TYPE "iree-flow-fuse-dequantization-matmul" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + namespace mlir { namespace iree_compiler { namespace IREE { @@ -34,7 +40,7 @@ static LogicalResult fuseDequantAndMatmul(RewriterBase &rewriter, Operation *dequant, Operation *matmul, std::optional fill) { - Flow::DispatchRegionOp regionOp = matmul->getParentOfType(); + DispatchRegionOp regionOp = matmul->getParentOfType(); if (!regionOp) { FailureOr maybeRegionOp = wrapOpInDispatchRegion(rewriter, matmul); @@ -58,6 +64,585 @@ static LogicalResult fuseDequantAndMatmul(RewriterBase &rewriter, return success(); } +static FailureOr +wrapConsecutiveOpsInDispatchRegion(RewriterBase &rewriter, + SmallVector ops) { + FailureOr maybeRegionOp = + wrapOpInDispatchRegion(rewriter, ops.back()); + if (failed(maybeRegionOp)) { + return failure(); + } + DispatchRegionOp regionOp = maybeRegionOp.value(); + + SmallVector precedingOps(ops.begin(), ops.end() - 1); + FailureOr maybeFusedRegionOp = + movePrecedingOpsIntoDispatchRegion(rewriter, precedingOps, regionOp); + if (failed(maybeFusedRegionOp)) { + return failure(); + } + regionOp = maybeFusedRegionOp.value(); + + return regionOp; +} + +static SmallVector +getParallelAndReductionIterators(unsigned nLoops, unsigned nReduction) { + SmallVector res(nLoops - nReduction, + utils::IteratorType::parallel); + res.append(nReduction, utils::IteratorType::reduction); + return res; +} + +// We set the tile sizes of the integer matmul result from the reassociated +// op sequence here to target a specific VectorContractCustomKernel and optimize +// performance on x86 with AVX512VNNI +static LogicalResult setTileSizes(linalg::GenericOp intMatmul, + linalg::GenericOp reassociation, + func::FuncOp entryPointFn) { + + SmallVector distTileSizes_mm = {128, 0, 0}; + SmallVector parallelTileSizes_mm = {4, 0, 0}; + SmallVector reductionTileSizes_mm = {4, 1, 16}; + SmallVector lastTileSizes_mm = {0, 0, 0}; + + TileSizesListType tileSizes_mm; + tileSizes_mm.push_back(distTileSizes_mm); + tileSizes_mm.push_back(parallelTileSizes_mm); + tileSizes_mm.push_back(reductionTileSizes_mm); + tileSizes_mm.push_back(lastTileSizes_mm); + + SmallVector distTileSizes_re = {128, 0}; + SmallVector parallelTileSizes_re = {4, 0}; + SmallVector reductionTileSizes_re = {0, 0}; + SmallVector lastTileSizes_re = {0, 0}; + + TileSizesListType tileSizes_re; + tileSizes_re.push_back(distTileSizes_re); + tileSizes_re.push_back(parallelTileSizes_re); + tileSizes_re.push_back(reductionTileSizes_re); + tileSizes_re.push_back(lastTileSizes_re); + + Codegen::DispatchLoweringPassPipeline passPipeline = + 
Codegen::DispatchLoweringPassPipeline::CPUDoubleTilingExpert; + + MLIRContext *context = entryPointFn.getContext(); + auto config_mm = Codegen::LoweringConfigAttr::get(context, tileSizes_mm); + intMatmul->setAttr("lowering_config", config_mm); + + auto config_re = Codegen::LoweringConfigAttr::get(context, tileSizes_re); + auto translationInfo_re = Codegen::TranslationInfoAttr::get( + entryPointFn.getContext(), passPipeline, 0, 1); + auto compilationInfo_re = + Codegen::CompilationInfoAttr::get(context, config_re, translationInfo_re, + ArrayRef({}), std::nullopt); + + reassociation->setAttr("compilation_info", compilationInfo_re); + + return success(); +} + +// Takes as input the dequantization `linalg.generic` op and the matmul +// `linalg.generic` op, and returns the scales, zero points, quantized +// input matrix, unquantizaed input matrix, and dequantized result +// matrix. +static std::optional> +getDequantMatmulInputs_f32(linalg::GenericOp dequant, + linalg::GenericOp matmul) { + OpOperand *scales, *zps, *quantMat, *unquantMat, *dequantMat; + for (int operandIdx = 0; operandIdx < dequant.getNumDpsInputs(); + operandIdx++) { + OpOperand *operand = dequant.getDpsInputOperand(operandIdx); + Value input = operand->get(); + RankedTensorType inputType = + llvm::dyn_cast(input.getType()); + if (!inputType) { + continue; + } + if (inputType.getElementTypeBitWidth() != 32) { + quantMat = operand; + continue; + } + for (Operation &bodyOp : dequant.getBlock()->getOperations()) { + if (isa(bodyOp)) { + if (bodyOp.getOperand(1) == + dequant.getBlock()->getArgument(operandIdx)) { + scales = operand; + break; + } + } else if (isa(bodyOp)) { + if (bodyOp.getOperand(1) == + dequant.getBlock()->getArgument(operandIdx)) { + zps = operand; + break; + } + } + } + } + Value dequantOut = dequant.getResult(0); + if (matmul.getDpsInputOperand(0)->get() == dequantOut) { + unquantMat = matmul.getDpsInputOperand(1); + dequantMat = matmul.getDpsInputOperand(0); + } else { + unquantMat = matmul.getDpsInputOperand(0); + dequantMat = matmul.getDpsInputOperand(1); + } + if (scales && zps && quantMat && unquantMat) { + return SmallVector( + {quantMat, unquantMat, scales, zps, dequantMat}); + } + return std::nullopt; +} + +// This function does the bulk of the rewrite for the dequantization + matmul. +// +// Starting with 2 `linalg.generic` ops (dequantization->matmul) +// %arg0 = quantized input +// %arg1 = scales +// %arg2 = zero points +// %arg3 = unquantized input +// ```mlir +// %0 = linalg.generic ins(%arg0, %arg1, %arg2 : tensor<8x4x2xi4>, +// tensor<8x4x1xf32>, tensor<8x4x1xf32>) outs(%1 : tensor<8x4x2xf32>) { +// ^bb0(%in: i4, %in_0: f32, %in_1: f32, %out: f32): +// %9 = arith.extui %in : i4 to i32 +// %10 = arith.uitofp %9 : i32 to f32 +// %11 = arith.subf %10, %in_1 : f32 +// %12 = arith.mulf %11, %in_0 : f32 +// linalg.yield %12 : f32 +// } -> tensor<8x4x2xf32> +// %2 = linalg.generic ins(%arg3, %0 : tensor<4x2xf32>, tensor<8x4x2xf32>) +// outs(%3 : tensor<8xf32>) { ^bb0(%in: f32, %in_0: f32, %out: f32): +// %9 = arith.mulf %in, %in_0 : f32 +// %10 = arith.addf %9, %out : f32 +// linalg.yield %10 : f32 +// } -> tensor<8xf32> +// ``` +// +// This function rewrites the above ops as the following new sequence of 6 ops +// that does the following: +// +// a) Dynamically quantize the unquantized input: +// 1. Compute the absolute max of the unquantized input (%arg3) within each +// group. +// 2. 
Compute scales for %arg3 by dividing the absolute max by (1 << +// newBitWidth) - 1), +// where newBitWidth is the bitwidth of the new quantized type, +// currently set to `i16`. +// 3. Compute the sum along groups of the unquantized input. This is not +// necessary for +// the quantization step, but it is needed to efficiently perform a +// reassociated quantized matmul in steps 5-6. +// 4. Quantize the unquantized input (%arg3) by dividing elements in each +// group by +// the corresponding scale. This quantization is symmetric with no zero +// point. +// b) Perform the reassociated quantized matmul, keeping the bulk of the +// computation in +// integer arithmetic: +// 5. The first op performs a matmul-like operation that reduces the +// innermost group +// dimension. Note the indexing maps in the following example: +// ```mlir +// %22 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> +// (d1, d2)>, +// affine_map<(d0, d1, d2) -> +// (d0, d1, d2)>, affine_map<(d0, +// d1, d2) -> (d0, d1)>], +// iterator_types = ["parallel", "parallel", +// "reduction"]} ins(%17, %0 : tensor<4x2xi16>, +// tensor<8x4x2xi4>) outs(%19 : tensor<8x4xi32>) { +// ^bb0(%in: i16, %in_4: i4, %out: i32): +// %24 = arith.extsi %in : i16 to i32 +// %25 = arith.extui %in_4 : i4 to i32 +// %26 = arith.muli %24, %25 : i32 +// %27 = arith.addi %26, %out : i32 +// linalg.yield %27 : i32 +// } -> tensor<8x4xi32> +// ``` +// This op also extends the inputs to the accumulation type, i32 in this +// case, to target specific x86 instructions. We perform the matrix +// multiplication before the dequantization arithmetic, which has been +// reassociated into op 6. +// 6. The final op performs the remaining reduction across groups and does +// the +// dequantization arithmetic: +// ```mlir +// %23 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, +// d1)>, +// affine_map<(d0, d1) -> (d1)>, +// affine_map<(d0, d1) -> (d1)>, +// affine_map<(d0, d1) -> (d0, +// d1)>, affine_map<(d0, d1) -> +// (d0, d1)>, affine_map<(d0, d1) +// -> (d0)>], +// iterator_types = ["parallel", "reduction"]} +// ins(tensor<8x4xi32>, tensor<4xf32>, +// tensor<4xf32>, tensor<8x4xf32>, +// tensor<8x4xf32>) outs(tensor<8xf32>) { +// ^bb0(%in: i32, %in_4: f32, %in_5: f32, %in_6: f32, %in_7: f32, +// %out: f32): +// %24 = arith.sitofp %in : i32 to f32 +// %25 = arith.mulf %24, %in_4 : f32 +// %26 = arith.mulf %25, %in_6 : f32 +// %27 = arith.mulf %in_7, %in_6 : f32 +// %28 = arith.mulf %27, %in_5 : f32 +// %29 = arith.subf %26, %28 : f32 +// %30 = arith.addf %29, %out : f32 +// linalg.yield %30 : f32 +// } -> tensor<8xf32> +// ``` +// +// The rewrite also forms a `flow.dispatch.region` op around ops 5 and 6, and +// sets a specific tiling on ops 5 and 6 to target a VectorContractCustomKernel. +// +// ** Note that this rewrite introduces precision loss in the matmul, and is a +// tradeoff between precision and performance. This rewrite should most +// likely be opt-in only. 
** +static LogicalResult ReassociateAndFuseDequantMatmul( + RewriterBase &rewriter, linalg::GenericOp dequant, linalg::GenericOp matmul, + func::FuncOp entryPointFn) { + LDBG("Reassociating"); + LDBG("dequant: " << dequant); + LDBG("matmul: " << matmul); + std::optional> maybeInputs = + getDequantMatmulInputs_f32(dequant, matmul); + if (!maybeInputs) { + return failure(); + } + SmallVector ins = maybeInputs.value(); + OpOperand *quantInOperand = ins[0]; + OpOperand *unquantInOperand = ins[1]; + Value quantIn = quantInOperand->get(); + Value unquantIn = unquantInOperand->get(); + Value scales = ins[2]->get(); + Value zps = ins[3]->get(); + OpOperand *matmulDequantOperand = ins[4]; + RankedTensorType unquantInType = + llvm::dyn_cast(unquantIn.getType()); + if (!unquantInType) { + return failure(); + } + RankedTensorType quantInType = + llvm::dyn_cast(quantIn.getType()); + if (!quantInType) { + return failure(); + } + OpOperand *matmulOutputOperand = matmul.getDpsInitOperand(0); + Value matmulOutput = matmulOutputOperand->get(); + RankedTensorType matmulOutputType = + llvm::dyn_cast(matmulOutput.getType()); + SmallVector matmulOutShape(matmulOutputType.getShape()); + SmallVector unquantInShape(unquantInType.getShape()); + SmallVector quantInShape(quantInType.getShape()); + SmallVector dequantIndexingMaps = dequant.getIndexingMapsArray(); + SmallVector matmulIndexingMaps = matmul.getIndexingMapsArray(); + SmallVector dequantIteratorTypes = + dequant.getIteratorTypesArray(); + SmallVector matmulIteratorTypes = + matmul.getIteratorTypesArray(); + FloatType f32Type = rewriter.getF32Type(); + IntegerType i32Type = rewriter.getI32Type(); + // Type for accumulation of integer matmul + IntegerType accType = rewriter.getI32Type(); + // Type for dynamic quantization of unquantized input + IntegerType quantType = rewriter.getI16Type(); + Type srcQuantType = quantInType.getElementType(); + // Type for multiplication in integer matmul (should probably be same as + // `accType`) + IntegerType mulType = rewriter.getI32Type(); + unsigned quantBitRange = + std::min(quantType.getIntOrFloatBitWidth() - 1, + mulType.getIntOrFloatBitWidth() - + srcQuantType.getIntOrFloatBitWidth() - 1); + + // ----- Quantize unquantized input ----- // + Value cst = rewriter.create( + dequant.getLoc(), rewriter.getF32FloatAttr((1 << quantBitRange) - 1)); + LDBG("cst: " << cst); + Value zeroF32cst = rewriter.create( + dequant.getLoc(), rewriter.getF32FloatAttr(0)); + Value zeroI32cst = rewriter.create( + dequant.getLoc(), rewriter.getI32IntegerAttr(0)); + // Generic to find max along groups + SmallVector groupMaxShape; + SmallVector groupMaxIterators; + int64_t numGroups = 0; + + SmallVector exprs; + AffineMap indexingMap = matmul.getMatchingIndexingMap(unquantInOperand); + for (const auto &expr : enumerate(indexingMap.getResults())) { + if (auto dimExpr = expr.value().dyn_cast()) { + if (matmulIteratorTypes[dimExpr.getPosition()] == + utils::IteratorType::parallel || + dimExpr.getPosition() != indexingMap.getNumDims() - 1) { + groupMaxIterators.push_back(utils::IteratorType::parallel); + groupMaxShape.push_back(unquantInShape[expr.index()]); + exprs.push_back(rewriter.getAffineDimExpr(groupMaxShape.size() - 1)); + if (matmulIteratorTypes[dimExpr.getPosition()] == + utils::IteratorType::reduction) { + numGroups = unquantInShape[expr.index()]; + } + } else { + groupMaxIterators.push_back(utils::IteratorType::reduction); + } + } else { + return failure(); + } + } + if (!numGroups) { + return failure(); + } + RankedTensorType 
groupMaxType = + RankedTensorType::get(groupMaxShape, unquantInType.getElementType()); + Value groupMaxEmpty = rewriter.create( + dequant.getLoc(), groupMaxType.getShape(), groupMaxType.getElementType()); + Value groupMaxOut = + rewriter + .create(dequant.getLoc(), zeroF32cst, groupMaxEmpty) + .result(); + LDBG("groupMaxOut: " << groupMaxOut); + SmallVector groupMaxMaps; + groupMaxMaps.push_back( + rewriter.getMultiDimIdentityMap(unquantInShape.size())); + groupMaxMaps.push_back(AffineMap::get(unquantInShape.size(), 0, exprs, + exprs.front().getContext())); + auto groupMaxOp = rewriter.create( + dequant.getLoc(), groupMaxOut.getType(), unquantIn, groupMaxOut, + groupMaxMaps, groupMaxIterators, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value abs = b.create(loc, args[0]); + Value max = b.create(loc, abs, args[1]); + b.create(loc, max); + }); + LDBG("groupMaxOp: " << groupMaxOp); + Value groupMax = groupMaxOp.getResult(0); + + // Generic to find scales + RankedTensorType unquantInScalesType = groupMaxType; + Value unquantInScalesOut = rewriter.create( + dequant.getLoc(), unquantInScalesType.getShape(), + unquantInScalesType.getElementType()); + LDBG("unquantInScalesOut: " << unquantInScalesOut); + SmallVector unquantInScalesMaps; + unquantInScalesMaps.push_back( + rewriter.getMultiDimIdentityMap(unquantInShape.size() - 1)); + unquantInScalesMaps.push_back( + rewriter.getMultiDimIdentityMap(unquantInShape.size() - 1)); + + auto unquantInScalesOp = rewriter.create( + dequant.getLoc(), unquantInScalesOut.getType(), groupMax, + unquantInScalesOut, unquantInScalesMaps, + getParallelAndReductionIterators(unquantInScalesType.getRank(), 0), + [&](OpBuilder &b, Location loc, ValueRange args) { + Value scale = b.create(loc, args[0], cst); + b.create(loc, scale); + }); + LDBG("unquantInScalesOp: " << unquantInScalesOp); + Value unquantInScales = unquantInScalesOp.getResult(0); + + // Generic to find scaled sums + RankedTensorType scaledSumsType = groupMaxType; + Value scaledSumsEmpty = rewriter.create( + dequant.getLoc(), scaledSumsType.getShape(), + scaledSumsType.getElementType()); + Value scaledSumsOut = + rewriter + .create(dequant.getLoc(), zeroF32cst, scaledSumsEmpty) + .result(); + LDBG("scaledSumsOut: " << scaledSumsOut); + SmallVector scaledSumsMaps; + scaledSumsMaps.push_back( + rewriter.getMultiDimIdentityMap(unquantInShape.size())); + scaledSumsMaps.push_back(AffineMap::get(unquantInShape.size(), 0, exprs, + exprs.front().getContext())); + SmallVector scaledSumsIterators = groupMaxIterators; + auto scaledSumsOp = rewriter.create( + dequant.getLoc(), scaledSumsOut.getType(), unquantIn, scaledSumsOut, + scaledSumsMaps, scaledSumsIterators, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value sum = b.create(loc, args[0], args[1]); + b.create(loc, sum); + }); + LDBG("scaledSumsOp: " << scaledSumsOp); + Value scaledSums = scaledSumsOp.getResult(0); + + // Generic to quantized the unquantized input + Value newQuantInOut = rewriter.create( + dequant.getLoc(), unquantInShape, quantType); + LDBG("newQuantInOut: " << newQuantInOut); + SmallVector newQuantInMaps; + newQuantInMaps.push_back( + rewriter.getMultiDimIdentityMap(unquantInShape.size())); + newQuantInMaps.push_back(AffineMap::get(unquantInShape.size(), 0, exprs, + exprs.front().getContext())); + newQuantInMaps.push_back( + rewriter.getMultiDimIdentityMap(unquantInShape.size())); + auto newQuantInOp = rewriter.create( + dequant.getLoc(), newQuantInOut.getType(), + ValueRange{unquantIn, unquantInScales}, newQuantInOut, 
newQuantInMaps, + getParallelAndReductionIterators(unquantInShape.size(), 0), + [&](OpBuilder &b, Location loc, ValueRange args) { + Value scaled = b.create(loc, args[0], args[1]); + Value quant = b.create(loc, quantType, scaled); + b.create(loc, quant); + }); + LDBG("newQuantInOp: " << newQuantInOp); + Value newQuantIn = newQuantInOp.getResult(0); + + // ----- Reassociated dequantization matmul ----- // + + // Generic to perform integer matmul and reduce within groups + SmallVector integerMatmulShape = matmulOutShape; + integerMatmulShape.push_back(numGroups); + Value integerMatmulEmpty = rewriter.create( + dequant.getLoc(), integerMatmulShape, accType); + Value integerMatmulOut = + rewriter + .create(dequant.getLoc(), zeroI32cst, + integerMatmulEmpty) + .result(); + LDBG("integerMatmulOut: " << integerMatmulOut); + SmallVector integerMatmulIterators = + getParallelAndReductionIterators(matmul.getNumLoops(), 1); + SmallVector integerMatmulMaps; + integerMatmulMaps.push_back(matmul.getMatchingIndexingMap(unquantInOperand)); + integerMatmulMaps.push_back( + matmul.getMatchingIndexingMap(matmulDequantOperand)); + SmallVector outputExprs( + matmul.getMatchingIndexingMap(matmulOutputOperand).getResults()); + outputExprs.push_back(rewriter.getAffineDimExpr(matmul.getNumLoops() - 2)); + integerMatmulMaps.push_back(AffineMap::get(integerMatmulIterators.size(), 0, + outputExprs, + outputExprs.front().getContext())); + auto integerMatmulOp = rewriter.create( + dequant.getLoc(), integerMatmulOut.getType(), + ValueRange{newQuantIn, quantIn}, integerMatmulOut, integerMatmulMaps, + integerMatmulIterators, [&](OpBuilder &b, Location loc, ValueRange args) { + Value mul; + if (quantType == mulType) { + Value ext1 = b.create(loc, mulType, args[1]); + mul = b.create(loc, args[0], ext1); + } else { + Value ext0 = b.create(loc, mulType, args[0]); + Value ext1 = b.create(loc, mulType, args[1]); + mul = b.create(loc, ext0, ext1); + } + Value sum; + if (mulType == accType) { + sum = b.create(loc, mul, args[2]); + } else { + Value extMul = b.create(loc, accType, mul); + sum = b.create(loc, extMul, args[2]); + } + b.create(loc, sum); + }); + LDBG("integerMatmulOp: " << integerMatmulOp); + Value integerMatmul = integerMatmulOp.getResult(0); + + // Generic to perform dequantization and finish reduction + SmallVector dequantizedMatmulIterators = + getParallelAndReductionIterators(matmul.getNumLoops() - 1, 1); + SmallVector dequantizedMatmulMaps; + dequantizedMatmulMaps.push_back( + rewriter.getMultiDimIdentityMap(dequantizedMatmulIterators.size())); + AffineMap intMatmulNewQuantMap = + matmul.getMatchingIndexingMap(unquantInOperand); + SmallVector newQuantScalesExprs; + for (const auto &expr : enumerate(intMatmulNewQuantMap.getResults())) { + if (auto dimExpr = expr.value().dyn_cast()) { + if (dimExpr.getPosition() != intMatmulNewQuantMap.getNumDims() - 1) { + newQuantScalesExprs.push_back( + rewriter.getAffineDimExpr(dimExpr.getPosition())); + } + } else { + return failure(); + } + } + dequantizedMatmulMaps.push_back( + AffineMap::get(dequantizedMatmulIterators.size(), 0, newQuantScalesExprs, + newQuantScalesExprs.front().getContext())); + dequantizedMatmulMaps.push_back( + AffineMap::get(dequantizedMatmulIterators.size(), 0, newQuantScalesExprs, + newQuantScalesExprs.front().getContext())); + AffineMap matmulQuantInMap = + matmul.getMatchingIndexingMap(matmulDequantOperand); + SmallVector quantScalesExprs; + for (const auto &expr : enumerate(matmulQuantInMap.getResults())) { + if (auto dimExpr = 
expr.value().dyn_cast()) { + if (dimExpr.getPosition() != matmulQuantInMap.getNumDims() - 1) { + quantScalesExprs.push_back( + rewriter.getAffineDimExpr(dimExpr.getPosition())); + } + } else { + return failure(); + } + } + RankedTensorType scalesType = + llvm::dyn_cast(scales.getType()); + if (!scalesType) { + return failure(); + } + if (quantScalesExprs.size() < scalesType.getShape().size()) { + quantScalesExprs.push_back(rewriter.getAffineConstantExpr(0)); + if (quantScalesExprs.size() < scalesType.getShape().size()) { + return failure(); + } + } + dequantizedMatmulMaps.push_back( + AffineMap::get(dequantizedMatmulIterators.size(), 0, quantScalesExprs, + quantScalesExprs.front().getContext())); + dequantizedMatmulMaps.push_back( + AffineMap::get(dequantizedMatmulIterators.size(), 0, quantScalesExprs, + quantScalesExprs.front().getContext())); + SmallVector finalOutputExprs( + matmul.getMatchingIndexingMap(matmulOutputOperand).getResults()); + dequantizedMatmulMaps.push_back( + AffineMap::get(dequantizedMatmulIterators.size(), 0, finalOutputExprs, + finalOutputExprs.front().getContext())); + + auto dequantizedMatmulOp = rewriter.create( + dequant.getLoc(), matmulOutput.getType(), + ValueRange{integerMatmul, unquantInScales, scaledSums, scales, zps}, + matmulOutput, dequantizedMatmulMaps, dequantizedMatmulIterators, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value dq; + if (accType == i32Type) { + dq = b.create(loc, f32Type, args[0]); + } else { + Value ext = b.create(loc, i32Type, args[0]); + dq = b.create(loc, f32Type, ext); + } + Value scaledRes0 = b.create(loc, dq, args[1]); + Value scaledRes1 = b.create(loc, scaledRes0, args[3]); + Value scaledZp0 = b.create(loc, args[4], args[3]); + Value scaledZp1 = b.create(loc, scaledZp0, args[2]); + Value groupRes = b.create(loc, scaledRes1, scaledZp1); + Value sum = b.create(loc, groupRes, args[5]); + b.create(loc, sum); + }); + LDBG("dequantizedMatmulOp: " << dequantizedMatmulOp); + Value dequantizedMatmul = dequantizedMatmulOp.getResult(0); + + rewriter.replaceOp(matmul, dequantizedMatmul); + + // Set tile sizes for dequantization + matmul ops + if (failed( + setTileSizes(integerMatmulOp, dequantizedMatmulOp, entryPointFn))) { + return failure(); + } + + // Fuse dequantization + matmul ops into a single dispatch region + SmallVector dequantMatmulOps( + {integerMatmulOp, dequantizedMatmulOp}); + FailureOr maybeDequantMatmulDispatch = + wrapConsecutiveOpsInDispatchRegion(rewriter, dequantMatmulOps); + if (failed(maybeDequantMatmulDispatch)) { + return failure(); + } + + return success(); +} + // Checks if the passed op is a contraction on grouped input // This function checks that the genericOp: // 1. isaContractionOpInterface @@ -172,6 +757,8 @@ static LogicalResult isGroupedDequantizationOp(linalg::GenericOp genericOp) { // Patterns //----------------------------------------------------------------------------// +// This pattern does a basic fusion of dequantization + matmul `linalg.generic` +// ops, moving them into a single `flow.dispatch.region` op. 
class FuseDequantizationMatmulPattern final : public OpRewritePattern { public: @@ -222,27 +809,84 @@ struct FuseDequantizationMatmulPass : public FuseDequantizationMatmulBase { void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry + .insert(); + } + FuseDequantizationMatmulPass(bool enableQuantizedMatmulReassociation) { + this->enableQuantizedMatmulReassociation = + enableQuantizedMatmulReassociation; } + FuseDequantizationMatmulPass(const FuseDequantizationMatmulPass &pass) + : FuseDequantizationMatmulPass(pass.enableQuantizedMatmulReassociation) {} - void runOnOperation() override { - MLIRContext *context = &getContext(); - // Main pattern. - { - RewritePatternSet patterns(&getContext()); - patterns.insert(context); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + void runOnOperation() override; +}; + +} // namespace + +void FuseDequantizationMatmulPass::runOnOperation() { + MLIRContext *context = &getContext(); + auto funcOp = getOperation(); + + // Perform reassociation if enabled + if (this->enableQuantizedMatmulReassociation) { + SmallVector> candidates; + for (auto genericOp : + funcOp.getFunctionBody().getOps()) { + if (failed(isGroupedContractionOp(genericOp))) + continue; + + OpOperand *lhs = genericOp.getDpsInputOperand(0); + OpOperand *rhs = genericOp.getDpsInputOperand(1); + auto lhsOp = lhs->get().getDefiningOp(); + auto rhsOp = rhs->get().getDefiningOp(); + if (!llvm::cast(genericOp.getInputs()[0].getType()) + .hasStaticShape() || + !llvm::cast(genericOp.getInputs()[1].getType()) + .hasStaticShape() || + !llvm::cast(genericOp.getResults()[0].getType()) + .hasStaticShape()) { + // Codegen can't handle the dynamic case yet. + continue; + } + if (lhsOp) { + if (!failed(isGroupedDequantizationOp(lhsOp))) { + candidates.push_back(std::make_pair(lhsOp, genericOp)); + continue; + } + } + if (rhsOp) { + if (!failed(isGroupedDequantizationOp(rhsOp))) { + candidates.push_back(std::make_pair(rhsOp, genericOp)); + } + } + } + IRRewriter rewriter(context); + for (auto candidate : candidates) { + rewriter.setInsertionPointAfter(candidate.second); + if (failed(ReassociateAndFuseDequantMatmul( + rewriter, candidate.first, candidate.second, + llvm::cast(funcOp)))) { return signalPassFailure(); } } } -}; -} // namespace + // Normal fusion pattern. 
+ { + RewritePatternSet patterns(context); + patterns.insert(context); + if (failed(applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) { + return signalPassFailure(); + } + } +} -std::unique_ptr createFuseDequantizationMatmulPass() { - return std::make_unique(); +std::unique_ptr> +createFuseDequantizationMatmulPass(bool enableQuantizedMatmulReassociation) { + return std::make_unique( + enableQuantizedMatmulReassociation); } } // namespace Flow diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h index bbfc494b7452..36525302c08d 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h @@ -175,7 +175,9 @@ std::unique_ptr> createCloneProducersIntoDispatchRegionsPass(); // A pass to fuse dequantization and matmul linalg.generic ops -std::unique_ptr createFuseDequantizationMatmulPass(); +std::unique_ptr> +createFuseDequantizationMatmulPass( + bool enableQuantizedMatmulReassociation = false); //===----------------------------------------------------------------------===// // Dispatches (flow.dispatch.workgroups) diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fuse_dequantization_matmul.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fuse_dequantization_matmul.mlir index 2537efd5b1a9..567cf1652b29 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fuse_dequantization_matmul.mlir +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fuse_dequantization_matmul.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt --split-input-file --iree-flow-fuse-dequantization-matmul --canonicalize %s | FileCheck %s +// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-flow-fuse-dequantization-matmul,canonicalize))" %s | FileCheck %s module { func.func @grouped_quantized_matmul(%arg0: tensor<4096x32x128xi8>, %arg1: tensor<1x1x32x128xf32>, %arg2: tensor<4096x32x1xf32>, %arg3: tensor<4096x32x1xf32>) -> tensor<1x1x4096xf32> { diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp index 472f7bd00e11..d32416b0eb20 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp @@ -19,6 +19,12 @@ namespace GlobalOptimization { using FunctionLikeNest = MultiOpNest; +static llvm::cl::opt clEnableQuantizedMatmulReassociation( + "iree-flow-enable-quantized-matmul-reassociation", + llvm::cl::desc( + "Enables reassociation of quantized matmul ops (experimental)."), + llvm::cl::init(false)); + void buildGlobalOptimizationPassPipeline( OpPassManager &mainPassManager, const TransformOptions &transformOptions) { // ML frontends have very uneven support for user-controlled types _and_ users @@ -67,7 +73,12 @@ void buildGlobalOptimizationPassPipeline( // this pass both before unit dim folding + consteval, as well as after. .addPass(IREE::Flow::createRaiseSpecialOps) .addPass(IREE::Flow::createFoldUnitExtentDimsPass) - .addPass(IREE::Flow::createFuseDequantizationMatmulPass) + .addPass([&]() { + return IREE::Flow::createFuseDequantizationMatmulPass( + clEnableQuantizedMatmulReassociation); + }) + .addPass(mlir::createCanonicalizerPass) + .addPass(mlir::createCSEPass) // Enable data tiling after they are in a canonical form. .addPredicatedPass(transformOptions.options.dataTiling, IREE::Flow::createSetEncodingPass)
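Usage sketch (illustrative, not part of the diff): the reassociation path is
opt-in, so default behavior is unchanged. The first command mirrors the
updated RUN line of the lit test and performs only the basic dequantization +
matmul fusion; the second enables the experimental reassociation through the
new global flag. `input.mlir`, the llvm-cpu target, and the output name are
placeholder choices for illustration.

  iree-opt --split-input-file \
    --pass-pipeline="builtin.module(func.func(iree-flow-fuse-dequantization-matmul,canonicalize))" \
    input.mlir

  iree-compile --iree-hal-target-backends=llvm-cpu \
    --iree-flow-enable-quantized-matmul-reassociation=true \
    input.mlir -o output.vmfb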