From 3eb555e69701b2afe4b853f82f7bc8ccf81ce522 Mon Sep 17 00:00:00 2001
From: jinchen62 <49575973+jinchen62@users.noreply.github.com>
Date: Tue, 18 Jul 2023 12:42:18 -0700
Subject: [PATCH] Support brevitas custom op (#2320)

---
 .../TorchConversion/Transforms/Passes.h       |   3 +
 .../TorchConversion/Transforms/Passes.td      |   5 +
 lib/Conversion/TorchToLinalg/Linear.cpp       | 162 ++++++++++++++++++
 lib/Dialect/Torch/IR/TorchTypes.cpp           |   4 +-
 .../TorchConversion/Transforms/CMakeLists.txt |   1 +
 .../Transforms/UnpackTensor.cpp               | 141 +++++++++++++++
 6 files changed, 314 insertions(+), 2 deletions(-)
 create mode 100644 lib/Dialect/TorchConversion/Transforms/UnpackTensor.cpp

diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h
index e6493a154edd..c90a55ed537e 100644
--- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h
+++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h
@@ -57,6 +57,9 @@ std::unique_ptr<OperationPass<func::FuncOp>>
 createFuncBackendTypeConversionPass();
 std::unique_ptr<OperationPass<func::FuncOp>>
 createFinalizingBackendTypeConversionPass();
 
+std::unique_ptr<OperationPass<func::FuncOp>>
+createUnpackTorchTensorPass();
+
 std::unique_ptr<OperationPass<ModuleOp>>
 createVerifyLinalgOnTensorsBackendContractPass();
diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td
index cb58dbbd998b..47047d669021 100644
--- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td
+++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td
@@ -32,6 +32,11 @@ def FinalizingBackendTypeConversion
   }];
 }
 
+def UnpackTorchTensor : Pass<"torch-unpack-torch-tensor", "func::FuncOp"> {
+  let summary = "Unpack Int4 Torch Tensor";
+  let constructor = "mlir::torch::TorchConversion::createUnpackTorchTensorPass()";
+}
+
 def VerifyLinalgOnTensorsBackendContract : Pass<"torch-verify-linalg-on-tensors-backend-contract", "ModuleOp"> {
   let summary = "Verifies conformity to the linalg-on-tensors backend contract";
   let constructor = "mlir::torch::TorchConversion::createVerifyLinalgOnTensorsBackendContractPass()";
diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp
index 65f08a4d71ca..ff69dda21cb2 100644
--- a/lib/Conversion/TorchToLinalg/Linear.cpp
+++ b/lib/Conversion/TorchToLinalg/Linear.cpp
@@ -427,6 +427,166 @@ class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {
 };
 } // namespace
 
+namespace {
+class ConvertCustomQuantizedMatmulOp : public OpConversionPattern<OperatorOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(OperatorOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (op.getName().str() != "brevitas.matmul_rhs_group_quant") {
+      return failure();
+    }
+    Location loc = op->getLoc();
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter))) {
+      return failure();
+    }
+
+    // get inputs: lhs, q_rhs, scales, zps
+    Value lhs = adaptor.getOperands()[0];
+    auto lhsType = lhs.getType().cast<RankedTensorType>();
+    if (!lhsType) {
+      return failure();
+    }
+    auto lhsShape = lhsType.getShape();
+    int lhs_reduct_dim_size = lhsShape.back();
+
+    Value q_rhs = adaptor.getOperands()[1];
+    auto rhsType = q_rhs.getType().cast<RankedTensorType>();
+    if (!rhsType) {
+      return failure();
+    }
+    auto rhsShape = rhsType.getShape();
+    int rhs_reduct_dim_size = rhsShape.back();
+    Type rhs_elementType = rhsType.getElementType();
+
+    Value scales = adaptor.getOperands()[2];
+    Value zps = adaptor.getOperands()[3];
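+    // Operands 4 and 5 carry the unpacked weight bit width and the
+    // quantization group size; both are expected to come from constant ops
+    // and are extracted as plain integers below.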
+    Value unpacked_type_width = adaptor.getOperands()[4];
+    Value group_size = adaptor.getOperands()[5];
+
+    auto getConstantIntegerFromDefiningOp = [](Value operand,
+                                               int &extractedInt) {
+      auto castOp =
+          dyn_cast<mlir::UnrealizedConversionCastOp>(operand.getDefiningOp());
+      if (!castOp) {
+        return failure();
+      }
+      auto constOp =
+          dyn_cast<Torch::ConstantIntOp>(castOp.getOperand(0).getDefiningOp());
+      if (!constOp) {
+        return failure();
+      }
+      extractedInt = constOp.getValue();
+      return success();
+    };
+
+    int gs;
+    if (failed(getConstantIntegerFromDefiningOp(group_size, gs))) {
+      return failure();
+    }
+    int unpackedBitWidth;
+    if (failed(getConstantIntegerFromDefiningOp(unpacked_type_width,
+                                                unpackedBitWidth))) {
+      return failure();
+    }
+    if (unpackedBitWidth != rhs_elementType.getIntOrFloatBitWidth()) {
+      return failure();
+    }
+
+    // get outputs
+    Type newResultType = getTypeConverter()->convertType(op.getType(0));
+    auto resultType = newResultType.cast<RankedTensorType>();
+    if (!resultType) {
+      return failure();
+    }
+    auto resultShape = resultType.getShape();
+    Type elementType = resultType.getElementType();
+
+    // expand lhs
+    std::vector<int64_t> lhs_expandedShape = {lhsShape[0], lhsShape[1],
+                                              lhs_reduct_dim_size / gs, gs};
+    RankedTensorType lhs_expandedType =
+        RankedTensorType::get(lhs_expandedShape, elementType);
+    SmallVector<ReassociationIndices, 4> lhs_reassociation = {{0}, {1}, {2, 3}};
+    Value expanded_lhs = rewriter.create<tensor::ExpandShapeOp>(
+        loc, lhs_expandedType, lhs, lhs_reassociation);
+
+    // expand rhs
+    std::vector<int64_t> expandedShape = {rhsShape[0],
+                                          rhs_reduct_dim_size / gs, gs};
+    RankedTensorType expandedType =
+        RankedTensorType::get(expandedShape, rhs_elementType);
+    SmallVector<ReassociationIndices, 4> reassociation = {{0}, {1, 2}};
+    Value expanded_rhs = rewriter.create<tensor::ExpandShapeOp>(
+        loc, expandedType, q_rhs, reassociation);
+    Value cst_0 = rewriter.create<arith::ConstantOp>(
+        loc, FloatAttr::get(elementType, 0.0));
+
+    Value dq_empty = rewriter.create<tensor::EmptyOp>(
+        loc, expandedShape, elementType);
+    SmallVector<Value> dynDims;
+    for (int i = 0; i < lhsType.getRank(); i++) {
+      if (lhsType.isDynamicDim(i)) {
+        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, lhs, i));
+      }
+    }
+    Value empty = rewriter.create<tensor::EmptyOp>(
+        loc, resultShape, elementType, dynDims);
+    Value output = rewriter.create<linalg::FillOp>(
+        loc, cst_0, empty).getResult(0);
+
+    AffineExpr d0, d1, d2, d3, d4;
+    bindDims(getContext(), d0, d1, d2, d3, d4);
+    auto c0 = rewriter.getAffineConstantExpr(0);
+    auto map = AffineMap::get(3, 0, {d0, d1, d2}, rewriter.getContext());
+    auto map1 = AffineMap::get(3, 0, {d0, d1, c0}, rewriter.getContext());
+    auto map2 = AffineMap::get(5, 0, {d0, d1, d3, d4}, rewriter.getContext());
+    auto map3 = AffineMap::get(5, 0, {d2, d3, d4}, rewriter.getContext());
+    auto map4 = AffineMap::get(5, 0, {d0, d1, d2}, rewriter.getContext());
+    SmallVector<AffineMap> dq_indexingMaps = {map, map1, map1, map};
+    SmallVector<AffineMap> mat_indexingMaps = {map2, map3, map4};
+
+    SmallVector<utils::IteratorType> dq_iteratorTypes(
+        3, utils::IteratorType::parallel);
+    SmallVector<utils::IteratorType> mat_iteratorTypes = {
+        utils::IteratorType::parallel, utils::IteratorType::parallel,
+        utils::IteratorType::parallel, utils::IteratorType::reduction,
+        utils::IteratorType::reduction
+    };
+
+    Value dq_rhs =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, dq_empty.getType(),
+                ValueRange{expanded_rhs, scales, zps}, dq_empty,
+                /*indexingMaps=*/dq_indexingMaps,
+                /*iteratorTypes=*/dq_iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value w = args[0], scale = args[1], zeroPoint = args[2];
+                  Value extw =
+                      b.create<arith::ExtUIOp>(loc, rewriter.getI32Type(), w);
+                  Value fp_extw = b.create<arith::UIToFPOp>(
+                      loc, rewriter.getF32Type(), extw);
+                  Value shifted =
+                      b.create<arith::SubFOp>(loc, fp_extw, zeroPoint);
+                  Value dqw = b.create<arith::MulFOp>(loc, shifted, scale);
+                  b.create<linalg::YieldOp>(loc, dqw);
+                })
+            .getResult(0);
+
+    Value quantMat =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, output.getType(),
+                ValueRange{expanded_lhs, dq_rhs}, output,
+                /*indexingMaps=*/mat_indexingMaps,
+                /*iteratorTypes=*/mat_iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value l = args[0], r = args[1], out = args[2];
+                  Value pd = b.create<arith::MulFOp>(loc, l, r);
+                  Value ac = b.create<arith::AddFOp>(loc, pd, out);
+                  b.create<linalg::YieldOp>(loc, ac);
+                })
+            .getResult(0);
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, quantMat);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class ConvertAtenBmmOp : public OpConversionPattern<AtenBmmOp> {
 public:
@@ -860,6 +1020,8 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
   patterns.add<ConvertAtenFlipOp>(typeConverter, context);
   target.addIllegalOp<AtenMatmulOp>();
   patterns.add<ConvertAtenMatmulOp>(typeConverter, context);
+  target.addIllegalOp<OperatorOp>();
+  patterns.add<ConvertCustomQuantizedMatmulOp>(typeConverter, context);
   target.addIllegalOp<AtenBmmOp>();
   patterns.add<ConvertAtenBmmOp>(typeConverter, context);
   target.addIllegalOp<AtenConvolutionOp>();
diff --git a/lib/Dialect/Torch/IR/TorchTypes.cpp b/lib/Dialect/Torch/IR/TorchTypes.cpp
index 8eb844cbd00b..1c8d3c6f722a 100644
--- a/lib/Dialect/Torch/IR/TorchTypes.cpp
+++ b/lib/Dialect/Torch/IR/TorchTypes.cpp
@@ -194,13 +194,13 @@ static bool isValidTorchDtype(Type dtype) {
     if (type.isSignless() && type.getWidth() == 1)
       return true;
     if (type.isSigned()) {
-      for (unsigned width : {8, 16, 32, 64}) {
+      for (unsigned width : {4, 8, 16, 32, 64}) {
         if (type.getWidth() == width)
           return true;
       }
     }
     if (type.isUnsigned()) {
-      return type.getWidth() == 8;
+      return type.getWidth() == 8 || type.getWidth() == 4;
     }
   }
   return false;
diff --git a/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt b/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt
index 1f7f4e8f8294..bd685bc038d5 100644
--- a/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt
+++ b/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt
@@ -25,6 +25,7 @@ add_mlir_library(TorchMLIRTorchConversionPasses
   BackendTypeConversion.cpp
   BackendTypeConversionPasses.cpp
   Passes.cpp
+  UnpackTensor.cpp
   VerifyLinalgOnTensorsBackendContract.cpp
   VerifyTosaBackendContract.cpp
   VerifyStablehloBackendContract.cpp
diff --git a/lib/Dialect/TorchConversion/Transforms/UnpackTensor.cpp b/lib/Dialect/TorchConversion/Transforms/UnpackTensor.cpp
new file mode 100644
index 000000000000..5c023641c8f6
--- /dev/null
+++ b/lib/Dialect/TorchConversion/Transforms/UnpackTensor.cpp
@@ -0,0 +1,141 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
+#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
+
+using namespace mlir;
+using namespace mlir::torch;
+using namespace mlir::torch::Torch;
+
+namespace {
+class UnpackQuantizedMatmulWeights
+    : public OpRewritePattern<ValueTensorLiteralOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(ValueTensorLiteralOp constOp,
+                                PatternRewriter &rewriter) const override {
+    if (!constOp->hasOneUse())
+      return failure();
+
+    OpOperand *use = constOp.getResult().use_begin().getOperand();
+    auto op = dyn_cast<OperatorOp>(use->getOwner());
+    if (!op)
+      return failure();
+
+    if (use->getOperandNumber() != 1)
+      return failure();
+
+    if (op.getName().str() != "brevitas.matmul_rhs_group_quant") {
+      return failure();
+    }
+
+    Value rhs = op.getOperand(1);
+    Value bitWidth = op.getOperand(4);
+
+    auto getConstantIntegerFromDefiningOp = [](Value operand,
+                                               int &extractedInt) {
+      auto constOp = dyn_cast<Torch::ConstantIntOp>(operand.getDefiningOp());
+      if (!constOp) {
+        return failure();
+      }
+      extractedInt = constOp.getValue();
+      return success();
+    };
+    int unpackedBitWidth;
+    if (failed(getConstantIntegerFromDefiningOp(bitWidth, unpackedBitWidth)))
+      return failure();
+
+    auto rhsType = rhs.getType().dyn_cast<ValueTensorType>();
+    if (!rhsType)
+      return failure();
+
+    if (!rhsType.hasDtype())
+      return failure();
+
+    Type dType = rhsType.getDtype();
+    int dTypeWidth = dType.getIntOrFloatBitWidth();
+    if (dTypeWidth == unpackedBitWidth)
+      return failure();
+
+    if (!rhsType.hasSizes())
+      return failure();
+
+    SmallVector<int64_t> tensorShape(rhsType.getSizes());
+    if (tensorShape.back() == kUnknownSize)
+      return failure();
+    int packRatio = dTypeWidth / unpackedBitWidth;
+
+    tensorShape[tensorShape.size() - 1] *= packRatio;
+    Type unpackedElementType;
+    if (dType.isSignedInteger())
+      unpackedElementType = rewriter.getIntegerType(unpackedBitWidth, true);
+    else
+      unpackedElementType = rewriter.getIntegerType(unpackedBitWidth, false);
+    ValueTensorType newRhsType = ValueTensorType::get(
+        rewriter.getContext(), tensorShape, unpackedElementType);
+
+    auto elements = constOp.getValueAttr().dyn_cast<DenseIntElementsAttr>();
+    if (!elements)
+      return failure();
+
+    auto attrType = RankedTensorType::get(tensorShape, unpackedElementType);
+
+    // There is no existing utility for this, so unpack the packed values by
+    // hand.
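+    // Each packed storage element holds `packRatio` values of
+    // `unpackedBitWidth` bits, least-significant bits first; peel them off
+    // with a mask that shifts up by `unpackedBitWidth` per step.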
+    auto data = elements.getRawData();
+    std::vector<APInt> newData(data.size() * packRatio,
+                               APInt(unpackedBitWidth, 0));
+    for (int i = 0, e = data.size(); i < e; ++i) {
+      auto el = data[i];
+      char mask = (1 << unpackedBitWidth) - 1;
+      for (int b = 0; b < packRatio; b++) {
+        newData[i * packRatio + b] =
+            APInt(unpackedBitWidth, (el & mask) >> (unpackedBitWidth * b));
+        mask = mask << unpackedBitWidth;
+      }
+    }
+    rewriter.replaceOpWithNewOp<ValueTensorLiteralOp>(
+        constOp, newRhsType,
+        DenseElementsAttr::get(attrType, ArrayRef(newData)));
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+class UnpackTorchTensorPass
+    : public TorchConversion::UnpackTorchTensorBase<UnpackTorchTensorPass> {
+  using UnpackTorchTensorBase::UnpackTorchTensorBase;
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<func::FuncDialect>();
+    registry.insert<Torch::TorchDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(context);
+    patterns.add<UnpackQuantizedMatmulWeights>(context);
+
+    if (failed(
+            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+mlir::torch::TorchConversion::createUnpackTorchTensorPass() {
+  return std::make_unique<UnpackTorchTensorPass>();
+}