diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 661543ba5e7..912449f97e7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,6 +18,11 @@ repos: hooks: - id: pyupgrade args: [--py38-plus] +- repo: https://github.com/MarcoGorelli/cython-lint + rev: v0.16.0 + hooks: + - id: cython-lint + - id: double-quote-cython-strings - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.5.0 # Use the ref you want to point at hooks: diff --git a/mesa/space.py b/mesa/space.pyx similarity index 98% rename from mesa/space.py rename to mesa/space.pyx index 9be3b5637ce..914d877aea7 100644 --- a/mesa/space.py +++ b/mesa/space.pyx @@ -13,10 +13,6 @@ NetworkGrid: a network where each node contains zero or more agents. """ -# Mypy; for the `|` operator purpose -# Remove this __future__ import once the oldest supported Python is 3.10 -from __future__ import annotations - import collections import contextlib import inspect @@ -25,7 +21,7 @@ import warnings from collections.abc import Iterable, Iterator, Sequence from numbers import Real -from typing import Any, Callable, TypeVar, Union, cast, overload +from typing import Any, Callable, TypeVar, Union, cast from warnings import warn with contextlib.suppress(ImportError): @@ -49,6 +45,7 @@ GridContent = Union[Agent, None] MultiGridContent = list[Agent] +GridIndex = tuple[int | slice, int | slice] | int | Sequence[Coordinate] F = TypeVar("F", bound=Callable[..., Any]) @@ -66,12 +63,18 @@ def wrapper(grid_instance, positions) -> Any: return cast(F, wrapper) +def ensure_positions_as_list(positions): + if len(positions) == 2 and not isinstance(positions[0], tuple): + return [positions] + return positions + + def is_integer(x: Real) -> bool: # Check if x is either a CPython integer or Numpy integer. return isinstance(x, _types_integer) -class _Grid: +cdef class _Grid: """Base class for a rectangular grid. Grid cells are indexed by [x, y], where [0, 0] is assumed to be the @@ -134,17 +137,7 @@ def build_empties(self) -> None: ) self._empties_built = True - @overload - def __getitem__(self, index: int | Sequence[Coordinate]) -> list[GridContent]: - ... - - @overload - def __getitem__( - self, index: tuple[int | slice, int | slice] - ) -> GridContent | list[GridContent]: - ... - - def __getitem__(self, index): + def __getitem__(self, index: GridIndex) -> GridContent | list[GridContent]: """Access contents from the grid.""" if isinstance(index, int): @@ -328,7 +321,8 @@ def iter_neighbors( """ default_val = self.default_val() for x, y in self.get_neighborhood(pos, moore, include_center, radius): - if (cell := self._grid[x][y]) != default_val: + cell = self._grid[x][y] + if cell != default_val: yield cell def get_neighbors( @@ -389,11 +383,11 @@ def iter_cell_list_contents( # iter_cell_list_contents returns only non-empty contents. default_val = self.default_val() for x, y in cell_list: - if (cell := self._grid[x][y]) != default_val: + cell = self._grid[x][y] + if cell != default_val: yield cell - @accept_tuple_argument - def get_cell_list_contents(self, cell_list: Iterable[Coordinate]) -> list[Agent]: + cpdef object get_cell_list_contents(self, cell_list: Iterable[Coordinate]): """Returns an iterator of the agents contained in the cells identified in `cell_list`; cells with empty content are excluded. @@ -403,7 +397,7 @@ def get_cell_list_contents(self, cell_list: Iterable[Coordinate]) -> list[Agent] Returns: A list of the agents contained in the cells identified in `cell_list`. """ - return list(self.iter_cell_list_contents(cell_list)) + return list(self.iter_cell_list_contents(ensure_positions_as_list(cell_list))) def place_agent(self, agent: Agent, pos: Coordinate) -> None: ... @@ -489,9 +483,11 @@ def _distance_squared(self, pos1: Coordinate, pos2: Coordinate) -> float: def swap_pos(self, agent_a: Agent, agent_b: Agent) -> None: """Swap agents positions""" agents_no_pos = [] - if (pos_a := agent_a.pos) is None: + pos_a = agent_a.pos + if pos_a is None: agents_no_pos.append(agent_a) - if (pos_b := agent_b.pos) is None: + pos_b = agent_b.pos + if pos_b is None: agents_no_pos.append(agent_b) if agents_no_pos: agents_no_pos = [f"" for a in agents_no_pos] @@ -992,7 +988,8 @@ def place_agent(self, agent: Agent, pos: Coordinate) -> None: def remove_agent(self, agent: Agent) -> None: """Remove the agent from the grid and set its pos attribute to None.""" - if (pos := agent.pos) is None: + pos = agent.pos + if pos is None: return x, y = pos self._grid[x][y] = self.default_val() @@ -1072,7 +1069,7 @@ def iter_cell_list_contents( """ default_val = self.default_val() return itertools.chain.from_iterable( - cell for x, y in cell_list if (cell := self._grid[x][y]) != default_val + self._grid[x][y] for x, y in cell_list if self._grid[x][y] != default_val ) @@ -1099,7 +1096,7 @@ def torus_adj_2d(self, pos: Coordinate) -> Coordinate: def get_neighborhood( self, pos: Coordinate, include_center: bool = False, radius: int = 1 - ) -> list[Coordinate]: + ) -> Sequence[Coordinate]: """Return a list of coordinates that are in the neighborhood of a certain point. To calculate the neighborhood for a HexGrid the parity of the x coordinate of the point is @@ -1561,7 +1558,7 @@ def is_cell_empty(self, node_id: int) -> bool: """Returns a bool of the contents of a cell.""" return self.G.nodes[node_id]["agent"] == self.default_val() - def get_cell_list_contents(self, cell_list: list[int]) -> list[Agent]: + def get_cell_list_contents(self, cell_list: list[int] | nx.Graph) -> list[Agent]: """Returns a list of the agents contained in the nodes identified in `cell_list`; nodes with empty content are excluded. """ @@ -1571,7 +1568,7 @@ def get_all_cell_contents(self) -> list[Agent]: """Returns a list of all the agents in the network.""" return self.get_cell_list_contents(self.G) - def iter_cell_list_contents(self, cell_list: list[int]) -> Iterator[Agent]: + def iter_cell_list_contents(self, cell_list: list[int] | nx.Graph) -> Iterator[Agent]: """Returns an iterator of the agents contained in the nodes identified in `cell_list`; nodes with empty content are excluded. """ diff --git a/pyproject.toml b/pyproject.toml index 0233bc5874b..0df3d17e99f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,6 +78,13 @@ packages = ["mesa"] [tool.hatch.version] path = "mesa/__init__.py" +[tool.hatch.build.hooks.cython] +dependencies = ["hatch-cython"] + +[tool.hatch.build.hooks.cython.options] +src = "mesa" +compile_py = false + [tool.ruff] # See https://github.com/charliermarsh/ruff#rules for error code definitions. select = [ @@ -130,3 +137,6 @@ extend-exclude = ["docs", "build"] # Hardcode to Python 3.9. # Reminder to update mesa-examples if the value below is changed. target-version = "py39" + +[tool.cython-lint] +ignore = ["E501"]