Skip to content

Commit

Permalink
feat: Implement experimental DataCollector API
Browse files Browse the repository at this point in the history
  • Loading branch information
rht committed Feb 4, 2024
1 parent 2dc485f commit c08bede
Showing 1 changed file with 79 additions and 0 deletions.
79 changes: 79 additions & 0 deletions mesa/experimental/measure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from collections import defaultdict

class Group:
def __init__(self, model, fn):
self.model = model
self.fn = fn

@property
def value(self):
return self.fn(self.model)

class Measure:
def __init__(self, group, measurer):
self.group = group
self.measurer = measurer

def _measure_group(self, group, measurer):
# get an attribute
if isinstance(measurer, str):
return getattr(group, measurer)
# apply
return measurer(group)

@property
def value(self):
group_object = self.group
if isinstance(self.group, Group):
group_object = self.group.value
return self._measure_group(group_object, self.measurer)


class DataCollector:
"""
Example: a model consisting of a hybrid of Boltzmann wealth model and
Epstein civil violence.
class EpsteinBoltzmannModel:
def __init__(self):
# Groups
self.quiescents = Group(
lambda model: model.agents.select(
agent_type=Citizen, filter_func=lambda a: a.condition == "Quiescent"
)
)
self.citizens = Group(lambda model: model.get_agents_of_type(Citizen))
# Measurements
self.num_quiescents = Measure(self.quiescents, len)
self.gini = Measure(
self.agents, lambda agents: calculate_gini(agents.get("wealth"))
)
self.gini_quiescents = Measure(
self.quiescents, lambda agents: calculate_gini(agents.get("wealth"))
)
self.condition = Measure(self.citizens, "condition")
self.wealth = Measure(self.agents, "wealth")
def run():
model = EpsteinBoltzmannModel()
datacollector = DataCollector(
model, ["num_quiescents", "gini_quiescents", "wealth"]
)
for _ in range(10):
model.step()
datacollector.collect()
"""
def __init__(self, model, attributes):
self.model = model
self.attributes = attributes
self.data_collection = defaultdict(list)

def collect(self):
for name in self.attributes:
attribute = getattr(self.model, name)
if isinstance(attribute, Measure):
attribute = attribute.value
self.data_collection[name] = attribute

0 comments on commit c08bede

Please sign in to comment.