Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Apr 8, 2024
1 parent e81c9a5 commit 331faf4
Showing 1 changed file with 18 additions and 25 deletions.
43 changes: 18 additions & 25 deletions sota-implementations/dreamer/dreamer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,29 @@ def _make_env(cfg, device):
lib = cfg.env.backend
if lib in ("gym", "gymnasium"):
with set_gym_backend(lib):
return GymEnv(
env = GymEnv(
cfg.env.name,
device=device,
)
elif lib == "dm_control":
return DMControlEnv(cfg.env.name, cfg.env.task, from_pixels=cfg.env.from_pixels)
env = DMControlEnv(cfg.env.name, cfg.env.task, from_pixels=cfg.env.from_pixels)
else:
raise NotImplementedError(f"Unknown lib {lib}.")
default_dict = {
"state": UnboundedContinuousTensorSpec(
shape=(cfg.networks.state_dim,)
),
"belief": UnboundedContinuousTensorSpec(
shape=(cfg.networks.rssm_hidden_dim,)
),
}
env = env.append_transform(
TensorDictPrimer(random=False, default_value=0, **default_dict)
)
return env


def transform_env(cfg, env, parallel_envs, dummy=False):
def transform_env(cfg, env):
env = TransformedEnv(env)
if cfg.env.from_pixels:
# transforms pixel from 0-255 to 0-1 (uint8 to float32)
Expand All @@ -95,25 +107,6 @@ def transform_env(cfg, env, parallel_envs, dummy=False):
env.append_transform(RewardSum())
env.append_transform(FrameSkipTransform(cfg.env.frame_skip))
env.append_transform(StepCounter(cfg.env.horizon))
if dummy:
default_dict = {
"state": UnboundedContinuousTensorSpec(shape=(cfg.networks.state_dim)),
"belief": UnboundedContinuousTensorSpec(
shape=(cfg.networks.rssm_hidden_dim)
),
}
else:
default_dict = {
"state": UnboundedContinuousTensorSpec(
shape=(parallel_envs, cfg.networks.state_dim)
),
"belief": UnboundedContinuousTensorSpec(
shape=(parallel_envs, cfg.networks.rssm_hidden_dim)
),
}
env.append_transform(
TensorDictPrimer(random=False, default_value=0, **default_dict)
)

return env

Expand All @@ -126,14 +119,14 @@ def make_environments(cfg, device, parallel_envs=1):
EnvCreator(func),
serial_for_single=True,
)
train_env = transform_env(cfg, train_env, parallel_envs)
train_env = transform_env(cfg, train_env)
train_env.set_seed(cfg.env.seed)
eval_env = ParallelEnv(
parallel_envs,
EnvCreator(func),
serial_for_single=True,
)
eval_env = transform_env(cfg, eval_env, parallel_envs)
eval_env = transform_env(cfg, eval_env)
eval_env.set_seed(cfg.env.seed + 1)
check_env_specs(train_env)
check_env_specs(eval_env)
Expand All @@ -148,7 +141,7 @@ def make_dreamer(
use_decoder_in_env: bool = False,
):
test_env = _make_env(config, device="cpu")
test_env = transform_env(config, test_env, parallel_envs=1, dummy=True)
test_env = transform_env(config, test_env)
# Make encoder and decoder
if config.env.from_pixels:
encoder = ObsEncoder()
Expand Down

0 comments on commit 331faf4

Please sign in to comment.