Skip to content

Commit

Permalink
feat: Simplify the api and minor improvements.
Browse files Browse the repository at this point in the history
  • Loading branch information
undo76 committed Nov 26, 2024
1 parent 04d9eb5 commit 4e32914
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 60 deletions.
36 changes: 15 additions & 21 deletions src/raglite/_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,33 +355,26 @@ class ContextSegment:
chunks: list[Chunk]
chunk_scores: list[float]

def __post_init__(self) -> None:
"""Validate the segment data after initialization."""
if not isinstance(self.document_id, str) or not self.document_id.strip():
msg = "document_id must be a non-empty string"
raise ValueError(msg)
if not self.chunks:
msg = "chunks cannot be empty"
raise ValueError(msg)
if not all(isinstance(chunk, Chunk) for chunk in self.chunks):
msg = "all elements in chunks must be Chunk instances"
raise ValueError(msg)

def to_xml(self, indent: int = 4) -> str:
"""Convert the segment to an XML string representation.
def __str__(self) -> str:
"""Return a string representation of the segment."""
return self.as_xml

Args:
indent (int): Number of spaces to use for indentation.
@property
def as_xml(self) -> str:
"""Returns the segment as an XML string representation.
Returns
-------
str: XML representation of the segment.
"""
chunks_content = "\n".join(str(chunk) for chunk in self.chunks)

# Create the final XML
chunk_ids = ",".join(self.chunk_ids)
xml = f"""<document id="{escape(self.document_id)}" chunk_ids="{escape(chunk_ids)}">\n{escape(str(chunks_content))}\n</document>"""
xml = "\n".join(
[
f'<document id="{escape(self.document_id)}" chunk_ids="{escape(chunk_ids)}">',
escape(self.as_str),
"</document>",
]
)

return xml

Expand All @@ -394,7 +387,8 @@ def chunk_ids(self) -> list[str]:
"""Return a list of chunk IDs."""
return [chunk.id for chunk in self.chunks]

def __str__(self) -> str:
@property
def as_str(self) -> str:
"""Return a string representation reconstructing the document with headings.
Treats headings as a stack, showing headers only when they differ from
Expand Down
68 changes: 29 additions & 39 deletions src/raglite/_rag.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Retrieval-augmented generation."""

from collections.abc import AsyncIterator, Iterator
from typing import Literal
from typing import cast

from litellm import acompletion, completion

Expand Down Expand Up @@ -88,21 +88,25 @@ def rag( # noqa: PLR0913
*,
max_contexts: int = 5,
context_neighbors: tuple[int, ...] | None = (-1, 1),
search: SearchMethod | list[str] | list[Chunk] = hybrid_search,
search: SearchMethod | list[str] | list[Chunk] | list[ContextSegment] = hybrid_search,
messages: list[dict[str, str]] | None = None,
system_prompt: str = RAG_SYSTEM_PROMPT,
config: RAGLiteConfig | None = None,
) -> Iterator[str]:
"""Retrieval-augmented generation."""
# Get the contexts for RAG as contiguous segments of chunks.
config = config or RAGLiteConfig()
segments = context_segments(
prompt,
max_contexts=max_contexts,
context_neighbors=context_neighbors,
search=search,
config=config,
)
segments: list[ContextSegment]
if isinstance(search, list) and any(isinstance(segment, ContextSegment) for segment in search):
segments = cast(list[ContextSegment], search)
else:
segments = context_segments(
prompt,
max_contexts=max_contexts,
context_neighbors=context_neighbors,
search=search, # type: ignore[arg-type]
config=config,
)
# Stream the LLM response.
stream = completion(
model=config.llm,
Expand All @@ -121,21 +125,25 @@ async def async_rag( # noqa: PLR0913
*,
max_contexts: int = 5,
context_neighbors: tuple[int, ...] | None = (-1, 1),
search: SearchMethod | list[str] | list[Chunk] = hybrid_search,
search: SearchMethod | list[str] | list[Chunk] | list[ContextSegment] = hybrid_search,
messages: list[dict[str, str]] | None = None,
system_prompt: str = RAG_SYSTEM_PROMPT,
config: RAGLiteConfig | None = None,
) -> AsyncIterator[str]:
"""Retrieval-augmented generation."""
# Get the contexts for RAG as contiguous segments of chunks.
config = config or RAGLiteConfig()
segments = context_segments(
prompt,
max_contexts=max_contexts,
context_neighbors=context_neighbors,
search=search,
config=config,
)
segments: list[ContextSegment]
if isinstance(search, list) and any(isinstance(segment, ContextSegment) for segment in search):
segments = cast(list[ContextSegment], search)
else:
segments = context_segments(
prompt,
max_contexts=max_contexts,
context_neighbors=context_neighbors,
search=search, # type: ignore[arg-type]
config=config,
)
messages = _compose_messages(
prompt=prompt, system_prompt=system_prompt, messages=messages, segments=segments
)
Expand All @@ -151,11 +159,8 @@ def _compose_messages(
system_prompt: str,
messages: list[dict[str, str]] | None,
segments: list[ContextSegment] | None,
context_placement: Literal[
"system_prompt", "user_prompt", "separate_system_prompt"
] = "user_prompt",
) -> list[dict[str, str]]:
"""Compose the messages for the LLM, placing the context in the desired position."""
"""Compose the messages for the LLM, placing the context in the user position."""
# Using the format recommended by Anthropic for documents in RAG
# (https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/long-context-tips#essential-tips-for-long-context-prompts)
if not segments:
Expand All @@ -164,26 +169,11 @@ def _compose_messages(
*(messages or []),
{"role": "user", "content": prompt},
]
context_content = (
"\n\n<documents>\n" + "\n\n".join(seg.to_xml() for seg in segments) + "\n</documents>"
)
if context_placement == "system_prompt":
return [
{"role": "system", "content": system_prompt + "\n\n" + context_content},
*(messages or []),
{"role": "user", "content": prompt},
]
if context_placement == "user_prompt":
return [
{"role": "system", "content": system_prompt},
*(messages or []),
{"role": "user", "content": prompt + "\n\n" + context_content},
]

# Separate system prompt from context
context_content = "<documents>\n" + "\n".join(str(seg) for seg in segments) + "\n</documents>"

return [
{"role": "system", "content": system_prompt},
*(messages or []),
{"role": "system", "content": context_content},
{"role": "user", "content": prompt},
{"role": "user", "content": prompt + "\n\n" + context_content},
]

0 comments on commit 4e32914

Please sign in to comment.