forked from ROCm/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataio.py
635 lines (510 loc) · 23 KB
/
dataio.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
## @package dataio
# Module caffe2.python.dataio
"""
Defines the base interface for reading and writing operations.
Readers/Writers are objects that produce operations that read/write sequences
of data. Each operation reads or writes a list of BlobReferences.
Readers and Writers must be implemented such that read and write operations
are atomic and thread safe.
Examples of possible Readers and Writers:
QueueReader, QueueWriter,
DatasetReader, DatasetWriter,
See `dataset.py` for an example of implementation.
"""
from caffe2.python import core
from caffe2.python.schema import Field, Struct, from_blob_list
import numpy as np
import time
class Reader(object):
    """
    Abstract base class for objects that add data-reading operations to nets.

    A Reader must implement at least one operation, `read`, which appends to
    a net the operations that read the next batch of data. Readers can
    optionally support `reset`, which is useful when multiple passes over
    the data are required.
    """

    def __init__(self, schema=None):
        """Remember the (optional) schema describing the fields being read.

        Args:
            schema: (optional) a `Field` instance. When given, it is asserted
                    to be a `Field`.
        """
        if schema is not None:
            assert isinstance(schema, Field)
        self._schema = schema

    def schema(self):
        """Return the schema; asserts that one has been provided."""
        assert self._schema is not None, 'Schema not provided for this reader.'
        return self._schema

    def _set_schema(self, schema):
        """Replace the stored schema without any validation."""
        self._schema = schema

    def setup_ex(self, init_net, finish_net):
        """Setup nets to run at task initialization and cleanup time.

        The base implementation does nothing.

        Args:
            init_net: A net invoked at task init time.
            finish_net: A net invoked at task cleanup time.
        """
        pass

    def read_ex(self, local_init_net, local_finish_net):
        """Create a 'reader_body' net and fill it through `read`.

        Args:
            local_init_net: unused by the base implementation.
            local_finish_net: unused by the base implementation.

        Returns:
            A tuple ([read_net], should_stop, fields), i.e. the new net
            wrapped in a list, followed by whatever `read` returned.
        """
        read_net = core.Net('reader_body')
        # `read` returns a tuple, so tuple concatenation prepends the nets.
        return ([read_net], ) + self.read(read_net)

    def read_record_ex(self, local_init_net, local_finish_net):
        """Same as `read_ex`, but wrap the fields in the schema if one is set.

        Returns:
            A tuple (nets, should_stop, fields). When a schema is set,
            `fields` is the result of `from_blob_list(schema, fields)`.
        """
        nets, should_stop, fields = self.read_ex(
            local_init_net, local_finish_net)
        if self._schema:
            fields = from_blob_list(self._schema, fields)
        return nets, should_stop, fields

    def read(self, read_net):
        """Append operations to read_net that will read a batch from the
        underlying data source.

        Operations added to `read_net` must be thread safe and atomic, that is,
        it should be possible to clone `read_net` and run multiple instances of
        it in parallel.

        Args:
            read_net: the net that will be appended with read operations

        Returns:
            A tuple (should_stop, fields), with:
                should_stop: BlobReference pointing to a boolean scalar
                             blob that indicates whether the read operation
                             was successful or whether the end of data has
                             been reached.
                fields: A tuple of BlobReference containing the latest batch
                        of data that was read.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError('Readers must implement `read`.')

    def reset(self, net):
        """Append operations to `net` that will reset the reader.

        This can be used to read the data multiple times.
        Not all readers support this operation.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError('This reader cannot be resetted.')

    def read_record(self, read_net):
        """Call `read` and wrap the fields in the schema if one is set.

        Returns:
            A tuple (should_stop, fields). When a schema is set, `fields` is
            the result of `from_blob_list(schema, fields)`.
        """
        should_stop, fields = self.read(read_net)
        if self._schema:
            fields = from_blob_list(self._schema, fields)
        return should_stop, fields

    def execution_step(self, reader_net_name=None, external_should_stop=None):
        """Create an execution step with a net containing read operators.

        The execution step will contain a `stop_blob` that knows how to stop
        the execution loop when end of data was reached.

        E.g.:

            read_step, fields = reader.execution_step()
            consume_net = core.Net('consume')
            consume_net.Print(fields[0], [])
            p = core.Plan('reader')
            p.AddStep(read_step.AddNet(consume_net))
            core.RunPlan(p)

        Args:
            reader_net_name: (optional) the name of the reader_net to be
                             created; defaults to 'reader'. The execution
                             step is named '<reader_net_name>_step'.
            external_should_stop: (optional) an extra stop blob. When given,
                                  it is OR-ed (inside the reader net) with
                                  the reader's own should_stop blob.

        Returns:
            A tuple (read_step, fields), with:
                read_step: A newly created execution step containing a net with
                           read operations. The step will have `stop_blob` set,
                           in order to stop the loop on end of data.
                fields: A tuple of BlobReference containing the latest batch
                        of data that was read.
        """
        # Resolve the default once so the net and the step agree on the name;
        # formatting the raw argument would name the step 'None_step'.
        net_name = reader_net_name or 'reader'
        reader_net = core.Net(net_name)
        should_stop, fields = self.read_record(reader_net)
        if external_should_stop is not None:
            should_stop = reader_net.Or([external_should_stop, should_stop])
        read_step = core.execution_step(
            '{}_step'.format(net_name),
            reader_net,
            should_stop_blob=should_stop)
        return (read_step, fields)
