From 1d40688e2fee53765c41d692843f3c4afbc60932 Mon Sep 17 00:00:00 2001 From: belerico_t Date: Thu, 10 Aug 2023 18:11:21 +0200 Subject: [PATCH 1/2] Add Dreamer-V3 algo --- sheeprl/__init__.py | 1 + sheeprl/algos/args.py | 6 +- sheeprl/algos/dreamer_v1/agent.py | 2 +- sheeprl/algos/dreamer_v1/args.py | 4 +- sheeprl/algos/dreamer_v1/dreamer_v1.py | 12 +- sheeprl/algos/dreamer_v2/agent.py | 88 +- sheeprl/algos/dreamer_v2/args.py | 4 +- sheeprl/algos/dreamer_v2/dreamer_v2.py | 75 +- sheeprl/algos/dreamer_v2/utils.py | 54 +- sheeprl/algos/dreamer_v3/__init__.py | 0 sheeprl/algos/dreamer_v3/agent.py | 1052 ++++++++++++++++++++++++ sheeprl/algos/dreamer_v3/args.py | 132 +++ sheeprl/algos/dreamer_v3/dreamer_v3.py | 715 ++++++++++++++++ sheeprl/algos/dreamer_v3/loss.py | 86 ++ sheeprl/algos/dreamer_v3/utils.py | 117 +++ sheeprl/algos/p2e_dv1/p2e_dv1.py | 4 +- sheeprl/algos/p2e_dv2/p2e_dv2.py | 4 +- sheeprl/algos/ppo/args.py | 6 +- sheeprl/envs/diambra_wrapper.py | 2 +- sheeprl/envs/dummy.py | 20 +- sheeprl/envs/minedojo.py | 1 + sheeprl/envs/wrappers.py | 55 +- sheeprl/models/models.py | 8 +- sheeprl/utils/distribution.py | 127 +++ sheeprl/utils/utils.py | 9 + tests/test_algos/test_algos.py | 120 ++- 26 files changed, 2624 insertions(+), 80 deletions(-) create mode 100644 sheeprl/algos/dreamer_v3/__init__.py create mode 100644 sheeprl/algos/dreamer_v3/agent.py create mode 100644 sheeprl/algos/dreamer_v3/args.py create mode 100644 sheeprl/algos/dreamer_v3/dreamer_v3.py create mode 100644 sheeprl/algos/dreamer_v3/loss.py create mode 100644 sheeprl/algos/dreamer_v3/utils.py diff --git a/sheeprl/__init__.py b/sheeprl/__init__.py index dfefc505..120325c8 100644 --- a/sheeprl/__init__.py +++ b/sheeprl/__init__.py @@ -12,6 +12,7 @@ from sheeprl.algos.dreamer_v1 import dreamer_v1 from sheeprl.algos.dreamer_v2 import dreamer_v2 +from sheeprl.algos.dreamer_v3 import dreamer_v3 from sheeprl.algos.droq import droq from sheeprl.algos.p2e_dv1 import p2e_dv1 from sheeprl.algos.p2e_dv2 import p2e_dv2 diff --git a/sheeprl/algos/args.py b/sheeprl/algos/args.py index 5f503911..40063534 100644 --- a/sheeprl/algos/args.py +++ b/sheeprl/algos/args.py @@ -32,7 +32,11 @@ class StandardArgs: screen_size: int = Arg(default=64, help="the size of the pixel-from observations (if any)") frame_stack: int = Arg(default=-1, help="how many frame to stack (only for pixel-like observations)") frame_stack_dilation: int = Arg(default=1, help="the dilation between the stacked frames, 1 no dilation") - max_episode_steps: int = Arg(default=-1) + max_episode_steps: int = Arg( + default=-1, + help="the maximum duration in terms of number of steps of an episode, -1 to disable. " + "This value will be divided by the `action_repeat` value during the environment creation.", + ) def __setattr__(self, __name: str, __value: Any) -> None: super().__setattr__(__name, __value) diff --git a/sheeprl/algos/dreamer_v1/agent.py b/sheeprl/algos/dreamer_v1/agent.py index 1fa44671..181677f0 100644 --- a/sheeprl/algos/dreamer_v1/agent.py +++ b/sheeprl/algos/dreamer_v1/agent.py @@ -193,7 +193,7 @@ def __init__( self.continue_model = continue_model -class Player(nn.Module): +class PlayerDV1(nn.Module): """The model of the DreamerV1 player. 
Args: diff --git a/sheeprl/algos/dreamer_v1/args.py b/sheeprl/algos/dreamer_v1/args.py index a8b91b7c..b9517ce4 100644 --- a/sheeprl/algos/dreamer_v1/args.py +++ b/sheeprl/algos/dreamer_v1/args.py @@ -69,7 +69,9 @@ class DreamerV1Args(StandardArgs): ) action_repeat: int = Arg(default=2, help="the number of times an action is repeated") max_episode_steps: int = Arg( - default=1000, help="the maximum duration in terms of number of steps of an episode, -1 to disable" + default=1000, + help="the maximum duration in terms of number of steps of an episode, -1 to disable. " + "This value will be divided by the `action_repeat` value during the environment creation.", ) atari_noop_max: int = Arg( default=30, diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py index 935701d3..13e09376 100644 --- a/sheeprl/algos/dreamer_v1/dreamer_v1.py +++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py @@ -19,7 +19,7 @@ from torch.utils.data import BatchSampler from torchmetrics import MeanMetric -from sheeprl.algos.dreamer_v1.agent import Player, WorldModel, build_models +from sheeprl.algos.dreamer_v1.agent import PlayerDV1, WorldModel, build_models from sheeprl.algos.dreamer_v1.args import DreamerV1Args from sheeprl.algos.dreamer_v1.loss import actor_loss, critic_loss, reconstruction_loss from sheeprl.algos.dreamer_v2.utils import test @@ -209,8 +209,8 @@ def train( aggregator.update("Loss/state_loss", state_loss.detach()) aggregator.update("Loss/continue_loss", continue_loss.detach()) aggregator.update("State/kl", kl.detach()) - aggregator.update("State/p_entropy", p.entropy().mean().detach()) - aggregator.update("State/q_entropy", q.entropy().mean().detach()) + aggregator.update("State/post_entropy", p.entropy().mean().detach()) + aggregator.update("State/prior_entropy", q.entropy().mean().detach()) # Behaviour Learning # unflatten first 2 dimensions of recurrent and posterior states in order to have all the states on the first dimension. @@ -443,7 +443,7 @@ def main(): state["actor"] if args.checkpoint_path else None, state["critic"] if args.checkpoint_path else None, ) - player = Player( + player = PlayerDV1( world_model.encoder.module, world_model.rssm.recurrent_model.module, world_model.rssm.representation_model.module, @@ -482,8 +482,8 @@ def main(): "Loss/reward_loss": MeanMetric(sync_on_compute=False), "Loss/state_loss": MeanMetric(sync_on_compute=False), "Loss/continue_loss": MeanMetric(sync_on_compute=False), - "State/p_entropy": MeanMetric(sync_on_compute=False), - "State/q_entropy": MeanMetric(sync_on_compute=False), + "State/post_entropy": MeanMetric(sync_on_compute=False), + "State/prior_entropy": MeanMetric(sync_on_compute=False), "State/kl": MeanMetric(sync_on_compute=False), "Params/exploration_amout": MeanMetric(sync_on_compute=False), "Grads/world_model": MeanMetric(sync_on_compute=False), diff --git a/sheeprl/algos/dreamer_v2/agent.py b/sheeprl/algos/dreamer_v2/agent.py index 84ad0d2b..6236f5c2 100644 --- a/sheeprl/algos/dreamer_v2/agent.py +++ b/sheeprl/algos/dreamer_v2/agent.py @@ -25,6 +25,24 @@ class CNNEncoder(nn.Module): + """The Dreamer-V2 image encoder. This is composed of 4 `nn.Conv2d` with + kernel_size=3, stride=2 and padding=1. No bias is used if a `nn.LayerNorm` + is used after the convolution. This 4-stages model assumes that the image + is a 64x64. If more than one image is to be encoded, then those will + be concatenated on the channel dimension and fed to the encoder. 
+ + Args: + keys (Sequence[str]): the keys representing the image observations to encode. + input_channels (Sequence[int]): the input channels, one for each image observation to encode. + image_size (Tuple[int, int]): the image size as (Height,Width). + channels_multiplier (int): the multiplier for the output channels. Given the 4 stages, the 4 output channels + will be [1, 2, 4, 8] * `channels_multiplier`. + layer_norm (bool, optional): whether to apply the layer normalization. + Defaults to True. + activation (ModuleType, optional): the activation function. + Defaults to nn.ELU. + """ + def __init__( self, keys: Sequence[str], @@ -59,6 +77,24 @@ def forward(self, obs: Dict[str, Tensor]) -> Tensor: class MLPEncoder(nn.Module): + """The Dreamer-V3 vector encoder. This is composed of N `nn.Linear` layers, where + N is specified by `mlp_layers`. No bias is used if a `nn.LayerNorm` is used after the linear layer. + If more than one vector is to be encoded, then those will concatenated on the last + dimension before being fed to the encoder. + + Args: + keys (Sequence[str]): the keys representing the vector observations to encode. + input_dims (Sequence[int]): the dimensions of every vector to encode. + mlp_layers (int, optional): how many mlp layers. + Defaults to 4. + dense_units (int, optional): the dimension of every mlp. + Defaults to 512. + layer_norm (bool, optional): whether to apply the layer normalization. + Defaults to True. + activation (ModuleType, optional): the activation function after every layer. + Defaults to nn.ELU. + """ + def __init__( self, keys: Sequence[str], @@ -87,6 +123,25 @@ def forward(self, obs: Dict[str, Tensor]) -> Tensor: class CNNDecoder(nn.Module): + """The almost-exact inverse of the `CNNEncoder` class, where in 4 stages it reconstructs + the observation image to 64x64. If multiple images are to be reconstructed, + then it will create a dictionary with an entry for every reconstructed image. + No bias is used if a `nn.LayerNorm` is used after the `nn.Conv2dTranspose` layer. + + Args: + keys (Sequence[str]): the keys of the image observation to be reconstructed. + output_channels (Sequence[int]): the output channels, one for every image observation. + channels_multiplier (int): the channels multiplier, same for the encoder network. + latent_state_size (int): the size of the latent state. Before applying the decoder, + a `nn.Linear` layer is used to project the latent state to a feature vector. + cnn_encoder_output_dim (int): the output of the image encoder. + image_size (Tuple[int, int]): the final image size. + activation (nn.Module, optional): the activation function. + Defaults to nn.ELU. + layer_norm (bool, optional): whether to apply the layer normalization. + Defaults to True. + """ + def __init__( self, keys: Sequence[str], @@ -137,6 +192,25 @@ def forward(self, latent_states: Tensor) -> Dict[str, Tensor]: class MLPDecoder(nn.Module): + """The exact inverse of the MLPEncoder. This is composed of N `nn.Linear` layers, where + N is specified by `mlp_layers`. No bias is used if a `nn.LayerNorm` is used after the linear layer. + If more than one vector is to be decoded, then it will create a dictionary with an entry + for every reconstructed vector. + + Args: + keys (Sequence[str]): the keys representing the vector observations to decode. + output_dims (Sequence[int]): the dimensions of every vector to decode. + latent_state_size (int): the dimension of the latent state. + mlp_layers (int, optional): how many mlp layers. + Defaults to 4. 
+ dense_units (int, optional): the dimension of every mlp. + Defaults to 512. + layer_norm (bool, optional): whether to apply the layer normalization. + Defaults to True. + activation (ModuleType, optional): the activation function after every layer. + Defaults to nn.ELU. + """ + def __init__( self, keys: Sequence[str], @@ -168,8 +242,10 @@ def forward(self, latent_states: Tensor) -> Dict[str, Tensor]: class RecurrentModel(nn.Module): - """ - Recurrent model for the model-base Dreamer agent. + """Recurrent model for the model-base Dreamer-V3 agent. + This implementation uses the `sheeprl.models.models.LayerNormGRUCell`, which combines + the standard GRUCell from PyTorch with the `nn.LayerNorm`, where the normalization is applied + right after having computed the projection from the input to the weight space. Args: input_size (int): the input size of the model. @@ -559,7 +635,7 @@ def __init__( self.continue_model = continue_model -class Player(nn.Module): +class PlayerDV2(nn.Module): """ The model of the Dreamer_v1 player. @@ -605,7 +681,6 @@ def __init__( self.discrete_size = discrete_size self.recurrent_state_size = recurrent_state_size self.num_envs = num_envs - self.init_states() def init_states(self, reset_envs: Optional[Sequence[int]] = None) -> None: """Initialize the states and the actions for the ended environments. @@ -751,7 +826,6 @@ def build_models( # Sizes stochastic_size = args.stochastic_size * args.discrete_size latent_state_size = stochastic_size + args.recurrent_state_size - mlp_dims = [obs_space[k].shape[0] for k in mlp_keys] # Define models cnn_encoder = ( @@ -769,7 +843,7 @@ def build_models( mlp_encoder = ( MLPEncoder( keys=mlp_keys, - input_dims=mlp_dims, + input_dims=[obs_space[k].shape[0] for k in mlp_keys], mlp_layers=args.mlp_layers, dense_units=args.dense_units, activation=dense_act, @@ -826,7 +900,7 @@ def build_models( mlp_decoder = ( MLPDecoder( keys=mlp_keys, - output_dims=mlp_dims, + output_dims=[obs_space[k].shape[0] for k in mlp_keys], latent_state_size=latent_state_size, mlp_layers=args.mlp_layers, dense_units=args.dense_units, diff --git a/sheeprl/algos/dreamer_v2/args.py b/sheeprl/algos/dreamer_v2/args.py index a9aceeff..e1cc90bd 100644 --- a/sheeprl/algos/dreamer_v2/args.py +++ b/sheeprl/algos/dreamer_v2/args.py @@ -90,7 +90,9 @@ class DreamerV2Args(StandardArgs): max_step_expl_decay: int = Arg(default=0, help="the maximum number of decay steps") action_repeat: int = Arg(default=2, help="the number of times an action is repeated") max_episode_steps: int = Arg( - default=1000, help="the maximum duration in terms of number of steps of an episode, -1 to disable" + default=1000, + help="the maximum duration in terms of number of steps of an episode, -1 to disable. 
" + "This value will be divided by the `action_repeat` value during the environment creation.", ) atari_noop_max: int = Arg( default=30, diff --git a/sheeprl/algos/dreamer_v2/dreamer_v2.py b/sheeprl/algos/dreamer_v2/dreamer_v2.py index 596d21fc..a6d0f7aa 100644 --- a/sheeprl/algos/dreamer_v2/dreamer_v2.py +++ b/sheeprl/algos/dreamer_v2/dreamer_v2.py @@ -24,7 +24,7 @@ from torch.utils.data import BatchSampler from torchmetrics import MeanMetric -from sheeprl.algos.dreamer_v2.agent import Player, WorldModel, build_models +from sheeprl.algos.dreamer_v2.agent import PlayerDV2, WorldModel, build_models from sheeprl.algos.dreamer_v2.args import DreamerV2Args from sheeprl.algos.dreamer_v2.loss import reconstruction_loss from sheeprl.algos.dreamer_v2.utils import compute_lambda_values, test @@ -106,14 +106,29 @@ def train( mlp_keys (Sequence[str]): the mlp keys to encode/decode. actions_dim (Sequence[int]): the actions dimension. """ + + # The environment interaction goes like this: + # Actions: 0 a1 a2 a3 + # ^ \ ^ \ ^ \ + # / \ / \ / \ + # / v / v / v + # Observations: o0 o1 o2 o3 + # Rewards: 0 r1 r2 r3 + # Dones: 0 d1 d2 d3 + # Is-first 1 i1 i2 i3 + batch_size = args.per_rank_batch_size sequence_length = args.per_rank_sequence_length device = fabric.device batch_obs = {k: data[k] / 255 - 0.5 for k in cnn_keys} batch_obs.update({k: data[k] for k in mlp_keys}) + + # Given how the environment interaction works, we assume that the first element in a sequence + # is the first one, as if the environment has been reset data["is_first"][0, :] = torch.tensor([1.0], device=fabric.device).expand_as(data["is_first"][0, :]) # Dynamic Learning + stoch_state_size = args.stochastic_size * args.discrete_size recurrent_state = torch.zeros(1, batch_size, args.recurrent_state_size, device=device) posterior = torch.zeros(1, batch_size, args.stochastic_size, args.discrete_size, device=device) @@ -122,13 +137,11 @@ def train( recurrent_states = torch.zeros(sequence_length, batch_size, args.recurrent_state_size, device=device) # Initialize all the tensor to collect priors and posteriors states with their associated logits - priors_logits = torch.empty(sequence_length, batch_size, args.stochastic_size * args.discrete_size, device=device) + priors_logits = torch.empty(sequence_length, batch_size, stoch_state_size, device=device) posteriors = torch.empty(sequence_length, batch_size, args.stochastic_size, args.discrete_size, device=device) - posteriors_logits = torch.empty( - sequence_length, batch_size, args.stochastic_size * args.discrete_size, device=device - ) + posteriors_logits = torch.empty(sequence_length, batch_size, stoch_state_size, device=device) - # Embedded observations from the environment + # Embed observations from the environment embedded_obs = world_model.encoder(batch_obs) for i in range(0, sequence_length): @@ -197,17 +210,17 @@ def train( aggregator.update("Loss/continue_loss", continue_loss.detach()) aggregator.update("State/kl", kl.mean().detach()) aggregator.update( - "State/p_entropy", + "State/post_entropy", Independent(OneHotCategorical(logits=posteriors_logits.detach()), 1).entropy().mean().detach(), ) aggregator.update( - "State/q_entropy", + "State/prior_entropy", Independent(OneHotCategorical(logits=priors_logits.detach()), 1).entropy().mean().detach(), ) # Behaviour Learning # (1, batch_size * sequence_length, stochastic_size * discrete_size) - imagined_prior = posteriors.detach().reshape(1, -1, args.stochastic_size * args.discrete_size) + imagined_prior = posteriors.detach().reshape(1, 
-1, stoch_state_size) # (1, batch_size * sequence_length, recurrent_state_size). recurrent_state = recurrent_states.detach().reshape(1, -1, args.recurrent_state_size) @@ -219,7 +232,7 @@ def train( imagined_trajectories = torch.empty( args.horizon + 1, batch_size * sequence_length, - args.stochastic_size * args.discrete_size + args.recurrent_state_size, + stoch_state_size + args.recurrent_state_size, device=device, ) imagined_trajectories[0] = imagined_latent_state @@ -233,6 +246,19 @@ def train( ) imagined_actions[0] = torch.zeros(1, batch_size * sequence_length, data["actions"].shape[-1]) + # The imagination goes like this, with H=3: + # Actions: 0 a'1 a'2 a'3 + # ^ \ ^ \ ^ \ + # / \ / \ / \ + # / v / v / v + # States: z0 ---> z'1 ---> z'2 ---> z'3 + # Rewards: r'0 r'1 r'2 r'3 + # Values: v'0 v'1 v'2 v'3 + # Lambda-values: l'0 l'1 l'2 + # Continues: c0 c'1 c'2 c'3 + # where z0 comes from the posterior (is initialized as the concatenation of the posteriors and the recurrent states), + # while z'i is the imagined states (prior) + # Imagine trajectories in the latent space for i in range(1, args.horizon + 1): # (1, batch_size * sequence_length, sum(actions_dim)) @@ -243,7 +269,7 @@ def train( imagined_prior, recurrent_state = world_model.rssm.imagination(imagined_prior, recurrent_state, actions) # Update current state - imagined_prior = imagined_prior.view(1, -1, args.stochastic_size * args.discrete_size) + imagined_prior = imagined_prior.view(1, -1, stoch_state_size) imagined_latent_state = torch.cat((imagined_prior, recurrent_state), -1) imagined_trajectories[i] = imagined_latent_state @@ -259,7 +285,7 @@ def train( else: continues = torch.ones_like(predicted_rewards.detach()) * args.gamma - # Compute the lambda_values, by passing as last values the values of the last imagined state + # Compute the lambda_values, by passing as last value the value of the last imagined state # (horizon, batch_size * sequence_length, 1) lambda_values = compute_lambda_values( predicted_rewards[:-1], @@ -270,11 +296,26 @@ def train( lmbda=args.lmbda, ) - # Compute the discounts to multiply to the lambda values + # Compute the discounts to multiply the lambda values with torch.no_grad(): discount = torch.cumprod(torch.cat((torch.ones_like(continues[:1]), continues[:-1]), 0), 0) # Actor optimization step. Eq. 6 from the paper + # Given the following diagram, with H=3: + # Actions: 0 [a'1] [a'2] a'3 + # ^ \ ^ \ ^ \ + # / \ / \ / \ + # / v / v / v + # States: [z0] -> [z'1] -> z'2 -> z'3 + # Values: [v'0] [v'1] v'2 v'3 + # Lambda-values: l'0 [l'1] [l'2] + # Entropies: [e'1] [e'2] + # The quantities wrapped into `[]` are the ones used for the actor optimization. + # From Hafner (https://github.com/danijar/dreamerv2/blob/main/dreamerv2/agent.py#L253): + # `Two states are lost at the end of the trajectory, one for the boostrap + # value prediction and one because the corresponding action does not lead + # anywhere anymore. 
One target is lost at the start of the trajectory + # because the initial state comes from the replay buffer.` actor_optimizer.zero_grad(set_to_none=True) policies: Sequence[Distribution] = actor(imagined_trajectories[:-2].detach())[1] @@ -309,7 +350,7 @@ def train( aggregator.update("Loss/policy_loss", policy_loss.detach()) # Predict the values distribution only for the first H (horizon) imagined states (to match the dimension with the lambda values), - # It removes the last imagined state in the trajectory because it is used only for computing correclty the lambda values + # It removes the last imagined state in the trajectory because it is used for bootstrapping qv = Independent(Normal(critic(imagined_trajectories.detach()[:-1]), 1), 1) # Critic optimization step. Eq. 5 from the paper. @@ -434,7 +475,7 @@ def main(): state["critic"] if args.checkpoint_path else None, state["target_critic"] if args.checkpoint_path else None, ) - player = Player( + player = PlayerDV2( world_model.encoder.module, world_model.rssm.recurrent_model.module, world_model.rssm.representation_model.module, @@ -474,8 +515,8 @@ def main(): "Loss/reward_loss": MeanMetric(sync_on_compute=False), "Loss/state_loss": MeanMetric(sync_on_compute=False), "Loss/continue_loss": MeanMetric(sync_on_compute=False), - "State/p_entropy": MeanMetric(sync_on_compute=False), - "State/q_entropy": MeanMetric(sync_on_compute=False), + "State/post_entropy": MeanMetric(sync_on_compute=False), + "State/prior_entropy": MeanMetric(sync_on_compute=False), "State/kl": MeanMetric(sync_on_compute=False), "Params/exploration_amout": MeanMetric(sync_on_compute=False), "Grads/world_model": MeanMetric(sync_on_compute=False), diff --git a/sheeprl/algos/dreamer_v2/utils.py b/sheeprl/algos/dreamer_v2/utils.py index 27a2b548..392d664c 100644 --- a/sheeprl/algos/dreamer_v2/utils.py +++ b/sheeprl/algos/dreamer_v2/utils.py @@ -1,3 +1,4 @@ +import os from typing import TYPE_CHECKING, List, Optional, Union import gymnasium as gym @@ -11,16 +12,13 @@ from sheeprl.utils.env import make_dict_env if TYPE_CHECKING: - from sheeprl.algos.dreamer_v2.agent import Player + from sheeprl.algos.dreamer_v1.agent import PlayerDV1 + from sheeprl.algos.dreamer_v1.args import DreamerV1Args + from sheeprl.algos.dreamer_v2.agent import PlayerDV2 + from sheeprl.algos.dreamer_v2.args import DreamerV2Args -from sheeprl.algos.dreamer_v1.args import DreamerV1Args -from sheeprl.algos.dreamer_v2.args import DreamerV2Args - -def compute_stochastic_state( - logits: Tensor, - discrete: int = 32, -) -> Tensor: +def compute_stochastic_state(logits: Tensor, discrete: int = 32, sample=True) -> Tensor: """ Compute the stochastic state from the logits computed by the transition or representaiton model. @@ -28,16 +26,19 @@ def compute_stochastic_state( logits (Tensor): logits from either the representation model or the transition model. discrete (int, optional): the size of the Categorical variables. Defaults to 32. + sample (bool): whether or not to sample the stochastic state. + Default to True. Returns: The sampled stochastic state. 
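The straight-through one-hot sampling this helper wraps can be sanity-checked in isolation. A minimal, runnable sketch (illustrative only, not part of the patch), assuming the default stochastic_size=32 and discrete=32:

import torch
from torch.distributions import Independent, OneHotCategoricalStraightThrough

logits = torch.randn(2, 32 * 32)              # (batch, stochastic_size * discrete)
logits = logits.view(2, -1, 32)               # (batch, 32 categorical variables, 32 classes each)
dist = Independent(OneHotCategoricalStraightThrough(logits=logits, validate_args=False), 1)
state = dist.rsample()                        # hard one-hot sample with straight-through gradients
assert state.shape == (2, 32, 32)
assert torch.all(state.sum(-1) == 1)          # exactly one active class per categorical variable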
""" logits = logits.view(*logits.shape[:-1], -1, discrete) - dist = Independent(OneHotCategoricalStraightThrough(logits=logits), 1) - return dist.rsample() + dist = Independent(OneHotCategoricalStraightThrough(logits=logits, validate_args=False), 1) + stochastic_state = dist.rsample() if sample else dist.mode + return stochastic_state -def init_weights(m: nn.Module): +def init_weights(m: nn.Module, mode: str = "normal"): """ Initialize the parameters of the m module acording to the Xavier normal method. @@ -45,13 +46,17 @@ def init_weights(m: nn.Module): Args: m (nn.Module): the module to be initialized. """ - if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): - nn.init.xavier_normal_(m.weight.data) + if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)): + if mode == "normal": + nn.init.xavier_normal_(m.weight.data) + elif mode == "uniform": + nn.init.xavier_uniform_(m.weight.data) + elif mode == "zero": + nn.init.constant_(m.weight.data, 0) + else: + raise RuntimeError(f"Unrecognized initialization: {mode}. Choose between: `normal`, `uniform` and `zero`") if m.bias is not None: nn.init.constant_(m.bias.data, 0) - elif isinstance(m, nn.Linear): - nn.init.xavier_normal_(m.weight.data) - nn.init.constant_(m.bias.data, 0) def compute_lambda_values( @@ -76,26 +81,28 @@ def compute_lambda_values( @torch.no_grad() def test( - player: "Player", + player: Union["PlayerDV2", "PlayerDV1"], fabric: Fabric, - args: Union[DreamerV2Args, "DreamerV1Args"], + args: Union["DreamerV2Args", "DreamerV1Args"], cnn_keys: List[str], mlp_keys: List[str], test_name: str = "", + sample_actions: bool = False, ): """Test the model on the environment with the frozen model. Args: - player (Player): the agent which contains all the models needed to play. + player (PlayerDV2): the agent which contains all the models needed to play. fabric (Fabric): the fabric instance. - args (Union[DreamerV2Args, DreamerV1Args]): the hyper-parameters. + args (Union[DreamerV3Args, DreamerV2Args, DreamerV1Args]): the hyper-parameters. cnn_keys (Sequence[str]): the keys encoded by the cnn encoder. mlp_keys (Sequence[str]): the keys encoded by the mlp encoder. test_name (str): the name of the test. Default to "". 
""" + log_dir = fabric.logger.log_dir if len(fabric.loggers) > 0 else os.getcwd() env: gym.Env = make_dict_env( - args.env_id, args.seed, 0, args, fabric.logger.log_dir, "test" + (f"_{test_name}" if test_name != "" else "") + args.env_id, args.seed, 0, args, log_dir, "test" + (f"_{test_name}" if test_name != "" else "") )() done = False cumulative_rew = 0 @@ -114,7 +121,7 @@ def test( elif k in mlp_keys: preprocessed_obs[k] = v[None, ...].to(device) real_actions = player.get_greedy_action( - preprocessed_obs, False, {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} + preprocessed_obs, sample_actions, {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} ) if player.actor.is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() @@ -128,5 +135,6 @@ def test( done = done or truncated or args.dry_run cumulative_rew += reward fabric.print("Test - Reward:", cumulative_rew) - fabric.logger.log_metrics({"Test/cumulative_reward": cumulative_rew}, 0) + if len(fabric.loggers) > 0: + fabric.logger.log_metrics({"Test/cumulative_reward": cumulative_rew}, 0) env.close() diff --git a/sheeprl/algos/dreamer_v3/__init__.py b/sheeprl/algos/dreamer_v3/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sheeprl/algos/dreamer_v3/agent.py b/sheeprl/algos/dreamer_v3/agent.py new file mode 100644 index 00000000..325178e1 --- /dev/null +++ b/sheeprl/algos/dreamer_v3/agent.py @@ -0,0 +1,1052 @@ +import copy +from functools import partial +from typing import Any, Dict, List, Optional, Sequence, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from lightning.fabric import Fabric +from lightning.fabric.wrappers import _FabricModule +from torch import Tensor, device, nn +from torch.distributions import ( + Distribution, + Independent, + Normal, + OneHotCategorical, + OneHotCategoricalStraightThrough, + TanhTransform, + TransformedDistribution, +) +from torch.distributions.utils import probs_to_logits + +from sheeprl.algos.dreamer_v2.agent import WorldModel +from sheeprl.algos.dreamer_v2.utils import compute_stochastic_state, init_weights +from sheeprl.algos.dreamer_v3.args import DreamerV3Args +from sheeprl.models.models import CNN, MLP, DeCNN, LayerNormGRUCell, MultiDecoder, MultiEncoder +from sheeprl.utils.distribution import TruncatedNormal +from sheeprl.utils.model import LayerNormChannelLast, ModuleType, cnn_forward +from sheeprl.utils.utils import symlog + + +class CNNEncoder(nn.Module): + """The Dreamer-V3 image encoder. This is composed of 4 `nn.Conv2d` with + kernel_size=3, stride=2 and padding=1. No bias is used if a `nn.LayerNorm` + is used after the convolution. This 4-stages model assumes that the image + is a 64x64 and it ends with a resolution of 4x4. If more than one image is to be encoded, then those will + be concatenated on the channel dimension and fed to the encoder. + + Args: + keys (Sequence[str]): the keys representing the image observations to encode. + input_channels (Sequence[int]): the input channels, one for each image observation to encode. + image_size (Tuple[int, int]): the image size as (Height,Width). + channels_multiplier (int): the multiplier for the output channels. Given the 4 stages, the 4 output channels + will be [1, 2, 4, 8] * `channels_multiplier`. + layer_norm (bool, optional): whether to apply the layer normalization. + Defaults to True. + activation (ModuleType, optional): the activation function. + Defaults to nn.SiLU. 
+ """ + + def __init__( + self, + keys: Sequence[str], + input_channels: Sequence[int], + image_size: Tuple[int, int], + channels_multiplier: int, + layer_norm: bool = True, + activation: ModuleType = nn.SiLU, + ) -> None: + super().__init__() + self.keys = keys + self.input_dim = (sum(input_channels), *image_size) + self.model = nn.Sequential( + CNN( + input_channels=self.input_dim[0], + hidden_channels=(torch.tensor([1, 2, 4, 8]) * channels_multiplier).tolist(), + cnn_layer=nn.Conv2d, + layer_args={"kernel_size": 4, "stride": 2, "padding": 1, "bias": not layer_norm}, + activation=activation, + norm_layer=[LayerNormChannelLast for _ in range(4)] if layer_norm else None, + norm_args=[{"normalized_shape": (2**i) * channels_multiplier, "eps": 1e-3} for i in range(4)] + if layer_norm + else None, + ), + nn.Flatten(-3, -1), + ) + with torch.no_grad(): + self.output_dim = self.model(torch.zeros(1, *self.input_dim)).shape[-1] + + def forward(self, obs: Dict[str, Tensor]) -> Tensor: + x = torch.cat([obs[k] for k in self.keys], -3) # channels dimension + return cnn_forward(self.model, x, x.shape[-3:], (-1,)) + + +class MLPEncoder(nn.Module): + """The Dreamer-V3 vector encoder. This is composed of N `nn.Linear` layers, where + N is specified by `mlp_layers`. No bias is used if a `nn.LayerNorm` is used after the linear layer. + If more than one vector is to be encoded, then those will concatenated on the last + dimension before being fed to the encoder. + + Args: + keys (Sequence[str]): the keys representing the vector observations to encode. + input_dims (Sequence[int]): the dimensions of every vector to encode. + mlp_layers (int, optional): how many mlp layers. + Defaults to 4. + dense_units (int, optional): the dimension of every mlp. + Defaults to 512. + layer_norm (bool, optional): whether to apply the layer normalization. + Defaults to True. + activation (ModuleType, optional): the activation function after every layer. + Defaults to nn.SiLU. + symlog_inputs (bool, optional): whether to squash the input with the symlog function. + Defaults to True. + """ + + def __init__( + self, + keys: Sequence[str], + input_dims: Sequence[int], + mlp_layers: int = 4, + dense_units: int = 512, + layer_norm: bool = True, + activation: ModuleType = nn.SiLU, + symlog_inputs: bool = True, + ) -> None: + super().__init__() + self.keys = keys + self.input_dim = sum(input_dims) + self.model = MLP( + self.input_dim, + None, + [dense_units] * mlp_layers, + activation=activation, + layer_args={"bias": not layer_norm}, + norm_layer=[nn.LayerNorm for _ in range(mlp_layers)] if layer_norm else None, + norm_args=[{"normalized_shape": dense_units, "eps": 1e-3} for _ in range(mlp_layers)] + if layer_norm + else None, + ) + self.output_dim = dense_units + self.symlog_inputs = symlog_inputs + + def forward(self, obs: Dict[str, Tensor]) -> Tensor: + x = torch.cat([symlog(obs[k]) if self.symlog_inputs else obs[k] for k in self.keys], -1) + return self.model(x) + + +class CNNDecoder(nn.Module): + """The exact inverse of the `CNNEncoder` class. It assumes an initial resolution + of 4x4, and in 4 stages reconstructs the observation image to 64x64. If multiple + images are to be reconstructed, then it will create a dictionary with an entry + for every reconstructed image. No bias is used if a `nn.LayerNorm` is used after + the `nn.Conv2dTranspose` layer. + + Args: + keys (Sequence[str]): the keys of the image observation to be reconstructed. + output_channels (Sequence[int]): the output channels, one for every image observation. 
+ channels_multiplier (int): the channels multiplier, same for the encoder network. + latent_state_size (int): the size of the latent state. Before applying the decoder, + a `nn.Linear` layer is used to project the latent state to a feature vector + of dimension [8 * `channels_multiplier`, 4, 4]. + cnn_encoder_output_dim (int): the output of the image encoder. It should be equal to + 8 * `channels_multiplier` * 4 * 4. + image_size (Tuple[int, int]): the final image size. + activation (nn.Module, optional): the activation function. + Defaults to nn.SiLU. + layer_norm (bool, optional): whether to apply the layer normalization. + Defaults to True. + """ + + def __init__( + self, + keys: Sequence[str], + output_channels: Sequence[int], + channels_multiplier: int, + latent_state_size: int, + cnn_encoder_output_dim: int, + image_size: Tuple[int, int], + activation: nn.Module = nn.SiLU, + layer_norm: bool = True, + ) -> None: + super().__init__() + self.keys = keys + self.output_channels = output_channels + self.cnn_encoder_output_dim = cnn_encoder_output_dim + self.image_size = image_size + self.output_dim = (sum(output_channels), *image_size) + self.model = nn.Sequential( + nn.Linear(latent_state_size, cnn_encoder_output_dim), + nn.Unflatten(1, (-1, 4, 4)), + DeCNN( + input_channels=8 * channels_multiplier, + hidden_channels=(torch.tensor([4, 2, 1]) * channels_multiplier).tolist() + [self.output_dim[0]], + cnn_layer=nn.ConvTranspose2d, + layer_args=[ + {"kernel_size": 4, "stride": 2, "padding": 1, "bias": not layer_norm}, + {"kernel_size": 4, "stride": 2, "padding": 1, "bias": not layer_norm}, + {"kernel_size": 4, "stride": 2, "padding": 1, "bias": not layer_norm}, + {"kernel_size": 4, "stride": 2, "padding": 1}, + ], + activation=[activation, activation, activation, None], + norm_layer=[LayerNormChannelLast for _ in range(3)] + [None] if layer_norm else None, + norm_args=[ + {"normalized_shape": (2 ** (4 - i - 2)) * channels_multiplier, "eps": 1e-3} for i in range(3) + ] + + [None] + if layer_norm + else None, + ), + ) + + def forward(self, latent_states: Tensor) -> Dict[str, Tensor]: + cnn_out = cnn_forward(self.model, latent_states, (latent_states.shape[-1],), self.output_dim) + 0.5 + return {k: rec_obs for k, rec_obs in zip(self.keys, torch.split(cnn_out, self.output_channels, -3))} + + +class MLPDecoder(nn.Module): + """The exact inverse of the MLPEncoder. This is composed of N `nn.Linear` layers, where + N is specified by `mlp_layers`. No bias is used if a `nn.LayerNorm` is used after the linear layer. + If more than one vector is to be decoded, then it will create a dictionary with an entry + for every reconstructed vector. + + Args: + keys (Sequence[str]): the keys representing the vector observations to decode. + output_dims (Sequence[int]): the dimensions of every vector to decode. + latent_state_size (int): the dimension of the latent state. + mlp_layers (int, optional): how many mlp layers. + Defaults to 4. + dense_units (int, optional): the dimension of every mlp. + Defaults to 512. + layer_norm (bool, optional): whether to apply the layer normalization. + Defaults to True. + activation (ModuleType, optional): the activation function after every layer. + Defaults to nn.SiLU. 
+ """ + + def __init__( + self, + keys: Sequence[str], + output_dims: Sequence[str], + latent_state_size: int, + mlp_layers: int = 4, + dense_units: int = 512, + activation: ModuleType = nn.SiLU, + layer_norm: bool = True, + ) -> None: + super().__init__() + self.output_dims = output_dims + self.keys = keys + self.model = MLP( + latent_state_size, + None, + [dense_units] * mlp_layers, + activation=activation, + layer_args={"bias": not layer_norm}, + norm_layer=[nn.LayerNorm for _ in range(mlp_layers)] if layer_norm else None, + norm_args=[{"normalized_shape": dense_units, "eps": 1e-3} for _ in range(mlp_layers)] + if layer_norm + else None, + ) + self.heads = nn.ModuleList([nn.Linear(dense_units, mlp_dim) for mlp_dim in self.output_dims]) + + def forward(self, latent_states: Tensor) -> Dict[str, Tensor]: + x = self.model(latent_states) + return {k: h(x) for k, h in zip(self.keys, self.heads)} + + +class RecurrentModel(nn.Module): + """Recurrent model for the model-base Dreamer-V3 agent. + This implementation uses the `sheeprl.models.models.LayerNormGRUCell`, which combines + the standard GRUCell from PyTorch with the `nn.LayerNorm`, where the normalization is applied + right after having computed the projection from the input to the weight space. + + Args: + input_size (int): the input size of the model. + dense_units (int): the number of dense units. + recurrent_state_size (int): the size of the recurrent state. + activation_fn (nn.Module): the activation function. + Default to SiLU. + layer_norm (bool, optional): whether to use the LayerNorm inside the GRU. + Defaults to True. + """ + + def __init__( + self, + input_size: int, + recurrent_state_size: int, + dense_units: int, + activation_fn: nn.Module = nn.SiLU, + layer_norm: bool = True, + ) -> None: + super().__init__() + self.mlp = MLP( + input_dims=input_size, + output_dim=None, + hidden_sizes=[dense_units], + activation=activation_fn, + layer_args={"bias": not layer_norm}, + norm_layer=[nn.LayerNorm] if layer_norm else None, + norm_args=[{"normalized_shape": dense_units, "eps": 1e-3}] if layer_norm else None, + ) + self.rnn = LayerNormGRUCell(dense_units, recurrent_state_size, bias=False, batch_first=False, layer_norm=True) + + def forward(self, input: Tensor, recurrent_state: Tensor) -> Tensor: + """ + Compute the next recurrent state from the latent state (stochastic and recurrent states) and the actions. + + Args: + input (Tensor): the input tensor composed by the stochastic state and the actions concatenated together. + recurrent_state (Tensor): the previous recurrent state. + + Returns: + the computed recurrent output and recurrent state. + """ + feat = self.mlp(input) + out = self.rnn(feat, recurrent_state) + return out + + +class RSSM(nn.Module): + """RSSM model for the model-base Dreamer agent. + + Args: + recurrent_model (nn.Module): the recurrent model of the RSSM model described in [https://arxiv.org/abs/1811.04551](https://arxiv.org/abs/1811.04551). + representation_model (nn.Module): the representation model composed by a multi-layer perceptron to compute the stochastic part of the latent state. + For more information see [https://arxiv.org/abs/2010.02193](https://arxiv.org/abs/2010.02193). + transition_model (nn.Module): the transition model described in [https://arxiv.org/abs/2010.02193](https://arxiv.org/abs/2010.02193). + The model is composed by a multi-layer perceptron to predict the stochastic part of the latent state. + discrete (int, optional): the size of the Categorical variables. + Defaults to 32. 
+ unimix: (float, optional): the percentage of uniform distribution to inject into the categorical distribution over states, + i.e. given some logits `l` and probabilities `p = softmax(l)`, then `p = (1 - self.unimix) * p + self.unimix * unif`, + where `unif = `1 / self.discrete`. + Defaults to 0.01. + """ + + def __init__( + self, + recurrent_model: nn.Module, + representation_model: nn.Module, + transition_model: nn.Module, + discrete: int = 32, + unimix: float = 0.01, + ) -> None: + super().__init__() + self.recurrent_model = recurrent_model + self.representation_model = representation_model + self.transition_model = transition_model + self.discrete = discrete + self.unimix = unimix + + def dynamic( + self, posterior: Tensor, recurrent_state: Tensor, action: Tensor, embedded_obs: Tensor, is_first: Tensor + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + """ + Perform one step of the dynamic learning: + Recurrent model: compute the recurrent state from the previous latent space, the action taken by the agent, + i.e., it computes the deterministic state (or ht). + Transition model: predict the prior from the recurrent output. + Representation model: compute the posterior from the recurrent state and from + the embedded observations provided by the environment. + For more information see [https://arxiv.org/abs/1811.04551](https://arxiv.org/abs/1811.04551) + and [https://arxiv.org/abs/2010.02193](https://arxiv.org/abs/2010.02193). + + Args: + posterior (Tensor): the stochastic state computed by the representation model (posterior). It is expected + to be of dimension `[stoch_size, self.discrete]`, which by default is `[32, 32]`. + recurrent_state (Tensor): a tuple representing the recurrent state of the recurrent model. + action (Tensor): the action taken by the agent. + embedded_obs (Tensor): the embedded observations provided by the environment. + is_first (Tensor): if this is the first step in the episode. + + Returns: + The recurrent state (Tensor): the recurrent state of the recurrent model. + The posterior stochastic state (Tensor): computed by the representation model + The prior stochastic state (Tensor): computed by the transition model + The logits of the posterior state (Tensor): computed by the transition model from the recurrent state. + The logits of the prior state (Tensor): computed by the transition model from the recurrent state. + from the recurrent state and the embbedded observation. 
+ """ + action = (1 - is_first) * action + recurrent_state = (1 - is_first) * recurrent_state + is_first * torch.tanh(torch.zeros_like(recurrent_state)) + posterior = posterior.view(*posterior.shape[:-2], -1) + posterior = (1 - is_first) * posterior + is_first * self._transition(recurrent_state, sample_state=False)[ + 1 + ].view_as(posterior) + recurrent_state = self.recurrent_model(torch.cat((posterior, action), -1), recurrent_state) + prior_logits, prior = self._transition(recurrent_state) + posterior_logits, posterior = self._representation(recurrent_state, embedded_obs) + return recurrent_state, posterior, prior, posterior_logits, prior_logits + + def _uniform_mix(self, logits: Tensor) -> Tensor: + dim = logits.dim() + if dim == 3: + logits = logits.view(*logits.shape[:-1], -1, self.discrete) + elif dim != 4: + raise RuntimeError(f"The logits expected shape is 3 or 4: received a {dim}D tensor") + if self.unimix > 0.0: + probs = logits.softmax(dim=-1) + uniform = torch.ones_like(probs) / self.discrete + probs = (1 - self.unimix) * probs + self.unimix * uniform + logits = probs_to_logits(probs) + logits = logits.view(*logits.shape[:-2], -1) + return logits + + def _representation(self, recurrent_state: Tensor, embedded_obs: Tensor) -> Tuple[Tensor, Tensor]: + """ + Args: + recurrent_state (Tensor): the recurrent state of the recurrent model, i.e., + what is called h or deterministic state in [https://arxiv.org/abs/1811.04551](https://arxiv.org/abs/1811.04551). + embedded_obs (Tensor): the embedded real observations provided by the environment. + + Returns: + logits (Tensor): the logits of the distribution of the posterior state. + posterior (Tensor): the sampled posterior stochastic state. + """ + logits: Tensor = self.representation_model(torch.cat((recurrent_state, embedded_obs), -1)) + logits = self._uniform_mix(logits) + return logits, compute_stochastic_state(logits, discrete=self.discrete) + + def _transition(self, recurrent_out: Tensor, sample_state=True) -> Tuple[Tensor, Tensor]: + """ + Args: + recurrent_out (Tensor): the output of the recurrent model, i.e., the deterministic part of the latent space. + sampler_state (bool): whether or not to sample the stochastic state. + Default to True + + Returns: + logits (Tensor): the logits of the distribution of the prior state. + prior (Tensor): the sampled prior stochastic state. + """ + logits: Tensor = self.transition_model(recurrent_out) + logits = self._uniform_mix(logits) + return logits, compute_stochastic_state(logits, discrete=self.discrete, sample=sample_state) + + def imagination(self, prior: Tensor, recurrent_state: Tensor, actions: Tensor) -> Tuple[Tensor, Tensor]: + """ + One-step imagination of the next latent state. + It can be used several times to imagine trajectories in the latent space (Transition Model). + + Args: + prior (Tensor): the prior state. + recurrent_state (Tensor): the recurrent state of the recurrent model. + actions (Tensor): the actions taken by the agent. + + Returns: + The imagined prior state (Tuple[Tensor, Tensor]): the imagined prior state. + The recurrent state (Tensor). + """ + recurrent_state = self.recurrent_model(torch.cat((prior, actions), -1), recurrent_state) + _, imagined_prior = self._transition(recurrent_state) + return imagined_prior, recurrent_state + + +class PlayerDV3(nn.Module): + """ + The model of the Dreamer_v3 player. + + Args: + encoder (_FabricModule): the encoder. + recurrent_model (_FabricModule): the recurrent model. 
+ representation_model (_FabricModule): the representation model. + actor (_FabricModule): the actor. + actions_dim (Sequence[int]): the dimension of the actions. + expl_amout (float): the exploration amout to use during training. + num_envs (int): the number of environments. + stochastic_size (int): the size of the stochastic state. + recurrent_state_size (int): the size of the recurrent state. + device (torch.device): the device to work on. + transition_model (_FabricModule): the transition model. + discrete_size (int): the dimension of a single Categorical variable in the + stochastic state (prior or posterior). + Defaults to 32. + """ + + def __init__( + self, + encoder: _FabricModule, + rssm: RSSM, + actor: _FabricModule, + actions_dim: Sequence[int], + expl_amount: float, + num_envs: int, + stochastic_size: int, + recurrent_state_size: int, + device: device = "cpu", + discrete_size: int = 32, + ) -> None: + super().__init__() + self.encoder = encoder + self.rssm = RSSM( + recurrent_model=rssm.recurrent_model.module, + representation_model=rssm.representation_model.module, + transition_model=rssm.transition_model.module, + discrete=rssm.discrete, + unimix=rssm.unimix, + ) + self.actor = actor + self.device = device + self.expl_amount = expl_amount + self.actions_dim = actions_dim + self.stochastic_size = stochastic_size + self.discrete_size = discrete_size + self.recurrent_state_size = recurrent_state_size + self.num_envs = num_envs + + @torch.no_grad() + def init_states(self, reset_envs: Optional[Sequence[int]] = None) -> None: + """Initialize the states and the actions for the ended environments. + + Args: + reset_envs (Optional[Sequence[int]], optional): which environments' states to reset. + If None, then all environments' states are reset. + Defaults to None. + """ + if reset_envs is None or len(reset_envs) == 0: + self.actions = torch.zeros(1, self.num_envs, np.sum(self.actions_dim), device=self.device) + self.recurrent_state = torch.tanh( + torch.zeros(1, self.num_envs, self.recurrent_state_size, device=self.device) + ) + self.stochastic_state = self.rssm._transition(self.recurrent_state, sample_state=False)[1].reshape( + 1, self.num_envs, -1 + ) + else: + self.actions[:, reset_envs] = torch.zeros_like(self.actions[:, reset_envs]) + self.recurrent_state[:, reset_envs] = torch.tanh(torch.zeros_like(self.recurrent_state[:, reset_envs])) + self.stochastic_state[:, reset_envs] = self.rssm._transition( + self.recurrent_state[:, reset_envs], sample_state=False + )[1].reshape(1, len(reset_envs), -1) + + def get_exploration_action( + self, + obs: Dict[str, Tensor], + is_continuous: bool, + mask: Optional[Dict[str, np.ndarray]] = None, + ) -> Tensor: + """ + Return the actions with a certain amount of noise for exploration. + + Args: + obs (Dict[str, Tensor]): the current observations. + is_continuous (bool): whether or not the actions are continuous. + + Returns: + The actions the agent has to perform. 
+ """ + actions = self.get_greedy_action(obs, mask=mask) + if is_continuous: + self.actions = torch.cat(actions, -1) + if self.expl_amount > 0.0: + self.actions = torch.clip(Normal(self.actions, self.expl_amount).sample(), -1, 1) + expl_actions = [self.actions] + else: + expl_actions = [] + for act in actions: + sample = OneHotCategorical(logits=torch.zeros_like(act), validate_args=False).sample().to(self.device) + expl_actions.append( + torch.where(torch.rand(act.shape[:1], device=self.device) < self.expl_amount, sample, act) + ) + self.actions = torch.cat(expl_actions, -1) + return tuple(expl_actions) + + def get_greedy_action( + self, + obs: Dict[str, Tensor], + is_training: bool = True, + mask: Optional[Dict[str, np.ndarray]] = None, + ) -> Sequence[Tensor]: + """ + Return the greedy actions. + + Args: + obs (Dict[str, Tensor]): the current observations. + is_training (bool): whether it is training. + Default to True. + + Returns: + The actions the agent has to perform. + """ + embedded_obs = self.encoder(obs) + self.recurrent_state = self.rssm.recurrent_model( + torch.cat((self.stochastic_state, self.actions), -1), self.recurrent_state + ) + _, self.stochastic_state = self.rssm._representation(self.recurrent_state, embedded_obs) + self.stochastic_state = self.stochastic_state.view( + *self.stochastic_state.shape[:-2], self.stochastic_size * self.discrete_size + ) + actions, _ = self.actor(torch.cat((self.stochastic_state, self.recurrent_state), -1), is_training, mask) + self.actions = torch.cat(actions, -1) + return actions + + +class Actor(nn.Module): + """ + The wrapper class of the Dreamer_v2 Actor model. + + Args: + latent_state_size (int): the dimension of the latent state (stochastic size + recurrent_state_size). + actions_dim (Sequence[int]): the dimension in output of the actor. + The number of actions if continuous, the dimension of the action if discrete. + is_continuous (bool): whether or not the actions are continuous. + init_std (float): the amount to sum to the input of the softplus function for the standard deviation. + Default to 5. + min_std (float): the minimum standard deviation for the actions. + Default to 0.1. + dense_units (int): the dimension of the hidden dense layers. + Default to 400. + dense_act (int): the activation function to apply after the dense layers. + Default to nn.SiLU. + distribution (str): the distribution for the action. Possible values are: `auto`, `discrete`, `normal`, + `tanh_normal` and `trunc_normal`. If `auto`, then the distribution will be `discrete` if the + space is a discrete one, `trunc_normal` otherwise. + Defaults to `auto`. + layer_norm (bool, optional): whether to apply the layer normalization. + Defaults to True. + unimix: (float, optional): the percentage of uniform distribution to inject into the categorical distribution over actions, + i.e. given some logits `l` and probabilities `p = softmax(l)`, then `p = (1 - self.unimix) * p + self.unimix * unif`, + where `unif = `1 / self.discrete`. + Defaults to 0.01. 
+ """ + + def __init__( + self, + latent_state_size: int, + actions_dim: Sequence[int], + is_continuous: bool, + init_std: float = 0.0, + min_std: float = 0.1, + dense_units: int = 400, + dense_act: nn.Module = nn.SiLU, + mlp_layers: int = 4, + distribution: str = "auto", + layer_norm: bool = True, + unimix: float = 0.01, + ) -> None: + super().__init__() + self.distribution = distribution.lower() + if self.distribution not in ("auto", "normal", "tanh_normal", "discrete", "trunc_normal"): + raise ValueError( + "The distribution must be on of: `auto`, `discrete`, `normal`, `tanh_normal` and `trunc_normal`. " + f"Found: {self.distribution}" + ) + if self.distribution == "discrete" and is_continuous: + raise ValueError("You have choose a discrete distribution but `is_continuous` is true") + if self.distribution == "auto": + if is_continuous: + self.distribution = "trunc_normal" + else: + self.distribution = "discrete" + self.model = MLP( + input_dims=latent_state_size, + output_dim=None, + hidden_sizes=[dense_units] * mlp_layers, + activation=dense_act, + flatten_dim=None, + layer_args={"bias": not layer_norm}, + norm_layer=[nn.LayerNorm for _ in range(mlp_layers)] if layer_norm else None, + norm_args=[{"normalized_shape": dense_units, "eps": 1e-3} for _ in range(mlp_layers)] + if layer_norm + else None, + ) + if is_continuous: + self.mlp_heads = nn.ModuleList([nn.Linear(dense_units, np.sum(actions_dim) * 2)]) + else: + self.mlp_heads = nn.ModuleList([nn.Linear(dense_units, action_dim) for action_dim in actions_dim]) + self.actions_dim = actions_dim + self.is_continuous = is_continuous + self.init_std = torch.tensor(init_std) + self.min_std = min_std + self._unimix = unimix + + def forward( + self, state: Tensor, is_training: bool = True, mask: Optional[Dict[str, np.ndarray]] = None + ) -> Tuple[Sequence[Tensor], Sequence[Distribution]]: + """ + Call the forward method of the actor model and reorganizes the result with shape (batch_size, *, num_actions), + where * means any number of dimensions including None. + + Args: + state (Tensor): the current state of shape (batch_size, *, stochastic_size + recurrent_state_size). + + Returns: + The tensor of the actions taken by the agent with shape (batch_size, *, num_actions). 
+ The distribution of the actions + """ + out: Tensor = self.model(state) + pre_dist: List[Tensor] = [head(out) for head in self.mlp_heads] + if self.is_continuous: + mean, std = torch.chunk(pre_dist[0], 2, -1) + if self.distribution == "tanh_normal": + mean = 5 * torch.tanh(mean / 5) + std = F.softplus(std + self.init_std) + self.min_std + actions_dist = Normal(mean, std) + actions_dist = Independent(TransformedDistribution(actions_dist, TanhTransform()), 1) + elif self.distribution == "normal": + actions_dist = Normal(mean, std) + actions_dist = Independent(actions_dist, 1) + elif self.distribution == "trunc_normal": + std = 2 * torch.sigmoid((std + self.init_std) / 2) + self.min_std + dist = TruncatedNormal(torch.tanh(mean), std, -1, 1) + actions_dist = Independent(dist, 1) + if is_training: + actions = actions_dist.rsample() + else: + sample = actions_dist.sample((100,)) + log_prob = actions_dist.log_prob(sample) + actions = sample[log_prob.argmax(0)].view(1, 1, -1) + actions = [actions] + actions_dist = [actions_dist] + else: + actions_dist: List[Distribution] = [] + actions: List[Tensor] = [] + for logits in pre_dist: + actions_dist.append( + OneHotCategoricalStraightThrough(logits=self._uniform_mix(logits), validate_args=False) + ) + if is_training: + actions.append(actions_dist[-1].rsample()) + else: + actions.append(actions_dist[-1].mode) + return tuple(actions), tuple(actions_dist) + + def _uniform_mix(self, logits: Tensor) -> Tensor: + if self._unimix > 0.0: + probs = logits.softmax(dim=-1) + uniform = torch.ones_like(probs) / probs.shape[-1] + probs = (1 - self._unimix) * probs + self._unimix * uniform + logits = probs_to_logits(probs) + return logits + + +class MinedojoActor(Actor): + def __init__( + self, + latent_state_size: int, + actions_dim: Sequence[int], + is_continuous: bool, + init_std: float = 0, + min_std: float = 0.1, + dense_units: int = 400, + dense_act: nn.Module = nn.SiLU, + mlp_layers: int = 4, + distribution: str = "auto", + layer_norm: bool = True, + ) -> None: + super().__init__( + latent_state_size, + actions_dim, + is_continuous, + init_std, + min_std, + dense_units, + dense_act, + mlp_layers, + distribution, + layer_norm, + ) + + def forward( + self, state: Tensor, is_training: bool = True, mask: Optional[Dict[str, np.ndarray]] = None + ) -> Tuple[Sequence[Tensor], Sequence[Distribution]]: + """ + Call the forward method of the actor model and reorganizes the result with shape (batch_size, *, num_actions), + where * means any number of dimensions including None. + + Args: + state (Tensor): the current state of shape (batch_size, *, stochastic_size + recurrent_state_size). + + Returns: + The tensor of the actions taken by the agent with shape (batch_size, *, num_actions). 
+ The distribution of the actions + """ + out: Tensor = self.model(state) + actions_logits: List[Tensor] = [head(out) for head in self.mlp_heads] + actions_dist: List[Distribution] = [] + actions: List[Tensor] = [] + functional_action = None + for i, logits in enumerate(actions_logits): + if mask is not None: + if i == 0: + logits[torch.logical_not(mask["mask_action_type"].expand_as(logits))] = -torch.inf + elif i == 1: + mask["mask_craft_smelt"] = mask["mask_craft_smelt"].expand_as(logits) + for t in range(functional_action.shape[0]): + for b in range(functional_action.shape[1]): + sampled_action = functional_action[t, b].item() + if sampled_action == 15: # Craft action + logits[t, b][torch.logical_not(mask["mask_craft_smelt"][t, b])] = -torch.inf + elif i == 2: + mask["mask_destroy"][t, b] = mask["mask_destroy"].expand_as(logits) + mask["mask_equip/place"] = mask["mask_equip/place"].expand_as(logits) + for t in range(functional_action.shape[0]): + for b in range(functional_action.shape[1]): + sampled_action = functional_action[t, b].item() + if sampled_action in (16, 17): # Equip/Place action + logits[t, b][torch.logical_not(mask["mask_equip/place"][t, b])] = -torch.inf + elif sampled_action == 18: # Destroy action + logits[t, b][torch.logical_not(mask["mask_destroy"][t, b])] = -torch.inf + actions_dist.append(OneHotCategoricalStraightThrough(logits=logits)) + if is_training: + actions.append(actions_dist[-1].rsample()) + else: + actions.append(actions_dist[-1].mode) + if functional_action is None: + functional_action = actions[0].argmax(dim=-1) # [T, B] + return tuple(actions), tuple(actions_dist) + + +def build_models( + fabric: Fabric, + actions_dim: Sequence[int], + is_continuous: bool, + args: DreamerV3Args, + obs_space: Dict[str, Any], + cnn_keys: Sequence[str], + mlp_keys: Sequence[str], + world_model_state: Optional[Dict[str, Tensor]] = None, + actor_state: Optional[Dict[str, Tensor]] = None, + critic_state: Optional[Dict[str, Tensor]] = None, + target_critic_state: Optional[Dict[str, Tensor]] = None, +) -> Tuple[WorldModel, _FabricModule, _FabricModule, torch.nn.Module]: + """Build the models and wrap them with Fabric. + + Args: + fabric (Fabric): the fabric object. + actions_dim (Sequence[int]): the dimension of the actions. + is_continuous (bool): whether or not the actions are continuous. + args (DreamerV3Args): the hyper-parameters of DreamerV2. + obs_space (Dict[str, Any]): the observation space. + cnn_keys (Sequence[str]): the keys of the observation space to encode through the cnn encoder. + mlp_keys (Sequence[str]): the keys of the observation space to encode through the mlp encoder. + world_model_state (Dict[str, Tensor], optional): the state of the world model. + Default to None. + actor_state: (Dict[str, Tensor], optional): the state of the actor. + Default to None. + critic_state: (Dict[str, Tensor], optional): the state of the critic. + Default to None. + target_critic_state: (Dict[str, Tensor], optional): the state of the critic. + Default to None. + + Returns: + The world model (WorldModel): composed by the encoder, rssm, observation and reward models and the continue model. + The actor (_FabricModule). + The critic (_FabricModule). + The target critic (nn.Module). 
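For orientation, a hedged usage sketch of `build_models` (illustrative only, not part of the patch; `fabric`, `args`, and `obs_space` are assumed to be built beforehand as in `dreamer_v3.py`, and "rgb" is a hypothetical image key):

world_model, actor, critic, target_critic = build_models(
    fabric,
    actions_dim=(4,),        # e.g. a single 4-action discrete space
    is_continuous=False,
    args=args,               # DreamerV3Args
    obs_space=obs_space,     # dict observation space
    cnn_keys=["rgb"],        # hypothetical cnn key
    mlp_keys=[],
)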
+ """ + if args.cnn_channels_multiplier <= 0: + raise ValueError(f"cnn_channels_multiplier must be greater than zero, given {args.cnn_channels_multiplier}") + if args.dense_units <= 0: + raise ValueError(f"dense_units must be greater than zero, given {args.dense_units}") + try: + cnn_act = getattr(nn, args.cnn_act) + except: + raise ValueError( + f"Invalid value for cnn_act, given {args.cnn_act}, " + "must be one of https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity" + ) + try: + dense_act = getattr(nn, args.dense_act) + except: + raise ValueError( + f"Invalid value for dense_act, given {args.dense_act}, " + "must be one of https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity" + ) + + # Sizes + stochastic_size = args.stochastic_size * args.discrete_size + latent_state_size = stochastic_size + args.recurrent_state_size + + # Define models + cnn_encoder = ( + CNNEncoder( + keys=cnn_keys, + input_channels=[int(np.prod(obs_space[k].shape[:-2])) for k in cnn_keys], + image_size=obs_space[cnn_keys[0]].shape[-2:], + channels_multiplier=args.cnn_channels_multiplier, + layer_norm=args.layer_norm, + activation=cnn_act, + ) + if cnn_keys is not None and len(cnn_keys) > 0 + else None + ) + mlp_encoder = ( + MLPEncoder( + keys=mlp_keys, + input_dims=[obs_space[k].shape[0] for k in mlp_keys], + mlp_layers=args.mlp_layers, + dense_units=args.dense_units, + activation=dense_act, + layer_norm=args.layer_norm, + ) + if mlp_keys is not None and len(mlp_keys) > 0 + else None + ) + encoder = MultiEncoder(cnn_encoder, mlp_encoder) + recurrent_model = RecurrentModel( + int(np.sum(actions_dim)) + stochastic_size, + args.recurrent_state_size, + args.dense_units, + layer_norm=args.layer_norm, + ) + representation_model = MLP( + input_dims=args.recurrent_state_size + encoder.cnn_output_dim + encoder.mlp_output_dim, + output_dim=stochastic_size, + hidden_sizes=[args.hidden_size], + activation=dense_act, + flatten_dim=None, + layer_args={"bias": not args.layer_norm}, + norm_layer=[nn.LayerNorm] if args.layer_norm else None, + norm_args=[{"normalized_shape": args.hidden_size, "eps": 1e-3}] if args.layer_norm else None, + ) + transition_model = MLP( + input_dims=args.recurrent_state_size, + output_dim=stochastic_size, + hidden_sizes=[args.hidden_size], + activation=dense_act, + flatten_dim=None, + layer_args={"bias": not args.layer_norm}, + norm_layer=[nn.LayerNorm] if args.layer_norm else None, + norm_args=[{"normalized_shape": args.hidden_size, "eps": 1e-3}] if args.layer_norm else None, + ) + rssm = RSSM( + recurrent_model.apply(init_weights), + representation_model.apply(init_weights), + transition_model.apply(init_weights), + args.discrete_size, + ) + cnn_decoder = ( + CNNDecoder( + keys=cnn_keys, + output_channels=[int(np.prod(obs_space[k].shape[:-2])) for k in cnn_keys], + channels_multiplier=args.cnn_channels_multiplier, + latent_state_size=latent_state_size, + cnn_encoder_output_dim=cnn_encoder.output_dim, + image_size=obs_space[cnn_keys[0]].shape[-2:], + activation=cnn_act, + layer_norm=args.layer_norm, + ) + if cnn_keys is not None and len(cnn_keys) > 0 + else None + ) + mlp_decoder = ( + MLPDecoder( + keys=mlp_keys, + output_dims=[obs_space[k].shape[0] for k in mlp_keys], + latent_state_size=latent_state_size, + mlp_layers=args.mlp_layers, + dense_units=args.dense_units, + activation=dense_act, + layer_norm=args.layer_norm, + ) + if mlp_keys is not None and len(mlp_keys) > 0 + else None + ) + observation_model = 
MultiDecoder(cnn_decoder, mlp_decoder) + reward_model = MLP( + input_dims=stochastic_size + args.recurrent_state_size, + output_dim=args.bins, + hidden_sizes=[args.dense_units] * args.mlp_layers, + activation=dense_act, + flatten_dim=None, + layer_args={"bias": not args.layer_norm}, + norm_layer=[nn.LayerNorm for _ in range(args.mlp_layers)] if args.layer_norm else None, + norm_args=[{"normalized_shape": args.dense_units, "eps": 1e-3} for _ in range(args.mlp_layers)] + if args.layer_norm + else None, + ) + continue_model = MLP( + input_dims=stochastic_size + args.recurrent_state_size, + output_dim=1, + hidden_sizes=[args.dense_units] * args.mlp_layers, + activation=dense_act, + flatten_dim=None, + layer_args={"bias": not args.layer_norm}, + norm_layer=[nn.LayerNorm for _ in range(args.mlp_layers)] if args.layer_norm else None, + norm_args=[{"normalized_shape": args.dense_units, "eps": 1e-3} for _ in range(args.mlp_layers)] + if args.layer_norm + else None, + ) + world_model = WorldModel( + encoder.apply(init_weights), + rssm, + observation_model.apply(init_weights), + reward_model.apply(init_weights), + continue_model.apply(init_weights), + ) + if "minedojo" in args.env_id: + actor = MinedojoActor( + stochastic_size + args.recurrent_state_size, + actions_dim, + is_continuous, + args.actor_init_std, + args.actor_min_std, + args.dense_units, + dense_act, + args.mlp_layers, + distribution=args.actor_distribution, + layer_norm=args.layer_norm, + ) + else: + actor = Actor( + stochastic_size + args.recurrent_state_size, + actions_dim, + is_continuous, + args.actor_init_std, + args.actor_min_std, + args.dense_units, + dense_act, + args.mlp_layers, + distribution=args.actor_distribution, + layer_norm=args.layer_norm, + ) + critic = MLP( + input_dims=stochastic_size + args.recurrent_state_size, + output_dim=args.bins, + hidden_sizes=[args.dense_units] * args.mlp_layers, + activation=dense_act, + flatten_dim=None, + layer_args={"bias": not args.layer_norm}, + norm_layer=[nn.LayerNorm for _ in range(args.mlp_layers)] if args.layer_norm else None, + norm_args=[{"normalized_shape": args.dense_units, "eps": 1e-3} for _ in range(args.mlp_layers)] + if args.layer_norm + else None, + ) + actor.apply(init_weights) + critic.apply(init_weights) + + if args.hafner_initialization: + actor.mlp_heads.apply(partial(init_weights, mode="uniform")) + critic.model[-1].apply(partial(init_weights, mode="zero")) + rssm.transition_model.model[-1].apply(partial(init_weights, mode="uniform")) + rssm.representation_model.model[-1].apply(partial(init_weights, mode="uniform")) + world_model.reward_model.model[-1].apply(partial(init_weights, mode="zero")) + world_model.continue_model.model[-1].apply(partial(init_weights, mode="uniform")) + if mlp_decoder is not None: + mlp_decoder.heads.apply(partial(init_weights, mode="uniform")) + if cnn_decoder is not None: + cnn_decoder.model[-1].model[-1].apply(partial(init_weights, mode="uniform")) + + # Load models from checkpoint + if world_model_state: + world_model.load_state_dict(world_model_state) + if actor_state: + actor.load_state_dict(actor_state) + if critic_state: + critic.load_state_dict(critic_state) + + # Setup models with Fabric + world_model.encoder = fabric.setup_module(world_model.encoder) + world_model.observation_model = fabric.setup_module(world_model.observation_model) + world_model.reward_model = fabric.setup_module(world_model.reward_model) + world_model.rssm.recurrent_model = fabric.setup_module(world_model.rssm.recurrent_model) + 
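+    # Each sub-module is wrapped with Fabric separately, so that it is moved to the
+    # proper device (and wrapped for the chosen strategy) while the underlying module
+    # stays reachable through its `.module` attribute, as done when building the player.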
world_model.rssm.representation_model = fabric.setup_module(world_model.rssm.representation_model) + world_model.rssm.transition_model = fabric.setup_module(world_model.rssm.transition_model) + if world_model.continue_model: + world_model.continue_model = fabric.setup_module(world_model.continue_model) + actor = fabric.setup_module(actor) + critic = fabric.setup_module(critic) + target_critic = copy.deepcopy(critic.module) + if target_critic_state: + target_critic.load_state_dict(target_critic_state) + + return world_model, actor, critic, target_critic diff --git a/sheeprl/algos/dreamer_v3/args.py b/sheeprl/algos/dreamer_v3/args.py new file mode 100644 index 00000000..aac2d492 --- /dev/null +++ b/sheeprl/algos/dreamer_v3/args.py @@ -0,0 +1,132 @@ +from dataclasses import dataclass +from typing import List, Optional + +from sheeprl.algos.dreamer_v2.args import DreamerV2Args +from sheeprl.utils.parser import Arg + + +@dataclass +class DreamerV3Args(DreamerV2Args): + env_id: str = Arg(default="dmc_walker_walk", help="the id of the environment") + + # Experiment settings + per_rank_batch_size: int = Arg(default=16, help="the batch size for each rank") + per_rank_sequence_length: int = Arg(default=64, help="the sequence length for each rank") + total_steps: int = Arg(default=int(5e6), help="total timesteps of the experiments") + capture_video: bool = Arg( + default=False, help="whether to capture videos of the agent performances (check out `videos` folder)" + ) + buffer_size: int = Arg(default=int(1e6), help="the size of the buffer") + learning_starts: int = Arg(default=int(1024), help="timestep to start learning") + pretrain_steps: int = Arg(default=1, help="the number of pretrain steps") + gradient_steps: int = Arg(default=1, help="the number of gradient steps per each environment interaction") + train_every: int = Arg(default=5, help="the number of steps between one training and another") + checkpoint_every: int = Arg(default=-1, help="how often to make the checkpoint, -1 to deactivate the checkpoint") + checkpoint_buffer: bool = Arg(default=False, help="whether or not to save the buffer during the checkpoint") + checkpoint_path: Optional[str] = Arg(default=None, help="the path of the checkpoint from which you want to restart") + + # Agent settings + world_lr: float = Arg(default=1e-4, help="the learning rate of the optimizer of the world model") + actor_lr: float = Arg(default=3e-5, help="the learning rate of the optimizer of the actor") + critic_lr: float = Arg(default=3e-5, help="the learning rate of the optimizer of the critic") + horizon: int = Arg(default=15, help="the number of imagination step") + gamma: float = Arg(default=(1 - 1 / 333), help="the discount factor gamma") + lmbda: float = Arg(default=0.95, help="the lambda for the TD lambda values") + use_continues: bool = Arg(default=True, help="wheter or not to use the continue predictor") + stochastic_size: int = Arg(default=32, help="the dimension of the stochastic state") + discrete_size: int = Arg(default=32, help="the dimension of the discrete state") + hidden_size: int = Arg(default=512, help="the hidden size for the transition and representation model") + recurrent_state_size: int = Arg(default=512, help="the dimension of the recurrent state") + kl_dynamic: float = Arg(default=0.5, help="the regularizer for the KL dynamic loss") + kl_representation: float = Arg(default=0.1, help="the regularizer for the KL representation loss") + kl_free_nats: float = Arg(default=1.0, help="the minimum value for the kl divergence") + 
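+    # The three KL coefficients above, together with `kl_regularizer` below, implement the
+    # KL balancing used by the world model loss (see `reconstruction_loss` in
+    # `sheeprl/algos/dreamer_v3/loss.py`), roughly:
+    #   kl_loss = kl_dynamic * max(KL(sg(posterior) || prior), kl_free_nats)
+    #           + kl_representation * max(KL(posterior || sg(prior)), kl_free_nats)
+    # where sg() denotes the stop-gradient operator.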
kl_regularizer: float = Arg(default=1.0, help="the scale factor for the kl divergence") + continue_scale_factor: float = Arg(default=1.0, help="the scale factor for the continue loss") + actor_ent_coef: float = Arg(default=3e-4, help="the entropy coefficient for the actor loss") + actor_init_std: float = Arg( + default=0.0, help="the amout to sum to the input of the function of the standard deviation of the actions" + ) + actor_min_std: float = Arg(default=0.1, help="the minimum standard deviation for the actions") + actor_distribution: str = Arg( + default="auto", + help="the actor distribution. One can chose between `auto`, `discrete` (one-hot categorical), " + "`normal`, `tanh_normal` and `trunc_normal`. If `auto`, then the distribution will be a one-hot categorical if " + "the action space is discrete, otherwise it will be a truncated normal distribution.", + ) + world_clip_gradients: float = Arg(default=1000.0, help="how much to clip the gradient norms") + actor_clip_gradients: float = Arg(default=100.0, help="how much to clip the gradient norms") + critic_clip_gradients: float = Arg(default=100.0, help="how much to clip the gradient norms") + dense_units: int = Arg(default=512, help="the number of units in dense layers, must be greater than zero") + mlp_layers: int = Arg( + default=2, help="the number of MLP layers for every model: actor, critic, continue and reward" + ) + cnn_channels_multiplier: int = Arg(default=32, help="cnn width multiplication factor, must be greater than zero") + dense_act: str = Arg( + default="SiLU", + help="the activation function for the dense layers, one of https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity (case sensitive, without 'nn.')", + ) + cnn_act: str = Arg( + default="SiLU", + help="the activation function for the convolutional layers, one of https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity (case sensitive, without 'nn.')", + ) + critic_target_network_update_freq: int = Arg(default=1, help="the frequency to update the target critic network") + layer_norm: bool = Arg(default=True, help="whether to apply nn.LayerNorm after every Linear/Conv2D/ConvTranspose2D") + critic_tau: float = Arg( + default=0.02, + help="tau value to be used for the EMA critic update: `critic_param * tau + (1 - tau) * target_critic_param`", + ) + unimix: float = Arg( + default=0.01, help="whether to use a mix of uniform and categorical for the stochastic state distribution" + ) + hafner_initialization: bool = Arg( + default=True, + help="whether to initialize the models as in the original Hafner code, i.e. " + "every model is initialized with a standard Xavier Normal initialization; every last layer before a distribution is " + "initialized with a Xavier Uniform distribution; the last critic and reward model layer are initialized to zero.", + ) + + # Environment settings + expl_amount: float = Arg(default=0.0, help="the exploration amout to add to the actions") + expl_decay: bool = Arg(default=False, help="whether or not to decrement the exploration amount") + expl_min: float = Arg(default=0.0, help="the minimum value for the exploration amout") + max_step_expl_decay: int = Arg(default=0, help="the maximum number of decay steps") + action_repeat: int = Arg(default=4, help="the number of times an action is repeated") + max_episode_steps: int = Arg( + default=108000, + help="the maximum duration in terms of number of steps of an episode, -1 to disable. 
" + "This value will be divided by the `action_repeat` value during the environment creation.", + ) + atari_noop_max: int = Arg( + default=30, + help="for No-op reset in Atari environment, the max number no-ops actions are taken at reset, to turn off, set to 0", + ) + clip_rewards: bool = Arg(default=False, help="whether or not to clip rewards using tanh") + grayscale_obs: bool = Arg(default=False, help="whether or not to the observations are grayscale") + cnn_keys: Optional[List[str]] = Arg( + default=None, help="a list of observation keys to be processed by the CNN encoder" + ) + mlp_keys: Optional[List[str]] = Arg( + default=None, help="a list of observation keys to be processed by the MLP encoder" + ) + mine_min_pitch: int = Arg(default=-60, help="The minimum value of pitch in Minecraft environmnets.") + mine_max_pitch: int = Arg(default=60, help="The maximum value of pitch in Minecraft environmnets.") + mine_start_position: Optional[List[str]] = Arg( + default=None, help="The starting position of the agent in Minecraft environment. (x, y, z, pitch, yaw)" + ) + minerl_dense: bool = Arg(default=False, help="whether or not the task has dense reward") + minerl_extreme: bool = Arg(default=False, help="whether or not the task is extreme") + mine_break_speed: int = Arg(default=100, help="the break speed multiplier of Minecraft environments") + mine_sticky_attack: int = Arg(default=30, help="the sticky value for the attack action") + mine_sticky_jump: int = Arg(default=10, help="the sticky value for the jump action") + + # Returns normalization: returns are rescaled by an exponentially + # decaying average of the range from their 5th to their 95th batch percentile + moments_decay: float = Arg(default=0.99, help="exponential moving average decay factor") + moment_max: float = Arg( + default=1.0, help="max value to be applied in `max(moment_max, Per(R_t, 95th) - Per(R_t, 5th))`" + ) + moments_percentile_low: float = Arg(default=0.05, help="lower percentile") + moments_percentile_high: float = Arg(default=0.95, help="higher percentile") + + # Two-hot encoding bins + bins: int = Arg(default=255, help="the number of bins to two-hot-encode rewards and critic values") diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py new file mode 100644 index 00000000..69519f9a --- /dev/null +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -0,0 +1,715 @@ +"""Dreamer-V3 implementation from [https://arxiv.org/abs/2301.04104](https://arxiv.org/abs/2301.04104) +Adapted from the original implementation from https://github.com/danijar/dreamerv3 +""" + +import copy +import os +import pathlib +import time +from dataclasses import asdict +from functools import partial +from typing import Dict, Sequence + +import gymnasium as gym +import numpy as np +import torch +import torch.nn.functional as F +from lightning.fabric import Fabric +from lightning.fabric.fabric import _is_using_cli +from lightning.fabric.wrappers import _FabricModule +from tensordict import TensorDict +from tensordict.tensordict import TensorDictBase +from torch import Tensor +from torch.distributions import Bernoulli, Distribution, Independent, OneHotCategorical +from torch.optim import Adam, Optimizer +from torch.utils.data import BatchSampler +from torchmetrics import MeanMetric + +from sheeprl.algos.dreamer_v2.utils import test +from sheeprl.algos.dreamer_v3.agent import PlayerDV3, WorldModel, build_models +from sheeprl.algos.dreamer_v3.args import DreamerV3Args +from sheeprl.algos.dreamer_v3.loss import 
reconstruction_loss +from sheeprl.algos.dreamer_v3.utils import Moments, compute_lambda_values +from sheeprl.data.buffers import AsyncReplayBuffer +from sheeprl.envs.wrappers import RestartOnException +from sheeprl.utils.callback import CheckpointCallback +from sheeprl.utils.distribution import MSEDistribution, SymlogDistribution, TwoHotEncodingDistribution +from sheeprl.utils.env import make_dict_env +from sheeprl.utils.logger import create_tensorboard_logger +from sheeprl.utils.metric import MetricAggregator +from sheeprl.utils.parser import HfArgumentParser +from sheeprl.utils.registry import register_algorithm +from sheeprl.utils.utils import polynomial_decay + +# Decomment the following two lines if you cannot start an experiment with DMC environments +# os.environ["PYOPENGL_PLATFORM"] = "" +# os.environ["MUJOCO_GL"] = "osmesa" + + +def train( + fabric: Fabric, + world_model: WorldModel, + actor: _FabricModule, + critic: _FabricModule, + target_critic: torch.nn.Module, + world_optimizer: Optimizer, + actor_optimizer: Optimizer, + critic_optimizer: Optimizer, + data: TensorDictBase, + aggregator: MetricAggregator, + args: DreamerV3Args, + is_continuous: bool, + cnn_keys: Sequence[str], + mlp_keys: Sequence[str], + actions_dim: Sequence[int], + moments: Moments, +) -> None: + """Runs one-step update of the agent. + + Args: + fabric (Fabric): the fabric instance. + world_model (_FabricModule): the world model wrapped with Fabric. + actor (_FabricModule): the actor model wrapped with Fabric. + critic (_FabricModule): the critic model wrapped with Fabric. + target_critic (nn.Module): the target critic model. + world_optimizer (Optimizer): the world optimizer. + actor_optimizer (Optimizer): the actor optimizer. + critic_optimizer (Optimizer): the critic optimizer. + data (TensorDictBase): the batch of data to use for training. + aggregator (MetricAggregator): the aggregator to print the metrics. + args (DreamerV3Args): the configs. + cnn_keys (Sequence[str]): the cnn keys to encode/decode. + mlp_keys (Sequence[str]): the mlp keys to encode/decode. + actions_dim (Sequence[int]): the actions dimension. 
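+        is_continuous (bool): whether or not the actions are continuous.
+        moments (Moments): the moments used to normalize the lambda values (returns)
+            when computing the actor loss.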
+ """ + # The environment interaction goes like this: + # Actions: a0 a1 a2 a4 + # ^ \ ^ \ ^ \ ^ + # / \ / \ / \ / + # / v / v / v / + # Observations: o0 o1 o2 o3 + # Rewards: 0 r1 r2 r3 + # Dones: 0 d1 d2 d3 + # Is-first 1 i1 i2 i3 + + batch_size = args.per_rank_batch_size + sequence_length = args.per_rank_sequence_length + device = fabric.device + batch_obs = {k: data[k] / 255.0 for k in cnn_keys} + batch_obs.update({k: data[k] for k in mlp_keys}) + data["is_first"][0, :] = torch.tensor([1.0], device=fabric.device).expand_as(data["is_first"][0, :]) + + # Given how the environment interaction works, we remove the last actions + # and add the first one as the zero action + batch_actions = torch.cat((torch.zeros_like(data["actions"][:1]), data["actions"][:-1]), dim=0) + + # Dynamic Learning + stoch_state_size = args.stochastic_size * args.discrete_size + recurrent_state = torch.zeros(1, batch_size, args.recurrent_state_size, device=device) + posterior = torch.zeros(1, batch_size, args.stochastic_size, args.discrete_size, device=device) + recurrent_states = torch.empty(sequence_length, batch_size, args.recurrent_state_size, device=device) + priors_logits = torch.empty(sequence_length, batch_size, stoch_state_size, device=device) + posteriors = torch.empty(sequence_length, batch_size, args.stochastic_size, args.discrete_size, device=device) + posteriors_logits = torch.empty(sequence_length, batch_size, stoch_state_size, device=device) + + # Embed observations from the environment + embedded_obs = world_model.encoder(batch_obs) + + for i in range(0, sequence_length): + recurrent_state, posterior, _, posterior_logits, prior_logits = world_model.rssm.dynamic( + posterior, recurrent_state, batch_actions[i : i + 1], embedded_obs[i : i + 1], data["is_first"][i : i + 1] + ) + recurrent_states[i] = recurrent_state + priors_logits[i] = prior_logits + posteriors[i] = posterior + posteriors_logits[i] = posterior_logits + latent_states = torch.cat((posteriors.view(*posteriors.shape[:-2], -1), recurrent_states), -1) + + # Compute predictions for the observations + reconstructed_obs: Dict[str, torch.Tensor] = world_model.observation_model(latent_states) + + # Compute the distribution over the reconstructed observations + po = {k: MSEDistribution(reconstructed_obs[k], dims=len(reconstructed_obs[k].shape[2:])) for k in cnn_keys} + po.update({k: SymlogDistribution(reconstructed_obs[k], dims=len(reconstructed_obs[k].shape[2:])) for k in mlp_keys}) + + # Compute the distribution over the rewards + pr = TwoHotEncodingDistribution(world_model.reward_model(latent_states), dims=1) + + # Compute the distribution over the terminal steps, if required + pc = Independent(Bernoulli(logits=world_model.continue_model(latent_states), validate_args=False), 1) + continue_targets = 1 - data["dones"] + + # Reshape posterior and prior logits to shape [B, T, 32, 32] + priors_logits = priors_logits.view(*priors_logits.shape[:-1], args.stochastic_size, args.discrete_size) + posteriors_logits = posteriors_logits.view(*posteriors_logits.shape[:-1], args.stochastic_size, args.discrete_size) + + # World model optimization step. Eq. 
4 in the paper + world_optimizer.zero_grad(set_to_none=True) + rec_loss, kl, state_loss, reward_loss, observation_loss, continue_loss = reconstruction_loss( + po, + batch_obs, + pr, + data["rewards"], + priors_logits, + posteriors_logits, + args.kl_dynamic, + args.kl_representation, + args.kl_free_nats, + args.kl_regularizer, + pc, + continue_targets, + args.continue_scale_factor, + ) + fabric.backward(rec_loss) + if args.world_clip_gradients is not None and args.world_clip_gradients > 0: + world_model_grads = fabric.clip_gradients( + module=world_model, optimizer=world_optimizer, max_norm=args.world_clip_gradients, error_if_nonfinite=False + ) + world_optimizer.step() + aggregator.update("Grads/world_model", world_model_grads.mean().detach()) + aggregator.update("Loss/reconstruction_loss", rec_loss.detach()) + aggregator.update("Loss/observation_loss", observation_loss.detach()) + aggregator.update("Loss/reward_loss", reward_loss.detach()) + aggregator.update("Loss/state_loss", state_loss.detach()) + aggregator.update("Loss/continue_loss", continue_loss.detach()) + aggregator.update("State/kl", kl.mean().detach()) + aggregator.update( + "State/post_entropy", + Independent(OneHotCategorical(logits=posteriors_logits.detach()), 1).entropy().mean().detach(), + ) + aggregator.update( + "State/prior_entropy", + Independent(OneHotCategorical(logits=priors_logits.detach()), 1).entropy().mean().detach(), + ) + + # Behaviour Learning + imagined_prior = posteriors.detach().reshape(1, -1, stoch_state_size) + recurrent_state = recurrent_states.detach().reshape(1, -1, args.recurrent_state_size) + imagined_latent_state = torch.cat((imagined_prior, recurrent_state), -1) + imagined_trajectories = torch.empty( + args.horizon + 1, + batch_size * sequence_length, + stoch_state_size + args.recurrent_state_size, + device=device, + ) + imagined_trajectories[0] = imagined_latent_state + imagined_actions = torch.empty( + args.horizon + 1, + batch_size * sequence_length, + data["actions"].shape[-1], + device=device, + ) + actions = torch.cat(actor(imagined_latent_state.detach())[0], dim=-1) + imagined_actions[0] = actions + + # The imagination goes like this, with H=3: + # Actions: a'0 a'1 a'2 a'4 + # ^ \ ^ \ ^ \ ^ + # / \ / \ / \ / + # / \ / \ / \ / + # States: z0 ---> z'1 ---> z'2 ---> z'3 + # Rewards: r'0 r'1 r'2 r'3 + # Values: v'0 v'1 v'2 v'3 + # Lambda-values: l'1 l'2 l'3 + # Continues: c0 c'1 c'2 c'3 + # where z0 comes from the posterior, while z'i is the imagined states (prior) + + # Imagine trajectories in the latent space + for i in range(1, args.horizon + 1): + imagined_prior, recurrent_state = world_model.rssm.imagination(imagined_prior, recurrent_state, actions) + imagined_prior = imagined_prior.view(1, -1, stoch_state_size) + imagined_latent_state = torch.cat((imagined_prior, recurrent_state), -1) + imagined_trajectories[i] = imagined_latent_state + actions = torch.cat(actor(imagined_latent_state.detach())[0], dim=-1) + imagined_actions[i] = actions + + # Predict values, rewards and continues + predicted_values = TwoHotEncodingDistribution(critic(imagined_trajectories), dims=1).mean + predicted_rewards = TwoHotEncodingDistribution(world_model.reward_model(imagined_trajectories), dims=1).mean + continues = Independent(Bernoulli(logits=world_model.continue_model(imagined_trajectories)), 1).mode + true_done = (1 - data["dones"]).flatten().reshape(1, -1, 1) + continues = torch.cat((true_done, continues[1:])) + + # Estimate lambda-values + lambda_values = compute_lambda_values( + predicted_rewards[1:], + 
predicted_values[1:], + continues[1:] * args.gamma, + lmbda=args.lmbda, + ) + + # Compute the discounts to multiply the lambda values to + with torch.no_grad(): + discount = torch.cumprod(continues * args.gamma, dim=0) / args.gamma + + # Actor optimization step. Eq. 11 from the paper + # Given the following diagram, with H=3 + # Actions: [a'0] [a'1] [a'2] a'3 + # ^ \ ^ \ ^ \ ^ + # / \ / \ / \ / + # / \ / \ / \ / + # States: [z0] -> [z'1] -> [z'2] -> z'3 + # Values: [v'0] [v'1] [v'2] v'3 + # Lambda-values: [l'1] [l'2] [l'3] + # Entropies: [e'0] [e'1] [e'2] + actor_optimizer.zero_grad(set_to_none=True) + policies: Sequence[Distribution] = actor(imagined_trajectories.detach())[1] + + baseline = predicted_values[:-1] + offset, invscale = moments(lambda_values) + normed_lambda_values = (lambda_values - offset) / invscale + normed_baseline = (baseline - offset) / invscale + advantage = normed_lambda_values - normed_baseline + if is_continuous: + objective = advantage + else: + objective = ( + torch.stack( + [ + p.log_prob(imgnd_act.detach()).unsqueeze(-1)[:-1] + for p, imgnd_act in zip(policies, torch.split(imagined_actions, actions_dim, dim=-1)) + ], + dim=-1, + ).sum(dim=-1) + * advantage.detach() + ) + try: + entropy = args.actor_ent_coef * torch.stack([p.entropy() for p in policies], -1).sum(dim=-1) + except NotImplementedError: + entropy = torch.zeros_like(objective) + policy_loss = -torch.mean(discount[:-1].detach() * (objective + entropy.unsqueeze(dim=-1)[:-1])) + fabric.backward(policy_loss) + if args.actor_clip_gradients is not None and args.actor_clip_gradients > 0: + actor_grads = fabric.clip_gradients( + module=actor, optimizer=actor_optimizer, max_norm=args.actor_clip_gradients, error_if_nonfinite=False + ) + actor_optimizer.step() + aggregator.update("Grads/actor", actor_grads.mean().detach()) + aggregator.update("Loss/policy_loss", policy_loss.detach()) + + # Predict the values + qv = TwoHotEncodingDistribution(critic(imagined_trajectories.detach()[:-1]), dims=1) + predicted_target_values = TwoHotEncodingDistribution( + target_critic(imagined_trajectories.detach()[:-1]), dims=1 + ).mean + + # Critic optimization. Eq. 
10 in the paper + critic_optimizer.zero_grad(set_to_none=True) + value_loss = -qv.log_prob(lambda_values.detach()) + value_loss = value_loss - qv.log_prob(predicted_target_values.detach()) + value_loss = torch.mean(value_loss * discount[:-1].squeeze(-1)) + + fabric.backward(value_loss) + if args.critic_clip_gradients is not None and args.critic_clip_gradients > 0: + critic_grads = fabric.clip_gradients( + module=critic, optimizer=critic_optimizer, max_norm=args.critic_clip_gradients, error_if_nonfinite=False + ) + critic_optimizer.step() + aggregator.update("Grads/critic", critic_grads.mean().detach()) + aggregator.update("Loss/value_loss", value_loss.detach()) + + # Reset everything + actor_optimizer.zero_grad(set_to_none=True) + critic_optimizer.zero_grad(set_to_none=True) + world_optimizer.zero_grad(set_to_none=True) + + +@register_algorithm() +def main(): + parser = HfArgumentParser(DreamerV3Args) + args: DreamerV3Args = parser.parse_args_into_dataclasses()[0] + + # These arguments cannot be changed + args.screen_size = 64 + args.frame_stack = -1 + + # Initialize Fabric + fabric = Fabric(callbacks=[CheckpointCallback()]) + if not _is_using_cli(): + fabric.launch() + rank = fabric.global_rank + device = fabric.device + fabric.seed_everything(args.seed) + torch.backends.cudnn.deterministic = args.torch_deterministic + + if args.checkpoint_path: + state = fabric.load(args.checkpoint_path) + state["args"]["checkpoint_path"] = args.checkpoint_path + args = DreamerV3Args(**state["args"]) + args.per_rank_batch_size = state["batch_size"] // fabric.world_size + ckpt_path = pathlib.Path(args.checkpoint_path) + + # Create TensorBoardLogger. This will create the logger only on the + # rank-0 process + logger, log_dir = create_tensorboard_logger(fabric, args, "dreamer_v3") + if fabric.is_global_zero: + fabric._loggers = [logger] + fabric.logger.log_hyperparams(asdict(args)) + + # Environment setup + vectorized_env = gym.vector.SyncVectorEnv if args.sync_env else gym.vector.AsyncVectorEnv + envs = vectorized_env( + [ + partial( + RestartOnException, + env_fn=make_dict_env( + args.env_id, + args.seed + rank * args.num_envs, + rank, + args, + logger.log_dir if rank == 0 else None, + "train", + ), + ) + for i in range(args.num_envs) + ], + ) + action_space = envs.single_action_space + observation_space = envs.single_observation_space + + is_continuous = isinstance(action_space, gym.spaces.Box) + is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) + actions_dim = ( + action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) + ) + clip_rewards_fn = lambda r: torch.tanh(r) if args.clip_rewards else r + cnn_keys = [] + mlp_keys = [] + if isinstance(observation_space, gym.spaces.Dict): + cnn_keys = [] + for k, v in observation_space.spaces.items(): + if args.cnn_keys and k in args.cnn_keys: + if len(v.shape) == 3: + cnn_keys.append(k) + else: + fabric.print( + f"Found a CNN key which is not an image: `{k}` of shape {v.shape}. " + "Try to transform the observation from the environment into a 3D image" + ) + mlp_keys = [] + for k, v in observation_space.spaces.items(): + if args.mlp_keys and k in args.mlp_keys: + if len(v.shape) == 1: + mlp_keys.append(k) + else: + fabric.print( + f"Found an MLP key which is not a vector: `{k}` of shape {v.shape}. 
" + "Try to flatten the observation from the environment" + ) + else: + raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") + if cnn_keys == [] and mlp_keys == []: + raise RuntimeError( + "You should specify at least one CNN keys or MLP keys from the cli: `--cnn_keys rgb` or `--mlp_keys state` " + ) + fabric.print("CNN keys:", cnn_keys) + fabric.print("MLP keys:", mlp_keys) + obs_keys = cnn_keys + mlp_keys + + world_model, actor, critic, target_critic = build_models( + fabric, + actions_dim, + is_continuous, + args, + observation_space, + cnn_keys, + mlp_keys, + state["world_model"] if args.checkpoint_path else None, + state["actor"] if args.checkpoint_path else None, + state["critic"] if args.checkpoint_path else None, + state["target_critic"] if args.checkpoint_path else None, + ) + player = PlayerDV3( + world_model.encoder.module, + world_model.rssm, + actor.module, + actions_dim, + args.expl_amount, + args.num_envs, + args.stochastic_size, + args.recurrent_state_size, + fabric.device, + discrete_size=args.discrete_size, + ) + + # Optimizers + world_optimizer = Adam(world_model.parameters(), lr=args.world_lr, weight_decay=0.0, eps=1e-8) + actor_optimizer = Adam(actor.parameters(), lr=args.actor_lr, weight_decay=0.0, eps=1e-5) + critic_optimizer = Adam(critic.parameters(), lr=args.critic_lr, weight_decay=0.0, eps=1e-5) + if args.checkpoint_path: + world_optimizer.load_state_dict(state["world_optimizer"]) + actor_optimizer.load_state_dict(state["actor_optimizer"]) + critic_optimizer.load_state_dict(state["critic_optimizer"]) + world_optimizer, actor_optimizer, critic_optimizer = fabric.setup_optimizers( + world_optimizer, actor_optimizer, critic_optimizer + ) + moments = Moments( + fabric, args.moments_decay, args.moment_max, args.moments_percentile_low, args.moments_percentile_high + ) + + # Metrics + with device: + aggregator = MetricAggregator( + { + "Rewards/rew_avg": MeanMetric(sync_on_compute=False), + "Game/ep_len_avg": MeanMetric(sync_on_compute=False), + "Time/step_per_second": MeanMetric(sync_on_compute=False), + "Loss/reconstruction_loss": MeanMetric(sync_on_compute=False), + "Loss/value_loss": MeanMetric(sync_on_compute=False), + "Loss/policy_loss": MeanMetric(sync_on_compute=False), + "Loss/observation_loss": MeanMetric(sync_on_compute=False), + "Loss/reward_loss": MeanMetric(sync_on_compute=False), + "Loss/state_loss": MeanMetric(sync_on_compute=False), + "Loss/continue_loss": MeanMetric(sync_on_compute=False), + "State/kl": MeanMetric(sync_on_compute=False), + "State/post_entropy": MeanMetric(sync_on_compute=False), + "State/prior_entropy": MeanMetric(sync_on_compute=False), + "Params/exploration_amout": MeanMetric(sync_on_compute=False), + "Grads/world_model": MeanMetric(sync_on_compute=False), + "Grads/actor": MeanMetric(sync_on_compute=False), + "Grads/critic": MeanMetric(sync_on_compute=False), + } + ) + aggregator.to(fabric.device) + + # Local data + buffer_size = args.buffer_size // int(args.num_envs * fabric.world_size) if not args.dry_run else 2 + rb = AsyncReplayBuffer( + buffer_size, + args.num_envs, + device="cpu", + memmap=args.memmap_buffer, + memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), + sequential=True, + ) + if args.checkpoint_path and args.checkpoint_buffer: + if isinstance(state["rb"], list) and fabric.world_size == len(state["rb"]): + rb = state["rb"][fabric.global_rank] + elif isinstance(state["rb"], AsyncReplayBuffer): + rb = state["rb"] + else: + raise 
RuntimeError(f"Given {len(state['rb'])}, but {fabric.world_size} processes are instantiated") + step_data = TensorDict({}, batch_size=[args.num_envs], device="cpu") + expl_decay_steps = state["expl_decay_steps"] if args.checkpoint_path else 0 + + # Global variables + start_time = time.perf_counter() + start_step = state["global_step"] // fabric.world_size if args.checkpoint_path else 1 + single_global_step = int(args.num_envs * fabric.world_size) + step_before_training = args.train_every // single_global_step + num_updates = int(args.total_steps // single_global_step) if not args.dry_run else 1 + learning_starts = args.learning_starts // single_global_step if not args.dry_run else 0 + if args.checkpoint_path and not args.checkpoint_buffer: + learning_starts += start_step + max_step_expl_decay = args.max_step_expl_decay // (args.gradient_steps * fabric.world_size) + if args.checkpoint_path: + player.expl_amount = polynomial_decay( + expl_decay_steps, + initial=args.expl_amount, + final=args.expl_min, + max_decay_steps=max_step_expl_decay, + ) + + # Get the first environment observation and start the optimization + o = envs.reset(seed=args.seed)[0] + obs = {} + for k in obs_keys: + torch_obs = torch.from_numpy(o[k]).view(args.num_envs, *o[k].shape[1:]) + if k in mlp_keys: + # Images stay uint8 to save space + torch_obs = torch_obs.float() + step_data[k] = torch_obs + obs[k] = torch_obs + step_data["dones"] = torch.zeros(args.num_envs, 1).float() + step_data["rewards"] = torch.zeros(args.num_envs, 1).float() + step_data["is_first"] = torch.ones_like(step_data["dones"]).float() + player.init_states() + + gradient_steps = 0 + for global_step in range(start_step, num_updates + 1): + # Sample an action given the observation received by the environment + if global_step <= learning_starts and args.checkpoint_path is None and "minedojo" not in args.env_id: + real_actions = actions = np.array(envs.action_space.sample()) + if not is_continuous: + actions = np.concatenate( + [ + F.one_hot(torch.tensor(act), act_dim).numpy() + for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) + ], + axis=-1, + ) + else: + with torch.no_grad(): + preprocessed_obs = {} + for k, v in obs.items(): + if k in cnn_keys: + preprocessed_obs[k] = v[None, ...].to(device) / 255.0 + else: + preprocessed_obs[k] = v[None, ...].to(device) + mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} + if len(mask) == 0: + mask = None + real_actions = actions = player.get_exploration_action(preprocessed_obs, is_continuous, mask) + actions = torch.cat(actions, -1).cpu().numpy() + if is_continuous: + real_actions = torch.cat(real_actions, dim=-1).cpu().numpy() + else: + real_actions = np.array([real_act.cpu().argmax(dim=-1).numpy() for real_act in real_actions]) + + step_data["actions"] = torch.from_numpy(actions).view(args.num_envs, -1).float() + rb.add(step_data[None, ...]) + + o, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) + dones = np.logical_or(dones, truncated) + + step_data["is_first"] = torch.zeros_like(step_data["dones"]) + if "restart_on_exception" in infos: + for i, agent_roe in enumerate(infos["restart_on_exception"]): + if agent_roe and not dones[i]: + last_inserted_idx = (rb.buffer[i]._pos - 1) % rb.buffer[i].buffer_size + rb.buffer[i]["dones"][last_inserted_idx] = torch.ones_like(rb.buffer[i]["dones"][last_inserted_idx]) + rb.buffer[i]["is_first"][last_inserted_idx] = torch.zeros_like( + rb.buffer[i]["is_first"][last_inserted_idx] + ) + 
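+                    # The step collected right after the restart begins a fresh episode:
+                    # the crashed one has just been closed in the buffer above.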
step_data["is_first"][i] = torch.ones_like(step_data["is_first"][i]) + + if "final_info" in infos: + for i, agent_final_info in enumerate(infos["final_info"]): + if agent_final_info is not None and "episode" in agent_final_info: + fabric.print( + f"Rank-0: global_step={global_step}, reward_env_{i}={agent_final_info['episode']['r'][0]}" + ) + aggregator.update("Rewards/rew_avg", agent_final_info["episode"]["r"][0]) + aggregator.update("Game/ep_len_avg", agent_final_info["episode"]["l"][0]) + + # Save the real next observation + real_next_obs = copy.deepcopy(o) + if "final_observation" in infos: + for idx, final_obs in enumerate(infos["final_observation"]): + if final_obs is not None: + for k, v in final_obs.items(): + real_next_obs[k][idx] = v + + next_obs: Dict[str, Tensor] = {} + for k in real_next_obs.keys(): # [N_envs, N_obs] + if k in obs_keys: + next_obs[k] = torch.from_numpy(o[k]).view(args.num_envs, *o[k].shape[1:]) + step_data[k] = next_obs[k] + if k in mlp_keys: + next_obs[k] = next_obs[k].float() + step_data[k] = step_data[k].float() + + # next_obs becomes the new obs + obs = next_obs + + rewards = torch.from_numpy(rewards).view(args.num_envs, -1).float() + dones = torch.from_numpy(dones).view(args.num_envs, -1).float() + step_data["dones"] = dones + step_data["rewards"] = clip_rewards_fn(rewards) + + dones_idxes = dones.nonzero(as_tuple=True)[0].tolist() + reset_envs = len(dones_idxes) + if reset_envs > 0: + reset_data = TensorDict({}, batch_size=[reset_envs], device="cpu") + for k in real_next_obs.keys(): + if k in obs_keys: + reset_data[k] = real_next_obs[k][dones_idxes] + if k in mlp_keys: + reset_data[k] = reset_data[k].float() + reset_data["dones"] = torch.ones(reset_envs, 1).float() + reset_data["actions"] = torch.zeros(reset_envs, np.sum(actions_dim)).float() + reset_data["rewards"] = step_data["rewards"][dones_idxes].float() + reset_data["is_first"] = torch.zeros_like(reset_data["dones"]).float() + rb.add(reset_data[None, ...], dones_idxes) + + # Reset already inserted step data + step_data["rewards"][dones_idxes] = torch.zeros_like(reset_data["rewards"]).float() + step_data["dones"][dones_idxes] = torch.zeros_like(step_data["dones"][dones_idxes]).float() + step_data["is_first"][dones_idxes] = torch.ones_like(step_data["is_first"][dones_idxes]).float() + player.init_states(dones_idxes) + + step_before_training -= 1 + + # Train the agent + if global_step >= learning_starts and step_before_training <= 0: + fabric.barrier() + local_data = rb.sample( + args.per_rank_batch_size, + sequence_length=args.per_rank_sequence_length, + n_samples=args.pretrain_steps if global_step == learning_starts else args.gradient_steps, + ).to(device) + distributed_sampler = BatchSampler(range(local_data.shape[0]), batch_size=1, drop_last=False) + for i in distributed_sampler: + if gradient_steps % args.critic_target_network_update_freq == 0: + tau = 1 if gradient_steps == 0 else args.critic_tau + for cp, tcp in zip(critic.module.parameters(), target_critic.parameters()): + tcp.data.copy_(tau * cp.data + (1 - tau) * tcp.data) + train( + fabric, + world_model, + actor, + critic, + target_critic, + world_optimizer, + actor_optimizer, + critic_optimizer, + local_data[i].view(args.per_rank_sequence_length, args.per_rank_batch_size), + aggregator, + args, + is_continuous, + cnn_keys, + mlp_keys, + actions_dim, + moments, + ) + gradient_steps += 1 + step_before_training = args.train_every // single_global_step + if args.expl_decay: + expl_decay_steps += 1 + player.expl_amount = polynomial_decay( + 
expl_decay_steps, + initial=args.expl_amount, + final=args.expl_min, + max_decay_steps=max_step_expl_decay, + ) + aggregator.update("Params/exploration_amout", player.expl_amount) + aggregator.update("Time/step_per_second", int(global_step / (time.perf_counter() - start_time))) + fabric.log_dict(aggregator.compute(), global_step) + aggregator.reset() + + # Checkpoint Model + if ( + (args.checkpoint_every > 0 and global_step % args.checkpoint_every == 0) + or args.dry_run + or global_step == num_updates + ): + state = { + "world_model": world_model.state_dict(), + "actor": actor.state_dict(), + "critic": critic.state_dict(), + "target_critic": target_critic.state_dict(), + "world_optimizer": world_optimizer.state_dict(), + "actor_optimizer": actor_optimizer.state_dict(), + "critic_optimizer": critic_optimizer.state_dict(), + "expl_decay_steps": expl_decay_steps, + "args": asdict(args), + "moments": moments.state_dict(), + "global_step": global_step * fabric.world_size, + "batch_size": args.per_rank_batch_size * fabric.world_size, + } + ckpt_path = log_dir + f"/checkpoint/ckpt_{global_step}_{fabric.global_rank}.ckpt" + fabric.call( + "on_checkpoint_coupled", + fabric=fabric, + ckpt_path=ckpt_path, + state=state, + replay_buffer=rb if args.checkpoint_buffer else None, + ) + + envs.close() + if fabric.is_global_zero: + test(player, fabric, args, cnn_keys, mlp_keys, sample_actions=True) + + +if __name__ == "__main__": + main() diff --git a/sheeprl/algos/dreamer_v3/loss.py b/sheeprl/algos/dreamer_v3/loss.py new file mode 100644 index 00000000..48dc4d5e --- /dev/null +++ b/sheeprl/algos/dreamer_v3/loss.py @@ -0,0 +1,86 @@ +from typing import Dict, Optional, Tuple + +import torch +from torch import Tensor +from torch.distributions import Distribution, Independent, OneHotCategoricalStraightThrough +from torch.distributions.kl import kl_divergence + + +def reconstruction_loss( + po: Dict[str, Distribution], + observations: Tensor, + pr: Distribution, + rewards: Tensor, + priors_logits: Tensor, + posteriors_logits: Tensor, + kl_dynamic: float = 0.5, + kl_representation: float = 0.1, + kl_free_nats: float = 1.0, + kl_regularizer: float = 1.0, + pc: Optional[Distribution] = None, + continue_targets: Optional[Tensor] = None, + continue_scale_factor: float = 1.0, +) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + """ + Compute the reconstruction loss as described in Eq. 2 in [https://arxiv.org/abs/2010.02193](https://arxiv.org/abs/2010.02193). + + Args: + po (Dict[str, Distribution]): the distribution returned by the observation_model (decoder). + observations (Tensor): the observations provided by the environment. + pr (Distribution): the reward distribution returned by the reward_model. + rewards (Tensor): the rewards obtained by the agent during the "Environment interaction" phase. + priors_logits (Tensor): the logits of the prior. + posteriors_logits (Tensor): the logits of the posterior. + kl_dynamic (float): the kl-balancing dynamic loss regularizer. + Defaults to 0.5. + kl_balancing_alpha (float): the kl-balancing representation loss regularizer. + Defaults to 0.1. + kl_free_nats (float): lower bound of the KL divergence. + Default to 1.0. + kl_regularizer (float): scale factor of the KL divergence. + Default to 1.0. + pc (Bernoulli, optional): the predicted Bernoulli distribution of the terminal steps. + 0s for the entries that are relative to a terminal step, 1s otherwise. + Default to None. + continue_targets (Tensor, optional): the targets for the discount predictor. 
Those are normally computed + as `(1 - data["dones"]) * args.gamma`. + Default to None. + continue_scale_factor (float): the scale factor for the continue loss. + Default to 10. + + Returns: + observation_loss (Tensor): the value of the observation loss. + KL divergence (Tensor): the KL divergence between the posterior and the prior. + reward_loss (Tensor): the value of the reward loss. + state_loss (Tensor): the value of the state loss. + continue_loss (Tensor): the value of the continue loss (0 if it is not computed). + reconstruction_loss (Tensor): the value of the overall reconstruction loss. + """ + device = rewards.device + observation_loss = -sum([po[k].log_prob(observations[k]) for k in po.keys()]) + reward_loss = -pr.log_prob(rewards) + # KL balancing + kl_free_nats = torch.tensor([kl_free_nats], device=device) + dyn_loss = kl = kl_divergence( + Independent(OneHotCategoricalStraightThrough(logits=posteriors_logits.detach(), validate_args=False), 1), + Independent(OneHotCategoricalStraightThrough(logits=priors_logits, validate_args=False), 1), + ) + dyn_loss = kl_dynamic * torch.maximum(dyn_loss, kl_free_nats) + repr_loss = kl_divergence( + Independent(OneHotCategoricalStraightThrough(logits=posteriors_logits, validate_args=False), 1), + Independent(OneHotCategoricalStraightThrough(logits=priors_logits.detach(), validate_args=False), 1), + ) + repr_loss = kl_representation * torch.maximum(repr_loss, kl_free_nats) + kl_loss = dyn_loss + repr_loss + continue_loss = torch.tensor(0.0, device=device) + if pc is not None and continue_targets is not None: + continue_loss = continue_scale_factor * -pc.log_prob(continue_targets) + reconstruction_loss = (kl_regularizer * kl_loss + observation_loss + reward_loss + continue_loss).mean() + return ( + reconstruction_loss, + kl.mean(), + kl_loss.mean(), + reward_loss.mean(), + observation_loss.mean(), + continue_loss.mean(), + ) diff --git a/sheeprl/algos/dreamer_v3/utils.py b/sheeprl/algos/dreamer_v3/utils.py new file mode 100644 index 00000000..4ab89790 --- /dev/null +++ b/sheeprl/algos/dreamer_v3/utils.py @@ -0,0 +1,117 @@ +import os +from typing import TYPE_CHECKING, Any, List + +import gymnasium as gym +import numpy as np +import torch +from lightning import Fabric +from torch import Tensor + +from sheeprl.utils.env import make_dict_env + +if TYPE_CHECKING: + from sheeprl.algos.dreamer_v3.agent import PlayerDV3 + from sheeprl.algos.dreamer_v3.args import DreamerV3Args + + +class Moments(torch.nn.Module): + def __init__( + self, + fabric: Fabric, + decay: float = 0.99, + max_: float = 1e8, + percentile_low: float = 0.05, + percentile_high: float = 0.95, + ) -> None: + super().__init__() + self._fabric = fabric + self._decay = decay + self._max = torch.tensor(max_) + self._percentile_low = percentile_low + self._percentile_high = percentile_high + self.register_buffer("low", torch.zeros((), dtype=torch.float32)) + self.register_buffer("high", torch.zeros((), dtype=torch.float32)) + + def forward(self, x: Tensor) -> Any: + gathered_x = self._fabric.all_gather(x).detach() + low = torch.quantile(gathered_x, self._percentile_low) + high = torch.quantile(gathered_x, self._percentile_high) + self.low = self._decay * self.low + (1 - self._decay) * low + self.high = self._decay * self.high + (1 - self._decay) * high + invscale = torch.max(1 / self._max, self.high - self.low) + return self.low.detach(), invscale.detach() + + +def compute_lambda_values( + rewards: Tensor, + values: Tensor, + continues: Tensor, + lmbda: float = 0.95, +): + vals = 
[values[-1:]] + interm = rewards + continues * values * (1 - lmbda) + for t in reversed(range(len(continues))): + vals.append(interm[t] + continues[t] * lmbda * vals[-1]) + ret = torch.cat(list(reversed(vals))[:-1]) + return ret + + +@torch.no_grad() +def test( + player: "PlayerDV3", + fabric: Fabric, + args: "DreamerV3Args", + cnn_keys: List[str], + mlp_keys: List[str], + test_name: str = "", + sample_actions: bool = False, +): + """Test the model on the environment with the frozen model. + + Args: + player (PlayerDV2): the agent which contains all the models needed to play. + fabric (Fabric): the fabric instance. + args (Union[DreamerV3Args, DreamerV2Args, DreamerV1Args]): the hyper-parameters. + cnn_keys (Sequence[str]): the keys encoded by the cnn encoder. + mlp_keys (Sequence[str]): the keys encoded by the mlp encoder. + test_name (str): the name of the test. + Default to "". + """ + log_dir = fabric.logger.log_dir if len(fabric.loggers) > 0 else os.getcwd() + env: gym.Env = make_dict_env( + args.env_id, args.seed, 0, args, log_dir, "test" + (f"_{test_name}" if test_name != "" else "") + )() + done = False + cumulative_rew = 0 + device = fabric.device + next_obs = env.reset(seed=args.seed)[0] + for k in next_obs.keys(): + next_obs[k] = torch.from_numpy(next_obs[k]).view(1, *next_obs[k].shape).float() + player.num_envs = 1 + player.init_states() + while not done: + # Act greedly through the environment + preprocessed_obs = {} + for k, v in next_obs.items(): + if k in cnn_keys: + preprocessed_obs[k] = v[None, ...].to(device) / 255 + elif k in mlp_keys: + preprocessed_obs[k] = v[None, ...].to(device) + real_actions = player.get_greedy_action( + preprocessed_obs, sample_actions, {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} + ) + if player.actor.is_continuous: + real_actions = torch.cat(real_actions, -1).cpu().numpy() + else: + real_actions = np.array([real_act.cpu().argmax(dim=-1).numpy() for real_act in real_actions]) + + # Single environment step + next_obs, reward, done, truncated, _ = env.step(real_actions.reshape(env.action_space.shape)) + for k in next_obs.keys(): + next_obs[k] = torch.from_numpy(next_obs[k]).view(1, *next_obs[k].shape).float() + done = done or truncated or args.dry_run + cumulative_rew += reward + fabric.print("Test - Reward:", cumulative_rew) + if len(fabric.loggers) > 0: + fabric.logger.log_metrics({"Test/cumulative_reward": cumulative_rew}, 0) + env.close() diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1.py b/sheeprl/algos/p2e_dv1/p2e_dv1.py index ef5b1d08..5594e14d 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1.py @@ -21,7 +21,7 @@ from torch.utils.data import BatchSampler from torchmetrics import MeanMetric -from sheeprl.algos.dreamer_v1.agent import Player, WorldModel +from sheeprl.algos.dreamer_v1.agent import PlayerDV1, WorldModel from sheeprl.algos.dreamer_v1.loss import actor_loss, critic_loss, reconstruction_loss from sheeprl.algos.dreamer_v2.utils import test from sheeprl.algos.p2e_dv1.agent import build_models @@ -472,7 +472,7 @@ def main(): if args.checkpoint_path: ensembles.load_state_dict(state["ensembles"]) fabric.setup_module(ensembles) - player = Player( + player = PlayerDV1( world_model.encoder.module, world_model.rssm.recurrent_model.module, world_model.rssm.representation_model.module, diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2.py b/sheeprl/algos/p2e_dv2/p2e_dv2.py index 750447f4..0ba3d0af 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2.py @@ -21,7 +21,7 @@ 
from torch.utils.data import BatchSampler from torchmetrics import MeanMetric -from sheeprl.algos.dreamer_v2.agent import Player, WorldModel +from sheeprl.algos.dreamer_v2.agent import PlayerDV2, WorldModel from sheeprl.algos.dreamer_v2.loss import reconstruction_loss from sheeprl.algos.dreamer_v2.utils import compute_lambda_values, init_weights, test from sheeprl.algos.p2e_dv2.agent import build_models @@ -595,7 +595,7 @@ def main(): if args.checkpoint_path: ensembles.load_state_dict(state["ensembles"]) fabric.setup_module(ensembles) - player = Player( + player = PlayerDV2( world_model.encoder.module, world_model.rssm.recurrent_model.module, world_model.rssm.representation_model.module, diff --git a/sheeprl/algos/ppo/args.py b/sheeprl/algos/ppo/args.py index a9aab268..308e28b6 100644 --- a/sheeprl/algos/ppo/args.py +++ b/sheeprl/algos/ppo/args.py @@ -60,7 +60,11 @@ class PPOArgs(StandardArgs): default=None, help="a list of observation keys to be processed by the MLP encoder" ) eps: float = Arg(default=1e-4) - max_episode_steps: int = Arg(default=-1, help="the maximum amount of steps in an episode") + max_episode_steps: int = Arg( + default=-1, + help="the maximum duration in terms of number of steps of an episode, -1 to disable. " + "This value will be divided by the `action_repeat` value during the environment creation.", + ) cnn_features_dim: int = Arg(default=512, help="the features dimension after the CNNEncoder") mlp_features_dim: int = Arg(default=64, help="the features dimension after the MLPEncoder") atari_noop_max: int = Arg(default=30, help="the maximum number of noop in Atari envs on reset") diff --git a/sheeprl/envs/diambra_wrapper.py b/sheeprl/envs/diambra_wrapper.py index 8ff27e93..8214a3c5 100644 --- a/sheeprl/envs/diambra_wrapper.py +++ b/sheeprl/envs/diambra_wrapper.py @@ -92,7 +92,7 @@ def _convert_obs(self, obs: Dict[str, Union[int, np.ndarray]]) -> Dict[str, np.n for k, v in obs.items() } - def step(self, action: Any) -> tuple[Any, SupportsFloat, bool, bool, Dict[str, Any]]: + def step(self, action: Any) -> Tuple[Any, SupportsFloat, bool, bool, Dict[str, Any]]: obs, reward, done, infos = self._env.step(action) infos["env_domain"] = "DIAMBRA" return self._convert_obs(obs), reward, done, False, infos diff --git a/sheeprl/envs/dummy.py b/sheeprl/envs/dummy.py index a4711547..3e6c77f8 100644 --- a/sheeprl/envs/dummy.py +++ b/sheeprl/envs/dummy.py @@ -7,7 +7,7 @@ class ContinuousDummyEnv(gym.Env): def __init__(self, action_dim: int = 2, size: Tuple[int, int, int] = (3, 64, 64), n_steps: int = 128): self.action_space = gym.spaces.Box(-np.inf, np.inf, shape=(action_dim,)) - self.observation_space = gym.spaces.Box(0, 255, shape=size, dtype=np.uint8) + self.observation_space = gym.spaces.Box(0, 256, shape=size, dtype=np.uint8) self.reward_range = (-np.inf, np.inf) self._current_step = 0 self._n_steps = n_steps @@ -16,7 +16,7 @@ def step(self, action): done = self._current_step == self._n_steps self._current_step += 1 return ( - np.zeros(self.observation_space.shape, dtype=np.float32), + np.zeros(self.observation_space.shape, dtype=np.uint8), np.zeros(1, dtype=np.float32).item(), done, False, @@ -25,7 +25,7 @@ def step(self, action): def reset(self, seed=None, options=None): self._current_step = 0 - return np.zeros(self.observation_space.shape, dtype=np.float32), {} + return np.zeros(self.observation_space.shape, dtype=np.uint8), {} def render(self, mode="human", close=False): pass @@ -38,9 +38,9 @@ def seed(self, seed=None): class DiscreteDummyEnv(gym.Env): - def __init__(self, 
action_dim: int = 2, size: Tuple[int, int, int] = (3, 64, 64), n_steps: int = 128): + def __init__(self, action_dim: int = 2, size: Tuple[int, int, int] = (3, 64, 64), n_steps: int = 4): self.action_space = gym.spaces.Discrete(action_dim) - self.observation_space = gym.spaces.Box(0, 255, shape=size, dtype=np.uint8) + self.observation_space = gym.spaces.Box(0, 256, shape=size, dtype=np.uint8) self.reward_range = (-np.inf, np.inf) self._current_step = 0 self._n_steps = n_steps @@ -49,7 +49,7 @@ def step(self, action): done = self._current_step == self._n_steps self._current_step += 1 return ( - np.zeros(self.observation_space.shape, dtype=np.float32), + np.random.randint(0, 256, self.observation_space.shape, dtype=np.uint8), np.zeros(1, dtype=np.float32).item(), done, False, @@ -58,7 +58,7 @@ def step(self, action): def reset(self, seed=None, options=None): self._current_step = 0 - return np.zeros(self.observation_space.shape, dtype=np.float32), {} + return np.zeros(self.observation_space.shape, dtype=np.uint8), {} def render(self, mode="human", close=False): pass @@ -73,7 +73,7 @@ def seed(self, seed=None): class MultiDiscreteDummyEnv(gym.Env): def __init__(self, action_dims: List[int] = [2, 2], size: Tuple[int, int, int] = (3, 64, 64), n_steps: int = 128): self.action_space = gym.spaces.MultiDiscrete(action_dims) - self.observation_space = gym.spaces.Box(0, 255, shape=size, dtype=np.uint8) + self.observation_space = gym.spaces.Box(0, 256, shape=size, dtype=np.uint8) self.reward_range = (-np.inf, np.inf) self._current_step = 0 self._n_steps = n_steps @@ -82,7 +82,7 @@ def step(self, action): done = self._current_step == self._n_steps self._current_step += 1 return ( - np.zeros(self.observation_space.shape, dtype=np.float32), + np.zeros(self.observation_space.shape, dtype=np.uint8), np.zeros(1, dtype=np.float32).item(), done, False, @@ -91,7 +91,7 @@ def step(self, action): def reset(self, seed=None, options=None): self._current_step = 0 - return np.zeros(self.observation_space.shape, dtype=np.float32), {} + return np.zeros(self.observation_space.shape, dtype=np.uint8), {} def render(self, mode="human", close=False): pass diff --git a/sheeprl/envs/minedojo.py b/sheeprl/envs/minedojo.py index b52e6a03..5b2e6f73 100644 --- a/sheeprl/envs/minedojo.py +++ b/sheeprl/envs/minedojo.py @@ -275,6 +275,7 @@ def reset( "food": float(obs["life_stats"]["food"].item()), }, "location_stats": copy.deepcopy(self._pos), + "biomeid": float(obs["location_stats"]["biome_id"].item()), } def close(self): diff --git a/sheeprl/envs/wrappers.py b/sheeprl/envs/wrappers.py index 5335554c..3fed01d1 100644 --- a/sheeprl/envs/wrappers.py +++ b/sheeprl/envs/wrappers.py @@ -1,6 +1,7 @@ import copy +import time from collections import deque -from typing import Any, Dict, Optional, Sequence, SupportsFloat, Tuple +from typing import Any, Callable, Dict, Optional, Sequence, SupportsFloat, Tuple import gymnasium as gym import numpy as np @@ -69,6 +70,58 @@ def step(self, action): return obs, total_reward, done, truncated, info +class RestartOnException(gym.Wrapper): + def __init__(self, env_fn: Callable[..., gym.Env], exceptions=(Exception,), window=300, maxfails=2, wait=20): + if not isinstance(exceptions, (tuple, list)): + exceptions = [exceptions] + self._env_fn = env_fn + self._exceptions = tuple(exceptions) + self._window = window + self._maxfails = maxfails + self._wait = wait + self._last = time.time() + self._fails = 0 + super().__init__(self._env_fn()) + + def step(self, action) -> Tuple[Any, SupportsFloat, bool, bool, 
Dict[str, Any]]: + try: + return self.env.step(action) + except self._exceptions as e: + if time.time() > self._last + self._window: + self._last = time.time() + self._fails = 1 + else: + self._fails += 1 + if self._fails > self._maxfails: + raise RuntimeError(f"The env crashed too many times: {self._fails}") + gym.logger.warn(f"STEP - Restarting env after crash with {type(e).__name__}: {e}") + time.sleep(self._wait) + self.env = self._env_fn() + new_obs, info = self.env.reset() + info.update({"restart_on_exception": True}) + return new_obs, 0.0, False, False, info + + def reset( + self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None + ) -> tuple[Any, Dict[str, Any]]: + try: + return self.env.reset(seed=seed, options=options) + except self._exceptions as e: + if time.time() > self._last + self._window: + self._last = time.time() + self._fails = 1 + else: + self._fails += 1 + if self._fails > self._maxfails: + raise RuntimeError(f"The env crashed too many times: {self._fails}") + gym.logger.warn(f"RESET - Restarting env after crash with {type(e).__name__}: {e}") + time.sleep(self._wait) + self.env = self._env_fn() + new_obs, info = self.env.reset() + info.update({"restart_on_exception": True}) + return new_obs, info + + class FrameStack(gym.Wrapper): def __init__(self, env: Env, num_stack: int, cnn_keys: Sequence[str], dilation: int = 1) -> None: super().__init__(env) diff --git a/sheeprl/models/models.py b/sheeprl/models/models.py index 605b1dc3..bc63472e 100644 --- a/sheeprl/models/models.py +++ b/sheeprl/models/models.py @@ -146,6 +146,7 @@ def __init__( self, input_channels: int, hidden_channels: Sequence[int], + cnn_layer: ModuleType = nn.Conv2d, layer_args: ArgsType = None, dropout_layer: Optional[Union[ModuleType, Sequence[ModuleType]]] = None, dropout_args: Optional[ArgsType] = None, @@ -181,7 +182,7 @@ def __init__( activation_list, act_args_list, ): - model += miniblock(in_dim, out_dim, nn.Conv2d, l_args, drop, drop_args, norm, norm_args, activ, act_args) + model += miniblock(in_dim, out_dim, cnn_layer, l_args, drop, drop_args, norm, norm_args, activ, act_args) self._output_dim = hidden_sizes[-1] self._model = nn.Sequential(*model) @@ -227,6 +228,7 @@ def __init__( self, input_channels: int, hidden_channels: Sequence[int] = (), + cnn_layer: ModuleType = nn.ConvTranspose2d, layer_args: ArgsType = None, dropout_layer: Optional[Union[ModuleType, Sequence[ModuleType]]] = None, dropout_args: Optional[ArgsType] = None, @@ -262,9 +264,7 @@ def __init__( activation_list, act_args_list, ): - model += miniblock( - in_dim, out_dim, nn.ConvTranspose2d, l_args, drop, drop_args, norm, norm_args, activ, act_args - ) + model += miniblock(in_dim, out_dim, cnn_layer, l_args, drop, drop_args, norm, norm_args, activ, act_args) self._output_dim = hidden_sizes[-1] self._model = nn.Sequential(*model) diff --git a/sheeprl/utils/distribution.py b/sheeprl/utils/distribution.py index d1f3f1c3..ded14ea8 100644 --- a/sheeprl/utils/distribution.py +++ b/sheeprl/utils/distribution.py @@ -2,11 +2,16 @@ import math from numbers import Number +from typing import Callable import torch +import torch.nn.functional as F +from torch import Tensor from torch.distributions import Distribution, constraints from torch.distributions.utils import broadcast_all +from sheeprl.utils.utils import symexp, symlog + CONST_SQRT_2 = math.sqrt(2) CONST_INV_SQRT_2PI = 1 / math.sqrt(2 * math.pi) CONST_INV_SQRT_2 = 1 / math.sqrt(2) @@ -137,3 +142,125 @@ def icdf(self, value): def log_prob(self, value): return 
super(TruncatedNormal, self).log_prob(self._to_std_rv(value)) - self._log_scale + + +# From https://github.com/danijar/dreamerv3/blob/8fa35f83eee1ce7e10f3dee0b766587d0a713a60/dreamerv3/jaxutils.py +class SymlogDistribution(Distribution): + def __init__( + self, + mode: Tensor, + dims: int, + dist: str = "mse", + agg: str = "sum", + tol: float = 1e-8, + ): + self._mode = mode + self._dims = tuple([-x for x in range(1, dims + 1)]) + self._dist = dist + self._agg = agg + self._tol = tol + self._batch_shape = mode.shape[: len(mode.shape) - dims] + self._event_shape = mode.shape[len(mode.shape) - dims :] + + @property + def mode(self) -> Tensor: + return symexp(self._mode) + + @property + def mean(self) -> Tensor: + return symexp(self._mode) + + def log_prob(self, value: Tensor) -> Tensor: + assert self._mode.shape == value.shape, (self._mode.shape, value.shape) + if self._dist == "mse": + distance = (self._mode - symlog(value)) ** 2 + distance = torch.where(distance < self._tol, 0, distance) + elif self._dist == "abs": + distance = torch.abs(self._mode - symlog(value)) + distance = torch.where(distance < self._tol, 0, distance) + else: + raise NotImplementedError(self._dist) + if self._agg == "mean": + loss = distance.mean(self._dims) + elif self._agg == "sum": + loss = distance.sum(self._dims) + else: + raise NotImplementedError(self._agg) + return -loss + + +class MSEDistribution(Distribution): + def __init__(self, mode: Tensor, dims: int, agg: str = "sum"): + self._mode = mode + self._dims = tuple([-x for x in range(1, dims + 1)]) + self._agg = agg + self._batch_shape = mode.shape[: len(mode.shape) - dims] + self._event_shape = mode.shape[len(mode.shape) - dims :] + + @property + def mode(self) -> Tensor: + return self._mode + + @property + def mean(self) -> Tensor: + return self._mode + + def log_prob(self, value: Tensor) -> Tensor: + assert self._mode.shape == value.shape, (self._mode.shape, value.shape) + distance = (self._mode - value) ** 2 + if self._agg == "mean": + loss = distance.mean(self._dims) + elif self._agg == "sum": + loss = distance.sum(self._dims) + else: + raise NotImplementedError(self._agg) + return -loss + + +class TwoHotEncodingDistribution(Distribution): + def __init__( + self, + logits: Tensor, + dims: int = 0, + low: int = -20, + high: int = 20, + transfwd: Callable[[Tensor], Tensor] = symlog, + transbwd: Callable[[Tensor], Tensor] = symexp, + ): + self.logits = logits + self.probs = F.softmax(logits, dim=-1) + self.dims = tuple([-x for x in range(1, dims + 1)]) + self.bins = torch.linspace(low, high, logits.shape[-1], device=logits.device) + self.low = low + self.high = high + self.transfwd = transfwd + self.transbwd = transbwd + self._batch_shape = logits.shape[: len(logits.shape) - dims] + self._event_shape = logits.shape[len(logits.shape) - dims : -1] + (1,) + + @property + def mean(self) -> Tensor: + return self.transbwd((self.probs * self.bins).sum(dim=self.dims, keepdim=True)) + + @property + def mode(self) -> Tensor: + return self.transbwd((self.probs * self.bins).sum(dim=self.dims, keepdim=True)) + + def log_prob(self, x: Tensor) -> Tensor: + x = self.transfwd(x) + below = (self.bins <= x).type(torch.int32).sum(dim=-1, keepdim=True) - 1 + above = len(self.bins) - (self.bins > x).type(torch.int32).sum(dim=-1, keepdim=True) + below = torch.clip(below, 0, len(self.bins) - 1) + above = torch.clip(above, 0, len(self.bins) - 1) + equal = below == above + dist_to_below = torch.where(equal, 1, torch.abs(self.bins[below] - x)) + dist_to_above = torch.where(equal, 1, 
torch.abs(self.bins[above] - x)) + total = dist_to_below + dist_to_above + weight_below = dist_to_above / total + weight_above = dist_to_below / total + target = ( + F.one_hot(below, len(self.bins)) * weight_below[..., None] + + F.one_hot(above, len(self.bins)) * weight_above[..., None] + ).squeeze(-2) + log_pred = self.logits - torch.logsumexp(self.logits, dim=-1, keepdims=True) + return (target * log_pred).sum(dim=self.dims) diff --git a/sheeprl/utils/utils.py b/sheeprl/utils/utils.py index ac8da33d..34aa1b96 100644 --- a/sheeprl/utils/utils.py +++ b/sheeprl/utils/utils.py @@ -122,3 +122,12 @@ def polynomial_decay( return final else: return (initial - final) * ((1 - current_step / max_decay_steps) ** power) + final + + +# From https://github.com/danijar/dreamerv3/blob/8fa35f83eee1ce7e10f3dee0b766587d0a713a60/dreamerv3/jaxutils.py +def symlog(x: Tensor) -> Tensor: + return torch.sign(x) * torch.log(1 + torch.abs(x)) + + +def symexp(x: Tensor) -> Tensor: + return torch.sign(x) * (torch.exp(torch.abs(x)) - 1) diff --git a/tests/test_algos/test_algos.py b/tests/test_algos/test_algos.py index 5cd5634a..7a8a568a 100644 --- a/tests/test_algos/test_algos.py +++ b/tests/test_algos/test_algos.py @@ -20,7 +20,7 @@ def devices(request): @pytest.fixture() def standard_args(): - return ["--num_envs=1", "--dry_run"] + return ["--num_envs=1", "--dry_run=True", f"--sync_env={_IS_WINDOWS}"] @pytest.fixture() @@ -416,6 +416,66 @@ def test_p2e_dv1(standard_args, env_id, checkpoint_buffer, start_time): shutil.rmtree(f"pytest_{start_time}") +@pytest.mark.timeout(60) +@pytest.mark.parametrize("env_id", ["discrete_dummy", "multidiscrete_dummy", "continuous_dummy"]) +@pytest.mark.parametrize("checkpoint_buffer", [True, False]) +def test_p2e_dv2(standard_args, env_id, checkpoint_buffer, start_time): + task = importlib.import_module("sheeprl.algos.p2e_dv2.p2e_dv2") + root_dir = os.path.join("pytest_" + start_time, "p2e_dv2", os.environ["LT_DEVICES"]) + run_name = "checkpoint_buffer" if checkpoint_buffer else "no_checkpoint_buffer" + ckpt_path = os.path.join(root_dir, run_name) + version = 0 if not os.path.isdir(ckpt_path) else len(os.listdir(ckpt_path)) + ckpt_path = os.path.join(ckpt_path, f"version_{version}", "checkpoint") + args = standard_args + [ + "--per_rank_batch_size=2", + "--per_rank_sequence_length=2", + f"--buffer_size={int(os.environ['LT_DEVICES'])}", + "--learning_starts=0", + "--gradient_steps=1", + "--horizon=2", + "--env_id=" + env_id, + "--root_dir=" + root_dir, + "--run_name=" + run_name, + "--dense_units=8", + "--cnn_channels_multiplier=2", + "--recurrent_state_size=8", + "--hidden_size=8", + "--cnn_keys=rgb", + "--pretrain_steps=1", + ] + if checkpoint_buffer: + args.append("--checkpoint_buffer") + + with mock.patch.object(sys, "argv", [task.__file__] + args): + for command in task.__all__: + if command == "main": + task.__dict__[command]() + + with mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu", "LT_DEVICES": str(1)}): + keys = { + "world_model", + "actor_task", + "critic_task", + "ensembles", + "world_optimizer", + "actor_task_optimizer", + "critic_task_optimizer", + "ensemble_optimizer", + "expl_decay_steps", + "args", + "global_step", + "batch_size", + "actor_exploration", + "critic_exploration", + "actor_exploration_optimizer", + "critic_exploration_optimizer", + } + if checkpoint_buffer: + keys.add("rb") + check_checkpoint(ckpt_path, keys, checkpoint_buffer) + shutil.rmtree("pytest_" + start_time) + + @pytest.mark.timeout(60) @pytest.mark.parametrize("env_id", ["discrete_dummy", 
"multidiscrete_dummy", "continuous_dummy"]) @pytest.mark.parametrize("checkpoint_buffer", [True, False]) @@ -529,4 +589,60 @@ def test_p2e_dv2(standard_args, env_id, checkpoint_buffer, start_time): if checkpoint_buffer: keys.add("rb") check_checkpoint(ckpt_path, keys, checkpoint_buffer) - shutil.rmtree(f"pytest_{start_time}") + shutil.rmtree("pytest_" + start_time) + + +@pytest.mark.timeout(60) +@pytest.mark.parametrize("env_id", ["discrete_dummy", "multidiscrete_dummy", "continuous_dummy"]) +@pytest.mark.parametrize("checkpoint_buffer", [True, False]) +def test_dreamer_v3(standard_args, env_id, checkpoint_buffer, start_time): + task = importlib.import_module("sheeprl.algos.dreamer_v3.dreamer_v3") + root_dir = os.path.join("pytest_" + start_time, "dreamer_v3", os.environ["LT_DEVICES"]) + run_name = "checkpoint_buffer" if checkpoint_buffer else "no_checkpoint_buffer" + ckpt_path = os.path.join(root_dir, run_name) + version = 0 if not os.path.isdir(ckpt_path) else len(os.listdir(ckpt_path)) + ckpt_path = os.path.join(ckpt_path, f"version_{version}", "checkpoint") + args = standard_args + [ + "--per_rank_batch_size=1", + "--per_rank_sequence_length=1", + f"--buffer_size={int(os.environ['LT_DEVICES'])}", + "--learning_starts=0", + "--gradient_steps=1", + "--horizon=8", + "--env_id=" + env_id, + "--root_dir=" + root_dir, + "--run_name=" + run_name, + "--dense_units=8", + "--cnn_channels_multiplier=2", + "--recurrent_state_size=8", + "--hidden_size=8", + "--cnn_keys=rgb", + "--layer_norm=True", + "--train_every=1", + ] + if checkpoint_buffer: + args.append("--checkpoint_buffer") + + with mock.patch.object(sys, "argv", [task.__file__] + args): + for command in task.__all__: + if command == "main": + task.__dict__[command]() + + keys = { + "world_model", + "actor", + "critic", + "target_critic", + "world_optimizer", + "actor_optimizer", + "critic_optimizer", + "expl_decay_steps", + "args", + "global_step", + "batch_size", + "moments", + } + if checkpoint_buffer: + keys.add("rb") + check_checkpoint(ckpt_path, keys, checkpoint_buffer) + shutil.rmtree("pytest_" + start_time) From e2e164716efabdfe01512defd6f5870621dcec22 Mon Sep 17 00:00:00 2001 From: belerico_t Date: Thu, 10 Aug 2023 18:20:43 +0200 Subject: [PATCH 2/2] Fix typing --- sheeprl/envs/wrappers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sheeprl/envs/wrappers.py b/sheeprl/envs/wrappers.py index 3fed01d1..8998ab90 100644 --- a/sheeprl/envs/wrappers.py +++ b/sheeprl/envs/wrappers.py @@ -103,7 +103,7 @@ def step(self, action) -> Tuple[Any, SupportsFloat, bool, bool, Dict[str, Any]]: def reset( self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None - ) -> tuple[Any, Dict[str, Any]]: + ) -> Tuple[Any, Dict[str, Any]]: try: return self.env.reset(seed=seed, options=options) except self._exceptions as e: