diff --git a/benchmarks/Schelling/schelling.py b/benchmarks/Schelling/schelling.py index 73ddeaba4e2..52519efc9ae 100644 --- a/benchmarks/Schelling/schelling.py +++ b/benchmarks/Schelling/schelling.py @@ -24,7 +24,7 @@ def __init__(self, model, agent_type, radius, homophily): def step(self): """Run one step of the agent.""" similar = 0 - neighborhood = self.cell.neighborhood(radius=self.radius) + neighborhood = self.cell.get_neighborhood(radius=self.radius) for neighbor in neighborhood.agents: if neighbor.type == self.type: similar += 1 diff --git a/benchmarks/WolfSheep/wolf_sheep.py b/benchmarks/WolfSheep/wolf_sheep.py index 16c01e000c1..aaeed13536c 100644 --- a/benchmarks/WolfSheep/wolf_sheep.py +++ b/benchmarks/WolfSheep/wolf_sheep.py @@ -33,7 +33,7 @@ def __init__(self, model, energy, p_reproduce, energy_from_food): def random_move(self): """Move to a random neighboring cell.""" - self.move_to(self.cell.neighborhood().select_random_cell()) + self.move_to(self.cell.neighborhood.select_random_cell()) def spawn_offspring(self): """Create offspring.""" diff --git a/mesa/experimental/cell_space/cell.py b/mesa/experimental/cell_space/cell.py index 08a37102ebb..4ee0c5ea84e 100644 --- a/mesa/experimental/cell_space/cell.py +++ b/mesa/experimental/cell_space/cell.py @@ -2,7 +2,7 @@ from __future__ import annotations -from functools import cache +from functools import cache, cached_property from random import Random from typing import TYPE_CHECKING @@ -34,6 +34,7 @@ class Cell: "capacity", "properties", "random", + "__dict__", ] # def __new__(cls, @@ -131,10 +132,33 @@ def is_full(self) -> bool: def __repr__(self): # noqa return f"Cell({self.coordinate}, {self.agents})" + @cached_property + def neighborhood(self) -> CellCollection: + """Returns the direct neigborhood of the cell. + + This is equivalent to cell.get_neighborhood(radius=1) + + """ + return self.get_neighborhood() + # FIXME: Revisit caching strategy on methods @cache # noqa: B019 - def neighborhood(self, radius: int = 1, include_center: bool = False): - """Returns a list of all neighboring cells.""" + def get_neighborhood( + self, radius: int = 1, include_center: bool = False + ) -> CellCollection: + """Returns a list of all neighboring cells for the given radius. + + For getting the direct neighborhood (i.e., radius=1) you can also use + the `neighborhood` property. + + Args: + radius (int): the radius of the neighborhood + include_center (bool): include the center of the neighborhood + + Returns: + a list of all neighboring cells + + """ return CellCollection[Cell]( self._neighborhood(radius=radius, include_center=include_center), random=self.random, diff --git a/tests/test_cell_space.py b/tests/test_cell_space.py index 64d275663d2..01420043797 100644 --- a/tests/test_cell_space.py +++ b/tests/test_cell_space.py @@ -280,7 +280,10 @@ def test_cell_neighborhood(): height = 10 grid = OrthogonalVonNeumannGrid((width, height), torus=False, capacity=None) for radius, n in zip(range(1, 4), [2, 5, 9]): - neighborhood = grid._cells[(0, 0)].neighborhood(radius=radius) + if radius == 1: + neighborhood = grid._cells[(0, 0)].neighborhood + else: + neighborhood = grid._cells[(0, 0)].get_neighborhood(radius=radius) assert len(neighborhood) == n ## Moore @@ -288,25 +291,34 @@ def test_cell_neighborhood(): height = 10 grid = OrthogonalMooreGrid((width, height), torus=False, capacity=None) for radius, n in zip(range(1, 4), [3, 8, 15]): - neighborhood = grid._cells[(0, 0)].neighborhood(radius=radius) + if radius == 1: + neighborhood = grid._cells[(0, 0)].neighborhood + else: + neighborhood = grid._cells[(0, 0)].get_neighborhood(radius=radius) assert len(neighborhood) == n with pytest.raises(ValueError): - grid._cells[(0, 0)].neighborhood(radius=0) + grid._cells[(0, 0)].get_neighborhood(radius=0) # hexgrid width = 10 height = 10 grid = HexGrid((width, height), torus=False, capacity=None) for radius, n in zip(range(1, 4), [2, 6, 11]): - neighborhood = grid._cells[(0, 0)].neighborhood(radius=radius) + if radius == 1: + neighborhood = grid._cells[(0, 0)].neighborhood + else: + neighborhood = grid._cells[(0, 0)].get_neighborhood(radius=radius) assert len(neighborhood) == n width = 10 height = 10 grid = HexGrid((width, height), torus=False, capacity=None) for radius, n in zip(range(1, 4), [5, 10, 17]): - neighborhood = grid._cells[(1, 0)].neighborhood(radius=radius) + if radius == 1: + neighborhood = grid._cells[(1, 0)].neighborhood + else: + neighborhood = grid._cells[(1, 0)].get_neighborhood(radius=radius) assert len(neighborhood) == n # networkgrid