From f0bc6ba5500694afca4e7f47a20bfd8896edc4e6 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 19 Dec 2024 16:28:45 +0000 Subject: [PATCH] [BugFix] Avoid KeyError in slice sampler (for compile) ghstack-source-id: 75d8882254aad3c6d8c2bb5c993a5fe93c9143e8 Pull Request resolved: https://github.com/pytorch/rl/pull/2670 --- torchrl/data/replay_buffers/samplers.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index bbdf2387683..2ad0550ed06 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -1485,13 +1485,13 @@ def _get_index( truncated[seq_length.cumsum(0) - 1] = 1 index = index.to(torch.long).unbind(-1) st_index = storage[index] - try: - done = st_index[done_key] | truncated - except KeyError: + done = st_index.get(done_key, default=None) + if done is None: done = truncated.clone() - try: - terminated = st_index[terminated_key] - except KeyError: + else: + done = done | truncated + terminated = st_index.get(terminated_key, default=None) + if terminated is None: terminated = torch.zeros_like(truncated) return index, { truncated_key: truncated,