diff --git a/benchmark_marl.py b/benchmark_marl.py index 614424f9..3c6a0e70 100644 --- a/benchmark_marl.py +++ b/benchmark_marl.py @@ -4,9 +4,9 @@ def parse_args(): parser = argparse.ArgumentParser("Run an MARL demo.") - parser.add_argument("--method", type=str, default="vdac") - parser.add_argument("--env", type=str, default="sc2") - parser.add_argument("--env-id", type=str, default="3m") + parser.add_argument("--method", type=str, default="ippo") + parser.add_argument("--env", type=str, default="mpe") + parser.add_argument("--env-id", type=str, default="simple_spread_v3") parser.add_argument("--seed", type=int, default=10) parser.add_argument("--test", type=int, default=0) parser.add_argument("--device", type=str, default="cuda:0") diff --git a/xuance/environment/pettingzoo/pettingzoo_env.py b/xuance/environment/pettingzoo/pettingzoo_env.py index 3eeca9b7..4609e5f1 100644 --- a/xuance/environment/pettingzoo/pettingzoo_env.py +++ b/xuance/environment/pettingzoo/pettingzoo_env.py @@ -9,7 +9,8 @@ class PettingZoo_Env(ParallelEnv): def __init__(self, env_name: str, env_id: str, seed: int, **kwargs): super(PettingZoo_Env, self).__init__() scenario = importlib.import_module('pettingzoo.' + env_name + '.' + env_id) - self.env = scenario.parallel_env(continuous_actions=kwargs["continuous"], + self.continuous_actions = kwargs["continuous"] + self.env = scenario.parallel_env(continuous_actions=self.continuous_actions, render_mode=kwargs["render_mode"]) self.scenario_name = env_name + "." + env_id self.n_handles = len(AGENT_NAME_DICT[self.scenario_name]) @@ -53,8 +54,9 @@ def reset(self, seed=None, options=None): return observations, reset_info def step(self, actions): - for k, v in actions.items(): - actions[k] = np.clip(v, self.action_spaces[k].low, self.action_spaces[k].high) + if self.continuous_actions: + for k, v in actions.items(): + actions[k] = np.clip(v, self.action_spaces[k].low, self.action_spaces[k].high) observations, rewards, terminations, truncations, infos = self.env.step(actions) for k, v in rewards.items(): self.individual_episode_reward[k] += v