forked from pytorch/rl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
replay_buffers.py
1066 lines (945 loc) · 41.9 KB
/
replay_buffers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import collections
import threading
import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
import torch
from tensordict import is_tensorclass
from tensordict.tensordict import (
is_tensor_collection,
LazyStackedTensorDict,
TensorDict,
TensorDictBase,
)
from tensordict.utils import expand_as_right
from torchrl._utils import accept_remote_rref_udf_invocation
from torchrl.data.replay_buffers.samplers import (
PrioritizedSampler,
RandomSampler,
Sampler,
)
from torchrl.data.replay_buffers.storages import (
_get_default_collate,
ListStorage,
Storage,
)
from torchrl.data.replay_buffers.utils import (
_to_numpy,
_to_torch,
INT_CLASSES,
pin_memory_output,
)
from torchrl.data.replay_buffers.writers import (
RoundRobinWriter,
TensorDictRoundRobinWriter,
Writer,
)
from torchrl.data.utils import DEVICE_TYPING
class ReplayBuffer:
    """A generic, composable replay buffer class.

    Keyword Args:
        storage (Storage, optional): the storage to be used. If none is provided
            a default :class:`~torchrl.data.replay_buffers.ListStorage` with
            ``max_size`` of ``1_000`` will be created.
        sampler (Sampler, optional): the sampler to be used. If none is provided,
            a default :class:`~torchrl.data.replay_buffers.RandomSampler`
            will be used.
        writer (Writer, optional): the writer to be used. If none is provided
            a default :class:`~torchrl.data.replay_buffers.RoundRobinWriter`
            will be used.
        collate_fn (callable, optional): merges a list of samples to form a
            mini-batch of Tensor(s)/outputs. Used when using batched
            loading from a map-style dataset. The default value will be decided
            based on the storage type.
        pin_memory (bool): whether pin_memory() should be called on the rb
            samples.
        prefetch (int, optional): number of next batches to be prefetched
            using multithreading. Defaults to None (no prefetching).
        transform (Transform, optional): Transform to be executed when
            sample() is called.
            To chain transforms use the :class:`~torchrl.envs.Compose` class.
            Transforms should be used with :class:`tensordict.TensorDict`
            content. If used with other structures, the transforms should be
            encoded with a ``"data"`` leading key that will be used to
            construct a tensordict from the non-tensordict content.
        batch_size (int, optional): the batch size to be used when sample() is
            called.

            .. note::
              The batch-size can be specified at construction time via the
              ``batch_size`` argument, or at sampling time. The former should
              be preferred whenever the batch-size is consistent across the
              experiment. If the batch-size is likely to change, it can be
              passed to the :meth:`~.sample` method. This option is
              incompatible with prefetching (since this requires to know the
              batch-size in advance) as well as with samplers that have a
              ``drop_last`` argument.

    Examples:
        >>> import torch
        >>>
        >>> from torchrl.data import ReplayBuffer, ListStorage
        >>>
        >>> torch.manual_seed(0)
        >>> rb = ReplayBuffer(
        ...     storage=ListStorage(max_size=1000),
        ...     batch_size=5,
        ... )
        >>> # populate the replay buffer and get the item indices
        >>> data = range(10)
        >>> indices = rb.extend(data)
        >>> # sample will return as many elements as specified in the constructor
        >>> sample = rb.sample()
        >>> print(sample)
        tensor([4, 9, 3, 0, 3])
        >>> # Passing the batch-size to the sample method overrides the one in the constructor
        >>> sample = rb.sample(batch_size=3)
        >>> print(sample)
        tensor([9, 7, 3])
        >>> # one can sample using the ``sample`` method or iterate over the buffer
        >>> for i, batch in enumerate(rb):
        ...     print(i, batch)
        ...     if i == 3:
        ...         break
        0 tensor([7, 3, 1, 6, 6])
        1 tensor([9, 8, 6, 6, 8])
        2 tensor([4, 3, 6, 9, 1])
        3 tensor([4, 4, 1, 9, 9])

    Replay buffers accept *any* kind of data. Not all storage types
    will work, as some expect numerical data only, but the default
    :class:`torchrl.data.ListStorage` will:

    Examples:
        >>> torch.manual_seed(0)
        >>> buffer = ReplayBuffer(storage=ListStorage(100), collate_fn=lambda x: x)
        >>> indices = buffer.extend(["a", 1, None])
        >>> buffer.sample(3)
        [None, 'a', None]
    """

    def __init__(
        self,
        *,
        storage: Optional[Storage] = None,
        sampler: Optional[Sampler] = None,
        writer: Optional[Writer] = None,
        collate_fn: Optional[Callable] = None,
        pin_memory: bool = False,
        prefetch: Optional[int] = None,
        transform: Optional["Transform"] = None,  # noqa-F821
        batch_size: Optional[int] = None,
    ) -> None:
        self._storage = storage if storage is not None else ListStorage(max_size=1_000)
        self._storage.attach(self)
        self._sampler = sampler if sampler is not None else RandomSampler()
        self._writer = writer if writer is not None else RoundRobinWriter()
        self._writer.register_storage(self._storage)

        self._collate_fn = (
            collate_fn
            if collate_fn is not None
            else _get_default_collate(
                self._storage, _is_tensordict=isinstance(self, TensorDictReplayBuffer)
            )
        )
        self._pin_memory = pin_memory

        # Prefetching: a pool of `prefetch` worker threads keeps up to
        # `prefetch` future batches in `_prefetch_queue`.
        self._prefetch = bool(prefetch)
        self._prefetch_cap = prefetch or 0
        self._prefetch_queue = collections.deque()
        if self._prefetch_cap:
            self._prefetch_executor = ThreadPoolExecutor(max_workers=self._prefetch_cap)

        # _replay_lock guards writer/sampler/storage access;
        # _futures_lock guards the prefetch queue.
        self._replay_lock = threading.RLock()
        self._futures_lock = threading.RLock()

        # Local import, kept as in the original module layout.
        from torchrl.envs.transforms.transforms import Compose

        # The transform is always normalized to a Compose (possibly empty).
        if transform is None:
            transform = Compose()
        elif not isinstance(transform, Compose):
            transform = Compose(transform)
        transform.eval()
        self._transform = transform

        if batch_size is None and prefetch:
            raise ValueError(
                "Dynamic batch-size specification is incompatible "
                "with multithreaded sampling. "
                "When using prefetch, the batch-size must be specified in "
                "advance. "
            )
        if (
            batch_size is None
            and hasattr(self._sampler, "drop_last")
            and self._sampler.drop_last
        ):
            raise ValueError(
                "Samplers with drop_last=True must work with a predictible batch-size. "
                "Please pass the batch-size to the ReplayBuffer constructor."
            )
        self._batch_size = batch_size

    def __len__(self) -> int:
        with self._replay_lock:
            return len(self._storage)

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}("
            f"storage={self._storage}, "
            f"sampler={self._sampler}, "
            f"writer={self._writer}"
            ")"
        )

    def _apply_sample_transform(self, data: Any) -> Any:
        """Applies the (forward) transform to data read from the storage.

        Non tensor-collection data is wrapped in a ``TensorDict`` under the
        ``"data"`` key before being transformed, then unwrapped. Locked
        tensordicts are unlocked for the duration of the transform and
        re-locked afterwards, so that transforms can write new entries.
        Returns ``data`` untouched when the transform is empty.
        """
        if self._transform is None or not len(self._transform):
            return data
        is_td = is_tensor_collection(data)
        if not is_td:
            data = TensorDict({"data": data}, [])
        is_locked = data.is_locked
        if is_locked:
            data.unlock_()
        data = self._transform(data)
        if is_locked:
            data.lock_()
        if not is_td:
            data = data["data"]
        return data

    @pin_memory_output
    def __getitem__(self, index: Union[int, torch.Tensor]) -> Any:
        """Reads the element(s) at ``index`` from the storage.

        Non-integer indices are collated with ``collate_fn``, and the
        sampling transform is applied to the result.
        """
        index = _to_numpy(index)
        with self._replay_lock:
            data = self._storage[index]
        if not isinstance(index, INT_CLASSES):
            data = self._collate_fn(data)
        return self._apply_sample_transform(data)

    def state_dict(self) -> Dict[str, Any]:
        """Returns the state of storage, sampler, writer and the batch size."""
        return {
            "_storage": self._storage.state_dict(),
            "_sampler": self._sampler.state_dict(),
            "_writer": self._writer.state_dict(),
            "_batch_size": self._batch_size,
        }

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Restores a state produced by :meth:`state_dict`."""
        self._storage.load_state_dict(state_dict["_storage"])
        self._sampler.load_state_dict(state_dict["_sampler"])
        self._writer.load_state_dict(state_dict["_writer"])
        self._batch_size = state_dict["_batch_size"]

    def add(self, data: Any) -> int:
        """Add a single element to the replay buffer.

        Args:
            data (Any): data to be added to the replay buffer

        Returns:
            index where the data lives in the replay buffer.
        """
        if self._transform is not None and (
            is_tensor_collection(data) or len(self._transform)
        ):
            data = self._transform.inv(data)
        return self._add(data)

    def _add(self, data):
        # Writes ``data`` and registers the returned index with the sampler.
        with self._replay_lock:
            index = self._writer.add(data)
            self._sampler.add(index)
        return index

    def _extend(self, data: Sequence) -> torch.Tensor:
        # Writes ``data`` and registers the returned indices with the sampler.
        with self._replay_lock:
            index = self._writer.extend(data)
            self._sampler.extend(index)
        return index

    def extend(self, data: Sequence) -> torch.Tensor:
        """Extends the replay buffer with one or more elements contained in an iterable.

        If present, the inverse transforms will be called.

        Args:
            data (iterable): collection of data to be added to the replay
                buffer.

        Returns:
            Indices of the data added to the replay buffer.
        """
        if self._transform is not None and (
            is_tensor_collection(data) or len(self._transform)
        ):
            data = self._transform.inv(data)
        return self._extend(data)

    def update_priority(
        self,
        index: Union[int, torch.Tensor],
        priority: Union[int, torch.Tensor],
    ) -> None:
        """Forwards a priority update to the sampler under the replay lock."""
        with self._replay_lock:
            self._sampler.update_priority(index, priority)

    @pin_memory_output
    def _sample(self, batch_size: int) -> Tuple[Any, dict]:
        # Samples indices, reads the data, then collates and transforms it.
        # Returns (data, info) where info["index"] holds the sampled indices.
        with self._replay_lock:
            index, info = self._sampler.sample(self._storage, batch_size)
            info["index"] = index
            data = self._storage[index]
        if not isinstance(index, INT_CLASSES):
            data = self._collate_fn(data)
        data = self._apply_sample_transform(data)
        return data, info

    def empty(self):
        """Empties the replay buffer and reset cursor to 0."""
        with self._replay_lock:
            self._writer._empty()
            self._sampler._empty()
            self._storage._empty()

    def sample(
        self, batch_size: Optional[int] = None, return_info: bool = False
    ) -> Any:
        """Samples a batch of data from the replay buffer.

        Uses Sampler to sample indices, and retrieves them from Storage.

        Args:
            batch_size (int, optional): size of data to be collected. If none
                is provided, this method will sample a batch-size as indicated
                by the sampler.
            return_info (bool): whether to return info. If True, the result
                is a tuple (data, info). If False, the result is the data.

        Returns:
            A batch of data selected in the replay buffer.
            A tuple containing this batch and info if return_info flag is set to True.

        Raises:
            RuntimeError: if no batch size is given here nor at construction.
        """
        if (
            batch_size is not None
            and self._batch_size is not None
            and batch_size != self._batch_size
        ):
            warnings.warn(
                f"Got conflicting batch_sizes in constructor ({self._batch_size}) "
                f"and `sample` ({batch_size}). Refer to the ReplayBuffer documentation "
                "for a proper usage of the batch-size arguments. "
                "The batch-size provided to the sample method "
                "will prevail."
            )
        elif batch_size is None and self._batch_size is not None:
            batch_size = self._batch_size
        elif batch_size is None:
            raise RuntimeError(
                "batch_size not specified. You can specify the batch_size when "
                "constructing the replay buffer, or pass it to the sample method. "
                "Refer to the ReplayBuffer documentation "
                "for a proper usage of the batch-size arguments."
            )

        if not self._prefetch:
            ret = self._sample(batch_size)
        else:
            # Check and pop atomically: testing the queue length outside the
            # lock let two callers race on a single pending future.
            with self._futures_lock:
                future = (
                    self._prefetch_queue.popleft() if self._prefetch_queue else None
                )
            if future is None:
                ret = self._sample(batch_size)
            else:
                ret = future.result()
            # Refill the queue with batches sampled in the background.
            with self._futures_lock:
                while len(self._prefetch_queue) < self._prefetch_cap:
                    fut = self._prefetch_executor.submit(self._sample, batch_size)
                    self._prefetch_queue.append(fut)

        if return_info:
            return ret
        return ret[0]

    def mark_update(self, index: Union[int, torch.Tensor]) -> None:
        """Forwards ``index`` to the sampler's ``mark_update``."""
        self._sampler.mark_update(index)

    def append_transform(self, transform: "Transform") -> None:  # noqa-F821
        """Appends transform at the end.

        Transforms are applied in order when `sample` is called.

        Args:
            transform (Transform): The transform to be appended
        """
        transform.eval()
        self._transform.append(transform)

    def insert_transform(self, index: int, transform: "Transform") -> None:  # noqa-F821
        """Inserts transform.

        Transforms are executed in order when `sample` is called.

        Args:
            index (int): Position to insert the transform.
            transform (Transform): The transform to be appended
        """
        transform.eval()
        self._transform.insert(index, transform)

    def __iter__(self):
        # Yields batches until the sampler reports it ran out. Samplers that
        # never set ``ran_out`` therefore yield indefinitely.
        if self._sampler.ran_out:
            self._sampler.ran_out = False
        if self._batch_size is None:
            raise RuntimeError(
                "Cannot iterate over the replay buffer. "
                "Batch_size was not specified during construction of the replay buffer."
            )
        while not self._sampler.ran_out:
            data = self.sample()
            yield data

    def __getstate__(self) -> Dict[str, Any]:
        state = self.__dict__.copy()
        # Locks cannot be pickled: replace them with placeholders that are
        # turned back into fresh locks in __setstate__.
        _replay_lock = state.pop("_replay_lock", None)
        _futures_lock = state.pop("_futures_lock", None)
        if _replay_lock is not None:
            state["_replay_lock_placeholder"] = None
        if _futures_lock is not None:
            state["_futures_lock_placeholder"] = None
        # The thread pool and pending futures hold threads and locks and are
        # not picklable either. They are rebuilt in __setstate__, and any
        # already-prefetched batches are dropped.
        state.pop("_prefetch_executor", None)
        state.pop("_prefetch_queue", None)
        return state

    def __setstate__(self, state: Dict[str, Any]):
        if "_replay_lock_placeholder" in state:
            state.pop("_replay_lock_placeholder")
            state["_replay_lock"] = threading.RLock()
        if "_futures_lock_placeholder" in state:
            state.pop("_futures_lock_placeholder")
            state["_futures_lock"] = threading.RLock()
        if "_prefetch_queue" not in state:
            state["_prefetch_queue"] = collections.deque()
        if state.get("_prefetch_cap") and "_prefetch_executor" not in state:
            state["_prefetch_executor"] = ThreadPoolExecutor(
                max_workers=state["_prefetch_cap"]
            )
        self.__dict__.update(state)
