Skip to content

Commit

Permalink
Implement handling of observables
Browse files Browse the repository at this point in the history
  • Loading branch information
bgyori committed Jul 13, 2023
1 parent 516256a commit 8c7207e
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 56 deletions.
22 changes: 19 additions & 3 deletions mira/metamodel/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import sympy

from .template_model import TemplateModel, Initial, Parameter
from .template_model import TemplateModel, Initial, Parameter, Observable
from .templates import *
from .units import Unit, dimensionless_units
from .utils import SympyExprStr
Expand Down Expand Up @@ -104,6 +104,8 @@ def stratify(
# e.g. unvaccinated -> vaccinated.

concept_map = template_model.get_concepts_map()
concept_names_map = template_model.get_concepts_name_map()
concept_names = set(concept_names_map.keys())

templates = []
params_count = Counter()
Expand All @@ -115,7 +117,6 @@ def stratify(
else:
exclude_concepts = set(concepts_to_preserve)
else:
concept_names = set(template_model.get_concepts_name_map())
if concepts_to_preserve is None:
exclude_concepts = concept_names - set(concepts_to_stratify)
else:
Expand Down Expand Up @@ -205,6 +206,21 @@ def stratify(
concept=new_concept, value=initial.value,
)

observables = {}
for observable_key, observable in template_model.observables.items():
syms = set(observable.expression.args[0].free_symbols)
expr = deepcopy(observable.expression.args[0])
for sym in (syms & concept_names) - exclude_concepts:
new_symbols = []
for stratum in strata:
new_concept = concept_names_map[sym].with_context(
do_rename=modify_names, **{key: stratum},
)
new_symbols.append(sympy.Symbol(new_concept.name))
expr = expr.args[0].subs(sym, sympy.Add(*new_symbols))
observables[observable_key] = deepcopy(observable)
observables[observable_key].expression = SympyExprStr(expr)

# Generate a conversion between each concept of each strata based on the network structure
for (source_stratum, target_stratum), concept in itt.product(structure, concept_map.values()):
if concept.name in exclude_concepts:
Expand All @@ -225,7 +241,7 @@ def stratify(
new_model = TemplateModel(templates=templates,
parameters=parameters,
initials=initials,
observables=deepcopy(template_model.observables),
observables=observables,
annotations=deepcopy(template_model.annotations),
time=template_model.time)
# We do this so that any subsequent stratifications will
Expand Down
Loading

0 comments on commit 8c7207e

Please sign in to comment.