diff --git a/packages/client/hmi-client/src/temp/Equations.vue b/packages/client/hmi-client/src/temp/Equations.vue index 82d24f044b..d8c04f8a81 100644 --- a/packages/client/hmi-client/src/temp/Equations.vue +++ b/packages/client/hmi-client/src/temp/Equations.vue @@ -12,6 +12,7 @@

Cleaned LaTeX => SymPy equation strings + AMR json

+

 				

 				

 			
@@ -28,6 +29,7 @@ import API from '@/api/api'; const latex1 = ref(null); const latex2 = ref(null); const result1 = ref(null); +const resultCode = ref(null); const resultSympy = ref(null); const resultAmr = ref(null); @@ -52,6 +54,9 @@ const latex2amr = async () => { const inputStr = latex2.value?.value || '[]'; const equations = JSON.parse(inputStr); + if (resultCode.value) { + resultCode.value.innerHTML = 'processing...'; + } if (resultSympy.value) { resultSympy.value.innerHTML = 'processing...'; } @@ -62,6 +67,10 @@ const latex2amr = async () => { const resp = await API.post('/mira/latex-to-amr', equations); const respData = resp.data; + if (resultCode.value) { + resultCode.value.innerHTML = ''; + resultCode.value.innerHTML = respData.response.sympyCode; + } if (resultSympy.value) { resultSympy.value.innerHTML = ''; resultSympy.value.innerHTML = JSON.stringify(respData.response.sympyExprs, null, 2); diff --git a/packages/gollm/gollm_openai/prompts/latex_to_sympy.py b/packages/gollm/gollm_openai/prompts/latex_to_sympy.py new file mode 100644 index 0000000000..024ffc98d8 --- /dev/null +++ b/packages/gollm/gollm_openai/prompts/latex_to_sympy.py @@ -0,0 +1,26 @@ +LATEX_TO_SYMPY_PROMPT=""" +You are a helpful assistant who is an expert in writing mathematical expressions in LaTeX code and the Python package SymPy. + +Here is an example input LaTeX +["\frac{{d S(t)}}{{d t}} = -beta * S(t) * I(t) + b - m * S(t)"] + +You should return this output in SymPy: +``` +import sympy +# Define time variable +t = sympy.symbols("t") +# Define time-dependent variables +S, I = sympy.symbols("S I", cls = sympy.Function) +# Define constant parameters +beta, b, m = sympy.symbols("beta b m") +equation_output = [sympy.Eq(S(t).diff(t), -beta * S(t) * I(t) + b - m * S(t))] +``` + +Now, do the same for this LaTeX input: +{latex_equations} + +Respond with just the code and nothing else, do not surround the answer in ``` + + +Answer: +""" diff --git a/packages/gollm/gollm_openai/tool_utils.py b/packages/gollm/gollm_openai/tool_utils.py index 4ca980f72c..d84e58c37d 100644 --- a/packages/gollm/gollm_openai/tool_utils.py +++ b/packages/gollm/gollm_openai/tool_utils.py @@ -24,6 +24,8 @@ MODEL_METADATA_COMPARE_PROMPT, MODEL_METADATA_COMPARE_GOAL_PROMPT ) +from gollm_openai.prompts.latex_to_sympy import LATEX_TO_SYMPY_PROMPT + from openai import OpenAI from openai.types.chat.completion_create_params import ResponseFormat from typing import List @@ -77,6 +79,33 @@ def get_image_format_string(image_format: str) -> str: return format_strings.get(image_format.lower()) +def latex_to_sympy(equations: List[str]) -> str: + print("latex to sympy ...") + + prompt = LATEX_TO_SYMPY_PROMPT.format( + latex_equations=",\n".join(equations) + ) + + client = OpenAI() + output = client.chat.completions.create( + model=GPT_MODEL, + frequency_penalty=0, + max_tokens=4000, + presence_penalty=0, + seed=905, + temperature=0, + top_p=1, + response_format={ + "type": "text" + }, + messages=[ + {"role": "user", "content": prompt}, + ] + ) + return output.choices[0].message.content + + + def equations_cleanup(equations: List[str]) -> dict: print("Reformatting equations...") diff --git a/packages/gollm/setup.py b/packages/gollm/setup.py index ae803bd414..717f0ec0dd 100644 --- a/packages/gollm/setup.py +++ b/packages/gollm/setup.py @@ -24,6 +24,7 @@ "gollm:embedding=tasks.embedding:main", "gollm:enrich_amr=tasks.enrich_amr:main", "gollm:enrich_dataset=tasks.enrich_dataset:main", + "gollm:latex_to_sympy=tasks.latex_to_sympy:main", "gollm:equations_cleanup=tasks.equations_cleanup:main", "gollm:equations_from_image=tasks.equations_from_image:main", "gollm:generate_response=tasks.generate_response:main", diff --git a/packages/gollm/tasks/latex_to_sympy.py b/packages/gollm/tasks/latex_to_sympy.py new file mode 100644 index 0000000000..998fd23b6f --- /dev/null +++ b/packages/gollm/tasks/latex_to_sympy.py @@ -0,0 +1,39 @@ +import sys +from entities import EquationsCleanup +from gollm_openai.tool_utils import latex_to_sympy + +from taskrunner import TaskRunnerInterface +import traceback + + +def cleanup(): + pass + + +def main(): + exitCode = 0 + try: + taskrunner = TaskRunnerInterface(description="Converting latex to sympy") + taskrunner.on_cancellation(cleanup) + + input_dict = taskrunner.read_input_dict_with_timeout() + + taskrunner.log("Sending request to OpenAI API") + response = latex_to_sympy(equations=input_dict) + taskrunner.log("Received response from OpenAI API") + + taskrunner.write_output_dict_with_timeout({"response": response}) + + except Exception as e: + sys.stderr.write(f"Error: {str(e)}\n") + sys.stderr.write(traceback.format_exc()) + sys.stderr.flush() + exitCode = 1 + + taskrunner.log("Shutting down") + taskrunner.shutdown() + sys.exit(exitCode) + + +if __name__ == "__main__": + main() diff --git a/packages/mira/setup.py b/packages/mira/setup.py index 605c273a7f..7c2754869e 100644 --- a/packages/mira/setup.py +++ b/packages/mira/setup.py @@ -12,7 +12,8 @@ "mira_task:mdl_to_stockflow=tasks.mdl_to_stockflow:main", "mira_task:stella_to_stockflow=tasks.stella_to_stockflow:main", "mira_task:amr_to_mmt=tasks.amr_to_mmt:main", - "mira_task:generate_model_latex=tasks.generate_model_latex:main", + "mira_task:generate_model_latex=tasks.generate_model_latex:main", + "mira_task:sympy_to_amr=tasks.sympy_to_amr:main", ], }, python_requires=">=3.10", diff --git a/packages/mira/tasks/sympy_to_amr.py b/packages/mira/tasks/sympy_to_amr.py new file mode 100644 index 0000000000..25de994f2f --- /dev/null +++ b/packages/mira/tasks/sympy_to_amr.py @@ -0,0 +1,56 @@ +import sys +import json +import traceback +import sympy + +from taskrunner import TaskRunnerInterface +from mira.sources.sympy_ode import template_model_from_sympy_odes +from mira.modeling.amr.petrinet import template_model_to_petrinet_json + +def cleanup(): + pass + +def main(): + exitCode = 0 + + try: + taskrunner = TaskRunnerInterface(description="Sympy to AMR") + taskrunner.on_cancellation(cleanup) + + sympy_code = taskrunner.read_input_str_with_timeout() + taskrunner.log("== input code") + taskrunner.log(sympy_code) + taskrunner.log("") + + globals = {} + exec(sympy_code, globals) # output should be in placed into "equation_output" + taskrunner.log("== equations") + taskrunner.log(globals["equation_output"]) + taskrunner.log("") + + # SymPy to MMT + mmt = template_model_from_sympy_odes(globals["equation_output"]) + + # MMT to AMR + amr_json = template_model_to_petrinet_json(mmt) + + # Gather results + response = {} + response["amr"] = amr_json + response["sympyCode"] = sympy_code + response["sympyExprs"] = list(map(lambda x: str(x), globals["equation_output"])) + + taskrunner.log(f"Sympy to AMR conversion succeeded") + taskrunner.write_output_dict_with_timeout({"response": response }) + except Exception as e: + sys.stderr.write(f"Error: {str(e)}\n") + sys.stderr.write(traceback.format_exc()) + sys.stderr.flush() + exitCode = 1 + + taskrunner.log("Shutting down") + taskrunner.shutdown() + sys.exit(exitCode) + +if __name__ == "__main__": + main() diff --git a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/controller/mira/MiraController.java b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/controller/mira/MiraController.java index bbbffa670b..39c69f1347 100644 --- a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/controller/mira/MiraController.java +++ b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/controller/mira/MiraController.java @@ -49,9 +49,11 @@ import software.uncharted.terarium.hmiserver.service.tasks.AMRToMMTResponseHandler; import software.uncharted.terarium.hmiserver.service.tasks.GenerateModelLatexResponseHandler; import software.uncharted.terarium.hmiserver.service.tasks.LatexToAMRResponseHandler; +import software.uncharted.terarium.hmiserver.service.tasks.LatexToSympyResponseHandler; import software.uncharted.terarium.hmiserver.service.tasks.MdlToStockflowResponseHandler; import software.uncharted.terarium.hmiserver.service.tasks.SbmlToPetrinetResponseHandler; import software.uncharted.terarium.hmiserver.service.tasks.StellaToStockflowResponseHandler; +import software.uncharted.terarium.hmiserver.service.tasks.SympyToAMRResponseHandler; import software.uncharted.terarium.hmiserver.service.tasks.TaskService; import software.uncharted.terarium.hmiserver.utils.Messages; import software.uncharted.terarium.hmiserver.utils.rebac.Schema; @@ -269,47 +271,71 @@ public ResponseEntity generateModelLatex(@RequestBody final JsonNode m } ) public ResponseEntity latexToAMR(@RequestBody final String latex) { - // create request: - final TaskRequest req = new TaskRequest(); - req.setType(TaskType.MIRA); + //////////////////////////////////////////////////////////////////////////////// + // 1. Convert latex string to python sympy code string + // + // Note this is a gollm string => string task + //////////////////////////////////////////////////////////////////////////////// + final TaskRequest latexToSympyRequest = new TaskRequest(); + final TaskResponse latexToSympyResponse; + String code = null; try { - req.setInput(latex.getBytes()); + latexToSympyRequest.setType(TaskType.GOLLM); + latexToSympyRequest.setInput(latex.getBytes()); + latexToSympyRequest.setScript(LatexToSympyResponseHandler.NAME); + latexToSympyRequest.setUserId(currentUserService.get().getId()); + latexToSympyResponse = taskService.runTaskSync(latexToSympyRequest); + + final JsonNode node = objectMapper.readValue(latexToSympyResponse.getOutput(), JsonNode.class); + code = node.get("response").asText(); + } catch (final TimeoutException e) { + log.warn("Timeout while waiting for task response", e); + throw new ResponseStatusException(HttpStatus.SERVICE_UNAVAILABLE, messages.get("task.gollm.timeout")); + } catch (final InterruptedException e) { + log.warn("Interrupted while waiting for task response", e); + throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.gollm.interrupted")); + } catch (final ExecutionException e) { + log.error("Error while waiting for task response", e); + throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.gollm.execution-failure")); } catch (final Exception e) { - log.error("Unable to serialize input", e); - throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("generic.io-error.write")); + log.error("Unexpected error", e); } - req.setScript(LatexToAMRResponseHandler.NAME); - req.setUserId(currentUserService.get().getId()); + if (code == null) { + throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.gollm.execution-failure")); + } + + //////////////////////////////////////////////////////////////////////////////// + // 2. Convert python sympy code string to amr + // + // This returns the AMR json, and intermediate data representations for debugging + //////////////////////////////////////////////////////////////////////////////// + final TaskRequest sympyToAMRRequest = new TaskRequest(); + final TaskResponse sympyToAMRResponse; + final JsonNode response; - // send the request - final TaskResponse resp; try { - resp = taskService.runTaskSync(req); - } catch (final JsonProcessingException e) { - log.error("Unable to serialize input", e); - throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.mira.json-processing")); + sympyToAMRRequest.setType(TaskType.MIRA); + sympyToAMRRequest.setInput(code.getBytes()); + sympyToAMRRequest.setScript(SympyToAMRResponseHandler.NAME); + sympyToAMRRequest.setUserId(currentUserService.get().getId()); + sympyToAMRResponse = taskService.runTaskSync(sympyToAMRRequest); + response = objectMapper.readValue(sympyToAMRResponse.getOutput(), JsonNode.class); + return ResponseEntity.ok().body(response); } catch (final TimeoutException e) { log.warn("Timeout while waiting for task response", e); throw new ResponseStatusException(HttpStatus.SERVICE_UNAVAILABLE, messages.get("task.mira.timeout")); } catch (final InterruptedException e) { log.warn("Interrupted while waiting for task response", e); - throw new ResponseStatusException(HttpStatus.UNPROCESSABLE_ENTITY, messages.get("task.mira.interrupted")); + throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.mira.interrupted")); } catch (final ExecutionException e) { log.error("Error while waiting for task response", e); throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.mira.execution-failure")); + } catch (final Exception e) { + log.error("Unexpected error", e); } - - final JsonNode latexResponse; - try { - latexResponse = objectMapper.readValue(resp.getOutput(), JsonNode.class); - } catch (final IOException e) { - log.error("Unable to deserialize output", e); - throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("generic.io-error.read")); - } - - return ResponseEntity.ok().body(latexResponse); + throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("generic.io-error.read")); } @PostMapping("/convert-and-create-model") diff --git a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/service/tasks/LatexToSympyResponseHandler.java b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/service/tasks/LatexToSympyResponseHandler.java new file mode 100644 index 0000000000..3591e52e81 --- /dev/null +++ b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/service/tasks/LatexToSympyResponseHandler.java @@ -0,0 +1,31 @@ +package software.uncharted.terarium.hmiserver.service.tasks; + +import lombok.Data; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Component; +import software.uncharted.terarium.hmiserver.models.task.TaskResponse; + +@Component +@RequiredArgsConstructor +@Slf4j +public class LatexToSympyResponseHandler extends TaskResponseHandler { + + public static final String NAME = "gollm:latex_to_sympy"; + + @Override + public String getName() { + return NAME; + } + + @Data + public static class Response { + + String response; + } + + @Override + public TaskResponse onSuccess(final TaskResponse resp) { + return resp; + } +} diff --git a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/service/tasks/SympyToAMRResponseHandler.java b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/service/tasks/SympyToAMRResponseHandler.java new file mode 100644 index 0000000000..dd4ccc5de9 --- /dev/null +++ b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/service/tasks/SympyToAMRResponseHandler.java @@ -0,0 +1,18 @@ +package software.uncharted.terarium.hmiserver.service.tasks; + +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Component; + +@Component +@RequiredArgsConstructor +@Slf4j +public class SympyToAMRResponseHandler extends TaskResponseHandler { + + public static final String NAME = "mira_task:sympy_to_amr"; + + @Override + public String getName() { + return NAME; + } +}