Skip to content

Commit

Permalink
Merge pull request #206 from indralab/hackathon
Browse files Browse the repository at this point in the history
Stratification development and hackathon notebooks
  • Loading branch information
bgyori authored Jul 13, 2023
2 parents 55d6e33 + a08ed55 commit c9771ef
Show file tree
Hide file tree
Showing 16 changed files with 7,040 additions and 110 deletions.
143 changes: 116 additions & 27 deletions mira/metamodel/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import sympy

from .template_model import TemplateModel, Initial, Parameter
from .template_model import TemplateModel, Initial, Parameter, Observable
from .templates import *
from .units import Unit, dimensionless_units
from .utils import SympyExprStr
Expand All @@ -32,6 +32,10 @@ def stratify(
conversion_cls: Type[Template] = NaturalConversion,
cartesian_control: bool = False,
modify_names: bool = True,
params_to_stratify: Optional[Collection[str]] = None,
params_to_preserve: Optional[Collection[str]] = None,
concepts_to_stratify: Optional[Collection[str]] = None,
concepts_to_preserve: Optional[Collection[str]] = None,
) -> TemplateModel:
"""Multiplies a model into several strata.
Expand Down Expand Up @@ -69,6 +73,22 @@ def stratify(
on cities, since the infected population in one city won't (directly,
through the perspective of the model) affect the infection of susceptible
population in another city.
modify_names :
If true, will modify the names of the concepts to include the strata
(e.g., ``"S"`` becomes ``"S_boston"``). If false, will keep the original
names.
params_to_stratify :
A list of parameters to stratify. If none given, will stratify all
parameters.
params_to_preserve:
A list of parameters to preserve. If none given, will stratify all
parameters.
concepts_to_stratify :
A list of concepts to stratify. If none given, will stratify all
concepts.
concepts_to_preserve:
A list of concepts to preserve. If none given, will stratify all
concepts.
Returns
-------
Expand All @@ -84,20 +104,44 @@ def stratify(
# e.g. unvaccinated -> vaccinated.

concept_map = template_model.get_concepts_map()
concept_names_map = template_model.get_concepts_name_map()
concept_names = set(concept_names_map.keys())

templates = []
params_count = Counter()

# Figure out excluded concepts
if concepts_to_stratify is None:
if concepts_to_preserve is None:
exclude_concepts = set()
else:
exclude_concepts = set(concepts_to_preserve)
else:
if concepts_to_preserve is None:
exclude_concepts = concept_names - set(concepts_to_stratify)
else:
exclude_concepts = set(concepts_to_preserve) | (
concept_names - set(concepts_to_stratify)
)

for template in template_model.templates:
# Generate a derived template for each strata
for stratum in strata:
new_template = template.with_context(
do_rename=modify_names, **{key: stratum},
)
rewrite_rate_law(template, new_template, params_count)
# parameters = list(template_model.get_parameters_from_rate_law(template.rate_law))
# if len(parameters) == 1:
# new_template.set_mass_action_rate_law(parameters[0])
templates.append(new_template)
if set(template.get_concept_names()) - exclude_concepts:
new_template = template.with_context(
do_rename=modify_names, exclude_concepts=exclude_concepts,
**{key: stratum},
)
rewrite_rate_law(template_model=template_model,
old_template=template,
new_template=new_template,
params_count=params_count,
params_to_stratify=params_to_stratify,
params_to_preserve=params_to_preserve)
# parameters = list(template_model.get_parameters_from_rate_law(template.rate_law))
# if len(parameters) == 1:
# new_template.set_mass_action_rate_law(parameters[0])
templates.append(new_template)

# assume all controllers have to get stratified together
# and mixing of strata doesn't occur during control
Expand All @@ -116,6 +160,8 @@ def stratify(
for c_strata_tuple in c_strata_tuples:
stratified_controllers = [
controller.with_context(do_rename=modify_names, **{key: c_stratum})
if controller.name not in exclude_concepts
else controller
for controller, c_stratum in zip(controllers, c_strata_tuple)
]
if isinstance(template, (GroupedControlledConversion, GroupedControlledProduction)):
Expand All @@ -126,7 +172,12 @@ def stratify(
else:
raise NotImplementedError
# the old template is used here on purpose for easier bookkeeping
rewrite_rate_law(template, stratified_template, params_count)
rewrite_rate_law(template_model=template_model,
old_template=template,
new_template=stratified_template,
params_count=params_count,
params_to_stratify=params_to_stratify,
params_to_preserve=params_to_preserve)
templates.append(stratified_template)

parameters = {}
Expand All @@ -145,6 +196,8 @@ def stratify(
# values of the original compartments
initials = {}
for initial_key, initial in template_model.initials.items():
if initial.concept.name in exclude_concepts:
continue
for stratum in strata:
new_concept = initial.concept.with_context(
do_rename=modify_names, **{key: stratum},
Expand All @@ -153,8 +206,25 @@ def stratify(
concept=new_concept, value=initial.value,
)

observables = {}
for observable_key, observable in template_model.observables.items():
syms = {s.name for s in observable.expression.args[0].free_symbols}
expr = deepcopy(observable.expression.args[0])
for sym in (syms & concept_names) - exclude_concepts:
new_symbols = []
for stratum in strata:
new_concept = concept_names_map[sym].with_context(
do_rename=modify_names, **{key: stratum},
)
new_symbols.append(sympy.Symbol(new_concept.name))
expr = expr.subs(sympy.Symbol(sym), sympy.Add(*new_symbols))
observables[observable_key] = deepcopy(observable)
observables[observable_key].expression = SympyExprStr(expr)

# Generate a conversion between each concept of each strata based on the network structure
for (source_stratum, target_stratum), concept in itt.product(structure, concept_map.values()):
if concept.name in exclude_concepts:
continue
subject = concept.with_context(do_rename=modify_names,
**{key: source_stratum})
outcome = concept.with_context(do_rename=modify_names,
Expand All @@ -166,30 +236,47 @@ def stratify(
if not directed:
templates.append(conversion_cls(subject=outcome, outcome=subject))

return TemplateModel(templates=templates,
parameters=parameters,
initials=initials)


def rewrite_rate_law(old_template: Template, new_template: Template, params_count):
new_model = TemplateModel(templates=templates,
parameters=parameters,
initials=initials,
observables=observables,
annotations=deepcopy(template_model.annotations),
time=template_model.time)
# We do this so that any subsequent stratifications will
# be agnostic to previous ones
new_model.reset_base_names()
return new_model


def rewrite_rate_law(template_model: TemplateModel, old_template: Template,
new_template: Template, params_count,
params_to_stratify=None, params_to_preserve=None):
# Rewrite the rate law by substituting new symbols corresponding
# to the stratified controllers in for the originals
rate_law = old_template.rate_law
if not rate_law:
return

# Step 1. Identify the mass action symbol and rename it with a
# TODO replace with pre-existing TemplateModel.get_parameters_from_rate_law()
try:
parameter = old_template.get_mass_action_symbol()
except ValueError:
parameter = None
if parameter:
rate_law = rate_law.subs(
parameter.name,
sympy.Symbol(f"{parameter.name}_{params_count[parameter.name]}")
)
params_count[parameter.name] += 1 # increment this each time to keep unique
parameters = list(template_model.get_parameters_from_rate_law(rate_law))
for parameter in parameters:
# If a parameter is explicitly listed as one to preserve, then
# don't stratify it
if params_to_preserve is not None and parameter in params_to_preserve:
continue
# If we have an explicit stratification list then if something isn't
# in the list then don't stratify it.
elif params_to_stratify is not None and parameter not in params_to_stratify:
continue
# Otherwise we go ahead with stratification, i.e., in cases
# where nothing was said about parameter stratification or the
# parameter was listed explicitly to be stratified
else:
rate_law = rate_law.subs(
parameter,
sympy.Symbol(f"{parameter}_{params_count[parameter]}")
)
params_count[parameter] += 1 # increment this each time to keep unique

# Step 2. Rename symbols corresponding to compartments based on the new concepts
for old_controller, new_controller in zip(
Expand Down Expand Up @@ -457,6 +544,8 @@ def counts_to_dimensionless(tm: TemplateModel,
if p.units:
(coeff, exponent) = \
p.units.expression.args[0].as_coeff_exponent(counts_unit_symbol)
if isinstance(exponent, sympy.core.numbers.One):
exponent = 1
if exponent:
p.units.expression = \
SympyExprStr(p.units.expression.args[0] /
Expand Down
21 changes: 16 additions & 5 deletions mira/metamodel/template_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,21 +511,32 @@ def print_params_table(self):
print(tabulate.tabulate(rows, headers='firstrow'))

def get_concepts_map(self):
"""
Get a mapping from concept keys to concepts that
appear in this template models' templates.
"""Return a mapping from concept keys to concepts that
appear in this template model's templates.
"""
return {concept.get_key(): concept for concept in _iter_concepts(self)}

def get_concepts_name_map(self):
"""Return a mapping from concept names to concepts that
appear in this template model's templates.
"""
return {concept.name: concept for concept in _iter_concepts(self)}

def get_concept(self, name: str) -> Optional[Concept]:
"""Get the first concept that has the given name."""
"""Return the first concept that has the given name."""
names = self.get_concepts_by_name(name)
if names:
return names[0]
return None

def reset_base_names(self):
"""Reset the base names of all concepts in this model to the current name."""
for template in self.templates:
for concept in template.get_concepts():
concept._base_name = concept.name

def get_concepts_by_name(self, name: str) -> List[Concept]:
"""Get a list of all concepts that have the given name.
"""Return a list of all concepts that have the given name.
.. warning::
Expand Down
Loading

0 comments on commit c9771ef

Please sign in to comment.