Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More parameter functionality in matplotlib (default) drawer visualization #2242

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
186 changes: 117 additions & 69 deletions mesa/visualization/components/matplotlib.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from collections import defaultdict

import matplotlib.pyplot as plt
import networkx as nx
import solara
from matplotlib.colors import Normalize
from matplotlib.figure import Figure
from matplotlib.ticker import MaxNLocator

import mesa
from mesa.space import GridContent


@solara.component
Expand All @@ -22,47 +25,100 @@
_draw_continuous_space(space, space_ax, agent_portrayal)
else:
_draw_grid(space, space_ax, agent_portrayal)

solara.FigureMatplotlib(space_fig, format="png", dependencies=dependencies)


# used to make non(less?)-breaking change
# this *does* however block the matplotlib 'color' param which is somewhat distinct from 'c'.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain or link to how color and c differ in matplotlib (scatter)?

# maybe translate 'size' and 'shape' but not 'color'?
def _translate_old_keywords(data):
"""
Translates old keyword names in the given dictionary to the new names.
"""
key_mapping = {"size": "s", "color": "c", "shape": "marker"}

Check warning on line 39 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L39

Added line #L39 was not covered by tests
return {key_mapping.get(key, key): val for (key, val) in data.items()}


def _apply_color_map(color, cmap=None, norm=None, vmin=None, vmax=None):
"""
Given parameters for manual colormap application, applies color map
according to default implementation in matplotlib
"""
if not cmap: # if no colormap is provided, return original color
return color
color_map = plt.get_cmap(cmap)

Check warning on line 50 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L49-L50

Added lines #L49 - L50 were not covered by tests
if norm: # check if norm is provided and apply it
if not isinstance(norm, Normalize):
raise TypeError(

Check warning on line 53 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L53

Added line #L53 was not covered by tests
"'norm' must be an instance of Normalize or its subclasses."
)
return color_map(norm(color))

Check warning on line 56 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L56

Added line #L56 was not covered by tests
if not (vmin == None or vmax == None): # check for custom norm params
new_norm = Normalize(vmin, vmax)
return color_map(new_norm(color))
try:
return color_map(color)
except Exception as e:
raise ValueError("Color mapping failed due to invalid arguments") from e

Check warning on line 63 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L58-L63

Added lines #L58 - L63 were not covered by tests


# matplotlib scatter does not allow for multiple shapes in one call
def _split_and_scatter(portray_data, space_ax):
grouped_data = defaultdict(lambda: {"x": [], "y": [], "s": [], "c": []})

# Extract data from the dictionary
x = portray_data["x"]
y = portray_data["y"]
s = portray_data["s"]
c = portray_data["c"]
m = portray_data["m"]

if not (len(x) == len(y) == len(s) == len(c) == len(m)):
raise ValueError(
"Length mismatch in portrayal data lists: "
f"x: {len(x)}, y: {len(y)}, size: {len(s)}, "
f"color: {len(c)}, marker: {len(m)}"
)

# Group the data by marker
for i in range(len(x)):
marker = m[i]
grouped_data[marker]["x"].append(x[i])
grouped_data[marker]["y"].append(y[i])
grouped_data[marker]["s"].append(s[i])
grouped_data[marker]["c"].append(c[i])

# Plot each group with the same marker
def _split_and_scatter(portray_data: dict, space_ax) -> None:
# if any of the following params are passed into portray(), this is true
cmap_exists = portray_data.pop("cmap", None)
norm_exists = portray_data.pop("norm", None)
vmin_exists = portray_data.pop("vmin", None)
vmax_exists = portray_data.pop("vmax", None)

Check warning on line 72 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L69-L72

Added lines #L69 - L72 were not covered by tests
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a feeling this can be done more elegant


# enforce marker iterability
markers = portray_data.pop("marker", ["o"] * len(portray_data["x"]))

Check warning on line 75 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L75

Added line #L75 was not covered by tests
# enforce default color
if ( # if no 'color' or 'facecolor' or 'c' then default to "tab:blue" color
"color" not in portray_data
and "facecolor" not in portray_data
and "c" not in portray_data
):
portray_data["color"] = ["tab:blue"] * len(portray_data["x"])

Check warning on line 82 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L82

Added line #L82 was not covered by tests

grouped_data = defaultdict(lambda: {key: [] for key in portray_data})

for i, marker in enumerate(markers):
for key in portray_data:
if key == "c": # apply colormap if possible
# prepare arguments
cmap = cmap_exists[i] if cmap_exists else None
norm = norm_exists[i] if norm_exists else None
vmin = vmin_exists[i] if vmin_exists else None
vmax = vmax_exists[i] if vmax_exists else None

Check warning on line 93 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L90-L93

Added lines #L90 - L93 were not covered by tests
# apply colormap with prepared arguments
portray_data["c"][i] = _apply_color_map(

Check warning on line 95 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L95

Added line #L95 was not covered by tests
portray_data["c"][i], cmap, norm, vmin, vmax
)

grouped_data[marker][key].append(portray_data[key][i])

Check warning on line 99 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L99

Added line #L99 was not covered by tests

