Skip to content

Commit

Permalink
Merge pull request #210 from indralab/semantics
Browse files Browse the repository at this point in the history
Implement additional parameter logic for semantics reconstitution
  • Loading branch information
bgyori authored Jul 14, 2023
2 parents 02fb6cd + c4ae22a commit add147d
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 3 deletions.
37 changes: 36 additions & 1 deletion mira/sources/askenet/flux_span.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@

def reproduce_ode_semantics(flux_span):
"""Reproduce ODE semantics from a flux span."""

# First we make the original template model
tm = template_model_from_askenet_json(flux_span)

# We grab the relevant pieces of the flux span structure
semantics = flux_span['semantics']
span_1 = semantics['span'][0]
span_2 = semantics['span'][1]
Expand All @@ -27,6 +31,9 @@ def reproduce_ode_semantics(flux_span):
tm_2 = template_model_from_askenet_json(model_2)
sem_1 = model_1['semantics']['ode']
sem_2 = model_2['semantics']['ode']

# Make sure we have forward and reverse mappings between
# the original models and the stratified model
map_1 = dict(span_1['map'])
map_2 = dict(span_2['map'])
reverse_map_1 = defaultdict(list)
Expand All @@ -36,10 +43,16 @@ def reproduce_ode_semantics(flux_span):
for k, v in map_2.items():
reverse_map_2[v].append(k)

# We need to be able to look up templates by ID
template_map = {t.name: t for t in tm.templates}
template_map_1 = {t.name: t for t in tm_1.templates}
template_map_2 = {t.name: t for t in tm_2.templates}

# To handle non-standard transitions we need to also
# have a transition mapping
transition_map_1 = {t['id']: t for t in model_1['model']['transitions']}
transition_map_2 = {t['id']: t for t in model_2['model']['transitions']}

# If we are missing semantics, we have to make them up
if not sem_1:
set_semantics(tm_1, '1')
Expand All @@ -56,12 +69,32 @@ def reproduce_ode_semantics(flux_span):
# Find what this template is mapped to in the original models
mapped_1 = map_1[template.name]
mapped_2 = map_2[template.name]
# Find the template in the original models - only one of these exists

# Find the transitions and templates in the original models
transition_1 = transition_map_1[mapped_1]
transition_2 = transition_map_2[mapped_2]
template_1 = template_map_1.get(mapped_1)
template_2 = template_map_2.get(mapped_2)

# This happens if the transition is not a standard transition
# but something like old -> old or old + young -> old + young.
# We can pragmatically handle these as cases where only an
# extra parameter needs to be introduced and applied to the
# joint rate law.
if not template_1:
extra_param = f'p_1_{transition_1["id"]}'
elif not template_2:
extra_param = f'p_2_{transition_2["id"]}'
else:
extra_param = None

if extra_param:
all_parameters[extra_param] = Parameter(name=extra_param, value=1.0)

original_map = map_1 if template_1 else map_2
original_model = tm_1 if template_1 else tm_2
original_template = template_1 if template_1 else template_2

# Find the rate law components in the original model
rate_law = deepcopy(original_template.rate_law.args[0])
# Now we need to map states to new states
Expand All @@ -71,6 +104,8 @@ def reproduce_ode_semantics(flux_span):
original_concept = original_map[concept_name]
rate_law = rate_law.subs(sympy.Symbol(original_concept),
sympy.Symbol(concept_name))
if extra_param:
rate_law *= sympy.Symbol(extra_param)
template.rate_law = SympyExprStr(rate_law)

# Deal with observables
Expand Down
2 changes: 1 addition & 1 deletion tests/test_flux_span.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,5 @@ def test_flux_span_ode_semantics():
flux_span = json.load(fh)
tm = reproduce_ode_semantics(flux_span)
assert len(tm.templates) == 10
assert len(tm.parameters) == 4
assert len(tm.parameters) == 11
assert all(t.rate_law for t in tm.templates)
2 changes: 1 addition & 1 deletion tests/test_model_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,6 @@ def test_flux_span_endpoint(self):
flux_span_tm_json = response.json()
flux_span_tm = TemplateModel.from_json(flux_span_tm_json)
assert len(flux_span_tm.templates) == 10
assert len(flux_span_tm.parameters) == 4
assert len(flux_span_tm.parameters) == 11
assert all(t.rate_law for t in flux_span_tm.templates)

0 comments on commit add147d

Please sign in to comment.