diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
index b83ac2881433..7f062548522f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -1683,6 +1683,180 @@ static LogicalResult setElementwiseGenericOpRootConfig(
                                                tileSizes, passPipeline);
 }
 
+/// Checks if the passed op is a dequantization on grouped input fused with a
+/// matvec. This function checks that the genericOp:
+///   1. Has 4 inputs (rhs, weights, scales, zero points) and 1 output.
+///   2. Has 4 loops, exactly 2 of which are reduction loops.
+///   3. Yields a value produced by the following chain of ops:
+///        arith.extui
+///        arith.uitofp
+///        arith.subf
+///        arith.mulf
+///        arith.mulf
+///        arith.addf
+///   4. Increases the bit width from the element type of the second input
+///      to the element type of the output.
+static bool isGroupedDequantizationMatvecOp(linalg::GenericOp genericOp) {
+  // Check for exactly 1 result and 4 inputs
+  // (rhs, weights, scales, zero points).
+  if (genericOp.getNumDpsInits() != 1) {
+    LLVM_DEBUG(KD_DBGS() << "Wrong number of outputs: "
+                         << genericOp.getNumDpsInits() << "\n");
+    return false;
+  }
+  if (genericOp.getNumDpsInputs() != 4) {
+    LLVM_DEBUG(KD_DBGS() << "Wrong number of inputs: "
+                         << genericOp.getNumDpsInputs() << "\n");
+    return false;
+  }
+
+  // Check for exactly 4 loops, 2 of which are reduction loops.
+  unsigned numLoops = genericOp.getNumLoops();
+  unsigned numReductionLoops = genericOp.getNumReductionLoops();
+  if (numLoops != 4) {
+    LLVM_DEBUG(KD_DBGS() << "Wrong number of loops: " << numLoops << "\n");
+    return false;
+  }
+  if (numReductionLoops != 2) {
+    LLVM_DEBUG(KD_DBGS() << "Wrong number of reduction loops: "
+                         << numReductionLoops << "\n");
+    return false;
+  }
+  // Work back from linalg.yield and check body of genericOp.
+  // The genericOp should yield the result of an arith.addf, preceded by
+  // arith.mulf, arith.mulf, arith.subf, arith.uitofp, and arith.extui.
+  auto yieldOp = cast<linalg::YieldOp>(genericOp.getBody()->getTerminator());
+  Value producerOutput;
+  Operation *producer;
+
+  // Producer of linalg.yield op is arith.addf
+  {
+    producerOutput = yieldOp->getOperand(0);
+    producer = producerOutput.getDefiningOp();
+    if (!producer || producer->getNumOperands() == 0)
+      return false;
+    if (!matchPattern(producer, m_Op<arith::AddFOp>()))
+      return false;
+  }
+
+  // Producer of the first arith.addf operand is arith.mulf
+  {
+    producerOutput = producer->getOperand(0);
+    producer = producerOutput.getDefiningOp();
+    if (!producer || producer->getNumOperands() == 0)
+      return false;
+    if (!matchPattern(producer, m_Op<arith::MulFOp>()))
+      return false;
+  }
+
+  // Producer of the second arith.mulf operand is another arith.mulf
+  {
+    producerOutput = producer->getOperand(1);
+    producer = producerOutput.getDefiningOp();
+    if (!producer || producer->getNumOperands() == 0)
+      return false;
+    if (!matchPattern(producer, m_Op<arith::MulFOp>()))
+      return false;
+  }
+
+  // Producer of the first arith.mulf operand is arith.subf
+  {
+    producerOutput = producer->getOperand(0);
+    producer = producerOutput.getDefiningOp();
+    if (!producer || producer->getNumOperands() == 0)
+      return false;
+    if (!matchPattern(producer, m_Op<arith::SubFOp>()))
+      return false;
+  }
+
+  // Producer of the first arith.subf operand is arith.uitofp
+  {
+    producerOutput = producer->getOperand(0);
+    producer = producerOutput.getDefiningOp();
+    if (!producer || producer->getNumOperands() == 0)
+      return false;
+    if (!matchPattern(producer, m_Op<arith::UIToFPOp>()))
+      return false;
+  }
+
+  // Producer of arith.uitofp op is arith.extui
+  {
+    producerOutput = producer->getOperand(0);
+    producer = producerOutput.getDefiningOp();
+    if (!producer)
+      return false;
+    if (!matchPattern(producer, m_Op<arith::ExtUIOp>()))
+      return false;
+  }
+
+  // Ensure that the dequantization increases the
+  // bitwidth from the input to the output
+  auto elementTypeOut =
+      llvm::cast<ShapedType>(genericOp.getOutputs()[0].getType())
+          .getElementType();
+  if (!elementTypeOut.isIntOrFloat())
+    return false;
+  unsigned bitWidthOut = elementTypeOut.getIntOrFloatBitWidth();
+  auto elementTypeIn =
+      llvm::cast<ShapedType>(genericOp.getInputs()[1].getType())
+          .getElementType();
+  if (!elementTypeIn.isIntOrFloat())
+    return false;
+  unsigned bitWidthIn = elementTypeIn.getIntOrFloatBitWidth();
+  if (bitWidthIn >= bitWidthOut)
+    return false;
+
+  return true;
+}
+
+/// Sets the root lowering configuration for linalg.generic ops matched by
+/// `isGroupedDequantizationMatvecOp` (dequantization fused with a matvec).
+/// Returns failure, without setting a config, when the op does not match.
+static LogicalResult setDequantizationMatvecOpRootConfig(
+    func::FuncOp entryPointFn, linalg::GenericOp genericOp,
+    const LinalgOpInfo &linalgOpInfo,
+    const TargetMLTransformInfo &targetMLTransInfo) {
+  assert(!getLoweringConfig(genericOp) &&
+         "expected lowering_config is not set");
+  unsigned numLoops = genericOp.getNumLoops();
+  if (!isGroupedDequantizationMatvecOp(genericOp)) {
+    LLVM_DEBUG(KD_DBGS() << "Failed matching for dequantized matvec\n");
+    return failure();
+  }
+
+  // The tile sizes below assume loops 0-1 are parallel and loops 2-3 are
+  // reductions; the matcher only checks loop counts, so verify positions.
+  SmallVector<unsigned> reductionDims;
+  genericOp.getReductionDims(reductionDims);
+  if (reductionDims.size() != 2 || reductionDims[0] != 2 ||
+      reductionDims[1] != 3) {
+    LLVM_DEBUG(KD_DBGS() << "Reduction loops are not the last two loops\n");
+    return failure();
+  }
+
+  SmallVector<int64_t> distTileSizes = {32, 32, 0, 0};
+  SmallVector<int64_t> parallelTileSizes = {1, 1, 0, 0};
+  SmallVector<int64_t> reductionTileSizes = {0, 0, 1, 64};
+
+  TileSizesListType tileSizes;
+  tileSizes.push_back(distTileSizes);
+  tileSizes.push_back(parallelTileSizes);
+  tileSizes.push_back(reductionTileSizes);
+  tileSizes.emplace_back(numLoops, 0);
+
+  LLVM_DEBUG(KD_DBGS() << "Setting dequantized matmul config\n");
+  LLVM_DEBUG(KD_DBGS() << "Distribution tile sizes: " << distTileSizes << "\n");
+  LLVM_DEBUG(KD_DBGS() << "Parallel tile sizes: " << parallelTileSizes << "\n");
+  LLVM_DEBUG(KD_DBGS() << "Reduction tile size: " << reductionTileSizes
+                       << "\n");
+
+  DispatchLoweringPassPipeline passPipeline =
+      DispatchLoweringPassPipeline::CPUDoubleTilingExpert;
+
+  return setOpConfigAndEntryPointFnTranslation(entryPointFn, genericOp,
+                                               tileSizes, passPipeline);
+}
+
 /// Sets the lowering configuration for a generic op to use
 /// CPUDoubleTilingExpert pipeline.
 static LogicalResult
@@ -1705,6 +1879,10 @@ setRootConfig(func::FuncOp entryPointFn, linalg::GenericOp genericOp,
           entryPointFn, genericOp, linalgOpInfo, targetMLTransInfo))) {
     return success();
   }
+  if (succeeded(setDequantizationMatvecOpRootConfig(
+          entryPointFn, genericOp, linalgOpInfo, targetMLTransInfo))) {
+    return success();
+  }
   if (succeeded(setDefaultGenericOpRootConfig(
           entryPointFn, genericOp, linalgOpInfo, targetMLTransInfo))) {
     return success();