diff --git a/mira/metamodel/ops.py b/mira/metamodel/ops.py index 3a8f2051d..609c4f533 100644 --- a/mira/metamodel/ops.py +++ b/mira/metamodel/ops.py @@ -258,6 +258,98 @@ def stratify( all_param_mappings[old_param].add(new_param) templates.append(stratified_template) + # Handle initial values and expressions depending on different + # criteria + initials = {} + param_value_mappings = {} + for initial_key, initial in template_model.initials.items(): + # We need to keep track of whether we stratified any parameters in + # the expression for this initial and if the parameter is being + # replaced by multiple stratified parameters + any_param_stratified = False + param_replacements = defaultdict(set) + + for stratum_idx, stratum in enumerate(strata): + # Figure out if the concept for this initial is one that we + # need to stratify or not + if (exclude_concepts and initial.concept.name in exclude_concepts) or \ + (concepts_to_preserve and initial.concept.name in concepts_to_preserve): + # Just make a copy of the original initial concept + new_concept = deepcopy(initial.concept) + concept_stratified = False + else: + # We create a new concept for the given stratum + new_concept = initial.concept.with_context( + do_rename=modify_names, + curie_to_name_map=strata_curie_to_name, + **{key: stratum}, + ) + concept_stratified = True + # Now we may have to rewrite the expression so that we can + # update for stratified parameters so we make a copy and figure + # out what parameters are in the expression + new_expression = deepcopy(initial.expression) + init_expr_params = template_model.get_parameters_from_expression( + new_expression.args[0] + ) + template_strata = [stratum if + param_renaming_uses_strata_names else stratum_idx] + for parameter in init_expr_params: + # If a parameter is explicitly listed as one to preserve, then + # don't stratify it + if params_to_preserve is not None and parameter in params_to_preserve: + continue + # If we have an explicit stratification list then if something isn't + # in the list then don't stratify it. + elif params_to_stratify is not None and parameter not in params_to_stratify: + continue + # Otherwise we go ahead with stratification, i.e., in cases + # where nothing was said about parameter stratification or the + # parameter was listed explicitly to be stratified + else: + # We create a new parameter symbol for the given stratum + param_suffix = '_'.join([str(s) for s in template_strata]) + new_param = f'{parameter}_{param_suffix}' + any_param_stratified = True + all_param_mappings[parameter].add(new_param) + # We need to update the new, stratified parameter's value + # to be the original parameter's value divided by the number + # of strata + param_value_mappings[new_param] = \ + template_model.parameters[parameter].value / len(strata) + # If the concept is not stratified then we have to replace + # the original parameter with the sum of stratified ones + # so we just keep track of that in a set + if not concept_stratified: + param_replacements[parameter].add(new_param) + # Otherwise we have to rewrite the expression to use the + # new parameter as replacement for the original one + else: + new_expression = new_expression.subs(parameter, + sympy.Symbol(new_param)) + + # If we stratified any parameters in the expression then we have + # to update the initial value expression to reflect that + if any_param_stratified: + if param_replacements: + for orig_param, new_params in param_replacements.items(): + new_expression = new_expression.subs( + orig_param, + sympy.Add(*[sympy.Symbol(np) for np in new_params]) + ) + new_initial = new_expression + # Otherwise we can just use the original expression, except if the + # concept was stratified, then we have to divide the initial + # expression into as many parts as there are strata + else: + if concept_stratified: + new_initial = SympyExprStr(new_expression.args[0] / len(strata)) + else: + new_initial = new_expression + + initials[new_concept.name] = \ + Initial(concept=new_concept, expression=new_initial) + parameters = {} for parameter_key, parameter in template_model.parameters.items(): if parameter_key not in all_param_mappings: @@ -273,26 +365,10 @@ def stratify( for stratified_param in all_param_mappings[parameter_key]: d = deepcopy(parameter) d.name = stratified_param + if stratified_param in param_value_mappings: + d.value = param_value_mappings[stratified_param] parameters[stratified_param] = d - # Create new initial values for each of the strata - # of the original compartments, copied from the initial - # values of the original compartments - initials = {} - for initial_key, initial in template_model.initials.items(): - if initial.concept.name in exclude_concepts: - initials[initial.concept.name] = deepcopy(initial) - continue - for stratum in strata: - new_concept = initial.concept.with_context( - do_rename=modify_names, - curie_to_name_map=strata_curie_to_name, - **{key: stratum}, - ) - initials[new_concept.name] = Initial( - concept=new_concept, expression=SympyExprStr(initial.expression.args[0] / len(strata)) - ) - observables = {} for observable_key, observable in template_model.observables.items(): syms = {s.name for s in observable.expression.args[0].free_symbols} diff --git a/mira/metamodel/template_model.py b/mira/metamodel/template_model.py index 2c4fade32..fb4874d2d 100644 --- a/mira/metamodel/template_model.py +++ b/mira/metamodel/template_model.py @@ -372,41 +372,60 @@ class TemplateModel(BaseModel): "Note that all annotations are optional.", ) - def get_parameters_from_rate_law(self, rate_law) -> Set[str]: - """Given a rate law, find its elements that are model parameters. + def get_parameters_from_expression(self, expression) -> Set[str]: + """Given a symbolic expression, find its elements that are model parameters. - Rate laws consist of some combination of participants, rate parameters - and potentially other factors. This function finds those elements of - rate laws that are rate parameters. + Expressions such as rate laws consist of some combination of participants, + rate parameters and potentially other factors. This function finds those + elements of expressions that are rate parameters. Parameters ---------- - rate_law : sympy.Symbol | sympy.Expr - A sympy expression or symbol, whose names are extracted. + expression : sympy.Symbol | sympy.Expr + A sympy expression or symbol, whose parameters are extracted. Returns ------- : A set of parameter names (as strings). """ - if rate_law is None: + if expression is None: return set() params = set() - if isinstance(rate_law, sympy.Symbol): - if rate_law.name in self.parameters: + if isinstance(expression, sympy.Symbol): + if expression.name in self.parameters: # add the string name to the set - params.add(rate_law.name) + params.add(expression.name) # There are many sympy classes that have args that can occur here # so it's better to check for the presence of args - elif not hasattr(rate_law, "args"): + elif not hasattr(expression, "args"): raise ValueError( - f"Rate law is of invalid type {type(rate_law)}: {rate_law}" + f"Rate law is of invalid type {type(expression)}: {expression}" ) else: - for arg in rate_law.args: - params |= self.get_parameters_from_rate_law(arg) + for arg in expression.args: + params |= self.get_parameters_from_expression(arg) return params + def get_parameters_from_rate_law(self, rate_law) -> Set[str]: + """Given a rate law, find its elements that are model parameters. + + Rate laws consist of some combination of participants, rate parameters + and potentially other factors. This function finds those elements of + rate laws that are rate parameters. + + Parameters + ---------- + rate_law : sympy.Symbol | sympy.Expr + A sympy expression or symbol, whose parameters are extracted. + + Returns + ------- + : + A set of parameter names (as strings). + """ + return self.get_parameters_from_expression(rate_law) + def update_parameters(self, parameter_dict): """ Update parameter values. diff --git a/tests/test_ops.py b/tests/test_ops.py index 585daab74..f40f199b8 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -636,3 +636,44 @@ def test_add_observable_pattern(): assert 'young' in tm.observables obs = tm.observables['young'] assert obs.expression.args[0] == sympy.Symbol('A_young') + sympy.Symbol('B_young') + + +def test_stratify_initials_parameters(): + s = Concept(name='S') + t = NaturalDegradation(subject=s, rate_law=sympy.Symbol('alpha') * + sympy.Symbol(s.name)) + S0 = Initial(concept=s, expression=sympy.Symbol('S0')) + tm = TemplateModel(templates=[t], + parameters={'alpha': Parameter(name='alpha', value=0.1), + 'S0': Parameter(name='S0', value=1000)}, + initials={'S': S0}) + tm1 = stratify(tm, key='age', strata=['young', 'old'], structure=[], + param_renaming_uses_strata_names=True) + assert 'S_young' in tm1.initials + assert tm1.initials['S_young'].expression.args[0] == sympy.Symbol('S0_young') + assert 'S_old' in tm1.initials + assert tm1.initials['S_old'].expression.args[0] == sympy.Symbol('S0_old') + assert 'S0_young' in tm1.parameters + assert tm1.parameters['S0_young'].value == 500 + assert 'S0_old' in tm1.parameters + assert tm1.parameters['S0_old'].value == 500 + + tm2 = stratify(tm, key='age', strata=['young', 'old'], structure=[], + param_renaming_uses_strata_names=True, + params_to_preserve={'S0'}) + assert 'S_young' in tm2.initials + assert tm2.initials['S_young'].expression.args[0] == sympy.Symbol('S0') / 2 + assert 'S_old' in tm2.initials + assert tm2.initials['S_old'].expression.args[0] == sympy.Symbol('S0') / 2 + assert 'S0' in tm2.parameters + assert tm2.parameters['S0'].value == 1000 + + tm3 = stratify(tm, key='age', strata=['young', 'old'], structure=[], + param_renaming_uses_strata_names=True, + concepts_to_preserve={'S'}) + assert set(tm3.initials) == {'S'} + assert tm3.initials['S'].expression.args[0] == \ + sympy.Symbol('S0_old') + sympy.Symbol('S0_young') + assert set(tm3.parameters) == {'alpha', 'S0_old', 'S0_young'} + assert tm3.parameters['S0_old'].value == 500 + assert tm3.parameters['S0_young'].value == 500