diff --git a/src/raglite/_database.py b/src/raglite/_database.py
index 9ed394d..e1a5f6c 100644
--- a/src/raglite/_database.py
+++ b/src/raglite/_database.py
@@ -355,33 +355,26 @@ class ContextSegment:
chunks: list[Chunk]
chunk_scores: list[float]
- def __post_init__(self) -> None:
- """Validate the segment data after initialization."""
- if not isinstance(self.document_id, str) or not self.document_id.strip():
- msg = "document_id must be a non-empty string"
- raise ValueError(msg)
- if not self.chunks:
- msg = "chunks cannot be empty"
- raise ValueError(msg)
- if not all(isinstance(chunk, Chunk) for chunk in self.chunks):
- msg = "all elements in chunks must be Chunk instances"
- raise ValueError(msg)
-
- def to_xml(self, indent: int = 4) -> str:
- """Convert the segment to an XML string representation.
+ def __str__(self) -> str:
+ """Return a string representation of the segment."""
+ return self.as_xml
- Args:
- indent (int): Number of spaces to use for indentation.
+ @property
+ def as_xml(self) -> str:
+ """Returns the segment as an XML string representation.
Returns
-------
str: XML representation of the segment.
"""
- chunks_content = "\n".join(str(chunk) for chunk in self.chunks)
-
- # Create the final XML
chunk_ids = ",".join(self.chunk_ids)
-        xml = f"""<document id="{self.document_id}" chunk_ids="{chunk_ids}">\n{escape(str(chunks_content))}\n</document>"""
+ xml = "\n".join(
+ [
+                f'<document id="{self.document_id}" chunk_ids="{chunk_ids}">',
+                escape(self.as_str),
+                "</document>",
+ ]
+ )
return xml
@@ -394,7 +387,8 @@ def chunk_ids(self) -> list[str]:
"""Return a list of chunk IDs."""
return [chunk.id for chunk in self.chunks]
- def __str__(self) -> str:
+ @property
+ def as_str(self) -> str:
"""Return a string representation reconstructing the document with headings.
Treats headings as a stack, showing headers only when they differ from
diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py
index 89a7cc0..90e7262 100644
--- a/src/raglite/_rag.py
+++ b/src/raglite/_rag.py
@@ -1,7 +1,7 @@
"""Retrieval-augmented generation."""
from collections.abc import AsyncIterator, Iterator
-from typing import Literal
+from typing import cast
from litellm import acompletion, completion
@@ -88,7 +88,7 @@ def rag( # noqa: PLR0913
*,
max_contexts: int = 5,
context_neighbors: tuple[int, ...] | None = (-1, 1),
- search: SearchMethod | list[str] | list[Chunk] = hybrid_search,
+ search: SearchMethod | list[str] | list[Chunk] | list[ContextSegment] = hybrid_search,
messages: list[dict[str, str]] | None = None,
system_prompt: str = RAG_SYSTEM_PROMPT,
config: RAGLiteConfig | None = None,
@@ -96,13 +96,17 @@ def rag( # noqa: PLR0913
"""Retrieval-augmented generation."""
# Get the contexts for RAG as contiguous segments of chunks.
config = config or RAGLiteConfig()
- segments = context_segments(
- prompt,
- max_contexts=max_contexts,
- context_neighbors=context_neighbors,
- search=search,
- config=config,
- )
+ segments: list[ContextSegment]
+ if isinstance(search, list) and any(isinstance(segment, ContextSegment) for segment in search):
+ segments = cast(list[ContextSegment], search)
+ else:
+ segments = context_segments(
+ prompt,
+ max_contexts=max_contexts,
+ context_neighbors=context_neighbors,
+ search=search, # type: ignore[arg-type]
+ config=config,
+ )
# Stream the LLM response.
stream = completion(
model=config.llm,
@@ -121,7 +125,7 @@ async def async_rag( # noqa: PLR0913
*,
max_contexts: int = 5,
context_neighbors: tuple[int, ...] | None = (-1, 1),
- search: SearchMethod | list[str] | list[Chunk] = hybrid_search,
+ search: SearchMethod | list[str] | list[Chunk] | list[ContextSegment] = hybrid_search,
messages: list[dict[str, str]] | None = None,
system_prompt: str = RAG_SYSTEM_PROMPT,
config: RAGLiteConfig | None = None,
@@ -129,13 +133,17 @@ async def async_rag( # noqa: PLR0913
"""Retrieval-augmented generation."""
# Get the contexts for RAG as contiguous segments of chunks.
config = config or RAGLiteConfig()
- segments = context_segments(
- prompt,
- max_contexts=max_contexts,
- context_neighbors=context_neighbors,
- search=search,
- config=config,
- )
+ segments: list[ContextSegment]
+ if isinstance(search, list) and any(isinstance(segment, ContextSegment) for segment in search):
+ segments = cast(list[ContextSegment], search)
+ else:
+ segments = context_segments(
+ prompt,
+ max_contexts=max_contexts,
+ context_neighbors=context_neighbors,
+ search=search, # type: ignore[arg-type]
+ config=config,
+ )
messages = _compose_messages(
prompt=prompt, system_prompt=system_prompt, messages=messages, segments=segments
)
@@ -151,11 +159,8 @@ def _compose_messages(
system_prompt: str,
messages: list[dict[str, str]] | None,
segments: list[ContextSegment] | None,
- context_placement: Literal[
- "system_prompt", "user_prompt", "separate_system_prompt"
- ] = "user_prompt",
) -> list[dict[str, str]]:
- """Compose the messages for the LLM, placing the context in the desired position."""
+ """Compose the messages for the LLM, placing the context in the user position."""
# Using the format recommended by Anthropic for documents in RAG
# (https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/long-context-tips#essential-tips-for-long-context-prompts
if not segments:
@@ -164,26 +169,11 @@ def _compose_messages(
*(messages or []),
{"role": "user", "content": prompt},
]
-    context_content = (
-        "\n\n<documents>\n" + "\n\n".join(seg.to_xml() for seg in segments) + "\n</documents>"
-    )
- if context_placement == "system_prompt":
- return [
- {"role": "system", "content": system_prompt + "\n\n" + context_content},
- *(messages or []),
- {"role": "user", "content": prompt},
- ]
- if context_placement == "user_prompt":
- return [
- {"role": "system", "content": system_prompt},
- *(messages or []),
- {"role": "user", "content": prompt + "\n\n" + context_content},
- ]
- # Separate system prompt from context
+    context_content = "<documents>\n" + "\n".join(str(seg) for seg in segments) + "\n</documents>"
+
return [
{"role": "system", "content": system_prompt},
*(messages or []),
- {"role": "system", "content": context_content},
- {"role": "user", "content": prompt},
+ {"role": "user", "content": prompt + "\n\n" + context_content},
]