Skip to content

Commit

Permalink
datacollector: Allow collecting data from Agent (sub)classes (#2300)
Browse files Browse the repository at this point in the history
Enhanced Mesa's DataCollector to allow collecting data from Agent (sub)classes, providing more flexible and granular data collection capabilities.

## Motive
To enable more comprehensive data collection in multi-agent simulations, allowing researchers to track attributes and behaviors specific to different agent types, including custom Agent subclasses.

## Implementation
- Modified `DataCollector` class to accept `agenttype_reporters` parameter
- Added `_new_agenttype_reporter` method for handling agent-type-specific reporters
- Updated `collect` method to handle agent-type-specific data collection
- Added `get_agenttype_vars_dataframe` method for retrieving agent-type-specific data
- Updated `Model` class to support `agenttype_reporters` in `initialize_data_collector`
- Added support for collecting data from all Agent subclasses, not just predefined agent types
- Updated docstrings and module-level documentation
- Added comprehensive unit tests for the new functionality

## Usage Examples
```python
class MyModel(Model):
    def __init__(self):
        super().__init__()
        self.datacollector = DataCollector(
            agent_reporters={"life_span": "life_span"},
            # The new agenttype_reporters argument
            agenttype_reporters={
                Wolf: {"sheep_eaten": "sheep_eaten"},
                Sheep: {"wool": "wool_amount"},
                Animal: {"energy": "energy"}  # Collects from all animals
            }
        )

# Retrieve data for a specific agent type
wolf_data = model.datacollector.get_agenttype_vars_dataframe(Wolf)
# Retrieve data for all Animal subclasses, which are in this case all Wolf and Sheep
animal_data = model.datacollector.get_agenttype_vars_dataframe(Animal)
```

## Additional Notes
- Backward compatible with existing DataCollector usage
- Supports collecting data from custom Agent subclasses and superclasses
  • Loading branch information
EwoutH authored Sep 21, 2024
1 parent 7d4a4af commit e6874ad
Show file tree
Hide file tree
Showing 3 changed files with 316 additions and 23 deletions.
159 changes: 136 additions & 23 deletions mesa/datacollection.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
"""Mesa Data Collection Module.
DataCollector is meant to provide a simple, standard way to collect data
generated by a Mesa model. It collects three types of data: model-level data,
agent-level data, and tables.
generated by a Mesa model. It collects four types of data: model-level data,
agent-level data, agent-type-level data, and tables.
A DataCollector is instantiated with two dictionaries of reporter names and
associated variable names or functions for each, one for model-level data and
one for agent-level data; a third dictionary provides table names and columns.
Variable names are converted into functions which retrieve attributes of that
name.
A DataCollector is instantiated with three dictionaries of reporter names and
associated variable names or functions for each, one for model-level data,
one for agent-level data, and one for agent-type-level data; a fourth dictionary
provides table names and columns. Variable names are converted into functions
which retrieve attributes of that name.
When the collect() method is called, each model-level function is called, with
the model as the argument, and the results associated with the relevant
variable. Then the agent-level functions are called on each agent.
variable. Then the agent-level functions are called on each agent, and the
agent-type-level functions are called on each agent of the specified type.
Additionally, other objects can write directly to tables by passing in an
appropriate dictionary object for a table row.
Expand All @@ -21,19 +22,18 @@
* model_vars maps each reporter to a list of its values
* tables maps each table to a dictionary, with each column as a key with a
list as its value.
* _agent_records maps each model step to a list of each agents id
* _agent_records maps each model step to a list of each agent's id
and its values.
* _agenttype_records maps each model step to a dictionary of agent types,
each containing a list of each agent's id and its values.
Finally, DataCollector can create a pandas DataFrame from each collection.
The default DataCollector here makes several assumptions:
* The model has an agent list called agents
* For collecting agent-level variables, agents must have a unique_id
"""

import contextlib
import itertools
import types
import warnings
from copy import deepcopy
from functools import partial

Expand All @@ -44,24 +44,25 @@
class DataCollector:
"""Class for collecting data generated by a Mesa model.
A DataCollector is instantiated with dictionaries of names of model- and
agent-level variables to collect, associated with attribute names or
functions which actually collect them. When the collect(...) method is
called, it collects these attributes and executes these functions one by
one and stores the results.
A DataCollector is instantiated with dictionaries of names of model-,
agent-, and agent-type-level variables to collect, associated with
attribute names or functions which actually collect them. When the
collect(...) method is called, it collects these attributes and executes
these functions one by one and stores the results.
"""

def __init__(
self,
model_reporters=None,
agent_reporters=None,
agenttype_reporters=None,
tables=None,
):
"""Instantiate a DataCollector with lists of model and agent reporters.
"""Instantiate a DataCollector with lists of model, agent, and agent-type reporters.
Both model_reporters and agent_reporters accept a dictionary mapping a
variable name to either an attribute name, a function, a method of a class/instance,
or a function with parameters placed in a list.
Both model_reporters, agent_reporters, and agenttype_reporters accept a
dictionary mapping a variable name to either an attribute name, a function,
a method of a class/instance, or a function with parameters placed in a list.
Model reporters can take four types of arguments:
1. Lambda function:
Expand All @@ -85,6 +86,10 @@ def __init__(
4. Functions with parameters placed in a list:
{"Agent_Function": [function, [param_1, param_2]]}
Agenttype reporters take a dictionary mapping agent types to dictionaries
of reporter names and attributes/funcs/methods, similar to agent_reporters:
{Wolf: {"energy": lambda a: a.energy}}
The tables arg accepts a dictionary mapping names of tables to lists of
columns. For example, if we want to allow agents to write their age
when they are destroyed (to keep track of lifespans), it might look
Expand All @@ -94,6 +99,8 @@ def __init__(
Args:
model_reporters: Dictionary of reporter names and attributes/funcs/methods.
agent_reporters: Dictionary of reporter names and attributes/funcs/methods.
agenttype_reporters: Dictionary of agent types to dictionaries of
reporter names and attributes/funcs/methods.
tables: Dictionary of table names to lists of column names.
Notes:
Expand All @@ -103,9 +110,11 @@ def __init__(
"""
self.model_reporters = {}
self.agent_reporters = {}
self.agenttype_reporters = {}

self.model_vars = {}
self._agent_records = {}
self._agenttype_records = {}
self.tables = {}

if model_reporters is not None:
Expand All @@ -116,6 +125,11 @@ def __init__(
for name, reporter in agent_reporters.items():
self._new_agent_reporter(name, reporter)

if agenttype_reporters is not None:
for agent_type, reporters in agenttype_reporters.items():
for name, reporter in reporters.items():
self._new_agenttype_reporter(agent_type, name, reporter)

if tables is not None:
for name, columns in tables.items():
self._new_table(name, columns)
Expand Down Expand Up @@ -163,6 +177,38 @@ def func_with_params(agent):

self.agent_reporters[name] = reporter

def _new_agenttype_reporter(self, agent_type, name, reporter):
"""Add a new agent-type-level reporter to collect.
Args:
agent_type: The type of agent to collect data for.
name: Name of the agent-type-level variable to collect.
reporter: Attribute string, function object, method of a class/instance, or
function with parameters placed in a list that returns the
variable when given an agent instance.
"""
if agent_type not in self.agenttype_reporters:
self.agenttype_reporters[agent_type] = {}

# Use the same logic as _new_agent_reporter
if isinstance(reporter, str):
attribute_name = reporter

def attr_reporter(agent):
return getattr(agent, attribute_name, None)

reporter = attr_reporter

elif isinstance(reporter, list):
func, params = reporter[0], reporter[1]

def func_with_params(agent):
return func(agent, *params)

reporter = func_with_params

self.agenttype_reporters[agent_type][name] = reporter

def _new_table(self, table_name, table_columns):
"""Add a new table that objects can write to.
Expand Down Expand Up @@ -190,6 +236,34 @@ def get_reports(agent):
)
return agent_records

def _record_agenttype(self, model, agent_type):
"""Record agent-type data in a mapping of functions and agents."""
rep_funcs = self.agenttype_reporters[agent_type].values()

def get_reports(agent):
_prefix = (agent.model.steps, agent.unique_id)
reports = tuple(rep(agent) for rep in rep_funcs)
return _prefix + reports

agent_types = model.agent_types
if agent_type in agent_types:
agents = model.agents_by_type[agent_type]
else:
from mesa import Agent

if issubclass(agent_type, Agent):
agents = [
agent for agent in model.agents if isinstance(agent, agent_type)
]
else:
# Raise error if agent_type is not in model.agent_types
raise ValueError(
f"Agent type {agent_type} is not recognized as an Agent type in the model or Agent subclass. Use an Agent (sub)class, like {agent_types}."
)

agenttype_records = map(get_reports, agents)
return agenttype_records

def collect(self, model):
"""Collect all the data for the given model object."""
if self.model_reporters:
Expand All @@ -208,14 +282,21 @@ def collect(self, model):
elif isinstance(reporter, list):
self.model_vars[var].append(deepcopy(reporter[0](*reporter[1])))
# Assume it's a callable otherwise (e.g., method)
# TODO: Check if method of a class explicitly
else:
self.model_vars[var].append(deepcopy(reporter()))

if self.agent_reporters:
agent_records = self._record_agents(model)
self._agent_records[model.steps] = list(agent_records)

if self.agenttype_reporters:
self._agenttype_records[model.steps] = {}
for agent_type in self.agenttype_reporters:
agenttype_records = self._record_agenttype(model, agent_type)
self._agenttype_records[model.steps][agent_type] = list(
agenttype_records
)

def add_table_row(self, table_name, row, ignore_missing=False):
"""Add a row dictionary to a specific table.
Expand Down Expand Up @@ -272,6 +353,38 @@ def get_agent_vars_dataframe(self):
)
return df

def get_agenttype_vars_dataframe(self, agent_type):
"""Create a pandas DataFrame from the agent-type variables for a specific agent type.
The DataFrame has one column for each variable, with two additional
columns for tick and agent_id.
Args:
agent_type: The type of agent to get the data for.
"""
# Check if self.agenttype_reporters dictionary is empty for this agent type, if so return empty DataFrame
if agent_type not in self.agenttype_reporters:
warnings.warn(
f"No agent-type reporters have been defined for {agent_type} in the DataCollector, returning empty DataFrame.",
UserWarning,
stacklevel=2,
)
return pd.DataFrame()

all_records = itertools.chain.from_iterable(
records[agent_type]
for records in self._agenttype_records.values()
if agent_type in records
)
rep_names = list(self.agenttype_reporters[agent_type])

df = pd.DataFrame.from_records(
data=all_records,
columns=["Step", "AgentID", *rep_names],
index=["Step", "AgentID"],
)
return df

def get_table_dataframe(self, table_name):
"""Create a pandas DataFrame from a particular table.
Expand Down
3 changes: 3 additions & 0 deletions mesa/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,13 +198,15 @@ def initialize_data_collector(
self,
model_reporters=None,
agent_reporters=None,
agenttype_reporters=None,
tables=None,
) -> None:
"""Initialize the data collector for the model.
Args:
model_reporters: model reporters to collect
agent_reporters: agent reporters to collect
agenttype_reporters: agent type reporters to collect
tables: tables to collect
"""
Expand All @@ -219,6 +221,7 @@ def initialize_data_collector(
self.datacollector = DataCollector(
model_reporters=model_reporters,
agent_reporters=agent_reporters,
agenttype_reporters=agenttype_reporters,
tables=tables,
)
# Collect data for the first time during initialization.
Expand Down
Loading

0 comments on commit e6874ad

Please sign in to comment.