-
Notifications
You must be signed in to change notification settings - Fork 52
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
15 changed files
with
1,142 additions
and
111 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,7 @@ | |
"cookiecutter": { | ||
"project_type": "package", | ||
"project_name": "RAGLite", | ||
"project_description": "A Python package for Retrieval-Augmented Generation (RAG) with SQLite or PostgreSQL.", | ||
"project_description": "A Python toolkit for Retrieval-Augmented Generation (RAG) with SQLite or PostgreSQL.", | ||
"project_url": "https://github.com/superlinear-ai/raglite", | ||
"author_name": "Laurent Sorber", | ||
"author_email": "[email protected]", | ||
|
@@ -26,4 +26,4 @@ | |
} | ||
}, | ||
"directory": null | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api" | |
[tool.poetry] # https://python-poetry.org/docs/pyproject/ | ||
name = "raglite" | ||
version = "0.1.4" | ||
description = "A Python package for Retrieval-Augmented Generation (RAG) with SQLite or PostgreSQL." | ||
description = "A Python toolkit for Retrieval-Augmented Generation (RAG) with SQLite or PostgreSQL." | ||
authors = ["Laurent Sorber <[email protected]>"] | ||
readme = "README.md" | ||
repository = "https://github.com/superlinear-ai/raglite" | ||
|
@@ -37,7 +37,7 @@ llama-cpp-python = ">=0.2.88" | |
pydantic = ">=2.7.0" | ||
# Approximate Nearest Neighbors: | ||
pynndescent = ">=0.5.12" | ||
# Reranking | ||
# Reranking: | ||
langdetect = ">=1.0.9" | ||
rerankers = { extras = ["flashrank"], version = ">=0.5.3" } | ||
# Storage: | ||
|
@@ -48,8 +48,13 @@ tqdm = ">=4.66.0" | |
# Evaluation: | ||
pandas = ">=2.1.0" | ||
ragas = { version = ">=0.1.12", optional = true } | ||
# CLI: | ||
typer = ">=0.12.5" | ||
# Frontend: | ||
chainlit = { version = ">=1.2.0", optional = true } | ||
|
||
[tool.poetry.extras] # https://python-poetry.org/docs/pyproject/#extras | ||
chainlit = ["chainlit"] | ||
pandoc = ["pypandoc-binary"] | ||
ragas = ["ragas"] | ||
|
||
|
@@ -76,6 +81,9 @@ matplotlib = ">=3.9.0" | |
memory-profiler = ">=0.61.0" | ||
pdoc = ">=14.4.0" | ||
|
||
[tool.poetry.scripts] # https://python-poetry.org/docs/pyproject/#scripts | ||
raglite = "raglite:cli" | ||
|
||
[tool.coverage.report] # https://coverage.readthedocs.io/en/latest/config.html#report | ||
fail_under = 50 | ||
precision = 1 | ||
|
@@ -104,7 +112,7 @@ show_error_context = true | |
warn_unreachable = true | ||
|
||
[tool.pytest.ini_options] # https://docs.pytest.org/en/latest/reference/reference.html#ini-options-ref | ||
addopts = "--color=yes --doctest-modules --exitfirst --failed-first --strict-config --strict-markers --verbosity=2 --junitxml=reports/pytest.xml" | ||
addopts = "--color=yes --exitfirst --failed-first --strict-config --strict-markers --verbosity=2 --junitxml=reports/pytest.xml" | ||
filterwarnings = ["error", "ignore::DeprecationWarning", "ignore::pytest.PytestUnraisableExceptionWarning"] | ||
testpaths = ["src", "tests"] | ||
xfail_strict = true | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
"""Chainlit frontend for RAGLite.""" | ||
|
||
import os | ||
from pathlib import Path | ||
|
||
import chainlit as cl | ||
from chainlit.input_widget import Switch, TextInput | ||
|
||
from raglite import ( | ||
RAGLiteConfig, | ||
async_rag, | ||
hybrid_search, | ||
insert_document, | ||
rerank_chunks, | ||
retrieve_chunks, | ||
) | ||
from raglite._markdown import document_to_markdown | ||
|
||
async_insert_document = cl.make_async(insert_document) | ||
async_hybrid_search = cl.make_async(hybrid_search) | ||
async_retrieve_chunks = cl.make_async(retrieve_chunks) | ||
async_rerank_chunks = cl.make_async(rerank_chunks) | ||
|
||
|
||
@cl.on_chat_start | ||
async def start_chat() -> None: | ||
"""Initialize the chat.""" | ||
# Disable tokenizes parallelism to avoid the deadlock warning. | ||
os.environ["TOKENIZERS_PARALLELISM"] = "false" | ||
# Add Chainlit settings with which the user can configure the RAGLite config. | ||
default_config = RAGLiteConfig() | ||
config = RAGLiteConfig( | ||
db_url=os.environ.get("RAGLITE_DB_URL", default_config.db_url), | ||
llm=os.environ.get("RAGLITE_LLM", default_config.llm), | ||
embedder=os.environ.get("RAGLITE_EMBEDDER", default_config.embedder), | ||
) | ||
settings = await cl.ChatSettings( # type: ignore[no-untyped-call] | ||
[ | ||
TextInput(id="db_url", label="Database URL", initial=str(config.db_url)), | ||
TextInput(id="llm", label="LLM", initial=config.llm), | ||
TextInput(id="embedder", label="Embedder", initial=config.embedder), | ||
Switch(id="vector_search_query_adapter", label="Query adapter", initial=True), | ||
] | ||
).send() | ||
await update_config(settings) | ||
|
||
|
||
@cl.on_settings_update # type: ignore[arg-type] | ||
async def update_config(settings: cl.ChatSettings) -> None: | ||
"""Update the RAGLite config.""" | ||
# Update the RAGLite config given the Chainlit settings. | ||
config = RAGLiteConfig( | ||
db_url=settings["db_url"], # type: ignore[index] | ||
llm=settings["llm"], # type: ignore[index] | ||
embedder=settings["embedder"], # type: ignore[index] | ||
vector_search_query_adapter=settings["vector_search_query_adapter"], # type: ignore[index] | ||
) | ||
cl.user_session.set("config", config) # type: ignore[no-untyped-call] | ||
# Run a search to prime the pipeline if it's a local pipeline. | ||
# TODO: Don't do this for SQLite once we switch from PyNNDescent to sqlite-vec. | ||
if str(config.db_url).startswith("sqlite") or config.embedder.startswith("llama-cpp-python"): | ||
# async with cl.Step(name="initialize", type="retrieval"): | ||
query = "Hello world" | ||
chunk_ids, _ = await async_hybrid_search(query=query, config=config) | ||
_ = await async_rerank_chunks(query=query, chunk_ids=chunk_ids, config=config) | ||
|
||
|
||
@cl.on_message | ||
async def handle_message(user_message: cl.Message) -> None: | ||
"""Respond to a user message.""" | ||
# Get the config and message history from the user session. | ||
config: RAGLiteConfig = cl.user_session.get("config") # type: ignore[no-untyped-call] | ||
# Determine what to do with the attachments. | ||
inline_attachments = [] | ||
for file in user_message.elements: | ||
if file.path: | ||
doc_md = document_to_markdown(Path(file.path)) | ||
if len(doc_md) // 3 <= 5 * (config.chunk_max_size // 3): | ||
# Document is small enough to attach to the context. | ||
inline_attachments.append(f"{Path(file.path).name}:\n\n{doc_md}") | ||
else: | ||
# Document is too large and must be inserted into the database. | ||
async with cl.Step(name="insert", type="run") as step: | ||
step.input = Path(file.path).name | ||
await async_insert_document(Path(file.path), config=config) | ||
# Append any inline attachments to the user prompt. | ||
user_prompt = f"{user_message.content}\n\n" + "\n\n".join( | ||
f'<attachment index="{i}">\n{attachment.strip()}\n</attachment>' | ||
for i, attachment in enumerate(inline_attachments) | ||
) | ||
# Search for relevant contexts for RAG. | ||
async with cl.Step(name="search", type="retrieval") as step: | ||
step.input = user_message.content | ||
chunk_ids, _ = await async_hybrid_search(query=user_prompt, num_results=10, config=config) | ||
chunks = await async_retrieve_chunks(chunk_ids=chunk_ids, config=config) | ||
step.output = chunks | ||
step.elements = [ # Show the top 3 chunks inline. | ||
cl.Text(content=str(chunk), display="inline") for chunk in chunks[:3] | ||
] | ||
# Rerank the chunks. | ||
async with cl.Step(name="rerank", type="rerank") as step: | ||
step.input = chunks | ||
chunks = await async_rerank_chunks(query=user_prompt, chunk_ids=chunks, config=config) | ||
step.output = chunks | ||
step.elements = [ # Show the top 3 chunks inline. | ||
cl.Text(content=str(chunk), display="inline") for chunk in chunks[:3] | ||
] | ||
# Stream the LLM response. | ||
assistant_message = cl.Message(content="") | ||
async for token in async_rag( | ||
prompt=user_prompt, | ||
search=chunks, | ||
messages=cl.chat_context.to_openai()[-5:], # type: ignore[no-untyped-call] | ||
config=config, | ||
): | ||
await assistant_message.stream_token(token) | ||
await assistant_message.update() # type: ignore[no-untyped-call] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
"""RAGLite CLI.""" | ||
|
||
import os | ||
|
||
import typer | ||
|
||
from raglite._config import RAGLiteConfig | ||
|
||
cli = typer.Typer() | ||
|
||
|
||
@cli.callback() | ||
def main() -> None: | ||
"""RAGLite CLI.""" | ||
|
||
|
||
@cli.command() | ||
def chainlit( | ||
db_url: str = typer.Option(RAGLiteConfig().db_url, help="Database URL"), | ||
llm: str = typer.Option(RAGLiteConfig().llm, help="LiteLLM LLM"), | ||
embedder: str = typer.Option(RAGLiteConfig().embedder, help="LiteLLM embedder"), | ||
) -> None: | ||
"""Serve a Chainlit frontend.""" | ||
# Set the environment variables for the Chainlit frontend. | ||
os.environ["RAGLITE_DB_URL"] = os.environ.get("RAGLITE_DB_URL", db_url) | ||
os.environ["RAGLITE_LLM"] = os.environ.get("RAGLITE_LLM", llm) | ||
os.environ["RAGLITE_EMBEDDER"] = os.environ.get("RAGLITE_EMBEDDER", embedder) | ||
# Import Chainlit here as it's an optional dependency. | ||
try: | ||
from chainlit.cli import run_chainlit | ||
except ImportError as error: | ||
error_message = "To serve a Chainlit frontend, please install the `chainlit` extra." | ||
raise ImportError(error_message) from error | ||
# Serve the frontend. | ||
run_chainlit(__file__.replace("_cli.py", "_chainlit.py")) | ||
|
||
|
||
if __name__ == "__main__": | ||
cli() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.