diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 2ad0550ed06..273cf627521 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -1243,7 +1243,7 @@ def _get_stop_and_length(self, storage, fallback=True): "Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories." ) vals = self._find_start_stop_traj( - trajectory=trajectory.clone(), + trajectory=trajectory, at_capacity=storage._is_full, cursor=getattr(storage, "_last_cursor", None), )