Skip to content

Commit

Permalink
Simplify API for specifying observable patterns
Browse files Browse the repository at this point in the history
  • Loading branch information
bgyori committed Sep 20, 2024
1 parent d24064d commit f26b909
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 29 deletions.
36 changes: 14 additions & 22 deletions mira/metamodel/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
"aggregate_parameters",
"get_term_roles",
"counts_to_dimensionless",
"deactivate_templates"
"deactivate_templates",
"add_observable_pattern",
]


Expand Down Expand Up @@ -748,43 +749,34 @@ def deactivate_templates(

def add_observable_pattern(
template_model: TemplateModel,
concept_pattern: Concept,
name: str,
refinement_func=None,
identifiers: Mapping = None,
context: Mapping = None,
):
"""Add an observable for a pattern of concepts.
Parameters
----------
template_model :
A template model.
concept_pattern :
A concept pattern.
name :
The name of the observable.
Returns
-------
:
A template model with the observable added.
identifiers :
Identifiers serving as a pattern for concepts to observe.
context :
Context serving as a pattern for concepts to observe.
"""
observable_concepts = []
identifiers = set(concept_pattern.identifiers.items())
contexts = set(concept_pattern.context.items())
name_only = (not identifiers) and (not contexts)
identifiers = set(identifiers.items() if identifiers else {})
contexts = set(context.items() if context else {})
for key, concept in template_model.get_concepts_map().items():
if name_only:
if concept.name == concept_pattern.name:
if (not identifiers) or identifiers.issubset(
set(concept.identifiers.items())):
if (not contexts) or contexts.issubset(
set(concept.context.items())):
observable_concepts.append(concept)
else:
if (not identifiers) or identifiers.issubset(
set(concept.identifiers.items())):
if (not contexts) or contexts.issubset(
set(concept.context.items())):
observable_concepts.append(concept)
obs = get_observable_for_concepts(observable_concepts, name)
template_model.observables[name] = obs
return template_model


def get_observable_for_concepts(concepts: List[Concept], name: str):
Expand Down
2 changes: 1 addition & 1 deletion notebooks/viz_strat_petri.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.14"
}
},
"nbformat": 4,
Expand Down
12 changes: 6 additions & 6 deletions tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,12 +607,12 @@ def test_add_observable_pattern():
tm = TemplateModel(templates=templates,
parameters={'alpha': Parameter(name='alpha', value=0.1)})
tm = stratify(tm, key='age', strata=['young', 'old'], structure=[])
add_observable_pattern(tm, Concept(name='A', identifiers={'ido': '0000514'}), 'obs')
assert 'obs' in tm.observables
obs = tm.observables['obs']
add_observable_pattern(tm, name='A', identifiers={'ido': '0000514'})
assert 'A' in tm.observables
obs = tm.observables['A']
assert obs.expression.args[0] == sympy.Symbol('A_old') + sympy.Symbol('A_young')

add_observable_pattern(tm, Concept(name='young', context={'age': 'young'}), 'obs2')
assert 'obs2' in tm.observables
obs = tm.observables['obs2']
add_observable_pattern(tm, 'young', context={'age': 'young'})
assert 'young' in tm.observables
obs = tm.observables['young']
assert obs.expression.args[0] == sympy.Symbol('A_young') + sympy.Symbol('B_young')

0 comments on commit f26b909

Please sign in to comment.