diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py
index 869ea5cdae3..4658651dcf0 100644
--- a/torchrl/data/replay_buffers/samplers.py
+++ b/torchrl/data/replay_buffers/samplers.py
@@ -982,22 +982,20 @@ def _find_start_stop_traj(
 
             # faster
             end = trajectory[:-1] != trajectory[1:]
-            end = torch.cat([end, trajectory[-1:] != trajectory[:1]], 0)
+            if not at_capacity:
+                end = torch.cat([end, torch.ones_like(end[:1])], 0)
+            else:
+                end = torch.cat([end, trajectory[-1:] != trajectory[:1]], 0)
             length = trajectory.shape[0]
         else:
-            # TODO: check that storage is at capacity here, if not we need to assume that the last element of end is True
-            # We presume that not done at the end means that the traj spans across end and beginning of storage
             length = end.shape[0]
+            if not at_capacity:
+                end = end.clone()
+                end[length - 1] = True
+
         ndim = end.ndim
-        if not at_capacity:
-            end = torch.index_fill(
-                end,
-                index=torch.tensor(-1, device=end.device, dtype=torch.long),
-                dim=0,
-                value=1,
-            )
-        else:
+        if at_capacity:
             # we must have at least one end by traj to individuate trajectories
             # so if no end can be found we set it manually
             if cursor is not None:
@@ -1019,7 +1017,6 @@ def _find_start_stop_traj(
                 mask = ~end.any(0, True)
                 mask = torch.cat([torch.zeros_like(end[:-1]), mask])
                 end = torch.masked_fill(mask, end, 1)
-        ndim = end.ndim
         if ndim == 0:
             raise RuntimeError(
                 "Expected the end-of-trajectory signal to be at least 1-dimensional."
             )
@@ -1126,7 +1123,7 @@ def _get_stop_and_length(self, storage, fallback=True):
                     "Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories."
                 )
             vals = self._find_start_stop_traj(
-                trajectory=trajectory,
+                trajectory=trajectory.clone(),
                 at_capacity=storage._is_full,
                 cursor=getattr(storage, "_last_cursor", None),
             )
diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py
index fb684e7c043..d2d37e86f07 100644
--- a/torchrl/data/replay_buffers/storages.py
+++ b/torchrl/data/replay_buffers/storages.py
@@ -902,9 +902,7 @@ def max_size_along_dim0(data_shape):
 
         if is_tensor_collection(data):
             out = data.to(self.device)
-            out = out.expand(max_size_along_dim0(data.shape))
-            out = out.clone()
-            out = out.zero_()
+            out = torch.empty_like(out.expand(max_size_along_dim0(data.shape)))
         else:
            # if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype
             out = tree_map(
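
For illustration, a minimal standalone sketch of the end-of-trajectory detection that the first samplers.py hunk introduces (find_ends is a hypothetical helper, not the torchrl API; the real logic lives in SliceSampler._find_start_stop_traj):

import torch

def find_ends(trajectory: torch.Tensor, at_capacity: bool) -> torch.Tensor:
    # An element ends a trajectory whenever the next element carries a different ID.
    end = trajectory[:-1] != trajectory[1:]
    if not at_capacity:
        # Storage not yet full: the last written element always closes its trajectory.
        end = torch.cat([end, torch.ones_like(end[:1])], 0)
    else:
        # Storage full (circular buffer): a trajectory may wrap around, so the last
        # element ends a trajectory only if its ID differs from the first element's.
        end = torch.cat([end, trajectory[-1:] != trajectory[:1]], 0)
    return end

traj = torch.tensor([0, 0, 1, 1, 1, 0])
print(find_ends(traj, at_capacity=False))  # tensor([False,  True, False, False,  True,  True])
print(find_ends(traj, at_capacity=True))   # last entry False: trajectory 0 wraps to the start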