forked from onnx/onnx-mlir
-
Notifications
You must be signed in to change notification settings - Fork 2
/
NNPAAccelerator.cpp
170 lines (136 loc) · 5.75 KB
/
NNPAAccelerator.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
/*
* SPDX-License-Identifier: Apache-2.0
*/
//===-------------------------- NNPAAccelerator.cpp -----------------------===//
//
// Copyright 2022 The IBM Research Authors.
//
// =============================================================================
//
// Add accelerator support for the IBM Telum processor.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/Support/Debug.h"
#include "src/Accelerators/NNPA/Compiler/NNPACompilerUtils.hpp"
#include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.hpp"
#include "src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.hpp"
#include "src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVM.hpp"
#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp"
#include "src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.hpp"
#include "src/Accelerators/NNPA/NNPAAccelerator.hpp"
#include "src/Accelerators/NNPA/Pass/NNPAPasses.hpp"
#include "src/Accelerators/NNPA/Support/NNPALimit.h"
#include "src/Compiler/CompilerOptions.hpp"
#include "zdnn.h"
#include <memory>
#define DEBUG_TYPE "NNPAAccelerator"
extern llvm::cl::OptionCategory OMNNPAPassOptions;
namespace onnx_mlir {
namespace accel {
/// Factory hook used by the accelerator framework to obtain the NNPA
/// accelerator singleton.
Accelerator *createNNPA() {
  return NNPAAccelerator::getInstance();
}
// The process-wide singleton; created lazily by getInstance().
NNPAAccelerator *NNPAAccelerator::instance = nullptr;

/// Return the process-wide NNPA accelerator, creating it on first use.
/// NOTE(review): lazy initialization here is not thread-safe; presumably it
/// runs during single-threaded compiler setup — confirm with callers.
NNPAAccelerator *NNPAAccelerator::getInstance() {
  if (!instance)
    instance = new NNPAAccelerator();
  return instance;
}
/// Construct the NNPA accelerator and register it with the global list of
/// accelerator targets and the shared-library dependencies of the runtime.
NNPAAccelerator::NNPAAccelerator() : Accelerator(Accelerator::Kind::NNPA) {
  LLVM_DEBUG(llvm::dbgs() << "Creating an NNPA accelerator\n");
  // Print a warning if mcpu is not set or < z16: without a compatible target
  // level, no NNPA code will be generated.
  if (!isCompatibleWithNNPALevel(NNPA_Z16))
    llvm::outs() << "Warning: No NNPA code is generated because --mcpu is not "
                    "set or < z16.\n";
  acceleratorTargets.push_back(this);
  // Order is important! libRuntimeNNPA depends on libzdnn.
  addCompilerConfig(CCM_SHARED_LIB_DEPS, {"RuntimeNNPA", "zdnn"});
}
/// Destructor. Clears the singleton pointer so a later getInstance() does not
/// return a dangling pointer. The previous `delete instance;` was a bug: since
/// `instance == this` for the singleton, deleting it from inside its own
/// destructor recurses indefinitely (and double-frees the object being
/// destroyed). Ownership of the allocation belongs to whoever deletes the
/// accelerator, which is what triggered this destructor in the first place.
NNPAAccelerator::~NNPAAccelerator() { instance = nullptr; }
/// Report the zDNN library version number this accelerator was built against
/// (ZDNN_VERNUM comes from zdnn.h).
uint64_t NNPAAccelerator::getVersionNumber() const {
  return ZDNN_VERNUM;
}
// Append the NNPA compilation pipeline to `pm` for `module`.
// `emissionTarget` is passed by reference so the NNPA pipeline may adjust it;
// `outputNameNoExt` is the base output file name without extension.
// Delegates the actual pipeline construction to NNPACompilerUtils.
void NNPAAccelerator::addPasses(mlir::OwningOpRef<mlir::ModuleOp> &module,
    mlir::PassManager &pm, onnx_mlir::EmissionTargetType &emissionTarget,
    std::string outputNameNoExt) const {
  LLVM_DEBUG(llvm::dbgs() << "Adding passes for NNPA accelerator\n");
  addPassesNNPA(module, pm, emissionTarget, outputNameNoExt);
}
void NNPAAccelerator::registerDialects(mlir::DialectRegistry ®istry) const {
LLVM_DEBUG(llvm::dbgs() << "Registering dialects for NNPA accelerator\n");
registry.insert<zhigh::ZHighDialect>();
registry.insert<zlow::ZLowDialect>();
}
/// Register every NNPA-specific pass with the global MLIR pass registry so
/// they are available from the command line. `optLevel` is currently unused.
void NNPAAccelerator::registerPasses(int optLevel) const {
  LLVM_DEBUG(llvm::dbgs() << "Registering passes for NNPA accelerator\n");
  // ONNX-level preparation: decide per-op CPU/NNPA placement.
  mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
    return onnx_mlir::createDevicePlacementPass(nnpaLoadDevicePlacementFile,
        nnpaSaveDevicePlacementFile, nnpaPlacementHeuristic);
  });
  // Lowering between ONNX and the ZHigh dialect (both directions).
  mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
    return onnx_mlir::createONNXToZHighPass();
  });
  mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
    return onnx_mlir::createRewriteONNXForZHighPass();
  });
  mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
    return onnx_mlir::createZHighToONNXPass();
  });
  // ZLow-level cleanups.
  mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
    return onnx_mlir::zlow::createZLowRewritePass();
  });
  mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
    return onnx_mlir::zlow::createZLowDummyOpForMultiDerefPass();
  });
  mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
    return createFoldStdAllocPass();
  });
  // ZHigh-level optimizations.
  mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
    return onnx_mlir::zhigh::createZHighConstPropagationPass();
  });
  mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
    return onnx_mlir::zhigh::createZHighLayoutPropagationPass();
  });
  mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
    return onnx_mlir::zhigh::createZHighClipToDLFloatPass();
  });
}
/// Convert a ranked tensor type to its MemRef equivalent.
/// Tensors carrying a ZTensorEncodingAttr are converted through the
/// zTensor-specific lowering; all other tensors return a null MemRefType so
/// the caller falls back to the default conversion.
mlir::MemRefType NNPAAccelerator::convertTensorTypeToMemRefType(
    const mlir::TensorType tensorType) const {
  assert(tensorType.hasRank() && "expected only ranked shapes");
  mlir::Attribute encoding =
      tensorType.cast<mlir::RankedTensorType>().getEncoding();
  if (encoding.dyn_cast_or_null<onnx_mlir::zhigh::ZTensorEncodingAttr>())
    return onnx_mlir::zhigh::convertZTensorToMemRefType(tensorType).value;
  return nullptr;
}
/// Return the default allocation alignment for `tensorType`.
/// zTensors (those carrying a ZTensorEncodingAttr) use the NNPA alignment
/// constant `gAlignment`; -1 tells the caller no special alignment is needed.
int64_t NNPAAccelerator::getDefaultAllocAlignment(
    const mlir::TensorType tensorType) const {
  assert(tensorType.hasRank() && "expected only ranked shapes");
  mlir::Attribute encoding =
      tensorType.cast<mlir::RankedTensorType>().getEncoding();
  if (encoding.dyn_cast_or_null<onnx_mlir::zhigh::ZTensorEncodingAttr>())
    return gAlignment;
  return -1;
}
// During ONNX-to-Krnl lowering, mark the ZLow dialect as legal so ops
// produced by the accelerator's conversion patterns are accepted by the
// conversion target.
void NNPAAccelerator::conversionTargetONNXToKrnl(
    mlir::ConversionTarget &target) const {
  target.addLegalDialect<zlow::ZLowDialect>();
}
// Contribute the ZHigh-to-ZLow lowering patterns to the ONNX-to-Krnl
// conversion pattern set.
void NNPAAccelerator::rewritePatternONNXToKrnl(
    mlir::RewritePatternSet &patterns, mlir::TypeConverter &typeConverter,
    mlir::MLIRContext *ctx) const {
  onnx_mlir::zhigh::populateZHighToZLowConversionPattern(
      patterns, typeConverter, ctx);
}
// The NNPA accelerator adds no extra legality constraints for the
// Krnl-to-LLVM conversion; intentionally a no-op.
void NNPAAccelerator::conversionTargetKrnlToLLVM(
    mlir::ConversionTarget &target) const {}
// Contribute the ZLow-to-LLVM lowering patterns to the Krnl-to-LLVM
// conversion pattern set.
void NNPAAccelerator::rewritePatternKrnlToLLVM(
    mlir::RewritePatternSet &patterns, mlir::LLVMTypeConverter &typeConverter,
    mlir::MLIRContext *ctx) const {
  onnx_mlir::zlow::populateZLowToLLVMConversionPattern(
      patterns, typeConverter, ctx);
}
} // namespace accel
} // namespace onnx_mlir