Skip to content

Commit

Permalink
Merge pull request #55 from Eclectic-Sheep/fix/dreamer-v2-untied-weights
Browse files (browse the repository at this point in the history)
Untie decoder MLP weights
  • Loading branch information
belerico authored Jul 13, 2023
2 parents 77a66e0 + 54da7ef commit c2f39ec
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions sheeprl/algos/dreamer_v2/agent.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -126,7 +126,8 @@ def __init__(
),
)
if self.mlp_keys != []:
- self.mlp_decoder = MLP(latent_state_size, mlp_output_dim, [dense_units] * mlp_layers, activation=mlp_act)
+ self.mlp_decoder = MLP(latent_state_size, None, [dense_units] * mlp_layers, activation=mlp_act)
+ self.mlp_heads = nn.ModuleList([nn.Linear(dense_units, mlp_dim) for mlp_dim in self.mlp_splits])

def forward(self, latent_states: Tensor) -> Dict[str, Tensor]:
reconstructed_obs = {}
Expand All @@ -139,9 +140,7 @@ def forward(self, latent_states: Tensor) -> Dict[str, Tensor]:
)
if self.mlp_keys != []:
mlp_out = self.mlp_decoder(latent_states)
- reconstructed_obs.update(
- {k: rec_obs for k, rec_obs in zip(self.mlp_keys, torch.split(mlp_out, self.mlp_splits, -1))}
- )
+ reconstructed_obs.update({k: head(mlp_out) for k, head in zip(self.mlp_keys, self.mlp_heads)})
return reconstructed_obs


Expand Down

0 comments on commit c2f39ec

Please sign in to comment.