forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_workers.py
461 lines (388 loc) · 15.6 KB
/
data_workers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
## @package data_workers
# Module caffe2.python.data_workers
'''
This module provides a python-land multithreaded data input mechanism
for Caffe2 nets.
Basic usage is as follows:
coordinator = data_workers.init_data_input_workers(
net,
["data", "label"],
my_fetch_fun,
batch_size=32,
input_source_name="train",
dont_rebatch=False
)
...
coordinator.start()
First argument is the Caffe2 net (or model helper), and the second argument
is a list of the input blobs that are to be fed.
Argument 'input_source_name' is used to distinguish different sources of data,
such as train or test data. This ensures that data from different sources does
not get mixed up, even when two nets share blobs.
To do the actual data loading, one defines a "fetcher function"
that has call signature
my_fetch_fun(worker_id, batch_size)
Optionally, one can define an "init function" that is called once before
threads start, and has call signature:
my_init_fun(data_coordinator, global_coordinator)
If dont_rebatch is set to True, the data input is not re-batched into
equal-sized chunks; instead, the data provided directly by the fetchers is used.
'batch_columns' can be used to specify which dimension is the batch dimension,
for each of the inputs. Default is 0 for all inputs.
'timeout' is the timeout in seconds after which if no data is available, the
net will fail (default 600s = 10 mins).
The fetcher function returns a list of numpy arrays corresponding to the different
input blobs. In the example above, it would return two arrays, one for the
data blob and another for the labels. These arrays can have arbitrary number
of elements (i.e they do not need to match the batch size). The batch size
is provided for the function as a hint only.
For example, fetcher function could download images from a remote service or
load random images from a directory on a file system.
For a dummy example, see the data_workers_test unit test.
Note that for data_parallel_models, init_data_input_workers will be called
for each GPU. Note that the 'coordinator' returned by the function is the
same object each time.
'''
import queue as Queue
from itertools import chain
import logging
import threading
import numpy as np
import time
from caffe2.python import workspace, core, scope, utils
from caffe2.proto import caffe2_pb2
from caffe2.python.parallel_workers import Metrics, State, \
WorkerCoordinator, GlobalWorkerCoordinator, Worker, run_worker
# Minimum number of seconds between periodic throughput / lag log messages.
LOG_INT_SECS = 60

# Module-wide logger, pinned to INFO level.
log = logging.getLogger("data_workers")
log.setLevel(logging.INFO)
def get_worker_ids(num_workers):
    """Return the worker ids ``[0, 1, ..., num_workers - 1]`` as a list."""
    return [worker_id for worker_id in range(num_workers)]
def init_data_input_workers(
    net,
    input_blob_names,
    fetch_fun,
    batch_size,
    num_worker_threads=2,
    input_source_name="train",
    max_buffered_batches=800,
    init_fun=None,
    external_loggers=None,
    dont_rebatch=False,
    batch_columns=None,
    timeout=600
):
    """Create fetcher threads and an enqueuer thread that feed ``net``.

    A BatchFeeder is built for ``input_blob_names`` on top of the python
    queue that ``global_coordinator`` keeps for ``input_source_name``.
    ``num_worker_threads`` fetcher threads (each running a DataWorker around
    ``fetch_fun``) plus one enqueuer thread are created, attached to a new
    WorkerCoordinator, and registered with ``global_coordinator``. The
    threads are constructed here but not started.

    Returns:
        The module-level ``global_coordinator`` (the same object every call).
    """
    device_option = scope.CurrentDeviceScope()
    if device_option is None:
        # No device scope active: fall back to CPU.
        device_option = caffe2_pb2.DeviceOption(device_type=caffe2_pb2.CPU)

    metrics = Metrics(external_loggers)

    namescope = scope.CurrentNameScope()
    shared_queue = global_coordinator.get_queue(
        input_source_name, max_buffered_batches)
    batch_feeder = BatchFeeder(
        net,
        input_blob_names,
        batch_size,
        device_option,
        namescope,
        input_source_name,
        shared_queue,
        metrics,
        dont_rebatch,
        batch_columns,
        timeout=timeout
    )

    # Reserve one globally unique id per fetch worker thread.
    worker_ids = []
    for _ in range(num_worker_threads):
        worker_ids.append(global_coordinator.get_new_worker_id())

    coordinator = WorkerCoordinator(
        input_source_name, worker_ids, init_fun, batch_feeder)

    threads = []
    for worker_id in worker_ids:
        data_worker = DataWorker(coordinator, worker_id, fetch_fun, metrics,
                                 batch_size, batch_feeder)
        threads.append(threading.Thread(
            target=run_worker,
            name="data_workers fetcher id {}".format(worker_id),
            args=[coordinator, data_worker],
        ))

    # A single extra thread drains the python queue into Caffe2.
    enqueuer_name = "Enqueuer {} {}".format(
        input_source_name, scope.CurrentNameScope())
    threads.append(threading.Thread(
        target=enqueuer,
        name=enqueuer_name,
        args=[coordinator, batch_feeder]))

    coordinator._workers = threads
    global_coordinator.add(coordinator)
    return global_coordinator