class Writer(object):
    """
    Abstract base class for objects that feed a data stream or a dataset.

    A Writer must implement 2 operations:
    `write`, which adds operations to a net that write the next batch of
    data, and `commit`, which adds operations to a net in order to indicate
    that no more data will be written.
    """

    # Class-level default; replaced per instance the first time a `Field`
    # record is written through `write_record` / `write_record_ex`.
    _schema = None

    def schema(self):
        """Return the schema captured from the last written record, if any."""
        return self._schema

    def write(self, writer_net, fields):
        """Add operations to `writer_net` that write the next batch of data.

        Operations added to the net must be thread-safe and unique, that is:
        multiple writers must be able to write to the dataset in parallel.

        Args:
            writer_net: the net that receives the write operations.
            fields: a tuple of BlobReference containing the batch of data to
                    write.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError('Writers must implement write.')

    def _record_to_blobs(self, fields):
        """Return the blobs of `fields`, remembering it as schema if a Field."""
        if not isinstance(fields, Field):
            return fields
        self._schema = fields
        return fields.field_blobs()

    def write_record(self, writer_net, fields):
        """Write `fields`, which may be a `Field` record or a blob sequence.

        A `Field` is stored as this writer's schema and flattened with
        `field_blobs()` before being handed to `write`.
        """
        blobs = self._record_to_blobs(fields)
        self.write(writer_net, blobs)

    def setup_ex(self, init_net, finish_net):
        """Experimental, don't use yet"""
        self.commit(finish_net)

    def write_ex(self, fields, local_init_net, local_finish_net, stop_blob):
        """Experimental extension to the interface. Don't use yet"""
        net = core.Net('write_net')
        self.write(net, fields)
        return [net]

    def write_record_ex(
            self, fields, local_init_net, local_finish_net, stop_blob=None):
        """Experimental extension to the interface. Don't use yet."""
        blobs = self._record_to_blobs(fields)
        if stop_blob is None:
            # No stop blob supplied: reserve a fresh name on the init net.
            stop_blob = local_init_net.NextName("dequeue_status")
        nets = self.write_ex(
            blobs, local_init_net, local_finish_net, stop_blob)
        return (nets, stop_blob)

    def commit(self, finish_net):
        """Add operations to `finish_net` that signal end of data.

        This must be implemented by all Writers, but may be no-op for some
        of them.
        """
        pass
class ReaderBuilder(object):
    """Allow usage of a reader in distributed fashion.

    This is an interface only: every method raises NotImplementedError.
    """

    def schema(self):
        """Return the schema of the readers this builder produces."""
        raise NotImplementedError()

    def setup(self, **kwargs):
        """Optionally, perform one-time setup before calling new_reader().

        Subclass should make sure this function is only called once.
        """
        raise NotImplementedError()

    def new_reader(self, **kwargs):
        """Create and return a new reader."""
        raise NotImplementedError()
