diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel index 1b3fa5e176f9..dae112e47b45 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel @@ -78,6 +78,7 @@ iree_compiler_cc_library( ], deps = [ ":PassesIncGen", + "//compiler/src/iree/compiler/Codegen/Dialect:IREECodegenDialect", "//compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow", "//compiler/src/iree/compiler/Dialect/Flow/IR", "//compiler/src/iree/compiler/Dialect/HAL/IR", diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt index e90ef1bee6ae..3aaa01ac6efd 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt @@ -109,6 +109,7 @@ iree_cc_library( MLIRTransformDialectTransforms MLIRTransformUtils MLIRTransforms + iree::compiler::Codegen::Dialect::IREECodegenDialect iree::compiler::Dialect::Flow::Conversion::TensorToFlow iree::compiler::Dialect::Flow::IR iree::compiler::Dialect::HAL::IR diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FuseDequantizationMatmul.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FuseDequantizationMatmul.cpp index 51494accd4ae..9b8a5ec362db 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FuseDequantizationMatmul.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FuseDequantizationMatmul.cpp @@ -4,18 +4,24 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +#include "iree/compiler/Codegen/Dialect/IREECodegenAttrs.h" #include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h" #include "iree/compiler/Dialect/Flow/Transforms/Passes.h" #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Tensor/Utils/Utils.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#define DEBUG_TYPE "iree-flow-fuse-dequantization-matmul" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + namespace mlir { namespace iree_compiler { namespace IREE { @@ -34,7 +40,7 @@ static LogicalResult fuseDequantAndMatmul(RewriterBase &rewriter, Operation *dequant, Operation *matmul, std::optional fill) { - Flow::DispatchRegionOp regionOp = matmul->getParentOfType(); + DispatchRegionOp regionOp = matmul->getParentOfType(); if (!regionOp) { FailureOr maybeRegionOp = wrapOpInDispatchRegion(rewriter, matmul); @@ -58,6 +64,585 @@ static LogicalResult fuseDequantAndMatmul(RewriterBase &rewriter, return success(); } +static FailureOr +wrapConsecutiveOpsInDispatchRegion(RewriterBase &rewriter, + SmallVector ops) { + FailureOr maybeRegionOp = + wrapOpInDispatchRegion(rewriter, ops.back()); + if (failed(maybeRegionOp)) { + return failure(); + } + DispatchRegionOp regionOp = maybeRegionOp.value(); + + SmallVector precedingOps(ops.begin(), ops.end() - 1); + FailureOr maybeFusedRegionOp = + movePrecedingOpsIntoDispatchRegion(rewriter, precedingOps, regionOp); + if (failed(maybeFusedRegionOp)) { + return failure(); + } + regionOp = maybeFusedRegionOp.value(); + + return regionOp; +} + +static SmallVector +getParallelAndReductionIterators(unsigned nLoops, unsigned nReduction) { + SmallVector res(nLoops - nReduction, + utils::IteratorType::parallel); + res.append(nReduction, utils::IteratorType::reduction); + return res; +} + +// We set the tile sizes of the integer matmul result from the reassociated +// op sequence here to target a specific VectorContractCustomKernel and optimize +// performance on x86 with AVX512VNNI +static LogicalResult setTileSizes(linalg::GenericOp intMatmul, + linalg::GenericOp reassociation, + func::FuncOp entryPointFn) { + + SmallVector distTileSizes_mm = {128, 0, 0}; + SmallVector parallelTileSizes_mm = {4, 0, 0}; + SmallVector reductionTileSizes_mm = {4, 1, 16}; + SmallVector lastTileSizes_mm = {0, 0, 0}; + + TileSizesListType tileSizes_mm; + tileSizes_mm.push_back(distTileSizes_mm); + tileSizes_mm.push_back(parallelTileSizes_mm); + tileSizes_mm.push_back(reductionTileSizes_mm); + tileSizes_mm.push_back(lastTileSizes_mm); + + SmallVector distTileSizes_re = {128, 0}; + SmallVector parallelTileSizes_re = {4, 0}; + SmallVector reductionTileSizes_re = {0, 0}; + SmallVector lastTileSizes_re = {0, 0}; + + TileSizesListType tileSizes_re; + tileSizes_re.push_back(distTileSizes_re); + tileSizes_re.push_back(parallelTileSizes_re); + tileSizes_re.push_back(reductionTileSizes_re); + tileSizes_re.push_back(lastTileSizes_re); + + Codegen::DispatchLoweringPassPipeline passPipeline = + Codegen::DispatchLoweringPassPipeline::CPUDoubleTilingExpert; + + MLIRContext *context = entryPointFn.getContext(); + auto config_mm = Codegen::LoweringConfigAttr::get(context, tileSizes_mm); + intMatmul->setAttr("lowering_config", config_mm); + + auto config_re = Codegen::LoweringConfigAttr::get(context, tileSizes_re); + auto translationInfo_re = Codegen::TranslationInfoAttr::get( + entryPointFn.getContext(), passPipeline, 0, 1); + auto compilationInfo_re = + Codegen::CompilationInfoAttr::get(context, config_re, translationInfo_re, + ArrayRef({}), std::nullopt); + + reassociation->setAttr("compilation_info", compilationInfo_re); + + return success(); +} + +// Takes as input the dequantization `linalg.generic` op and the matmul +// `linalg.generic` op, and returns the scales, zero points, quantized +// input matrix, unquantizaed input matrix, and dequantized result +// matrix. +static std::optional> +getDequantMatmulInputs_f32(linalg::GenericOp dequant, + linalg::GenericOp matmul) { + OpOperand *scales, *zps, *quantMat, *unquantMat, *dequantMat; + for (int operandIdx = 0; operandIdx < dequant.getNumDpsInputs(); + operandIdx++) { + OpOperand *operand = dequant.getDpsInputOperand(operandIdx); + Value input = operand->get(); + RankedTensorType inputType = + llvm::dyn_cast(input.getType()); + if (!inputType) { + continue; + } + if (inputType.getElementTypeBitWidth() != 32) { + quantMat = operand; + continue; + } + for (Operation &bodyOp : dequant.getBlock()->getOperations()) { + if (isa(bodyOp)) { + if (bodyOp.getOperand(1) == + dequant.getBlock()->getArgument(operandIdx)) { + scales = operand; + break; + } + } else if (isa(bodyOp)) { + if (bodyOp.getOperand(1) == + dequant.getBlock()->getArgument(operandIdx)) { + zps = operand; + break; + } + } + } + } + Value dequantOut = dequant.getResult(0); + if (matmul.getDpsInputOperand(0)->get() == dequantOut) { + unquantMat = matmul.getDpsInputOperand(1); + dequantMat = matmul.getDpsInputOperand(0); + } else { + unquantMat = matmul.getDpsInputOperand(0); + dequantMat = matmul.getDpsInputOperand(1); + } + if (scales && zps && quantMat && unquantMat) { + return SmallVector( + {quantMat, unquantMat, scales, zps, dequantMat}); + } + return std::nullopt; +} + +// This function does the bulk of the rewrite for the dequantization + matmul. +// +// Starting with 2 `linalg.generic` ops (dequantization->matmul) +// %arg0 = quantized input +// %arg1 = scales +// %arg2 = zero points +// %arg3 = unquantized input +// ```mlir +// %0 = linalg.generic ins(%arg0, %arg1, %arg2 : tensor<8x4x2xi4>, +// tensor<8x4x1xf32>, tensor<8x4x1xf32>) outs(%1 : tensor<8x4x2xf32>) { +// ^bb0(%in: i4, %in_0: f32, %in_1: f32, %out: f32): +// %9 = arith.extui %in : i4 to i32 +// %10 = arith.uitofp %9 : i32 to f32 +// %11 = arith.subf %10, %in_1 : f32 +// %12 = arith.mulf %11, %in_0 : f32 +// linalg.yield %12 : f32 +// } -> tensor<8x4x2xf32> +// %2 = linalg.generic ins(%arg3, %0 : tensor<4x2xf32>, tensor<8x4x2xf32>) +// outs(%3 : tensor<8xf32>) { ^bb0(%in: f32, %in_0: f32, %out: f32): +// %9 = arith.mulf %in, %in_0 : f32 +// %10 = arith.addf %9, %out : f32 +// linalg.yield %10 : f32 +// } -> tensor<8xf32> +// ``` +// +// This function rewrites the above ops as the following new sequence of 6 ops +// that does the following: +// +// a) Dynamically quantize the unquantized input: +// 1. Compute the absolute max of the unquantized input (%arg3) within each +// group. +// 2. Compute scales for %arg3 by dividing the absolute max by (1 << +// newBitWidth) - 1), +// where newBitWidth is the bitwidth of the new quantized type, +// currently set to `i16`. +// 3. Compute the sum along groups of the unquantized input. This is not +// necessary for +// the quantization step, but it is needed to efficiently perform a +// reassociated quantized matmul in steps 5-6. +// 4. Quantize the unquantized input (%arg3) by dividing elements in each +// group by +// the corresponding scale. This quantization is symmetric with no zero +// point. +// b) Perform the reassociated quantized matmul, keeping the bulk of the +// computation in +// integer arithmetic: +// 5. The first op performs a matmul-like operation that reduces the +// innermost group +// dimension. Note the indexing maps in the following example: +// ```mlir +// %22 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> +// (d1, d2)>, +// affine_map<(d0, d1, d2) -> +// (d0, d1, d2)>, affine_map<(d0, +// d1, d2) -> (d0, d1)>], +// iterator_types = ["parallel", "parallel", +// "reduction"]} ins(%17, %0 : tensor<4x2xi16>, +// tensor<8x4x2xi4>) outs(%19 : tensor<8x4xi32>) { +// ^bb0(%in: i16, %in_4: i4, %out: i32): +// %24 = arith.extsi %in : i16 to i32 +// %25 = arith.extui %in_4 : i4 to i32 +// %26 = arith.muli %24, %25 : i32 +// %27 = arith.addi %26, %out : i32 +// linalg.yield %27 : i32 +// } -> tensor<8x4xi32> +// ``` +// This op also extends the inputs to the accumulation type, i32 in this +// case, to target specific x86 instructions. We perform the matrix +// multiplication before the dequantization arithmetic, which has been +// reassociated into op 6. +// 6. The final op performs the remaining reduction across groups and does +// the +// dequantization arithmetic: +// ```mlir +// %23 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, +// d1)>, +// affine_map<(d0, d1) -> (d1)>, +// affine_map<(d0, d1) -> (d1)>, +// affine_map<(d0, d1) -> (d0, +// d1)>, affine_map<(d0, d1) -> +// (d0, d1)>, affine_map<(d0, d1) +// -> (d0)>], +// iterator_types = ["parallel", "reduction"]} +// ins(tensor<8x4xi32>, tensor<4xf32>, +// tensor<4xf32>, tensor<8x4xf32>, +// tensor<8x4xf32>) outs(tensor<8xf32>) { +// ^bb0(%in: i32, %in_4: f32, %in_5: f32, %in_6: f32, %in_7: f32, +// %out: f32): +// %24 = arith.sitofp %in : i32 to f32 +// %25 = arith.mulf %24, %in_4 : f32 +// %26 = arith.mulf %25, %in_6 : f32 +// %27 = arith.mulf %in_7, %in_6 : f32 +// %28 = arith.mulf %27, %in_5 : f32 +// %29 = arith.subf %26, %28 : f32 +// %30 = arith.addf %29, %out : f32 +// linalg.yield %30 : f32 +// } -> tensor<8xf32> +// ``` +// +// The rewrite also forms a `flow.dispatch.region` op around ops 5 and 6, and +// sets a specific tiling on ops 5 and 6 to target a VectorContractCustomKernel. +// +// ** Note that this rewrite introduces precision loss in the matmul, and is a +// tradeoff between precision and performance. This rewrite should most +// likely be opt-in only. ** +static LogicalResult ReassociateAndFuseDequantMatmul( + RewriterBase &rewriter, linalg::GenericOp dequant, linalg::GenericOp matmul, + func::FuncOp entryPointFn) { + LDBG("Reassociating"); + LDBG("dequant: " << dequant); + LDBG("matmul: " << matmul); + std::optional> maybeInputs = + getDequantMatmulInputs_f32(dequant, matmul); + if (!maybeInputs) { + return failure(); + } + SmallVector ins = maybeInputs.value(); + OpOperand *quantInOperand = ins[0]; + OpOperand *unquantInOperand = ins[1]; + Value quantIn = quantInOperand->get(); + Value unquantIn = unquantInOperand->get(); + Value scales = ins[2]->get(); + Value zps = ins[3]->get(); + OpOperand *matmulDequantOperand = ins[4]; + RankedTensorType unquantInType = + llvm::dyn_cast(unquantIn.getType()); + if (!unquantInType) { + return failure(); + } + RankedTensorType quantInType = + llvm::dyn_cast(quantIn.getType()); + if (!quantInType) { + return failure(); + } + OpOperand *matmulOutputOperand = matmul.getDpsInitOperand(0); + Value matmulOutput = matmulOutputOperand->get(); + RankedTensorType matmulOutputType = + llvm::dyn_cast(matmulOutput.getType()); + SmallVector matmulOutShape(matmulOutputType.getShape()); + SmallVector unquantInShape(unquantInType.getShape()); + SmallVector quantInShape(quantInType.getShape()); + SmallVector dequantIndexingMaps = dequant.getIndexingMapsArray(); + SmallVector matmulIndexingMaps = matmul.getIndexingMapsArray(); + SmallVector dequantIteratorTypes = + dequant.getIteratorTypesArray(); + SmallVector matmulIteratorTypes = + matmul.getIteratorTypesArray(); + FloatType f32Type = rewriter.getF32Type(); + IntegerType i32Type = rewriter.getI32Type(); + // Type for accumulation of integer matmul + IntegerType accType = rewriter.getI32Type(); + // Type for dynamic quantization of unquantized input + IntegerType quantType = rewriter.getI16Type(); + Type srcQuantType = quantInType.getElementType(); + // Type for multiplication in integer matmul (should probably be same as + // `accType`) + IntegerType mulType = rewriter.getI32Type(); + unsigned quantBitRange = + std::min(quantType.getIntOrFloatBitWidth() - 1, + mulType.getIntOrFloatBitWidth() - + srcQuantType.getIntOrFloatBitWidth() - 1); + + // ----- Quantize unquantized input ----- // + Value cst = rewriter.create( + dequant.getLoc(), rewriter.getF32FloatAttr((1 << quantBitRange) - 1)); + LDBG("cst: " << cst); + Value zeroF32cst = rewriter.create( + dequant.getLoc(), rewriter.getF32FloatAttr(0)); + Value zeroI32cst = rewriter.create( + dequant.getLoc(), rewriter.getI32IntegerAttr(0)); + // Generic to find max along groups + SmallVector groupMaxShape; + SmallVector groupMaxIterators; + int64_t numGroups = 0; + + SmallVector exprs; + AffineMap indexingMap = matmul.getMatchingIndexingMap(unquantInOperand); + for (const auto &expr : enumerate(indexingMap.getResults())) { + if (auto dimExpr = expr.value().dyn_cast()) { + if (matmulIteratorTypes[dimExpr.getPosition()] == + utils::IteratorType::parallel || + dimExpr.getPosition() != indexingMap.getNumDims() - 1) { + groupMaxIterators.push_back(utils::IteratorType::parallel); + groupMaxShape.push_back(unquantInShape[expr.index()]); + exprs.push_back(rewriter.getAffineDimExpr(groupMaxShape.size() - 1)); + if (matmulIteratorTypes[dimExpr.getPosition()] == + utils::IteratorType::reduction) { + numGroups = unquantInShape[expr.index()]; + } + } else { + groupMaxIterators.push_back(utils::IteratorType::reduction); + } + } else { + return failure(); + } + } + if (!numGroups) { + return failure(); + } + RankedTensorType groupMaxType = + RankedTensorType::get(groupMaxShape, unquantInType.getElementType()); + Value groupMaxEmpty = rewriter.create( + dequant.getLoc(), groupMaxType.getShape(), groupMaxType.getElementType()); + Value groupMaxOut = + rewriter + .create(dequant.getLoc(), zeroF32cst, groupMaxEmpty) + .result(); + LDBG("groupMaxOut: " << groupMaxOut); + SmallVector groupMaxMaps; + groupMaxMaps.push_back( + rewriter.getMultiDimIdentityMap(unquantInShape.size())); + groupMaxMaps.push_back(AffineMap::get(unquantInShape.size(), 0, exprs, + exprs.front().getContext())); + auto groupMaxOp = rewriter.create( + dequant.getLoc(), groupMaxOut.getType(), unquantIn, groupMaxOut, + groupMaxMaps, groupMaxIterators, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value abs = b.create(loc, args[0]); + Value max = b.create(loc, abs, args[1]); + b.create(loc, max); + }); + LDBG("groupMaxOp: " << groupMaxOp); + Value groupMax = groupMaxOp.getResult(0); + + // Generic to find scales + RankedTensorType unquantInScalesType = groupMaxType; + Value unquantInScalesOut = rewriter.create( + dequant.getLoc(), unquantInScalesType.getShape(), + unquantInScalesType.getElementType()); + LDBG("unquantInScalesOut: " << unquantInScalesOut); + SmallVector unquantInScalesMaps; + unquantInScalesMaps.push_back( + rewriter.getMultiDimIdentityMap(unquantInShape.size() - 1)); + unquantInScalesMaps.push_back( + rewriter.getMultiDimIdentityMap(unquantInShape.size() - 1)); + + auto unquantInScalesOp = rewriter.create( + dequant.getLoc(), unquantInScalesOut.getType(), groupMax, + unquantInScalesOut, unquantInScalesMaps, + getParallelAndReductionIterators(unquantInScalesType.getRank(), 0), + [&](OpBuilder &b, Location loc, ValueRange args) { + Value scale = b.create(loc, args[0], cst); + b.create(loc, scale); + }); + LDBG("unquantInScalesOp: " << unquantInScalesOp); + Value unquantInScales = unquantInScalesOp.getResult(0); + + // Generic to find scaled sums + RankedTensorType scaledSumsType = groupMaxType; + Value scaledSumsEmpty = rewriter.create( + dequant.getLoc(), scaledSumsType.getShape(), + scaledSumsType.getElementType()); + Value scaledSumsOut = + rewriter + .create(dequant.getLoc(), zeroF32cst, scaledSumsEmpty) + .result(); + LDBG("scaledSumsOut: " << scaledSumsOut); + SmallVector scaledSumsMaps; + scaledSumsMaps.push_back( + rewriter.getMultiDimIdentityMap(unquantInShape.size())); + scaledSumsMaps.push_back(AffineMap::get(unquantInShape.size(), 0, exprs, + exprs.front().getContext())); + SmallVector scaledSumsIterators = groupMaxIterators; + auto scaledSumsOp = rewriter.create( + dequant.getLoc(), scaledSumsOut.getType(), unquantIn, scaledSumsOut, + scaledSumsMaps, scaledSumsIterators, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value sum = b.create(loc, args[0], args[1]); + b.create(loc, sum); + }); + LDBG("scaledSumsOp: " << scaledSumsOp); + Value scaledSums = scaledSumsOp.getResult(0); + + // Generic to quantized the unquantized input + Value newQuantInOut = rewriter.create( + dequant.getLoc(), unquantInShape, quantType); + LDBG("newQuantInOut: " << newQuantInOut); + SmallVector newQuantInMaps; + newQuantInMaps.push_back( + rewriter.getMultiDimIdentityMap(unquantInShape.size())); + newQuantInMaps.push_back(AffineMap::get(unquantInShape.size(), 0, exprs, + exprs.front().getContext())); + newQuantInMaps.push_back( + rewriter.getMultiDimIdentityMap(unquantInShape.size())); + auto newQuantInOp = rewriter.create( + dequant.getLoc(), newQuantInOut.getType(), + ValueRange{unquantIn, unquantInScales}, newQuantInOut, newQuantInMaps, + getParallelAndReductionIterators(unquantInShape.size(), 0), + [&](OpBuilder &b, Location loc, ValueRange args) { + Value scaled = b.create(loc, args[0], args[1]); + Value quant = b.create(loc, quantType, scaled); + b.create(loc, quant); + }); + LDBG("newQuantInOp: " << newQuantInOp); + Value newQuantIn = newQuantInOp.getResult(0); + + // ----- Reassociated dequantization matmul ----- // + + // Generic to perform integer matmul and reduce within groups + SmallVector integerMatmulShape = matmulOutShape; + integerMatmulShape.push_back(numGroups); + Value integerMatmulEmpty = rewriter.create( + dequant.getLoc(), integerMatmulShape, accType); + Value integerMatmulOut = + rewriter + .create(dequant.getLoc(), zeroI32cst, + integerMatmulEmpty) + .result(); + LDBG("integerMatmulOut: " << integerMatmulOut); + SmallVector integerMatmulIterators = + getParallelAndReductionIterators(matmul.getNumLoops(), 1); + SmallVector integerMatmulMaps; + integerMatmulMaps.push_back(matmul.getMatchingIndexingMap(unquantInOperand)); + integerMatmulMaps.push_back( + matmul.getMatchingIndexingMap(matmulDequantOperand)); + SmallVector outputExprs( + matmul.getMatchingIndexingMap(matmulOutputOperand).getResults()); + outputExprs.push_back(rewriter.getAffineDimExpr(matmul.getNumLoops() - 2)); + integerMatmulMaps.push_back(AffineMap::get(integerMatmulIterators.size(), 0, + outputExprs, + outputExprs.front().getContext())); + auto integerMatmulOp = rewriter.create( + dequant.getLoc(), integerMatmulOut.getType(), + ValueRange{newQuantIn, quantIn}, integerMatmulOut, integerMatmulMaps, + integerMatmulIterators, [&](OpBuilder &b, Location loc, ValueRange args) { + Value mul; + if (quantType == mulType) { + Value ext1 = b.create(loc, mulType, args[1]); + mul = b.create(loc, args[0], ext1); + } else { + Value ext0 = b.create(loc, mulType, args[0]); + Value ext1 = b.create(loc, mulType, args[1]); + mul = b.create(loc, ext0, ext1); + } + Value sum; + if (mulType == accType) { + sum = b.create(loc, mul, args[2]); + } else { + Value extMul = b.create(loc, accType, mul); + sum = b.create(loc, extMul, args[2]); + } + b.create(loc, sum); + }); + LDBG("integerMatmulOp: " << integerMatmulOp); + Value integerMatmul = integerMatmulOp.getResult(0); + + // Generic to perform dequantization and finish reduction + SmallVector dequantizedMatmulIterators = + getParallelAndReductionIterators(matmul.getNumLoops() - 1, 1); + SmallVector dequantizedMatmulMaps; + dequantizedMatmulMaps.push_back( + rewriter.getMultiDimIdentityMap(dequantizedMatmulIterators.size())); + AffineMap intMatmulNewQuantMap = + matmul.getMatchingIndexingMap(unquantInOperand); + SmallVector newQuantScalesExprs; + for (const auto &expr : enumerate(intMatmulNewQuantMap.getResults())) { + if (auto dimExpr = expr.value().dyn_cast()) { + if (dimExpr.getPosition() != intMatmulNewQuantMap.getNumDims() - 1) { + newQuantScalesExprs.push_back( + rewriter.getAffineDimExpr(dimExpr.getPosition())); + } + } else { + return failure(); + } + } + dequantizedMatmulMaps.push_back( + AffineMap::get(dequantizedMatmulIterators.size(), 0, newQuantScalesExprs, + newQuantScalesExprs.front().getContext())); + dequantizedMatmulMaps.push_back( + AffineMap::get(dequantizedMatmulIterators.size(), 0, newQuantScalesExprs, + newQuantScalesExprs.front().getContext())); + AffineMap matmulQuantInMap = + matmul.getMatchingIndexingMap(matmulDequantOperand); + SmallVector quantScalesExprs; + for (const auto &expr : enumerate(matmulQuantInMap.getResults())) { + if (auto dimExpr = expr.value().dyn_cast()) { + if (dimExpr.getPosition() != matmulQuantInMap.getNumDims() - 1) { + quantScalesExprs.push_back( + rewriter.getAffineDimExpr(dimExpr.getPosition())); + } + } else { + return failure(); + } + } + RankedTensorType scalesType = + llvm::dyn_cast(scales.getType()); + if (!scalesType) { + return failure(); + } + if (quantScalesExprs.size() < scalesType.getShape().size()) { + quantScalesExprs.push_back(rewriter.getAffineConstantExpr(0)); + if (quantScalesExprs.size() < scalesType.getShape().size()) { + return failure(); + } + } + dequantizedMatmulMaps.push_back( + AffineMap::get(dequantizedMatmulIterators.size(), 0, quantScalesExprs, + quantScalesExprs.front().getContext())); + dequantizedMatmulMaps.push_back( + AffineMap::get(dequantizedMatmulIterators.size(), 0, quantScalesExprs, + quantScalesExprs.front().getContext())); + SmallVector finalOutputExprs( + matmul.getMatchingIndexingMap(matmulOutputOperand).getResults()); + dequantizedMatmulMaps.push_back( + AffineMap::get(dequantizedMatmulIterators.size(), 0, finalOutputExprs, + finalOutputExprs.front().getContext())); + + auto dequantizedMatmulOp = rewriter.create( + dequant.getLoc(), matmulOutput.getType(), + ValueRange{integerMatmul, unquantInScales, scaledSums, scales, zps}, + matmulOutput, dequantizedMatmulMaps, dequantizedMatmulIterators, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value dq; + if (accType == i32Type) { + dq = b.create(loc, f32Type, args[0]); + } else { + Value ext = b.create(loc, i32Type, args[0]); + dq = b.create(loc, f32Type, ext); + } + Value scaledRes0 = b.create(loc, dq, args[1]); + Value scaledRes1 = b.create(loc, scaledRes0, args[3]); + Value scaledZp0 = b.create(loc, args[4], args[3]); + Value scaledZp1 = b.create(loc, scaledZp0, args[2]); + Value groupRes = b.create(loc, scaledRes1, scaledZp1); + Value sum = b.create(loc, groupRes, args[5]); + b.create(loc, sum); + }); + LDBG("dequantizedMatmulOp: " << dequantizedMatmulOp); + Value dequantizedMatmul = dequantizedMatmulOp.getResult(0); + + rewriter.replaceOp(matmul, dequantizedMatmul); + + // Set tile sizes for dequantization + matmul ops + if (failed( + setTileSizes(integerMatmulOp, dequantizedMatmulOp, entryPointFn))) { + return failure(); + } + + // Fuse dequantization + matmul ops into a single dispatch region + SmallVector dequantMatmulOps( + {integerMatmulOp, dequantizedMatmulOp}); + FailureOr maybeDequantMatmulDispatch = + wrapConsecutiveOpsInDispatchRegion(rewriter, dequantMatmulOps); + if (failed(maybeDequantMatmulDispatch)) { + return failure(); + } + + return success(); +} + // Checks if the passed op is a contraction on grouped input // This function checks that the genericOp: // 1. isaContractionOpInterface @@ -172,6 +757,8 @@ static LogicalResult isGroupedDequantizationOp(linalg::GenericOp genericOp) { // Patterns //----------------------------------------------------------------------------// +// This pattern does a basic fusion of dequantization + matmul `linalg.generic` +// ops, moving them into a single `flow.dispatch.region` op. class FuseDequantizationMatmulPattern final : public OpRewritePattern { public: @@ -222,27 +809,84 @@ struct FuseDequantizationMatmulPass : public FuseDequantizationMatmulBase { void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry + .insert(); + } + FuseDequantizationMatmulPass(bool enableQuantizedMatmulReassociation) { + this->enableQuantizedMatmulReassociation = + enableQuantizedMatmulReassociation; } + FuseDequantizationMatmulPass(const FuseDequantizationMatmulPass &pass) + : FuseDequantizationMatmulPass(pass.enableQuantizedMatmulReassociation) {} - void runOnOperation() override { - MLIRContext *context = &getContext(); - // Main pattern. - { - RewritePatternSet patterns(&getContext()); - patterns.insert(context); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + void runOnOperation() override; +}; + +} // namespace + +void FuseDequantizationMatmulPass::runOnOperation() { + MLIRContext *context = &getContext(); + auto funcOp = getOperation(); + + // Perform reassociation if enabled + if (this->enableQuantizedMatmulReassociation) { + SmallVector> candidates; + for (auto genericOp : + funcOp.getFunctionBody().getOps()) { + if (failed(isGroupedContractionOp(genericOp))) + continue; + + OpOperand *lhs = genericOp.getDpsInputOperand(0); + OpOperand *rhs = genericOp.getDpsInputOperand(1); + auto lhsOp = lhs->get().getDefiningOp(); + auto rhsOp = rhs->get().getDefiningOp(); + if (!llvm::cast(genericOp.getInputs()[0].getType()) + .hasStaticShape() || + !llvm::cast(genericOp.getInputs()[1].getType()) + .hasStaticShape() || + !llvm::cast(genericOp.getResults()[0].getType()) + .hasStaticShape()) { + // Codegen can't handle the dynamic case yet. + continue; + } + if (lhsOp) { + if (!failed(isGroupedDequantizationOp(lhsOp))) { + candidates.push_back(std::make_pair(lhsOp, genericOp)); + continue; + } + } + if (rhsOp) { + if (!failed(isGroupedDequantizationOp(rhsOp))) { + candidates.push_back(std::make_pair(rhsOp, genericOp)); + } + } + } + IRRewriter rewriter(context); + for (auto candidate : candidates) { + rewriter.setInsertionPointAfter(candidate.second); + if (failed(ReassociateAndFuseDequantMatmul( + rewriter, candidate.first, candidate.second, + llvm::cast(funcOp)))) { return signalPassFailure(); } } } -}; -} // namespace + // Normal fusion pattern. + { + RewritePatternSet patterns(context); + patterns.insert(context); + if (failed(applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) { + return signalPassFailure(); + } + } +} -std::unique_ptr createFuseDequantizationMatmulPass() { - return std::make_unique(); +std::unique_ptr> +createFuseDequantizationMatmulPass(bool enableQuantizedMatmulReassociation) { + return std::make_unique( + enableQuantizedMatmulReassociation); } } // namespace Flow diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h index 78dd11a0dd45..b829519dc2ae 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h @@ -156,7 +156,9 @@ std::unique_ptr> createCloneProducersIntoDispatchRegionsPass(); // A pass to fuse dequantization and matmul linalg.generic ops -std::unique_ptr createFuseDequantizationMatmulPass(); +std::unique_ptr> +createFuseDequantizationMatmulPass( + bool enableQuantizedMatmulReassociation = false); //===----------------------------------------------------------------------===// // Dispatches (flow.dispatch.workgroups) diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fuse_dequantization_matmul.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fuse_dequantization_matmul.mlir index 2537efd5b1a9..567cf1652b29 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fuse_dequantization_matmul.mlir +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fuse_dequantization_matmul.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt --split-input-file --iree-flow-fuse-dequantization-matmul --canonicalize %s | FileCheck %s +// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-flow-fuse-dequantization-matmul,canonicalize))" %s | FileCheck %s module { func.func @grouped_quantized_matmul(%arg0: tensor<4096x32x128xi8>, %arg1: tensor<1x1x32x128xf32>, %arg2: tensor<4096x32x1xf32>, %arg3: tensor<4096x32x1xf32>) -> tensor<1x1x4096xf32> { diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp index d2de7959b6cc..47cc09e52710 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp @@ -19,6 +19,12 @@ namespace GlobalOptimization { using FunctionLikeNest = MultiOpNest; +static llvm::cl::opt clEnableQuantizedMatmulReassociation( + "iree-flow-enable-quantized-matmul-reassociation", + llvm::cl::desc( + "Enables reassociation of quantized matmul ops (experimental)."), + llvm::cl::init(false)); + void buildGlobalOptimizationPassPipeline( OpPassManager &mainPassManager, const TransformOptions &transformOptions) { // ML frontends have very uneven support for user-controlled types _and_ users @@ -67,7 +73,12 @@ void buildGlobalOptimizationPassPipeline( // this pass both before unit dim folding + consteval, as well as after. .addPass(IREE::Flow::createRaiseSpecialOps) .addPass(IREE::Flow::createFoldUnitExtentDimsPass) - .addPass(IREE::Flow::createFuseDequantizationMatmulPass) + .addPass([&]() { + return IREE::Flow::createFuseDequantizationMatmulPass( + clEnableQuantizedMatmulReassociation); + }) + .addPass(mlir::createCanonicalizerPass) + .addPass(mlir::createCSEPass) // Enable data tiling after they are in a canonical form. .addPredicatedPass(transformOptions.options.dataTiling, createSetEncodingPass)