From 4fd54fef493da9bd2084a574ff731f844c83913c Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 20 Dec 2024 10:26:43 +0000 Subject: [PATCH] [Performance] Avoid cloning trajs in SliceSampler ghstack-source-id: 2e133fcea716b202694cfa84df3f6e4ba3507bbc Pull Request resolved: https://github.com/pytorch/rl/pull/2671 --- torchrl/data/replay_buffers/samplers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 2ad0550ed06..273cf627521 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -1243,7 +1243,7 @@ def _get_stop_and_length(self, storage, fallback=True): "Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories." ) vals = self._find_start_stop_traj( - trajectory=trajectory.clone(), + trajectory=trajectory, at_capacity=storage._is_full, cursor=getattr(storage, "_last_cursor", None), )