class BatchFeeder(State):
    """Moves fetched chunks from a python queue into per-blob Caffe2 queues.

    Fetcher threads hand chunks (a sequence holding one numpy array per input
    blob) to ``put``, which validates them and pushes them onto the shared
    python ``queue``. The enqueuer thread repeatedly calls ``_enqueue_batch``,
    which pulls chunks back off that queue, re-batches them into pieces of
    ``batch_size`` samples along each input's batch column (unless
    ``dont_rebatch`` is set), and feeds each piece into a Caffe2 BlobsQueue.
    A ``DequeueBlobs`` op is appended to ``net`` for every input blob so that
    the net reads its inputs from those queues.
    """

    def __init__(self, net, input_blob_names, batch_size,
                 device_option, namescope, input_source_name, queue,
                 metrics, dont_rebatch, batch_columns, timeout=600):
        """
        Args:
            net: net (or model helper) that receives the DequeueBlobs ops.
            input_blob_names: names of the blobs the net consumes.
            batch_size: number of samples per re-batched piece.
            device_option: DeviceOption used when feeding/enqueueing blobs.
            namescope: name-scope prefix used for the scratch blob names.
            input_source_name: source tag (e.g. "train") used in blob names
                and log messages.
            queue: shared python queue of fetched chunks.
            metrics: Metrics sink for timing and throughput values.
            dont_rebatch: if True, chunks are enqueued exactly as fetched.
            batch_columns: batch dimension per input; defaults to 0 for all.
            timeout: seconds passed as ``timeout_secs`` to DequeueBlobs.
        """
        self._counter = 0
        self._input_blob_names = input_blob_names
        self._batch_size = batch_size
        self._internal_queue = queue
        self._queues = []
        self._device_option = device_option
        self._namescope = namescope
        self._timeout = timeout
        self._input_source_name = input_source_name
        self._c2_queue_capacity = 4
        self._create_caffe2_queues(net)
        self._create_caffe2_ops(net)
        self._inputs = 0
        self._prev_seconds = 0
        self._last_warning = time.time()
        self._dont_rebatch = dont_rebatch
        self._init_scratch()
        self._metrics = metrics
        if batch_columns is None:
            batch_columns = [0 for _ in input_blob_names]
        self._batch_columns = batch_columns

    def start(self):
        """Reset the throughput counter and begin a new logging interval."""
        self._inputs = 0
        self._prev_seconds = time.time()

    def stop(self):
        """Close every Caffe2 queue, then force a final throughput report."""
        try:
            for q in self._queues:
                workspace.RunOperatorOnce(
                    core.CreateOperator("CloseBlobsQueue", [q], [])
                )
        finally:
            self._log_inputs_per_interval(0, force=True)

    def cleanup(self):
        """Reset the scratch data and scratch status blobs."""
        utils.ResetBlobs(self._scratch_blob.values())
        utils.ResetBlobs(self._scratch_status.values())

    def _get(self, data_input_coordinator):
        """Pop the next chunk from the python-side queue.

        Polls in 0.5 s slices so that a stopped coordinator is noticed, and
        logs a warning at most every 10 s while no data arrives.

        Returns:
            The next chunk, or None once the coordinator is no longer active.
        """
        start_time = time.time()
        last_warning = time.time()
        while data_input_coordinator.is_active():
            try:
                return self._internal_queue.get(block=True, timeout=0.5)
            except Queue.Empty:
                if time.time() - last_warning > 10.0:
                    log.warning("** Data input is slow: (still) no data in {} secs.".format(
                        time.time() - start_time))
                    last_warning = time.time()
        return None

    def _validate_chunk(self, chunk):
        """Check a fetched chunk before it is queued.

        Returns:
            False (after logging a warning) for a None or empty chunk,
            True otherwise.

        Raises:
            AssertionError: if the chunk does not hold exactly one numpy array
                per input blob or, when re-batching, if the inputs disagree on
                the number of samples along their batch columns.
        """
        if chunk is None:
            log.warning("Fetcher function returned None")
            return False

        assert len(chunk) == len(self._input_blob_names), \
            "Expecting data blob for each input"
        for d in chunk:
            assert isinstance(d, np.ndarray), \
                "Fetcher function must return a numpy array"
        if not self._dont_rebatch:
            for j, d in enumerate(chunk[1:], start=1):
                assert d.shape[self._batch_columns[j]] == \
                    chunk[0].shape[self._batch_columns[0]], \
                    "Each returned input must have equal number of samples"

        if len(chunk) == 0:
            log.warning("Worker provided zero length input")
            return False

        return True

    def put(self, chunk, data_input_coordinator):
        """Validate ``chunk`` and push it onto the python-side queue.

        Invalid chunks are dropped. While the queue is full this retries every
        0.5 s; it returns without queuing if the coordinator stops first.
        """
        if not self._validate_chunk(chunk):
            return
        while data_input_coordinator.is_active():
            try:
                qsize = self._internal_queue.qsize()
                if qsize < 2 and (time.time() - self._last_warning) > LOG_INT_SECS:
                    log.warning("Warning, data loading lagging behind: " +
                                "queue size={}, name={}".format(qsize, self._input_source_name))
                    self._last_warning = time.time()
                self._counter += 1
                self._internal_queue.put(chunk, block=True, timeout=0.5)
                # Count samples along the first input's configured batch
                # column (not blindly axis 0), consistent with how
                # _enqueue_batch measures batch size.
                self._log_inputs_per_interval(
                    chunk[0].shape[self._batch_columns[0]])
                return
            except Queue.Full:
                log.debug("Queue full: stalling fetchers...")

    def _enqueue_batch_direct(self, data_input_coordinator):
        """Enqueue the next fetched chunk as-is (``dont_rebatch`` mode)."""
        data = self._get(data_input_coordinator)
        if data is None:
            return
        if data_input_coordinator.is_active():
            for b, q, c in zip(self._input_blob_names, self._queues, data):
                self._enqueue(b, q, c)

    def _enqueue_batch(self, data_input_coordinator):
        '''
        This pulls data from the python-side queue and collects them
        into batch-sized pieces, unless dont_rebatch is set to true.
        '''
        if self._dont_rebatch:
            self._enqueue_batch_direct(data_input_coordinator)
            return

        cur_batch = [np.array([]) for _ in self._input_blob_names]
        first_batch_col = self._batch_columns[0]

        # Collect data until we have a full batch size
        while (
            cur_batch[0].shape[0] == 0 or
            cur_batch[0].shape[first_batch_col] < self._batch_size
        ) and data_input_coordinator.is_active():
            chunk = self._get(data_input_coordinator)
            if chunk is None:
                continue
            for j, chunk_elem in enumerate(chunk):
                if cur_batch[j].shape[0] == 0:
                    cur_batch[j] = chunk_elem.copy()
                else:
                    cur_batch[j] = np.append(
                        cur_batch[j], chunk_elem, axis=self._batch_columns[j]
                    )

        if not data_input_coordinator.is_active():
            # Shutting down. The loop above may have stopped with a partial
            # (or still empty) batch; falling through would trip the batch
            # size assertion below (or index a missing axis) and raise a
            # spurious exception in the enqueuer thread.
            return

        start_time = time.time()
        try:
            # Return data over the batch size back to queue
            if cur_batch[0].shape[0] > 0 and cur_batch[0].shape[
                first_batch_col
            ] > self._batch_size:
                leftover = []
                trimmed_batch = []
                for j, b in enumerate(cur_batch):
                    head, rest = np.split(
                        b, [self._batch_size], axis=self._batch_columns[j]
                    )
                    leftover.append(rest)
                    trimmed_batch.append(head)
                cur_batch = trimmed_batch
                # NOTE: the leftover goes to the back of the FIFO queue, so
                # sample order across batches is not preserved.
                try:
                    self._internal_queue.put(leftover, block=False)
                except Queue.Full:
                    # Best effort: rather than block the enqueuer, surplus
                    # samples are discarded when the python queue is full.
                    log.debug(
                        "Queue full: dropping {} leftover samples".format(
                            leftover[0].shape[first_batch_col]))
            assert cur_batch[0].shape[first_batch_col] == self._batch_size

            if data_input_coordinator.is_active():
                for b, q, c in zip(
                    self._input_blob_names, self._queues, cur_batch
                ):
                    self._enqueue(b, q, c)
        finally:
            self._metrics.put_metric('enqueue_time', time.time() - start_time)

    def _init_scratch(self):
        """Create and pre-feed the per-input scratch data/status blobs."""
        self._scratch_blob = {}
        self._scratch_status = {}
        for blob_name in self._input_blob_names:
            scratch_name = self._namescope + blob_name + \
                "_scratch_" + self._input_source_name
            self._scratch_blob[blob_name] = core.BlobReference(scratch_name)
            self._scratch_status[blob_name] = core.BlobReference(
                scratch_name + "_status"
            )

        # Feed empty arrays to the scratch blobs here, so that there won't be
        # race conditions when calling FeedBlob (which calls workspace
        # CreateBlob()) from enqueue threads
        for b in chain(
            self._scratch_blob.values(), self._scratch_status.values()
        ):
            workspace.FeedBlob(
                b,
                np.array([]).astype(np.float32),
                device_option=self._device_option,
            )

    def _enqueue(self, blob_name, queue, data_arr):
        '''
        Enqueue the correctly sized batch arrays to Caffe2's queue.

        ``data_arr`` is fed into the scratch blob for ``blob_name`` and pushed
        onto ``queue`` with SafeEnqueueBlobs; the op's status output is
        written to the matching scratch status blob.
        '''
        workspace.FeedBlob(
            self._scratch_blob[blob_name],
            data_arr,
            device_option=self._device_option
        )

        op = core.CreateOperator(
            "SafeEnqueueBlobs",
            [queue, self._scratch_blob[blob_name]],
            [self._scratch_blob[blob_name], self._scratch_status[blob_name]],
            device_option=self._device_option
        )
        workspace.RunOperatorOnce(op)

    def _create_caffe2_queues(self, net):
        '''
        Creates queues on caffe2 side: one BlobsQueue per input blob, named
        ``<blob>_c2queue_<input_source_name>``.
        '''
        def create_queue(queue_name, num_blobs, capacity):
            # Honor the num_blobs argument instead of hard-coding 1.
            workspace.RunOperatorOnce(
                core.CreateOperator(
                    "CreateBlobsQueue",
                    [], [queue_name],
                    num_blobs=num_blobs,
                    capacity=capacity))
            return core.ScopedBlobReference(queue_name)

        for blob_name in self._input_blob_names:
            qname = blob_name + "_c2queue" + "_" + self._input_source_name
            q = create_queue(
                qname, num_blobs=1, capacity=self._c2_queue_capacity
            )
            self._queues.append(q)

    def _create_caffe2_ops(self, net):
        '''
        Creates dequeue-ops on caffe2 side
        '''
        for q, blob_name in zip(self._queues, self._input_blob_names):
            # Add operator to the Caffe2 network to dequeue
            net.DequeueBlobs(q, blob_name, timeout_secs=float(self._timeout))

    def _log_inputs_per_interval(self, inputs, force=False):
        """Accumulate ``inputs`` samples and periodically report throughput.

        Once at least LOG_INT_SECS have elapsed since the last report (or when
        ``force`` is set), logs inputs/sec and the python queue size, records
        them as metrics, logs and resets the metrics, and starts a new
        interval.
        """
        # NOTE(review): put() reaches this from several fetcher threads at
        # once; these counters are not lock-protected.
        self._inputs += inputs
        current_seconds = time.time()
        delta_seconds = current_seconds - self._prev_seconds
        if delta_seconds >= LOG_INT_SECS or force:
            # A forced report can land on a zero-length (or, if the wall clock
            # stepped back, negative) interval, e.g. stop() right after
            # start() on a coarse clock; avoid ZeroDivisionError there.
            if delta_seconds > 0:
                inputs_per_sec = int(self._inputs / delta_seconds)
            else:
                inputs_per_sec = 0
            qsize = self._internal_queue.qsize()
            log.info("{}/{}: {} inputs/sec".format(
                self._input_source_name,
                self._namescope,
                inputs_per_sec,
            ))
            log.info("-- queue: {} batches".format(qsize))
            # log and reset perf metrics
            self._metrics.put_metric(
                'inputs_per_sec', inputs_per_sec, False)
            self._metrics.put_metric('queue_size', qsize, False)
            self._metrics.put_metric(
                'time_elapsed', delta_seconds, False)
            self._metrics.log_metrics()
            self._metrics.reset_metrics()
            self._inputs = 0
            self._prev_seconds = current_seconds