class PrioritizedReplayBuffer(ReplayBuffer):
    """Prioritized replay buffer.

    All arguments are keyword-only arguments.

    Presented in
        "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015.
        Prioritized experience replay."
        (https://arxiv.org/abs/1511.05952)

    This is a :class:`~torchrl.data.ReplayBuffer` whose sampler is always a
    :class:`~torchrl.data.PrioritizedSampler` sized after the storage's
    ``max_size``.

    Keyword Args:
        alpha (float): exponent α determines how much prioritization is used,
            with α = 0 corresponding to the uniform case.
        beta (float): importance sampling negative exponent.
        eps (float): delta added to the priorities to ensure that the buffer
            does not contain null priorities.
        dtype (torch.dtype, optional): dtype forwarded to the underlying
            :class:`~torchrl.data.PrioritizedSampler`. Defaults to
            ``torch.float``.
        storage (Storage, optional): the storage to be used. If none is provided
            a default :class:`~torchrl.data.replay_buffers.ListStorage` with
            ``max_size`` of ``1_000`` will be created.
        collate_fn (callable, optional): merges a list of samples to form a
            mini-batch of Tensor(s)/outputs. Used when using batched
            loading from a map-style dataset. The default value will be decided
            based on the storage type.
        pin_memory (bool): whether pin_memory() should be called on the rb
            samples.
        prefetch (int, optional): number of next batches to be prefetched
            using multithreading. Defaults to None (no prefetching).
        transform (Transform, optional): Transform to be executed when
            sample() is called.
            To chain transforms use the :class:`~torchrl.envs.Compose` class.
            Transforms should be used with :class:`tensordict.TensorDict`
            content. If used with other structures, the transforms should be
            encoded with a ``"data"`` leading key that will be used to
            construct a tensordict from the non-tensordict content.
        batch_size (int, optional): the batch size to be used when sample() is
            called.

            .. note::
              The batch-size can be specified at construction time via the
              ``batch_size`` argument, or at sampling time. The former should
              be preferred whenever the batch-size is consistent across the
              experiment. If the batch-size is likely to change, it can be
              passed to the :meth:`~.sample` method. This option is
              incompatible with prefetching (since this requires to know the
              batch-size in advance) as well as with samplers that have a
              ``drop_last`` argument.

    .. note::
      Generic prioritized replay buffers (ie. non-tensordict backed) require
      calling :meth:`~.sample` with the ``return_info`` argument set to
      ``True`` to have access to the indices, and hence update the priority.
      Using :class:`tensordict.TensorDict` and the related
      :class:`~torchrl.data.TensorDictPrioritizedReplayBuffer` simplifies this
      process.

    Examples:
        >>> import torch
        >>>
        >>> from torchrl.data import ListStorage, PrioritizedReplayBuffer
        >>>
        >>> torch.manual_seed(0)
        >>>
        >>> rb = PrioritizedReplayBuffer(alpha=0.7, beta=0.9, storage=ListStorage(10))
        >>> data = range(10)
        >>> rb.extend(data)
        >>> sample = rb.sample(3)
        >>> print(sample)
        tensor([1, 0, 1])
        >>> # get the info to find what the indices are
        >>> sample, info = rb.sample(5, return_info=True)
        >>> print(sample, info)
        tensor([2, 7, 4, 3, 5]) {'_weight': array([1., 1., 1., 1., 1.], dtype=float32), 'index': array([2, 7, 4, 3, 5])}
        >>> # update priority
        >>> priority = torch.ones(5) * 5
        >>> rb.update_priority(info["index"], priority)
        >>> # and now a new sample, the weights should be updated
        >>> sample, info = rb.sample(5, return_info=True)
        >>> print(sample, info)
        tensor([2, 5, 2, 2, 5]) {'_weight': array([0.36278465, 0.36278465, 0.36278465, 0.36278465, 0.36278465],
              dtype=float32), 'index': array([2, 5, 2, 2, 5])}
    """

    def __init__(
        self,
        *,
        alpha: float,
        beta: float,
        eps: float = 1e-8,
        dtype: torch.dtype = torch.float,
        storage: Optional[Storage] = None,
        collate_fn: Optional[Callable] = None,
        pin_memory: bool = False,
        prefetch: Optional[int] = None,
        transform: Optional["Transform"] = None,  # noqa-F821
        batch_size: Optional[int] = None,
    ) -> None:
        # The sampler capacity must match the storage, so the storage is
        # resolved first.
        storage = ListStorage(max_size=1_000) if storage is None else storage
        prioritized_sampler = PrioritizedSampler(
            storage.max_size, alpha, beta, eps, dtype
        )
        super().__init__(
            storage=storage,
            sampler=prioritized_sampler,
            collate_fn=collate_fn,
            pin_memory=pin_memory,
            prefetch=prefetch,
            transform=transform,
            batch_size=batch_size,
        )
