Skip to content

Commit

Permalink
Added yaml validator based on reference yaml. Simple but elegant apro…
Browse files Browse the repository at this point in the history
…ach.
  • Loading branch information
Alex-Karmazin committed Oct 30, 2024
1 parent 867590f commit 360bed4
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 27 deletions.
4 changes: 2 additions & 2 deletions just_agents/cot_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from just_agents.streaming.protocols.openai_streaming import OpenaiStreamingProtocol
from just_agents.streaming.protocols.abstract_protocol import AbstractStreamingProtocol
from pathlib import Path
from just_agents.utils import _resolve_agent_schema, resolve_llm_options, resolve_system_prompt, resolve_tools
from just_agents.utils import resolve_and_validate_agent_schema, resolve_llm_options, resolve_system_prompt, resolve_tools

# schema parameters:
LLM_SESSION = "llm_session"
Expand All @@ -21,7 +21,7 @@ class ChainOfThoughtAgent(IAgent):

def __init__(self, llm_options: dict = None, agent_schema: str | Path | dict | None = None,
tools: list = None, output_streaming:AbstractStreamingProtocol = OpenaiStreamingProtocol()):
self.agent_schema: dict = _resolve_agent_schema(agent_schema, "cot_agent_prompt.yaml")
self.agent_schema: dict = resolve_and_validate_agent_schema(agent_schema, "cot_agent_prompt.yaml")
if tools is None:
tools = resolve_tools(self.agent_schema)
self.session: LLMSession = LLMSession(llm_options=resolve_llm_options(self.agent_schema, llm_options),
Expand Down
4 changes: 2 additions & 2 deletions just_agents/llm_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from just_agents.streaming.abstract_streaming import AbstractStreaming
from just_agents.streaming.openai_streaming import AsyncSession
# from just_agents.utils import rotate_completion
from just_agents.utils import _resolve_agent_schema, resolve_llm_options, resolve_system_prompt, resolve_tools
from just_agents.utils import resolve_and_validate_agent_schema, resolve_llm_options, resolve_system_prompt, resolve_tools
from just_agents.rotate_keys import RotateKeys

OnCompletion = Callable[[ModelResponse], None]
Expand Down Expand Up @@ -46,7 +46,7 @@ def __init__(self, llm_options: dict[str, Any] = None,
self.key_getter: RotateKeys = None
self.on_response: list[OnCompletion] = []

self.agent_schema = _resolve_agent_schema(agent_schema, "llm_session_schema.yaml")
self.agent_schema = resolve_and_validate_agent_schema(agent_schema, "llm_session_schema.yaml")
self.llm_options: dict[str, Any] = resolve_llm_options(self.agent_schema, llm_options)
if self.agent_schema.get(KEY_LIST_PATH, None) is not None:
self.key_getter = RotateKeys(self.agent_schema[KEY_LIST_PATH])
Expand Down
56 changes: 33 additions & 23 deletions just_agents/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,11 @@
from dotenv import load_dotenv
import importlib
from typing import Callable
from litellm import Message, ModelResponse, completion
import copy

#
# class RotateKeys():
# keys:list[str]
#
# def __init__(self, file_path:str):
# with open(file_path) as f:
# text = f.read().strip()
# self.keys = text.split("\n")
#
# def __call__(self, *args, **kwargs):
# return random.choice(self.keys)
#
# def remove(self, key:str):
# self.keys.remove(key)
#
# def len(self):
# return len(self.keys)

class SchemaValidationError(ValueError):
pass

VALIDATION_EXTRAS = ["package", "function"]

def resolve_agent_schema(agent_schema: str | Path | dict):
"""
Expand All @@ -47,11 +32,36 @@ def resolve_agent_schema(agent_schema: str | Path | dict):

return agent_schema

def _resolve_agent_schema(agent_schema: str | Path | dict | None, default_file_name: str):
def resolve_and_validate_agent_schema(agent_schema: str | Path | dict | None, default_file_name: str):
reference_schema = resolve_agent_schema(Path(Path(__file__).parent, "config", default_file_name))
if agent_schema is None:
agent_schema = Path(Path(__file__).parent, "config", default_file_name)
return reference_schema

agent_schema = resolve_agent_schema(agent_schema)
validate_schema(reference_schema, agent_schema)

return agent_schema


def create_fields_set(source: dict[str, Any], fields_set: set[str]):
for key in source:
fields_set.add(key)
if isinstance(source[key], dict):
create_fields_set(source[key], fields_set)


def validate_schema(reference: dict[str, Any], schema: dict[str, Any]):
reference_set: set[str] = set(VALIDATION_EXTRAS)
schema_set: set[str] = set()
create_fields_set(reference, reference_set)
create_fields_set(schema, schema_set)
error_fields = []
for field in schema_set:
if field not in reference_set:
error_fields.append(field)

return resolve_agent_schema(agent_schema)
if len(error_fields) > 0:
raise SchemaValidationError(f" Fields {error_fields} not exists in yaml schema. Choose from {reference_set}")


def resolve_llm_options(agent_schema: dict, llm_options: dict):
Expand Down
9 changes: 9 additions & 0 deletions tests/test_yaml_validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from just_agents.interfaces.IAgent import build_agent
from just_agents.utils import SchemaValidationError

def test_yaml_validator():
try:
agent = build_agent("wrong_yaml.yaml")
except SchemaValidationError as e:
assert "jus_streming_method" in str(e)
assert "backup_optionss" in str(e)
10 changes: 10 additions & 0 deletions tests/wrong_yaml.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
class: "LLMSession" # class name to create could be LLMSession, ChainOfThoughtAgent. Default is LLMSession
jus_streming_method: "openai" # protocol to handle llm format for function calling
system_prompt: # system prompt exclude use of system_prompt_path
completion_max_tries: 2 # maximum number of completion retries before giving up
backup_optionss: # options that will be used after we give up with main options, one more completion call will be done with backup options
key_list_path: # path to text file with list of api keys, one key per line
tools: # list of functions that will be used as tools, each record should contain package and function name
drop_params: True # drop params from the request, useful for some models that do not support them
# - package:
# function:

0 comments on commit 360bed4

Please sign in to comment.