From 74a25107ea8414dd72e5a2c8dffbfc8cdbcda68a Mon Sep 17 00:00:00 2001 From: Dmitry Paramonov Date: Thu, 19 Dec 2024 15:49:10 +0300 Subject: [PATCH] feat: Add SQL validation --- agents-api/agents_api/exceptions.py | 9 ++ .../queries/chat/prepare_chat_context.py | 90 ++++++++++--------- 2 files changed, 56 insertions(+), 43 deletions(-) diff --git a/agents-api/agents_api/exceptions.py b/agents-api/agents_api/exceptions.py index 615958a87..f6fcc4741 100644 --- a/agents-api/agents_api/exceptions.py +++ b/agents-api/agents_api/exceptions.py @@ -49,3 +49,12 @@ class FailedEncodingSentinel: """Sentinel object returned when failed to encode payload.""" payload_data: bytes + + +class QueriesBaseException(AgentsBaseException): + pass + + +class InvalidSQLQuery(QueriesBaseException): + def __init__(self, query_name: str): + super().__init__(f"invalid query: {query_name}") diff --git a/agents-api/agents_api/queries/chat/prepare_chat_context.py b/agents-api/agents_api/queries/chat/prepare_chat_context.py index 23926ea4c..1d9bd52fb 100644 --- a/agents-api/agents_api/queries/chat/prepare_chat_context.py +++ b/agents-api/agents_api/queries/chat/prepare_chat_context.py @@ -1,9 +1,11 @@ from typing import Any, TypeVar from uuid import UUID +import sqlvalidator from beartype import beartype from ...common.protocol.sessions import ChatContext, make_session +from ...exceptions import InvalidSQLQuery from ..utils import ( pg_query, wrap_in_class, @@ -13,19 +15,19 @@ T = TypeVar("T") -query = """ -SELECT * FROM +sql_query = sqlvalidator.parse( + """SELECT * FROM ( SELECT jsonb_agg(u) AS users FROM ( SELECT session_lookup.participant_id, users.user_id AS id, - users.developer_id, - users.name, - users.about, - users.created_at, - users.updated_at, - users.metadata + users.developer_id, + users.name, + users.about, + users.created_at, + users.updated_at, + users.metadata FROM session_lookup INNER JOIN users ON session_lookup.participant_id = users.user_id WHERE @@ -39,16 +41,16 @@ SELECT session_lookup.participant_id, agents.agent_id AS id, - agents.developer_id, - agents.canonical_name, - agents.name, - agents.about, - agents.instructions, - agents.model, - agents.created_at, - agents.updated_at, - agents.metadata, - agents.default_settings + agents.developer_id, + agents.canonical_name, + agents.name, + agents.about, + agents.instructions, + agents.model, + agents.created_at, + agents.updated_at, + agents.metadata, + agents.default_settings FROM session_lookup INNER JOIN agents ON session_lookup.participant_id = agents.agent_id WHERE @@ -58,24 +60,24 @@ ) a ) AS agents, ( - SELECT to_jsonb(s) AS session FROM ( + SELECT to_jsonb(s) AS session FROM ( SELECT sessions.session_id AS id, - sessions.developer_id, - sessions.situation, - sessions.system_template, - sessions.created_at, - sessions.metadata, - sessions.render_templates, - sessions.token_budget, - sessions.context_overflow, - sessions.forward_tool_calls, - sessions.recall_options + sessions.developer_id, + sessions.situation, + sessions.system_template, + sessions.created_at, + sessions.metadata, + sessions.render_templates, + sessions.token_budget, + sessions.context_overflow, + sessions.forward_tool_calls, + sessions.recall_options FROM sessions WHERE developer_id = $1 AND session_id = $2 - LIMIT 1 + LIMIT 1 ) s ) AS session, ( @@ -83,16 +85,16 @@ SELECT session_lookup.participant_id, tools.tool_id as id, - tools.developer_id, - tools.agent_id, - tools.task_id, - tools.task_version, - tools.type, - tools.name, - tools.description, - tools.spec, - tools.updated_at, - tools.created_at + tools.developer_id, + tools.agent_id, + tools.task_id, + tools.task_version, + tools.type, + tools.name, + tools.description, + tools.spec, + tools.updated_at, + tools.created_at FROM session_lookup INNER JOIN tools ON session_lookup.participant_id = tools.agent_id WHERE @@ -100,8 +102,10 @@ session_id = $2 AND session_lookup.participant_type = 'agent' ) r -) AS toolsets -""" +) AS toolsets""" +) +if not sql_query.is_valid(): + raise InvalidSQLQuery("prepare_chat_context") def _transform(d): @@ -160,6 +164,6 @@ async def prepare_chat_context( """ return ( - [query], + [sql_query.format()], [developer_id, session_id], )