From f26b909e3706c0a4519ee94a07fa8989fdb85b15 Mon Sep 17 00:00:00 2001 From: Ben Gyori Date: Wed, 18 Sep 2024 13:07:21 -0400 Subject: [PATCH] Simplify API for specifying observable patterns --- mira/metamodel/ops.py | 36 +++++++++++++-------------------- notebooks/viz_strat_petri.ipynb | 2 +- tests/test_ops.py | 12 +++++------ 3 files changed, 21 insertions(+), 29 deletions(-) diff --git a/mira/metamodel/ops.py b/mira/metamodel/ops.py index 4252f9557..a4b209d43 100644 --- a/mira/metamodel/ops.py +++ b/mira/metamodel/ops.py @@ -20,7 +20,8 @@ "aggregate_parameters", "get_term_roles", "counts_to_dimensionless", - "deactivate_templates" + "deactivate_templates", + "add_observable_pattern", ] @@ -748,9 +749,9 @@ def deactivate_templates( def add_observable_pattern( template_model: TemplateModel, - concept_pattern: Concept, name: str, - refinement_func=None, + identifiers: Mapping = None, + context: Mapping = None, ): """Add an observable for a pattern of concepts. @@ -758,33 +759,24 @@ def add_observable_pattern( ---------- template_model : A template model. - concept_pattern : - A concept pattern. name : The name of the observable. - - Returns - ------- - : - A template model with the observable added. + identifiers : + Identifiers serving as a pattern for concepts to observe. + context : + Context serving as a pattern for concepts to observe. """ observable_concepts = [] - identifiers = set(concept_pattern.identifiers.items()) - contexts = set(concept_pattern.context.items()) - name_only = (not identifiers) and (not contexts) + identifiers = set(identifiers.items() if identifiers else {}) + contexts = set(context.items() if context else {}) for key, concept in template_model.get_concepts_map().items(): - if name_only: - if concept.name == concept_pattern.name: + if (not identifiers) or identifiers.issubset( + set(concept.identifiers.items())): + if (not contexts) or contexts.issubset( + set(concept.context.items())): observable_concepts.append(concept) - else: - if (not identifiers) or identifiers.issubset( - set(concept.identifiers.items())): - if (not contexts) or contexts.issubset( - set(concept.context.items())): - observable_concepts.append(concept) obs = get_observable_for_concepts(observable_concepts, name) template_model.observables[name] = obs - return template_model def get_observable_for_concepts(concepts: List[Concept], name: str): diff --git a/notebooks/viz_strat_petri.ipynb b/notebooks/viz_strat_petri.ipynb index 775d7fd87..78c33a447 100644 --- a/notebooks/viz_strat_petri.ipynb +++ b/notebooks/viz_strat_petri.ipynb @@ -329,7 +329,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.10.14" } }, "nbformat": 4, diff --git a/tests/test_ops.py b/tests/test_ops.py index 867d9731d..7276e131c 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -607,12 +607,12 @@ def test_add_observable_pattern(): tm = TemplateModel(templates=templates, parameters={'alpha': Parameter(name='alpha', value=0.1)}) tm = stratify(tm, key='age', strata=['young', 'old'], structure=[]) - add_observable_pattern(tm, Concept(name='A', identifiers={'ido': '0000514'}), 'obs') - assert 'obs' in tm.observables - obs = tm.observables['obs'] + add_observable_pattern(tm, name='A', identifiers={'ido': '0000514'}) + assert 'A' in tm.observables + obs = tm.observables['A'] assert obs.expression.args[0] == sympy.Symbol('A_old') + sympy.Symbol('A_young') - add_observable_pattern(tm, Concept(name='young', context={'age': 'young'}), 'obs2') - assert 'obs2' in tm.observables - obs = tm.observables['obs2'] + add_observable_pattern(tm, 'young', context={'age': 'young'}) + assert 'young' in tm.observables + obs = tm.observables['young'] assert obs.expression.args[0] == sympy.Symbol('A_young') + sympy.Symbol('B_young')