diff --git a/baselines/baselines_utils.py b/baselines/baselines_utils.py index 31950f39..553107b2 100644 --- a/baselines/baselines_utils.py +++ b/baselines/baselines_utils.py @@ -10,17 +10,6 @@ def load_or_create_model(model_to_load_path, env_config, total_timesteps, num_cpu): env = SubprocVecEnv([make_env(i, env_config) for i in range(num_cpu)]) - #env = make_env(0, env_config) - #if env_config['stream'] is True: - # env = StreamWrapper( - # env, - # stream_metadata = { # All of this is part is optional - # "user": "MATHIEU", # choose your own username - # "env_id": env_config['instance_id'], # environment identifier - # "color": "#d900ff", # choose your color :) - # "extra": "", # any extra text you put here will be displayed - # } - # ) if exists(model_to_load_path + '.zip'): print('\nloading checkpoint') model = PPO.load(model_to_load_path, env=env) @@ -35,7 +24,7 @@ def load_or_create_model(model_to_load_path, env_config, total_timesteps, num_cp return model -def make_env(rank, env_conf, seed=0): +def make_env(rank, env_config, seed=0): """ Utility function for multiprocessed env. :param env_id: (str) the environment ID @@ -44,8 +33,18 @@ def make_env(rank, env_conf, seed=0): :param rank: (int) index of the subprocess """ def _init(): - env = RedGymEnv(env_conf) + env = RedGymEnv(env_config) env.reset(seed=(seed + rank)) + if env_config['stream'] is True: + env = StreamWrapper( + env, + stream_metadata = { # All of this is part is optional + "user": "MATHIEU", # choose your own username + "env_id": env_config['instance_id'], # environment identifier + "color": "#d900ff", # choose your color :) + "extra": "", # any extra text you put here will be displayed + } + ) return env set_random_seed(seed) return _init \ No newline at end of file