feat: make query adapter minimally invasive (#16)
lsorber authored Aug 26, 2024
1 parent 072a968 commit eba82ce
Showing 3 changed files with 40 additions and 17 deletions.
2 changes: 2 additions & 0 deletions src/raglite/_config.py
@@ -72,3 +72,5 @@ class RAGLiteConfig:
# Vector search config.
vector_search_index_id: str = "default"
vector_search_index_metric: str = "cosine" # The query adapter supports "dot" and "cosine".
+ # Query adapter config.
+ enable_query_adapter: bool = True
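A minimal sketch of how the new flag is meant to be used (not part of the commit). It assumes RAGLiteConfig is a dataclass whose fields can be passed as keyword arguments, which the **config.__dict__ override used later in this commit suggests:

```python
# Sketch (illustration only): constructing configs with the adapter on and off.
from raglite._config import RAGLiteConfig

config = RAGLiteConfig(vector_search_index_metric="cosine")
assert config.enable_query_adapter  # the adapter is enabled by default

# Copy-and-override idiom, as used by update_query_adapter in this commit:
config_no_adapter = RAGLiteConfig(**{**config.__dict__, "enable_query_adapter": False})
assert not config_no_adapter.enable_query_adapter
```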
52 changes: 37 additions & 15 deletions src/raglite/_query_adapter.py
@@ -10,7 +10,7 @@
from raglite._search import vector_search


- def update_query_adapter( # noqa: C901
+ def update_query_adapter( # noqa: C901, PLR0915
*,
max_triplets: int = 4096,
max_triplets_per_eval: int = 64,
@@ -26,7 +26,7 @@ def update_query_adapter( # noqa: C901
Given a set of triplets (qᵢ, pᵢ, nᵢ), we want to find the query adapter A that increases the
score pᵢ'qᵢ of the positive chunk pᵢ and decreases the score nᵢ'qᵢ of the negative chunk nᵢ.
- If the nearest neighbour search uses the dot product as its ranking function, we can find the
+ If the nearest neighbour search uses the dot product as its relevance score, we can find the
optimal query adapter by solving the following relaxed Procrustes optimisation problem with a
bound on the Frobenius norm of A:
@@ -40,7 +40,7 @@ def update_query_adapter( # noqa: C901
s.t. ||A||_F == 1
= M' / ||M||_F
- If the nearest neighbour search uses the cosine similarity as its ranking function, we can find
+ If the nearest neighbour search uses the cosine similarity as its relevance score, we can find
the optimal query adapter by solving the following orthogonal Procrustes optimisation problem
with an orthogonality constraint on A:
@@ -53,8 +53,17 @@ def update_query_adapter( # noqa: C901
trace[ Σ V' A U ]
s.t. A'A == 𝕀
= V U'
+ Additionally, we want to limit the effect of A* so that it adjusts q just enough to invert
+ incorrectly ordered (q, p, n) triplets, but not so much as to affect the correctly ordered ones.
+ To achieve this, we'll rewrite M as α(M / s) + (1 - α)𝕀, where s scales M to the same norm as 𝕀,
+ and choose the smallest α that ranks (q, p, n) correctly. If α = 0, the relevance score gap
+ between an incorrect (p, n) pair would be B := (p - n)' q < 0. If α = 1, the relevance score gap
+ would be A := (p - n)' (p - n) / ||p - n|| > 0. For a target relevance score gap of say
+ C := 5% * A, the optimal α is then given by αA + (1 - α)B = C => α = (B - C) / (B - A).
"""
config = config or RAGLiteConfig()
+ config_no_query_adapter = RAGLiteConfig(**{**config.__dict__, "enable_query_adapter": False})
engine = create_database_engine(config.db_url)
with Session(engine) as session:
# Get random evals from the database.
@@ -75,7 +84,7 @@ def update_query_adapter( # noqa: C901
question_embedding = embed_strings([eval_.question], config=config)
# Retrieve chunks that would be used to answer the question.
chunk_rowids, _ = vector_search(
- question_embedding, num_results=optimize_top_k, query_adapter=False, config=config
+ question_embedding, num_results=optimize_top_k, config=config_no_query_adapter
)
retrieved_chunks = [
session.exec(select(Chunk).offset(chunk_rowid - 1)).first()
@@ -92,20 +101,27 @@
raise ValueError(error_message)
# Select irrelevant chunks.
if retrieved_chunk.id not in eval_.chunk_ids:
- # Look up all positive chunks that are ranked lower than this negative one.
- p_mve = [
+ # Look up all positive chunks (each represented by the mean of its multi-vector
+ # embedding) that are ranked lower than this negative one (represented by the
+ # embedding in the multi-vector embedding that best matches the query).
+ p_mean = [
np.mean(chunk.multi_vector_embedding, axis=0, keepdims=True)
for chunk in retrieved_chunks[i + 1 :]
if chunk is not None and chunk.id in eval_.chunk_ids
]
- if not p_mve:
+ n_top = retrieved_chunk.multi_vector_embedding[
+ np.argmax(retrieved_chunk.multi_vector_embedding @ question_embedding.T),
+ np.newaxis,
+ :,
+ ]
+ # Filter out any (p, n, q) triplets for which the mean positive embedding ranks
+ # higher than the top negative one.
+ p_mean = [p_e for p_e in p_mean if (n_top - p_e) @ question_embedding.T > 0]
+ if not p_mean:
continue
- p = np.vstack(p_mve)
- n = np.repeat(
- np.mean(retrieved_chunk.multi_vector_embedding, axis=0, keepdims=True),
- p.shape[0],
- axis=0,
- )
+ # Stack the (p, n, q) triplets.
+ p = np.vstack(p_mean)
+ n = np.repeat(n_top, p.shape[0], axis=0)
q = np.repeat(question_embedding, p.shape[0], axis=0)
num_triplets += p.shape[0]
# Append the (query, positive, negative) tuples to the Q, P, N matrices.
@@ -123,9 +139,15 @@
Q /= np.linalg.norm(Q, axis=1, keepdims=True) # noqa: N806
P /= np.linalg.norm(P, axis=1, keepdims=True) # noqa: N806
N /= np.linalg.norm(N, axis=1, keepdims=True) # noqa: N806
- # Compute the optimal query adapter A*.
+ # Compute the optimal weighted query adapter A*.
# TODO: Matmul in float16 is extremely slow compared to single or double precision, why?
- MT = (P - N).T @ Q # noqa: N806
+ gap_before = np.sum((P - N) * Q, axis=1)
+ gap_after = 2 * (1 - np.sum(P * N, axis=1)) / np.linalg.norm(P - N, axis=1)
+ gap_target = 0.05 * gap_after
+ α = (gap_before - gap_target) / (gap_before - gap_after) # noqa: PLC2401
+ MT = (α[:, np.newaxis] * (P - N)).T @ Q # noqa: N806
+ s = np.linalg.norm(MT, ord="fro") / np.sqrt(MT.shape[0])
+ MT = np.mean(α) * (MT / s) + np.mean(1 - α) * np.eye(Q.shape[1]) # noqa: N806
if config.vector_search_index_metric == "dot":
# Use the relaxed Procrustes solution.
A_star = MT / np.linalg.norm(MT, ord="fro") # noqa: N806
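The closed forms derived in the docstring above can be written out directly in NumPy. The sketch below is an illustration of the math only and is not copied from the collapsed part of this hunk: it assumes Q, P and N are the row-normalised query/positive/negative matrices built in this function, and the cosine branch is the standard orthogonal Procrustes solution via an SVD.

```python
import numpy as np

def optimal_query_adapter(Q: np.ndarray, P: np.ndarray, N: np.ndarray, metric: str) -> np.ndarray:
    """Sketch: closed-form query adapter for row-normalised Q, P, N of shape (num_triplets, dim)."""
    # Per-triplet blending weight alpha from the docstring: alpha*A + (1 - alpha)*B = C with C = 5% of A.
    # Assumes each triplet is incorrectly ordered (gap_before < 0), as guaranteed by the filtering above.
    gap_before = np.sum((P - N) * Q, axis=1)                                      # B = (p - n)' q
    gap_after = 2 * (1 - np.sum(P * N, axis=1)) / np.linalg.norm(P - N, axis=1)   # A = ||p - n||
    alpha = (gap_before - 0.05 * gap_after) / (gap_before - gap_after)
    # M' with M := Q' (alpha-weighted (P - N)), blended towards the identity to stay minimally invasive.
    MT = (alpha[:, np.newaxis] * (P - N)).T @ Q
    s = np.linalg.norm(MT, ord="fro") / np.sqrt(MT.shape[0])
    MT = np.mean(alpha) * (MT / s) + np.mean(1 - alpha) * np.eye(Q.shape[1])
    if metric == "dot":
        # Relaxed Procrustes with ||A||_F == 1: A* = M' / ||M||_F.
        return MT / np.linalg.norm(MT, ord="fro")
    if metric == "cosine":
        # Orthogonal Procrustes with A'A == I: if M = U S V', then A* = V U'.
        # Since MT = M' = V S U', the SVD of MT yields those factors directly.
        U, _, VT = np.linalg.svd(MT)
        return U @ VT
    raise ValueError(f"Unsupported metric: {metric!r}")
```

With metric="dot" this reduces to the `A_star = MT / np.linalg.norm(MT, ord="fro")` line visible at the end of the hunk above.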
3 changes: 1 addition & 2 deletions src/raglite/_search.py
@@ -43,7 +43,6 @@ def vector_search(
prompt: str | FloatMatrix,
*,
num_results: int = 3,
- query_adapter: bool = True,
config: RAGLiteConfig | None = None,
) -> tuple[list[int], list[float]]:
"""Search chunks using ANN vector search."""
@@ -57,7 +56,7 @@
else np.reshape(prompt, (1, -1))
)
# Apply the query adapter.
- if query_adapter and Q is not None:
+ if config.enable_query_adapter and Q is not None:
prompt_embedding = (Q @ prompt_embedding[0, :])[np.newaxis, :].astype(config.embedder_dtype)
# Find the neighbouring multi-vector indices.
multi_vector_indices, cosine_distance = index.query(prompt_embedding, k=8 * num_results)
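After this change, callers no longer pass query_adapter= per call; whether the adapter is applied is controlled by the config alone. A usage sketch based on the signature shown above (the question strings and result count are illustrative):

```python
# Usage sketch (not part of the commit): vector_search reads the flag from the config.
from raglite._config import RAGLiteConfig
from raglite._search import vector_search

config = RAGLiteConfig()  # adapter on by default
config_no_adapter = RAGLiteConfig(**{**config.__dict__, "enable_query_adapter": False})

adapted_rowids, _ = vector_search("How do I index a document?", num_results=5, config=config)
baseline_rowids, _ = vector_search("How do I index a document?", num_results=5, config=config_no_adapter)
```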