class PipedReaderBuilder(ReaderBuilder):
    """ReaderBuilder that post-processes every reader of a wrapped builder.

    Each reader produced by the underlying builder is passed to the `piper`
    function, and the result of that function is returned instead. This way,
    it is possible to append data processing pipelines that will be
    replicated for each reader that gets created.

    E.g.:

        PipedReaderBuilder(
            ReaderBuilder(...),
            lambda reader: pipe(reader, processor=my_proc))
    """

    def __init__(self, builder, piper):
        """
        Args:
            builder: the underlying builder that creates the raw readers.
            piper: callable invoked as `piper(reader=..., **kwargs)`.
        """
        self._builder = builder
        self._piper = piper

    def schema(self):
        """Delegate to the underlying builder."""
        return self._builder.schema()

    def setup(self, **kwargs):
        """Delegate to the underlying builder and return its result."""
        return self._builder.setup(**kwargs)

    def new_reader(self, **kwargs):
        """Build a reader, run it through `piper`, and return a Reader.

        If `piper` returns something that is not a Reader, its `reader()`
        method is called to obtain one.
        """
        # Passing everything down since you could wrap a PipedReaderBuilder in
        # another PipedReaderBuilder
        raw_reader = self._builder.new_reader(**kwargs)
        piped = self._piper(reader=raw_reader, **kwargs)
        if isinstance(piped, Reader):
            return piped
        return piped.reader()
class Pipe(object):
    """Base class for an object exposing a `reader()` and a `writer()` end.

    Keeps count of how many readers and writers were registered through
    `_new_reader` / `_new_writer`.
    """

    def __init__(self, schema=None, obj_key=None):
        """
        Args:
            schema: (optional) schema of the piped data.
            obj_key: (optional) when set, this pipe is attached under that
                     key (via `add_attribute`) to every init net passed to
                     `_new_reader` / `_new_writer`.
        """
        self._schema = schema
        self._obj_key = obj_key
        self._num_readers = 0
        self._num_writers = 0

    def schema(self):
        return self._schema

    def setup(self, global_init_net):
        pass

    def reader(self):
        raise NotImplementedError()

    def writer(self):
        raise NotImplementedError()

    def num_readers(self):
        return self._num_readers

    def num_writers(self):
        return self._num_writers

    def _attach_to(self, init_net):
        """Register this pipe on `init_net` when an object key is configured."""
        if self._obj_key is None:
            return
        init_net.add_attribute(self._obj_key, self)

    def _new_writer(self, writer_schema, writer_init_net):
        """Account for a new writer; adopt its schema if none is set yet."""
        if self._schema is None and writer_schema is not None:
            self._schema = writer_schema
        self._num_writers += 1
        self._attach_to(writer_init_net)

    def _new_reader(self, reader_init_net):
        """Account for a new reader."""
        self._num_readers += 1
        self._attach_to(reader_init_net)
class CounterReader(Reader):
    """ Reader that produces increasing integers. """

    def __init__(self):
        # Single int64 field named 'iter'.
        Reader.__init__(self, schema=Struct(('iter', np.int64)))
        self.counter = None
        self.should_stop = None

    def setup_ex(self, global_init_net, global_finish_net):
        """Create the counter and the constant-False stop blob, only once."""
        if self.counter is not None:
            # Already set up by a previous call.
            return
        self.counter = global_init_net.CreateCounter([], init_count=0)
        self.should_stop = global_init_net.ConstantFill(
            [], shape=[], dtype=core.DataType.BOOL, value=False)

    def read_ex(self, local_init_net, local_finish_net):
        """Return a net that applies `CountUp` to the counter.

        Returns:
            A tuple ([net], should_stop, [value]) where `should_stop` is the
            blob filled with False in `setup_ex`.
        """
        counting_net = core.Net('limited_reader_counter')
        current = counting_net.CountUp([self.counter], 1)
        return [counting_net], self.should_stop, [current]
