From 4e3291409fc9b683263475fb0425efffcc567ebc Mon Sep 17 00:00:00 2001 From: Manolo Santos Date: Tue, 26 Nov 2024 19:47:03 +0100 Subject: [PATCH] feat: Simplify the api and minor improvements. --- src/raglite/_database.py | 36 +++++++++------------ src/raglite/_rag.py | 68 +++++++++++++++++----------------------- 2 files changed, 44 insertions(+), 60 deletions(-) diff --git a/src/raglite/_database.py b/src/raglite/_database.py index 9ed394d..e1a5f6c 100644 --- a/src/raglite/_database.py +++ b/src/raglite/_database.py @@ -355,33 +355,26 @@ class ContextSegment: chunks: list[Chunk] chunk_scores: list[float] - def __post_init__(self) -> None: - """Validate the segment data after initialization.""" - if not isinstance(self.document_id, str) or not self.document_id.strip(): - msg = "document_id must be a non-empty string" - raise ValueError(msg) - if not self.chunks: - msg = "chunks cannot be empty" - raise ValueError(msg) - if not all(isinstance(chunk, Chunk) for chunk in self.chunks): - msg = "all elements in chunks must be Chunk instances" - raise ValueError(msg) - - def to_xml(self, indent: int = 4) -> str: - """Convert the segment to an XML string representation. + def __str__(self) -> str: + """Return a string representation of the segment.""" + return self.as_xml - Args: - indent (int): Number of spaces to use for indentation. + @property + def as_xml(self) -> str: + """Returns the segment as an XML string representation. Returns ------- str: XML representation of the segment. """ - chunks_content = "\n".join(str(chunk) for chunk in self.chunks) - - # Create the final XML chunk_ids = ",".join(self.chunk_ids) - xml = f"""\n{escape(str(chunks_content))}\n""" + xml = "\n".join( + [ + f'', + escape(self.as_str), + "", + ] + ) return xml @@ -394,7 +387,8 @@ def chunk_ids(self) -> list[str]: """Return a list of chunk IDs.""" return [chunk.id for chunk in self.chunks] - def __str__(self) -> str: + @property + def as_str(self) -> str: """Return a string representation reconstructing the document with headings. Treats headings as a stack, showing headers only when they differ from diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py index 89a7cc0..90e7262 100644 --- a/src/raglite/_rag.py +++ b/src/raglite/_rag.py @@ -1,7 +1,7 @@ """Retrieval-augmented generation.""" from collections.abc import AsyncIterator, Iterator -from typing import Literal +from typing import cast from litellm import acompletion, completion @@ -88,7 +88,7 @@ def rag( # noqa: PLR0913 *, max_contexts: int = 5, context_neighbors: tuple[int, ...] | None = (-1, 1), - search: SearchMethod | list[str] | list[Chunk] = hybrid_search, + search: SearchMethod | list[str] | list[Chunk] | list[ContextSegment] = hybrid_search, messages: list[dict[str, str]] | None = None, system_prompt: str = RAG_SYSTEM_PROMPT, config: RAGLiteConfig | None = None, @@ -96,13 +96,17 @@ def rag( # noqa: PLR0913 """Retrieval-augmented generation.""" # Get the contexts for RAG as contiguous segments of chunks. config = config or RAGLiteConfig() - segments = context_segments( - prompt, - max_contexts=max_contexts, - context_neighbors=context_neighbors, - search=search, - config=config, - ) + segments: list[ContextSegment] + if isinstance(search, list) and any(isinstance(segment, ContextSegment) for segment in search): + segments = cast(list[ContextSegment], search) + else: + segments = context_segments( + prompt, + max_contexts=max_contexts, + context_neighbors=context_neighbors, + search=search, # type: ignore[arg-type] + config=config, + ) # Stream the LLM response. stream = completion( model=config.llm, @@ -121,7 +125,7 @@ async def async_rag( # noqa: PLR0913 *, max_contexts: int = 5, context_neighbors: tuple[int, ...] | None = (-1, 1), - search: SearchMethod | list[str] | list[Chunk] = hybrid_search, + search: SearchMethod | list[str] | list[Chunk] | list[ContextSegment] = hybrid_search, messages: list[dict[str, str]] | None = None, system_prompt: str = RAG_SYSTEM_PROMPT, config: RAGLiteConfig | None = None, @@ -129,13 +133,17 @@ async def async_rag( # noqa: PLR0913 """Retrieval-augmented generation.""" # Get the contexts for RAG as contiguous segments of chunks. config = config or RAGLiteConfig() - segments = context_segments( - prompt, - max_contexts=max_contexts, - context_neighbors=context_neighbors, - search=search, - config=config, - ) + segments: list[ContextSegment] + if isinstance(search, list) and any(isinstance(segment, ContextSegment) for segment in search): + segments = cast(list[ContextSegment], search) + else: + segments = context_segments( + prompt, + max_contexts=max_contexts, + context_neighbors=context_neighbors, + search=search, # type: ignore[arg-type] + config=config, + ) messages = _compose_messages( prompt=prompt, system_prompt=system_prompt, messages=messages, segments=segments ) @@ -151,11 +159,8 @@ def _compose_messages( system_prompt: str, messages: list[dict[str, str]] | None, segments: list[ContextSegment] | None, - context_placement: Literal[ - "system_prompt", "user_prompt", "separate_system_prompt" - ] = "user_prompt", ) -> list[dict[str, str]]: - """Compose the messages for the LLM, placing the context in the desired position.""" + """Compose the messages for the LLM, placing the context in the user position.""" # Using the format recommended by Anthropic for documents in RAG # (https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/long-context-tips#essential-tips-for-long-context-prompts if not segments: @@ -164,26 +169,11 @@ def _compose_messages( *(messages or []), {"role": "user", "content": prompt}, ] - context_content = ( - "\n\n\n" + "\n\n".join(seg.to_xml() for seg in segments) + "\n" - ) - if context_placement == "system_prompt": - return [ - {"role": "system", "content": system_prompt + "\n\n" + context_content}, - *(messages or []), - {"role": "user", "content": prompt}, - ] - if context_placement == "user_prompt": - return [ - {"role": "system", "content": system_prompt}, - *(messages or []), - {"role": "user", "content": prompt + "\n\n" + context_content}, - ] - # Separate system prompt from context + context_content = "\n" + "\n".join(str(seg) for seg in segments) + "\n" + return [ {"role": "system", "content": system_prompt}, *(messages or []), - {"role": "system", "content": context_content}, - {"role": "user", "content": prompt}, + {"role": "user", "content": prompt + "\n\n" + context_content}, ]