forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
init.h
179 lines (162 loc) · 6.34 KB
/
init.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
#ifndef CAFFE2_CORE_INIT_H_
#define CAFFE2_CORE_INIT_H_
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "caffe2/core/common.h"
#include "caffe2/core/flags.h"
#include "caffe2/core/logging.h"
namespace caffe2 {
namespace internal {
class TORCH_API Caffe2InitializeRegistry {
public:
typedef bool (*InitFunction)(int*, char***);
// Registry() is defined in .cpp file to make registration work across
// multiple shared libraries loaded with RTLD_LOCAL
static Caffe2InitializeRegistry* Registry();
void Register(
InitFunction function,
bool run_early,
const char* description,
const char* name = nullptr) {
if (name) {
named_functions_[name] = function;
}
if (run_early) {
// Disallow registration after GlobalInit of early init functions
CAFFE_ENFORCE(!early_init_functions_run_yet_);
early_init_functions_.emplace_back(function, description);
} else {
if (init_functions_run_yet_) {
// Run immediately, since GlobalInit already ran. This should be
// rare but we want to allow it in some cases.
LOG(WARNING) << "Running init function after GlobalInit: "
<< description;
// TODO(orionr): Consider removing argc and argv for non-early
// registration. Unfortunately that would require a new InitFunction
// typedef, so not making the change right now.
//
// Note that init doesn't receive argc and argv, so the function
// might fail and we want to raise an error in that case.
int argc = 0;
char** argv = nullptr;
bool success = (function)(&argc, &argv);
CAFFE_ENFORCE(success);
} else {
// Wait until GlobalInit to run
init_functions_.emplace_back(function, description);
}
}
}
bool RunRegisteredEarlyInitFunctions(int* pargc, char*** pargv) {
CAFFE_ENFORCE(!early_init_functions_run_yet_);
early_init_functions_run_yet_ = true;
return RunRegisteredInitFunctionsInternal(
early_init_functions_, pargc, pargv);
}
bool RunRegisteredInitFunctions(int* pargc, char*** pargv) {
CAFFE_ENFORCE(!init_functions_run_yet_);
init_functions_run_yet_ = true;
return RunRegisteredInitFunctionsInternal(init_functions_, pargc, pargv);
}
bool RunNamedFunction(const char* name, int* pargc, char*** pargv) {
if (named_functions_.count(name)) {
return named_functions_[name](pargc, pargv);
}
return false;
}
private:
// Run all registered initialization functions. This has to be called AFTER
// all static initialization are finished and main() has started, since we are
// using logging.
bool RunRegisteredInitFunctionsInternal(
vector<std::pair<InitFunction, const char*>>& functions,
int* pargc, char*** pargv) {
for (const auto& init_pair : functions) {
VLOG(1) << "Running init function: " << init_pair.second;
if (!(*init_pair.first)(pargc, pargv)) {
LOG(ERROR) << "Initialization function failed.";
return false;
}
}
return true;
}
Caffe2InitializeRegistry() {}
vector<std::pair<InitFunction, const char*> > early_init_functions_;
vector<std::pair<InitFunction, const char*> > init_functions_;
std::unordered_map<std::string, InitFunction> named_functions_;
bool early_init_functions_run_yet_ = false;
bool init_functions_run_yet_ = false;
};
} // namespace internal
/**
 * @brief Declaration only; the definition is not in this header.
 *
 * NOTE(review): judging by the name and signature, this runs a single
 * registered init function looked up by name -- presumably by forwarding to
 * internal::Caffe2InitializeRegistry::RunNamedFunction() -- without going
 * through GlobalInit. Confirm against the definition.
 *
 * @param name   Name of the init function. For functions registered through
 *               the REGISTER_CAFFE2_*INIT_FUNCTION macros the registered
 *               name is the stringized `name` macro argument.
 * @param pargc  Pointer to argc; defaults to nullptr.
 * @param pargv  Pointer to argv; defaults to nullptr.
 * @return NOTE(review): presumably the init function's result, and false
 *         when no function is registered under `name` -- confirm.
 */
TORCH_API bool unsafeRunCaffe2InitFunction(
const char* name,
int* pargc = nullptr,
char*** pargv = nullptr);
/**
 * @brief Helper whose constructor registers an init function.
 *
 * Constructing an InitRegisterer forwards all four arguments to
 * Register() on the registry returned by
 * internal::Caffe2InitializeRegistry::Registry(). The object itself holds
 * no state.
 */
class TORCH_API InitRegisterer {
 public:
  InitRegisterer(
      internal::Caffe2InitializeRegistry::InitFunction function,
      bool run_early,
      const char* description,
      const char* name = nullptr) {
    auto* const registry = internal::Caffe2InitializeRegistry::Registry();
    registry->Register(function, run_early, description, name);
  }
};
// Registers `function` as a regular (non-early) init function.
//
// Expands to an anonymous-namespace ::caffe2::InitRegisterer object named
// g_caffe2_initregisterer_<name>, constructed with run_early == false and
// with the stringized `name` as the registered name. Because `name` is
// token-pasted into an identifier, it must be a valid identifier suffix.
// `description` is passed through as the human-readable text for logging.
#define REGISTER_CAFFE2_INIT_FUNCTION(name, function, description) \
namespace { \
::caffe2::InitRegisterer \
g_caffe2_initregisterer_##name(function, false, description, #name); \
} // namespace
// Registers `function` as an early init function.
//
// Same expansion as REGISTER_CAFFE2_INIT_FUNCTION, except that the
// ::caffe2::InitRegisterer object is constructed with run_early == true.
// The registry enforces (CAFFE_ENFORCE) that early registrations happen
// before the early init functions have been run.
#define REGISTER_CAFFE2_EARLY_INIT_FUNCTION(name, function, description) \
namespace { \
::caffe2::InitRegisterer \
g_caffe2_initregisterer_##name(function, true, description, #name); \
} // namespace
/**
 * @brief Determine whether GlobalInit has already been run
 *
 * @return true if GlobalInit has already been run, false otherwise.
 *
 * NOTE(review): the definition is not in this header, so it is not visible
 * here whether "run" means "started" or "completed successfully" -- confirm.
 */
TORCH_API bool GlobalInitAlreadyRun();
/**
 * @brief Guard that warns when it is constructed before GlobalInit has run.
 *
 * The constructor queries GlobalInitAlreadyRun() and logs a warning if it
 * reports false. It only logs; it does not abort or throw by itself.
 */
class TORCH_API GlobalInitIsCalledGuard {
 public:
  GlobalInitIsCalledGuard() {
    if (GlobalInitAlreadyRun()) {
      return;
    }
    LOG(WARNING)
        << "Caffe2 GlobalInit should be run before any other API calls.";
  }
};
/**
 * @brief Initialize the global environment of caffe2.
 *
 * Caffe2 uses a registration pattern for initialization functions. Custom
 * initialization functions should have the signature
 *     bool (*func)(int*, char***)
 * where the pointers to argc and argv are passed in. Caffe2 then runs the
 * initialization in three phases:
 * (1) Functions registered with REGISTER_CAFFE2_EARLY_INIT_FUNCTION. Note
 *     that since the logger may not be initialized yet, any logging in such
 *     early init functions may not be printed correctly.
 * (2) Parsing of Caffe-specific command line flags and initialization of
 *     caffe logging.
 * (3) Functions registered with REGISTER_CAFFE2_INIT_FUNCTION.
 * If something goes wrong at any stage, the function returns false. If
 * the global initialization has already been run, the function returns false
 * as well.
 *
 * GlobalInit is re-entrant safe; a re-entrant call will no-op and exit.
 *
 * GlobalInit is safe to call multiple times but not idempotent;
 * successive calls will parse flags and re-set caffe2 logging levels from
 * flags as needed, but will NOT re-run early init and init functions.
 *
 * GlobalInit is also thread-safe and can be called concurrently.
 *
 * NOTE(review): "returns false if already run" and "safe to call multiple
 * times" read as conflicting; the definition is not in this header, so
 * confirm the actual return value of repeated calls and update this comment.
 *
 * @param pargc  Pointer to argc.
 * @param argv   Pointer to argv (named `pargv` in the other signatures in
 *               this header).
 * @return false if any initialization stage fails; see the note above
 *         regarding repeated calls.
 */
TORCH_API bool GlobalInit(int* pargc, char*** argv);
/**
 * @brief Initialize the global environment without command line arguments
 *
 * This is the version of GlobalInit that takes no arguments. Use it on
 * mobile devices, where command line options cannot be passed to caffe2.
 *
 * @return NOTE(review): presumably the same success/failure result as
 *         GlobalInit(int*, char***) -- the definition is not in this
 *         header; confirm.
 */
TORCH_API bool GlobalInit();
} // namespace caffe2
#endif // CAFFE2_CORE_INIT_H_