Merge pull request #56 from Eclectic-Sheep/feature/dreamer-layer-norm
Feature/dreamer layer norm
belerico authored Jul 14, 2023
2 parents c2f39ec + 6aa5e61 commit 45a28f7
Showing 5 changed files with 101 additions and 16 deletions.
95 changes: 81 additions & 14 deletions sheeprl/algos/dreamer_v2/agent.py
@@ -22,7 +22,7 @@
from sheeprl.algos.dreamer_v2.utils import compute_stochastic_state, init_weights
from sheeprl.models.models import CNN, MLP, DeCNN, LayerNormGRUCell
from sheeprl.utils.distribution import TruncatedNormal
from sheeprl.utils.model import ModuleType
from sheeprl.utils.model import ModuleType, LayerNormChannelLast


class MultiEncoder(nn.Module):
@@ -37,6 +37,7 @@ def __init__(
cnn_act: Optional[Union[ModuleType, Sequence[ModuleType]]] = nn.ELU,
mlp_act: Optional[Union[ModuleType, Sequence[ModuleType]]] = nn.ELU,
device: Union[str, torch.device] = "cpu",
layer_norm: bool = False,
) -> None:
super().__init__()
if isinstance(device, str):
@@ -55,16 +56,27 @@ def __init__(
hidden_channels=(torch.tensor([1, 2, 4, 8]) * cnn_channels_multiplier).tolist(),
layer_args={"kernel_size": 4, "stride": 2},
activation=cnn_act,
norm_layer=[LayerNormChannelLast for _ in range(4)] if layer_norm else None,
norm_args=[{"normalized_shape": (2**i) * cnn_channels_multiplier} for i in range(4)]
if layer_norm
else None,
),
nn.Flatten(-3, -1),
)
with torch.no_grad():
self.cnn_output_dim = self.cnn_encoder(torch.zeros(*self.cnn_input_dim)).shape[-1]
self.cnn_output_dim = self.cnn_encoder(torch.zeros(1, *self.cnn_input_dim)).shape[-1]
else:
self.cnn_output_dim = 0

if self.mlp_keys != []:
self.mlp_encoder = MLP(self.mlp_input_dim, None, [dense_units] * mlp_layers, activation=mlp_act)
self.mlp_encoder = MLP(
self.mlp_input_dim,
None,
[dense_units] * mlp_layers,
activation=mlp_act,
norm_layer=[nn.LayerNorm for _ in range(mlp_layers)] if layer_norm else None,
norm_args=[{"normalized_shape": dense_units} for _ in range(mlp_layers)] if layer_norm else None,
)
self.mlp_output_dim = dense_units
else:
self.mlp_output_dim = 0
@@ -88,7 +100,6 @@ def __init__(
cnn_keys: Sequence[str],
mlp_keys: Sequence[str],
cnn_channels_multiplier: int,
mlp_output_dim: int,
latent_state_size: int,
cnn_decoder_input_dim: int,
cnn_decoder_output_dim: Tuple[int, int, int],
@@ -97,6 +108,7 @@ def __init__(
cnn_act: Optional[Union[ModuleType, Sequence[ModuleType]]] = nn.ELU,
mlp_act: Optional[Union[ModuleType, Sequence[ModuleType]]] = nn.ELU,
device: Union[str, torch.device] = "cpu",
layer_norm: bool = False,
) -> None:
super().__init__()
if isinstance(device, str):
@@ -123,10 +135,22 @@ def __init__(
{"kernel_size": 6, "stride": 2},
],
activation=[cnn_act, cnn_act, cnn_act, None],
norm_layer=[LayerNormChannelLast for _ in range(3)] + [None] if layer_norm else None,
norm_args=[{"normalized_shape": (2 ** (4 - i - 2)) * cnn_channels_multiplier} for i in range(3)]
+ [None]
if layer_norm
else None,
),
)
if self.mlp_keys != []:
self.mlp_decoder = MLP(latent_state_size, None, [dense_units] * mlp_layers, activation=mlp_act)
self.mlp_decoder = MLP(
latent_state_size,
None,
[dense_units] * mlp_layers,
activation=mlp_act,
norm_layer=[nn.LayerNorm for _ in range(mlp_layers)] if layer_norm else None,
norm_args=[{"normalized_shape": dense_units} for _ in range(mlp_layers)] if layer_norm else None,
)
self.mlp_heads = nn.ModuleList([nn.Linear(dense_units, mlp_dim) for mlp_dim in self.mlp_splits])

def forward(self, latent_states: Tensor) -> Dict[str, Tensor]:
@@ -159,10 +183,22 @@ class RecurrentModel(nn.Module):
"""

def __init__(
self, input_size: int, recurrent_state_size: int, dense_units: int, activation_fn: nn.Module = nn.ELU
self,
input_size: int,
recurrent_state_size: int,
dense_units: int,
activation_fn: nn.Module = nn.ELU,
layer_norm: bool = False,
) -> None:
super().__init__()
self.mlp = nn.Sequential(nn.Linear(input_size, dense_units), activation_fn())
self.mlp = MLP(
input_dims=input_size,
output_dim=None,
hidden_sizes=[dense_units],
activation=activation_fn,
norm_layer=[nn.LayerNorm] if layer_norm else None,
norm_args=[{"normalized_shape": dense_units}] if layer_norm else None,
)
self.rnn = LayerNormGRUCell(dense_units, recurrent_state_size, bias=True, batch_first=False, layer_norm=True)

def forward(self, input: Tensor, recurrent_state: Tensor) -> Tensor:
@@ -326,6 +362,7 @@ def __init__(
dense_act: nn.Module = nn.ELU,
mlp_layers: int = 4,
distribution: str = "auto",
layer_norm: bool = False,
) -> None:
super().__init__()
self.distribution = distribution.lower()
@@ -343,11 +380,17 @@ def __init__(
self.distribution = "discrete"
self.model = MLP(
input_dims=latent_state_size,
output_dim=np.sum(actions_dim) * 2 if is_continuous else np.sum(actions_dim),
output_dim=None,
hidden_sizes=[dense_units] * mlp_layers,
activation=dense_act,
flatten_dim=None,
norm_layer=[nn.LayerNorm for _ in range(mlp_layers)] if layer_norm else None,
norm_args=[{"normalized_shape": dense_units} for _ in range(mlp_layers)] if layer_norm else None,
)
if is_continuous:
self.mlp_heads = nn.ModuleList([nn.Linear(dense_units, np.sum(actions_dim) * 2)])
else:
self.mlp_heads = nn.ModuleList([nn.Linear(dense_units, action_dim) for action_dim in actions_dim])
self.actions_dim = actions_dim
self.is_continuous = is_continuous
self.init_std = torch.tensor(init_std)
@@ -368,8 +411,9 @@ def forward(
The distribution of the actions
"""
out: Tensor = self.model(state)
pre_dist: List[Tensor] = [head(out) for head in self.mlp_heads]
if self.is_continuous:
mean, std = torch.chunk(out, 2, -1)
mean, std = torch.chunk(pre_dist[0], 2, -1)
if self.distribution == "tanh_normal":
mean = 5 * torch.tanh(mean / 5)
std = F.softplus(std + self.init_std) + self.min_std
@@ -391,10 +435,9 @@ def forward(
actions = [actions]
actions_dist = [actions_dist]
else:
actions_logits = torch.split(out, self.actions_dim, -1)
actions_dist: List[Distribution] = []
actions: List[Tensor] = []
for logits in actions_logits:
for logits in pre_dist:
actions_dist.append(OneHotCategoricalStraightThrough(logits=logits))
if is_training:
actions.append(actions_dist[-1].rsample())
@@ -415,6 +458,7 @@ def __init__(
dense_act: nn.Module = nn.ELU,
mlp_layers: int = 4,
distribution: str = "auto",
layer_norm: bool = False,
) -> None:
super().__init__(
latent_state_size,
@@ -426,6 +470,7 @@ def __init__(
dense_act,
mlp_layers,
distribution,
layer_norm,
)

def forward(
@@ -443,7 +488,7 @@ def forward(
The distribution of the actions
"""
out: Tensor = self.model(state)
actions_logits = torch.split(out, self.actions_dim, -1)
actions_logits: List[Tensor] = [head(out) for head in self.mlp_heads]
actions_dist: List[Distribution] = []
actions: List[Tensor] = []
functional_action = None
@@ -677,22 +722,32 @@ def build_models(
cnn_act,
dense_act,
fabric.device,
args.layer_norm,
)
stochastic_size = args.stochastic_size * args.discrete_size
recurrent_model = RecurrentModel(np.sum(actions_dim) + stochastic_size, args.recurrent_state_size, args.dense_units)
recurrent_model = RecurrentModel(
int(np.sum(actions_dim)) + stochastic_size,
args.recurrent_state_size,
args.dense_units,
layer_norm=args.layer_norm,
)
representation_model = MLP(
input_dims=args.recurrent_state_size + encoder.cnn_output_dim + encoder.mlp_output_dim,
output_dim=stochastic_size,
hidden_sizes=[args.hidden_size],
activation=dense_act,
flatten_dim=None,
norm_layer=[nn.LayerNorm] if args.layer_norm else None,
norm_args=[{"normalized_shape": args.hidden_size}] if args.layer_norm else None,
)
transition_model = MLP(
input_dims=args.recurrent_state_size,
output_dim=stochastic_size,
hidden_sizes=[args.hidden_size],
activation=dense_act,
flatten_dim=None,
norm_layer=[nn.LayerNorm] if args.layer_norm else None,
norm_args=[{"normalized_shape": args.hidden_size}] if args.layer_norm else None,
)
rssm = RSSM(
recurrent_model.apply(init_weights),
@@ -706,7 +761,6 @@
cnn_keys,
mlp_keys,
args.cnn_channels_multiplier,
encoder.mlp_input_dim,
args.stochastic_size * args.discrete_size + args.recurrent_state_size,
encoder.cnn_output_dim,
encoder.cnn_input_dim,
@@ -715,13 +769,16 @@
cnn_act,
dense_act,
fabric.device,
args.layer_norm,
)
reward_model = MLP(
input_dims=stochastic_size + args.recurrent_state_size,
output_dim=1,
hidden_sizes=[args.dense_units] * args.mlp_layers,
activation=dense_act,
flatten_dim=None,
norm_layer=[nn.LayerNorm for _ in range(args.mlp_layers)] if args.layer_norm else None,
norm_args=[{"normalized_shape": args.dense_units} for _ in range(args.mlp_layers)] if args.layer_norm else None,
)
if args.use_continues:
continue_model = MLP(
@@ -730,6 +787,10 @@
hidden_sizes=[args.dense_units] * args.mlp_layers,
activation=dense_act,
flatten_dim=None,
norm_layer=[nn.LayerNorm for _ in range(args.mlp_layers)] if args.layer_norm else None,
norm_args=[{"normalized_shape": args.dense_units} for _ in range(args.mlp_layers)]
if args.layer_norm
else None,
)
world_model = WorldModel(
encoder.apply(init_weights),
Expand All @@ -748,6 +809,8 @@ def build_models(
args.dense_units,
dense_act,
args.mlp_layers,
distribution=args.actor_distribution,
layer_norm=args.layer_norm,
)
else:
actor = Actor(
@@ -759,13 +822,17 @@
args.dense_units,
dense_act,
args.mlp_layers,
distribution=args.actor_distribution,
layer_norm=args.layer_norm,
)
critic = MLP(
input_dims=stochastic_size + args.recurrent_state_size,
output_dim=1,
hidden_sizes=[args.dense_units] * args.mlp_layers,
activation=dense_act,
flatten_dim=None,
norm_layer=[nn.LayerNorm for _ in range(args.mlp_layers)] if args.layer_norm else None,
norm_args=[{"normalized_shape": args.dense_units} for _ in range(args.mlp_layers)] if args.layer_norm else None,
)
actor.apply(init_weights)
critic.apply(init_weights)
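Note: the norm_layer / norm_args pairs threaded through the MLP and CNN constructors above follow a one-normalization-module-per-hidden-layer pattern. A minimal, self-contained sketch of that pattern in plain PyTorch (illustrative only — not the actual sheeprl.models.models.MLP implementation; names and sizes are made up):

import torch
from torch import nn

def dense_stack(input_dim: int, hidden_sizes: list, layer_norm: bool = False) -> nn.Sequential:
    # Interleave Linear -> (LayerNorm) -> ELU, mirroring the norm_layer/norm_args lists above
    layers, in_dim = [], input_dim
    for units in hidden_sizes:
        layers.append(nn.Linear(in_dim, units))
        if layer_norm:
            layers.append(nn.LayerNorm(normalized_shape=units))
        layers.append(nn.ELU())
        in_dim = units
    return nn.Sequential(*layers)

model = dense_stack(64, [32, 32], layer_norm=True)
print(model(torch.zeros(1, 64)).shape)  # torch.Size([1, 32])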
3 changes: 3 additions & 0 deletions sheeprl/algos/dreamer_v2/args.py
@@ -78,6 +78,9 @@ class DreamerV2Args(StandardArgs):
help="the activation function for the convolutional layers, one of https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity (case sensitive, without 'nn.')",
)
critic_target_network_update_freq: int = Arg(default=100, help="the frequency to update the target critic network")
layer_norm: bool = Arg(
default=False, help="whether to apply nn.LayerNorm after every Linear/Conv2D/ConvTranspose2D"
)

# Environment settings
expl_amount: float = Arg(default=0.0, help="the exploration amout to add to the actions")
4 changes: 2 additions & 2 deletions sheeprl/algos/dreamer_v2/dreamer_v2.py
@@ -315,8 +315,8 @@ def train(
if is_continuous:
objective = lambda_values[1:]
else:
baseline = target_critic(imagined_trajectories)
advantage = (lambda_values[1:] - baseline[:-2]).detach()
baseline = target_critic(imagined_trajectories[:-2])
advantage = (lambda_values[1:] - baseline).detach()
objective = (
torch.stack(
[
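Note on the change above: slicing before or after the target critic yields the same baseline values, but calling the critic on imagined_trajectories[:-2] avoids evaluating it on the last two imagined steps. A tiny self-contained check with a stand-in critic (shapes are hypothetical):

import torch
from torch import nn

T, B, F = 10, 4, 8                        # imagined steps, batch, feature size (made up)
imagined_trajectories = torch.randn(T, B, F)
target_critic = nn.Linear(F, 1)           # stand-in for the real target critic

with torch.no_grad():
    old = target_critic(imagined_trajectories)[:-2]   # critic on all T steps, then slice
    new = target_critic(imagined_trajectories[:-2])   # slice first, critic on T - 2 steps
    print(torch.allclose(old, new))                   # True: same baseline, less compute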
14 changes: 14 additions & 0 deletions sheeprl/utils/model.py
@@ -4,6 +4,7 @@
from typing import Any, Dict, List, Optional, Tuple, Type, Union

from torch import nn
from torch import Tensor

ModuleType = Optional[Type[nn.Module]]
ArgType = Union[Tuple[Any, ...], Dict[Any, Any], None]
@@ -154,3 +155,16 @@ def per_layer_ortho_init_weights(module: nn.Module, gain: float = 1.0, bias: flo
elif isinstance(module, (nn.Sequential, nn.ModuleList)):
for i in range(len(module)):
per_layer_ortho_init_weights(module[i], gain=gain, bias=bias)


class LayerNormChannelLast(nn.LayerNorm):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

def forward(self, x: Tensor) -> Tensor:
if x.dim() != 4:
raise ValueError(f"Input tensor must be 4D (NCHW), received {len(x.shape)}D instead: {x.shape}")
x = x.permute(0, 2, 3, 1)
x = super().forward(x)
x = x.permute(0, 3, 1, 2)
return x
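A quick usage sketch for the new LayerNormChannelLast (shapes are illustrative): it permutes an NCHW tensor to NHWC so that nn.LayerNorm normalizes over the channel dimension, then permutes back.

import torch
from sheeprl.utils.model import LayerNormChannelLast

x = torch.randn(2, 16, 8, 8)                     # N, C, H, W
ln = LayerNormChannelLast(normalized_shape=16)   # normalize across the 16 channels
y = ln(x)
print(y.shape)                                   # torch.Size([2, 16, 8, 8])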
1 change: 1 addition & 0 deletions tests/test_algos/test_algos.py
@@ -556,6 +556,7 @@ def test_dreamer_v2(standard_args, env_id, checkpoint_buffer, start_time):
"--recurrent_state_size=8",
"--hidden_size=8",
"--cnn_keys=rgb",
"--layer_norm=True"
]
if checkpoint_buffer:
args.append("--checkpoint_buffer")
