[Performance] Faster slice sampler (#2031)
vmoens authored Mar 22, 2024
1 parent a69c667 commit cd540bf
Showing 2 changed files with 144 additions and 39 deletions.
175 changes: 140 additions & 35 deletions torchrl/data/replay_buffers/samplers.py
@@ -22,7 +22,7 @@

from torchrl._extension import EXTENSION_WARNING

from torchrl._utils import _replace_last
from torchrl._utils import _replace_last, logger
from torchrl.data.replay_buffers.storages import Storage, StorageEnsemble, TensorStorage
from torchrl.data.replay_buffers.utils import _is_int

@@ -54,7 +54,7 @@ def extend(self, index: torch.Tensor) -> None:

def update_priority(
self, index: Union[int, torch.Tensor], priority: Union[float, torch.Tensor]
) -> dict:
) -> dict | None:
return

def mark_update(self, index: Union[int, torch.Tensor]) -> None:
@@ -221,7 +221,7 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[Any, dict]:
if storage.ndim > 1:
index = torch.unravel_index(index, storage.shape)
# we 'always' return the indices. The 'drop_last' just instructs the
# sampler to turn to 'ran_out = True` whenever the next sample
# sampler to turn to `ran_out = True` whenever the next sample
# will be too short. This will be read by the replay buffer
# as a signal for an early break of the __iter__().
return index, {}
@@ -477,7 +477,7 @@ def update_priority(
"""
priority = torch.as_tensor(priority, device=torch.device("cpu")).detach()
index = torch.as_tensor(index, dtype=torch.long, device=torch.device("cpu"))
# we need to reshape priority if it has more than one elements or if it has
# we need to reshape priority if it has more than one element or if it has
# a different shape than index
if priority.numel() > 1 and priority.shape != index.shape:
try:
@@ -637,7 +637,25 @@ class SliceSampler(Sampler):
if the last element of the trajectory tensor is identical to the first,
the same trajectory spans across end and beginning.
cache_values (bool, optional): to be used with static datasets.
Will cache the start and end signal of the trajectory.
Will cache the start and end signal of the trajectory. This can be safely used even
if the trajectory indices change during calls to :class:`~torchrl.data.ReplayBuffer.extend`
as this operation will erase the cache.
.. warning:: ``cache_values=True`` will not work if the sampler is used with a
storage that is extended by another buffer. For instance:
>>> buffer0 = ReplayBuffer(storage=storage,
... sampler=SliceSampler(num_slices=8, cache_values=True),
... writer=ImmutableWriter())
>>> buffer1 = ReplayBuffer(storage=storage,
... sampler=other_sampler)
>>> # Wrong! Does not erase the cache from the sampler of buffer0
>>> buffer1.extend(data)
.. warning:: ``cache_values=True`` will not work as expected if the buffer is
shared between processes and one process is responsible for writing
and one process for sampling, as erasing the cache can only be done locally.
truncated_key (NestedKey, optional): If not ``None``, this argument
indicates where a truncated signal should be written in the output
data. This is used to indicate to value estimators where the provided
@@ -652,6 +670,10 @@ class SliceSampler(Sampler):
Be mindful that this can result in effective `batch_size` shorter
than the one asked for! Trajectories can be split using
:func:`~torchrl.collectors.split_trajectories`. Defaults to ``True``.
compile (bool or dict of kwargs, optional): if ``True``, the bottleneck of
the :meth:`~sample` method will be compiled with :func:`~torch.compile`.
Keyword arguments can also be passed to torch.compile with this arg.
Defaults to ``False``.
.. note:: To recover the trajectory splits in the storage,
:class:`~torchrl.data.replay_buffers.samplers.SliceSampler` will first
@@ -730,6 +752,7 @@ def __init__(
cache_values: bool = False,
truncated_key: NestedKey | None = ("next", "truncated"),
strict_length: bool = True,
compile: bool | dict = False,
):
self.num_slices = num_slices
self.slice_len = slice_len
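The ``compile`` argument documented above wraps the sampler's indexing bottleneck in torch.compile. The following is a minimal usage sketch: the buffer setup, storage size and the "episode" trajectory key are illustrative assumptions, not part of this diff.

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SliceSampler

# Pass compile=True for default settings, or a dict forwarded to torch.compile.
sampler = SliceSampler(num_slices=8, traj_key="episode", compile={"mode": "reduce-overhead"})
rb = TensorDictReplayBuffer(storage=LazyTensorStorage(1000), sampler=sampler, batch_size=64)
data = TensorDict(
    {"obs": torch.randn(64, 4), "episode": torch.arange(64) // 16},
    batch_size=[64],
)
rb.extend(data)
sample = rb.sample()   # the hot indexing path now runs through torch.compile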
@@ -784,6 +807,31 @@ def __init__(
"Either num_slices or slice_len must be not None, and not both. "
f"Got num_slices={num_slices} and slice_len={slice_len}."
)
self.compile = bool(compile)
if self.compile:
if isinstance(compile, dict):
kwargs = compile
else:
kwargs = {}
self._get_index = torch.compile(self._get_index, **kwargs)

def __getstate__(self):
if get_spawning_popen() is not None and self.cache_values:
logger.warning(
f"It seems you are sharing a {type(self).__name__} across processes with"
f"cache_values=True. "
f"While this isn't forbidden and could perfectly work if your dataset "
f"is unaltered on both processes, remember that calling extend/add on"
f"one process will NOT erase the cache on another process's sampler, "
f"which will cause synchronization issues."
)
state = copy(self.__dict__)
state["_cache"] = {}
return state

def extend(self, index: torch.Tensor) -> None:
if self.cache_values:
self._cache.clear()
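A short sketch of the caching contract spelled out in the docstring warnings above: writing through the buffer that owns the cached sampler goes through this ``extend`` and clears the cache, while pickling the sampler (as happens when sharing it across processes) ships an empty cache. The buffer setup and toy data below are illustrative assumptions.

import pickle
import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SliceSampler

sampler = SliceSampler(num_slices=4, traj_key="episode", cache_values=True)
rb = TensorDictReplayBuffer(storage=LazyTensorStorage(100), sampler=sampler, batch_size=32)
data = TensorDict(
    {"obs": torch.randn(40, 3), "episode": torch.arange(40) // 10},
    batch_size=[40],
)
rb.extend(data)        # routed through SliceSampler.extend -> cache cleared
_ = rb.sample()        # start/stop indices recomputed (and cached) here
clone = pickle.loads(pickle.dumps(sampler))
assert clone._cache == {}   # the cache never travels to another process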

def __repr__(self):
return (
@@ -795,8 +843,8 @@ def __repr__(self):
f"strict_length={self.strict_length})"
)

@staticmethod
def _find_start_stop_traj(*, trajectory=None, end=None, at_capacity: bool):
@classmethod
def _find_start_stop_traj(cls, *, trajectory=None, end=None, at_capacity: bool):
if trajectory is not None:
# slower
# _, stop_idx = torch.unique_consecutive(trajectory, return_counts=True)
@@ -835,6 +883,10 @@ def _find_start_stop_traj(*, trajectory=None, end=None, at_capacity: bool):
raise RuntimeError(
"Expected the end-of-trajectory signal to be at least 1-dimensional."
)
return cls._end_to_start_stop(length=length, end=end)

@staticmethod
def _end_to_start_stop(end, length):
# Using transpose ensures the start and stop are sorted the same way
stop_idx = end.transpose(0, -1).nonzero()
stop_idx[:, [0, -1]] = stop_idx[:, [-1, 0]].clone()
@@ -859,30 +911,33 @@ def _find_start_stop_traj(*, trajectory=None, end=None, at_capacity: bool):
lengths[lengths < 0] = lengths[lengths < 0] + length
return start_idx, stop_idx, lengths
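For intuition, a tiny illustration of what the bookkeeping above computes from an end-of-trajectory flag. The exact 2d layout of the returned index tensors is inferred from the visible code and should be treated as an assumption.

import torch
from torchrl.data.replay_buffers.samplers import SliceSampler

# Two trajectories stored back to back: steps 0..2 and steps 3..4.
end = torch.tensor([False, False, True, False, True])
start_idx, stop_idx, lengths = SliceSampler._end_to_start_stop(end=end, length=end.shape[0])
# Expected semantics: stops at steps 2 and 4, starts at steps 0 and 3,
# lengths 3 and 2; a start that falls "after" the last stop wraps around
# thanks to the modulo correction applied to negative lengths above.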

def _start_to_end(self, st: torch.Tensor, length: int):
arange = torch.arange(length, device=st.device, dtype=st.dtype)
ndims = st.shape[-1] - 1 if st.ndim else 0
if ndims:
arange = torch.stack([arange] + [torch.zeros_like(arange)] * ndims, -1)
else:
arange = arange.unsqueeze(-1)
if st.shape != arange.shape:
# we do this to make sure that we're not broadcasting the start
# wrong as a tensor with shape [N] can't be expanded to [N, 1]
# without getting an error
st = st.expand_as(arange)
return arange + st
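The helper above fans each start index out into ``length`` consecutive indices. Here is a self-contained sketch of the same idea with toy values, ignoring the extra batch columns used for multidimensional storages.

import torch

starts = torch.tensor([[3], [10]])           # two slice starts, shape [num_slices, 1]
length = 4
arange = torch.arange(length).unsqueeze(-1)  # column of offsets 0..length-1
print(starts[0] + arange)                    # indices 3, 4, 5, 6
print(starts[1] + arange)                    # indices 10, 11, 12, 13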

def _tensor_slices_from_startend(self, seq_length, start, storage_length):
# start is a 2d tensor resulting from nonzero()
# seq_length is a 1d tensor indicating the desired length of each sequence

def _start_to_end(st: torch.Tensor, length: int):
arange = torch.arange(length, device=st.device, dtype=st.dtype)
ndims = st.shape[-1] - 1 if st.ndim else 0
arange = torch.stack([arange] + [torch.zeros_like(arange)] * ndims, -1)
if st.shape != arange.shape:
# we do this to make sure that we're not broadcasting the start
# wrong as a tensor with shape [N] can't be expanded to [N, 1]
# without getting an error
st = st.expand_as(arange)
return arange + st

if isinstance(seq_length, int):
result = torch.cat(
[_start_to_end(_start, length=seq_length) for _start in start]
[self._start_to_end(_start, length=seq_length) for _start in start]
)
else:
# when padding is needed
result = torch.cat(
[
_start_to_end(_start, _seq_len)
self._start_to_end(_start, _seq_len)
for _start, _seq_len in zip(start, seq_length)
]
)
@@ -945,14 +1000,16 @@ def _adjusted_batch_size(self, batch_size):
if self.num_slices is not None:
if batch_size % self.num_slices != 0:
raise RuntimeError(
f"The batch-size must be divisible by the number of slices, got batch_size={batch_size} and num_slices={self.num_slices}."
f"The batch-size must be divisible by the number of slices, got "
f"batch_size={batch_size} and num_slices={self.num_slices}."
)
seq_length = batch_size // self.num_slices
num_slices = self.num_slices
else:
if batch_size % self.slice_len != 0:
raise RuntimeError(
f"The batch-size must be divisible by the slice length, got batch_size={batch_size} and slice_len={self.slice_len}."
f"The batch-size must be divisible by the slice length, got "
f"batch_size={batch_size} and slice_len={self.slice_len}."
)
seq_length = self.slice_len
num_slices = batch_size // self.slice_len
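The arithmetic above in a few lines, with illustrative numbers: when ``num_slices`` is fixed the per-slice length is derived from the batch size, and vice versa when ``slice_len`` is fixed.

batch_size, num_slices = 128, 8
assert batch_size % num_slices == 0        # otherwise a RuntimeError is raised
seq_length = batch_size // num_slices      # -> 16 steps per sampled sub-trajectory
# Symmetrically, with slice_len=16 one gets num_slices = 128 // 16 = 8.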
@@ -993,8 +1050,8 @@ def _sample_slices(
) -> Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]]:
# start_idx and stop_idx are 2d tensors organized like a non-zero

def get_traj_idx(lengths=lengths):
return torch.randint(lengths.shape[0], (num_slices,), device=lengths.device)
def get_traj_idx(maxval):
return torch.randint(maxval, (num_slices,), device=lengths.device)

if (lengths < seq_length).any():
if self.strict_length:
@@ -1013,7 +1070,7 @@ def get_traj_idx(lengths=lengths):
stop_idx = stop_idx[idx]

if traj_idx is None:
traj_idx = get_traj_idx(lengths=lengths_idx)
traj_idx = get_traj_idx(lengths_idx.shape[0])
else:
# Here we must filter out the indices that correspond to trajectories
# we don't want to keep. That could potentially lead to an empty sample.
@@ -1036,18 +1093,37 @@ def get_traj_idx(lengths=lengths):
lengths = lengths_idx
else:
if traj_idx is None:
traj_idx = get_traj_idx()
traj_idx = get_traj_idx(lengths.shape[0])
else:
num_slices = traj_idx.shape[0]

# make seq_length a tensor with values clamped by lengths
seq_length = lengths[traj_idx].clamp_max(seq_length)
else:
if traj_idx is None:
traj_idx = get_traj_idx()
traj_idx = get_traj_idx(lengths.shape[0])
else:
num_slices = traj_idx.shape[0]
return self._get_index(
lengths=lengths,
start_idx=start_idx,
stop_idx=stop_idx,
num_slices=num_slices,
seq_length=seq_length,
storage_length=storage_length,
traj_idx=traj_idx,
)
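The heavy lifting now lives in ``_get_index`` below, so that ``torch.compile`` can target just that function. Its body is only partially visible in this hunk; the following is a hedged sketch of the slice-start selection it performs, and the real offset formula may differ in detail.

import torch

lengths = torch.tensor([10, 7, 12])          # illustrative trajectory lengths
start_idx = torch.tensor([[0], [10], [17]])  # trajectory starts (nonzero-style layout)
seq_length, num_slices = 5, 4
traj_idx = torch.randint(lengths.shape[0], (num_slices,))
# Draw a random offset so that a window of seq_length steps fits in the trajectory.
relative_starts = (
    torch.rand(num_slices) * (lengths[traj_idx] - seq_length + 1)
).floor().long()
starts = start_idx[traj_idx, 0] + relative_starts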

def _get_index(
self,
lengths: torch.Tensor,
start_idx: torch.Tensor,
stop_idx: torch.Tensor,
seq_length: int,
num_slices: int,
storage_length: int,
traj_idx: torch.Tensor | None = None,
) -> Tuple[torch.Tensor, dict]:
relative_starts = (
(
torch.rand(num_slices, device=lengths.device)
Expand Down Expand Up @@ -1130,11 +1206,6 @@ def state_dict(self) -> Dict[str, Any]:
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
...

def __getstate__(self):
state = copy(self.__dict__)
state["_cache"] = {}
return state


class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement):
"""Samples slices of data along the first dimension, given start and stop signals, without replacement.
@@ -1182,6 +1253,10 @@ class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement):
:func:`~torchrl.collectors.split_trajectories`. Defaults to ``True``.
shuffle (bool, optional): if ``False``, the order of the trajectories
is not shuffled. Defaults to ``True``.
compile (bool or dict of kwargs, optional): if ``True``, the bottleneck of
the :meth:`~sample` method will be compiled with :func:`~torch.compile`.
Keyword arguments can also be passed to torch.compile with this arg.
Defaults to ``False``.
.. note:: To recover the trajectory splits in the storage,
:class:`~torchrl.data.replay_buffers.samplers.SliceSamplerWithoutReplacement` will first
@@ -1256,6 +1331,7 @@ def __init__(
truncated_key: NestedKey | None = ("next", "truncated"),
strict_length: bool = True,
shuffle: bool = True,
compile: bool | dict = False,
):
SliceSampler.__init__(
self,
@@ -1268,6 +1344,7 @@ def __init__(
strict_length=strict_length,
ends=ends,
trajectories=trajectories,
compile=compile,
)
SamplerWithoutReplacement.__init__(self, drop_last=drop_last, shuffle=shuffle)

@@ -1376,7 +1453,25 @@ class PrioritizedSliceSampler(SliceSampler, PrioritizedSampler):
or when this signal is readily available. Must be used with ``cache_values=True``
and cannot be used in conjunction with ``end_key`` or ``traj_key``.
cache_values (bool, optional): to be used with static datasets.
Will cache the start and end signal of the trajectory.
Will cache the start and end signal of the trajectory. This can be safely used even
if the trajectory indices change during calls to :class:`~torchrl.data.ReplayBuffer.extend`
as this operation will erase the cache.
.. warning:: ``cache_values=True`` will not work if the sampler is used with a
storage that is extended by another buffer. For instance:
>>> buffer0 = ReplayBuffer(storage=storage,
... sampler=SliceSampler(num_slices=8, cache_values=True),
... writer=ImmutableWriter())
>>> buffer1 = ReplayBuffer(storage=storage,
... sampler=other_sampler)
>>> # Wrong! Does not erase the cache from the sampler of buffer0
>>> buffer1.extend(data)
.. warning:: ``cache_values=True`` will not work as expected if the buffer is
shared between processes and one process is responsible for writing
and one process for sampling, as erasing the cache can only be done locally.
truncated_key (NestedKey, optional): If not ``None``, this argument
indicates where a truncated signal should be written in the output
data. This is used to indicate to value estimators where the provided
@@ -1391,6 +1486,10 @@ class PrioritizedSliceSampler(SliceSampler, PrioritizedSampler):
Be mindful that this can result in effective `batch_size` shorter
than the one asked for! Trajectories can be split using
:func:`~torchrl.collectors.split_trajectories`. Defaults to ``True``.
compile (bool or dict of kwargs, optional): if ``True``, the bottleneck of
the :meth:`~sample` method will be compiled with :func:`~torch.compile`.
Keyword arguments can also be passed to torch.compile with this arg.
Defaults to ``False``.
Examples:
>>> import torch
@@ -1447,6 +1546,7 @@ def __init__(
cache_values: bool = False,
truncated_key: NestedKey | None = ("next", "truncated"),
strict_length: bool = True,
compile: bool | dict = False,
):
SliceSampler.__init__(
self,
@@ -1459,6 +1559,7 @@ def __init__(
strict_length=strict_length,
ends=ends,
trajectories=trajectories,
compile=compile,
)
PrioritizedSampler.__init__(
self,
@@ -1493,6 +1594,10 @@ def __getstate__(self):
state = SliceSampler.__getstate__(self)
state.update(PrioritizedSampler.__getstate__(self))

def extend(self, index: torch.Tensor) -> None:
super(PrioritizedSampler, self).extend(index)
return super(SliceSampler, self).extend(index)

def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]:
# Sample `batch_size` indices representing the start of a slice.
# The sampling is based on a weight vector.
@@ -1512,7 +1617,7 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
# make seq_length a tensor with values clamped by lengths
seq_length = lengths[traj_idx].clamp_max(seq_length)

# build a list of index that we dont want to sample: all the steps at a `seq_length` distance of
# build a list of index that we don't want to sample: all the steps at a `seq_length` distance of
# the end the trajectory, with the end of trajectory (`stop_idx`) included
if not isinstance(seq_length, int):
try:
@@ -1676,7 +1781,7 @@ class SamplerEnsemble(Sampler):
The indices provided in the info dictionary are placed in a :class:`~tensordict.TensorDict` with
keys ``index`` and ``buffer_ids`` that allow the upper :class:`~torchrl.data.ReplayBufferEnsemble`
and :class:`~torchrl.data.StorageEnsemble` objects to retrieve the data.
This format is different than with other samplers which usually return indices
This format is different from that of other samplers, which usually return indices
as regular tensors.
"""