feat(agents-api): added docs hybrid search

commit 830206b (1 parent: 358b60b)
Showing 15 changed files with 303 additions and 287 deletions.
agents-api/agents_api/queries/docs/search_docs_hybrid.py (239 changes: 97 additions & 142 deletions)

@@ -1,158 +1,113 @@
-from typing import List, Literal
+from typing import List, Any, Literal
 from uuid import UUID

 from beartype import beartype

-from ...autogen.openapi_model import Doc
-from .search_docs_by_embedding import search_docs_by_embedding
-from .search_docs_by_text import search_docs_by_text
-
-
-def dbsf_normalize(scores: List[float]) -> List[float]:
-    """
-    Example distribution-based normalization: clamp each score
-    from (mean - 3*stddev) to (mean + 3*stddev) and scale to 0..1
-    """
-    import statistics
-
-    if len(scores) < 2:
-        return scores
-    m = statistics.mean(scores)
-    sd = statistics.pstdev(scores)  # population std
-    if sd == 0:
-        return scores
-    upper = m + 3 * sd
-    lower = m - 3 * sd
-
-    def clamp_scale(v):
-        c = min(upper, max(lower, v))
-        return (c - lower) / (upper - lower)
-
-    return [clamp_scale(s) for s in scores]
-
-
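For reference, a worked example of the helper being deleted here (the input scores are made up):

    # With scores = [0.2, 0.4, 0.9]: mean = 0.5, pstdev ≈ 0.294, so each score is
    # clamped to [0.5 - 3*0.294, 0.5 + 3*0.294] ≈ [-0.383, 1.383], then rescaled to 0..1:
    dbsf_normalize([0.2, 0.4, 0.9])  # ≈ [0.330, 0.443, 0.726]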
-@beartype
-def fuse_results(
-    text_docs: List[Doc], embedding_docs: List[Doc], alpha: float
-) -> List[Doc]:
-    """
-    Merges text search results (descending by text rank) with
-    embedding results (descending by closeness or inverse distance).
-    alpha ~ how much to weigh the embedding score
-    """
-    # Suppose we stored each doc's "distance" from the embedding query, and
-    # for text search we store a rank or negative distance. We'll unify them:
-    # Make up a dictionary of doc_id -> text_score, doc_id -> embed_score
-    # For example, text_score = -distance if you want bigger = better
-    text_scores = {}
-    embed_scores = {}
-    for doc in text_docs:
-        # If you had "rank", you might store doc.distance = rank
-        # For demo, let's assume doc.distance is negative... up to you
-        text_scores[doc.id] = float(-doc.distance if doc.distance else 0)
-
-    for doc in embedding_docs:
-        # Lower distance => better, so we do embed_score = -distance
-        embed_scores[doc.id] = float(-doc.distance if doc.distance else 0)
-
-    # Normalize them
-    text_vals = list(text_scores.values())
-    embed_vals = list(embed_scores.values())
-    text_vals_norm = dbsf_normalize(text_vals)
-    embed_vals_norm = dbsf_normalize(embed_vals)
-
-    # Map them back
-    t_keys = list(text_scores.keys())
-    for i, key in enumerate(t_keys):
-        text_scores[key] = text_vals_norm[i]
-    e_keys = list(embed_scores.keys())
-    for i, key in enumerate(e_keys):
-        embed_scores[key] = embed_vals_norm[i]
-
-    # Gather all doc IDs
-    all_ids = set(text_scores.keys()) | set(embed_scores.keys())
-
-    # Weighted sum => combined
-    out = []
-    for doc_id in all_ids:
-        # text and embed might be missing doc_id => 0
-        t_score = text_scores.get(doc_id, 0)
-        e_score = embed_scores.get(doc_id, 0)
-        combined = alpha * e_score + (1 - alpha) * t_score
-        # We'll store final "distance" as -(combined) so bigger combined => smaller distance
-        out.append((doc_id, combined))
-
-    # Sort descending by combined
-    out.sort(key=lambda x: x[1], reverse=True)
-
-    # Convert to doc objects. We can pick from text_docs or embedding_docs, whichever is found.
-    # If present in both, we could merge fields; for simplicity, pick from text_docs, then fall back to embedding_docs.
-
-    # Create a quick ID->doc map
-    text_map = {d.id: d for d in text_docs}
-    embed_map = {d.id: d for d in embedding_docs}
-
-    final_docs = []
-    for doc_id, score in out:
-        doc = text_map.get(doc_id) or embed_map.get(doc_id)
-        doc = doc.model_copy()  # Pydantic copy
-        doc.distance = float(-score)  # so a higher combined score => smaller distance
-        final_docs.append(doc)
-    return final_docs
-
-
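The deleted fusion step was a plain convex combination: with alpha = 0.5, a doc with normalized embedding score 0.4 and text score 0.8 got combined = 0.5 * 0.4 + 0.5 * 0.8 = 0.6, stored as distance = -0.6 so that higher combined scores sort first. This commit moves that work into the database.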
+from ...autogen.openapi_model import DocReference
+import asyncpg
+from fastapi import HTTPException
+
+from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class
+
+# Raw query for hybrid search
+search_docs_hybrid_query = """
+SELECT * FROM search_hybrid(
+    $1, -- developer_id
+    $2, -- text_query
+    $3::vector(1024), -- embedding
+    $4::text[], -- owner_types
+    $UUID_LIST::uuid[], -- owner_ids
+    $5, -- k
+    $6, -- alpha
+    $7, -- confidence
+    $8, -- metadata_filter
+    $9 -- search_language
+)
+"""
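The ranking itself is delegated to a search_hybrid SQL function, presumably defined elsewhere in this commit. The (query, params) tuple built below is then executed by the pg_query decorator, roughly along these lines (a hypothetical sketch, not the decorator's actual code):

    import asyncpg

    async def _run(pool: asyncpg.Pool, query: str, params: list):
        # asyncpg binds $1..$9 positionally from params; $UUID_LIST is already inlined
        async with pool.acquire() as conn:
            return await conn.fetch(query, *params)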
+@rewrap_exceptions(
+    {
+        asyncpg.UniqueViolationError: partialclass(
+            HTTPException,
+            status_code=404,
+            detail="The specified developer does not exist.",
+        )
+    }
+)
+@wrap_in_class(
+    DocReference,
+    transform=lambda d: {
+        "owner": {
+            "id": d["owner_id"],
+            "role": d["owner_type"],
+        },
+        "metadata": d.get("metadata", {}),
+        **d,
+    },
+)
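To see what the wrap_in_class transform produces, consider its effect on a hypothetical result row (the "title" field is illustrative; only owner_id, owner_type, and metadata appear in the code above):

    row = {
        "owner_id": "11111111-2222-3333-4444-555555555555",
        "owner_type": "user",
        "title": "Note",
        "metadata": {"tag": "x"},
    }
    # transform(row) nests the owner fields, then spreads the row back in, so the
    # row's own keys (including "metadata", when present) win on conflicts:
    # {
    #     "owner": {"id": "11111111-...", "role": "user"},
    #     "metadata": {"tag": "x"},
    #     "owner_id": "11111111-...",
    #     "owner_type": "user",
    #     "title": "Note",
    # }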
+@pg_query
 @beartype
 async def search_docs_hybrid(
     developer_id: UUID,
+    owners: list[tuple[Literal["user", "agent"], UUID]],
     text_query: str = "",
     embedding: List[float] = None,
     k: int = 10,
     alpha: float = 0.5,
-    owner_type: Literal["user", "agent", "org"] | None = None,
-    owner_id: UUID | None = None,
-) -> List[Doc]:
+    metadata_filter: dict[str, Any] = {},
+    search_language: str = "english",
+    confidence: float = 0.5,
+) -> tuple[str, list]:
     """
-    Hybrid text-and-embedding doc search. We get top-K from each approach,
-    then fuse them client-side. Adjust concurrency or approach as you like.
-    """
-    # We'll dispatch two queries in parallel
-    # (one full-text, one embedding-based), each limited to k
-    tasks = []
-    if text_query.strip():
-        tasks.append(
-            search_docs_by_text(
-                developer_id=developer_id,
-                query=text_query,
-                k=k,
-                owner_type=owner_type,
-                owner_id=owner_id,
-            )
-        )
-    else:
-        tasks.append([])  # no text results if query is empty
-
-    if embedding and any(embedding):
-        tasks.append(
-            search_docs_by_embedding(
-                developer_id=developer_id,
-                query_embedding=embedding,
-                k=k,
-                owner_type=owner_type,
-                owner_id=owner_id,
-            )
-        )
-    else:
-        tasks.append([])
-
-    # Run both queries concurrently with asyncio.gather:
-    from asyncio import gather
-
-    text_results, embed_results = await gather(*tasks)
-
-    # Fuse them, then pick the top k overall
-    fused = fuse_results(text_results, embed_results, alpha)
-    return fused[:k]
+    Hybrid text-and-embedding doc search; the fusion is performed by the
+    search_hybrid function in the database.
+
+    Parameters:
+        developer_id (UUID): The unique identifier for the developer.
+        owners (list[tuple[Literal["user", "agent"], UUID]]): The owners to filter docs by.
+        text_query (str): The text query to search for.
+        embedding (List[float]): The embedding to search for.
+        k (int): The number of results to return.
+        alpha (float): The weight for the embedding results.
+        metadata_filter (dict[str, Any]): Metadata to filter the docs by.
+        search_language (str): The language passed to the full-text search.
+        confidence (float): The confidence parameter passed to search_hybrid.
+
+    Returns:
+        tuple[str, list]: The SQL query and parameters for the search.
+    """
+    if k < 1:
+        raise HTTPException(status_code=400, detail="k must be >= 1")
+
+    if not text_query and not embedding:
+        raise HTTPException(status_code=400, detail="Empty query provided")
+
+    if not embedding:
+        raise HTTPException(status_code=400, detail="Empty embedding provided")
+
+    # Convert the embedding to its PostgreSQL vector literal, e.g. "[0.1, 0.2, ...]"
+    embedding_str = f"[{', '.join(map(str, embedding))}]"
+
+    # Split the (owner_type, owner_id) pairs into parallel arrays
+    owner_types: list[str] = [owner[0] for owner in owners]
+    owner_ids: list[str] = [str(owner[1]) for owner in owners]
+
+    # NOTE: Manually inline the UUID list because asyncpg isn't sending it correctly
+    owner_ids_pg_str = "ARRAY['" + "', '".join(owner_ids) + "']"
+    query = search_docs_hybrid_query.replace("$UUID_LIST", owner_ids_pg_str)
+
+    return (
+        query,
+        [
+            developer_id,
+            text_query,
+            embedding_str,
+            owner_types,
+            k,
+            alpha,
+            confidence,
+            metadata_filter,
+            search_language,
+        ],
+    )
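A usage sketch of the underlying function with made-up IDs (in the real code path the pg_query and wrap_in_class decorators execute the query and return DocReference objects rather than this raw tuple):

    # Inside an async context:
    query, params = await search_docs_hybrid(
        developer_id=UUID("00000000-0000-0000-0000-000000000001"),
        owners=[("user", UUID("00000000-0000-0000-0000-000000000002"))],
        text_query="vacation policy",
        embedding=[0.1] * 1024,  # must be 1024-d to satisfy the $3::vector(1024) cast
    )
    # $UUID_LIST has been inlined, so the query now contains:
    #     ARRAY['00000000-0000-0000-0000-000000000002']::uuid[], -- owner_ids
    # and params carries the nine bind values for $1..$9.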