Sympy code to AMR pipeline (#5846)

mwdchang authored Dec 16, 2024
1 parent c405cb5 commit 7da2180
Showing 10 changed files with 262 additions and 26 deletions.
9 changes: 9 additions & 0 deletions packages/client/hmi-client/src/temp/Equations.vue
@@ -12,6 +12,7 @@
<h4>Cleaned LaTeX => SymPy equation strings + AMR json</h4>
<div style="display: flex; flex-direction: row">
<textarea ref="latex2"></textarea>
<pre class="result" ref="resultCode"></pre>
<pre class="result" ref="resultSympy"></pre>
<pre class="result" ref="resultAmr"></pre>
</div>
@@ -28,6 +29,7 @@ import API from '@/api/api';
const latex1 = ref<HTMLTextAreaElement | null>(null);
const latex2 = ref<HTMLTextAreaElement | null>(null);
const result1 = ref<HTMLDivElement | null>(null);
const resultCode = ref<HTMLDivElement | null>(null);
const resultSympy = ref<HTMLDivElement | null>(null);
const resultAmr = ref<HTMLDivElement | null>(null);
@@ -52,6 +54,9 @@ const latex2amr = async () => {
const inputStr = latex2.value?.value || '[]';
const equations = JSON.parse(inputStr);
if (resultCode.value) {
resultCode.value.innerHTML = 'processing...';
}
if (resultSympy.value) {
resultSympy.value.innerHTML = 'processing...';
}
@@ -62,6 +67,10 @@ const latex2amr = async () => {
const resp = await API.post('/mira/latex-to-amr', equations);
const respData = resp.data;
if (resultCode.value) {
resultCode.value.innerHTML = '';
resultCode.value.innerHTML = respData.response.sympyCode;
}
if (resultSympy.value) {
resultSympy.value.innerHTML = '';
resultSympy.value.innerHTML = JSON.stringify(respData.response.sympyExprs, null, 2);
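For reference, a minimal sketch (not part of the commit) of the payload this temp page posts to /mira/latex-to-amr and the response shape it reads back (sympyCode, sympyExprs, amr), inferred from the component above and the sympy_to_amr task later in this commit; all values are illustrative.

```
# Illustrative only: request/response shape for POST /mira/latex-to-amr,
# inferred from Equations.vue and the sympy_to_amr task in this commit.
request_body = [
    "\\frac{d S(t)}{d t} = -beta * S(t) * I(t) + b - m * S(t)"
]

response_body = {
    "response": {
        "sympyCode": "import sympy\n# ... generated SymPy code ...",
        "sympyExprs": ["Eq(Derivative(S(t), t), -beta*S(t)*I(t) + b - m*S(t))"],
        "amr": {},  # petrinet AMR JSON produced by the MIRA task
    }
}
```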
26 changes: 26 additions & 0 deletions packages/gollm/gollm_openai/prompts/latex_to_sympy.py
@@ -0,0 +1,26 @@
LATEX_TO_SYMPY_PROMPT = r"""
You are a helpful assistant who is an expert in writing mathematical expressions in LaTeX and in the Python package SymPy.
Here is an example LaTeX input:
["\frac{{d S(t)}}{{d t}} = -beta * S(t) * I(t) + b - m * S(t)"]
You should return this output in SymPy:
```
import sympy
# Define time variable
t = sympy.symbols("t")
# Define time-dependent variables
S, I = sympy.symbols("S I", cls = sympy.Function)
# Define constant parameters
beta, b, m = sympy.symbols("beta b m")
equation_output = [sympy.Eq(S(t).diff(t), -beta * S(t) * I(t) + b - m * S(t))]
```
Now, do the same for this LaTeX input:
{latex_equations}
Respond with just the code and nothing else; do not surround the answer in ```
Answer:
"""
29 changes: 29 additions & 0 deletions packages/gollm/gollm_openai/tool_utils.py
@@ -24,6 +24,8 @@
MODEL_METADATA_COMPARE_PROMPT,
MODEL_METADATA_COMPARE_GOAL_PROMPT
)
from gollm_openai.prompts.latex_to_sympy import LATEX_TO_SYMPY_PROMPT

from openai import OpenAI
from openai.types.chat.completion_create_params import ResponseFormat
from typing import List
@@ -77,6 +79,33 @@ def get_image_format_string(image_format: str) -> str:
return format_strings.get(image_format.lower())


def latex_to_sympy(equations: List[str]) -> str:
print("latex to sympy ...")

prompt = LATEX_TO_SYMPY_PROMPT.format(
latex_equations=",\n".join(equations)
)

client = OpenAI()
output = client.chat.completions.create(
model=GPT_MODEL,
frequency_penalty=0,
max_tokens=4000,
presence_penalty=0,
seed=905,
temperature=0,
top_p=1,
response_format={
"type": "text"
},
messages=[
{"role": "user", "content": prompt},
]
)
return output.choices[0].message.content



def equations_cleanup(equations: List[str]) -> dict:
print("Reformatting equations...")

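A minimal usage sketch (not part of the commit) for the new helper, assuming OPENAI_API_KEY is set in the environment and GPT_MODEL is configured as elsewhere in this module; the second equation is illustrative.

```
from gollm_openai.tool_utils import latex_to_sympy

equations = [
    r"\frac{d S(t)}{d t} = -beta * S(t) * I(t) + b - m * S(t)",
    r"\frac{d I(t)}{d t} = beta * S(t) * I(t) - g * I(t)",  # illustrative
]

# Returns a single string of executable SymPy code defining `equation_output`,
# which the MIRA sympy_to_amr task below will exec.
sympy_code = latex_to_sympy(equations)
print(sympy_code)
```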
1 change: 1 addition & 0 deletions packages/gollm/setup.py
@@ -24,6 +24,7 @@
"gollm:embedding=tasks.embedding:main",
"gollm:enrich_amr=tasks.enrich_amr:main",
"gollm:enrich_dataset=tasks.enrich_dataset:main",
"gollm:latex_to_sympy=tasks.latex_to_sympy:main",
"gollm:equations_cleanup=tasks.equations_cleanup:main",
"gollm:equations_from_image=tasks.equations_from_image:main",
"gollm:generate_response=tasks.generate_response:main",
39 changes: 39 additions & 0 deletions packages/gollm/tasks/latex_to_sympy.py
@@ -0,0 +1,39 @@
import sys
from gollm_openai.tool_utils import latex_to_sympy

from taskrunner import TaskRunnerInterface
import traceback


def cleanup():
pass


def main():
exitCode = 0
try:
taskrunner = TaskRunnerInterface(description="Converting latex to sympy")
taskrunner.on_cancellation(cleanup)

input_dict = taskrunner.read_input_dict_with_timeout()

taskrunner.log("Sending request to OpenAI API")
response = latex_to_sympy(equations=input_dict)
taskrunner.log("Received response from OpenAI API")

taskrunner.write_output_dict_with_timeout({"response": response})

except Exception as e:
sys.stderr.write(f"Error: {str(e)}\n")
sys.stderr.write(traceback.format_exc())
sys.stderr.flush()
exitCode = 1

taskrunner.log("Shutting down")
taskrunner.shutdown()
sys.exit(exitCode)


if __name__ == "__main__":
main()
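For reference, a sketch (not part of the commit) of the task's input/output contract: the input is the JSON array of LaTeX strings forwarded by the HMI server, and the output wraps the generated code under a single response key, matching LatexToSympyResponseHandler.Response below.

```
# Illustrative only: the gollm:latex_to_sympy task contract.
task_input = [
    "\\frac{d S(t)}{d t} = -beta * S(t) * I(t) + b - m * S(t)"
]

task_output = {
    "response": "import sympy\n# ... generated SymPy code defining equation_output ..."
}
```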
3 changes: 2 additions & 1 deletion packages/mira/setup.py
@@ -12,7 +12,8 @@
"mira_task:mdl_to_stockflow=tasks.mdl_to_stockflow:main",
"mira_task:stella_to_stockflow=tasks.stella_to_stockflow:main",
"mira_task:amr_to_mmt=tasks.amr_to_mmt:main",
"mira_task:generate_model_latex=tasks.generate_model_latex:main",
"mira_task:generate_model_latex=tasks.generate_model_latex:main",
"mira_task:sympy_to_amr=tasks.sympy_to_amr:main",
],
},
python_requires=">=3.10",
56 changes: 56 additions & 0 deletions packages/mira/tasks/sympy_to_amr.py
@@ -0,0 +1,56 @@
import sys
import json
import traceback
import sympy

from taskrunner import TaskRunnerInterface
from mira.sources.sympy_ode import template_model_from_sympy_odes
from mira.modeling.amr.petrinet import template_model_to_petrinet_json

def cleanup():
pass

def main():
exitCode = 0

try:
taskrunner = TaskRunnerInterface(description="Sympy to AMR")
taskrunner.on_cancellation(cleanup)

sympy_code = taskrunner.read_input_str_with_timeout()
taskrunner.log("== input code")
taskrunner.log(sympy_code)
taskrunner.log("")

exec_globals = {}
exec(sympy_code, exec_globals)  # output should be placed into "equation_output"
taskrunner.log("== equations")
taskrunner.log(str(exec_globals["equation_output"]))
taskrunner.log("")

# SymPy to MMT
mmt = template_model_from_sympy_odes(exec_globals["equation_output"])

# MMT to AMR
amr_json = template_model_to_petrinet_json(mmt)

# Gather results
response = {}
response["amr"] = amr_json
response["sympyCode"] = sympy_code
response["sympyExprs"] = list(map(lambda x: str(x), globals["equation_output"]))

taskrunner.log("Sympy to AMR conversion succeeded")
taskrunner.write_output_dict_with_timeout({"response": response})
except Exception as e:
sys.stderr.write(f"Error: {str(e)}\n")
sys.stderr.write(traceback.format_exc())
sys.stderr.flush()
exitCode = 1

taskrunner.log("Shutting down")
taskrunner.shutdown()
sys.exit(exitCode)

if __name__ == "__main__":
main()
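A minimal sketch (not part of the commit) of the kind of code string this task expects to exec: it must define equation_output as a list of sympy.Eq objects, mirroring the example in the gollm prompt, before being handed to template_model_from_sympy_odes and template_model_to_petrinet_json.

```
# Illustrative only: a code string of the form the gollm task generates.
sympy_code = """
import sympy

t = sympy.symbols("t")
S, I = sympy.symbols("S I", cls=sympy.Function)
beta, b, m = sympy.symbols("beta b m")

equation_output = [sympy.Eq(S(t).diff(t), -beta * S(t) * I(t) + b - m * S(t))]
"""

exec_globals = {}
exec(sympy_code, exec_globals)

# The task then converts equation_output into a MIRA template model and a
# petrinet AMR, returning {"amr": ..., "sympyCode": ..., "sympyExprs": [...]}.
print(exec_globals["equation_output"])
```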
@@ -49,9 +49,11 @@
import software.uncharted.terarium.hmiserver.service.tasks.AMRToMMTResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.GenerateModelLatexResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.LatexToAMRResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.LatexToSympyResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.MdlToStockflowResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.SbmlToPetrinetResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.StellaToStockflowResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.SympyToAMRResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.TaskService;
import software.uncharted.terarium.hmiserver.utils.Messages;
import software.uncharted.terarium.hmiserver.utils.rebac.Schema;
@@ -269,47 +271,71 @@ public ResponseEntity<JsonNode> generateModelLatex(@RequestBody final JsonNode m
}
)
public ResponseEntity<JsonNode> latexToAMR(@RequestBody final String latex) {
// create request:
final TaskRequest req = new TaskRequest();
req.setType(TaskType.MIRA);
////////////////////////////////////////////////////////////////////////////////
// 1. Convert the LaTeX string to a Python SymPy code string
//
// Note this is a gollm string => string task
////////////////////////////////////////////////////////////////////////////////
final TaskRequest latexToSympyRequest = new TaskRequest();
final TaskResponse latexToSympyResponse;
String code = null;

try {
req.setInput(latex.getBytes());
latexToSympyRequest.setType(TaskType.GOLLM);
latexToSympyRequest.setInput(latex.getBytes());
latexToSympyRequest.setScript(LatexToSympyResponseHandler.NAME);
latexToSympyRequest.setUserId(currentUserService.get().getId());
latexToSympyResponse = taskService.runTaskSync(latexToSympyRequest);

final JsonNode node = objectMapper.readValue(latexToSympyResponse.getOutput(), JsonNode.class);
code = node.get("response").asText();
} catch (final TimeoutException e) {
log.warn("Timeout while waiting for task response", e);
throw new ResponseStatusException(HttpStatus.SERVICE_UNAVAILABLE, messages.get("task.gollm.timeout"));
} catch (final InterruptedException e) {
log.warn("Interrupted while waiting for task response", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.gollm.interrupted"));
} catch (final ExecutionException e) {
log.error("Error while waiting for task response", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.gollm.execution-failure"));
} catch (final Exception e) {
log.error("Unable to serialize input", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("generic.io-error.write"));
log.error("Unexpected error", e);
}

req.setScript(LatexToAMRResponseHandler.NAME);
req.setUserId(currentUserService.get().getId());
if (code == null) {
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.gollm.execution-failure"));
}

////////////////////////////////////////////////////////////////////////////////
// 2. Convert the Python SymPy code string to an AMR
//
// This returns the AMR JSON plus intermediate data representations for debugging
////////////////////////////////////////////////////////////////////////////////
final TaskRequest sympyToAMRRequest = new TaskRequest();
final TaskResponse sympyToAMRResponse;
final JsonNode response;

// send the request
final TaskResponse resp;
try {
resp = taskService.runTaskSync(req);
} catch (final JsonProcessingException e) {
log.error("Unable to serialize input", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.mira.json-processing"));
sympyToAMRRequest.setType(TaskType.MIRA);
sympyToAMRRequest.setInput(code.getBytes());
sympyToAMRRequest.setScript(SympyToAMRResponseHandler.NAME);
sympyToAMRRequest.setUserId(currentUserService.get().getId());
sympyToAMRResponse = taskService.runTaskSync(sympyToAMRRequest);
response = objectMapper.readValue(sympyToAMRResponse.getOutput(), JsonNode.class);
return ResponseEntity.ok().body(response);
} catch (final TimeoutException e) {
log.warn("Timeout while waiting for task response", e);
throw new ResponseStatusException(HttpStatus.SERVICE_UNAVAILABLE, messages.get("task.mira.timeout"));
} catch (final InterruptedException e) {
log.warn("Interrupted while waiting for task response", e);
throw new ResponseStatusException(HttpStatus.UNPROCESSABLE_ENTITY, messages.get("task.mira.interrupted"));
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.mira.interrupted"));
} catch (final ExecutionException e) {
log.error("Error while waiting for task response", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.mira.execution-failure"));
} catch (final Exception e) {
log.error("Unexpected error", e);
}

final JsonNode latexResponse;
try {
latexResponse = objectMapper.readValue(resp.getOutput(), JsonNode.class);
} catch (final IOException e) {
log.error("Unable to deserialize output", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("generic.io-error.read"));
}

return ResponseEntity.ok().body(latexResponse);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("generic.io-error.read"));
}

@PostMapping("/convert-and-create-model")
@@ -0,0 +1,31 @@
package software.uncharted.terarium.hmiserver.service.tasks;

import lombok.Data;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import software.uncharted.terarium.hmiserver.models.task.TaskResponse;

@Component
@RequiredArgsConstructor
@Slf4j
public class LatexToSympyResponseHandler extends TaskResponseHandler {

public static final String NAME = "gollm:latex_to_sympy";

@Override
public String getName() {
return NAME;
}

@Data
public static class Response {

String response;
}

@Override
public TaskResponse onSuccess(final TaskResponse resp) {
return resp;
}
}
@@ -0,0 +1,18 @@
package software.uncharted.terarium.hmiserver.service.tasks;

import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

@Component
@RequiredArgsConstructor
@Slf4j
public class SympyToAMRResponseHandler extends TaskResponseHandler {

public static final String NAME = "mira_task:sympy_to_amr";

@Override
public String getName() {
return NAME;
}
}