class ReaderWithLimitBase(Reader):
    """Abstract Reader constrained by certain conditions.

    Base class for Reader classes which check for certain conditions to stop
    further processing (e.g. max number of iterations or time limit).
    Also produces a boolean blob (data_finished) that can be used to see if
    the reader exhausted all input data (true) or stopped for another reason
    (false).
    """

    def __init__(self, reader):
        """
        Args:
            reader: the underlying reader; its schema is reused as ours.
        """
        Reader.__init__(self, schema=reader._schema)
        self.reader = reader
        # This net is used here only to reserve names / external inputs.
        self.net = core.Net('reader_with_limit')
        finished_name = self.net.NextName('data_finished')
        self._data_finished = self.net.AddExternalInput(finished_name)
        self.should_stop = None

    def setup_ex(self, global_init_net, global_finish_net):
        """Initialize data_finished to False, then set up reader and limiter."""
        global_init_net.ConstantFill(
            [], [self._data_finished],
            shape=[], value=False, dtype=core.DataType.BOOL)
        self.reader.setup_ex(global_init_net, global_finish_net)
        self.setup_limiter(global_init_net, global_finish_net)

    def read_ex(self, local_init_net, local_finish_net):
        """Reads from an underlying Reader class, but may stop due to additional
        constraints.

        Build and return network(s) to read data from a Reader with
        additional constraints, depending on which derived class is used.
        Derived classes implement setup_limiter and check_limiter_condition
        which determine the nature of the constraint imposed on the reader,
        e.g. iteration limits or time limit.

        Args:
            local_init_net: A net invoked at task instance init time (Once per
                            parallel thread).
            local_finish_net: A net invoked at task instance cleanup time (Once
                              per parallel thread).

        Returns:
            A tuple (nets, should_stop, fields) where `nets` is the limiter
            condition net, followed by the underlying reader's nets, followed
            by the net that propagates the end-of-data flag.
        """
        # Check if limiting constraint is met.
        condition_net = core.Net('limited_reader_condition')
        should_stop = self.check_limiter_condition(condition_net)

        # Call original reader.
        inner_nets, inner_finished, fields = self.reader.read_ex(
            local_init_net, local_finish_net)
        self._set_schema(self.reader._schema)

        # Check if original reader is done.
        post_net = core.Net('limited_reader_post')
        # Copy to the same blob as the counter output to trigger reader
        # stopping - this is ok because execution will check should_stop_blob
        # after every single operation, so it has already been checked on this
        # iteration by this point.
        post_net.Copy(inner_finished, should_stop)
        # Update externally-accessible flag indicating if reader is done
        post_net.Or(
            [self._data_finished, inner_finished], [self._data_finished])

        all_nets = [condition_net] + inner_nets + [post_net]
        return all_nets, should_stop, fields

    def setup_limiter(self, global_init_net, global_finish_net):
        """Configure task level init/cleanup nets required to implement limit
        condition. Must be implemented by subclass.

        Args:
            global_init_net: A net invoked at task init time.
            global_finish_net: A net invoked at task cleanup time.
        """
        raise NotImplementedError("Subclass must implement `setup_limiter`")

    def check_limiter_condition(self, stop_condition_net):
        """Configure a net that is invoked between reading batches to see if
        limit condition is met. Must be implemented by subclass.

        Args:
            stop_condition_net: A net invoked to evaluate an early termination
                                condition.
        """
        raise NotImplementedError("Subclass must implement `check_limiter_condition")

    def data_finished(self):
        """
        Return a blob that can be checked after the end of the reading task,
        which will contain a boolean scalar indicating whether the underlying
        reader has been exhausted (True) or whether we stopped because reached
        the limit of iterations (False).
        """
        return self._data_finished
class ReaderWithLimit(ReaderWithLimitBase):
    """Reader that stops after `num_iter` batches.

    If `num_iter` is None, reverts to an unconstrained reader that
    exports a boolean blob indicating that the reader has exhausted
    the data stream.

    NOTE(review): only `None` disables the limit in this class; any other
    value (including values <= 0) still creates a counter initialized with
    `int(num_iter)`.
    """

    def __init__(self, reader, num_iter=1):
        """Class initializer.

        Args:
            reader: The underlying reader object doing the actual read.
            num_iter: Number of batches to read. If `None`,
                      the class reverts to a normal reader except that it also
                      produces a data_finished blob as a side effect to indicate
                      whether the input stream is exhausted.
        """
        super(ReaderWithLimit, self).__init__(reader)
        self.num_iter = num_iter
        self.counter = None
        if num_iter is not None:
            counter_name = self.net.NextName('counter')
            self.counter = self.net.AddExternalInput(counter_name)

    def setup_limiter(self, global_init_net, global_finish_net):
        """Create the counter blob, initialized to `num_iter`, if limited."""
        if not self.counter:
            return
        global_init_net.CreateCounter(
            [], [self.counter], init_count=int(self.num_iter))

    def check_limiter_condition(self, stop_condition_net):
        """Return the stop blob: a CountDown output, or a constant False."""
        if not self.counter:
            # Unlimited: the limiter never asks to stop.
            return stop_condition_net.ConstantFill(
                [], 1,
                shape=[], value=False, dtype=core.DataType.BOOL)
        return stop_condition_net.CountDown([self.counter], 1)
