Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Apr 8, 2024
1 parent 36a672d commit f517e45
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 11 deletions.
6 changes: 3 additions & 3 deletions sota-implementations/dreamer/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ env:
image_size : 64
horizon: 500
n_parallel_envs: 8
device: null

collector:
total_frames: 5_000_000
init_random_frames: 1000
frames_per_batch: 1000
max_frames_per_traj: 1000
init_random_frames: 8000
frames_per_batch: 8000
device: cuda:0

optimization:
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/dreamer/dreamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def main(cfg: "DictConfig"): # noqa: F821
wandb_kwargs={"mode": cfg.logger.mode}, # "config": cfg},
)

train_env, test_env = make_environments(cfg=cfg, device=device, parallel_envs=cfg.env.n_parallel_envs)
train_env, test_env = make_environments(cfg=cfg, parallel_envs=cfg.env.n_parallel_envs)

# Make dreamer components
action_key = "action"
Expand Down
13 changes: 6 additions & 7 deletions sota-implementations/dreamer/dreamer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,9 @@ def transform_env(cfg, env):
return env


def make_environments(cfg, device, parallel_envs=1):
def make_environments(cfg, parallel_envs=1):
"""Make environments for training and evaluation."""
func = functools.partial(_make_env, cfg=cfg, device=device)
func = functools.partial(_make_env, cfg=cfg, device=cfg.env.device)
train_env = ParallelEnv(
parallel_envs,
EnvCreator(func),
Expand All @@ -141,6 +141,7 @@ def make_dreamer(
action_key: str = "action",
value_key: str = "state_value",
use_decoder_in_env: bool = False,
compile: bool=True,
):
test_env = _make_env(config, device="cpu")
test_env = transform_env(config, test_env)
Expand All @@ -163,10 +164,6 @@ def make_dreamer(
num_cells=config.networks.hidden_dim,
activation_class=get_activation(config.networks.activation),
)
# if config.env.backend == "dm_control":
# observation_in_key = ("position", "velocity")
# obsevation_out_key = "reco_observation"
# else:
observation_in_key = "observation"
obsevation_out_key = "reco_observation"

Expand Down Expand Up @@ -283,7 +280,9 @@ def make_collector(cfg, train_env, actor_model_explore):
init_random_frames=cfg.collector.init_random_frames,
frames_per_batch=cfg.collector.frames_per_batch,
total_frames=cfg.collector.total_frames,
device=cfg.collector.device,
policy_device=cfg.collector.device,
env_device=train_env.device,
storing_device="cpu",
)
collector.set_seed(cfg.env.seed)

Expand Down

0 comments on commit f517e45

Please sign in to comment.