fix: Refactoring

undo76 committed Dec 12, 2024
1 parent 93e0495 commit 6002da1
Showing 15 changed files with 187 additions and 180 deletions.
src/raglite/_chainlit.py (6 additions, 16 deletions)
@@ -64,8 +64,7 @@ async def update_config(settings: cl.ChatSettings) -> None:
     if str(config.db_url).startswith("sqlite") or config.embedder.startswith("llama-cpp-python"):
         # async with cl.Step(name="initialize", type="retrieval"):
         query = "Hello world"
-        chunk_ids, _ = await async_hybrid_search(query=query, config=config)
-        _ = await async_rerank_chunks(query=query, chunk_ids=chunk_ids, config=config)
+        config.retrieval(query=query, config=config)


 @cl.on_message
@@ -94,21 +93,12 @@ async def handle_message(user_message: cl.Message) -> None:
         )
         + f"\n\n{user_message.content}"
     )
-    # Search for relevant contexts for RAG.
-    async with cl.Step(name="search", type="retrieval") as step:
+    # Retrieve the context for RAG.
+    async with cl.Step(name="retrieval", type="retrieval") as step:
         step.input = user_message.content
-        chunk_ids, _ = await async_hybrid_search(query=user_prompt, config=config)
-        chunks = await async_retrieve_chunks(chunk_ids=chunk_ids, config=config)
-        step.output = chunks
-        step.elements = [  # Show the top chunks inline.
-            cl.Text(content=str(chunk), display="inline") for chunk in chunks[:5]
-        ]
-        await step.update()  # TODO: Workaround for https://github.com/Chainlit/chainlit/issues/602.
-    # Rerank the chunks and group them into chunk spans.
-    async with cl.Step(name="rerank", type="rerank") as step:
-        step.input = chunks
-        chunks = await async_rerank_chunks(query=user_prompt, chunk_ids=chunks, config=config)
-        chunk_spans = await async_retrieve_chunk_spans(chunks[:5], config=config)
+        retrieval = cl.make_async(config.retrieval)
+        chunk_spans = await retrieval(query=user_prompt, config=config)
         step.output = chunk_spans
         step.elements = [  # Show the top chunk spans inline.
             cl.Text(content=str(chunk_span), display="inline") for chunk_span in chunk_spans
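With this change the Chainlit handler no longer orchestrates hybrid search, chunk retrieval, and reranking itself: it wraps the synchronous `config.retrieval` callable with `cl.make_async` and awaits a single call that returns chunk spans. A rough sketch of the same call outside Chainlit, assuming `RAGLiteConfig` is exported from the package root and that `retrieval` returns a list of `ChunkSpan` objects, as the `_config.py` diff below suggests:

from raglite import RAGLiteConfig

config = RAGLiteConfig(db_url="sqlite:///raglite.db")

# One call now covers hybrid search, reranking, and chunk-span expansion.
chunk_spans = config.retrieval(query="What does late chunking do?", config=config)
for chunk_span in chunk_spans:
    print(str(chunk_span)[:80])  # Preview each retrieved chunk span.

Note that `config.retrieval` holds a plain `functools.partial`, so attribute access does not bind `self`; the config is passed explicitly via `config=config`, exactly as the warm-up call in `update_config` above does.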
src/raglite/_config.py (23 additions, 10 deletions)
@@ -11,9 +11,16 @@
 from sqlalchemy.engine import URL

 from raglite._prompts import RAG_INSTRUCTION_TEMPLATE
+from raglite._rag import retrieve_rag_context
+from raglite._search import (
+    hybrid_search,
+    keyword_search,
+    rerank_chunks,
+    vector_search,
+)

 if TYPE_CHECKING:
-    from raglite._typing import SearchMethod
+    from raglite._typing import ChunkSpanSearchMethod

 # Suppress rerankers output on import until [1] is fixed.
 # [1] https://github.com/AnswerDotAI/rerankers/issues/36
@@ -22,11 +29,20 @@
 from rerankers.models.ranker import BaseRanker


-def _default_search_method() -> "SearchMethod":
-    """Get the default search method."""
-    from raglite._search import hybrid_search
-
-    return partial(hybrid_search, oversample=4)
+default_retrieval: "ChunkSpanSearchMethod" = partial(
+    retrieve_rag_context,
+    max_chunk_spans=5,
+    search=partial(
+        hybrid_search,
+        subsearches=[
+            partial(keyword_search, max_chunks=20),
+            partial(vector_search, max_chunks=20),
+        ],
+        max_chunks=20,
+    ),
+    rerank=rerank_chunks,
+    chunk_neighbors=(-1, 1),
+)


 @dataclass(frozen=True)
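The removed `_default_search_method` factory returned a bare hybrid search; the new module-level `default_retrieval` declares the whole pipeline instead: a hybrid search over keyword and vector subsearches, followed by reranking and neighbor expansion, with every stage a `functools.partial`. That makes individual stages swappable. A hedged sketch of a custom pipeline built only from names that appear in this diff, assuming `retrieve_rag_context` accepts any search callable with a `hybrid_search`-compatible signature:

from functools import partial

from raglite._rag import retrieve_rag_context
from raglite._search import rerank_chunks, vector_search

# A vector-only variant of the default pipeline: skip keyword search,
# oversample more candidates, and keep more chunk spans.
custom_retrieval = partial(
    retrieve_rag_context,
    max_chunk_spans=10,
    search=partial(vector_search, max_chunks=40),
    rerank=rerank_chunks,
    chunk_neighbors=(-1, 1),
)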
@@ -67,12 +83,9 @@ class RAGLiteConfig:
         ),
         compare=False,  # Exclude the reranker from comparison to avoid lru_cache misses.
     )
-    search_method: "SearchMethod" = field(default_factory=_default_search_method, compare=False)
+    retrieval: "ChunkSpanSearchMethod" = default_retrieval
     system_prompt: str | None = None
     rag_instruction_template: str = RAG_INSTRUCTION_TEMPLATE
-    num_chunks: int = 5
-    chunk_neighbors: tuple[int, ...] | None = (-1, 1)  # Neighbors to include in the context.
-    reranker_oversample: int = 4  # How many extra chunks to retrieve for reranking (multiplied).

     def __post_init__(self) -> None:
         # Late chunking with llama-cpp-python does not apply sentence windowing.
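Four knobs the old config exposed (`search_method`, `num_chunks`, `chunk_neighbors`, and `reranker_oversample`) collapse into the single `retrieval` field, so tuning now means passing a different callable at construction time. A sketch mirroring `default_retrieval` but keeping 10 spans, roughly what `num_chunks` plus `reranker_oversample` used to control:

from functools import partial

from raglite import RAGLiteConfig
from raglite._rag import retrieve_rag_context
from raglite._search import hybrid_search, keyword_search, rerank_chunks, vector_search

config = RAGLiteConfig(
    db_url="postgresql://localhost/raglite",
    retrieval=partial(
        retrieve_rag_context,
        max_chunk_spans=10,  # Roughly the role of the removed num_chunks field.
        search=partial(
            hybrid_search,
            subsearches=[
                partial(keyword_search, max_chunks=40),
                partial(vector_search, max_chunks=40),
            ],
            max_chunks=40,
        ),
        rerank=rerank_chunks,
        chunk_neighbors=(-1, 1),
    ),
)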
src/raglite/_database.py (14 additions, 11 deletions)
@@ -6,7 +6,7 @@
 from functools import lru_cache
 from hashlib import sha256
 from pathlib import Path
-from typing import Any
+from typing import TYPE_CHECKING, Any
 from xml.sax.saxutils import escape

 import numpy as np
@@ -16,7 +16,6 @@
 from sqlalchemy.engine import Engine, make_url
 from sqlmodel import JSON, Column, Field, Relationship, Session, SQLModel, create_engine, text

-from raglite._config import RAGLiteConfig
 from raglite._litellm import get_embedding_dim
 from raglite._typing import (
     ChunkId,
@@ -29,6 +28,9 @@
     PickledObject,
 )

+if TYPE_CHECKING:
+    from raglite._config import RAGLiteConfig
+

 def hash_bytes(data: bytes, max_len: int = 16) -> str:
     """Hash bytes to a hexadecimal string."""
@@ -236,7 +238,7 @@ class IndexMetadata(SQLModel, table=True):

     @staticmethod
     @lru_cache(maxsize=4)
-    def _get(id_: str, *, config: RAGLiteConfig | None = None) -> dict[str, Any] | None:
+    def _get(id_: str, *, config: "RAGLiteConfig") -> dict[str, Any] | None:
         engine = create_database_engine(config)
         with Session(engine) as session:
             index_metadata_record = session.get(IndexMetadata, id_)
@@ -245,7 +247,7 @@ def _get(id_: str, *, config: RAGLiteConfig | None = None) -> dict[str, Any] | None:
         return index_metadata_record.metadata_

     @staticmethod
-    def get(id_: str = "default", *, config: RAGLiteConfig | None = None) -> dict[str, Any]:
+    def get(id_: str = "default", *, config: "RAGLiteConfig") -> dict[str, Any]:
         metadata = IndexMetadata._get(id_, config=config) or {}
         return metadata

@@ -271,18 +273,20 @@ class Eval(SQLModel, table=True):
     document: Document = Relationship(back_populates="evals")

     @staticmethod
-    def from_chunks(
-        question: str, contexts: list[Chunk], ground_truth: str, **kwargs: Any
+    def from_contexts(
+        question: str, contexts: list[ChunkSpan], ground_truth: str, **kwargs: Any
     ) -> "Eval":
         """Create a chunk from Markdown."""
-        document_id = contexts[0].document_id
-        chunk_ids = [context.id for context in contexts]
+        document_id = contexts[0].document.id
+        chunk_ids = [
+            chunk.id for span in contexts for chunk in span.chunks
+        ]  # Should we take out the neighbors?
         return Eval(
             id=hash_bytes(f"{document_id}-{chunk_ids}-{question}".encode()),
             document_id=document_id,
             chunk_ids=chunk_ids,
             question=question,
-            contexts=[str(context) for context in contexts],
+            contexts=contexts,
             ground_truth=ground_truth,
             metadata_=kwargs,
         )
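The renamed `from_contexts` records evals against reranked `ChunkSpan` contexts rather than raw chunks, and the nested comprehension flattens every span's chunks (neighbors included, per the inline question) into one ID list in span order. A tiny worked example of that flattening, with hypothetical IDs standing in for `span.chunks`:

# Hypothetical: two spans, each grouping a matched chunk with its neighbors.
spans = [["c3", "c4", "c5"], ["c8", "c9"]]

chunk_ids = [chunk_id for span in spans for chunk_id in span]
assert chunk_ids == ["c3", "c4", "c5", "c8", "c9"]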
@@ -301,10 +305,9 @@ def _get_pgvector_version(session: Session) -> str | None:


 @lru_cache(maxsize=1)
-def create_database_engine(config: RAGLiteConfig | None = None) -> Engine:
+def create_database_engine(config: "RAGLiteConfig") -> Engine:
     """Create a database engine and initialize it."""
     # Parse the database URL and validate that the database backend is supported.
-    config = config or RAGLiteConfig()
     db_url = make_url(config.db_url)
     db_backend = db_url.get_backend_name()
     # Update database configuration.
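Making `config` required removes the old implicit fallback that silently constructed (and connected to) a default `RAGLiteConfig()` database. Note also that `@lru_cache(maxsize=1)` keys the cache on the config instance itself, which only works because `RAGLiteConfig` is a frozen, hashable dataclass with the reranker excluded from comparison, as the `_config.py` diff notes. A minimal sketch of the resulting behavior, under those assumptions:

from raglite import RAGLiteConfig
from raglite._database import create_database_engine

config = RAGLiteConfig(db_url="sqlite:///raglite.db")

# Two calls with an equal config hit the lru_cache and share one engine.
assert create_database_engine(config) is create_database_engine(config)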
src/raglite/_embed.py (9 additions, 11 deletions)
@@ -1,20 +1,22 @@
"""String embedder."""

from functools import partial
from typing import Literal
from typing import TYPE_CHECKING, Literal

import numpy as np
from litellm import embedding
from llama_cpp import LLAMA_POOLING_TYPE_NONE, Llama
from tqdm.auto import tqdm, trange

from raglite._config import RAGLiteConfig
from raglite._litellm import LlamaCppPythonLLM
from raglite._typing import FloatMatrix, IntVector

if TYPE_CHECKING:
from raglite._config import RAGLiteConfig


def _embed_sentences_with_late_chunking( # noqa: PLR0915
sentences: list[str], *, config: RAGLiteConfig | None = None
sentences: list[str], *, config: "RAGLiteConfig"
) -> FloatMatrix:
"""Embed a document's sentences with late chunking."""

@@ -59,7 +61,6 @@ def _create_segment(

     # Assert that we're using a llama-cpp-python model, since API-based embedding models don't
     # support outputting token-level embeddings.
-    config = config or RAGLiteConfig()
     assert config.embedder.startswith("llama-cpp-python")
     embedder = LlamaCppPythonLLM.llm(
         config.embedder, embedding=True, pooling_type=LLAMA_POOLING_TYPE_NONE
@@ -138,11 +139,11 @@ def _create_segment(


 def _embed_sentences_with_windowing(
-    sentences: list[str], *, config: RAGLiteConfig | None = None
+    sentences: list[str], *, config: "RAGLiteConfig"
 ) -> FloatMatrix:
     """Embed a document's sentences with windowing."""

-    def _embed_string_batch(string_batch: list[str], *, config: RAGLiteConfig) -> FloatMatrix:
+    def _embed_string_batch(string_batch: list[str], *, config: "RAGLiteConfig") -> FloatMatrix:
         # Embed the batch of strings.
         if config.embedder.startswith("llama-cpp-python"):
             # LiteLLM doesn't yet support registering a custom embedder, so we handle it here.
@@ -164,7 +165,6 @@ def _embed_string_batch(string_batch: list[str], *, config: RAGLiteConfig) -> FloatMatrix:
         return embeddings

     # Window the sentences with a lookback of `config.embedder_sentence_window_size - 1` sentences.
-    config = config or RAGLiteConfig()
     sentence_windows = [
         "".join(sentences[max(0, i - (config.embedder_sentence_window_size - 1)) : i + 1])
         for i in range(len(sentences))
@@ -186,16 +186,14 @@ def _embed_string_batch(string_batch: list[str], *, config: RAGLiteConfig) -> FloatMatrix:

 def sentence_embedding_type(
     *,
-    config: RAGLiteConfig | None = None,
+    config: "RAGLiteConfig",
 ) -> Literal["late_chunking", "windowing"]:
     """Return the type of sentence embeddings."""
-    config = config or RAGLiteConfig()
     return "late_chunking" if config.embedder.startswith("llama-cpp-python") else "windowing"


-def embed_sentences(sentences: list[str], *, config: RAGLiteConfig | None = None) -> FloatMatrix:
+def embed_sentences(sentences: list[str], *, config: "RAGLiteConfig") -> FloatMatrix:
     """Embed the sentences of a document as a NumPy matrix with one row per sentence."""
-    config = config or RAGLiteConfig()
     if sentence_embedding_type(config=config) == "late_chunking":
         sentence_embeddings = _embed_sentences_with_late_chunking(sentences, config=config)
     else:
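For reference, the `sentence_windows` comprehension shown above concatenates each sentence with up to `config.embedder_sentence_window_size - 1` preceding sentences before embedding. A worked example with a window size of 3, where `window_size` is a stand-in for the config field:

window_size = 3  # Stand-in for config.embedder_sentence_window_size.
sentences = ["A. ", "B. ", "C. ", "D. "]

sentence_windows = [
    "".join(sentences[max(0, i - (window_size - 1)) : i + 1])
    for i in range(len(sentences))
]
assert sentence_windows == ["A. ", "A. B. ", "A. B. C. ", "B. C. D. "]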
(Diffs for the remaining 11 changed files are not shown.)