From 3c6f4af70062e542ca79a2ee52178fb9a686e4d9 Mon Sep 17 00:00:00 2001 From: belerico Date: Tue, 25 Jul 2023 19:09:18 +0200 Subject: [PATCH] Fix episode buffer memory mapping on windows --- sheeprl/data/buffers.py | 8 ++++++-- tests/test_data/test_buffers.py | 2 ++ tests/test_data/test_episode_buffer.py | 16 ++++++++++------ 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/sheeprl/data/buffers.py b/sheeprl/data/buffers.py index ac6bc983..2c6d849c 100644 --- a/sheeprl/data/buffers.py +++ b/sheeprl/data/buffers.py @@ -454,13 +454,17 @@ def add(self, episode: TensorDictBase) -> None: self._cum_lengths = cum_lengths.tolist() self._cum_lengths.append(len(self) + ep_len) if self._memmap: + episode_dir = None + if self._memmap_dir is not None: + episode_dir = self._memmap_dir / f"episode_{len(self._buf)}" + episode_dir.mkdir(parents=True, exist_ok=True) for k, v in episode.items(): episode[k] = MemmapTensor.from_tensor( v, - filename=None if self._memmap_dir is None else self._memmap_dir / f"{k}.memmap", + filename=None if episode_dir is None else episode_dir / f"{k}.memmap", transfer_ownership=False, ) - episode.memmap_(prefix=self._memmap_dir) + episode.memmap_(prefix=episode_dir) episode.to(self.device) self._buf.append(episode) diff --git a/tests/test_data/test_buffers.py b/tests/test_data/test_buffers.py index 4c3f3972..683ca72f 100644 --- a/tests/test_data/test_buffers.py +++ b/tests/test_data/test_buffers.py @@ -189,4 +189,6 @@ def test_memmap_to_file_replay_buffer(): fabric.save(ckpt_file, {"rb": rb}) ckpt = fabric.load(ckpt_file) assert (ckpt["rb"]["observations"][:10] == rb["observations"][:10]).all() + del rb + del ckpt shutil.rmtree(root_dir) diff --git a/tests/test_data/test_episode_buffer.py b/tests/test_data/test_episode_buffer.py index d88d48bd..cc3f196f 100644 --- a/tests/test_data/test_episode_buffer.py +++ b/tests/test_data/test_episode_buffer.py @@ -190,18 +190,22 @@ def test_memmap_to_file_episode_buffer(): root_dir = os.path.join("pytest_" + str(int(time.time()))) memmap_dir = os.path.join(root_dir, "memmap_buffer") rb = EpisodeBuffer(buf_size, sl, memmap=True, memmap_dir=memmap_dir) - for _ in range(buf_size // bs): + for i in range(buf_size // bs): td = TensorDict( {"observations": torch.randint(0, 256, (bs, 3, 64, 64), dtype=torch.uint8), "dones": torch.zeros(bs)}, batch_size=[bs], ) td["dones"][-1] = 1 rb.add(td) + del td assert rb[-1].is_memmap() + assert os.path.exists(os.path.join(memmap_dir, f"episode_{i}", "meta.pt")) + assert os.path.exists(os.path.join(memmap_dir, f"episode_{i}", "dones.meta.pt")) + assert os.path.exists(os.path.join(memmap_dir, f"episode_{i}", "dones.memmap")) + assert os.path.exists(os.path.join(memmap_dir, f"episode_{i}", "observations.meta.pt")) + assert os.path.exists(os.path.join(memmap_dir, f"episode_{i}", "observations.memmap")) assert rb.is_memmap - assert os.path.exists(os.path.join(memmap_dir, "meta.pt")) - assert os.path.exists(os.path.join(memmap_dir, "dones.meta.pt")) - assert os.path.exists(os.path.join(memmap_dir, "dones.memmap")) - assert os.path.exists(os.path.join(memmap_dir, "observations.meta.pt")) - assert os.path.exists(os.path.join(memmap_dir, "observations.memmap")) + for ep in rb.buffer: + del ep + del rb shutil.rmtree(root_dir)