From f613eefb0f44b6a9ccd79f82703bb309aa165009 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 12 Jun 2024 09:14:27 +0100 Subject: [PATCH] [BugFix] Fix prefetch in samples without replacement - .sample() compatibility issues (#2226) --- test/test_rb.py | 18 +++++++++++------- torchrl/data/replay_buffers/replay_buffers.py | 3 ++- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/test/test_rb.py b/test/test_rb.py index 3bbc5fdc659..aae21953cf7 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -1927,12 +1927,13 @@ def test_sampler_without_rep_state_dict(self, backend): s = new_replay_buffer.sample(batch_size=1) assert (s.exclude("index") == 0).all() - def test_sampler_without_replacement_cap_prefetch(self): + @pytest.mark.parametrize("drop_last", [False, True]) + def test_sampler_without_replacement_cap_prefetch(self, drop_last): torch.manual_seed(0) - data = TensorDict({"a": torch.arange(10)}, batch_size=[10]) + data = TensorDict({"a": torch.arange(11)}, batch_size=[11]) rb = ReplayBuffer( - storage=LazyTensorStorage(10), - sampler=SamplerWithoutReplacement(), + storage=LazyTensorStorage(11), + sampler=SamplerWithoutReplacement(drop_last=drop_last), batch_size=2, prefetch=3, ) @@ -1941,10 +1942,13 @@ def test_sampler_without_replacement_cap_prefetch(self): for _ in range(100): s = set() for i, d in enumerate(rb): - assert i <= 4 + assert i <= (4 + int(not drop_last)), i s = s.union(set(d["a"].tolist())) - assert i == 4 - assert s == set(range(10)) + assert i == (4 + int(not drop_last)), i + if drop_last: + assert s != set(range(11)) + else: + assert s == set(range(11)) @pytest.mark.parametrize( "batch_size,num_slices,slice_len,prioritized", diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index 808d2af2327..a06ab656a5f 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -637,7 +637,8 @@ def sample(self, batch_size: int | None = None, return_info: bool = False) -> An while ( len(self._prefetch_queue) < min(self._sampler._remaining_batches, self._prefetch_cap) - ) and not self._sampler.ran_out: + and not self._sampler.ran_out + ) or not len(self._prefetch_queue): fut = self._prefetch_executor.submit(self._sample, batch_size) self._prefetch_queue.append(fut) ret = self._prefetch_queue.popleft().result()