Adding mira model context and handler
mattprintz committed Jul 14, 2023
1 parent 458e4cf commit fc7804e
Showing 4 changed files with 1,005 additions and 594 deletions.
146 changes: 126 additions & 20 deletions llmkernel/kernel.py
@@ -13,6 +13,7 @@
from ipykernel.kernelbase import Kernel
from ipykernel.ipkernel import IPythonKernel
from toolsets.dataset_toolset import DatasetToolset
from toolsets.mira_model_toolset import MiraModelToolset
from archytas.react import ReActAgent

logger = logging.getLogger(__name__)
@@ -30,49 +31,69 @@ class PythonLLMKernel(IPythonKernel):


def setup_instance(self, *args, **kwargs):
self.toolset = None
self.agent = None
self.context = None
self.msg_types.append("context_setup_request")
self.msg_types.append("llm_request")
self.msg_types.append("download_dataset_request")
self.msg_types.append("save_dataset_request")
self.msg_types.append("save_amr_request")
self.msg_types.append("load_dataset")
self.msg_types.append("load_mira_model")
return super().setup_instance(*args, **kwargs)
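
Each of these msg_types entries registers a custom shell-channel message type that the kernel dispatches to the coroutine of the same name. As a minimal client-side sketch (assuming a standard Jupyter connection file; the file name, the "request" content key, and the request text are assumptions inferred from the handlers below), one of these messages could be sent with jupyter_client:

from jupyter_client import BlockingKernelClient

kc = BlockingKernelClient()
kc.load_connection_file("kernel-12345.json")  # hypothetical connection file
kc.start_channels()

# Build and send one of the custom message types registered above
msg = kc.session.msg("llm_request", {"request": "Summarize the dataset"})
kc.shell_channel.send(msg)
reply = kc.get_shell_msg(timeout=30)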


def set_context(self, context, context_info):
self.toolset = None
match context:
case "dataset":
self.toolset = DatasetToolset()
self.agent = ReActAgent(tools=[self.toolset], allow_ask_user=False, verbose=True, spinner=None, rich_print=False)
self.toolset.agent = self.agent
if getattr(self, 'context', None) is not None:
self.agent.clear_all_context()
dataset_id = context_info["id"]
print(f"Processing dataset w/id {dataset_id}")
self.toolset.kernel = self.shell
self.toolset.set_dataset(dataset_id)
self.context = self.agent.add_context(self.toolset.context())
self.shell.ex("""import pandas as pd; import numpy as np; import scipy;""")
self.shell.push({
"df": self.toolset.df
})
self.send_df_preview_message()
case "mira_model":
self.toolset = MiraModelToolset()
self.agent = ReActAgent(tools=[self.toolset], allow_ask_user=False, verbose=True, spinner=None, rich_print=False)
self.toolset.agent = self.agent
if getattr(self, 'context', None) is not None:
self.agent.clear_all_context()
model_id = context_info["id"]
print(f"Processing AMR {model_id} as a MIRA model")
self.toolset.kernel = self.shell
self.toolset.set_model(model_id)
self.context = self.agent.add_context(self.toolset.context())
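
set_context is driven by the context_setup_request handler whose hunk appears further down. Assuming that handler passes the message's context and context_info fields through unchanged, a request selecting the new MIRA context would carry content shaped like this (the id value is a hypothetical placeholder):

context_setup_content = {
    "context": "mira_model",
    "context_info": {"id": "model-123"},  # hypothetical model id
}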


def send_df_preview_message(self):
try:
df = self.shell.ev("df")
if isinstance(df, pd.DataFrame):
split_df = json.loads(df.head(30).to_json(orient="split"))
payload = {
"name": "Temp dataset (not saved)",
"headers": split_df["columns"],
"csv": [split_df["columns"]] + split_df["data"],
}
self.send_response(
stream=self.iopub_socket,
msg_or_type="dataset",
content=payload,
)
        except Exception:
            # Skip the preview quietly if `df` is missing or not a DataFrame
            pass
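
The preview relies on pandas' "split" JSON orientation, which separates column labels from row data. A standalone illustration of the payload construction, using a synthetic two-row frame:

import json
import pandas as pd

df = pd.DataFrame({"x": [1, 2], "y": [3.0, 4.0]})
split_df = json.loads(df.head(30).to_json(orient="split"))
# split_df == {"columns": ["x", "y"], "index": [0, 1], "data": [[1, 3.0], [2, 4.0]]}

# Same shape the handler sends: headers plus a header-prefixed list of rows
payload = {
    "name": "Temp dataset (not saved)",
    "headers": split_df["columns"],
    "csv": [split_df["columns"]] + split_df["data"],
}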


# def send_response(self, stream, msg_or_type, content=None, ident=None, buffers=None, track=False, header=None, metadata=None, channel="shell"):
@@ -87,7 +108,7 @@ async def llm_request(self, queue, message_id, message, **kwargs):
if not request:
return
try:
result = await self.agent.react_async(request)
except Exception as err:
error_text = f"""LLM Error:
{err}
@@ -140,6 +161,42 @@ async def context_setup_request(self, queue, message_id, message, **kwargs):
channel="iopub",
)

async def load_dataset(self, queue, message_id, message, **kwargs):
content = message.get('content', {})
dataset_id = content.get('dataset_id', None)
filename = content.get('filename', None)
var_name = content.get('var_name', 'df')
meta_url = f"{os.environ['DATA_SERVICE_URL']}/datasets/{dataset_id}"
if not filename:
dataset_meta = requests.get(meta_url).json()
filename = dataset_meta.get('file_names', [])[0]
data_url_req = requests.get(f'{meta_url}/download-url?filename={filename}')
data_url = data_url_req.json().get('url', None)
self.shell.ex(
"""import pandas as pd; import numpy as np; import scipy;\n"""
"""import sympy; import itertools; from mira.metamodel import *; from mira.modeling import Model;\n"""
"""from mira.modeling.askenet.petrinet import AskeNetPetriNetModel; from mira.modeling.viz import GraphicalModel;\n"""
f"""{var_name} = pd.read_csv({data_url});"""
f"""{var_name};"""
)
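
Because shell.ex executes a string of Python source, any interpolated value must render as a valid literal; the !r conversion above wraps the URL in quotes. A small sketch of why the quoting matters (values are illustrative):

var_name = "df"
data_url = "https://example.com/data.csv"  # illustrative URL
code = f"""{var_name} = pd.read_csv({data_url!r});"""
# code == "df = pd.read_csv('https://example.com/data.csv');"
# Without !r the URL is pasted unquoted and raises a SyntaxError when executed.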


async def load_mira_model(self, queue, message_id, message, **kwargs):
content = message.get('content', {})
model_id = content.get('model_id', None)
var_name = content.get('var_name', 'model')

model_url = f"{os.environ['DATA_SERVICE_URL']}/models/{model_id}"

self.shell.ex(
"""import requests; import pandas as pd; import numpy as np; import scipy;\n"""
"""import json; from mira.sources.askenet.petrinet import template_model_from_askenet_json;\n"""
"""import sympy; import itertools; from mira.metamodel import *; from mira.modeling import Model;\n"""
"""from mira.modeling.askenet.petrinet import AskeNetPetriNetModel; from mira.modeling.viz import GraphicalModel;\n"""
f"""amr = requests.get({model_url}).json(); {var_name} = template_model_from_askenet_json(amr));"""
f"""${var_name};"""
)
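
The adjacent f-strings concatenate into a single source string, and the trailing {var_name}; leaves the loaded model as the final expression the shell evaluates. With illustrative values (var_name of 'model', a placeholder service URL), the last generated statements read roughly:

# generated source (illustrative values)
amr = requests.get('https://data.example/models/model-123').json(); model = template_model_from_askenet_json(amr);model;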

async def execute_request(self, stream, ident, parent):
# Rewrite parent so that this is properly tied to requests in terarium
# notebook_item = parent.get('metadata', {}).get('notebook_item', None)
@@ -234,6 +291,55 @@ async def save_dataset_request(self, queue, message_id, message, **kwargs):
},
channel="iopub",
)


async def save_amr_request(self, queue, message_id, message, **kwargs):
self.send_response(
stream=self.iopub_socket,
msg_or_type="status",
content={
"execution_state": "busy",
},
channel="iopub",
)

content = message.get('content', {})
parent_model_id = content.get("parent_model_id")
new_name = content.get("name")
var_name = 'model'

amr = self.shell.ev(f"AskeNetPetriNetModel(Model({var_name})).to_json()")

parent_url = f"{os.environ['DATA_SERVICE_URL']}/models/{parent_model_id}"
parent_model = requests.get(parent_url).json()
if not parent_model:
raise Exception(f"Unable to locate parent model '{parent_model_id}'")

new_model = copy.deepcopy(parent_model)
del new_model["id"]
new_model["name"] = new_name
new_model["description"] += f"\nTransformed from model '{parent_model['name']}' ({parent_model['id']}) at {datetime.datetime.utcnow().strftime('%c %Z')}"

create_req = requests.post(f"{os.environ['DATA_SERVICE_URL']}/models", json=new_model)
new_model_id = create_req.json()["id"]

self.send_response(
stream=self.iopub_socket,
msg_or_type="save_model_response",
content={
"model_id": new_model_id,
"parent_model_id": parent_model_id
},
channel="iopub",
)
self.send_response(
stream=self.iopub_socket,
msg_or_type="status",
content={
"execution_state": "idle",
},
channel="iopub",
)
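
save_amr_request serializes the in-kernel MIRA model back to AMR JSON before posting it to the data service. Together with load_mira_model this forms a round trip between MIRA's template model and the AMR petrinet format; a minimal sketch, assuming an existing template model tm (imports as used in the handlers above):

from mira.modeling import Model
from mira.modeling.askenet.petrinet import AskeNetPetriNetModel
from mira.sources.askenet.petrinet import template_model_from_askenet_json

amr_json = AskeNetPetriNetModel(Model(tm)).to_json()    # TemplateModel -> AMR dict
tm_again = template_model_from_askenet_json(amr_json)   # AMR dict -> TemplateModel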

async def do_execute(self, code, silent, store_history=True, user_expressions=None, allow_stdin=False, *, cell_id=None):
result = await super().do_execute(code, silent, store_history, user_expressions, allow_stdin, cell_id=cell_id)
