Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Apr 8, 2024
1 parent e22ba77 commit 83b2074
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 36 deletions.
11 changes: 8 additions & 3 deletions sota-implementations/dreamer/dreamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ def main(cfg: "DictConfig"): # noqa: F821
wandb_kwargs={"mode": cfg.logger.mode}, # "config": cfg},
)

train_env, test_env = make_environments(cfg=cfg, parallel_envs=cfg.env.n_parallel_envs)
train_env, test_env = make_environments(
cfg=cfg, parallel_envs=cfg.env.n_parallel_envs
)

# Make dreamer components
action_key = "action"
Expand Down Expand Up @@ -144,7 +146,9 @@ def main(cfg: "DictConfig"): # noqa: F821
for _ in range(optim_steps_per_batch):
# sample from replay buffer
t_sample_init = time.time()
sampled_tensordict = replay_buffer.sample(batch_size).reshape(-1, batch_length)
sampled_tensordict = replay_buffer.sample(batch_size).reshape(
-1, batch_length
)
t_sample = time.time() - t_sample_init

t_loss_model_init = time.time()
Expand All @@ -165,7 +169,7 @@ def main(cfg: "DictConfig"): # noqa: F821
clip_grad_norm_(world_model.parameters(), grad_clip)
scaler1.step(world_model_opt)
scaler1.update()
t_loss_model += (time.time()-t_loss_model_init)
t_loss_model += time.time() - t_loss_model_init

# update actor network
t_loss_actor_init = time.time()
Expand Down Expand Up @@ -230,5 +234,6 @@ def main(cfg: "DictConfig"): # noqa: F821
log_metrics(logger, eval_metrics, collected_frames)
t_collect_init = time.time()


if __name__ == "__main__":
main()
22 changes: 12 additions & 10 deletions sota-implementations/dreamer/dreamer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,13 @@
import torch

import torch.nn as nn
from tensordict.nn import InteractionType, TensorDictModule, ProbabilisticTensorDictModule, \
ProbabilisticTensorDictSequential, TensorDictSequential
from tensordict.nn import (
InteractionType,
ProbabilisticTensorDictModule,
ProbabilisticTensorDictSequential,
TensorDictModule,
TensorDictSequential,
)
from torchrl.collectors import SyncDataCollector
from torchrl.data import SliceSampler, TensorDictReplayBuffer
from torchrl.data.replay_buffers.storages import LazyMemmapStorage
Expand All @@ -36,10 +41,11 @@
TransformedEnv,
)
from torchrl.envs.transforms.transforms import (
DeviceCastTransform,
ExcludeTransform,
RenameTransform,
StepCounter,
TensorDictPrimer, DeviceCastTransform,
TensorDictPrimer,
)
from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type
from torchrl.modules import (
Expand Down Expand Up @@ -75,12 +81,8 @@ def _make_env(cfg, device):
else:
raise NotImplementedError(f"Unknown lib {lib}.")
default_dict = {
"state": UnboundedContinuousTensorSpec(
shape=(cfg.networks.state_dim,)
),
"belief": UnboundedContinuousTensorSpec(
shape=(cfg.networks.rssm_hidden_dim,)
),
"state": UnboundedContinuousTensorSpec(shape=(cfg.networks.state_dim,)),
"belief": UnboundedContinuousTensorSpec(shape=(cfg.networks.rssm_hidden_dim,)),
}
env = env.append_transform(
TensorDictPrimer(random=False, default_value=0, **default_dict)
Expand Down Expand Up @@ -142,7 +144,7 @@ def make_dreamer(
action_key: str = "action",
value_key: str = "state_value",
use_decoder_in_env: bool = False,
compile: bool=True,
compile: bool = True,
):
test_env = _make_env(config, device="cpu")
test_env = transform_env(config, test_env)
Expand Down
6 changes: 1 addition & 5 deletions torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,7 @@

from torch.utils._pytree import LeafSpec, tree_flatten, tree_map, tree_unflatten

from torchrl._utils import (
_CKPT_BACKEND,
implement_for,
logger as torchrl_logger,
)
from torchrl._utils import _CKPT_BACKEND, implement_for, logger as torchrl_logger
from torchrl.data.replay_buffers.utils import _is_int, INT_CLASSES

try:
Expand Down
37 changes: 19 additions & 18 deletions torchrl/modules/models/model_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@

import torch
from packaging import version
from tensordict.nn import TensorDictModule, TensorDictModuleBase
from tensordict.nn import NormalParamExtractor, TensorDictModule, TensorDictModuleBase
from torch import nn

from torchrl.envs.utils import step_mdp
from torchrl.modules.distributions import NormalParamWrapper
from torchrl.modules.models.models import MLP
from torchrl.modules.tensordict_module.sequence import SafeSequential

Expand Down Expand Up @@ -49,14 +48,16 @@ def __init__(
std_min_val=1e-4,
):
super().__init__()
self.backbone = NormalParamWrapper(
self.backbone = nn.Sequential(
MLP(
out_features=2 * out_features,
depth=depth,
num_cells=num_cells,
activation_class=activation_class,
),
scale_mapping=f"biased_softplus_{std_bias}_{std_min_val}",
NormalParamExtractor(
scale_mapping=f"biased_softplus_{std_bias}_{std_min_val}",
),
)

def forward(self, state, belief):
Expand Down Expand Up @@ -289,14 +290,14 @@ def __init__(
# Prior
self.rnn = nn.GRUCell(hidden_dim, rnn_hidden_dim)
self.action_state_projector = nn.Sequential(nn.LazyLinear(hidden_dim), nn.ELU())
self.rnn_to_prior_projector = NormalParamWrapper(
nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.ELU(),
nn.Linear(hidden_dim, 2 * state_dim),
self.rnn_to_prior_projector = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.ELU(),
nn.Linear(hidden_dim, 2 * state_dim),
NormalParamExtractor(
scale_lb=scale_lb,
scale_mapping="softplus",
),
scale_lb=scale_lb,
scale_mapping="softplus",
)

self.state_dim = state_dim
Expand Down Expand Up @@ -344,14 +345,14 @@ class RSSMPosterior(nn.Module):

def __init__(self, hidden_dim=200, state_dim=30, scale_lb=0.1):
super().__init__()
self.obs_rnn_to_post_projector = NormalParamWrapper(
nn.Sequential(
nn.LazyLinear(hidden_dim),
nn.ELU(),
nn.Linear(hidden_dim, 2 * state_dim),
self.obs_rnn_to_post_projector = nn.Sequential(
nn.LazyLinear(hidden_dim),
nn.ELU(),
nn.Linear(hidden_dim, 2 * state_dim),
NormalParamExtractor(
scale_lb=scale_lb,
scale_mapping="softplus",
),
scale_lb=scale_lb,
scale_mapping="softplus",
)
self.hidden_dim = hidden_dim

Expand Down

0 comments on commit 83b2074

Please sign in to comment.