for marker, data in grouped_data.items():
space_ax.scatter(data["x"], data["y"], s=data["s"], c=data["c"], marker=marker)
space_ax.scatter(marker=marker, **data)

Check warning on line 102 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L102

Added line #L102 was not covered by tests


def _draw_grid(space, space_ax, agent_portrayal):
def portray(g):
x = []
y = []
s = [] # size
c = [] # color
m = [] # shape
default_values = {

Check warning on line 107 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L107

Added line #L107 was not covered by tests
"size": (180 / max(g.width, g.height)) ** 2,
}

out = {}
num_agents = 0

Check warning on line 112 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L111-L112

Added lines #L111 - L112 were not covered by tests
for content in g:
if not content:
continue

Check warning on line 115 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L115

Added line #L115 was not covered by tests
if isinstance(content, GridContent): # one agent
num_agents += 1
continue
num_agents += len(content)

Check warning on line 119 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L117-L119

Added lines #L117 - L119 were not covered by tests

index = 0

Check warning on line 121 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L121

Added line #L121 was not covered by tests
for i in range(g.width):
for j in range(g.height):
content = g._grid[i][j]
Expand All @@ -73,27 +129,25 @@
content = [content]
for agent in content:
data = agent_portrayal(agent)
x.append(i)
y.append(j)

# This is the default value for the marker size, which auto-scales
# according to the grid area.
default_size = (180 / max(g.width, g.height)) ** 2
# establishing a default prevents misalignment if some agents are not given size, color, etc.
size = data.get("size", default_size)
s.append(size)
color = data.get("color", "tab:blue")
c.append(color)
mark = data.get("shape", "o")
m.append(mark)
out = {"x": x, "y": y, "s": s, "c": c, "m": m}
return out
data["x"] = i
data["y"] = j

Check warning on line 133 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L132-L133

Added lines #L132 - L133 were not covered by tests

for key, value in data.items():
if key not in out:
# initialize list
out[key] = [default_values.get(key)] * num_agents
out[key][index] = value
index += 1

Check warning on line 140 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L138-L140

Added lines #L138 - L140 were not covered by tests

return _translate_old_keywords(out)

Check warning on line 142 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L142

Added line #L142 was not covered by tests

space_ax.set_xlim(-1, space.width)
space_ax.set_ylim(-1, space.height)

_split_and_scatter(portray(space), space_ax)


# draws using networkx's matplotlib integration
def _draw_network_grid(space, space_ax, agent_portrayal):
graph = space.G
pos = nx.spring_layout(graph, seed=0)
Expand All @@ -107,28 +161,23 @@

def _draw_continuous_space(space, space_ax, agent_portrayal):
def portray(space):
x = []
y = []
s = [] # size
c = [] # color
m = [] # shape
for agent in space._agent_to_index:
# TODO: look into if more default values are needed
# especially relating to 'color', 'facecolor', and 'c' params &
# interactions w/ the current implementation of _split_and_scatter
default_values = {"s": 20}
out = {}
num_agents = len(space._agent_to_index)

Check warning on line 169 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L167-L169

Added lines #L167 - L169 were not covered by tests

for i, agent in enumerate(space._agent_to_index):
data = agent_portrayal(agent)
_x, _y = agent.pos
x.append(_x)
y.append(_y)

# This is matplotlib's default marker size
default_size = 20
# establishing a default prevents misalignment if some agents are not given size, color, etc.
size = data.get("size", default_size)
s.append(size)
color = data.get("color", "tab:blue")
c.append(color)
mark = data.get("shape", "o")
m.append(mark)
out = {"x": x, "y": y, "s": s, "c": c, "m": m}
return out
data["x"], data["y"] = agent.pos

Check warning on line 173 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L173

Added line #L173 was not covered by tests

for key, value in data.items():
if key not in out: # initialize list
out[key] = [default_values.get(key, default=None)] * num_agents
out[key][i] = value

Check warning on line 178 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L177-L178

Added lines #L177 - L178 were not covered by tests

return _translate_old_keywords(out)

Check warning on line 180 in mesa/visualization/components/matplotlib.py

View check run for this annotation

Codecov / codecov/patch

mesa/visualization/components/matplotlib.py#L180

Added line #L180 was not covered by tests

# Determine border style based on space.torus
border_style = "solid" if not space.torus else (0, (5, 10))
Expand All @@ -146,7 +195,6 @@
space_ax.set_xlim(space.x_min - x_padding, space.x_max + x_padding)
space_ax.set_ylim(space.y_min - y_padding, space.y_max + y_padding)

# Portray and scatter the agents in the space
_split_and_scatter(portray(space), space_ax)


Expand Down
3 changes: 2 additions & 1 deletion mesa/visualization/solara_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ def SolaraViz(
measures: List of callables or data attributes to plot
name: Name for display
agent_portrayal: Options for rendering agents (dictionary);
Default drawer supports custom `"size"`, `"color"`, and `"shape"`.
Default drawer supports custom matplotlib's [scatter](https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.scatter.html)
params, with the exception of (currently) vmin, vmax, & plotnonfinite.
space_drawer: Method to render the agent space for
the model; default implementation is the `SpaceMatplotlib` component;
simulations with no space to visualize should
Expand Down