forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
engine.h
264 lines (223 loc) · 9.79 KB
/
engine.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
#pragma once
// Engine implements backpropagation from output variables and their gradients
// to "root" variables (variables created by the user with requires_grad=True).
#include <ATen/ThreadLocalDebugInfo.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/autograd/anomaly_mode.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <torch/csrc/autograd/input_buffer.h>
#include <torch/csrc/utils/future.h>

#include <atomic>
#include <condition_variable>
#include <cstdint>
#include <deque>
#include <exception>
#include <functional>
#include <memory>
#include <mutex>
#include <queue>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
namespace torch {
namespace autograd {
// Forward declaration only: below, ReadyQueue is referred to solely through
// references and std::shared_ptr, so the complete type is not needed in this
// header.
struct ReadyQueue;
} // namespace autograd
} // namespace torch
namespace torch { namespace autograd {
// Future carrying a variable_list. GraphTask::future_result_ is one of these,
// and Engine::execute_with_graph_task hands a shared_ptr to one back to the
// caller.
using FutureVariableList = torch::utils::Future<variable_list>;

// Declared here, defined elsewhere. NOTE(review): presumably checks each entry
// of `grads` against the corresponding entry of `edges` and uses
// `format_error` to decorate any error message it raises — confirm against the
// definition. `grads` is taken by non-const reference, so the callee may
// modify it in place.
void validate_outputs(const edge_list& edges,
                      variable_list& grads,
                      const std::function<std::string(const std::string&)>&
                          format_error);
// Sentinel worker-device value; GraphTask starts with owner_ set to it.
// Deliberately distinct from -1, which indicates the CPU worker.
// A namespace-scope constexpr variable already has internal linkage, so no
// `static` is needed.
constexpr int NO_DEVICE = -2;
// GraphTask holds metadata needed for a single execution of backward()
struct GraphTask {
// Indicates if an error occurred while executing any task. When this is
// true, it signals all threads to stop executing.
std::atomic_bool has_error_;
std::atomic<uint64_t> outstanding_tasks_;
// It is safe to read grad_mode_ and keep_graph_ without synchronization
bool keep_graph_;
bool grad_mode_;
// To protect reads/writes to not_ready_, dependencies_, captured_vars_,
// has_error_, future_result_ and leaf_streams.
std::mutex mutex_;
std::unordered_map<Node*, InputBuffer> not_ready_;
std::unordered_map<Node*, int> dependencies_;
struct ExecInfo {
struct Capture {
Capture(int input_idx, int output_idx)
: input_idx_(input_idx), output_idx_(output_idx) {}
int input_idx_; // within Node inputs
int output_idx_; // within the output vector of a GraphTask
};
bool should_execute() const {
return needed_ || captures_;
}
bool needed_ = false;
std::unique_ptr<std::vector<Capture>> captures_;
};
// Exec info has a bit complicated semantics. If it's empty, it means the task
// is run in a "default" mode, which means that all next_edges we encounter
// should get executed. If it's not empty, only functions that have an entry
// and this entry has needed == True should be executed. exec_info_.empty()
// means it's .backward(), otherwise it's .grad(). exec_info_ is safe to read
// without synchronization
std::unordered_map<Node*, ExecInfo> exec_info_;
std::vector<Variable> captured_vars_;
std::shared_ptr<at::ThreadLocalDebugInfoBase> debug_info_ =
at::getThreadLocalDebugInfo();
std::unordered_set<c10::Stream> leaf_streams;
void init_to_execute(Node& graph_root, const edge_list& outputs);
// The value of worker_device in the thread that created this task.
// See Note [Reentrant backwards]
// Safe to read owner_ and reentrant_depth_ without synchronizaton
int owner_;
// The number of parent graph tasks for this graph task
const int reentrant_depth_;
bool can_checkpoint() {
return exec_info_.empty();
}
// Set an appropriate exception on this graph_task which was encountered while
// running the provided function.
void set_exception(std::exception& e, const std::shared_ptr<Node>& fn);
// Whether or not to stop execution for this GraphTask when an error is
// encountered. When set to true, this would cause Engine::execute() to throw
// an exception as soon as the autograd engine receives an exception.
bool exit_on_error_;
// Future representing the completion of the graph task. Notified when all
// tasks are done.
std::shared_ptr<FutureVariableList> future_result_;
GraphTask(
bool keep_graph,
bool grad_mode,
int reentrant_depth,
bool exit_on_error = false)
: has_error_(false),
outstanding_tasks_(0),
keep_graph_(keep_graph),
grad_mode_(grad_mode),
owner_(NO_DEVICE),
reentrant_depth_(reentrant_depth),
exit_on_error_(exit_on_error),
future_result_(std::make_shared<FutureVariableList>()) {}
};
// A unit of work for an engine worker: run `fn_` on `inputs_` on behalf of the
// GraphTask referenced by `base_`.
struct NodeTask {
  // The owning GraphTask, held through a weak_ptr, so a NodeTask does not by
  // itself keep its GraphTask alive.
  std::weak_ptr<GraphTask> base_;
  std::shared_ptr<Node> fn_;
  // This buffer serves as an implicit "addition" node for all of the
  // gradients flowing here. Once all the dependencies are finished, we
  // use the contents of this buffer to run the function.
  InputBuffer inputs_;
  // When worker receives a task with isShutdownTask = true, it will immediately
  // exit. The engine sends a shutdown task to every queue upon its destruction.
  bool isShutdownTask_;

  // Defined elsewhere. NOTE(review): presumably reports the reentrant_depth_
  // of the GraphTask in base_ — confirm in engine.cpp.
  int getReentrantDepth() const;

