diff --git a/just_agents/cot_agent.py b/just_agents/cot_agent.py index e08d88c..204be1d 100644 --- a/just_agents/cot_agent.py +++ b/just_agents/cot_agent.py @@ -4,7 +4,7 @@ from just_agents.streaming.protocols.openai_streaming import OpenaiStreamingProtocol from just_agents.streaming.protocols.abstract_protocol import AbstractStreamingProtocol from pathlib import Path -from just_agents.utils import _resolve_agent_schema, resolve_llm_options, resolve_system_prompt, resolve_tools +from just_agents.utils import resolve_and_validate_agent_schema, resolve_llm_options, resolve_system_prompt, resolve_tools # schema parameters: LLM_SESSION = "llm_session" @@ -21,7 +21,7 @@ class ChainOfThoughtAgent(IAgent): def __init__(self, llm_options: dict = None, agent_schema: str | Path | dict | None = None, tools: list = None, output_streaming:AbstractStreamingProtocol = OpenaiStreamingProtocol()): - self.agent_schema: dict = _resolve_agent_schema(agent_schema, "cot_agent_prompt.yaml") + self.agent_schema: dict = resolve_and_validate_agent_schema(agent_schema, "cot_agent_prompt.yaml") if tools is None: tools = resolve_tools(self.agent_schema) self.session: LLMSession = LLMSession(llm_options=resolve_llm_options(self.agent_schema, llm_options), diff --git a/just_agents/llm_session.py b/just_agents/llm_session.py index f503de4..b9759fb 100644 --- a/just_agents/llm_session.py +++ b/just_agents/llm_session.py @@ -15,7 +15,7 @@ from just_agents.streaming.abstract_streaming import AbstractStreaming from just_agents.streaming.openai_streaming import AsyncSession # from just_agents.utils import rotate_completion -from just_agents.utils import _resolve_agent_schema, resolve_llm_options, resolve_system_prompt, resolve_tools +from just_agents.utils import resolve_and_validate_agent_schema, resolve_llm_options, resolve_system_prompt, resolve_tools from just_agents.rotate_keys import RotateKeys OnCompletion = Callable[[ModelResponse], None] @@ -46,7 +46,7 @@ def __init__(self, llm_options: dict[str, Any] = None, self.key_getter: RotateKeys = None self.on_response: list[OnCompletion] = [] - self.agent_schema = _resolve_agent_schema(agent_schema, "llm_session_schema.yaml") + self.agent_schema = resolve_and_validate_agent_schema(agent_schema, "llm_session_schema.yaml") self.llm_options: dict[str, Any] = resolve_llm_options(self.agent_schema, llm_options) if self.agent_schema.get(KEY_LIST_PATH, None) is not None: self.key_getter = RotateKeys(self.agent_schema[KEY_LIST_PATH]) diff --git a/just_agents/utils.py b/just_agents/utils.py index 0bcc1fe..702ac5a 100644 --- a/just_agents/utils.py +++ b/just_agents/utils.py @@ -8,26 +8,11 @@ from dotenv import load_dotenv import importlib from typing import Callable -from litellm import Message, ModelResponse, completion -import copy - -# -# class RotateKeys(): -# keys:list[str] -# -# def __init__(self, file_path:str): -# with open(file_path) as f: -# text = f.read().strip() -# self.keys = text.split("\n") -# -# def __call__(self, *args, **kwargs): -# return random.choice(self.keys) -# -# def remove(self, key:str): -# self.keys.remove(key) -# -# def len(self): -# return len(self.keys) + +class SchemaValidationError(ValueError): + pass + +VALIDATION_EXTRAS = ["package", "function"] def resolve_agent_schema(agent_schema: str | Path | dict): """ @@ -47,11 +32,36 @@ def resolve_agent_schema(agent_schema: str | Path | dict): return agent_schema -def _resolve_agent_schema(agent_schema: str | Path | dict | None, default_file_name: str): +def resolve_and_validate_agent_schema(agent_schema: str | Path | dict | None, default_file_name: str): + reference_schema = resolve_agent_schema(Path(Path(__file__).parent, "config", default_file_name)) if agent_schema is None: - agent_schema = Path(Path(__file__).parent, "config", default_file_name) + return reference_schema + + agent_schema = resolve_agent_schema(agent_schema) + validate_schema(reference_schema, agent_schema) + + return agent_schema + + +def create_fields_set(source: dict[str, Any], fields_set: set[str]): + for key in source: + fields_set.add(key) + if isinstance(source[key], dict): + create_fields_set(source[key], fields_set) + + +def validate_schema(reference: dict[str, Any], schema: dict[str, Any]): + reference_set: set[str] = set(VALIDATION_EXTRAS) + schema_set: set[str] = set() + create_fields_set(reference, reference_set) + create_fields_set(schema, schema_set) + error_fields = [] + for field in schema_set: + if field not in reference_set: + error_fields.append(field) - return resolve_agent_schema(agent_schema) + if len(error_fields) > 0: + raise SchemaValidationError(f" Fields {error_fields} not exists in yaml schema. Choose from {reference_set}") def resolve_llm_options(agent_schema: dict, llm_options: dict): diff --git a/tests/test_yaml_validator.py b/tests/test_yaml_validator.py new file mode 100644 index 0000000..70640ba --- /dev/null +++ b/tests/test_yaml_validator.py @@ -0,0 +1,9 @@ +from just_agents.interfaces.IAgent import build_agent +from just_agents.utils import SchemaValidationError + +def test_yaml_validator(): + try: + agent = build_agent("wrong_yaml.yaml") + except SchemaValidationError as e: + assert "jus_streming_method" in str(e) + assert "backup_optionss" in str(e) \ No newline at end of file diff --git a/tests/wrong_yaml.yaml b/tests/wrong_yaml.yaml new file mode 100644 index 0000000..7e9fb76 --- /dev/null +++ b/tests/wrong_yaml.yaml @@ -0,0 +1,10 @@ +class: "LLMSession" # class name to create could be LLMSession, ChainOfThoughtAgent. Default is LLMSession +jus_streming_method: "openai" # protocol to handle llm format for function calling +system_prompt: # system prompt exclude use of system_prompt_path +completion_max_tries: 2 # maximum number of completion retries before giving up +backup_optionss: # options that will be used after we give up with main options, one more completion call will be done with backup options +key_list_path: # path to text file with list of api keys, one key per line +tools: # list of functions that will be used as tools, each record should contain package and function name +drop_params: True # drop params from the request, useful for some models that do not support them +# - package: +# function: \ No newline at end of file