From cd540bf96a9c998e89a59382b1961fd8a2bc57f0 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 22 Mar 2024 08:56:19 +0000 Subject: [PATCH] [Performance] Faster slice sampler (#2031) --- torchrl/data/replay_buffers/samplers.py | 175 +++++++++++++++++++----- torchrl/data/replay_buffers/storages.py | 8 +- 2 files changed, 144 insertions(+), 39 deletions(-) diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 5b49b067956..20f0472ed0e 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -22,7 +22,7 @@ from torchrl._extension import EXTENSION_WARNING -from torchrl._utils import _replace_last +from torchrl._utils import _replace_last, logger from torchrl.data.replay_buffers.storages import Storage, StorageEnsemble, TensorStorage from torchrl.data.replay_buffers.utils import _is_int @@ -54,7 +54,7 @@ def extend(self, index: torch.Tensor) -> None: def update_priority( self, index: Union[int, torch.Tensor], priority: Union[float, torch.Tensor] - ) -> dict: + ) -> dict | None: return def mark_update(self, index: Union[int, torch.Tensor]) -> None: @@ -221,7 +221,7 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[Any, dict]: if storage.ndim > 1: index = torch.unravel_index(index, storage.shape) # we 'always' return the indices. The 'drop_last' just instructs the - # sampler to turn to 'ran_out = True` whenever the next sample + # sampler to turn to `ran_out = True` whenever the next sample # will be too short. This will be read by the replay buffer # as a signal for an early break of the __iter__(). return index, {} @@ -477,7 +477,7 @@ def update_priority( """ priority = torch.as_tensor(priority, device=torch.device("cpu")).detach() index = torch.as_tensor(index, dtype=torch.long, device=torch.device("cpu")) - # we need to reshape priority if it has more than one elements or if it has + # we need to reshape priority if it has more than one element or if it has # a different shape than index if priority.numel() > 1 and priority.shape != index.shape: try: @@ -637,7 +637,25 @@ class SliceSampler(Sampler): if the last element of the trajectory tensor is identical to the first, the same trajectory spans across end and beginning. cache_values (bool, optional): to be used with static datasets. - Will cache the start and end signal of the trajectory. + Will cache the start and end signal of the trajectory. This can be safely used even + if the trajectory indices change during calls to :class:`~torchrl.data.ReplayBuffer.extend` + as this operation will erase the cache. + + .. warning:: ``cache_values=True`` will not work if the sampler is used with a + storage that is extended by another buffer. For instance: + + >>> buffer0 = ReplayBuffer(storage=storage, + ... sampler=SliceSampler(num_slices=8, cache_values=True), + ... writer=ImmutableWriter()) + >>> buffer1 = ReplayBuffer(storage=storage, + ... sampler=other_sampler) + >>> # Wrong! Does not erase the buffer from the sampler of buffer0 + >>> buffer1.extend(data) + + .. warning:: ``cache_values=True`` will not work as expected if the buffer is + shared between processes and one process is responsible for writing + and one process for sampling, as erasing the cache can only be done locally. + truncated_key (NestedKey, optional): If not ``None``, this argument indicates where a truncated signal should be written in the output data. 
This is used to indicate to value estimators where the provided @@ -652,6 +670,10 @@ class SliceSampler(Sampler): Be mindful that this can result in effective `batch_size` shorter than the one asked for! Trajectories can be split using :func:`~torchrl.collectors.split_trajectories`. Defaults to ``True``. + compile (bool or dict of kwargs, optional): if ``True``, the bottleneck of + the :meth:`~sample` method will be compiled with :func:`~torch.compile`. + Keyword arguments can also be passed to torch.compile with this arg. + Defaults to ``False``. .. note:: To recover the trajectory splits in the storage, :class:`~torchrl.data.replay_buffers.samplers.SliceSampler` will first @@ -730,6 +752,7 @@ def __init__( cache_values: bool = False, truncated_key: NestedKey | None = ("next", "truncated"), strict_length: bool = True, + compile: bool | dict = False, ): self.num_slices = num_slices self.slice_len = slice_len @@ -784,6 +807,31 @@ def __init__( "Either num_slices or slice_len must be not None, and not both. " f"Got num_slices={num_slices} and slice_len={slice_len}." ) + self.compile = bool(compile) + if self.compile: + if isinstance(compile, dict): + kwargs = compile + else: + kwargs = {} + self._get_index = torch.compile(self._get_index, **kwargs) + + def __getstate__(self): + if get_spawning_popen() is not None and self.cache_values: + logger.warning( + f"It seems you are sharing a {type(self).__name__} across processes with" + f"cache_values=True. " + f"While this isn't forbidden and could perfectly work if your dataset " + f"is unaltered on both processes, remember that calling extend/add on" + f"one process will NOT erase the cache on another process's sampler, " + f"which will cause synchronization issues." + ) + state = copy(self.__dict__) + state["_cache"] = {} + return state + + def extend(self, index: torch.Tensor) -> None: + if self.cache_values: + self._cache.clear() def __repr__(self): return ( @@ -795,8 +843,8 @@ def __repr__(self): f"strict_length={self.strict_length})" ) - @staticmethod - def _find_start_stop_traj(*, trajectory=None, end=None, at_capacity: bool): + @classmethod + def _find_start_stop_traj(cls, *, trajectory=None, end=None, at_capacity: bool): if trajectory is not None: # slower # _, stop_idx = torch.unique_consecutive(trajectory, return_counts=True) @@ -835,6 +883,10 @@ def _find_start_stop_traj(*, trajectory=None, end=None, at_capacity: bool): raise RuntimeError( "Expected the end-of-trajectory signal to be at least 1-dimensional." 
) + return cls._end_to_start_stop(length=length, end=end) + + @staticmethod + def _end_to_start_stop(end, length): # Using transpose ensures the start and stop are sorted the same way stop_idx = end.transpose(0, -1).nonzero() stop_idx[:, [0, -1]] = stop_idx[:, [-1, 0]].clone() @@ -859,30 +911,33 @@ def _find_start_stop_traj(*, trajectory=None, end=None, at_capacity: bool): lengths[lengths < 0] = lengths[lengths < 0] + length return start_idx, stop_idx, lengths + def _start_to_end(self, st: torch.Tensor, length: int): + arange = torch.arange(length, device=st.device, dtype=st.dtype) + ndims = st.shape[-1] - 1 if st.ndim else 0 + if ndims: + arange = torch.stack([arange] + [torch.zeros_like(arange)] * ndims, -1) + else: + arange = arange.unsqueeze(-1) + if st.shape != arange.shape: + # we do this to make sure that we're not broadcasting the start + # wrong as a tensor with shape [N] can't be expanded to [N, 1] + # without getting an error + st = st.expand_as(arange) + return arange + st + def _tensor_slices_from_startend(self, seq_length, start, storage_length): # start is a 2d tensor resulting from nonzero() # seq_length is a 1d tensor indicating the desired length of each sequence - def _start_to_end(st: torch.Tensor, length: int): - arange = torch.arange(length, device=st.device, dtype=st.dtype) - ndims = st.shape[-1] - 1 if st.ndim else 0 - arange = torch.stack([arange] + [torch.zeros_like(arange)] * ndims, -1) - if st.shape != arange.shape: - # we do this to make sure that we're not broadcasting the start - # wrong as a tensor with shape [N] can't be expanded to [N, 1] - # without getting an error - st = st.expand_as(arange) - return arange + st - if isinstance(seq_length, int): result = torch.cat( - [_start_to_end(_start, length=seq_length) for _start in start] + [self._start_to_end(_start, length=seq_length) for _start in start] ) else: # when padding is needed result = torch.cat( [ - _start_to_end(_start, _seq_len) + self._start_to_end(_start, _seq_len) for _start, _seq_len in zip(start, seq_length) ] ) @@ -945,14 +1000,16 @@ def _adjusted_batch_size(self, batch_size): if self.num_slices is not None: if batch_size % self.num_slices != 0: raise RuntimeError( - f"The batch-size must be divisible by the number of slices, got batch_size={batch_size} and num_slices={self.num_slices}." + f"The batch-size must be divisible by the number of slices, got " + f"batch_size={batch_size} and num_slices={self.num_slices}." ) seq_length = batch_size // self.num_slices num_slices = self.num_slices else: if batch_size % self.slice_len != 0: raise RuntimeError( - f"The batch-size must be divisible by the slice length, got batch_size={batch_size} and slice_len={self.slice_len}." + f"The batch-size must be divisible by the slice length, got " + f"batch_size={batch_size} and slice_len={self.slice_len}." 
) seq_length = self.slice_len num_slices = batch_size // self.slice_len @@ -993,8 +1050,8 @@ def _sample_slices( ) -> Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]]: # start_idx and stop_idx are 2d tensors organized like a non-zero - def get_traj_idx(lengths=lengths): - return torch.randint(lengths.shape[0], (num_slices,), device=lengths.device) + def get_traj_idx(maxval): + return torch.randint(maxval, (num_slices,), device=lengths.device) if (lengths < seq_length).any(): if self.strict_length: @@ -1013,7 +1070,7 @@ def get_traj_idx(lengths=lengths): stop_idx = stop_idx[idx] if traj_idx is None: - traj_idx = get_traj_idx(lengths=lengths_idx) + traj_idx = get_traj_idx(lengths_idx.shape[0]) else: # Here we must filter out the indices that correspond to trajectories # we don't want to keep. That could potentially lead to an empty sample. @@ -1036,7 +1093,7 @@ def get_traj_idx(lengths=lengths): lengths = lengths_idx else: if traj_idx is None: - traj_idx = get_traj_idx() + traj_idx = get_traj_idx(lengths.shape[0]) else: num_slices = traj_idx.shape[0] @@ -1044,10 +1101,29 @@ def get_traj_idx(lengths=lengths): seq_length = lengths[traj_idx].clamp_max(seq_length) else: if traj_idx is None: - traj_idx = get_traj_idx() + traj_idx = get_traj_idx(lengths.shape[0]) else: num_slices = traj_idx.shape[0] + return self._get_index( + lengths=lengths, + start_idx=start_idx, + stop_idx=stop_idx, + num_slices=num_slices, + seq_length=seq_length, + storage_length=storage_length, + traj_idx=traj_idx, + ) + def _get_index( + self, + lengths: torch.Tensor, + start_idx: torch.Tensor, + stop_idx: torch.Tensor, + seq_length: int, + num_slices: int, + storage_length: int, + traj_idx: torch.Tensor | None = None, + ) -> Tuple[torch.Tensor, dict]: relative_starts = ( ( torch.rand(num_slices, device=lengths.device) @@ -1130,11 +1206,6 @@ def state_dict(self) -> Dict[str, Any]: def load_state_dict(self, state_dict: Dict[str, Any]) -> None: ... - def __getstate__(self): - state = copy(self.__dict__) - state["_cache"] = {} - return state - class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement): """Samples slices of data along the first dimension, given start and stop signals, without replacement. @@ -1182,6 +1253,10 @@ class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement): :func:`~torchrl.collectors.split_trajectories`. Defaults to ``True``. shuffle (bool, optional): if ``False``, the order of the trajectories is not shuffled. Defaults to ``True``. + compile (bool or dict of kwargs, optional): if ``True``, the bottleneck of + the :meth:`~sample` method will be compiled with :func:`~torch.compile`. + Keyword arguments can also be passed to torch.compile with this arg. + Defaults to ``False``. .. note:: To recover the trajectory splits in the storage, :class:`~torchrl.data.replay_buffers.samplers.SliceSamplerWithoutReplacement` will first @@ -1256,6 +1331,7 @@ def __init__( truncated_key: NestedKey | None = ("next", "truncated"), strict_length: bool = True, shuffle: bool = True, + compile: bool | dict = False, ): SliceSampler.__init__( self, @@ -1268,6 +1344,7 @@ def __init__( strict_length=strict_length, ends=ends, trajectories=trajectories, + compile=compile, ) SamplerWithoutReplacement.__init__(self, drop_last=drop_last, shuffle=shuffle) @@ -1376,7 +1453,25 @@ class PrioritizedSliceSampler(SliceSampler, PrioritizedSampler): or when this signal is readily available. Must be used with ``cache_values=True`` and cannot be used in conjunction with ``end_key`` or ``traj_key``. 
cache_values (bool, optional): to be used with static datasets. - Will cache the start and end signal of the trajectory. + Will cache the start and end signal of the trajectory. This can be safely used even + if the trajectory indices change during calls to :class:`~torchrl.data.ReplayBuffer.extend` + as this operation will erase the cache. + + .. warning:: ``cache_values=True`` will not work if the sampler is used with a + storage that is extended by another buffer. For instance: + + >>> buffer0 = ReplayBuffer(storage=storage, + ... sampler=SliceSampler(num_slices=8, cache_values=True), + ... writer=ImmutableWriter()) + >>> buffer1 = ReplayBuffer(storage=storage, + ... sampler=other_sampler) + >>> # Wrong! Does not erase the buffer from the sampler of buffer0 + >>> buffer1.extend(data) + + .. warning:: ``cache_values=True`` will not work as expected if the buffer is + shared between processes and one process is responsible for writing + and one process for sampling, as erasing the cache can only be done locally. + truncated_key (NestedKey, optional): If not ``None``, this argument indicates where a truncated signal should be written in the output data. This is used to indicate to value estimators where the provided @@ -1391,6 +1486,10 @@ class PrioritizedSliceSampler(SliceSampler, PrioritizedSampler): Be mindful that this can result in effective `batch_size` shorter than the one asked for! Trajectories can be split using :func:`~torchrl.collectors.split_trajectories`. Defaults to ``True``. + compile (bool or dict of kwargs, optional): if ``True``, the bottleneck of + the :meth:`~sample` method will be compiled with :func:`~torch.compile`. + Keyword arguments can also be passed to torch.compile with this arg. + Defaults to ``False``. Examples: >>> import torch @@ -1447,6 +1546,7 @@ def __init__( cache_values: bool = False, truncated_key: NestedKey | None = ("next", "truncated"), strict_length: bool = True, + compile: bool | dict = False, ): SliceSampler.__init__( self, @@ -1459,6 +1559,7 @@ def __init__( strict_length=strict_length, ends=ends, trajectories=trajectories, + compile=compile, ) PrioritizedSampler.__init__( self, @@ -1493,6 +1594,10 @@ def __getstate__(self): state = SliceSampler.__getstate__(self) state.update(PrioritizedSampler.__getstate__(self)) + def extend(self, index: torch.Tensor) -> None: + super(PrioritizedSampler, self).extend(index) + return super(SliceSampler, self).extend(index) + def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]: # Sample `batch_size` indices representing the start of a slice. # The sampling is based on a weight vector. @@ -1512,7 +1617,7 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict] # make seq_length a tensor with values clamped by lengths seq_length = lengths[traj_idx].clamp_max(seq_length) - # build a list of index that we dont want to sample: all the steps at a `seq_length` distance of + # build a list of index that we don't want to sample: all the steps at a `seq_length` distance of # the end the trajectory, with the end of trajectory (`stop_idx`) included if not isinstance(seq_length, int): try: @@ -1676,7 +1781,7 @@ class SamplerEnsemble(Sampler): The indices provided in the info dictionary are placed in a :class:`~tensordict.TensorDict` with keys ``index`` and ``buffer_ids`` that allow the upper :class:`~torchrl.data.ReplayBufferEnsemble` and :class:`~torchrl.data.StorageEnsemble` objects to retrieve the data. 
- This format is different than with other samplers which usually return indices + This format is different from that of other samplers, which usually return indices as regular tensors. """ diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 146f017f4ad..383bce0386d 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -582,16 +582,16 @@ def flatten(self): def __getstate__(self): state = copy(self.__dict__) if get_spawning_popen() is None: - len = self._len + length = self._len del state["_len_value"] - state["len__context"] = len + state["len__context"] = length elif not self.initialized: # check that the storage is initialized raise RuntimeError( - f"Cannot share a storage of type {type(self)} between processed if " + f"Cannot share a storage of type {type(self)} between processes if " f"it has not been initialized yet. Populate the buffer with " f"some data in the main process before passing it to the other " - f"subprocesses (or create the buffer explicitely with a TensorStorage)." + f"subprocesses (or create the buffer explicitly with a TensorStorage)." ) else: # check that the content is shared, otherwise tell the user we can't help
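A minimal usage sketch of the features added above (the ``compile`` flag and the cache-clearing ``extend``), assuming only the public ``torchrl``/``tensordict`` APIs already referenced in the diff; the storage size, the keys ("obs", "episode") and the batch geometry are illustrative choices, not taken from the patch.

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.data.replay_buffers.samplers import SliceSampler

# 100 transitions spread over 4 trajectories of 25 steps each,
# identified by an integer "episode" id stored next to the data.
data = TensorDict(
    {"obs": torch.randn(100, 4), "episode": torch.arange(100) // 25},
    batch_size=[100],
)

sampler = SliceSampler(
    num_slices=8,        # 8 slices per sampled batch
    traj_key="episode",  # where the trajectory id lives in the storage
    cache_values=True,   # cache the start/stop indices of each trajectory
    compile=True,        # new flag: compile the indexing bottleneck (_get_index)
)
rb = ReplayBuffer(storage=LazyTensorStorage(100), sampler=sampler, batch_size=32)

rb.extend(data)   # extending through *this* buffer also clears the sampler cache
sample = rb.sample()                     # 32 = 8 slices of 4 consecutive steps
print(sample["episode"].reshape(8, 4))   # each row comes from a single episode

As the warnings in the docstrings above describe, extending the same storage through a different buffer, or from another process, would not trigger this cache invalidation.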