Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
mypy fixes for boids example
Browse files Browse the repository at this point in the history
stevebachmeier committed Dec 26, 2024
1 parent 56dd0a0 commit 8799346
Showing 6 changed files with 31 additions and 32 deletions.
5 changes: 0 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -32,11 +32,6 @@ exclude = [
# You will need to remove the mypy: ignore-errors comment from the file heading as well
'docs/source/conf.py',
'setup.py',
'src/vivarium/examples/boids/forces.py',
'src/vivarium/examples/boids/movement.py',
'src/vivarium/examples/boids/neighbors.py',
'src/vivarium/examples/boids/population.py',
'src/vivarium/examples/boids/visualization.py',
'src/vivarium/examples/disease_model/__init__.py',
'src/vivarium/examples/disease_model/disease.py',
'src/vivarium/examples/disease_model/intervention.py',
27 changes: 15 additions & 12 deletions src/vivarium/examples/boids/forces.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# mypy: ignore-errors

from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any

@@ -43,7 +44,7 @@ def setup(self, builder: Builder) -> None:
# Pipeline sources and modifiers #
##################################

def apply_force(self, index: pd.Index, acceleration: pd.DataFrame) -> pd.DataFrame:
def apply_force(self, index: pd.Index[int], acceleration: pd.DataFrame) -> pd.DataFrame:
neighbors = self.neighbors(index)
pop = self.population_view.get(index)
pairs = self._get_pairs(neighbors, pop)
@@ -56,18 +57,18 @@ def apply_force(self, index: pd.Index, acceleration: pd.DataFrame) -> pd.DataFra
max_speed=self.max_speed,
)

acceleration.loc[force.index] += force[["x", "y"]]
acceleration.loc[force.index, ["x", "y"]] += force[["x", "y"]]
return acceleration

##################
# Helper methods #
##################

@abstractmethod
def calculate_force(self, neighbors: pd.DataFrame):
def calculate_force(self, neighbors: pd.DataFrame) -> pd.DataFrame:
pass

def _get_pairs(self, neighbors: pd.Series, pop: pd.DataFrame):
def _get_pairs(self, neighbors: pd.Series[int], pop: pd.DataFrame) -> pd.DataFrame:
pairs = (
pop.join(neighbors.rename("neighbors"))
.reset_index()
@@ -91,7 +92,7 @@ def _normalize_and_limit_force(
velocity: pd.DataFrame,
max_force: float,
max_speed: float,
):
) -> pd.DataFrame:
normalization_factor = np.where(
(force.x != 0) | (force.y != 0),
max_speed / self._magnitude(force),
@@ -111,8 +112,8 @@ def _normalize_and_limit_force(
force["y"] *= limit_scaling_factor
return force[["x", "y"]]

def _magnitude(self, df: pd.DataFrame):
return np.sqrt(np.square(df.x) + np.square(df.y))
def _magnitude(self, df: pd.DataFrame) -> pd.Series[float]:
return pd.Series(np.sqrt(np.square(df.x) + np.square(df.y)), dtype=float)


class Separation(Force):
@@ -125,7 +126,7 @@ class Separation(Force):
},
}

def calculate_force(self, neighbors: pd.DataFrame):
def calculate_force(self, neighbors: pd.DataFrame) -> pd.DataFrame:
# Push boids apart when they get too close
separation_neighbors = neighbors[neighbors.distance < self.config.distance].copy()
force_scaling_factor = np.where(
@@ -140,17 +141,19 @@ def calculate_force(self, neighbors: pd.DataFrame):
separation_neighbors["distance_y"] * force_scaling_factor
)

return (
force: pd.DataFrame = (
separation_neighbors.groupby("index_self")[["force_x", "force_y"]]
.sum()
.rename(columns=lambda c: c.replace("force_", ""))
)

return force


class Cohesion(Force):
"""Push boids together."""

def calculate_force(self, pairs: pd.DataFrame):
def calculate_force(self, pairs: pd.DataFrame) -> pd.DataFrame:
return (
pairs.groupby("index_self")[["distance_x", "distance_y"]]
.sum()
@@ -161,7 +164,7 @@ def calculate_force(self, pairs: pd.DataFrame):
class Alignment(Force):
"""Push boids toward where others are going."""

def calculate_force(self, pairs: pd.DataFrame):
def calculate_force(self, pairs: pd.DataFrame) -> pd.DataFrame:
return (
pairs.groupby("index_self")[["vx_other", "vy_other"]]
.sum()
7 changes: 4 additions & 3 deletions src/vivarium/examples/boids/movement.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# mypy: ignore-errors
from __future__ import annotations
import numpy as np
import pandas as pd

from vivarium.framework.event import Event
from vivarium import Component
from vivarium.framework.engine import Builder
from vivarium.framework.population import SimulantData
@@ -38,7 +39,7 @@ def setup(self, builder: Builder) -> None:
# Pipeline sources and modifiers #
##################################

def base_acceleration(self, index: pd.Index) -> pd.DataFrame:
def base_acceleration(self, index: pd.Index[int]) -> pd.DataFrame:
return pd.DataFrame(0.0, columns=["x", "y"], index=index)

########################
@@ -59,7 +60,7 @@ def on_initialize_simulants(self, pop_data: SimulantData) -> None:
)
self.population_view.update(new_population)

def on_time_step(self, event):
def on_time_step(self, event: Event) -> None:
pop = self.population_view.get(event.index)

acceleration = self.acceleration(event.index)
4 changes: 2 additions & 2 deletions src/vivarium/examples/boids/neighbors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# mypy: ignore-errors
from __future__ import annotations
import pandas as pd
from scipy import spatial

@@ -43,7 +43,7 @@ def on_time_step(self, event: Event) -> None:
# Pipeline sources and modifiers #
##################################

def get_neighbors(self, index: pd.Index) -> pd.Series:
def get_neighbors(self, index: pd.Index[int]) -> pd.Series[list[int]]: # type: ignore[type-var]
if not self.neighbors_calculated:
self._calculate_neighbors()
return self._neighbors[index]
1 change: 0 additions & 1 deletion src/vivarium/examples/boids/population.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# mypy: ignore-errors
import numpy as np
import pandas as pd

19 changes: 10 additions & 9 deletions src/vivarium/examples/boids/visualization.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,43 @@
# mypy: ignore-errors
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

from vivarium import InteractiveContext

def plot_boids(simulation, plot_velocity=False):

def plot_boids(simulation: InteractiveContext, plot_velocity: bool=False) -> None:
width = simulation.configuration.field.width
height = simulation.configuration.field.height
pop = simulation.get_population()

plt.figure(figsize=[12, 12])
plt.figure(figsize=(12, 12))
plt.scatter(pop.x, pop.y, color=pop.color)
if plot_velocity:
plt.quiver(pop.x, pop.y, pop.vx, pop.vy, color=pop.color, width=0.002)
plt.xlabel("x")
plt.ylabel("y")
plt.axis([0, width, 0, height])
plt.axis((0, width, 0, height))
plt.show()


def plot_boids_animated(simulation):
def plot_boids_animated(simulation: InteractiveContext) -> FuncAnimation:
width = simulation.configuration.field.width
height = simulation.configuration.field.height
pop = simulation.get_population()

fig = plt.figure(figsize=[12, 12])
fig = plt.figure(figsize=(12, 12))
ax = fig.add_subplot(111)
s = ax.scatter(pop.x, pop.y, color=pop.color)
plt.xlabel("x")
plt.ylabel("y")
plt.axis([0, width, 0, height])
plt.axis((0, width, 0, height))

frames = range(2_000)
frame_pops = []
for _ in frames:
simulation.step()
frame_pops.append(simulation.get_population()[["x", "y"]])

def animate(i):
def animate(i: int) -> None:
s.set_offsets(frame_pops[i])

return FuncAnimation(fig, animate, frames=frames, interval=10)
return FuncAnimation(fig, animate, frames=frames, interval=10) # type: ignore[arg-type]

0 comments on commit 8799346

Please sign in to comment.