diff --git a/agents-api/agents_api/queries/docs/__init__.py b/agents-api/agents_api/queries/docs/__init__.py index 31b44e7b4..3862131bb 100644 --- a/agents-api/agents_api/queries/docs/__init__.py +++ b/agents-api/agents_api/queries/docs/__init__.py @@ -23,10 +23,10 @@ from .delete_doc import delete_doc from .get_doc import get_doc from .list_docs import list_docs - from .search_docs_by_embedding import search_docs_by_embedding from .search_docs_by_text import search_docs_by_text from .search_docs_hybrid import search_docs_hybrid + __all__ = [ "create_doc", "delete_doc", diff --git a/agents-api/agents_api/queries/docs/search_docs_by_embedding.py b/agents-api/agents_api/queries/docs/search_docs_by_embedding.py index 9c8b15955..d573b4d8f 100644 --- a/agents-api/agents_api/queries/docs/search_docs_by_embedding.py +++ b/agents-api/agents_api/queries/docs/search_docs_by_embedding.py @@ -1,12 +1,12 @@ from typing import Any, List, Literal from uuid import UUID +import asyncpg from beartype import beartype from fastapi import HTTPException -import asyncpg from ...autogen.openapi_model import DocReference -from ..utils import pg_query, rewrap_exceptions, wrap_in_class, partialclass +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class # Raw query for vector search search_docs_by_embedding_query = """ @@ -21,6 +21,7 @@ ) """ + @rewrap_exceptions( { asyncpg.UniqueViolationError: partialclass( diff --git a/agents-api/agents_api/queries/docs/search_docs_hybrid.py b/agents-api/agents_api/queries/docs/search_docs_hybrid.py index 8e14f36dd..aa27ed648 100644 --- a/agents-api/agents_api/queries/docs/search_docs_hybrid.py +++ b/agents-api/agents_api/queries/docs/search_docs_hybrid.py @@ -1,12 +1,11 @@ -from typing import List, Any, Literal +from typing import Any, List, Literal from uuid import UUID -from beartype import beartype - -from ...autogen.openapi_model import DocReference import asyncpg +from beartype import beartype from fastapi import HTTPException +from ...autogen.openapi_model import DocReference from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class # Raw query for hybrid search @@ -46,7 +45,6 @@ **d, }, ) - @pg_query @beartype async def search_docs_hybrid( diff --git a/agents-api/agents_api/queries/tools/create_tools.py b/agents-api/agents_api/queries/tools/create_tools.py index b91964a39..70277ab99 100644 --- a/agents-api/agents_api/queries/tools/create_tools.py +++ b/agents-api/agents_api/queries/tools/create_tools.py @@ -3,20 +3,19 @@ from typing import Any from uuid import UUID +import asyncpg from beartype import beartype -from uuid_extensions import uuid7 from fastapi import HTTPException -import asyncpg -from sqlglot import parse_one +from sqlglot import parse_one +from uuid_extensions import uuid7 from ...autogen.openapi_model import CreateToolRequest, Tool from ...metrics.counters import increase_counter - from ..utils import ( + partialclass, pg_query, rewrap_exceptions, wrap_in_class, - partialclass, ) # Define the raw SQL query for creating tools @@ -50,15 +49,15 @@ { asyncpg.UniqueViolationError: partialclass( HTTPException, - status_code=409, - detail="A tool with this name already exists for this agent" - ), + status_code=409, + detail="A tool with this name already exists for this agent", + ), asyncpg.ForeignKeyViolationError: partialclass( HTTPException, status_code=404, detail="Agent not found", ), -} + } ) @wrap_in_class( Tool, @@ -113,4 +112,3 @@ async def create_tools( tools_data, "fetchmany", ) - diff --git a/agents-api/agents_api/queries/tools/delete_tool.py b/agents-api/agents_api/queries/tools/delete_tool.py index 9a507523d..32fca1571 100644 --- a/agents-api/agents_api/queries/tools/delete_tool.py +++ b/agents-api/agents_api/queries/tools/delete_tool.py @@ -1,20 +1,14 @@ from typing import Any from uuid import UUID -from fastapi import HTTPException +import asyncpg from beartype import beartype +from fastapi import HTTPException +from sqlglot import parse_one from ...autogen.openapi_model import ResourceDeletedResponse from ...common.utils.datetime import utcnow -from sqlglot import parse_one -import asyncpg - -from ..utils import ( - pg_query, - wrap_in_class, - rewrap_exceptions, - partialclass -) +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class # Define the raw SQL query for deleting a tool tools_query = parse_one(""" @@ -29,14 +23,14 @@ @rewrap_exceptions( -{ + { # Handle foreign key constraint asyncpg.ForeignKeyViolationError: partialclass( HTTPException, status_code=404, detail="Developer or agent not found", ), -} + } ) @wrap_in_class( ResourceDeletedResponse, diff --git a/agents-api/agents_api/queries/tools/get_tool.py b/agents-api/agents_api/queries/tools/get_tool.py index 9f71dec40..6f25d3893 100644 --- a/agents-api/agents_api/queries/tools/get_tool.py +++ b/agents-api/agents_api/queries/tools/get_tool.py @@ -1,19 +1,13 @@ from typing import Any from uuid import UUID +import asyncpg from beartype import beartype - -from ...autogen.openapi_model import Tool -from sqlglot import parse_one from fastapi import HTTPException -import asyncpg -from ..utils import ( - pg_query, - wrap_in_class, - rewrap_exceptions, - partialclass -) +from sqlglot import parse_one +from ...autogen.openapi_model import Tool +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class # Define the raw SQL query for getting a tool tools_query = parse_one(""" @@ -25,6 +19,7 @@ LIMIT 1 """).sql(pretty=True) + @rewrap_exceptions( { asyncpg.ForeignKeyViolationError: partialclass( diff --git a/agents-api/agents_api/queries/tools/get_tool_args_from_metadata.py b/agents-api/agents_api/queries/tools/get_tool_args_from_metadata.py index 937442797..0171f5093 100644 --- a/agents-api/agents_api/queries/tools/get_tool_args_from_metadata.py +++ b/agents-api/agents_api/queries/tools/get_tool_args_from_metadata.py @@ -3,13 +3,13 @@ import sqlvalidator from beartype import beartype - from sqlglot import parse_one + from ..utils import ( + partialclass, pg_query, - wrap_in_class, rewrap_exceptions, - partialclass, + wrap_in_class, ) # Define the raw SQL query for getting tool args from metadata @@ -54,7 +54,6 @@ ) AS sessions_md""").sql(pretty=True) - # @rewrap_exceptions( # { # QueryException: partialclass(HTTPException, status_code=400), diff --git a/agents-api/agents_api/queries/tools/list_tools.py b/agents-api/agents_api/queries/tools/list_tools.py index d85bb9da0..fbd14f8b1 100644 --- a/agents-api/agents_api/queries/tools/list_tools.py +++ b/agents-api/agents_api/queries/tools/list_tools.py @@ -1,18 +1,13 @@ from typing import Literal from uuid import UUID -from beartype import beartype import asyncpg +from beartype import beartype from fastapi import HTTPException +from sqlglot import parse_one from ...autogen.openapi_model import Tool -from sqlglot import parse_one -from ..utils import ( - pg_query, - wrap_in_class, - rewrap_exceptions, - partialclass -) +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class # Define the raw SQL query for listing tools tools_query = parse_one(""" @@ -30,13 +25,13 @@ @rewrap_exceptions( -{ - asyncpg.ForeignKeyViolationError: partialclass( - HTTPException, - status_code=400, - detail="Developer or agent not found", - ), -} + { + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=400, + detail="Developer or agent not found", + ), + } ) @wrap_in_class( Tool, diff --git a/agents-api/agents_api/queries/tools/patch_tool.py b/agents-api/agents_api/queries/tools/patch_tool.py index fb4c680e1..b65eca481 100644 --- a/agents-api/agents_api/queries/tools/patch_tool.py +++ b/agents-api/agents_api/queries/tools/patch_tool.py @@ -1,19 +1,14 @@ from typing import Any from uuid import UUID +import asyncpg from beartype import beartype +from fastapi import HTTPException +from sqlglot import parse_one from ...autogen.openapi_model import PatchToolRequest, ResourceUpdatedResponse -from sqlglot import parse_one -import asyncpg -from fastapi import HTTPException from ...metrics.counters import increase_counter -from ..utils import ( - pg_query, - wrap_in_class, - rewrap_exceptions, - partialclass -) +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class # Define the raw SQL query for patching a tool tools_query = parse_one(""" @@ -35,13 +30,13 @@ @rewrap_exceptions( -{ - asyncpg.UniqueViolationError: partialclass( - HTTPException, - status_code=409, - detail="Developer or agent not found", - ), -} + { + asyncpg.UniqueViolationError: partialclass( + HTTPException, + status_code=409, + detail="Developer or agent not found", + ), + } ) @wrap_in_class( ResourceUpdatedResponse, diff --git a/agents-api/agents_api/queries/tools/update_tool.py b/agents-api/agents_api/queries/tools/update_tool.py index 18ff44f18..45c5a022d 100644 --- a/agents-api/agents_api/queries/tools/update_tool.py +++ b/agents-api/agents_api/queries/tools/update_tool.py @@ -1,24 +1,18 @@ +import json from typing import Any, TypeVar from uuid import UUID +import asyncpg from beartype import beartype +from fastapi import HTTPException +from sqlglot import parse_one from ...autogen.openapi_model import ( ResourceUpdatedResponse, UpdateToolRequest, ) -import asyncpg -import json -from fastapi import HTTPException - -from sqlglot import parse_one from ...metrics.counters import increase_counter -from ..utils import ( - pg_query, - wrap_in_class, - rewrap_exceptions, - partialclass -) +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class # Define the raw SQL query for updating a tool tools_query = parse_one(""" @@ -37,18 +31,18 @@ @rewrap_exceptions( -{ - asyncpg.UniqueViolationError: partialclass( - HTTPException, - status_code=409, - detail="A tool with this name already exists for this agent", - ), - json.JSONDecodeError: partialclass( - HTTPException, - status_code=400, - detail="Invalid tool specification format", - ), -} + { + asyncpg.UniqueViolationError: partialclass( + HTTPException, + status_code=409, + detail="A tool with this name already exists for this agent", + ), + json.JSONDecodeError: partialclass( + HTTPException, + status_code=400, + detail="Invalid tool specification format", + ), + } ) @wrap_in_class( ResourceUpdatedResponse, diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index 1760209a8..2c43ba9d6 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -30,7 +30,6 @@ from agents_api.queries.tasks.create_task import create_task from agents_api.queries.tools.create_tools import create_tools from agents_api.queries.users.create_user import create_user -from agents_api.queries.users.create_user import create_user from agents_api.web import app from .utils import ( diff --git a/agents-api/tests/test_docs_queries.py b/agents-api/tests/test_docs_queries.py index 125033276..f0070adfe 100644 --- a/agents-api/tests/test_docs_queries.py +++ b/agents-api/tests/test_docs_queries.py @@ -13,6 +13,7 @@ EMBEDDING_SIZE: int = 1024 + @test("query: create user doc") async def _(dsn=pg_dsn, developer=test_developer, user=test_user): pool = await create_db_pool(dsn=dsn) @@ -276,6 +277,7 @@ async def _(dsn=pg_dsn, agent=test_agent, developer=test_developer): assert len(result) >= 1 assert result[0].metadata is not None + @test("query: search docs by hybrid") async def _(dsn=pg_dsn, agent=test_agent, developer=test_developer): pool = await create_db_pool(dsn=dsn) @@ -306,4 +308,4 @@ async def _(dsn=pg_dsn, agent=test_agent, developer=test_developer): ) assert len(result) >= 1 - assert result[0].metadata is not None \ No newline at end of file + assert result[0].metadata is not None