From 8e6f3e0f36674bc127a173c7b35461d560f1e6eb Mon Sep 17 00:00:00 2001 From: Brandon Rose Date: Tue, 11 Jun 2024 08:51:45 -0500 Subject: [PATCH] Add handler to detect incorrect parameter name (#137) * add handler to detect incorrect parameter name * cleanup where model config is loaded so it can be used globally by the Agent --------- Co-authored-by: Brandon Rose --- .../contexts/mira_config_edit/agent.py | 16 +++++++++++++++- .../contexts/mira_config_edit/context.py | 2 ++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/src/askem_beaker/contexts/mira_config_edit/agent.py b/src/askem_beaker/contexts/mira_config_edit/agent.py index eb18452..07c32a7 100644 --- a/src/askem_beaker/contexts/mira_config_edit/agent.py +++ b/src/askem_beaker/contexts/mira_config_edit/agent.py @@ -74,11 +74,25 @@ async def update_parameters(self, parameter_values: dict, agent: AgentRef, loop: Please generate the code as if you were programming inside a Jupyter Notebook and the code is to be executed inside a cell. You MUST wrap the code with a line containing three backticks (```) before and after the generated code. - No addtional text is needed in the response, just the code block. + No addtional text is needed in the response, just the code block. Args: parameter_values (dict): the dictionary of parameter names and the values to update them with """ + # load in model config's parameters to use in comparison to the + # user provided parameters to update + model_params = agent.context.model_config.parameters.keys() + user_params = parameter_values['parameter_values'].keys() + + # check if any in user_params is not in model_params and return an error + if not all(param in model_params for param in user_params): + loop.set_state(loop.STOP_FATAL) + error_message = f"It looks like you're trying to update parameter(s) that don't exist: " \ + f"[{', '.join(param for param in user_params if param not in model_params)}]. " \ + f"Please ensure you are updating a valid parameter: " \ + f"[{', '.join(param for param in model_params)}]." + return error_message + loop.set_state(loop.STOP_SUCCESS) code = agent.context.get_code("update_params", {"parameter_values": parameter_values}) return json.dumps( diff --git a/src/askem_beaker/contexts/mira_config_edit/context.py b/src/askem_beaker/contexts/mira_config_edit/context.py index a95a13b..8b464ab 100644 --- a/src/askem_beaker/contexts/mira_config_edit/context.py +++ b/src/askem_beaker/contexts/mira_config_edit/context.py @@ -19,6 +19,7 @@ logger = logging.getLogger(__name__) +from mira.sources.amr import model_from_json; class MiraConfigEditContext(BaseContext): @@ -64,6 +65,7 @@ async def set_model_config(self, item_id, agent=None, parent_header={}): self.original_amr = copy.deepcopy(self.amr) if self.amr: await self.load_mira() + self.model_config = model_from_json(self.amr) else: raise Exception(f"Model config '{item_id}' not found.") await self.send_mira_preview_message(parent_header=parent_header)