diff --git a/packages/client/hmi-client/src/temp/Equations.vue b/packages/client/hmi-client/src/temp/Equations.vue
index 82d24f044b..d8c04f8a81 100644
--- a/packages/client/hmi-client/src/temp/Equations.vue
+++ b/packages/client/hmi-client/src/temp/Equations.vue
@@ -12,6 +12,7 @@
Cleaned LaTeX => SymPy equation strings + AMR json
@@ -28,6 +29,7 @@ import API from '@/api/api';
const latex1 = ref(null);
const latex2 = ref(null);
const result1 = ref(null);
+const resultCode = ref(null);
const resultSympy = ref(null);
const resultAmr = ref(null);
@@ -52,6 +54,9 @@ const latex2amr = async () => {
const inputStr = latex2.value?.value || '[]';
const equations = JSON.parse(inputStr);
+ if (resultCode.value) {
+ resultCode.value.innerHTML = 'processing...';
+ }
if (resultSympy.value) {
resultSympy.value.innerHTML = 'processing...';
}
@@ -62,6 +67,10 @@ const latex2amr = async () => {
const resp = await API.post('/mira/latex-to-amr', equations);
const respData = resp.data;
+ if (resultCode.value) {
+ resultCode.value.innerHTML = '';
+ resultCode.value.innerHTML = respData.response.sympyCode;
+ }
if (resultSympy.value) {
resultSympy.value.innerHTML = '';
resultSympy.value.innerHTML = JSON.stringify(respData.response.sympyExprs, null, 2);
diff --git a/packages/gollm/gollm_openai/prompts/latex_to_sympy.py b/packages/gollm/gollm_openai/prompts/latex_to_sympy.py
new file mode 100644
index 0000000000..024ffc98d8
--- /dev/null
+++ b/packages/gollm/gollm_openai/prompts/latex_to_sympy.py
@@ -0,0 +1,26 @@
+LATEX_TO_SYMPY_PROMPT="""
+You are a helpful assistant who is an expert in writing mathematical expressions in LaTeX code and the Python package SymPy.
+
+Here is an example input LaTeX
+["\frac{{d S(t)}}{{d t}} = -beta * S(t) * I(t) + b - m * S(t)"]
+
+You should return this output in SymPy:
+```
+import sympy
+# Define time variable
+t = sympy.symbols("t")
+# Define time-dependent variables
+S, I = sympy.symbols("S I", cls = sympy.Function)
+# Define constant parameters
+beta, b, m = sympy.symbols("beta b m")
+equation_output = [sympy.Eq(S(t).diff(t), -beta * S(t) * I(t) + b - m * S(t))]
+```
+
+Now, do the same for this LaTeX input:
+{latex_equations}
+
+Respond with just the code and nothing else, do not surround the answer in ```
+
+
+Answer:
+"""
diff --git a/packages/gollm/gollm_openai/tool_utils.py b/packages/gollm/gollm_openai/tool_utils.py
index 4ca980f72c..d84e58c37d 100644
--- a/packages/gollm/gollm_openai/tool_utils.py
+++ b/packages/gollm/gollm_openai/tool_utils.py
@@ -24,6 +24,8 @@
MODEL_METADATA_COMPARE_PROMPT,
MODEL_METADATA_COMPARE_GOAL_PROMPT
)
+from gollm_openai.prompts.latex_to_sympy import LATEX_TO_SYMPY_PROMPT
+
from openai import OpenAI
from openai.types.chat.completion_create_params import ResponseFormat
from typing import List
@@ -77,6 +79,33 @@ def get_image_format_string(image_format: str) -> str:
return format_strings.get(image_format.lower())
+def latex_to_sympy(equations: List[str]) -> str:
+ print("latex to sympy ...")
+
+ prompt = LATEX_TO_SYMPY_PROMPT.format(
+ latex_equations=",\n".join(equations)
+ )
+
+ client = OpenAI()
+ output = client.chat.completions.create(
+ model=GPT_MODEL,
+ frequency_penalty=0,
+ max_tokens=4000,
+ presence_penalty=0,
+ seed=905,
+ temperature=0,
+ top_p=1,
+ response_format={
+ "type": "text"
+ },
+ messages=[
+ {"role": "user", "content": prompt},
+ ]
+ )
+ return output.choices[0].message.content
+
+
+
def equations_cleanup(equations: List[str]) -> dict:
print("Reformatting equations...")
diff --git a/packages/gollm/setup.py b/packages/gollm/setup.py
index ae803bd414..717f0ec0dd 100644
--- a/packages/gollm/setup.py
+++ b/packages/gollm/setup.py
@@ -24,6 +24,7 @@
"gollm:embedding=tasks.embedding:main",
"gollm:enrich_amr=tasks.enrich_amr:main",
"gollm:enrich_dataset=tasks.enrich_dataset:main",
+ "gollm:latex_to_sympy=tasks.latex_to_sympy:main",
"gollm:equations_cleanup=tasks.equations_cleanup:main",
"gollm:equations_from_image=tasks.equations_from_image:main",
"gollm:generate_response=tasks.generate_response:main",
diff --git a/packages/gollm/tasks/latex_to_sympy.py b/packages/gollm/tasks/latex_to_sympy.py
new file mode 100644
index 0000000000..998fd23b6f
--- /dev/null
+++ b/packages/gollm/tasks/latex_to_sympy.py
@@ -0,0 +1,39 @@
+import sys
+from entities import EquationsCleanup
+from gollm_openai.tool_utils import latex_to_sympy
+
+from taskrunner import TaskRunnerInterface
+import traceback
+
+
+def cleanup():
+ pass
+
+
+def main():
+ exitCode = 0
+ try:
+ taskrunner = TaskRunnerInterface(description="Converting latex to sympy")
+ taskrunner.on_cancellation(cleanup)
+
+ input_dict = taskrunner.read_input_dict_with_timeout()
+
+ taskrunner.log("Sending request to OpenAI API")
+ response = latex_to_sympy(equations=input_dict)
+ taskrunner.log("Received response from OpenAI API")
+
+ taskrunner.write_output_dict_with_timeout({"response": response})
+
+ except Exception as e:
+ sys.stderr.write(f"Error: {str(e)}\n")
+ sys.stderr.write(traceback.format_exc())
+ sys.stderr.flush()
+ exitCode = 1
+
+ taskrunner.log("Shutting down")
+ taskrunner.shutdown()
+ sys.exit(exitCode)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/packages/mira/setup.py b/packages/mira/setup.py
index 605c273a7f..7c2754869e 100644
--- a/packages/mira/setup.py
+++ b/packages/mira/setup.py
@@ -12,7 +12,8 @@
"mira_task:mdl_to_stockflow=tasks.mdl_to_stockflow:main",
"mira_task:stella_to_stockflow=tasks.stella_to_stockflow:main",
"mira_task:amr_to_mmt=tasks.amr_to_mmt:main",
- "mira_task:generate_model_latex=tasks.generate_model_latex:main",
+ "mira_task:generate_model_latex=tasks.generate_model_latex:main",
+ "mira_task:sympy_to_amr=tasks.sympy_to_amr:main",
],
},
python_requires=">=3.10",
diff --git a/packages/mira/tasks/sympy_to_amr.py b/packages/mira/tasks/sympy_to_amr.py
new file mode 100644
index 0000000000..25de994f2f
--- /dev/null
+++ b/packages/mira/tasks/sympy_to_amr.py
@@ -0,0 +1,56 @@
+import sys
+import json
+import traceback
+import sympy
+
+from taskrunner import TaskRunnerInterface
+from mira.sources.sympy_ode import template_model_from_sympy_odes
+from mira.modeling.amr.petrinet import template_model_to_petrinet_json
+
+def cleanup():
+ pass
+
+def main():
+ exitCode = 0
+
+ try:
+ taskrunner = TaskRunnerInterface(description="Sympy to AMR")
+ taskrunner.on_cancellation(cleanup)
+
+ sympy_code = taskrunner.read_input_str_with_timeout()
+ taskrunner.log("== input code")
+ taskrunner.log(sympy_code)
+ taskrunner.log("")
+
+ globals = {}
+ exec(sympy_code, globals) # output should be in placed into "equation_output"
+ taskrunner.log("== equations")
+ taskrunner.log(globals["equation_output"])
+ taskrunner.log("")
+
+ # SymPy to MMT
+ mmt = template_model_from_sympy_odes(globals["equation_output"])
+
+ # MMT to AMR
+ amr_json = template_model_to_petrinet_json(mmt)
+
+ # Gather results
+ response = {}
+ response["amr"] = amr_json
+ response["sympyCode"] = sympy_code
+ response["sympyExprs"] = list(map(lambda x: str(x), globals["equation_output"]))
+
+ taskrunner.log(f"Sympy to AMR conversion succeeded")
+ taskrunner.write_output_dict_with_timeout({"response": response })
+ except Exception as e:
+ sys.stderr.write(f"Error: {str(e)}\n")
+ sys.stderr.write(traceback.format_exc())
+ sys.stderr.flush()
+ exitCode = 1
+
+ taskrunner.log("Shutting down")
+ taskrunner.shutdown()
+ sys.exit(exitCode)
+
+if __name__ == "__main__":
+ main()
diff --git a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/controller/mira/MiraController.java b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/controller/mira/MiraController.java
index bbbffa670b..39c69f1347 100644
--- a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/controller/mira/MiraController.java
+++ b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/controller/mira/MiraController.java
@@ -49,9 +49,11 @@
import software.uncharted.terarium.hmiserver.service.tasks.AMRToMMTResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.GenerateModelLatexResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.LatexToAMRResponseHandler;
+import software.uncharted.terarium.hmiserver.service.tasks.LatexToSympyResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.MdlToStockflowResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.SbmlToPetrinetResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.StellaToStockflowResponseHandler;
+import software.uncharted.terarium.hmiserver.service.tasks.SympyToAMRResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.TaskService;
import software.uncharted.terarium.hmiserver.utils.Messages;
import software.uncharted.terarium.hmiserver.utils.rebac.Schema;
@@ -269,47 +271,71 @@ public ResponseEntity generateModelLatex(@RequestBody final JsonNode m
}
)
public ResponseEntity latexToAMR(@RequestBody final String latex) {
- // create request:
- final TaskRequest req = new TaskRequest();
- req.setType(TaskType.MIRA);
+ ////////////////////////////////////////////////////////////////////////////////
+ // 1. Convert latex string to python sympy code string
+ //
+ // Note this is a gollm string => string task
+ ////////////////////////////////////////////////////////////////////////////////
+ final TaskRequest latexToSympyRequest = new TaskRequest();
+ final TaskResponse latexToSympyResponse;
+ String code = null;
try {
- req.setInput(latex.getBytes());
+ latexToSympyRequest.setType(TaskType.GOLLM);
+ latexToSympyRequest.setInput(latex.getBytes());
+ latexToSympyRequest.setScript(LatexToSympyResponseHandler.NAME);
+ latexToSympyRequest.setUserId(currentUserService.get().getId());
+ latexToSympyResponse = taskService.runTaskSync(latexToSympyRequest);
+
+ final JsonNode node = objectMapper.readValue(latexToSympyResponse.getOutput(), JsonNode.class);
+ code = node.get("response").asText();
+ } catch (final TimeoutException e) {
+ log.warn("Timeout while waiting for task response", e);
+ throw new ResponseStatusException(HttpStatus.SERVICE_UNAVAILABLE, messages.get("task.gollm.timeout"));
+ } catch (final InterruptedException e) {
+ log.warn("Interrupted while waiting for task response", e);
+ throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.gollm.interrupted"));
+ } catch (final ExecutionException e) {
+ log.error("Error while waiting for task response", e);
+ throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.gollm.execution-failure"));
} catch (final Exception e) {
- log.error("Unable to serialize input", e);
- throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("generic.io-error.write"));
+ log.error("Unexpected error", e);
}
- req.setScript(LatexToAMRResponseHandler.NAME);
- req.setUserId(currentUserService.get().getId());
+ if (code == null) {
+ throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.gollm.execution-failure"));
+ }
+
+ ////////////////////////////////////////////////////////////////////////////////
+ // 2. Convert python sympy code string to amr
+ //
+ // This returns the AMR json, and intermediate data representations for debugging
+ ////////////////////////////////////////////////////////////////////////////////
+ final TaskRequest sympyToAMRRequest = new TaskRequest();
+ final TaskResponse sympyToAMRResponse;
+ final JsonNode response;
- // send the request
- final TaskResponse resp;
try {
- resp = taskService.runTaskSync(req);
- } catch (final JsonProcessingException e) {
- log.error("Unable to serialize input", e);
- throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.mira.json-processing"));
+ sympyToAMRRequest.setType(TaskType.MIRA);
+ sympyToAMRRequest.setInput(code.getBytes());
+ sympyToAMRRequest.setScript(SympyToAMRResponseHandler.NAME);
+ sympyToAMRRequest.setUserId(currentUserService.get().getId());
+ sympyToAMRResponse = taskService.runTaskSync(sympyToAMRRequest);
+ response = objectMapper.readValue(sympyToAMRResponse.getOutput(), JsonNode.class);
+ return ResponseEntity.ok().body(response);
} catch (final TimeoutException e) {
log.warn("Timeout while waiting for task response", e);
throw new ResponseStatusException(HttpStatus.SERVICE_UNAVAILABLE, messages.get("task.mira.timeout"));
} catch (final InterruptedException e) {
log.warn("Interrupted while waiting for task response", e);
- throw new ResponseStatusException(HttpStatus.UNPROCESSABLE_ENTITY, messages.get("task.mira.interrupted"));
+ throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.mira.interrupted"));
} catch (final ExecutionException e) {
log.error("Error while waiting for task response", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.mira.execution-failure"));
+ } catch (final Exception e) {
+ log.error("Unexpected error", e);
}
-
- final JsonNode latexResponse;
- try {
- latexResponse = objectMapper.readValue(resp.getOutput(), JsonNode.class);
- } catch (final IOException e) {
- log.error("Unable to deserialize output", e);
- throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("generic.io-error.read"));
- }
-
- return ResponseEntity.ok().body(latexResponse);
+ throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("generic.io-error.read"));
}
@PostMapping("/convert-and-create-model")
diff --git a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/service/tasks/LatexToSympyResponseHandler.java b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/service/tasks/LatexToSympyResponseHandler.java
new file mode 100644
index 0000000000..3591e52e81
--- /dev/null
+++ b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/service/tasks/LatexToSympyResponseHandler.java
@@ -0,0 +1,31 @@
+package software.uncharted.terarium.hmiserver.service.tasks;
+
+import lombok.Data;
+import lombok.RequiredArgsConstructor;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.stereotype.Component;
+import software.uncharted.terarium.hmiserver.models.task.TaskResponse;
+
+@Component
+@RequiredArgsConstructor
+@Slf4j
+public class LatexToSympyResponseHandler extends TaskResponseHandler {
+
+ public static final String NAME = "gollm:latex_to_sympy";
+
+ @Override
+ public String getName() {
+ return NAME;
+ }
+
+ @Data
+ public static class Response {
+
+ String response;
+ }
+
+ @Override
+ public TaskResponse onSuccess(final TaskResponse resp) {
+ return resp;
+ }
+}
diff --git a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/service/tasks/SympyToAMRResponseHandler.java b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/service/tasks/SympyToAMRResponseHandler.java
new file mode 100644
index 0000000000..dd4ccc5de9
--- /dev/null
+++ b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/service/tasks/SympyToAMRResponseHandler.java
@@ -0,0 +1,18 @@
+package software.uncharted.terarium.hmiserver.service.tasks;
+
+import lombok.RequiredArgsConstructor;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.stereotype.Component;
+
+@Component
+@RequiredArgsConstructor
+@Slf4j
+public class SympyToAMRResponseHandler extends TaskResponseHandler {
+
+ public static final String NAME = "mira_task:sympy_to_amr";
+
+ @Override
+ public String getName() {
+ return NAME;
+ }
+}