From 0b4b919e108f6c32fcfc106f68260b7ac6ba6d00 Mon Sep 17 00:00:00 2001 From: Laurent Sorber Date: Mon, 7 Oct 2024 07:45:26 +0200 Subject: [PATCH] feat: add reranking (#20) --- .cruft.json | 4 +- .devcontainer/devcontainer.json | 6 +- .gitignore | 3 + Dockerfile | 1 + README.md | 89 ++++++++++--- poetry.lock | 221 +++++++++++++++++++++++++++++++- pyproject.toml | 8 +- src/raglite/__init__.py | 6 +- src/raglite/_config.py | 43 +++++-- src/raglite/_embed.py | 9 +- src/raglite/_eval.py | 6 +- src/raglite/_litellm.py | 5 + src/raglite/_rag.py | 25 +++- src/raglite/_search.py | 177 ++++++++++++------------- src/raglite/_typing.py | 10 +- tests/conftest.py | 4 +- tests/test_embed.py | 2 +- 17 files changed, 471 insertions(+), 148 deletions(-) diff --git a/.cruft.json b/.cruft.json index a44ca99..28789e3 100644 --- a/.cruft.json +++ b/.cruft.json @@ -1,6 +1,6 @@ { "template": "https://github.com/superlinear-ai/poetry-cookiecutter", - "commit": "a969f1d182ec39d7d27ccb1116cf60ba736adcfa", + "commit": "b7f2fb0f123aae0a01d2ab015db31f52d2d8cc21", "checkout": null, "context": { "cookiecutter": { @@ -26,4 +26,4 @@ } }, "directory": null -} \ No newline at end of file +} diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 1c3d4b4..7d2fbb1 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -38,7 +38,9 @@ 100 ], "files.autoSave": "onFocusChange", - "jupyter.kernels.excludePythonEnvironments": ["/usr/local/bin/python"], + "jupyter.kernels.excludePythonEnvironments": [ + "/usr/local/bin/python" + ], "mypy-type-checker.importStrategy": "fromEnvironment", "mypy-type-checker.preferDaemon": true, "notebook.codeActionsOnSave": { @@ -50,7 +52,7 @@ "python.terminal.activateEnvironment": false, "python.testing.pytestEnabled": true, "ruff.importStrategy": "fromEnvironment", - "ruff.logLevel": "warn", + "ruff.logLevel": "warning", "terminal.integrated.defaultProfile.linux": "zsh", "terminal.integrated.profiles.linux": { "zsh": { diff --git a/.gitignore b/.gitignore index 3b903be..e2d2da3 100644 --- a/.gitignore +++ b/.gitignore @@ -19,6 +19,9 @@ data/ # dotenv .env +# Rerankers +.*_cache/ + # Hypothesis .hypothesis/ diff --git a/Dockerfile b/Dockerfile index a7260de..04c115e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -70,6 +70,7 @@ RUN --mount=type=cache,target=/var/cache/apt/ \ sh -c "$(curl -fsSL https://starship.rs/install.sh)" -- "--yes" && \ usermod --shell /usr/bin/zsh user && \ echo 'user ALL=(root) NOPASSWD:ALL' > /etc/sudoers.d/user && chmod 0440 /etc/sudoers.d/user +RUN git config --system --add safe.directory '*' USER user # Install the development Python dependencies in the virtual environment. 
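For illustration, the reranking flow that this patch introduces can be exercised end to end as follows. This is a minimal sketch based on the README and `src/raglite/_search.py` changes below; the database URL and prompt are placeholders.

```python
# Minimal sketch of the retrieval + reranking flow added by this patch.
# Assumes the default per-language FlashRank rerankers in RAGLiteConfig.
from raglite import RAGLiteConfig, hybrid_search, rerank, retrieve_chunks

my_config = RAGLiteConfig(db_url="sqlite:///raglite.sqlite")  # Placeholder database URL.
prompt = "How is intelligence measured?"

# 1. Hybrid (vector + keyword) search returns chunk ids and their RRF scores.
chunk_ids, scores = hybrid_search(prompt, num_results=20, config=my_config)
# 2. Materialize the chunks from the database.
chunks = retrieve_chunks(chunk_ids, config=my_config)
# 3. Rerank the chunks with the reranker configured in RAGLiteConfig.
chunks_reranked = rerank(prompt, chunks, config=my_config)
```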
diff --git a/README.md b/README.md index 21a9a09..5a4febc 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -[![Open in Dev Containers](https://img.shields.io/static/v1?label=Dev%20Containers&message=Open&color=blue&logo=visualstudiocode)](https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://github.com/superlinear-ai/raglite) [![Open in GitHub Codespaces](https://img.shields.io/static/v1?label=GitHub%20Codespaces&message=Open&color=blue&logo=github)](https://github.com/codespaces/new?hide_repo_select=true&ref=main&repo=812973394&skip_quickstart=true) +[![Open in Dev Containers](https://img.shields.io/static/v1?label=Dev%20Containers&message=Open&color=blue&logo=visualstudiocode)](https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://github.com/superlinear-ai/raglite) [![Open in GitHub Codespaces](https://img.shields.io/static/v1?label=GitHub%20Codespaces&message=Open&color=blue&logo=github)](https://github.com/codespaces/new/superlinear-ai/raglite) # 🥤 RAGLite @@ -6,17 +6,31 @@ RAGLite is a Python package for Retrieval-Augmented Generation (RAG) with Postgr ## Features -1. ❤️ Only lightweight and permissive open source dependencies (e.g., no [PyTorch](https://github.com/pytorch/pytorch), [LangChain](https://github.com/langchain-ai/langchain), or [PyMuPDF](https://github.com/pymupdf/PyMuPDF)) -2. 🧠 Choose any LLM provider with [LiteLLM](https://github.com/BerriAI/litellm), including local [llama-cpp-python](https://github.com/abetlen/llama-cpp-python) models -3. 💾 Either [PostgreSQL](https://github.com/postgres/postgres) or [SQLite](https://github.com/sqlite/sqlite) as a keyword & vector search database -4. 🚀 Acceleration with Metal on macOS, and CUDA on Linux and Windows -5. 📖 PDF to Markdown conversion on top of [pdftext](https://github.com/VikParuchuri/pdftext) and [pypdfium2](https://github.com/pypdfium2-team/pypdfium2) -6. 🧬 Multi-vector chunk embedding with [late chunking](https://weaviate.io/blog/late-chunking) and [contextual chunk headings](https://d-star.ai/solving-the-out-of-context-chunk-problem-for-rag) -7. ✂️ Optimal [level 4 semantic chunking](https://medium.com/@anuragmishra_27746/five-levels-of-chunking-strategies-in-rag-notes-from-gregs-video-7b735895694d) by solving a [binary integer programming problem](https://en.wikipedia.org/wiki/Integer_programming) -8. 🌀 Optimal [closed-form linear query adapter](src/raglite/_query_adapter.py) by solving an [orthogonal Procrustes problem](https://en.wikipedia.org/wiki/Orthogonal_Procrustes_problem) -9. 🔍 [Hybrid search](https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf) that combines the database's built-in keyword search ([tsvector](https://www.postgresql.org/docs/current/datatype-textsearch.html) in PostgreSQL, [FTS5](https://www.sqlite.org/fts5.html) in SQLite) with their native vector search extensions ([pgvector](https://github.com/pgvector/pgvector) in PostgreSQL, [sqlite-vec](https://github.com/asg017/sqlite-vec) in SQLite) -10. ✍️ Optional: conversion of any input document to Markdown with [Pandoc](https://github.com/jgm/pandoc) -11. 
✅ Optional: evaluation of retrieval and generation performance with [Ragas](https://github.com/explodinggradients/ragas) +##### Configurable + +- 🧠 Choose any LLM provider with [LiteLLM](https://github.com/BerriAI/litellm), including local [llama-cpp-python](https://github.com/abetlen/llama-cpp-python) models +- 💾 Choose either [PostgreSQL](https://github.com/postgres/postgres) or [SQLite](https://github.com/sqlite/sqlite) as a keyword & vector search database +- 🥇 Choose any reranker with [rerankers](https://github.com/AnswerDotAI/rerankers), including multi-lingual [FlashRank](https://github.com/PrithivirajDamodaran/FlashRank) as the default + +##### Fast and permissive + +- ❤️ Only lightweight and permissive open source dependencies (e.g., no [PyTorch](https://github.com/pytorch/pytorch) or [LangChain](https://github.com/langchain-ai/langchain)) +- 🚀 Acceleration with Metal on macOS, and CUDA on Linux and Windows + +##### Unhobbled + +- 📖 PDF to Markdown conversion on top of [pdftext](https://github.com/VikParuchuri/pdftext) and [pypdfium2](https://github.com/pypdfium2-team/pypdfium2) +- 🧬 Multi-vector chunk embedding with [late chunking](https://weaviate.io/blog/late-chunking) and [contextual chunk headings](https://d-star.ai/solving-the-out-of-context-chunk-problem-for-rag) +- ✂️ Optimal [level 4 semantic chunking](https://medium.com/@anuragmishra_27746/five-levels-of-chunking-strategies-in-rag-notes-from-gregs-video-7b735895694d) by solving a [binary integer programming problem](https://en.wikipedia.org/wiki/Integer_programming) +- 🔍 [Hybrid search](https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf) with the database's native keyword & vector search ([tsvector](https://www.postgresql.org/docs/current/datatype-textsearch.html)+[pgvector](https://github.com/pgvector/pgvector), [FTS5](https://www.sqlite.org/fts5.html)+[sqlite-vec](https://github.com/asg017/sqlite-vec)[^1]) +- 🌀 Optimal [closed-form linear query adapter](src/raglite/_query_adapter.py) by solving an [orthogonal Procrustes problem](https://en.wikipedia.org/wiki/Orthogonal_Procrustes_problem) + +##### Extensible + +- ✍️ Optional conversion of any input document to Markdown with [Pandoc](https://github.com/jgm/pandoc) +- ✅ Optional evaluation of retrieval and generation performance with [Ragas](https://github.com/explodinggradients/ragas) + +[^1]: We use [PyNNDescent](https://github.com/lmcinnes/pynndescent) until [sqlite-vec](https://github.com/asg017/sqlite-vec) is more mature. ## Installing @@ -57,10 +71,10 @@ pip install raglite[ragas] ### 1. Configuring RAGLite > [!TIP] -> 🧠 RAGLite extends [LiteLLM](https://github.com/BerriAI/litellm) with support for [llama.cpp](https://github.com/ggerganov/llama.cpp) models using [llama-cpp-python](https://github.com/abetlen/llama-cpp-python). To select a llama.cpp model (e.g., from [bartowski's collection](https://huggingface.co/collections/bartowski/recent-highlights-65cf8e08f8ab7fc669d7b5bd)), use a model identifier of the form `"llama-cpp-python/<hugging_face_repo_id>/<filename>@<n_ctx>"`, where `n_ctx` is an optional parameter that specifies the context size of the model. +> 🧠 RAGLite extends [LiteLLM](https://github.com/BerriAI/litellm) with support for [llama.cpp](https://github.com/ggerganov/llama.cpp) models using [llama-cpp-python](https://github.com/abetlen/llama-cpp-python). 
To select a llama.cpp model (e.g., from [bartowski's collection](https://huggingface.co/bartowski)), use a model identifier of the form `"llama-cpp-python/<hugging_face_repo_id>/<filename>@<n_ctx>"`, where `n_ctx` is an optional parameter that specifies the context size of the model. > [!TIP] -> 💾 You can create a PostgreSQL database for free in a few clicks at [neon.tech](https://neon.tech) (not sponsored). +> 💾 You can create a PostgreSQL database in a few clicks at [neon.tech](https://neon.tech). First, configure RAGLite with your preferred PostgreSQL or SQLite database and [any LLM supported by LiteLLM](https://docs.litellm.ai/docs/providers/openai): @@ -82,6 +96,27 @@ my_config = RAGLiteConfig( ) ``` +You can also configure [any reranker supported by rerankers](https://github.com/AnswerDotAI/rerankers): + +```python +from rerankers import Reranker + +# Example remote API-based reranker: +my_config = RAGLiteConfig( +    db_url="postgresql://my_username:my_password@my_host:5432/my_database", +    reranker=Reranker("cohere", lang="en", api_key=COHERE_API_KEY) +) + +# Example local cross-encoder reranker per language (this is the default): +my_config = RAGLiteConfig( +    db_url="sqlite:///raglite.sqlite", +    reranker=( +        ("en", Reranker("ms-marco-MiniLM-L-12-v2", model_type="flashrank")),  # English +        ("other", Reranker("ms-marco-MultiBERT-L-12", model_type="flashrank")),  # Other languages +    ) +) +``` + ### 2. Inserting documents > [!TIP] @@ -100,24 +135,38 @@ insert_document(Path("Special Relativity.pdf"), config=my_config) ### 3. Searching and Retrieval-Augmented Generation (RAG) -Now, you can search for chunks with keyword search, vector search, or a hybrid of the two. You can also answer questions with RAG and the search method of your choice (`hybrid` is the default): +Now, you can search for chunks with vector search, keyword search, or a hybrid of the two. You can also rerank the search results with the configured reranker. And you can use any search method of your choice (`hybrid_search` is the default) together with reranking to answer questions with RAG: ```python # Search for chunks: from raglite import hybrid_search, keyword_search, vector_search prompt = "How is intelligence measured?" -results_vector = vector_search(prompt, num_results=5, config=my_config) -results_keyword = keyword_search(prompt, num_results=5, config=my_config) -results_hybrid = hybrid_search(prompt, num_results=5, config=my_config) +chunk_ids_vector, _ = vector_search(prompt, num_results=20, config=my_config) +chunk_ids_keyword, _ = keyword_search(prompt, num_results=20, config=my_config) +chunk_ids_hybrid, _ = hybrid_search(prompt, num_results=20, config=my_config) + +# Retrieve chunks: +from raglite import retrieve_chunks + +chunks_hybrid = retrieve_chunks(chunk_ids_hybrid, config=my_config) + +# Rerank chunks: +from raglite import rerank + +chunks_reranked = rerank(prompt, chunks_hybrid, config=my_config) # Answer questions with RAG: from raglite import rag prompt = "What does it mean for two events to be simultaneous?" -stream = rag(prompt, search=hybrid_search, config=my_config) +stream = rag(prompt, config=my_config) for update in stream: print(update, end="") + +# You can also pass a search method or search results directly: +stream = rag(prompt, search=hybrid_search, config=my_config) +stream = rag(prompt, search=chunks_reranked, config=my_config) ``` ### 4. 
Computing and using an optimal query adapter @@ -129,7 +178,7 @@ RAGLite can compute and apply an [optimal closed-form query adapter](src/raglite from raglite import insert_evals, update_query_adapter insert_evals(num_evals=100, config=my_config) -update_query_adapter(config=my_config) +update_query_adapter(config=my_config) # From here, simply call vector_search to use the query adapter. ``` ### 5. Evaluation of retrieval and generation diff --git a/poetry.lock b/poetry.lock index dc2c0f5..dd0bf9f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -610,6 +610,23 @@ files = [ {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +[[package]] +name = "coloredlogs" +version = "15.0.1" +description = "Colored terminal output for Python's logging module" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +files = [ + {file = "coloredlogs-15.0.1-py2.py3-none-any.whl", hash = "sha256:612ee75c546f53e92e70049c9dbfcc18c935a2b9a53b66085ce9ef6a6e5c0934"}, + {file = "coloredlogs-15.0.1.tar.gz", hash = "sha256:7c991aa71a4577af2f82600d8f8f3a89f936baeaf9b50a9c197da014e5bf16b0"}, +] + +[package.dependencies] +humanfriendly = ">=9.1" + +[package.extras] +cron = ["capturer (>=2.4)"] + [[package]] name = "comm" version = "0.2.2" @@ -629,13 +646,13 @@ test = ["pytest"] [[package]] name = "commitizen" -version = "3.29.0" +version = "3.29.1" description = "Python commitizen client tool" optional = false python-versions = ">=3.8" files = [ - {file = "commitizen-3.29.0-py3-none-any.whl", hash = "sha256:0c6c479dbee6d19292315c6fca3782cf5c1f7f1638bc4bb5ab4cfb67f4e11894"}, - {file = "commitizen-3.29.0.tar.gz", hash = "sha256:586b30c1976850d244b836cd4730771097ba362c9c1684d1f8c379176c2ea532"}, + {file = "commitizen-3.29.1-py3-none-any.whl", hash = "sha256:83f6563fae6a6262238e4424c55db5743eaa9827d2044dc23719466e4e78a0ca"}, + {file = "commitizen-3.29.1.tar.gz", hash = "sha256:b9a56190f4f3b20c73600e5ba448c7b81e0e6f87be3092aec1db4de75bf0fa91"}, ] [package.dependencies] @@ -1187,6 +1204,38 @@ docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1 testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-asyncio (>=0.21)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)", "virtualenv (>=20.26.2)"] typing = ["typing-extensions (>=4.8)"] +[[package]] +name = "flashrank" +version = "0.2.9" +description = "Ultra lite & Super fast SoTA cross-encoder based re-ranking for your search & retrieval pipelines." 
+optional = false +python-versions = ">=3.6" +files = [ + {file = "FlashRank-0.2.9-py3-none-any.whl", hash = "sha256:4e43e0ccb95f143bb6eaf9bde74b9bd7159fd2161116eba4c0fa295def86156d"}, + {file = "FlashRank-0.2.9.tar.gz", hash = "sha256:475f1192e0722da1a4409812165ebc7e3eccec56e7b7853ed9dd5dd5c9c985f5"}, +] + +[package.dependencies] +numpy = "*" +onnxruntime = "*" +requests = "*" +tokenizers = "*" +tqdm = "*" + +[package.extras] +listwise = ["llama-cpp-python (==0.2.76)"] + +[[package]] +name = "flatbuffers" +version = "24.3.25" +description = "The FlatBuffers serialization format for Python" +optional = false +python-versions = "*" +files = [ + {file = "flatbuffers-24.3.25-py2.py3-none-any.whl", hash = "sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812"}, + {file = "flatbuffers-24.3.25.tar.gz", hash = "sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4"}, +] + [[package]] name = "fonttools" version = "4.53.1" @@ -1573,6 +1622,20 @@ testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gr torch = ["safetensors[torch]", "torch"] typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"] +[[package]] +name = "humanfriendly" +version = "10.0" +description = "Human friendly output for text interfaces using Python" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +files = [ + {file = "humanfriendly-10.0-py2.py3-none-any.whl", hash = "sha256:1697e1a8a8f550fd43c2865cd84542fc175a61dcb779b6fee18cf6b6ccba1477"}, + {file = "humanfriendly-10.0.tar.gz", hash = "sha256:6b0b831ce8f15f7300721aa49829fc4e83921a9a301cc7f606be6686a2288ddc"}, +] + +[package.dependencies] +pyreadline3 = {version = "*", markers = "sys_platform == \"win32\" and python_version >= \"3.8\""} + [[package]] name = "identify" version = "2.6.0" @@ -2189,6 +2252,20 @@ language-data = ">=1.2" build = ["build", "twine"] test = ["pytest", "pytest-cov"] +[[package]] +name = "langdetect" +version = "1.0.9" +description = "Language detection library ported from Google's language-detection." 
+optional = false +python-versions = "*" +files = [ + {file = "langdetect-1.0.9-py2-none-any.whl", hash = "sha256:7cbc0746252f19e76f77c0b1690aadf01963be835ef0cd4b56dddf2a8f1dfc2a"}, + {file = "langdetect-1.0.9.tar.gz", hash = "sha256:cbc1fef89f8d062739774bd51eda3da3274006b3661d199c2655f6b3f6d605a0"}, +] + +[package.dependencies] +six = "*" + [[package]] name = "langsmith" version = "0.1.99" @@ -2946,6 +3023,23 @@ files = [ [package.dependencies] psutil = "*" +[[package]] +name = "mpmath" +version = "1.3.0" +description = "Python library for arbitrary-precision floating-point arithmetic" +optional = false +python-versions = "*" +files = [ + {file = "mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c"}, + {file = "mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f"}, +] + +[package.extras] +develop = ["codecov", "pycodestyle", "pytest (>=4.6)", "pytest-cov", "wheel"] +docs = ["sphinx"] +gmpy = ["gmpy2 (>=2.1.0a4)"] +tests = ["pytest (>=4.6)"] + [[package]] name = "multidict" version = "6.0.5" @@ -3270,6 +3364,48 @@ files = [ {file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"}, ] +[[package]] +name = "onnxruntime" +version = "1.19.2" +description = "ONNX Runtime is a runtime accelerator for Machine Learning models" +optional = false +python-versions = "*" +files = [ + {file = "onnxruntime-1.19.2-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:84fa57369c06cadd3c2a538ae2a26d76d583e7c34bdecd5769d71ca5c0fc750e"}, + {file = "onnxruntime-1.19.2-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bdc471a66df0c1cdef774accef69e9f2ca168c851ab5e4f2f3341512c7ef4666"}, + {file = "onnxruntime-1.19.2-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e3a4ce906105d99ebbe817f536d50a91ed8a4d1592553f49b3c23c4be2560ae6"}, + {file = "onnxruntime-1.19.2-cp310-cp310-win32.whl", hash = "sha256:4b3d723cc154c8ddeb9f6d0a8c0d6243774c6b5930847cc83170bfe4678fafb3"}, + {file = "onnxruntime-1.19.2-cp310-cp310-win_amd64.whl", hash = "sha256:17ed7382d2c58d4b7354fb2b301ff30b9bf308a1c7eac9546449cd122d21cae5"}, + {file = "onnxruntime-1.19.2-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:d863e8acdc7232d705d49e41087e10b274c42f09e259016a46f32c34e06dc4fd"}, + {file = "onnxruntime-1.19.2-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c1dfe4f660a71b31caa81fc298a25f9612815215a47b286236e61d540350d7b6"}, + {file = "onnxruntime-1.19.2-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a36511dc07c5c964b916697e42e366fa43c48cdb3d3503578d78cef30417cb84"}, + {file = "onnxruntime-1.19.2-cp311-cp311-win32.whl", hash = "sha256:50cbb8dc69d6befad4746a69760e5b00cc3ff0a59c6c3fb27f8afa20e2cab7e7"}, + {file = "onnxruntime-1.19.2-cp311-cp311-win_amd64.whl", hash = "sha256:1c3e5d415b78337fa0b1b75291e9ea9fb2a4c1f148eb5811e7212fed02cfffa8"}, + {file = "onnxruntime-1.19.2-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:68e7051bef9cfefcbb858d2d2646536829894d72a4130c24019219442b1dd2ed"}, + {file = "onnxruntime-1.19.2-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d2d366fbcc205ce68a8a3bde2185fd15c604d9645888703785b61ef174265168"}, + {file = "onnxruntime-1.19.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:477b93df4db467e9cbf34051662a4b27c18e131fa1836e05974eae0d6e4cf29b"}, + {file = 
"onnxruntime-1.19.2-cp312-cp312-win32.whl", hash = "sha256:9a174073dc5608fad05f7cf7f320b52e8035e73d80b0a23c80f840e5a97c0147"}, + {file = "onnxruntime-1.19.2-cp312-cp312-win_amd64.whl", hash = "sha256:190103273ea4507638ffc31d66a980594b237874b65379e273125150eb044857"}, + {file = "onnxruntime-1.19.2-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:636bc1d4cc051d40bc52e1f9da87fbb9c57d9d47164695dfb1c41646ea51ea66"}, + {file = "onnxruntime-1.19.2-cp38-cp38-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5bd8b875757ea941cbcfe01582970cc299893d1b65bd56731e326a8333f638a3"}, + {file = "onnxruntime-1.19.2-cp38-cp38-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b2046fc9560f97947bbc1acbe4c6d48585ef0f12742744307d3364b131ac5778"}, + {file = "onnxruntime-1.19.2-cp38-cp38-win32.whl", hash = "sha256:31c12840b1cde4ac1f7d27d540c44e13e34f2345cf3642762d2a3333621abb6a"}, + {file = "onnxruntime-1.19.2-cp38-cp38-win_amd64.whl", hash = "sha256:016229660adea180e9a32ce218b95f8f84860a200f0f13b50070d7d90e92956c"}, + {file = "onnxruntime-1.19.2-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:006c8d326835c017a9e9f74c9c77ebb570a71174a1e89fe078b29a557d9c3848"}, + {file = "onnxruntime-1.19.2-cp39-cp39-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:df2a94179a42d530b936f154615b54748239c2908ee44f0d722cb4df10670f68"}, + {file = "onnxruntime-1.19.2-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fae4b4de45894b9ce7ae418c5484cbf0341db6813effec01bb2216091c52f7fb"}, + {file = "onnxruntime-1.19.2-cp39-cp39-win32.whl", hash = "sha256:dc5430f473e8706fff837ae01323be9dcfddd3ea471c900a91fa7c9b807ec5d3"}, + {file = "onnxruntime-1.19.2-cp39-cp39-win_amd64.whl", hash = "sha256:38475e29a95c5f6c62c2c603d69fc7d4c6ccbf4df602bd567b86ae1138881c49"}, +] + +[package.dependencies] +coloredlogs = "*" +flatbuffers = "*" +numpy = ">=1.21.6" +packaging = "*" +protobuf = "*" +sympy = "*" + [[package]] name = "openai" version = "1.45.0" @@ -3759,6 +3895,26 @@ files = [ [package.dependencies] wcwidth = "*" +[[package]] +name = "protobuf" +version = "5.28.2" +description = "" +optional = false +python-versions = ">=3.8" +files = [ + {file = "protobuf-5.28.2-cp310-abi3-win32.whl", hash = "sha256:eeea10f3dc0ac7e6b4933d32db20662902b4ab81bf28df12218aa389e9c2102d"}, + {file = "protobuf-5.28.2-cp310-abi3-win_amd64.whl", hash = "sha256:2c69461a7fcc8e24be697624c09a839976d82ae75062b11a0972e41fd2cd9132"}, + {file = "protobuf-5.28.2-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:a8b9403fc70764b08d2f593ce44f1d2920c5077bf7d311fefec999f8c40f78b7"}, + {file = "protobuf-5.28.2-cp38-abi3-manylinux2014_aarch64.whl", hash = "sha256:35cfcb15f213449af7ff6198d6eb5f739c37d7e4f1c09b5d0641babf2cc0c68f"}, + {file = "protobuf-5.28.2-cp38-abi3-manylinux2014_x86_64.whl", hash = "sha256:5e8a95246d581eef20471b5d5ba010d55f66740942b95ba9b872d918c459452f"}, + {file = "protobuf-5.28.2-cp38-cp38-win32.whl", hash = "sha256:87317e9bcda04a32f2ee82089a204d3a2f0d3c8aeed16568c7daf4756e4f1fe0"}, + {file = "protobuf-5.28.2-cp38-cp38-win_amd64.whl", hash = "sha256:c0ea0123dac3399a2eeb1a1443d82b7afc9ff40241433296769f7da42d142ec3"}, + {file = "protobuf-5.28.2-cp39-cp39-win32.whl", hash = "sha256:ca53faf29896c526863366a52a8f4d88e69cd04ec9571ed6082fa117fac3ab36"}, + {file = "protobuf-5.28.2-cp39-cp39-win_amd64.whl", hash = "sha256:8ddc60bf374785fb7cb12510b267f59067fa10087325b8e1855b898a0d81d276"}, + {file = "protobuf-5.28.2-py3-none-any.whl", hash = 
"sha256:52235802093bd8a2811abbe8bf0ab9c5f54cca0a751fdd3f6ac2a21438bffece"}, + {file = "protobuf-5.28.2.tar.gz", hash = "sha256:59379674ff119717404f7454647913787034f03fe7049cbef1d74a97bb4593f0"}, +] + [[package]] name = "psutil" version = "6.0.0" @@ -4102,6 +4258,20 @@ files = [ {file = "pypdfium2-4.30.0.tar.gz", hash = "sha256:48b5b7e5566665bc1015b9d69c1ebabe21f6aee468b509531c3c8318eeee2e16"}, ] +[[package]] +name = "pyreadline3" +version = "3.5.4" +description = "A python implementation of GNU readline." +optional = false +python-versions = ">=3.8" +files = [ + {file = "pyreadline3-3.5.4-py3-none-any.whl", hash = "sha256:eaf8e6cc3c49bcccf145fc6067ba8643d1df34d604a1ec0eccbf7a18e6d3fae6"}, + {file = "pyreadline3-3.5.4.tar.gz", hash = "sha256:8d57d53039a1c75adba8e50dd3d992b28143480816187ea5efbd5c78e6c885b7"}, +] + +[package.extras] +dev = ["build", "flake8", "mypy", "pytest", "twine"] + [[package]] name = "pysbd" version = "0.3.4" @@ -4578,6 +4748,32 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "rerankers" +version = "0.5.3" +description = "A unified API for various document re-ranking models." +optional = false +python-versions = ">=3.8" +files = [ + {file = "rerankers-0.5.3-py3-none-any.whl", hash = "sha256:773aef2f7db371ec9c7b38b6652a78b7fdb96bb7999dc12817e882104c3e24e0"}, + {file = "rerankers-0.5.3.tar.gz", hash = "sha256:415bae5e9aaa71e1c2e9932839cc84ff8d939afcf31833bdb398c4bb3cbc67a3"}, +] + +[package.dependencies] +flashrank = {version = "*", optional = true, markers = "extra == \"flashrank\""} +pydantic = "*" +tqdm = "*" + +[package.extras] +all = ["flash-attn", "flashrank", "litellm", "nmslib-metabrainz", "protobuf", "rank-llm", "requests", "sentencepiece", "torch", "transformers"] +api = ["requests"] +dev = ["ipyprogress", "ipython", "ir-datasets", "isort", "pytest", "ranx", "ruff", "srsly"] +flashrank = ["flashrank"] +gpt = ["litellm"] +llmlayerwise = ["flash-attn", "protobuf", "sentencepiece", "torch", "transformers"] +rankllm = ["nmslib-metabrainz", "rank-llm"] +transformers = ["protobuf", "sentencepiece", "torch", "transformers"] + [[package]] name = "rich" version = "13.7.1" @@ -5339,6 +5535,23 @@ pure-eval = "*" [package.extras] tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] +[[package]] +name = "sympy" +version = "1.13.3" +description = "Computer algebra system (CAS) in Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "sympy-1.13.3-py3-none-any.whl", hash = "sha256:54612cf55a62755ee71824ce692986f23c88ffa77207b30c1368eda4a7060f73"}, + {file = "sympy-1.13.3.tar.gz", hash = "sha256:b27fd2c6530e0ab39e275fc9b683895367e51d5da91baa8d3d64db2565fec4d9"}, +] + +[package.dependencies] +mpmath = ">=1.1.0,<1.4" + +[package.extras] +dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"] + [[package]] name = "tenacity" version = "8.5.0" @@ -6255,4 +6468,4 @@ ragas = ["ragas"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<4.0" -content-hash = "f04b163e8e553b80afa7944958d063b72748ca55cc9f996ead98523ecce9716f" +content-hash = "b67e5de29567433c962694b0a54ae8c7843e36988e70ced50d71f7a70eacacf1" diff --git a/pyproject.toml b/pyproject.toml index 3955805..bf4780a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,9 @@ llama-cpp-python = [ pydantic = ">=2.7.0" # Approximate Nearest Neighbors: pynndescent = ">=0.5.12" +# Reranking +langdetect = ">=1.0.9" +rerankers = { extras = ["flashrank"], version = ">=0.5.3" } # Storage: pg8000 = ">=1.31.2" 
sqlmodel-slim = ">=0.0.18" @@ -66,14 +69,14 @@ pandoc = ["pypandoc-binary"] ragas = ["ragas"] [tool.poetry.group.test.dependencies] # https://python-poetry.org/docs/master/managing-dependencies/ -commitizen = ">=3.21.3" +commitizen = ">=3.29.1" coverage = { extras = ["toml"], version = ">=7.4.4" } mypy = ">=1.9.0" poethepoet = ">=0.25.0" pre-commit = ">=3.7.0" pytest = ">=8.1.1" pytest-mock = ">=3.14.0" -ruff = ">=0.3.5" +ruff = ">=0.5.7" safety = ">=3.1.0" shellcheck-py = ">=0.10.0.1" typeguard = ">=4.2.1" @@ -81,6 +84,7 @@ typeguard = ">=4.2.1" [tool.poetry.group.dev.dependencies] # https://python-poetry.org/docs/master/managing-dependencies/ cruft = ">=2.15.0" ipykernel = ">=6.29.4" +ipython = ">=8.8.0" ipywidgets = ">=8.1.2" matplotlib = ">=3.9.0" memory-profiler = ">=0.61.0" diff --git a/src/raglite/__init__.py b/src/raglite/__init__.py index 1719c59..b6d6231 100644 --- a/src/raglite/__init__.py +++ b/src/raglite/__init__.py @@ -6,9 +6,10 @@ from raglite._query_adapter import update_query_adapter from raglite._rag import rag from raglite._search import ( - fusion_search, hybrid_search, keyword_search, + rerank, + retrieve_chunks, retrieve_segments, vector_search, ) @@ -19,11 +20,12 @@ # Insert "insert_document", # Search - "fusion_search", "hybrid_search", "keyword_search", "vector_search", + "retrieve_chunks", "retrieve_segments", + "rerank", # RAG "rag", # Query adapter diff --git a/src/raglite/_config.py b/src/raglite/_config.py index 5d8c4c6..38a9248 100644 --- a/src/raglite/_config.py +++ b/src/raglite/_config.py @@ -1,22 +1,18 @@ """RAGLite config.""" +import contextlib import os -from dataclasses import dataclass +from dataclasses import dataclass, field +from io import StringIO from llama_cpp import llama_supports_gpu_offload from sqlalchemy.engine import URL -DEFAULT_LLM = ( - "llama-cpp-python/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/*Q4_K_M.gguf@8192" - if llama_supports_gpu_offload() - else "llama-cpp-python/bartowski/Phi-3.5-mini-instruct-GGUF/*Q4_K_M.gguf@4096" -) - -DEFAULT_EMBEDDER = ( - "llama-cpp-python/lm-kit/bge-m3-gguf/*F16.gguf" - if llama_supports_gpu_offload() or (os.cpu_count() or 1) >= 4 # noqa: PLR2004 - else "llama-cpp-python/yishan-wang/snowflake-arctic-embed-m-v1.5-Q8_0-GGUF/*q8_0.gguf" -) +# Suppress rerankers output on import until [1] is fixed. +# [1] https://github.com/AnswerDotAI/rerankers/issues/36 +with contextlib.redirect_stdout(StringIO()): + from rerankers.models.flashrank_ranker import FlashRankRanker + from rerankers.models.ranker import BaseRanker @dataclass(frozen=True) @@ -26,10 +22,22 @@ class RAGLiteConfig: # Database config. db_url: str | URL = "sqlite:///raglite.sqlite" # LLM config used for generation. - llm: str = DEFAULT_LLM + llm: str = field( + default_factory=lambda: ( + "llama-cpp-python/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/*Q4_K_M.gguf@8192" + if llama_supports_gpu_offload() + else "llama-cpp-python/bartowski/Llama-3.2-3B-Instruct-GGUF/*Q4_K_M.gguf@4096" + ) + ) llm_max_tries: int = 4 # Embedder config used for indexing. - embedder: str = DEFAULT_EMBEDDER + embedder: str = field( + default_factory=lambda: ( # Nomic-embed may be better if only English is used. + "llama-cpp-python/lm-kit/bge-m3-gguf/*F16.gguf" + if llama_supports_gpu_offload() or (os.cpu_count() or 1) >= 4 # noqa: PLR2004 + else "llama-cpp-python/lm-kit/bge-m3-gguf/*Q4_K_M.gguf" + ) + ) embedder_normalize: bool = True embedder_sentence_window_size: int = 3 # Chunk config used to partition documents into chunks. 
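The hunk above moves the LLM and embedder defaults from module-level constants into `default_factory` lambdas, which defers the `llama_supports_gpu_offload()` hardware probe from import time to config construction time. A hypothetical minimal sketch of that pattern, with made-up names:

```python
# Hypothetical sketch of the default_factory pattern used above (made-up names).
from dataclasses import dataclass, field


def probe_hardware() -> bool:
    """Stand-in for llama_supports_gpu_offload(); imagine an expensive check."""
    return True


@dataclass(frozen=True)
class ExampleConfig:
    # The lambda runs once per instantiation instead of once at module import:
    model: str = field(default_factory=lambda: "large" if probe_hardware() else "small")


config = ExampleConfig()  # The hardware probe runs here, not at import time.
```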
@@ -37,6 +45,13 @@ class RAGLiteConfig: # Vector search config. vector_search_index_metric: str = "cosine" # The query adapter supports "dot" and "cosine". vector_search_query_adapter: bool = True + # Reranking config. + reranker: BaseRanker | tuple[tuple[str, BaseRanker], ...] | None = field( + default_factory=lambda: ( + ("en", FlashRankRanker("ms-marco-MiniLM-L-12-v2", verbose=0)), + ("other", FlashRankRanker("ms-marco-MultiBERT-L-12", verbose=0)), + ) + ) def __post_init__(self) -> None: # Late chunking with llama-cpp-python does not apply sentence windowing. diff --git a/src/raglite/_embed.py b/src/raglite/_embed.py index 0855144..bf96874 100644 --- a/src/raglite/_embed.py +++ b/src/raglite/_embed.py @@ -77,7 +77,9 @@ def _create_segment( assert len(sentinel_tokens), f"Sentinel `{sentinel_char}` not supported by embedder" # Compute the number of tokens per sentence. We use a method based on a sentinel token to # minimise the number of calls to embedder.tokenize, which incurs a significant overhead - # (presumably to load the tokenizer). + # (presumably to load the tokenizer) [1]. + # TODO: Make token counting faster and more robust once [1] is fixed. + # [1] https://github.com/abetlen/llama-cpp-python/issues/1763 num_tokens_list: list[int] = [] sentence_batch, sentence_batch_len = [], 0 for i, sentence in enumerate(sentences): @@ -90,8 +92,9 @@ def _create_segment( sentence_batch, sentence_batch_len = [], 0 num_tokens = np.asarray(num_tokens_list, dtype=np.intp) # Compute the maximum number of tokens for each segment's preamble and content. - # TODO: Unfortunately, llama-cpp-python truncates the input to n_batch tokens and crashes if you - # try to increase it [1]. Until this is fixed, we have to limit max_tokens to n_batch. + # Unfortunately, llama-cpp-python truncates the input to n_batch tokens and crashes if you try + # to increase it [1]. Until this is fixed, we have to limit max_tokens to n_batch. + # TODO: Improve the context window size once [1] is fixed. # [1] https://github.com/abetlen/llama-cpp-python/issues/1762 max_tokens = min(n_ctx, n_batch) - 16 max_tokens_preamble = round(0.382 * max_tokens) # Golden ratio. diff --git a/src/raglite/_eval.py b/src/raglite/_eval.py index 4270340..bfce7c4 100644 --- a/src/raglite/_eval.py +++ b/src/raglite/_eval.py @@ -1,6 +1,5 @@ """Generation and evaluation of evals.""" -from collections.abc import Callable from random import randint from typing import ClassVar @@ -15,6 +14,7 @@ from raglite._extract import extract_with_llm from raglite._rag import rag from raglite._search import hybrid_search, retrieve_segments, vector_search +from raglite._typing import SearchMethod def insert_evals( # noqa: C901 @@ -167,7 +167,7 @@ class AnswerResponse(BaseModel): def answer_evals( num_evals: int = 100, - search: Callable[[str], tuple[list[str], list[float]]] = hybrid_search, + search: SearchMethod = hybrid_search, *, config: RAGLiteConfig | None = None, ) -> pd.DataFrame: @@ -184,7 +184,7 @@ def answer_evals( response = rag(eval_.question, search=search, config=config) answer = "".join(response) answers.append(answer) - chunk_ids, _ = search(eval_.question, config=config) # type: ignore[call-arg] + chunk_ids, _ = search(eval_.question, config=config) contexts.append(retrieve_segments(chunk_ids)) # Collect the answered evals. 
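# (The answers and contexts lists built above are index-aligned with the sampled evals.)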
answered_evals: dict[str, list[str] | list[list[str]]] = { diff --git a/src/raglite/_litellm.py b/src/raglite/_litellm.py index 0a15252..8a98a4d 100644 --- a/src/raglite/_litellm.py +++ b/src/raglite/_litellm.py @@ -1,5 +1,6 @@ """Add support for llama-cpp-python models to LiteLLM.""" +import logging import warnings from collections.abc import Callable, Iterator from functools import cache @@ -22,6 +23,10 @@ LlamaRAMCache, ) +# Reduce the logging level for LiteLLM and flashrank. +logging.getLogger("litellm").setLevel(logging.WARNING) +logging.getLogger("flashrank").setLevel(logging.WARNING) + class LlamaCppPythonLLM(CustomLLM): """A llama-cpp-python provider for LiteLLM. diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py index 48aefb9..0c42524 100644 --- a/src/raglite/_rag.py +++ b/src/raglite/_rag.py @@ -1,12 +1,14 @@ """Retrieval-augmented generation.""" -from collections.abc import Callable, Iterator +from collections.abc import Iterator from litellm import completion, get_model_info # type: ignore[attr-defined] from raglite._config import RAGLiteConfig +from raglite._database import Chunk from raglite._litellm import LlamaCppPythonLLM -from raglite._search import hybrid_search, retrieve_segments +from raglite._search import hybrid_search, rerank, retrieve_segments +from raglite._typing import SearchMethod def rag( @@ -14,7 +16,7 @@ def rag( *, max_contexts: int = 5, context_neighbors: tuple[int, ...] | None = (-1, 1), - search: Callable[[str], tuple[list[str], list[float]]] = hybrid_search, + search: SearchMethod | list[str] | list[Chunk] = hybrid_search, config: RAGLiteConfig | None = None, ) -> Iterator[str]: """Retrieval-augmented generation.""" @@ -30,9 +32,20 @@ def rag( max_tokens_per_context = round(1.2 * (config.chunk_max_size // 4)) max_tokens_per_context *= 1 + len(context_neighbors or []) max_contexts = min(max_contexts, max_tokens // max_tokens_per_context) - # Retrieve relevant contexts. - chunk_ids, _ = search(prompt, num_results=max_contexts, config=config) # type: ignore[call-arg] - segments = retrieve_segments(chunk_ids, neighbors=context_neighbors) + # Retrieve the top chunks. + chunks: list[str] | list[Chunk] + if callable(search): + # If the user has configured a reranker, we retrieve extra contexts to rerank. + extra_contexts = 4 * max_contexts if config.reranker else 0 + # Retrieve relevant contexts. + chunk_ids, _ = search(prompt, num_results=max_contexts + extra_contexts, config=config) + # Rerank the relevant contexts. + chunks = rerank(query=prompt, chunk_ids=chunk_ids, config=config) + else: + # The user has passed a list of chunk_ids or chunks directly. + chunks = search + # Extend the top contexts with their neighbors and group chunks into contiguous segments. + segments = retrieve_segments(chunks[:max_contexts], neighbors=context_neighbors, config=config) # Respond with an LLM. 
contexts = "\n\n".join( f'<context index="{i}">\n{segment.strip()}\n</context>' diff --git a/src/raglite/_search.py index 48c2661..424cfa4 100644 --- a/src/raglite/_search.py +++ b/src/raglite/_search.py @@ -3,23 +3,23 @@ import re import string from collections import defaultdict +from collections.abc import Sequence from itertools import groupby -from typing import Annotated, ClassVar, cast +from typing import cast import numpy as np -from pydantic import BaseModel, Field +from langdetect import detect from sqlalchemy.engine import make_url -from sqlmodel import Session, select, text +from sqlmodel import Session, and_, col, or_, select, text from raglite._config import RAGLiteConfig from raglite._database import Chunk, ChunkEmbedding, IndexMetadata, create_database_engine from raglite._embed import embed_sentences -from raglite._extract import extract_with_llm from raglite._typing import FloatMatrix def vector_search( - prompt: str | FloatMatrix, + query: str | FloatMatrix, *, num_results: int = 3, config: RAGLiteConfig | None = None, @@ -30,17 +30,15 @@ def vector_search( db_backend = make_url(config.db_url).get_backend_name() # Get the index metadata (including the query adapter, and in the case of SQLite, the index). index_metadata = IndexMetadata.get("default", config=config) - # Embed the prompt. - prompt_embedding = ( - embed_sentences([prompt], config=config)[0, :] - if isinstance(prompt, str) - else np.ravel(prompt) + # Embed the query. + query_embedding = ( + embed_sentences([query], config=config)[0, :] if isinstance(query, str) else np.ravel(query) ) - # Apply the query adapter to the prompt embedding. + # Apply the query adapter to the query embedding. Q = index_metadata.get("query_adapter") # noqa: N806 if config.vector_search_query_adapter and Q is not None: - prompt_embedding = (Q @ prompt_embedding).astype(prompt_embedding.dtype) - # Search for the multi-vector chunk embeddings that are most similar to the prompt embedding. + query_embedding = (Q @ query_embedding).astype(query_embedding.dtype) + # Search for the multi-vector chunk embeddings that are most similar to the query embedding. if db_backend == "postgresql": # Check that the selected metric is supported by pgvector. metrics = {"cosine": "<=>", "dot": "<#>", "euclidean": "<->", "l1": "<+>", "l2": "<->"} @@ -53,7 +51,7 @@ def vector_search( distance_func = getattr( ChunkEmbedding.embedding, f"{config.vector_search_index_metric}_distance" ) - distance = distance_func(prompt_embedding).label("distance") + distance = distance_func(query_embedding).label("distance") results = session.exec( select(ChunkEmbedding.chunk_id, distance).order_by(distance).limit(8 * num_results) ) @@ -68,7 +66,7 @@ def vector_search( from pynndescent import NNDescent multi_vector_indices, distance = cast(NNDescent, index).query( - prompt_embedding[np.newaxis, :], k=8 * num_results + query_embedding[np.newaxis, :], k=8 * num_results ) similarity = 1 - distance[0, :] # Transform the multi-vector indices into chunk indices, and then to chunk ids. @@ -91,7 +89,7 @@ def vector_search( def keyword_search( - prompt: str, *, num_results: int = 3, config: RAGLiteConfig | None = None + query: str, *, num_results: int = 3, config: RAGLiteConfig | None = None ) -> tuple[list[str], list[float]]: """Search chunks using BM25 keyword search.""" # Read the config. @@ -101,10 +99,10 @@ def keyword_search( engine = create_database_engine(config) with Session(engine) as session: if db_backend == "postgresql": - # Convert the prompt to a tsquery [1]. 
+ # Convert the query to a tsquery [1]. # [1] https://www.postgresql.org/docs/current/textsearch-controls.html - prompt_escaped = re.sub(r"[&|!():<>\"]", " ", prompt) - tsv_query = " | ".join(prompt_escaped.split()) + query_escaped = re.sub(r"[&|!():<>\"]", " ", query) + tsv_query = " | ".join(query_escaped.split()) # Perform keyword search with tsvector. statement = text(""" SELECT id as chunk_id, ts_rank(to_tsvector('simple', body), to_tsquery('simple', :query)) AS score @@ -115,10 +113,10 @@ def keyword_search( """) results = session.execute(statement, params={"query": tsv_query, "limit": num_results}) elif db_backend == "sqlite": - # Convert the prompt to an FTS5 query [1]. + # Convert the query to an FTS5 query [1]. # [1] https://www.sqlite.org/fts5.html#full_text_query_syntax - prompt_escaped = re.sub(f"[{re.escape(string.punctuation)}]", "", prompt) - fts5_query = " OR ".join(prompt_escaped.split()) + query_escaped = re.sub(f"[{re.escape(string.punctuation)}]", "", query) + fts5_query = " OR ".join(query_escaped.split()) # Perform keyword search with FTS5. In FTS5, BM25 scores are negative [1], so we # negate them to make them positive. # [1] https://www.sqlite.org/fts5.html#the_bm25_function @@ -155,90 +153,59 @@ def reciprocal_rank_fusion( def hybrid_search( - prompt: str, *, num_results: int = 3, num_rerank: int = 100, config: RAGLiteConfig | None = None + query: str, *, num_results: int = 3, num_rerank: int = 100, config: RAGLiteConfig | None = None ) -> tuple[list[str], list[float]]: """Search chunks by combining ANN vector search with BM25 keyword search.""" # Run both searches. - chunkeyword_search_vector, _ = vector_search(prompt, num_results=num_rerank, config=config) - chunkeyword_search_keyword, _ = keyword_search(prompt, num_results=num_rerank, config=config) + vs_chunk_ids, _ = vector_search(query, num_results=num_rerank, config=config) + ks_chunk_ids, _ = keyword_search(query, num_results=num_rerank, config=config) # Combine the results with Reciprocal Rank Fusion (RRF). - chunk_ids, hybrid_score = reciprocal_rank_fusion( - [chunkeyword_search_vector, chunkeyword_search_keyword] - ) + chunk_ids, hybrid_score = reciprocal_rank_fusion([vs_chunk_ids, ks_chunk_ids]) chunk_ids, hybrid_score = chunk_ids[:num_results], hybrid_score[:num_results] return chunk_ids, hybrid_score -def fusion_search( - prompt: str, +def retrieve_chunks( + chunk_ids: list[str], *, - num_results: int = 5, - num_rerank: int = 100, config: RAGLiteConfig | None = None, -) -> tuple[list[str], list[float]]: - """Search for chunks with the RAG-Fusion method.""" - - class QueriesResponse(BaseModel): - """An array of queries that help answer the user prompt.""" - - queries: list[Annotated[str, Field(min_length=1)]] = Field( - ..., description="A single query that helps answer the user prompt." - ) - system_prompt: ClassVar[str] = """ -The user will give you a prompt in search of an answer. -Your task is to generate a minimal set of search queries for a search engine that together provide a complete answer to the user prompt. - """.strip() - - try: - queries_response = extract_with_llm(QueriesResponse, prompt, config=config) - except ValueError: - queries = [prompt] - else: - queries = [*queries_response.queries, prompt] - # Collect the search results for all the queries. - rankings = [] - for query in queries: - # Run both searches. 
- chunkeyword_search_vector, _ = vector_search(query, num_results=num_rerank, config=config) - chunkeyword_search_keyword, _ = keyword_search(query, num_results=num_rerank, config=config) - # Add results to the rankings. - rankings.append(chunkeyword_search_vector) - rankings.append(chunkeyword_search_keyword) - # Combine all the search results with Reciprocal Rank Fusion (RRF). - chunk_ids, fusion_score = reciprocal_rank_fusion(rankings) - chunk_ids, fusion_score = chunk_ids[:num_results], fusion_score[:num_results] - return chunk_ids, fusion_score +) -> list[Chunk]: + """Retrieve chunks by their ids.""" + config = config or RAGLiteConfig() + engine = create_database_engine(config) + with Session(engine) as session: + chunks = list(session.exec(select(Chunk).where(col(Chunk.id).in_(chunk_ids))).all()) + return chunks def retrieve_segments( - chunk_ids: list[str], + chunk_ids: list[str] | list[Chunk], *, neighbors: tuple[int, ...] | None = (-1, 1), config: RAGLiteConfig | None = None, ) -> list[str]: """Group chunks into contiguous segments and retrieve them.""" - # Get the chunks and extend them with their neighbours. + # Retrieve the chunks. config = config or RAGLiteConfig() - chunks = set() - engine = create_database_engine(config) - with Session(engine) as session: - for chunk_id in chunk_ids: - # Get the chunk by id. - chunk = session.get(Chunk, chunk_id) - if chunk is not None: - chunks.add(chunk) - # Extend the chunk with its neighbouring chunks. - if chunk is not None and neighbors is not None and len(neighbors) > 0: - for offset in sorted(neighbors, key=abs): - where = ( - Chunk.document_id == chunk.document_id, - Chunk.index == chunk.index + offset, - ) - neighbor = session.exec(select(Chunk).where(*where)).first() - if neighbor is not None: - chunks.add(neighbor) + chunks: list[Chunk] = ( + retrieve_chunks(chunk_ids, config=config) # type: ignore[arg-type,assignment] + if all(isinstance(chunk_id, str) for chunk_id in chunk_ids) + else chunk_ids + ) + # Extend the chunks with their neighbouring chunks. + if neighbors: + engine = create_database_engine(config) + with Session(engine) as session: + neighbor_conditions = [ + and_(Chunk.document_id == chunk.document_id, Chunk.index == chunk.index + offset) + for chunk in chunks + for offset in neighbors + ] + chunks += list(session.exec(select(Chunk).where(or_(*neighbor_conditions))).all()) + # Keep only the unique chunks. + chunks = list(set(chunks)) # Sort the chunks by document_id and index (needed for groupby). - chunks = sorted(chunks, key=lambda chunk: (chunk.document_id, chunk.index)) # type: ignore[assignment] + chunks = sorted(chunks, key=lambda chunk: (chunk.document_id, chunk.index)) # Group the chunks into contiguous segments. segments: list[list[Chunk]] = [] for _, group in groupby(chunks, key=lambda chunk: chunk.document_id): @@ -256,3 +223,41 @@ def retrieve_segments( for segment in segments ] return segments # type: ignore[return-value] + + +def rerank( + query: str, + chunk_ids: list[str] | list[Chunk], + *, + config: RAGLiteConfig | None = None, +) -> list[Chunk]: + """Rerank chunks according to their relevance to a given query.""" + # Retrieve the chunks. + config = config or RAGLiteConfig() + chunks: list[Chunk] = ( + retrieve_chunks(chunk_ids, config=config) # type: ignore[arg-type,assignment] + if all(isinstance(chunk_id, str) for chunk_id in chunk_ids) + else chunk_ids + ) + # Early exit if no reranker is configured. + if not config.reranker: + return chunks + # Select the reranker. 
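+    # (config.reranker is either a single rerankers BaseRanker that is applied to
+    # every query, or a tuple of (language, BaseRanker) pairs with an "other"
+    # fallback, as declared in RAGLiteConfig.)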
+ if isinstance(config.reranker, Sequence): + # Detect the languages of the chunks and queries. + langs = {detect(str(chunk)) for chunk in chunks} + langs.add(detect(query)) + # If all chunks and the query are in the same language, use a language-specific reranker. + rerankers = dict(config.reranker) + if len(langs) == 1 and (lang := next(iter(langs))) in rerankers: + reranker = rerankers[lang] + else: + reranker = rerankers.get("other") + else: + # A specific reranker was configured. + reranker = config.reranker + # Rerank the chunks. + if reranker: + results = reranker.rank(query=query, docs=[str(chunk) for chunk in chunks]) + chunks = [chunks[result.doc_id] for result in results.results] + return chunks diff --git a/src/raglite/_typing.py b/src/raglite/_typing.py index adda9d0..07a6904 100644 --- a/src/raglite/_typing.py +++ b/src/raglite/_typing.py @@ -3,18 +3,26 @@ import io import pickle from collections.abc import Callable -from typing import Any +from typing import Any, Protocol import numpy as np from sqlalchemy.engine import Dialect from sqlalchemy.sql.operators import Operators from sqlalchemy.types import Float, LargeBinary, TypeDecorator, TypeEngine, UserDefinedType +from raglite._config import RAGLiteConfig + FloatMatrix = np.ndarray[tuple[int, int], np.dtype[np.floating[Any]]] FloatVector = np.ndarray[tuple[int], np.dtype[np.floating[Any]]] IntVector = np.ndarray[tuple[int], np.dtype[np.intp]] +class SearchMethod(Protocol): + def __call__( + self, query: str, *, num_results: int = 3, config: RAGLiteConfig | None = None + ) -> tuple[list[str], list[float]]: ... + + class NumpyArray(TypeDecorator[np.ndarray[Any, np.dtype[np.floating[Any]]]]): """A NumPy array column type for SQLAlchemy.""" diff --git a/tests/conftest.py b/tests/conftest.py index f49582b..96121b4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -54,8 +54,8 @@ def database(request: pytest.FixtureRequest) -> str: @pytest.fixture( params=[ pytest.param( - "llama-cpp-python/ChristianAzinn/snowflake-arctic-embed-xs-gguf/*f16.GGUF", - id="snowflake_arctic_embed_xs", + "llama-cpp-python/lm-kit/bge-m3-gguf/*Q4_K_M.gguf", + id="bge_m3", ), pytest.param( "text-embedding-3-small", diff --git a/tests/test_embed.py b/tests/test_embed.py index 77d5d5f..8de56d2 100644 --- a/tests/test_embed.py +++ b/tests/test_embed.py @@ -23,4 +23,4 @@ def test_embed(embedder: str) -> None: assert sentence_embeddings.shape[1] >= 128 # noqa: PLR2004 assert sentence_embeddings.dtype == np.float16 assert np.all(np.isfinite(sentence_embeddings)) - assert np.allclose(np.linalg.norm(sentence_embeddings, axis=1), 1.0) + assert np.allclose(np.linalg.norm(sentence_embeddings, axis=1), 1.0, rtol=1e-3)
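For completeness, a small sketch of how the language-aware dispatch in `rerank()` above behaves with the default reranker mapping. The chunk texts are made up; `langdetect` is the same dependency this patch adds:

```python
# Sketch of the per-language reranker selection implemented in rerank() above.
from langdetect import detect

query = "How is intelligence measured?"
chunk_texts = [
    "Intelligence is often measured with standardized IQ tests.",
    "L'intelligence se mesure souvent avec des tests de QI.",  # A French chunk.
]

# Detect the languages of the query and all candidate chunks.
langs = {detect(text) for text in chunk_texts} | {detect(query)}

# Use a language-specific reranker only if the query and all chunks agree on one
# language that has a configured reranker; otherwise fall back to "other".
configured = {"en", "other"}  # Mirrors the default ("en", ...), ("other", ...) config.
if len(langs) == 1 and (lang := next(iter(langs))) in configured:
    reranker_key = lang
else:
    reranker_key = "other"
print(reranker_key)  # -> "other", because the French chunk breaks language agreement.
```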