Skip to content

Commit

Permalink
Add missing "(t)" to all the model variables in the generated LaTeX t…
Browse files Browse the repository at this point in the history
…o align input/output math in the "Create model from equation" operator (#5926)
  • Loading branch information
liunelson authored Jan 2, 2025
1 parent 3e052cc commit 8e5d121
Showing 1 changed file with 40 additions and 21 deletions.
61 changes: 40 additions & 21 deletions packages/mira/tasks/generate_model_latex.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,53 +12,72 @@ def main():
exitCode = 0

try:
taskrunner = TaskRunnerInterface(description="Generate latex")
taskrunner = TaskRunnerInterface(description="Generate LaTeX")
taskrunner.on_cancellation(cleanup)

data = taskrunner.read_input_str_with_timeout()
amr = json.loads(data)
model = model_from_json(amr)

# =========================================
# Generate LaTeX code string from MMT model
# =========================================

odeterms = {var: 0 for var in model.get_concepts_name_map().keys()}

for t in model.templates:
if hasattr(t, "subject"):
var = t.subject.name
odeterms[var] -= t.rate_law.args[0]

if hasattr(t, "outcome"):
var = t.outcome.name
odeterms[var] += t.rate_law.args[0]

# Time
if model.time and model.time.name:
time = model.time.name
else:
time = "t"

t = sympy.Symbol(time)

# Construct Sympy equations
odesys = [
sympy.latex(sympy.Eq(sympy.diff(sympy.Function(var)(t), t), terms))
for var, terms in odeterms.items()
]

odesys = []
for var, terms in odeterms.items():
lhs = sympy.diff(sympy.Function(var)(t), t)

# Write (time-dependent) symbols with "(t)"
rhs = terms
for atom in terms.atoms(sympy.Symbol):
if str(atom) in odeterms.keys():
rhs = rhs.subs(atom, sympy.Function(str(atom))(t))

odesys.append(sympy.latex(sympy.Eq(lhs, rhs)))

# Observables
if len(model.observables) != 0:
obs_eqs = [
f"{{{obs.name}}}(t) = " + sympy.latex(obs.expression.args[0])
for obs in model.observables.values()
]

#add observables.
if len(model.observables) > 0:

# Write (time-dependent) symbols with "(t)"
obs_eqs = []
for obs in model.observables.values():
lhs = sympy.Function(obs.name)(t)
terms = obs.expression.args[0]
rhs = terms
for atom in terms.atoms(sympy.Symbol):
if str(atom) in odeterms.keys():
rhs = rhs.subs(atom, sympy.Function(str(atom))(t))
obs_eqs.append(sympy.latex(sympy.Eq(lhs, rhs)))

# Add observables
odesys += obs_eqs

# Reformat:
odesys = "\\begin{align} \n " + " \\\\ \n ".join([eq for eq in odesys]) + "\n\\end{align}"

# =========================================

taskrunner.write_output_dict_with_timeout({"response": odesys})
print("Generate latex succeeded")
print("Generate LaTeX succeeded")

except Exception as e:
sys.stderr.write(f"Error: {str(e)}\n")
Expand Down

0 comments on commit 8e5d121

Please sign in to comment.