class TensorDictReplayBuffer(ReplayBuffer):
    """TensorDict-specific wrapper around the :class:`~torchrl.data.ReplayBuffer` class.

    Keyword Args:
        storage (Storage, optional): the storage to be used. If none is provided
            a default :class:`~torchrl.data.replay_buffers.ListStorage` with
            ``max_size`` of ``1_000`` will be created.
        sampler (Sampler, optional): the sampler to be used. If none is provided
            a default RandomSampler() will be used.
        writer (Writer, optional): the writer to be used. If none is provided
            a default :class:`~torchrl.data.replay_buffers.TensorDictRoundRobinWriter`
            will be used.
        collate_fn (callable, optional): merges a list of samples to form a
            mini-batch of Tensor(s)/outputs. Used when using batched
            loading from a map-style dataset. The default value will be decided
            based on the storage type.
        pin_memory (bool): whether pin_memory() should be called on the rb
            samples.
        prefetch (int, optional): number of next batches to be prefetched
            using multithreading. Defaults to None (no prefetching).
        transform (Transform, optional): Transform to be executed when
            sample() is called.
            To chain transforms use the :class:`~torchrl.envs.Compose` class.
            Transforms should be used with :class:`tensordict.TensorDict`
            content. If used with other structures, the transforms should be
            encoded with a ``"data"`` leading key that will be used to
            construct a tensordict from the non-tensordict content.
        batch_size (int, optional): the batch size to be used when sample() is
            called.

            .. note::
              The batch-size can be specified at construction time via the
              ``batch_size`` argument, or at sampling time. The former should
              be preferred whenever the batch-size is consistent across the
              experiment. If the batch-size is likely to change, it can be
              passed to the :meth:`~.sample` method. This option is
              incompatible with prefetching (since this requires to know the
              batch-size in advance) as well as with samplers that have a
              ``drop_last`` argument.

        priority_key (str, optional): the key at which priority is assumed to
            be stored within TensorDicts added to this ReplayBuffer.
            This is to be used when the sampler is of type
            :class:`~torchrl.data.PrioritizedSampler`.
            Defaults to ``"td_error"``.

    Examples:
        >>> import torch
        >>>
        >>> from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
        >>> from tensordict import TensorDict
        >>>
        >>> torch.manual_seed(0)
        >>>
        >>> rb = TensorDictReplayBuffer(storage=LazyTensorStorage(10), batch_size=5)
        >>> data = TensorDict({"a": torch.ones(10, 3), ("b", "c"): torch.zeros(10, 1, 1)}, [10])
        >>> rb.extend(data)
        >>> sample = rb.sample(3)
        >>> # samples keep track of the index
        >>> print(sample)
        TensorDict(
            fields={
                a: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                b: TensorDict(
                    fields={
                        c: Tensor(shape=torch.Size([3, 1, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([3]),
                    device=cpu,
                    is_shared=False),
                index: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.int32, is_shared=False)},
            batch_size=torch.Size([3]),
            device=cpu,
            is_shared=False)
        >>> # we can iterate over the buffer
        >>> for i, data in enumerate(rb):
        ...     print(i, data)
        ...     if i == 2:
        ...         break
        0 TensorDict(
            fields={
                a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                b: TensorDict(
                    fields={
                        c: Tensor(shape=torch.Size([5, 1, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([5]),
                    device=cpu,
                    is_shared=False),
                index: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int32, is_shared=False)},
            batch_size=torch.Size([5]),
            device=cpu,
            is_shared=False)
        1 TensorDict(
            fields={
                a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                b: TensorDict(
                    fields={
                        c: Tensor(shape=torch.Size([5, 1, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([5]),
                    device=cpu,
                    is_shared=False),
                index: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int32, is_shared=False)},
            batch_size=torch.Size([5]),
            device=cpu,
            is_shared=False)
    """

    def __init__(self, *, priority_key: str = "td_error", **kw) -> None:
        # Default to the tensordict-aware writer unless one is supplied.
        writer = kw.get("writer", None)
        if writer is None:
            kw["writer"] = TensorDictRoundRobinWriter()
        super().__init__(**kw)
        self.priority_key = priority_key

    def _get_priority_item(self, tensordict: TensorDictBase) -> float:
        """Returns a single priority value for a 0-dim tensordict.

        The priority is read from ``self.priority_key``, looking under the
        ``"_data"`` wrapper key when present. If the entry is missing, the
        sampler's ``default_priority`` is returned. Multi-element entries are
        reduced with the sampler's ``reduction``.
        """
        if "_data" in tensordict.keys():
            tensordict = tensordict.get("_data")

        priority = tensordict.get(self.priority_key, None)
        if priority is None:
            return self._sampler.default_priority
        try:
            if priority.numel() > 1:
                priority = _reduce(priority, self._sampler.reduction)
            else:
                priority = priority.item()
        except ValueError as err:
            raise ValueError(
                f"Found a priority key of size"
                f" {tensordict.get(self.priority_key).shape} but expected "
                f"scalar value"
            ) from err
        return priority

    def _get_priority_vector(self, tensordict: TensorDictBase) -> torch.Tensor:
        """Returns one priority per element along the leading batch dimension.

        If the priority entry is missing, the sampler's ``default_priority`` is
        expanded to ``tensordict.shape[0]``. Otherwise trailing dimensions are
        flattened and reduced with the sampler's ``reduction``.
        """
        if "_data" in tensordict.keys():
            tensordict = tensordict.get("_data")

        priority = tensordict.get(self.priority_key, None)
        if priority is None:
            return torch.tensor(
                self._sampler.default_priority,
                dtype=torch.float,
                device=tensordict.device,
            ).expand(tensordict.shape[0])

        priority = priority.reshape(priority.shape[0], -1)
        priority = _reduce(priority, self._sampler.reduction, dim=1)

        return priority

    def add(self, data: TensorDictBase) -> int:
        """Adds a single tensordict (or other element) to the buffer.

        Tensor collections are wrapped under a ``"_data"`` key. A non-empty
        batch size is recorded in ``"_rb_batch_size"``. The written index is
        stored under ``"index"``, and priorities are updated when the sampler
        is prioritized.

        Returns:
            the index where the data was written.
        """
        if self._transform is not None:
            data = self._transform.inv(data)

        if is_tensor_collection(data):
            data_add = TensorDict(
                {
                    "_data": data,
                },
                batch_size=[],
                device=data.device,
            )
            if data.batch_size:
                data_add["_rb_batch_size"] = torch.tensor(data.batch_size)
        else:
            data_add = data

        index = super()._add(data_add)
        if index is not None:
            if is_tensor_collection(data_add):
                data_add.set("index", index)
            self.update_tensordict_priority(data_add)
        return index

    def extend(self, tensordicts: TensorDictBase) -> torch.Tensor:
        """Extends the buffer with a batched tensordict, split along dim 0.

        Each element is written with an ``"index"`` entry. The inverse
        transform is applied to the data before writing, and priorities are
        updated afterwards.

        Returns:
            the indices where the data was written.
        """
        tensordicts = TensorDict(
            {"_data": tensordicts},
            batch_size=tensordicts.batch_size[:1],
        )
        # NOTE(review): ``tensordicts`` was just built with at most one batch
        # dimension, so this branch looks unreachable. Kept as-is to avoid
        # changing what gets written to storage; confirm the intent.
        if tensordicts.batch_dims > 1:
            # we want the tensordict to have one dimension only. The batch size
            # of the sampled tensordicts can be changed thereafter
            if not isinstance(tensordicts, LazyStackedTensorDict):
                tensordicts = tensordicts.clone(recurse=False)
            else:
                tensordicts = tensordicts.contiguous()
            # we keep track of the batch size to reinstantiate it when sampling
            if "_rb_batch_size" in tensordicts.keys():
                raise KeyError(
                    "conflicting key '_rb_batch_size'. Consider removing from data."
                )
            shape = torch.tensor(tensordicts.batch_size[1:]).expand(
                tensordicts.batch_size[0], tensordicts.batch_dims - 1
            )
            tensordicts.set("_rb_batch_size", shape)
        # Placeholder index entry, filled in by the writer.
        tensordicts.set(
            "index",
            torch.zeros(tensordicts.shape, device=tensordicts.device, dtype=torch.int),
        )

        if self._transform is not None:
            data = self._transform.inv(tensordicts.get("_data"))
            tensordicts.set("_data", data)
            if data.device is not None:
                tensordicts = tensordicts.to(data.device)

        index = super()._extend(tensordicts)
        self.update_tensordict_priority(tensordicts)
        return index

    def update_tensordict_priority(self, data: TensorDictBase) -> None:
        """Updates the sampler priorities from the entries of ``data``.

        Does nothing unless the sampler is a
        :class:`~torchrl.data.PrioritizedSampler`. Indices are read from
        ``data["index"]``. Trailing index dimensions are dropped until the
        index shape matches the priority shape.
        """
        if not isinstance(self._sampler, PrioritizedSampler):
            return
        if data.ndim:
            priority = self._get_priority_vector(data)
        else:
            # _get_priority_item returns a Python scalar, which has no
            # ``.shape``. Wrap it so the shape comparison below works.
            priority = torch.as_tensor(self._get_priority_item(data))
        index = data.get("index")
        while index.shape != priority.shape:
            # reduce index
            index = index[..., 0]
        self.update_priority(index, priority)

    def sample(
        self,
        batch_size: Optional[int] = None,
        return_info: bool = False,
        include_info: bool = None,
    ) -> TensorDictBase:
        """Samples a batch of data from the replay buffer.

        Uses Sampler to sample indices, and retrieves them from Storage.

        Args:
            batch_size (int, optional): size of data to be collected. If none
                is provided, this method will sample a batch-size as indicated
                by the sampler.
            return_info (bool): whether to return info. If True, the result
                is a tuple (data, info). If False, the result is the data.
            include_info (bool, optional): deprecated. When ``None`` or
                ``True``, each info entry is written into the returned
                tensordict (expanded to its batch shape). When ``False``, the
                entries are not written.

        Returns:
            A tensordict containing a batch of data selected in the replay buffer.
            A tuple containing this tensordict and info if return_info flag is set to True.
        """
        if include_info is not None:
            warnings.warn(
                "include_info is going to be deprecated soon. "
                "The default behaviour has changed to `include_info=True` "
                "to avoid bugs linked to wrongly preassigned values in the "
                "output tensordict."
            )

        data, info = super().sample(batch_size, return_info=True)
        if not is_tensorclass(data) and include_info in (True, None):
            # Temporarily unlock so that the info entries can be written.
            is_locked = data.is_locked
            if is_locked:
                data.unlock_()
            for k, v in info.items():
                data.set(k, expand_as_right(_to_torch(v, data.device), data))
            if is_locked:
                data.lock_()
        if return_info:
            return data, info
        return data