def CountUntil(num_iter):
    """Return a `CounterReader` wrapped in a `ReaderWithLimit` of `num_iter`."""
    counting_reader = CounterReader()
    return ReaderWithLimit(counting_reader, num_iter)
class ReaderWithTimeLimit(ReaderWithLimitBase):
    """Reader that stops after `duration` seconds.

    If `duration` <= 0 or is None, reverts to an unconstrained reader that
    exports a boolean blob indicating that the reader has exhausted
    the data stream.
    """

    def __init__(self, reader, duration=0):
        """Class initializer.

        Args:
            reader: The underlying reader object doing the actual read.
            duration: Number of seconds to read. If un-specified, None, or <= 0,
                      the class reverts to a normal reader except that it also
                      produces a data_finished blob as a side effect to indicate
                      whether the input stream is exhausted.
        """
        super(ReaderWithTimeLimit, self).__init__(reader)
        self.timer = None
        self.duration = duration
        self.duration_ns_blob = None

    def _has_time_limit(self):
        """Return True when `duration` describes an actual (positive) limit."""
        return self.duration is not None and self.duration > 0

    def setup_limiter(self, global_init_net, global_finish_net):
        """Start the timer and create the duration constant, if limited.

        Args:
            global_init_net: receives TimerBegin, TimerGet and the constant
                             holding the duration in nanoseconds.
            global_finish_net: receives the matching TimerEnd.
        """
        if self._has_time_limit():
            duration_ns = int(self.duration * (10**9))
            self.timer = global_init_net.TimerBegin(
                [], counter_name='epoch_timer')
            start_time = global_init_net.TimerGet(self.timer)
            # start_time is only passed as the input of ConstantFill
            # (presumably so the constant mirrors its shape -- TODO confirm).
            self.duration_ns_blob = global_init_net.ConstantFill(
                [start_time], value=duration_ns)
            global_finish_net.TimerEnd([self.timer], [])

    def check_limiter_condition(self, stop_condition_net):
        """Return a blob telling whether the time budget has been used up.

        Args:
            stop_condition_net: net that receives the comparison operators.

        Returns:
            The output of `GE(elapsed, duration_ns)` when a positive duration
            is configured, otherwise a constant False blob.
        """
        # Use the same predicate as setup_limiter: a plain truthiness test
        # would also accept negative durations, for which no timer exists.
        if self._has_time_limit():
            time_elapsed = stop_condition_net.TimerGet(self.timer)
            # should_stop is None unless a caller assigned it; str(None) would
            # name the output blob "None", so let the net pick a fresh name
            # in that case.
            if self.should_stop is None:
                stop_output = 1
            else:
                stop_output = str(self.should_stop)
            return stop_condition_net.GE(
                [time_elapsed, self.duration_ns_blob], stop_output)
        else:
            return stop_condition_net.ConstantFill(
                [], 1, shape=[], value=False, dtype=core.DataType.BOOL
            )
class ReaderWithDelay(Reader):
    """Test reader class that inserts a delay between reading batches."""

    def __init__(self, reader, delay):
        """
        Args:
            reader: the underlying reader; its schema is reused as ours.
            delay: value passed to `time.sleep` before each read.
        """
        Reader.__init__(self, schema=reader._schema)
        self.reader = reader
        self.delay = delay

    def setup_ex(self, global_init_net, global_finish_net):
        """Delegate setup to the underlying reader."""
        self.reader.setup_ex(global_init_net, global_finish_net)

    def read_ex(self, local_init_net, local_finish_net):
        """Return a net that sleeps, then runs the underlying reader's `read`.

        Returns:
            A tuple ([read_net], ...) extended with whatever the underlying
            reader's `read` returned.
        """
        body_net = core.Net("reader_body")

        def _pause(*args, **argd):
            time.sleep(self.delay)

        # Python op with no inputs and no outputs: it only sleeps.
        body_net.Python(_pause)([], [])
        read_result = self.reader.read(body_net)
        return ([body_net],) + read_result
