From eece34d4de9a586f53ae46c72b2162b17a0752f4 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 19 Dec 2024 16:28:45 +0000 Subject: [PATCH] [Performance] Avoid cloning trajs in SliceSampler ghstack-source-id: ac4a85a7dba5b045af980bfafaf1da95fb2c6198 Pull Request resolved: https://github.com/pytorch/rl/pull/2671 --- torchrl/data/replay_buffers/samplers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 2ad0550ed06..273cf627521 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -1243,7 +1243,7 @@ def _get_stop_and_length(self, storage, fallback=True): "Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories." ) vals = self._find_start_stop_traj( - trajectory=trajectory.clone(), + trajectory=trajectory, at_capacity=storage._is_full, cursor=getattr(storage, "_last_cursor", None), )