From b42bb209983fd293ce6272e382732abe5c112c11 Mon Sep 17 00:00:00 2001 From: wenzhangliu Date: Thu, 2 Nov 2023 21:21:38 +0800 Subject: [PATCH] multiprocess --- xuance/common/memory_tools_marl.py | 8 +- xuance/configs/dcg/sc2/1c3s5z.yaml | 2 +- xuance/configs/dcg/sc2/25m.yaml | 2 +- xuance/configs/dcg/sc2/2m_vs_1z.yaml | 2 +- xuance/configs/dcg/sc2/2s3z.yaml | 2 +- xuance/configs/dcg/sc2/3m.yaml | 2 +- xuance/configs/dcg/sc2/5m_vs_6m.yaml | 2 +- xuance/configs/dcg/sc2/8m.yaml | 2 +- xuance/configs/dcg/sc2/8m_vs_9m.yaml | 2 +- xuance/configs/dcg/sc2/MMM2.yaml | 2 +- xuance/configs/dcg/sc2/corridor.yaml | 2 +- xuance/configs/iql/sc2/1c3s5z.yaml | 2 +- xuance/configs/iql/sc2/25m.yaml | 2 +- xuance/configs/iql/sc2/2m_vs_1z.yaml | 2 +- xuance/configs/iql/sc2/2s3z.yaml | 2 +- xuance/configs/iql/sc2/3m.yaml | 2 +- xuance/configs/iql/sc2/5m_vs_6m.yaml | 2 +- xuance/configs/iql/sc2/8m.yaml | 2 +- xuance/configs/iql/sc2/8m_vs_9m.yaml | 2 +- xuance/configs/iql/sc2/MMM2.yaml | 2 +- xuance/configs/iql/sc2/corridor.yaml | 2 +- xuance/configs/qmix/sc2/1c3s5z.yaml | 2 +- xuance/configs/qmix/sc2/25m.yaml | 2 +- xuance/configs/qmix/sc2/2m_vs_1z.yaml | 2 +- xuance/configs/qmix/sc2/2s3z.yaml | 2 +- xuance/configs/qmix/sc2/5m_vs_6m.yaml | 2 +- xuance/configs/qmix/sc2/8m.yaml | 2 +- xuance/configs/qmix/sc2/8m_vs_9m.yaml | 2 +- xuance/configs/qmix/sc2/MMM2.yaml | 2 +- xuance/configs/qmix/sc2/corridor.yaml | 2 +- xuance/configs/vdn/sc2/1c3s5z.yaml | 2 +- xuance/configs/vdn/sc2/25m.yaml | 2 +- xuance/configs/vdn/sc2/2m_vs_1z.yaml | 2 +- xuance/configs/vdn/sc2/2s3z.yaml | 2 +- xuance/configs/vdn/sc2/3m.yaml | 2 +- xuance/configs/vdn/sc2/5m_vs_6m.yaml | 2 +- xuance/configs/vdn/sc2/8m.yaml | 2 +- xuance/configs/vdn/sc2/8m_vs_9m.yaml | 2 +- xuance/configs/vdn/sc2/MMM2.yaml | 2 +- xuance/configs/vdn/sc2/corridor.yaml | 2 +- xuance/configs/wqmix/sc2/1c3s5z.yaml | 2 +- xuance/configs/wqmix/sc2/25m.yaml | 2 +- xuance/configs/wqmix/sc2/2m_vs_1z.yaml | 2 +- xuance/configs/wqmix/sc2/2s3z.yaml | 2 +- 
xuance/configs/wqmix/sc2/5m_vs_6m.yaml | 2 +- xuance/configs/wqmix/sc2/8m.yaml | 2 +- xuance/configs/wqmix/sc2/8m_vs_9m.yaml | 2 +- xuance/configs/wqmix/sc2/MMM2.yaml | 2 +- xuance/configs/wqmix/sc2/corridor.yaml | 2 +- xuance/environment/__init__.py | 38 +-- .../pettingzoo/pettingzoo_vec_env.py | 245 ++++++++++++++++-- xuance/torch/runners/runner_sc2.py | 11 +- 52 files changed, 300 insertions(+), 98 deletions(-) diff --git a/xuance/common/memory_tools_marl.py b/xuance/common/memory_tools_marl.py index 4200a5fc..1097ee8d 100644 --- a/xuance/common/memory_tools_marl.py +++ b/xuance/common/memory_tools_marl.py @@ -131,10 +131,6 @@ def clear_episodes(self): 'obs': np.zeros((self.n_envs, self.n_agents, self.max_eps_len + 1) + self.obs_space, dtype=np.float32), 'actions': np.zeros((self.n_envs, self.n_agents, self.max_eps_len) + self.act_space, dtype=np.float32), 'rewards': np.zeros((self.n_envs, self.n_agents, self.max_eps_len) + self.rew_space, dtype=np.float32), - 'returns': np.zeros((self.n_envs, self.n_agents, self.max_eps_len) + self.rew_space, np.float32), - 'values': np.zeros((self.n_envs, self.n_agents, self.max_eps_len) + self.rew_space, np.float32), - 'advantages': np.zeros((self.n_envs, self.n_agents, self.max_eps_len) + self.rew_space, np.float32), - 'log_pi_old': np.zeros((self.n_envs, self.n_agents, self.max_eps_len,), np.float32), 'terminals': np.zeros((self.n_envs, self.max_eps_len) + self.done_space, dtype=np.bool), 'avail_actions': np.ones((self.n_envs, self.n_agents, self.max_eps_len + 1, self.dim_act), dtype=np.bool), 'filled': np.zeros((self.n_envs, self.max_eps_len, 1), dtype=np.bool), @@ -149,8 +145,6 @@ def store_transitions(self, t_envs, *transition_data): self.episode_data['obs'][:, :, t_envs] = obs_n self.episode_data['actions'][:, :, t_envs] = actions_dict['actions_n'] self.episode_data['rewards'][:, :, t_envs] = rewards - self.episode_data['values'][:, :, t_envs] = actions_dict['values'] - self.episode_data['log_pi_old'][:, :, t_envs] = 
actions_dict['log_pi'] self.episode_data['terminals'][:, t_envs] = terminated self.episode_data['avail_actions'][:, :, t_envs] = avail_actions if self.state_space is not None: @@ -159,7 +153,7 @@ def store_transitions(self, t_envs, *transition_data): def store_episodes(self): for i_env in range(self.n_envs): for k in self.keys: - self.data[k][self.ptr] = self.episode_data[k][i_env] + self.data[k][self.ptr] = self.episode_data[k][i_env].copy() self.ptr = (self.ptr + 1) % self.buffer_size self.size = np.min([self.size + 1, self.buffer_size]) self.clear_episodes() diff --git a/xuance/configs/dcg/sc2/1c3s5z.yaml b/xuance/configs/dcg/sc2/1c3s5z.yaml index bf063642..645063fc 100644 --- a/xuance/configs/dcg/sc2/1c3s5z.yaml +++ b/xuance/configs/dcg/sc2/1c3s5z.yaml @@ -4,7 +4,7 @@ env_id: "1c3s5z" fps: 15 policy: "DCG_policy" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/dcg/sc2/25m.yaml b/xuance/configs/dcg/sc2/25m.yaml index 626f11cf..4a9eaef1 100644 --- a/xuance/configs/dcg/sc2/25m.yaml +++ b/xuance/configs/dcg/sc2/25m.yaml @@ -4,7 +4,7 @@ env_id: "25m" fps: 15 policy: "DCG_policy" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/dcg/sc2/2m_vs_1z.yaml b/xuance/configs/dcg/sc2/2m_vs_1z.yaml index a0660d06..383eefd4 100644 --- a/xuance/configs/dcg/sc2/2m_vs_1z.yaml +++ b/xuance/configs/dcg/sc2/2m_vs_1z.yaml @@ -4,7 +4,7 @@ env_id: "2m_vs_1z" fps: 15 policy: "DCG_policy" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/dcg/sc2/2s3z.yaml b/xuance/configs/dcg/sc2/2s3z.yaml index 08f89a96..8ae26c47 100644 --- a/xuance/configs/dcg/sc2/2s3z.yaml +++ b/xuance/configs/dcg/sc2/2s3z.yaml @@ -4,7 +4,7 @@ env_id: "2s3z" fps: 15 policy: 
"DCG_policy" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/dcg/sc2/3m.yaml b/xuance/configs/dcg/sc2/3m.yaml index 5caf2243..844e9532 100644 --- a/xuance/configs/dcg/sc2/3m.yaml +++ b/xuance/configs/dcg/sc2/3m.yaml @@ -4,7 +4,7 @@ env_id: "3m" fps: 15 policy: "DCG_policy" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/dcg/sc2/5m_vs_6m.yaml b/xuance/configs/dcg/sc2/5m_vs_6m.yaml index 9a08a1ce..9713c22e 100644 --- a/xuance/configs/dcg/sc2/5m_vs_6m.yaml +++ b/xuance/configs/dcg/sc2/5m_vs_6m.yaml @@ -4,7 +4,7 @@ env_id: "5m_vs_6m" fps: 15 policy: "DCG_policy" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/dcg/sc2/8m.yaml b/xuance/configs/dcg/sc2/8m.yaml index 614f6143..82c03f63 100644 --- a/xuance/configs/dcg/sc2/8m.yaml +++ b/xuance/configs/dcg/sc2/8m.yaml @@ -4,7 +4,7 @@ env_id: "8m" fps: 15 policy: "DCG_policy" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/dcg/sc2/8m_vs_9m.yaml b/xuance/configs/dcg/sc2/8m_vs_9m.yaml index 31673c88..236a69c4 100644 --- a/xuance/configs/dcg/sc2/8m_vs_9m.yaml +++ b/xuance/configs/dcg/sc2/8m_vs_9m.yaml @@ -4,7 +4,7 @@ env_id: "8m_vs_9m" fps: 15 policy: "DCG_policy" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/dcg/sc2/MMM2.yaml b/xuance/configs/dcg/sc2/MMM2.yaml index 4555f9b9..d5efbf7d 100644 --- a/xuance/configs/dcg/sc2/MMM2.yaml +++ b/xuance/configs/dcg/sc2/MMM2.yaml @@ -4,7 +4,7 @@ env_id: "MMM2" fps: 15 policy: "DCG_policy" representation: 
"Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/dcg/sc2/corridor.yaml b/xuance/configs/dcg/sc2/corridor.yaml index ed598056..2916b010 100644 --- a/xuance/configs/dcg/sc2/corridor.yaml +++ b/xuance/configs/dcg/sc2/corridor.yaml @@ -4,7 +4,7 @@ env_id: "corridor" fps: 15 policy: "DCG_policy" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/iql/sc2/1c3s5z.yaml b/xuance/configs/iql/sc2/1c3s5z.yaml index 64318747..756b0176 100644 --- a/xuance/configs/iql/sc2/1c3s5z.yaml +++ b/xuance/configs/iql/sc2/1c3s5z.yaml @@ -4,7 +4,7 @@ env_id: "1c3s5z" fps: 15 policy: "Basic_Q_network_marl" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/iql/sc2/25m.yaml b/xuance/configs/iql/sc2/25m.yaml index 9594571e..c13def31 100644 --- a/xuance/configs/iql/sc2/25m.yaml +++ b/xuance/configs/iql/sc2/25m.yaml @@ -4,7 +4,7 @@ env_id: "25m" fps: 15 policy: "Basic_Q_network_marl" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/iql/sc2/2m_vs_1z.yaml b/xuance/configs/iql/sc2/2m_vs_1z.yaml index 2ccbf3c8..e4ecd2d7 100644 --- a/xuance/configs/iql/sc2/2m_vs_1z.yaml +++ b/xuance/configs/iql/sc2/2m_vs_1z.yaml @@ -5,7 +5,7 @@ env_id: "2m_vs_1z" fps: 15 policy: "Basic_Q_network_marl" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/iql/sc2/2s3z.yaml b/xuance/configs/iql/sc2/2s3z.yaml index e12b520c..1f2ec821 100644 --- a/xuance/configs/iql/sc2/2s3z.yaml +++ b/xuance/configs/iql/sc2/2s3z.yaml @@ -4,7 +4,7 @@ env_id: "2s3z" fps: 15 policy: 
"Basic_Q_network_marl" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/iql/sc2/3m.yaml b/xuance/configs/iql/sc2/3m.yaml index 1341a64b..8e983d68 100644 --- a/xuance/configs/iql/sc2/3m.yaml +++ b/xuance/configs/iql/sc2/3m.yaml @@ -4,7 +4,7 @@ env_id: "3m" fps: 15 policy: "Basic_Q_network_marl" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/iql/sc2/5m_vs_6m.yaml b/xuance/configs/iql/sc2/5m_vs_6m.yaml index 5b309139..1481c827 100644 --- a/xuance/configs/iql/sc2/5m_vs_6m.yaml +++ b/xuance/configs/iql/sc2/5m_vs_6m.yaml @@ -4,7 +4,7 @@ env_id: "5m_vs_6m" fps: 15 policy: "Basic_Q_network_marl" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/iql/sc2/8m.yaml b/xuance/configs/iql/sc2/8m.yaml index ad06f205..6385416d 100644 --- a/xuance/configs/iql/sc2/8m.yaml +++ b/xuance/configs/iql/sc2/8m.yaml @@ -4,7 +4,7 @@ env_id: "8m" fps: 15 policy: "Basic_Q_network_marl" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/iql/sc2/8m_vs_9m.yaml b/xuance/configs/iql/sc2/8m_vs_9m.yaml index eb4d83c5..87d4f42d 100644 --- a/xuance/configs/iql/sc2/8m_vs_9m.yaml +++ b/xuance/configs/iql/sc2/8m_vs_9m.yaml @@ -4,7 +4,7 @@ env_id: "8m_vs_9m" fps: 15 policy: "Basic_Q_network_marl" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/iql/sc2/MMM2.yaml b/xuance/configs/iql/sc2/MMM2.yaml index 3b603625..6722810d 100644 --- a/xuance/configs/iql/sc2/MMM2.yaml +++ b/xuance/configs/iql/sc2/MMM2.yaml @@ -4,7 +4,7 @@ env_id: "MMM2" 
fps: 15 policy: "Basic_Q_network_marl" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/iql/sc2/corridor.yaml b/xuance/configs/iql/sc2/corridor.yaml index 56113b19..7989fcd4 100644 --- a/xuance/configs/iql/sc2/corridor.yaml +++ b/xuance/configs/iql/sc2/corridor.yaml @@ -4,7 +4,7 @@ env_id: "corridor" fps: 15 policy: "Basic_Q_network_marl" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/qmix/sc2/1c3s5z.yaml b/xuance/configs/qmix/sc2/1c3s5z.yaml index 71d7deb7..f1b682bb 100644 --- a/xuance/configs/qmix/sc2/1c3s5z.yaml +++ b/xuance/configs/qmix/sc2/1c3s5z.yaml @@ -5,7 +5,7 @@ env_id: "1c3s5z" fps: 15 policy: "Mixing_Q_network" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/qmix/sc2/25m.yaml b/xuance/configs/qmix/sc2/25m.yaml index 93574fc8..8c336c82 100644 --- a/xuance/configs/qmix/sc2/25m.yaml +++ b/xuance/configs/qmix/sc2/25m.yaml @@ -5,7 +5,7 @@ env_id: "25m" fps: 15 policy: "Mixing_Q_network" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/qmix/sc2/2m_vs_1z.yaml b/xuance/configs/qmix/sc2/2m_vs_1z.yaml index 82f4db23..3b9d85e2 100644 --- a/xuance/configs/qmix/sc2/2m_vs_1z.yaml +++ b/xuance/configs/qmix/sc2/2m_vs_1z.yaml @@ -5,7 +5,7 @@ env_id: "2m_vs_1z" fps: 15 policy: "Mixing_Q_network" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/qmix/sc2/2s3z.yaml b/xuance/configs/qmix/sc2/2s3z.yaml index 25fdc65e..fd545aeb 100644 --- a/xuance/configs/qmix/sc2/2s3z.yaml +++ 
b/xuance/configs/qmix/sc2/2s3z.yaml @@ -5,7 +5,7 @@ env_id: "2s3z" fps: 15 policy: "Mixing_Q_network" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/qmix/sc2/5m_vs_6m.yaml b/xuance/configs/qmix/sc2/5m_vs_6m.yaml index 0d2518ab..b1f36193 100644 --- a/xuance/configs/qmix/sc2/5m_vs_6m.yaml +++ b/xuance/configs/qmix/sc2/5m_vs_6m.yaml @@ -5,7 +5,7 @@ env_id: "5m_vs_6m" fps: 15 policy: "Mixing_Q_network" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/qmix/sc2/8m.yaml b/xuance/configs/qmix/sc2/8m.yaml index 48508ac6..2e32158b 100644 --- a/xuance/configs/qmix/sc2/8m.yaml +++ b/xuance/configs/qmix/sc2/8m.yaml @@ -5,7 +5,7 @@ env_id: "8m" fps: 15 policy: "Mixing_Q_network" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/qmix/sc2/8m_vs_9m.yaml b/xuance/configs/qmix/sc2/8m_vs_9m.yaml index dd64fe86..5ba1fe1b 100644 --- a/xuance/configs/qmix/sc2/8m_vs_9m.yaml +++ b/xuance/configs/qmix/sc2/8m_vs_9m.yaml @@ -5,7 +5,7 @@ env_id: "8m_vs_9m" fps: 15 policy: "Mixing_Q_network" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/qmix/sc2/MMM2.yaml b/xuance/configs/qmix/sc2/MMM2.yaml index 78a1170f..43ba3948 100644 --- a/xuance/configs/qmix/sc2/MMM2.yaml +++ b/xuance/configs/qmix/sc2/MMM2.yaml @@ -5,7 +5,7 @@ env_id: "MMM2" fps: 15 policy: "Mixing_Q_network" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/qmix/sc2/corridor.yaml b/xuance/configs/qmix/sc2/corridor.yaml index b75c0cc1..a25c5aef 100644 --- 
a/xuance/configs/qmix/sc2/corridor.yaml +++ b/xuance/configs/qmix/sc2/corridor.yaml @@ -5,7 +5,7 @@ env_id: "corridor" fps: 15 policy: "Mixing_Q_network" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/vdn/sc2/1c3s5z.yaml b/xuance/configs/vdn/sc2/1c3s5z.yaml index 25bae01d..91a9afc1 100644 --- a/xuance/configs/vdn/sc2/1c3s5z.yaml +++ b/xuance/configs/vdn/sc2/1c3s5z.yaml @@ -5,7 +5,7 @@ env_id: "1c3s5z" fps: 15 policy: "Mixing_Q_network" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/vdn/sc2/25m.yaml b/xuance/configs/vdn/sc2/25m.yaml index 3ca718af..c6ca6414 100644 --- a/xuance/configs/vdn/sc2/25m.yaml +++ b/xuance/configs/vdn/sc2/25m.yaml @@ -5,7 +5,7 @@ env_id: "25m" fps: 15 policy: "Mixing_Q_network" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/vdn/sc2/2m_vs_1z.yaml b/xuance/configs/vdn/sc2/2m_vs_1z.yaml index 58da592d..0ada594a 100644 --- a/xuance/configs/vdn/sc2/2m_vs_1z.yaml +++ b/xuance/configs/vdn/sc2/2m_vs_1z.yaml @@ -5,7 +5,7 @@ env_id: "2m_vs_1z" fps: 15 policy: "Mixing_Q_network" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/vdn/sc2/2s3z.yaml b/xuance/configs/vdn/sc2/2s3z.yaml index 5a29a89d..34672749 100644 --- a/xuance/configs/vdn/sc2/2s3z.yaml +++ b/xuance/configs/vdn/sc2/2s3z.yaml @@ -5,7 +5,7 @@ env_id: "2s3z" fps: 15 policy: "Mixing_Q_network" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/vdn/sc2/3m.yaml b/xuance/configs/vdn/sc2/3m.yaml index a150ece6..3b552ae9 
100644 --- a/xuance/configs/vdn/sc2/3m.yaml +++ b/xuance/configs/vdn/sc2/3m.yaml @@ -5,7 +5,7 @@ env_id: "3m" fps: 15 policy: "Mixing_Q_network" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/vdn/sc2/5m_vs_6m.yaml b/xuance/configs/vdn/sc2/5m_vs_6m.yaml index 3744ef5d..5f2afc78 100644 --- a/xuance/configs/vdn/sc2/5m_vs_6m.yaml +++ b/xuance/configs/vdn/sc2/5m_vs_6m.yaml @@ -5,7 +5,7 @@ env_id: "5m_vs_6m" fps: 15 policy: "Mixing_Q_network" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/vdn/sc2/8m.yaml b/xuance/configs/vdn/sc2/8m.yaml index 0c35e5e3..f3695c6b 100644 --- a/xuance/configs/vdn/sc2/8m.yaml +++ b/xuance/configs/vdn/sc2/8m.yaml @@ -5,7 +5,7 @@ env_id: "8m" fps: 15 policy: "Mixing_Q_network" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/vdn/sc2/8m_vs_9m.yaml b/xuance/configs/vdn/sc2/8m_vs_9m.yaml index d26f31c9..2601103d 100644 --- a/xuance/configs/vdn/sc2/8m_vs_9m.yaml +++ b/xuance/configs/vdn/sc2/8m_vs_9m.yaml @@ -5,7 +5,7 @@ env_id: "8m_vs_9m" fps: 15 policy: "Mixing_Q_network" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/vdn/sc2/MMM2.yaml b/xuance/configs/vdn/sc2/MMM2.yaml index 05d1237e..21915698 100644 --- a/xuance/configs/vdn/sc2/MMM2.yaml +++ b/xuance/configs/vdn/sc2/MMM2.yaml @@ -5,7 +5,7 @@ env_id: "MMM2" fps: 15 policy: "Mixing_Q_network" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/vdn/sc2/corridor.yaml b/xuance/configs/vdn/sc2/corridor.yaml index 
3babce95..f8750666 100644 --- a/xuance/configs/vdn/sc2/corridor.yaml +++ b/xuance/configs/vdn/sc2/corridor.yaml @@ -5,7 +5,7 @@ env_id: "corridor" fps: 15 policy: "Mixing_Q_network" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/wqmix/sc2/1c3s5z.yaml b/xuance/configs/wqmix/sc2/1c3s5z.yaml index 0877d2a6..011889f6 100644 --- a/xuance/configs/wqmix/sc2/1c3s5z.yaml +++ b/xuance/configs/wqmix/sc2/1c3s5z.yaml @@ -4,7 +4,7 @@ env_id: "1c3s5z" fps: 15 policy: "Weighted_Mixing_Q_network" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/wqmix/sc2/25m.yaml b/xuance/configs/wqmix/sc2/25m.yaml index 71a1bfc1..54c53573 100644 --- a/xuance/configs/wqmix/sc2/25m.yaml +++ b/xuance/configs/wqmix/sc2/25m.yaml @@ -4,7 +4,7 @@ env_id: "25m" fps: 15 policy: "Weighted_Mixing_Q_network" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/wqmix/sc2/2m_vs_1z.yaml b/xuance/configs/wqmix/sc2/2m_vs_1z.yaml index 99cfda8d..ddc6fb00 100644 --- a/xuance/configs/wqmix/sc2/2m_vs_1z.yaml +++ b/xuance/configs/wqmix/sc2/2m_vs_1z.yaml @@ -4,7 +4,7 @@ env_id: "2m_vs_1z" fps: 15 policy: "Weighted_Mixing_Q_network" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/wqmix/sc2/2s3z.yaml b/xuance/configs/wqmix/sc2/2s3z.yaml index 9d162e03..0a935c45 100644 --- a/xuance/configs/wqmix/sc2/2s3z.yaml +++ b/xuance/configs/wqmix/sc2/2s3z.yaml @@ -4,7 +4,7 @@ env_id: "2s3z" fps: 15 policy: "Weighted_Mixing_Q_network" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff 
--git a/xuance/configs/wqmix/sc2/5m_vs_6m.yaml b/xuance/configs/wqmix/sc2/5m_vs_6m.yaml index df900f0b..d549c8cb 100644 --- a/xuance/configs/wqmix/sc2/5m_vs_6m.yaml +++ b/xuance/configs/wqmix/sc2/5m_vs_6m.yaml @@ -4,7 +4,7 @@ env_id: "5m_vs_6m" fps: 15 policy: "Weighted_Mixing_Q_network" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/wqmix/sc2/8m.yaml b/xuance/configs/wqmix/sc2/8m.yaml index da2611fd..e0bf8159 100644 --- a/xuance/configs/wqmix/sc2/8m.yaml +++ b/xuance/configs/wqmix/sc2/8m.yaml @@ -4,7 +4,7 @@ env_id: "8m" fps: 15 policy: "Weighted_Mixing_Q_network" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/wqmix/sc2/8m_vs_9m.yaml b/xuance/configs/wqmix/sc2/8m_vs_9m.yaml index c197d998..47a8a952 100644 --- a/xuance/configs/wqmix/sc2/8m_vs_9m.yaml +++ b/xuance/configs/wqmix/sc2/8m_vs_9m.yaml @@ -4,7 +4,7 @@ env_id: "8m_vs_9m" fps: 15 policy: "Weighted_Mixing_Q_network" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/wqmix/sc2/MMM2.yaml b/xuance/configs/wqmix/sc2/MMM2.yaml index 6f6066ae..2084de44 100644 --- a/xuance/configs/wqmix/sc2/MMM2.yaml +++ b/xuance/configs/wqmix/sc2/MMM2.yaml @@ -4,7 +4,7 @@ env_id: "3m" fps: 15 policy: "Weighted_Mixing_Q_network" representation: "Basic_RNN" -vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/configs/wqmix/sc2/corridor.yaml b/xuance/configs/wqmix/sc2/corridor.yaml index 3904be87..e906e614 100644 --- a/xuance/configs/wqmix/sc2/corridor.yaml +++ b/xuance/configs/wqmix/sc2/corridor.yaml @@ -4,7 +4,7 @@ env_id: "corridor" fps: 15 policy: "Weighted_Mixing_Q_network" representation: "Basic_RNN" 
-vectorize: "Subproc_StarCraft2 +vectorize: "Subproc_StarCraft2" runner: "StarCraft2_Runner" on_policy: False diff --git a/xuance/environment/__init__.py b/xuance/environment/__init__.py index 8977f18d..79f1133b 100644 --- a/xuance/environment/__init__.py +++ b/xuance/environment/__init__.py @@ -6,13 +6,29 @@ from .vector_envs.vector_env import VecEnv from xuance.environment.gym.gym_vec_env import DummyVecEnv_Gym, SubprocVecEnv_Gym from xuance.environment.gym.gym_vec_env import DummyVecEnv_Atari, SubprocVecEnv_Atari -from xuance.environment.pettingzoo.pettingzoo_vec_env import DummyVecEnv_Pettingzoo +from xuance.environment.pettingzoo.pettingzoo_vec_env import DummyVecEnv_Pettingzoo, SubprocVecEnv_Pettingzoo from xuance.environment.magent2.magent_vec_env import DummyVecEnv_MAgent from xuance.environment.starcraft2.sc2_vec_env import DummyVecEnv_StarCraft2, SubprocVecEnv_StarCraft2 from xuance.environment.football.gfootball_vec_env import DummyVecEnv_GFootball from .vector_envs.subproc_vec_env import SubprocVecEnv +REGISTRY_VEC_ENV = { + "Dummy_Gym": DummyVecEnv_Gym, + "Dummy_Pettingzoo": DummyVecEnv_Pettingzoo, + "Dummy_MAgent": DummyVecEnv_MAgent, + "Dummy_StarCraft2": DummyVecEnv_StarCraft2, + "Dummy_Football": DummyVecEnv_GFootball, + "Dummy_Atari": DummyVecEnv_Atari, + + # multiprocess # + "Subproc": SubprocVecEnv, + "Subproc_Gym": SubprocVecEnv_Gym, + "Subproc_Pettingzoo": SubprocVecEnv_Pettingzoo, + "Subproc_StarCraft2": SubprocVecEnv_StarCraft2, + "Subproc_Atari": SubprocVecEnv_Atari, +} + def make_envs(config: Namespace): def _thunk(): @@ -47,23 +63,9 @@ def _thunk(): env = Gym_Env(config.env_id, config.seed, config.render_mode) return env - if config.vectorize == "Subproc": - return SubprocVecEnv([_thunk for _ in range(config.parallels)]) - elif config.vectorize == "Dummy_Gym": - return DummyVecEnv_Gym([_thunk for _ in range(config.parallels)]) - elif config.vectorize == "Dummy_Pettingzoo": - return DummyVecEnv_Pettingzoo([_thunk for _ in 
range(config.parallels)]) - elif config.vectorize == "Dummy_MAgent": - return DummyVecEnv_MAgent([_thunk for _ in range(config.parallels)]) - elif config.vectorize == "Dummy_StarCraft2": - return DummyVecEnv_StarCraft2([_thunk for _ in range(config.parallels)]) - elif config.vectorize == "Subproc_StarCraft2": - return SubprocVecEnv_StarCraft2([_thunk for _ in range(config.parallels)]) - elif config.vectorize == "Dummy_Football": - return DummyVecEnv_GFootball([_thunk for _ in range(config.parallels)]) - elif config.vectorize == "Dummy_Atari": - return DummyVecEnv_Atari([_thunk for _ in range(config.parallels)]) - elif config.vectorize == "NOREQUIRED": + if config.vectorize in REGISTRY_VEC_ENV.keys(): + return REGISTRY_VEC_ENV[config.vectorize]([_thunk for _ in range(config.parallels)]) + elif config.vectorize == "NOREQUIRED": return _thunk() else: raise NotImplementedError diff --git a/xuance/environment/pettingzoo/pettingzoo_vec_env.py b/xuance/environment/pettingzoo/pettingzoo_vec_env.py index c88e23cc..9ac860db 100644 --- a/xuance/environment/pettingzoo/pettingzoo_vec_env.py +++ b/xuance/environment/pettingzoo/pettingzoo_vec_env.py @@ -4,6 +4,231 @@ from operator import itemgetter from gymnasium.spaces.box import Box import numpy as np +from xuance.environment.vector_envs.subproc_vec_env import clear_mpi_env_vars, flatten_list, CloudpickleWrapper +import multiprocessing as mp + + +def worker(remote, parent_remote, env_fn_wrappers): + def step_env(env, action): + obs_n, reward_n, terminated, truncated, info = env.step(action) + return obs_n, reward_n, terminated, truncated, info + + parent_remote.close() + envs = [env_fn_wrapper() for env_fn_wrapper in env_fn_wrappers.x] + try: + while True: + cmd, data = remote.recv() + if cmd == 'step': + remote.send([step_env(env, action) for env, action in zip(envs, data)]) + elif cmd == 'state': + remote.send([env.state() for env in envs]) + elif cmd == 'get_agent_mask': + remote.send([env.get_agent_mask() for env in envs])
+ elif cmd == 'reset': + remote.send([env.reset() for env in envs]) + elif cmd == 'render': + remote.send([env.render(data) for env in envs]) + elif cmd == 'close': + remote.close() + break + elif cmd == 'get_env_info': + env_info = { + "handles": envs[0].handles, + "observation_spaces": envs[0].observation_spaces, + "state_space": envs[0].state_space, + "action_spaces": envs[0].action_spaces, + "agent_ids": envs[0].agent_ids, + "n_agents": [envs[0].get_num(h) for h in envs[0].handles], + "max_cycles": envs[0].max_cycles + } + remote.send(CloudpickleWrapper(env_info)) + else: + raise NotImplementedError + except KeyboardInterrupt: + print('SubprocVecEnv worker: got KeyboardInterrupt') + finally: + for env in envs: + env.close() + + +class SubprocVecEnv_Pettingzoo(VecEnv): + """ + VecEnv that runs multiple environments in parallel in subprocesses and communicates with them via pipes. + Recommended to use when num_envs > 1 and step() can be a bottleneck. + """ + + def __init__(self, env_fns, context="spawn"): + """ + Arguments: + env_fns: iterable of callables - functions that create environments to run in subprocesses. Need to be cloud-pickleable + in_series: number of environments to run in series in a single process + (e.g.
when len(env_fns) == 12 and in_series == 3, it will run 4 processes, each running 3 envs in series) + """ + self.waiting = False + self.closed = False + self.n_remotes = num_envs = len(env_fns) + env_fns = np.array_split(env_fns, self.n_remotes) + ctx = mp.get_context(context) + self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(self.n_remotes)]) + self.ps = [ctx.Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn))) + for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)] + for p in self.ps: + p.daemon = True # if the main process crashes, we should not cause things to hang + with clear_mpi_env_vars(): + p.start() + for remote in self.work_remotes: + remote.close() + + self.remotes[0].send(('get_env_info', None)) + env_info = self.remotes[0].recv().x + self.handles = env_info["handles"] + self.state_space = env_info["state_space"] + self.state_shape = self.state_space.shape + self.state_dtype = self.state_space.dtype + obs_n_space = env_info["observation_spaces"] + self.agent_ids = env_info["agent_ids"] + self.n_agents = env_info["n_agents"] + VecEnv.__init__(self, num_envs, obs_n_space, env_info["action_spaces"]) + + self.keys, self.shapes, self.dtypes = obs_n_space_info(obs_n_space) + self.agent_keys = [[self.keys[k] for k in ids] for ids in self.agent_ids] + if isinstance(env_info["action_spaces"][self.agent_keys[0][0]], Box): + self.act_dim = [env_info["action_spaces"][keys[0]].shape[0] for keys in self.agent_keys] + else: + self.act_dim = [env_info["action_spaces"][keys[0]].n for keys in self.agent_keys] + self.n_agent_all = len(self.keys) + self.obs_shapes = [self.shapes[self.agent_keys[h.value][0]] for h in self.handles] + self.obs_dtype = self.dtypes[self.keys[0]] + + # buffer of dict data + self.buf_obs_dict = [{k: np.zeros(tuple(self.shapes[k]), dtype=self.dtypes[k]) for k in self.keys} for _ in + range(self.num_envs)] + self.buf_rews_dict = [{k: 0.0 for k in self.keys} for _ in 
range(self.num_envs)] + self.buf_dones_dict = [{k: False for k in self.keys} for _ in range(self.num_envs)] + self.buf_trunctions_dict = [{k: False for k in self.keys} for _ in range(self.num_envs)] + self.buf_infos_dict = [{} for _ in range(self.num_envs)] + # buffer of numpy data + self.buf_state = np.zeros((self.num_envs,) + self.state_shape, dtype=self.state_dtype) + self.buf_agent_mask = [np.ones([self.num_envs, n], dtype=np.bool) for n in self.n_agents] + self.buf_obs = [np.zeros((self.num_envs, n) + tuple(self.obs_shapes[h]), dtype=self.obs_dtype) for h, n in + enumerate(self.n_agents)] + self.buf_rews = [np.zeros((self.num_envs, n, 1), dtype=np.float32) for n in self.n_agents] + self.buf_dones = [np.ones((self.num_envs, n), dtype=np.bool) for n in self.n_agents] + self.buf_trunctions = [np.ones((self.num_envs, n), dtype=np.bool) for n in self.n_agents] + + self.max_episode_length = env_info["max_cycles"] + self.actions = None + + def empty_dict_buffers(self, i_env): + # buffer of dict data + self.buf_obs_dict[i_env] = {k: np.zeros(tuple(self.shapes[k]), dtype=self.dtypes[k]) for k in self.keys} + self.buf_rews_dict[i_env] = {k: 0.0 for k in self.keys} + self.buf_dones_dict[i_env] = {k: False for k in self.keys} + self.buf_trunctions_dict[i_env] = {k: False for k in self.keys} + self.buf_infos_dict[i_env] = {k: {} for k in self.keys} + + def reset(self): + for remote in self.remotes: + remote.send(('reset', None)) + result = [remote.recv() for remote in self.remotes] + result = flatten_list(result) + obs, info = zip(*result) + for e in range(self.num_envs): + self.buf_obs_dict[e].update(obs[e]) + self.buf_infos_dict[e].update(info[e]["infos"]) + for h, agent_keys_h in enumerate(self.agent_keys): + self.buf_obs[h][e] = itemgetter(*agent_keys_h)(self.buf_obs_dict[e]) + return self.buf_obs.copy(), self.buf_infos_dict.copy() + + def step_async(self, actions): + if self.waiting: + raise AlreadySteppingError + listify = True + try: + if len(actions) == self.num_envs: +
listify = False + except TypeError: + pass + if not listify: + self.actions = actions + else: + assert self.num_envs == 1, "actions {} is either not a list or has a wrong size - cannot match to {} environments".format( + actions, self.num_envs) + self.actions = [actions] + for remote, action in zip(self.remotes, self.actions): + remote.send(('step', action)) + self.waiting = True + + def step_wait(self): + if not self.waiting: + raise NotSteppingError + + for e, remote in zip(range(self.num_envs), self.remotes): + result = remote.recv() + result = flatten_list(result) + o, r, d, t, info = result + remote.send(('state', None)) + self.buf_state[e] = remote.recv() + + if len(o.keys()) < self.n_agent_all: + self.empty_dict_buffers(e) + # update the data of alive agents + self.buf_obs_dict[e].update(o) + self.buf_rews_dict[e].update(r) + self.buf_dones_dict[e].update(d) + self.buf_trunctions_dict[e].update(t) + self.buf_infos_dict[e].update(info["infos"]) + + # resort the data as group-wise + episode_scores = [] + remote.send(('get_agent_mask', None)) + mask = remote.recv() + for h, agent_keys_h in enumerate(self.agent_keys): + getter = itemgetter(*agent_keys_h) + self.buf_agent_mask[h][e] = mask[self.agent_ids[h]] + self.buf_obs[h][e] = getter(self.buf_obs_dict[e]) + self.buf_rews[h][e, :, 0] = getter(self.buf_rews_dict[e]) + self.buf_dones[h][e] = getter(self.buf_dones_dict[e]) + self.buf_trunctions[h][e] = getter(self.buf_trunctions_dict[e]) + episode_scores.append(getter(info["individual_episode_rewards"])) + self.buf_infos_dict[e]["individual_episode_rewards"] = episode_scores + + if all(self.buf_dones_dict[e].values()) or all(self.buf_trunctions_dict[e].values()): + remote.send(('reset', None)) + obs_reset, _ = remote.recv() + remote.send(('state', None)) + state_reset = remote.recv() + remote.send(('get_agent_mask', None)) + mask_reset = remote.recv() + obs_reset_handles, mask_reset_handles = [], [] + for h, agent_keys_h in enumerate(self.agent_keys): + getter =
itemgetter(*agent_keys_h) + obs_reset_handles.append(np.array(getter(obs_reset))) + mask_reset_handles.append(mask_reset[self.agent_ids[h]]) + + self.buf_infos_dict[e]["reset_obs"] = obs_reset_handles + self.buf_infos_dict[e]["reset_agent_mask"] = mask_reset_handles + self.buf_infos_dict[e]["reset_state"] = state_reset + + self.waiting = False + return self.buf_obs.copy(), self.buf_rews.copy(), self.buf_dones.copy(), self.buf_trunctions.copy(), self.buf_infos_dict.copy() + + def render(self, mode=None): + for pipe in self.remotes: + pipe.send(('render', mode)) + imgs = [pipe.recv() for pipe in self.remotes] + imgs = flatten_list(imgs) + return imgs + + def global_state(self): + return self.buf_state + + def agent_mask(self): + return self.buf_agent_mask + + def available_actions(self): + act_mask = [np.ones([self.num_envs, n, self.act_dim[h]], dtype=np.bool_) for h, n in enumerate(self.n_agents)] + return np.array(act_mask) class DummyVecEnv_Pettingzoo(DummyVecEnv_Gym): @@ -66,23 +291,6 @@ def reset(self): self.buf_obs[h][e] = itemgetter(*agent_keys_h)(self.buf_obs_dict[e]) return self.buf_obs.copy(), self.buf_infos_dict.copy() - def reset_one_env(self, e): - o = self.envs[e].reset() - self.buf_obs_dict[e].update(o) - obs_e = [] - for h, agent_keys_h in enumerate(self.agent_keys): - self.buf_obs[h][e] = itemgetter(*agent_keys_h)(self.buf_obs_dict[e]) - obs_e.append(self.buf_obs[h][e]) - - return obs_e - - def _get_max_obs_shape(self, k, observation_shape): - obs_shape_n = itemgetter(*list(k))(observation_shape) - size_obs_n = [] - for shape in obs_shape_n: - size_obs_n.append(shape.shape) - return max(size_obs_n) - def step_async(self, actions): if self.waiting: raise AlreadySteppingError @@ -153,9 +361,6 @@ def render(self, mode=None): def global_state(self): return self.buf_state - def global_state_one_env(self, e): - return np.array(self.envs[e].state()) - def agent_mask(self): return self.buf_agent_mask diff --git a/xuance/torch/runners/runner_sc2.py
b/xuance/torch/runners/runner_sc2.py index 490081fe..72a548bc 100644 --- a/xuance/torch/runners/runner_sc2.py +++ b/xuance/torch/runners/runner_sc2.py @@ -214,11 +214,12 @@ def run_episodes(self, n_episodes, test_mode=False): obs_n, state = deepcopy(next_obs_n), deepcopy(next_state) # train the model - if not test_mode: - self.agents.memory.store_episodes() # store episode data - n_epoch = self.agents.n_epoch if self.on_policy else self.n_envs - train_info = self.agents.train(self.current_step, n_epoch=n_epoch) # train - self.log_infos(train_info, self.current_step) + if test_mode: + continue + self.agents.memory.store_episodes() # store episode data + n_epoch = self.agents.n_epoch if self.on_policy else self.n_envs + train_info = self.agents.train(self.current_step, n_epoch=n_epoch) # train + self.log_infos(train_info, self.current_step) # After running n_episodes episode_score = np.array(episode_score)