Skip to content

Commit

Permalink
refactor: Option to pad observation separately
Browse files Browse the repository at this point in the history
  • Loading branch information
strakam committed Jan 13, 2025
1 parent 814bdc5 commit a3d7861
Showing 1 changed file with 34 additions and 15 deletions.
49 changes: 34 additions & 15 deletions generals/core/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ class Observation(dict):
Observation object to be accessible in dictionary-style format,
e.g. observation["armies"]. And to allow for providing a
listing of the keys/attributes.
These steps are necessary because PettingZoo & Gymnasium expect
dictionary-like Observation objects, but we want the benefits of
knowing the dictionaries' members which a dataclass/class provides.
Expand Down Expand Up @@ -44,26 +43,46 @@ def values(self):
def items(self):
return dataclasses.asdict(self).items()

def pad_observation(self, pad_to: int) -> None:
"""
Pads all the observation arrays to the specified size.
Args:
pad_to (int): The target size to pad to. Must be >= the current observation size.
"""
assert pad_to >= max(self.armies.shape), "Can't pad to a smaller size than the original observation."

h_pad = (0, pad_to - self.armies.shape[0])
w_pad = (0, pad_to - self.armies.shape[1])

# Regular zero padding for most arrays
zero_pad_arrays = [
"armies",
"generals",
"cities",
"neutral_cells",
"owned_cells",
"opponent_cells",
"fog_cells",
"structures_in_fog",
]

for array_name in zero_pad_arrays:
setattr(self, array_name, np.pad(getattr(self, array_name), (h_pad, w_pad), "constant"))

# Special case for mountains which are padded with ones
self.mountains = np.pad(self.mountains, (h_pad, w_pad), "constant", constant_values=1)

def as_tensor(self, pad_to: int | None = None) -> np.ndarray:
"""
Returns a 3D tensor of shape (15, rows, cols). Suitable for neural nets.
"""
shape = self.armies.shape
if pad_to is not None:
self.pad_observation(pad_to)
shape = (pad_to, pad_to)
assert pad_to >= max(self.armies.shape), "Can't pad to a smaller size than the original observation."
# pad every channel with zeros, except for mountains, those are padded with ones
h_pad = (0, pad_to - self.armies.shape[0])
w_pad = (0, pad_to - self.armies.shape[1])
self.armies = np.pad(self.armies, (h_pad, w_pad), "constant")
self.generals = np.pad(self.generals, (h_pad, w_pad), "constant")
self.cities = np.pad(self.cities, (h_pad, w_pad), "constant")
self.mountains = np.pad(self.mountains, (h_pad, w_pad), "constant", constant_values=1)
self.neutral_cells = np.pad(self.neutral_cells, (h_pad, w_pad), "constant")
self.owned_cells = np.pad(self.owned_cells, (h_pad, w_pad), "constant")
self.opponent_cells = np.pad(self.opponent_cells, (h_pad, w_pad), "constant")
self.fog_cells = np.pad(self.fog_cells, (h_pad, w_pad), "constant")
self.structures_in_fog = np.pad(self.structures_in_fog, (h_pad, w_pad), "constant")
else:
shape = self.armies.shape

return np.stack(
[
self.armies,
Expand Down

0 comments on commit a3d7861

Please sign in to comment.