diff --git a/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp b/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp index f757ae8a0f..18478d34b4 100644 --- a/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp +++ b/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp @@ -48,4 +48,11 @@ llvm::cl::opt nnpaEnableZHighToOnnx("enable-zhigh-to-onnx", "level. Default is true."), llvm::cl::init(true), llvm::cl::cat(OnnxMlirOptions)); +llvm::cl::opt nnpaEnableZHighCostModel("enable-zhigh-cost-model", + llvm::cl::desc( + "Enabling a performance cost model to estimate the benefit of " + "migrating an eligible onnx operation to a ZHigh operation. Default is " + "false."), + llvm::cl::init(false), llvm::cl::cat(OnnxMlirOptions)); + } // namespace onnx_mlir diff --git a/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp b/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp index ce09a6e0b7..0373f0991e 100644 --- a/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp +++ b/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp @@ -46,6 +46,7 @@ extern llvm::cl::opt nnpaEmissionTarget; extern llvm::cl::list execNodesOnCpu; extern llvm::cl::opt nnpaClipToDLFloatRange; extern llvm::cl::opt nnpaEnableZHighToOnnx; +extern llvm::cl::opt nnpaEnableZHighCostModel; extern llvm::cl::opt profileZHighIR; } // namespace onnx_mlir diff --git a/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp b/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp index 010da8b3be..631b119351 100644 --- a/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp +++ b/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp @@ -45,12 +45,13 @@ using namespace onnx_mlir; namespace onnx_mlir { -void addONNXToZHighPasses( - mlir::PassManager &pm, ArrayRef execNodesOnCpu) { +void addONNXToZHighPasses(mlir::PassManager &pm, + ArrayRef execNodesOnCpu, bool useCostModel) { for (unsigned i = 0; i < 3; i++) { // Repeat this process so that shape-related ops such as Shape, Expand, // Gather generated during RewriteONNXForZHigh will become constants. - pm.addPass(onnx_mlir::createRewriteONNXForZHighPass(execNodesOnCpu)); + pm.addPass(onnx_mlir::createRewriteONNXForZHighPass( + execNodesOnCpu, false /*useCostModel*/)); // Simplify shape-related ops, including ShapeOp-to-DimOp replacement, // constant propagation, shape inference and canonicalize. pm.addPass(onnx_mlir::createSimplifyShapeRelatedOpsPass()); @@ -75,7 +76,7 @@ void addONNXToZHighPasses( pm.addNestedPass( onnx_mlir::createInstrumentPass(instrumentOps, instrumentActions)); - pm.addPass(onnx_mlir::createONNXToZHighPass(execNodesOnCpu)); + pm.addPass(onnx_mlir::createONNXToZHighPass(execNodesOnCpu, useCostModel)); pm.addNestedPass(onnx_mlir::createShapeInferencePass()); // There are more opportunities for const propagation once all zhigh ops were // generated. @@ -155,7 +156,7 @@ void addPassesNNPA(mlir::OwningOpRef &module, if (emissionTarget >= EmitMLIR) { // Lower zAIU-compatible ONNX ops to ZHigh dialect where possible. 
- addONNXToZHighPasses(pm, execNodesOnCpu); + addONNXToZHighPasses(pm, execNodesOnCpu, nnpaEnableZHighCostModel); if (nnpaEmissionTarget >= EmitZHighIR) emissionTarget = EmitMLIR; diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/CMakeLists.txt b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/CMakeLists.txt index 1fc6c4aac8..372b9465d2 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/CMakeLists.txt +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/CMakeLists.txt @@ -5,6 +5,7 @@ add_onnx_mlir_library(OMONNXToZHigh ONNXLegalityCheck.cpp ONNXToZHigh.cpp ONNXToZHighCommon.cpp + ZHighPerfModel.cpp DEPENDS OMONNXONNXToZHighIncGen @@ -25,6 +26,7 @@ add_onnx_mlir_library(OMRewriteONNXForZHigh ONNXLegalityCheck.cpp RewriteONNXForZHigh.cpp ONNXToZHighCommon.cpp + ZHighPerfModel.cpp DEPENDS OMONNXRewriteONNXForZHighIncGen diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp index 69ad591ad2..652e54c5b2 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp @@ -15,6 +15,7 @@ #include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.hpp" #include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/NNPALimit.h" +#include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighPerfModel.hpp" #include "src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp" #include "src/Dialect/ONNX/ONNXDimAnalysis.hpp" #include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp" @@ -45,9 +46,10 @@ bool isValidElementTypeAndRank(Value val) { } /// Common legality check for pooling ops. -template -bool checkLegalityPoolOpsCommon(POOLOP op, Value Y) { - POOLOPShapeHelper shapeHelper(op.getOperation(), {}); +template +bool checkLegalityPoolOpsCommon(POOL_OP op, Value Y) { + POOL_OP_ShapeHelper shapeHelper(op.getOperation(), {}); shapeHelper.computeShapeAndAssertOnFailure(); Value X = op.getX(); int64_t ceilMode = op.getCeilMode(); @@ -69,7 +71,7 @@ bool checkLegalityPoolOpsCommon(POOLOP op, Value Y) { // When input has unknown dimension and auto_pad is `NOTSET`, paddingType is // empty. StringRef paddingType = - getStrPaddingType(op); + getStrPaddingType(op); if (paddingType.empty()) return false; @@ -241,59 +243,119 @@ bool meetPoolParamRestrictions(int64_t inputShape, int64_t kernelShape, /// Default legality check. template -bool isSuitableForZDNN(OP_TYPE op, const DimAnalysis *dimAnalysis) { +bool isSuitableForZDNN( + OP_TYPE op, bool useCostModel, const DimAnalysis *dimAnalysis) { return false; } +/// Default model, assume beneficial until proven otherwise. +template +bool isFasterOnNNPA(OP_TYPE op, const DimAnalysis *dimAnalysis) { + return true; +} + /// Check legality for ONNXAdd. // zDNN Add, Sub, Mul, Div do not support broadcasting. + +template <> +bool isFasterOnNNPA(ONNXAddOp op, const DimAnalysis *dimAnalysis) { + return isElementwiseFasterOnNNPA( + op.getOperation(), op.getA(), op.getB(), dimAnalysis); +} + template <> bool isSuitableForZDNN( - ONNXAddOp op, const DimAnalysis *dimAnalysis) { + ONNXAddOp op, bool useCostModel, const DimAnalysis *dimAnalysis) { if (!isValidElementTypeAndRank(op.getA())) return false; if (!isValidElementTypeAndRank(op.getB())) return false; - return dimAnalysis->sameShape(op.getA(), op.getB()); + if (!dimAnalysis->sameShape(op.getA(), op.getB())) + return false; + if (useCostModel && !isFasterOnNNPA(op, dimAnalysis)) + return false; + return true; } /// Check legality for ONNXSub. 
+ +template <> +bool isFasterOnNNPA(ONNXSubOp op, const DimAnalysis *dimAnalysis) { + return isElementwiseFasterOnNNPA( + op.getOperation(), op.getA(), op.getB(), dimAnalysis); +} + template <> bool isSuitableForZDNN( - ONNXSubOp op, const DimAnalysis *dimAnalysis) { + ONNXSubOp op, bool useCostModel, const DimAnalysis *dimAnalysis) { if (!isValidElementTypeAndRank(op.getA())) return false; if (!isValidElementTypeAndRank(op.getB())) return false; - return dimAnalysis->sameShape(op.getA(), op.getB()); + if (!dimAnalysis->sameShape(op.getA(), op.getB())) + return false; + if (useCostModel && !isFasterOnNNPA(op, dimAnalysis)) + return false; + return true; } /// Check legality for ONNXMul. + +template <> +bool isFasterOnNNPA(ONNXMulOp op, const DimAnalysis *dimAnalysis) { + return isElementwiseFasterOnNNPA( + op.getOperation(), op.getA(), op.getB(), dimAnalysis); +} + template <> bool isSuitableForZDNN( - ONNXMulOp op, const DimAnalysis *dimAnalysis) { + ONNXMulOp op, bool useCostModel, const DimAnalysis *dimAnalysis) { if (!isValidElementTypeAndRank(op.getA())) return false; if (!isValidElementTypeAndRank(op.getB())) return false; - return dimAnalysis->sameShape(op.getA(), op.getB()); + if (!dimAnalysis->sameShape(op.getA(), op.getB())) + return false; + if (useCostModel && !isFasterOnNNPA(op, dimAnalysis)) + return false; + return true; } /// Check legality for ONNXDiv. + +template <> +bool isFasterOnNNPA(ONNXDivOp op, const DimAnalysis *dimAnalysis) { + return isElementwiseFasterOnNNPA( + op.getOperation(), op.getA(), op.getB(), dimAnalysis, 8.0); +} + template <> bool isSuitableForZDNN( - ONNXDivOp op, const DimAnalysis *dimAnalysis) { + ONNXDivOp op, bool useCostModel, const DimAnalysis *dimAnalysis) { if (!isValidElementTypeAndRank(op.getA())) return false; if (!isValidElementTypeAndRank(op.getB())) return false; - return dimAnalysis->sameShape(op.getA(), op.getB()); + if (!dimAnalysis->sameShape(op.getA(), op.getB())) + return false; + if (useCostModel && !isFasterOnNNPA(op, dimAnalysis)) + return false; + return true; } /// Check legality for ONNXSum. + +template <> +bool isFasterOnNNPA(ONNXSumOp op, const DimAnalysis *dimAnalysis) { + // Since all summed elements must have the same shape, sufficient to determine + // if faster by looking at the first 2. + return isElementwiseFasterOnNNPA( + op.getOperation(), op.getData_0()[0], op.getData_0()[1], dimAnalysis); +} + template <> bool isSuitableForZDNN( - ONNXSumOp op, const DimAnalysis *dimAnalysis) { + ONNXSumOp op, bool useCostModel, const DimAnalysis *dimAnalysis) { // Do not support a single input. if (op.getData_0().size() < 2) return false; @@ -308,113 +370,213 @@ bool isSuitableForZDNN( if (!dimAnalysis->sameShape(op.getData_0()[0], op.getData_0()[i])) return false; } + if (useCostModel && !isFasterOnNNPA(op, dimAnalysis)) + return false; return true; } /// Check legality for ONNXMin. /// zDNN Min/Max do not support broadcasting, and getNumOperands != 2. 
+ +template <> +bool isFasterOnNNPA(ONNXMinOp op, const DimAnalysis *dimAnalysis) { + return isElementwiseFasterOnNNPA( + op.getOperation(), op.getOperand(0), op.getOperand(1), dimAnalysis); +} + template <> bool isSuitableForZDNN( - ONNXMinOp op, const DimAnalysis *dimAnalysis) { - int64_t opnum = op.getNumOperands(); - if (opnum != 2) { + ONNXMinOp op, bool useCostModel, const DimAnalysis *dimAnalysis) { + int64_t opNum = op.getNumOperands(); + if (opNum != 2) { return false; } if (!isValidElementTypeAndRank(op.getOperand(0))) return false; if (!isValidElementTypeAndRank(op.getOperand(1))) return false; - return dimAnalysis->sameShape(op.getOperand(0), op.getOperand(1)); + if (!dimAnalysis->sameShape(op.getOperand(0), op.getOperand(1))) + return false; + if (useCostModel && !isFasterOnNNPA(op, dimAnalysis)) + return false; + return true; } /// Check legality for ONNXMax. -/// zDNN Min/Max do not support boradcasting, and getNumOperands != 2. +/// zDNN Min/Max do not support broadcasting, and getNumOperands != 2. + +template <> +bool isFasterOnNNPA(ONNXMaxOp op, const DimAnalysis *dimAnalysis) { + return isElementwiseFasterOnNNPA( + op.getOperation(), op.getOperand(0), op.getOperand(1), dimAnalysis); +} + template <> bool isSuitableForZDNN( - ONNXMaxOp op, const DimAnalysis *dimAnalysis) { - int64_t opnum = op.getNumOperands(); - if (opnum != 2) { + ONNXMaxOp op, bool useCostModel, const DimAnalysis *dimAnalysis) { + int64_t opNum = op.getNumOperands(); + if (opNum != 2) { return false; } if (!isValidElementTypeAndRank(op.getOperand(0))) return false; if (!isValidElementTypeAndRank(op.getOperand(1))) return false; - return dimAnalysis->sameShape(op.getOperand(0), op.getOperand(1)); + if (!dimAnalysis->sameShape(op.getOperand(0), op.getOperand(1))) + return false; + if (useCostModel && !isFasterOnNNPA(op, dimAnalysis)) + return false; + return true; } /// Check legality for ONNXSoftmax. /// zDNN softmax only supports axis = rank-1 (or -1) when rank = 2 or 3). If /// axis is not rank-1 (or -1) when rank = 2/3), keep ONNXSoftmax unchanged. + template <> -bool isSuitableForZDNN( +bool isFasterOnNNPA( ONNXSoftmaxOp op, const DimAnalysis *dimAnalysis) { + // Has no data on softmax, approx for the moment by using elementwise. + return isElementwiseFasterOnNNPA( + op.getOperation(), op.getInput(), dimAnalysis); +} + +template <> +bool isSuitableForZDNN( + ONNXSoftmaxOp op, bool useCostModel, const DimAnalysis *dimAnalysis) { if (!isValidElementTypeAndRank(op.getInput())) return false; ShapedType inputType = op.getType().cast(); if (!inputType.hasRank()) return false; int64_t rank = inputType.getRank(); - return (((rank == 2) || (rank == 3)) && - ((op.getAxis() == rank - 1) || (op.getAxis() == -1))); + if (!(((rank == 2) || (rank == 3)) && + ((op.getAxis() == rank - 1) || (op.getAxis() == -1)))) + return false; + if (useCostModel && !isFasterOnNNPA(op, dimAnalysis)) + return false; + return true; } /// Check legality for ONNXRelu. 
+ +template <> +bool isFasterOnNNPA(ONNXReluOp op, const DimAnalysis *dimAnalysis) { + return isElementwiseFasterOnNNPA(op.getOperation(), op.getX(), dimAnalysis); +} + template <> bool isSuitableForZDNN( - ONNXReluOp op, const DimAnalysis *dimAnalysis) { + ONNXReluOp op, bool useCostModel, const DimAnalysis *dimAnalysis) { if (!isValidElementTypeAndRank(op.getX())) return false; ShapedType xType = op.getX().getType().cast(); - return xType.hasRank() && (xType.getRank() <= 4); + if (!xType.hasRank() || xType.getRank() > 4) + return false; + if (useCostModel && !isFasterOnNNPA(op, dimAnalysis)) + return false; + return true; } /// Check legality for ONNXTanh. + +template <> +bool isFasterOnNNPA(ONNXTanhOp op, const DimAnalysis *dimAnalysis) { + return isElementwiseFasterOnNNPA( + op.getOperation(), op.getInput(), dimAnalysis); +} + template <> bool isSuitableForZDNN( - ONNXTanhOp op, const DimAnalysis *dimAnalysis) { + ONNXTanhOp op, bool useCostModel, const DimAnalysis *dimAnalysis) { if (!isValidElementTypeAndRank(op.getInput())) return false; ShapedType inputType = op.getType().cast(); - return inputType.hasRank() && (inputType.getRank() <= 4); + if (!inputType.hasRank() || inputType.getRank() > 4) + return false; + if (useCostModel && !isFasterOnNNPA(op, dimAnalysis)) + return false; + return true; } /// Check legality for ONNXSigmoid. + template <> -bool isSuitableForZDNN( +bool isFasterOnNNPA( ONNXSigmoidOp op, const DimAnalysis *dimAnalysis) { + return isElementwiseFasterOnNNPA(op.getOperation(), op.getX(), dimAnalysis); +} + +template <> +bool isSuitableForZDNN( + ONNXSigmoidOp op, bool useCostModel, const DimAnalysis *dimAnalysis) { if (!isValidElementTypeAndRank(op.getX())) return false; ShapedType xType = op.getX().getType().cast(); - return xType.hasRank() && (xType.getRank() <= 4); + if (!xType.hasRank() || xType.getRank() > 4) + return false; + if (useCostModel && !isFasterOnNNPA(op, dimAnalysis)) + return false; + return true; } /// Check legality for ONNXLog. + +template <> +bool isFasterOnNNPA(ONNXLogOp op, const DimAnalysis *dimAnalysis) { + return isElementwiseFasterOnNNPA( + op.getOperation(), op.getInput(), dimAnalysis); +} + template <> bool isSuitableForZDNN( - ONNXLogOp op, const DimAnalysis *dimAnalysis) { + ONNXLogOp op, bool useCostModel, const DimAnalysis *dimAnalysis) { if (!isValidElementTypeAndRank(op.getInput())) return false; ShapedType inputType = op.getInput().getType().cast(); - return inputType.hasRank() && (inputType.getRank() <= 4); + if (!inputType.hasRank() || inputType.getRank() > 4) + return false; + if (useCostModel && !isFasterOnNNPA(op, dimAnalysis)) + return false; + return true; } /// Check legality for ONNXExp. + +template <> +bool isFasterOnNNPA(ONNXExpOp op, const DimAnalysis *dimAnalysis) { + return isElementwiseFasterOnNNPA( + op.getOperation(), op.getInput(), dimAnalysis); +} + template <> bool isSuitableForZDNN( - ONNXExpOp op, const DimAnalysis *dimAnalysis) { + ONNXExpOp op, bool useCostModel, const DimAnalysis *dimAnalysis) { if (!isValidElementTypeAndRank(op.getInput())) return false; ShapedType inputType = op.getInput().getType().cast(); - return inputType.hasRank() && (inputType.getRank() <= 4); + if (!inputType.hasRank() || inputType.getRank() > 4) + return false; + if (useCostModel && !isFasterOnNNPA(op, dimAnalysis)) + return false; + return true; } /// Check legality for ONNXMatMul. 
+ template <> -bool isSuitableForZDNN( +bool isFasterOnNNPA( ONNXMatMulOp op, const DimAnalysis *dimAnalysis) { - int64_t opnum = op.getNumOperands(); - if (opnum != 2) { + return isMatMulFasterOnNNPA(op.getOperation(), op.getOperand(0), + op.getOperand(1), /*aTransposed*/ false, /*bTransposed*/ false, + dimAnalysis); +} + +template <> +bool isSuitableForZDNN( + ONNXMatMulOp op, bool useCostModel, const DimAnalysis *dimAnalysis) { + int64_t opNum = op.getNumOperands(); + if (opNum != 2) { return false; } if (!isValidElementTypeAndRank(op.getOperand(0))) @@ -440,29 +602,37 @@ bool isSuitableForZDNN( if ((shapeA.size() == 2) && (shapeB.size() == 2)) { // unstacked case if (aType.hasStaticShape() && bType.hasStaticShape()) - return (shapeA[1] == shapeB[0]); - else - return true; + if (shapeA[1] != shapeB[0]) + return false; } else if ((shapeA.size() == 3) && (shapeB.size() == 3)) { // stacked w/o bcast case if (aType.hasStaticShape() && bType.hasStaticShape()) - return ((shapeA[0] == shapeB[0]) && (shapeA[2] == shapeB[1])); - else - return true; + if ((shapeA[0] != shapeB[0]) || (shapeA[2] != shapeB[1])) + return false; } else if ((shapeA.size() == 3) && (shapeB.size() == 2)) { // stacked w/ bcast if (aType.hasStaticShape() && bType.hasStaticShape()) - return (shapeA[2] == shapeB[0]); - else - return true; + if (shapeA[2] != shapeB[0]) + return false; + } else { + return false; // Unsupported case. } - return false; // unsupported case + if (useCostModel && !isFasterOnNNPA(op, dimAnalysis)) + return false; + return true; } /// Check legality for ONNXGemm. +template <> +bool isFasterOnNNPA(ONNXGemmOp op, const DimAnalysis *dimAnalysis) { + // For GEMM, consider only the multiplication for the cost model. + return isMatMulFasterOnNNPA(op.getOperation(), op.getA(), op.getB(), + op.getTransA(), op.getTransB(), dimAnalysis); +} + template <> bool isSuitableForZDNN( - ONNXGemmOp op, const DimAnalysis *dimAnalysis) { + ONNXGemmOp op, bool useCostModel, const DimAnalysis *dimAnalysis) { Value A = op.getA(); Value B = op.getB(); Value C = op.getC(); @@ -512,13 +682,23 @@ bool isSuitableForZDNN( if (cShape[0] != bShape1) return false; } + if (useCostModel && !isFasterOnNNPA(op, dimAnalysis)) + return false; return true; } /// Check legality for ONNXReduceMeanV13. template <> -bool isSuitableForZDNN( +bool isFasterOnNNPA( ONNXReduceMeanV13Op op, const DimAnalysis *dimAnalysis) { + // Has no data on reduce min, approx for the moment by using elementwise. + return isElementwiseFasterOnNNPA( + op.getOperation(), op.getData(), dimAnalysis); +} + +template <> +bool isSuitableForZDNN( + ONNXReduceMeanV13Op op, bool useCostModel, const DimAnalysis *dimAnalysis) { // Check data type. 
if (!isValidElementTypeAndRank(op.getData())) return false; @@ -547,6 +727,8 @@ bool isSuitableForZDNN( (shapeData[3] > 1024)) return false; + if (useCostModel && !isFasterOnNNPA(op, dimAnalysis)) + return false; return true; } @@ -554,7 +736,7 @@ bool isSuitableForZDNN( /// TODO: current ONNX-to-zhigh conversion does not support bi-direction template <> bool isSuitableForZDNN( - ONNXLSTMOp op, const DimAnalysis *dimAnalysis) { + ONNXLSTMOp op, bool useCostModel, const DimAnalysis *dimAnalysis) { StringRef direction = op.getDirection(); Value W = op.getW(); Value R = op.getR(); @@ -629,7 +811,7 @@ bool isSuitableForZDNN( /// TODO: current ONNX-to-zhigh conversion does not support bi-direction template <> bool isSuitableForZDNN( - ONNXGRUOp op, const DimAnalysis *dimAnalysis) { + ONNXGRUOp op, bool useCostModel, const DimAnalysis *dimAnalysis) { StringRef direction = op.getDirection(); Value W = op.getW(); Value R = op.getR(); @@ -699,8 +881,8 @@ bool isSuitableForZDNN( /// Check legality for ONNXMaxPool. template <> -bool isSuitableForZDNN( - ONNXMaxPoolSingleOutOp op, const DimAnalysis *dimAnalysis) { +bool isSuitableForZDNN(ONNXMaxPoolSingleOutOp op, + bool useCostModel, const DimAnalysis *dimAnalysis) { // Check data type. if (!isValidElementTypeAndRank(op.getX())) return false; @@ -723,7 +905,7 @@ bool isSuitableForZDNN( /// Check legality for ONNXAveragePool. template <> bool isSuitableForZDNN( - ONNXAveragePoolOp op, const DimAnalysis *dimAnalysis) { + ONNXAveragePoolOp op, bool useCostModel, const DimAnalysis *dimAnalysis) { // Check data type. if (!isValidElementTypeAndRank(op.getX())) return false; @@ -780,7 +962,7 @@ static bool checkConv2DParamRestrictions(int64_t inputDim, int64_t kernelDim, /// Check legality for ONNXConvOp. template <> bool isSuitableForZDNN( - ONNXConvOp op, const DimAnalysis *dimAnalysis) { + ONNXConvOp op, bool useCostModel, const DimAnalysis *dimAnalysis) { // Check data type. if (!isValidElementTypeAndRank(op.getX())) return false; @@ -855,9 +1037,18 @@ bool isSuitableForZDNN( } /// Check legality for ONNXBatchNormOp. + template <> -bool isSuitableForZDNN( +bool isFasterOnNNPA( ONNXBatchNormalizationInferenceModeOp op, const DimAnalysis *dimAnalysis) { + // Cedric's spreadsheet shows similar results to elementwise. + return isElementwiseFasterOnNNPA(op.getOperation(), op.getX(), dimAnalysis); +} + +template <> +bool isSuitableForZDNN( + ONNXBatchNormalizationInferenceModeOp op, bool useCostModel, + const DimAnalysis *dimAnalysis) { ShapedType inputType = op.getX().getType().cast(); ShapedType outputType = op.getO_Y().getType().cast(); ArrayRef shapeInput = inputType.getShape(); @@ -866,6 +1057,7 @@ bool isSuitableForZDNN( // 4D tensors(N x C x H x W) are supported as input and output. if (shapeInput.size() != 4 || shapeOutput.size() != 4) return false; - + if (useCostModel && !isFasterOnNNPA(op, dimAnalysis)) + return false; return true; } diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.hpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.hpp index c6225bee24..fbe8c709d7 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.hpp +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.hpp @@ -22,9 +22,10 @@ /// A function to check whether an ONNX op is suitable for being lowered to zDNN /// or not. +// TODO: revisit if the cost model should not be optional. 
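+// When useCostModel is true, the per-op specializations in
+// ONNXLegalityCheck.cpp additionally consult isFasterOnNNPA, which relies on
+// the estimates in ZHighPerfModel, and report the op as unsuitable when NNPA
+// is not expected to be faster than the CPU.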
template -bool isSuitableForZDNN( - OP_TYPE op, const onnx_mlir::DimAnalysis *dimAnalysis = nullptr); +bool isSuitableForZDNN(OP_TYPE op, bool useCostModel = false, + const onnx_mlir::DimAnalysis *dimAnalysis = nullptr); /// Get padding type using shape helper. This returns /// `SAME_PADDING`, `VALID_PADDING`, or empty. diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp index 1b8f8ed6fd..bb5aa15c25 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp @@ -257,8 +257,10 @@ struct ONNXToZHighLoweringPass ONNXToZHighLoweringPass() = default; ONNXToZHighLoweringPass(const ONNXToZHighLoweringPass &pass) : PassWrapper>() {} - ONNXToZHighLoweringPass(mlir::ArrayRef execNodesOnCpu) { + ONNXToZHighLoweringPass( + mlir::ArrayRef execNodesOnCpu, bool useCostModel) { this->execNodesOnCpu = execNodesOnCpu; + this->useCostModel = useCostModel; } void runOnOperation() final; @@ -269,6 +271,11 @@ struct ONNXToZHighLoweringPass "using the zDNN. The node name is an optional attribute " "in onnx graph, which is `onnx_node_name` in ONNX IR"), llvm::cl::ZeroOrMore}; + Option useCostModel{*this, "use-cost-model", + llvm::cl::desc( + "Whether to use performance cost model to estimate if it is " + " beneficial to map an operation to ZHigh. Default: false"), + llvm::cl::init(false)}; }; } // end anonymous namespace. @@ -317,32 +324,48 @@ void ONNXToZHighLoweringPass::runOnOperation() { // ONNX ops to ZHigh dialect under specific conditions. // When adding a new op, need to implement a method, i.e. isSuitableForZDNN, // for the op in ONNXLegalityCheck.cpp. - addDynamicallyLegalOpFor(&target, &dimAnalysis, execNodesOnCpu); - addDynamicallyLegalOpFor(&target, &dimAnalysis, execNodesOnCpu); - addDynamicallyLegalOpFor(&target, &dimAnalysis, execNodesOnCpu); - addDynamicallyLegalOpFor(&target, &dimAnalysis, execNodesOnCpu); - addDynamicallyLegalOpFor(&target, &dimAnalysis, execNodesOnCpu); - addDynamicallyLegalOpFor(&target, &dimAnalysis, execNodesOnCpu); - addDynamicallyLegalOpFor(&target, &dimAnalysis, execNodesOnCpu); - addDynamicallyLegalOpFor(&target, &dimAnalysis, execNodesOnCpu); - addDynamicallyLegalOpFor(&target, &dimAnalysis, execNodesOnCpu); + addDynamicallyLegalOpFor( + &target, &dimAnalysis, useCostModel, execNodesOnCpu); + addDynamicallyLegalOpFor( + &target, &dimAnalysis, useCostModel, execNodesOnCpu); + addDynamicallyLegalOpFor( + &target, &dimAnalysis, useCostModel, execNodesOnCpu); + addDynamicallyLegalOpFor( + &target, &dimAnalysis, useCostModel, execNodesOnCpu); + addDynamicallyLegalOpFor( + &target, &dimAnalysis, useCostModel, execNodesOnCpu); + addDynamicallyLegalOpFor( + &target, &dimAnalysis, useCostModel, execNodesOnCpu); + addDynamicallyLegalOpFor( + &target, &dimAnalysis, useCostModel, execNodesOnCpu); + addDynamicallyLegalOpFor( + &target, &dimAnalysis, useCostModel, execNodesOnCpu); + addDynamicallyLegalOpFor( + &target, &dimAnalysis, useCostModel, execNodesOnCpu); addDynamicallyLegalOpFor( - &target, &dimAnalysis, execNodesOnCpu); - addDynamicallyLegalOpFor(&target, &dimAnalysis, execNodesOnCpu); - addDynamicallyLegalOpFor(&target, &dimAnalysis, execNodesOnCpu); + &target, &dimAnalysis, useCostModel, execNodesOnCpu); + addDynamicallyLegalOpFor( + &target, &dimAnalysis, useCostModel, execNodesOnCpu); + addDynamicallyLegalOpFor( + &target, &dimAnalysis, useCostModel, execNodesOnCpu); addDynamicallyLegalOpFor( - &target, 
&dimAnalysis, execNodesOnCpu); + &target, &dimAnalysis, useCostModel, execNodesOnCpu); addDynamicallyLegalOpFor( - &target, &dimAnalysis, execNodesOnCpu); + &target, &dimAnalysis, useCostModel, execNodesOnCpu); addDynamicallyLegalOpFor( - &target, &dimAnalysis, execNodesOnCpu); - addDynamicallyLegalOpFor(&target, &dimAnalysis, execNodesOnCpu); - addDynamicallyLegalOpFor(&target, &dimAnalysis, execNodesOnCpu); + &target, &dimAnalysis, useCostModel, execNodesOnCpu); + addDynamicallyLegalOpFor( + &target, &dimAnalysis, useCostModel, execNodesOnCpu); + addDynamicallyLegalOpFor( + &target, &dimAnalysis, useCostModel, execNodesOnCpu); addDynamicallyLegalOpFor( - &target, &dimAnalysis, execNodesOnCpu); - addDynamicallyLegalOpFor(&target, &dimAnalysis, execNodesOnCpu); - addDynamicallyLegalOpFor(&target, &dimAnalysis, execNodesOnCpu); - addDynamicallyLegalOpFor(&target, &dimAnalysis, execNodesOnCpu); + &target, &dimAnalysis, useCostModel, execNodesOnCpu); + addDynamicallyLegalOpFor( + &target, &dimAnalysis, useCostModel, execNodesOnCpu); + addDynamicallyLegalOpFor( + &target, &dimAnalysis, useCostModel, execNodesOnCpu); + addDynamicallyLegalOpFor( + &target, &dimAnalysis, useCostModel, execNodesOnCpu); // With the target and rewrite patterns defined, we can now attempt the // conversion. The conversion will signal failure if any of our `illegal` @@ -356,8 +379,9 @@ std::unique_ptr createONNXToZHighPass() { } std::unique_ptr createONNXToZHighPass( - mlir::ArrayRef execNodesOnCpu) { - return std::make_unique(execNodesOnCpu); + mlir::ArrayRef execNodesOnCpu, bool useCostModel) { + return std::make_unique( + execNodesOnCpu, useCostModel); } } // namespace onnx_mlir diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.td b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.td index 5d6d2ab924..c548cf8da8 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.td +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.td @@ -301,7 +301,7 @@ def replaceONNXSoftmax3DPattern : Pat< //===----------------------------------------------------------------------===// def IsSoftmaxLegalForZDNN: Constraint< CPred<"isSuitableForZDNN(" # - "dyn_cast_or_null($0.getDefiningOp()))">, + "dyn_cast_or_null($0.getDefiningOp()), false)">, "Softmax is legal for zDNN" >; def replaceONNXLogSoftmaxPattern : Pattern< @@ -461,7 +461,7 @@ def replaceONNXMatMulPattern : Pat< def IsMatMulLegalForZDNN: Constraint< CPred<"isSuitableForZDNN(" # - "dyn_cast_or_null($0.getDefiningOp()))">, + "dyn_cast_or_null($0.getDefiningOp()), false)">, "MatMul is legal for zDNN" >; @@ -983,7 +983,7 @@ def replaceONNXConv2DPattern : Pattern< def IsConv2DLegalForZDNN: Constraint< CPred<"isSuitableForZDNN(" # - "dyn_cast_or_null($0.getDefiningOp()))">, + "dyn_cast_or_null($0.getDefiningOp()), false)">, "Conv is legal for zDNN" >; diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp index b66c0cc927..df694e5d65 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp @@ -21,10 +21,10 @@ template void addDynamicallyLegalOpFor(mlir::ConversionTarget *target, - const onnx_mlir::DimAnalysis *dimAnalysis, + const onnx_mlir::DimAnalysis *dimAnalysis, bool useCostModel, mlir::ArrayRef execNodesOnCpu) { - target->addDynamicallyLegalOp([dimAnalysis, execNodesOnCpu]( - OP_TYPE op) { + target->addDynamicallyLegalOp([dimAnalysis, 
useCostModel,
+      execNodesOnCpu](OP_TYPE op) {
     // Check operations to be forced to run on CPU.
     mlir::Operation *genericOp = op.getOperation();
     mlir::StringAttr nodeName =
@@ -56,7 +56,7 @@ void addDynamicallyLegalOpFor(mlir::ConversionTarget *target,
     if (exceedLimit)
       return true;

-    return !isSuitableForZDNN(op, dimAnalysis);
+    return !isSuitableForZDNN(op, useCostModel, dimAnalysis);
   });
 }

diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.cpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.cpp
index 0e6b3fdb07..3e0f41d40a 100644
--- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.cpp
+++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.cpp
@@ -463,12 +463,14 @@ struct RewriteONNXForZHighPass
   }

   RewriteONNXForZHighPass() = default;
-  RewriteONNXForZHighPass(mlir::ArrayRef execNodesOnCpu)
-      : execNodesOnCpu(execNodesOnCpu) {}
+  RewriteONNXForZHighPass(
+      mlir::ArrayRef execNodesOnCpu, bool useCostModel)
+      : execNodesOnCpu(execNodesOnCpu), useCostModel(useCostModel) {}
   void runOnOperation() final;

 public:
   mlir::ArrayRef execNodesOnCpu = mlir::ArrayRef();
+  bool useCostModel = false;
 };

 void RewriteONNXForZHighPass::runOnOperation() {
@@ -491,7 +493,7 @@ void RewriteONNXForZHighPass::runOnOperation() {
   // generating `ONNX.Add`, `ONNX.Sub`, `ONNX.Mul`, `ONNX.Div`,
   // and `ONNX.Sqrt` to calculate inputs(`a` and `b`)
   addDynamicallyLegalOpFor(
-      &target, &dimAnalysis, execNodesOnCpu);
+      &target, &dimAnalysis, useCostModel, execNodesOnCpu);

   // Illegalize BinaryOp if one of the two inputs is a constant and
   // unidirectional broadcastable to the other input. Rewrite patterns will be
@@ -615,8 +617,9 @@ std::unique_ptr createRewriteONNXForZHighPass() {
 }

 std::unique_ptr createRewriteONNXForZHighPass(
-    mlir::ArrayRef execNodesOnCpu) {
-  return std::make_unique(execNodesOnCpu);
+    mlir::ArrayRef execNodesOnCpu, bool useCostModel) {
+  return std::make_unique(
+      execNodesOnCpu, useCostModel);
 }

 } // namespace onnx_mlir
diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighPerfModel.cpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighPerfModel.cpp
new file mode 100644
index 0000000000..b69d106f42
--- /dev/null
+++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighPerfModel.cpp
@@ -0,0 +1,189 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+//===---------- ZHighPerfModel.cpp - Deciding ONNX vs ZHigh for ops -------===//
+//
+// Copyright 2023 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This file contains model info to help decide for the relevant NNPA ops if
+// they are faster / slower than their equivalent CPU versions.
+//
+//===----------------------------------------------------------------------===//
+
+// TODO: determine which of these includes are really needed.
+#include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighPerfModel.hpp"
+#include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/NNPALimit.h"
+#include "src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp"
+#include "src/Dialect/ONNX/ONNXDimAnalysis.hpp"
+#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"
+
+#define DEBUG_TYPE "zhigh-perf-model"
+
+using namespace mlir;
+using namespace onnx_mlir;
+
+namespace {
+
+// This is the implementation of Excel's CEILING: it returns the smallest value
+// that is greater than or equal to a and is a multiple of b.
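+// For example, roundToNextMultiple(5, 4) == 8 and roundToNextMultiple(8, 4) == 8.
+// It is used below to round dimensions up before estimating the 2D-FMA work.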
+inline int64_t roundToNextMultiple(int64_t a, int64_t b) { + // Ceil only works with unsigned number, but shapes are given as signed + // numbers. Do the necessary checks/conversions here. + assert(a >= 0 && "expected nonnegative"); + assert(b > 0 && "expected strictly positive"); + uint64_t aa = a; + uint64_t bb = b; + uint64_t ceilDiv = (aa + bb - 1) / bb; + return ceilDiv * bb; +} + +// Return true with a debug message reporting reason for success on NNPA. +inline bool fasterOnNNPA(Operation *op, std::string msg) { + LLVM_DEBUG({ + llvm::dbgs() << "Faster on NNPA: " << msg << " for op: "; + op->dump(); + }); + return true; +} + +// Return false with a debug message reporting reason for failure on NNPA. +inline bool fasterOnCPU(Operation *op, std::string msg) { + LLVM_DEBUG({ + llvm::dbgs() << "Faster on CPU: " << msg << " for op: "; + op->dump(); + }); + return false; +} + +} // namespace + +bool isElementwiseFasterOnNNPA(Operation *op, Value oper, + const DimAnalysis *dimAnalysis, double relativeNNPASpeedup) { + // At this time, use only 1 of the two + ShapedType operType = oper.getType().dyn_cast_or_null(); + assert(operType && operType.hasRank() && "expected shaped type with rank"); + int64_t operRank = operType.getRank(); + assert(operRank <= 4 && "expected rank <= 4"); + llvm::ArrayRef shape = operType.getShape(); + int64_t e4 = operRank >= 4 ? shape[operRank - 4] : 1; + int64_t e3 = operRank >= 3 ? shape[operRank - 3] : 1; + int64_t e2 = operRank >= 2 ? shape[operRank - 2] : 1; + int64_t e1 = operRank >= 1 ? shape[operRank - 1] : 1; + + // Disqualify if e1 is too small (full is 64, so shoot for half full). + if (e1 > 0 && e1 < 32) + return fasterOnCPU(op, "elementwise has too small e1"); + // If e1 or e2 are runtime, assume they will be large enough. + if (e1 == ShapedType::kDynamic || e2 == ShapedType::kDynamic) + return fasterOnNNPA(op, "elementwise has runtime e1 or e2"); + // If larger dims are runtime, assume it might just be size 1. + if (e3 == ShapedType::kDynamic) + e3 = 1; + if (e4 == ShapedType::kDynamic) + e4 = 1; + // Cedric's spreadsheet calculations. + int64_t computed2dFMA = + e4 * e3 * roundToNextMultiple(e2, 2) * roundToNextMultiple(e1, 64); + computed2dFMA = (double)computed2dFMA * relativeNNPASpeedup; + assert(computed2dFMA > 0 && "dyn size should have been removed"); + // Cedric's model show still significant benefits for 16 full tiles, + // arbitrarily assume cross over at 8. Will need new measurements on this. + if (computed2dFMA < 8 * 2048) + return fasterOnCPU( + op, "elementwise computed 2D FMA is too small (<8 full tiles)"); + return fasterOnNNPA(op, "elementwise has enough computations"); +} + +bool isElementwiseFasterOnNNPA(Operation *op, Value oper, Value rhs, + const DimAnalysis *dimAnalysis, double relativeNNPASpeedup) { + // At this time, we can treat the binary elementwise the same way as an unary + // elementwise. + return isElementwiseFasterOnNNPA(op, oper, dimAnalysis, relativeNNPASpeedup); +} + +bool isMatMulFasterOnNNPA(Operation *op, Value a, Value b, bool aTransposed, + bool bTransposed, const DimAnalysis *dimAnalysis) { + // Scanning A. + ShapedType aType = a.getType().dyn_cast_or_null(); + assert(aType && aType.hasRank() && "expected shaped type with A rank"); + int64_t aRank = aType.getRank(); + assert(aRank >= 2 && aRank <= 3 && "expected A rank 2..3"); + llvm::ArrayRef aShape = aType.getShape(); + int64_t aB = aRank >= 3 ? aShape[aRank - 3] : 1; + int64_t aNIndex = aTransposed ? aRank - 1 : aRank - 2; + int64_t aMIndex = aTransposed ? 
aRank - 2 : aRank - 1; + int64_t aN = aShape[aNIndex]; + int64_t aM = aShape[aMIndex]; + // Scanning B. + ShapedType bType = b.getType().dyn_cast_or_null(); + assert(bType && bType.hasRank() && "expected shaped type with B rank"); + int64_t bRank = bType.getRank(); + assert(bRank >= 2 && bRank <= 3 && "expected B rank 2..3"); + llvm::ArrayRef bShape = bType.getShape(); + int64_t bB = bRank >= 3 ? bShape[bRank - 3] : 1; + int64_t bMIndex = bTransposed ? bRank - 1 : bRank - 2; + int64_t bKIndex = bTransposed ? bRank - 2 : bRank - 1; + int64_t bM = bShape[bMIndex]; + int64_t bK = bShape[bKIndex]; + assert(aM == bM && "expected M dims to be identical"); + // Rules common to matmul with/without broadcast. + // Make sure the constant lower dim of the matmul are large enough. + if (aM > 0 && aM < 32) + return fasterOnCPU(op, "matmul no-broadcast M dim too small)"); + if (bK > 0 && bK < 32) + return fasterOnCPU(op, "matmul no-broadcast K dim too small)"); + // Assume the dynamic lower dim of the matmul will be large enough. + if (aM == ShapedType::kDynamic || bK == ShapedType::kDynamic) + return fasterOnNNPA(op, "matmul no-broadcast has runtime M or K"); + // Assume the N dim of the matmul will be small. + if (aN == ShapedType::kDynamic) + aN = 1; + // Determine if we have a broadcast (will change cost calculations). + bool hasBroadcast = true; + if (aRank == 2 && bRank == 2) // No broadcast dim. + hasBroadcast = false; + else if (aB == 1 && bB == 1) // No broadcast because both 1. + hasBroadcast = false; + else if (aRank == 3 && bRank == 3 && + dimAnalysis->sameDim(a, aRank - 3, b, bRank - 3)) + hasBroadcast = false; + // Assume the B dim of the matmul will be small. + if (aB == ShapedType::kDynamic) + aB = 1; + if (bB == ShapedType::kDynamic) + bB = 1; + + // Handle case without broadcast. + if (!hasBroadcast) { + int64_t computed2dFMA = aB * roundToNextMultiple(aN, 2) * + roundToNextMultiple(aM, 64) * + roundToNextMultiple(bK, 64); + assert(computed2dFMA > 0 && "dyn size should have been removed"); + // Cedric's model show still benefits for 64^3 == 128 full tiles. + if (computed2dFMA < 64 * 64 * 64) + return fasterOnCPU( + op, "matmul no-broadcast computed 2D FMA is too small (<64^3 flops)"); + return fasterOnNNPA(op, "matmul no-broadcast has enough computations"); + } + // Else we have broadcast. + // Virtual E2/E4 from Cedric's spreadsheet: TODO, need refinement. + int B = bB == 1 ? aB : bB; + int64_t virtualNB; + if (aN >= 128) + virtualNB = aN; // No B? + else if (aM >= 64) + virtualNB = aN * std::min(2, B); + else if (aM >= 32) + virtualNB = aN * std::min(4, B); + else + virtualNB = aN * std::min(8, B); + int64_t computed2dFMA = + virtualNB * roundToNextMultiple(aM, 64) * roundToNextMultiple(bK, 64); + if (computed2dFMA < 64 * 64 * 64) + return fasterOnCPU( + op, "matmul broadcast computed 2D FMA is too small (<64^3 flops)"); + return fasterOnNNPA(op, "matmul broadcast has enough computations"); +} diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighPerfModel.hpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighPerfModel.hpp new file mode 100644 index 0000000000..54c424b12e --- /dev/null +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighPerfModel.hpp @@ -0,0 +1,40 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===---------- ZHighPerfModel.hpp - Deciding ONNX vs ZHigh for ops -------===// +// +// Copyright 2023 The IBM Research Authors. 
+// +// ============================================================================= +// +// This file contains model info to help decide for the relevant NNPA ops if +// they are faster / slower than their equivalent CPU versions. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" +#include "src/Dialect/ONNX/ONNXDimAnalysis.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +// Return true if operation is faster on NNPA than CPU. + +// Result is normalized for add/sub/mul. Operations that have an additional +// advantage on the NNPA vs CPU execution can reflect that advantage via the +// relativeNNPASpeedup ratio. + +// Elementwise with one input operand. +bool isElementwiseFasterOnNNPA(mlir::Operation *op, mlir::Value operand, + const onnx_mlir::DimAnalysis *dimAnalysis, + double relativeNNPASpeedup = 1.0); + +// Elementwise with two input operands, lhs and rhs. +bool isElementwiseFasterOnNNPA(mlir::Operation *op, mlir::Value lhs, + mlir::Value rhs, const onnx_mlir::DimAnalysis *dimAnalysis, + double relativeNNPASpeedup = 1.0); + +bool isMatMulFasterOnNNPA(mlir::Operation *op, mlir::Value a, mlir::Value b, + bool aTransposed, bool bTransposed, + const onnx_mlir::DimAnalysis *dimAnalysis); diff --git a/src/Accelerators/NNPA/Pass/NNPAPasses.hpp b/src/Accelerators/NNPA/Pass/NNPAPasses.hpp index f994b7c7bc..603da9d330 100644 --- a/src/Accelerators/NNPA/Pass/NNPAPasses.hpp +++ b/src/Accelerators/NNPA/Pass/NNPAPasses.hpp @@ -22,12 +22,12 @@ namespace onnx_mlir { /// Add pass for lowering ONNX ops to ZHigh ops. std::unique_ptr createONNXToZHighPass(); std::unique_ptr createONNXToZHighPass( - mlir::ArrayRef execNodesOnCpu); + mlir::ArrayRef execNodesOnCpu, bool useCostModel); /// Add pass for rewriting ONNX ops for ZHigh. std::unique_ptr createRewriteONNXForZHighPass(); std::unique_ptr createRewriteONNXForZHighPass( - mlir::ArrayRef execNodesOnCpu); + mlir::ArrayRef execNodesOnCpu, bool useCostModel); /// Add pass for re-construct ONNX ops from ZHigh ops. std::unique_ptr createZHighToONNXPass(); diff --git a/test/accelerators/NNPA/backend/CMakeLists.txt b/test/accelerators/NNPA/backend/CMakeLists.txt index 41f1fa6b65..2a8ce58878 100644 --- a/test/accelerators/NNPA/backend/CMakeLists.txt +++ b/test/accelerators/NNPA/backend/CMakeLists.txt @@ -92,6 +92,8 @@ endif() # instruction name is added after test case name in each test case. set(NNPA_TEST_LIST + # When changing ==xxx== annotations, please run `make onnx_mlir_supported_ops` + # to re-actualize our support md pages. # ==ARCH== NNPA # ==ADDITIONAL_PARAGRAPH== NNPA has hardware limitations in dimension index size and tensor size, which are described in [NNPALimit.h](../src/Accelerators/NNPA/Conversion/ONNXToZHigh/NNPALimit.h). They are large enough for normal use cases, but if your model exceeds the limitations, CPU is used instead of NNPA. diff --git a/test/backend/inference_backend.py b/test/backend/inference_backend.py index 13f8e20cca..ba68092d77 100644 --- a/test/backend/inference_backend.py +++ b/test/backend/inference_backend.py @@ -74,6 +74,8 @@ def get_test_models(): # Elementary ops, ordered in the order they are found in # onnx-mlir/third_party/onnx/onnx/backend/test/case/node. + # When changing ==xxx== annotations, please run `make onnx_mlir_supported_ops` + # to re-actualize our support md pages. 
# ==ARCH== cpu # ==OP== Abs diff --git a/utils/RunONNXModel.py b/utils/RunONNXModel.py index fc55973a7c..062d2685e1 100755 --- a/utils/RunONNXModel.py +++ b/utils/RunONNXModel.py @@ -155,11 +155,13 @@ def check_non_negative(argname, value): " E.g. --upper-bound=int64:10,float32:0.2,uint8:9." " Supported types are bool, uint8, int8, uint16, int16, uint32, int32," " uint64, int64, float16, float32, float64") -parser.add_argument('--warmup', +parser.add_argument('-w', + '--warmup', type=lambda s: check_non_negative("--warmup", s), default=0, help="The number of warmup inference runs") -parser.add_argument('--n-iteration', +parser.add_argument('-n', + '--n-iteration', type=lambda s: check_positive("--n-iteration", s), default=1, help="The number of inference runs excluding warmup") @@ -495,6 +497,10 @@ def verify_outs(actual_outs, ref_outs): def warning(msg): print("Warning:", msg) +def data_without_top_bottom_quartile(data, percent): + data = np.array(sorted(data)) + trim = int(percent*data.size/100.0) + return data[trim:-trim] def main(): if not (args.model or args.load_so): @@ -646,14 +652,19 @@ def main(): end = time.perf_counter() elapsed = end - start perf_results += [elapsed] - print(" {} iteration: {} seconds".format(ordinal(i+1), elapsed)) + print(" {} iteration, {}, seconds".format(ordinal(i+1), elapsed)) # Print statistics info, e.g., min/max/stddev inference time. if args.n_iteration > 1 : - print(" Statistics (excluding warmup):" - " min {}, max {}, mean {}, stddev {}".format( + print(" Statistics 1 (excluding warmup)," + " min, {:.6e}, max, {:.6e}, mean, {:.6e}, stdev, {:.6e}".format( np.min(perf_results), np.max(perf_results), - np.mean(perf_results), np.std(perf_results, dtype=np.float64))) + np.mean(perf_results),np.std(perf_results, dtype=np.float64))) + t_perf_results = data_without_top_bottom_quartile(perf_results, 25) + print(" Statistics 2 (no warmup/quart.)," + " min, {:.6e}, max, {:.6e}, mean, {:.6e}, stdev, {:.6e}".format( + np.min(t_perf_results), np.max(t_perf_results), + np.mean(t_perf_results),np.std(t_perf_results, dtype=np.float64))) # Print the output if required. diff --git a/utils/analyze-simd.py b/utils/analyze-simd.py index ecec4b830c..4e844616e7 100755 --- a/utils/analyze-simd.py +++ b/utils/analyze-simd.py @@ -14,6 +14,7 @@ import re import io import subprocess +from pathlib import Path ################################################################################ # Usage. @@ -25,7 +26,7 @@ def print_usage(msg = ""): dprint("") if msg: dprint("ERROR: " + msg + "\n") - dprint("analyze-simd [-a ] (-c|-m|-o)+ [-n num] [-f pattern] [-dhlp] file") + dprint("analyze-simd [-t ] (-a|-c|-m|-o)+ [-n num] [-f pattern] [-dhlp] file") dprint(" Utility to analyze and print SIMD code located in functions") dprint("") dprint("Pattern:") @@ -59,7 +60,7 @@ def print_usage(msg = ""): print_code = False print_listing = False print_details = False -fct_match_str = r'^main_graph$' +fct_match_str = "" op_dict = {} aggr_dict = {} op_name = {} @@ -426,6 +427,12 @@ def main(argv): # All commands after the file name seems to be added here!!! 
print_usage("Need an single input file as last option: ", args, ".") filename = args[0] + + name_stub = Path(filename).stem + if not fct_match_str: + fct_match_str = r'^main_graph_' + name_stub + '$' + print("# search default function: main_graph with default tag \""+fct_match_str+"\"") + match_binary = re.match(r'(.*)\.so$', filename) if match_binary: asm_filename = match_binary.group(1) + ".s" @@ -433,6 +440,7 @@ def main(argv): dprint("# generate asm file with: " + cmd) ret = subprocess.call(cmd, shell=True) filename = asm_filename + scan_basic_blocks(filename) buff = scan_for_simd(filename, pattern, num) return