
[NNPA] Use device attribute to control device placement for ONNX operations #2510

Status: merged (11 commits) on Sep 21, 2023.
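The PR replaces the `execNodesOnCpu` pass option with a per-op `device` string attribute ("cpu" or "nnpa") that controls device placement. As a rough illustration of how a frontend might use it — the `forceOpsToCpu` helper and `nodeNames` below are hypothetical, not part of this PR; only the attribute name and values come from the change:

```cpp
#include <string>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"

// Tag selected ONNX ops with device="cpu" so the compiler keeps them on CPU.
// Ops could likewise be tagged with device="nnpa" to force NNPA placement.
void forceOpsToCpu(mlir::ModuleOp module, llvm::ArrayRef<std::string> nodeNames) {
  mlir::OpBuilder b(module.getContext());
  module.walk([&](mlir::Operation *op) {
    auto nodeName = op->getAttrOfType<mlir::StringAttr>("onnx_node_name");
    if (nodeName && llvm::is_contained(nodeNames, nodeName.getValue().str()))
      op->setAttr("device", b.getStringAttr("cpu"));
  });
}
```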
9 changes: 4 additions & 5 deletions src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp
```diff
@@ -45,12 +45,11 @@ using namespace onnx_mlir;

 namespace onnx_mlir {
 
-void addONNXToZHighPasses(
-    mlir::PassManager &pm, ArrayRef<std::string> execNodesOnCpu) {
+void addONNXToZHighPasses(mlir::PassManager &pm) {
   for (unsigned i = 0; i < 3; i++) {
     // Repeat this process so that shape-related ops such as Shape, Expand, and
     // Gather generated during RewriteONNXForZHigh become constants.
-    pm.addPass(onnx_mlir::createRewriteONNXForZHighPass(execNodesOnCpu));
+    pm.addPass(onnx_mlir::createRewriteONNXForZHighPass());
     // Simplify shape-related ops, including ShapeOp-to-DimOp replacement,
     // constant propagation, shape inference, and canonicalization.
     pm.addPass(onnx_mlir::createSimplifyShapeRelatedOpsPass());

@@ -75,7 +74,7 @@ void addONNXToZHighPasses(
   pm.addNestedPass<func::FuncOp>(
       onnx_mlir::createInstrumentPass(instrumentOps, instrumentActions));
 
-  pm.addPass(onnx_mlir::createONNXToZHighPass(execNodesOnCpu));
+  pm.addPass(onnx_mlir::createONNXToZHighPass());
   pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
   // There are more opportunities for const propagation once all ZHigh ops are
   // generated.

@@ -155,7 +154,7 @@ void addPassesNNPA(mlir::OwningOpRef<mlir::ModuleOp> &module,
 
   if (emissionTarget >= EmitMLIR) {
     // Lower zAIU-compatible ONNX ops to ZHigh dialect where possible.
-    addONNXToZHighPasses(pm, execNodesOnCpu);
+    addONNXToZHighPasses(pm);
 
     if (nnpaEmissionTarget >= EmitZHighIR)
      emissionTarget = EmitMLIR;
```
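A minimal driver sketch of the effect of this change (assumed caller code, not part of the PR): the pipeline is now assembled without threading `execNodesOnCpu` through the pass constructors; placement is read from each op's `device` attribute instead.

```cpp
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

mlir::LogicalResult lowerToZHigh(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  onnx_mlir::addONNXToZHighPasses(pm); // no execNodesOnCpu parameter anymore
  return pm.run(module);
}
```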
63 changes: 21 additions & 42 deletions src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp
```diff
@@ -257,18 +257,7 @@ struct ONNXToZHighLoweringPass
   ONNXToZHighLoweringPass() = default;
   ONNXToZHighLoweringPass(const ONNXToZHighLoweringPass &pass)
       : PassWrapper<ONNXToZHighLoweringPass, OperationPass<ModuleOp>>() {}
-  ONNXToZHighLoweringPass(mlir::ArrayRef<std::string> execNodesOnCpu) {
-    this->execNodesOnCpu = execNodesOnCpu;
-  }
   void runOnOperation() final;
-
-public:
-  ListOption<std::string> execNodesOnCpu{*this, "execNodesOnCpu",
-      llvm::cl::desc("Comma-separated list of node names in an onnx graph. The "
-                     "specified nodes are forced to run on the CPU instead of "
-                     "using the zDNN. The node name is an optional attribute "
-                     "in onnx graph, which is `onnx_node_name` in ONNX IR"),
-      llvm::cl::ZeroOrMore};
 };
 } // end anonymous namespace.

@@ -317,32 +306,27 @@ void ONNXToZHighLoweringPass::runOnOperation() {
   // ONNX ops to ZHigh dialect under specific conditions.
   // When adding a new op, one needs to implement a method, i.e.
   // isSuitableForZDNN, for the op in ONNXLegalityCheck.cpp.
-  addDynamicallyLegalOpFor<ONNXAddOp>(&target, &dimAnalysis, execNodesOnCpu);
-  addDynamicallyLegalOpFor<ONNXSubOp>(&target, &dimAnalysis, execNodesOnCpu);
-  addDynamicallyLegalOpFor<ONNXMulOp>(&target, &dimAnalysis, execNodesOnCpu);
-  addDynamicallyLegalOpFor<ONNXDivOp>(&target, &dimAnalysis, execNodesOnCpu);
-  addDynamicallyLegalOpFor<ONNXSumOp>(&target, &dimAnalysis, execNodesOnCpu);
-  addDynamicallyLegalOpFor<ONNXMinOp>(&target, &dimAnalysis, execNodesOnCpu);
-  addDynamicallyLegalOpFor<ONNXMaxOp>(&target, &dimAnalysis, execNodesOnCpu);
-  addDynamicallyLegalOpFor<ONNXReluOp>(&target, &dimAnalysis, execNodesOnCpu);
-  addDynamicallyLegalOpFor<ONNXTanhOp>(&target, &dimAnalysis, execNodesOnCpu);
-  addDynamicallyLegalOpFor<ONNXSigmoidOp>(
-      &target, &dimAnalysis, execNodesOnCpu);
-  addDynamicallyLegalOpFor<ONNXLogOp>(&target, &dimAnalysis, execNodesOnCpu);
-  addDynamicallyLegalOpFor<ONNXExpOp>(&target, &dimAnalysis, execNodesOnCpu);
-  addDynamicallyLegalOpFor<ONNXSoftmaxOp>(
-      &target, &dimAnalysis, execNodesOnCpu);
-  addDynamicallyLegalOpFor<ONNXMaxPoolSingleOutOp>(
-      &target, &dimAnalysis, execNodesOnCpu);
-  addDynamicallyLegalOpFor<ONNXAveragePoolOp>(
-      &target, &dimAnalysis, execNodesOnCpu);
-  addDynamicallyLegalOpFor<ONNXMatMulOp>(&target, &dimAnalysis, execNodesOnCpu);
-  addDynamicallyLegalOpFor<ONNXGemmOp>(&target, &dimAnalysis, execNodesOnCpu);
-  addDynamicallyLegalOpFor<ONNXReduceMeanV13Op>(
-      &target, &dimAnalysis, execNodesOnCpu);
-  addDynamicallyLegalOpFor<ONNXLSTMOp>(&target, &dimAnalysis, execNodesOnCpu);
-  addDynamicallyLegalOpFor<ONNXGRUOp>(&target, &dimAnalysis, execNodesOnCpu);
-  addDynamicallyLegalOpFor<ONNXConvOp>(&target, &dimAnalysis, execNodesOnCpu);
+  addDynamicallyLegalOpFor<ONNXAddOp>(&target, &dimAnalysis);
+  addDynamicallyLegalOpFor<ONNXSubOp>(&target, &dimAnalysis);
+  addDynamicallyLegalOpFor<ONNXMulOp>(&target, &dimAnalysis);
+  addDynamicallyLegalOpFor<ONNXDivOp>(&target, &dimAnalysis);
+  addDynamicallyLegalOpFor<ONNXSumOp>(&target, &dimAnalysis);
+  addDynamicallyLegalOpFor<ONNXMinOp>(&target, &dimAnalysis);
+  addDynamicallyLegalOpFor<ONNXMaxOp>(&target, &dimAnalysis);
+  addDynamicallyLegalOpFor<ONNXReluOp>(&target, &dimAnalysis);
+  addDynamicallyLegalOpFor<ONNXTanhOp>(&target, &dimAnalysis);
+  addDynamicallyLegalOpFor<ONNXSigmoidOp>(&target, &dimAnalysis);
+  addDynamicallyLegalOpFor<ONNXLogOp>(&target, &dimAnalysis);
+  addDynamicallyLegalOpFor<ONNXExpOp>(&target, &dimAnalysis);
+  addDynamicallyLegalOpFor<ONNXSoftmaxOp>(&target, &dimAnalysis);
+  addDynamicallyLegalOpFor<ONNXMaxPoolSingleOutOp>(&target, &dimAnalysis);
+  addDynamicallyLegalOpFor<ONNXAveragePoolOp>(&target, &dimAnalysis);
+  addDynamicallyLegalOpFor<ONNXMatMulOp>(&target, &dimAnalysis);
+  addDynamicallyLegalOpFor<ONNXGemmOp>(&target, &dimAnalysis);
+  addDynamicallyLegalOpFor<ONNXReduceMeanV13Op>(&target, &dimAnalysis);
+  addDynamicallyLegalOpFor<ONNXLSTMOp>(&target, &dimAnalysis);
+  addDynamicallyLegalOpFor<ONNXGRUOp>(&target, &dimAnalysis);
+  addDynamicallyLegalOpFor<ONNXConvOp>(&target, &dimAnalysis);
 
   // With the target and rewrite patterns defined, we can now attempt the
   // conversion. The conversion will signal failure if any of our `illegal`

@@ -355,9 +339,4 @@ std::unique_ptr<Pass> createONNXToZHighPass() {
   return std::make_unique<ONNXToZHighLoweringPass>();
 }
 
-std::unique_ptr<Pass> createONNXToZHighPass(
-    mlir::ArrayRef<std::string> execNodesOnCpu) {
-  return std::make_unique<ONNXToZHighLoweringPass>(execNodesOnCpu);
-}
-
 } // namespace onnx_mlir
```
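The two-argument calls above rely on the new optional legality callback that `addDynamicallyLegalOpFor` now accepts (see ONNXToZHighCommon.hpp below). A sketch of overriding the default zDNN check with a custom policy — the setup mirrors `runOnOperation`, but `configureTarget` and the static-shape policy are illustrative assumptions:

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

void configureTarget(mlir::ConversionTarget &target,
                     const onnx_mlir::DimAnalysis &dimAnalysis) {
  // Returning true keeps the op "legal" as ONNX (it stays on the CPU path);
  // returning false lets it be lowered to ZHigh and run on the NNPA.
  onnx_mlir::addDynamicallyLegalOpFor<mlir::ONNXMatMulOp>(&target, &dimAnalysis,
      [](mlir::ONNXMatMulOp op, const onnx_mlir::DimAnalysis *) {
        auto aTy = op.getA().getType().cast<mlir::ShapedType>();
        auto bTy = op.getB().getType().cast<mlir::ShapedType>();
        // Illustrative policy: only offload fully static matmuls to the NNPA.
        return !(aTy.hasStaticShape() && bTy.hasStaticShape());
      });
}
```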
src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.cpp (filename inferred from the definitions below; the page omits it)

```diff
@@ -17,7 +17,7 @@
 #include "src/Dialect/ONNX/DialectBuilder.hpp"
 
 using namespace mlir;
-using namespace onnx_mlir;
+namespace onnx_mlir {
 
 /// Get transposed tensor by using a permutation array.
 Value emitONNXTranspose(

@@ -67,3 +67,5 @@ ValueRange splitAlongAxis(
   ValueRange splits = create.onnx.split(splitTy, X, splitSizes, axis);
   return splits;
 }
+
+} // namespace onnx_mlir
```
42 changes: 29 additions & 13 deletions src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp
```diff
@@ -14,31 +14,44 @@
 
 #pragma once
 
 #include "llvm/ADT/STLExtras.h"
 
 #include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/NNPALimit.h"
 #include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.hpp"
 #include "src/Accelerators/NNPA/Support/LayoutHelper.hpp"
 #include "src/Dialect/ONNX/ONNXDimAnalysis.hpp"
 
 namespace onnx_mlir {
 
+const std::string CPU_DEVICE = "cpu";
+const std::string NNPA_DEVICE = "nnpa";
+
 template <typename OP_TYPE>
 void addDynamicallyLegalOpFor(mlir::ConversionTarget *target,
     const onnx_mlir::DimAnalysis *dimAnalysis,
-    mlir::ArrayRef<std::string> execNodesOnCpu) {
-  target->addDynamicallyLegalOp<OP_TYPE>([dimAnalysis, execNodesOnCpu](
+    llvm::function_ref<bool(OP_TYPE, const DimAnalysis *)> checkLegalityFn =
+        nullptr) {
+  target->addDynamicallyLegalOp<OP_TYPE>([dimAnalysis, checkLegalityFn](
                                              OP_TYPE op) {
     // Check operations to be forced to run on CPU.
     mlir::Operation *genericOp = op.getOperation();
-    mlir::StringAttr nodeName =
-        genericOp->getAttrOfType<mlir::StringAttr>("onnx_node_name");
-    if (nodeName) {
-      bool exists =
-          llvm::any_of(execNodesOnCpu, [nodeName](llvm::StringRef val) {
-            return nodeName.getValue().equals_insensitive(val);
-          });
-      if (exists)
-        return true;
-    }
+    mlir::StringAttr device =
+        genericOp->getAttrOfType<mlir::StringAttr>("device");
+    // If device is CPU, force the op to run on CPU.
+    if (device && device.getValue().equals_insensitive(CPU_DEVICE))
+      return true;
+    // If device is NNPA, force the op to run on NNPA.
+    if (device && device.getValue().equals_insensitive(NNPA_DEVICE))
```
A review thread is attached to the `device=NNPA` branch:

Reviewer (Collaborator): I would suggest adding a legality check here.

Author (Collaborator): Do you mean something like `isNNPA() && isLegality()`? `device=NNPA` can mean either "force to NNPA" or "maybe good for NNPA"; I am OK going with either one.

Forcing to NNPA is convenient when we annotate an op with `device=NNPA` directly and we really want that op to go to NNPA despite compiler optimizations.

"Maybe good for NNPA" is safer when we use a cost model, since the cost model may make a mistake in assigning an op to NNPA (e.g., the op is not actually suitable for NNPA).

Author (Collaborator): Forcing to NNPA is also useful when we have dynamic shapes and we want an op to run on NNPA because the compiler cannot tell whether it is suitable for CPU or NNPA.

Reviewer (Collaborator): I should have been clearer:

```cpp
assert(isLegal(xxx) && "trying to force an op to NNPA that is not perceived as legal for NNPA");
```
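A sketch of folding the reviewer's suggestion into the `device=NNPA` branch of `addDynamicallyLegalOpFor` — an assumption about how the suggestion could be resolved, not necessarily the merged code; `isSuitableForZDNN` is the per-op check mentioned above as living in ONNXLegalityCheck.cpp:

```cpp
    // If device is NNPA, force the op to run on NNPA, but verify the op is
    // actually legal for the NNPA first (reviewer's suggestion).
    if (device && device.getValue().equals_insensitive(NNPA_DEVICE)) {
      assert(isSuitableForZDNN<OP_TYPE>(op, dimAnalysis) &&
             "trying to force an op to NNPA that is not perceived as legal "
             "for NNPA");
      return false; // "illegal" for the target, so it is lowered to ZHigh
    }
```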

The merged diff then continues:

```diff
+      return false;
+    // If device is empty, let the compiler make the decision.
+    assert((!device || (device && device.getValue().equals_insensitive(""))) &&
+           "Invalid device name");
+
+    // Use the user legality check if it's given.
+    if (checkLegalityFn)
+      return checkLegalityFn(op, dimAnalysis);
+
-    // Check zDNN limitations
+    // Check zDNN limitations for each input tensor.
     // TODO: Check tensor size NNPA_MAXIMUM_TENSOR_SIZE or another limitation.
     bool exceedLimit =
         llvm::any_of(genericOp->getOperands(), [](mlir::Value operand) {

@@ -53,6 +66,7 @@ void addDynamicallyLegalOpFor(mlir::ConversionTarget *target,
           }
           return false;
         });
+    // A tensor's size exceeds the zDNN limitations; run the op on CPU.
     if (exceedLimit)
       return true;

@@ -75,3 +89,5 @@ mlir::Value emitONNXTransposeWithType(mlir::Location loc,
 mlir::ValueRange splitAlongAxis(
     onnx_mlir::MultiDialectBuilder<onnx_mlir::OnnxBuilder> &create,
     mlir::Value X, int64_t axis);
+
+} // namespace onnx_mlir
```