Skip to content

Commit

Permalink
fix: Fix text and refactor hybrid search
Browse files Browse the repository at this point in the history
  • Loading branch information
undo76 committed Dec 12, 2024
1 parent 501063b commit 2a14928
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 15 deletions.
8 changes: 0 additions & 8 deletions src/raglite/_chainlit.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,11 @@
RAGLiteConfig,
async_rag,
create_rag_instruction,
hybrid_search,
insert_document,
rerank_chunks,
retrieve_chunk_spans,
retrieve_chunks,
)
from raglite._markdown import document_to_markdown

async_insert_document = cl.make_async(insert_document)
async_hybrid_search = cl.make_async(hybrid_search)
async_retrieve_chunks = cl.make_async(retrieve_chunks)
async_retrieve_chunk_spans = cl.make_async(retrieve_chunk_spans)
async_rerank_chunks = cl.make_async(rerank_chunks)


@cl.on_chat_start
Expand Down
6 changes: 0 additions & 6 deletions src/raglite/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@
from raglite._rag import retrieve_rag_context
from raglite._search import (
hybrid_search,
keyword_search,
rerank_chunks,
vector_search,
)

if TYPE_CHECKING:
Expand All @@ -34,10 +32,6 @@
max_chunk_spans=5,
search=partial(
hybrid_search,
subsearches=[
partial(keyword_search, max_chunks=20),
partial(vector_search, max_chunks=20),
],
max_chunks=20,
),
rerank=rerank_chunks,
Expand Down
21 changes: 20 additions & 1 deletion src/raglite/_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import string
from collections import defaultdict
from collections.abc import Sequence
from functools import partial
from itertools import groupby
from typing import TYPE_CHECKING, cast

Expand Down Expand Up @@ -165,7 +166,7 @@ def reciprocal_rank_fusion(
return list(rrf_chunk_ids), list(rrf_score)


def hybrid_search(
def multi_search(
query: str,
*,
subsearches: list[ChunkSearchMethod],
Expand All @@ -179,6 +180,24 @@ def hybrid_search(
return chunk_ids[:max_chunks], hybrid_score[:max_chunks]


def hybrid_search(
    query: str,
    *,
    max_chunks: int = 10,
    oversample: int = 4,
    config: "RAGLiteConfig",
) -> tuple[list[ChunkId], list[float]]:
    """Search chunks with both keyword and vector search and merge the results.

    Each subsearch is asked for ``oversample * max_chunks`` candidates; the
    candidate lists are then combined by ``multi_search`` and at most
    ``max_chunks`` results are returned.

    Args:
        query: The search query.
        max_chunks: Maximum number of chunk ids (and scores) to return.
        oversample: Multiplier applied to ``max_chunks`` to size the candidate
            pool requested from each individual subsearch.
        config: The RAGLite configuration forwarded to ``multi_search``.

    Returns:
        A tuple ``(chunk_ids, scores)`` of equally truncated lists.
    """
    # Request more candidates per subsearch than we finally return.
    candidates_per_search = oversample * max_chunks
    # Keyword search first, then vector search (same order as before).
    methods: list[ChunkSearchMethod] = [
        partial(search_fn, max_chunks=candidates_per_search)
        for search_fn in (keyword_search, vector_search)
    ]
    merged_ids, merged_scores = multi_search(
        query,
        subsearches=methods,
        max_chunks=max_chunks,
        config=config,
    )
    # Defensive truncation to the requested number of results.
    top_ids = merged_ids[:max_chunks]
    top_scores = merged_scores[:max_chunks]
    return top_ids, top_scores


def retrieve_chunks(chunk_ids: list[ChunkId], *, config: "RAGLiteConfig") -> list[Chunk]:
"""Retrieve chunks by their ids."""
engine = create_database_engine(config)
Expand Down

0 comments on commit 2a14928

Please sign in to comment.