From d6d232a6fdc9b2204f7308ae30932236206cfa0c Mon Sep 17 00:00:00 2001 From: Brandon Rose Date: Mon, 10 Jun 2024 16:01:02 -0500 Subject: [PATCH] cleanup where model config is loaded so it can be used globally by the Agent --- src/askem_beaker/contexts/mira_config_edit/agent.py | 8 +++----- src/askem_beaker/contexts/mira_config_edit/context.py | 2 ++ 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/askem_beaker/contexts/mira_config_edit/agent.py b/src/askem_beaker/contexts/mira_config_edit/agent.py index efa736b..07c32a7 100644 --- a/src/askem_beaker/contexts/mira_config_edit/agent.py +++ b/src/askem_beaker/contexts/mira_config_edit/agent.py @@ -13,7 +13,6 @@ logging.disable(logging.WARNING) # Disable warnings logger = logging.Logger(__name__) -from mira.sources.amr import model_from_json; class MiraConfigEditAgent(BaseAgent): """ @@ -82,17 +81,16 @@ async def update_parameters(self, parameter_values: dict, agent: AgentRef, loop: """ # load in model config's parameters to use in comparison to the # user provided parameters to update - model_config = model_from_json(agent.context.amr) - model_params = model_config.parameters.keys() + model_params = agent.context.model_config.parameters.keys() user_params = parameter_values['parameter_values'].keys() # check if any in user_params is not in model_params and return an error if not all(param in model_params for param in user_params): loop.set_state(loop.STOP_FATAL) error_message = f"It looks like you're trying to update parameter(s) that don't exist: " \ - f"{', '.join(param for param in user_params if param not in model_params)}." \ + f"[{', '.join(param for param in user_params if param not in model_params)}]. " \ f"Please ensure you are updating a valid parameter: " \ - f"{', '.join(param for param in model_params)}" + f"[{', '.join(param for param in model_params)}]." return error_message loop.set_state(loop.STOP_SUCCESS) diff --git a/src/askem_beaker/contexts/mira_config_edit/context.py b/src/askem_beaker/contexts/mira_config_edit/context.py index a95a13b..8b464ab 100644 --- a/src/askem_beaker/contexts/mira_config_edit/context.py +++ b/src/askem_beaker/contexts/mira_config_edit/context.py @@ -19,6 +19,7 @@ logger = logging.getLogger(__name__) +from mira.sources.amr import model_from_json; class MiraConfigEditContext(BaseContext): @@ -64,6 +65,7 @@ async def set_model_config(self, item_id, agent=None, parent_header={}): self.original_amr = copy.deepcopy(self.amr) if self.amr: await self.load_mira() + self.model_config = model_from_json(self.amr) else: raise Exception(f"Model config '{item_id}' not found.") await self.send_mira_preview_message(parent_header=parent_header)