diff --git a/sheeprl/algos/dreamer_v2/agent.py b/sheeprl/algos/dreamer_v2/agent.py index c83e46b1..39c2c036 100644 --- a/sheeprl/algos/dreamer_v2/agent.py +++ b/sheeprl/algos/dreamer_v2/agent.py @@ -61,9 +61,9 @@ def __init__( layer_args={"kernel_size": 4, "stride": 2}, activation=activation, norm_layer=[LayerNormChannelLast for _ in range(4)] if layer_norm else None, - norm_args=[{"normalized_shape": (2**i) * channels_multiplier} for i in range(4)] - if layer_norm - else None, + norm_args=( + [{"normalized_shape": (2**i) * channels_multiplier} for i in range(4)] if layer_norm else None + ), ), nn.Flatten(-3, -1), ) @@ -172,12 +172,12 @@ def __init__( ], activation=[activation, activation, activation, None], norm_layer=[LayerNormChannelLast for _ in range(3)] + [None] if layer_norm else None, - norm_args=[ - {"normalized_shape": (2 ** (4 - i - 2)) * channels_multiplier} for i in range(self.output_dim[0]) - ] - + [None] - if layer_norm - else None, + norm_args=( + [{"normalized_shape": (2 ** (4 - i - 2)) * channels_multiplier} for i in range(self.output_dim[0])] + + [None] + if layer_norm + else None + ), ), ) @@ -943,9 +943,11 @@ def build_agent( activation=eval(world_model_cfg.representation_model.dense_act), flatten_dim=None, norm_layer=[nn.LayerNorm] if world_model_cfg.representation_model.layer_norm else None, - norm_args=[{"normalized_shape": world_model_cfg.representation_model.hidden_size}] - if world_model_cfg.representation_model.layer_norm - else None, + norm_args=( + [{"normalized_shape": world_model_cfg.representation_model.hidden_size}] + if world_model_cfg.representation_model.layer_norm + else None + ), ) transition_model = MLP( input_dims=world_model_cfg.recurrent_model.recurrent_state_size, @@ -954,9 +956,11 @@ def build_agent( activation=eval(world_model_cfg.transition_model.dense_act), flatten_dim=None, norm_layer=[nn.LayerNorm] if world_model_cfg.transition_model.layer_norm else None, - norm_args=[{"normalized_shape": world_model_cfg.transition_model.hidden_size}] - if world_model_cfg.transition_model.layer_norm - else None, + norm_args=( + [{"normalized_shape": world_model_cfg.transition_model.hidden_size}] + if world_model_cfg.transition_model.layer_norm + else None + ), ) rssm = RSSM( recurrent_model=recurrent_model.apply(init_weights), @@ -999,15 +1003,19 @@ def build_agent( hidden_sizes=[world_model_cfg.reward_model.dense_units] * world_model_cfg.reward_model.mlp_layers, activation=eval(world_model_cfg.reward_model.dense_act), flatten_dim=None, - norm_layer=[nn.LayerNorm for _ in range(world_model_cfg.reward_model.mlp_layers)] - if world_model_cfg.reward_model.layer_norm - else None, - norm_args=[ - {"normalized_shape": world_model_cfg.reward_model.dense_units} - for _ in range(world_model_cfg.reward_model.mlp_layers) - ] - if world_model_cfg.reward_model.layer_norm - else None, + norm_layer=( + [nn.LayerNorm for _ in range(world_model_cfg.reward_model.mlp_layers)] + if world_model_cfg.reward_model.layer_norm + else None + ), + norm_args=( + [ + {"normalized_shape": world_model_cfg.reward_model.dense_units} + for _ in range(world_model_cfg.reward_model.mlp_layers) + ] + if world_model_cfg.reward_model.layer_norm + else None + ), ) if world_model_cfg.use_continues: continue_model = MLP( @@ -1016,15 +1024,19 @@ def build_agent( hidden_sizes=[world_model_cfg.discount_model.dense_units] * world_model_cfg.discount_model.mlp_layers, activation=eval(world_model_cfg.discount_model.dense_act), flatten_dim=None, - norm_layer=[nn.LayerNorm for _ in range(world_model_cfg.discount_model.mlp_layers)] - if world_model_cfg.discount_model.layer_norm - else None, - norm_args=[ - {"normalized_shape": world_model_cfg.discount_model.dense_units} - for _ in range(world_model_cfg.discount_model.mlp_layers) - ] - if world_model_cfg.discount_model.layer_norm - else None, + norm_layer=( + [nn.LayerNorm for _ in range(world_model_cfg.discount_model.mlp_layers)] + if world_model_cfg.discount_model.layer_norm + else None + ), + norm_args=( + [ + {"normalized_shape": world_model_cfg.discount_model.dense_units} + for _ in range(world_model_cfg.discount_model.mlp_layers) + ] + if world_model_cfg.discount_model.layer_norm + else None + ), ) world_model = WorldModel( encoder.apply(init_weights), @@ -1053,9 +1065,11 @@ def build_agent( activation=eval(critic_cfg.dense_act), flatten_dim=None, norm_layer=[nn.LayerNorm for _ in range(critic_cfg.mlp_layers)] if critic_cfg.layer_norm else None, - norm_args=[{"normalized_shape": critic_cfg.dense_units} for _ in range(critic_cfg.mlp_layers)] - if critic_cfg.layer_norm - else None, + norm_args=( + [{"normalized_shape": critic_cfg.dense_units} for _ in range(critic_cfg.mlp_layers)] + if critic_cfg.layer_norm + else None + ), ) actor.apply(init_weights) critic.apply(init_weights) diff --git a/sheeprl/algos/p2e_dv2/agent.py b/sheeprl/algos/p2e_dv2/agent.py index ffb146cd..7a1f2942 100644 --- a/sheeprl/algos/p2e_dv2/agent.py +++ b/sheeprl/algos/p2e_dv2/agent.py @@ -116,9 +116,11 @@ def build_agent( activation=eval(critic_cfg.dense_act), flatten_dim=None, norm_layer=[nn.LayerNorm for _ in range(critic_cfg.mlp_layers)] if critic_cfg.layer_norm else None, - norm_args=[{"normalized_shape": critic_cfg.dense_units} for _ in range(critic_cfg.mlp_layers)] - if critic_cfg.layer_norm - else None, + norm_args=( + [{"normalized_shape": critic_cfg.dense_units} for _ in range(critic_cfg.mlp_layers)] + if critic_cfg.layer_norm + else None + ), ) actor_task.apply(init_weights) critic_task.apply(init_weights) diff --git a/sheeprl/algos/p2e_dv3/agent.py b/sheeprl/algos/p2e_dv3/agent.py index 66547eaa..8d6c80c1 100644 --- a/sheeprl/algos/p2e_dv3/agent.py +++ b/sheeprl/algos/p2e_dv3/agent.py @@ -126,9 +126,11 @@ def build_agent( flatten_dim=None, layer_args={"bias": not critic_cfg.layer_norm}, norm_layer=[nn.LayerNorm for _ in range(critic_cfg.mlp_layers)] if critic_cfg.layer_norm else None, - norm_args=[{"normalized_shape": critic_cfg.dense_units} for _ in range(critic_cfg.mlp_layers)] - if critic_cfg.layer_norm - else None, + norm_args=( + [{"normalized_shape": critic_cfg.dense_units} for _ in range(critic_cfg.mlp_layers)] + if critic_cfg.layer_norm + else None + ), ), } critics_exploration[k]["module"].apply(init_weights) diff --git a/sheeprl/algos/ppo_recurrent/agent.py b/sheeprl/algos/ppo_recurrent/agent.py index ac914e07..1df8480e 100644 --- a/sheeprl/algos/ppo_recurrent/agent.py +++ b/sheeprl/algos/ppo_recurrent/agent.py @@ -26,9 +26,11 @@ def __init__( activation=eval(pre_rnn_mlp_cfg.activation), layer_args={"bias": pre_rnn_mlp_cfg.bias}, norm_layer=[nn.LayerNorm] if pre_rnn_mlp_cfg.layer_norm else None, - norm_args=[{"normalized_shape": pre_rnn_mlp_cfg.dense_units, "eps": 1e-3}] - if pre_rnn_mlp_cfg.layer_norm - else None, + norm_args=( + [{"normalized_shape": pre_rnn_mlp_cfg.dense_units, "eps": 1e-3}] + if pre_rnn_mlp_cfg.layer_norm + else None + ), ) else: self._pre_mlp = nn.Identity() @@ -45,9 +47,11 @@ def __init__( activation=eval(post_rnn_mlp_cfg.activation), layer_args={"bias": post_rnn_mlp_cfg.bias}, norm_layer=[nn.LayerNorm] if post_rnn_mlp_cfg.layer_norm else None, - norm_args=[{"normalized_shape": post_rnn_mlp_cfg.dense_units, "eps": 1e-3}] - if post_rnn_mlp_cfg.layer_norm - else None, + norm_args=( + [{"normalized_shape": post_rnn_mlp_cfg.dense_units, "eps": 1e-3}] + if post_rnn_mlp_cfg.layer_norm + else None + ), ) self._output_dim = post_rnn_mlp_cfg.dense_units else: diff --git a/sheeprl/data/buffers.py b/sheeprl/data/buffers.py index 73672e39..c2c4dafe 100644 --- a/sheeprl/data/buffers.py +++ b/sheeprl/data/buffers.py @@ -135,12 +135,10 @@ def to_tensor( return buf @typing.overload - def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None: - ... + def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None: ... @typing.overload - def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None: - ... + def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None: ... def add(self, data: "ReplayBuffer" | Dict[str, np.ndarray], validate_args: bool = False) -> None: """Add data to the replay buffer. If the replay buffer is full, then the oldest data is overwritten. @@ -614,12 +612,10 @@ def __len__(self) -> int: return self.buffer_size @typing.overload - def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None: - ... + def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None: ... @typing.overload - def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None: - ... + def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None: ... def add( self, @@ -857,8 +853,9 @@ def __len__(self) -> int: return self._cum_lengths[-1] if len(self._buf) > 0 else 0 @typing.overload - def add(self, data: "ReplayBuffer", env_idxes: Sequence[int] | None = None, validate_args: bool = False) -> None: - ... + def add( + self, data: "ReplayBuffer", env_idxes: Sequence[int] | None = None, validate_args: bool = False + ) -> None: ... @typing.overload def add( @@ -866,8 +863,7 @@ def add( data: Dict[str, np.ndarray], env_idxes: Sequence[int] | None = None, validate_args: bool = False, - ) -> None: - ... + ) -> None: ... def add( self, diff --git a/sheeprl/models/models.py b/sheeprl/models/models.py index df774c13..089c45ce 100644 --- a/sheeprl/models/models.py +++ b/sheeprl/models/models.py @@ -1,6 +1,7 @@ """ Adapted from: https://github.com/thu-ml/tianshou/blob/master/tianshou/utils/net/common.py """ + import warnings from math import prod from typing import Dict, Optional, Sequence, Union, no_type_check diff --git a/sheeprl/utils/distribution.py b/sheeprl/utils/distribution.py index 31765bb6..842a745d 100644 --- a/sheeprl/utils/distribution.py +++ b/sheeprl/utils/distribution.py @@ -307,6 +307,7 @@ class OneHotCategoricalValidateArgs(Distribution): probs (Tensor): event probabilities logits (Tensor): event log probabilities (unnormalized) """ + arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} support = constraints.one_hot has_enumerate_support = True @@ -391,6 +392,7 @@ class OneHotCategoricalStraightThroughValidateArgs(OneHotCategoricalValidateArgs [1] Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation (Bengio et al, 2013) """ + has_rsample = True def rsample(self, sample_shape=torch.Size()): diff --git a/sheeprl/utils/model.py b/sheeprl/utils/model.py index f74ba626..89ea6c29 100644 --- a/sheeprl/utils/model.py +++ b/sheeprl/utils/model.py @@ -1,6 +1,7 @@ """ Adapted from: https://github.com/thu-ml/tianshou/blob/master/tianshou/utils/net/common.py """ + from typing import Any, Dict, List, Optional, Tuple, Type, Union import torch