Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 20, 2024
2 parents 572a570 + 0238b1c commit 563e4a9
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 3 deletions.
19 changes: 18 additions & 1 deletion torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,10 @@ class SliceSampler(Sampler):
This class samples sub-trajectories with replacement. For a version without
replacement, see :class:`~torchrl.data.replay_buffers.samplers.SliceSamplerWithoutReplacement`.
.. note:: `SliceSampler` can be slow to retrieve the trajectory indices. To accelerate
its execution, prefer using `end_key` over `traj_key`, and consider the following
keyword arguments: :attr:`compile`, :attr:`cache_values` and :attr:`use_gpu`.
Keyword Args:
num_slices (int): the number of slices to be sampled. The batch-size
must be greater or equal to the ``num_slices`` argument. Exclusive
Expand Down Expand Up @@ -796,6 +800,10 @@ class SliceSampler(Sampler):
that at least `slice_len - i` samples will be gathered for each sampled trajectory.
Using tuples allows a fine grained control over the span on the left (beginning
of the stored trajectory) and on the right (end of the stored trajectory).
use_gpu (bool or torch.device): if ``True`` (or if a device is passed), an accelerator
    will be used to retrieve the indices of the trajectory starts. This can significantly
    accelerate the sampling when the buffer content is large.
    Defaults to ``False``.
.. note:: To recover the trajectory splits in the storage,
:class:`~torchrl.data.replay_buffers.samplers.SliceSampler` will first
Expand Down Expand Up @@ -1562,6 +1570,10 @@ class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement):
the sampler, and continuous sampling without replacement is currently not
allowed.
.. note:: `SliceSamplerWithoutReplacement` can be slow to retrieve the trajectory indices. To accelerate
its execution, prefer using `end_key` over `traj_key`, and consider the following
keyword arguments: :attr:`compile`, :attr:`cache_values` and :attr:`use_gpu`.
Keyword Args:
drop_last (bool, optional): if ``True``, the last incomplete sample (if any) will be dropped.
If ``False``, this last sample will be kept.
Expand Down Expand Up @@ -1604,6 +1616,10 @@ class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement):
the :meth:`~sample` method will be compiled with :func:`~torch.compile`.
Keyword arguments can also be passed to torch.compile with this arg.
Defaults to ``False``.
use_gpu (bool or torch.device): if ``True`` (or if a device is passed), an accelerator
    will be used to retrieve the indices of the trajectory starts. This can significantly
    accelerate the sampling when the buffer content is large.
    Defaults to ``False``.
.. note:: To recover the trajectory splits in the storage,
:class:`~torchrl.data.replay_buffers.samplers.SliceSamplerWithoutReplacement` will first
Expand Down Expand Up @@ -1708,7 +1724,6 @@ class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement):
tensor([[0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1.]])
"""

def __init__(
Expand All @@ -1725,6 +1740,7 @@ def __init__(
strict_length: bool = True,
shuffle: bool = True,
compile: bool | dict = False,
use_gpu: bool | torch.device = False,
):
SliceSampler.__init__(
self,
Expand All @@ -1738,6 +1754,7 @@ def __init__(
ends=ends,
trajectories=trajectories,
compile=compile,
use_gpu=use_gpu,
)
SamplerWithoutReplacement.__init__(self, drop_last=drop_last, shuffle=shuffle)

Expand Down
10 changes: 8 additions & 2 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4996,8 +4996,14 @@ def __init__(
)
kwargs = primers
if not isinstance(kwargs, Composite):
kwargs = Composite(**kwargs)
self.primers = kwargs
shape = kwargs.pop("shape", None)
device = kwargs.pop("device", None)
if "batch_size" in kwargs.keys():
extra_kwargs = {"batch_size": kwargs.pop("batch_size")}
else:
extra_kwargs = {}
primers = Composite(kwargs, device=device, shape=shape, **extra_kwargs)
self.primers = primers
self.expand_specs = expand_specs

if random and default_value:
Expand Down

0 comments on commit 563e4a9

Please sign in to comment.