From a7b6d3341cb3d6852378d290f5b9a778fe98974a Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 25 Jan 2024 11:54:33 +0100 Subject: [PATCH 001/113] update config --- examples/dreamer/config.yaml | 88 +++++++++++++++++++++--------------- 1 file changed, 52 insertions(+), 36 deletions(-) diff --git a/examples/dreamer/config.yaml b/examples/dreamer/config.yaml index db207136656..f1a10286e58 100644 --- a/examples/dreamer/config.yaml +++ b/examples/dreamer/config.yaml @@ -1,37 +1,53 @@ -env_name: cheetah -env_task: run -env_library: dm_control -catframes: 1 -async_collection: True -record_video: 0 -frame_skip: 2 -batch_size: 50 -batch_length: 50 -total_frames: 5000000 -world_model_lr: 6e-4 -actor_value_lr: 8e-5 -from_pixels: True +env: + name: cheetah + task: run + backend: dm_control + catframes: 1 + record_video: 0 + frame_skip: 2 + from_pixels: True + grayscale: False + image_size : 64 + batch_transform: 1 + # probably not needed vvvv + normalize_rewards_online: True + normalize_rewards_online_scale: 5.0 + normalize_rewards_online_decay: 0.99999 + reward_scaling: 1.0 + +collector: + async_collection: True + total_frames: 5000000 + init_env_steps: 1000 + init_random_frames: 5000 + max_frames_per_traj: 1000 + + env_per_collector: 8 + num_workers: 8 + collector_device: cuda:1 + frames_per_batch: 800 + +optimization: + grad_clip: 100 + batch_size: 50 + batch_length: 50 + + world_model_lr: 6e-4 + actor_value_lr: 8e-5 + optim_steps_per_batch: 80 + + # we want 50 frames / traj in the replay buffer. Given the frame_skip=2 this makes each traj 100 steps long -env_per_collector: 8 -num_workers: 8 -collector_device: cuda:1 -frames_per_batch: 800 -optim_steps_per_batch: 80 -record_interval: 30 -max_frames_per_traj: 1000 -record_frames: 1000 -batch_transform: 1 -state_dim: 30 -rssm_hidden_dim: 200 -grad_clip: 100 -grayscale: False -image_size : 64 -buffer_size: 20000 -init_env_steps: 1000 -init_random_frames: 5000 -logger: csv -offline_logging: False -normalize_rewards_online: True -normalize_rewards_online_scale: 5.0 -normalize_rewards_online_decay: 0.99999 -reward_scaling: 1.0 +networks: + state_dim: 30 + rssm_hidden_dim: 200 + +replay_buffer: + buffer_size: 20000 + + +logger: + offline_logging: False + logger: csv + record_interval: 30 + record_frames: 1000 From 2f273a85346366af4a7f43d81c6194fdf314aa20 Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 6 Feb 2024 08:35:38 +0100 Subject: [PATCH 002/113] fixes --- examples/dreamer/config.yaml | 6 +- examples/dreamer/dreamer.py | 108 +++++++++++++++----------- examples/dreamer/dreamer_utils.py | 122 ++++++++++++++++++++---------- 3 files changed, 150 insertions(+), 86 deletions(-) diff --git a/examples/dreamer/config.yaml b/examples/dreamer/config.yaml index 1cd74802fdc..5891d4b0d6d 100644 --- a/examples/dreamer/config.yaml +++ b/examples/dreamer/config.yaml @@ -24,7 +24,7 @@ collector: init_random_frames: 1000 frames_per_batch: 1000 max_frames_per_traj: 1000 - device: cuda:0 + device: cpu optimization: @@ -64,9 +64,9 @@ replay_buffer: logger: - logger_type: wandb + backend: wandb project: dreamer-v1 exp_name: ${env.name}-${env.task}-${env.seed} - mode: offline + mode: online record_interval: 30 record_frames: 1000 diff --git a/examples/dreamer/dreamer.py b/examples/dreamer/dreamer.py index e1a419102ec..a8192f62b36 100644 --- a/examples/dreamer/dreamer.py +++ b/examples/dreamer/dreamer.py @@ -7,10 +7,11 @@ make_dreamer, make_environments, make_replay_buffer, + log_metrics, ) # float16 -from torch.cuda.amp import GradScaler +from torch.cuda.amp import 
autocast, GradScaler from torch.nn.utils import clip_grad_norm_ from torchrl.objectives.dreamer import ( @@ -19,7 +20,7 @@ DreamerValueLoss, ) -# from torchrl.record.loggers import generate_exp_name, get_logger +from torchrl.record.loggers import generate_exp_name, get_logger # from torchrl.trainers.helpers.envs import ( # correct_for_frame_skip, @@ -40,17 +41,16 @@ def main(cfg: "DictConfig"): # noqa: F821 else: device = torch.device("cpu") - # exp_name = generate_exp_name("Dreamer", cfg.logger.exp_name) - # logger = get_logger( - # logger_type=cfg.logger.logger_type, - # logger_name="dreamer", - # experiment_name=exp_name, - # wandb_kwargs={ - # "project": cfg.logger.project, - # "mode": cfg.logger.mode, - # "config": cfg, - # }, - # ) + # Create logger + exp_name = generate_exp_name("Dreamer", cfg.logger.exp_name) + logger = None + if cfg.logger.backend: + logger = get_logger( + logger_type=cfg.logger.backend, + logger_name="dreamer_logging", + experiment_name=exp_name, + wandb_kwargs={"mode": cfg.logger.mode}, # "config": cfg}, + ) train_env, test_env = make_environments(cfg=cfg, device=device) @@ -59,7 +59,6 @@ def main(cfg: "DictConfig"): # noqa: F821 value_key = "state_value" world_model, model_based_env, actor_model, value_model, policy = make_dreamer( config=cfg, - test_env=test_env, device=device, action_key=action_key, value_key=value_key, @@ -92,6 +91,7 @@ def main(cfg: "DictConfig"): # noqa: F821 buffer_size=cfg.replay_buffer.buffer_size, buffer_scratch_dir=cfg.replay_buffer.scratch_dir, device=cfg.networks.device, + pixel_obs=cfg.env.from_pixels, ) # Training loop @@ -105,9 +105,9 @@ def main(cfg: "DictConfig"): # noqa: F821 actor_opt = torch.optim.Adam(actor_model.parameters(), lr=cfg.optimization.actor_lr) value_opt = torch.optim.Adam(value_model.parameters(), lr=cfg.optimization.value_lr) - scaler1 = GradScaler() - scaler2 = GradScaler() - scaler3 = GradScaler() + # scaler1 = GradScaler() + # scaler2 = GradScaler() + # scaler3 = GradScaler() init_random_frames = cfg.collector.init_random_frames batch_size = cfg.optimization.batch_size @@ -119,6 +119,11 @@ def main(cfg: "DictConfig"): # noqa: F821 current_frames = tensordict.numel() collected_frames += current_frames + tensordict["pixels"] = (tensordict["pixels"] * 255).to(torch.uint8) + tensordict["next", "pixels"] = (tensordict["next", "pixels"] * 255).to( + torch.uint8 + ) + ep_reward = tensordict.get("episode_reward")[:, -1] replay_buffer.extend(tensordict.cpu()) if collected_frames >= init_random_frames: @@ -128,45 +133,62 @@ def main(cfg: "DictConfig"): # noqa: F821 device, non_blocking=True ) # update world model - # with autocast(dtype=torch.float16): - model_loss_td, sampled_tensordict = world_model_loss(sampled_tensordict) - loss_world_model = ( - model_loss_td["loss_model_kl"] - + model_loss_td["loss_model_reco"] - + model_loss_td["loss_model_reward"] - ) + with autocast(dtype=torch.float16): + model_loss_td, sampled_tensordict = world_model_loss( + sampled_tensordict + ) + loss_world_model = ( + model_loss_td["loss_model_kl"] + + model_loss_td["loss_model_reco"] + + model_loss_td["loss_model_reward"] + ) world_model_opt.zero_grad() - scaler1.scale(loss_world_model).backward() - scaler1.unscale_(world_model_opt) + loss_world_model.backward() + # scaler1.scale(loss_world_model).backward() + # scaler1.unscale_(world_model_opt) clip_grad_norm_(world_model.parameters(), grad_clip) - - scaler1.step(world_model_opt) - scaler1.update() + world_model_opt.step() + # scaler1.step(world_model_opt) + # scaler1.update() # 
update actor network - # with autocast(dtype=torch.float16): - actor_loss_td, sampled_tensordict = actor_loss(sampled_tensordict) + with autocast(dtype=torch.float16): + actor_loss_td, sampled_tensordict = actor_loss(sampled_tensordict) actor_opt.zero_grad() - scaler2.scale(actor_loss_td["loss_actor"]).backward() - scaler2.unscale_(actor_opt) + actor_loss_td["loss_actor"].backward() + # scaler2.scale(actor_loss_td["loss_actor"]).backward() + # scaler2.unscale_(actor_opt) clip_grad_norm_(actor_model.parameters(), grad_clip) - - scaler2.step(actor_opt) - scaler2.update() + actor_opt.step() + # scaler2.step(actor_opt) + # scaler2.update() # update value network - # with autocast(dtype=torch.float16): - value_loss_td, sampled_tensordict = value_loss(sampled_tensordict) + with autocast(dtype=torch.float16): + value_loss_td, sampled_tensordict = value_loss(sampled_tensordict) value_opt.zero_grad() - scaler3.scale(value_loss_td["loss_value"]).backward() - scaler3.unscale_(value_opt) + value_loss_td["loss_value"].backward() + # scaler3.scale(value_loss_td["loss_value"]).backward() + # scaler3.unscale_(value_opt) clip_grad_norm_(value_model.parameters(), grad_clip) - - scaler3.step(value_opt) - scaler3.update() + value_opt.step() + # scaler3.step(value_opt) + # scaler3.update() + + metrics_to_log = { + "reward": ep_reward.item(), + "loss_model_kl": model_loss_td["loss_model_kl"].item(), + "loss_model_reco": model_loss_td["loss_model_reco"].item(), + "loss_model_reward": model_loss_td["loss_model_reward"].item(), + "loss_actor": actor_loss_td["loss_actor"].item(), + "loss_value": value_loss_td["loss_value"].item(), + } + + if logger is not None: + log_metrics(logger, metrics_to_log, collected_frames) policy.step(current_frames) collector.update_policy_weights_() diff --git a/examples/dreamer/dreamer_utils.py b/examples/dreamer/dreamer_utils.py index 38275118cdc..c07f89d1295 100644 --- a/examples/dreamer/dreamer_utils.py +++ b/examples/dreamer/dreamer_utils.py @@ -32,9 +32,12 @@ from torchrl.envs.libs.gym import GymEnv, set_gym_backend from torchrl.envs.model_based.dreamer import DreamerEnv from torchrl.envs.transforms import ( + Compose, # CatFrames, # CenterCrop, DoubleToFloat, + DTypeCastTransform, + ExcludeTransform, FrameSkipTransform, GrayScale, # NoopResetEnv, @@ -45,9 +48,8 @@ RewardSum, ToTensorImage, TransformedEnv, - VecNorm, + UnsqueezeTransform, ) - from torchrl.envs.transforms.transforms import TensorDictPrimer # FlattenObservation from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import ( @@ -83,63 +85,72 @@ def _make_env(cfg, device): ) elif lib == "dm_control": env = DMControlEnv(cfg.env.name, cfg.env.task, from_pixels=cfg.env.from_pixels) - env = TransformedEnv(env) - if cfg.env.from_pixels: - env.append_transform(ToTensorImage()) - if cfg.env.grayscale: - env.append_transform(GrayScale()) - img_size = cfg.env.image_size - env.append_transform(Resize(img_size, img_size)) - env.append_transform(VecNorm(in_keys=["pixels"])) - obs_stats = { - "loc": torch.zeros(()), - "scale": torch.ones(()), - } - obs_norm = ObservationNorm(**obs_stats, in_keys=["pixels"]) - env.append_transform(obs_norm) - else: - # TODO: - # concatenate vel and pos to observation - # .apppend_transform(C) - pass - - env.append_transform(DoubleToFloat()) - env.append_transform(RewardSum()) - env.append_transform(FrameSkipTransform(cfg.env.frame_skip)) + return env + else: + raise NotImplementedError(f"Unknown lib {lib}.") + +def transform_env(cfg, env, parallel_envs, 
dummy=False): + env = TransformedEnv(env) + if cfg.env.from_pixels: + # transforms pixel from 0-255 to 0-1 (uint8 to float32) + env.append_transform(ToTensorImage(from_int=True)) + if cfg.env.grayscale: + env.append_transform(GrayScale()) + img_size = cfg.env.image_size + env.append_transform(Resize(img_size, img_size)) + else: + # TODO: + # concatenate vel and pos to observation + # .apppend_transform(C) + pass + + env.append_transform(DoubleToFloat()) + env.append_transform(RewardSum()) + env.append_transform(FrameSkipTransform(cfg.env.frame_skip)) + if dummy: default_dict = { "state": UnboundedContinuousTensorSpec(shape=(cfg.networks.state_dim)), "belief": UnboundedContinuousTensorSpec( shape=(cfg.networks.rssm_hidden_dim) ), } - env.append_transform( - TensorDictPrimer(random=False, default_value=0, **default_dict) - ) - return env - else: - raise NotImplementedError(f"Unknown lib {lib}.") + default_dict = { + "state": UnboundedContinuousTensorSpec( + shape=(parallel_envs, cfg.networks.state_dim) + ), + "belief": UnboundedContinuousTensorSpec( + shape=(parallel_envs, cfg.networks.rssm_hidden_dim) + ), + } + env.append_transform( + TensorDictPrimer(random=False, default_value=0, **default_dict) + ) + + return env -def make_environments(cfg, device): +def make_environments(cfg, device, parallel_envs=1): """Make environments for training and evaluation.""" train_env = ParallelEnv( - 1, + parallel_envs, EnvCreator(lambda cfg=cfg: _make_env(cfg, device=device)), ) + train_env = transform_env(cfg, train_env, parallel_envs) train_env.set_seed(cfg.env.seed) eval_env = ParallelEnv( - 1, + parallel_envs, EnvCreator(lambda cfg=cfg: _make_env(cfg, device=device)), ) + eval_env = transform_env(cfg, eval_env, parallel_envs) eval_env.set_seed(cfg.env.seed + 1) + return train_env, eval_env def make_dreamer( config, - test_env, device, action_key: str = "action", value_key: str = "state_value", @@ -156,6 +167,8 @@ def make_dreamer( obsevation_out_key = "reco_observation" raise NotImplementedError("Currently only pixel observations are supported.") + test_env = _make_env(config, device=device) + test_env = transform_env(config, test_env, parallel_envs=1, dummy=True) # Make RSSM rssm_prior = RSSMPrior( hidden_dim=config.networks.rssm_hidden_dim, @@ -187,9 +200,8 @@ def make_dreamer( # Initialize world model with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): - tensordict = test_env.fake_tensordict().unsqueeze(-1) + tensordict = test_env.rollout(5).unsqueeze(-1) tensordict = tensordict.to_tensordict() - tensordict = tensordict world_model(tensordict) # Create model-based environment @@ -251,6 +263,8 @@ def make_collector(cfg, train_env, actor_model_explore): frames_per_batch=cfg.collector.frames_per_batch, total_frames=cfg.collector.total_frames, device=cfg.collector.device, + reset_at_each_iter=True, + ) collector.set_seed(cfg.env.seed) return collector @@ -264,13 +278,36 @@ def make_replay_buffer( buffer_scratch_dir=None, device="cpu", prefetch=3, + pixel_obs=True, ): with ( tempfile.TemporaryDirectory() if buffer_scratch_dir is None else nullcontext(buffer_scratch_dir) ) as scratch_dir: + transforms = [] crop_seq = RandomCropTensorDict(sub_seq_len=batch_seq_len, sample_dim=-1) + transforms.append(crop_seq) + exclude_keys = ExcludeTransform( + ("next", "pixels"), + ("next", "belief"), + ("next", "state"), + ) + # transforms.append(exclude_keys) + if pixel_obs: + # dtype_transform = DTypeCastTransform(torch.double, torch.float32, in_keys=["pixels"]) + # from 0-255 to 0-1 + norm_obs = 
ObservationNorm( + loc=0, + scale=255, + standard_normal=True, + in_keys=["pixels", ("next", "pixels")], + ) + # transforms.append(dtype_transform) + transforms.append(norm_obs) + + transforms = Compose(*transforms) + replay_buffer = TensorDictReplayBuffer( pin_memory=False, prefetch=prefetch, @@ -279,7 +316,7 @@ def make_replay_buffer( scratch_dir=scratch_dir, device=device, ), - transform=crop_seq, + transform=transforms, batch_size=batch_size, ) return replay_buffer @@ -479,8 +516,8 @@ def _dreamer_make_mbenv( model_based_env.set_specs_from_env(test_env) model_based_env = TransformedEnv(model_based_env) default_dict = { - "state": UnboundedContinuousTensorSpec((1, state_dim)), - "belief": UnboundedContinuousTensorSpec((1, rssm_hidden_dim)), + "state": UnboundedContinuousTensorSpec(shape=(state_dim)), + "belief": UnboundedContinuousTensorSpec(shape=(rssm_hidden_dim)), } model_based_env.append_transform( TensorDictPrimer(random=False, default_value=0, **default_dict) @@ -545,6 +582,11 @@ def _dreamer_make_world_model( return world_model +def log_metrics(logger, metrics, step): + for metric_name, metric_value in metrics.items(): + logger.log_scalar(metric_name, metric_value, step) + + def get_activation(name): if name == "relu": return nn.ReLU From 6856587e203d16cbf48c2f6cc84598210fe7b5da Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 6 Feb 2024 08:37:43 +0100 Subject: [PATCH 003/113] fix --- examples/dreamer/dreamer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreamer/dreamer.py b/examples/dreamer/dreamer.py index a8192f62b36..70e2bd1401f 100644 --- a/examples/dreamer/dreamer.py +++ b/examples/dreamer/dreamer.py @@ -3,11 +3,11 @@ import torch.cuda import tqdm from dreamer_utils import ( + log_metrics, make_collector, make_dreamer, make_environments, make_replay_buffer, - log_metrics, ) # float16 From 578150ee5c8300dcb0c7e6546fab4697e4baa82e Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 6 Feb 2024 08:41:32 +0100 Subject: [PATCH 004/113] flake --- examples/dreamer/dreamer.py | 3 ++- examples/dreamer/dreamer_utils.py | 14 ++++---------- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/examples/dreamer/dreamer.py b/examples/dreamer/dreamer.py index 70e2bd1401f..efb6e32af42 100644 --- a/examples/dreamer/dreamer.py +++ b/examples/dreamer/dreamer.py @@ -11,7 +11,7 @@ ) # float16 -from torch.cuda.amp import autocast, GradScaler +from torch.cuda.amp import autocast # , GradScaler from torch.nn.utils import clip_grad_norm_ from torchrl.objectives.dreamer import ( @@ -105,6 +105,7 @@ def main(cfg: "DictConfig"): # noqa: F821 actor_opt = torch.optim.Adam(actor_model.parameters(), lr=cfg.optimization.actor_lr) value_opt = torch.optim.Adam(value_model.parameters(), lr=cfg.optimization.value_lr) + # Not sure we need those # scaler1 = GradScaler() # scaler2 = GradScaler() # scaler3 = GradScaler() diff --git a/examples/dreamer/dreamer_utils.py b/examples/dreamer/dreamer_utils.py index c07f89d1295..fc3fc2301dc 100644 --- a/examples/dreamer/dreamer_utils.py +++ b/examples/dreamer/dreamer_utils.py @@ -36,8 +36,8 @@ # CatFrames, # CenterCrop, DoubleToFloat, - DTypeCastTransform, - ExcludeTransform, + # DTypeCastTransform, + # ExcludeTransform, FrameSkipTransform, GrayScale, # NoopResetEnv, @@ -48,7 +48,7 @@ RewardSum, ToTensorImage, TransformedEnv, - UnsqueezeTransform, + # UnsqueezeTransform, ) from torchrl.envs.transforms.transforms import TensorDictPrimer # FlattenObservation from torchrl.envs.utils import ExplorationType, set_exploration_type @@ -264,7 
+264,6 @@ def make_collector(cfg, train_env, actor_model_explore): total_frames=cfg.collector.total_frames, device=cfg.collector.device, reset_at_each_iter=True, - ) collector.set_seed(cfg.env.seed) return collector @@ -288,12 +287,7 @@ def make_replay_buffer( transforms = [] crop_seq = RandomCropTensorDict(sub_seq_len=batch_seq_len, sample_dim=-1) transforms.append(crop_seq) - exclude_keys = ExcludeTransform( - ("next", "pixels"), - ("next", "belief"), - ("next", "state"), - ) - # transforms.append(exclude_keys) + if pixel_obs: # dtype_transform = DTypeCastTransform(torch.double, torch.float32, in_keys=["pixels"]) # from 0-255 to 0-1 From 82fc9d826baa3e7cb0693b496d491ffcad91f6ff Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 8 Feb 2024 12:28:34 +0100 Subject: [PATCH 005/113] update and add dense networks --- examples/dreamer/dreamer.py | 9 ++----- examples/dreamer/dreamer_utils.py | 21 +++++++---------- torchrl/modules/models/model_based.py | 34 +++++++++++++++++++++++++++ 3 files changed, 45 insertions(+), 19 deletions(-) diff --git a/examples/dreamer/dreamer.py b/examples/dreamer/dreamer.py index efb6e32af42..4d9bc97cbef 100644 --- a/examples/dreamer/dreamer.py +++ b/examples/dreamer/dreamer.py @@ -22,12 +22,6 @@ from torchrl.record.loggers import generate_exp_name, get_logger -# from torchrl.trainers.helpers.envs import ( -# correct_for_frame_skip, -# initialize_observation_norm_transforms, -# retrieve_observation_norms_state_dict, -# ) - @hydra.main(version_base="1.1", config_path=".", config_name="config") def main(cfg: "DictConfig"): # noqa: F821 @@ -78,8 +72,9 @@ def main(cfg: "DictConfig"): # noqa: F821 value_model, model_based_env, imagination_horizon=cfg.optimization.imagination_horizon, + discount_loss=True, ) - value_loss = DreamerValueLoss(value_model) + value_loss = DreamerValueLoss(value_model, discount_loss=True) # Make collector collector = make_collector(cfg, train_env, policy) diff --git a/examples/dreamer/dreamer_utils.py b/examples/dreamer/dreamer_utils.py index fc3fc2301dc..efcb19c707d 100644 --- a/examples/dreamer/dreamer_utils.py +++ b/examples/dreamer/dreamer_utils.py @@ -33,11 +33,8 @@ from torchrl.envs.model_based.dreamer import DreamerEnv from torchrl.envs.transforms import ( Compose, - # CatFrames, # CenterCrop, DoubleToFloat, - # DTypeCastTransform, - # ExcludeTransform, FrameSkipTransform, GrayScale, # NoopResetEnv, @@ -63,6 +60,8 @@ ) from torchrl.modules.distributions import TanhNormal from torchrl.modules.models.model_based import ( + DenseDecoder, + DenseEncoder, DreamerActor, ObsDecoder, ObsEncoder, @@ -99,11 +98,6 @@ def transform_env(cfg, env, parallel_envs, dummy=False): env.append_transform(GrayScale()) img_size = cfg.env.image_size env.append_transform(Resize(img_size, img_size)) - else: - # TODO: - # concatenate vel and pos to observation - # .apppend_transform(C) - pass env.append_transform(DoubleToFloat()) env.append_transform(RewardSum()) @@ -156,6 +150,8 @@ def make_dreamer( value_key: str = "state_value", use_decoder_in_env: bool = False, ): + test_env = _make_env(config, device=device) + test_env = transform_env(config, test_env, parallel_envs=1, dummy=True) # Make encoder and decoder if config.env.from_pixels: encoder = ObsEncoder() @@ -163,12 +159,14 @@ def make_dreamer( observation_in_key = "pixels" obsevation_out_key = "reco_pixels" else: + encoder = DenseEncoder() + decoder = DenseDecoder( + observation_dim=test_env.observation_spec["observation"].shape[-1] + ) observation_in_key = "observation" obsevation_out_key = "reco_observation" 
- raise NotImplementedError("Currently only pixel observations are supported.") + # raise NotImplementedError("Currently only pixel observations are supported.") - test_env = _make_env(config, device=device) - test_env = transform_env(config, test_env, parallel_envs=1, dummy=True) # Make RSSM rssm_prior = RSSMPrior( hidden_dim=config.networks.rssm_hidden_dim, @@ -272,7 +270,6 @@ def make_collector(cfg, train_env, actor_model_explore): def make_replay_buffer( batch_size, batch_seq_len, - prb=False, buffer_size=1000000, buffer_scratch_dir=None, device="cpu", diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py index 6196d69c543..5abcf2bbb48 100644 --- a/torchrl/modules/models/model_based.py +++ b/torchrl/modules/models/model_based.py @@ -186,6 +186,40 @@ def forward(self, state, rnn_hidden): return obs_decoded +class DenseEncoder(nn.Module): + """Dense encoder network.""" + + def __init__(self, num_layer=3, hidden_dim=300, embedding_dim=1024): + super().__init__() + + layers = [nn.LazyLinear(hidden_dim), nn.ReLU()] + for _ in range(num_layer - 2): + layers += [nn.LazyLinear(hidden_dim), nn.ReLU()] + layers += [nn.LazyLinear(embedding_dim), nn.ReLU()] + + self.encoder = nn.Sequential(*layers) + + def forward(self, state): + return self.encoder(state) + + +class DenseDecoder(nn.Module): + """Dense decoder network.""" + + def __init__(self, observation_dim, num_layer=3, hidden_dim=300): + super().__init__() + + layers = [nn.LazyLinear(hidden_dim), nn.ReLU()] + for _ in range(num_layer - 2): + layers += [nn.LazyLinear(hidden_dim), nn.ReLU()] + layers += [nn.LazyLinear(observation_dim), nn.ReLU()] + + self.decoder = nn.Sequential(*layers) + + def forward(self, state): + return self.decoder(state) + + class RSSMRollout(TensorDictModuleBase): """Rollout the RSSM network. 
From 0e48d9a9b6b4402486f56200d5cd731a000f93d6 Mon Sep 17 00:00:00 2001 From: BY571 Date: Fri, 9 Feb 2024 12:10:54 +0100 Subject: [PATCH 006/113] updates loss --- examples/dreamer/dreamer_utils.py | 95 +++++++++++++++++++++++-------- torchrl/objectives/dreamer.py | 25 ++++---- 2 files changed, 83 insertions(+), 37 deletions(-) diff --git a/examples/dreamer/dreamer_utils.py b/examples/dreamer/dreamer_utils.py index efcb19c707d..cef40d9142b 100644 --- a/examples/dreamer/dreamer_utils.py +++ b/examples/dreamer/dreamer_utils.py @@ -51,8 +51,7 @@ from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import ( MLP, - # NoisyLinear, - # NormalParamWrapper, + NormalParamWrapper, SafeModule, SafeProbabilisticModule, SafeProbabilisticTensorDictSequential, @@ -159,7 +158,7 @@ def make_dreamer( observation_in_key = "pixels" obsevation_out_key = "reco_pixels" else: - encoder = DenseEncoder() + encoder = DenseEncoder() # TODO: make them just MLPs decoder = DenseDecoder( observation_dim=test_env.observation_spec["observation"].shape[-1] ) @@ -179,7 +178,7 @@ def make_dreamer( ) # Make reward module reward_module = MLP( - out_features=1, + out_features=2, depth=2, num_cells=config.networks.hidden_dim, activation_class=get_activation(config.networks.activation), @@ -316,16 +315,26 @@ def make_replay_buffer( def _dreamer_make_value_model( hidden_dim: int = 400, activation: str = "elu", value_key: str = "state_value" ): - value_model = SafeModule( - MLP( - out_features=1, - depth=3, - num_cells=hidden_dim, - activation_class=get_activation(activation), + value_model = MLP( + out_features=2, + depth=3, + num_cells=hidden_dim, + activation_class=get_activation(activation), + ) + value_model = SafeProbabilisticTensorDictSequential( + SafeModule( + NormalParamWrapper(value_model), + in_keys=["state", "belief"], + out_keys=["loc", "scale"], + ), + SafeProbabilisticModule( + in_keys=["loc", "scale"], + out_keys=[value_key], + distribution_class=TanhNormal, + distribution_kwargs={"tanh_loc": False}, ), - in_keys=["state", "belief"], - out_keys=[value_key], ) + return value_model @@ -489,11 +498,31 @@ def _dreamer_make_mbenv( ], ), ) - reward_model = SafeModule( - reward_module, - in_keys=["state", "belief"], - out_keys=["reward"], + + reward_model = SafeProbabilisticTensorDictSequential( + SafeModule( + NormalParamWrapper(reward_module), + in_keys=["state", "belief"], + out_keys=["reward_loc", "reward_scale"], + # spec=CompositeSpec( + # **{ + # "reward_loc": UnboundedContinuousTensorSpec( + # 1, + # ), + # "reward_scale": UnboundedContinuousTensorSpec( + # 1, + # ), + # } + # ), + ), + SafeProbabilisticModule( + in_keys=["reward_loc", "reward_scale"], + out_keys=["reward"], + distribution_class=TanhNormal, + distribution_kwargs={"tanh_loc": False}, + ), ) + model_based_env = DreamerEnv( world_model=WorldModelWrapper( transition_model, @@ -548,6 +577,20 @@ def _dreamer_make_world_model( ), ) + decoder = SafeProbabilisticTensorDictSequential( + SafeModule( + decoder, + in_keys=[("next", "state"), ("next", "belief")], + out_keys=["loc"], + ), + SafeProbabilisticModule( + in_keys=["loc"], + out_keys=[("next", observation_out_key)], + distribution_class=TanhNormal, + distribution_kwargs={"tanh_loc": False}, + ), + ) + transition_model = SafeSequential( SafeModule( encoder, @@ -555,17 +598,23 @@ def _dreamer_make_world_model( out_keys=[("next", "encoded_latents")], ), rssm_rollout, + decoder, + ) + + reward_model = SafeProbabilisticTensorDictSequential( SafeModule( - decoder, + 
NormalParamWrapper(reward_module), in_keys=[("next", "state"), ("next", "belief")], - out_keys=[("next", observation_out_key)], + out_keys=[("next", "loc"), ("next", "scale")], + ), + SafeProbabilisticModule( + in_keys=[("next", "loc"), ("next", "scale")], + out_keys=[("next", "reward")], + distribution_class=TanhNormal, + distribution_kwargs={"tanh_loc": False}, ), ) - reward_model = SafeModule( - reward_module, - in_keys=[("next", "state"), ("next", "belief")], - out_keys=[("next", "reward")], - ) + world_model = WorldModelWrapper( transition_model, reward_model, diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py index 7bdfde573fa..39f9c57b2d6 100644 --- a/torchrl/objectives/dreamer.py +++ b/torchrl/objectives/dreamer.py @@ -129,20 +129,17 @@ def forward(self, tensordict: TensorDict) -> torch.Tensor: tensordict.get(("next", self.tensor_keys.posterior_mean)), tensordict.get(("next", self.tensor_keys.posterior_std)), ).unsqueeze(-1) - reco_loss = distance_loss( - tensordict.get(("next", self.tensor_keys.pixels)), - tensordict.get(("next", self.tensor_keys.reco_pixels)), - self.reco_loss, - ) - if not self.global_average: - reco_loss = reco_loss.sum((-3, -2, -1)) - reco_loss = reco_loss.mean().unsqueeze(-1) - reward_loss = distance_loss( - tensordict.get(("next", self.tensor_keys.true_reward)), - tensordict.get(("next", self.tensor_keys.reward)), - self.reward_loss, + decoder = self.world_model[0][-1] + dist = decoder.get_dist(tensordict) + reco_loss = -dist.log_prob(tensordict.get(("next", self.tensor_keys.pixels))) + + reward_model = self.world_model[1] + dist = reward_model.get_dist(tensordict) + reward_loss = -dist.log_prob( + tensordict.get(("next", self.tensor_keys.true_reward)) ) + if not self.global_average: reward_loss = reward_loss.squeeze(-1) reward_loss = reward_loss.mean().unsqueeze(-1) @@ -236,7 +233,7 @@ def __init__( model_based_env: DreamerEnv, *, imagination_horizon: int = 15, - discount_loss: bool = False, # for consistency with paper + discount_loss: bool = True, # for consistency with paper gamma: int = None, lmbda: int = None, ): @@ -392,7 +389,7 @@ def __init__( self, value_model: TensorDictModule, value_loss: Optional[str] = None, - discount_loss: bool = False, # for consistency with paper + discount_loss: bool = True, # for consistency with paper gamma: int = 0.99, ): super().__init__() From e839deef6e2df1cf3ccf774b904002bfa5b9bdaf Mon Sep 17 00:00:00 2001 From: BY571 Date: Fri, 9 Feb 2024 15:34:27 +0100 Subject: [PATCH 007/113] update losses --- examples/dreamer/dreamer_utils.py | 51 ++++++++++++----------------- torchrl/objectives/dreamer.py | 54 ++++++++----------------------- 2 files changed, 34 insertions(+), 71 deletions(-) diff --git a/examples/dreamer/dreamer_utils.py b/examples/dreamer/dreamer_utils.py index cef40d9142b..ab1911bddd0 100644 --- a/examples/dreamer/dreamer_utils.py +++ b/examples/dreamer/dreamer_utils.py @@ -15,6 +15,7 @@ import torch.nn as nn from tensordict.nn import InteractionType +from torch.distributions import Normal from torchrl.collectors import SyncDataCollector from torchrl.data import TensorDictReplayBuffer from torchrl.data.replay_buffers.storages import LazyMemmapStorage @@ -51,7 +52,7 @@ from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import ( MLP, - NormalParamWrapper, + # NormalParamWrapper, SafeModule, SafeProbabilisticModule, SafeProbabilisticTensorDictSequential, @@ -178,7 +179,7 @@ def make_dreamer( ) # Make reward module reward_module = MLP( - 
out_features=2, + out_features=1, depth=2, num_cells=config.networks.hidden_dim, activation_class=get_activation(config.networks.activation), @@ -316,22 +317,22 @@ def _dreamer_make_value_model( hidden_dim: int = 400, activation: str = "elu", value_key: str = "state_value" ): value_model = MLP( - out_features=2, + out_features=1, depth=3, num_cells=hidden_dim, activation_class=get_activation(activation), ) value_model = SafeProbabilisticTensorDictSequential( SafeModule( - NormalParamWrapper(value_model), + value_model, in_keys=["state", "belief"], - out_keys=["loc", "scale"], + out_keys=["loc"], ), SafeProbabilisticModule( - in_keys=["loc", "scale"], + in_keys=["loc"], out_keys=[value_key], - distribution_class=TanhNormal, - distribution_kwargs={"tanh_loc": False}, + distribution_class=Normal, + distribution_kwargs={"scale": 1.0}, ), ) @@ -501,25 +502,15 @@ def _dreamer_make_mbenv( reward_model = SafeProbabilisticTensorDictSequential( SafeModule( - NormalParamWrapper(reward_module), + reward_module, in_keys=["state", "belief"], - out_keys=["reward_loc", "reward_scale"], - # spec=CompositeSpec( - # **{ - # "reward_loc": UnboundedContinuousTensorSpec( - # 1, - # ), - # "reward_scale": UnboundedContinuousTensorSpec( - # 1, - # ), - # } - # ), + out_keys=["loc"], ), SafeProbabilisticModule( - in_keys=["reward_loc", "reward_scale"], + in_keys=["loc"], out_keys=["reward"], - distribution_class=TanhNormal, - distribution_kwargs={"tanh_loc": False}, + distribution_class=Normal, + distribution_kwargs={"scale": 1.0}, ), ) @@ -586,8 +577,8 @@ def _dreamer_make_world_model( SafeProbabilisticModule( in_keys=["loc"], out_keys=[("next", observation_out_key)], - distribution_class=TanhNormal, - distribution_kwargs={"tanh_loc": False}, + distribution_class=Normal, + distribution_kwargs={"scale": 1.0}, ), ) @@ -603,15 +594,15 @@ def _dreamer_make_world_model( reward_model = SafeProbabilisticTensorDictSequential( SafeModule( - NormalParamWrapper(reward_module), + reward_module, in_keys=[("next", "state"), ("next", "belief")], - out_keys=[("next", "loc"), ("next", "scale")], + out_keys=[("next", "loc")], ), SafeProbabilisticModule( - in_keys=[("next", "loc"), ("next", "scale")], + in_keys=[("next", "loc")], out_keys=[("next", "reward")], - distribution_class=TanhNormal, - distribution_kwargs={"tanh_loc": False}, + distribution_class=Normal, + distribution_kwargs={"scale": 1.0}, ), ) diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py index 39f9c57b2d6..30099be83ad 100644 --- a/torchrl/objectives/dreamer.py +++ b/torchrl/objectives/dreamer.py @@ -17,7 +17,7 @@ from torchrl.objectives.utils import ( _GAMMA_LMBDA_DEPREC_WARNING, default_value_kwargs, - distance_loss, + # distance_loss, hold_out_net, ValueEstimators, ) @@ -132,18 +132,16 @@ def forward(self, tensordict: TensorDict) -> torch.Tensor: decoder = self.world_model[0][-1] dist = decoder.get_dist(tensordict) - reco_loss = -dist.log_prob(tensordict.get(("next", self.tensor_keys.pixels))) + reco_loss = -dist.log_prob( + tensordict.get(("next", self.tensor_keys.pixels)) + ).mean() reward_model = self.world_model[1] dist = reward_model.get_dist(tensordict) reward_loss = -dist.log_prob( tensordict.get(("next", self.tensor_keys.true_reward)) - ) + ).mean() - if not self.global_average: - reward_loss = reward_loss.squeeze(-1) - reward_loss = reward_loss.mean().unsqueeze(-1) - # import ipdb; ipdb.set_trace() return ( TensorDict( { @@ -168,14 +166,8 @@ def kl_loss( + (posterior_std**2 + (prior_mean - posterior_mean) ** 2) / (2 * 
prior_std**2) - 0.5 - ) - if not self.global_average: - kl = kl.sum(-1) - if self.delayed_clamp: - kl = kl.mean().clamp_min(self.free_nats) - else: - kl = kl.clamp_min(self.free_nats).mean() - return kl + ).mean() + return kl.clamp_min(self.free_nats) class DreamerActorLoss(LossModule): @@ -262,7 +254,7 @@ def forward(self, tensordict: TensorDict) -> Tuple[TensorDict, TensorDict]: tensordict = tensordict.reshape(-1) with hold_out_net(self.model_based_env), set_exploration_type( - ExplorationType.RANDOM + ExplorationType.MODE ): tensordict = self.model_based_env.reset(tensordict.clone(recurse=False)) fake_data = self.model_based_env.rollout( @@ -289,9 +281,9 @@ def forward(self, tensordict: TensorDict) -> Tuple[TensorDict, TensorDict]: discount = gamma.expand(lambda_target.shape).clone() discount[..., 0, :] = 1 discount = discount.cumprod(dim=-2) - actor_loss = -(lambda_target * discount).sum((-2, -1)).mean() + actor_loss = -(lambda_target * discount).mean() else: - actor_loss = -lambda_target.sum((-2, -1)).mean() + actor_loss = -lambda_target.mean() loss_tensordict = TensorDict({"loss_actor": actor_loss}, []) return loss_tensordict, fake_data.detach() @@ -404,35 +396,15 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: def forward(self, fake_data) -> torch.Tensor: lambda_target = fake_data.get("lambda_target") tensordict_select = fake_data.select(*self.value_model.in_keys) - self.value_model(tensordict_select) + dist = self.value_model.get_dist(tensordict_select) if self.discount_loss: discount = self.gamma * torch.ones_like( lambda_target, device=lambda_target.device ) discount[..., 0, :] = 1 discount = discount.cumprod(dim=-2) - value_loss = ( - ( - discount - * distance_loss( - tensordict_select.get(self.tensor_keys.value), - lambda_target, - self.value_loss, - ) - ) - .sum((-1, -2)) - .mean() - ) + value_loss = -(discount * dist.log_prob(lambda_target)).mean() else: - value_loss = ( - distance_loss( - tensordict_select.get(self.tensor_keys.value), - lambda_target, - self.value_loss, - ) - .sum((-1, -2)) - .mean() - ) - + value_loss = -dist.log_prob(lambda_target).mean() loss_tensordict = TensorDict({"loss_value": value_loss}, []) return loss_tensordict, fake_data From b555c9b7e7987c5b88749ecdce4efdb79d1f03c4 Mon Sep 17 00:00:00 2001 From: BY571 Date: Mon, 12 Feb 2024 20:11:01 +0100 Subject: [PATCH 008/113] fixes --- examples/dreamer/dreamer_utils.py | 16 ++++++++-------- torchrl/envs/common.py | 4 ++++ torchrl/envs/model_based/dreamer.py | 13 +++++++++---- torchrl/objectives/dreamer.py | 15 ++++++++++++++- 4 files changed, 35 insertions(+), 13 deletions(-) diff --git a/examples/dreamer/dreamer_utils.py b/examples/dreamer/dreamer_utils.py index ab1911bddd0..a3306f6b404 100644 --- a/examples/dreamer/dreamer_utils.py +++ b/examples/dreamer/dreamer_utils.py @@ -525,14 +525,14 @@ def _dreamer_make_mbenv( ) model_based_env.set_specs_from_env(test_env) - model_based_env = TransformedEnv(model_based_env) - default_dict = { - "state": UnboundedContinuousTensorSpec(shape=(state_dim)), - "belief": UnboundedContinuousTensorSpec(shape=(rssm_hidden_dim)), - } - model_based_env.append_transform( - TensorDictPrimer(random=False, default_value=0, **default_dict) - ) + # model_based_env = TransformedEnv(model_based_env) + # default_dict = { + # "state": UnboundedContinuousTensorSpec(shape=(state_dim)), + # "belief": UnboundedContinuousTensorSpec(shape=(rssm_hidden_dim)), + # } + # model_based_env.append_transform( + # TensorDictPrimer(random=False, default_value=0, **default_dict) + # ) 
return model_based_env diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 87a51e6bef5..6e72be3ed20 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -2328,6 +2328,10 @@ def _rollout_stop_early( for i in range(max_steps): if auto_cast_to_device: tensordict = tensordict.to(policy_device, non_blocking=True) + # TODO: tensordict states and beliefs are not detached + if "state" in tensordict.keys() and "belief" in tensordict.keys(): + tensordict["state"] = tensordict["state"].detach() + tensordict["belief"] = tensordict["belief"].detach() tensordict = policy(tensordict) if auto_cast_to_device: tensordict = tensordict.to(env_device, non_blocking=True) diff --git a/torchrl/envs/model_based/dreamer.py b/torchrl/envs/model_based/dreamer.py index e36ddf9e02a..916509171bf 100644 --- a/torchrl/envs/model_based/dreamer.py +++ b/torchrl/envs/model_based/dreamer.py @@ -57,10 +57,15 @@ def set_specs_from_env(self, env: EnvBase): def _reset(self, tensordict=None, **kwargs) -> TensorDict: batch_size = tensordict.batch_size if tensordict is not None else [] device = tensordict.device if tensordict is not None else self.device - td = self.state_spec.rand(shape=batch_size).to(device) - td.set("action", self.action_spec.rand(shape=batch_size).to(device)) - td[("next", "reward")] = self.reward_spec.rand(shape=batch_size).to(device) - td.update(self.observation_spec.rand(shape=batch_size).to(device)) + # TODO: why do we overright here incoming belief and states that are correct + if tensordict is None: + td = self.state_spec.rand(shape=batch_size).to(device) + # why dont we reuse actions taken at those steps? + td.set("action", self.action_spec.rand(shape=batch_size).to(device)) + td[("next", "reward")] = self.reward_spec.rand(shape=batch_size).to(device) + td.update(self.observation_spec.rand(shape=batch_size).to(device)) + else: + td = tensordict.clone() return td def decode_obs(self, tensordict: TensorDict, compute_latents=False) -> TensorDict: diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py index 30099be83ad..686e1d04af0 100644 --- a/torchrl/objectives/dreamer.py +++ b/torchrl/objectives/dreamer.py @@ -128,7 +128,7 @@ def forward(self, tensordict: TensorDict) -> torch.Tensor: tensordict.get(("next", self.tensor_keys.prior_std)), tensordict.get(("next", self.tensor_keys.posterior_mean)), tensordict.get(("next", self.tensor_keys.posterior_std)), - ).unsqueeze(-1) + ) decoder = self.world_model[0][-1] dist = decoder.get_dist(tensordict) @@ -250,13 +250,24 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: def forward(self, tensordict: TensorDict) -> Tuple[TensorDict, TensorDict]: with torch.no_grad(): + # TODO: I think we need to take the "next" state and "next" beliefs tensordict = tensordict.select("state", self.tensor_keys.belief) tensordict = tensordict.reshape(-1) + # td = tensordict.select(("next", self.tensor_keys.state), ("next", self.tensor_keys.belief)) + # td = td.rename_key_(("next", "state"), "state") + # td = td.rename_key_(("next", "belief"), "belief") + # td = td.reshape(-1) + + # TODO: do we need exploration here? with hold_out_net(self.model_based_env), set_exploration_type( ExplorationType.MODE ): + # action_td = self.actor_model(td) + + # TODO: we are not using the actual batch beliefs as starting ones - should be solved! 
took of the primer for the mb_env tensordict = self.model_based_env.reset(tensordict.clone(recurse=False)) + # TODO: do we detach state gradients when passing again for new actions: action = self.actor(state.detach()) fake_data = self.model_based_env.rollout( max_steps=self.imagination_horizon, policy=self.actor_model, @@ -332,6 +343,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams self._value_estimator = TDLambdaEstimator( **hp, value_network=value_net, + vectorized=False, # TODO: vectorized version seems not to be similar to the non vectoried ) else: raise NotImplementedError(f"Unknown value type {value_type}") @@ -395,6 +407,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: def forward(self, fake_data) -> torch.Tensor: lambda_target = fake_data.get("lambda_target") + # TODO: I think this should be next state and belief tensordict_select = fake_data.select(*self.value_model.in_keys) dist = self.value_model.get_dist(tensordict_select) if self.discount_loss: From 46e72348a2970f1743ad58481db6152ea23d417d Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 13 Feb 2024 11:52:59 +0100 Subject: [PATCH 009/113] test changes --- examples/dreamer/dreamer.py | 8 ++++---- examples/dreamer/dreamer_utils.py | 24 ++++++++++++------------ torchrl/objectives/dreamer.py | 7 +++++-- 3 files changed, 21 insertions(+), 18 deletions(-) diff --git a/examples/dreamer/dreamer.py b/examples/dreamer/dreamer.py index 4d9bc97cbef..a7585d1adba 100644 --- a/examples/dreamer/dreamer.py +++ b/examples/dreamer/dreamer.py @@ -115,10 +115,10 @@ def main(cfg: "DictConfig"): # noqa: F821 current_frames = tensordict.numel() collected_frames += current_frames - tensordict["pixels"] = (tensordict["pixels"] * 255).to(torch.uint8) - tensordict["next", "pixels"] = (tensordict["next", "pixels"] * 255).to( - torch.uint8 - ) + # tensordict["pixels"] = (tensordict["pixels"] * 255).to(torch.uint8) + # tensordict["next", "pixels"] = (tensordict["next", "pixels"] * 255).to( + # torch.uint8 + # ) ep_reward = tensordict.get("episode_reward")[:, -1] replay_buffer.extend(tensordict.cpu()) diff --git a/examples/dreamer/dreamer_utils.py b/examples/dreamer/dreamer_utils.py index a3306f6b404..23bce7c40e3 100644 --- a/examples/dreamer/dreamer_utils.py +++ b/examples/dreamer/dreamer_utils.py @@ -39,7 +39,7 @@ FrameSkipTransform, GrayScale, # NoopResetEnv, - ObservationNorm, + # ObservationNorm, RandomCropTensorDict, Resize, # RewardScaling, @@ -285,17 +285,17 @@ def make_replay_buffer( crop_seq = RandomCropTensorDict(sub_seq_len=batch_seq_len, sample_dim=-1) transforms.append(crop_seq) - if pixel_obs: - # dtype_transform = DTypeCastTransform(torch.double, torch.float32, in_keys=["pixels"]) - # from 0-255 to 0-1 - norm_obs = ObservationNorm( - loc=0, - scale=255, - standard_normal=True, - in_keys=["pixels", ("next", "pixels")], - ) - # transforms.append(dtype_transform) - transforms.append(norm_obs) + # if pixel_obs: + # # dtype_transform = DTypeCastTransform(torch.double, torch.float32, in_keys=["pixels"]) + # # from 0-255 to 0-1 + # norm_obs = ObservationNorm( + # loc=0, + # scale=255, + # standard_normal=True, + # in_keys=["pixels", ("next", "pixels")], + # ) + # # transforms.append(dtype_transform) + # transforms.append(norm_obs) transforms = Compose(*transforms) diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py index 686e1d04af0..a7e9145ddb2 100644 --- a/torchrl/objectives/dreamer.py +++ b/torchrl/objectives/dreamer.py @@ -261,7 +261,7 @@ def forward(self, tensordict: 
TensorDict) -> Tuple[TensorDict, TensorDict]: # TODO: do we need exploration here? with hold_out_net(self.model_based_env), set_exploration_type( - ExplorationType.MODE + ExplorationType.MEAN ): # action_td = self.actor_model(td) @@ -408,7 +408,10 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: def forward(self, fake_data) -> torch.Tensor: lambda_target = fake_data.get("lambda_target") # TODO: I think this should be next state and belief - tensordict_select = fake_data.select(*self.value_model.in_keys) + td = fake_data.select(("next", "state"), ("next", "belief")) + td = td.rename_key_(("next", "state"), "state") + tensordict_select = td.rename_key_(("next", "belief"), "belief") + # tensordict_select = fake_data.select(*self.value_model.in_keys) dist = self.value_model.get_dist(tensordict_select) if self.discount_loss: discount = self.gamma * torch.ones_like( From 12a594d447e7529060490f60789c826d5addc748 Mon Sep 17 00:00:00 2001 From: BY571 Date: Wed, 14 Feb 2024 18:23:25 +0100 Subject: [PATCH 010/113] add eval env --- examples/dreamer/config.yaml | 4 ++ examples/dreamer/dreamer.py | 71 ++++++++++++++++++------------- examples/dreamer/dreamer_utils.py | 58 +++++++++---------------- 3 files changed, 66 insertions(+), 67 deletions(-) diff --git a/examples/dreamer/config.yaml b/examples/dreamer/config.yaml index 5891d4b0d6d..d7d06d82395 100644 --- a/examples/dreamer/config.yaml +++ b/examples/dreamer/config.yaml @@ -58,11 +58,13 @@ networks: replay_buffer: + uint8_casting: True buffer_size: 20000 batch_size: 50 scratch_dir: ${logger.exp_name}_${env.seed} + logger: backend: wandb project: dreamer-v1 @@ -70,3 +72,5 @@ logger: mode: online record_interval: 30 record_frames: 1000 + eval_iter: 1000 + eval_rollout_steps: 1000 diff --git a/examples/dreamer/dreamer.py b/examples/dreamer/dreamer.py index a7585d1adba..55d14e7b0eb 100644 --- a/examples/dreamer/dreamer.py +++ b/examples/dreamer/dreamer.py @@ -3,6 +3,7 @@ import torch.cuda import tqdm from dreamer_utils import ( + cast_to_uint8, log_metrics, make_collector, make_dreamer, @@ -10,16 +11,16 @@ make_replay_buffer, ) -# float16 -from torch.cuda.amp import autocast # , GradScaler +# mixed precision training +from torch.cuda.amp import autocast, GradScaler from torch.nn.utils import clip_grad_norm_ +from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.objectives.dreamer import ( DreamerActorLoss, DreamerModelLoss, DreamerValueLoss, ) - from torchrl.record.loggers import generate_exp_name, get_logger @@ -87,6 +88,7 @@ def main(cfg: "DictConfig"): # noqa: F821 buffer_scratch_dir=cfg.replay_buffer.scratch_dir, device=cfg.networks.device, pixel_obs=cfg.env.from_pixels, + cast_to_uint8=cfg.replay_buffer.uint8_casting, ) # Training loop @@ -100,25 +102,29 @@ def main(cfg: "DictConfig"): # noqa: F821 actor_opt = torch.optim.Adam(actor_model.parameters(), lr=cfg.optimization.actor_lr) value_opt = torch.optim.Adam(value_model.parameters(), lr=cfg.optimization.value_lr) - # Not sure we need those - # scaler1 = GradScaler() - # scaler2 = GradScaler() - # scaler3 = GradScaler() + # Grad scaler for mixed precision training https://pytorch.org/docs/stable/amp.html + scaler1 = GradScaler() + scaler2 = GradScaler() + scaler3 = GradScaler() init_random_frames = cfg.collector.init_random_frames batch_size = cfg.optimization.batch_size optim_steps_per_batch = cfg.optimization.optim_steps_per_batch grad_clip = cfg.optimization.grad_clip + uint8_casting = cfg.replay_buffer.uint8_casting + pixel_obs = cfg.env.from_pixels + 
frames_per_batch = cfg.collector.frames_per_batch + eval_iter = cfg.logger.eval_iter + eval_rollout_steps = cfg.logger.eval_rollout_steps for _, tensordict in enumerate(collector): pbar.update(tensordict.numel()) current_frames = tensordict.numel() collected_frames += current_frames - # tensordict["pixels"] = (tensordict["pixels"] * 255).to(torch.uint8) - # tensordict["next", "pixels"] = (tensordict["next", "pixels"] * 255).to( - # torch.uint8 - # ) + if uint8_casting and pixel_obs: + tensordict = cast_to_uint8(tensordict) + ep_reward = tensordict.get("episode_reward")[:, -1] replay_buffer.extend(tensordict.cpu()) @@ -140,39 +146,33 @@ def main(cfg: "DictConfig"): # noqa: F821 ) world_model_opt.zero_grad() - loss_world_model.backward() - # scaler1.scale(loss_world_model).backward() - # scaler1.unscale_(world_model_opt) + scaler1.scale(loss_world_model).backward() + scaler1.unscale_(world_model_opt) clip_grad_norm_(world_model.parameters(), grad_clip) - world_model_opt.step() - # scaler1.step(world_model_opt) - # scaler1.update() + scaler1.step(world_model_opt) + scaler1.update() # update actor network with autocast(dtype=torch.float16): actor_loss_td, sampled_tensordict = actor_loss(sampled_tensordict) actor_opt.zero_grad() - actor_loss_td["loss_actor"].backward() - # scaler2.scale(actor_loss_td["loss_actor"]).backward() - # scaler2.unscale_(actor_opt) + scaler2.scale(actor_loss_td["loss_actor"]).backward() + scaler2.unscale_(actor_opt) clip_grad_norm_(actor_model.parameters(), grad_clip) - actor_opt.step() - # scaler2.step(actor_opt) - # scaler2.update() + scaler2.step(actor_opt) + scaler2.update() # update value network with autocast(dtype=torch.float16): value_loss_td, sampled_tensordict = value_loss(sampled_tensordict) value_opt.zero_grad() - value_loss_td["loss_value"].backward() - # scaler3.scale(value_loss_td["loss_value"]).backward() - # scaler3.unscale_(value_opt) + scaler3.scale(value_loss_td["loss_value"]).backward() + scaler3.unscale_(value_opt) clip_grad_norm_(value_model.parameters(), grad_clip) - value_opt.step() - # scaler3.step(value_opt) - # scaler3.update() + scaler3.step(value_opt) + scaler3.update() metrics_to_log = { "reward": ep_reward.item(), @@ -188,6 +188,19 @@ def main(cfg: "DictConfig"): # noqa: F821 policy.step(current_frames) collector.update_policy_weights_() + # Evaluation + if abs(collected_frames % eval_iter) < frames_per_batch: + with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + eval_rollout = test_env.rollout( + eval_rollout_steps, + policy, + auto_cast_to_device=True, + break_when_any_done=True, + ) + eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() + metrics_to_log["eval/reward"] = eval_reward + if logger is not None: + log_metrics(logger, metrics_to_log, collected_frames) if __name__ == "__main__": diff --git a/examples/dreamer/dreamer_utils.py b/examples/dreamer/dreamer_utils.py index 23bce7c40e3..e343595a911 100644 --- a/examples/dreamer/dreamer_utils.py +++ b/examples/dreamer/dreamer_utils.py @@ -5,14 +5,8 @@ import tempfile from contextlib import nullcontext -# from torchrl.record.loggers import Logger -# from torchrl.record.recorder import VideoRecorder - import torch -# from dataclasses import dataclass, field as dataclass_field -# from typing import Any, Callable, Optional, Sequence, Union - import torch.nn as nn from tensordict.nn import InteractionType from torch.distributions import Normal @@ -20,35 +14,27 @@ from torchrl.data import TensorDictReplayBuffer from torchrl.data.replay_buffers.storages import 
LazyMemmapStorage -from torchrl.data.tensor_specs import ( - CompositeSpec, - # DiscreteTensorSpec, - UnboundedContinuousTensorSpec, -) +from torchrl.data.tensor_specs import CompositeSpec, UnboundedContinuousTensorSpec from torchrl.envs import ParallelEnv -# from torchrl.envs.common import EnvBase from torchrl.envs.env_creator import EnvCreator from torchrl.envs.libs.dm_control import DMControlEnv from torchrl.envs.libs.gym import GymEnv, set_gym_backend from torchrl.envs.model_based.dreamer import DreamerEnv from torchrl.envs.transforms import ( Compose, - # CenterCrop, DoubleToFloat, FrameSkipTransform, GrayScale, # NoopResetEnv, - # ObservationNorm, + ObservationNorm, RandomCropTensorDict, Resize, - # RewardScaling, RewardSum, ToTensorImage, TransformedEnv, - # UnsqueezeTransform, ) -from torchrl.envs.transforms.transforms import TensorDictPrimer # FlattenObservation +from torchrl.envs.transforms.transforms import TensorDictPrimer from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import ( MLP, @@ -159,13 +145,12 @@ def make_dreamer( observation_in_key = "pixels" obsevation_out_key = "reco_pixels" else: - encoder = DenseEncoder() # TODO: make them just MLPs + encoder = DenseEncoder() decoder = DenseDecoder( observation_dim=test_env.observation_spec["observation"].shape[-1] ) observation_in_key = "observation" obsevation_out_key = "reco_observation" - # raise NotImplementedError("Currently only pixel observations are supported.") # Make RSSM rssm_prior = RSSMPrior( @@ -275,6 +260,7 @@ def make_replay_buffer( device="cpu", prefetch=3, pixel_obs=True, + cast_to_uint8=True, ): with ( tempfile.TemporaryDirectory() @@ -285,17 +271,15 @@ def make_replay_buffer( crop_seq = RandomCropTensorDict(sub_seq_len=batch_seq_len, sample_dim=-1) transforms.append(crop_seq) - # if pixel_obs: - # # dtype_transform = DTypeCastTransform(torch.double, torch.float32, in_keys=["pixels"]) - # # from 0-255 to 0-1 - # norm_obs = ObservationNorm( - # loc=0, - # scale=255, - # standard_normal=True, - # in_keys=["pixels", ("next", "pixels")], - # ) - # # transforms.append(dtype_transform) - # transforms.append(norm_obs) + if pixel_obs and cast_to_uint8: + # from 0-255 to 0-1 + norm_obs = ObservationNorm( + loc=0, + scale=255, + standard_normal=True, + in_keys=["pixels", ("next", "pixels")], + ) + transforms.append(norm_obs) transforms = Compose(*transforms) @@ -525,14 +509,6 @@ def _dreamer_make_mbenv( ) model_based_env.set_specs_from_env(test_env) - # model_based_env = TransformedEnv(model_based_env) - # default_dict = { - # "state": UnboundedContinuousTensorSpec(shape=(state_dim)), - # "belief": UnboundedContinuousTensorSpec(shape=(rssm_hidden_dim)), - # } - # model_based_env.append_transform( - # TensorDictPrimer(random=False, default_value=0, **default_dict) - # ) return model_based_env @@ -613,6 +589,12 @@ def _dreamer_make_world_model( return world_model +def cast_to_uint8(tensordict): + tensordict["pixels"] = (tensordict["pixels"] * 255).to(torch.uint8) + tensordict["next", "pixels"] = (tensordict["next", "pixels"] * 255).to(torch.uint8) + return tensordict + + def log_metrics(logger, metrics, step): for metric_name, metric_value in metrics.items(): logger.log_scalar(metric_name, metric_value, step) From 0f21038bd4321d0864fe26c529893ab9b6d75ce6 Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 15 Feb 2024 11:51:54 +0100 Subject: [PATCH 011/113] use independent normal + cleanup + dense encoder/decoder --- examples/dreamer/dreamer.py | 32 ++++++++-------- 
examples/dreamer/dreamer_utils.py | 54 +++++++++++++++++---------- torchrl/modules/models/model_based.py | 34 ----------------- torchrl/objectives/dreamer.py | 2 +- 4 files changed, 52 insertions(+), 70 deletions(-) diff --git a/examples/dreamer/dreamer.py b/examples/dreamer/dreamer.py index 55d14e7b0eb..15faf975359 100644 --- a/examples/dreamer/dreamer.py +++ b/examples/dreamer/dreamer.py @@ -60,14 +60,12 @@ def main(cfg: "DictConfig"): # noqa: F821 use_decoder_in_env=False, ) - world_model.to(device) - model_based_env.to(device) - actor_model.to(device) - value_model.to(device) - policy.to(device) - # Losses world_model_loss = DreamerModelLoss(world_model) + # Adapt loss keys to gym backend + if cfg.env.backend == "gym": + world_model_loss.set_keys(pixels="observation", reco_pixels="reco_observation") + actor_loss = DreamerActorLoss( actor_model, value_model, @@ -174,14 +172,16 @@ def main(cfg: "DictConfig"): # noqa: F821 scaler3.step(value_opt) scaler3.update() - metrics_to_log = { - "reward": ep_reward.item(), - "loss_model_kl": model_loss_td["loss_model_kl"].item(), - "loss_model_reco": model_loss_td["loss_model_reco"].item(), - "loss_model_reward": model_loss_td["loss_model_reward"].item(), - "loss_actor": actor_loss_td["loss_actor"].item(), - "loss_value": value_loss_td["loss_value"].item(), - } + metrics_to_log = {"reward": ep_reward.item()} + if collected_frames >= init_random_frames: + loss_metrics = { + "loss_model_kl": model_loss_td["loss_model_kl"].item(), + "loss_model_reco": model_loss_td["loss_model_reco"].item(), + "loss_model_reward": model_loss_td["loss_model_reward"].item(), + "loss_actor": actor_loss_td["loss_actor"].item(), + "loss_value": value_loss_td["loss_value"].item(), + } + metrics_to_log.update(loss_metrics) if logger is not None: log_metrics(logger, metrics_to_log, collected_frames) @@ -198,9 +198,9 @@ def main(cfg: "DictConfig"): # noqa: F821 break_when_any_done=True, ) eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() - metrics_to_log["eval/reward"] = eval_reward + eval_metrics = {"eval/reward": eval_reward} if logger is not None: - log_metrics(logger, metrics_to_log, collected_frames) + log_metrics(logger, eval_metrics, collected_frames) if __name__ == "__main__": diff --git a/examples/dreamer/dreamer_utils.py b/examples/dreamer/dreamer_utils.py index e343595a911..2bf27071543 100644 --- a/examples/dreamer/dreamer_utils.py +++ b/examples/dreamer/dreamer_utils.py @@ -9,7 +9,6 @@ import torch.nn as nn from tensordict.nn import InteractionType -from torch.distributions import Normal from torchrl.collectors import SyncDataCollector from torchrl.data import TensorDictReplayBuffer from torchrl.data.replay_buffers.storages import LazyMemmapStorage @@ -38,16 +37,13 @@ from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import ( MLP, - # NormalParamWrapper, SafeModule, SafeProbabilisticModule, SafeProbabilisticTensorDictSequential, SafeSequential, ) -from torchrl.modules.distributions import TanhNormal +from torchrl.modules.distributions import IndependentNormal, TanhNormal from torchrl.modules.models.model_based import ( - DenseDecoder, - DenseEncoder, DreamerActor, ObsDecoder, ObsEncoder, @@ -69,7 +65,9 @@ def _make_env(cfg, device): device=device, ) elif lib == "dm_control": - env = DMControlEnv(cfg.env.name, cfg.env.task, from_pixels=cfg.env.from_pixels) + env = DMControlEnv( + cfg.env.name, cfg.env.task, from_pixels=cfg.env.from_pixels, device=device + ) return env else: raise NotImplementedError(f"Unknown 
lib {lib}.") @@ -145,10 +143,22 @@ def make_dreamer( observation_in_key = "pixels" obsevation_out_key = "reco_pixels" else: - encoder = DenseEncoder() - decoder = DenseDecoder( - observation_dim=test_env.observation_spec["observation"].shape[-1] + encoder = MLP( + out_features=1024, + depth=2, + num_cells=config.networks.hidden_dim, + activation_class=get_activation(config.networks.activation), + ) + decoder = MLP( + out_features=test_env.observation_spec["observation"].shape[-1], + depth=2, + num_cells=config.networks.hidden_dim, + activation_class=get_activation(config.networks.activation), ) + # if config.env.backend == "dm_control": + # observation_in_key = ("position", "velocity") + # obsevation_out_key = "reco_observation" + # else: observation_in_key = "observation" obsevation_out_key = "reco_observation" @@ -180,10 +190,11 @@ def make_dreamer( observation_in_key=observation_in_key, observation_out_key=obsevation_out_key, ) + world_model.to(device) # Initialize world model with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): - tensordict = test_env.rollout(5).unsqueeze(-1) + tensordict = test_env.rollout(5, auto_cast_to_device=True).unsqueeze(-1) tensordict = tensordict.to_tensordict() world_model(tensordict) @@ -227,6 +238,11 @@ def make_dreamer( value_key=value_key, ) + actor_simulator.to(device) + value_model.to(device) + actor_realworld.to(device) + model_based_env.to(device) + # Initialize model-based environment, actor and critic with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): tensordict = model_based_env.fake_tensordict().unsqueeze(-1) @@ -315,8 +331,8 @@ def _dreamer_make_value_model( SafeProbabilisticModule( in_keys=["loc"], out_keys=[value_key], - distribution_class=Normal, - distribution_kwargs={"scale": 1.0}, + distribution_class=IndependentNormal, + distribution_kwargs={"scale": 1.0, "event_dim": 1}, ), ) @@ -493,8 +509,8 @@ def _dreamer_make_mbenv( SafeProbabilisticModule( in_keys=["loc"], out_keys=["reward"], - distribution_class=Normal, - distribution_kwargs={"scale": 1.0}, + distribution_class=IndependentNormal, + distribution_kwargs={"scale": 1.0, "event_dim": 1}, ), ) @@ -543,7 +559,7 @@ def _dreamer_make_world_model( ], ), ) - + event_dim = 3 if observation_out_key == "reco_pixels" else 1 # 3 for RGB decoder = SafeProbabilisticTensorDictSequential( SafeModule( decoder, @@ -553,8 +569,8 @@ def _dreamer_make_world_model( SafeProbabilisticModule( in_keys=["loc"], out_keys=[("next", observation_out_key)], - distribution_class=Normal, - distribution_kwargs={"scale": 1.0}, + distribution_class=IndependentNormal, + distribution_kwargs={"scale": 1.0, "event_dim": event_dim}, ), ) @@ -577,8 +593,8 @@ def _dreamer_make_world_model( SafeProbabilisticModule( in_keys=[("next", "loc")], out_keys=[("next", "reward")], - distribution_class=Normal, - distribution_kwargs={"scale": 1.0}, + distribution_class=IndependentNormal, + distribution_kwargs={"scale": 1.0, "event_dim": 1}, ), ) diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py index 5abcf2bbb48..6196d69c543 100644 --- a/torchrl/modules/models/model_based.py +++ b/torchrl/modules/models/model_based.py @@ -186,40 +186,6 @@ def forward(self, state, rnn_hidden): return obs_decoded -class DenseEncoder(nn.Module): - """Dense encoder network.""" - - def __init__(self, num_layer=3, hidden_dim=300, embedding_dim=1024): - super().__init__() - - layers = [nn.LazyLinear(hidden_dim), nn.ReLU()] - for _ in range(num_layer - 2): - layers += 
[nn.LazyLinear(hidden_dim), nn.ReLU()] - layers += [nn.LazyLinear(embedding_dim), nn.ReLU()] - - self.encoder = nn.Sequential(*layers) - - def forward(self, state): - return self.encoder(state) - - -class DenseDecoder(nn.Module): - """Dense decoder network.""" - - def __init__(self, observation_dim, num_layer=3, hidden_dim=300): - super().__init__() - - layers = [nn.LazyLinear(hidden_dim), nn.ReLU()] - for _ in range(num_layer - 2): - layers += [nn.LazyLinear(hidden_dim), nn.ReLU()] - layers += [nn.LazyLinear(observation_dim), nn.ReLU()] - - self.decoder = nn.Sequential(*layers) - - def forward(self, state): - return self.decoder(state) - - class RSSMRollout(TensorDictModuleBase): """Rollout the RSSM network. diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py index a7e9145ddb2..4bdca2169f7 100644 --- a/torchrl/objectives/dreamer.py +++ b/torchrl/objectives/dreamer.py @@ -419,7 +419,7 @@ def forward(self, fake_data) -> torch.Tensor: ) discount[..., 0, :] = 1 discount = discount.cumprod(dim=-2) - value_loss = -(discount * dist.log_prob(lambda_target)).mean() + value_loss = -(discount * dist.log_prob(lambda_target).unsqueeze(-1)).mean() else: value_loss = -dist.log_prob(lambda_target).mean() loss_tensordict = TensorDict({"loss_value": value_loss}, []) From fe65b95c67298042c342fe8ced8445d004cbdc34 Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 15 Feb 2024 12:59:43 +0100 Subject: [PATCH 012/113] cleanup --- examples/dreamer/config.yaml | 23 +---------------------- examples/dreamer/dreamer_utils.py | 22 +++++++++++++++------- 2 files changed, 16 insertions(+), 29 deletions(-) diff --git a/examples/dreamer/config.yaml b/examples/dreamer/config.yaml index d7d06d82395..e5444003279 100644 --- a/examples/dreamer/config.yaml +++ b/examples/dreamer/config.yaml @@ -3,34 +3,20 @@ env: task: run seed: 0 backend: dm_control - catframes: 1 - record_video: 0 frame_skip: 2 from_pixels: True grayscale: False image_size : 64 - center_crop: False - batch_transform: 1 - - # # probably not needed vvvv - # normalize_rewards_online: True - # normalize_rewards_online_scale: 5.0 - # normalize_rewards_online_decay: 0.99999 - # reward_scaling: 1.0 collector: total_frames: 5_000_000 - init_env_steps: 1000 init_random_frames: 1000 frames_per_batch: 1000 max_frames_per_traj: 1000 device: cpu - optimization: train_every: 1000 - train_steps: 100 - pretrain: 100 grad_clip: 100 batch_size: 50 batch_length: 50 @@ -45,12 +31,9 @@ optimization: lambda: 0.95 imagination_horizon: 15 - -# we want 50 frames / traj in the replay buffer. 
Given the frame_skip=2 this makes each traj 100 steps long networks: - # additive gaussian exploration exploration_noise: 0.3 - device: cuda:0 + device: cuda:1 state_dim: 30 rssm_hidden_dim: 200 hidden_dim: 400 @@ -63,14 +46,10 @@ replay_buffer: batch_size: 50 scratch_dir: ${logger.exp_name}_${env.seed} - - logger: backend: wandb project: dreamer-v1 exp_name: ${env.name}-${env.task}-${env.seed} mode: online - record_interval: 30 - record_frames: 1000 eval_iter: 1000 eval_rollout_steps: 1000 diff --git a/examples/dreamer/dreamer_utils.py b/examples/dreamer/dreamer_utils.py index 2bf27071543..207734a86d0 100644 --- a/examples/dreamer/dreamer_utils.py +++ b/examples/dreamer/dreamer_utils.py @@ -23,9 +23,9 @@ from torchrl.envs.transforms import ( Compose, DoubleToFloat, + ExcludeTransform, FrameSkipTransform, GrayScale, - # NoopResetEnv, ObservationNorm, RandomCropTensorDict, Resize, @@ -55,7 +55,6 @@ from torchrl.modules.tensordict_module.world_models import WorldModelWrapper -# TODO make env with action repeat transform def _make_env(cfg, device): lib = cfg.env.backend if lib in ("gym", "gymnasium"): @@ -65,9 +64,7 @@ def _make_env(cfg, device): device=device, ) elif lib == "dm_control": - env = DMControlEnv( - cfg.env.name, cfg.env.task, from_pixels=cfg.env.from_pixels, device=device - ) + env = DMControlEnv(cfg.env.name, cfg.env.task, from_pixels=cfg.env.from_pixels) return env else: raise NotImplementedError(f"Unknown lib {lib}.") @@ -80,6 +77,7 @@ def transform_env(cfg, env, parallel_envs, dummy=False): env.append_transform(ToTensorImage(from_int=True)) if cfg.env.grayscale: env.append_transform(GrayScale()) + img_size = cfg.env.image_size env.append_transform(Resize(img_size, img_size)) @@ -194,7 +192,11 @@ def make_dreamer( # Initialize world model with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): - tensordict = test_env.rollout(5, auto_cast_to_device=True).unsqueeze(-1) + tensordict = ( + test_env.rollout(5, auto_cast_to_device=True) + .unsqueeze(-1) + .to(world_model.device) + ) tensordict = tensordict.to_tensordict() world_model(tensordict) @@ -245,7 +247,9 @@ def make_dreamer( # Initialize model-based environment, actor and critic with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): - tensordict = model_based_env.fake_tensordict().unsqueeze(-1) + tensordict = ( + model_based_env.fake_tensordict().unsqueeze(-1).to(value_model.device) + ) tensordict = tensordict tensordict = actor_simulator(tensordict) value_model(tensordict) @@ -263,8 +267,12 @@ def make_collector(cfg, train_env, actor_model_explore): total_frames=cfg.collector.total_frames, device=cfg.collector.device, reset_at_each_iter=True, + postproc=ExcludeTransform( + "belief", "state", ("next", "belief"), ("next", "state"), "encoded_latents" + ), ) collector.set_seed(cfg.env.seed) + return collector From 8adef8ae650d3680796a3c4e09c376f3db9bae47 Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 15 Feb 2024 17:47:02 +0100 Subject: [PATCH 013/113] fixes --- examples/dreamer/config.yaml | 2 +- examples/dreamer/dreamer.py | 1 - examples/dreamer/dreamer_utils.py | 10 +++++----- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/examples/dreamer/config.yaml b/examples/dreamer/config.yaml index e5444003279..325aef7d5db 100644 --- a/examples/dreamer/config.yaml +++ b/examples/dreamer/config.yaml @@ -33,7 +33,7 @@ optimization: networks: exploration_noise: 0.3 - device: cuda:1 + device: cuda:0 state_dim: 30 rssm_hidden_dim: 200 hidden_dim: 400 diff --git a/examples/dreamer/dreamer.py 
b/examples/dreamer/dreamer.py index 15faf975359..ac505197497 100644 --- a/examples/dreamer/dreamer.py +++ b/examples/dreamer/dreamer.py @@ -28,7 +28,6 @@ def main(cfg: "DictConfig"): # noqa: F821 # cfg = correct_for_frame_skip(cfg) - # TODO really needed? if so then also needed for collector device if torch.cuda.is_available() and cfg.networks.device == "": device = torch.device("cuda:0") elif cfg.networks.device: diff --git a/examples/dreamer/dreamer_utils.py b/examples/dreamer/dreamer_utils.py index 207734a86d0..73866cb3a06 100644 --- a/examples/dreamer/dreamer_utils.py +++ b/examples/dreamer/dreamer_utils.py @@ -23,7 +23,7 @@ from torchrl.envs.transforms import ( Compose, DoubleToFloat, - ExcludeTransform, + # ExcludeTransform, FrameSkipTransform, GrayScale, ObservationNorm, @@ -132,7 +132,7 @@ def make_dreamer( value_key: str = "state_value", use_decoder_in_env: bool = False, ): - test_env = _make_env(config, device=device) + test_env = _make_env(config, device="cpu") test_env = transform_env(config, test_env, parallel_envs=1, dummy=True) # Make encoder and decoder if config.env.from_pixels: @@ -267,9 +267,9 @@ def make_collector(cfg, train_env, actor_model_explore): total_frames=cfg.collector.total_frames, device=cfg.collector.device, reset_at_each_iter=True, - postproc=ExcludeTransform( - "belief", "state", ("next", "belief"), ("next", "state"), "encoded_latents" - ), + # postproc=ExcludeTransform( + # "belief", "state", ("next", "belief"), ("next", "state"), "encoded_latents" + # ), ) collector.set_seed(cfg.env.seed) From 1faacc997a446bff784fad25bb8b92aa8ed666dd Mon Sep 17 00:00:00 2001 From: BY571 Date: Fri, 16 Feb 2024 11:06:22 +0100 Subject: [PATCH 014/113] update naming --- examples/dreamer/{dreamer_utils.py => utils.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename examples/dreamer/{dreamer_utils.py => utils.py} (100%) diff --git a/examples/dreamer/dreamer_utils.py b/examples/dreamer/utils.py similarity index 100% rename from examples/dreamer/dreamer_utils.py rename to examples/dreamer/utils.py From 99e9c3cbe0f529fec1d35e0043d6fbaca403678b Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 11:27:25 +0200 Subject: [PATCH 015/113] amend --- sota-implementations/dreamer/dreamer.py | 5 + sota-implementations/dreamer/dreamer_utils.py | 637 ++++++++++++++++++ 2 files changed, 642 insertions(+) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index 2f2a3710930..3aa824ab407 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + import hydra import torch import torch.cuda diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index e69de29bb2d..73866cb3a06 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -0,0 +1,637 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+import tempfile +from contextlib import nullcontext + +import torch + +import torch.nn as nn +from tensordict.nn import InteractionType +from torchrl.collectors import SyncDataCollector +from torchrl.data import TensorDictReplayBuffer +from torchrl.data.replay_buffers.storages import LazyMemmapStorage + +from torchrl.data.tensor_specs import CompositeSpec, UnboundedContinuousTensorSpec +from torchrl.envs import ParallelEnv + +from torchrl.envs.env_creator import EnvCreator +from torchrl.envs.libs.dm_control import DMControlEnv +from torchrl.envs.libs.gym import GymEnv, set_gym_backend +from torchrl.envs.model_based.dreamer import DreamerEnv +from torchrl.envs.transforms import ( + Compose, + DoubleToFloat, + # ExcludeTransform, + FrameSkipTransform, + GrayScale, + ObservationNorm, + RandomCropTensorDict, + Resize, + RewardSum, + ToTensorImage, + TransformedEnv, +) +from torchrl.envs.transforms.transforms import TensorDictPrimer +from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.modules import ( + MLP, + SafeModule, + SafeProbabilisticModule, + SafeProbabilisticTensorDictSequential, + SafeSequential, +) +from torchrl.modules.distributions import IndependentNormal, TanhNormal +from torchrl.modules.models.model_based import ( + DreamerActor, + ObsDecoder, + ObsEncoder, + RSSMPosterior, + RSSMPrior, + RSSMRollout, +) +from torchrl.modules.tensordict_module.exploration import AdditiveGaussianWrapper +from torchrl.modules.tensordict_module.world_models import WorldModelWrapper + + +def _make_env(cfg, device): + lib = cfg.env.backend + if lib in ("gym", "gymnasium"): + with set_gym_backend(lib): + return GymEnv( + cfg.env.name, + device=device, + ) + elif lib == "dm_control": + env = DMControlEnv(cfg.env.name, cfg.env.task, from_pixels=cfg.env.from_pixels) + return env + else: + raise NotImplementedError(f"Unknown lib {lib}.") + + +def transform_env(cfg, env, parallel_envs, dummy=False): + env = TransformedEnv(env) + if cfg.env.from_pixels: + # transforms pixel from 0-255 to 0-1 (uint8 to float32) + env.append_transform(ToTensorImage(from_int=True)) + if cfg.env.grayscale: + env.append_transform(GrayScale()) + + img_size = cfg.env.image_size + env.append_transform(Resize(img_size, img_size)) + + env.append_transform(DoubleToFloat()) + env.append_transform(RewardSum()) + env.append_transform(FrameSkipTransform(cfg.env.frame_skip)) + if dummy: + default_dict = { + "state": UnboundedContinuousTensorSpec(shape=(cfg.networks.state_dim)), + "belief": UnboundedContinuousTensorSpec( + shape=(cfg.networks.rssm_hidden_dim) + ), + } + else: + default_dict = { + "state": UnboundedContinuousTensorSpec( + shape=(parallel_envs, cfg.networks.state_dim) + ), + "belief": UnboundedContinuousTensorSpec( + shape=(parallel_envs, cfg.networks.rssm_hidden_dim) + ), + } + env.append_transform( + TensorDictPrimer(random=False, default_value=0, **default_dict) + ) + + return env + + +def make_environments(cfg, device, parallel_envs=1): + """Make environments for training and evaluation.""" + train_env = ParallelEnv( + parallel_envs, + EnvCreator(lambda cfg=cfg: _make_env(cfg, device=device)), + ) + train_env = transform_env(cfg, train_env, parallel_envs) + train_env.set_seed(cfg.env.seed) + eval_env = ParallelEnv( + parallel_envs, + EnvCreator(lambda cfg=cfg: _make_env(cfg, device=device)), + ) + eval_env = transform_env(cfg, eval_env, parallel_envs) + eval_env.set_seed(cfg.env.seed + 1) + + return train_env, eval_env + + +def make_dreamer( + config, + device, + action_key: str = "action", 
+ value_key: str = "state_value", + use_decoder_in_env: bool = False, +): + test_env = _make_env(config, device="cpu") + test_env = transform_env(config, test_env, parallel_envs=1, dummy=True) + # Make encoder and decoder + if config.env.from_pixels: + encoder = ObsEncoder() + decoder = ObsDecoder() + observation_in_key = "pixels" + obsevation_out_key = "reco_pixels" + else: + encoder = MLP( + out_features=1024, + depth=2, + num_cells=config.networks.hidden_dim, + activation_class=get_activation(config.networks.activation), + ) + decoder = MLP( + out_features=test_env.observation_spec["observation"].shape[-1], + depth=2, + num_cells=config.networks.hidden_dim, + activation_class=get_activation(config.networks.activation), + ) + # if config.env.backend == "dm_control": + # observation_in_key = ("position", "velocity") + # obsevation_out_key = "reco_observation" + # else: + observation_in_key = "observation" + obsevation_out_key = "reco_observation" + + # Make RSSM + rssm_prior = RSSMPrior( + hidden_dim=config.networks.rssm_hidden_dim, + rnn_hidden_dim=config.networks.rssm_hidden_dim, + state_dim=config.networks.state_dim, + action_spec=test_env.action_spec, + ) + rssm_posterior = RSSMPosterior( + hidden_dim=config.networks.rssm_hidden_dim, state_dim=config.networks.state_dim + ) + # Make reward module + reward_module = MLP( + out_features=1, + depth=2, + num_cells=config.networks.hidden_dim, + activation_class=get_activation(config.networks.activation), + ) + + # Make combined world model + world_model = _dreamer_make_world_model( + encoder, + decoder, + rssm_prior, + rssm_posterior, + reward_module, + observation_in_key=observation_in_key, + observation_out_key=obsevation_out_key, + ) + world_model.to(device) + + # Initialize world model + with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): + tensordict = ( + test_env.rollout(5, auto_cast_to_device=True) + .unsqueeze(-1) + .to(world_model.device) + ) + tensordict = tensordict.to_tensordict() + world_model(tensordict) + + # Create model-based environment + model_based_env = _dreamer_make_mbenv( + reward_module=reward_module, + rssm_prior=rssm_prior, + decoder=decoder, + observation_out_key=obsevation_out_key, + test_env=test_env, + use_decoder_in_env=use_decoder_in_env, + state_dim=config.networks.state_dim, + rssm_hidden_dim=config.networks.rssm_hidden_dim, + ) + + # Make actor + actor_simulator, actor_realworld = _dreamer_make_actors( + encoder=encoder, + observation_in_key=observation_in_key, + rssm_prior=rssm_prior, + rssm_posterior=rssm_posterior, + mlp_num_units=config.networks.hidden_dim, + activation=get_activation(config.networks.activation), + action_key=action_key, + test_env=test_env, + ) + # Exploration noise to be added to the actor_realworld + actor_realworld = AdditiveGaussianWrapper( + actor_realworld, + sigma_init=1.0, + sigma_end=1.0, + annealing_num_steps=1, + mean=0.0, + std=config.networks.exploration_noise, + ) + + # Make Critic + value_model = _dreamer_make_value_model( + hidden_dim=config.networks.hidden_dim, + activation=config.networks.activation, + value_key=value_key, + ) + + actor_simulator.to(device) + value_model.to(device) + actor_realworld.to(device) + model_based_env.to(device) + + # Initialize model-based environment, actor and critic + with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): + tensordict = ( + model_based_env.fake_tensordict().unsqueeze(-1).to(value_model.device) + ) + tensordict = tensordict + tensordict = actor_simulator(tensordict) + value_model(tensordict) 
+ + return world_model, model_based_env, actor_simulator, value_model, actor_realworld + + +def make_collector(cfg, train_env, actor_model_explore): + """Make collector.""" + collector = SyncDataCollector( + train_env, + actor_model_explore, + init_random_frames=cfg.collector.init_random_frames, + frames_per_batch=cfg.collector.frames_per_batch, + total_frames=cfg.collector.total_frames, + device=cfg.collector.device, + reset_at_each_iter=True, + # postproc=ExcludeTransform( + # "belief", "state", ("next", "belief"), ("next", "state"), "encoded_latents" + # ), + ) + collector.set_seed(cfg.env.seed) + + return collector + + +def make_replay_buffer( + batch_size, + batch_seq_len, + buffer_size=1000000, + buffer_scratch_dir=None, + device="cpu", + prefetch=3, + pixel_obs=True, + cast_to_uint8=True, +): + with ( + tempfile.TemporaryDirectory() + if buffer_scratch_dir is None + else nullcontext(buffer_scratch_dir) + ) as scratch_dir: + transforms = [] + crop_seq = RandomCropTensorDict(sub_seq_len=batch_seq_len, sample_dim=-1) + transforms.append(crop_seq) + + if pixel_obs and cast_to_uint8: + # from 0-255 to 0-1 + norm_obs = ObservationNorm( + loc=0, + scale=255, + standard_normal=True, + in_keys=["pixels", ("next", "pixels")], + ) + transforms.append(norm_obs) + + transforms = Compose(*transforms) + + replay_buffer = TensorDictReplayBuffer( + pin_memory=False, + prefetch=prefetch, + storage=LazyMemmapStorage( + buffer_size, + scratch_dir=scratch_dir, + device=device, + ), + transform=transforms, + batch_size=batch_size, + ) + return replay_buffer + + +def _dreamer_make_value_model( + hidden_dim: int = 400, activation: str = "elu", value_key: str = "state_value" +): + value_model = MLP( + out_features=1, + depth=3, + num_cells=hidden_dim, + activation_class=get_activation(activation), + ) + value_model = SafeProbabilisticTensorDictSequential( + SafeModule( + value_model, + in_keys=["state", "belief"], + out_keys=["loc"], + ), + SafeProbabilisticModule( + in_keys=["loc"], + out_keys=[value_key], + distribution_class=IndependentNormal, + distribution_kwargs={"scale": 1.0, "event_dim": 1}, + ), + ) + + return value_model + + +def _dreamer_make_actors( + encoder, + observation_in_key, + rssm_prior, + rssm_posterior, + mlp_num_units, + activation, + action_key, + test_env, +): + actor_module = DreamerActor( + out_features=test_env.action_spec.shape[-1], + depth=3, + num_cells=mlp_num_units, + activation_class=activation, + ) + actor_simulator = _dreamer_make_actor_sim(action_key, test_env, actor_module) + actor_realworld = _dreamer_make_actor_real( + encoder, + observation_in_key, + rssm_prior, + rssm_posterior, + actor_module, + action_key, + test_env, + ) + return actor_simulator, actor_realworld + + +def _dreamer_make_actor_sim(action_key, proof_environment, actor_module): + actor_simulator = SafeProbabilisticTensorDictSequential( + SafeModule( + actor_module, + in_keys=["state", "belief"], + out_keys=["loc", "scale"], + spec=CompositeSpec( + **{ + "loc": UnboundedContinuousTensorSpec( + proof_environment.action_spec.shape, + device=proof_environment.action_spec.device, + ), + "scale": UnboundedContinuousTensorSpec( + proof_environment.action_spec.shape, + device=proof_environment.action_spec.device, + ), + } + ), + ), + SafeProbabilisticModule( + in_keys=["loc", "scale"], + out_keys=[action_key], + default_interaction_type=InteractionType.RANDOM, + distribution_class=TanhNormal, + distribution_kwargs={"tanh_loc": True}, + spec=CompositeSpec(**{action_key: proof_environment.action_spec}), + ), + 
+    )
+    return actor_simulator
+
+
+def _dreamer_make_actor_real(
+    encoder,
+    observation_in_key,
+    rssm_prior,
+    rssm_posterior,
+    actor_module,
+    action_key,
+    proof_environment,
+):
+    # actor for real world: interacts with states ~ posterior
+    # Our actor differs from the original paper, where the prior and posterior are computed first and then acted on,
+    # but we found that this approach worked better.
+    actor_realworld = SafeSequential(
+        SafeModule(
+            encoder,
+            in_keys=[observation_in_key],
+            out_keys=["encoded_latents"],
+        ),
+        SafeModule(
+            rssm_posterior,
+            in_keys=["belief", "encoded_latents"],
+            out_keys=[
+                "_",
+                "_",
+                "state",
+            ],
+        ),
+        SafeProbabilisticTensorDictSequential(
+            SafeModule(
+                actor_module,
+                in_keys=["state", "belief"],
+                out_keys=["loc", "scale"],
+                spec=CompositeSpec(
+                    **{
+                        "loc": UnboundedContinuousTensorSpec(
+                            proof_environment.action_spec.shape,
+                        ),
+                        "scale": UnboundedContinuousTensorSpec(
+                            proof_environment.action_spec.shape,
+                        ),
+                    }
+                ),
+            ),
+            SafeProbabilisticModule(
+                in_keys=["loc", "scale"],
+                out_keys=[action_key],
+                default_interaction_type=InteractionType.MODE,
+                distribution_class=TanhNormal,
+                distribution_kwargs={"tanh_loc": True},
+                spec=CompositeSpec(
+                    **{action_key: proof_environment.action_spec.to("cpu")}
+                ),
+            ),
+        ),
+        SafeModule(
+            rssm_prior,
+            in_keys=["state", "belief", action_key],
+            out_keys=[
+                "_",
+                "_",
+                "_",  # we don't need the prior state
+                ("next", "belief"),
+            ],
+        ),
+    )
+    return actor_realworld
+
+
+def _dreamer_make_mbenv(
+    reward_module,
+    rssm_prior,
+    test_env,
+    decoder,
+    observation_out_key: str = "reco_pixels",
+    use_decoder_in_env: bool = False,
+    state_dim: int = 30,
+    rssm_hidden_dim: int = 200,
+):
+    # MB environment
+    if use_decoder_in_env:
+        mb_env_obs_decoder = SafeModule(
+            decoder,
+            in_keys=[("next", "state"), ("next", "belief")],
+            out_keys=[("next", observation_out_key)],
+        )
+    else:
+        mb_env_obs_decoder = None
+
+    transition_model = SafeSequential(
+        SafeModule(
+            rssm_prior,
+            in_keys=["state", "belief", "action"],
+            out_keys=[
+                "_",
+                "_",
+                "state",
+                "belief",
+            ],
+        ),
+    )
+
+    reward_model = SafeProbabilisticTensorDictSequential(
+        SafeModule(
+            reward_module,
+            in_keys=["state", "belief"],
+            out_keys=["loc"],
+        ),
+        SafeProbabilisticModule(
+            in_keys=["loc"],
+            out_keys=["reward"],
+            distribution_class=IndependentNormal,
+            distribution_kwargs={"scale": 1.0, "event_dim": 1},
+        ),
+    )
+
+    model_based_env = DreamerEnv(
+        world_model=WorldModelWrapper(
+            transition_model,
+            reward_model,
+        ),
+        prior_shape=torch.Size([state_dim]),
+        belief_shape=torch.Size([rssm_hidden_dim]),
+        obs_decoder=mb_env_obs_decoder,
+    )
+
+    model_based_env.set_specs_from_env(test_env)
+    return model_based_env
+
+
+def _dreamer_make_world_model(
+    encoder,
+    decoder,
+    rssm_prior,
+    rssm_posterior,
+    reward_module,
+    observation_in_key: str = "pixels",
+    observation_out_key: str = "reco_pixels",
+):
+    # World Model and reward model
+    rssm_rollout = RSSMRollout(
+        SafeModule(
+            rssm_prior,
+            in_keys=["state", "belief", "action"],
+            out_keys=[
+                ("next", "prior_mean"),
+                ("next", "prior_std"),
+                "_",
+                ("next", "belief"),
+            ],
+        ),
+        SafeModule(
+            rssm_posterior,
+            in_keys=[("next", "belief"), ("next", "encoded_latents")],
+            out_keys=[
+                ("next", "posterior_mean"),
+                ("next", "posterior_std"),
+                ("next", "state"),
+            ],
+        ),
+    )
+    event_dim = 3 if observation_out_key == "reco_pixels" else 1  # 3 for RGB
+    decoder = SafeProbabilisticTensorDictSequential(
+        SafeModule(
+            decoder,
+            in_keys=[("next", "state"), ("next",
"belief")], + out_keys=["loc"], + ), + SafeProbabilisticModule( + in_keys=["loc"], + out_keys=[("next", observation_out_key)], + distribution_class=IndependentNormal, + distribution_kwargs={"scale": 1.0, "event_dim": event_dim}, + ), + ) + + transition_model = SafeSequential( + SafeModule( + encoder, + in_keys=[("next", observation_in_key)], + out_keys=[("next", "encoded_latents")], + ), + rssm_rollout, + decoder, + ) + + reward_model = SafeProbabilisticTensorDictSequential( + SafeModule( + reward_module, + in_keys=[("next", "state"), ("next", "belief")], + out_keys=[("next", "loc")], + ), + SafeProbabilisticModule( + in_keys=[("next", "loc")], + out_keys=[("next", "reward")], + distribution_class=IndependentNormal, + distribution_kwargs={"scale": 1.0, "event_dim": 1}, + ), + ) + + world_model = WorldModelWrapper( + transition_model, + reward_model, + ) + return world_model + + +def cast_to_uint8(tensordict): + tensordict["pixels"] = (tensordict["pixels"] * 255).to(torch.uint8) + tensordict["next", "pixels"] = (tensordict["next", "pixels"] * 255).to(torch.uint8) + return tensordict + + +def log_metrics(logger, metrics, step): + for metric_name, metric_value in metrics.items(): + logger.log_scalar(metric_name, metric_value, step) + + +def get_activation(name): + if name == "relu": + return nn.ReLU + elif name == "tanh": + return nn.Tanh + elif name == "leaky_relu": + return nn.LeakyReLU + elif name == "elu": + return nn.ELU + else: + raise NotImplementedError From fbb09fa8e50a567eabc2ec92221d9811ee82de9d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 12:48:36 +0200 Subject: [PATCH 016/113] amend --- sota-implementations/dreamer/dreamer_utils.py | 24 +++++++--- torchrl/envs/common.py | 3 -- torchrl/envs/model_based/common.py | 17 ++++--- torchrl/envs/model_based/dreamer.py | 25 +++++----- torchrl/envs/utils.py | 46 +++++++++++++++---- torchrl/objectives/dreamer.py | 12 +---- 6 files changed, 80 insertions(+), 47 deletions(-) diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index 73866cb3a06..598aa7e8f5f 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+import functools import tempfile from contextlib import nullcontext @@ -34,7 +35,7 @@ TransformedEnv, ) from torchrl.envs.transforms.transforms import TensorDictPrimer -from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type from torchrl.modules import ( MLP, SafeModule, @@ -64,8 +65,7 @@ def _make_env(cfg, device): device=device, ) elif lib == "dm_control": - env = DMControlEnv(cfg.env.name, cfg.env.task, from_pixels=cfg.env.from_pixels) - return env + return DMControlEnv(cfg.env.name, cfg.env.task, from_pixels=cfg.env.from_pixels) else: raise NotImplementedError(f"Unknown lib {lib}.") @@ -109,19 +109,23 @@ def transform_env(cfg, env, parallel_envs, dummy=False): def make_environments(cfg, device, parallel_envs=1): """Make environments for training and evaluation.""" + func = lambda _cfg=cfg: _make_env(cfg=_cfg, device=device) train_env = ParallelEnv( parallel_envs, - EnvCreator(lambda cfg=cfg: _make_env(cfg, device=device)), + EnvCreator(func), + serial_for_single=True, ) train_env = transform_env(cfg, train_env, parallel_envs) train_env.set_seed(cfg.env.seed) eval_env = ParallelEnv( parallel_envs, - EnvCreator(lambda cfg=cfg: _make_env(cfg, device=device)), + EnvCreator(func), + serial_for_single=True, ) eval_env = transform_env(cfg, eval_env, parallel_envs) eval_env.set_seed(cfg.env.seed + 1) - + check_env_specs(train_env) + check_env_specs(eval_env) return train_env, eval_env @@ -212,6 +216,14 @@ def make_dreamer( rssm_hidden_dim=config.networks.rssm_hidden_dim, ) + def detach_state_and_belief(data): + data.set("state", data.get("state").detach()) + data.set("belief", data.get("belief").detach()) + return data + + model_based_env = model_based_env.append_transform(detach_state_and_belief) + check_env_specs(model_based_env) + # Make actor actor_simulator, actor_realworld = _dreamer_make_actors( encoder=encoder, diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 11fb916bf5c..4565aa8abca 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -2610,9 +2610,6 @@ def _rollout_stop_early( else: tensordict.clear_device_() - if "state" in tensordict.keys() and "belief" in tensordict.keys(): - tensordict["state"] = tensordict["state"].detach() - tensordict["belief"] = tensordict["belief"].detach() tensordict = policy(tensordict) if auto_cast_to_device: if env_device is not None: diff --git a/torchrl/envs/model_based/common.py b/torchrl/envs/model_based/common.py index a04607829c6..c1940f75a8f 100644 --- a/torchrl/envs/model_based/common.py +++ b/torchrl/envs/model_based/common.py @@ -139,11 +139,15 @@ def __new__(cls, *args, **kwargs): def set_specs_from_env(self, env: EnvBase): """Sets the specs of the environment from the specs of the given environment.""" - self.observation_spec = env.observation_spec.clone().to(self.device) - self.reward_spec = env.reward_spec.clone().to(self.device) - self.action_spec = env.action_spec.clone().to(self.device) - self.done_spec = env.done_spec.clone().to(self.device) - self.state_spec = env.state_spec.clone().to(self.device) + device = self.device + output_spec = env.output_spec.clone() + input_spec = env.input_spec.clone() + if device is not None: + output_spec = output_spec.to(device) + input_spec = input_spec.to(device) + self.__dict__["_output_spec"] = output_spec + self.__dict__["_input_spec"] = input_spec + self.empty_cache() def _step( self, @@ -161,12 +165,13 @@ def _step( else: tensordict_out = 
self.world_model(tensordict_out) # done can be missing, it will be filled by `step` - return tensordict_out.select( + tensordict_out = tensordict_out.select( *self.observation_spec.keys(), *self.full_done_spec.keys(), *self.full_reward_spec.keys(), strict=False, ) + return tensordict_out @abc.abstractmethod def _reset(self, tensordict: TensorDict, **kwargs) -> TensorDict: diff --git a/torchrl/envs/model_based/dreamer.py b/torchrl/envs/model_based/dreamer.py index 916509171bf..5f17fede18a 100644 --- a/torchrl/envs/model_based/dreamer.py +++ b/torchrl/envs/model_based/dreamer.py @@ -39,14 +39,6 @@ def __init__( def set_specs_from_env(self, env: EnvBase): """Sets the specs of the environment from the specs of the given environment.""" super().set_specs_from_env(env) - # self.observation_spec = CompositeSpec( - # next_state=UnboundedContinuousTensorSpec( - # shape=self.prior_shape, device=self.device - # ), - # next_belief=UnboundedContinuousTensorSpec( - # shape=self.belief_shape, device=self.device - # ), - # ) self.action_spec = self.action_spec.to(self.device) self.state_spec = CompositeSpec( state=self.observation_spec["state"], @@ -57,13 +49,18 @@ def set_specs_from_env(self, env: EnvBase): def _reset(self, tensordict=None, **kwargs) -> TensorDict: batch_size = tensordict.batch_size if tensordict is not None else [] device = tensordict.device if tensordict is not None else self.device - # TODO: why do we overright here incoming belief and states that are correct if tensordict is None: - td = self.state_spec.rand(shape=batch_size).to(device) - # why dont we reuse actions taken at those steps? - td.set("action", self.action_spec.rand(shape=batch_size).to(device)) - td[("next", "reward")] = self.reward_spec.rand(shape=batch_size).to(device) - td.update(self.observation_spec.rand(shape=batch_size).to(device)) + td = self.state_spec.rand(shape=batch_size) + # why don't we reuse actions taken at those steps? 
+ td.set("action", self.action_spec.rand(shape=batch_size)) + td[("next", "reward")] = self.reward_spec.rand(shape=batch_size) + td.update(self.observation_spec.rand(shape=batch_size)) + if device is not None: + td = td.to(device, non_blocking=True) + if torch.cuda.is_available() and device.type == "cpu": + torch.cuda.synchronize() + elif torch.backends.mps.is_available(): + torch.mps.synchronize() else: td = tensordict.clone() return td diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index e0fea1751ed..d7730322ff1 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -151,6 +151,11 @@ def __init__( self.keys_from_next = self._repr_key_list_as_tree(self.keys_from_next) self.validated = None + # Model based envs can have missing keys + from torchrl.envs import ModelBasedEnvBase + + self._allow_absent_keys = isinstance(env, ModelBasedEnvBase) + def validate(self, tensordict): if self.validated: return True @@ -196,7 +201,11 @@ def _repr_key_list_as_tree(key_list): @classmethod def _grab_and_place( - cls, nested_key_dict: dict, data_in: TensorDictBase, data_out: TensorDictBase + cls, + nested_key_dict: dict, + data_in: TensorDictBase, + data_out: TensorDictBase, + _allow_absent_keys: bool, ): for key, subdict in nested_key_dict.items(): val = data_in._get_str(key, NO_DEFAULT) @@ -208,7 +217,12 @@ def _grab_and_place( val = LazyStackedTensorDict( *( - cls._grab_and_place(subdict, _val, _val_out) + cls._grab_and_place( + subdict, + _val, + _val_out, + _allow_absent_keys=_allow_absent_keys, + ) for (_val, _val_out) in zip( val.unbind(val.stack_dim), val_out.unbind(val_out.stack_dim), @@ -217,10 +231,16 @@ def _grab_and_place( stack_dim=val.stack_dim, ) else: - val = cls._grab_and_place(subdict, val, val_out) - data_out._set_str( - key, val, validated=True, inplace=False, non_blocking=False - ) + val = cls._grab_and_place( + subdict, val, val_out, _allow_absent_keys=_allow_absent_keys + ) + if val is NO_DEFAULT: + if not _allow_absent_keys: + raise KeyError(f"key {key} not found.") + else: + data_out._set_str( + key, val, validated=True, inplace=False, non_blocking=False + ) return data_out @classmethod @@ -267,8 +287,18 @@ def __call__(self, tensordict): out = self._exclude(self.exclude_from_root, tensordict, out=None) else: out = next_td.empty() - self._grab_and_place(self.keys_from_root, tensordict, out) - self._grab_and_place(self.keys_from_next, next_td, out) + self._grab_and_place( + self.keys_from_root, + tensordict, + out, + _allow_absent_keys=self._allow_absent_keys, + ) + self._grab_and_place( + self.keys_from_next, + next_td, + out, + _allow_absent_keys=self._allow_absent_keys, + ) return out else: out = next_td.empty() diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py index c50204324af..b2ffae6d2bf 100644 --- a/torchrl/objectives/dreamer.py +++ b/torchrl/objectives/dreamer.py @@ -253,11 +253,6 @@ def forward(self, tensordict: TensorDict) -> Tuple[TensorDict, TensorDict]: tensordict = tensordict.select("state", self.tensor_keys.belief) tensordict = tensordict.reshape(-1) - # td = tensordict.select(("next", self.tensor_keys.state), ("next", self.tensor_keys.belief)) - # td = td.rename_key_(("next", "state"), "state") - # td = td.rename_key_(("next", "belief"), "belief") - # td = td.reshape(-1) - # TODO: do we need exploration here? 
with hold_out_net(self.model_based_env), set_exploration_type( ExplorationType.MEAN @@ -265,7 +260,7 @@ def forward(self, tensordict: TensorDict) -> Tuple[TensorDict, TensorDict]: # action_td = self.actor_model(td) # TODO: we are not using the actual batch beliefs as starting ones - should be solved! took of the primer for the mb_env - tensordict = self.model_based_env.reset(tensordict.clone(recurse=False)) + tensordict = self.model_based_env.reset(tensordict.copy()) # TODO: do we detach state gradients when passing again for new actions: action = self.actor(state.detach()) fake_data = self.model_based_env.rollout( max_steps=self.imagination_horizon, @@ -274,10 +269,7 @@ def forward(self, tensordict: TensorDict) -> Tuple[TensorDict, TensorDict]: tensordict=tensordict, ) - next_tensordict = step_mdp( - fake_data, - keep_other=True, - ) + next_tensordict = step_mdp(fake_data, keep_other=True) with hold_out_net(self.value_model): next_tensordict = self.value_model(next_tensordict) From a7554c981daa1f80b7ffb0ec6cd23543978425f1 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 12:51:30 +0200 Subject: [PATCH 017/113] amend --- torchrl/envs/utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index d7730322ff1..44701cafc31 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -152,9 +152,8 @@ def __init__( self.validated = None # Model based envs can have missing keys - from torchrl.envs import ModelBasedEnvBase - - self._allow_absent_keys = isinstance(env, ModelBasedEnvBase) + # TODO: do we want to always allow this? check_env_specs should catch these or downstream ops + self._allow_absent_keys = True def validate(self, tensordict): if self.validated: From ac5f2fab8ab51f0bd0d9ce219d3f5afc2b7355dc Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 12:51:56 +0200 Subject: [PATCH 018/113] amend --- sota-implementations/dreamer/dreamer_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index 598aa7e8f5f..368a0bdffd4 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -109,7 +109,7 @@ def transform_env(cfg, env, parallel_envs, dummy=False): def make_environments(cfg, device, parallel_envs=1): """Make environments for training and evaluation.""" - func = lambda _cfg=cfg: _make_env(cfg=_cfg, device=device) + func = functools.partial(_make_env, cfg=cfg, device=device) train_env = ParallelEnv( parallel_envs, EnvCreator(func), From a912f7c45ea9ba924d85be3894b009af649d6334 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 14:56:55 +0200 Subject: [PATCH 019/113] amend --- sota-implementations/dreamer/config.yaml | 2 +- sota-implementations/dreamer/dreamer.py | 9 +-- sota-implementations/dreamer/dreamer_utils.py | 74 +++++++++++-------- torchrl/_utils.py | 2 +- torchrl/data/replay_buffers/storages.py | 25 +++---- torchrl/envs/transforms/transforms.py | 53 ++++++++----- 6 files changed, 91 insertions(+), 74 deletions(-) diff --git a/sota-implementations/dreamer/config.yaml b/sota-implementations/dreamer/config.yaml index 325aef7d5db..48509d05fae 100644 --- a/sota-implementations/dreamer/config.yaml +++ b/sota-implementations/dreamer/config.yaml @@ -7,6 +7,7 @@ env: from_pixels: True grayscale: False image_size : 64 + horizon: 500 collector: total_frames: 5_000_000 @@ -41,7 +42,6 @@ networks: 
replay_buffer: - uint8_casting: True buffer_size: 20000 batch_size: 50 scratch_dir: ${logger.exp_name}_${env.seed} diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index 3aa824ab407..392175d2dfa 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -8,7 +8,6 @@ import torch.cuda import tqdm from dreamer_utils import ( - cast_to_uint8, log_metrics, make_collector, make_dreamer, @@ -90,7 +89,8 @@ def main(cfg: "DictConfig"): # noqa: F821 buffer_scratch_dir=cfg.replay_buffer.scratch_dir, device=cfg.networks.device, pixel_obs=cfg.env.from_pixels, - cast_to_uint8=cfg.replay_buffer.uint8_casting, + grayscale=cfg.env.grayscale, + image_size=cfg.env.image_size, ) # Training loop @@ -113,8 +113,6 @@ def main(cfg: "DictConfig"): # noqa: F821 batch_size = cfg.optimization.batch_size optim_steps_per_batch = cfg.optimization.optim_steps_per_batch grad_clip = cfg.optimization.grad_clip - uint8_casting = cfg.replay_buffer.uint8_casting - pixel_obs = cfg.env.from_pixels frames_per_batch = cfg.collector.frames_per_batch eval_iter = cfg.logger.eval_iter eval_rollout_steps = cfg.logger.eval_rollout_steps @@ -124,9 +122,6 @@ def main(cfg: "DictConfig"): # noqa: F821 current_frames = tensordict.numel() collected_frames += current_frames - if uint8_casting and pixel_obs: - tensordict = cast_to_uint8(tensordict) - ep_reward = tensordict.get("episode_reward")[:, -1] replay_buffer.extend(tensordict.cpu()) diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index 368a0bdffd4..fafc1bf99db 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -11,7 +11,7 @@ import torch.nn as nn from tensordict.nn import InteractionType from torchrl.collectors import SyncDataCollector -from torchrl.data import TensorDictReplayBuffer +from torchrl.data import SliceSampler, TensorDictReplayBuffer from torchrl.data.replay_buffers.storages import LazyMemmapStorage from torchrl.data.tensor_specs import CompositeSpec, UnboundedContinuousTensorSpec @@ -34,7 +34,12 @@ ToTensorImage, TransformedEnv, ) -from torchrl.envs.transforms.transforms import TensorDictPrimer +from torchrl.envs.transforms.transforms import ( + ExcludeTransform, + RenameTransform, + StepCounter, + TensorDictPrimer, +) from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type from torchrl.modules import ( MLP, @@ -74,16 +79,22 @@ def transform_env(cfg, env, parallel_envs, dummy=False): env = TransformedEnv(env) if cfg.env.from_pixels: # transforms pixel from 0-255 to 0-1 (uint8 to float32) - env.append_transform(ToTensorImage(from_int=True)) + env.append_transform( + RenameTransform(in_keys=["pixels"], out_keys=["pixels_int"]) + ) + env.append_transform( + ToTensorImage(from_int=True, in_keys=["pixels_int"], out_keys=["pixels"]) + ) if cfg.env.grayscale: env.append_transform(GrayScale()) - img_size = cfg.env.image_size - env.append_transform(Resize(img_size, img_size)) + image_size = cfg.env.image_size + env.append_transform(Resize(image_size, image_size)) env.append_transform(DoubleToFloat()) env.append_transform(RewardSum()) env.append_transform(FrameSkipTransform(cfg.env.frame_skip)) + env.append_transform(StepCounter(cfg.env.horizon)) if dummy: default_dict = { "state": UnboundedContinuousTensorSpec(shape=(cfg.networks.state_dim)), @@ -278,10 +289,6 @@ def make_collector(cfg, train_env, actor_model_explore): 
frames_per_batch=cfg.collector.frames_per_batch, total_frames=cfg.collector.total_frames, device=cfg.collector.device, - reset_at_each_iter=True, - # postproc=ExcludeTransform( - # "belief", "state", ("next", "belief"), ("next", "state"), "encoded_latents" - # ), ) collector.set_seed(cfg.env.seed) @@ -290,34 +297,41 @@ def make_collector(cfg, train_env, actor_model_explore): def make_replay_buffer( batch_size, + *, batch_seq_len, buffer_size=1000000, buffer_scratch_dir=None, device="cpu", prefetch=3, pixel_obs=True, - cast_to_uint8=True, + grayscale=True, + image_size, ): with ( tempfile.TemporaryDirectory() if buffer_scratch_dir is None else nullcontext(buffer_scratch_dir) ) as scratch_dir: - transforms = [] - crop_seq = RandomCropTensorDict(sub_seq_len=batch_seq_len, sample_dim=-1) - transforms.append(crop_seq) - - if pixel_obs and cast_to_uint8: - # from 0-255 to 0-1 - norm_obs = ObservationNorm( - loc=0, - scale=255, - standard_normal=True, - in_keys=["pixels", ("next", "pixels")], + transforms = None + if pixel_obs: + + def check_no_pixels(data): + assert "pixels" not in data.keys() + return data + + transforms = Compose( + ExcludeTransform("pixels", ("next", "pixels"), inverse=True), + check_no_pixels, # will be called only during forward + ToTensorImage( + in_keys=["pixels_int", ("next", "pixels_int")], + out_keys=["pixels", ("next", "pixels")], + ), + ) + if grayscale: + transforms.append(GrayScale(in_keys=["pixels", ("next", "pixels")])) + transforms.append( + Resize(image_size, image_size, in_keys=["pixels", ("next", "pixels")]) ) - transforms.append(norm_obs) - - transforms = Compose(*transforms) replay_buffer = TensorDictReplayBuffer( pin_memory=False, @@ -326,6 +340,12 @@ def make_replay_buffer( buffer_size, scratch_dir=scratch_dir, device=device, + ndim=2, + ), + sampler=SliceSampler( + slice_len=batch_seq_len, + strict_length=False, + traj_key=("collector", "traj_ids"), ), transform=transforms, batch_size=batch_size, @@ -625,12 +645,6 @@ def _dreamer_make_world_model( return world_model -def cast_to_uint8(tensordict): - tensordict["pixels"] = (tensordict["pixels"] * 255).to(torch.uint8) - tensordict["next", "pixels"] = (tensordict["next", "pixels"] * 255).to(torch.uint8) - return tensordict - - def log_metrics(logger, metrics, step): for metric_name, metric_value in metrics.items(): logger.log_scalar(metric_name, metric_value, step) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 3c800c0b4aa..64615de467e 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -31,7 +31,7 @@ from tensordict.utils import NestedKey from torch import multiprocessing as mp -LOGGING_LEVEL = os.environ.get("RL_LOGGING_LEVEL", "DEBUG") +LOGGING_LEVEL = os.environ.get("RL_LOGGING_LEVEL", "LOGGING") logger = logging.getLogger("torchrl") logger.setLevel(getattr(logging, LOGGING_LEVEL)) # Disable propagation to the root logger diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index d90746a44a1..61822515365 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -29,7 +29,6 @@ _CKPT_BACKEND, implement_for, logger as torchrl_logger, - VERBOSE, ) from torchrl.data.replay_buffers.utils import _is_int, INT_CLASSES @@ -913,8 +912,7 @@ def _init( self, data: Union[TensorDictBase, torch.Tensor, "PyTree"], # noqa: F821 ) -> None: - if VERBOSE: - torchrl_logger.info("Creating a TensorStorage...") + torchrl_logger.debug("Creating a TensorStorage...") if self.device == "auto": self.device = data.device @@ -1090,8 
+1088,7 @@ def load_state_dict(self, state_dict): self._len = state_dict["_len"] def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None: - if VERBOSE: - torchrl_logger.info("Creating a MemmapStorage...") + torchrl_logger.debug("Creating a MemmapStorage...") if self.device == "auto": self.device = data.device if self.device.type != "cpu": @@ -1116,11 +1113,10 @@ def max_size_along_dim0(data_shape): for key, tensor in sorted( out.items(include_nested=True, leaves_only=True), key=str ): - if VERBOSE: - filesize = os.path.getsize(tensor.filename) / 1024 / 1024 - torchrl_logger.info( - f"\t{key}: {tensor.filename}, {filesize} Mb of storage (size: {tensor.shape})." - ) + filesize = os.path.getsize(tensor.filename) / 1024 / 1024 + torchrl_logger.debug( + f"\t{key}: {tensor.filename}, {filesize} Mb of storage (size: {tensor.shape})." + ) else: out = _init_pytree(self.scratch_dir, max_size_along_dim0, data) self._storage = out @@ -1476,11 +1472,10 @@ def _init_pytree_common(tensor_path, scratch_dir, max_size_fn, tensor): filename=total_tensor_path, dtype=tensor.dtype, ) - if VERBOSE: - filesize = os.path.getsize(out.filename) / 1024 / 1024 - torchrl_logger.info( - f"The storage was created in {out.filename} and occupies {filesize} Mb of storage." - ) + filesize = os.path.getsize(out.filename) / 1024 / 1024 + torchrl_logger.debug( + f"The storage was created in {out.filename} and occupies {filesize} Mb of storage." + ) return out diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index b83de8b71f8..511eef6c410 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -5668,6 +5668,8 @@ class ExcludeTransform(Transform): Args: *excluded_keys (iterable of NestedKey): The name of the keys to exclude. If the key is not present, it is simply ignored. + inverse (bool, optional): if ``True``, the exclusion will occur during the ``inv`` call. + Defaults to ``False``. Examples: >>> import gymnasium @@ -5696,7 +5698,7 @@ class ExcludeTransform(Transform): """ - def __init__(self, *excluded_keys): + def __init__(self, *excluded_keys, inverse: bool = False): super().__init__() try: excluded_keys = unravel_key_list(excluded_keys) @@ -5705,35 +5707,46 @@ def __init__(self, *excluded_keys): "excluded keys must be a list or tuple of strings or tuples of strings." 
) self.excluded_keys = excluded_keys + self.inverse = inverse def _call(self, tensordict: TensorDictBase) -> TensorDictBase: - return tensordict.exclude(*self.excluded_keys) + if not self.inverse: + return tensordict.exclude(*self.excluded_keys) + return tensordict + + def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: + if self.inverse: + return tensordict.exclude(*self.excluded_keys) + return tensordict forward = _call def _reset( self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase ) -> TensorDictBase: - return tensordict_reset.exclude(*self.excluded_keys) + if not self.inverse: + return tensordict.exclude(*self.excluded_keys) + return tensordict def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: - full_done_spec = output_spec["full_done_spec"] - full_reward_spec = output_spec["full_reward_spec"] - full_observation_spec = output_spec["full_observation_spec"] - for key in self.excluded_keys: - # done_spec - if unravel_key(key) in list(full_done_spec.keys(True, True)): - del full_done_spec[key] - continue - # reward_spec - if unravel_key(key) in list(full_reward_spec.keys(True, True)): - del full_reward_spec[key] - continue - # observation_spec - if unravel_key(key) in list(full_observation_spec.keys(True, True)): - del full_observation_spec[key] - continue - raise KeyError(f"Key {key} not found in the environment outputs.") + if not self.inverse: + full_done_spec = output_spec["full_done_spec"] + full_reward_spec = output_spec["full_reward_spec"] + full_observation_spec = output_spec["full_observation_spec"] + for key in self.excluded_keys: + # done_spec + if unravel_key(key) in list(full_done_spec.keys(True, True)): + del full_done_spec[key] + continue + # reward_spec + if unravel_key(key) in list(full_reward_spec.keys(True, True)): + del full_reward_spec[key] + continue + # observation_spec + if unravel_key(key) in list(full_observation_spec.keys(True, True)): + del full_observation_spec[key] + continue + raise KeyError(f"Key {key} not found in the environment outputs.") return output_spec From 74cf3f8af5e5d29c9e356eb7b0ea685bee287155 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 14:58:42 +0200 Subject: [PATCH 020/113] amend --- torchrl/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 64615de467e..a9109f97354 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -31,7 +31,7 @@ from tensordict.utils import NestedKey from torch import multiprocessing as mp -LOGGING_LEVEL = os.environ.get("RL_LOGGING_LEVEL", "LOGGING") +LOGGING_LEVEL = os.environ.get("RL_LOGGING_LEVEL", "INFO") logger = logging.getLogger("torchrl") logger.setLevel(getattr(logging, LOGGING_LEVEL)) # Disable propagation to the root logger From 47cfff711ffa2c7db1a46a6b15c3ab88e3da5c27 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 14:59:18 +0200 Subject: [PATCH 021/113] amend --- sota-implementations/dreamer/dreamer_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index fafc1bf99db..1a7ce0aa3f0 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -346,6 +346,8 @@ def check_no_pixels(data): slice_len=batch_seq_len, strict_length=False, traj_key=("collector", "traj_ids"), + cache_values=True, + compile=True, ), transform=transforms, batch_size=batch_size, From 
b8452e7c133a0050b6952dfe08193342e9bb2c9a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 15:02:59 +0200 Subject: [PATCH 022/113] amend --- sota-implementations/dreamer/dreamer.py | 1 - sota-implementations/dreamer/dreamer_utils.py | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index 392175d2dfa..700eed0adae 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -87,7 +87,6 @@ def main(cfg: "DictConfig"): # noqa: F821 batch_seq_len=cfg.optimization.batch_length, buffer_size=cfg.replay_buffer.buffer_size, buffer_scratch_dir=cfg.replay_buffer.scratch_dir, - device=cfg.networks.device, pixel_obs=cfg.env.from_pixels, grayscale=cfg.env.grayscale, image_size=cfg.env.image_size, diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index 1a7ce0aa3f0..af09479085a 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -301,7 +301,6 @@ def make_replay_buffer( batch_seq_len, buffer_size=1000000, buffer_scratch_dir=None, - device="cpu", prefetch=3, pixel_obs=True, grayscale=True, @@ -339,7 +338,7 @@ def check_no_pixels(data): storage=LazyMemmapStorage( buffer_size, scratch_dir=scratch_dir, - device=device, + device="cpu", ndim=2, ), sampler=SliceSampler( From 2cb1b4ad65157c8a86dfdcc3ee955ee4887accfb Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 15:04:03 +0200 Subject: [PATCH 023/113] amend --- sota-implementations/dreamer/dreamer.py | 1 + sota-implementations/dreamer/dreamer_utils.py | 6 ++++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index 700eed0adae..392175d2dfa 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -87,6 +87,7 @@ def main(cfg: "DictConfig"): # noqa: F821 batch_seq_len=cfg.optimization.batch_length, buffer_size=cfg.replay_buffer.buffer_size, buffer_scratch_dir=cfg.replay_buffer.scratch_dir, + device=cfg.networks.device, pixel_obs=cfg.env.from_pixels, grayscale=cfg.env.grayscale, image_size=cfg.env.image_size, diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index af09479085a..4e89c52deba 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -38,7 +38,7 @@ ExcludeTransform, RenameTransform, StepCounter, - TensorDictPrimer, + TensorDictPrimer, DeviceCastTransform, ) from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type from torchrl.modules import ( @@ -301,6 +301,7 @@ def make_replay_buffer( batch_seq_len, buffer_size=1000000, buffer_scratch_dir=None, + device=None, prefetch=3, pixel_obs=True, grayscale=True, @@ -311,7 +312,7 @@ def make_replay_buffer( if buffer_scratch_dir is None else nullcontext(buffer_scratch_dir) ) as scratch_dir: - transforms = None + transforms = Compose() if pixel_obs: def check_no_pixels(data): @@ -331,6 +332,7 @@ def check_no_pixels(data): transforms.append( Resize(image_size, image_size, in_keys=["pixels", ("next", "pixels")]) ) + transforms.append(DeviceCastTransform(device=device)) replay_buffer = TensorDictReplayBuffer( pin_memory=False, From e81c9a511d280103d9404e345e2af4729b323cb1 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 15:09:39 +0200 
Subject: [PATCH 024/113] amend --- sota-implementations/dreamer/config.yaml | 3 ++- sota-implementations/dreamer/dreamer.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/sota-implementations/dreamer/config.yaml b/sota-implementations/dreamer/config.yaml index 48509d05fae..363eac99db4 100644 --- a/sota-implementations/dreamer/config.yaml +++ b/sota-implementations/dreamer/config.yaml @@ -8,13 +8,14 @@ env: grayscale: False image_size : 64 horizon: 500 + n_parallel_envs: 8 collector: total_frames: 5_000_000 init_random_frames: 1000 frames_per_batch: 1000 max_frames_per_traj: 1000 - device: cpu + device: cuda:0 optimization: train_every: 1000 diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index 392175d2dfa..291d0d4294c 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -50,7 +50,7 @@ def main(cfg: "DictConfig"): # noqa: F821 wandb_kwargs={"mode": cfg.logger.mode}, # "config": cfg}, ) - train_env, test_env = make_environments(cfg=cfg, device=device) + train_env, test_env = make_environments(cfg=cfg, device=device, parallel_envs=cfg.env.n_parallel_envs) # Make dreamer components action_key = "action" From 331faf46ee2b8aeaa1f4e26e03e75a2c8569a3d1 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 15:13:06 +0200 Subject: [PATCH 025/113] amend --- sota-implementations/dreamer/dreamer_utils.py | 43 ++++++++----------- 1 file changed, 18 insertions(+), 25 deletions(-) diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index 4e89c52deba..76e61cbb3f2 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -65,17 +65,29 @@ def _make_env(cfg, device): lib = cfg.env.backend if lib in ("gym", "gymnasium"): with set_gym_backend(lib): - return GymEnv( + env = GymEnv( cfg.env.name, device=device, ) elif lib == "dm_control": - return DMControlEnv(cfg.env.name, cfg.env.task, from_pixels=cfg.env.from_pixels) + env = DMControlEnv(cfg.env.name, cfg.env.task, from_pixels=cfg.env.from_pixels) else: raise NotImplementedError(f"Unknown lib {lib}.") + default_dict = { + "state": UnboundedContinuousTensorSpec( + shape=(cfg.networks.state_dim,) + ), + "belief": UnboundedContinuousTensorSpec( + shape=(cfg.networks.rssm_hidden_dim,) + ), + } + env = env.append_transform( + TensorDictPrimer(random=False, default_value=0, **default_dict) + ) + return env -def transform_env(cfg, env, parallel_envs, dummy=False): +def transform_env(cfg, env): env = TransformedEnv(env) if cfg.env.from_pixels: # transforms pixel from 0-255 to 0-1 (uint8 to float32) @@ -95,25 +107,6 @@ def transform_env(cfg, env, parallel_envs, dummy=False): env.append_transform(RewardSum()) env.append_transform(FrameSkipTransform(cfg.env.frame_skip)) env.append_transform(StepCounter(cfg.env.horizon)) - if dummy: - default_dict = { - "state": UnboundedContinuousTensorSpec(shape=(cfg.networks.state_dim)), - "belief": UnboundedContinuousTensorSpec( - shape=(cfg.networks.rssm_hidden_dim) - ), - } - else: - default_dict = { - "state": UnboundedContinuousTensorSpec( - shape=(parallel_envs, cfg.networks.state_dim) - ), - "belief": UnboundedContinuousTensorSpec( - shape=(parallel_envs, cfg.networks.rssm_hidden_dim) - ), - } - env.append_transform( - TensorDictPrimer(random=False, default_value=0, **default_dict) - ) return env @@ -126,14 +119,14 @@ def make_environments(cfg, device, parallel_envs=1): EnvCreator(func), 
serial_for_single=True, ) - train_env = transform_env(cfg, train_env, parallel_envs) + train_env = transform_env(cfg, train_env) train_env.set_seed(cfg.env.seed) eval_env = ParallelEnv( parallel_envs, EnvCreator(func), serial_for_single=True, ) - eval_env = transform_env(cfg, eval_env, parallel_envs) + eval_env = transform_env(cfg, eval_env) eval_env.set_seed(cfg.env.seed + 1) check_env_specs(train_env) check_env_specs(eval_env) @@ -148,7 +141,7 @@ def make_dreamer( use_decoder_in_env: bool = False, ): test_env = _make_env(config, device="cpu") - test_env = transform_env(config, test_env, parallel_envs=1, dummy=True) + test_env = transform_env(config, test_env) # Make encoder and decoder if config.env.from_pixels: encoder = ObsEncoder() From 97b69d1449ab8ee3dbc8db0dc50b41cb0dda2102 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 15:16:34 +0200 Subject: [PATCH 026/113] amend --- sota-implementations/dreamer/dreamer_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index 76e61cbb3f2..e3aad6afc9e 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -84,11 +84,13 @@ def _make_env(cfg, device): env = env.append_transform( TensorDictPrimer(random=False, default_value=0, **default_dict) ) + assert env is not None return env def transform_env(cfg, env): - env = TransformedEnv(env) + if not isinstance(env, TransformedEnv): + env = TransformedEnv(env) if cfg.env.from_pixels: # transforms pixel from 0-255 to 0-1 (uint8 to float32) env.append_transform( From 36a672d95cc8905e73e2086b190012eb6966bd04 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 15:18:17 +0200 Subject: [PATCH 027/113] amend --- sota-implementations/dreamer/dreamer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index 291d0d4294c..348575c9abf 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -171,7 +171,7 @@ def main(cfg: "DictConfig"): # noqa: F821 scaler3.step(value_opt) scaler3.update() - metrics_to_log = {"reward": ep_reward.item()} + metrics_to_log = {"reward": ep_reward.mean().item()} if collected_frames >= init_random_frames: loss_metrics = { "loss_model_kl": model_loss_td["loss_model_kl"].item(), From f517e452783743baefd6220b320c937d5cc39e80 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 15:40:20 +0200 Subject: [PATCH 028/113] amend --- sota-implementations/dreamer/config.yaml | 6 +++--- sota-implementations/dreamer/dreamer.py | 2 +- sota-implementations/dreamer/dreamer_utils.py | 13 ++++++------- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/sota-implementations/dreamer/config.yaml b/sota-implementations/dreamer/config.yaml index 363eac99db4..e74e2a3fb5d 100644 --- a/sota-implementations/dreamer/config.yaml +++ b/sota-implementations/dreamer/config.yaml @@ -9,12 +9,12 @@ env: image_size : 64 horizon: 500 n_parallel_envs: 8 + device: null collector: total_frames: 5_000_000 - init_random_frames: 1000 - frames_per_batch: 1000 - max_frames_per_traj: 1000 + init_random_frames: 8000 + frames_per_batch: 8000 device: cuda:0 optimization: diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index 348575c9abf..5321c5399f1 100644 --- a/sota-implementations/dreamer/dreamer.py +++ 
b/sota-implementations/dreamer/dreamer.py @@ -50,7 +50,7 @@ def main(cfg: "DictConfig"): # noqa: F821 wandb_kwargs={"mode": cfg.logger.mode}, # "config": cfg}, ) - train_env, test_env = make_environments(cfg=cfg, device=device, parallel_envs=cfg.env.n_parallel_envs) + train_env, test_env = make_environments(cfg=cfg, parallel_envs=cfg.env.n_parallel_envs) # Make dreamer components action_key = "action" diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index e3aad6afc9e..7c23fe2cd3a 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -113,9 +113,9 @@ def transform_env(cfg, env): return env -def make_environments(cfg, device, parallel_envs=1): +def make_environments(cfg, parallel_envs=1): """Make environments for training and evaluation.""" - func = functools.partial(_make_env, cfg=cfg, device=device) + func = functools.partial(_make_env, cfg=cfg, device=cfg.env.device) train_env = ParallelEnv( parallel_envs, EnvCreator(func), @@ -141,6 +141,7 @@ def make_dreamer( action_key: str = "action", value_key: str = "state_value", use_decoder_in_env: bool = False, + compile: bool=True, ): test_env = _make_env(config, device="cpu") test_env = transform_env(config, test_env) @@ -163,10 +164,6 @@ def make_dreamer( num_cells=config.networks.hidden_dim, activation_class=get_activation(config.networks.activation), ) - # if config.env.backend == "dm_control": - # observation_in_key = ("position", "velocity") - # obsevation_out_key = "reco_observation" - # else: observation_in_key = "observation" obsevation_out_key = "reco_observation" @@ -283,7 +280,9 @@ def make_collector(cfg, train_env, actor_model_explore): init_random_frames=cfg.collector.init_random_frames, frames_per_batch=cfg.collector.frames_per_batch, total_frames=cfg.collector.total_frames, - device=cfg.collector.device, + policy_device=cfg.collector.device, + env_device=train_env.device, + storing_device="cpu", ) collector.set_seed(cfg.env.seed) From c7c5d47586651a5ab658794305fb9d08a2a926fd Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 15:41:21 +0200 Subject: [PATCH 029/113] amend --- sota-implementations/dreamer/config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sota-implementations/dreamer/config.yaml b/sota-implementations/dreamer/config.yaml index e74e2a3fb5d..b5ab4449b7b 100644 --- a/sota-implementations/dreamer/config.yaml +++ b/sota-implementations/dreamer/config.yaml @@ -28,7 +28,7 @@ optimization: value_lr: 8e-5 kl_scale: 1.0 free_nats: 3.0 - optim_steps_per_batch: 80 + optim_steps_per_batch: 640 gamma: 0.99 lambda: 0.95 imagination_horizon: 15 From 0b8c7e73370c07b500f639b4f9d065280f5497ae Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 15:43:00 +0200 Subject: [PATCH 030/113] amend --- sota-implementations/dreamer/config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sota-implementations/dreamer/config.yaml b/sota-implementations/dreamer/config.yaml index b5ab4449b7b..e703b9e7959 100644 --- a/sota-implementations/dreamer/config.yaml +++ b/sota-implementations/dreamer/config.yaml @@ -13,7 +13,7 @@ env: collector: total_frames: 5_000_000 - init_random_frames: 8000 + init_random_frames: 1000 frames_per_batch: 8000 device: cuda:0 From 5ad10604104e1a5839806f5c987c2f0216d04c10 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 15:45:08 +0200 Subject: [PATCH 031/113] amend --- sota-implementations/dreamer/config.yaml | 4 
++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sota-implementations/dreamer/config.yaml b/sota-implementations/dreamer/config.yaml index e703b9e7959..7f80695e6ca 100644 --- a/sota-implementations/dreamer/config.yaml +++ b/sota-implementations/dreamer/config.yaml @@ -14,7 +14,7 @@ env: collector: total_frames: 5_000_000 init_random_frames: 1000 - frames_per_batch: 8000 + frames_per_batch: 1000 device: cuda:0 optimization: @@ -28,7 +28,7 @@ optimization: value_lr: 8e-5 kl_scale: 1.0 free_nats: 3.0 - optim_steps_per_batch: 640 + optim_steps_per_batch: 80 gamma: 0.99 lambda: 0.95 imagination_horizon: 15 From 140ad7b3eb4c709e2259813e067a84d04c7a1b40 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 15:51:21 +0200 Subject: [PATCH 032/113] amend --- sota-implementations/dreamer/config.yaml | 3 ++- sota-implementations/dreamer/dreamer.py | 8 +++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/sota-implementations/dreamer/config.yaml b/sota-implementations/dreamer/config.yaml index 7f80695e6ca..0b2523fe44d 100644 --- a/sota-implementations/dreamer/config.yaml +++ b/sota-implementations/dreamer/config.yaml @@ -52,5 +52,6 @@ logger: project: dreamer-v1 exp_name: ${env.name}-${env.task}-${env.seed} mode: online - eval_iter: 1000 + # eval interval, in collection counts + eval_iter: 10 eval_rollout_steps: 1000 diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index 5321c5399f1..25e5b766906 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -117,7 +117,7 @@ def main(cfg: "DictConfig"): # noqa: F821 eval_iter = cfg.logger.eval_iter eval_rollout_steps = cfg.logger.eval_rollout_steps - for _, tensordict in enumerate(collector): + for i, tensordict in enumerate(collector): pbar.update(tensordict.numel()) current_frames = tensordict.numel() collected_frames += current_frames @@ -128,9 +128,7 @@ def main(cfg: "DictConfig"): # noqa: F821 if collected_frames >= init_random_frames: for _ in range(optim_steps_per_batch): # sample from replay buffer - sampled_tensordict = replay_buffer.sample(batch_size).to( - device, non_blocking=True - ) + sampled_tensordict = replay_buffer.sample(batch_size) # update world model with autocast(dtype=torch.float16): model_loss_td, sampled_tensordict = world_model_loss( @@ -188,7 +186,7 @@ def main(cfg: "DictConfig"): # noqa: F821 policy.step(current_frames) collector.update_policy_weights_() # Evaluation - if abs(collected_frames % eval_iter) < frames_per_batch: + if (i % eval_iter) == 0: with set_exploration_type(ExplorationType.MODE), torch.no_grad(): eval_rollout = test_env.rollout( eval_rollout_steps, From 13f3e05d33d50d8d8ec973ac018cfd27d3aa6ec4 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 15:51:37 +0200 Subject: [PATCH 033/113] amend --- sota-implementations/dreamer/dreamer_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index 7c23fe2cd3a..b6d6bb7606f 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -124,7 +124,7 @@ def make_environments(cfg, parallel_envs=1): train_env = transform_env(cfg, train_env) train_env.set_seed(cfg.env.seed) eval_env = ParallelEnv( - parallel_envs, + 1, EnvCreator(func), serial_for_single=True, ) From c27559431f36454565c5e276794fcea4f1f6e43a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 
8 Apr 2024 16:01:56 +0200 Subject: [PATCH 034/113] amend --- sota-implementations/dreamer/dreamer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index 25e5b766906..17559885ed2 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -78,6 +78,10 @@ def main(cfg: "DictConfig"): # noqa: F821 ) value_loss = DreamerValueLoss(value_model, discount_loss=True) + world_model_loss = torch.compile(world_model_loss) + actor_loss = torch.compile(actor_loss) + value_loss = torch.compile(value_loss) + # Make collector collector = make_collector(cfg, train_env, policy) From d02d1e335eaaaef65da12203adb4f37c3d3ec8c5 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 16:08:30 +0200 Subject: [PATCH 035/113] amend --- sota-implementations/dreamer/dreamer.py | 27 ++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index 17559885ed2..a01469b3ecb 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import time import hydra import torch @@ -121,18 +122,31 @@ def main(cfg: "DictConfig"): # noqa: F821 eval_iter = cfg.logger.eval_iter eval_rollout_steps = cfg.logger.eval_rollout_steps + t_collect_init = time.time() for i, tensordict in enumerate(collector): + t_collect = time.time() - t_collect_init + + t_preproc_init = time.time() pbar.update(tensordict.numel()) current_frames = tensordict.numel() collected_frames += current_frames ep_reward = tensordict.get("episode_reward")[:, -1] replay_buffer.extend(tensordict.cpu()) + t_preproc = time.time() - t_preproc_init if collected_frames >= init_random_frames: + t_loss_actor = 0.0 + t_loss_critic = 0.0 + t_loss_model = 0.0 + for _ in range(optim_steps_per_batch): # sample from replay buffer + t_sample_init = time.time() sampled_tensordict = replay_buffer.sample(batch_size) + t_sample = time.time() - t_sample_init + + t_loss_model_init = time.time() # update world model with autocast(dtype=torch.float16): model_loss_td, sampled_tensordict = world_model_loss( @@ -150,8 +164,10 @@ def main(cfg: "DictConfig"): # noqa: F821 clip_grad_norm_(world_model.parameters(), grad_clip) scaler1.step(world_model_opt) scaler1.update() + t_loss_model += (time.time()-t_loss_model_init) # update actor network + t_loss_actor_init = time.time() with autocast(dtype=torch.float16): actor_loss_td, sampled_tensordict = actor_loss(sampled_tensordict) @@ -161,8 +177,10 @@ def main(cfg: "DictConfig"): # noqa: F821 clip_grad_norm_(actor_model.parameters(), grad_clip) scaler2.step(actor_opt) scaler2.update() + t_loss_actor += time.time() - t_loss_actor_init # update value network + t_loss_critic_init = time.time() with autocast(dtype=torch.float16): value_loss_td, sampled_tensordict = value_loss(sampled_tensordict) @@ -172,6 +190,7 @@ def main(cfg: "DictConfig"): # noqa: F821 clip_grad_norm_(value_model.parameters(), grad_clip) scaler3.step(value_opt) scaler3.update() + t_loss_critic += time.time() - t_loss_critic_init metrics_to_log = {"reward": ep_reward.mean().item()} if collected_frames >= init_random_frames: @@ -181,6 +200,12 @@ def main(cfg: "DictConfig"): # noqa: F821 "loss_model_reward": model_loss_td["loss_model_reward"].item(), "loss_actor": 
actor_loss_td["loss_actor"].item(), "loss_value": value_loss_td["loss_value"].item(), + "t_loss_actor": t_loss_actor, + "t_loss_critic": t_loss_critic, + "t_loss_model": t_loss_model, + "t_sample": t_sample, + "t_preproc": t_preproc, + "t_collect": t_collect, } metrics_to_log.update(loss_metrics) @@ -202,7 +227,7 @@ def main(cfg: "DictConfig"): # noqa: F821 eval_metrics = {"eval/reward": eval_reward} if logger is not None: log_metrics(logger, eval_metrics, collected_frames) - + t_collect_init = time.time() if __name__ == "__main__": main() From 9a1f5efe57fa37ef9155916291b72711c9a302aa Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 16:09:34 +0200 Subject: [PATCH 036/113] amend --- sota-implementations/dreamer/dreamer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index a01469b3ecb..ce719e44fc8 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -79,9 +79,9 @@ def main(cfg: "DictConfig"): # noqa: F821 ) value_loss = DreamerValueLoss(value_model, discount_loss=True) - world_model_loss = torch.compile(world_model_loss) - actor_loss = torch.compile(actor_loss) - value_loss = torch.compile(value_loss) + # world_model_loss = torch.compile(world_model_loss) + # actor_loss = torch.compile(actor_loss) + # value_loss = torch.compile(value_loss) # Make collector collector = make_collector(cfg, train_env, policy) From acbeb51a2cde32c1806a67fbe6e010419a1081e6 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 16:40:15 +0200 Subject: [PATCH 037/113] amend --- sota-implementations/dreamer/dreamer.py | 5 ++-- sota-implementations/dreamer/dreamer_utils.py | 23 ++++++++++--------- torchrl/modules/models/model_based.py | 2 +- torchrl/objectives/dreamer.py | 4 ++-- 4 files changed, 18 insertions(+), 16 deletions(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index ce719e44fc8..1032a93302a 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -87,9 +87,10 @@ def main(cfg: "DictConfig"): # noqa: F821 collector = make_collector(cfg, train_env, policy) # Make replay buffer + batch_length = cfg.optimization.batch_length replay_buffer = make_replay_buffer( batch_size=cfg.replay_buffer.batch_size, - batch_seq_len=cfg.optimization.batch_length, + batch_seq_len=batch_length, buffer_size=cfg.replay_buffer.buffer_size, buffer_scratch_dir=cfg.replay_buffer.scratch_dir, device=cfg.networks.device, @@ -143,7 +144,7 @@ def main(cfg: "DictConfig"): # noqa: F821 for _ in range(optim_steps_per_batch): # sample from replay buffer t_sample_init = time.time() - sampled_tensordict = replay_buffer.sample(batch_size) + sampled_tensordict = replay_buffer.sample(batch_size).reshape(-1, batch_length) t_sample = time.time() - t_sample_init t_loss_model_init = time.time() diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index b6d6bb7606f..b931ca4addc 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -9,7 +9,8 @@ import torch import torch.nn as nn -from tensordict.nn import InteractionType +from tensordict.nn import InteractionType, TensorDictModule, ProbabilisticTensorDictModule, \ + ProbabilisticTensorDictSequential, TensorDictSequential from torchrl.collectors import SyncDataCollector from torchrl.data import SliceSampler, 
TensorDictReplayBuffer from torchrl.data.replay_buffers.storages import LazyMemmapStorage @@ -576,7 +577,7 @@ def _dreamer_make_world_model( ): # World Model and reward model rssm_rollout = RSSMRollout( - SafeModule( + TensorDictModule( rssm_prior, in_keys=["state", "belief", "action"], out_keys=[ @@ -586,7 +587,7 @@ def _dreamer_make_world_model( ("next", "belief"), ], ), - SafeModule( + TensorDictModule( rssm_posterior, in_keys=[("next", "belief"), ("next", "encoded_latents")], out_keys=[ @@ -597,13 +598,13 @@ def _dreamer_make_world_model( ), ) event_dim = 3 if observation_out_key == "reco_pixels" else 1 # 3 for RGB - decoder = SafeProbabilisticTensorDictSequential( - SafeModule( + decoder = ProbabilisticTensorDictSequential( + TensorDictModule( decoder, in_keys=[("next", "state"), ("next", "belief")], out_keys=["loc"], ), - SafeProbabilisticModule( + ProbabilisticTensorDictModule( in_keys=["loc"], out_keys=[("next", observation_out_key)], distribution_class=IndependentNormal, @@ -611,8 +612,8 @@ def _dreamer_make_world_model( ), ) - transition_model = SafeSequential( - SafeModule( + transition_model = TensorDictSequential( + TensorDictModule( encoder, in_keys=[("next", observation_in_key)], out_keys=[("next", "encoded_latents")], @@ -621,13 +622,13 @@ def _dreamer_make_world_model( decoder, ) - reward_model = SafeProbabilisticTensorDictSequential( - SafeModule( + reward_model = ProbabilisticTensorDictSequential( + TensorDictModule( reward_module, in_keys=[("next", "state"), ("next", "belief")], out_keys=[("next", "loc")], ), - SafeProbabilisticModule( + ProbabilisticTensorDictModule( in_keys=[("next", "loc")], out_keys=[("next", "reward")], distribution_class=IndependentNormal, diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py index 6196d69c543..c99b98b26ed 100644 --- a/torchrl/modules/models/model_based.py +++ b/torchrl/modules/models/model_based.py @@ -250,7 +250,7 @@ def forward(self, tensordict): ) _tensordict = update_values[..., t + 1].update(_tensordict) - return torch.stack(tensordict_out, tensordict.ndimension() - 1).contiguous() + return torch.stack(tensordict_out, tensordict.ndim - 1) class RSSMPrior(nn.Module): diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py index b2ffae6d2bf..485d5f48570 100644 --- a/torchrl/objectives/dreamer.py +++ b/torchrl/objectives/dreamer.py @@ -116,8 +116,8 @@ def __init__( def _forward_value_estimator_keys(self, **kwargs) -> None: pass - def forward(self, tensordict: TensorDict) -> torch.Tensor: - tensordict = tensordict.clone(recurse=False) + def forward(self, tensordict: TensorDict) -> Tuple[TensorDict, TensorDict]: + tensordict = tensordict.copy() tensordict.rename_key_( ("next", self.tensor_keys.reward), ("next", self.tensor_keys.true_reward), From e22ba77ac1dac469434f727f5f3f9e2a44a9151c Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 16:46:37 +0200 Subject: [PATCH 038/113] amend --- torchrl/modules/models/model_based.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py index c99b98b26ed..3a8ef3279a1 100644 --- a/torchrl/modules/models/model_based.py +++ b/torchrl/modules/models/model_based.py @@ -208,8 +208,8 @@ def __init__(self, rssm_prior: TensorDictModule, rssm_posterior: TensorDictModul _module = SafeSequential(rssm_prior, rssm_posterior) self.in_keys = _module.in_keys self.out_keys = _module.out_keys - self.rssm_prior = rssm_prior - self.rssm_posterior = 
rssm_posterior + self.rssm_prior = torch.compile(rssm_prior) + self.rssm_posterior = torch.compile(rssm_posterior) def forward(self, tensordict): """Runs a rollout of simulated transitions in the latent space given a sequence of actions and environment observations. From 83b2074cb205e3bf54c1813f3b29f29ac42ba07b Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 16:51:57 +0200 Subject: [PATCH 039/113] amend --- sota-implementations/dreamer/dreamer.py | 11 ++++-- sota-implementations/dreamer/dreamer_utils.py | 22 ++++++----- torchrl/data/replay_buffers/storages.py | 6 +-- torchrl/modules/models/model_based.py | 37 ++++++++++--------- 4 files changed, 40 insertions(+), 36 deletions(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index 1032a93302a..0c7e67dd907 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -51,7 +51,9 @@ def main(cfg: "DictConfig"): # noqa: F821 wandb_kwargs={"mode": cfg.logger.mode}, # "config": cfg}, ) - train_env, test_env = make_environments(cfg=cfg, parallel_envs=cfg.env.n_parallel_envs) + train_env, test_env = make_environments( + cfg=cfg, parallel_envs=cfg.env.n_parallel_envs + ) # Make dreamer components action_key = "action" @@ -144,7 +146,9 @@ def main(cfg: "DictConfig"): # noqa: F821 for _ in range(optim_steps_per_batch): # sample from replay buffer t_sample_init = time.time() - sampled_tensordict = replay_buffer.sample(batch_size).reshape(-1, batch_length) + sampled_tensordict = replay_buffer.sample(batch_size).reshape( + -1, batch_length + ) t_sample = time.time() - t_sample_init t_loss_model_init = time.time() @@ -165,7 +169,7 @@ def main(cfg: "DictConfig"): # noqa: F821 clip_grad_norm_(world_model.parameters(), grad_clip) scaler1.step(world_model_opt) scaler1.update() - t_loss_model += (time.time()-t_loss_model_init) + t_loss_model += time.time() - t_loss_model_init # update actor network t_loss_actor_init = time.time() @@ -230,5 +234,6 @@ def main(cfg: "DictConfig"): # noqa: F821 log_metrics(logger, eval_metrics, collected_frames) t_collect_init = time.time() + if __name__ == "__main__": main() diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index b931ca4addc..2ce807f23ba 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -9,8 +9,13 @@ import torch import torch.nn as nn -from tensordict.nn import InteractionType, TensorDictModule, ProbabilisticTensorDictModule, \ - ProbabilisticTensorDictSequential, TensorDictSequential +from tensordict.nn import ( + InteractionType, + ProbabilisticTensorDictModule, + ProbabilisticTensorDictSequential, + TensorDictModule, + TensorDictSequential, +) from torchrl.collectors import SyncDataCollector from torchrl.data import SliceSampler, TensorDictReplayBuffer from torchrl.data.replay_buffers.storages import LazyMemmapStorage @@ -36,10 +41,11 @@ TransformedEnv, ) from torchrl.envs.transforms.transforms import ( + DeviceCastTransform, ExcludeTransform, RenameTransform, StepCounter, - TensorDictPrimer, DeviceCastTransform, + TensorDictPrimer, ) from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type from torchrl.modules import ( @@ -75,12 +81,8 @@ def _make_env(cfg, device): else: raise NotImplementedError(f"Unknown lib {lib}.") default_dict = { - "state": UnboundedContinuousTensorSpec( - shape=(cfg.networks.state_dim,) - ), - "belief": 
UnboundedContinuousTensorSpec( - shape=(cfg.networks.rssm_hidden_dim,) - ), + "state": UnboundedContinuousTensorSpec(shape=(cfg.networks.state_dim,)), + "belief": UnboundedContinuousTensorSpec(shape=(cfg.networks.rssm_hidden_dim,)), } env = env.append_transform( TensorDictPrimer(random=False, default_value=0, **default_dict) @@ -142,7 +144,7 @@ def make_dreamer( action_key: str = "action", value_key: str = "state_value", use_decoder_in_env: bool = False, - compile: bool=True, + compile: bool = True, ): test_env = _make_env(config, device="cpu") test_env = transform_env(config, test_env) diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 61822515365..23177a301e1 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -25,11 +25,7 @@ from torch.utils._pytree import LeafSpec, tree_flatten, tree_map, tree_unflatten -from torchrl._utils import ( - _CKPT_BACKEND, - implement_for, - logger as torchrl_logger, -) +from torchrl._utils import _CKPT_BACKEND, implement_for, logger as torchrl_logger from torchrl.data.replay_buffers.utils import _is_int, INT_CLASSES try: diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py index 3a8ef3279a1..ccad0860ab3 100644 --- a/torchrl/modules/models/model_based.py +++ b/torchrl/modules/models/model_based.py @@ -6,11 +6,10 @@ import torch from packaging import version -from tensordict.nn import TensorDictModule, TensorDictModuleBase +from tensordict.nn import NormalParamExtractor, TensorDictModule, TensorDictModuleBase from torch import nn from torchrl.envs.utils import step_mdp -from torchrl.modules.distributions import NormalParamWrapper from torchrl.modules.models.models import MLP from torchrl.modules.tensordict_module.sequence import SafeSequential @@ -49,14 +48,16 @@ def __init__( std_min_val=1e-4, ): super().__init__() - self.backbone = NormalParamWrapper( + self.backbone = nn.Sequential( MLP( out_features=2 * out_features, depth=depth, num_cells=num_cells, activation_class=activation_class, ), - scale_mapping=f"biased_softplus_{std_bias}_{std_min_val}", + NormalParamExtractor( + scale_mapping=f"biased_softplus_{std_bias}_{std_min_val}", + ), ) def forward(self, state, belief): @@ -289,14 +290,14 @@ def __init__( # Prior self.rnn = nn.GRUCell(hidden_dim, rnn_hidden_dim) self.action_state_projector = nn.Sequential(nn.LazyLinear(hidden_dim), nn.ELU()) - self.rnn_to_prior_projector = NormalParamWrapper( - nn.Sequential( - nn.Linear(hidden_dim, hidden_dim), - nn.ELU(), - nn.Linear(hidden_dim, 2 * state_dim), + self.rnn_to_prior_projector = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), + nn.ELU(), + nn.Linear(hidden_dim, 2 * state_dim), + NormalParamExtractor( + scale_lb=scale_lb, + scale_mapping="softplus", ), - scale_lb=scale_lb, - scale_mapping="softplus", ) self.state_dim = state_dim @@ -344,14 +345,14 @@ class RSSMPosterior(nn.Module): def __init__(self, hidden_dim=200, state_dim=30, scale_lb=0.1): super().__init__() - self.obs_rnn_to_post_projector = NormalParamWrapper( - nn.Sequential( - nn.LazyLinear(hidden_dim), - nn.ELU(), - nn.Linear(hidden_dim, 2 * state_dim), + self.obs_rnn_to_post_projector = nn.Sequential( + nn.LazyLinear(hidden_dim), + nn.ELU(), + nn.Linear(hidden_dim, 2 * state_dim), + NormalParamExtractor( + scale_lb=scale_lb, + scale_mapping="softplus", ), - scale_lb=scale_lb, - scale_mapping="softplus", ) self.hidden_dim = hidden_dim From 87107db33bf33fef84415dfeac5fe959720862ea Mon Sep 17 00:00:00 2001 
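As a side note on patch 039 just above: NormalParamExtractor (from tensordict) takes over from NormalParamWrapper by splitting the backbone's 2 * out_features output into a (loc, scale) pair. Below is a hedged, standalone illustration with example sizes and an example scale-mapping string, not the repo's exact defaults.

import torch
from tensordict.nn import NormalParamExtractor
from torchrl.modules import MLP

# the backbone emits twice the target feature count; the extractor chunks the last
# dimension in two and maps the raw scale through a biased softplus so it stays positive
backbone = MLP(out_features=2 * 30, depth=3, num_cells=200)
extractor = NormalParamExtractor(scale_mapping="biased_softplus_5.0_1e-4")

x = torch.randn(8, 64)  # the MLP input layer is lazy, so any feature size works
loc, scale = extractor(backbone(x))
assert loc.shape == (8, 30) and scale.shape == (8, 30)
assert (scale > 0).all()
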
From: Vincent Moens Date: Mon, 8 Apr 2024 16:56:43 +0200 Subject: [PATCH 040/113] amend --- sota-implementations/dreamer/dreamer.py | 5 ++++- torchrl/modules/models/model_based.py | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index 0c7e67dd907..39dd1659f60 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -121,10 +121,13 @@ def main(cfg: "DictConfig"): # noqa: F821 batch_size = cfg.optimization.batch_size optim_steps_per_batch = cfg.optimization.optim_steps_per_batch grad_clip = cfg.optimization.grad_clip - frames_per_batch = cfg.collector.frames_per_batch eval_iter = cfg.logger.eval_iter eval_rollout_steps = cfg.logger.eval_rollout_steps + print('Compiling') + world_model_loss.world_model.rssm_prior = torch.compile(world_model_loss.world_model.rssm_prior) + world_model_loss.world_model.rssm_posterior = torch.compile(world_model_loss.world_model.rssm_posterior) + t_collect_init = time.time() for i, tensordict in enumerate(collector): t_collect = time.time() - t_collect_init diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py index ccad0860ab3..2c5aede0be5 100644 --- a/torchrl/modules/models/model_based.py +++ b/torchrl/modules/models/model_based.py @@ -209,8 +209,8 @@ def __init__(self, rssm_prior: TensorDictModule, rssm_posterior: TensorDictModul _module = SafeSequential(rssm_prior, rssm_posterior) self.in_keys = _module.in_keys self.out_keys = _module.out_keys - self.rssm_prior = torch.compile(rssm_prior) - self.rssm_posterior = torch.compile(rssm_posterior) + self.rssm_prior = rssm_prior + self.rssm_posterior = rssm_posterior def forward(self, tensordict): """Runs a rollout of simulated transitions in the latent space given a sequence of actions and environment observations. 
From 4ede077dfab2a020b75ebf73263b7cea8bdc577f Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 17:00:05 +0200 Subject: [PATCH 041/113] amend --- torchrl/modules/models/model_based.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py index 2c5aede0be5..074af0a31fb 100644 --- a/torchrl/modules/models/model_based.py +++ b/torchrl/modules/models/model_based.py @@ -48,13 +48,13 @@ def __init__( std_min_val=1e-4, ): super().__init__() - self.backbone = nn.Sequential( - MLP( + self.backbone = MLP( out_features=2 * out_features, depth=depth, num_cells=num_cells, activation_class=activation_class, - ), + ) + self.backbone.append( NormalParamExtractor( scale_mapping=f"biased_softplus_{std_bias}_{std_min_val}", ), From 5c5529f38057238f9996f49bb3da26597878e83a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 17:02:51 +0200 Subject: [PATCH 042/113] amend --- sota-implementations/dreamer/dreamer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index 39dd1659f60..d7aded61ea8 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -125,8 +125,8 @@ def main(cfg: "DictConfig"): # noqa: F821 eval_rollout_steps = cfg.logger.eval_rollout_steps print('Compiling') - world_model_loss.world_model.rssm_prior = torch.compile(world_model_loss.world_model.rssm_prior) - world_model_loss.world_model.rssm_posterior = torch.compile(world_model_loss.world_model.rssm_posterior) + world_model_loss.world_model[0][1].rssm_prior = torch.compile(world_model_loss.world_model.rssm_prior) + world_model_loss.world_model[0][1].rssm_posterior = torch.compile(world_model_loss.world_model.rssm_posterior) t_collect_init = time.time() for i, tensordict in enumerate(collector): From 0a795a47dd2adbb294b884d2f9c5fc1b0006d6be Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 17:05:28 +0200 Subject: [PATCH 043/113] amend --- sota-implementations/dreamer/dreamer.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index d7aded61ea8..b4518ae70f3 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -125,8 +125,11 @@ def main(cfg: "DictConfig"): # noqa: F821 eval_rollout_steps = cfg.logger.eval_rollout_steps print('Compiling') - world_model_loss.world_model[0][1].rssm_prior = torch.compile(world_model_loss.world_model.rssm_prior) - world_model_loss.world_model[0][1].rssm_posterior = torch.compile(world_model_loss.world_model.rssm_posterior) + def compile_rssms(module): + if isinstance(module, RSSMRollout): + module.rssm_prior = torch.compile(module.rssm_prior) + module.rssm_posterior = torch.compile(module.rssm_posterior) + world_model_loss.apply(compile_rssms) t_collect_init = time.time() for i, tensordict in enumerate(collector): From 0ee83cd2e80f3b25c8fc990039b239e81b0d5b6c Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 17:05:36 +0200 Subject: [PATCH 044/113] amend --- sota-implementations/dreamer/dreamer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index b4518ae70f3..fb4ec37718b 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -20,6 
+20,7 @@ from torch.cuda.amp import autocast, GradScaler from torch.nn.utils import clip_grad_norm_ from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.modules.models.model_based import RSSMRollout from torchrl.objectives.dreamer import ( DreamerActorLoss, From 7c59b96c3a463c0ac38e1588dd0970451aa77f3f Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 17:21:52 +0200 Subject: [PATCH 045/113] amend --- sota-implementations/dreamer/dreamer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index fb4ec37718b..105f171f60e 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -127,7 +127,8 @@ def main(cfg: "DictConfig"): # noqa: F821 print('Compiling') def compile_rssms(module): - if isinstance(module, RSSMRollout): + if isinstance(module, RSSMRollout) and not getattr(module, "_compiled", False): + module._compiled = True module.rssm_prior = torch.compile(module.rssm_prior) module.rssm_posterior = torch.compile(module.rssm_posterior) world_model_loss.apply(compile_rssms) From 07d3e93af0aea26cbdfaf50a4c3700f9d5400448 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 17:35:29 +0200 Subject: [PATCH 046/113] amend --- sota-implementations/dreamer/dreamer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index 105f171f60e..5560d151655 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -129,8 +129,8 @@ def main(cfg: "DictConfig"): # noqa: F821 def compile_rssms(module): if isinstance(module, RSSMRollout) and not getattr(module, "_compiled", False): module._compiled = True - module.rssm_prior = torch.compile(module.rssm_prior) - module.rssm_posterior = torch.compile(module.rssm_posterior) + module.rssm_prior.module = torch.compile(module.rssm_prior.module) + module.rssm_posterior.module = torch.compile(module.rssm_posterior.module) world_model_loss.apply(compile_rssms) t_collect_init = time.time() From 980a7ae2a10a6fe9a4bd6776c962fba7ccc2b4f0 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 17:41:52 +0200 Subject: [PATCH 047/113] amend --- sota-implementations/dreamer/dreamer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index 5560d151655..f81da7ae36f 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -129,8 +129,8 @@ def main(cfg: "DictConfig"): # noqa: F821 def compile_rssms(module): if isinstance(module, RSSMRollout) and not getattr(module, "_compiled", False): module._compiled = True - module.rssm_prior.module = torch.compile(module.rssm_prior.module) - module.rssm_posterior.module = torch.compile(module.rssm_posterior.module) + module.rssm_prior.module = torch.compile(module.rssm_prior.module, backend="cudagraphs") + module.rssm_posterior.module = torch.compile(module.rssm_posterior.module, backend="cudagraphs") world_model_loss.apply(compile_rssms) t_collect_init = time.time() From 211882139f796de5e5361f475c98ef143d052247 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 17:46:09 +0200 Subject: [PATCH 048/113] amend --- sota-implementations/dreamer/config.yaml | 1 + sota-implementations/dreamer/dreamer.py | 14 ++++++++------ 
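The torch.compile changes iterated on in patches 045 through 048 settle on the pattern sketched below: walk the loss module with nn.Module.apply, compile only the RSSM prior/posterior inner modules with the cudagraphs backend, and guard against wrapping them twice. This is a condensed illustration that assumes a world_model_loss assembled as in dreamer.py; it is not a drop-in replacement for the patched file.

import torch

from torchrl.modules.models.model_based import RSSMRollout


def compile_rssms(module):
    # apply() visits every submodule; only RSSMRollout instances are touched,
    # and the _compiled flag keeps repeated passes from re-wrapping them
    if isinstance(module, RSSMRollout) and not getattr(module, "_compiled", False):
        module._compiled = True
        module.rssm_prior.module = torch.compile(
            module.rssm_prior.module, backend="cudagraphs"
        )
        module.rssm_posterior.module = torch.compile(
            module.rssm_posterior.module, backend="cudagraphs"
        )


# usage, gated on the config flag introduced alongside this change:
# if cfg.optimization.compile:
#     world_model_loss.apply(compile_rssms)
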
torchrl/modules/models/model_based.py | 3 ++- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/sota-implementations/dreamer/config.yaml b/sota-implementations/dreamer/config.yaml index 0b2523fe44d..33e8ceffb8d 100644 --- a/sota-implementations/dreamer/config.yaml +++ b/sota-implementations/dreamer/config.yaml @@ -32,6 +32,7 @@ optimization: gamma: 0.99 lambda: 0.95 imagination_horizon: 15 + compile: True networks: exploration_noise: 0.3 diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index f81da7ae36f..c5bd87168b8 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -21,6 +21,7 @@ from torch.nn.utils import clip_grad_norm_ from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules.models.model_based import RSSMRollout +from torchrl._utils import logger as torchrl_logger from torchrl.objectives.dreamer import ( DreamerActorLoss, @@ -125,12 +126,13 @@ def main(cfg: "DictConfig"): # noqa: F821 eval_iter = cfg.logger.eval_iter eval_rollout_steps = cfg.logger.eval_rollout_steps - print('Compiling') - def compile_rssms(module): - if isinstance(module, RSSMRollout) and not getattr(module, "_compiled", False): - module._compiled = True - module.rssm_prior.module = torch.compile(module.rssm_prior.module, backend="cudagraphs") - module.rssm_posterior.module = torch.compile(module.rssm_posterior.module, backend="cudagraphs") + if cfg.optimization.compile: + torchrl_logger.info('Compiling') + def compile_rssms(module): + if isinstance(module, RSSMRollout) and not getattr(module, "_compiled", False): + module._compiled = True + module.rssm_prior.module = torch.compile(module.rssm_prior.module, backend="cudagraphs") + module.rssm_posterior.module = torch.compile(module.rssm_posterior.module, backend="cudagraphs") world_model_loss.apply(compile_rssms) t_collect_init = time.time() diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py index 074af0a31fb..6b3279017b1 100644 --- a/torchrl/modules/models/model_based.py +++ b/torchrl/modules/models/model_based.py @@ -10,6 +10,7 @@ from torch import nn from torchrl.envs.utils import step_mdp +from torchrl.modules import GRUCell from torchrl.modules.models.models import MLP from torchrl.modules.tensordict_module.sequence import SafeSequential @@ -288,7 +289,7 @@ def __init__( super().__init__() # Prior - self.rnn = nn.GRUCell(hidden_dim, rnn_hidden_dim) + self.rnn = GRUCell(hidden_dim, rnn_hidden_dim) self.action_state_projector = nn.Sequential(nn.LazyLinear(hidden_dim), nn.ELU()) self.rnn_to_prior_projector = nn.Sequential( nn.Linear(hidden_dim, hidden_dim), From e85070f84821ea3a0fe4703b8c19af6e1d603950 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 17:46:40 +0200 Subject: [PATCH 049/113] amend --- torchrl/modules/models/model_based.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py index 6b3279017b1..7a296052f37 100644 --- a/torchrl/modules/models/model_based.py +++ b/torchrl/modules/models/model_based.py @@ -10,7 +10,7 @@ from torch import nn from torchrl.envs.utils import step_mdp -from torchrl.modules import GRUCell +from torchrl.modules.tensordict_module.rnn import GRUCell from torchrl.modules.models.models import MLP from torchrl.modules.tensordict_module.sequence import SafeSequential From 6b373294de1727b2d5e69a17ecca05cf3d98dbb6 Mon Sep 17 00:00:00 2001 From: 
Vincent Moens Date: Mon, 8 Apr 2024 17:53:03 +0200 Subject: [PATCH 050/113] amend --- torchrl/modules/models/model_based.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py index 7a296052f37..2bc515f56e5 100644 --- a/torchrl/modules/models/model_based.py +++ b/torchrl/modules/models/model_based.py @@ -14,6 +14,7 @@ from torchrl.modules.models.models import MLP from torchrl.modules.tensordict_module.sequence import SafeSequential +UNSQUEEZE_RNN_INPUT = version.parse(torch.__version__) < version.parse("1.11") class DreamerActor(nn.Module): """Dreamer actor network. @@ -239,10 +240,12 @@ def forward(self, tensordict): for t in range(time_steps): # samples according to p(s_{t+1} | s_t, a_t, b_t) # ["state", "belief", "action"] -> [("next", "prior_mean"), ("next", "prior_std"), "_", ("next", "belief")] + print("_tensordict", _tensordict) self.rssm_prior(_tensordict) # samples according to p(s_{t+1} | s_t, a_t, o_{t+1}) = p(s_t | b_t, o_t) # [("next", "belief"), ("next", "encoded_latents")] -> [("next", "posterior_mean"), ("next", "posterior_std"), ("next", "state")] + print("_tensordict", _tensordict) self.rssm_posterior(_tensordict) tensordict_out.append(_tensordict) @@ -304,15 +307,12 @@ def __init__( self.state_dim = state_dim self.rnn_hidden_dim = rnn_hidden_dim self.action_shape = action_spec.shape - self._unsqueeze_rnn_input = version.parse(torch.__version__) < version.parse( - "1.11" - ) def forward(self, state, belief, action): projector_input = torch.cat([state, action], dim=-1) action_state = self.action_state_projector(projector_input) unsqueeze = False - if self._unsqueeze_rnn_input and action_state.ndimension() == 1: + if UNSQUEEZE_RNN_INPUT and action_state.ndimension() == 1: if belief is not None: belief = belief.unsqueeze(0) action_state = action_state.unsqueeze(0) From ac8f7449eb126cacb72211cfbeb0a8ef2e70f672 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 17:57:05 +0200 Subject: [PATCH 051/113] amend --- sota-implementations/dreamer/config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sota-implementations/dreamer/config.yaml b/sota-implementations/dreamer/config.yaml index 33e8ceffb8d..5abb903f2c4 100644 --- a/sota-implementations/dreamer/config.yaml +++ b/sota-implementations/dreamer/config.yaml @@ -20,7 +20,7 @@ collector: optimization: train_every: 1000 grad_clip: 100 - batch_size: 50 + batch_size: 2500 batch_length: 50 world_model_lr: 6e-4 From 9cc8445052da19bfc227070e4a44ceb66ffaddf8 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 17:58:05 +0200 Subject: [PATCH 052/113] amend --- sota-implementations/dreamer/config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sota-implementations/dreamer/config.yaml b/sota-implementations/dreamer/config.yaml index 5abb903f2c4..a101f1df294 100644 --- a/sota-implementations/dreamer/config.yaml +++ b/sota-implementations/dreamer/config.yaml @@ -13,7 +13,7 @@ env: collector: total_frames: 5_000_000 - init_random_frames: 1000 + init_random_frames: 3000 frames_per_batch: 1000 device: cuda:0 From 1c204f4915dd9d1971fa40485677320296c4236b Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 18:03:16 +0200 Subject: [PATCH 053/113] amend --- sota-implementations/dreamer/config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sota-implementations/dreamer/config.yaml b/sota-implementations/dreamer/config.yaml index 
a101f1df294..fa5f26088ae 100644 --- a/sota-implementations/dreamer/config.yaml +++ b/sota-implementations/dreamer/config.yaml @@ -32,7 +32,7 @@ optimization: gamma: 0.99 lambda: 0.95 imagination_horizon: 15 - compile: True + compile: False networks: exploration_noise: 0.3 From 0e4b1eedc1afe37b5de985e026c169c4145f2e26 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 18:05:03 +0200 Subject: [PATCH 054/113] amend --- sota-implementations/dreamer/dreamer.py | 6 +----- torchrl/modules/models/model_based.py | 6 ++---- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index c5bd87168b8..d52038ffbfa 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -83,10 +83,6 @@ def main(cfg: "DictConfig"): # noqa: F821 ) value_loss = DreamerValueLoss(value_model, discount_loss=True) - # world_model_loss = torch.compile(world_model_loss) - # actor_loss = torch.compile(actor_loss) - # value_loss = torch.compile(value_loss) - # Make collector collector = make_collector(cfg, train_env, policy) @@ -133,7 +129,7 @@ def compile_rssms(module): module._compiled = True module.rssm_prior.module = torch.compile(module.rssm_prior.module, backend="cudagraphs") module.rssm_posterior.module = torch.compile(module.rssm_posterior.module, backend="cudagraphs") - world_model_loss.apply(compile_rssms) + world_model_loss.apply(compile_rssms) t_collect_init = time.time() for i, tensordict in enumerate(collector): diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py index 2bc515f56e5..d36af265730 100644 --- a/torchrl/modules/models/model_based.py +++ b/torchrl/modules/models/model_based.py @@ -6,7 +6,7 @@ import torch from packaging import version -from tensordict.nn import NormalParamExtractor, TensorDictModule, TensorDictModuleBase +from tensordict.nn import NormalParamExtractor, TensorDictModule, TensorDictModuleBase, TensorDictSequential from torch import nn from torchrl.envs.utils import step_mdp @@ -208,7 +208,7 @@ class RSSMRollout(TensorDictModuleBase): def __init__(self, rssm_prior: TensorDictModule, rssm_posterior: TensorDictModule): super().__init__() - _module = SafeSequential(rssm_prior, rssm_posterior) + _module = TensorDictSequential(rssm_prior, rssm_posterior) self.in_keys = _module.in_keys self.out_keys = _module.out_keys self.rssm_prior = rssm_prior @@ -240,12 +240,10 @@ def forward(self, tensordict): for t in range(time_steps): # samples according to p(s_{t+1} | s_t, a_t, b_t) # ["state", "belief", "action"] -> [("next", "prior_mean"), ("next", "prior_std"), "_", ("next", "belief")] - print("_tensordict", _tensordict) self.rssm_prior(_tensordict) # samples according to p(s_{t+1} | s_t, a_t, o_{t+1}) = p(s_t | b_t, o_t) # [("next", "belief"), ("next", "encoded_latents")] -> [("next", "posterior_mean"), ("next", "posterior_std"), ("next", "state")] - print("_tensordict", _tensordict) self.rssm_posterior(_tensordict) tensordict_out.append(_tensordict) From b0c94968ec507710fc2f5ed481852f219e7a8785 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 18:09:56 +0200 Subject: [PATCH 055/113] amend --- sota-implementations/dreamer/config.yaml | 3 +- sota-implementations/dreamer/dreamer.py | 45 ++++++++++++++---------- 2 files changed, 29 insertions(+), 19 deletions(-) diff --git a/sota-implementations/dreamer/config.yaml b/sota-implementations/dreamer/config.yaml index fa5f26088ae..40f4b93e467 100644 --- 
a/sota-implementations/dreamer/config.yaml +++ b/sota-implementations/dreamer/config.yaml @@ -32,7 +32,8 @@ optimization: gamma: 0.99 lambda: 0.95 imagination_horizon: 15 - compile: False + compile: True + use_autocast: False networks: exploration_noise: 0.3 diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index d52038ffbfa..f1caa240561 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import contextlib import time import hydra @@ -111,9 +112,11 @@ def main(cfg: "DictConfig"): # noqa: F821 value_opt = torch.optim.Adam(value_model.parameters(), lr=cfg.optimization.value_lr) # Grad scaler for mixed precision training https://pytorch.org/docs/stable/amp.html - scaler1 = GradScaler() - scaler2 = GradScaler() - scaler3 = GradScaler() + use_autocast = cfg.optimization.use_autocast + if use_autocast: + scaler1 = GradScaler() + scaler2 = GradScaler() + scaler3 = GradScaler() init_random_frames = cfg.collector.init_random_frames batch_size = cfg.optimization.batch_size @@ -159,7 +162,7 @@ def compile_rssms(module): t_loss_model_init = time.time() # update world model - with autocast(dtype=torch.float16): + with autocast(dtype=torch.float16) if use_autocast else contextlib.nullcontext(): model_loss_td, sampled_tensordict = world_model_loss( sampled_tensordict ) @@ -170,37 +173,43 @@ def compile_rssms(module): ) world_model_opt.zero_grad() - scaler1.scale(loss_world_model).backward() - scaler1.unscale_(world_model_opt) + if use_autocast: + scaler1.scale(loss_world_model).backward() + scaler1.unscale_(world_model_opt) clip_grad_norm_(world_model.parameters(), grad_clip) - scaler1.step(world_model_opt) - scaler1.update() + if use_autocast: + scaler1.step(world_model_opt) + scaler1.update() t_loss_model += time.time() - t_loss_model_init # update actor network t_loss_actor_init = time.time() - with autocast(dtype=torch.float16): + with autocast(dtype=torch.float16) if use_autocast else contextlib.nullcontext(): actor_loss_td, sampled_tensordict = actor_loss(sampled_tensordict) actor_opt.zero_grad() - scaler2.scale(actor_loss_td["loss_actor"]).backward() - scaler2.unscale_(actor_opt) + if use_autocast: + scaler2.scale(actor_loss_td["loss_actor"]).backward() + scaler2.unscale_(actor_opt) clip_grad_norm_(actor_model.parameters(), grad_clip) - scaler2.step(actor_opt) - scaler2.update() + if use_autocast: + scaler2.step(actor_opt) + scaler2.update() t_loss_actor += time.time() - t_loss_actor_init # update value network t_loss_critic_init = time.time() - with autocast(dtype=torch.float16): + with autocast(dtype=torch.float16) if use_autocast else contextlib.nullcontext(): value_loss_td, sampled_tensordict = value_loss(sampled_tensordict) value_opt.zero_grad() - scaler3.scale(value_loss_td["loss_value"]).backward() - scaler3.unscale_(value_opt) + if use_autocast: + scaler3.scale(value_loss_td["loss_value"]).backward() + scaler3.unscale_(value_opt) clip_grad_norm_(value_model.parameters(), grad_clip) - scaler3.step(value_opt) - scaler3.update() + if use_autocast: + scaler3.step(value_opt) + scaler3.update() t_loss_critic += time.time() - t_loss_critic_init metrics_to_log = {"reward": ep_reward.mean().item()} From 1a4d97997476ba7bbc740e66c51851f6182d3b8d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 18:15:26 +0200 Subject: [PATCH 056/113] amend --- 
sota-implementations/dreamer/dreamer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index f1caa240561..a97b7ee1b28 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -21,6 +21,7 @@ from torch.cuda.amp import autocast, GradScaler from torch.nn.utils import clip_grad_norm_ from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.modules import WorldModelWrapper from torchrl.modules.models.model_based import RSSMRollout from torchrl._utils import logger as torchrl_logger @@ -128,10 +129,9 @@ def main(cfg: "DictConfig"): # noqa: F821 if cfg.optimization.compile: torchrl_logger.info('Compiling') def compile_rssms(module): - if isinstance(module, RSSMRollout) and not getattr(module, "_compiled", False): + if isinstance(module, WorldModelWrapper) and not getattr(module, "_compiled", False): module._compiled = True - module.rssm_prior.module = torch.compile(module.rssm_prior.module, backend="cudagraphs") - module.rssm_posterior.module = torch.compile(module.rssm_posterior.module, backend="cudagraphs") + module[0] = torch.compile(module[0], backend="cudagraphs") world_model_loss.apply(compile_rssms) t_collect_init = time.time() From de47d0c3dbb5bef188e551340fe4a8af158c5eb7 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 18:17:49 +0200 Subject: [PATCH 057/113] amend --- sota-implementations/dreamer/dreamer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index a97b7ee1b28..d2869562d38 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -21,7 +21,6 @@ from torch.cuda.amp import autocast, GradScaler from torch.nn.utils import clip_grad_norm_ from torchrl.envs.utils import ExplorationType, set_exploration_type -from torchrl.modules import WorldModelWrapper from torchrl.modules.models.model_based import RSSMRollout from torchrl._utils import logger as torchrl_logger @@ -129,9 +128,10 @@ def main(cfg: "DictConfig"): # noqa: F821 if cfg.optimization.compile: torchrl_logger.info('Compiling') def compile_rssms(module): - if isinstance(module, WorldModelWrapper) and not getattr(module, "_compiled", False): + if isinstance(module, RSSMRollout) and not getattr(module, "_compiled", False): module._compiled = True - module[0] = torch.compile(module[0], backend="cudagraphs") + module.rssm_prior.module = torch.compile(module.rssm_prior.module, backend="cudagraphs") + # module.rssm_posterior.module = torch.compile(module.rssm_posterior.module, backend="cudagraphs") world_model_loss.apply(compile_rssms) t_collect_init = time.time() From 19b04acab4ac4b78af506ce57aaf6e3a3c81aa5c Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 18:22:15 +0200 Subject: [PATCH 058/113] amend --- sota-implementations/dreamer/dreamer.py | 2 +- torchrl/modules/models/model_based.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index d2869562d38..f1caa240561 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -131,7 +131,7 @@ def compile_rssms(module): if isinstance(module, RSSMRollout) and not getattr(module, "_compiled", False): module._compiled = True module.rssm_prior.module = 
torch.compile(module.rssm_prior.module, backend="cudagraphs") - # module.rssm_posterior.module = torch.compile(module.rssm_posterior.module, backend="cudagraphs") + module.rssm_posterior.module = torch.compile(module.rssm_posterior.module, backend="cudagraphs") world_model_loss.apply(compile_rssms) t_collect_init = time.time() diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py index d36af265730..b9183641985 100644 --- a/torchrl/modules/models/model_based.py +++ b/torchrl/modules/models/model_based.py @@ -234,12 +234,13 @@ def forward(self, tensordict): """ tensordict_out = [] *batch, time_steps = tensordict.shape - _tensordict = tensordict[..., 0] - update_values = tensordict.exclude(*self.out_keys) + update_values = tensordict.exclude(*self.out_keys).unbind(-1) + _tensordict = update_values[0] for t in range(time_steps): # samples according to p(s_{t+1} | s_t, a_t, b_t) # ["state", "belief", "action"] -> [("next", "prior_mean"), ("next", "prior_std"), "_", ("next", "belief")] + print('t', t) self.rssm_prior(_tensordict) # samples according to p(s_{t+1} | s_t, a_t, o_{t+1}) = p(s_t | b_t, o_t) @@ -251,7 +252,7 @@ def forward(self, tensordict): _tensordict = step_mdp( _tensordict.select(*self.out_keys, strict=False), keep_other=False ) - _tensordict = update_values[..., t + 1].update(_tensordict) + _tensordict = update_values[t + 1].update(_tensordict) return torch.stack(tensordict_out, tensordict.ndim - 1) From 3d42352df06c810029aeeb0ec0b1865ae06d61c6 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 18:26:54 +0200 Subject: [PATCH 059/113] amend --- sota-implementations/dreamer/config.yaml | 4 ++-- sota-implementations/dreamer/dreamer.py | 1 + torchrl/modules/models/model_based.py | 1 - 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sota-implementations/dreamer/config.yaml b/sota-implementations/dreamer/config.yaml index 40f4b93e467..8e20fa002ee 100644 --- a/sota-implementations/dreamer/config.yaml +++ b/sota-implementations/dreamer/config.yaml @@ -32,8 +32,8 @@ optimization: gamma: 0.99 lambda: 0.95 imagination_horizon: 15 - compile: True - use_autocast: False + compile: False + use_autocast: True networks: exploration_noise: 0.3 diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index f1caa240561..41441fe97cc 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -171,6 +171,7 @@ def compile_rssms(module): + model_loss_td["loss_model_reco"] + model_loss_td["loss_model_reward"] ) + print(loss_world_model.dtype) world_model_opt.zero_grad() if use_autocast: diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py index b9183641985..b5d112a9a7b 100644 --- a/torchrl/modules/models/model_based.py +++ b/torchrl/modules/models/model_based.py @@ -240,7 +240,6 @@ def forward(self, tensordict): for t in range(time_steps): # samples according to p(s_{t+1} | s_t, a_t, b_t) # ["state", "belief", "action"] -> [("next", "prior_mean"), ("next", "prior_std"), "_", ("next", "belief")] - print('t', t) self.rssm_prior(_tensordict) # samples according to p(s_{t+1} | s_t, a_t, o_{t+1}) = p(s_t | b_t, o_t) From 3fd1dfa2aa49ae8a2acf4c68134722f462efb677 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 18:31:54 +0200 Subject: [PATCH 060/113] amend --- sota-implementations/dreamer/dreamer.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sota-implementations/dreamer/dreamer.py 
b/sota-implementations/dreamer/dreamer.py index 41441fe97cc..a97816ae51f 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -177,6 +177,8 @@ def compile_rssms(module): if use_autocast: scaler1.scale(loss_world_model).backward() scaler1.unscale_(world_model_opt) + else: + loss_world_model.backward() clip_grad_norm_(world_model.parameters(), grad_clip) if use_autocast: scaler1.step(world_model_opt) @@ -192,6 +194,8 @@ def compile_rssms(module): if use_autocast: scaler2.scale(actor_loss_td["loss_actor"]).backward() scaler2.unscale_(actor_opt) + else: + actor_loss_td["loss_actor"].backward() clip_grad_norm_(actor_model.parameters(), grad_clip) if use_autocast: scaler2.step(actor_opt) @@ -207,6 +211,8 @@ def compile_rssms(module): if use_autocast: scaler3.scale(value_loss_td["loss_value"]).backward() scaler3.unscale_(value_opt) + else: + value_loss_td["loss_value"].backward() clip_grad_norm_(value_model.parameters(), grad_clip) if use_autocast: scaler3.step(value_opt) From 31f8dbcf11160819c437fa91940465ecddd5caed Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 18:33:58 +0200 Subject: [PATCH 061/113] amend --- sota-implementations/dreamer/dreamer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index a97816ae51f..56605137d9a 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -162,7 +162,7 @@ def compile_rssms(module): t_loss_model_init = time.time() # update world model - with autocast(dtype=torch.float16) if use_autocast else contextlib.nullcontext(): + with torch.autocast(device_type="cuda", dtype=torch.float16) if use_autocast else contextlib.nullcontext(): model_loss_td, sampled_tensordict = world_model_loss( sampled_tensordict ) @@ -187,7 +187,7 @@ def compile_rssms(module): # update actor network t_loss_actor_init = time.time() - with autocast(dtype=torch.float16) if use_autocast else contextlib.nullcontext(): + with torch.autocast(device_type="cuda", dtype=torch.float16) if use_autocast else contextlib.nullcontext(): actor_loss_td, sampled_tensordict = actor_loss(sampled_tensordict) actor_opt.zero_grad() @@ -204,7 +204,7 @@ def compile_rssms(module): # update value network t_loss_critic_init = time.time() - with autocast(dtype=torch.float16) if use_autocast else contextlib.nullcontext(): + with torch.autocast(device_type="cuda", dtype=torch.float16) if use_autocast else contextlib.nullcontext(): value_loss_td, sampled_tensordict = value_loss(sampled_tensordict) value_opt.zero_grad() From f8229013a8e4e651df2c787ffca48ff477df0b4e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 18:40:52 +0200 Subject: [PATCH 062/113] amend --- sota-implementations/dreamer/dreamer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index 56605137d9a..f7564e407ae 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -18,7 +18,7 @@ ) # mixed precision training -from torch.cuda.amp import autocast, GradScaler +from torch.cuda.amp import GradScaler from torch.nn.utils import clip_grad_norm_ from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules.models.model_based import RSSMRollout @@ -36,7 +36,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # cfg = correct_for_frame_skip(cfg) - if 
torch.cuda.is_available() and cfg.networks.device == "": + if torch.cuda.is_available() and cfg.networks.device in (None, ""): device = torch.device("cuda:0") elif cfg.networks.device: device = torch.device(cfg.networks.device) @@ -162,7 +162,7 @@ def compile_rssms(module): t_loss_model_init = time.time() # update world model - with torch.autocast(device_type="cuda", dtype=torch.float16) if use_autocast else contextlib.nullcontext(): + with torch.autocast(device_type=device.type, dtype=torch.float16) if use_autocast else contextlib.nullcontext(): model_loss_td, sampled_tensordict = world_model_loss( sampled_tensordict ) @@ -187,7 +187,7 @@ def compile_rssms(module): # update actor network t_loss_actor_init = time.time() - with torch.autocast(device_type="cuda", dtype=torch.float16) if use_autocast else contextlib.nullcontext(): + with torch.autocast(device_type=device.type, dtype=torch.float16) if use_autocast else contextlib.nullcontext(): actor_loss_td, sampled_tensordict = actor_loss(sampled_tensordict) actor_opt.zero_grad() @@ -204,7 +204,7 @@ def compile_rssms(module): # update value network t_loss_critic_init = time.time() - with torch.autocast(device_type="cuda", dtype=torch.float16) if use_autocast else contextlib.nullcontext(): + with torch.autocast(device_type=device.type, dtype=torch.float16) if use_autocast else contextlib.nullcontext(): value_loss_td, sampled_tensordict = value_loss(sampled_tensordict) value_opt.zero_grad() From 18ea308de6541b5bc0d43926bd2389447faf4ade Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 18:55:59 +0200 Subject: [PATCH 063/113] amend --- sota-implementations/dreamer/dreamer.py | 35 +++++++++++++------ sota-implementations/dreamer/dreamer_utils.py | 2 -- torchrl/modules/models/model_based.py | 22 ++++++++---- torchrl/objectives/dreamer.py | 8 ++--- 4 files changed, 44 insertions(+), 23 deletions(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index f7564e407ae..5f1ca9c2f07 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -20,9 +20,9 @@ # mixed precision training from torch.cuda.amp import GradScaler from torch.nn.utils import clip_grad_norm_ +from torchrl._utils import logger as torchrl_logger from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules.models.model_based import RSSMRollout -from torchrl._utils import logger as torchrl_logger from torchrl.objectives.dreamer import ( DreamerActorLoss, @@ -90,7 +90,6 @@ def main(cfg: "DictConfig"): # noqa: F821 # Make replay buffer batch_length = cfg.optimization.batch_length replay_buffer = make_replay_buffer( - batch_size=cfg.replay_buffer.batch_size, batch_seq_len=batch_length, buffer_size=cfg.replay_buffer.buffer_size, buffer_scratch_dir=cfg.replay_buffer.scratch_dir, @@ -126,12 +125,20 @@ def main(cfg: "DictConfig"): # noqa: F821 eval_rollout_steps = cfg.logger.eval_rollout_steps if cfg.optimization.compile: - torchrl_logger.info('Compiling') + torchrl_logger.info("Compiling") + def compile_rssms(module): - if isinstance(module, RSSMRollout) and not getattr(module, "_compiled", False): + if isinstance(module, RSSMRollout) and not getattr( + module, "_compiled", False + ): module._compiled = True - module.rssm_prior.module = torch.compile(module.rssm_prior.module, backend="cudagraphs") - module.rssm_posterior.module = torch.compile(module.rssm_posterior.module, backend="cudagraphs") + module.rssm_prior.module = torch.compile( + 
module.rssm_prior.module, backend="cudagraphs" + ) + module.rssm_posterior.module = torch.compile( + module.rssm_posterior.module, backend="cudagraphs" + ) + world_model_loss.apply(compile_rssms) t_collect_init = time.time() @@ -162,7 +169,9 @@ def compile_rssms(module): t_loss_model_init = time.time() # update world model - with torch.autocast(device_type=device.type, dtype=torch.float16) if use_autocast else contextlib.nullcontext(): + with torch.autocast( + device_type=device.type, dtype=torch.float16 + ) if use_autocast else contextlib.nullcontext(): model_loss_td, sampled_tensordict = world_model_loss( sampled_tensordict ) @@ -171,7 +180,9 @@ def compile_rssms(module): + model_loss_td["loss_model_reco"] + model_loss_td["loss_model_reward"] ) - print(loss_world_model.dtype) + if use_autocast: + print(loss_world_model.dtype) + assert loss_world_model.dtype == torch.bfloat16 world_model_opt.zero_grad() if use_autocast: @@ -187,7 +198,9 @@ def compile_rssms(module): # update actor network t_loss_actor_init = time.time() - with torch.autocast(device_type=device.type, dtype=torch.float16) if use_autocast else contextlib.nullcontext(): + with torch.autocast( + device_type=device.type, dtype=torch.float16 + ) if use_autocast else contextlib.nullcontext(): actor_loss_td, sampled_tensordict = actor_loss(sampled_tensordict) actor_opt.zero_grad() @@ -204,7 +217,9 @@ def compile_rssms(module): # update value network t_loss_critic_init = time.time() - with torch.autocast(device_type=device.type, dtype=torch.float16) if use_autocast else contextlib.nullcontext(): + with torch.autocast( + device_type=device.type, dtype=torch.float16 + ) if use_autocast else contextlib.nullcontext(): value_loss_td, sampled_tensordict = value_loss(sampled_tensordict) value_opt.zero_grad() diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index 2ce807f23ba..46111e7fca2 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -293,7 +293,6 @@ def make_collector(cfg, train_env, actor_model_explore): def make_replay_buffer( - batch_size, *, batch_seq_len, buffer_size=1000000, @@ -348,7 +347,6 @@ def check_no_pixels(data): compile=True, ), transform=transforms, - batch_size=batch_size, ) return replay_buffer diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py index b5d112a9a7b..fef6cc8d431 100644 --- a/torchrl/modules/models/model_based.py +++ b/torchrl/modules/models/model_based.py @@ -6,16 +6,22 @@ import torch from packaging import version -from tensordict.nn import NormalParamExtractor, TensorDictModule, TensorDictModuleBase, TensorDictSequential +from tensordict.nn import ( + NormalParamExtractor, + TensorDictModule, + TensorDictModuleBase, + TensorDictSequential, +) from torch import nn from torchrl.envs.utils import step_mdp -from torchrl.modules.tensordict_module.rnn import GRUCell from torchrl.modules.models.models import MLP +from torchrl.modules.tensordict_module.rnn import GRUCell from torchrl.modules.tensordict_module.sequence import SafeSequential UNSQUEEZE_RNN_INPUT = version.parse(torch.__version__) < version.parse("1.11") + class DreamerActor(nn.Module): """Dreamer actor network. 
@@ -51,11 +57,11 @@ def __init__( ): super().__init__() self.backbone = MLP( - out_features=2 * out_features, - depth=depth, - num_cells=num_cells, - activation_class=activation_class, - ) + out_features=2 * out_features, + depth=depth, + num_cells=num_cells, + activation_class=activation_class, + ) self.backbone.append( NormalParamExtractor( scale_mapping=f"biased_softplus_{std_bias}_{std_min_val}", @@ -241,10 +247,12 @@ def forward(self, tensordict): # samples according to p(s_{t+1} | s_t, a_t, b_t) # ["state", "belief", "action"] -> [("next", "prior_mean"), ("next", "prior_std"), "_", ("next", "belief")] self.rssm_prior(_tensordict) + print("prior", _tensordict) # samples according to p(s_{t+1} | s_t, a_t, o_{t+1}) = p(s_t | b_t, o_t) # [("next", "belief"), ("next", "encoded_latents")] -> [("next", "posterior_mean"), ("next", "posterior_std"), ("next", "state")] self.rssm_posterior(_tensordict) + print("posterior", _tensordict) tensordict_out.append(_tensordict) if t < time_steps - 1: diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py index 485d5f48570..2f55669565a 100644 --- a/torchrl/objectives/dreamer.py +++ b/torchrl/objectives/dreamer.py @@ -112,6 +112,8 @@ def __init__( self.free_nats = free_nats self.delayed_clamp = delayed_clamp self.global_average = global_average + self.__dict__["decoder"] = self.world_model[0][-1] + self.__dict__["reward_model"] = self.world_model[1] def _forward_value_estimator_keys(self, **kwargs) -> None: pass @@ -131,14 +133,12 @@ def forward(self, tensordict: TensorDict) -> Tuple[TensorDict, TensorDict]: tensordict.get(("next", self.tensor_keys.posterior_std)), ) - decoder = self.world_model[0][-1] - dist = decoder.get_dist(tensordict) + dist = self.decoder.get_dist(tensordict) reco_loss = -dist.log_prob( tensordict.get(("next", self.tensor_keys.pixels)) ).mean() - reward_model = self.world_model[1] - dist = reward_model.get_dist(tensordict) + dist = self.reward_model.get_dist(tensordict) reward_loss = -dist.log_prob( tensordict.get(("next", self.tensor_keys.true_reward)) ).mean() From 9660f5159644751b8a7eb09d95b6a06e9c32397c Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 19:18:44 +0200 Subject: [PATCH 064/113] amend --- sota-implementations/dreamer/dreamer.py | 14 ++++---- sota-implementations/dreamer/dreamer_utils.py | 8 +++-- torchrl/modules/models/model_based.py | 2 -- torchrl/objectives/dreamer.py | 32 +++++++++++++++---- 4 files changed, 38 insertions(+), 18 deletions(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index 5f1ca9c2f07..0bf43a06a5d 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -88,8 +88,10 @@ def main(cfg: "DictConfig"): # noqa: F821 collector = make_collector(cfg, train_env, policy) # Make replay buffer + batch_size = cfg.optimization.batch_size batch_length = cfg.optimization.batch_length replay_buffer = make_replay_buffer( + batch_size=batch_size, batch_seq_len=batch_length, buffer_size=cfg.replay_buffer.buffer_size, buffer_scratch_dir=cfg.replay_buffer.scratch_dir, @@ -118,7 +120,6 @@ def main(cfg: "DictConfig"): # noqa: F821 scaler3 = GradScaler() init_random_frames = cfg.collector.init_random_frames - batch_size = cfg.optimization.batch_size optim_steps_per_batch = cfg.optimization.optim_steps_per_batch grad_clip = cfg.optimization.grad_clip eval_iter = cfg.logger.eval_iter @@ -162,7 +163,7 @@ def compile_rssms(module): for _ in range(optim_steps_per_batch): # sample from replay 
buffer t_sample_init = time.time() - sampled_tensordict = replay_buffer.sample(batch_size).reshape( + sampled_tensordict = replay_buffer.sample().reshape( -1, batch_length ) t_sample = time.time() - t_sample_init @@ -170,7 +171,7 @@ def compile_rssms(module): t_loss_model_init = time.time() # update world model with torch.autocast( - device_type=device.type, dtype=torch.float16 + device_type=device.type ) if use_autocast else contextlib.nullcontext(): model_loss_td, sampled_tensordict = world_model_loss( sampled_tensordict @@ -181,8 +182,7 @@ def compile_rssms(module): + model_loss_td["loss_model_reward"] ) if use_autocast: - print(loss_world_model.dtype) - assert loss_world_model.dtype == torch.bfloat16 + assert loss_world_model.dtype in (torch.bfloat16, torch.float16), loss_world_model.dtype world_model_opt.zero_grad() if use_autocast: @@ -199,7 +199,7 @@ def compile_rssms(module): # update actor network t_loss_actor_init = time.time() with torch.autocast( - device_type=device.type, dtype=torch.float16 + device_type=device.type ) if use_autocast else contextlib.nullcontext(): actor_loss_td, sampled_tensordict = actor_loss(sampled_tensordict) @@ -218,7 +218,7 @@ def compile_rssms(module): # update value network t_loss_critic_init = time.time() with torch.autocast( - device_type=device.type, dtype=torch.float16 + device_type=device.type ) if use_autocast else contextlib.nullcontext(): value_loss_td, sampled_tensordict = value_loss(sampled_tensordict) diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index 46111e7fca2..8be6fb22490 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -294,6 +294,7 @@ def make_collector(cfg, train_env, actor_model_explore): def make_replay_buffer( *, + batch_size, batch_seq_len, buffer_size=1000000, buffer_scratch_dir=None, @@ -347,6 +348,7 @@ def check_no_pixels(data): compile=True, ), transform=transforms, + batch_size=batch_size, ) return replay_buffer @@ -360,13 +362,13 @@ def _dreamer_make_value_model( num_cells=hidden_dim, activation_class=get_activation(activation), ) - value_model = SafeProbabilisticTensorDictSequential( - SafeModule( + value_model = ProbabilisticTensorDictSequential( + TensorDictModule( value_model, in_keys=["state", "belief"], out_keys=["loc"], ), - SafeProbabilisticModule( + ProbabilisticTensorDictModule( in_keys=["loc"], out_keys=[value_key], distribution_class=IndependentNormal, diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py index fef6cc8d431..0292bbad696 100644 --- a/torchrl/modules/models/model_based.py +++ b/torchrl/modules/models/model_based.py @@ -247,12 +247,10 @@ def forward(self, tensordict): # samples according to p(s_{t+1} | s_t, a_t, b_t) # ["state", "belief", "action"] -> [("next", "prior_mean"), ("next", "prior_std"), "_", ("next", "belief")] self.rssm_prior(_tensordict) - print("prior", _tensordict) # samples according to p(s_{t+1} | s_t, a_t, o_{t+1}) = p(s_t | b_t, o_t) # [("next", "belief"), ("next", "encoded_latents")] -> [("next", "posterior_mean"), ("next", "posterior_std"), ("next", "state")] self.rssm_posterior(_tensordict) - print("posterior", _tensordict) tensordict_out.append(_tensordict) if t < time_steps - 1: diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py index 2f55669565a..5d9a4cbb8f1 100644 --- a/torchrl/objectives/dreamer.py +++ b/torchrl/objectives/dreamer.py @@ -14,6 +14,7 @@ from 
torchrl.envs.model_based.dreamer import DreamerEnv from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp +from torchrl.modules import IndependentNormal from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( _GAMMA_LMBDA_DEPREC_ERROR, @@ -133,14 +134,30 @@ def forward(self, tensordict: TensorDict) -> Tuple[TensorDict, TensorDict]: tensordict.get(("next", self.tensor_keys.posterior_std)), ) - dist = self.decoder.get_dist(tensordict) - reco_loss = -dist.log_prob( - tensordict.get(("next", self.tensor_keys.pixels)) + dist: IndependentNormal = self.decoder.get_dist(tensordict) + # reco_loss = -dist.log_prob( + # tensordict.get(("next", self.tensor_keys.pixels)) + # ).mean() + x = tensordict.get(("next", self.tensor_keys.pixels)) + loc = dist.base_dist.loc + scale = dist.base_dist.scale + reco_loss = -self.normal_log_probability( + x, + loc, + scale ).mean() - dist = self.reward_model.get_dist(tensordict) - reward_loss = -dist.log_prob( - tensordict.get(("next", self.tensor_keys.true_reward)) + dist: IndependentNormal = self.reward_model.get_dist(tensordict) + # reward_loss = -dist.log_prob( + # tensordict.get(("next", self.tensor_keys.true_reward)) + # ).mean() + x = tensordict.get(("next", self.tensor_keys.true_reward)) + loc = dist.base_dist.loc + scale = dist.base_dist.scale + reward_loss = -self.normal_log_probability( + x, + loc, + scale ).mean() return ( @@ -154,6 +171,9 @@ def forward(self, tensordict: TensorDict) -> Tuple[TensorDict, TensorDict]: ), tensordict.detach(), ) + @staticmethod + def normal_log_probability(x, mean, std): + return -0.5 * ((x.to(mean.dtype) - mean) / std).pow(2) - std.log() # - 0.5 * math.log(2 * math.pi) def kl_loss( self, From 2ea1f90a136cb02a90eabd17d0faf6c08db9cbb5 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 19:20:40 +0200 Subject: [PATCH 065/113] amend --- sota-implementations/dreamer/dreamer.py | 9 +++++---- sota-implementations/dreamer/dreamer_utils.py | 2 -- torchrl/objectives/dreamer.py | 17 ++++++----------- 3 files changed, 11 insertions(+), 17 deletions(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index 0bf43a06a5d..03119b10b95 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -163,9 +163,7 @@ def compile_rssms(module): for _ in range(optim_steps_per_batch): # sample from replay buffer t_sample_init = time.time() - sampled_tensordict = replay_buffer.sample().reshape( - -1, batch_length - ) + sampled_tensordict = replay_buffer.sample().reshape(-1, batch_length) t_sample = time.time() - t_sample_init t_loss_model_init = time.time() @@ -182,7 +180,10 @@ def compile_rssms(module): + model_loss_td["loss_model_reward"] ) if use_autocast: - assert loss_world_model.dtype in (torch.bfloat16, torch.float16), loss_world_model.dtype + assert loss_world_model.dtype in ( + torch.bfloat16, + torch.float16, + ), loss_world_model.dtype world_model_opt.zero_grad() if use_autocast: diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index 8be6fb22490..b3b9072ec04 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -33,8 +33,6 @@ # ExcludeTransform, FrameSkipTransform, GrayScale, - ObservationNorm, - RandomCropTensorDict, Resize, RewardSum, ToTensorImage, diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py index 5d9a4cbb8f1..f5ba3d3b369 100644 --- 
a/torchrl/objectives/dreamer.py +++ b/torchrl/objectives/dreamer.py @@ -141,11 +141,7 @@ def forward(self, tensordict: TensorDict) -> Tuple[TensorDict, TensorDict]: x = tensordict.get(("next", self.tensor_keys.pixels)) loc = dist.base_dist.loc scale = dist.base_dist.scale - reco_loss = -self.normal_log_probability( - x, - loc, - scale - ).mean() + reco_loss = -self.normal_log_probability(x, loc, scale).mean() dist: IndependentNormal = self.reward_model.get_dist(tensordict) # reward_loss = -dist.log_prob( @@ -154,11 +150,7 @@ def forward(self, tensordict: TensorDict) -> Tuple[TensorDict, TensorDict]: x = tensordict.get(("next", self.tensor_keys.true_reward)) loc = dist.base_dist.loc scale = dist.base_dist.scale - reward_loss = -self.normal_log_probability( - x, - loc, - scale - ).mean() + reward_loss = -self.normal_log_probability(x, loc, scale).mean() return ( TensorDict( @@ -171,9 +163,12 @@ def forward(self, tensordict: TensorDict) -> Tuple[TensorDict, TensorDict]: ), tensordict.detach(), ) + @staticmethod def normal_log_probability(x, mean, std): - return -0.5 * ((x.to(mean.dtype) - mean) / std).pow(2) - std.log() # - 0.5 * math.log(2 * math.pi) + return ( + -0.5 * ((x.to(mean.dtype) - mean) / std).pow(2) - std.log() + ) # - 0.5 * math.log(2 * math.pi) def kl_loss( self, From de498498083bc40fb56c359564ebac9a3b687108 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 19:32:09 +0200 Subject: [PATCH 066/113] amend --- torchrl/modules/models/model_based.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py index 0292bbad696..0942f7730bf 100644 --- a/torchrl/modules/models/model_based.py +++ b/torchrl/modules/models/model_based.py @@ -17,7 +17,6 @@ from torchrl.envs.utils import step_mdp from torchrl.modules.models.models import MLP from torchrl.modules.tensordict_module.rnn import GRUCell -from torchrl.modules.tensordict_module.sequence import SafeSequential UNSQUEEZE_RNN_INPUT = version.parse(torch.__version__) < version.parse("1.11") From 15908b500d7efc3527c4761e56624909c7ff1992 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 19:32:35 +0200 Subject: [PATCH 067/113] amend --- sota-implementations/dreamer/dreamer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index 03119b10b95..788c3c12085 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -183,7 +183,7 @@ def compile_rssms(module): assert loss_world_model.dtype in ( torch.bfloat16, torch.float16, - ), loss_world_model.dtype + ), model_loss_td world_model_opt.zero_grad() if use_autocast: From 0739aa059e2aaddbb9ba8214bf4bb92b05f6d9fe Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 19:37:43 +0200 Subject: [PATCH 068/113] amend --- sota-implementations/dreamer/dreamer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index 788c3c12085..1d3c7427d46 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -169,7 +169,7 @@ def compile_rssms(module): t_loss_model_init = time.time() # update world model with torch.autocast( - device_type=device.type + device_type=device.type, dtype=torch.bfloat16, ) if use_autocast else contextlib.nullcontext(): model_loss_td, sampled_tensordict = world_model_loss( sampled_tensordict @@ 
-200,7 +200,7 @@ def compile_rssms(module): # update actor network t_loss_actor_init = time.time() with torch.autocast( - device_type=device.type + device_type=device.type, dtype=torch.bfloat16 ) if use_autocast else contextlib.nullcontext(): actor_loss_td, sampled_tensordict = actor_loss(sampled_tensordict) @@ -219,7 +219,7 @@ def compile_rssms(module): # update value network t_loss_critic_init = time.time() with torch.autocast( - device_type=device.type + device_type=device.type, dtype=torch.bfloat16 ) if use_autocast else contextlib.nullcontext(): value_loss_td, sampled_tensordict = value_loss(sampled_tensordict) From 934b7f860fe06c129b310dc723ca970ba60482cd Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 19:44:53 +0200 Subject: [PATCH 069/113] amend --- torchrl/modules/models/model_based.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py index 0942f7730bf..9de16c24e10 100644 --- a/torchrl/modules/models/model_based.py +++ b/torchrl/modules/models/model_based.py @@ -75,7 +75,7 @@ def forward(self, state, belief): class ObsEncoder(nn.Module): """Observation encoder network. - Takes an pixel observation and encodes it into a latent space. + Takes a pixel observation and encodes it into a latent space. Reference: https://arxiv.org/abs/1803.10122 @@ -246,10 +246,12 @@ def forward(self, tensordict): # samples according to p(s_{t+1} | s_t, a_t, b_t) # ["state", "belief", "action"] -> [("next", "prior_mean"), ("next", "prior_std"), "_", ("next", "belief")] self.rssm_prior(_tensordict) + print("prior", _tensordict) # samples according to p(s_{t+1} | s_t, a_t, o_{t+1}) = p(s_t | b_t, o_t) # [("next", "belief"), ("next", "encoded_latents")] -> [("next", "posterior_mean"), ("next", "posterior_std"), ("next", "state")] self.rssm_posterior(_tensordict) + print("posterior", _tensordict) tensordict_out.append(_tensordict) if t < time_steps - 1: From 0308fb439371099c26c9bc230cb1f4a340cec58f Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 20:45:47 +0200 Subject: [PATCH 070/113] amend --- torchrl/modules/models/model_based.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py index 9de16c24e10..d98e02902aa 100644 --- a/torchrl/modules/models/model_based.py +++ b/torchrl/modules/models/model_based.py @@ -63,7 +63,8 @@ def __init__( ) self.backbone.append( NormalParamExtractor( - scale_mapping=f"biased_softplus_{std_bias}_{std_min_val}", + # scale_mapping=f"biased_softplus_{std_bias}_{std_min_val}", + scale_mapping=f"exp", ), ) @@ -305,7 +306,7 @@ def __init__( nn.Linear(hidden_dim, 2 * state_dim), NormalParamExtractor( scale_lb=scale_lb, - scale_mapping="softplus", + scale_mapping="exp", ), ) @@ -357,7 +358,7 @@ def __init__(self, hidden_dim=200, state_dim=30, scale_lb=0.1): nn.Linear(hidden_dim, 2 * state_dim), NormalParamExtractor( scale_lb=scale_lb, - scale_mapping="softplus", + scale_mapping="exp", ), ) self.hidden_dim = hidden_dim From 8015029dc1637ceb9069736aa6af64dff2edbd1a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 21:09:16 +0200 Subject: [PATCH 071/113] amend --- torchrl/modules/models/model_based.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py index d98e02902aa..c3e0e6194c9 100644 --- a/torchrl/modules/models/model_based.py +++ 
b/torchrl/modules/models/model_based.py @@ -16,7 +16,8 @@ from torchrl.envs.utils import step_mdp from torchrl.modules.models.models import MLP -from torchrl.modules.tensordict_module.rnn import GRUCell +# from torchrl.modules.tensordict_module.rnn import GRUCell +from torch.nn import GRUCell UNSQUEEZE_RNN_INPUT = version.parse(torch.__version__) < version.parse("1.11") From b420c9551b9f72c57d4b67064424d903ebea3fdd Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 21:13:36 +0200 Subject: [PATCH 072/113] amend --- torchrl/modules/models/model_based.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py index c3e0e6194c9..41500da47e4 100644 --- a/torchrl/modules/models/model_based.py +++ b/torchrl/modules/models/model_based.py @@ -65,7 +65,7 @@ def __init__( self.backbone.append( NormalParamExtractor( # scale_mapping=f"biased_softplus_{std_bias}_{std_min_val}", - scale_mapping=f"exp", + scale_mapping=f"relu", ), ) @@ -307,7 +307,7 @@ def __init__( nn.Linear(hidden_dim, 2 * state_dim), NormalParamExtractor( scale_lb=scale_lb, - scale_mapping="exp", + scale_mapping="relu", ), ) @@ -359,7 +359,7 @@ def __init__(self, hidden_dim=200, state_dim=30, scale_lb=0.1): nn.Linear(hidden_dim, 2 * state_dim), NormalParamExtractor( scale_lb=scale_lb, - scale_mapping="exp", + scale_mapping="relu", ), ) self.hidden_dim = hidden_dim From f3c858ed85c7c36865e60d36abe8d4940bde7525 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 21:25:23 +0200 Subject: [PATCH 073/113] amend --- sota-implementations/dreamer/dreamer.py | 13 +++++++------ torchrl/modules/models/model_based.py | 7 +++---- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index 1d3c7427d46..b16c17c034f 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -169,7 +169,8 @@ def compile_rssms(module): t_loss_model_init = time.time() # update world model with torch.autocast( - device_type=device.type, dtype=torch.bfloat16, + device_type=device.type, + dtype=torch.bfloat16, ) if use_autocast else contextlib.nullcontext(): model_loss_td, sampled_tensordict = world_model_loss( sampled_tensordict @@ -179,11 +180,11 @@ def compile_rssms(module): + model_loss_td["loss_model_reco"] + model_loss_td["loss_model_reward"] ) - if use_autocast: - assert loss_world_model.dtype in ( - torch.bfloat16, - torch.float16, - ), model_loss_td + # if use_autocast: + # assert loss_world_model.dtype in ( + # torch.bfloat16, + # torch.float16, + # ), model_loss_td world_model_opt.zero_grad() if use_autocast: diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py index 41500da47e4..f4bae27ffef 100644 --- a/torchrl/modules/models/model_based.py +++ b/torchrl/modules/models/model_based.py @@ -14,11 +14,12 @@ ) from torch import nn -from torchrl.envs.utils import step_mdp -from torchrl.modules.models.models import MLP # from torchrl.modules.tensordict_module.rnn import GRUCell from torch.nn import GRUCell +from torchrl.envs.utils import step_mdp +from torchrl.modules.models.models import MLP + UNSQUEEZE_RNN_INPUT = version.parse(torch.__version__) < version.parse("1.11") @@ -248,12 +249,10 @@ def forward(self, tensordict): # samples according to p(s_{t+1} | s_t, a_t, b_t) # ["state", "belief", "action"] -> [("next", "prior_mean"), ("next", "prior_std"), "_", ("next", 
"belief")] self.rssm_prior(_tensordict) - print("prior", _tensordict) # samples according to p(s_{t+1} | s_t, a_t, o_{t+1}) = p(s_t | b_t, o_t) # [("next", "belief"), ("next", "encoded_latents")] -> [("next", "posterior_mean"), ("next", "posterior_std"), ("next", "state")] self.rssm_posterior(_tensordict) - print("posterior", _tensordict) tensordict_out.append(_tensordict) if t < time_steps - 1: From 4d221210515f246b279841f71a0a068236b59cc6 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 21:41:39 +0200 Subject: [PATCH 074/113] amend --- torchrl/modules/models/model_based.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py index f4bae27ffef..38d43cac9f7 100644 --- a/torchrl/modules/models/model_based.py +++ b/torchrl/modules/models/model_based.py @@ -65,8 +65,8 @@ def __init__( ) self.backbone.append( NormalParamExtractor( - # scale_mapping=f"biased_softplus_{std_bias}_{std_min_val}", - scale_mapping=f"relu", + scale_mapping=f"biased_softplus_{std_bias}_{std_min_val}", + # scale_mapping="relu", ), ) @@ -306,7 +306,7 @@ def __init__( nn.Linear(hidden_dim, 2 * state_dim), NormalParamExtractor( scale_lb=scale_lb, - scale_mapping="relu", + scale_mapping="softplus", ), ) @@ -358,7 +358,7 @@ def __init__(self, hidden_dim=200, state_dim=30, scale_lb=0.1): nn.Linear(hidden_dim, 2 * state_dim), NormalParamExtractor( scale_lb=scale_lb, - scale_mapping="relu", + scale_mapping="softplus", ), ) self.hidden_dim = hidden_dim From c4ccf0f9aef0b11d3c4e813db15fb777e012ca6e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 21:56:33 +0200 Subject: [PATCH 075/113] amend --- sota-implementations/dreamer/dreamer.py | 11 ++++++----- sota-implementations/dreamer/dreamer_utils.py | 5 ++++- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index b16c17c034f..2b3d2200a07 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -99,6 +99,7 @@ def main(cfg: "DictConfig"): # noqa: F821 pixel_obs=cfg.env.from_pixels, grayscale=cfg.env.grayscale, image_size=cfg.env.image_size, + use_autocast=cfg.optimization.use_autocast, ) # Training loop @@ -180,11 +181,11 @@ def compile_rssms(module): + model_loss_td["loss_model_reco"] + model_loss_td["loss_model_reward"] ) - # if use_autocast: - # assert loss_world_model.dtype in ( - # torch.bfloat16, - # torch.float16, - # ), model_loss_td + if use_autocast: + assert loss_world_model.dtype in ( + torch.bfloat16, + torch.float16, + ), model_loss_td world_model_opt.zero_grad() if use_autocast: diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index b3b9072ec04..84d3f67bb25 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -43,7 +43,7 @@ ExcludeTransform, RenameTransform, StepCounter, - TensorDictPrimer, + TensorDictPrimer, DTypeCastTransform, ) from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type from torchrl.modules import ( @@ -301,6 +301,7 @@ def make_replay_buffer( pixel_obs=True, grayscale=True, image_size, +use_autocast, ): with ( tempfile.TemporaryDirectory() @@ -328,6 +329,8 @@ def check_no_pixels(data): Resize(image_size, image_size, in_keys=["pixels", ("next", "pixels")]) ) transforms.append(DeviceCastTransform(device=device)) + if use_autocast: + 
transforms.append(DTypeCastTransform(dtype_in=torch.float32, dtype_out=torch.float16)) replay_buffer = TensorDictReplayBuffer( pin_memory=False, From a79f9b540459e1ee68f092d351ba0833282a26a0 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 22:05:22 +0200 Subject: [PATCH 076/113] amend --- sota-implementations/dreamer/dreamer_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index 84d3f67bb25..6bdac202c24 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -330,7 +330,7 @@ def check_no_pixels(data): ) transforms.append(DeviceCastTransform(device=device)) if use_autocast: - transforms.append(DTypeCastTransform(dtype_in=torch.float32, dtype_out=torch.float16)) + transforms.append(DTypeCastTransform(dtype_in=torch.float32, dtype_out=torch.bfloat16)) replay_buffer = TensorDictReplayBuffer( pin_memory=False, From 4b02cb717daf69112696b9c5b5cabb1ebd4bb3f4 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 22:07:07 +0200 Subject: [PATCH 077/113] amend --- torchrl/envs/transforms/transforms.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 511eef6c410..791a0615628 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -3494,10 +3494,11 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: "this functionality is not covered. Consider passing the in_keys " "or not passing any out_keys." ) - for in_key, item in list(tensordict.items(True, True)): + def func(item): if item.dtype == self.dtype_in: item = self._apply_transform(item) - tensordict.set(in_key, item) + return item + tensordict = tensordict._fast_apply(func) else: # we made sure that if in_keys is not None, out_keys is not None either for in_key, out_key in zip(in_keys, out_keys): From b0adb46afa98947c33fc6d7481b52264046f489a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Apr 2024 10:27:42 +0200 Subject: [PATCH 078/113] amend --- sota-implementations/dreamer/config.yaml | 2 +- sota-implementations/dreamer/dreamer.py | 13 +++++++++---- torchrl/_utils.py | 6 ++++++ torchrl/modules/models/model_based.py | 7 +++++-- torchrl/objectives/dreamer.py | 3 ++- 5 files changed, 23 insertions(+), 8 deletions(-) diff --git a/sota-implementations/dreamer/config.yaml b/sota-implementations/dreamer/config.yaml index 8e20fa002ee..138bf99d182 100644 --- a/sota-implementations/dreamer/config.yaml +++ b/sota-implementations/dreamer/config.yaml @@ -56,4 +56,4 @@ logger: mode: online # eval interval, in collection counts eval_iter: 10 - eval_rollout_steps: 1000 + eval_rollout_steps: 500 diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index 2b3d2200a07..c2f4bbf009f 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -9,6 +9,7 @@ import torch import torch.cuda import tqdm +from torchrl._utils import timeit from dreamer_utils import ( log_metrics, make_collector, @@ -152,7 +153,7 @@ def compile_rssms(module): current_frames = tensordict.numel() collected_frames += current_frames - ep_reward = tensordict.get("episode_reward")[:, -1] + ep_reward = tensordict.get("episode_reward")[..., -1, 0] replay_buffer.extend(tensordict.cpu()) t_preproc = time.time() - t_preproc_init @@ -193,7 +194,7 @@ 
def compile_rssms(module): scaler1.unscale_(world_model_opt) else: loss_world_model.backward() - clip_grad_norm_(world_model.parameters(), grad_clip) + world_model_grad = clip_grad_norm_(world_model.parameters(), grad_clip) if use_autocast: scaler1.step(world_model_opt) scaler1.update() @@ -212,7 +213,7 @@ def compile_rssms(module): scaler2.unscale_(actor_opt) else: actor_loss_td["loss_actor"].backward() - clip_grad_norm_(actor_model.parameters(), grad_clip) + actor_model_grad = clip_grad_norm_(actor_model.parameters(), grad_clip) if use_autocast: scaler2.step(actor_opt) scaler2.update() @@ -231,7 +232,7 @@ def compile_rssms(module): scaler3.unscale_(value_opt) else: value_loss_td["loss_value"].backward() - clip_grad_norm_(value_model.parameters(), grad_clip) + critic_model_grad = clip_grad_norm_(value_model.parameters(), grad_clip) if use_autocast: scaler3.step(value_opt) scaler3.update() @@ -245,12 +246,16 @@ def compile_rssms(module): "loss_model_reward": model_loss_td["loss_model_reward"].item(), "loss_actor": actor_loss_td["loss_actor"].item(), "loss_value": value_loss_td["loss_value"].item(), + "world_model_grad": world_model_grad, + "actor_model_grad": actor_model_grad, + "critic_model_grad": critic_model_grad, "t_loss_actor": t_loss_actor, "t_loss_critic": t_loss_critic, "t_loss_model": t_loss_model, "t_sample": t_sample, "t_preproc": t_preproc, "t_collect": t_collect, + **timeit.todict() } metrics_to_log.update(loss_metrics) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index a9109f97354..422245f8e72 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -94,6 +94,12 @@ def print(prefix=None): # noqa: T202 ) logger.info(" -- ".join(strings)) + @classmethod + def todict(cls, percall=True): + if percall: + return {key: val[0] for key, val in cls._REG.items()} + return {key: val[1] for key, val in cls._REG.items()} + @staticmethod def erase(): for k in timeit._REG: diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py index 38d43cac9f7..44867c4ae68 100644 --- a/torchrl/modules/models/model_based.py +++ b/torchrl/modules/models/model_based.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
import warnings +from torchrl._utils import timeit import torch from packaging import version @@ -248,11 +249,13 @@ def forward(self, tensordict): for t in range(time_steps): # samples according to p(s_{t+1} | s_t, a_t, b_t) # ["state", "belief", "action"] -> [("next", "prior_mean"), ("next", "prior_std"), "_", ("next", "belief")] - self.rssm_prior(_tensordict) + with timeit("rssm_rollout/rssm_prior"): + self.rssm_prior(_tensordict) # samples according to p(s_{t+1} | s_t, a_t, o_{t+1}) = p(s_t | b_t, o_t) # [("next", "belief"), ("next", "encoded_latents")] -> [("next", "posterior_mean"), ("next", "posterior_std"), ("next", "state")] - self.rssm_posterior(_tensordict) + with timeit("rssm_rollout/rssm_posterior"): + self.rssm_posterior(_tensordict) tensordict_out.append(_tensordict) if t < time_steps - 1: diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py index f5ba3d3b369..38b61d85ca9 100644 --- a/torchrl/objectives/dreamer.py +++ b/torchrl/objectives/dreamer.py @@ -12,6 +12,7 @@ from tensordict.nn import TensorDictModule from tensordict.utils import NestedKey +from torchrl._utils import timeit from torchrl.envs.model_based.dreamer import DreamerEnv from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp from torchrl.modules import IndependentNormal @@ -269,7 +270,7 @@ def forward(self, tensordict: TensorDict) -> Tuple[TensorDict, TensorDict]: tensordict = tensordict.reshape(-1) # TODO: do we need exploration here? - with hold_out_net(self.model_based_env), set_exploration_type( + with timeit("actor_loss/rollout"), hold_out_net(self.model_based_env), set_exploration_type( ExplorationType.MEAN ): # action_td = self.actor_model(td) From b8ec7b4fe36fcb25c785db04badf0f5a6c666900 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Apr 2024 10:33:46 +0200 Subject: [PATCH 079/113] amend --- sota-implementations/dreamer/dreamer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index c2f4bbf009f..718e7ae2942 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -255,8 +255,9 @@ def compile_rssms(module): "t_sample": t_sample, "t_preproc": t_preproc, "t_collect": t_collect, - **timeit.todict() + **timeit.todict(percall=False) } + timeit.erase() metrics_to_log.update(loss_metrics) if logger is not None: From b77f8056e804e984ffe603f185b930d6648d9afe Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Apr 2024 10:38:12 +0200 Subject: [PATCH 080/113] amend --- torchrl/objectives/dreamer.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py index 38b61d85ca9..1a0d3832e29 100644 --- a/torchrl/objectives/dreamer.py +++ b/torchrl/objectives/dreamer.py @@ -136,22 +136,22 @@ def forward(self, tensordict: TensorDict) -> Tuple[TensorDict, TensorDict]: ) dist: IndependentNormal = self.decoder.get_dist(tensordict) - # reco_loss = -dist.log_prob( - # tensordict.get(("next", self.tensor_keys.pixels)) - # ).mean() - x = tensordict.get(("next", self.tensor_keys.pixels)) - loc = dist.base_dist.loc - scale = dist.base_dist.scale - reco_loss = -self.normal_log_probability(x, loc, scale).mean() + reco_loss = -dist.log_prob( + tensordict.get(("next", self.tensor_keys.pixels)) + ).mean() + # x = tensordict.get(("next", self.tensor_keys.pixels)) + # loc = dist.base_dist.loc + # scale = dist.base_dist.scale + 
# reco_loss = -self.normal_log_probability(x, loc, scale).mean() dist: IndependentNormal = self.reward_model.get_dist(tensordict) - # reward_loss = -dist.log_prob( - # tensordict.get(("next", self.tensor_keys.true_reward)) - # ).mean() - x = tensordict.get(("next", self.tensor_keys.true_reward)) - loc = dist.base_dist.loc - scale = dist.base_dist.scale - reward_loss = -self.normal_log_probability(x, loc, scale).mean() + reward_loss = -dist.log_prob( + tensordict.get(("next", self.tensor_keys.true_reward)) + ).mean() + # x = tensordict.get(("next", self.tensor_keys.true_reward)) + # loc = dist.base_dist.loc + # scale = dist.base_dist.scale + # reward_loss = -self.normal_log_probability(x, loc, scale).mean() return ( TensorDict( @@ -183,8 +183,8 @@ def kl_loss( + (posterior_std**2 + (prior_mean - posterior_mean) ** 2) / (2 * prior_std**2) - 0.5 - ).mean() - return kl.clamp_min(self.free_nats) + ) + return kl.clamp_min(self.free_nats).sum(-1).mean() class DreamerActorLoss(LossModule): From 49b0d7e08af5f2b0976e03c218f6abe2efcf618e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Apr 2024 10:47:19 +0200 Subject: [PATCH 081/113] amend --- torchrl/modules/models/model_based.py | 4 ++-- torchrl/objectives/dreamer.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py index 44867c4ae68..9c671eabcab 100644 --- a/torchrl/modules/models/model_based.py +++ b/torchrl/modules/models/model_based.py @@ -249,12 +249,12 @@ def forward(self, tensordict): for t in range(time_steps): # samples according to p(s_{t+1} | s_t, a_t, b_t) # ["state", "belief", "action"] -> [("next", "prior_mean"), ("next", "prior_std"), "_", ("next", "belief")] - with timeit("rssm_rollout/rssm_prior"): + with timeit("rssm_rollout/time-rssm_prior"): self.rssm_prior(_tensordict) # samples according to p(s_{t+1} | s_t, a_t, o_{t+1}) = p(s_t | b_t, o_t) # [("next", "belief"), ("next", "encoded_latents")] -> [("next", "posterior_mean"), ("next", "posterior_std"), ("next", "state")] - with timeit("rssm_rollout/rssm_posterior"): + with timeit("rssm_rollout/time-rssm_posterior"): self.rssm_posterior(_tensordict) tensordict_out.append(_tensordict) diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py index 1a0d3832e29..e23fbf178fb 100644 --- a/torchrl/objectives/dreamer.py +++ b/torchrl/objectives/dreamer.py @@ -270,7 +270,7 @@ def forward(self, tensordict: TensorDict) -> Tuple[TensorDict, TensorDict]: tensordict = tensordict.reshape(-1) # TODO: do we need exploration here? 
- with timeit("actor_loss/rollout"), hold_out_net(self.model_based_env), set_exploration_type( + with timeit("actor_loss/time-rollout"), hold_out_net(self.model_based_env), set_exploration_type( ExplorationType.MEAN ): # action_td = self.actor_model(td) From 32aa92c17a4fbdb1638ae4277310e27a37fd6b1e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Apr 2024 12:36:45 +0200 Subject: [PATCH 082/113] amend --- sota-implementations/dreamer/dreamer.py | 15 ++++-- sota-implementations/dreamer/dreamer_utils.py | 9 ++-- torchrl/envs/transforms/transforms.py | 52 ++++++++++++++----- torchrl/modules/models/model_based.py | 2 +- torchrl/objectives/dreamer.py | 10 ++-- 5 files changed, 61 insertions(+), 27 deletions(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index 718e7ae2942..e8d355b8cb2 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -9,7 +9,6 @@ import torch import torch.cuda import tqdm -from torchrl._utils import timeit from dreamer_utils import ( log_metrics, make_collector, @@ -21,7 +20,7 @@ # mixed precision training from torch.cuda.amp import GradScaler from torch.nn.utils import clip_grad_norm_ -from torchrl._utils import logger as torchrl_logger +from torchrl._utils import logger as torchrl_logger, timeit from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules.models.model_based import RSSMRollout @@ -168,6 +167,10 @@ def compile_rssms(module): sampled_tensordict = replay_buffer.sample().reshape(-1, batch_length) t_sample = time.time() - t_sample_init + # print("sampled_tensordict", sampled_tensordict) + # print("steps", sampled_tensordict["next", "steps"]) + # print("traj_ids", sampled_tensordict["collector", "traj_ids"]) + t_loss_model_init = time.time() # update world model with torch.autocast( @@ -198,6 +201,8 @@ def compile_rssms(module): if use_autocast: scaler1.step(world_model_opt) scaler1.update() + else: + world_model_opt.step() t_loss_model += time.time() - t_loss_model_init # update actor network @@ -217,6 +222,8 @@ def compile_rssms(module): if use_autocast: scaler2.step(actor_opt) scaler2.update() + else: + actor_opt.step() t_loss_actor += time.time() - t_loss_actor_init # update value network @@ -236,6 +243,8 @@ def compile_rssms(module): if use_autocast: scaler3.step(value_opt) scaler3.update() + else: + value_opt.step() t_loss_critic += time.time() - t_loss_critic_init metrics_to_log = {"reward": ep_reward.mean().item()} @@ -255,7 +264,7 @@ def compile_rssms(module): "t_sample": t_sample, "t_preproc": t_preproc, "t_collect": t_collect, - **timeit.todict(percall=False) + **timeit.todict(percall=False), } timeit.erase() metrics_to_log.update(loss_metrics) diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index 6bdac202c24..e82ec806610 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -40,10 +40,11 @@ ) from torchrl.envs.transforms.transforms import ( DeviceCastTransform, + DTypeCastTransform, ExcludeTransform, RenameTransform, StepCounter, - TensorDictPrimer, DTypeCastTransform, + TensorDictPrimer, ) from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type from torchrl.modules import ( @@ -301,7 +302,7 @@ def make_replay_buffer( pixel_obs=True, grayscale=True, image_size, -use_autocast, + use_autocast, ): with ( tempfile.TemporaryDirectory() @@ -330,7 +331,9 @@ def 
check_no_pixels(data): ) transforms.append(DeviceCastTransform(device=device)) if use_autocast: - transforms.append(DTypeCastTransform(dtype_in=torch.float32, dtype_out=torch.bfloat16)) + transforms.append( + DTypeCastTransform(dtype_in=torch.float32, dtype_out=torch.bfloat16) + ) replay_buffer = TensorDictReplayBuffer( pin_memory=False, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 791a0615628..af3d54035c8 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -3494,10 +3494,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: "this functionality is not covered. Consider passing the in_keys " "or not passing any out_keys." ) + def func(item): if item.dtype == self.dtype_in: item = self._apply_transform(item) return item + tensordict = tensordict._fast_apply(func) else: # we made sure that if in_keys is not None, out_keys is not None either @@ -4432,7 +4434,7 @@ class TensorDictPrimer(Transform): random (bool, optional): if ``True``, the values will be drawn randomly from the TensorSpec domain (or a unit Gaussian if unbounded). Otherwise a fixed value will be assumed. Defaults to `False`. - default_value (float, optional): if non-random filling is chosen, this + default_value (float, Dict[NestedKey, float], Dict[NestedKey, Callable], optional): if non-random filling is chosen, this value will be used to populate the tensors. Defaults to `0.0`. reset_key (NestedKey, optional): the reset key to be used as partial reset indicator. Must be unique. If not provided, defaults to the @@ -4490,8 +4492,8 @@ class TensorDictPrimer(Transform): def __init__( self, primers: dict | CompositeSpec = None, - random: bool = False, - default_value: float = 0.0, + random: bool | None = None, + default_value: float | Dict[NestedKey, float] | Dict[NestedKey, Callable] = 0.0, reset_key: NestedKey | None = None, **kwargs, ): @@ -4506,8 +4508,17 @@ def __init__( if not isinstance(kwargs, CompositeSpec): kwargs = CompositeSpec(kwargs) self.primers = kwargs + if ( + isinstance(default_value, dict) + and check_callable(...) + and random is not None + ): + raise ValueError self.random = random + if not isinstance(self.default_value, dict): + default_value = {key: default_value for key in primers.keys(True, True)} self.default_value = default_value + self._validated = False self.reset_key = reset_key # sanity check @@ -4560,6 +4571,9 @@ def to(self, *args, **kwargs): self.primers = self.primers.to(device) return super().to(*args, **kwargs) + def _maybe_expand_shape(self, spec): + return spec.expand((*self.parent.batch_size, *spec.shape)) + def transform_observation_spec( self, observation_spec: CompositeSpec ) -> CompositeSpec: @@ -4569,10 +4583,13 @@ def transform_observation_spec( ) for key, spec in self.primers.items(): if spec.shape[: len(observation_spec.shape)] != observation_spec.shape: - raise RuntimeError( - f"The leading shape of the primer specs ({self.__class__}) should match the one of the parent env. " - f"Got observation_spec.shape={observation_spec.shape} but the '{key}' entry's shape is {spec.shape}." - ) + try: + spec = self._maybe_expand_shape(spec) + except Smth: + raise RuntimeError( + f"The leading shape of the primer specs ({self.__class__}) should match the one of the parent env. " + f"Got observation_spec.shape={observation_spec.shape} but the '{key}' entry's shape is {spec.shape}." 
+ ) try: device = observation_spec.device except RuntimeError: @@ -4591,7 +4608,7 @@ def _batch_size(self): return self.parent.batch_size def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - for key, spec in self.primers.items(): + for key, spec in self.primers.items(True, True): if spec.shape[: len(tensordict.shape)] != tensordict.shape: raise RuntimeError( "The leading shape of the spec must match the tensordict's, " @@ -4602,10 +4619,17 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: if self.random: value = spec.rand() else: - value = torch.full_like( - spec.zero(), - self.default_value, - ) + if callable(self.default_value[key]): + value = self.default_value[key]() + # validate the value + if not self._validated: + self.validate(value) + self._validated = True + else: + value = torch.full_like( + spec.zero(), + self.default_value[key], + ) tensordict.set(key, value) return tensordict @@ -4635,13 +4659,13 @@ def _reset( ) _reset = _get_reset(self.reset_key, tensordict) if _reset.any(): - for key, spec in self.primers.items(): + for key, spec in self.primers.items(True, True): if self.random: value = spec.rand(shape) else: value = torch.full_like( spec.zero(shape), - self.default_value, + self.default_value[key], ) prev_val = tensordict.get(key, 0.0) value = torch.where(expand_as_right(_reset, value), value, prev_val) diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py index 9c671eabcab..94e5ce18ae5 100644 --- a/torchrl/modules/models/model_based.py +++ b/torchrl/modules/models/model_based.py @@ -3,7 +3,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import warnings -from torchrl._utils import timeit import torch from packaging import version @@ -17,6 +16,7 @@ # from torchrl.modules.tensordict_module.rnn import GRUCell from torch.nn import GRUCell +from torchrl._utils import timeit from torchrl.envs.utils import step_mdp from torchrl.modules.models.models import MLP diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py index e23fbf178fb..eea561f21c3 100644 --- a/torchrl/objectives/dreamer.py +++ b/torchrl/objectives/dreamer.py @@ -248,7 +248,7 @@ def __init__( ): super().__init__() self.actor_model = actor_model - self.value_model = value_model + self.__dict__["value_model"] = value_model self.model_based_env = model_based_env self.imagination_horizon = imagination_horizon self.discount_loss = discount_loss @@ -270,11 +270,9 @@ def forward(self, tensordict: TensorDict) -> Tuple[TensorDict, TensorDict]: tensordict = tensordict.reshape(-1) # TODO: do we need exploration here? - with timeit("actor_loss/time-rollout"), hold_out_net(self.model_based_env), set_exploration_type( - ExplorationType.MEAN - ): - # action_td = self.actor_model(td) - + with timeit("actor_loss/time-rollout"), hold_out_net( + self.model_based_env + ), set_exploration_type(ExplorationType.RANDOM): # TODO: we are not using the actual batch beliefs as starting ones - should be solved! 
took of the primer for the mb_env tensordict = self.model_based_env.reset(tensordict.copy()) # TODO: do we detach state gradients when passing again for new actions: action = self.actor(state.detach()) From 0d4ae71b3607904a23d1b42d31222343e48ec659 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Apr 2024 12:45:48 +0200 Subject: [PATCH 083/113] amend --- torchrl/envs/transforms/transforms.py | 52 ++++++++------------------- 1 file changed, 14 insertions(+), 38 deletions(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index af3d54035c8..791a0615628 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -3494,12 +3494,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: "this functionality is not covered. Consider passing the in_keys " "or not passing any out_keys." ) - def func(item): if item.dtype == self.dtype_in: item = self._apply_transform(item) return item - tensordict = tensordict._fast_apply(func) else: # we made sure that if in_keys is not None, out_keys is not None either @@ -4434,7 +4432,7 @@ class TensorDictPrimer(Transform): random (bool, optional): if ``True``, the values will be drawn randomly from the TensorSpec domain (or a unit Gaussian if unbounded). Otherwise a fixed value will be assumed. Defaults to `False`. - default_value (float, Dict[NestedKey, float], Dict[NestedKey, Callable], optional): if non-random filling is chosen, this + default_value (float, optional): if non-random filling is chosen, this value will be used to populate the tensors. Defaults to `0.0`. reset_key (NestedKey, optional): the reset key to be used as partial reset indicator. Must be unique. If not provided, defaults to the @@ -4492,8 +4490,8 @@ class TensorDictPrimer(Transform): def __init__( self, primers: dict | CompositeSpec = None, - random: bool | None = None, - default_value: float | Dict[NestedKey, float] | Dict[NestedKey, Callable] = 0.0, + random: bool = False, + default_value: float = 0.0, reset_key: NestedKey | None = None, **kwargs, ): @@ -4508,17 +4506,8 @@ def __init__( if not isinstance(kwargs, CompositeSpec): kwargs = CompositeSpec(kwargs) self.primers = kwargs - if ( - isinstance(default_value, dict) - and check_callable(...) - and random is not None - ): - raise ValueError self.random = random - if not isinstance(self.default_value, dict): - default_value = {key: default_value for key in primers.keys(True, True)} self.default_value = default_value - self._validated = False self.reset_key = reset_key # sanity check @@ -4571,9 +4560,6 @@ def to(self, *args, **kwargs): self.primers = self.primers.to(device) return super().to(*args, **kwargs) - def _maybe_expand_shape(self, spec): - return spec.expand((*self.parent.batch_size, *spec.shape)) - def transform_observation_spec( self, observation_spec: CompositeSpec ) -> CompositeSpec: @@ -4583,13 +4569,10 @@ def transform_observation_spec( ) for key, spec in self.primers.items(): if spec.shape[: len(observation_spec.shape)] != observation_spec.shape: - try: - spec = self._maybe_expand_shape(spec) - except Smth: - raise RuntimeError( - f"The leading shape of the primer specs ({self.__class__}) should match the one of the parent env. " - f"Got observation_spec.shape={observation_spec.shape} but the '{key}' entry's shape is {spec.shape}." - ) + raise RuntimeError( + f"The leading shape of the primer specs ({self.__class__}) should match the one of the parent env. 
" + f"Got observation_spec.shape={observation_spec.shape} but the '{key}' entry's shape is {spec.shape}." + ) try: device = observation_spec.device except RuntimeError: @@ -4608,7 +4591,7 @@ def _batch_size(self): return self.parent.batch_size def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - for key, spec in self.primers.items(True, True): + for key, spec in self.primers.items(): if spec.shape[: len(tensordict.shape)] != tensordict.shape: raise RuntimeError( "The leading shape of the spec must match the tensordict's, " @@ -4619,17 +4602,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: if self.random: value = spec.rand() else: - if callable(self.default_value[key]): - value = self.default_value[key]() - # validate the value - if not self._validated: - self.validate(value) - self._validated = True - else: - value = torch.full_like( - spec.zero(), - self.default_value[key], - ) + value = torch.full_like( + spec.zero(), + self.default_value, + ) tensordict.set(key, value) return tensordict @@ -4659,13 +4635,13 @@ def _reset( ) _reset = _get_reset(self.reset_key, tensordict) if _reset.any(): - for key, spec in self.primers.items(True, True): + for key, spec in self.primers.items(): if self.random: value = spec.rand(shape) else: value = torch.full_like( spec.zero(shape), - self.default_value[key], + self.default_value, ) prev_val = tensordict.get(key, 0.0) value = torch.where(expand_as_right(_reset, value), value, prev_val) From 6124293a30ac6c7aa8312004cd9413d9e68a979c Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Apr 2024 15:07:28 +0200 Subject: [PATCH 084/113] amend --- examples/dreamer/utils.py | 637 ------------------ sota-implementations/dreamer/dreamer_utils.py | 12 +- 2 files changed, 6 insertions(+), 643 deletions(-) delete mode 100644 examples/dreamer/utils.py diff --git a/examples/dreamer/utils.py b/examples/dreamer/utils.py deleted file mode 100644 index 73866cb3a06..00000000000 --- a/examples/dreamer/utils.py +++ /dev/null @@ -1,637 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. 
-import tempfile -from contextlib import nullcontext - -import torch - -import torch.nn as nn -from tensordict.nn import InteractionType -from torchrl.collectors import SyncDataCollector -from torchrl.data import TensorDictReplayBuffer -from torchrl.data.replay_buffers.storages import LazyMemmapStorage - -from torchrl.data.tensor_specs import CompositeSpec, UnboundedContinuousTensorSpec -from torchrl.envs import ParallelEnv - -from torchrl.envs.env_creator import EnvCreator -from torchrl.envs.libs.dm_control import DMControlEnv -from torchrl.envs.libs.gym import GymEnv, set_gym_backend -from torchrl.envs.model_based.dreamer import DreamerEnv -from torchrl.envs.transforms import ( - Compose, - DoubleToFloat, - # ExcludeTransform, - FrameSkipTransform, - GrayScale, - ObservationNorm, - RandomCropTensorDict, - Resize, - RewardSum, - ToTensorImage, - TransformedEnv, -) -from torchrl.envs.transforms.transforms import TensorDictPrimer -from torchrl.envs.utils import ExplorationType, set_exploration_type -from torchrl.modules import ( - MLP, - SafeModule, - SafeProbabilisticModule, - SafeProbabilisticTensorDictSequential, - SafeSequential, -) -from torchrl.modules.distributions import IndependentNormal, TanhNormal -from torchrl.modules.models.model_based import ( - DreamerActor, - ObsDecoder, - ObsEncoder, - RSSMPosterior, - RSSMPrior, - RSSMRollout, -) -from torchrl.modules.tensordict_module.exploration import AdditiveGaussianWrapper -from torchrl.modules.tensordict_module.world_models import WorldModelWrapper - - -def _make_env(cfg, device): - lib = cfg.env.backend - if lib in ("gym", "gymnasium"): - with set_gym_backend(lib): - return GymEnv( - cfg.env.name, - device=device, - ) - elif lib == "dm_control": - env = DMControlEnv(cfg.env.name, cfg.env.task, from_pixels=cfg.env.from_pixels) - return env - else: - raise NotImplementedError(f"Unknown lib {lib}.") - - -def transform_env(cfg, env, parallel_envs, dummy=False): - env = TransformedEnv(env) - if cfg.env.from_pixels: - # transforms pixel from 0-255 to 0-1 (uint8 to float32) - env.append_transform(ToTensorImage(from_int=True)) - if cfg.env.grayscale: - env.append_transform(GrayScale()) - - img_size = cfg.env.image_size - env.append_transform(Resize(img_size, img_size)) - - env.append_transform(DoubleToFloat()) - env.append_transform(RewardSum()) - env.append_transform(FrameSkipTransform(cfg.env.frame_skip)) - if dummy: - default_dict = { - "state": UnboundedContinuousTensorSpec(shape=(cfg.networks.state_dim)), - "belief": UnboundedContinuousTensorSpec( - shape=(cfg.networks.rssm_hidden_dim) - ), - } - else: - default_dict = { - "state": UnboundedContinuousTensorSpec( - shape=(parallel_envs, cfg.networks.state_dim) - ), - "belief": UnboundedContinuousTensorSpec( - shape=(parallel_envs, cfg.networks.rssm_hidden_dim) - ), - } - env.append_transform( - TensorDictPrimer(random=False, default_value=0, **default_dict) - ) - - return env - - -def make_environments(cfg, device, parallel_envs=1): - """Make environments for training and evaluation.""" - train_env = ParallelEnv( - parallel_envs, - EnvCreator(lambda cfg=cfg: _make_env(cfg, device=device)), - ) - train_env = transform_env(cfg, train_env, parallel_envs) - train_env.set_seed(cfg.env.seed) - eval_env = ParallelEnv( - parallel_envs, - EnvCreator(lambda cfg=cfg: _make_env(cfg, device=device)), - ) - eval_env = transform_env(cfg, eval_env, parallel_envs) - eval_env.set_seed(cfg.env.seed + 1) - - return train_env, eval_env - - -def make_dreamer( - config, - device, - action_key: str = "action", 
- value_key: str = "state_value", - use_decoder_in_env: bool = False, -): - test_env = _make_env(config, device="cpu") - test_env = transform_env(config, test_env, parallel_envs=1, dummy=True) - # Make encoder and decoder - if config.env.from_pixels: - encoder = ObsEncoder() - decoder = ObsDecoder() - observation_in_key = "pixels" - obsevation_out_key = "reco_pixels" - else: - encoder = MLP( - out_features=1024, - depth=2, - num_cells=config.networks.hidden_dim, - activation_class=get_activation(config.networks.activation), - ) - decoder = MLP( - out_features=test_env.observation_spec["observation"].shape[-1], - depth=2, - num_cells=config.networks.hidden_dim, - activation_class=get_activation(config.networks.activation), - ) - # if config.env.backend == "dm_control": - # observation_in_key = ("position", "velocity") - # obsevation_out_key = "reco_observation" - # else: - observation_in_key = "observation" - obsevation_out_key = "reco_observation" - - # Make RSSM - rssm_prior = RSSMPrior( - hidden_dim=config.networks.rssm_hidden_dim, - rnn_hidden_dim=config.networks.rssm_hidden_dim, - state_dim=config.networks.state_dim, - action_spec=test_env.action_spec, - ) - rssm_posterior = RSSMPosterior( - hidden_dim=config.networks.rssm_hidden_dim, state_dim=config.networks.state_dim - ) - # Make reward module - reward_module = MLP( - out_features=1, - depth=2, - num_cells=config.networks.hidden_dim, - activation_class=get_activation(config.networks.activation), - ) - - # Make combined world model - world_model = _dreamer_make_world_model( - encoder, - decoder, - rssm_prior, - rssm_posterior, - reward_module, - observation_in_key=observation_in_key, - observation_out_key=obsevation_out_key, - ) - world_model.to(device) - - # Initialize world model - with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): - tensordict = ( - test_env.rollout(5, auto_cast_to_device=True) - .unsqueeze(-1) - .to(world_model.device) - ) - tensordict = tensordict.to_tensordict() - world_model(tensordict) - - # Create model-based environment - model_based_env = _dreamer_make_mbenv( - reward_module=reward_module, - rssm_prior=rssm_prior, - decoder=decoder, - observation_out_key=obsevation_out_key, - test_env=test_env, - use_decoder_in_env=use_decoder_in_env, - state_dim=config.networks.state_dim, - rssm_hidden_dim=config.networks.rssm_hidden_dim, - ) - - # Make actor - actor_simulator, actor_realworld = _dreamer_make_actors( - encoder=encoder, - observation_in_key=observation_in_key, - rssm_prior=rssm_prior, - rssm_posterior=rssm_posterior, - mlp_num_units=config.networks.hidden_dim, - activation=get_activation(config.networks.activation), - action_key=action_key, - test_env=test_env, - ) - # Exploration noise to be added to the actor_realworld - actor_realworld = AdditiveGaussianWrapper( - actor_realworld, - sigma_init=1.0, - sigma_end=1.0, - annealing_num_steps=1, - mean=0.0, - std=config.networks.exploration_noise, - ) - - # Make Critic - value_model = _dreamer_make_value_model( - hidden_dim=config.networks.hidden_dim, - activation=config.networks.activation, - value_key=value_key, - ) - - actor_simulator.to(device) - value_model.to(device) - actor_realworld.to(device) - model_based_env.to(device) - - # Initialize model-based environment, actor and critic - with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): - tensordict = ( - model_based_env.fake_tensordict().unsqueeze(-1).to(value_model.device) - ) - tensordict = tensordict - tensordict = actor_simulator(tensordict) - value_model(tensordict) 
- - return world_model, model_based_env, actor_simulator, value_model, actor_realworld - - -def make_collector(cfg, train_env, actor_model_explore): - """Make collector.""" - collector = SyncDataCollector( - train_env, - actor_model_explore, - init_random_frames=cfg.collector.init_random_frames, - frames_per_batch=cfg.collector.frames_per_batch, - total_frames=cfg.collector.total_frames, - device=cfg.collector.device, - reset_at_each_iter=True, - # postproc=ExcludeTransform( - # "belief", "state", ("next", "belief"), ("next", "state"), "encoded_latents" - # ), - ) - collector.set_seed(cfg.env.seed) - - return collector - - -def make_replay_buffer( - batch_size, - batch_seq_len, - buffer_size=1000000, - buffer_scratch_dir=None, - device="cpu", - prefetch=3, - pixel_obs=True, - cast_to_uint8=True, -): - with ( - tempfile.TemporaryDirectory() - if buffer_scratch_dir is None - else nullcontext(buffer_scratch_dir) - ) as scratch_dir: - transforms = [] - crop_seq = RandomCropTensorDict(sub_seq_len=batch_seq_len, sample_dim=-1) - transforms.append(crop_seq) - - if pixel_obs and cast_to_uint8: - # from 0-255 to 0-1 - norm_obs = ObservationNorm( - loc=0, - scale=255, - standard_normal=True, - in_keys=["pixels", ("next", "pixels")], - ) - transforms.append(norm_obs) - - transforms = Compose(*transforms) - - replay_buffer = TensorDictReplayBuffer( - pin_memory=False, - prefetch=prefetch, - storage=LazyMemmapStorage( - buffer_size, - scratch_dir=scratch_dir, - device=device, - ), - transform=transforms, - batch_size=batch_size, - ) - return replay_buffer - - -def _dreamer_make_value_model( - hidden_dim: int = 400, activation: str = "elu", value_key: str = "state_value" -): - value_model = MLP( - out_features=1, - depth=3, - num_cells=hidden_dim, - activation_class=get_activation(activation), - ) - value_model = SafeProbabilisticTensorDictSequential( - SafeModule( - value_model, - in_keys=["state", "belief"], - out_keys=["loc"], - ), - SafeProbabilisticModule( - in_keys=["loc"], - out_keys=[value_key], - distribution_class=IndependentNormal, - distribution_kwargs={"scale": 1.0, "event_dim": 1}, - ), - ) - - return value_model - - -def _dreamer_make_actors( - encoder, - observation_in_key, - rssm_prior, - rssm_posterior, - mlp_num_units, - activation, - action_key, - test_env, -): - actor_module = DreamerActor( - out_features=test_env.action_spec.shape[-1], - depth=3, - num_cells=mlp_num_units, - activation_class=activation, - ) - actor_simulator = _dreamer_make_actor_sim(action_key, test_env, actor_module) - actor_realworld = _dreamer_make_actor_real( - encoder, - observation_in_key, - rssm_prior, - rssm_posterior, - actor_module, - action_key, - test_env, - ) - return actor_simulator, actor_realworld - - -def _dreamer_make_actor_sim(action_key, proof_environment, actor_module): - actor_simulator = SafeProbabilisticTensorDictSequential( - SafeModule( - actor_module, - in_keys=["state", "belief"], - out_keys=["loc", "scale"], - spec=CompositeSpec( - **{ - "loc": UnboundedContinuousTensorSpec( - proof_environment.action_spec.shape, - device=proof_environment.action_spec.device, - ), - "scale": UnboundedContinuousTensorSpec( - proof_environment.action_spec.shape, - device=proof_environment.action_spec.device, - ), - } - ), - ), - SafeProbabilisticModule( - in_keys=["loc", "scale"], - out_keys=[action_key], - default_interaction_type=InteractionType.RANDOM, - distribution_class=TanhNormal, - distribution_kwargs={"tanh_loc": True}, - spec=CompositeSpec(**{action_key: proof_environment.action_spec}), - ), - 
) - return actor_simulator - - -def _dreamer_make_actor_real( - encoder, - observation_in_key, - rssm_prior, - rssm_posterior, - actor_module, - action_key, - proof_environment, -): - # actor for real world: interacts with states ~ posterior - # Out actor differs from the original paper where first they compute prior and posterior and then act on it - # but we found that this approach worked better. - actor_realworld = SafeSequential( - SafeModule( - encoder, - in_keys=[observation_in_key], - out_keys=["encoded_latents"], - ), - SafeModule( - rssm_posterior, - in_keys=["belief", "encoded_latents"], - out_keys=[ - "_", - "_", - "state", - ], - ), - SafeProbabilisticTensorDictSequential( - SafeModule( - actor_module, - in_keys=["state", "belief"], - out_keys=["loc", "scale"], - spec=CompositeSpec( - **{ - "loc": UnboundedContinuousTensorSpec( - proof_environment.action_spec.shape, - ), - "scale": UnboundedContinuousTensorSpec( - proof_environment.action_spec.shape, - ), - } - ), - ), - SafeProbabilisticModule( - in_keys=["loc", "scale"], - out_keys=[action_key], - default_interaction_type=InteractionType.MODE, - distribution_class=TanhNormal, - distribution_kwargs={"tanh_loc": True}, - spec=CompositeSpec( - **{action_key: proof_environment.action_spec.to("cpu")} - ), - ), - ), - SafeModule( - rssm_prior, - in_keys=["state", "belief", action_key], - out_keys=[ - "_", - "_", - "_", # we don't need the prior state - ("next", "belief"), - ], - ), - ) - return actor_realworld - - -def _dreamer_make_mbenv( - reward_module, - rssm_prior, - test_env, - decoder, - observation_out_key: str = "reco_pixels", - use_decoder_in_env: bool = False, - state_dim: int = 30, - rssm_hidden_dim: int = 200, -): - # MB environment - if use_decoder_in_env: - mb_env_obs_decoder = SafeModule( - decoder, - in_keys=[("next", "state"), ("next", "belief")], - out_keys=[("next", observation_out_key)], - ) - else: - mb_env_obs_decoder = None - - transition_model = SafeSequential( - SafeModule( - rssm_prior, - in_keys=["state", "belief", "action"], - out_keys=[ - "_", - "_", - "state", - "belief", - ], - ), - ) - - reward_model = SafeProbabilisticTensorDictSequential( - SafeModule( - reward_module, - in_keys=["state", "belief"], - out_keys=["loc"], - ), - SafeProbabilisticModule( - in_keys=["loc"], - out_keys=["reward"], - distribution_class=IndependentNormal, - distribution_kwargs={"scale": 1.0, "event_dim": 1}, - ), - ) - - model_based_env = DreamerEnv( - world_model=WorldModelWrapper( - transition_model, - reward_model, - ), - prior_shape=torch.Size([state_dim]), - belief_shape=torch.Size([rssm_hidden_dim]), - obs_decoder=mb_env_obs_decoder, - ) - - model_based_env.set_specs_from_env(test_env) - return model_based_env - - -def _dreamer_make_world_model( - encoder, - decoder, - rssm_prior, - rssm_posterior, - reward_module, - observation_in_key: str = "pixels", - observation_out_key: str = "reco_pixels", -): - # World Model and reward model - rssm_rollout = RSSMRollout( - SafeModule( - rssm_prior, - in_keys=["state", "belief", "action"], - out_keys=[ - ("next", "prior_mean"), - ("next", "prior_std"), - "_", - ("next", "belief"), - ], - ), - SafeModule( - rssm_posterior, - in_keys=[("next", "belief"), ("next", "encoded_latents")], - out_keys=[ - ("next", "posterior_mean"), - ("next", "posterior_std"), - ("next", "state"), - ], - ), - ) - event_dim = 3 if observation_out_key == "reco_pixels" else 1 # 3 for RGB - decoder = SafeProbabilisticTensorDictSequential( - SafeModule( - decoder, - in_keys=[("next", "state"), ("next", 
"belief")], - out_keys=["loc"], - ), - SafeProbabilisticModule( - in_keys=["loc"], - out_keys=[("next", observation_out_key)], - distribution_class=IndependentNormal, - distribution_kwargs={"scale": 1.0, "event_dim": event_dim}, - ), - ) - - transition_model = SafeSequential( - SafeModule( - encoder, - in_keys=[("next", observation_in_key)], - out_keys=[("next", "encoded_latents")], - ), - rssm_rollout, - decoder, - ) - - reward_model = SafeProbabilisticTensorDictSequential( - SafeModule( - reward_module, - in_keys=[("next", "state"), ("next", "belief")], - out_keys=[("next", "loc")], - ), - SafeProbabilisticModule( - in_keys=[("next", "loc")], - out_keys=[("next", "reward")], - distribution_class=IndependentNormal, - distribution_kwargs={"scale": 1.0, "event_dim": 1}, - ), - ) - - world_model = WorldModelWrapper( - transition_model, - reward_model, - ) - return world_model - - -def cast_to_uint8(tensordict): - tensordict["pixels"] = (tensordict["pixels"] * 255).to(torch.uint8) - tensordict["next", "pixels"] = (tensordict["next", "pixels"] * 255).to(torch.uint8) - return tensordict - - -def log_metrics(logger, metrics, step): - for metric_name, metric_value in metrics.items(): - logger.log_scalar(metric_name, metric_value, step) - - -def get_activation(name): - if name == "relu": - return nn.ReLU - elif name == "tanh": - return nn.Tanh - elif name == "leaky_relu": - return nn.LeakyReLU - elif name == "elu": - return nn.ELU - else: - raise NotImplementedError diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index e82ec806610..b0e5957943f 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -221,12 +221,12 @@ def make_dreamer( rssm_hidden_dim=config.networks.rssm_hidden_dim, ) - def detach_state_and_belief(data): - data.set("state", data.get("state").detach()) - data.set("belief", data.get("belief").detach()) - return data - - model_based_env = model_based_env.append_transform(detach_state_and_belief) + # def detach_state_and_belief(data): + # data.set("state", data.get("state").detach()) + # data.set("belief", data.get("belief").detach()) + # return data + # + # model_based_env = model_based_env.append_transform(detach_state_and_belief) check_env_specs(model_based_env) # Make actor From c5071b353bfa4cdaa02fdf4e19bbca35ac3bdb55 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Apr 2024 15:13:18 +0200 Subject: [PATCH 085/113] amend --- sota-implementations/dreamer/dreamer.py | 2 ++ torchrl/objectives/dreamer.py | 8 ++------ 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index e8d355b8cb2..32b8b660490 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -82,7 +82,9 @@ def main(cfg: "DictConfig"): # noqa: F821 imagination_horizon=cfg.optimization.imagination_horizon, discount_loss=True, ) + actor_loss.make_value_estimator(gamma=cfg.optimization.gamma, lmda=cfg.optimization.lmda) value_loss = DreamerValueLoss(value_model, discount_loss=True) + value_loss.make_value_estimator(gamma=cfg.optimization.gamma, lmda=cfg.optimization.lmda) # Make collector collector = make_collector(cfg, train_env, policy) diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py index eea561f21c3..e30a47aa36f 100644 --- a/torchrl/objectives/dreamer.py +++ b/torchrl/objectives/dreamer.py @@ -264,18 +264,14 @@ def 
_forward_value_estimator_keys(self, **kwargs) -> None: ) def forward(self, tensordict: TensorDict) -> Tuple[TensorDict, TensorDict]: - with torch.no_grad(): - # TODO: I think we need to take the "next" state and "next" beliefs - tensordict = tensordict.select("state", self.tensor_keys.belief) - tensordict = tensordict.reshape(-1) + tensordict = tensordict.select("state", self.tensor_keys.belief).detach() + tensordict = tensordict.reshape(-1) # TODO: do we need exploration here? with timeit("actor_loss/time-rollout"), hold_out_net( self.model_based_env ), set_exploration_type(ExplorationType.RANDOM): - # TODO: we are not using the actual batch beliefs as starting ones - should be solved! took of the primer for the mb_env tensordict = self.model_based_env.reset(tensordict.copy()) - # TODO: do we detach state gradients when passing again for new actions: action = self.actor(state.detach()) fake_data = self.model_based_env.rollout( max_steps=self.imagination_horizon, policy=self.actor_model, From 10173532b1e5d9b9d34297663dc38893a2ae7f21 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Apr 2024 15:18:50 +0200 Subject: [PATCH 086/113] amend --- sota-implementations/dreamer/dreamer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index 32b8b660490..a1038b2a90e 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -82,9 +82,9 @@ def main(cfg: "DictConfig"): # noqa: F821 imagination_horizon=cfg.optimization.imagination_horizon, discount_loss=True, ) - actor_loss.make_value_estimator(gamma=cfg.optimization.gamma, lmda=cfg.optimization.lmda) + actor_loss.make_value_estimator(gamma=cfg.optimization.gamma, lmbda=cfg.optimization.lmbda) value_loss = DreamerValueLoss(value_model, discount_loss=True) - value_loss.make_value_estimator(gamma=cfg.optimization.gamma, lmda=cfg.optimization.lmda) + value_loss.make_value_estimator(gamma=cfg.optimization.gamma, lmbda=cfg.optimization.lmbda) # Make collector collector = make_collector(cfg, train_env, policy) From 5ada854d7821e10d9bc9c08ab704f578ac1607f0 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Apr 2024 15:22:00 +0200 Subject: [PATCH 087/113] amend --- sota-implementations/dreamer/config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sota-implementations/dreamer/config.yaml b/sota-implementations/dreamer/config.yaml index 138bf99d182..928b17192a8 100644 --- a/sota-implementations/dreamer/config.yaml +++ b/sota-implementations/dreamer/config.yaml @@ -30,7 +30,7 @@ optimization: free_nats: 3.0 optim_steps_per_batch: 80 gamma: 0.99 - lambda: 0.95 + lmbda: 0.95 imagination_horizon: 15 compile: False use_autocast: True From de70612ff4f6b2011fa440ef6b332a798536f381 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Apr 2024 15:26:01 +0200 Subject: [PATCH 088/113] amend --- torchrl/objectives/dreamer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py index e30a47aa36f..c4d104acc53 100644 --- a/torchrl/objectives/dreamer.py +++ b/torchrl/objectives/dreamer.py @@ -389,6 +389,7 @@ class _AcceptedKeys: value: NestedKey = "state_value" default_keys = _AcceptedKeys() + default_value_estimator = TDLambdaEstimator def __init__( self, From 0dc1b08c77162fe57f2f860c0ea89cf10e74f8a5 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Apr 2024 15:27:13 +0200 Subject: [PATCH 089/113] amend --- 
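
Note (illustrative sketch, not part of the patch): this change passes gamma directly to
DreamerValueLoss instead of calling make_value_estimator on it, while the actor loss keeps
its gamma/lmbda TD(lambda) estimator. For reference, a minimal standalone sketch of the
lambda-return target those two parameters define, assuming rewards of shape [T] and values
of shape [T + 1] (the last entry is used as the bootstrap), with no termination handling:

    import torch

    def lambda_returns(rewards, values, gamma=0.99, lmbda=0.95):
        # Backward recursion: R_t = r_t + gamma * ((1 - lmbda) * v_{t+1} + lmbda * R_{t+1})
        returns = torch.empty_like(rewards)
        nxt = values[-1]
        for t in reversed(range(rewards.shape[0])):
            nxt = rewards[t] + gamma * ((1 - lmbda) * values[t + 1] + lmbda * nxt)
            returns[t] = nxt
        return returns

The losses themselves rely on TorchRL's TDLambdaEstimator; the function name and tensor
shapes above are assumptions made for illustration only.
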
sota-implementations/dreamer/dreamer.py | 3 +-- torchrl/objectives/dreamer.py | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index a1038b2a90e..a9e24c588d0 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -83,8 +83,7 @@ def main(cfg: "DictConfig"): # noqa: F821 discount_loss=True, ) actor_loss.make_value_estimator(gamma=cfg.optimization.gamma, lmbda=cfg.optimization.lmbda) - value_loss = DreamerValueLoss(value_model, discount_loss=True) - value_loss.make_value_estimator(gamma=cfg.optimization.gamma, lmbda=cfg.optimization.lmbda) + value_loss = DreamerValueLoss(value_model, discount_loss=True, gamma=cfg.optimization.gamma) # Make collector collector = make_collector(cfg, train_env, policy) diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py index c4d104acc53..e30a47aa36f 100644 --- a/torchrl/objectives/dreamer.py +++ b/torchrl/objectives/dreamer.py @@ -389,7 +389,6 @@ class _AcceptedKeys: value: NestedKey = "state_value" default_keys = _AcceptedKeys() - default_value_estimator = TDLambdaEstimator def __init__( self, From 4149a87565afabb651d760f7bdf892ced8b6a175 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Apr 2024 15:37:02 +0200 Subject: [PATCH 090/113] amend --- sota-implementations/dreamer/config.yaml | 1 + sota-implementations/dreamer/dreamer.py | 9 +++------ 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/sota-implementations/dreamer/config.yaml b/sota-implementations/dreamer/config.yaml index 928b17192a8..71ab6508c09 100644 --- a/sota-implementations/dreamer/config.yaml +++ b/sota-implementations/dreamer/config.yaml @@ -33,6 +33,7 @@ optimization: lmbda: 0.95 imagination_horizon: 15 compile: False + compile_backend: inductor use_autocast: True networks: diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index a9e24c588d0..38e27ccccc5 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -129,6 +129,7 @@ def main(cfg: "DictConfig"): # noqa: F821 if cfg.optimization.compile: torchrl_logger.info("Compiling") + backend = cfg.optimization.compile_backend def compile_rssms(module): if isinstance(module, RSSMRollout) and not getattr( @@ -136,10 +137,10 @@ def compile_rssms(module): ): module._compiled = True module.rssm_prior.module = torch.compile( - module.rssm_prior.module, backend="cudagraphs" + module.rssm_prior.module, backend=backend ) module.rssm_posterior.module = torch.compile( - module.rssm_posterior.module, backend="cudagraphs" + module.rssm_posterior.module, backend=backend ) world_model_loss.apply(compile_rssms) @@ -168,10 +169,6 @@ def compile_rssms(module): sampled_tensordict = replay_buffer.sample().reshape(-1, batch_length) t_sample = time.time() - t_sample_init - # print("sampled_tensordict", sampled_tensordict) - # print("steps", sampled_tensordict["next", "steps"]) - # print("traj_ids", sampled_tensordict["collector", "traj_ids"]) - t_loss_model_init = time.time() # update world model with torch.autocast( From 9fb63d4ccc008e209ef0363ed43e2a30569429e2 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 15 Apr 2024 07:19:33 -0700 Subject: [PATCH 091/113] amend --- sota-implementations/dreamer/dreamer.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index 
38e27ccccc5..6f0368764e1 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import contextlib import time +from torch.profiler import profile, record_function, ProfilerActivity import hydra import torch @@ -163,7 +164,7 @@ def compile_rssms(module): t_loss_critic = 0.0 t_loss_model = 0.0 - for _ in range(optim_steps_per_batch): + for k in range(optim_steps_per_batch): # sample from replay buffer t_sample_init = time.time() sampled_tensordict = replay_buffer.sample().reshape(-1, batch_length) @@ -174,7 +175,7 @@ def compile_rssms(module): with torch.autocast( device_type=device.type, dtype=torch.bfloat16, - ) if use_autocast else contextlib.nullcontext(): + ) if use_autocast else contextlib.nullcontext(), (profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) if (i == 1 and k == 1) else contextlib.nullcontext()) as prof: model_loss_td, sampled_tensordict = world_model_loss( sampled_tensordict ) @@ -188,6 +189,8 @@ def compile_rssms(module): torch.bfloat16, torch.float16, ), model_loss_td + if (i == 1 and k == 1): + prof.export_chrome_trace("trace_world_model.json") world_model_opt.zero_grad() if use_autocast: @@ -207,9 +210,12 @@ def compile_rssms(module): t_loss_actor_init = time.time() with torch.autocast( device_type=device.type, dtype=torch.bfloat16 - ) if use_autocast else contextlib.nullcontext(): + ) if use_autocast else contextlib.nullcontext(), (profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) if (i == 1 and k == 1) else contextlib.nullcontext()) as prof: actor_loss_td, sampled_tensordict = actor_loss(sampled_tensordict) + if (i == 1 and k == 1): + prof.export_chrome_trace("trace_actor.json") + actor_opt.zero_grad() if use_autocast: scaler2.scale(actor_loss_td["loss_actor"]).backward() @@ -228,9 +234,12 @@ def compile_rssms(module): t_loss_critic_init = time.time() with torch.autocast( device_type=device.type, dtype=torch.bfloat16 - ) if use_autocast else contextlib.nullcontext(): + ) if use_autocast else contextlib.nullcontext(), (profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) if (i == 1 and k == 1) else contextlib.nullcontext()) as prof: value_loss_td, sampled_tensordict = value_loss(sampled_tensordict) + if (i == 1 and k == 1): + prof.export_chrome_trace("trace_critic.json") + value_opt.zero_grad() if use_autocast: scaler3.scale(value_loss_td["loss_value"]).backward() From 55437617a16e10f209bed66b2b76f545fdec984d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 15 Apr 2024 15:54:14 +0100 Subject: [PATCH 092/113] amend --- sota-implementations/dreamer/dreamer.py | 8 +- torchrl/envs/transforms/transforms.py | 2 + torchrl/objectives/dreamer.py | 137 ++++++++++++++++++------ 3 files changed, 112 insertions(+), 35 deletions(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index 38e27ccccc5..db11ec3af7f 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -82,8 +82,12 @@ def main(cfg: "DictConfig"): # noqa: F821 imagination_horizon=cfg.optimization.imagination_horizon, discount_loss=True, ) - actor_loss.make_value_estimator(gamma=cfg.optimization.gamma, lmbda=cfg.optimization.lmbda) - value_loss = DreamerValueLoss(value_model, discount_loss=True, gamma=cfg.optimization.gamma) + actor_loss.make_value_estimator( + gamma=cfg.optimization.gamma, lmbda=cfg.optimization.lmbda + ) + value_loss = DreamerValueLoss( + 
value_model, discount_loss=True, gamma=cfg.optimization.gamma + ) # Make collector collector = make_collector(cfg, train_env, policy) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 791a0615628..5ee9929ea36 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -3494,10 +3494,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: "this functionality is not covered. Consider passing the in_keys " "or not passing any out_keys." ) + def func(item): if item.dtype == self.dtype_in: item = self._apply_transform(item) return item + tensordict = tensordict._fast_apply(func) else: # we made sure that if in_keys is not None, out_keys is not None either diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py index e30a47aa36f..90e99379b72 100644 --- a/torchrl/objectives/dreamer.py +++ b/torchrl/objectives/dreamer.py @@ -15,11 +15,11 @@ from torchrl._utils import timeit from torchrl.envs.model_based.dreamer import DreamerEnv from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp -from torchrl.modules import IndependentNormal from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( _GAMMA_LMBDA_DEPREC_ERROR, default_value_kwargs, + distance_loss, # distance_loss, hold_out_net, ValueEstimators, @@ -120,8 +120,8 @@ def __init__( def _forward_value_estimator_keys(self, **kwargs) -> None: pass - def forward(self, tensordict: TensorDict) -> Tuple[TensorDict, TensorDict]: - tensordict = tensordict.copy() + def forward(self, tensordict: TensorDict) -> torch.Tensor: + tensordict = tensordict.clone(recurse=False) tensordict.rename_key_( ("next", self.tensor_keys.reward), ("next", self.tensor_keys.true_reward), @@ -133,25 +133,59 @@ def forward(self, tensordict: TensorDict) -> Tuple[TensorDict, TensorDict]: tensordict.get(("next", self.tensor_keys.prior_std)), tensordict.get(("next", self.tensor_keys.posterior_mean)), tensordict.get(("next", self.tensor_keys.posterior_std)), + ).unsqueeze(-1) + reco_loss = distance_loss( + tensordict.get(("next", self.tensor_keys.pixels)), + tensordict.get(("next", self.tensor_keys.reco_pixels)), + self.reco_loss, ) - - dist: IndependentNormal = self.decoder.get_dist(tensordict) - reco_loss = -dist.log_prob( - tensordict.get(("next", self.tensor_keys.pixels)) - ).mean() - # x = tensordict.get(("next", self.tensor_keys.pixels)) - # loc = dist.base_dist.loc - # scale = dist.base_dist.scale - # reco_loss = -self.normal_log_probability(x, loc, scale).mean() - - dist: IndependentNormal = self.reward_model.get_dist(tensordict) - reward_loss = -dist.log_prob( - tensordict.get(("next", self.tensor_keys.true_reward)) - ).mean() - # x = tensordict.get(("next", self.tensor_keys.true_reward)) - # loc = dist.base_dist.loc - # scale = dist.base_dist.scale - # reward_loss = -self.normal_log_probability(x, loc, scale).mean() + if not self.global_average: + reco_loss = reco_loss.sum((-3, -2, -1)) + reco_loss = reco_loss.mean().unsqueeze(-1) + + reward_loss = distance_loss( + tensordict.get(("next", self.tensor_keys.true_reward)), + tensordict.get(("next", self.tensor_keys.reward)), + self.reward_loss, + ) + if not self.global_average: + reward_loss = reward_loss.squeeze(-1) + reward_loss = reward_loss.mean().unsqueeze(-1) + # import ipdb; ipdb.set_trace() + + # Alternative: + # def forward(self, tensordict: TensorDict) -> Tuple[TensorDict, TensorDict]: + # tensordict = tensordict.copy() + # tensordict.rename_key_( + 
# ("next", self.tensor_keys.reward), + # ("next", self.tensor_keys.true_reward), + # ) + # tensordict = self.world_model(tensordict) + # # compute model loss + # kl_loss = self.kl_loss( + # tensordict.get(("next", self.tensor_keys.prior_mean)), + # tensordict.get(("next", self.tensor_keys.prior_std)), + # tensordict.get(("next", self.tensor_keys.posterior_mean)), + # tensordict.get(("next", self.tensor_keys.posterior_std)), + # ) + # + # dist: IndependentNormal = self.decoder.get_dist(tensordict) + # reco_loss = -dist.log_prob( + # tensordict.get(("next", self.tensor_keys.pixels)) + # ).mean() + # # x = tensordict.get(("next", self.tensor_keys.pixels)) + # # loc = dist.base_dist.loc + # # scale = dist.base_dist.scale + # # reco_loss = -self.normal_log_probability(x, loc, scale).mean() + # + # dist: IndependentNormal = self.reward_model.get_dist(tensordict) + # reward_loss = -dist.log_prob( + # tensordict.get(("next", self.tensor_keys.true_reward)) + # ).mean() + # # x = tensordict.get(("next", self.tensor_keys.true_reward)) + # # loc = dist.base_dist.loc + # # scale = dist.base_dist.scale + # # reward_loss = -self.normal_log_probability(x, loc, scale).mean() return ( TensorDict( @@ -184,7 +218,13 @@ def kl_loss( / (2 * prior_std**2) - 0.5 ) - return kl.clamp_min(self.free_nats).sum(-1).mean() + if not self.global_average: + kl = kl.sum(-1) + if self.delayed_clamp: + kl = kl.mean().clamp_min(self.free_nats) + else: + kl = kl.clamp_min(self.free_nats).mean() + return kl class DreamerActorLoss(LossModule): @@ -293,9 +333,9 @@ def forward(self, tensordict: TensorDict) -> Tuple[TensorDict, TensorDict]: discount = gamma.expand(lambda_target.shape).clone() discount[..., 0, :] = 1 discount = discount.cumprod(dim=-2) - actor_loss = -(lambda_target * discount).mean() + actor_loss = -(lambda_target * discount).sum((-2, -1)).mean() else: - actor_loss = -lambda_target.mean() + actor_loss = -lambda_target.sum((-2, -1)).mean() loss_tensordict = TensorDict({"loss_actor": actor_loss}, []) return loss_tensordict, fake_data.detach() @@ -344,7 +384,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams self._value_estimator = TDLambdaEstimator( **hp, value_network=value_net, - vectorized=False, # TODO: vectorized version seems not to be similar to the non vectoried + vectorized=True, # TODO: vectorized version seems not to be similar to the non vectorised ) else: raise NotImplementedError(f"Unknown value type {value_type}") @@ -407,21 +447,52 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: pass def forward(self, fake_data) -> torch.Tensor: + # lambda_target = fake_data.get("lambda_target") + # # TODO: I think this should be next state and belief + # td = fake_data.select(("next", "state"), ("next", "belief")) + # td = td.rename_key_(("next", "state"), "state") + # tensordict_select = td.rename_key_(("next", "belief"), "belief") + # # tensordict_select = fake_data.select(*self.value_model.in_keys, strict=False) + # dist = self.value_model.get_dist(tensordict_select) + # if self.discount_loss: + # discount = self.gamma * torch.ones_like( + # lambda_target, device=lambda_target.device + # ) + # discount[..., 0, :] = 1 + # discount = discount.cumprod(dim=-2) + # value_loss = -(discount * dist.log_prob(lambda_target).unsqueeze(-1)).mean() + # else: + # value_loss = -dist.log_prob(lambda_target).mean() lambda_target = fake_data.get("lambda_target") - # TODO: I think this should be next state and belief - td = fake_data.select(("next", "state"), ("next", "belief")) - td = 
td.rename_key_(("next", "state"), "state") - tensordict_select = td.rename_key_(("next", "belief"), "belief") - # tensordict_select = fake_data.select(*self.value_model.in_keys, strict=False) - dist = self.value_model.get_dist(tensordict_select) + tensordict_select = fake_data.select(*self.value_model.in_keys, strict=False) + self.value_model(tensordict_select) if self.discount_loss: discount = self.gamma * torch.ones_like( lambda_target, device=lambda_target.device ) discount[..., 0, :] = 1 discount = discount.cumprod(dim=-2) - value_loss = -(discount * dist.log_prob(lambda_target).unsqueeze(-1)).mean() + value_loss = ( + ( + discount + * distance_loss( + tensordict_select.get(self.tensor_keys.value), + lambda_target, + self.value_loss, + ) + ) + .sum((-1, -2)) + .mean() + ) else: - value_loss = -dist.log_prob(lambda_target).mean() + value_loss = ( + distance_loss( + tensordict_select.get(self.tensor_keys.value), + lambda_target, + self.value_loss, + ) + .sum((-1, -2)) + .mean() + ) loss_tensordict = TensorDict({"loss_value": value_loss}, []) return loss_tensordict, fake_data From 1f78519ed2d8a41bf4c6850a725bab538c7090bb Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 17 Apr 2024 13:07:23 +0100 Subject: [PATCH 093/113] amend --- sota-implementations/dreamer/config.yaml | 13 +++- sota-implementations/dreamer/dreamer.py | 66 +++++++++++------ sota-implementations/dreamer/dreamer_utils.py | 73 ++++++++++++------- torchrl/envs/__init__.py | 2 +- torchrl/envs/model_based/__init__.py | 1 + torchrl/modules/__init__.py | 1 + torchrl/modules/models/__init__.py | 9 ++- 7 files changed, 110 insertions(+), 55 deletions(-) diff --git a/sota-implementations/dreamer/config.yaml b/sota-implementations/dreamer/config.yaml index 71ab6508c09..92bd0044454 100644 --- a/sota-implementations/dreamer/config.yaml +++ b/sota-implementations/dreamer/config.yaml @@ -9,13 +9,17 @@ env: image_size : 64 horizon: 500 n_parallel_envs: 8 - device: null + device: + _target_: dreamer_utils._default_device + device: null collector: total_frames: 5_000_000 init_random_frames: 3000 frames_per_batch: 1000 - device: cuda:0 + device: + _target_: dreamer_utils._default_device + device: null optimization: train_every: 1000 @@ -38,7 +42,9 @@ optimization: networks: exploration_noise: 0.3 - device: cuda:0 + device: + _target_: dreamer_utils._default_device + device: null state_dim: 30 rssm_hidden_dim: 200 hidden_dim: 400 @@ -58,3 +64,4 @@ logger: # eval interval, in collection counts eval_iter: 10 eval_rollout_steps: 500 + video: False diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index 507f4fc4ae3..04d9cc625be 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -4,26 +4,28 @@ # LICENSE file in the root directory of this source tree. 
import contextlib import time -from torch.profiler import profile, record_function, ProfilerActivity import hydra import torch import torch.cuda import tqdm from dreamer_utils import ( + dump_video, log_metrics, make_collector, make_dreamer, make_environments, make_replay_buffer, ) +from hydra.utils import instantiate # mixed precision training from torch.cuda.amp import GradScaler from torch.nn.utils import clip_grad_norm_ +from torch.profiler import profile, ProfilerActivity, record_function from torchrl._utils import logger as torchrl_logger, timeit from torchrl.envs.utils import ExplorationType, set_exploration_type -from torchrl.modules.models.model_based import RSSMRollout +from torchrl.modules import RSSMRollout from torchrl.objectives.dreamer import ( DreamerActorLoss, @@ -37,12 +39,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # cfg = correct_for_frame_skip(cfg) - if torch.cuda.is_available() and cfg.networks.device in (None, ""): - device = torch.device("cuda:0") - elif cfg.networks.device: - device = torch.device(cfg.networks.device) - else: - device = torch.device("cpu") + device = torch.device(instantiate(cfg.networks.device)) # Create logger exp_name = generate_exp_name("Dreamer", cfg.logger.exp_name) @@ -56,7 +53,9 @@ def main(cfg: "DictConfig"): # noqa: F821 ) train_env, test_env = make_environments( - cfg=cfg, parallel_envs=cfg.env.n_parallel_envs + cfg=cfg, + parallel_envs=cfg.env.n_parallel_envs, + logger=logger, ) # Make dreamer components @@ -101,7 +100,7 @@ def main(cfg: "DictConfig"): # noqa: F821 batch_seq_len=batch_length, buffer_size=cfg.replay_buffer.buffer_size, buffer_scratch_dir=cfg.replay_buffer.scratch_dir, - device=cfg.networks.device, + device=device, pixel_obs=cfg.env.from_pixels, grayscale=cfg.env.grayscale, image_size=cfg.env.image_size, @@ -179,7 +178,11 @@ def compile_rssms(module): with torch.autocast( device_type=device.type, dtype=torch.bfloat16, - ) if use_autocast else contextlib.nullcontext(), (profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) if (i == 1 and k == 1) else contextlib.nullcontext()) as prof: + ) if use_autocast else contextlib.nullcontext(), ( + profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) + if (i == 1 and k == 1) + else contextlib.nullcontext() + ) as prof: model_loss_td, sampled_tensordict = world_model_loss( sampled_tensordict ) @@ -188,12 +191,7 @@ def compile_rssms(module): + model_loss_td["loss_model_reco"] + model_loss_td["loss_model_reward"] ) - if use_autocast: - assert loss_world_model.dtype in ( - torch.bfloat16, - torch.float16, - ), model_loss_td - if (i == 1 and k == 1): + if i == 1 and k == 1: prof.export_chrome_trace("trace_world_model.json") world_model_opt.zero_grad() @@ -214,10 +212,14 @@ def compile_rssms(module): t_loss_actor_init = time.time() with torch.autocast( device_type=device.type, dtype=torch.bfloat16 - ) if use_autocast else contextlib.nullcontext(), (profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) if (i == 1 and k == 1) else contextlib.nullcontext()) as prof: + ) if use_autocast else contextlib.nullcontext(), ( + profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) + if (i == 1 and k == 1) + else contextlib.nullcontext() + ) as prof: actor_loss_td, sampled_tensordict = actor_loss(sampled_tensordict) - if (i == 1 and k == 1): + if i == 1 and k == 1: prof.export_chrome_trace("trace_actor.json") actor_opt.zero_grad() @@ -238,10 +240,14 @@ def compile_rssms(module): t_loss_critic_init = time.time() with torch.autocast( 
device_type=device.type, dtype=torch.bfloat16 - ) if use_autocast else contextlib.nullcontext(), (profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) if (i == 1 and k == 1) else contextlib.nullcontext()) as prof: + ) if use_autocast else contextlib.nullcontext(), ( + profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) + if (i == 1 and k == 1) + else contextlib.nullcontext() + ) as prof: value_loss_td, sampled_tensordict = value_loss(sampled_tensordict) - if (i == 1 and k == 1): + if i == 1 and k == 1: prof.export_chrome_trace("trace_critic.json") value_opt.zero_grad() @@ -287,17 +293,33 @@ def compile_rssms(module): collector.update_policy_weights_() # Evaluation if (i % eval_iter) == 0: + # Real env with set_exploration_type(ExplorationType.MODE), torch.no_grad(): eval_rollout = test_env.rollout( eval_rollout_steps, policy, - auto_cast_to_device=True, + # auto_cast_to_device=True, break_when_any_done=True, ) + test_env.apply(dump_video) eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() eval_metrics = {"eval/reward": eval_reward} if logger is not None: log_metrics(logger, eval_metrics, collected_frames) + # Simulated env + with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + eval_rollout = test_env.rollout( + eval_rollout_steps, + policy, + # auto_cast_to_device=True, + break_when_any_done=True, + ) + test_env.apply(dump_video) + eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() + eval_metrics = {"eval/simulated_reward": eval_reward} + if logger is not None: + log_metrics(logger, eval_metrics, collected_frames) + t_collect_init = time.time() diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index b0e5957943f..88030bf9242 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -9,6 +9,7 @@ import torch import torch.nn as nn +from hydra.utils import instantiate from tensordict.nn import ( InteractionType, ProbabilisticTensorDictModule, @@ -17,54 +18,52 @@ TensorDictSequential, ) from torchrl.collectors import SyncDataCollector -from torchrl.data import SliceSampler, TensorDictReplayBuffer -from torchrl.data.replay_buffers.storages import LazyMemmapStorage +from torchrl.data import LazyMemmapStorage, SliceSampler, TensorDictReplayBuffer from torchrl.data.tensor_specs import CompositeSpec, UnboundedContinuousTensorSpec -from torchrl.envs import ParallelEnv -from torchrl.envs.env_creator import EnvCreator -from torchrl.envs.libs.dm_control import DMControlEnv -from torchrl.envs.libs.gym import GymEnv, set_gym_backend -from torchrl.envs.model_based.dreamer import DreamerEnv -from torchrl.envs.transforms import ( +from torchrl.envs import ( Compose, + DeviceCastTransform, + DMControlEnv, DoubleToFloat, + DreamerEnv, + DTypeCastTransform, + EnvCreator, + ExcludeTransform, # ExcludeTransform, FrameSkipTransform, GrayScale, + GymEnv, + ParallelEnv, + RenameTransform, Resize, RewardSum, - ToTensorImage, - TransformedEnv, -) -from torchrl.envs.transforms.transforms import ( - DeviceCastTransform, - DTypeCastTransform, - ExcludeTransform, - RenameTransform, + set_gym_backend, StepCounter, TensorDictPrimer, + ToTensorImage, + TransformedEnv, ) from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type from torchrl.modules import ( - MLP, - SafeModule, - SafeProbabilisticModule, - SafeProbabilisticTensorDictSequential, - SafeSequential, -) -from torchrl.modules.distributions import 
IndependentNormal, TanhNormal -from torchrl.modules.models.model_based import ( + AdditiveGaussianWrapper, DreamerActor, + IndependentNormal, + MLP, ObsDecoder, ObsEncoder, RSSMPosterior, RSSMPrior, RSSMRollout, + SafeModule, + SafeProbabilisticModule, + SafeProbabilisticTensorDictSequential, + SafeSequential, + TanhNormal, + WorldModelWrapper, ) -from torchrl.modules.tensordict_module.exploration import AdditiveGaussianWrapper -from torchrl.modules.tensordict_module.world_models import WorldModelWrapper +from torchrl.record import VideoRecorder def _make_env(cfg, device): @@ -115,7 +114,7 @@ def transform_env(cfg, env): return env -def make_environments(cfg, parallel_envs=1): +def make_environments(cfg, parallel_envs=1, logger=None): """Make environments for training and evaluation.""" func = functools.partial(_make_env, cfg=cfg, device=cfg.env.device) train_env = ParallelEnv( @@ -125,6 +124,9 @@ def make_environments(cfg, parallel_envs=1): ) train_env = transform_env(cfg, train_env) train_env.set_seed(cfg.env.seed) + func = functools.partial( + _make_env, cfg=cfg, device=cfg.env.device, from_pixels=cfg.logger.video + ) eval_env = ParallelEnv( 1, EnvCreator(func), @@ -132,11 +134,18 @@ def make_environments(cfg, parallel_envs=1): ) eval_env = transform_env(cfg, eval_env) eval_env.set_seed(cfg.env.seed + 1) + if cfg.logger.video: + eval_env.insert_transform(0, VideoRecorder(logger, tag="eval/video")) check_env_specs(train_env) check_env_specs(eval_env) return train_env, eval_env +def dump_video(module): + if isinstance(module, VideoRecorder): + module.dump() + + def make_dreamer( config, device, @@ -282,7 +291,7 @@ def make_collector(cfg, train_env, actor_model_explore): init_random_frames=cfg.collector.init_random_frames, frames_per_batch=cfg.collector.frames_per_batch, total_frames=cfg.collector.total_frames, - policy_device=cfg.collector.device, + policy_device=instantiate(cfg.collector.device), env_device=train_env.device, storing_device="cpu", ) @@ -665,3 +674,11 @@ def get_activation(name): return nn.ELU else: raise NotImplementedError + + +def _default_device(device=None): + if device in ("", None): + if torch.cuda.is_available(): + return torch.device("cuda") + return torch.device("cpu") + return torch.device(device) diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index 2b02ba2feea..d5e08362904 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -34,7 +34,7 @@ VmasEnv, VmasWrapper, ) -from .model_based import ModelBasedEnvBase +from .model_based import DreamerEnv, ModelBasedEnvBase from .transforms import ( ActionMask, BatchSizeTransform, diff --git a/torchrl/envs/model_based/__init__.py b/torchrl/envs/model_based/__init__.py index 5f628079173..911cced5a9c 100644 --- a/torchrl/envs/model_based/__init__.py +++ b/torchrl/envs/model_based/__init__.py @@ -4,3 +4,4 @@ # LICENSE file in the root directory of this source tree. 
from .common import ModelBasedEnvBase +from .dreamer import DreamerEnv diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index 8f20f53fe5b..a987e701672 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -42,6 +42,7 @@ reset_noise, RSSMPosterior, RSSMPrior, + RSSMRollout, Squeeze2dLayer, SqueezeLayer, VDNMixer, diff --git a/torchrl/modules/models/__init__.py b/torchrl/modules/models/__init__.py index 7e8ace40dcd..7b11cae9515 100644 --- a/torchrl/modules/models/__init__.py +++ b/torchrl/modules/models/__init__.py @@ -8,7 +8,14 @@ from .decision_transformer import DecisionTransformer from .exploration import NoisyLazyLinear, NoisyLinear, reset_noise -from .model_based import DreamerActor, ObsDecoder, ObsEncoder, RSSMPosterior, RSSMPrior +from .model_based import ( + DreamerActor, + ObsDecoder, + ObsEncoder, + RSSMPosterior, + RSSMPrior, + RSSMRollout, +) from .models import ( Conv2dNet, Conv3dNet, From e9b6ebc85fcead7742828304ed33e20e493f8d68 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 17 Apr 2024 20:54:01 +0100 Subject: [PATCH 094/113] amend --- sota-implementations/dreamer/config.yaml | 9 +- sota-implementations/dreamer/dreamer.py | 54 +++++++---- sota-implementations/dreamer/dreamer_utils.py | 95 +++++++++++++------ torchrl/envs/__init__.py | 2 +- torchrl/envs/model_based/__init__.py | 2 +- torchrl/envs/model_based/dreamer.py | 12 +++ 6 files changed, 118 insertions(+), 56 deletions(-) diff --git a/sota-implementations/dreamer/config.yaml b/sota-implementations/dreamer/config.yaml index 92bd0044454..ab101e8486a 100644 --- a/sota-implementations/dreamer/config.yaml +++ b/sota-implementations/dreamer/config.yaml @@ -24,8 +24,6 @@ collector: optimization: train_every: 1000 grad_clip: 100 - batch_size: 2500 - batch_length: 50 world_model_lr: 6e-4 actor_lr: 8e-5 @@ -52,9 +50,10 @@ networks: replay_buffer: - buffer_size: 20000 - batch_size: 50 - scratch_dir: ${logger.exp_name}_${env.seed} + batch_size: 2500 + buffer_size: 1000000 + batch_length: 50 + scratch_dir: null logger: backend: wandb diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index 04d9cc625be..ce5f5e0090e 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -61,14 +61,21 @@ def main(cfg: "DictConfig"): # noqa: F821 # Make dreamer components action_key = "action" value_key = "state_value" - world_model, model_based_env, actor_model, value_model, policy = make_dreamer( - config=cfg, + ( + world_model, + model_based_env, + model_based_env_eval, + actor_model, + value_model, + policy, + ) = make_dreamer( + cfg=cfg, device=device, action_key=action_key, value_key=value_key, - use_decoder_in_env=False, + use_decoder_in_env=cfg.logger.video, + logger=logger, ) - # Losses world_model_loss = DreamerModelLoss(world_model) # Adapt loss keys to gym backend @@ -82,6 +89,7 @@ def main(cfg: "DictConfig"): # noqa: F821 imagination_horizon=cfg.optimization.imagination_horizon, discount_loss=True, ) + actor_loss.make_value_estimator( gamma=cfg.optimization.gamma, lmbda=cfg.optimization.lmbda ) @@ -93,13 +101,15 @@ def main(cfg: "DictConfig"): # noqa: F821 collector = make_collector(cfg, train_env, policy) # Make replay buffer - batch_size = cfg.optimization.batch_size - batch_length = cfg.optimization.batch_length + batch_size = cfg.replay_buffer.batch_size + batch_length = cfg.replay_buffer.batch_length + buffer_size = cfg.replay_buffer.buffer_size + scratch_dir = cfg.replay_buffer.scratch_dir 
replay_buffer = make_replay_buffer( batch_size=batch_size, batch_seq_len=batch_length, - buffer_size=cfg.replay_buffer.buffer_size, - buffer_scratch_dir=cfg.replay_buffer.scratch_dir, + buffer_size=buffer_size, + buffer_scratch_dir=scratch_dir, device=device, pixel_obs=cfg.env.from_pixels, grayscale=cfg.env.grayscale, @@ -132,6 +142,8 @@ def main(cfg: "DictConfig"): # noqa: F821 eval_rollout_steps = cfg.logger.eval_rollout_steps if cfg.optimization.compile: + torch._dynamo.config.capture_scalar_outputs = True + torchrl_logger.info("Compiling") backend = cfg.optimization.compile_backend @@ -166,7 +178,6 @@ def compile_rssms(module): t_loss_actor = 0.0 t_loss_critic = 0.0 t_loss_model = 0.0 - for k in range(optim_steps_per_batch): # sample from replay buffer t_sample_init = time.time() @@ -307,18 +318,19 @@ def compile_rssms(module): if logger is not None: log_metrics(logger, eval_metrics, collected_frames) # Simulated env - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): - eval_rollout = test_env.rollout( - eval_rollout_steps, - policy, - # auto_cast_to_device=True, - break_when_any_done=True, - ) - test_env.apply(dump_video) - eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() - eval_metrics = {"eval/simulated_reward": eval_reward} - if logger is not None: - log_metrics(logger, eval_metrics, collected_frames) + if model_based_env_eval is not None: + with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + eval_rollout = model_based_env_eval.rollout( + eval_rollout_steps, + policy, + # auto_cast_to_device=True, + break_when_any_done=True, + ) + model_based_env_eval.apply(dump_video) + eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() + eval_metrics = {"eval/simulated_reward": eval_reward} + if logger is not None: + log_metrics(logger, eval_metrics, collected_frames) t_collect_init = time.time() diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index 88030bf9242..b11dfed9e9f 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -18,15 +18,21 @@ TensorDictSequential, ) from torchrl.collectors import SyncDataCollector -from torchrl.data import LazyMemmapStorage, SliceSampler, TensorDictReplayBuffer -from torchrl.data.tensor_specs import CompositeSpec, UnboundedContinuousTensorSpec +from torchrl.data import ( + CompositeSpec, + LazyMemmapStorage, + SliceSampler, + TensorDictReplayBuffer, + UnboundedContinuousTensorSpec, +) from torchrl.envs import ( Compose, DeviceCastTransform, DMControlEnv, DoubleToFloat, + DreamerDecoder, DreamerEnv, DTypeCastTransform, EnvCreator, @@ -66,16 +72,23 @@ from torchrl.record import VideoRecorder -def _make_env(cfg, device): +def _make_env(cfg, device, from_pixels=False): lib = cfg.env.backend if lib in ("gym", "gymnasium"): with set_gym_backend(lib): env = GymEnv( cfg.env.name, device=device, + from_pixels=cfg.env.from_pixels or from_pixels, + pixels_only=cfg.env.from_pixels, ) elif lib == "dm_control": - env = DMControlEnv(cfg.env.name, cfg.env.task, from_pixels=cfg.env.from_pixels) + env = DMControlEnv( + cfg.env.name, + cfg.env.task, + from_pixels=cfg.env.from_pixels or from_pixels, + pixels_only=cfg.env.from_pixels, + ) else: raise NotImplementedError(f"Unknown lib {lib}.") default_dict = { @@ -147,17 +160,18 @@ def dump_video(module): def make_dreamer( - config, + cfg, device, action_key: str = "action", value_key: str = "state_value", use_decoder_in_env: bool = False, compile: 
bool = True, + logger=None, ): - test_env = _make_env(config, device="cpu") - test_env = transform_env(config, test_env) + test_env = _make_env(cfg, device="cpu") + test_env = transform_env(cfg, test_env) # Make encoder and decoder - if config.env.from_pixels: + if cfg.env.from_pixels: encoder = ObsEncoder() decoder = ObsDecoder() observation_in_key = "pixels" @@ -166,34 +180,34 @@ def make_dreamer( encoder = MLP( out_features=1024, depth=2, - num_cells=config.networks.hidden_dim, - activation_class=get_activation(config.networks.activation), + num_cells=cfg.networks.hidden_dim, + activation_class=get_activation(cfg.networks.activation), ) decoder = MLP( out_features=test_env.observation_spec["observation"].shape[-1], depth=2, - num_cells=config.networks.hidden_dim, - activation_class=get_activation(config.networks.activation), + num_cells=cfg.networks.hidden_dim, + activation_class=get_activation(cfg.networks.activation), ) observation_in_key = "observation" obsevation_out_key = "reco_observation" # Make RSSM rssm_prior = RSSMPrior( - hidden_dim=config.networks.rssm_hidden_dim, - rnn_hidden_dim=config.networks.rssm_hidden_dim, - state_dim=config.networks.state_dim, + hidden_dim=cfg.networks.rssm_hidden_dim, + rnn_hidden_dim=cfg.networks.rssm_hidden_dim, + state_dim=cfg.networks.state_dim, action_spec=test_env.action_spec, ) rssm_posterior = RSSMPosterior( - hidden_dim=config.networks.rssm_hidden_dim, state_dim=config.networks.state_dim + hidden_dim=cfg.networks.rssm_hidden_dim, state_dim=cfg.networks.state_dim ) # Make reward module reward_module = MLP( out_features=1, depth=2, - num_cells=config.networks.hidden_dim, - activation_class=get_activation(config.networks.activation), + num_cells=cfg.networks.hidden_dim, + activation_class=get_activation(cfg.networks.activation), ) # Make combined world model @@ -226,8 +240,8 @@ def make_dreamer( observation_out_key=obsevation_out_key, test_env=test_env, use_decoder_in_env=use_decoder_in_env, - state_dim=config.networks.state_dim, - rssm_hidden_dim=config.networks.rssm_hidden_dim, + state_dim=cfg.networks.state_dim, + rssm_hidden_dim=cfg.networks.rssm_hidden_dim, ) # def detach_state_and_belief(data): @@ -244,8 +258,8 @@ def make_dreamer( observation_in_key=observation_in_key, rssm_prior=rssm_prior, rssm_posterior=rssm_posterior, - mlp_num_units=config.networks.hidden_dim, - activation=get_activation(config.networks.activation), + mlp_num_units=cfg.networks.hidden_dim, + activation=get_activation(cfg.networks.activation), action_key=action_key, test_env=test_env, ) @@ -256,13 +270,13 @@ def make_dreamer( sigma_end=1.0, annealing_num_steps=1, mean=0.0, - std=config.networks.exploration_noise, + std=cfg.networks.exploration_noise, ) # Make Critic value_model = _dreamer_make_value_model( - hidden_dim=config.networks.hidden_dim, - activation=config.networks.activation, + hidden_dim=cfg.networks.hidden_dim, + activation=cfg.networks.activation, value_key=value_key, ) @@ -280,7 +294,32 @@ def make_dreamer( tensordict = actor_simulator(tensordict) value_model(tensordict) - return world_model, model_based_env, actor_simulator, value_model, actor_realworld + if cfg.logger.video: + model_based_env_eval = model_based_env.append_transform(DreamerDecoder()) + + def float_to_int(data): + reco_pixels = data.get("reco_pixels") * 255 + # assert (reco_pixels < 256).all() and (reco_pixels > 0).all(), (reco_pixels.min(), reco_pixels.max()) + reco_pixels = reco_pixels.to(torch.uint8) + return data.set("reco_pixels", reco_pixels) + + 
model_based_env_eval.append_transform(float_to_int) + model_based_env_eval.append_transform( + VideoRecorder( + logger=logger, tag="eval/simulated_rendering", in_keys=["reco_pixels"] + ) + ) + + else: + model_based_env_eval = None + return ( + world_model, + model_based_env, + model_based_env_eval, + actor_simulator, + value_model, + actor_realworld, + ) def make_collector(cfg, train_env, actor_model_explore): @@ -534,8 +573,8 @@ def _dreamer_make_mbenv( if use_decoder_in_env: mb_env_obs_decoder = SafeModule( decoder, - in_keys=[("next", "state"), ("next", "belief")], - out_keys=[("next", observation_out_key)], + in_keys=["state", "belief"], + out_keys=[observation_out_key], ) else: mb_env_obs_decoder = None diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index fc0637027a8..8b55ec089eb 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -34,7 +34,7 @@ VmasEnv, VmasWrapper, ) -from .model_based import DreamerEnv, ModelBasedEnvBase +from .model_based import DreamerDecoder, DreamerEnv, ModelBasedEnvBase from .transforms import ( ActionMask, AutoResetEnv, diff --git a/torchrl/envs/model_based/__init__.py b/torchrl/envs/model_based/__init__.py index 911cced5a9c..437146a4909 100644 --- a/torchrl/envs/model_based/__init__.py +++ b/torchrl/envs/model_based/__init__.py @@ -4,4 +4,4 @@ # LICENSE file in the root directory of this source tree. from .common import ModelBasedEnvBase -from .dreamer import DreamerEnv +from .dreamer import DreamerDecoder, DreamerEnv diff --git a/torchrl/envs/model_based/dreamer.py b/torchrl/envs/model_based/dreamer.py index 5f17fede18a..0aea4009b9c 100644 --- a/torchrl/envs/model_based/dreamer.py +++ b/torchrl/envs/model_based/dreamer.py @@ -14,6 +14,7 @@ from torchrl.data.utils import DEVICE_TYPING from torchrl.envs.common import EnvBase from torchrl.envs.model_based import ModelBasedEnvBase +from torchrl.envs.transforms.transforms import Transform class DreamerEnv(ModelBasedEnvBase): @@ -71,3 +72,14 @@ def decode_obs(self, tensordict: TensorDict, compute_latents=False) -> TensorDic if compute_latents: tensordict = self.world_model(tensordict) return self.obs_decoder(tensordict) + + +class DreamerDecoder(Transform): + def _call(self, tensordict): + return self.parent.base_env.obs_decoder(tensordict) + + def _reset(self, tensordict, tensordict_reset): + return self._call(tensordict_reset) + + def transform_observation_spec(self, observation_spec): + return observation_spec From d8816137149ee2504340d876e13548623c70dfac Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 17 Apr 2024 20:59:34 +0100 Subject: [PATCH 095/113] amend --- docs/source/reference/envs.rst | 1 + torchrl/envs/model_based/dreamer.py | 7 +++++++ 2 files changed, 8 insertions(+) diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index f24b31c71d3..969799f94af 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -799,6 +799,7 @@ Domain-specific ModelBasedEnvBase model_based.dreamer.DreamerEnv + model_based.dreamer.DreamerDecoder Libraries diff --git a/torchrl/envs/model_based/dreamer.py b/torchrl/envs/model_based/dreamer.py index 0aea4009b9c..f44c4aa025c 100644 --- a/torchrl/envs/model_based/dreamer.py +++ b/torchrl/envs/model_based/dreamer.py @@ -75,6 +75,13 @@ def decode_obs(self, tensordict: TensorDict, compute_latents=False) -> TensorDic class DreamerDecoder(Transform): + """A transform to record the decoded observations in Dreamer. + + Examples: + >>> model_based_env = DreamerEnv(...) 
+ >>> model_based_env_eval = model_based_env.append_transform(DreamerDecoder()) + """ + def _call(self, tensordict): return self.parent.base_env.obs_decoder(tensordict) From 449f9623ca91a9a6bb4c886f2556ca83f9242a2c Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 18 Apr 2024 08:30:05 +0100 Subject: [PATCH 096/113] amend --- sota-implementations/dreamer/dreamer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index ce5f5e0090e..3d840a24ce7 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -309,7 +309,7 @@ def compile_rssms(module): eval_rollout = test_env.rollout( eval_rollout_steps, policy, - # auto_cast_to_device=True, + auto_cast_to_device=True, break_when_any_done=True, ) test_env.apply(dump_video) @@ -323,7 +323,7 @@ def compile_rssms(module): eval_rollout = model_based_env_eval.rollout( eval_rollout_steps, policy, - # auto_cast_to_device=True, + auto_cast_to_device=True, break_when_any_done=True, ) model_based_env_eval.apply(dump_video) From b2473aab28d59e155eb974ba4dc240c107bb040b Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 18 Apr 2024 12:24:31 +0100 Subject: [PATCH 097/113] amend --- sota-implementations/dreamer/dreamer.py | 2 ++ sota-implementations/dreamer/dreamer_utils.py | 16 +++++++--------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index 3d840a24ce7..2ac1345184a 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -325,6 +325,8 @@ def compile_rssms(module): policy, auto_cast_to_device=True, break_when_any_done=True, + auto_reset=False, + tensordict=test_env._step_mdp(eval_rollout[..., -1]) ) model_based_env_eval.apply(dump_video) eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index b11dfed9e9f..9f82b1b1722 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -175,7 +175,7 @@ def make_dreamer( encoder = ObsEncoder() decoder = ObsDecoder() observation_in_key = "pixels" - obsevation_out_key = "reco_pixels" + observation_out_key = "reco_pixels" else: encoder = MLP( out_features=1024, @@ -190,7 +190,7 @@ def make_dreamer( activation_class=get_activation(cfg.networks.activation), ) observation_in_key = "observation" - obsevation_out_key = "reco_observation" + observation_out_key = "reco_observation" # Make RSSM rssm_prior = RSSMPrior( @@ -218,7 +218,7 @@ def make_dreamer( rssm_posterior, reward_module, observation_in_key=observation_in_key, - observation_out_key=obsevation_out_key, + observation_out_key=observation_out_key, ) world_model.to(device) @@ -237,7 +237,7 @@ def make_dreamer( reward_module=reward_module, rssm_prior=rssm_prior, decoder=decoder, - observation_out_key=obsevation_out_key, + observation_out_key=observation_out_key, test_env=test_env, use_decoder_in_env=use_decoder_in_env, state_dim=cfg.networks.state_dim, @@ -298,9 +298,11 @@ def make_dreamer( model_based_env_eval = model_based_env.append_transform(DreamerDecoder()) def float_to_int(data): - reco_pixels = data.get("reco_pixels") * 255 + reco_pixels_float = data.get("reco_pixels") + reco_pixels = (reco_pixels_float * 255).floor() # assert (reco_pixels < 256).all() and (reco_pixels > 0).all(), 
(reco_pixels.min(), reco_pixels.max()) reco_pixels = reco_pixels.to(torch.uint8) + data.set("reco_pixels_float", reco_pixels_float) return data.set("reco_pixels", reco_pixels) model_based_env_eval.append_transform(float_to_int) @@ -378,10 +380,6 @@ def check_no_pixels(data): Resize(image_size, image_size, in_keys=["pixels", ("next", "pixels")]) ) transforms.append(DeviceCastTransform(device=device)) - if use_autocast: - transforms.append( - DTypeCastTransform(dtype_in=torch.float32, dtype_out=torch.bfloat16) - ) replay_buffer = TensorDictReplayBuffer( pin_memory=False, From 6d8e006b2707ce2933f1fefd6f10d553025b2890 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 18 Apr 2024 12:33:13 +0100 Subject: [PATCH 098/113] amend --- sota-implementations/dreamer/dreamer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index 2ac1345184a..2348f2f7ba0 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -326,7 +326,7 @@ def compile_rssms(module): auto_cast_to_device=True, break_when_any_done=True, auto_reset=False, - tensordict=test_env._step_mdp(eval_rollout[..., -1]) + tensordict=test_env._step_mdp(eval_rollout[..., -1]).to(device) ) model_based_env_eval.apply(dump_video) eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() From 0a86244151b07af5b991c4f25a3448ae62b6ed16 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 18 Apr 2024 13:46:49 +0100 Subject: [PATCH 099/113] Update torchrl/objectives/dreamer.py --- torchrl/objectives/dreamer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py index 90e99379b72..f3dcd6f6d2c 100644 --- a/torchrl/objectives/dreamer.py +++ b/torchrl/objectives/dreamer.py @@ -307,7 +307,6 @@ def forward(self, tensordict: TensorDict) -> Tuple[TensorDict, TensorDict]: tensordict = tensordict.select("state", self.tensor_keys.belief).detach() tensordict = tensordict.reshape(-1) - # TODO: do we need exploration here? 
with timeit("actor_loss/time-rollout"), hold_out_net( self.model_based_env ), set_exploration_type(ExplorationType.RANDOM): From dac6a3696f8eedf93af29fc41d64403642448a1d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 18 Apr 2024 14:45:59 +0100 Subject: [PATCH 100/113] lint --- sota-implementations/dreamer/dreamer.py | 2 +- sota-implementations/dreamer/dreamer_utils.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index 2348f2f7ba0..4b80c9671f1 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -326,7 +326,7 @@ def compile_rssms(module): auto_cast_to_device=True, break_when_any_done=True, auto_reset=False, - tensordict=test_env._step_mdp(eval_rollout[..., -1]).to(device) + tensordict=test_env._step_mdp(eval_rollout[..., -1]).to(device), ) model_based_env_eval.apply(dump_video) eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index 9f82b1b1722..537f1ba99e9 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -10,6 +10,7 @@ import torch.nn as nn from hydra.utils import instantiate +from tensordict import NestedKey from tensordict.nn import ( InteractionType, ProbabilisticTensorDictModule, @@ -624,8 +625,8 @@ def _dreamer_make_world_model( rssm_prior, rssm_posterior, reward_module, - observation_in_key: str = "pixels", - observation_out_key: str = "reco_pixels", + observation_in_key: NestedKey = "pixels", + observation_out_key: NestedKey = "reco_pixels", ): # World Model and reward model rssm_rollout = RSSMRollout( From 7441f7930588bd69ca08006ee94e85f7f8c8654f Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 18 Apr 2024 18:49:40 +0100 Subject: [PATCH 101/113] lint --- sota-implementations/dreamer/dreamer.py | 2 +- sota-implementations/dreamer/dreamer_utils.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index 4b80c9671f1..adcbb6aaab8 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -22,7 +22,7 @@ # mixed precision training from torch.cuda.amp import GradScaler from torch.nn.utils import clip_grad_norm_ -from torch.profiler import profile, ProfilerActivity, record_function +from torch.profiler import profile, ProfilerActivity from torchrl._utils import logger as torchrl_logger, timeit from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import RSSMRollout diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index 537f1ba99e9..ff14871b011 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -35,7 +35,6 @@ DoubleToFloat, DreamerDecoder, DreamerEnv, - DTypeCastTransform, EnvCreator, ExcludeTransform, # ExcludeTransform, From 4f374d9bdc8e85c965cbc52a31aa8c55c2f98e3b Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 19 Apr 2024 10:04:39 +0100 Subject: [PATCH 102/113] lint --- sota-implementations/dreamer/dreamer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index adcbb6aaab8..e72c4cf63b8 100644 --- 
a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -326,7 +326,9 @@ def compile_rssms(module): auto_cast_to_device=True, break_when_any_done=True, auto_reset=False, - tensordict=test_env._step_mdp(eval_rollout[..., -1]).to(device), + tensordict=eval_rollout[..., 0] + .exclude("next", "action") + .to(device), ) model_based_env_eval.apply(dump_video) eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() From 2dfa7aed17c0fc00858955d3dbee6403df9b7e57 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 19 Apr 2024 04:10:50 -0700 Subject: [PATCH 103/113] amend --- sota-implementations/dreamer/dreamer.py | 27 +++---------------------- torchrl/modules/models/model_based.py | 10 ++++----- 2 files changed, 8 insertions(+), 29 deletions(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index e72c4cf63b8..a0679f65c05 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -22,7 +22,6 @@ # mixed precision training from torch.cuda.amp import GradScaler from torch.nn.utils import clip_grad_norm_ -from torch.profiler import profile, ProfilerActivity from torchrl._utils import logger as torchrl_logger, timeit from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import RSSMRollout @@ -189,11 +188,7 @@ def compile_rssms(module): with torch.autocast( device_type=device.type, dtype=torch.bfloat16, - ) if use_autocast else contextlib.nullcontext(), ( - profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) - if (i == 1 and k == 1) - else contextlib.nullcontext() - ) as prof: + ) if use_autocast else contextlib.nullcontext(): model_loss_td, sampled_tensordict = world_model_loss( sampled_tensordict ) @@ -202,8 +197,6 @@ def compile_rssms(module): + model_loss_td["loss_model_reco"] + model_loss_td["loss_model_reward"] ) - if i == 1 and k == 1: - prof.export_chrome_trace("trace_world_model.json") world_model_opt.zero_grad() if use_autocast: @@ -223,16 +216,9 @@ def compile_rssms(module): t_loss_actor_init = time.time() with torch.autocast( device_type=device.type, dtype=torch.bfloat16 - ) if use_autocast else contextlib.nullcontext(), ( - profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) - if (i == 1 and k == 1) - else contextlib.nullcontext() - ) as prof: + ) if use_autocast else contextlib.nullcontext(): actor_loss_td, sampled_tensordict = actor_loss(sampled_tensordict) - if i == 1 and k == 1: - prof.export_chrome_trace("trace_actor.json") - actor_opt.zero_grad() if use_autocast: scaler2.scale(actor_loss_td["loss_actor"]).backward() @@ -251,16 +237,9 @@ def compile_rssms(module): t_loss_critic_init = time.time() with torch.autocast( device_type=device.type, dtype=torch.bfloat16 - ) if use_autocast else contextlib.nullcontext(), ( - profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) - if (i == 1 and k == 1) - else contextlib.nullcontext() - ) as prof: + ) if use_autocast else contextlib.nullcontext(): value_loss_td, sampled_tensordict = value_loss(sampled_tensordict) - if i == 1 and k == 1: - prof.export_chrome_trace("trace_critic.json") - value_opt.zero_grad() if use_autocast: scaler3.scale(value_loss_td["loss_value"]).backward() diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py index 94e5ce18ae5..cc34d79bb5b 100644 --- a/torchrl/modules/models/model_based.py +++ b/torchrl/modules/models/model_based.py @@ -13,7 +13,7 @@ TensorDictSequential, ) from torch 
import nn - +from tensordict import LazyStackedTensorDict # from torchrl.modules.tensordict_module.rnn import GRUCell from torch.nn import GRUCell from torchrl._utils import timeit @@ -259,12 +259,12 @@ def forward(self, tensordict): tensordict_out.append(_tensordict) if t < time_steps - 1: - _tensordict = step_mdp( - _tensordict.select(*self.out_keys, strict=False), keep_other=False - ) + _tensordict = _tensordict.select(*self.in_keys, strict=False) _tensordict = update_values[t + 1].update(_tensordict) - return torch.stack(tensordict_out, tensordict.ndim - 1) + out = torch.stack(tensordict_out, tensordict.ndim - 1) + assert not any(isinstance(val, LazyStackedTensorDict) for val in out.values(True)), out + return out class RSSMPrior(nn.Module): From 4e7496986853c6a25ac576f8f2629ea3eeae20ee Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 22 Apr 2024 16:11:07 +0100 Subject: [PATCH 104/113] Update torchrl/objectives/dreamer.py --- torchrl/objectives/dreamer.py | 34 ---------------------------------- 1 file changed, 34 deletions(-) diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py index f3dcd6f6d2c..6921a38e0e8 100644 --- a/torchrl/objectives/dreamer.py +++ b/torchrl/objectives/dreamer.py @@ -153,40 +153,6 @@ def forward(self, tensordict: TensorDict) -> torch.Tensor: reward_loss = reward_loss.mean().unsqueeze(-1) # import ipdb; ipdb.set_trace() - # Alternative: - # def forward(self, tensordict: TensorDict) -> Tuple[TensorDict, TensorDict]: - # tensordict = tensordict.copy() - # tensordict.rename_key_( - # ("next", self.tensor_keys.reward), - # ("next", self.tensor_keys.true_reward), - # ) - # tensordict = self.world_model(tensordict) - # # compute model loss - # kl_loss = self.kl_loss( - # tensordict.get(("next", self.tensor_keys.prior_mean)), - # tensordict.get(("next", self.tensor_keys.prior_std)), - # tensordict.get(("next", self.tensor_keys.posterior_mean)), - # tensordict.get(("next", self.tensor_keys.posterior_std)), - # ) - # - # dist: IndependentNormal = self.decoder.get_dist(tensordict) - # reco_loss = -dist.log_prob( - # tensordict.get(("next", self.tensor_keys.pixels)) - # ).mean() - # # x = tensordict.get(("next", self.tensor_keys.pixels)) - # # loc = dist.base_dist.loc - # # scale = dist.base_dist.scale - # # reco_loss = -self.normal_log_probability(x, loc, scale).mean() - # - # dist: IndependentNormal = self.reward_model.get_dist(tensordict) - # reward_loss = -dist.log_prob( - # tensordict.get(("next", self.tensor_keys.true_reward)) - # ).mean() - # # x = tensordict.get(("next", self.tensor_keys.true_reward)) - # # loc = dist.base_dist.loc - # # scale = dist.base_dist.scale - # # reward_loss = -self.normal_log_probability(x, loc, scale).mean() - return ( TensorDict( { From b36f86bb0366decd53773e28e2039c089e9d0fd3 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 22 Apr 2024 16:11:15 +0100 Subject: [PATCH 105/113] Update torchrl/objectives/dreamer.py --- torchrl/objectives/dreamer.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py index 6921a38e0e8..30f6dd10699 100644 --- a/torchrl/objectives/dreamer.py +++ b/torchrl/objectives/dreamer.py @@ -412,22 +412,6 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: pass def forward(self, fake_data) -> torch.Tensor: - # lambda_target = fake_data.get("lambda_target") - # # TODO: I think this should be next state and belief - # td = fake_data.select(("next", "state"), ("next", "belief")) - # td = 
td.rename_key_(("next", "state"), "state") - # tensordict_select = td.rename_key_(("next", "belief"), "belief") - # # tensordict_select = fake_data.select(*self.value_model.in_keys, strict=False) - # dist = self.value_model.get_dist(tensordict_select) - # if self.discount_loss: - # discount = self.gamma * torch.ones_like( - # lambda_target, device=lambda_target.device - # ) - # discount[..., 0, :] = 1 - # discount = discount.cumprod(dim=-2) - # value_loss = -(discount * dist.log_prob(lambda_target).unsqueeze(-1)).mean() - # else: - # value_loss = -dist.log_prob(lambda_target).mean() lambda_target = fake_data.get("lambda_target") tensordict_select = fake_data.select(*self.value_model.in_keys, strict=False) self.value_model(tensordict_select) From 46e8ac0daa4b0a7e1f442cfcc913e1740171121d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 22 Apr 2024 17:34:35 +0100 Subject: [PATCH 106/113] fix examples --- .../linux_examples/scripts/run_test.sh | 48 +++++++++---------- sota-implementations/dreamer/dreamer.py | 6 +-- test/test_libs.py | 4 +- torchrl/modules/models/model_based.py | 7 ++- 4 files changed, 33 insertions(+), 32 deletions(-) diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index 1d11d481e3c..6ce551cb140 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -167,19 +167,17 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/di # logger.record_video=True \ # logger.record_frames=4 \ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dreamer/dreamer.py \ - total_frames=200 \ - init_random_frames=10 \ - batch_size=10 \ - frames_per_batch=200 \ - num_workers=4 \ - env_per_collector=2 \ - collector_device=cuda:0 \ - model_device=cuda:0 \ - optim_steps_per_batch=1 \ - record_video=True \ - record_frames=4 \ - buffer_size=120 \ - rssm_hidden_dim=17 + collector.total_frames=200 \ + collector.init_random_frames=10 \ + collector.frames_per_batch=200 \ + replay_buffer.batch_size=10 \ + env.n_parallel_envs=4 \ +# env_per_collector=2 \ + optimization.optim_steps_per_batch=1 \ + logger.video=True \ +# record_frames=4 \ + replay_buffer.buffer_size=120 \ + networks.rssm_hidden_dim=17 python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/td3/td3.py \ collector.total_frames=48 \ collector.init_random_frames=10 \ @@ -223,19 +221,17 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iq # With single envs python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dreamer/dreamer.py \ - total_frames=200 \ - init_random_frames=10 \ - batch_size=10 \ - frames_per_batch=200 \ - num_workers=2 \ - env_per_collector=1 \ - collector_device=cuda:0 \ - model_device=cuda:0 \ - optim_steps_per_batch=1 \ - record_video=True \ - record_frames=4 \ - buffer_size=120 \ - rssm_hidden_dim=17 + collector.total_frames=200 \ + collector.init_random_frames=10 \ + collector.frames_per_batch=200 \ + replay_buffer.batch_size=10 \ + env.n_parallel_envs=1 \ +# env_per_collector=2 \ + optimization.optim_steps_per_batch=1 \ + logger.video=True \ +# record_frames=4 \ + replay_buffer.buffer_size=120 \ + networks.rssm_hidden_dim=17 python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/ddpg/ddpg.py \ collector.total_frames=48 \ collector.init_random_frames=10 \ diff --git a/sota-implementations/dreamer/dreamer.py 
b/sota-implementations/dreamer/dreamer.py index a0679f65c05..f002c6420e7 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -188,7 +188,7 @@ def compile_rssms(module): with torch.autocast( device_type=device.type, dtype=torch.bfloat16, - ) if use_autocast else contextlib.nullcontext(): + ) if use_autocast else contextlib.nullcontext(): model_loss_td, sampled_tensordict = world_model_loss( sampled_tensordict ) @@ -216,7 +216,7 @@ def compile_rssms(module): t_loss_actor_init = time.time() with torch.autocast( device_type=device.type, dtype=torch.bfloat16 - ) if use_autocast else contextlib.nullcontext(): + ) if use_autocast else contextlib.nullcontext(): actor_loss_td, sampled_tensordict = actor_loss(sampled_tensordict) actor_opt.zero_grad() @@ -237,7 +237,7 @@ def compile_rssms(module): t_loss_critic_init = time.time() with torch.autocast( device_type=device.type, dtype=torch.bfloat16 - ) if use_autocast else contextlib.nullcontext(): + ) if use_autocast else contextlib.nullcontext(): value_loss_td, sampled_tensordict = value_loss(sampled_tensordict) value_opt.zero_grad() diff --git a/test/test_libs.py b/test/test_libs.py index 7ddb0d4fc02..cfcde85cc2c 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -3416,7 +3416,9 @@ def test_robohive(self, envname, from_pixels, from_depths): torchrl_logger.info("no camera") return try: - env = RoboHiveEnv(envname, from_pixels=from_pixels, from_depths=from_depths) + env = RoboHiveEnv( + envname, from_pixels=from_pixels, from_depths=from_depths + ) except AttributeError as err: if "'MjData' object has no attribute 'get_body_xipos'" in str(err): torchrl_logger.info("tcdm are broken") diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py index cc34d79bb5b..11ca9d12232 100644 --- a/torchrl/modules/models/model_based.py +++ b/torchrl/modules/models/model_based.py @@ -6,6 +6,7 @@ import torch from packaging import version +from tensordict import LazyStackedTensorDict from tensordict.nn import ( NormalParamExtractor, TensorDictModule, @@ -13,7 +14,7 @@ TensorDictSequential, ) from torch import nn -from tensordict import LazyStackedTensorDict + # from torchrl.modules.tensordict_module.rnn import GRUCell from torch.nn import GRUCell from torchrl._utils import timeit @@ -263,7 +264,9 @@ def forward(self, tensordict): _tensordict = update_values[t + 1].update(_tensordict) out = torch.stack(tensordict_out, tensordict.ndim - 1) - assert not any(isinstance(val, LazyStackedTensorDict) for val in out.values(True)), out + assert not any( + isinstance(val, LazyStackedTensorDict) for val in out.values(True) + ), out return out From e43aee440f05143f06bdcc73d2d218b54734d80f Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 22 Apr 2024 18:05:28 +0100 Subject: [PATCH 107/113] amend --- sota-implementations/dreamer/dreamer.py | 2 +- torchrl/modules/models/model_based.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index f002c6420e7..e7b346b2b22 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -177,7 +177,7 @@ def compile_rssms(module): t_loss_actor = 0.0 t_loss_critic = 0.0 t_loss_model = 0.0 - for k in range(optim_steps_per_batch): + for _ in range(optim_steps_per_batch): # sample from replay buffer t_sample_init = time.time() sampled_tensordict = replay_buffer.sample().reshape(-1, batch_length) diff --git 
a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py index 11ca9d12232..f8ee69363d9 100644 --- a/torchrl/modules/models/model_based.py +++ b/torchrl/modules/models/model_based.py @@ -19,7 +19,6 @@ from torch.nn import GRUCell from torchrl._utils import timeit -from torchrl.envs.utils import step_mdp from torchrl.modules.models.models import MLP UNSQUEEZE_RNN_INPUT = version.parse(torch.__version__) < version.parse("1.11") From fd23a54dd45492e3bf952f2df43ae53aaeb4a78d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 22 Apr 2024 18:06:12 +0100 Subject: [PATCH 108/113] amend --- .github/unittest/linux_examples/scripts/run_test.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index 6ce551cb140..517a2b1929a 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -228,6 +228,7 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dr env.n_parallel_envs=1 \ # env_per_collector=2 \ optimization.optim_steps_per_batch=1 \ + logger.backend=csv \ logger.video=True \ # record_frames=4 \ replay_buffer.buffer_size=120 \ From 98d402088234052f4c1c952ec14fcccbf10eebe1 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 22 Apr 2024 18:55:26 +0100 Subject: [PATCH 109/113] init --- .github/unittest/linux_examples/scripts/run_test.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index 517a2b1929a..1b0413a0716 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -175,6 +175,7 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dr # env_per_collector=2 \ optimization.optim_steps_per_batch=1 \ logger.video=True \ + logger.backend=csv \ # record_frames=4 \ replay_buffer.buffer_size=120 \ networks.rssm_hidden_dim=17 From a9e1cb05d566433959cde9ccf6c4cdde5ada890d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 22 Apr 2024 20:21:26 +0100 Subject: [PATCH 110/113] amend --- .github/unittest/linux_examples/scripts/run_test.sh | 4 ---- 1 file changed, 4 deletions(-) diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index 1b0413a0716..1d2434aeaf2 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -172,11 +172,9 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dr collector.frames_per_batch=200 \ replay_buffer.batch_size=10 \ env.n_parallel_envs=4 \ -# env_per_collector=2 \ optimization.optim_steps_per_batch=1 \ logger.video=True \ logger.backend=csv \ -# record_frames=4 \ replay_buffer.buffer_size=120 \ networks.rssm_hidden_dim=17 python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/td3/td3.py \ @@ -227,11 +225,9 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dr collector.frames_per_batch=200 \ replay_buffer.batch_size=10 \ env.n_parallel_envs=1 \ -# env_per_collector=2 \ optimization.optim_steps_per_batch=1 \ logger.backend=csv \ logger.video=True \ -# record_frames=4 \ replay_buffer.buffer_size=120 \ networks.rssm_hidden_dim=17 python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/ddpg/ddpg.py \ From 
12db41f12bafe2b633afefb81e1029f4d3d24b70 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 22 Apr 2024 21:22:25 +0100 Subject: [PATCH 111/113] amend --- .github/unittest/linux_examples/scripts/run_test.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index 1d2434aeaf2..aacd9484572 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -170,12 +170,12 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dr collector.total_frames=200 \ collector.init_random_frames=10 \ collector.frames_per_batch=200 \ - replay_buffer.batch_size=10 \ env.n_parallel_envs=4 \ optimization.optim_steps_per_batch=1 \ logger.video=True \ logger.backend=csv \ replay_buffer.buffer_size=120 \ + replay_buffer.batch_length=12 \ networks.rssm_hidden_dim=17 python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/td3/td3.py \ collector.total_frames=48 \ @@ -223,12 +223,12 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dr collector.total_frames=200 \ collector.init_random_frames=10 \ collector.frames_per_batch=200 \ - replay_buffer.batch_size=10 \ env.n_parallel_envs=1 \ optimization.optim_steps_per_batch=1 \ logger.backend=csv \ logger.video=True \ replay_buffer.buffer_size=120 \ + replay_buffer.batch_length=12 \ networks.rssm_hidden_dim=17 python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/ddpg/ddpg.py \ collector.total_frames=48 \ From 81ec41c9ac013c937f891d69d1bbaddbca1db3e5 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 23 Apr 2024 14:07:37 +0100 Subject: [PATCH 112/113] amend --- .github/unittest/linux_examples/scripts/run_test.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index aacd9484572..4587be88ddc 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -175,6 +175,7 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dr logger.video=True \ logger.backend=csv \ replay_buffer.buffer_size=120 \ + replay_buffer.batch_size=24 \ replay_buffer.batch_length=12 \ networks.rssm_hidden_dim=17 python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/td3/td3.py \ @@ -228,6 +229,7 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dr logger.backend=csv \ logger.video=True \ replay_buffer.buffer_size=120 \ + replay_buffer.batch_size=24 \ replay_buffer.batch_length=12 \ networks.rssm_hidden_dim=17 python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/ddpg/ddpg.py \ From 7733c37b83d18061a295519fe497797a43097bc1 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 23 Apr 2024 14:17:08 +0100 Subject: [PATCH 113/113] amend --- torchrl/record/loggers/csv.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchrl/record/loggers/csv.py b/torchrl/record/loggers/csv.py index 6bcd3f50c86..3f188a02a61 100644 --- a/torchrl/record/loggers/csv.py +++ b/torchrl/record/loggers/csv.py @@ -40,6 +40,8 @@ def add_scalar(self, name: str, value: float, global_step: Optional[int] = None) value = float(value) self.scalars[name].append((global_step, value)) filepath = os.path.join(self.log_dir, "scalars", "".join([name, ".csv"])) + if not 
os.path.isfile(filepath): + os.makedirs(Path(filepath).parent, exist_ok=True) if filepath not in self.files: self.files[filepath] = open(filepath, "a") fd = self.files[filepath] @@ -95,6 +97,8 @@ def add_text(self, tag, text, global_step: Optional[int] = None): filepath = os.path.join( self.log_dir, "texts", "".join([tag, str(global_step)]) + ".txt" ) + if not os.path.isfile(filepath): + os.makedirs(Path(filepath).parent, exist_ok=True) if filepath not in self.files: self.files[filepath] = open(filepath, "w+") fd = self.files[filepath]
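
Note on the final patch above: it makes the CSV logger create the parent directory of each scalar/text file before opening it in append or write mode, so logging to a fresh `log_dir` no longer fails with a missing-directory error. Below is a minimal standalone sketch of that pattern, assuming only the standard library; the helper name `append_scalar_record` and the single-column CSV payload are illustrative and not part of TorchRL's API.

```python
import os
from pathlib import Path


def append_scalar_record(log_dir: str, name: str, value: float) -> None:
    # Build the per-metric file path, mirroring the "scalars/<name>.csv" layout
    # used by the patched logger.
    filepath = os.path.join(log_dir, "scalars", f"{name}.csv")
    # Ensure the parent directory exists before opening the file; exist_ok=True
    # makes this safe to call on every write, which is why the extra
    # os.path.isfile() guard in the patch is a cheap no-op rather than a bug.
    os.makedirs(Path(filepath).parent, exist_ok=True)
    with open(filepath, "a") as fd:
        fd.write(f"{value}\n")


if __name__ == "__main__":
    # First call creates <tmp>/scalars/ on the fly; later calls just append.
    append_scalar_record("/tmp/dreamer_csv_demo", "eval/reward", 123.4)
```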