Skip to content

Commit

Permalink
Add NNPA level compatibility check (#2533)
Browse files Browse the repository at this point in the history
* Add NNPA level compatibility check

Signed-off-by: Mike Essenmacher <[email protected]>

* Clang format changes

Signed-off-by: Mike Essenmacher <[email protected]>

* Clang format changes

Signed-off-by: Mike Essenmacher <[email protected]>

* Add mcpu to NNPA tests

Signed-off-by: Mike Essenmacher <[email protected]>

* Add link dependency to ONNXToZHigh for OMCompilerOptions

Signed-off-by: Mike Essenmacher <[email protected]>

* Update link dependency

Signed-off-by: Mike Essenmacher <[email protected]>

* Review updates

Signed-off-by: Mike Essenmacher <[email protected]>

* Review updates

Signed-off-by: Mike Essenmacher <[email protected]>

* Resolve conflicts

Signed-off-by: Mike Essenmacher <[email protected]>

* Add mcpu z16 to device_placement_pass_perf_model.mlir

Signed-off-by: Mike Essenmacher <[email protected]>

---------

Signed-off-by: Mike Essenmacher <[email protected]>
Signed-off-by: Mike Essenmacher <[email protected]>
  • Loading branch information
mikeessen authored Oct 3, 2023
1 parent c03bc49 commit 7592edf
Show file tree
Hide file tree
Showing 92 changed files with 219 additions and 93 deletions.
2 changes: 2 additions & 0 deletions src/Accelerators/NNPA/Conversion/ONNXToZHigh/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ add_onnx_mlir_library(OMONNXToZHigh
libzdnn

LINK_LIBS PUBLIC
OMCompilerOptions
OMONNXOps
OMONNXToKrnl
OMZHighOps
Expand All @@ -31,6 +32,7 @@ add_onnx_mlir_library(OMRewriteONNXForZHigh
libzdnn

LINK_LIBS PUBLIC
OMCompilerOptions
OMONNXOps
OMONNXToKrnl
OMZHighOps
Expand Down
5 changes: 4 additions & 1 deletion src/Accelerators/NNPA/Conversion/ONNXToZHigh/NNPALimit.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

//===----------------------- NNPALimit.h ----------------------------------===//
//
// Copyright 2022 The IBM Research Authors.
// Copyright 2022-2023 The IBM Research Authors.
//
// =============================================================================
//
Expand All @@ -28,3 +28,6 @@ static constexpr int64_t NNPA_MAXIMUM_TENSOR_SIZE = 4294967296;
// See zDNN API doc
static constexpr int64_t MAXIMUM_NUM_HIDDEN_SIZE_LSTM = 8192;
static constexpr int64_t MAXIMUM_NUM_HIDDEN_SIZE_GRU = 10880;

// The NNPA levels.
static constexpr const char *NNPA_Z16 = "z16";
98 changes: 98 additions & 0 deletions src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,37 @@

#include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.hpp"
#include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/NNPALimit.h"
#include "src/Compiler/CompilerOptions.hpp"
#include "src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp"
#include "src/Dialect/ONNX/ONNXDimAnalysis.hpp"
#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"

using namespace mlir;
using namespace onnx_mlir;

/// Convert the input NNPA level, e.g. "z16", to a floating point value
/// representing the level, e.g. 16.0.
/// \param inputNNPALevel level string of the form "<letter><number>".
/// \return the numeric part as a float, or 0 when the string is empty or the
///         numeric part is absent/unparsable.
float convertNNPALevel(std::string inputNNPALevel) {
  // A valid level has at least a leading letter and one digit, e.g. "z16".
  // Checking the size up front avoids std::out_of_range from substr(1) on an
  // empty string (the original code relied on try/catch for that case;
  // std::strtof itself never throws).
  if (inputNNPALevel.size() < 2)
    return 0;
  // Skip the leading letter and parse the numeric remainder. strtof returns
  // 0.0f when no conversion can be performed, which matches the documented
  // "unknown level" sentinel.
  return std::strtof(inputNNPALevel.c_str() + 1, nullptr);
}

/// Check whether the input NNPA level, e.g. "z16", is compatible with the
/// NNPA level selected through the compiler's --mcpu option.
/// \param inputNNPALevel the level required by an operation or pattern.
/// \return true when the required level is less than or equal to the level
///         implied by mcpu; false when both levels fail to parse (both 0).
bool isCompatibleWithNNPALevel(std::string inputNNPALevel) {
  float requestedLevel = convertNNPALevel(inputNNPALevel);
  float currentLevel = convertNNPALevel(mcpu);
  // Neither string yields a meaningful level: treat as incompatible rather
  // than accepting 0 <= 0.
  bool bothUnknown = (requestedLevel == 0) && (currentLevel == 0);
  return bothUnknown ? false : (requestedLevel <= currentLevel);
}

/// A function to check whether a value's element type is valid for zAIU or not.
/// zAIU supports only F16, F32 and BFLOAT. Since MLIR does not support BFLOAT,
/// we check F16 and F32 here only. zAIU only supports rank in range of (0, 4].
Expand Down Expand Up @@ -250,6 +274,9 @@ bool isSuitableForZDNN(OP_TYPE op, const DimAnalysis *dimAnalysis) {
template <>
bool isSuitableForZDNN<ONNXAddOp>(
ONNXAddOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return false;
if (!isValidElementTypeAndRank(op.getA()))
return false;
if (!isValidElementTypeAndRank(op.getB()))
Expand All @@ -261,6 +288,9 @@ bool isSuitableForZDNN<ONNXAddOp>(
template <>
bool isSuitableForZDNN<ONNXSubOp>(
ONNXSubOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return false;
if (!isValidElementTypeAndRank(op.getA()))
return false;
if (!isValidElementTypeAndRank(op.getB()))
Expand All @@ -272,6 +302,9 @@ bool isSuitableForZDNN<ONNXSubOp>(
template <>
bool isSuitableForZDNN<ONNXMulOp>(
ONNXMulOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return false;
if (!isValidElementTypeAndRank(op.getA()))
return false;
if (!isValidElementTypeAndRank(op.getB()))
Expand All @@ -283,6 +316,9 @@ bool isSuitableForZDNN<ONNXMulOp>(
template <>
bool isSuitableForZDNN<ONNXDivOp>(
ONNXDivOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return false;
if (!isValidElementTypeAndRank(op.getA()))
return false;
if (!isValidElementTypeAndRank(op.getB()))
Expand All @@ -294,6 +330,9 @@ bool isSuitableForZDNN<ONNXDivOp>(
template <>
bool isSuitableForZDNN<ONNXSumOp>(
ONNXSumOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return false;
// Do not support a single input.
if (op.getData_0().size() < 2)
return false;
Expand All @@ -316,6 +355,9 @@ bool isSuitableForZDNN<ONNXSumOp>(
template <>
bool isSuitableForZDNN<ONNXMinOp>(
ONNXMinOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return false;
int64_t opnum = op.getNumOperands();
if (opnum != 2) {
return false;
Expand All @@ -332,6 +374,9 @@ bool isSuitableForZDNN<ONNXMinOp>(
template <>
bool isSuitableForZDNN<ONNXMaxOp>(
ONNXMaxOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return false;
int64_t opnum = op.getNumOperands();
if (opnum != 2) {
return false;
Expand All @@ -349,6 +394,9 @@ bool isSuitableForZDNN<ONNXMaxOp>(
template <>
bool isSuitableForZDNN<ONNXSoftmaxOp>(
ONNXSoftmaxOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return false;
if (!isValidElementTypeAndRank(op.getInput()))
return false;
ShapedType inputType = op.getType().cast<ShapedType>();
Expand All @@ -363,6 +411,9 @@ bool isSuitableForZDNN<ONNXSoftmaxOp>(
template <>
bool isSuitableForZDNN<ONNXReluOp>(
ONNXReluOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return false;
if (!isValidElementTypeAndRank(op.getX()))
return false;
ShapedType xType = op.getX().getType().cast<ShapedType>();
Expand All @@ -373,6 +424,9 @@ bool isSuitableForZDNN<ONNXReluOp>(
template <>
bool isSuitableForZDNN<ONNXTanhOp>(
ONNXTanhOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return false;
if (!isValidElementTypeAndRank(op.getInput()))
return false;
ShapedType inputType = op.getType().cast<ShapedType>();
Expand All @@ -383,6 +437,9 @@ bool isSuitableForZDNN<ONNXTanhOp>(
template <>
bool isSuitableForZDNN<ONNXSigmoidOp>(
ONNXSigmoidOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return false;
if (!isValidElementTypeAndRank(op.getX()))
return false;
ShapedType xType = op.getX().getType().cast<ShapedType>();
Expand All @@ -393,6 +450,9 @@ bool isSuitableForZDNN<ONNXSigmoidOp>(
template <>
bool isSuitableForZDNN<ONNXLogOp>(
ONNXLogOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return false;
if (!isValidElementTypeAndRank(op.getInput()))
return false;
ShapedType inputType = op.getInput().getType().cast<ShapedType>();
Expand All @@ -403,6 +463,9 @@ bool isSuitableForZDNN<ONNXLogOp>(
template <>
bool isSuitableForZDNN<ONNXExpOp>(
ONNXExpOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return false;
if (!isValidElementTypeAndRank(op.getInput()))
return false;
ShapedType inputType = op.getInput().getType().cast<ShapedType>();
Expand All @@ -413,6 +476,9 @@ bool isSuitableForZDNN<ONNXExpOp>(
template <>
bool isSuitableForZDNN<ONNXMatMulOp>(
ONNXMatMulOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return false;
int64_t opnum = op.getNumOperands();
if (opnum != 2) {
return false;
Expand Down Expand Up @@ -467,6 +533,10 @@ bool isSuitableForZDNN<ONNXGemmOp>(
Value B = op.getB();
Value C = op.getC();

// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return false;

// Check data type.
if (!isValidElementTypeAndRank(A))
return false;
Expand Down Expand Up @@ -519,6 +589,10 @@ bool isSuitableForZDNN<ONNXGemmOp>(
template <>
bool isSuitableForZDNN<ONNXReduceMeanV13Op>(
ONNXReduceMeanV13Op op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return false;

// Check data type.
if (!isValidElementTypeAndRank(op.getData()))
return false;
Expand Down Expand Up @@ -560,6 +634,10 @@ bool isSuitableForZDNN<ONNXLSTMOp>(
Value R = op.getR();
Value B = op.getB();

// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return false;

// Check direction.
if ((direction != FORWARD) && (direction != REVERSE) &&
(direction != BIDIRECTIONAL))
Expand Down Expand Up @@ -635,6 +713,10 @@ bool isSuitableForZDNN<ONNXGRUOp>(
Value R = op.getR();
Value B = op.getB();

// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return false;

// Check direction.
if ((direction != FORWARD) && (direction != REVERSE) &&
(direction != BIDIRECTIONAL))
Expand Down Expand Up @@ -702,6 +784,10 @@ bool isSuitableForZDNN<ONNXGRUOp>(
template <>
bool isSuitableForZDNN<ONNXMaxPoolSingleOutOp>(
ONNXMaxPoolSingleOutOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return false;

// Check data type.
if (!isValidElementTypeAndRank(op.getX()))
return false;
Expand All @@ -725,6 +811,10 @@ bool isSuitableForZDNN<ONNXMaxPoolSingleOutOp>(
template <>
bool isSuitableForZDNN<ONNXAveragePoolOp>(
ONNXAveragePoolOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return false;

// Check data type.
if (!isValidElementTypeAndRank(op.getX()))
return false;
Expand Down Expand Up @@ -782,6 +872,10 @@ static bool checkConv2DParamRestrictions(int64_t inputDim, int64_t kernelDim,
template <>
bool isSuitableForZDNN<ONNXConvOp>(
ONNXConvOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return false;

// Check data type.
if (!isValidElementTypeAndRank(op.getX()))
return false;
Expand Down Expand Up @@ -864,6 +958,10 @@ bool isSuitableForZDNN<ONNXBatchNormalizationInferenceModeOp>(
ArrayRef<int64_t> shapeInput = inputType.getShape();
ArrayRef<int64_t> shapeOutput = outputType.getShape();

// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return false;

// 4D tensors(N x C x H x W) are supported as input and output.
if (shapeInput.size() != 4 || shapeOutput.size() != 4)
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

//===---------- ONNXLegalityCheck.hpp - Check legality for ONNX ops -------===//
//
// Copyright 2019-2020 The IBM Research Authors.
// Copyright 2019-2023 The IBM Research Authors.
//
// =============================================================================
//
Expand All @@ -26,6 +26,10 @@ template <typename OP_TYPE>
bool isSuitableForZDNN(
OP_TYPE op, const onnx_mlir::DimAnalysis *dimAnalysis = nullptr);

/// Check if the input NNPA level is compatible with the current NNPA
/// level.
bool isCompatibleWithNNPALevel(std::string inputNNPALevel);

/// Get padding type using shape helper. This returns
/// `SAME_PADDING`, `VALID_PADDING`, or empty.
template <typename OP, typename OPAdaptor, typename OPShapeHelper>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,9 @@ void getRewriteONNXForZHighDynamicallyLegal(
// broadcasting.
addDynamicallyLegalOpFor<ONNXAddOp>(
target, dimAnalysis, [](ONNXAddOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return true;
return !((isDefinedByONNXConstantOp(op.getA()) &&
isUniBroadcatableFirstToSecond(op.getA(), op.getB())) ||
(isDefinedByONNXConstantOp(op.getB()) &&
Expand All @@ -508,20 +511,29 @@ void getRewriteONNXForZHighDynamicallyLegal(
});
addDynamicallyLegalOpFor<ONNXDivOp>(
target, dimAnalysis, [](ONNXDivOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return true;
return !((isDefinedByONNXConstantOp(op.getA()) &&
isUniBroadcatableFirstToSecond(op.getA(), op.getB())) ||
(isDefinedByONNXConstantOp(op.getB()) &&
isUniBroadcatableFirstToSecond(op.getB(), op.getA())));
});
addDynamicallyLegalOpFor<ONNXMulOp>(
target, dimAnalysis, [](ONNXMulOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return true;
return !((isDefinedByONNXConstantOp(op.getA()) &&
isUniBroadcatableFirstToSecond(op.getA(), op.getB())) ||
(isDefinedByONNXConstantOp(op.getB()) &&
isUniBroadcatableFirstToSecond(op.getB(), op.getA())));
});
addDynamicallyLegalOpFor<ONNXSubOp>(
target, dimAnalysis, [](ONNXSubOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return true;
return !((isDefinedByONNXConstantOp(op.getA()) &&
isUniBroadcatableFirstToSecond(op.getA(), op.getB())) ||
(isDefinedByONNXConstantOp(op.getB()) &&
Expand All @@ -540,6 +552,9 @@ void getRewriteONNXForZHighDynamicallyLegal(
// one where N-D will become 3-D or to split MatMul into smaller MatMuls.
addDynamicallyLegalOpFor<ONNXMatMulOp>(
target, dimAnalysis, [](ONNXMatMulOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return true;
Type aType = op.getA().getType();
Type bType = op.getB().getType();
if (!isRankedShapedType(aType) || !isRankedShapedType(bType))
Expand Down Expand Up @@ -579,10 +594,14 @@ void getRewriteONNXForZHighDynamicallyLegal(
});

// Illegalize SoftmaxOp if
// - the NNPA level is not compatible, or
// - axis is the last dimension.
// This SoftmaxOp will be rewritten in which its input is reshaped to 3D.
addDynamicallyLegalOpFor<ONNXSoftmaxOp>(target, dimAnalysis,
[](ONNXSoftmaxOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return true;
Value input = op.getInput();
if (auto shapedType = input.getType().dyn_cast<RankedTensorType>()) {
if ((shapedType.getRank() > 3) &&
Expand Down
2 changes: 1 addition & 1 deletion test/mlir/accelerators/nnpa/analysis/dyn-dim-analysis.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: onnx-mlir-opt --maccel=NNPA --onnx-dim-analysis %s -split-input-file | FileCheck %s
// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --onnx-dim-analysis %s -split-input-file | FileCheck %s

// COM: test zdnn unary operations. Use Relu as a sample.
func.func @test_stick_unary_unstick(%arg0 : tensor<?x3x?xf32>) -> tensor<?x3x?xf32> {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: onnx-mlir-opt --device-placement --maccel=NNPA --split-input-file %s | FileCheck %s
// RUN: onnx-mlir-opt --device-placement --mcpu=z16 --maccel=NNPA --split-input-file %s | FileCheck %s

module attributes {llvm.data_layout = "E-m:e-i1:8:16-i8:8:16-i64:64-f128:64-v128:64-a:8:16-n32:64", llvm.target_triple = "s390x-ibm-linux", "onnx-mlir.symbol-postfix" = "model"} {
func.func @mnist(%arg0: tensor<1x1x28x28xf32>) -> tensor<1x10xf32> attributes {input_names = ["Input3"], output_names = ["Plus214_Output_0"]} {
Expand Down
Loading

0 comments on commit 7592edf

Please sign in to comment.