diff --git a/generals/envs/multiagent_gymnasium_generals.py b/generals/envs/multiagent_gymnasium_generals.py
index b6090e3..5a2fd1d 100644
--- a/generals/envs/multiagent_gymnasium_generals.py
+++ b/generals/envs/multiagent_gymnasium_generals.py
@@ -107,12 +107,22 @@ def reset(
         elif hasattr(self, "replay"):
             del self.replay
 
-        obs1 = self.game.agent_observation(self.agents[0]).as_tensor()
-        obs2 = self.game.agent_observation(self.agents[1]).as_tensor()
-        observations = np.stack([obs1, obs2], dtype=np.float32)
+        _obs = {agent: self.game.agent_observation(agent) for agent in self.agents}
+        observations = np.stack([_obs[agent].as_tensor() for agent in self.agents], dtype=np.float32)
 
-        info: dict[str, Any] = {}
-        return observations, info
+        infos: dict[str, Any] = self.game.get_infos()
+        # Flatten each agent's info dict into a fixed-order list, appending the valid-move mask.
+        infos = {
+            agent: [
+                infos[agent]["army"],
+                infos[agent]["land"],
+                infos[agent]["is_done"],
+                infos[agent]["is_winner"],
+                compute_valid_move_mask(_obs[agent]),
+            ]
+            for agent in self.agents
+        }
+        return observations, infos
 
     def step(self, actions: list[Action]) -> tuple[Any, Any, bool, bool, dict[str, Any]]:
         _actions = {
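
A minimal usage sketch of the new reset() contract, assuming the class in this file is named MultiAgentGymnasiumGenerals and is constructed with an agents list; the import path follows the file path in the diff, but the class name, constructor arguments, and agent names are illustrative assumptions, not part of the change:

    import numpy as np

    # Assumed class name, inferred from the module path in this diff.
    from generals.envs.multiagent_gymnasium_generals import MultiAgentGymnasiumGenerals

    env = MultiAgentGymnasiumGenerals(agents=["red", "blue"])  # constructor args are an assumption

    observations, infos = env.reset()

    # observations: one observation tensor per agent, stacked along axis 0 as float32.
    assert observations.dtype == np.float32
    assert observations.shape[0] == len(env.agents)

    # infos: per-agent flat list in the order [army, land, is_done, is_winner, valid_move_mask].
    for agent in env.agents:
        army, land, is_done, is_winner, mask = infos[agent]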