From 2a149288451532c56e4a323d5ff8cd5ff3b29c20 Mon Sep 17 00:00:00 2001 From: Manolo Santos Date: Thu, 12 Dec 2024 23:01:23 +0100 Subject: [PATCH] fix: Fix text and refactor hybrid search --- src/raglite/_chainlit.py | 8 -------- src/raglite/_config.py | 6 ------ src/raglite/_search.py | 21 ++++++++++++++++++++- 3 files changed, 20 insertions(+), 15 deletions(-) diff --git a/src/raglite/_chainlit.py b/src/raglite/_chainlit.py index 6846473..78bcabd 100644 --- a/src/raglite/_chainlit.py +++ b/src/raglite/_chainlit.py @@ -10,19 +10,11 @@ RAGLiteConfig, async_rag, create_rag_instruction, - hybrid_search, insert_document, - rerank_chunks, - retrieve_chunk_spans, - retrieve_chunks, ) from raglite._markdown import document_to_markdown async_insert_document = cl.make_async(insert_document) -async_hybrid_search = cl.make_async(hybrid_search) -async_retrieve_chunks = cl.make_async(retrieve_chunks) -async_retrieve_chunk_spans = cl.make_async(retrieve_chunk_spans) -async_rerank_chunks = cl.make_async(rerank_chunks) @cl.on_chat_start diff --git a/src/raglite/_config.py b/src/raglite/_config.py index 411f172..8783601 100644 --- a/src/raglite/_config.py +++ b/src/raglite/_config.py @@ -14,9 +14,7 @@ from raglite._rag import retrieve_rag_context from raglite._search import ( hybrid_search, - keyword_search, rerank_chunks, - vector_search, ) if TYPE_CHECKING: @@ -34,10 +32,6 @@ max_chunk_spans=5, search=partial( hybrid_search, - subsearches=[ - partial(keyword_search, max_chunks=20), - partial(vector_search, max_chunks=20), - ], max_chunks=20, ), rerank=rerank_chunks, diff --git a/src/raglite/_search.py b/src/raglite/_search.py index 89edb6c..2208d76 100644 --- a/src/raglite/_search.py +++ b/src/raglite/_search.py @@ -4,6 +4,7 @@ import string from collections import defaultdict from collections.abc import Sequence +from functools import partial from itertools import groupby from typing import TYPE_CHECKING, cast @@ -165,7 +166,7 @@ def reciprocal_rank_fusion( return list(rrf_chunk_ids), list(rrf_score) -def hybrid_search( +def multi_search( query: str, *, subsearches: list[ChunkSearchMethod], @@ -179,6 +180,24 @@ def hybrid_search( return chunk_ids[:max_chunks], hybrid_score[:max_chunks] +def hybrid_search( + query: str, + *, + max_chunks: int = 10, + oversample: int = 4, + config: "RAGLiteConfig", +) -> tuple[list[ChunkId], list[float]]: + """Search chunks by combining vector and keyword search.""" + subsearches: list[ChunkSearchMethod] = [ + partial(keyword_search, max_chunks=oversample * max_chunks), + partial(vector_search, max_chunks=oversample * max_chunks), + ] + chunk_ids, hybrid_score = multi_search( + query, subsearches=subsearches, max_chunks=max_chunks, config=config + ) + return chunk_ids[:max_chunks], hybrid_score[:max_chunks] + + def retrieve_chunks(chunk_ids: list[ChunkId], *, config: "RAGLiteConfig") -> list[Chunk]: """Retrieve chunks by their ids.""" engine = create_database_engine(config)