diff --git a/mesa/visualization/components/matplotlib.py b/mesa/visualization/components/matplotlib.py index 83d0e3d8eaf..f85b640ed25 100644 --- a/mesa/visualization/components/matplotlib.py +++ b/mesa/visualization/components/matplotlib.py @@ -1,11 +1,14 @@ from collections import defaultdict +import matplotlib.pyplot as plt import networkx as nx import solara +from matplotlib.colors import Normalize from matplotlib.figure import Figure from matplotlib.ticker import MaxNLocator import mesa +from mesa.space import GridContent @solara.component @@ -22,47 +25,100 @@ def SpaceMatplotlib(model, agent_portrayal, dependencies: list[any] | None = Non _draw_continuous_space(space, space_ax, agent_portrayal) else: _draw_grid(space, space_ax, agent_portrayal) + solara.FigureMatplotlib(space_fig, format="png", dependencies=dependencies) +# used to make non(less?)-breaking change +# this *does* however block the matplotlib 'color' param which is somewhat distinct from 'c'. +# maybe translate 'size' and 'shape' but not 'color'? +def _translate_old_keywords(data): + """ + Translates old keyword names in the given dictionary to the new names. + """ + key_mapping = {"size": "s", "color": "c", "shape": "marker"} + return {key_mapping.get(key, key): val for (key, val) in data.items()} + + +def _apply_color_map(color, cmap=None, norm=None, vmin=None, vmax=None): + """ + Given parameters for manual colormap application, applies color map + according to default implementation in matplotlib + """ + if not cmap: # if no colormap is provided, return original color + return color + color_map = plt.get_cmap(cmap) + if norm: # check if norm is provided and apply it + if not isinstance(norm, Normalize): + raise TypeError( + "'norm' must be an instance of Normalize or its subclasses." + ) + return color_map(norm(color)) + if not (vmin == None or vmax == None): # check for custom norm params + new_norm = Normalize(vmin, vmax) + return color_map(new_norm(color)) + try: + return color_map(color) + except Exception as e: + raise ValueError("Color mapping failed due to invalid arguments") from e + + # matplotlib scatter does not allow for multiple shapes in one call -def _split_and_scatter(portray_data, space_ax): - grouped_data = defaultdict(lambda: {"x": [], "y": [], "s": [], "c": []}) - - # Extract data from the dictionary - x = portray_data["x"] - y = portray_data["y"] - s = portray_data["s"] - c = portray_data["c"] - m = portray_data["m"] - - if not (len(x) == len(y) == len(s) == len(c) == len(m)): - raise ValueError( - "Length mismatch in portrayal data lists: " - f"x: {len(x)}, y: {len(y)}, size: {len(s)}, " - f"color: {len(c)}, marker: {len(m)}" - ) - - # Group the data by marker - for i in range(len(x)): - marker = m[i] - grouped_data[marker]["x"].append(x[i]) - grouped_data[marker]["y"].append(y[i]) - grouped_data[marker]["s"].append(s[i]) - grouped_data[marker]["c"].append(c[i]) - - # Plot each group with the same marker +def _split_and_scatter(portray_data: dict, space_ax) -> None: + # if any of the following params are passed into portray(), this is true + cmap_exists = portray_data.pop("cmap", None) + norm_exists = portray_data.pop("norm", None) + vmin_exists = portray_data.pop("vmin", None) + vmax_exists = portray_data.pop("vmax", None) + + # enforce marker iterability + markers = portray_data.pop("marker", ["o"] * len(portray_data["x"])) + # enforce default color + if ( # if no 'color' or 'facecolor' or 'c' then default to "tab:blue" color + "color" not in portray_data + and "facecolor" not in portray_data + and "c" not in portray_data + ): + portray_data["color"] = ["tab:blue"] * len(portray_data["x"]) + + grouped_data = defaultdict(lambda: {key: [] for key in portray_data}) + + for i, marker in enumerate(markers): + for key in portray_data: + if key == "c": # apply colormap if possible + # prepare arguments + cmap = cmap_exists[i] if cmap_exists else None + norm = norm_exists[i] if norm_exists else None + vmin = vmin_exists[i] if vmin_exists else None + vmax = vmax_exists[i] if vmax_exists else None + # apply colormap with prepared arguments + portray_data["c"][i] = _apply_color_map( + portray_data["c"][i], cmap, norm, vmin, vmax + ) + + grouped_data[marker][key].append(portray_data[key][i]) + for marker, data in grouped_data.items(): - space_ax.scatter(data["x"], data["y"], s=data["s"], c=data["c"], marker=marker) + space_ax.scatter(marker=marker, **data) def _draw_grid(space, space_ax, agent_portrayal): def portray(g): - x = [] - y = [] - s = [] # size - c = [] # color - m = [] # shape + default_values = { + "size": (180 / max(g.width, g.height)) ** 2, + } + + out = {} + num_agents = 0 + for content in g: + if not content: + continue + if isinstance(content, GridContent): # one agent + num_agents += 1 + continue + num_agents += len(content) + + index = 0 for i in range(g.width): for j in range(g.height): content = g._grid[i][j] @@ -73,27 +129,25 @@ def portray(g): content = [content] for agent in content: data = agent_portrayal(agent) - x.append(i) - y.append(j) - - # This is the default value for the marker size, which auto-scales - # according to the grid area. - default_size = (180 / max(g.width, g.height)) ** 2 - # establishing a default prevents misalignment if some agents are not given size, color, etc. - size = data.get("size", default_size) - s.append(size) - color = data.get("color", "tab:blue") - c.append(color) - mark = data.get("shape", "o") - m.append(mark) - out = {"x": x, "y": y, "s": s, "c": c, "m": m} - return out + data["x"] = i + data["y"] = j + + for key, value in data.items(): + if key not in out: + # initialize list + out[key] = [default_values.get(key)] * num_agents + out[key][index] = value + index += 1 + + return _translate_old_keywords(out) space_ax.set_xlim(-1, space.width) space_ax.set_ylim(-1, space.height) + _split_and_scatter(portray(space), space_ax) +# draws using networkx's matplotlib integration def _draw_network_grid(space, space_ax, agent_portrayal): graph = space.G pos = nx.spring_layout(graph, seed=0) @@ -107,28 +161,23 @@ def _draw_network_grid(space, space_ax, agent_portrayal): def _draw_continuous_space(space, space_ax, agent_portrayal): def portray(space): - x = [] - y = [] - s = [] # size - c = [] # color - m = [] # shape - for agent in space._agent_to_index: + # TODO: look into if more default values are needed + # especially relating to 'color', 'facecolor', and 'c' params & + # interactions w/ the current implementation of _split_and_scatter + default_values = {"s": 20} + out = {} + num_agents = len(space._agent_to_index) + + for i, agent in enumerate(space._agent_to_index): data = agent_portrayal(agent) - _x, _y = agent.pos - x.append(_x) - y.append(_y) - - # This is matplotlib's default marker size - default_size = 20 - # establishing a default prevents misalignment if some agents are not given size, color, etc. - size = data.get("size", default_size) - s.append(size) - color = data.get("color", "tab:blue") - c.append(color) - mark = data.get("shape", "o") - m.append(mark) - out = {"x": x, "y": y, "s": s, "c": c, "m": m} - return out + data["x"], data["y"] = agent.pos + + for key, value in data.items(): + if key not in out: # initialize list + out[key] = [default_values.get(key, default=None)] * num_agents + out[key][i] = value + + return _translate_old_keywords(out) # Determine border style based on space.torus border_style = "solid" if not space.torus else (0, (5, 10)) @@ -146,7 +195,6 @@ def portray(space): space_ax.set_xlim(space.x_min - x_padding, space.x_max + x_padding) space_ax.set_ylim(space.y_min - y_padding, space.y_max + y_padding) - # Portray and scatter the agents in the space _split_and_scatter(portray(space), space_ax) diff --git a/mesa/visualization/solara_viz.py b/mesa/visualization/solara_viz.py index 6ec33231cae..347a44616d9 100644 --- a/mesa/visualization/solara_viz.py +++ b/mesa/visualization/solara_viz.py @@ -104,7 +104,8 @@ def SolaraViz( measures: List of callables or data attributes to plot name: Name for display agent_portrayal: Options for rendering agents (dictionary); - Default drawer supports custom `"size"`, `"color"`, and `"shape"`. + Default drawer supports custom matplotlib's [scatter](https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.scatter.html) + params, with the exception of (currently) vmin, vmax, & plotnonfinite. space_drawer: Method to render the agent space for the model; default implementation is the `SpaceMatplotlib` component; simulations with no space to visualize should