Skip to content

Commit

Permalink
Fix episode buffer memory mapping on windows
Browse files Browse the repository at this point in the history
  • Loading branch information
belerico committed Jul 25, 2023
1 parent ece0bbd commit 3c6f4af
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 8 deletions.
8 changes: 6 additions & 2 deletions sheeprl/data/buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,13 +454,17 @@ def add(self, episode: TensorDictBase) -> None:
self._cum_lengths = cum_lengths.tolist()
self._cum_lengths.append(len(self) + ep_len)
if self._memmap:
episode_dir = None
if self._memmap_dir is not None:
episode_dir = self._memmap_dir / f"episode_{len(self._buf)}"
episode_dir.mkdir(parents=True, exist_ok=True)
for k, v in episode.items():
episode[k] = MemmapTensor.from_tensor(
v,
filename=None if self._memmap_dir is None else self._memmap_dir / f"{k}.memmap",
filename=None if episode_dir is None else episode_dir / f"{k}.memmap",
transfer_ownership=False,
)
episode.memmap_(prefix=self._memmap_dir)
episode.memmap_(prefix=episode_dir)
episode.to(self.device)
self._buf.append(episode)

Expand Down
2 changes: 2 additions & 0 deletions tests/test_data/test_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,4 +189,6 @@ def test_memmap_to_file_replay_buffer():
fabric.save(ckpt_file, {"rb": rb})
ckpt = fabric.load(ckpt_file)
assert (ckpt["rb"]["observations"][:10] == rb["observations"][:10]).all()
del rb
del ckpt
shutil.rmtree(root_dir)
16 changes: 10 additions & 6 deletions tests/test_data/test_episode_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,18 +190,22 @@ def test_memmap_to_file_episode_buffer():
root_dir = os.path.join("pytest_" + str(int(time.time())))
memmap_dir = os.path.join(root_dir, "memmap_buffer")
rb = EpisodeBuffer(buf_size, sl, memmap=True, memmap_dir=memmap_dir)
for _ in range(buf_size // bs):
for i in range(buf_size // bs):
td = TensorDict(
{"observations": torch.randint(0, 256, (bs, 3, 64, 64), dtype=torch.uint8), "dones": torch.zeros(bs)},
batch_size=[bs],
)
td["dones"][-1] = 1
rb.add(td)
del td
assert rb[-1].is_memmap()
assert os.path.exists(os.path.join(memmap_dir, f"episode_{i}", "meta.pt"))
assert os.path.exists(os.path.join(memmap_dir, f"episode_{i}", "dones.meta.pt"))
assert os.path.exists(os.path.join(memmap_dir, f"episode_{i}", "dones.memmap"))
assert os.path.exists(os.path.join(memmap_dir, f"episode_{i}", "observations.meta.pt"))
assert os.path.exists(os.path.join(memmap_dir, f"episode_{i}", "observations.memmap"))
assert rb.is_memmap
assert os.path.exists(os.path.join(memmap_dir, "meta.pt"))
assert os.path.exists(os.path.join(memmap_dir, "dones.meta.pt"))
assert os.path.exists(os.path.join(memmap_dir, "dones.memmap"))
assert os.path.exists(os.path.join(memmap_dir, "observations.meta.pt"))
assert os.path.exists(os.path.join(memmap_dir, "observations.memmap"))
for ep in rb.buffer:
del ep
del rb
shutil.rmtree(root_dir)

0 comments on commit 3c6f4af

Please sign in to comment.