Skip to content

Commit

Permalink
Have a dedicated neighborhood property and a get_neighborhood method …
Browse files Browse the repository at this point in the history
…on Cell (#2309)
  • Loading branch information
quaquel authored Sep 21, 2024
1 parent 25925a2 commit 1ad2e1f
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 10 deletions.
2 changes: 1 addition & 1 deletion benchmarks/Schelling/schelling.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(self, model, agent_type, radius, homophily):
def step(self):
"""Run one step of the agent."""
similar = 0
neighborhood = self.cell.neighborhood(radius=self.radius)
neighborhood = self.cell.get_neighborhood(radius=self.radius)
for neighbor in neighborhood.agents:
if neighbor.type == self.type:
similar += 1
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/WolfSheep/wolf_sheep.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, model, energy, p_reproduce, energy_from_food):

def random_move(self):
"""Move to a random neighboring cell."""
self.move_to(self.cell.neighborhood().select_random_cell())
self.move_to(self.cell.neighborhood.select_random_cell())

def spawn_offspring(self):
"""Create offspring."""
Expand Down
30 changes: 27 additions & 3 deletions mesa/experimental/cell_space/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from functools import cache
from functools import cache, cached_property
from random import Random
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -34,6 +34,7 @@ class Cell:
"capacity",
"properties",
"random",
"__dict__",
]

# def __new__(cls,
Expand Down Expand Up @@ -131,10 +132,33 @@ def is_full(self) -> bool:
def __repr__(self): # noqa
return f"Cell({self.coordinate}, {self.agents})"

@cached_property
def neighborhood(self) -> CellCollection:
"""Returns the direct neigborhood of the cell.
This is equivalent to cell.get_neighborhood(radius=1)
"""
return self.get_neighborhood()

# FIXME: Revisit caching strategy on methods
@cache # noqa: B019
def neighborhood(self, radius: int = 1, include_center: bool = False):
"""Returns a list of all neighboring cells."""
def get_neighborhood(
self, radius: int = 1, include_center: bool = False
) -> CellCollection:
"""Returns a list of all neighboring cells for the given radius.
For getting the direct neighborhood (i.e., radius=1) you can also use
the `neighborhood` property.
Args:
radius (int): the radius of the neighborhood
include_center (bool): include the center of the neighborhood
Returns:
a list of all neighboring cells
"""
return CellCollection[Cell](
self._neighborhood(radius=radius, include_center=include_center),
random=self.random,
Expand Down
22 changes: 17 additions & 5 deletions tests/test_cell_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,33 +280,45 @@ def test_cell_neighborhood():
height = 10
grid = OrthogonalVonNeumannGrid((width, height), torus=False, capacity=None)
for radius, n in zip(range(1, 4), [2, 5, 9]):
neighborhood = grid._cells[(0, 0)].neighborhood(radius=radius)
if radius == 1:
neighborhood = grid._cells[(0, 0)].neighborhood
else:
neighborhood = grid._cells[(0, 0)].get_neighborhood(radius=radius)
assert len(neighborhood) == n

## Moore
width = 10
height = 10
grid = OrthogonalMooreGrid((width, height), torus=False, capacity=None)
for radius, n in zip(range(1, 4), [3, 8, 15]):
neighborhood = grid._cells[(0, 0)].neighborhood(radius=radius)
if radius == 1:
neighborhood = grid._cells[(0, 0)].neighborhood
else:
neighborhood = grid._cells[(0, 0)].get_neighborhood(radius=radius)
assert len(neighborhood) == n

with pytest.raises(ValueError):
grid._cells[(0, 0)].neighborhood(radius=0)
grid._cells[(0, 0)].get_neighborhood(radius=0)

# hexgrid
width = 10
height = 10
grid = HexGrid((width, height), torus=False, capacity=None)
for radius, n in zip(range(1, 4), [2, 6, 11]):
neighborhood = grid._cells[(0, 0)].neighborhood(radius=radius)
if radius == 1:
neighborhood = grid._cells[(0, 0)].neighborhood
else:
neighborhood = grid._cells[(0, 0)].get_neighborhood(radius=radius)
assert len(neighborhood) == n

width = 10
height = 10
grid = HexGrid((width, height), torus=False, capacity=None)
for radius, n in zip(range(1, 4), [5, 10, 17]):
neighborhood = grid._cells[(1, 0)].neighborhood(radius=radius)
if radius == 1:
neighborhood = grid._cells[(1, 0)].neighborhood
else:
neighborhood = grid._cells[(1, 0)].get_neighborhood(radius=radius)
assert len(neighborhood) == n

# networkgrid
Expand Down

0 comments on commit 1ad2e1f

Please sign in to comment.