Skip to content

Commit

Permalink
Add handler to detect incorrect parameter name (#137)
Browse files Browse the repository at this point in the history
* add handler to detect incorrect parameter name

* cleanup where model config is loaded so it can be used globally by the Agent

---------

Co-authored-by: Brandon Rose <[email protected]>
  • Loading branch information
brandomr and Brandon Rose committed Jun 11, 2024
1 parent f2a1651 commit 8e6f3e0
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
16 changes: 15 additions & 1 deletion src/askem_beaker/contexts/mira_config_edit/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,25 @@ async def update_parameters(self, parameter_values: dict, agent: AgentRef, loop:
Please generate the code as if you were programming inside a Jupyter Notebook and the code is to be executed inside a cell.
You MUST wrap the code with a line containing three backticks (```) before and after the generated code.
No addtional text is needed in the response, just the code block.
No addtional text is needed in the response, just the code block.
Args:
parameter_values (dict): the dictionary of parameter names and the values to update them with
"""
# load in model config's parameters to use in comparison to the
# user provided parameters to update
model_params = agent.context.model_config.parameters.keys()
user_params = parameter_values['parameter_values'].keys()

# check if any in user_params is not in model_params and return an error
if not all(param in model_params for param in user_params):
loop.set_state(loop.STOP_FATAL)
error_message = f"It looks like you're trying to update parameter(s) that don't exist: " \
f"[{', '.join(param for param in user_params if param not in model_params)}]. " \
f"Please ensure you are updating a valid parameter: " \
f"[{', '.join(param for param in model_params)}]."
return error_message

loop.set_state(loop.STOP_SUCCESS)
code = agent.context.get_code("update_params", {"parameter_values": parameter_values})
return json.dumps(
Expand Down
2 changes: 2 additions & 0 deletions src/askem_beaker/contexts/mira_config_edit/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

logger = logging.getLogger(__name__)

from mira.sources.amr import model_from_json;

class MiraConfigEditContext(BaseContext):

Expand Down Expand Up @@ -64,6 +65,7 @@ async def set_model_config(self, item_id, agent=None, parent_header={}):
self.original_amr = copy.deepcopy(self.amr)
if self.amr:
await self.load_mira()
self.model_config = model_from_json(self.amr)
else:
raise Exception(f"Model config '{item_id}' not found.")
await self.send_mira_preview_message(parent_header=parent_header)
Expand Down

0 comments on commit 8e6f3e0

Please sign in to comment.