diff --git a/generals/core/observation.py b/generals/core/observation.py index dbab230..0af9a9a 100644 --- a/generals/core/observation.py +++ b/generals/core/observation.py @@ -10,7 +10,6 @@ class Observation(dict): Observation object to be accessible in dictionary-style format, e.g. observation["armies"]. And to allow for providing a listing of the keys/attributes. - These steps are necessary because PettingZoo & Gymnasium expect dictionary-like Observation objects, but we want the benefits of knowing the dictionaries' members which a dataclass/class provides. @@ -44,26 +43,46 @@ def values(self): def items(self): return dataclasses.asdict(self).items() + def pad_observation(self, pad_to: int) -> None: + """ + Pads all the observation arrays to the specified size. + + Args: + pad_to (int): The target size to pad to. Must be >= the current observation size. + """ + assert pad_to >= max(self.armies.shape), "Can't pad to a smaller size than the original observation." + + h_pad = (0, pad_to - self.armies.shape[0]) + w_pad = (0, pad_to - self.armies.shape[1]) + + # Regular zero padding for most arrays + zero_pad_arrays = [ + "armies", + "generals", + "cities", + "neutral_cells", + "owned_cells", + "opponent_cells", + "fog_cells", + "structures_in_fog", + ] + + for array_name in zero_pad_arrays: + setattr(self, array_name, np.pad(getattr(self, array_name), (h_pad, w_pad), "constant")) + + # Special case for mountains which are padded with ones + self.mountains = np.pad(self.mountains, (h_pad, w_pad), "constant", constant_values=1) + def as_tensor(self, pad_to: int | None = None) -> np.ndarray: """ Returns a 3D tensor of shape (15, rows, cols). Suitable for neural nets. """ - shape = self.armies.shape if pad_to is not None: + self.pad_observation(pad_to) shape = (pad_to, pad_to) - assert pad_to >= max(self.armies.shape), "Can't pad to a smaller size than the original observation." - # pad every channel with zeros, except for mountains, those are padded with ones - h_pad = (0, pad_to - self.armies.shape[0]) - w_pad = (0, pad_to - self.armies.shape[1]) - self.armies = np.pad(self.armies, (h_pad, w_pad), "constant") - self.generals = np.pad(self.generals, (h_pad, w_pad), "constant") - self.cities = np.pad(self.cities, (h_pad, w_pad), "constant") - self.mountains = np.pad(self.mountains, (h_pad, w_pad), "constant", constant_values=1) - self.neutral_cells = np.pad(self.neutral_cells, (h_pad, w_pad), "constant") - self.owned_cells = np.pad(self.owned_cells, (h_pad, w_pad), "constant") - self.opponent_cells = np.pad(self.opponent_cells, (h_pad, w_pad), "constant") - self.fog_cells = np.pad(self.fog_cells, (h_pad, w_pad), "constant") - self.structures_in_fog = np.pad(self.structures_in_fog, (h_pad, w_pad), "constant") + else: + shape = self.armies.shape + return np.stack( [ self.armies,