diff --git a/src/raglite/_config.py b/src/raglite/_config.py index bda5be7..e6c74c5 100644 --- a/src/raglite/_config.py +++ b/src/raglite/_config.py @@ -72,3 +72,5 @@ class RAGLiteConfig: # Vector search config. vector_search_index_id: str = "default" vector_search_index_metric: str = "cosine" # The query adapter supports "dot" and "cosine". + # Query adapter config. + enable_query_adapter: bool = True diff --git a/src/raglite/_query_adapter.py b/src/raglite/_query_adapter.py index 51e08f8..bd3fed9 100644 --- a/src/raglite/_query_adapter.py +++ b/src/raglite/_query_adapter.py @@ -10,7 +10,7 @@ from raglite._search import vector_search -def update_query_adapter( # noqa: C901 +def update_query_adapter( # noqa: C901, PLR0915 *, max_triplets: int = 4096, max_triplets_per_eval: int = 64, @@ -26,7 +26,7 @@ def update_query_adapter( # noqa: C901 Given a set of triplets (qᵢ, pᵢ, nᵢ), we want to find the query adapter A that increases the score pᵢ'qᵢ of the positive chunk pᵢ and decreases the score nᵢ'qᵢ of the negative chunk nᵢ. - If the nearest neighbour search uses the dot product as its ranking function, we can find the + If the nearest neighbour search uses the dot product as its relevance score, we can find the optimal query adapter by solving the following relaxed Procrustes optimisation problem with a bound on the Frobenius norm of A: @@ -40,7 +40,7 @@ def update_query_adapter( # noqa: C901 s.t. ||A||_F == 1 = M' / ||M||_F - If the nearest neighbour search uses the cosine similarity as its ranking function, we can find + If the nearest neighbour search uses the cosine similarity as its relevance score, we can find the optimal query adapter by solving the following orthogonal Procrustes optimisation problem with an orthogonality constraint on A: @@ -53,8 +53,17 @@ def update_query_adapter( # noqa: C901 trace[ Σ V' A U ] s.t. A'A == 𝕀 = V U' + + Additionally, we want to limit the effect of A* so that it adjusts q just enough to invert + incorrectly ordered (q, p, n) triplets, but not so much as to affect the correctly ordered ones. + To achieve this, we'll rewrite M as α(M / s) + (1 - α)𝕀, where s scales M to the same norm as 𝕀, + and choose the smallest α that ranks (q, p, n) correctly. If α = 0, the relevance score gap + between an incorrect (p, n) pair would be B := (p - n)' q < 0. If α = 1, the relevance score gap + would be A := (p - n)' (p - n) / ||p - n|| > 0. For a target relevance score gap of say + C := 5% * A, the optimal α is then given by αA + (1 - α)B = C => α = (B - C) / (B - A). """ config = config or RAGLiteConfig() + config_no_query_adapter = RAGLiteConfig(**{**config.__dict__, "enable_query_adapter": False}) engine = create_database_engine(config.db_url) with Session(engine) as session: # Get random evals from the database. @@ -75,7 +84,7 @@ def update_query_adapter( # noqa: C901 question_embedding = embed_strings([eval_.question], config=config) # Retrieve chunks that would be used to answer the question. chunk_rowids, _ = vector_search( - question_embedding, num_results=optimize_top_k, query_adapter=False, config=config + question_embedding, num_results=optimize_top_k, config=config_no_query_adapter ) retrieved_chunks = [ session.exec(select(Chunk).offset(chunk_rowid - 1)).first() @@ -92,20 +101,27 @@ def update_query_adapter( # noqa: C901 raise ValueError(error_message) # Select irrelevant chunks. if retrieved_chunk.id not in eval_.chunk_ids: - # Look up all positive chunks that are ranked lower than this negative one. - p_mve = [ + # Look up all positive chunks (each represented by the mean of its multi-vector + # embedding) that are ranked lower than this negative one (represented by the + # embedding in the multi-vector embedding that best matches the query). + p_mean = [ np.mean(chunk.multi_vector_embedding, axis=0, keepdims=True) for chunk in retrieved_chunks[i + 1 :] if chunk is not None and chunk.id in eval_.chunk_ids ] - if not p_mve: + n_top = retrieved_chunk.multi_vector_embedding[ + np.argmax(retrieved_chunk.multi_vector_embedding @ question_embedding.T), + np.newaxis, + :, + ] + # Filter out any (p, n, q) triplets for which the mean positive embedding ranks + # higher than the top negative one. + p_mean = [p_e for p_e in p_mean if (n_top - p_e) @ question_embedding.T > 0] + if not p_mean: continue - p = np.vstack(p_mve) - n = np.repeat( - np.mean(retrieved_chunk.multi_vector_embedding, axis=0, keepdims=True), - p.shape[0], - axis=0, - ) + # Stack the (p, n, q) triplets. + p = np.vstack(p_mean) + n = np.repeat(n_top, p.shape[0], axis=0) q = np.repeat(question_embedding, p.shape[0], axis=0) num_triplets += p.shape[0] # Append the (query, positive, negative) tuples to the Q, P, N matrices. @@ -123,9 +139,15 @@ def update_query_adapter( # noqa: C901 Q /= np.linalg.norm(Q, axis=1, keepdims=True) # noqa: N806 P /= np.linalg.norm(P, axis=1, keepdims=True) # noqa: N806 N /= np.linalg.norm(N, axis=1, keepdims=True) # noqa: N806 - # Compute the optimal query adapter A*. + # Compute the optimal weighted query adapter A*. # TODO: Matmul in float16 is extremely slow compared to single or double precision, why? - MT = (P - N).T @ Q # noqa: N806 + gap_before = np.sum((P - N) * Q, axis=1) + gap_after = 2 * (1 - np.sum(P * N, axis=1)) / np.linalg.norm(P - N, axis=1) + gap_target = 0.05 * gap_after + α = (gap_before - gap_target) / (gap_before - gap_after) # noqa: PLC2401 + MT = (α[:, np.newaxis] * (P - N)).T @ Q # noqa: N806 + s = np.linalg.norm(MT, ord="fro") / np.sqrt(MT.shape[0]) + MT = np.mean(α) * (MT / s) + np.mean(1 - α) * np.eye(Q.shape[1]) # noqa: N806 if config.vector_search_index_metric == "dot": # Use the relaxed Procrustes solution. A_star = MT / np.linalg.norm(MT, ord="fro") # noqa: N806 diff --git a/src/raglite/_search.py b/src/raglite/_search.py index 4248929..73c2dd8 100644 --- a/src/raglite/_search.py +++ b/src/raglite/_search.py @@ -43,7 +43,6 @@ def vector_search( prompt: str | FloatMatrix, *, num_results: int = 3, - query_adapter: bool = True, config: RAGLiteConfig | None = None, ) -> tuple[list[int], list[float]]: """Search chunks using ANN vector search.""" @@ -57,7 +56,7 @@ def vector_search( else np.reshape(prompt, (1, -1)) ) # Apply the query adapter. - if query_adapter and Q is not None: + if config.enable_query_adapter and Q is not None: prompt_embedding = (Q @ prompt_embedding[0, :])[np.newaxis, :].astype(config.embedder_dtype) # Find the neighbouring multi-vector indices. multi_vector_indices, cosine_distance = index.query(prompt_embedding, k=8 * num_results)