  // All sink parameters are taken by value and moved into place. `base` is
  // moved too: copying a weak_ptr would cost an extra atomic increment and
  // decrement of the weak count.
  NodeTask(
      std::weak_ptr<GraphTask> base,
      std::shared_ptr<Node> fn,
      InputBuffer inputs,
      bool isShutdownTask = false)
      : base_(std::move(base)),
        fn_(std::move(fn)),
        inputs_(std::move(inputs)),
        isShutdownTask_(isShutdownTask) {}
};
// A single instance of this struct should be created through the whole process
// lifetime. The worker thread creation logic and Engine's destructor rely on
// this.
struct TORCH_API Engine {
  /// Returns a reference to a static `Engine` instance.
  static Engine& get_default_engine();

  Engine();
  virtual ~Engine();

  using ready_queue_type = std::deque<std::pair<std::shared_ptr<Node>, InputBuffer>>;
  using dependencies_type = std::unordered_map<Node*, int>;

  // Given a list of (Node, input number) pairs computes the value of the graph
  // by following next_edge references.
  virtual variable_list execute(
      const edge_list& roots,
      const variable_list& inputs,
      bool keep_graph,
      bool create_graph,
      const edge_list& outputs = {});

  // Given a pre-populated GraphTask and GraphRoot, computes the backward pass
  // for the graph. This API should only be used by internal autograd specific
  // machinery and shouldn't be exposed to users in any way.
  virtual std::shared_ptr<FutureVariableList> execute_with_graph_task(
      const std::shared_ptr<GraphTask>& graph_task,
      std::shared_ptr<Node> graph_root);

  // Enqueues a blocked task for execution on the CPU thread. A blocked task is
  // basically a task that isn't triggered automatically to be
  // 'ready to execute' by the autograd engine. This task needs to be unblocked
  // for execution via an external mechanism. This method assumes that
  // the appropriate GraphTask has already been initialized appropriately.
  // Another important part is that this does not increment 'outstanding_tasks_'
  // in the appropriate GraphTask. It is assumed we've already done this before
  // hand for this task (to ensure we block for its execution). This is useful
  // in the distributed autograd case where we need to increment
  // 'outstanding_tasks_' first to indicate the local autograd engine needs to
  // wait for this task, but the task might actually be received later over the
  // network for execution.
  void enqueue_blocked_task_on_cpu(NodeTask task);

  // The base engine provides no anomaly metadata; subclasses may override.
  virtual std::unique_ptr<AnomalyMetadata> make_anomaly_metadata() {
    return nullptr;
  }

  // NOTE(review): presumably appends `callback` to final_callbacks_ (guarded
  // by post_callbacks_lock_) — confirm in engine.cpp.
  void queue_callback(std::function<void()> callback);

  bool is_checkpoint_valid();

  size_t ready_queue_size(at::Device device);

 protected:
  void compute_dependencies(Node* root, GraphTask& task);
  void evaluate_function(
      std::shared_ptr<GraphTask>& graph_task,
      Node* func,
      InputBuffer& inputs);
  ReadyQueue& ready_queue(at::Device device);
  ReadyQueue& ready_queue_by_index(int device_index);
  void start_threads();
  virtual void thread_init(int device);
  virtual void thread_on_exception(
      std::shared_ptr<GraphTask>& graph_task,
      const std::shared_ptr<Node>& fn,
      std::exception& e);
  virtual void thread_main(
      const std::shared_ptr<GraphTask>& task,
      bool reentrant_thread);
  void reentrant_thread_init();
  void add_thread_pool_task(const std::weak_ptr<GraphTask>& graph_task);
  void set_device(int device);

  // Ensures ready_queues_ are initialized only once
  std::once_flag start_threads_flag_;
  // Safe to read ready_queues_ without synchronization after initialization
  std::vector<std::shared_ptr<ReadyQueue>> ready_queues_;
  std::vector<std::function<void()>> final_callbacks_;
  // To protect reads and writes to final_callbacks_
  std::mutex post_callbacks_lock_;
  // How many nested reentrant calls are allowed until a new thread is used
  int max_recursion_depth_;

  struct ThreadPoolShared {
    // Data structures used by the threads for executing reentrant backwards
    // tasks. See Note [Reentrant backwards]
    // Number of available threads for processing new GraphTasks.
    unsigned int num_workers_;
    // The threads will wait on work_ to be notified of GraphTasks
    std::condition_variable work_;
    // To protect reads and writes to graphtasks_queue_ and num_workers_
    // and for synchronizing creating new threads when needed
    std::mutex mutex_;
    // Workers will process the GraphTasks added to this queue. A GraphTask is
    // allocated inside Engine::execute and lives for the duration of execute
    std::queue<std::weak_ptr<GraphTask>> graphtasks_queue_;

    ThreadPoolShared() : num_workers_(0) {}
  };

  // Temporary workaround until shutting down threads is done
  // We need shared ownership of all these objects because the threads are leaked
  // when Engine shuts down, so there may be threads waiting on work_
  // for the graphtasks_queue_ to be nonempty.
  std::shared_ptr<ThreadPoolShared> thread_pool_shared_;

 private:
  variable_list graph_task_exec_post_processing(
      const std::shared_ptr<GraphTask>& graph_task);
  void mark_graph_task_completed(std::shared_ptr<GraphTask>& graph_task);
};
// allow python_engine to override the default engine when it loads
using EngineStub = Engine& (*)();
TORCH_API void set_default_engine_stub(EngineStub stub);
}} // namespace torch::autograd