From 340ac6b116517b42dc849388f9c16e9b76c3af40 Mon Sep 17 00:00:00 2001 From: Buridi Aditya Date: Fri, 4 Jan 2019 15:54:22 +0530 Subject: [PATCH] Support for training on MADRAS env --- baselines/a2c/a2c.py | 3 ++- baselines/a2c/runner.py | 3 ++- baselines/common/models.py | 14 ++++++++++++++ baselines/common/vec_env/subproc_vec_env.py | 5 +++++ baselines/run.py | 10 ++++++---- 5 files changed, 29 insertions(+), 6 deletions(-) diff --git a/baselines/a2c/a2c.py b/baselines/a2c/a2c.py index b0fccfb659..4a25f8dd83 100644 --- a/baselines/a2c/a2c.py +++ b/baselines/a2c/a2c.py @@ -118,7 +118,7 @@ def learn( network, env, seed=None, - nsteps=5, + nsteps=20, total_timesteps=int(80e6), vf_coef=0.5, ent_coef=0.01, @@ -187,6 +187,7 @@ def learn( nenvs = env.num_envs policy = build_policy(env, network, **network_kwargs) + print('Parallel %d number'%(nenvs)) # Instantiate the model object (that creates step_model and train_model) model = Model(policy=policy, env=env, nsteps=nsteps, ent_coef=ent_coef, vf_coef=vf_coef, max_grad_norm=max_grad_norm, lr=lr, alpha=alpha, epsilon=epsilon, total_timesteps=total_timesteps, lrschedule=lrschedule) diff --git a/baselines/a2c/runner.py b/baselines/a2c/runner.py index 8d0c6ecc7a..e6b22cdfdd 100644 --- a/baselines/a2c/runner.py +++ b/baselines/a2c/runner.py @@ -39,7 +39,8 @@ def run(self): self.dones = dones for n, done in enumerate(dones): if done: - self.obs[n] = self.obs[n]*0 + # self.obs[n] = self.obs[n]*0 + pass self.obs = obs mb_rewards.append(rewards) mb_dones.append(self.dones) diff --git a/baselines/common/models.py b/baselines/common/models.py index 0003079649..e5ad5bc97a 100644 --- a/baselines/common/models.py +++ b/baselines/common/models.py @@ -58,6 +58,20 @@ def network_fn(X): return network_fn +@register("mynn") +def mynn(num_layers=2, num_hidden=[300,500], activation=tf.tanh, layer_norm=False): + def network_fn(X): + h = tf.layers.flatten(X) + for i in range(num_layers): + h = fc(h, 'mlp_fc{}'.format(i), nh=num_hidden[i], init_scale=np.sqrt(2)) + if layer_norm: + h = tf.contrib.layers.layer_norm(h, center=True, scale=True) + h = activation(h) + + return h + + return network_fn + @register("cnn") def cnn(**conv_kwargs): diff --git a/baselines/common/vec_env/subproc_vec_env.py b/baselines/common/vec_env/subproc_vec_env.py index 4dc4d2c6c0..450412f7a9 100644 --- a/baselines/common/vec_env/subproc_vec_env.py +++ b/baselines/common/vec_env/subproc_vec_env.py @@ -78,6 +78,11 @@ def reset(self): remote.send(('reset', None)) return np.stack([remote.recv() for remote in self.remotes]) + def reset_envno(self,no): + self._assert_not_closed() + self.remotes[no].send(('reset', None)) + return self.remotes[no].recv() + def close_extras(self): self.closed = True if self.waiting: diff --git a/baselines/run.py b/baselines/run.py index 451544523e..889653420e 100644 --- a/baselines/run.py +++ b/baselines/run.py @@ -50,7 +50,7 @@ 'SpaceInvaders-Snes', } -_game_envs['madras'] = {'gym-torcs-v0','gym-madras-v0'} +_game_envs['madras'] = {'Madras-v0'} def train(args, extra_args): env_type, env_id = get_env_type(args.env) print('env_type: {}'.format(env_type)) @@ -88,6 +88,7 @@ def build_env(args): ncpu = multiprocessing.cpu_count() if sys.platform == 'darwin': ncpu //= 2 nenv = args.num_env or ncpu + print('Found %d CPUs'%(nenv)) alg = args.alg seed = args.seed @@ -196,7 +197,7 @@ def main(args): rank = MPI.COMM_WORLD.Get_rank() model, env = train(args, extra_args) - env.close() + # env.close() if args.save_path is not None and rank == 0: save_path = osp.expanduser(args.save_path) @@ -204,15 +205,16 @@ def main(args): if args.play: logger.log("Running trained model") - env = build_env(args) + # env = build_env(args) obs = env.reset() def initialize_placeholders(nlstm=128,**kwargs): return np.zeros((args.num_env or 1, 2*nlstm)), np.zeros((1)) state, dones = initialize_placeholders(**extra_args) while True: actions, _, state, _ = model.step(obs,S=state, M=dones) + # actions, _, state, _ = model.step(obs) obs, _, done, _ = env.step(actions) - env.render() + # env.render() done = done.any() if isinstance(done, np.ndarray) else done if done: