Skip to content

Commit

Permalink
refactor: Move Matplotlib-specific Solara components to separate file
Browse files Browse the repository at this point in the history
  • Loading branch information
rht committed Jan 16, 2024
1 parent dd686fa commit 8cf7abd
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 116 deletions.
114 changes: 114 additions & 0 deletions mesa/experimental/components/matplotlib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from typing import Optional

import networkx as nx
import solara
from matplotlib.figure import Figure
from matplotlib.ticker import MaxNLocator

import mesa


@solara.component
def SpaceMatplotlib(model, agent_portrayal, dependencies: Optional[list[any]] = None):
space_fig = Figure()
space_ax = space_fig.subplots()
space = getattr(model, "grid", None)
if space is None:
# Sometimes the space is defined as model.space instead of model.grid
space = model.space
if isinstance(space, mesa.space.NetworkGrid):
_draw_network_grid(space, space_ax, agent_portrayal)
elif isinstance(space, mesa.space.ContinuousSpace):
_draw_continuous_space(space, space_ax, agent_portrayal)
else:
_draw_grid(space, space_ax, agent_portrayal)
space_ax.set_axis_off()
solara.FigureMatplotlib(space_fig, format="png", dependencies=dependencies)


def _draw_grid(space, space_ax, agent_portrayal):
def portray(g):
x = []
y = []
s = [] # size
c = [] # color
for i in range(g.width):
for j in range(g.height):
content = g._grid[i][j]
if not content:
continue
if not hasattr(content, "__iter__"):
# Is a single grid
content = [content]
for agent in content:
data = agent_portrayal(agent)
x.append(i)
y.append(j)
if "size" in data:
s.append(data["size"])
if "color" in data:
c.append(data["color"])
out = {"x": x, "y": y}
if len(s) > 0:
out["s"] = s
if len(c) > 0:
out["c"] = c
return out

space_ax.scatter(**portray(space))


def _draw_network_grid(space, space_ax, agent_portrayal):
graph = space.G
pos = nx.spring_layout(graph, seed=0)
nx.draw(
graph,
ax=space_ax,
pos=pos,
**agent_portrayal(graph),
)


def _draw_continuous_space(space, space_ax, agent_portrayal):
def portray(space):
x = []
y = []
s = [] # size
c = [] # color
for agent in space._agent_to_index:
data = agent_portrayal(agent)
_x, _y = agent.pos
x.append(_x)
y.append(_y)
if "size" in data:
s.append(data["size"])
if "color" in data:
c.append(data["color"])
out = {"x": x, "y": y}
if len(s) > 0:
out["s"] = s
if len(c) > 0:
out["c"] = c
return out

space_ax.scatter(**portray(space))


def make_plot(model, measure):
fig = Figure()
ax = fig.subplots()
df = model.datacollector.get_model_vars_dataframe()
if isinstance(measure, str):
ax.plot(df.loc[:, measure])
ax.set_ylabel(measure)
elif isinstance(measure, dict):
for m, color in measure.items():
ax.plot(df.loc[:, m], label=m, color=color)
fig.legend()
elif isinstance(measure, (list, tuple)):
for m in measure:
ax.plot(df.loc[:, m], label=m)
fig.legend()
# Set integer x axis
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
solara.FigureMatplotlib(fig)
121 changes: 6 additions & 115 deletions mesa/experimental/jupyter_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,11 @@
from typing import Optional

import matplotlib.pyplot as plt
import networkx as nx
import reacton.ipywidgets as widgets
import solara
from matplotlib.figure import Figure
from matplotlib.ticker import MaxNLocator
from solara.alias import rv

import mesa
import mesa.experimental.components.matplotlib as components_matplotlib

# Avoid interactive backend
plt.switch_backend("agg")
Expand Down Expand Up @@ -72,7 +69,7 @@ def ColorCard(color, layout_type):
rv.CardTitle(children=["Space"])
if space_drawer == "default":
# draw with the default implementation
SpaceMatplotlib(
components_matplotlib.SpaceMatplotlib(
model, agent_portrayal, dependencies=[current_step.value]
)
elif space_drawer:
Expand All @@ -85,7 +82,7 @@ def ColorCard(color, layout_type):
# Is a custom object
measure(model)
else:
make_plot(model, measure)
components_matplotlib.make_plot(model, measure)
return main

# 3. Set up UI
Expand All @@ -106,7 +103,7 @@ def render_in_jupyter():
# 4. Space
if space_drawer == "default":
# draw with the default implementation
SpaceMatplotlib(
components_matplotlib.SpaceMatplotlib(
model, agent_portrayal, dependencies=[current_step.value]
)
elif space_drawer:
Expand All @@ -121,7 +118,7 @@ def render_in_jupyter():
# Is a custom object
measure(model)
else:
make_plot(model, measure)
components_matplotlib.make_plot(model, measure)

def render_in_browser():
# if space drawer is disabled, do not include it
Expand Down Expand Up @@ -182,7 +179,7 @@ def on_value_play(change):
def do_step():
model.step()
previous_step.value = current_step.value
current_step.value = model.schedule.steps
current_step.value += 1

def do_play():
model.running = True
Expand Down Expand Up @@ -316,112 +313,6 @@ def change_handler(value, name=name):
raise ValueError(f"{input_type} is not a supported input type")


@solara.component
def SpaceMatplotlib(model, agent_portrayal, dependencies: Optional[list[any]] = None):
space_fig = Figure()
space_ax = space_fig.subplots()
space = getattr(model, "grid", None)
if space is None:
# Sometimes the space is defined as model.space instead of model.grid
space = model.space
if isinstance(space, mesa.space.NetworkGrid):
_draw_network_grid(space, space_ax, agent_portrayal)
elif isinstance(space, mesa.space.ContinuousSpace):
_draw_continuous_space(space, space_ax, agent_portrayal)
else:
_draw_grid(space, space_ax, agent_portrayal)
space_ax.set_axis_off()
solara.FigureMatplotlib(space_fig, format="png", dependencies=dependencies)


def _draw_grid(space, space_ax, agent_portrayal):
def portray(g):
x = []
y = []
s = [] # size
c = [] # color
for i in range(g.width):
for j in range(g.height):
content = g._grid[i][j]
if not content:
continue
if not hasattr(content, "__iter__"):
# Is a single grid
content = [content]
for agent in content:
data = agent_portrayal(agent)
x.append(i)
y.append(j)
if "size" in data:
s.append(data["size"])
if "color" in data:
c.append(data["color"])
out = {"x": x, "y": y}
if len(s) > 0:
out["s"] = s
if len(c) > 0:
out["c"] = c
return out

space_ax.scatter(**portray(space))


def _draw_network_grid(space, space_ax, agent_portrayal):
graph = space.G
pos = nx.spring_layout(graph, seed=0)
nx.draw(
graph,
ax=space_ax,
pos=pos,
**agent_portrayal(graph),
)


def _draw_continuous_space(space, space_ax, agent_portrayal):
def portray(space):
x = []
y = []
s = [] # size
c = [] # color
for agent in space._agent_to_index:
data = agent_portrayal(agent)
_x, _y = agent.pos
x.append(_x)
y.append(_y)
if "size" in data:
s.append(data["size"])
if "color" in data:
c.append(data["color"])
out = {"x": x, "y": y}
if len(s) > 0:
out["s"] = s
if len(c) > 0:
out["c"] = c
return out

space_ax.scatter(**portray(space))


def make_plot(model, measure):
fig = Figure()
ax = fig.subplots()
df = model.datacollector.get_model_vars_dataframe()
if isinstance(measure, str):
ax.plot(df.loc[:, measure])
ax.set_ylabel(measure)
elif isinstance(measure, dict):
for m, color in measure.items():
ax.plot(df.loc[:, m], label=m, color=color)
fig.legend()
elif isinstance(measure, (list, tuple)):
for m in measure:
ax.plot(df.loc[:, m], label=m)
fig.legend()
# Set integer x axis
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
solara.FigureMatplotlib(fig)


def make_text(renderer):
def function(model):
solara.Markdown(renderer(model))
Expand Down
2 changes: 1 addition & 1 deletion tests/test_jupyter_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def Test(user_params):


class TestJupyterViz(unittest.TestCase):
@patch("mesa.experimental.jupyter_viz.SpaceMatplotlib")
@patch("mesa.experimental.components.matplotlib.SpaceMatplotlib")
def test_call_space_drawer(self, mock_space_matplotlib):
mock_model_class = Mock()
agent_portrayal = {
Expand Down

0 comments on commit 8cf7abd

Please sign in to comment.