class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer):
    """TensorDict-specific wrapper around the :class:`~torchrl.data.PrioritizedReplayBuffer` class.

    This class returns tensordicts with a new key ``"index"`` that represents
    the index of each element in the replay buffer. It also provides the
    :meth:`~.update_tensordict_priority` method that only requires for the
    tensordict to be passed to it with its new priority value.

    Keyword Args:
        alpha (float): exponent α determines how much prioritization is used,
            with α = 0 corresponding to the uniform case.
        beta (float): importance sampling negative exponent.
        eps (float): delta added to the priorities to ensure that the buffer
            does not contain null priorities.
        storage (Storage, optional): the storage to be used. If none is provided
            a default :class:`~torchrl.data.replay_buffers.ListStorage` with
            ``max_size`` of ``1_000`` will be created.
        collate_fn (callable, optional): merges a list of samples to form a
            mini-batch of Tensor(s)/outputs. Used when using batched
            loading from a map-style dataset. The default value will be decided
            based on the storage type.
        pin_memory (bool): whether pin_memory() should be called on the rb
            samples.
        prefetch (int, optional): number of next batches to be prefetched
            using multithreading. Defaults to None (no prefetching).
        transform (Transform, optional): Transform to be executed when
            sample() is called.
            To chain transforms use the :class:`~torchrl.envs.Compose` class.
            Transforms should be used with :class:`tensordict.TensorDict`
            content. If used with other structures, the transforms should be
            encoded with a ``"data"`` leading key that will be used to
            construct a tensordict from the non-tensordict content.
        batch_size (int, optional): the batch size to be used when sample() is
            called.

            .. note::
              The batch-size can be specified at construction time via the
              ``batch_size`` argument, or at sampling time. The former should
              be preferred whenever the batch-size is consistent across the
              experiment. If the batch-size is likely to change, it can be
              passed to the :meth:`~.sample` method. This option is
              incompatible with prefetching (since this requires to know the
              batch-size in advance) as well as with samplers that have a
              ``drop_last`` argument.

        priority_key (str, optional): the key at which priority is assumed to
            be stored within TensorDicts added to this ReplayBuffer.
            This is to be used when the sampler is of type
            :class:`~torchrl.data.PrioritizedSampler`.
            Defaults to ``"td_error"``.
        reduction (str, optional): the reduction method for multidimensional
            tensordicts (ie stored trajectories). Can be one of "max", "min",
            "median" or "mean".
        dtype (torch.dtype, optional): dtype forwarded to the underlying
            :class:`~torchrl.data.PrioritizedSampler`, as in
            :class:`~torchrl.data.PrioritizedReplayBuffer`. Defaults to
            ``torch.float``.

    Examples:
        >>> import torch
        >>>
        >>> from torchrl.data import LazyTensorStorage, TensorDictPrioritizedReplayBuffer
        >>> from tensordict import TensorDict
        >>>
        >>> torch.manual_seed(0)
        >>>
        >>> rb = TensorDictPrioritizedReplayBuffer(alpha=0.7, beta=1.1, storage=LazyTensorStorage(10), batch_size=5)
        >>> data = TensorDict({"a": torch.ones(10, 3), ("b", "c"): torch.zeros(10, 3, 1)}, [10])
        >>> rb.extend(data)
        >>> print("len of rb", len(rb))
        len of rb 10
        >>> sample = rb.sample(5)
        >>> print(sample)
        TensorDict(
            fields={
                _weight: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.float32, is_shared=False),
                a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                b: TensorDict(
                    fields={
                        c: Tensor(shape=torch.Size([5, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([5]),
                    device=cpu,
                    is_shared=False),
                index: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([5]),
            device=cpu,
            is_shared=False)
        >>> print("index", sample["index"])
        index tensor([9, 5, 2, 2, 7])
        >>> # give a high priority to these samples...
        >>> sample.set("td_error", 100*torch.ones(sample.shape))
        >>> # and update priority
        >>> rb.update_tensordict_priority(sample)
        >>> # the new sample should have a high overlap with the previous one
        >>> sample = rb.sample(5)
        >>> print(sample)
        TensorDict(
            fields={
                _weight: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.float32, is_shared=False),
                a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                b: TensorDict(
                    fields={
                        c: Tensor(shape=torch.Size([5, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([5]),
                    device=cpu,
                    is_shared=False),
                index: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([5]),
            device=cpu,
            is_shared=False)
        >>> print("index", sample["index"])
        index tensor([2, 5, 5, 9, 7])
    """

    def __init__(
        self,
        *,
        alpha: float,
        beta: float,
        priority_key: str = "td_error",
        eps: float = 1e-8,
        storage: Optional[Storage] = None,
        collate_fn: Optional[Callable] = None,
        pin_memory: bool = False,
        prefetch: Optional[int] = None,
        transform: Optional["Transform"] = None,  # noqa-F821
        reduction: Optional[str] = "max",
        batch_size: Optional[int] = None,
        dtype: torch.dtype = torch.float,
    ) -> None:
        if storage is None:
            storage = ListStorage(max_size=1_000)
        # Same positional layout as PrioritizedReplayBuffer:
        # (max_capacity, alpha, beta, eps, dtype).
        sampler = PrioritizedSampler(
            storage.max_size, alpha, beta, eps, dtype, reduction=reduction
        )
        super(TensorDictPrioritizedReplayBuffer, self).__init__(
            priority_key=priority_key,
            storage=storage,
            sampler=sampler,
            collate_fn=collate_fn,
            pin_memory=pin_memory,
            prefetch=prefetch,
            transform=transform,
            batch_size=batch_size,
        )
