forked from onnx/onnx-mlir
-
Notifications
You must be signed in to change notification settings - Fork 2
/
NNPAAccelerator.hpp
81 lines (69 loc) · 3.13 KB
/
NNPAAccelerator.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
/*
* SPDX-License-Identifier: Apache-2.0
*/
//===-------------------------- NNPAAccelerator.hpp ----------------------===//
//
// Copyright 2022 The IBM Research Authors.
//
// ===========================================================================
//
// Accelerator support for the IBM Telum coprocessor.
//
//===---------------------------------------------------------------------===//
#pragma once
#include "mlir/IR/BuiltinTypes.h"
#include "src/Accelerators/Accelerator.hpp"
namespace onnx_mlir {
namespace accel {

/// Singleton class to construct an NNPA accelerator.
///
/// The constructor is private and the copy operations are deleted, so the
/// only public way to obtain an object of this class is getInstance().
/// Apart from the two classof() helpers, every member is only declared in
/// this header; the definitions live elsewhere.
class NNPAAccelerator final : public Accelerator {
private:
  /// Storage for the singleton pointer. Presumably this is what
  /// getInstance() creates and returns; its definition is not in this header.
  static NNPAAccelerator *instance;

  /// Private so that client code cannot construct objects directly.
  NNPAAccelerator();

public:
  /// Singleton should not be clonable or assignable.
  NNPAAccelerator(const NNPAAccelerator &) = delete;
  NNPAAccelerator &operator=(const NNPAAccelerator &) = delete;

  ~NNPAAccelerator();

  /// Creates an instance on the first invocation. Subsequent invocations
  /// return the existing instance.
  static NNPAAccelerator *getInstance();

  /// Define classof to be able to use isa<>, cast<>, dyn_cast<>, etc.
  /// An arbitrary Accelerator is an NNPAAccelerator exactly when its kind is
  /// Accelerator::Kind::NNPA.
  static bool classof(const Accelerator *accel) {
    return accel->getKind() == Accelerator::Kind::NNPA;
  }

  /// Trivial overload: a pointer already typed as NNPAAccelerator always
  /// matches.
  static bool classof(const NNPAAccelerator *) { return true; }

  /// Returns the version number of this accelerator. The encoding of the
  /// value is not visible in this header.
  uint64_t getVersionNumber() const final;

  //===--------------------------------------------------------------------===//
  // Hooks for onnx-mlir driver
  //===--------------------------------------------------------------------===//

  /// Adds passes to the pass manager `pm` for `module`.
  /// NOTE(review): presumably these are the NNPA-specific passes selected
  /// according to `emissionTarget`; `outputNameNoExt` looks like the output
  /// file name without extension. Confirm against the definition.
  virtual void addPasses(mlir::OwningOpRef<mlir::ModuleOp> &module,
      mlir::PassManager &pm, onnx_mlir::EmissionTargetType &emissionTarget,
      std::string outputNameNoExt) const final;

  //===--------------------------------------------------------------------===//
  // Hooks for onnx-mlir-opt driver
  //===--------------------------------------------------------------------===//

  /// Registers dialects into `registry`. Which dialects are registered is
  /// decided by the out-of-line definition.
  virtual void registerDialects(mlir::DialectRegistry &registry) const final;

  /// Registers passes; `optLevel` is presumably the optimization level the
  /// passes are configured with (confirm against the definition).
  virtual void registerPasses(int optLevel) const final;

  //===--------------------------------------------------------------------===//
  // Hooks for onnx-to-krnl pass
  //===--------------------------------------------------------------------===//

  /// Maps a tensor type to the MemRef type used for it.
  virtual mlir::MemRefType convertTensorTypeToMemRefType(
      const mlir::TensorType tensorType) const final;

  /// Configures the conversion target used when lowering ONNX to Krnl.
  virtual void conversionTargetONNXToKrnl(
      mlir::ConversionTarget &target) const final;

  /// Adds rewrite patterns to `patterns` for the ONNX to Krnl lowering.
  virtual void rewritePatternONNXToKrnl(mlir::RewritePatternSet &patterns,
      mlir::TypeConverter &typeConverter, mlir::MLIRContext *ctx) const final;

  /// Returns the default allocation alignment for `tensorType`.
  /// NOTE(review): unit (bytes?) and the meaning of special values are not
  /// visible here; confirm against the definition.
  virtual int64_t getDefaultAllocAlignment(
      const mlir::TensorType tensorType) const final;

  //===--------------------------------------------------------------------===//
  // Hooks for krnl-to-llvm pass
  //===--------------------------------------------------------------------===//

  /// Configures the conversion target used when lowering Krnl to LLVM.
  virtual void conversionTargetKrnlToLLVM(
      mlir::ConversionTarget &target) const final;

  /// Adds rewrite patterns to `patterns` for the Krnl to LLVM lowering.
  virtual void rewritePatternKrnlToLLVM(mlir::RewritePatternSet &patterns,
      mlir::LLVMTypeConverter &typeConverter,
      mlir::MLIRContext *ctx) const final;
};

} // namespace accel
} // namespace onnx_mlir