class GlobalCoordinator(GlobalWorkerCoordinator):
    """Global worker coordinator that also owns one python queue per source.

    Queues are keyed by name (the input source name at the call site in this
    module), so every BatchFeeder created for the same source shares a
    single buffer.
    """

    def __init__(self):
        GlobalWorkerCoordinator.__init__(self)
        self._queues = {}

    def get_queue(self, queue_name, max_buffered_batches):
        """Return the python queue for ``queue_name``, creating it if needed.

        ``max_buffered_batches`` (which must be an int) only sets the size
        when the queue is first created; later calls return the existing
        queue unchanged.
        """
        assert isinstance(max_buffered_batches, int)
        existing = self._queues.get(queue_name)
        if existing is None:
            existing = Queue.Queue(maxsize=max_buffered_batches)
            self._queues[queue_name] = existing
        return existing

    def reset_data_input(self, namescope, name, net, batch_size):
        """Set a new batch size and re-add dequeue ops for matching feeders.

        Matches every registered coordinator whose worker name is ``name`` and
        whose feeder was created under ``namescope``.
        """
        log.info("Reset data input {}, batch size {}: ".format(name, batch_size))
        matching = (
            coord for coord in self._coordinators
            if coord._worker_name == name
            and coord._state._namescope == namescope
        )
        for coord in matching:
            coord._state._batch_size = batch_size
            coord._state._create_caffe2_ops(net)