class CompositeReader(Reader):
    """
    Base class for a reader that wraps multiple readers, e.g., reading from
    multiple sources simultaneously.
    """

    def __init__(self, names, readers):
        """
        Args:
            names: list[str] names of readers; used as schema keys
            readers: list[Reader] Reader instances, must have schema
        """
        assert len(names) == len(readers)
        named_schemas = [
            (name, reader.schema()) for name, reader in zip(names, readers)
        ]
        super(CompositeReader, self).__init__(schema=Struct(*named_schemas))
        self._names = names
        self._readers = readers

    def setup_ex(self, init_net, finish_net):
        """Run `setup_ex` on every wrapped reader, in order."""
        for sub_reader in self._readers:
            sub_reader.setup_ex(init_net, finish_net)

    def read_ex(self, local_init_net, local_finish_net):
        """
        Stops when one of the readers has finished.

        Returns:
            A tuple (read_nets, should_stop, fields) where `fields` is the
            concatenation of the field blobs of every wrapped reader.
        """
        # First, instantiate all the reader nets
        fields = []
        stop_blobs = []
        nets_per_reader = []
        for _, sub_reader in zip(self._names, self._readers):
            sub_nets, sub_stop, record = sub_reader.read_record_ex(
                local_init_net, local_finish_net)
            stop_blobs.append(sub_stop)
            nets_per_reader.append(sub_nets)
            fields.extend(record.field_blobs())

        # Use the stop blob of the last reader as stop blob of composite reader.
        local_should_stop = stop_blobs[-1]

        read_nets = []
        for name, sub_nets, sub_stop in zip(
                self._names, nets_per_reader, stop_blobs):
            read_nets.extend(sub_nets)
            if sub_stop == local_should_stop:
                # Skip adding stop net because Or([A, A], A) doesn't pass operator
                # schema check
                continue
            # Fold this reader's stop flag into the composite stop blob.
            merge_net = core.Net("{}_stop".format(name))
            merge_net.Or([local_should_stop, sub_stop], local_should_stop)
            read_nets.append(merge_net)

        return read_nets, local_should_stop, fields

    def reset(self, net):
        """Append the reset operations of every wrapped reader to `net`."""
        for sub_reader in self._readers:
            sub_reader.reset(net)
class CompositeReaderBuilder(ReaderBuilder):
    """
    A reader builder for CompositeReader
    """

    def __init__(self, names, reader_builders):
        """
        Args:
            names: list[str] names of readers; used as schema keys
            reader_builders: list[ReaderBuilder] ReaderBuilder instances;
                must have schema
        """
        super(CompositeReaderBuilder, self).__init__()
        self._names = names
        self._reader_builders = reader_builders
        named_schemas = [
            (name, builder.schema())
            for name, builder in zip(names, reader_builders)
        ]
        self._schema = Struct(*named_schemas)

    def schema(self):
        return self._schema

    def setup(self, **kwargs):
        """Set up every sub-builder and merge their data_finished blobs.

        Returns:
            A dict combining what each sub-builder's `setup` returned. Both
            the keys and the values are asserted to be disjoint across
            sub-builders.
        """
        merged = {}
        # limiter is stateful; it can only be used once. Since
        # CompositeReader stops when one of the reader stops,
        # this is fine.
        limiter = kwargs.pop("limiter", None)
        last_index = len(self._reader_builders) - 1
        for index, builder in enumerate(self._reader_builders):
            if limiter is not None and index == last_index:
                # The limiter must be applied to the last reader so that the
                # batch counter is incremented only if every reader has data
                kwargs["limiter"] = limiter
            sub_blobs = builder.setup(**kwargs)
            overlapping_keys = set(merged.keys()) & set(sub_blobs.keys())
            overlapping_values = (
                set(merged.values()) & set(sub_blobs.values()))
            assert overlapping_keys == set(), "Overlapping keys: {}".format(overlapping_keys)
            assert overlapping_values == set(), "Overlapping values: {}".format(overlapping_values)
            merged.update(sub_blobs)

        return merged

    def new_reader(self, **kwargs):
        """Create one reader per sub-builder and wrap them in a CompositeReader.

        Raises:
            ValueError: if a sub-builder returns something that is neither a
                Reader nor an object with a `reader` attribute.
        """
        readers = []
        for builder in self._reader_builders:
            candidate = builder.new_reader(**kwargs)
            if not isinstance(candidate, Reader):
                if not hasattr(candidate, 'reader'):
                    raise ValueError('reader must be an instance of Reader or Pipe')
                # Pipe-like object: take its reading end.
                candidate = candidate.reader()
            readers.append(candidate)

        multi_reader = CompositeReader(self._names, readers)
        assert multi_reader.schema() == self._schema
        return multi_reader