subproc pettingzoo
wenzhangliu committed Nov 3, 2023
1 parent 3128144 commit 2f292e1
Showing 4 changed files with 32 additions and 16 deletions.
4 changes: 2 additions & 2 deletions benchmark_marl.py
@@ -5,8 +5,8 @@
 def parse_args():
     parser = argparse.ArgumentParser("Run an MARL demo.")
     parser.add_argument("--method", type=str, default="mappo")
-    parser.add_argument("--env", type=str, default="sc2")
-    parser.add_argument("--env-id", type=str, default="3m")
+    parser.add_argument("--env", type=str, default="mpe")
+    parser.add_argument("--env-id", type=str, default="simple_spread_v3")
     parser.add_argument("--seed", type=int, default=10)
     parser.add_argument("--test", type=int, default=0)
     parser.add_argument("--device", type=str, default="cuda:0")
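The demo's default benchmark moves from StarCraft II (sc2/3m) to MPE. As a sketch of what the new defaults parse to, rebuilding only the two changed arguments (note that argparse turns the dashed --env-id flag into an env_id attribute):

    import argparse

    # Rebuild just the two changed arguments to inspect the new defaults.
    parser = argparse.ArgumentParser("Run an MARL demo.")
    parser.add_argument("--env", type=str, default="mpe")
    parser.add_argument("--env-id", type=str, default="simple_spread_v3")

    args = parser.parse_args([])  # empty argv -> defaults only
    assert args.env == "mpe" and args.env_id == "simple_spread_v3"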
4 changes: 2 additions & 2 deletions xuance/configs/mappo/mpe/simple_spread_v3.yaml
@@ -18,8 +18,8 @@ critic_hidden_size: [64, 64]
 activation: "ReLU"
 
 seed: 1
-parallels: 128
-n_size: 25
+parallels: 8  # 128
+n_size: 128
 n_epoch: 10
 n_minibatch: 1
 learning_rate: 0.0007
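If parallels counts parallel environments and n_size is the per-environment rollout length (the usual convention in this codebase, though not confirmed by this diff), the rollout budget per update shifts from 128 × 25 = 3200 environment steps to 8 × 128 = 1024: fewer subprocess workers, each contributing a longer rollout.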
38 changes: 27 additions & 11 deletions xuance/environment/pettingzoo/pettingzoo_vec_env.py
@@ -1,3 +1,5 @@
+from abc import ABC
+
 from xuance.environment.vector_envs.vector_env import VecEnv, AlreadySteppingError, NotSteppingError
 from xuance.environment.vector_envs.env_utils import obs_n_space_info
 from xuance.environment.gym.gym_vec_env import DummyVecEnv_Gym
@@ -27,7 +29,7 @@ def step_env(env, action):
        elif cmd == 'reset':
            remote.send([env.reset() for env in envs])
        elif cmd == 'render':
-            remote.send([env.render(data) for env in envs])
+            remote.send([env.render() for env in envs])
        elif cmd == 'close':
            remote.close()
            break
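For context, a minimal, self-contained sketch of the pipe protocol this worker loop implements; the command names mirror the diff, but the worker body is simplified and the setup is hypothetical:

    import multiprocessing as mp

    def worker(remote, env_fns):
        envs = [fn() for fn in env_fns]
        while True:
            cmd, data = remote.recv()
            if cmd == 'render':
                # After this commit, env.render() takes no argument.
                remote.send([env.render() for env in envs])
            elif cmd == 'close':
                remote.close()
                break

    if __name__ == "__main__":
        parent, child = mp.Pipe()
        p = mp.Process(target=worker, args=(child, []))
        p.start()
        parent.send(('render', None))
        print(parent.recv())   # [] -- this toy worker hosts no envs
        parent.send(('close', None))
        p.join()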
@@ -39,7 +41,8 @@ def step_env(env, action):
                "action_spaces": envs[0].action_spaces,
                "agent_ids": envs[0].agent_ids,
                "n_agents": [envs[0].get_num(h) for h in envs[0].handles],
-                "max_cycles": envs[0].max_cycles
+                "max_cycles": envs[0].max_cycles,
+                "side_names": envs[0].side_names
            }
            remote.send(CloudpickleWrapper(env_info))
        else:
@@ -88,6 +91,7 @@ def __init__(self, env_fns, context="spawn"):
        obs_n_space = env_info["observation_spaces"]
        self.agent_ids = env_info["agent_ids"]
        self.n_agents = env_info["n_agents"]
+        self.side_names = env_info["side_names"]
        VecEnv.__init__(self, num_envs, obs_n_space, env_info["action_spaces"])
 
        self.keys, self.shapes, self.dtypes = obs_n_space_info(obs_n_space)
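This is the thread running through the file: a subprocess vec env cannot reach into self.envs[0] because the environments live in child processes, so any metadata the runner needs (here side_names) is collected once over the pipe in the get_env_info reply and cached on the wrapper.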
@@ -134,8 +138,8 @@ def reset(self):
        result = flatten_list(result)
        obs, info = zip(*result)
        for e in range(self.num_envs):
-            self.buf_obs_dict[e].update(obs)
-            self.buf_infos_dict[e].update(info["infos"])
+            self.buf_obs_dict[e].update(obs[e])
+            self.buf_infos_dict[e].update(info[e]["infos"])
            for h, agent_keys_h in enumerate(self.agent_keys):
                self.buf_obs[h][e] = itemgetter(*agent_keys_h)(self.buf_obs_dict[e])
        return self.buf_obs.copy(), self.buf_infos_dict.copy()
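The old lines updated every environment's buffers with the whole cross-environment tuple. Since zip(*result) regroups the per-env (obs, info) pairs by field, each env needs its own index, as a toy example shows:

    # Two envs' reset results; zip(*...) groups observations and infos.
    result = [({"agent_0": 0.1}, {"infos": {}}),
              ({"agent_0": 0.2}, {"infos": {}})]
    obs, info = zip(*result)
    assert obs[1] == {"agent_0": 0.2}    # per-env observation dict
    assert info[1]["infos"] == {}        # per-env info dict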
@@ -155,6 +159,7 @@ def step_async(self, actions):
            assert self.num_envs == 1, "actions {} is either not a list or has a wrong size - cannot match to {} environments".format(
                actions, self.num_envs)
            self.actions = [actions]
+        self.actions = np.array_split(self.actions, self.n_remotes)
        for remote, action in zip(self.remotes, self.actions):
            remote.send(('step', action))
        self.waiting = True
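Each remote hosts num_envs // n_remotes environments, so the flat list of per-env actions is now split into one chunk per worker before sending. With toy numbers:

    import numpy as np

    actions = np.arange(8)               # one action per environment
    chunks = np.array_split(actions, 4)  # 4 remotes -> 4 chunks of 2
    assert [c.tolist() for c in chunks] == [[0, 1], [2, 3], [4, 5], [6, 7]]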
@@ -163,12 +168,12 @@ def step_wait(self):
        if not self.waiting:
            raise NotSteppingError
 
-        for e, remote in zip(range(self.num_envs, self.remotes)):
+        for e, remote in zip(range(self.num_envs), self.remotes):
            result = remote.recv()
            result = flatten_list(result)
            o, r, d, t, info = result
            remote.send(('state', None))
-            self.buf_state[e] = remote.recv()
+            self.buf_state[e] = flatten_list(remote.recv())
 
            if len(o.keys()) < self.n_agent_all:
                self.empty_dict_buffers(e)
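The old line built range(self.num_envs, self.remotes), which raises a TypeError the moment it is evaluated (a list of pipes is not a valid range bound). The fix pairs each env index with its pipe:

    remotes = ['pipe0', 'pipe1']   # stand-ins for Pipe endpoints
    # range(2, remotes) -> TypeError: 'list' object cannot be
    # interpreted as an integer. The corrected pairing:
    assert list(zip(range(2), remotes)) == [(0, 'pipe0'), (1, 'pipe1')]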
@@ -182,7 +187,7 @@ def step_wait(self):
            # resort the data as group-wise
            episode_scores = []
            remote.send(('get_agent_mask', None))
-            mask = remote.recv()
+            mask = np.array(flatten_list(remote.recv()))
            for h, agent_keys_h in enumerate(self.agent_keys):
                getter = itemgetter(*agent_keys_h)
                self.buf_agent_mask[h][e] = mask[self.agent_ids[h]]
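mask[self.agent_ids[h]] indexes with a list of agent ids, which plain Python lists do not support; converting the flattened reply into an ndarray enables the fancy indexing. With toy values:

    import numpy as np

    mask = np.array([True, False, True])   # was a plain list off the pipe
    agent_ids_h = [0, 2]                   # one handle's agent ids
    assert mask[agent_ids_h].tolist() == [True, True]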
@@ -195,11 +200,11 @@ def step_wait(self):
 
            if all(self.buf_dones_dict[e].values()) or all(self.buf_trunctions_dict[e].values()):
                remote.send(('reset', None))
-                obs_reset, _ = remote.recv()
+                obs_reset, _ = flatten_list(remote.recv())
                remote.send(('state', None))
-                state_reset = remote.recv()
+                state_reset = flatten_list(remote.recv())
                remote.send(('get_agent_mask', None))
-                mask_reset = remote.recv()
+                mask_reset = np.array(flatten_list(remote.recv()))
                obs_reset_handles, mask_reset_handles = [], []
                for h, agent_keys_h in enumerate(self.agent_keys):
                    getter = itemgetter(*agent_keys_h)
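Each worker replies with a list holding one item per environment it hosts, so a single-env reply arrives as [(obs, info)] and must be flattened one level before unpacking. A simplified stand-in for xuance's flatten_list helper illustrates why:

    def flatten_list(nested):
        # Simplified stand-in: flatten exactly one level of nesting.
        return [item for sub in nested for item in sub]

    # A worker hosting one env answers [(obs, info)].
    obs_reset, info_reset = flatten_list([({"agent_0": 0.0}, {"infos": {}})])
    assert obs_reset == {"agent_0": 0.0}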
@@ -213,9 +218,19 @@ def step_wait(self):
        self.waiting = False
        return self.buf_obs.copy(), self.buf_rews.copy(), self.buf_dones.copy(), self.buf_trunctions.copy(), self.buf_infos_dict.copy()
 
+    def close_extras(self):
+        self.closed = True
+        if self.waiting:
+            for remote in self.remotes:
+                remote.recv()
+        for remote in self.remotes:
+            remote.send(('close', None))
+        for p in self.ps:
+            p.join()
+
    def render(self, mode=None):
        for pipe in self.remotes:
-            pipe.send(('render', mode))
+            pipe.send(('render', None))
        imgs = [pipe.recv() for pipe in self.remotes]
        imgs = flatten_list(imgs)
        return imgs
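Two details in this hunk: close_extras drains any reply still in flight (remote.recv()) before sending 'close', so a worker blocked on a pending step is not orphaned, then joins each child process; and render now always ships None as the payload, matching the worker side, which calls env.render() without arguments, consistent with newer PettingZoo versions where the render mode is fixed when the environment is created.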
@@ -244,6 +259,7 @@ def __init__(self, env_fns):
        obs_n_space = env.observation_spaces  # [Box(dim_o), Box(dim_o), ...] ----> dict
        self.agent_ids = env.agent_ids
        self.n_agents = [env.get_num(h) for h in self.handles]
+        self.side_names = env.side_names
 
        self.keys, self.shapes, self.dtypes = obs_n_space_info(obs_n_space)
        self.agent_keys = [[self.keys[k] for k in ids] for ids in self.agent_ids]
2 changes: 1 addition & 1 deletion xuance/torch/runners/runner_pettingzoo.py
@@ -77,7 +77,7 @@ def __init__(self, args):
 
        # environment details, representations, policies, optimizers, and agents.
        for h, arg in enumerate(self.args):
-            arg.handle_name = self.envs.envs[0].side_names[h]
+            arg.handle_name = self.envs.side_names[h]
            if self.n_handles > 1 and arg.agent != "RANDOM":
                arg.model_dir += "{}/".format(arg.handle_name)
                arg.handle, arg.n_agents = h, self.envs.n_agents[h]
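With side_names exposed on both wrappers, the runner no longer reaches through self.envs.envs[0], an attribute only the in-process dummy wrapper has. The same line now works whether the environments are vectorized in-process or in subprocesses, which is the point of the commit.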
