Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NNPA] Use device attribute to control device placement for ONNX operations #2510

Merged
merged 11 commits into from
Sep 21, 2023
9 changes: 4 additions & 5 deletions src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,11 @@ using namespace onnx_mlir;

namespace onnx_mlir {

void addONNXToZHighPasses(
mlir::PassManager &pm, ArrayRef<std::string> execNodesOnCpu) {
void addONNXToZHighPasses(mlir::PassManager &pm) {
for (unsigned i = 0; i < 3; i++) {
// Repeat this process so that shape-related ops such as Shape, Expand,
// Gather generated during RewriteONNXForZHigh will become constants.
pm.addPass(onnx_mlir::createRewriteONNXForZHighPass(execNodesOnCpu));
pm.addPass(onnx_mlir::createRewriteONNXForZHighPass());
// Simplify shape-related ops, including ShapeOp-to-DimOp replacement,
// constant propagation, shape inference and canonicalize.
pm.addPass(onnx_mlir::createSimplifyShapeRelatedOpsPass());
Expand All @@ -75,7 +74,7 @@ void addONNXToZHighPasses(
pm.addNestedPass<func::FuncOp>(
onnx_mlir::createInstrumentPass(instrumentOps, instrumentActions));

pm.addPass(onnx_mlir::createONNXToZHighPass(execNodesOnCpu));
pm.addPass(onnx_mlir::createONNXToZHighPass());
pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
// There are more opportunities for const propagation once all zhigh ops were
// generated.
Expand Down Expand Up @@ -155,7 +154,7 @@ void addPassesNNPA(mlir::OwningOpRef<mlir::ModuleOp> &module,

if (emissionTarget >= EmitMLIR) {
// Lower zAIU-compatible ONNX ops to ZHigh dialect where possible.
addONNXToZHighPasses(pm, execNodesOnCpu);
addONNXToZHighPasses(pm);

if (nnpaEmissionTarget >= EmitZHighIR)
emissionTarget = EmitMLIR;
Expand Down
63 changes: 21 additions & 42 deletions src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,18 +257,7 @@ struct ONNXToZHighLoweringPass
ONNXToZHighLoweringPass() = default;
ONNXToZHighLoweringPass(const ONNXToZHighLoweringPass &pass)
: PassWrapper<ONNXToZHighLoweringPass, OperationPass<ModuleOp>>() {}
ONNXToZHighLoweringPass(mlir::ArrayRef<std::string> execNodesOnCpu) {
this->execNodesOnCpu = execNodesOnCpu;
}
void runOnOperation() final;

public:
ListOption<std::string> execNodesOnCpu{*this, "execNodesOnCpu",
llvm::cl::desc("Comma-separated list of node names in an onnx graph. The "
"specified nodes are forced to run on the CPU instead of "
"using the zDNN. The node name is an optional attribute "
"in onnx graph, which is `onnx_node_name` in ONNX IR"),
llvm::cl::ZeroOrMore};
};
} // end anonymous namespace.

Expand Down Expand Up @@ -317,32 +306,27 @@ void ONNXToZHighLoweringPass::runOnOperation() {
// ONNX ops to ZHigh dialect under specific conditions.
// When adding a new op, need to implement a method, i.e. isSuitableForZDNN,
// for the op in ONNXLegalityCheck.cpp.
addDynamicallyLegalOpFor<ONNXAddOp>(&target, &dimAnalysis, execNodesOnCpu);
addDynamicallyLegalOpFor<ONNXSubOp>(&target, &dimAnalysis, execNodesOnCpu);
addDynamicallyLegalOpFor<ONNXMulOp>(&target, &dimAnalysis, execNodesOnCpu);
addDynamicallyLegalOpFor<ONNXDivOp>(&target, &dimAnalysis, execNodesOnCpu);
addDynamicallyLegalOpFor<ONNXSumOp>(&target, &dimAnalysis, execNodesOnCpu);
addDynamicallyLegalOpFor<ONNXMinOp>(&target, &dimAnalysis, execNodesOnCpu);
addDynamicallyLegalOpFor<ONNXMaxOp>(&target, &dimAnalysis, execNodesOnCpu);
addDynamicallyLegalOpFor<ONNXReluOp>(&target, &dimAnalysis, execNodesOnCpu);
addDynamicallyLegalOpFor<ONNXTanhOp>(&target, &dimAnalysis, execNodesOnCpu);
addDynamicallyLegalOpFor<ONNXSigmoidOp>(
&target, &dimAnalysis, execNodesOnCpu);
addDynamicallyLegalOpFor<ONNXLogOp>(&target, &dimAnalysis, execNodesOnCpu);
addDynamicallyLegalOpFor<ONNXExpOp>(&target, &dimAnalysis, execNodesOnCpu);
addDynamicallyLegalOpFor<ONNXSoftmaxOp>(
&target, &dimAnalysis, execNodesOnCpu);
addDynamicallyLegalOpFor<ONNXMaxPoolSingleOutOp>(
&target, &dimAnalysis, execNodesOnCpu);
addDynamicallyLegalOpFor<ONNXAveragePoolOp>(
&target, &dimAnalysis, execNodesOnCpu);
addDynamicallyLegalOpFor<ONNXMatMulOp>(&target, &dimAnalysis, execNodesOnCpu);
addDynamicallyLegalOpFor<ONNXGemmOp>(&target, &dimAnalysis, execNodesOnCpu);
addDynamicallyLegalOpFor<ONNXReduceMeanV13Op>(
&target, &dimAnalysis, execNodesOnCpu);
addDynamicallyLegalOpFor<ONNXLSTMOp>(&target, &dimAnalysis, execNodesOnCpu);
addDynamicallyLegalOpFor<ONNXGRUOp>(&target, &dimAnalysis, execNodesOnCpu);
addDynamicallyLegalOpFor<ONNXConvOp>(&target, &dimAnalysis, execNodesOnCpu);
addDynamicallyLegalOpFor<ONNXAddOp>(&target, &dimAnalysis);
addDynamicallyLegalOpFor<ONNXSubOp>(&target, &dimAnalysis);
addDynamicallyLegalOpFor<ONNXMulOp>(&target, &dimAnalysis);
addDynamicallyLegalOpFor<ONNXDivOp>(&target, &dimAnalysis);
addDynamicallyLegalOpFor<ONNXSumOp>(&target, &dimAnalysis);
addDynamicallyLegalOpFor<ONNXMinOp>(&target, &dimAnalysis);
addDynamicallyLegalOpFor<ONNXMaxOp>(&target, &dimAnalysis);
addDynamicallyLegalOpFor<ONNXReluOp>(&target, &dimAnalysis);
addDynamicallyLegalOpFor<ONNXTanhOp>(&target, &dimAnalysis);
addDynamicallyLegalOpFor<ONNXSigmoidOp>(&target, &dimAnalysis);
addDynamicallyLegalOpFor<ONNXLogOp>(&target, &dimAnalysis);
addDynamicallyLegalOpFor<ONNXExpOp>(&target, &dimAnalysis);
addDynamicallyLegalOpFor<ONNXSoftmaxOp>(&target, &dimAnalysis);
addDynamicallyLegalOpFor<ONNXMaxPoolSingleOutOp>(&target, &dimAnalysis);
addDynamicallyLegalOpFor<ONNXAveragePoolOp>(&target, &dimAnalysis);
addDynamicallyLegalOpFor<ONNXMatMulOp>(&target, &dimAnalysis);
addDynamicallyLegalOpFor<ONNXGemmOp>(&target, &dimAnalysis);
addDynamicallyLegalOpFor<ONNXReduceMeanV13Op>(&target, &dimAnalysis);
addDynamicallyLegalOpFor<ONNXLSTMOp>(&target, &dimAnalysis);
addDynamicallyLegalOpFor<ONNXGRUOp>(&target, &dimAnalysis);
addDynamicallyLegalOpFor<ONNXConvOp>(&target, &dimAnalysis);

// With the target and rewrite patterns defined, we can now attempt the
// conversion. The conversion will signal failure if any of our `illegal`
Expand All @@ -355,9 +339,4 @@ std::unique_ptr<Pass> createONNXToZHighPass() {
return std::make_unique<ONNXToZHighLoweringPass>();
}

std::unique_ptr<Pass> createONNXToZHighPass(
mlir::ArrayRef<std::string> execNodesOnCpu) {
return std::make_unique<ONNXToZHighLoweringPass>(execNodesOnCpu);
}

} // namespace onnx_mlir
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#include "src/Dialect/ONNX/DialectBuilder.hpp"

using namespace mlir;
using namespace onnx_mlir;
namespace onnx_mlir {

/// Get transposed tensor by using a permutation array.
Value emitONNXTranspose(
Expand Down Expand Up @@ -67,3 +67,5 @@ ValueRange splitAlongAxis(
ValueRange splits = create.onnx.split(splitTy, X, splitSizes, axis);
return splits;
}

} // namespace onnx_mlir
91 changes: 62 additions & 29 deletions src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,49 +14,80 @@

#pragma once

#include "llvm/ADT/STLExtras.h"

#include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/NNPALimit.h"
#include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.hpp"
#include "src/Accelerators/NNPA/Support/LayoutHelper.hpp"
#include "src/Dialect/ONNX/ONNXDimAnalysis.hpp"

namespace onnx_mlir {

const std::string CPU_DEVICE = "cpu";
const std::string NNPA_DEVICE = "nnpa";

template <typename OP_TYPE>
void addDynamicallyLegalOpFor(mlir::ConversionTarget *target,
const onnx_mlir::DimAnalysis *dimAnalysis,
mlir::ArrayRef<std::string> execNodesOnCpu) {
target->addDynamicallyLegalOp<OP_TYPE>([dimAnalysis, execNodesOnCpu](
llvm::function_ref<bool(OP_TYPE, const DimAnalysis *)> checkLegalityFn =
nullptr) {
target->addDynamicallyLegalOp<OP_TYPE>([dimAnalysis, checkLegalityFn](
OP_TYPE op) {
// Check operations to be forced to run on CPU.
mlir::Operation *genericOp = op.getOperation();
mlir::StringAttr nodeName =
genericOp->getAttrOfType<mlir::StringAttr>("onnx_node_name");
if (nodeName) {
bool exists =
llvm::any_of(execNodesOnCpu, [nodeName](llvm::StringRef val) {
return nodeName.getValue().equals_insensitive(val);
mlir::StringAttr device =
genericOp->getAttrOfType<mlir::StringAttr>("device");
assert((!device ||
(device &&
(device.getValue().equals_insensitive("") ||
device.getValue().equals_insensitive(CPU_DEVICE) ||
device.getValue().equals_insensitive(NNPA_DEVICE)))) &&
"Invalid device name");

// If device is CPU, force to run the op on CPU.
if (device && device.getValue().equals_insensitive(CPU_DEVICE))
return true;

// If not CPU, check if the op is legal for NNPA.
bool isLegalForNNPA = false;
if (checkLegalityFn)
isLegalForNNPA = !checkLegalityFn(op, dimAnalysis);
else {
// Check zDNN limitations for each input tensors.
// TODO: Check tensor size NNPA_MAXIMUM_TENSOR_SIZE of another limitation
bool exceedLimit =
llvm::any_of(genericOp->getOperands(), [](mlir::Value operand) {
if (auto valueType =
operand.getType().dyn_cast<mlir::ShapedType>()) {
// Check if static dimension size exceeds zDNN limitations
llvm::ArrayRef<int64_t> valueShape = valueType.getShape();
if (llvm::any_of(valueShape, [](int64_t dim) {
return (!mlir::ShapedType::isDynamic(dim)) &&
(dim > NNPA_MAXIMUM_DIMENSION_INDEX_SIZE);
}))
return true;
}
return false;
});
if (exists)
return true;
isLegalForNNPA =
!exceedLimit && isSuitableForZDNN<OP_TYPE>(op, dimAnalysis);
}

// Check zDNN limitations
// TODO: Check tensor size NNPA_MAXIMUM_TENSOR_SIZE of another limitation
bool exceedLimit =
llvm::any_of(genericOp->getOperands(), [](mlir::Value operand) {
if (auto valueType = operand.getType().dyn_cast<mlir::ShapedType>()) {
// Check if static dimension size exceeds zDNN limitations
llvm::ArrayRef<int64_t> valueShape = valueType.getShape();
if (llvm::any_of(valueShape, [](int64_t dim) {
return (!mlir::ShapedType::isDynamic(dim)) &&
(dim > NNPA_MAXIMUM_DIMENSION_INDEX_SIZE);
}))
return true;
}
return false;
});
if (exceedLimit)
return true;
// Users specified NNPA device of an op, but the compiler found the op is
// not legal for NNPA, e.g. in case of dynamic shape. In this case, print
// out a warning message.
if (device && device.getValue().equals_insensitive(NNPA_DEVICE) &&
!isLegalForNNPA) {
llvm::outs() << "Warning: though the following operation was specified "
"to run on NNPA, the compiler found that NNPA did not "
"support that operation. It's potentially that the "
"compiler was not able to check broadcasting in case of "
"dynamic shape so that it thought the operation was not "
"legal for NNPA.\n";
op.dump();
return false;
}

return !isSuitableForZDNN<OP_TYPE>(op, dimAnalysis);
return !isLegalForNNPA;
});
}

Expand All @@ -75,3 +106,5 @@ mlir::Value emitONNXTransposeWithType(mlir::Location loc,
mlir::ValueRange splitAlongAxis(
onnx_mlir::MultiDialectBuilder<onnx_mlir::OnnxBuilder> &create,
mlir::Value X, int64_t axis);

} // namespace onnx_mlir
Loading
Loading