From 6b1ec3b6453ce8183af96b00303693bd332fbb96 Mon Sep 17 00:00:00 2001 From: Matthew Printz Date: Wed, 10 Jan 2024 15:44:24 -0600 Subject: [PATCH] Add decapode reset message handler --- .../contexts/decapodes/context.py | 33 +++++++++++++++---- .../decapodes/procedures/julia/load_model.jl | 9 +++-- .../decapodes/procedures/julia/reset.jl | 1 + .../decapodes/procedures/julia/setup.jl | 3 ++ 4 files changed, 34 insertions(+), 12 deletions(-) create mode 100644 src/askem_beaker/contexts/decapodes/procedures/julia/reset.jl diff --git a/src/askem_beaker/contexts/decapodes/context.py b/src/askem_beaker/contexts/decapodes/context.py index fd8dfbe..c378891 100644 --- a/src/askem_beaker/contexts/decapodes/context.py +++ b/src/askem_beaker/contexts/decapodes/context.py @@ -30,7 +30,6 @@ def __init__(self, beaker_kernel: "LLMKernel", subkernel: "BaseSubkernel", confi async def setup(self, config, parent_header): self.config = config - var_names = list(config.keys()) def fetch_model(model_id): meta_url = f"{os.environ['DATA_SERVICE_URL']}/models/{model_id}" @@ -40,16 +39,16 @@ def fetch_model(model_id): model = json.dumps(response.json()["model"]) return model - load_commands = [ - '%s = parse_json_acset(SummationDecapode{Symbol, Symbol, Symbol},"""%s""")' % (var_name, fetch_model(decapode_id)) - for var_name, decapode_id in config.items() - ] + variables = { + var_name: fetch_model(decapode_id) for var_name, decapode_id in config.items() + } command = "\n".join( [ self.get_code("setup"), - "decapode = @decapode begin end", - *load_commands + self.get_code("load_model", { + "variables": variables, + }), ] ) print(f"Running command:\n-------\n{command}\n---------") @@ -220,3 +219,23 @@ async def save_amr_request(self, message): self.beaker_kernel.send_response( "iopub", "save_model_response", content, parent_header=message.header ) + + @intercept() + async def reset_request(self, message): + content = message.content + + model_name = content.get("model_name", self.target) + reset_code = self.get_code("reset", { + "var_name": model_name, + }) + reset_result = await self.execute(reset_code) + + content = { + "success": True, + "executed_code": reset_result["parent"].content["code"], + } + + self.beaker_kernel.send_response( + "iopub", "reset_response", content, parent_header=message.header + ) + await self.send_decapodes_preview_message(parent_header=message.header) diff --git a/src/askem_beaker/contexts/decapodes/procedures/julia/load_model.jl b/src/askem_beaker/contexts/decapodes/procedures/julia/load_model.jl index 829cfb5..7ac68a6 100644 --- a/src/askem_beaker/contexts/decapodes/procedures/julia/load_model.jl +++ b/src/askem_beaker/contexts/decapodes/procedures/julia/load_model.jl @@ -1,5 +1,4 @@ -using ACSets, Decapodes, SyntacticModels -import HTTP, JSON3, DisplayAs -_amr = JSON3.read(String(HTTP.get("{{ model_url }}").body), SyntacticModels.ASKEMDecapodes.ASKEMDecaExpr) -{{ var_name|default("model") }} = Decapodes.SummationDecapode(_amr.model) -Dict(["var_name" => "{{ var_name|default("model") }}"]) |> DisplayAs.unlimited ∘ JSON3.write # TODO: Fix 'Unable to parse result' false error +{% for var_name, definition in variables.items() %} +{{ var_name }} = parse_json_acset(SummationDecapode{Symbol, Symbol, Symbol},"""{{ definition }}""") +_model_reset_cache[:{{ var_name }}] = {{ var_name }} +{% endfor %} diff --git a/src/askem_beaker/contexts/decapodes/procedures/julia/reset.jl b/src/askem_beaker/contexts/decapodes/procedures/julia/reset.jl new file mode 100644 index 0000000..7be6ca5 --- /dev/null +++ b/src/askem_beaker/contexts/decapodes/procedures/julia/reset.jl @@ -0,0 +1 @@ +{{ var_name|default("decapode") }} = (haskey(_model_reset_cache, :{{ var_name|default("decapode") }}) ? _model_reset_cache[:{{ var_name|default("decapode") }}] : @decapode begin end) diff --git a/src/askem_beaker/contexts/decapodes/procedures/julia/setup.jl b/src/askem_beaker/contexts/decapodes/procedures/julia/setup.jl index 3280070..596076a 100644 --- a/src/askem_beaker/contexts/decapodes/procedures/julia/setup.jl +++ b/src/askem_beaker/contexts/decapodes/procedures/julia/setup.jl @@ -2,3 +2,6 @@ using ACSets using Catlab, Catlab.Graphics using CombinatorialSpaces using Decapodes + +_model_reset_cache = Dict() +decapode = @decapode begin end