@accept_remote_rref_udf_invocation
class RemoteTensorDictReplayBuffer(TensorDictReplayBuffer):
    """A remote invocation friendly ReplayBuffer class. Public methods can be invoked by remote agents using `torch.rpc` or called locally as normal."""

    def __init__(self, *args, **kwargs):
        # Plain pass-through to TensorDictReplayBuffer.
        super().__init__(*args, **kwargs)

    def sample(
        self,
        batch_size: Optional[int] = None,
        include_info: bool = None,
        return_info: bool = False,
    ) -> TensorDictBase:
        """Forwards to :meth:`TensorDictReplayBuffer.sample`.

        Note the positional order here is ``(batch_size, include_info,
        return_info)``. Arguments are forwarded by keyword.
        """
        forwarded = dict(
            batch_size=batch_size,
            include_info=include_info,
            return_info=return_info,
        )
        return super().sample(**forwarded)

    def add(self, data: TensorDictBase) -> int:
        """Forwards to :meth:`TensorDictReplayBuffer.add`."""
        return super().add(data=data)

    def extend(self, tensordicts: Union[List, TensorDictBase]) -> torch.Tensor:
        """Forwards to :meth:`TensorDictReplayBuffer.extend`."""
        return super().extend(tensordicts=tensordicts)

    def update_priority(
        self, index: Union[int, torch.Tensor], priority: Union[int, torch.Tensor]
    ) -> None:
        """Forwards to :meth:`ReplayBuffer.update_priority`."""
        return super().update_priority(index=index, priority=priority)

    def update_tensordict_priority(self, data: TensorDictBase) -> None:
        """Forwards to :meth:`TensorDictReplayBuffer.update_tensordict_priority`."""
        return super().update_tensordict_priority(data=data)
class InPlaceSampler:
"""A sampler to write tennsordicts in-place.
To be used cautiously as this may lead to unexpected behaviour (i.e. tensordicts
overwritten during execution).