From ec799575f44390bee7bad5ad157c8d1e04c0957f Mon Sep 17 00:00:00 2001 From: belerico Date: Fri, 21 Jul 2023 17:55:38 +0200 Subject: [PATCH 1/6] Set dir to save memmapped buffer --- sheeprl/algos/dreamer_v1/dreamer_v1.py | 6 +- sheeprl/algos/dreamer_v2/dreamer_v2.py | 26 +- sheeprl/algos/droq/droq.py | 8 +- sheeprl/algos/p2e_dv1/p2e_dv1.py | 8 +- sheeprl/algos/p2e_dv2/p2e_dv2.py | 14 +- sheeprl/algos/ppo/ppo.py | 8 +- sheeprl/algos/ppo/ppo_decoupled.py | 8 +- .../algos/ppo_continuous/ppo_continuous.py | 8 +- sheeprl/algos/ppo_pixel/ppo_atari.py | 8 +- .../algos/ppo_pixel/ppo_pixel_continuous.py | 8 +- sheeprl/algos/ppo_recurrent/ppo_recurrent.py | 8 +- sheeprl/algos/sac/sac.py | 8 +- sheeprl/algos/sac/sac_decoupled.py | 8 +- .../algos/sac_pixel/sac_pixel_continuous.py | 6 +- sheeprl/data/buffers.py | 45 ++- tests/test_algos/test_algos.py | 326 ++++++++++-------- tests/test_data/test_buffers.py | 7 +- 17 files changed, 345 insertions(+), 165 deletions(-) diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py index 4d3da750..5868d5fe 100644 --- a/sheeprl/algos/dreamer_v1/dreamer_v1.py +++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py @@ -480,7 +480,11 @@ def main(): args.buffer_size // int(args.num_envs * fabric.world_size * args.action_repeat) if not args.dry_run else 2 ) rb = SequentialReplayBuffer( - buffer_size, args.num_envs, device=fabric.device if args.memmap_buffer else "cpu", memmap=args.memmap_buffer + buffer_size, + args.num_envs, + device=fabric.device if args.memmap_buffer else "cpu", + memmap=args.memmap_buffer, + memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), ) if args.checkpoint_path and args.checkpoint_buffer: if isinstance(state["rb"], list) and fabric.world_size == len(state["rb"]): diff --git a/sheeprl/algos/dreamer_v2/dreamer_v2.py b/sheeprl/algos/dreamer_v2/dreamer_v2.py index 6f71ba4f..2711654f 100644 --- a/sheeprl/algos/dreamer_v2/dreamer_v2.py +++ b/sheeprl/algos/dreamer_v2/dreamer_v2.py @@ -254,10 +254,13 @@ def train( imagined_trajectories[i] = imagined_latent_state # predict values and rewards - predicted_target_values = target_critic(imagined_trajectories) - predicted_rewards = world_model.reward_model(imagined_trajectories) + with torch.no_grad(): + predicted_target_values = Independent(Normal(target_critic(imagined_trajectories), 1), 1).mean + predicted_rewards = Independent(Normal(world_model.reward_model(imagined_trajectories), 1), 1).mean if args.use_continues and world_model.continue_model: - continues = Independent(Bernoulli(logits=world_model.continue_model(imagined_trajectories)), 1).mean + continues = Independent( + Bernoulli(logits=world_model.continue_model(imagined_trajectories), validate_args=False), 1 + ).mean true_done = (1 - data["dones"]).reshape(1, -1, 1) * args.gamma continues = torch.cat((true_done, continues[1:])) else: @@ -317,8 +320,7 @@ def train( dynamics = lambda_values[1:] # Reinforce - baseline = target_critic(imagined_trajectories[:-2]) - advantage = (lambda_values[1:] - baseline).detach() + advantage = (lambda_values[1:] - predicted_target_values).detach() reinforce = ( torch.stack( [ @@ -546,10 +548,20 @@ def main(): ) buffer_type = args.buffer_type.lower() if buffer_type == "sequential": - rb = SequentialReplayBuffer(buffer_size, args.num_envs, device="cpu", memmap=args.memmap_buffer) + rb = SequentialReplayBuffer( + buffer_size, + args.num_envs, + device="cpu", + memmap=args.memmap_buffer, + memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), + ) elif buffer_type == "episode": rb = EpisodeBuffer( - buffer_size, sequence_length=args.per_rank_sequence_length, device="cpu", memmap=args.memmap_buffer + buffer_size, + sequence_length=args.per_rank_sequence_length, + device="cpu", + memmap=args.memmap_buffer, + memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), ) else: raise ValueError(f"Unrecognized buffer type: must be one of `sequential` or `episode`, received: {buffer_type}") diff --git a/sheeprl/algos/droq/droq.py b/sheeprl/algos/droq/droq.py index 77554207..0ab7bc72 100644 --- a/sheeprl/algos/droq/droq.py +++ b/sheeprl/algos/droq/droq.py @@ -234,7 +234,13 @@ def main(): buffer_size = ( args.buffer_size // int(args.num_envs * fabric.world_size * args.action_repeat) if not args.dry_run else 1 ) - rb = ReplayBuffer(buffer_size, args.num_envs, device=device, memmap=args.memmap_buffer) + rb = ReplayBuffer( + buffer_size, + args.num_envs, + device=device, + memmap=args.memmap_buffer, + memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), + ) step_data = TensorDict({}, batch_size=[args.num_envs], device=device) # Global variables diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1.py b/sheeprl/algos/p2e_dv1/p2e_dv1.py index 4e105587..3c56d85b 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1.py @@ -533,7 +533,13 @@ def main(): buffer_size = ( args.buffer_size // int(args.num_envs * fabric.world_size * args.action_repeat) if not args.dry_run else 4 ) - rb = SequentialReplayBuffer(buffer_size, args.num_envs, device="cpu", memmap=args.memmap_buffer) + rb = SequentialReplayBuffer( + buffer_size, + args.num_envs, + device="cpu", + memmap=args.memmap_buffer, + memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), + ) if args.checkpoint_path and args.checkpoint_buffer: if isinstance(state["rb"], list) and fabric.world_size == len(state["rb"]): rb = state["rb"][fabric.global_rank] diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2.py b/sheeprl/algos/p2e_dv2/p2e_dv2.py index efb810ab..dd7e3d42 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2.py @@ -691,10 +691,20 @@ def main(): ) buffer_type = args.buffer_type.lower() if buffer_type == "sequential": - rb = SequentialReplayBuffer(buffer_size, args.num_envs, device="cpu", memmap=args.memmap_buffer) + rb = SequentialReplayBuffer( + buffer_size, + args.num_envs, + device="cpu", + memmap=args.memmap_buffer, + memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), + ) elif buffer_type == "episode": rb = EpisodeBuffer( - buffer_size, sequence_length=args.per_rank_sequence_length, device="cpu", memmap=args.memmap_buffer + buffer_size, + sequence_length=args.per_rank_sequence_length, + device="cpu", + memmap=args.memmap_buffer, + memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), ) else: raise ValueError(f"Unrecognized buffer type: must be one of `sequential` or `episode`, received: {buffer_type}") diff --git a/sheeprl/algos/ppo/ppo.py b/sheeprl/algos/ppo/ppo.py index 5d6828b8..aca3d2fd 100644 --- a/sheeprl/algos/ppo/ppo.py +++ b/sheeprl/algos/ppo/ppo.py @@ -202,7 +202,13 @@ def main(): ) # Local data - rb = ReplayBuffer(args.rollout_steps, args.num_envs, device=device, memmap=args.memmap_buffer) + rb = ReplayBuffer( + args.rollout_steps, + args.num_envs, + device=device, + memmap=args.memmap_buffer, + memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), + ) step_data = TensorDict({}, batch_size=[args.num_envs], device=device) # Global variables diff --git a/sheeprl/algos/ppo/ppo_decoupled.py b/sheeprl/algos/ppo/ppo_decoupled.py index 28346647..49cbd716 100644 --- a/sheeprl/algos/ppo/ppo_decoupled.py +++ b/sheeprl/algos/ppo/ppo_decoupled.py @@ -128,7 +128,13 @@ def player(args: PPOArgs, world_collective: TorchCollective, player_trainer_coll ) # Local data - rb = ReplayBuffer(args.rollout_steps, args.num_envs, device=device, memmap=args.memmap_buffer) + rb = ReplayBuffer( + args.rollout_steps, + args.num_envs, + device=device, + memmap=args.memmap_buffer, + memmap_dir=os.path.join(logger.log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), + ) step_data = TensorDict({}, batch_size=[args.num_envs], device=device) # Global variables diff --git a/sheeprl/algos/ppo_continuous/ppo_continuous.py b/sheeprl/algos/ppo_continuous/ppo_continuous.py index dd8cc938..325757d0 100644 --- a/sheeprl/algos/ppo_continuous/ppo_continuous.py +++ b/sheeprl/algos/ppo_continuous/ppo_continuous.py @@ -193,7 +193,13 @@ def main(): ) # Local data - rb = ReplayBuffer(args.rollout_steps, args.num_envs, device=device, memmap=args.memmap_buffer) + rb = ReplayBuffer( + args.rollout_steps, + args.num_envs, + device=device, + memmap=args.memmap_buffer, + memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), + ) step_data = TensorDict({}, batch_size=[args.num_envs], device=device) # Global variables diff --git a/sheeprl/algos/ppo_pixel/ppo_atari.py b/sheeprl/algos/ppo_pixel/ppo_atari.py index ada019a5..af00d971 100644 --- a/sheeprl/algos/ppo_pixel/ppo_atari.py +++ b/sheeprl/algos/ppo_pixel/ppo_atari.py @@ -156,7 +156,13 @@ def player(args: PPOAtariArgs, world_collective: TorchCollective, player_trainer ) # Local data - rb = ReplayBuffer(args.rollout_steps, args.num_envs, device=device, memmap=args.memmap_buffer) + rb = ReplayBuffer( + args.rollout_steps, + args.num_envs, + device=device, + memmap=args.memmap_buffer, + memmap_dir=os.path.join(logger.log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), + ) step_data = TensorDict({}, batch_size=[args.num_envs], device=device) # Global variables diff --git a/sheeprl/algos/ppo_pixel/ppo_pixel_continuous.py b/sheeprl/algos/ppo_pixel/ppo_pixel_continuous.py index d312e08d..7433c0ca 100644 --- a/sheeprl/algos/ppo_pixel/ppo_pixel_continuous.py +++ b/sheeprl/algos/ppo_pixel/ppo_pixel_continuous.py @@ -164,7 +164,13 @@ def player(args: PPOPixelContinuousArgs, world_collective: TorchCollective, play ) # Local data - rb = ReplayBuffer(args.rollout_steps, args.num_envs, device=device, memmap=args.memmap_buffer) + rb = ReplayBuffer( + args.rollout_steps, + args.num_envs, + device=device, + memmap=args.memmap_buffer, + memmap_dir=os.path.join(logger.log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), + ) step_data = TensorDict({}, batch_size=[args.num_envs], device=device) # Global variables diff --git a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py index fb0ed709..1ef51d0a 100644 --- a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py +++ b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py @@ -212,7 +212,13 @@ def main(): ) # Local data - rb = ReplayBuffer(args.rollout_steps, args.num_envs, device=device, memmap=args.memmap_buffer) + rb = ReplayBuffer( + args.rollout_steps, + args.num_envs, + device=device, + memmap=args.memmap_buffer, + memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), + ) step_data = TensorDict({}, batch_size=[1, args.num_envs], device=device) # Global variables diff --git a/sheeprl/algos/sac/sac.py b/sheeprl/algos/sac/sac.py index 186bbba3..57d3dcc1 100644 --- a/sheeprl/algos/sac/sac.py +++ b/sheeprl/algos/sac/sac.py @@ -191,7 +191,13 @@ def main(): # Local data buffer_size = args.buffer_size // int(args.num_envs * fabric.world_size) if not args.dry_run else 1 - rb = ReplayBuffer(buffer_size, args.num_envs, device=device, memmap=args.memmap_buffer) + rb = ReplayBuffer( + buffer_size, + args.num_envs, + device=device, + memmap=args.memmap_buffer, + memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), + ) step_data = TensorDict({}, batch_size=[args.num_envs], device=device) # Global variables diff --git a/sheeprl/algos/sac/sac_decoupled.py b/sheeprl/algos/sac/sac_decoupled.py index 66640a89..72e92797 100644 --- a/sheeprl/algos/sac/sac_decoupled.py +++ b/sheeprl/algos/sac/sac_decoupled.py @@ -106,7 +106,13 @@ def player(args: SACArgs, world_collective: TorchCollective, player_trainer_coll # Local data buffer_size = args.buffer_size // args.num_envs if not args.dry_run else 1 - rb = ReplayBuffer(buffer_size, args.num_envs, device=device, memmap=args.memmap_buffer) + rb = ReplayBuffer( + buffer_size, + args.num_envs, + device=device, + memmap=args.memmap_buffer, + memmap_dir=os.path.join(logger.log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), + ) step_data = TensorDict({}, batch_size=[args.num_envs], device=device) # Global variables diff --git a/sheeprl/algos/sac_pixel/sac_pixel_continuous.py b/sheeprl/algos/sac_pixel/sac_pixel_continuous.py index 7e93a9c6..ba6c3151 100644 --- a/sheeprl/algos/sac_pixel/sac_pixel_continuous.py +++ b/sheeprl/algos/sac_pixel/sac_pixel_continuous.py @@ -276,7 +276,11 @@ def main(): # Local data buffer_size = args.buffer_size // int(args.num_envs * fabric.world_size) if not args.dry_run else 1 rb = ReplayBuffer( - buffer_size, args.num_envs, device=fabric.device if args.memmap_buffer else "cpu", memmap=args.memmap_buffer + buffer_size, + args.num_envs, + device=fabric.device if args.memmap_buffer else "cpu", + memmap=args.memmap_buffer, + memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), ) step_data = TensorDict({}, batch_size=[args.num_envs], device=fabric.device if args.memmap_buffer else "cpu") diff --git a/sheeprl/data/buffers.py b/sheeprl/data/buffers.py index edcf3ff2..3f54ed2f 100644 --- a/sheeprl/data/buffers.py +++ b/sheeprl/data/buffers.py @@ -1,4 +1,7 @@ +import os import typing +import warnings +from pathlib import Path from typing import List, Optional, Union import numpy as np @@ -15,6 +18,7 @@ def __init__( n_envs: int = 1, device: Union[device, str] = "cpu", memmap: bool = False, + memmap_dir: Optional[Union[str, os.PathLike]] = None, ): """A replay buffer which internally uses a TensorDict. @@ -34,7 +38,17 @@ def __init__( device = torch.device(device=device) self._device = device self._memmap = memmap + self._memmap_dir = memmap_dir if self._memmap: + if memmap_dir is None: + warnings.warn( + "The buffer will be memory-mapped into the `/tmp` folder, this means that there is the" + " possibility to lose the saved files. Set the `memmap_dir` to a known directory.", + UserWarning, + ) + else: + self._memmap_dir = Path(self._memmap_dir) + self._memmap_dir.mkdir(parents=True, exist_ok=True) self._buf = None else: self._buf = TensorDict({}, batch_size=[buffer_size, n_envs], device=device) @@ -115,13 +129,18 @@ def add(self, data: Union["ReplayBuffer", TensorDictBase]) -> None: if self._memmap and self._buf is None: self._buf = TensorDict( { - k: MemmapTensor((self._buffer_size, self._n_envs, *v.shape[2:]), dtype=v.dtype, device=v.device) + k: MemmapTensor( + (self._buffer_size, self._n_envs, *v.shape[2:]), + dtype=v.dtype, + device=v.device, + filename=None if self._memmap_dir is None else self._memmap_dir / (k + ".memmap"), + ) for k, v in data_to_store.items() }, batch_size=[self._buffer_size, self._n_envs], device=self.device, ) - self._buf.memmap_() + self._buf.memmap_(prefix=self._memmap_dir) self._buf[idxes, :] = data_to_store if self._pos + data_len >= self._buffer_size: self._full = True @@ -207,8 +226,9 @@ def __init__( n_envs: int = 1, device: Union[device, str] = "cpu", memmap: bool = False, + memmap_dir: Optional[Union[str, os.PathLike]] = None, ): - super().__init__(buffer_size, n_envs, device, memmap) + super().__init__(buffer_size, n_envs, device, memmap, memmap_dir) def sample( self, @@ -338,6 +358,7 @@ def __init__( sequence_length: int, device: Union[device, str] = "cpu", memmap: bool = False, + memmap_dir: Optional[Union[str, os.PathLike]] = None, ) -> None: if buffer_size <= 0: raise ValueError(f"The buffer size must be greater than zero, got: {buffer_size}") @@ -355,6 +376,16 @@ def __init__( device = torch.device(device=device) self._device = device self._memmap = memmap + self._memmap_dir = memmap_dir + if memmap_dir is None: + warnings.warn( + "The buffer will be memory-mapped into the `/tmp` folder, this means that there is the" + " possibility to lose the saved files. Set the `memmap_dir` to a known directory.", + UserWarning, + ) + else: + self._memmap_dir = Path(self._memmap_dir) + self._memmap_dir.mkdir(parents=True, exist_ok=True) self._chunk_length = torch.arange(sequence_length, device=self.device).reshape(1, -1) @property @@ -424,8 +455,12 @@ def add(self, episode: TensorDictBase) -> None: self._cum_lengths.append(len(self) + ep_len) if self._memmap: for k, v in episode.items(): - episode[k] = MemmapTensor.from_tensor(v) - episode.memmap_() + episode[k] = MemmapTensor.from_tensor( + v, + filename=None if self._memmap_dir is None else self._memmap_dir / (k + ".memmap"), + transfer_ownership=False, + ) + episode.memmap_(prefix=self._memmap_dir, copy_existing=True) episode.to(self.device) self._buf.append(episode) diff --git a/tests/test_algos/test_algos.py b/tests/test_algos/test_algos.py index 03df5e5e..4a57ccd6 100644 --- a/tests/test_algos/test_algos.py +++ b/tests/test_algos/test_algos.py @@ -9,7 +9,6 @@ import pytest import torch.distributed as dist from lightning import Fabric -from lightning.fabric.fabric import _is_using_cli from sheeprl.utils.imports import _IS_ATARI_AVAILABLE, _IS_ATARI_ROMS_AVAILABLE, _IS_WINDOWS @@ -39,10 +38,8 @@ def mock_env_and_destroy(devices): dist.destroy_process_group() -def check_checkpoint(ckpt_path: str, target_keys: set, checkpoint_buffer: bool = True): +def check_checkpoint(ckpt_path: str, target_keys: set, checkpoint_buffer: bool = True, memmap_buffer: bool = False): fabric = Fabric(accelerator="cpu") - if not _is_using_cli(): - fabric.launch() # check the presence of the checkpoint assert os.path.isdir(ckpt_path) @@ -58,6 +55,15 @@ def check_checkpoint(ckpt_path: str, target_keys: set, checkpoint_buffer: bool = # check args are saved assert os.path.exists(os.path.join(os.path.dirname(ckpt_path), "args.json")) + # check that memmap buffer are still there + if memmap_buffer: + rb = state["rb"] + if isinstance(rb, list): + for i in range(len(rb)): + rb[i].add(rb[i].buffer[:1]) + else: + rb.add(rb.buffer[:1]) + @pytest.mark.timeout(60) @pytest.mark.parametrize("checkpoint_buffer", [True, False]) @@ -84,12 +90,11 @@ def test_droq(standard_args, checkpoint_buffer, start_time): if command == "main": task.__dict__[command]() - with mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu", "LT_DEVICES": str(1)}): - keys = {"agent", "qf_optimizer", "actor_optimizer", "alpha_optimizer", "args", "global_step"} - if checkpoint_buffer: - keys.add("rb") - check_checkpoint(ckpt_path, keys, checkpoint_buffer) - shutil.rmtree("pytest_" + start_time) + keys = {"agent", "qf_optimizer", "actor_optimizer", "alpha_optimizer", "args", "global_step"} + if checkpoint_buffer: + keys.add("rb") + check_checkpoint(ckpt_path, keys, checkpoint_buffer) + shutil.rmtree("pytest_" + start_time) @pytest.mark.timeout(60) @@ -117,12 +122,11 @@ def test_sac(standard_args, checkpoint_buffer, start_time): if command == "main": task.__dict__[command]() - with mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu", "LT_DEVICES": str(1)}): - keys = {"agent", "qf_optimizer", "actor_optimizer", "alpha_optimizer", "args", "global_step"} - if checkpoint_buffer: - keys.add("rb") - check_checkpoint(ckpt_path, keys, checkpoint_buffer) - shutil.rmtree("pytest_" + start_time) + keys = {"agent", "qf_optimizer", "actor_optimizer", "alpha_optimizer", "args", "global_step"} + if checkpoint_buffer: + keys.add("rb") + check_checkpoint(ckpt_path, keys, checkpoint_buffer) + shutil.rmtree("pytest_" + start_time) @pytest.mark.timeout(60) @@ -150,24 +154,23 @@ def test_sac_pixel_continuous(standard_args, checkpoint_buffer, start_time): if command == "main": task.__dict__[command]() - with mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu", "LT_DEVICES": str(1)}): - keys = { - "agent", - "encoder", - "decoder", - "qf_optimizer", - "actor_optimizer", - "alpha_optimizer", - "encoder_optimizer", - "decoder_optimizer", - "args", - "global_step", - "batch_size", - } - if checkpoint_buffer: - keys.add("rb") - check_checkpoint(ckpt_path, keys, checkpoint_buffer) - shutil.rmtree("pytest_" + start_time) + keys = { + "agent", + "encoder", + "decoder", + "qf_optimizer", + "actor_optimizer", + "alpha_optimizer", + "encoder_optimizer", + "decoder_optimizer", + "args", + "global_step", + "batch_size", + } + if checkpoint_buffer: + keys.add("rb") + check_checkpoint(ckpt_path, keys, checkpoint_buffer) + shutil.rmtree("pytest_" + start_time) @pytest.mark.timeout(60) @@ -212,12 +215,11 @@ def test_sac_decoupled(standard_args, checkpoint_buffer, start_time): torchrun.main(torchrun_args) if os.environ["LT_DEVICES"] != "1": - with mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu", "LT_DEVICES": str(1)}): - keys = {"agent", "qf_optimizer", "actor_optimizer", "alpha_optimizer", "args", "global_step"} - if checkpoint_buffer: - keys.add("rb") - check_checkpoint(ckpt_path, keys, checkpoint_buffer) - shutil.rmtree("pytest_" + start_time) + keys = {"agent", "qf_optimizer", "actor_optimizer", "alpha_optimizer", "args", "global_step"} + if checkpoint_buffer: + keys.add("rb") + check_checkpoint(ckpt_path, keys, checkpoint_buffer) + shutil.rmtree("pytest_" + start_time) @pytest.mark.timeout(60) @@ -239,9 +241,8 @@ def test_ppo(standard_args, start_time): if command == "main": task.__dict__[command]() - with mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu", "LT_DEVICES": str(1)}): - check_checkpoint(ckpt_path, {"actor", "critic", "optimizer", "args", "update_step", "scheduler"}) - shutil.rmtree("pytest_" + start_time) + check_checkpoint(ckpt_path, {"actor", "critic", "optimizer", "args", "update_step", "scheduler"}) + shutil.rmtree("pytest_" + start_time) @pytest.mark.timeout(60) @@ -282,9 +283,8 @@ def test_ppo_decoupled(standard_args, start_time): torchrun.main(torchrun_args) if os.environ["LT_DEVICES"] != "1": - with mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu", "LT_DEVICES": str(1)}): - check_checkpoint(ckpt_path, {"agent", "optimizer", "args", "update_step", "scheduler"}) - shutil.rmtree("pytest_" + start_time) + check_checkpoint(ckpt_path, {"agent", "optimizer", "args", "update_step", "scheduler"}) + shutil.rmtree("pytest_" + start_time) @pytest.mark.timeout(60) @@ -330,9 +330,8 @@ def test_ppo_atari(standard_args, start_time): torchrun.main(torchrun_args) if os.environ["LT_DEVICES"] != "1": - with mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu", "LT_DEVICES": str(1)}): - check_checkpoint(ckpt_path, {"agent", "optimizer", "args", "update_step", "scheduler"}) - shutil.rmtree("pytest_" + start_time) + check_checkpoint(ckpt_path, {"agent", "optimizer", "args", "update_step", "scheduler"}) + shutil.rmtree("pytest_" + start_time) @pytest.mark.timeout(60) @@ -354,9 +353,8 @@ def test_ppo_continuous(standard_args, start_time): if command == "main": task.__dict__[command]() - with mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu", "LT_DEVICES": str(1)}): - check_checkpoint(ckpt_path, {"actor", "critic", "optimizer", "args", "update_step", "scheduler"}) - shutil.rmtree("pytest_" + start_time) + check_checkpoint(ckpt_path, {"actor", "critic", "optimizer", "args", "update_step", "scheduler"}) + shutil.rmtree("pytest_" + start_time) @pytest.mark.timeout(60) @@ -378,9 +376,8 @@ def test_ppo_recurrent(standard_args, start_time): if command == "main": task.__dict__[command]() - with mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu", "LT_DEVICES": str(1)}): - check_checkpoint(ckpt_path, {"agent", "optimizer", "args", "update_step", "scheduler"}) - shutil.rmtree("pytest_" + start_time) + check_checkpoint(ckpt_path, {"agent", "optimizer", "args", "update_step", "scheduler"}) + shutil.rmtree("pytest_" + start_time) @pytest.mark.timeout(60) @@ -421,9 +418,8 @@ def test_ppo_pixel_continuous(standard_args, start_time): torchrun.main(torchrun_args) if os.environ["LT_DEVICES"] != "1": - with mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu", "LT_DEVICES": str(1)}): - check_checkpoint(ckpt_path, {"agent", "optimizer", "args", "update_step", "scheduler"}) - shutil.rmtree("pytest_" + start_time) + check_checkpoint(ckpt_path, {"agent", "optimizer", "args", "update_step", "scheduler"}) + shutil.rmtree("pytest_" + start_time) @pytest.mark.timeout(60) @@ -458,23 +454,22 @@ def test_dreamer_v1(standard_args, env_id, checkpoint_buffer, start_time): if command == "main": task.__dict__[command]() - with mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu", "LT_DEVICES": str(1)}): - keys = { - "world_model", - "actor", - "critic", - "world_optimizer", - "actor_optimizer", - "critic_optimizer", - "expl_decay_steps", - "args", - "global_step", - "batch_size", - } - if checkpoint_buffer: - keys.add("rb") - check_checkpoint(ckpt_path, keys, checkpoint_buffer) - shutil.rmtree("pytest_" + start_time) + keys = { + "world_model", + "actor", + "critic", + "world_optimizer", + "actor_optimizer", + "critic_optimizer", + "expl_decay_steps", + "args", + "global_step", + "batch_size", + } + if checkpoint_buffer: + keys.add("rb") + check_checkpoint(ckpt_path, keys, checkpoint_buffer) + shutil.rmtree("pytest_" + start_time) @pytest.mark.timeout(60) @@ -509,29 +504,28 @@ def test_p2e_dv1(standard_args, env_id, checkpoint_buffer, start_time): if command == "main": task.__dict__[command]() - with mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu", "LT_DEVICES": str(1)}): - keys = { - "world_model", - "actor_task", - "critic_task", - "ensembles", - "world_optimizer", - "actor_task_optimizer", - "critic_task_optimizer", - "ensemble_optimizer", - "expl_decay_steps", - "args", - "global_step", - "batch_size", - "actor_exploration", - "critic_exploration", - "actor_exploration_optimizer", - "critic_exploration_optimizer", - } - if checkpoint_buffer: - keys.add("rb") - check_checkpoint(ckpt_path, keys, checkpoint_buffer) - shutil.rmtree("pytest_" + start_time) + keys = { + "world_model", + "actor_task", + "critic_task", + "ensembles", + "world_optimizer", + "actor_task_optimizer", + "critic_task_optimizer", + "ensemble_optimizer", + "expl_decay_steps", + "args", + "global_step", + "batch_size", + "actor_exploration", + "critic_exploration", + "actor_exploration_optimizer", + "critic_exploration_optimizer", + } + if checkpoint_buffer: + keys.add("rb") + check_checkpoint(ckpt_path, keys, checkpoint_buffer) + shutil.rmtree("pytest_" + start_time) @pytest.mark.timeout(60) @@ -570,23 +564,22 @@ def test_dreamer_v2(standard_args, env_id, checkpoint_buffer, start_time): if command == "main": task.__dict__[command]() - with mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu", "LT_DEVICES": str(1)}): - keys = { - "world_model", - "actor", - "critic", - "world_optimizer", - "actor_optimizer", - "critic_optimizer", - "expl_decay_steps", - "args", - "global_step", - "batch_size", - } - if checkpoint_buffer: - keys.add("rb") - check_checkpoint(ckpt_path, keys, checkpoint_buffer) - shutil.rmtree("pytest_" + start_time) + keys = { + "world_model", + "actor", + "critic", + "world_optimizer", + "actor_optimizer", + "critic_optimizer", + "expl_decay_steps", + "args", + "global_step", + "batch_size", + } + if checkpoint_buffer: + keys.add("rb") + check_checkpoint(ckpt_path, keys, checkpoint_buffer) + shutil.rmtree("pytest_" + start_time) @pytest.mark.timeout(60) @@ -624,26 +617,83 @@ def test_p2e_dv2(standard_args, env_id, checkpoint_buffer, start_time): if command == "main": task.__dict__[command]() - with mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu", "LT_DEVICES": str(1)}): - keys = { - "world_model", - "actor_task", - "critic_task", - "ensembles", - "world_optimizer", - "actor_task_optimizer", - "critic_task_optimizer", - "ensemble_optimizer", - "expl_decay_steps", - "args", - "global_step", - "batch_size", - "actor_exploration", - "critic_exploration", - "actor_exploration_optimizer", - "critic_exploration_optimizer", - } - if checkpoint_buffer: - keys.add("rb") - check_checkpoint(ckpt_path, keys, checkpoint_buffer) - shutil.rmtree("pytest_" + start_time) + keys = { + "world_model", + "actor_task", + "critic_task", + "ensembles", + "world_optimizer", + "actor_task_optimizer", + "critic_task_optimizer", + "ensemble_optimizer", + "expl_decay_steps", + "args", + "global_step", + "batch_size", + "actor_exploration", + "critic_exploration", + "actor_exploration_optimizer", + "critic_exploration_optimizer", + } + if checkpoint_buffer: + keys.add("rb") + check_checkpoint(ckpt_path, keys, checkpoint_buffer) + shutil.rmtree("pytest_" + start_time) + + +@pytest.mark.timeout(60) +@pytest.mark.parametrize("env_id", ["discrete_dummy"]) +@pytest.mark.parametrize("checkpoint_buffer", [True]) +def test_dreamer_v2_memmap_buffer(standard_args, env_id, checkpoint_buffer, start_time): + task = importlib.import_module("sheeprl.algos.dreamer_v2.dreamer_v2") + root_dir = os.path.join("pytest_" + start_time, "dreamer_v2", os.environ["LT_DEVICES"]) + run_name = "checkpoint_buffer" if checkpoint_buffer else "no_checkpoint_buffer" + ckpt_path = os.path.join(root_dir, run_name) + version = 0 if not os.path.isdir(ckpt_path) else len(os.listdir(ckpt_path)) + ckpt_path = os.path.join(ckpt_path, f"version_{version}", "checkpoint") + args = standard_args + [ + "--per_rank_batch_size=1", + "--per_rank_sequence_length=1", + f"--buffer_size={int(os.environ['LT_DEVICES'])}", + "--learning_starts=0", + "--gradient_steps=1", + "--horizon=2", + "--env_id=" + env_id, + "--root_dir=" + root_dir, + "--run_name=" + run_name, + "--dense_units=8", + "--cnn_channels_multiplier=2", + "--recurrent_state_size=8", + "--hidden_size=8", + "--cnn_keys=rgb", + "--pretrain_steps=1", + "--layer_norm=True", + "--memmap_buffer=True", + ] + if checkpoint_buffer: + args.append("--checkpoint_buffer") + + with mock.patch.object(sys, "argv", [task.__file__] + args): + for command in task.__all__: + if command == "main": + task.__dict__[command]() + + for i in range(int(os.environ["LT_DEVICES"])): + assert os.path.exists(os.path.join(os.path.dirname(ckpt_path), "memmap_buffer", f"rank_{i}")) + + keys = { + "world_model", + "actor", + "critic", + "world_optimizer", + "actor_optimizer", + "critic_optimizer", + "expl_decay_steps", + "args", + "global_step", + "batch_size", + } + if checkpoint_buffer: + keys.add("rb") + check_checkpoint(ckpt_path, keys, checkpoint_buffer, memmap_buffer=False) + shutil.rmtree("pytest_" + start_time) diff --git a/tests/test_data/test_buffers.py b/tests/test_data/test_buffers.py index 429853c8..1b7b72a0 100644 --- a/tests/test_data/test_buffers.py +++ b/tests/test_data/test_buffers.py @@ -152,7 +152,12 @@ def test_replay_buffer_sample_fail(): def test_memmap_replay_buffer(): buf_size = 1000000 n_envs = 4 - rb = ReplayBuffer(buf_size, n_envs, memmap=True) + with pytest.warns( + UserWarning, + match="The buffer will be memory-mapped into the `/tmp` folder, this means that there is the" + " possibility to lose the saved files. Set the `memmap_dir` to a known directory.", + ): + rb = ReplayBuffer(buf_size, n_envs, memmap=True, memmap_dir=None) td = TensorDict( {"observations": torch.randint(0, 256, (10, n_envs, 3, 64, 64), dtype=torch.uint8)}, batch_size=[10, n_envs] ) From 1a1f68793bf4c81c90d0390b5b4534453137e396 Mon Sep 17 00:00:00 2001 From: belerico Date: Mon, 24 Jul 2023 15:01:14 +0200 Subject: [PATCH 2/6] Add memmap buffer to file tests --- sheeprl/data/buffers.py | 6 +-- tests/test_algos/test_algos.py | 69 +------------------------- tests/test_data/test_buffers.py | 27 ++++++++++ tests/test_data/test_episode_buffer.py | 35 ++++++++++++- 4 files changed, 65 insertions(+), 72 deletions(-) diff --git a/sheeprl/data/buffers.py b/sheeprl/data/buffers.py index 3f54ed2f..ac6bc983 100644 --- a/sheeprl/data/buffers.py +++ b/sheeprl/data/buffers.py @@ -133,7 +133,7 @@ def add(self, data: Union["ReplayBuffer", TensorDictBase]) -> None: (self._buffer_size, self._n_envs, *v.shape[2:]), dtype=v.dtype, device=v.device, - filename=None if self._memmap_dir is None else self._memmap_dir / (k + ".memmap"), + filename=None if self._memmap_dir is None else self._memmap_dir / f"{k}.memmap", ) for k, v in data_to_store.items() }, @@ -457,10 +457,10 @@ def add(self, episode: TensorDictBase) -> None: for k, v in episode.items(): episode[k] = MemmapTensor.from_tensor( v, - filename=None if self._memmap_dir is None else self._memmap_dir / (k + ".memmap"), + filename=None if self._memmap_dir is None else self._memmap_dir / f"{k}.memmap", transfer_ownership=False, ) - episode.memmap_(prefix=self._memmap_dir, copy_existing=True) + episode.memmap_(prefix=self._memmap_dir) episode.to(self.device) self._buf.append(episode) diff --git a/tests/test_algos/test_algos.py b/tests/test_algos/test_algos.py index 4a57ccd6..85a2c39b 100644 --- a/tests/test_algos/test_algos.py +++ b/tests/test_algos/test_algos.py @@ -38,7 +38,7 @@ def mock_env_and_destroy(devices): dist.destroy_process_group() -def check_checkpoint(ckpt_path: str, target_keys: set, checkpoint_buffer: bool = True, memmap_buffer: bool = False): +def check_checkpoint(ckpt_path: str, target_keys: set, checkpoint_buffer: bool = True): fabric = Fabric(accelerator="cpu") # check the presence of the checkpoint @@ -55,15 +55,6 @@ def check_checkpoint(ckpt_path: str, target_keys: set, checkpoint_buffer: bool = # check args are saved assert os.path.exists(os.path.join(os.path.dirname(ckpt_path), "args.json")) - # check that memmap buffer are still there - if memmap_buffer: - rb = state["rb"] - if isinstance(rb, list): - for i in range(len(rb)): - rb[i].add(rb[i].buffer[:1]) - else: - rb.add(rb.buffer[:1]) - @pytest.mark.timeout(60) @pytest.mark.parametrize("checkpoint_buffer", [True, False]) @@ -639,61 +630,3 @@ def test_p2e_dv2(standard_args, env_id, checkpoint_buffer, start_time): keys.add("rb") check_checkpoint(ckpt_path, keys, checkpoint_buffer) shutil.rmtree("pytest_" + start_time) - - -@pytest.mark.timeout(60) -@pytest.mark.parametrize("env_id", ["discrete_dummy"]) -@pytest.mark.parametrize("checkpoint_buffer", [True]) -def test_dreamer_v2_memmap_buffer(standard_args, env_id, checkpoint_buffer, start_time): - task = importlib.import_module("sheeprl.algos.dreamer_v2.dreamer_v2") - root_dir = os.path.join("pytest_" + start_time, "dreamer_v2", os.environ["LT_DEVICES"]) - run_name = "checkpoint_buffer" if checkpoint_buffer else "no_checkpoint_buffer" - ckpt_path = os.path.join(root_dir, run_name) - version = 0 if not os.path.isdir(ckpt_path) else len(os.listdir(ckpt_path)) - ckpt_path = os.path.join(ckpt_path, f"version_{version}", "checkpoint") - args = standard_args + [ - "--per_rank_batch_size=1", - "--per_rank_sequence_length=1", - f"--buffer_size={int(os.environ['LT_DEVICES'])}", - "--learning_starts=0", - "--gradient_steps=1", - "--horizon=2", - "--env_id=" + env_id, - "--root_dir=" + root_dir, - "--run_name=" + run_name, - "--dense_units=8", - "--cnn_channels_multiplier=2", - "--recurrent_state_size=8", - "--hidden_size=8", - "--cnn_keys=rgb", - "--pretrain_steps=1", - "--layer_norm=True", - "--memmap_buffer=True", - ] - if checkpoint_buffer: - args.append("--checkpoint_buffer") - - with mock.patch.object(sys, "argv", [task.__file__] + args): - for command in task.__all__: - if command == "main": - task.__dict__[command]() - - for i in range(int(os.environ["LT_DEVICES"])): - assert os.path.exists(os.path.join(os.path.dirname(ckpt_path), "memmap_buffer", f"rank_{i}")) - - keys = { - "world_model", - "actor", - "critic", - "world_optimizer", - "actor_optimizer", - "critic_optimizer", - "expl_decay_steps", - "args", - "global_step", - "batch_size", - } - if checkpoint_buffer: - keys.add("rb") - check_checkpoint(ckpt_path, keys, checkpoint_buffer, memmap_buffer=False) - shutil.rmtree("pytest_" + start_time) diff --git a/tests/test_data/test_buffers.py b/tests/test_data/test_buffers.py index 1b7b72a0..8b57df42 100644 --- a/tests/test_data/test_buffers.py +++ b/tests/test_data/test_buffers.py @@ -1,5 +1,10 @@ +import os +import shutil +import time + import pytest import torch +from lightning import Fabric from tensordict import TensorDict from sheeprl.data.buffers import ReplayBuffer @@ -163,3 +168,25 @@ def test_memmap_replay_buffer(): ) rb.add(td) assert rb.buffer.is_memmap() + + +def test_memmap_to_file_replay_buffer(): + buf_size = 1000000 + n_envs = 4 + root_dir = os.path.join("pytest_" + str(int(time.time()))) + memmap_dir = os.path.join(root_dir, "memmap_buffer") + rb = ReplayBuffer(buf_size, n_envs, memmap=True, memmap_dir=memmap_dir) + td = TensorDict( + {"observations": torch.randint(0, 256, (10, n_envs, 3, 64, 64), dtype=torch.uint8)}, batch_size=[10, n_envs] + ) + rb.add(td) + assert rb.buffer.is_memmap() + assert os.path.exists(os.path.join(memmap_dir, "meta.pt")) + assert os.path.exists(os.path.join(memmap_dir, "observations.meta.pt")) + assert os.path.exists(os.path.join(memmap_dir, "observations.memmap")) + fabric = Fabric(devices=1, accelerator="cpu") + ckpt_file = os.path.join(root_dir, "checkpoint", "ckpt.ckpt") + fabric.save(ckpt_file, {"rb": rb}) + ckpt = fabric.load(ckpt_file) + assert (ckpt["rb"]["observations"][:10] == rb["observations"][:10]).all() + shutil.rmtree(root_dir) diff --git a/tests/test_data/test_episode_buffer.py b/tests/test_data/test_episode_buffer.py index 0f14bf35..d88d48bd 100644 --- a/tests/test_data/test_episode_buffer.py +++ b/tests/test_data/test_episode_buffer.py @@ -1,3 +1,7 @@ +import os +import shutil +import time + import pytest import torch from tensordict import TensorDict @@ -162,7 +166,30 @@ def test_memmap_episode_buffer(): buf_size = 10 bs = 4 sl = 4 - rb = EpisodeBuffer(buf_size, sl, memmap=True) + with pytest.warns( + UserWarning, + match="The buffer will be memory-mapped into the `/tmp` folder, this means that there is the" + " possibility to lose the saved files. Set the `memmap_dir` to a known directory.", + ): + rb = EpisodeBuffer(buf_size, sl, memmap=True) + for _ in range(buf_size // bs): + td = TensorDict( + {"observations": torch.randint(0, 256, (bs, 3, 64, 64), dtype=torch.uint8), "dones": torch.zeros(bs)}, + batch_size=[bs], + ) + td["dones"][-1] = 1 + rb.add(td) + assert rb[-1].is_memmap() + assert rb.is_memmap + + +def test_memmap_to_file_episode_buffer(): + buf_size = 10 + bs = 4 + sl = 4 + root_dir = os.path.join("pytest_" + str(int(time.time()))) + memmap_dir = os.path.join(root_dir, "memmap_buffer") + rb = EpisodeBuffer(buf_size, sl, memmap=True, memmap_dir=memmap_dir) for _ in range(buf_size // bs): td = TensorDict( {"observations": torch.randint(0, 256, (bs, 3, 64, 64), dtype=torch.uint8), "dones": torch.zeros(bs)}, @@ -172,3 +199,9 @@ def test_memmap_episode_buffer(): rb.add(td) assert rb[-1].is_memmap() assert rb.is_memmap + assert os.path.exists(os.path.join(memmap_dir, "meta.pt")) + assert os.path.exists(os.path.join(memmap_dir, "dones.meta.pt")) + assert os.path.exists(os.path.join(memmap_dir, "dones.memmap")) + assert os.path.exists(os.path.join(memmap_dir, "observations.meta.pt")) + assert os.path.exists(os.path.join(memmap_dir, "observations.memmap")) + shutil.rmtree(root_dir) From ece0bbd08ffbe3444371a17040bdf8aa0fc61698 Mon Sep 17 00:00:00 2001 From: belerico_t Date: Tue, 25 Jul 2023 17:30:28 +0200 Subject: [PATCH 3/6] Reduce buffer size in tests --- tests/test_data/test_buffers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_data/test_buffers.py b/tests/test_data/test_buffers.py index 8b57df42..4c3f3972 100644 --- a/tests/test_data/test_buffers.py +++ b/tests/test_data/test_buffers.py @@ -155,7 +155,7 @@ def test_replay_buffer_sample_fail(): def test_memmap_replay_buffer(): - buf_size = 1000000 + buf_size = 10 n_envs = 4 with pytest.warns( UserWarning, @@ -171,7 +171,7 @@ def test_memmap_replay_buffer(): def test_memmap_to_file_replay_buffer(): - buf_size = 1000000 + buf_size = 10 n_envs = 4 root_dir = os.path.join("pytest_" + str(int(time.time()))) memmap_dir = os.path.join(root_dir, "memmap_buffer") From 3c6f4af70062e542ca79a2ee52178fb9a686e4d9 Mon Sep 17 00:00:00 2001 From: belerico Date: Tue, 25 Jul 2023 19:09:18 +0200 Subject: [PATCH 4/6] Fix episode buffer memory mapping on windows --- sheeprl/data/buffers.py | 8 ++++++-- tests/test_data/test_buffers.py | 2 ++ tests/test_data/test_episode_buffer.py | 16 ++++++++++------ 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/sheeprl/data/buffers.py b/sheeprl/data/buffers.py index ac6bc983..2c6d849c 100644 --- a/sheeprl/data/buffers.py +++ b/sheeprl/data/buffers.py @@ -454,13 +454,17 @@ def add(self, episode: TensorDictBase) -> None: self._cum_lengths = cum_lengths.tolist() self._cum_lengths.append(len(self) + ep_len) if self._memmap: + episode_dir = None + if self._memmap_dir is not None: + episode_dir = self._memmap_dir / f"episode_{len(self._buf)}" + episode_dir.mkdir(parents=True, exist_ok=True) for k, v in episode.items(): episode[k] = MemmapTensor.from_tensor( v, - filename=None if self._memmap_dir is None else self._memmap_dir / f"{k}.memmap", + filename=None if episode_dir is None else episode_dir / f"{k}.memmap", transfer_ownership=False, ) - episode.memmap_(prefix=self._memmap_dir) + episode.memmap_(prefix=episode_dir) episode.to(self.device) self._buf.append(episode) diff --git a/tests/test_data/test_buffers.py b/tests/test_data/test_buffers.py index 4c3f3972..683ca72f 100644 --- a/tests/test_data/test_buffers.py +++ b/tests/test_data/test_buffers.py @@ -189,4 +189,6 @@ def test_memmap_to_file_replay_buffer(): fabric.save(ckpt_file, {"rb": rb}) ckpt = fabric.load(ckpt_file) assert (ckpt["rb"]["observations"][:10] == rb["observations"][:10]).all() + del rb + del ckpt shutil.rmtree(root_dir) diff --git a/tests/test_data/test_episode_buffer.py b/tests/test_data/test_episode_buffer.py index d88d48bd..cc3f196f 100644 --- a/tests/test_data/test_episode_buffer.py +++ b/tests/test_data/test_episode_buffer.py @@ -190,18 +190,22 @@ def test_memmap_to_file_episode_buffer(): root_dir = os.path.join("pytest_" + str(int(time.time()))) memmap_dir = os.path.join(root_dir, "memmap_buffer") rb = EpisodeBuffer(buf_size, sl, memmap=True, memmap_dir=memmap_dir) - for _ in range(buf_size // bs): + for i in range(buf_size // bs): td = TensorDict( {"observations": torch.randint(0, 256, (bs, 3, 64, 64), dtype=torch.uint8), "dones": torch.zeros(bs)}, batch_size=[bs], ) td["dones"][-1] = 1 rb.add(td) + del td assert rb[-1].is_memmap() + assert os.path.exists(os.path.join(memmap_dir, f"episode_{i}", "meta.pt")) + assert os.path.exists(os.path.join(memmap_dir, f"episode_{i}", "dones.meta.pt")) + assert os.path.exists(os.path.join(memmap_dir, f"episode_{i}", "dones.memmap")) + assert os.path.exists(os.path.join(memmap_dir, f"episode_{i}", "observations.meta.pt")) + assert os.path.exists(os.path.join(memmap_dir, f"episode_{i}", "observations.memmap")) assert rb.is_memmap - assert os.path.exists(os.path.join(memmap_dir, "meta.pt")) - assert os.path.exists(os.path.join(memmap_dir, "dones.meta.pt")) - assert os.path.exists(os.path.join(memmap_dir, "dones.memmap")) - assert os.path.exists(os.path.join(memmap_dir, "observations.meta.pt")) - assert os.path.exists(os.path.join(memmap_dir, "observations.memmap")) + for ep in rb.buffer: + del ep + del rb shutil.rmtree(root_dir) From 6f41f44efe5276d5a5f4653e1a5260ec346780b5 Mon Sep 17 00:00:00 2001 From: belerico Date: Wed, 26 Jul 2023 09:43:46 +0200 Subject: [PATCH 5/6] Remove unneeded buffers after those are removed from the list --- sheeprl/data/buffers.py | 18 +++++++++++++++--- tests/test_data/test_episode_buffer.py | 17 +++++++++-------- 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/sheeprl/data/buffers.py b/sheeprl/data/buffers.py index 2c6d849c..871f66e8 100644 --- a/sheeprl/data/buffers.py +++ b/sheeprl/data/buffers.py @@ -1,5 +1,6 @@ import os import typing +import uuid import warnings from pathlib import Path from typing import List, Optional, Union @@ -9,6 +10,7 @@ from tensordict import MemmapTensor, TensorDict from tensordict.tensordict import TensorDictBase from torch import Size, Tensor, device +import shutil class ReplayBuffer: @@ -449,14 +451,24 @@ def add(self, episode: TensorDictBase) -> None: if self.full or len(self) + ep_len > self._buffer_size: cum_lengths = np.array(self._cum_lengths) mask = (len(self) - cum_lengths + ep_len) <= self._buffer_size - self._buf = self._buf[mask.argmax() + 1 :] - cum_lengths = cum_lengths[mask.argmax() + 1 :] - cum_lengths[mask.argmax()] + last_to_remove = mask.argmax() + # Remove all memmaped episodes + if self._memmap and self._memmap_dir is not None: + for i in range(last_to_remove + 1): + filename = self._buf[i][self._buf[i].sorted_keys[0]].filename + for k in self._buf[i].sorted_keys: + f = self._buf[i][k].file + if f is not None: + f.close() + del self._buf[i] + shutil.rmtree(os.path.dirname(filename)) + cum_lengths = cum_lengths[last_to_remove + 1 :] - cum_lengths[last_to_remove] self._cum_lengths = cum_lengths.tolist() self._cum_lengths.append(len(self) + ep_len) if self._memmap: episode_dir = None if self._memmap_dir is not None: - episode_dir = self._memmap_dir / f"episode_{len(self._buf)}" + episode_dir = self._memmap_dir / f"episode_{str(uuid.uuid4())}" episode_dir.mkdir(parents=True, exist_ok=True) for k, v in episode.items(): episode[k] = MemmapTensor.from_tensor( diff --git a/tests/test_data/test_episode_buffer.py b/tests/test_data/test_episode_buffer.py index cc3f196f..92846318 100644 --- a/tests/test_data/test_episode_buffer.py +++ b/tests/test_data/test_episode_buffer.py @@ -184,13 +184,13 @@ def test_memmap_episode_buffer(): def test_memmap_to_file_episode_buffer(): - buf_size = 10 - bs = 4 + buf_size = 5 + bs = 5 sl = 4 root_dir = os.path.join("pytest_" + str(int(time.time()))) memmap_dir = os.path.join(root_dir, "memmap_buffer") rb = EpisodeBuffer(buf_size, sl, memmap=True, memmap_dir=memmap_dir) - for i in range(buf_size // bs): + for i in range(4): td = TensorDict( {"observations": torch.randint(0, 256, (bs, 3, 64, 64), dtype=torch.uint8), "dones": torch.zeros(bs)}, batch_size=[bs], @@ -199,11 +199,12 @@ def test_memmap_to_file_episode_buffer(): rb.add(td) del td assert rb[-1].is_memmap() - assert os.path.exists(os.path.join(memmap_dir, f"episode_{i}", "meta.pt")) - assert os.path.exists(os.path.join(memmap_dir, f"episode_{i}", "dones.meta.pt")) - assert os.path.exists(os.path.join(memmap_dir, f"episode_{i}", "dones.memmap")) - assert os.path.exists(os.path.join(memmap_dir, f"episode_{i}", "observations.meta.pt")) - assert os.path.exists(os.path.join(memmap_dir, f"episode_{i}", "observations.memmap")) + memmap_dir = os.path.dirname(rb.buffer[-1][rb.buffer[-1].sorted_keys[0]].filename) + assert os.path.exists(os.path.join(memmap_dir, "meta.pt")) + assert os.path.exists(os.path.join(memmap_dir, "dones.meta.pt")) + assert os.path.exists(os.path.join(memmap_dir, "dones.memmap")) + assert os.path.exists(os.path.join(memmap_dir, "observations.meta.pt")) + assert os.path.exists(os.path.join(memmap_dir, "observations.memmap")) assert rb.is_memmap for ep in rb.buffer: del ep From c51dcd5ed3bb0d2d84e51a2b3645e9d9342ac44e Mon Sep 17 00:00:00 2001 From: belerico Date: Wed, 26 Jul 2023 10:00:39 +0200 Subject: [PATCH 6/6] Fix remove buffer if not memmapped --- sheeprl/data/buffers.py | 12 +++++++----- tests/test_data/test_episode_buffer.py | 6 +++++- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/sheeprl/data/buffers.py b/sheeprl/data/buffers.py index 871f66e8..3f665fad 100644 --- a/sheeprl/data/buffers.py +++ b/sheeprl/data/buffers.py @@ -454,14 +454,16 @@ def add(self, episode: TensorDictBase) -> None: last_to_remove = mask.argmax() # Remove all memmaped episodes if self._memmap and self._memmap_dir is not None: - for i in range(last_to_remove + 1): - filename = self._buf[i][self._buf[i].sorted_keys[0]].filename - for k in self._buf[i].sorted_keys: - f = self._buf[i][k].file + for _ in range(last_to_remove + 1): + filename = self._buf[0][self._buf[0].sorted_keys[0]].filename + for k in self._buf[0].sorted_keys: + f = self._buf[0][k].file if f is not None: f.close() - del self._buf[i] + del self._buf[0] shutil.rmtree(os.path.dirname(filename)) + else: + self._buf = self._buf[last_to_remove + 1 :] cum_lengths = cum_lengths[last_to_remove + 1 :] - cum_lengths[last_to_remove] self._cum_lengths = cum_lengths.tolist() self._cum_lengths.append(len(self) + ep_len) diff --git a/tests/test_data/test_episode_buffer.py b/tests/test_data/test_episode_buffer.py index 92846318..e30fb76f 100644 --- a/tests/test_data/test_episode_buffer.py +++ b/tests/test_data/test_episode_buffer.py @@ -184,13 +184,17 @@ def test_memmap_episode_buffer(): def test_memmap_to_file_episode_buffer(): - buf_size = 5 + buf_size = 10 bs = 5 sl = 4 root_dir = os.path.join("pytest_" + str(int(time.time()))) memmap_dir = os.path.join(root_dir, "memmap_buffer") rb = EpisodeBuffer(buf_size, sl, memmap=True, memmap_dir=memmap_dir) for i in range(4): + if i >= 2: + bs = 7 + else: + bs = 5 td = TensorDict( {"observations": torch.randint(0, 256, (bs, 3, 64, 64), dtype=torch.uint8), "dones": torch.zeros(bs)}, batch_size=[bs],