Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NNPA] Use device attribute to control device placement for ONNX operations #2510

Merged
merged 11 commits into from
Sep 21, 2023
13 changes: 7 additions & 6 deletions src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,11 @@ using namespace onnx_mlir;

namespace onnx_mlir {

void addONNXToZHighPasses(
mlir::PassManager &pm, ArrayRef<std::string> execNodesOnCpu) {
void addONNXToZHighPasses(mlir::PassManager &pm) {
for (unsigned i = 0; i < 3; i++) {
// Repeat this process so that shape-related ops such as Shape, Expand,
// Gather generated during RewriteONNXForZHigh will become constants.
pm.addPass(onnx_mlir::createRewriteONNXForZHighPass(execNodesOnCpu));
pm.addPass(onnx_mlir::createRewriteONNXForZHighPass());
// Simplify shape-related ops, including ShapeOp-to-DimOp replacement,
// constant propagation, shape inference and canonicalize.
pm.addPass(onnx_mlir::createSimplifyShapeRelatedOpsPass());
Expand All @@ -75,7 +74,7 @@ void addONNXToZHighPasses(
pm.addNestedPass<func::FuncOp>(
onnx_mlir::createInstrumentPass(instrumentOps, instrumentActions));

pm.addPass(onnx_mlir::createONNXToZHighPass(execNodesOnCpu));
pm.addPass(onnx_mlir::createONNXToZHighPass());
pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
// There are more opportunities for const propagation once all zhigh ops were
// generated.
Expand Down Expand Up @@ -150,12 +149,14 @@ void addPassesNNPA(mlir::OwningOpRef<mlir::ModuleOp> &module,
// InputIRLevelType inputIRLevel = determineInputIRLevel(module);

// LLVM_DEBUG(llvm::dbgs() << "Adding NNPA passes" << std::endl;);
if (emissionTarget >= EmitONNXIR)
if (emissionTarget >= EmitONNXIR) {
addONNXToMLIRPasses(pm, /*target CPU*/ maccel.empty());
pm.addPass(onnx_mlir::createDevicePlacementPass());
}

if (emissionTarget >= EmitMLIR) {
// Lower zAIU-compatible ONNX ops to ZHigh dialect where possible.
addONNXToZHighPasses(pm, execNodesOnCpu);
addONNXToZHighPasses(pm);

if (nnpaEmissionTarget >= EmitZHighIR)
emissionTarget = EmitMLIR;
Expand Down
15 changes: 15 additions & 0 deletions src/Accelerators/NNPA/Conversion/ONNXToZHigh/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,18 @@ add_onnx_mlir_library(OMZHighToONNX
ACCEL_INCLUDE_DIRS PRIVATE
${NNPA_INCLUDE_PATH}
)

# OMDevicePlacement builds the DevicePlacement pass, which annotates ONNX
# operations with a `device` attribute selecting CPU or NNPA execution.
# It links against the ONNX->ZHigh conversion libraries because it reuses
# their rewrite patterns in analysis mode.
add_onnx_mlir_library(OMDevicePlacement
DevicePlacement.cpp

# libzdnn provides the zDNN headers/symbols the NNPA conversions rely on.
DEPENDS
libzdnn

LINK_LIBS PUBLIC
OMONNXOps
OMONNXToZHigh
OMRewriteONNXForZHigh

ACCEL_INCLUDE_DIRS PRIVATE
${NNPA_INCLUDE_PATH}
)
129 changes: 129 additions & 0 deletions src/Accelerators/NNPA/Conversion/ONNXToZHigh/DevicePlacement.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===-------- DevicePlacement.cpp - Device Placement for NNPA -------------===//
//
// Copyright 2023 The IBM Research Authors.
//
// =============================================================================
//
// This pass is to set device (CPU, or NNPA) for each operation in ONNX level.
// Device placement can be decided by:
// - user configuration file if given
// - a cost model
//
//===----------------------------------------------------------------------===//

#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Debug.h"

#include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.hpp"
#include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp"
#include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Pass/Passes.hpp"

#define DEBUG_TYPE "device-placement"

using namespace mlir;
using namespace onnx_mlir;

namespace {

/// Module-level pass that decides, for each ONNX operation, whether it runs
/// on the host CPU or on the NNPA accelerator. The decision is recorded as a
/// `device` string attribute on the operation (see runOnOperation below).
struct DevicePlacementPass
: public PassWrapper<DevicePlacementPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DevicePlacementPass)

// Command-line name used to invoke this pass.
StringRef getArgument() const override { return "device-placement"; }

StringRef getDescription() const override {
return "Device placement for NNPA";
}

void runOnOperation() final;
};

void DevicePlacementPass::runOnOperation() {
using OpSetType = DenseSet<Operation *>;
ModuleOp module = getOperation();
MLIRContext *context = &getContext();

// Run the unknown dimension analysis to help check equality of unknown
// dimensions at compile time.
DimAnalysis dimAnalysis(module);
dimAnalysis.analyze();

// Cost model and user configuration file go here if it's given.
// (Reserved for cost model and user configuration file)

// Run patterns that converts ONNX to ZHigh with analysis mode to collect
// operations that are not converted. Those non-converted ops are running on
// the host instead of accelerator.
// Keep the order of calling pass synced with RewriteONNXForZHigh.cpp and
// ONNXToZHigh.cpp.

OpSetType legalizedOps1, legalizedOps2, legalizedOps3;

ConversionTarget target(*context);
target.addLegalDialect<ONNXDialect, func::FuncDialect, arith::ArithDialect>();

// Call RewriteONNXForZHigh pass.
RewritePatternSet Patterns1(context);
getRewriteONNXForZHighPatterns(Patterns1, &dimAnalysis);
getRewriteONNXForZHighDynamicallyLegal(&target, &dimAnalysis);
(void)applyAnalysisConversion(
module, target, std::move(Patterns1), legalizedOps1);

// Call ONNXToZHigh pass for lowering multiple ONNX ops at once to ZHigh.
// E.g. `onnx.ReLu (onnx.Conv)` to zhigh.Conv.
RewritePatternSet Patterns2(context);
getONNXToZHighOneOpPatterns(Patterns2);
(void)applyAnalysisConversion(
module, target, std::move(Patterns2), legalizedOps2);

// Call ONNXToZHigh pass for lowering a single ONNX op to ZHigh.
RewritePatternSet Patterns3(context);
getONNXToZHighOneOpPatterns(Patterns3);
getONNXToZHighOneOpDynamicallyLegal(&target, &dimAnalysis);
(void)applyAnalysisConversion(
module, target, std::move(Patterns3), legalizedOps3);

// Get the legalized ops that will run on the host.
OpSetType cpuOps = llvm::set_intersection<OpSetType, OpSetType>(
legalizedOps1, llvm::set_intersection<OpSetType, OpSetType>(
legalizedOps2, legalizedOps3));

// Now annotate accelerator operations in the IR with `device` attribute.
module.walk([&](Operation *op) -> WalkResult {
if (op->getDialect()->getNamespace() != ONNXDialect::getDialectNamespace())
return WalkResult::advance();
// No annotation for these ops.
if (isa<ONNXEntryPointOp, ONNXReturnOp, ONNXConstantOp>(op))
return WalkResult::advance();
// If `device` is already set, respect it.
StringAttr device = op->getAttrOfType<mlir::StringAttr>(DEVICE_ATTRIBUTE);
if (device && !device.getValue().empty())
return WalkResult::advance();
// Otherwise, set device.
if (!cpuOps.contains(op))
op->setAttr(DEVICE_ATTRIBUTE, StringAttr::get(context, NNPA_DEVICE));
return WalkResult::advance();
});
}

} // namespace

namespace onnx_mlir {

/*!
 * Create a DevicePlacement pass.
 * Factory used by the NNPA pipeline (see NNPACompilerUtils.cpp) to schedule
 * device placement right after the ONNX-to-MLIR passes.
 */
std::unique_ptr<mlir::Pass> createDevicePlacementPass() {
return std::make_unique<DevicePlacementPass>();
}

} // namespace onnx_mlir
95 changes: 47 additions & 48 deletions src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
//
//===----------------------------------------------------------------------===//

#include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.hpp"
#include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp"
#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp"
#include "src/Accelerators/NNPA/Pass/NNPAPasses.hpp"
Expand Down Expand Up @@ -257,21 +258,49 @@ struct ONNXToZHighLoweringPass
ONNXToZHighLoweringPass() = default;
ONNXToZHighLoweringPass(const ONNXToZHighLoweringPass &pass)
: PassWrapper<ONNXToZHighLoweringPass, OperationPass<ModuleOp>>() {}
ONNXToZHighLoweringPass(mlir::ArrayRef<std::string> execNodesOnCpu) {
this->execNodesOnCpu = execNodesOnCpu;
}
void runOnOperation() final;

public:
ListOption<std::string> execNodesOnCpu{*this, "execNodesOnCpu",
llvm::cl::desc("Comma-separated list of node names in an onnx graph. The "
"specified nodes are forced to run on the CPU instead of "
"using the zDNN. The node name is an optional attribute "
"in onnx graph, which is `onnx_node_name` in ONNX IR"),
llvm::cl::ZeroOrMore};
};
} // end anonymous namespace.

/// Collect the patterns that lower a single ONNX op to ZHigh: the table-gen
/// generated ones plus the hand-written enhanced Sum recursion pattern.
void getONNXToZHighOneOpPatterns(RewritePatternSet &patterns) {
populateWithGenerated(patterns);
patterns.insert<ONNXSumOpPatternEnhancedRecursion>(patterns.getContext());
}

/// Register, on `target`, the dynamic-legality callbacks for every ONNX op
/// that may be lowered to ZHigh. Each callback consults the per-op
/// isSuitableForZDNN check (see ONNXLegalityCheck.cpp) using `dimAnalysis`
/// to reason about unknown dimensions; an op deemed unsuitable stays legal
/// in ONNX and therefore runs on the CPU.
void getONNXToZHighOneOpDynamicallyLegal(
ConversionTarget *target, const DimAnalysis *dimAnalysis) {
addDynamicallyLegalOpFor<ONNXAddOp>(target, dimAnalysis);
addDynamicallyLegalOpFor<ONNXSubOp>(target, dimAnalysis);
addDynamicallyLegalOpFor<ONNXMulOp>(target, dimAnalysis);
addDynamicallyLegalOpFor<ONNXDivOp>(target, dimAnalysis);
addDynamicallyLegalOpFor<ONNXSumOp>(target, dimAnalysis);
addDynamicallyLegalOpFor<ONNXMinOp>(target, dimAnalysis);
addDynamicallyLegalOpFor<ONNXMaxOp>(target, dimAnalysis);
addDynamicallyLegalOpFor<ONNXReluOp>(target, dimAnalysis);
addDynamicallyLegalOpFor<ONNXTanhOp>(target, dimAnalysis);
addDynamicallyLegalOpFor<ONNXSigmoidOp>(target, dimAnalysis);
addDynamicallyLegalOpFor<ONNXLogOp>(target, dimAnalysis);
addDynamicallyLegalOpFor<ONNXExpOp>(target, dimAnalysis);
addDynamicallyLegalOpFor<ONNXSoftmaxOp>(target, dimAnalysis);
addDynamicallyLegalOpFor<ONNXMaxPoolSingleOutOp>(target, dimAnalysis);
addDynamicallyLegalOpFor<ONNXAveragePoolOp>(target, dimAnalysis);
addDynamicallyLegalOpFor<ONNXMatMulOp>(target, dimAnalysis);
addDynamicallyLegalOpFor<ONNXGemmOp>(target, dimAnalysis);
addDynamicallyLegalOpFor<ONNXReduceMeanV13Op>(target, dimAnalysis);
addDynamicallyLegalOpFor<ONNXLSTMOp>(target, dimAnalysis);
addDynamicallyLegalOpFor<ONNXGRUOp>(target, dimAnalysis);
addDynamicallyLegalOpFor<ONNXConvOp>(target, dimAnalysis);
}

/// Collect the patterns that fuse several ONNX ops into one ZHigh op,
/// e.g. MatMul followed by Add into a single zhigh.MatMul.
void getONNXToZHighMultipleOpPatterns(RewritePatternSet &patterns) {
MLIRContext *ctx = patterns.getContext();
patterns.insert<replaceONNXMatMulAddPattern1, replaceONNXMatMulAddPattern2,
replaceONNXReluConvPattern, replaceONNXLogSoftmaxPattern>(ctx);
}

void ONNXToZHighLoweringPass::runOnOperation() {
ModuleOp module = getOperation();

Expand All @@ -289,25 +318,25 @@ void ONNXToZHighLoweringPass::runOnOperation() {
target.addLegalDialect<ONNXDialect, zhigh::ZHighDialect, KrnlDialect,
func::FuncDialect, arith::ArithDialect>();

// NOTE: if we change the order of calling combinedPatterns and single op
// patterns, make sure to change the order in DevicePlacement.cpp also to make
// them synced.

// Combined ONNX ops to ZHigh lowering.
// There are some combinations of ONNX ops that can be lowering into a single
// ZHigh op, e.g. ONNXMatMul and ONNXAdd can be lowered to ZHighMatmul.
// The lowering of such combinations should be done before the lowering of
// a single ONNX Op, because the single op lowering might have conditions that
// prohibit the combined ops lowering happened.
RewritePatternSet combinedPatterns(&getContext());
combinedPatterns.insert<replaceONNXMatMulAddPattern1>(&getContext());
combinedPatterns.insert<replaceONNXMatMulAddPattern2>(&getContext());
combinedPatterns.insert<replaceONNXReluConvPattern>(&getContext());
combinedPatterns.insert<replaceONNXLogSoftmaxPattern>(&getContext());
onnx_mlir::getONNXToZHighMultipleOpPatterns(combinedPatterns);

// It's ok to fail.
(void)applyPatternsAndFoldGreedily(module, std::move(combinedPatterns));

// Single ONNX to ZHigh operation lowering.
RewritePatternSet patterns(&getContext());
populateWithGenerated(patterns);
patterns.insert<ONNXSumOpPatternEnhancedRecursion>(&getContext());
onnx_mlir::getONNXToZHighOneOpPatterns(patterns);

// This is to make sure we don't want to alloc any MemRef at this high-level
// representation.
Expand All @@ -317,32 +346,7 @@ void ONNXToZHighLoweringPass::runOnOperation() {
// ONNX ops to ZHigh dialect under specific conditions.
// When adding a new op, need to implement a method, i.e. isSuitableForZDNN,
// for the op in ONNXLegalityCheck.cpp.
addDynamicallyLegalOpFor<ONNXAddOp>(&target, &dimAnalysis, execNodesOnCpu);
addDynamicallyLegalOpFor<ONNXSubOp>(&target, &dimAnalysis, execNodesOnCpu);
addDynamicallyLegalOpFor<ONNXMulOp>(&target, &dimAnalysis, execNodesOnCpu);
addDynamicallyLegalOpFor<ONNXDivOp>(&target, &dimAnalysis, execNodesOnCpu);
addDynamicallyLegalOpFor<ONNXSumOp>(&target, &dimAnalysis, execNodesOnCpu);
addDynamicallyLegalOpFor<ONNXMinOp>(&target, &dimAnalysis, execNodesOnCpu);
addDynamicallyLegalOpFor<ONNXMaxOp>(&target, &dimAnalysis, execNodesOnCpu);
addDynamicallyLegalOpFor<ONNXReluOp>(&target, &dimAnalysis, execNodesOnCpu);
addDynamicallyLegalOpFor<ONNXTanhOp>(&target, &dimAnalysis, execNodesOnCpu);
addDynamicallyLegalOpFor<ONNXSigmoidOp>(
&target, &dimAnalysis, execNodesOnCpu);
addDynamicallyLegalOpFor<ONNXLogOp>(&target, &dimAnalysis, execNodesOnCpu);
addDynamicallyLegalOpFor<ONNXExpOp>(&target, &dimAnalysis, execNodesOnCpu);
addDynamicallyLegalOpFor<ONNXSoftmaxOp>(
&target, &dimAnalysis, execNodesOnCpu);
addDynamicallyLegalOpFor<ONNXMaxPoolSingleOutOp>(
&target, &dimAnalysis, execNodesOnCpu);
addDynamicallyLegalOpFor<ONNXAveragePoolOp>(
&target, &dimAnalysis, execNodesOnCpu);
addDynamicallyLegalOpFor<ONNXMatMulOp>(&target, &dimAnalysis, execNodesOnCpu);
addDynamicallyLegalOpFor<ONNXGemmOp>(&target, &dimAnalysis, execNodesOnCpu);
addDynamicallyLegalOpFor<ONNXReduceMeanV13Op>(
&target, &dimAnalysis, execNodesOnCpu);
addDynamicallyLegalOpFor<ONNXLSTMOp>(&target, &dimAnalysis, execNodesOnCpu);
addDynamicallyLegalOpFor<ONNXGRUOp>(&target, &dimAnalysis, execNodesOnCpu);
addDynamicallyLegalOpFor<ONNXConvOp>(&target, &dimAnalysis, execNodesOnCpu);
getONNXToZHighOneOpDynamicallyLegal(&target, &dimAnalysis);

// With the target and rewrite patterns defined, we can now attempt the
// conversion. The conversion will signal failure if any of our `illegal`
Expand All @@ -355,9 +359,4 @@ std::unique_ptr<Pass> createONNXToZHighPass() {
return std::make_unique<ONNXToZHighLoweringPass>();
}

std::unique_ptr<Pass> createONNXToZHighPass(
mlir::ArrayRef<std::string> execNodesOnCpu) {
return std::make_unique<ONNXToZHighLoweringPass>(execNodesOnCpu);
}

} // namespace onnx_mlir
32 changes: 32 additions & 0 deletions src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//====------ ONNXToZHigh.hpp - ONNX dialect to ZHigh lowering -------------===//
//
// Copyright 2019-2022 The IBM Research Authors.
//
// =============================================================================
//
// This file implements the lowering of ONNX operations to a combination of
// ONNX and ZHigh operations.
//
//===----------------------------------------------------------------------===//

#pragma once

#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "src/Dialect/ONNX/ONNXDimAnalysis.hpp"

namespace onnx_mlir {

// Exports ONNXtoZHigh patterns.
// Patterns that lower one ONNX op to its ZHigh equivalent.
void getONNXToZHighOneOpPatterns(mlir::RewritePatternSet &patterns);
// Patterns that fuse several ONNX ops into a single ZHigh op.
void getONNXToZHighMultipleOpPatterns(mlir::RewritePatternSet &patterns);

// Exports ONNXtoZHigh dynamically legal checks.
// Registers per-op legality callbacks on `target`; `dimAnalysis` is used to
// compare unknown dimensions at compile time.
void getONNXToZHighOneOpDynamicallyLegal(
mlir::ConversionTarget *target, const DimAnalysis *dimAnalysis);

} // namespace onnx_mlir
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#include "src/Dialect/ONNX/DialectBuilder.hpp"

using namespace mlir;
using namespace onnx_mlir;
namespace onnx_mlir {

/// Get transposed tensor by using a permutation array.
Value emitONNXTranspose(
Expand Down Expand Up @@ -67,3 +67,5 @@ ValueRange splitAlongAxis(
ValueRange splits = create.onnx.split(splitTy, X, splitSizes, axis);
return splits;
}

} // namespace onnx_mlir
Loading
Loading