diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index 76e61cbb3f2..e3aad6afc9e 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -84,11 +84,13 @@ def _make_env(cfg, device): env = env.append_transform( TensorDictPrimer(random=False, default_value=0, **default_dict) ) + assert env is not None return env def transform_env(cfg, env): - env = TransformedEnv(env) + if not isinstance(env, TransformedEnv): + env = TransformedEnv(env) if cfg.env.from_pixels: # transforms pixel from 0-255 to 0-1 (uint8 to float32) env.append_transform(