class DataWorker(Worker):
    """Fetcher worker that calls the fetch function and feeds its output."""

    def __init__(
        self,
        coordinator,
        worker_id,
        worker_fun,
        metrics,
        batch_size,
        batch_feeder
    ):
        Worker.__init__(self, coordinator, worker_id, worker_fun=worker_fun,
                        metrics=metrics)
        # Passed to the fetch function as a size hint.
        self._batch_size = batch_size
        self._batch_feeder = batch_feeder

    def run(self):
        """Fetch one chunk and hand it to the batch feeder."""
        fetched = self._worker_fun(self._worker_id, self._batch_size)
        self._batch_feeder.put(fetched, self._coordinator)

    def finish(self):
        """Record the time elapsed since ``self._start_time`` as 'fetcher_time'."""
        # presumably _start_time is set by the Worker base class before run()
        elapsed = time.time() - self._start_time
        self._metrics.put_metric('fetcher_time', elapsed)
# Single process-wide coordinator shared by every init_data_input_workers call.
global_coordinator = GlobalCoordinator()
def enqueuer(coordinator, batch_feeder):
    """Enqueuer-thread body: push batches until the coordinator goes inactive."""
    while True:
        if not coordinator.is_active():
            break
        batch_feeder._enqueue_batch(coordinator)