feat(agents-api): added docs hybrid search
Vedantsahai18 committed Dec 24, 2024
1 parent 358b60b commit 830206b
Showing 15 changed files with 303 additions and 287 deletions.
9 changes: 6 additions & 3 deletions agents-api/agents_api/queries/docs/__init__.py
@@ -9,6 +9,8 @@
- Deleting documents by their unique identifiers.
- Embedding document snippets for retrieval purposes.
- Searching documents by text.
- Searching documents by hybrid text and embedding.
- Searching documents by embedding.
The module interacts with other parts of the application, such as the agents and users modules, to provide a comprehensive document management system. Its role is crucial in enabling document search, retrieval, and management features within the context of agents and users.
@@ -22,14 +24,15 @@
from .get_doc import get_doc
from .list_docs import list_docs

# from .search_docs_by_embedding import search_docs_by_embedding
from .search_docs_by_embedding import search_docs_by_embedding
from .search_docs_by_text import search_docs_by_text

from .search_docs_hybrid import search_docs_hybrid
__all__ = [
"create_doc",
"delete_doc",
"get_doc",
"list_docs",
# "search_docs_by_embedding",
"search_docs_by_embedding",
"search_docs_by_text",
"search_docs_hybrid",
]
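# A minimal usage sketch (illustrative, not part of this commit): each
# re-exported entry point is a plain importable callable, e.g.
#
#     from agents_api.queries.docs import search_docs_hybrid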
14 changes: 12 additions & 2 deletions agents-api/agents_api/queries/docs/search_docs_by_embedding.py
@@ -3,10 +3,12 @@

from beartype import beartype
from fastapi import HTTPException
import asyncpg

from ...autogen.openapi_model import DocReference
from ..utils import pg_query, wrap_in_class
from ..utils import pg_query, rewrap_exceptions, wrap_in_class, partialclass

# Raw query for vector search
search_docs_by_embedding_query = """
SELECT * FROM search_by_vector(
$1, -- developer_id
@@ -19,7 +21,15 @@
)
"""


@rewrap_exceptions(
{
asyncpg.UniqueViolationError: partialclass(
HTTPException,
status_code=404,
detail="The specified developer does not exist.",
)
}
)
@wrap_in_class(
DocReference,
transform=lambda d: {
1 change: 1 addition & 0 deletions agents-api/agents_api/queries/docs/search_docs_by_text.py
@@ -8,6 +8,7 @@
from ...autogen.openapi_model import DocReference
from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class

# Raw query for text search
search_docs_text_query = """
SELECT * FROM search_by_text(
$1, -- developer_id
239 changes: 97 additions & 142 deletions agents-api/agents_api/queries/docs/search_docs_hybrid.py
@@ -1,158 +1,113 @@
from typing import List, Literal
from typing import List, Any, Literal
from uuid import UUID

from beartype import beartype

from ...autogen.openapi_model import Doc
from .search_docs_by_embedding import search_docs_by_embedding
from .search_docs_by_text import search_docs_by_text


def dbsf_normalize(scores: List[float]) -> List[float]:
"""
Example distribution-based normalization: clamp each score
from (mean - 3*stddev) to (mean + 3*stddev) and scale to 0..1
"""
import statistics

if len(scores) < 2:
return scores
m = statistics.mean(scores)
sd = statistics.pstdev(scores) # population std
if sd == 0:
return scores
upper = m + 3 * sd
lower = m - 3 * sd

def clamp_scale(v):
c = min(upper, max(lower, v))
return (c - lower) / (upper - lower)

return [clamp_scale(s) for s in scores]
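# Worked example (illustrative): for scores [0.2, 0.4, 0.6, 0.8],
# mean = 0.5 and population stddev ~= 0.2236, so the clamp window is
# roughly [-0.171, 1.171] and the output is approximately
# [0.276, 0.425, 0.574, 0.724] -- nothing is clamped here, but an
# outlier beyond mean +/- 3*stddev would be.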


@beartype
def fuse_results(
text_docs: List[Doc], embedding_docs: List[Doc], alpha: float
) -> List[Doc]:
"""
Merges text search results (descending by text rank) with
embedding results (descending by closeness or inverse distance).
alpha: the weight given to the embedding score (1 - alpha goes to the text score).
"""
# Each embedding result carries a "distance" (lower is better); text results
# carry a rank or a negative distance. Unify them into doc_id -> score maps,
# flipping signs where needed so that bigger always means better.
text_scores = {}
embed_scores = {}
for doc in text_docs:
# doc.distance holds a distance-like value (lower = better);
# negate it so that bigger = better
text_scores[doc.id] = float(-doc.distance if doc.distance else 0)

for doc in embedding_docs:
# Lower distance => better, so we do embed_score = -distance
embed_scores[doc.id] = float(-doc.distance if doc.distance else 0)

# Normalize them
text_vals = list(text_scores.values())
embed_vals = list(embed_scores.values())
text_vals_norm = dbsf_normalize(text_vals)
embed_vals_norm = dbsf_normalize(embed_vals)

# Map them back
t_keys = list(text_scores.keys())
for i, key in enumerate(t_keys):
text_scores[key] = text_vals_norm[i]
e_keys = list(embed_scores.keys())
for i, key in enumerate(e_keys):
embed_scores[key] = embed_vals_norm[i]

# Gather all doc IDs
all_ids = set(text_scores.keys()) | set(embed_scores.keys())

# Weighted sum => combined
out = []
for doc_id in all_ids:
# a doc may appear in only one result set; missing scores default to 0
t_score = text_scores.get(doc_id, 0)
e_score = embed_scores.get(doc_id, 0)
combined = alpha * e_score + (1 - alpha) * t_score
# We'll store final "distance" as -(combined) so bigger combined => smaller distance
out.append((doc_id, combined))

# Sort descending by combined
out.sort(key=lambda x: x[1], reverse=True)
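# Worked example (illustrative): with alpha = 0.5, a doc with t_score = 0.8
# and e_score = 0.4 gets combined = 0.5 * 0.4 + 0.5 * 0.8 = 0.6.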

# Convert back to Doc objects, preferring the text_docs copy when a doc
# appears in both result sets and falling back to embedding_docs otherwise.

# Create a quick ID->doc map
text_map = {d.id: d for d in text_docs}
embed_map = {d.id: d for d in embedding_docs}

final_docs = []
for doc_id, score in out:
doc = text_map.get(doc_id) or embed_map.get(doc_id)
doc = doc.model_copy() # Pydantic v2; on Pydantic v1 use doc.copy()
doc.distance = float(-score) # so a higher combined => smaller distance
final_docs.append(doc)
return final_docs


from ...autogen.openapi_model import DocReference
import asyncpg
from fastapi import HTTPException

from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class

# Raw query for hybrid search
search_docs_hybrid_query = """
SELECT * FROM search_hybrid(
$1, -- developer_id
$2, -- text_query
$3::vector(1024), -- embedding
$4::text[], -- owner_types
$UUID_LIST::uuid[], -- owner_ids
$5, -- k
$6, -- alpha
$7, -- confidence
$8, -- metadata_filter
$9 -- search_language
)
"""


@rewrap_exceptions(
{
asyncpg.UniqueViolationError: partialclass(
HTTPException,
status_code=404,
detail="The specified developer does not exist.",
)
}
)
@wrap_in_class(
DocReference,
transform=lambda d: {
"owner": {
"id": d["owner_id"],
"role": d["owner_type"],
},
"metadata": d.get("metadata", {}),
**d,
},
)
@pg_query
@beartype
async def search_docs_hybrid(
developer_id: UUID,
owners: list[tuple[Literal["user", "agent"], UUID]],
text_query: str = "",
embedding: List[float] | None = None,
k: int = 10,
alpha: float = 0.5,
owner_type: Literal["user", "agent", "org"] | None = None,
owner_id: UUID | None = None,
) -> List[Doc]:
metadata_filter: dict[str, Any] = {},
search_language: str = "english",
confidence: float = 0.5,
) -> tuple[str, list]:
"""
Hybrid text-and-embedding doc search. We get top-K from each approach,
then fuse them client-side. Adjust concurrency or approach as you like.
"""
# We'll dispatch two queries in parallel
# (One full-text, one embedding-based) each limited to K
tasks = []
if text_query.strip():
tasks.append(
search_docs_by_text(
developer_id=developer_id,
query=text_query,
k=k,
owner_type=owner_type,
owner_id=owner_id,
)
)
else:
tasks.append([]) # no text results if query is empty

if embedding and any(embedding):
tasks.append(
search_docs_by_embedding(
developer_id=developer_id,
query_embedding=embedding,
k=k,
owner_type=owner_type,
owner_id=owner_id,
)
)
else:
tasks.append([])
# Run concurrently (or sequentially, if you prefer)
# If you have a 'run_concurrently' from your old code, you can do:
# text_results, embed_results = await run_concurrently([task1, task2])
# Otherwise just do them in parallel with e.g. asyncio.gather:
from asyncio import gather

text_results, embed_results = await gather(*tasks)
Parameters:
developer_id (UUID): The unique identifier for the developer.
owners (list[tuple[Literal["user", "agent"], UUID]]): The owners to search docs for.
text_query (str): The text query to search for.
embedding (List[float]): The embedding to search for.
k (int): The number of results to return.
alpha (float): The weight given to embedding results relative to text results.
metadata_filter (dict[str, Any]): Metadata to filter the documents by.
search_language (str): The language to use for full-text search.
confidence (float): The confidence threshold for embedding matches.
Returns:
tuple[str, list]: The SQL query and its parameters.
"""

# fuse them
fused = fuse_results(text_results, embed_results, alpha)
# Then pick top K overall
return fused[:k]
if k < 1:
raise HTTPException(status_code=400, detail="k must be >= 1")

if not text_query and not embedding:
raise HTTPException(status_code=400, detail="Empty query provided")

if not embedding:
raise HTTPException(status_code=400, detail="Empty embedding provided")

# Convert the embedding to a PostgreSQL vector literal string
embedding_str = f"[{', '.join(map(str, embedding))}]"
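# e.g. an embedding [0.1, 0.2, ...] becomes the literal "[0.1, 0.2, ...]",
# which PostgreSQL casts via $3::vector(1024)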

# Extract owner types and IDs
owner_types: list[str] = [owner[0] for owner in owners]
owner_ids: list[str] = [str(owner[1]) for owner in owners]

# NOTE: Manually inline the UUID list because asyncpg doesn't encode it correctly
owner_ids_joined = "', '".join(owner_ids)
owner_ids_pg_str = f"ARRAY['{owner_ids_joined}']"
query = search_docs_hybrid_query.replace("$UUID_LIST", owner_ids_pg_str)

return (
query,
[
developer_id,
text_query,
embedding_str,
owner_types,
k,
alpha,
confidence,
metadata_filter,
search_language,
],
)
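# Usage sketch (illustrative; argument values are assumptions, and the
# pg_query/wrap_in_class decorators are assumed to execute the returned
# query and map rows to DocReference):
#
#     refs = await search_docs_hybrid(
#         developer_id=developer_id,
#         owners=[("agent", agent_id)],
#         text_query="vector databases",
#         embedding=query_embedding,  # 1024-dim, to match $3::vector(1024)
#         k=5,
#         alpha=0.7,  # weight embedding matches over text matches
#     )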
10 changes: 10 additions & 0 deletions agents-api/agents_api/queries/tools/__init__.py
@@ -18,3 +18,13 @@
from .list_tools import list_tools
from .patch_tool import patch_tool
from .update_tool import update_tool

__all__ = [
"create_tools",
"delete_tool",
"get_tool",
"get_tool_args_from_metadata",
"list_tools",
"patch_tool",
"update_tool",
]