diff --git a/mesa/experimental/__init__.py b/mesa/experimental/__init__.py index e37cdb042e3..f12c2fb2928 100644 --- a/mesa/experimental/__init__.py +++ b/mesa/experimental/__init__.py @@ -1,3 +1,4 @@ from mesa.experimental import cell_space +from mesa.experimental.datacollector import DataCollector -__all__ = ["cell_space"] +__all__ = ["cell_space", "DataCollector"] diff --git a/mesa/experimental/datacollector.py b/mesa/experimental/datacollector.py new file mode 100644 index 00000000000..61b63f9ddb9 --- /dev/null +++ b/mesa/experimental/datacollector.py @@ -0,0 +1,42 @@ +from collections import defaultdict + +import pandas as pd + + +class DataCollector: + def __init__(self, model, group_reporters, groups=None): + self.model = model + self.group_reporters = group_reporters + self.groups = groups + self.data = defaultdict(lambda: defaultdict(list)) + + def get_group(self, group_name): + if group_name == "model": + return self.model + elif group_name == "agents": + return self.model.agents + else: + try: + return getattr(self.model, group_name) + except AttributeError as e: + raise Exception(f"Unknown group: {group_name}") from e + + def report(self, reporter, group): + if group is self.model: + return reporter(group) + if isinstance(reporter, str): + if hasattr(group, "get"): + return group.get(reporter) + else: + raise Exception() + return reporter(group) + + def collect(self): + for group_name, reporters in self.group_reporters.items(): + group = self.get_group(group_name) + for name, reporter in reporters.items(): + value = self.report(reporter, group) + self.data[group_name][name].append(value) + + def to_df(self, group_name): + return pd.DataFrame(self.data[group_name]) diff --git a/mesa/visualization/components/matplotlib.py b/mesa/visualization/components/matplotlib.py index aadfa206472..c2d725e9af1 100644 --- a/mesa/visualization/components/matplotlib.py +++ b/mesa/visualization/components/matplotlib.py @@ -117,7 +117,7 @@ def portray(space): def PlotMatplotlib(model, measure, dependencies: list[any] | None = None): fig = Figure() ax = fig.subplots() - df = model.datacollector.get_model_vars_dataframe() + df = model.datacollector.to_df("model") if isinstance(measure, str): ax.plot(df.loc[:, measure]) ax.set_ylabel(measure)