From fb9126098aed2bf471737e3a4b966193c77c2635 Mon Sep 17 00:00:00 2001 From: Laurent Sorber Date: Fri, 30 Aug 2024 10:45:36 +0200 Subject: [PATCH] feat: add PostgreSQL driver --- .cruft.json | 12 +- README.md | 27 +-- docker-compose.yml | 29 +++- poetry.lock | 42 ++++- pyproject.toml | 9 +- src/raglite/__init__.py | 3 +- src/raglite/_config.py | 13 +- src/raglite/_database.py | 317 ++++++++++++++++++++-------------- src/raglite/_embed.py | 4 +- src/raglite/_eval.py | 23 ++- src/raglite/_index.py | 130 +++++++------- src/raglite/_query_adapter.py | 39 ++--- src/raglite/_rag.py | 6 +- src/raglite/_search.py | 230 +++++++++++++----------- src/raglite/_typing.py | 129 ++++++++++++++ tests/conftest.py | 48 ++++- tests/test_basic.py | 10 +- 17 files changed, 679 insertions(+), 392 deletions(-) diff --git a/.cruft.json b/.cruft.json index c5664de..a44ca99 100644 --- a/.cruft.json +++ b/.cruft.json @@ -1,15 +1,15 @@ { - "template": "https://github.com/radix-ai/poetry-cookiecutter", + "template": "https://github.com/superlinear-ai/poetry-cookiecutter", "commit": "a969f1d182ec39d7d27ccb1116cf60ba736adcfa", "checkout": null, "context": { "cookiecutter": { "project_type": "package", "project_name": "RAGLite", - "project_description": "A RAG extension for SQLite.", - "project_url": "https://github.com/radix-ai/raglite", + "project_description": "A Python package for Retrieval-Augmented Generation (RAG) with SQLite or PostgreSQL.", + "project_url": "https://github.com/superlinear-ai/raglite", "author_name": "Laurent Sorber", - "author_email": "laurent@radix.ai", + "author_email": "laurent@superlinear.eu", "python_version": "3.10", "development_environment": "strict", "with_conventional_commits": "1", @@ -22,8 +22,8 @@ "__docstring_style": "NumPy", "__project_name_kebab_case": "raglite", "__project_name_snake_case": "raglite", - "_template": "https://github.com/radix-ai/poetry-cookiecutter" + "_template": "https://github.com/superlinear-ai/poetry-cookiecutter" } }, "directory": null -} +} \ No newline at end of file diff --git a/README.md b/README.md index 47f0a4f..29af5f0 100644 --- a/README.md +++ b/README.md @@ -1,22 +1,23 @@ -[![Open in Dev Containers](https://img.shields.io/static/v1?label=Dev%20Containers&message=Open&color=blue&logo=visualstudiocode)](https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://github.com/radix-ai/raglite) [![Open in GitHub Codespaces](https://img.shields.io/static/v1?label=GitHub%20Codespaces&message=Open&color=blue&logo=github)](https://github.com/codespaces/new?hide_repo_select=true&ref=main&repo=812973394&skip_quickstart=true) +[![Open in Dev Containers](https://img.shields.io/static/v1?label=Dev%20Containers&message=Open&color=blue&logo=visualstudiocode)](https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://github.com/superlinear-ai/raglite) [![Open in GitHub Codespaces](https://img.shields.io/static/v1?label=GitHub%20Codespaces&message=Open&color=blue&logo=github)](https://github.com/codespaces/new?hide_repo_select=true&ref=main&repo=812973394&skip_quickstart=true) # ๐Ÿงต RAGLite -RAGLite is a Python package for Retrieval-Augmented Generation (RAG) with SQLite. +RAGLite is a Python package for Retrieval-Augmented Generation (RAG) with PostgreSQL or SQLite. ## Features 1. โค๏ธ Only lightweight and permissive open source dependencies (e.g., no [PyTorch](https://github.com/pytorch/pytorch), [LangChain](https://github.com/langchain-ai/langchain), or [PyMuPDF](https://github.com/pymupdf/PyMuPDF)) -2. ๐Ÿ”’ Fully local RAG with [llama-cpp-python](https://github.com/abetlen/llama-cpp-python) as an LLM provider and [SQLite](https://github.com/sqlite/sqlite) as a local database -3. ๐Ÿš€ Acceleration with Metal on macOS and with CUDA on Linux and Windows -4. ๐Ÿ“– PDF to Markdown conversion on top of [pdftext](https://github.com/VikParuchuri/pdftext) and [pypdfium2](https://github.com/pypdfium2-team/pypdfium2) -5. โœ‚๏ธ Optimal [level 4 semantic chunking](https://medium.com/@anuragmishra_27746/five-levels-of-chunking-strategies-in-rag-notes-from-gregs-video-7b735895694d) by solving a [binary integer programming problem](https://en.wikipedia.org/wiki/Integer_programming) -6. ๐Ÿ“Œ Markdown-based [contextual chunk headings](https://d-star.ai/solving-the-out-of-context-chunk-problem-for-rag) -7. ๐ŸŒˆ Combined sentence-level and chunk-level matching with [multi-vector chunk retrieval](https://python.langchain.com/v0.2/docs/how_to/multi_vector/) -8. ๐ŸŒ€ Optimal [closed-form linear query adapter](src/raglite/_query_adapter.py) by solving an [orthogonal Procrustes problem](https://en.wikipedia.org/wiki/Orthogonal_Procrustes_problem) -9. ๐Ÿ” [Hybrid search](https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf) that combines [SQLite's BM25 full-text search](https://sqlite.org/fts5.html) with [PyNNDescent's ANN vector search](https://github.com/lmcinnes/pynndescent) -10. โœ๏ธ Optional support for conversion of any input document to Markdown with [Pandoc](https://github.com/jgm/pandoc) -11. โœ… Optional support for evaluation of retrieval and generation with [Ragas](https://github.com/explodinggradients/ragas) +2. ๐Ÿง  Your choice of local LLM with [llama-cpp-python](https://github.com/abetlen/llama-cpp-python) +3. ๐Ÿ’พ Your choice of [PostgreSQL](https://github.com/postgres/postgres) or [SQLite](https://github.com/sqlite/sqlite) as a full-text & vector search database +4. ๐Ÿš€ Acceleration with Metal on macOS and with CUDA on Linux and Windows +5. ๐Ÿ“– PDF to Markdown conversion on top of [pdftext](https://github.com/VikParuchuri/pdftext) and [pypdfium2](https://github.com/pypdfium2-team/pypdfium2) +6. โœ‚๏ธ Optimal [level 4 semantic chunking](https://medium.com/@anuragmishra_27746/five-levels-of-chunking-strategies-in-rag-notes-from-gregs-video-7b735895694d) by solving a [binary integer programming problem](https://en.wikipedia.org/wiki/Integer_programming) +7. ๐Ÿ“Œ Markdown-based [contextual chunk headings](https://d-star.ai/solving-the-out-of-context-chunk-problem-for-rag) +8. ๐ŸŒˆ Combined sentence-level and chunk-level matching with [multi-vector chunk retrieval](https://python.langchain.com/v0.2/docs/how_to/multi_vector/) +9. ๐ŸŒ€ Optimal [closed-form linear query adapter](src/raglite/_query_adapter.py) by solving an [orthogonal Procrustes problem](https://en.wikipedia.org/wiki/Orthogonal_Procrustes_problem) +10. ๐Ÿ” [Hybrid search](https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf) that combines the database's built-in full-text search ([tsvector](https://www.postgresql.org/docs/current/datatype-textsearch.html) in PostgreSQL, [FTS5](https://www.sqlite.org/fts5.html) in SQLite) with their native vector search extensions ([pgvector](https://github.com/pgvector/pgvector) in PostgreSQL, [sqlite-vec](https://github.com/asg017/sqlite-vec) in SQLite) +11. โœ๏ธ Optional support for conversion of any input document to Markdown with [Pandoc](https://github.com/jgm/pandoc) +12. โœ… Optional support for evaluation of retrieval and generation performance with [Ragas](https://github.com/explodinggradients/ragas) ## Installing @@ -145,7 +146,7 @@ evaluation_df = evaluate(answered_evals_df, config=my_config) The following development environments are supported: 1. โญ๏ธ _GitHub Codespaces_: click on _Code_ and select _Create codespace_ to start a Dev Container with [GitHub Codespaces](https://github.com/features/codespaces). -1. โญ๏ธ _Dev Container (with container volume)_: click on [Open in Dev Containers](https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://github.com/radix-ai/raglite) to clone this repository in a container volume and create a Dev Container with VS Code. +1. โญ๏ธ _Dev Container (with container volume)_: click on [Open in Dev Containers](https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://github.com/superlinear-ai/raglite) to clone this repository in a container volume and create a Dev Container with VS Code. 1. _Dev Container_: clone this repository, open it with VS Code, and run Ctrl/โŒ˜ + โ‡ง + P โ†’ _Dev Containers: Reopen in Container_. 1. _PyCharm_: clone this repository, open it with PyCharm, and [configure Docker Compose as a remote interpreter](https://www.jetbrains.com/help/pycharm/using-docker-compose-as-a-remote-interpreter.html#docker-compose-remote) with the `dev` service. 1. _Terminal_: clone this repository, open it with your terminal, and run `docker compose up --detach dev` to start a Dev Container in the background, and then run `docker compose exec dev zsh` to open a shell prompt in the Dev Container. diff --git a/docker-compose.yml b/docker-compose.yml index 0c889b3..e2eb595 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -12,6 +12,10 @@ services: GID: ${GID:-1000} environment: - POETRY_PYPI_TOKEN_PYPI + depends_on: + - postgres + networks: + - raglite-network volumes: - ..:/workspaces - command-history-volume:/home/user/.history/ @@ -21,15 +25,14 @@ services: stdin_open: true tty: true entrypoint: [] - command: - [ - "sh", - "-c", - "sudo chown user $$SSH_AUTH_SOCK && cp --update /opt/build/poetry/poetry.lock /workspaces/raglite/ && mkdir -p /workspaces/raglite/.git/hooks/ && cp --update /opt/build/git/* /workspaces/raglite/.git/hooks/ && zsh" - ] + command: [ "sh", "-c", "sudo chown user $$SSH_AUTH_SOCK && cp --update /opt/build/poetry/poetry.lock /workspaces/raglite/ && mkdir -p /workspaces/raglite/.git/hooks/ && cp --update /opt/build/git/* /workspaces/raglite/.git/hooks/ && zsh" ] environment: - POETRY_PYPI_TOKEN_PYPI - SSH_AUTH_SOCK=/run/host-services/ssh-auth.sock + depends_on: + - postgres + networks: + - raglite-network volumes: - ~/.gitconfig:/etc/gitconfig - ~/.ssh/known_hosts:/home/user/.ssh/known_hosts @@ -37,5 +40,19 @@ services: profiles: - dev + postgres: + image: pgvector/pgvector:pg16 + environment: + POSTGRES_USER: raglite_user + POSTGRES_PASSWORD: raglite_password + networks: + - raglite-network + tmpfs: + - /var/lib/postgresql/data + +networks: + raglite-network: + driver: bridge + volumes: command-history-volume: diff --git a/poetry.lock b/poetry.lock index f84bfeb..e4d26af 100644 --- a/poetry.lock +++ b/poetry.lock @@ -210,6 +210,17 @@ types-python-dateutil = ">=2.8.10" doc = ["doc8", "sphinx (>=7.0.0)", "sphinx-autobuild", "sphinx-autodoc-typehints", "sphinx_rtd_theme (>=1.3.0)"] test = ["dateparser (==1.*)", "pre-commit", "pytest", "pytest-cov", "pytest-mock", "pytz (==2021.1)", "simplejson (==3.*)"] +[[package]] +name = "asn1crypto" +version = "1.5.1" +description = "Fast ASN.1 parser and serializer with definitions for private keys, public keys, certificates, CRL, OCSP, CMS, PKCS#3, PKCS#7, PKCS#8, PKCS#12, PKCS#5, X.509 and TSP" +optional = false +python-versions = "*" +files = [ + {file = "asn1crypto-1.5.1-py2.py3-none-any.whl", hash = "sha256:db4e40728b728508912cbb3d44f19ce188f218e9eba635821bb4b68564f8fd67"}, + {file = "asn1crypto-1.5.1.tar.gz", hash = "sha256:13ae38502be632115abf8a24cbe5f4da52e3b5231990aff31123c805306ccb9c"}, +] + [[package]] name = "asttokens" version = "2.4.1" @@ -3437,6 +3448,21 @@ files = [ [package.dependencies] ptyprocess = ">=0.5" +[[package]] +name = "pg8000" +version = "1.31.2" +description = "PostgreSQL interface library" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pg8000-1.31.2-py3-none-any.whl", hash = "sha256:436c771ede71af4d4c22ba867a30add0bc5c942d7ab27fadbb6934a487ecc8f6"}, + {file = "pg8000-1.31.2.tar.gz", hash = "sha256:1ea46cf09d8eca07fe7eaadefd7951e37bee7fabe675df164f1a572ffb300876"}, +] + +[package.dependencies] +python-dateutil = ">=2.8.2" +scramp = ">=1.4.5" + [[package]] name = "pillow" version = "10.4.0" @@ -4745,6 +4771,20 @@ dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy (==1.10.0)", "pycodest doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.13.1)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0)", "sphinx-design (>=0.4.0)"] test = ["Cython", "array-api-strict", "asv", "gmpy2", "hypothesis (>=6.30)", "meson", "mpmath", "ninja", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] +[[package]] +name = "scramp" +version = "1.4.5" +description = "An implementation of the SCRAM protocol." +optional = false +python-versions = ">=3.8" +files = [ + {file = "scramp-1.4.5-py3-none-any.whl", hash = "sha256:50e37c464fc67f37994e35bee4151e3d8f9320e9c204fca83a5d313c121bbbe7"}, + {file = "scramp-1.4.5.tar.gz", hash = "sha256:be3fbe774ca577a7a658117dca014e5d254d158cecae3dd60332dfe33ce6d78e"}, +] + +[package.dependencies] +asn1crypto = ">=1.5.1" + [[package]] name = "setuptools" version = "72.2.0" @@ -5900,4 +5940,4 @@ ragas = ["ragas"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<4.0" -content-hash = "a65d8c1115521aa69301e35158c5d6f6dc0618f745c6373c86dd677a853e320c" +content-hash = "8178f5e494205788ce704f31e778aea16ce12d2e055d704c1c706fa2c1e8e5e9" diff --git a/pyproject.toml b/pyproject.toml index ae6fc6c..2cb9287 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,10 +5,10 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] # https://python-poetry.org/docs/pyproject/ name = "raglite" version = "0.0.0" -description = "A Python package for Retrieval-Augmented Generation (RAG) with SQLite." -authors = ["Laurent Sorber "] +description = "A Python package for Retrieval-Augmented Generation (RAG) with SQLite or PostgreSQL." +authors = ["Laurent Sorber "] readme = "README.md" -repository = "https://github.com/radix-ai/raglite" +repository = "https://github.com/superlinear-ai/raglite" [tool.commitizen] # https://commitizen-tools.github.io/commitizen/config/ bump_message = "bump(release): v$current_version โ†’ v$new_version" @@ -51,6 +51,7 @@ pydantic = ">=2.7.0" # Approximate Nearest Neighbors: pynndescent = ">=0.5.12" # Storage: +pg8000 = ">=1.31.2" sqlmodel-slim = ">=0.0.18" # Progress: tqdm = ">=4.66.0" @@ -114,7 +115,7 @@ warn_unreachable = true [tool.pytest.ini_options] # https://docs.pytest.org/en/latest/reference/reference.html#ini-options-ref addopts = "--color=yes --doctest-modules --exitfirst --failed-first --strict-config --strict-markers --verbosity=2 --junitxml=reports/pytest.xml" -filterwarnings = ["error", "ignore::DeprecationWarning"] +filterwarnings = ["error", "ignore::DeprecationWarning", "ignore::pytest.PytestUnraisableExceptionWarning"] testpaths = ["src", "tests"] xfail_strict = true diff --git a/src/raglite/__init__.py b/src/raglite/__init__.py index c77f710..a5ef385 100644 --- a/src/raglite/__init__.py +++ b/src/raglite/__init__.py @@ -2,7 +2,7 @@ from raglite._config import RAGLiteConfig from raglite._eval import answer_evals, evaluate, insert_evals -from raglite._index import insert_document, update_vector_index +from raglite._index import insert_document from raglite._query_adapter import update_query_adapter from raglite._rag import rag from raglite._search import ( @@ -18,7 +18,6 @@ "RAGLiteConfig", # Index "insert_document", - "update_vector_index", # Search "fusion_search", "hybrid_search", diff --git a/src/raglite/_config.py b/src/raglite/_config.py index e6c74c5..a76daf8 100644 --- a/src/raglite/_config.py +++ b/src/raglite/_config.py @@ -3,8 +3,6 @@ from dataclasses import dataclass, field from functools import lru_cache -import numpy as np -import numpy.typing as npt from llama_cpp import Llama, LlamaRAMCache, llama_supports_gpu_offload # type: ignore[attr-defined] from sqlalchemy.engine import URL @@ -12,14 +10,14 @@ @lru_cache(maxsize=1) def default_llm() -> Llama: """Get default LLM.""" - # Select the best available LLM for the given accelerator. + # Select the best available LLM for the given accelerator: + # - Llama-3.1-8B-instruct on GPU. + # - Phi-3.5-mini-instruct on CPU. if llama_supports_gpu_offload(): - # Llama-3.1-8B-instruct on GPU. repo_id = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF" # https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct filename = "*Q4_K_M.gguf" n_ctx = 8192 else: - # Phi-3.1-mini-128k-instruct on CPU. repo_id = "bartowski/Phi-3.5-mini-instruct-GGUF" # https://huggingface.co/microsoft/Phi-3.5-mini-instruct filename = "*Q4_K_M.gguf" n_ctx = 4096 @@ -61,7 +59,6 @@ class RAGLiteConfig: # Embedder config used for indexing. embedder: Llama = field(default_factory=default_embedder) embedder_batch_size: int = 128 - embedder_dtype: npt.DTypeLike = np.float16 embedder_normalize: bool = True sentence_embedding_weight: float = 0.5 # Between 0 (chunk level) and 1 (sentence level). # Chunker config used to partition documents into chunks. @@ -70,7 +67,5 @@ class RAGLiteConfig: # Database config. db_url: str | URL = "sqlite:///raglite.sqlite" # Vector search config. - vector_search_index_id: str = "default" vector_search_index_metric: str = "cosine" # The query adapter supports "dot" and "cosine". - # Query adapter config. - enable_query_adapter: bool = True + vector_search_query_adapter: bool = True diff --git a/src/raglite/_database.py b/src/raglite/_database.py index 2bedcd8..b4bafa4 100644 --- a/src/raglite/_database.py +++ b/src/raglite/_database.py @@ -1,7 +1,6 @@ -"""SQLite tables for RAGLite.""" +"""PostgreSQL or SQLite database tables for RAGLite.""" -import io -import pickle +import datetime from functools import lru_cache from hashlib import sha256 from pathlib import Path @@ -9,12 +8,22 @@ import numpy as np from markdown_it import MarkdownIt -from pynndescent import NNDescent -from sqlalchemy.engine import URL, Dialect, Engine, make_url -from sqlalchemy.types import LargeBinary, TypeDecorator -from sqlmodel import JSON, Column, Field, Relationship, Session, SQLModel, create_engine, text - -from raglite._typing import FloatMatrix +from pydantic import ConfigDict +from sqlalchemy.engine import Engine, make_url +from sqlmodel import ( + JSON, + Column, + Field, + Relationship, + Session, + SQLModel, + create_engine, + select, + text, +) + +from raglite._config import RAGLiteConfig +from raglite._typing import Embedding, FloatMatrix, FloatVector, PickledObject def hash_bytes(data: bytes, max_len: int = 16) -> str: @@ -22,55 +31,17 @@ def hash_bytes(data: bytes, max_len: int = 16) -> str: return sha256(data, usedforsecurity=False).hexdigest()[:max_len] -class NumpyArray(TypeDecorator[np.ndarray[Any, np.dtype[np.floating[Any]]]]): - """A NumPy array column type for SQLAlchemy.""" - - impl = LargeBinary - - def process_bind_param( - self, value: np.ndarray[Any, np.dtype[np.floating[Any]]] | None, dialect: Dialect - ) -> bytes | None: - """Convert a NumPy array to bytes.""" - if value is None: - return None - buffer = io.BytesIO() - np.save(buffer, value, allow_pickle=False, fix_imports=False) - return buffer.getvalue() - - def process_result_value( - self, value: bytes | None, dialect: Dialect - ) -> np.ndarray[Any, np.dtype[np.floating[Any]]] | None: - """Convert bytes to a NumPy array.""" - if value is None: - return None - return np.load(io.BytesIO(value), allow_pickle=False, fix_imports=False) # type: ignore[no-any-return] - - -class PickledObject(TypeDecorator[object]): - """A pickled object column type for SQLAlchemy.""" - - impl = LargeBinary - - def process_bind_param(self, value: object | None, dialect: Dialect) -> bytes | None: - """Convert a Python object to bytes.""" - if value is None: - return None - return pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL, fix_imports=False) - - def process_result_value(self, value: bytes | None, dialect: Dialect) -> object | None: - """Convert bytes to a Python object.""" - if value is None: - return None - return pickle.loads(value, fix_imports=False) # type: ignore[no-any-return] # noqa: S301 - - class Document(SQLModel, table=True): """A document.""" + # Enable JSON columns. + model_config = ConfigDict(arbitrary_types_allowed=True) # type: ignore[assignment] + + # Table columns. id: str = Field(..., primary_key=True) filename: str url: str | None = Field(default=None) - metadata_: dict[str, Any] = Field(default={}, sa_column=Column("metadata", JSON)) + metadata_: dict[str, Any] = Field(default_factory=dict, sa_column=Column("metadata", JSON)) # Add relationships so we can access document.chunks and document.evals. chunks: list["Chunk"] = Relationship(back_populates="document") @@ -90,26 +61,24 @@ def from_path(doc_path: Path, **kwargs: Any) -> "Document": }, ) - # Enable support for JSON columns. - class Config: - """Table configuration.""" - - arbitrary_types_allowed = True - class Chunk(SQLModel, table=True): """A document chunk.""" + # Enable JSON columns. + model_config = ConfigDict(arbitrary_types_allowed=True) # type: ignore[assignment] + + # Table columns. id: str = Field(..., primary_key=True) document_id: str = Field(..., foreign_key="document.id", index=True) index: int = Field(..., index=True) headings: str body: str - multi_vector_embedding: FloatMatrix = Field(..., sa_column=Column(NumpyArray)) - metadata_: dict[str, Any] = Field(default={}, sa_column=Column("metadata", JSON)) + metadata_: dict[str, Any] = Field(default_factory=dict, sa_column=Column("metadata", JSON)) - # Add relationship so we can access chunk.document. + # Add relationships so we can access chunk.document and chunk.embeddings. document: Document = Relationship(back_populates="chunks") + embeddings: list["ChunkEmbedding"] = Relationship(back_populates="chunk") @staticmethod def from_body( @@ -117,7 +86,6 @@ def from_body( index: int, body: str, headings: str = "", - multi_vector_embedding: FloatMatrix | None = None, **kwargs: Any, ) -> "Chunk": """Create a chunk from Markdown.""" @@ -127,9 +95,6 @@ def from_body( index=index, headings=headings, body=body, - multi_vector_embedding=multi_vector_embedding - if multi_vector_embedding is not None - else np.empty(0), metadata_=kwargs, ) @@ -151,6 +116,12 @@ def extract_headings(self) -> str: headings = "\n".join([heading for heading in heading_lines if heading]) return headings + @property + def embedding_matrix(self) -> FloatMatrix: + """Return this chunk's multi-vector embedding matrix.""" + # Uses the relationship chunk.embeddings to access the chunk_embedding table. + return np.vstack([embedding.embedding[np.newaxis, :] for embedding in self.embeddings]) + def __str__(self) -> str: """Context representation of this chunk.""" return f"{self.headings.strip()}\n\n{self.body.strip()}".strip() @@ -158,29 +129,75 @@ def __str__(self) -> str: def __hash__(self) -> int: return hash(self.id) - # Enable support for JSON and NumpyArray columns. - class Config: - """Table configuration.""" - arbitrary_types_allowed = True +class ChunkEmbedding(SQLModel, table=True): + """A (sub-)chunk embedding.""" + __tablename__ = "chunk_embedding" -class VectorSearchChunkIndex(SQLModel, table=True): - """A vector search index for chunks.""" + # Enable Embedding columns. + model_config = ConfigDict(arbitrary_types_allowed=True) # type: ignore[assignment] - __tablename__ = "vs_chunk_index" # Vector search chunk index. + # Table columns. + id: int = Field(..., primary_key=True) + chunk_id: str = Field(..., foreign_key="chunk.id", index=True) + embedding: FloatVector = Field(..., sa_column=Column(Embedding(dim=-1))) + # Add relationship so we can access embedding.chunk. + chunk: Chunk = Relationship(back_populates="embeddings") + + @classmethod + def set_embedding_dim(cls, dim: int) -> None: + """Modify the embedding column's dimension after class definition.""" + cls.__table__.c["embedding"].type.dim = dim # type: ignore[attr-defined] + + +class IndexMetadata(SQLModel, table=True): + """Vector and keyword search index metadata.""" + + __tablename__ = "index_metadata" + + # Enable PickledObject columns. + model_config = ConfigDict(arbitrary_types_allowed=True) # type: ignore[assignment] + + # Table columns. id: str = Field(..., primary_key=True) - chunk_sizes: list[int] = Field(default=[], sa_column=Column(JSON)) - index: NNDescent | None = Field(default=None, sa_column=Column(PickledObject)) - query_adapter: FloatMatrix | None = Field(default=None, sa_column=Column(NumpyArray)) - metadata_: dict[str, Any] = Field(default={}, sa_column=Column("metadata", JSON)) + version: datetime.datetime = Field( + default_factory=lambda: datetime.datetime.now(datetime.timezone.utc) + ) + metadata_: dict[str, Any] = Field( + default_factory=dict, sa_column=Column("metadata", PickledObject) + ) + + @staticmethod + def _get_version(id_: str, *, config: RAGLiteConfig | None = None) -> datetime.datetime | None: + """Get the version of the index metadata with a given id.""" + engine = create_database_engine(config) + with Session(engine) as session: + version = session.exec( + select(IndexMetadata.version).where(IndexMetadata.id == id_) + ).first() + return version - # Enable support for JSON, PickledObject, and NumpyArray columns. - class Config: - """Table configuration.""" + @staticmethod + @lru_cache(maxsize=4) + def _get( + id_: str, version: datetime.datetime | None, *, config: RAGLiteConfig | None = None + ) -> dict[str, Any] | None: + if version is None: + return None + engine = create_database_engine(config) + with Session(engine) as session: + index_metadata_record = session.get(IndexMetadata, id_) + if index_metadata_record is None: + return None + return index_metadata_record.metadata_ - arbitrary_types_allowed = True + @staticmethod + def get(id_: str = "default", *, config: RAGLiteConfig | None = None) -> dict[str, Any]: + version = IndexMetadata._get_version(id_, config=config) + metadata = IndexMetadata._get(id_, version, config=config) or {} + return metadata class Eval(SQLModel, table=True): @@ -188,13 +205,17 @@ class Eval(SQLModel, table=True): __tablename__ = "eval" + # Enable JSON columns. + model_config = ConfigDict(arbitrary_types_allowed=True) # type: ignore[assignment] + + # Table columns. id: str = Field(..., primary_key=True) document_id: str = Field(..., foreign_key="document.id", index=True) - chunk_ids: list[str] = Field(default=[], sa_column=Column(JSON)) + chunk_ids: list[str] = Field(default_factory=list, sa_column=Column(JSON)) question: str - contexts: list[str] = Field(default=[], sa_column=Column(JSON)) + contexts: list[str] = Field(default_factory=list, sa_column=Column(JSON)) ground_truth: str - metadata_: dict[str, Any] = Field(default={}, sa_column=Column("metadata", JSON)) + metadata_: dict[str, Any] = Field(default_factory=dict, sa_column=Column("metadata", JSON)) # Add relationship so we can access eval.document. document: Document = Relationship(back_populates="evals") @@ -219,56 +240,94 @@ def from_chunks( metadata_=kwargs, ) - # Enable support for JSON columns. - class Config: - """Table configuration.""" - - arbitrary_types_allowed = True - @lru_cache(maxsize=1) -def create_database_engine(db_url: str | URL = "sqlite:///raglite.sqlite") -> Engine: +def create_database_engine(config: RAGLiteConfig | None = None) -> Engine: """Create a database engine and initialize it.""" - # Parse the database URL. - db_url = make_url(db_url) - assert db_url.get_backend_name() == "sqlite", "RAGLite currently only supports SQLite." - # Optimize SQLite performance. - pragmas = {"journal_mode": "WAL", "synchronous": "NORMAL"} - db_url = db_url.update_query_dict(pragmas, append=True) + # Parse the database URL and validate that the database backend is supported. + config = config or RAGLiteConfig() + db_url = make_url(config.db_url) + db_backend = db_url.get_backend_name() + # Update database configuration. + connect_args = {} + if db_backend == "postgresql": + # Select the pg8000 driver if not set (psycopg2 is the default), and prefer SSL. + if "+" not in db_url.drivername: + db_url = db_url.set(drivername="postgresql+pg8000") + # Support setting the sslmode for pg8000. + if "pg8000" in db_url.drivername and "sslmode" in db_url.query: + query = dict(db_url.query) + if query.pop("sslmode") != "disable": + connect_args["ssl_context"] = True + db_url = db_url.set(query=query) + elif db_backend == "sqlite": + # Optimize SQLite performance. + pragmas = {"journal_mode": "WAL", "synchronous": "NORMAL"} + db_url = db_url.update_query_dict(pragmas, append=True) + else: + error_message = "RAGLite only supports PostgreSQL and SQLite." + raise ValueError(error_message) # Create the engine. - engine = create_engine(db_url) + engine = create_engine(db_url, pool_pre_ping=True, connect_args=connect_args) + # Install database extensions. + if db_backend == "postgresql": + with Session(engine) as session: + session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;")) + session.commit() # Create all SQLModel tables. + ChunkEmbedding.set_embedding_dim(config.embedder.n_embd()) SQLModel.metadata.create_all(engine) - # Create a virtual table for full-text search on the chunk table. - # We use the chunk table as an external content table [1] to avoid duplicating the data. - # [1] https://www.sqlite.org/fts5.html#external_content_tables - with Session(engine) as session: - session.execute( - text(""" - CREATE VIRTUAL TABLE IF NOT EXISTS fts_chunk_index USING fts5(body, content='chunk', content_rowid='rowid'); - """) - ) - session.execute( - text(""" - CREATE TRIGGER IF NOT EXISTS fts_chunk_index_auto_insert AFTER INSERT ON chunk BEGIN - INSERT INTO fts_chunk_index(rowid, body) VALUES (new.rowid, new.body); - END; - """) - ) - session.execute( - text(""" - CREATE TRIGGER IF NOT EXISTS fts_chunk_index_auto_delete AFTER DELETE ON chunk BEGIN - INSERT INTO fts_chunk_index(fts_chunk_index, rowid, body) VALUES('delete', old.rowid, old.body); - END; - """) - ) - session.execute( - text(""" - CREATE TRIGGER IF NOT EXISTS fts_chunk_index_auto_update AFTER UPDATE ON chunk BEGIN - INSERT INTO fts_chunk_index(fts_chunk_index, rowid, body) VALUES('delete', old.rowid, old.body); - INSERT INTO fts_chunk_index(rowid, body) VALUES (new.rowid, new.body); - END; - """) - ) - session.commit() + # Create backend-specific indexes. + if db_backend == "postgresql": + # Create a full-text search index with `tsvector` and a vector search index with `pgvector`. + with Session(engine) as session: + metrics = {"cosine": "cosine", "dot": "ip", "euclidean": "l2", "l1": "l1", "l2": "l2"} + session.execute( + text(""" + CREATE INDEX IF NOT EXISTS fts_chunk_index ON chunk USING GIN (to_tsvector('simple', body)); + """) + ) + session.execute( + text(f""" + CREATE INDEX IF NOT EXISTS vs_chunk_index ON chunk_embedding + USING hnsw ( + (embedding::halfvec({config.embedder.n_embd()})) + halfvec_{metrics[config.vector_search_index_metric]}_ops + ); + """) + ) + session.commit() + elif db_backend == "sqlite": + # Create a virtual table for full-text search on the chunk table. + # We use the chunk table as an external content table [1] to avoid duplicating the data. + # [1] https://www.sqlite.org/fts5.html#external_content_tables + with Session(engine) as session: + session.execute( + text(""" + CREATE VIRTUAL TABLE IF NOT EXISTS fts_chunk_index USING fts5(body, content='chunk', content_rowid='rowid'); + """) + ) + session.execute( + text(""" + CREATE TRIGGER IF NOT EXISTS fts_chunk_index_auto_insert AFTER INSERT ON chunk BEGIN + INSERT INTO fts_chunk_index(rowid, body) VALUES (new.rowid, new.body); + END; + """) + ) + session.execute( + text(""" + CREATE TRIGGER IF NOT EXISTS fts_chunk_index_auto_delete AFTER DELETE ON chunk BEGIN + INSERT INTO fts_chunk_index(fts_chunk_index, rowid, body) VALUES('delete', old.rowid, old.body); + END; + """) + ) + session.execute( + text(""" + CREATE TRIGGER IF NOT EXISTS fts_chunk_index_auto_update AFTER UPDATE ON chunk BEGIN + INSERT INTO fts_chunk_index(fts_chunk_index, rowid, body) VALUES('delete', old.rowid, old.body); + INSERT INTO fts_chunk_index(rowid, body) VALUES (new.rowid, new.body); + END; + """) + ) + session.commit() return engine diff --git a/src/raglite/_embed.py b/src/raglite/_embed.py index 7aa4a93..02a873d 100644 --- a/src/raglite/_embed.py +++ b/src/raglite/_embed.py @@ -22,8 +22,8 @@ def _embed_string_batch( # Normalise embeddings to unit norm. if config.embedder_normalize: embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True) - # Cast to the configured dtype after normalisation. - embeddings = embeddings.astype(config.embedder_dtype) + # Cast to half precision after normalisation. + embeddings = embeddings.astype(np.float16) return embeddings diff --git a/src/raglite/_eval.py b/src/raglite/_eval.py index daf2605..425464f 100644 --- a/src/raglite/_eval.py +++ b/src/raglite/_eval.py @@ -55,7 +55,7 @@ def validate_question(cls, value: str) -> str: return value config = config or RAGLiteConfig() - engine = create_database_engine(config.db_url) + engine = create_database_engine(config) with Session(engine) as session: for _ in trange(num_evals, desc="Generating evals", unit="eval", dynamic_ncols=True): # Sample a random document from the database. @@ -73,12 +73,12 @@ def validate_question(cls, value: str) -> str: if seed_chunk is None: continue # Expand the seed chunk into a set of related chunks. - related_chunk_rowids, _ = vector_search( - np.mean(seed_chunk.multi_vector_embedding, axis=0, keepdims=True), + related_chunk_ids, _ = vector_search( + np.mean(seed_chunk.embedding_matrix, axis=0, keepdims=True), num_results=randint(2, max_contexts_per_eval // 2), # noqa: S311 config=config, ) - related_chunks = retrieve_segments(related_chunk_rowids, config=config) + related_chunks = retrieve_segments(related_chunk_ids, config=config) # Extract a question from the seed chunk's related chunks. try: question_response = extract_with_llm( @@ -89,13 +89,10 @@ def validate_question(cls, value: str) -> str: else: question = question_response.question # Search for candidate chunks to answer the generated question. - candidate_chunk_rowids, _ = hybrid_search( + candidate_chunk_ids, _ = hybrid_search( question, num_results=max_contexts_per_eval, config=config ) - candidate_chunks = [ - session.exec(select(Chunk).offset(chunk_rowid - 1)).first() - for chunk_rowid in candidate_chunk_rowids - ] + candidate_chunks = [session.get(Chunk, chunk_id) for chunk_id in candidate_chunk_ids] # Determine which candidate chunks are relevant to answer the generated question. class ContextEvalResponse(BaseModel): @@ -170,14 +167,14 @@ class AnswerResponse(BaseModel): def answer_evals( num_evals: int = 100, - search: Callable[[str], tuple[list[int], list[float]]] = hybrid_search, + search: Callable[[str], tuple[list[str], list[float]]] = hybrid_search, *, config: RAGLiteConfig | None = None, ) -> pd.DataFrame: """Read evals from the database and answer them with RAG.""" # Read evals from the database. config = config or RAGLiteConfig() - engine = create_database_engine(config.db_url) + engine = create_database_engine(config) with Session(engine) as session: evals = session.exec(select(Eval).limit(num_evals)).all() # Answer evals with RAG. @@ -187,8 +184,8 @@ def answer_evals( response = rag(eval_.question, search=search, config=config) answer = "".join(response) answers.append(answer) - chunk_rowids, _ = search(eval_.question, config=config) # type: ignore[call-arg] - contexts.append(retrieve_segments(chunk_rowids)) + chunk_ids, _ = search(eval_.question, config=config) # type: ignore[call-arg] + contexts.append(retrieve_segments(chunk_ids)) # Collect the answered evals. answered_evals: dict[str, list[str] | list[list[str]]] = { "question": [eval_.question for eval_ in evals], diff --git a/src/raglite/_index.py b/src/raglite/_index.py index 81e531d..bfa2e25 100644 --- a/src/raglite/_index.py +++ b/src/raglite/_index.py @@ -1,16 +1,15 @@ """Index documents.""" -from copy import deepcopy from functools import partial from pathlib import Path import numpy as np -from pynndescent import NNDescent +from sqlalchemy.engine import make_url from sqlmodel import Session, select from tqdm.auto import tqdm from raglite._config import RAGLiteConfig -from raglite._database import Chunk, Document, VectorSearchChunkIndex, create_database_engine +from raglite._database import Chunk, ChunkEmbedding, Document, IndexMetadata, create_database_engine from raglite._embed import embed_strings from raglite._markdown import document_to_markdown from raglite._split_chunks import split_chunks @@ -23,8 +22,8 @@ def _create_chunk_records( chunks: list[str], sentence_embeddings: list[FloatMatrix], config: RAGLiteConfig, -) -> list[Chunk]: - """Process chunks into chunk records comprising headings, body, and a multi-vector embedding.""" +) -> tuple[list[Chunk], list[list[ChunkEmbedding]]]: + """Process chunks into chunk and chunk embedding records.""" # Create the chunk records. chunk_records, headings = [], "" for i, chunk in enumerate(chunks): @@ -37,28 +36,28 @@ def _create_chunk_records( contextualized_embeddings = embed_strings([str(chunk) for chunk in chunks], config=config) # Set the chunk's multi-vector embedding as a linear combination of its sentence embeddings # (for local context) and an embedding of the contextualised chunk (for global context). - for record, sentence_embedding, contextualized_embedding in zip( + ฮฑ = config.sentence_embedding_weight # noqa: PLC2401 + chunk_embedding_records = [] + for chunk_record, sentence_embedding, contextualized_embedding in zip( chunk_records, sentence_embeddings, contextualized_embeddings, strict=True ): - chunk_embedding = ( - config.sentence_embedding_weight * sentence_embedding - + (1 - config.sentence_embedding_weight) * contextualized_embedding[np.newaxis, :] - ) + chunk_embedding = ฮฑ * sentence_embedding + (1 - ฮฑ) * contextualized_embedding[np.newaxis, :] chunk_embedding = chunk_embedding / np.linalg.norm(chunk_embedding, axis=1, keepdims=True) - record.multi_vector_embedding = chunk_embedding - return chunk_records + chunk_embedding_records.append( + [ChunkEmbedding(chunk_id=chunk_record.id, embedding=row) for row in chunk_embedding] + ) + return chunk_records, chunk_embedding_records -def insert_document( - doc_path: Path, *, update_index: bool = True, config: RAGLiteConfig | None = None -) -> None: +def insert_document(doc_path: Path, *, config: RAGLiteConfig | None = None) -> None: """Insert a document into the database and update the index.""" # Use the default config if not provided. config = config or RAGLiteConfig() + db_backend = make_url(config.db_url).get_backend_name() # Preprocess the document into chunks. with tqdm(total=4, unit="step", dynamic_ncols=True) as pbar: pbar.set_description("Initializing database") - engine = create_database_engine(config.db_url) + engine = create_database_engine(config) pbar.update(1) pbar.set_description("Converting to Markdown") doc = document_to_markdown(doc_path) @@ -81,61 +80,62 @@ def insert_document( if session.get(Document, document_record.id) is None: session.add(document_record) session.commit() - # Create the chunk records. - chunk_records = _create_chunk_records( + # Create the chunk records to insert into the chunk table. + chunk_records, chunk_embedding_records = _create_chunk_records( document_record.id, chunks, sentence_embeddings, config ) - # Store the chunk records. - for chunk_record in tqdm( - chunk_records, desc="Storing chunks", unit="chunk", dynamic_ncols=True + # Store the chunk and chunk embedding records. + for chunk_record, chunk_embedding_record in tqdm( + zip(chunk_records, chunk_embedding_records, strict=True), + desc="Storing chunks" if db_backend == "sqlite" else "Storing and indexing chunks", + total=len(chunk_records), + unit="chunk", + dynamic_ncols=True, ): if session.get(Chunk, chunk_record.id) is not None: continue session.add(chunk_record) + session.add_all(chunk_embedding_record) session.commit() - # Update the vector search chunk index. - if update_index: - update_vector_index(config) - + # Manually update the vector search chunk index for SQLite. + if db_backend == "sqlite": + from pynndescent import NNDescent -def update_vector_index(config: RAGLiteConfig | None = None) -> None: - """Update the vector search chunk index with any unindexed chunks.""" - config = config or RAGLiteConfig() - engine = create_database_engine(config.db_url) - with Session(engine) as session: - # Get the vector search chunk index from the database, or create a new one. - vector_search_chunk_index = session.get( - VectorSearchChunkIndex, config.vector_search_index_id - ) or VectorSearchChunkIndex(id=config.vector_search_index_id) - num_chunks_indexed = len(vector_search_chunk_index.chunk_sizes) - # Get the unindexed chunks. - statement = select(Chunk).offset(num_chunks_indexed) - unindexed_chunks = session.exec(statement).all() - num_chunks_unindexed = len(unindexed_chunks) - # Index the unindexed chunks. - with tqdm( - total=num_chunks_indexed + num_chunks_unindexed, - desc="Indexing chunks", - unit="chunk", - dynamic_ncols=True, - ) as pbar: - # Fit or update the ANN index. - pbar.update(num_chunks_indexed) - if num_chunks_unindexed == 0: + with Session(engine) as session: + # Get the vector search chunk index from the database, or create a new one. + index_metadata = session.get(IndexMetadata, "default") or IndexMetadata(id="default") + chunk_ids = index_metadata.metadata_.get("chunk_ids", []) + chunk_sizes = index_metadata.metadata_.get("chunk_sizes", []) + # Get the unindexed chunks. + unindexed_chunks = list(session.exec(select(Chunk).offset(len(chunk_ids))).all()) + if not unindexed_chunks: return - X_unindexed = np.vstack([chunk.multi_vector_embedding for chunk in unindexed_chunks]) # noqa: N806 - if num_chunks_indexed == 0: - nndescent = NNDescent(X_unindexed, metric=config.vector_search_index_metric) - else: - nndescent = deepcopy(vector_search_chunk_index.index) - nndescent.update(X_unindexed) - nndescent.prepare() - # Mark the vector search chunk index as dirty. - vector_search_chunk_index.index = nndescent - vector_search_chunk_index.chunk_sizes = vector_search_chunk_index.chunk_sizes + [ - chunk.multi_vector_embedding.shape[0] for chunk in unindexed_chunks - ] - # Store the updated vector search chunk index. - session.add(vector_search_chunk_index) - session.commit() - pbar.update(num_chunks_unindexed) + # Assemble the unindexed chunk embeddings into a NumPy array. + unindexed_chunk_embeddings = [chunk.embedding_matrix for chunk in unindexed_chunks] + X = np.vstack(unindexed_chunk_embeddings) # noqa: N806 + # Index the unindexed chunks. + with tqdm( + total=len(unindexed_chunks), + desc="Indexing chunks", + unit="chunk", + dynamic_ncols=True, + ) as pbar: + # Fit or update the ANN index. + if len(chunk_ids) == 0: + nndescent = NNDescent(X, metric=config.vector_search_index_metric) + else: + nndescent = index_metadata.metadata_["index"] + nndescent.update(X) + # Prepare the ANN index so it can to handle query vectors not in the training set. + nndescent.prepare() + # Update the index metadata and mark it as dirty by recreating the dictionary. + index_metadata.metadata_ = { + **index_metadata.metadata_, + "index": nndescent, + "chunk_ids": chunk_ids + [c.id for c in unindexed_chunks], + "chunk_sizes": chunk_sizes + [len(em) for em in unindexed_chunk_embeddings], + } + # Store the updated vector search chunk index. + session.add(index_metadata) + session.commit() + pbar.update(len(unindexed_chunks)) diff --git a/src/raglite/_query_adapter.py b/src/raglite/_query_adapter.py index bd3fed9..bedb9a7 100644 --- a/src/raglite/_query_adapter.py +++ b/src/raglite/_query_adapter.py @@ -1,16 +1,16 @@ """Compute and update an optimal query adapter.""" import numpy as np -from sqlmodel import Session, select +from sqlmodel import Session, col, select from tqdm.auto import tqdm from raglite._config import RAGLiteConfig -from raglite._database import Chunk, Eval, VectorSearchChunkIndex, create_database_engine +from raglite._database import Chunk, Eval, IndexMetadata, create_database_engine from raglite._embed import embed_strings from raglite._search import vector_search -def update_query_adapter( # noqa: C901, PLR0915 +def update_query_adapter( # noqa: PLR0915 *, max_triplets: int = 4096, max_triplets_per_eval: int = 64, @@ -63,8 +63,10 @@ def update_query_adapter( # noqa: C901, PLR0915 C := 5% * A, the optimal ฮฑ is then given by ฮฑA + (1 - ฮฑ)B = C => ฮฑ = (B - C) / (B - A). """ config = config or RAGLiteConfig() - config_no_query_adapter = RAGLiteConfig(**{**config.__dict__, "enable_query_adapter": False}) - engine = create_database_engine(config.db_url) + config_no_query_adapter = RAGLiteConfig( + **{**config.__dict__, "vector_search_query_adapter": False} + ) + engine = create_database_engine(config) with Session(engine) as session: # Get random evals from the database. evals = session.exec( @@ -83,34 +85,25 @@ def update_query_adapter( # noqa: C901, PLR0915 # Embed the question. question_embedding = embed_strings([eval_.question], config=config) # Retrieve chunks that would be used to answer the question. - chunk_rowids, _ = vector_search( + chunk_ids, _ = vector_search( question_embedding, num_results=optimize_top_k, config=config_no_query_adapter ) - retrieved_chunks = [ - session.exec(select(Chunk).offset(chunk_rowid - 1)).first() - for chunk_rowid in chunk_rowids - ] + retrieved_chunks = session.exec(select(Chunk).where(col(Chunk.id).in_(chunk_ids))).all() # Extract (q, p, n) triplets by comparing the retrieved chunks with the eval. num_triplets = 0 for i, retrieved_chunk in enumerate(retrieved_chunks): - # Raise an error if the retrieved chunk is None. - if retrieved_chunk is None: - error_message = ( - f"The chunk with rowid {chunk_rowids[i]} is missing from the database." - ) - raise ValueError(error_message) # Select irrelevant chunks. if retrieved_chunk.id not in eval_.chunk_ids: # Look up all positive chunks (each represented by the mean of its multi-vector # embedding) that are ranked lower than this negative one (represented by the # embedding in the multi-vector embedding that best matches the query). p_mean = [ - np.mean(chunk.multi_vector_embedding, axis=0, keepdims=True) + np.mean(chunk.embedding_matrix, axis=0, keepdims=True) for chunk in retrieved_chunks[i + 1 :] if chunk is not None and chunk.id in eval_.chunk_ids ] - n_top = retrieved_chunk.multi_vector_embedding[ - np.argmax(retrieved_chunk.multi_vector_embedding @ question_embedding.T), + n_top = retrieved_chunk.embedding_matrix[ + np.argmax(retrieved_chunk.embedding_matrix @ question_embedding.T), np.newaxis, :, ] @@ -159,9 +152,7 @@ def update_query_adapter( # noqa: C901, PLR0915 error_message = f"Unsupported ANN metric: {config.vector_search_index_metric}" raise ValueError(error_message) # Store the optimal query adapter in the database. - vector_search_chunk_index = session.get( - VectorSearchChunkIndex, config.vector_search_index_id - ) or VectorSearchChunkIndex(id=config.vector_search_index_id) - vector_search_chunk_index.query_adapter = A_star - session.add(vector_search_chunk_index) + index_metadata = session.get(IndexMetadata, "default") or IndexMetadata(id="default") + index_metadata.metadata_ = {**index_metadata.metadata_, "query_adapter": A_star} + session.add(index_metadata) session.commit() diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py index 0235f06..08be304 100644 --- a/src/raglite/_rag.py +++ b/src/raglite/_rag.py @@ -11,7 +11,7 @@ def rag( *, max_contexts: int = 5, context_neighbors: tuple[int, ...] | None = (-1, 1), - search: Callable[[str], tuple[list[int], list[float]]] = hybrid_search, + search: Callable[[str], tuple[list[str], list[float]]] = hybrid_search, config: RAGLiteConfig | None = None, ) -> Iterator[str]: """Retrieval-augmented generation.""" @@ -22,8 +22,8 @@ def rag( max_tokens_per_context *= 1 + len(context_neighbors or []) max_contexts = min(max_contexts, max_tokens // max_tokens_per_context) # Retrieve relevant contexts. - chunk_rowids, _ = search(prompt, num_results=max_contexts, config=config) # type: ignore[call-arg] - segments = retrieve_segments(chunk_rowids, neighbors=context_neighbors) + chunk_ids, _ = search(prompt, num_results=max_contexts, config=config) # type: ignore[call-arg] + segments = retrieve_segments(chunk_ids, neighbors=context_neighbors) # Respond with an LLM. contexts = "\n\n".join( f'\n{segment.strip()}\n' diff --git a/src/raglite/_search.py b/src/raglite/_search.py index 73c2dd8..2f3a5ba 100644 --- a/src/raglite/_search.py +++ b/src/raglite/_search.py @@ -3,40 +3,19 @@ import re import string from collections import defaultdict -from functools import lru_cache from itertools import groupby -from typing import Annotated, ClassVar +from typing import Annotated, ClassVar, cast import numpy as np from pydantic import BaseModel, Field -from pynndescent import NNDescent +from sqlalchemy.engine import make_url from sqlmodel import Session, select, text from raglite._config import RAGLiteConfig -from raglite._database import Chunk, VectorSearchChunkIndex, create_database_engine +from raglite._database import Chunk, ChunkEmbedding, IndexMetadata, create_database_engine from raglite._embed import embed_strings from raglite._extract import extract_with_llm -from raglite._typing import FloatMatrix, IntVector - - -@lru_cache(maxsize=1) -def _vector_search_chunk_index( - config: RAGLiteConfig, -) -> tuple[NNDescent, IntVector, FloatMatrix | None]: - engine = create_database_engine(config.db_url) - with Session(engine) as session: - vector_search_chunk_index = session.get( - VectorSearchChunkIndex, config.vector_search_index_id - ) - if vector_search_chunk_index is None: - error_message = "First run `update_vector_index()` to create a vector search index." - raise ValueError(error_message) - index = vector_search_chunk_index.index - chunk_size_cumsum = np.cumsum( - np.asarray(vector_search_chunk_index.chunk_sizes, dtype=np.intp) - ) - query_adapter = vector_search_chunk_index.query_adapter - return index, chunk_size_cumsum, query_adapter +from raglite._typing import FloatMatrix def vector_search( @@ -44,101 +23,148 @@ def vector_search( *, num_results: int = 3, config: RAGLiteConfig | None = None, -) -> tuple[list[int], list[float]]: +) -> tuple[list[str], list[float]]: """Search chunks using ANN vector search.""" - # Retrieve the index from the database. + # Read the config. config = config or RAGLiteConfig() - index, chunk_size_cumsum, Q = _vector_search_chunk_index(config) # noqa: N806 + db_backend = make_url(config.db_url).get_backend_name() + # Get the index metadata (including the query adapter, and in the case of SQLite, the index). + index_metadata = IndexMetadata.get("default", config=config) # Embed the prompt. prompt_embedding = ( - embed_strings([prompt], config=config) + embed_strings([prompt], config=config)[0, :] if isinstance(prompt, str) - else np.reshape(prompt, (1, -1)) + else np.ravel(prompt) ) - # Apply the query adapter. - if config.enable_query_adapter and Q is not None: - prompt_embedding = (Q @ prompt_embedding[0, :])[np.newaxis, :].astype(config.embedder_dtype) - # Find the neighbouring multi-vector indices. - multi_vector_indices, cosine_distance = index.query(prompt_embedding, k=8 * num_results) - cosine_similarity = 1 - cosine_distance[0, :] - # Transform the multi-vector indices into chunk rowids. - chunk_rowids = np.searchsorted(chunk_size_cumsum, multi_vector_indices[0, :], side="right") + 1 - # Score each unique chunk rowid as the mean cosine similarity of its multi-vector hits. - # Chunk rowids with fewer hits are padded with the minimum cosine similarity of the result set. - unique_chunk_rowids, counts = np.unique(chunk_rowids, return_counts=True) + # Apply the query adapter to the prompt embedding. + Q = index_metadata.get("query_adapter") # noqa: N806 + if config.vector_search_query_adapter and Q is not None: + prompt_embedding = (Q @ prompt_embedding).astype(prompt_embedding.dtype) + # Search for the multi-vector chunk embeddings that are most similar to the prompt embedding. + if db_backend == "postgresql": + # Check that the selected metric is supported by pgvector. + metrics = {"cosine": "<=>", "dot": "<#>", "euclidean": "<->", "l1": "<+>", "l2": "<->"} + if config.vector_search_index_metric not in metrics: + error_message = f"Unsupported metric {config.vector_search_index_metric}." + raise ValueError(error_message) + # With pgvector, we can obtain the nearest neighbours and similarities with a single query. + engine = create_database_engine(config) + with Session(engine) as session: + distance_func = getattr( + ChunkEmbedding.embedding, f"{config.vector_search_index_metric}_distance" + ) + distance = distance_func(prompt_embedding).label("distance") + results = session.exec( + select(ChunkEmbedding.chunk_id, distance).order_by(distance).limit(8 * num_results) + ) + chunk_ids_, distance = zip(*results, strict=True) + chunk_ids, similarity = np.asarray(chunk_ids_), 1.0 - np.asarray(distance) + elif db_backend == "sqlite": + # Load the NNDescent index. + index = index_metadata.get("index") + ids = np.asarray(index_metadata.get("chunk_ids")) + cumsum = np.cumsum(np.asarray(index_metadata.get("chunk_sizes"))) + # Find the neighbouring multi-vector indices. + from pynndescent import NNDescent + + multi_vector_indices, distance = cast(NNDescent, index).query( + prompt_embedding[np.newaxis, :], k=8 * num_results + ) + similarity = 1 - distance[0, :] + # Transform the multi-vector indices into chunk indices, and then to chunk ids. + chunk_indices = np.searchsorted(cumsum, multi_vector_indices[0, :], side="right") + 1 + chunk_ids = np.asarray([ids[chunk_index - 1] for chunk_index in chunk_indices]) + # Score each unique chunk id as the mean similarity of its multi-vector hits. Chunk ids with + # fewer hits are padded with the minimum similarity of the result set. + unique_chunk_ids, counts = np.unique(chunk_ids, return_counts=True) score = np.full( - (len(unique_chunk_rowids), np.max(counts)), - np.min(cosine_similarity), - dtype=cosine_similarity.dtype, + (len(unique_chunk_ids), np.max(counts)), np.min(similarity), dtype=similarity.dtype ) - for i, (unique_chunk_rowid, count) in enumerate(zip(unique_chunk_rowids, counts, strict=True)): - score[i, :count] = cosine_similarity[chunk_rowids == unique_chunk_rowid] - pooled_cosine_similarity = np.mean(score, axis=1) - # Sort the chunk rowids by adjusted cosine similarity. - sorted_indices = np.argsort(pooled_cosine_similarity)[::-1] - unique_chunk_rowids = unique_chunk_rowids[sorted_indices][:num_results] - pooled_cosine_similarity = pooled_cosine_similarity[sorted_indices][:num_results] - return unique_chunk_rowids.tolist(), pooled_cosine_similarity.tolist() - - -def _prompt_to_fts_query(prompt: str) -> str: - """Convert a prompt to an FTS5 query.""" - # https://www.sqlite.org/fts5.html#full_text_query_syntax - prompt = re.sub(f"[{re.escape(string.punctuation)}]", "", prompt) - fts_query = " OR ".join(prompt.split()) - return fts_query + for i, (unique_chunk_id, count) in enumerate(zip(unique_chunk_ids, counts, strict=True)): + score[i, :count] = similarity[chunk_ids == unique_chunk_id] + pooled_similarity = np.mean(score, axis=1) + # Sort the chunk ids by their adjusted similarity. + sorted_indices = np.argsort(pooled_similarity)[::-1] + unique_chunk_ids = unique_chunk_ids[sorted_indices][:num_results] + pooled_similarity = pooled_similarity[sorted_indices][:num_results] + return unique_chunk_ids.tolist(), pooled_similarity.tolist() def keyword_search( prompt: str, *, num_results: int = 3, config: RAGLiteConfig | None = None -) -> tuple[list[int], list[float]]: +) -> tuple[list[str], list[float]]: """Search chunks using BM25 keyword search.""" + # Read the config. config = config or RAGLiteConfig() - engine = create_database_engine(config.db_url) + db_backend = make_url(config.db_url).get_backend_name() + # Connect to the database. + engine = create_database_engine(config) with Session(engine) as session: - # Perform the full-text search query using the BM25 ranking. - statement = text( - "SELECT chunk.rowid, bm25(fts_chunk_index) FROM chunk JOIN fts_chunk_index ON chunk.rowid = fts_chunk_index.rowid WHERE fts_chunk_index MATCH :match ORDER BY rank LIMIT :limit;" - ) - results = session.execute( - statement, params={"match": _prompt_to_fts_query(prompt), "limit": num_results} - ) - # Unpack the results and make FTS5's negative BM25 scores [1] positive. - # https://www.sqlite.org/fts5.html#the_bm25_function - chunk_rowids, bm25_score = zip(*results, strict=True) - chunk_rowids, bm25_score = list(chunk_rowids), [-s for s in bm25_score] # type: ignore[assignment] - return chunk_rowids, bm25_score # type: ignore[return-value] + if db_backend == "postgresql": + # Convert the prompt to a tsquery [1]. + # [1] https://www.postgresql.org/docs/current/textsearch-controls.html + prompt_escaped = re.sub(r"[&|!():<>\"]", " ", prompt) + tsv_query = " | ".join(prompt_escaped.split()) + # Perform full-text search with tsvector. + statement = text(""" + SELECT id as chunk_id, ts_rank(to_tsvector('simple', body), to_tsquery('simple', :query)) AS score + FROM chunk + WHERE to_tsvector('simple', body) @@ to_tsquery('simple', :query) + ORDER BY score DESC + LIMIT :limit; + """) + results = session.execute(statement, params={"query": tsv_query, "limit": num_results}) + elif db_backend == "sqlite": + # Convert the prompt to an FTS5 query [1]. + # [1] https://www.sqlite.org/fts5.html#full_text_query_syntax + prompt_escaped = re.sub(f"[{re.escape(string.punctuation)}]", "", prompt) + fts5_query = " OR ".join(prompt_escaped.split()) + # Perform full-text search with FTS5. In FTS5, BM25 scores are negative [1], so we + # negate them to make them positive. + # [1] https://www.sqlite.org/fts5.html#the_bm25_function + statement = text(""" + SELECT chunk.id as chunk_id, -bm25(fts_chunk_index) as score + FROM chunk JOIN fts_chunk_index ON chunk.rowid = fts_chunk_index.rowid + WHERE fts_chunk_index MATCH :match + ORDER BY score DESC + LIMIT :limit; + """) + results = session.execute(statement, params={"match": fts5_query, "limit": num_results}) + # Unpack the results. + chunk_ids, keyword_score = zip(*results, strict=True) + chunk_ids, keyword_score = list(chunk_ids), list(keyword_score) # type: ignore[assignment] + return chunk_ids, keyword_score # type: ignore[return-value] def reciprocal_rank_fusion( - rankings: list[list[int]], *, k: int = 60 -) -> tuple[list[int], list[float]]: + rankings: list[list[str]], *, k: int = 60 +) -> tuple[list[str], list[float]]: """Reciprocal Rank Fusion.""" # Compute the RRF score. - rowids = {rowid for ranking in rankings for rowid in ranking} - rowid_score: defaultdict[int, float] = defaultdict(float) + chunk_ids = {chunk_id for ranking in rankings for chunk_id in ranking} + chunk_id_score: defaultdict[str, float] = defaultdict(float) for ranking in rankings: - rowid_index = {rowid: i for i, rowid in enumerate(ranking)} - for rowid in rowids: - rowid_score[rowid] += 1 / (k + rowid_index.get(rowid, len(rowid_index))) + chunk_id_index = {chunk_id: i for i, chunk_id in enumerate(ranking)} + for chunk_id in chunk_ids: + chunk_id_score[chunk_id] += 1 / (k + chunk_id_index.get(chunk_id, len(chunk_id_index))) # Rank RRF results according to descending RRF score. - rrf_rowids, rrf_score = zip( - *sorted(rowid_score.items(), key=lambda x: x[1], reverse=True), strict=True + rrf_chunk_ids, rrf_score = zip( + *sorted(chunk_id_score.items(), key=lambda x: x[1], reverse=True), strict=True ) - return list(rrf_rowids), list(rrf_score) + return list(rrf_chunk_ids), list(rrf_score) def hybrid_search( prompt: str, *, num_results: int = 3, num_rerank: int = 100, config: RAGLiteConfig | None = None -) -> tuple[list[int], list[float]]: +) -> tuple[list[str], list[float]]: """Search chunks by combining ANN vector search with BM25 keyword search.""" # Run both searches. chunks_vector, _ = vector_search(prompt, num_results=num_rerank, config=config) - chunks_bm25, _ = keyword_search(prompt, num_results=num_rerank, config=config) + chunks_keyword, _ = keyword_search(prompt, num_results=num_rerank, config=config) # Combine the results with Reciprocal Rank Fusion (RRF). - chunk_rowids, hybrid_score = reciprocal_rank_fusion([chunks_vector, chunks_bm25]) - chunk_rowids, hybrid_score = chunk_rowids[:num_results], hybrid_score[:num_results] - return chunk_rowids, hybrid_score + chunk_ids, hybrid_score = reciprocal_rank_fusion([chunks_vector, chunks_keyword]) + chunk_ids, hybrid_score = chunk_ids[:num_results], hybrid_score[:num_results] + return chunk_ids, hybrid_score def fusion_search( @@ -147,7 +173,7 @@ def fusion_search( num_results: int = 5, num_rerank: int = 100, config: RAGLiteConfig | None = None, -) -> tuple[list[int], list[float]]: +) -> tuple[list[str], list[float]]: """Search for chunks with the RAG-Fusion method.""" class QueriesResponse(BaseModel): @@ -172,31 +198,31 @@ class QueriesResponse(BaseModel): for query in queries: # Run both searches. chunks_vector, _ = vector_search(query, num_results=num_rerank, config=config) - chunks_bm25, _ = keyword_search(query, num_results=num_rerank, config=config) + chunks_keyword, _ = keyword_search(query, num_results=num_rerank, config=config) # Add results to the rankings. rankings.append(chunks_vector) - rankings.append(chunks_bm25) + rankings.append(chunks_keyword) # Combine all the search results with Reciprocal Rank Fusion (RRF). - chunk_rowids, fusion_score = reciprocal_rank_fusion(rankings) - chunk_rowids, fusion_score = chunk_rowids[:num_results], fusion_score[:num_results] - return chunk_rowids, fusion_score + chunk_ids, fusion_score = reciprocal_rank_fusion(rankings) + chunk_ids, fusion_score = chunk_ids[:num_results], fusion_score[:num_results] + return chunk_ids, fusion_score def retrieve_segments( - chunk_rowids: list[int], + chunk_ids: list[str], *, neighbors: tuple[int, ...] | None = (-1, 1), config: RAGLiteConfig | None = None, ) -> list[str]: - """Group the chunks into contiguous segments and retrieve them.""" - # Get the chunks by rowid and extend them with their neighbours. + """Group chunks into contiguous segments and retrieve them.""" + # Get the chunks and extend them with their neighbours. config = config or RAGLiteConfig() chunks = set() - engine = create_database_engine(config.db_url) + engine = create_database_engine(config) with Session(engine) as session: - for chunk_rowid in chunk_rowids: - # Get the chunk at the given rowid. - chunk = session.exec(select(Chunk).offset(chunk_rowid - 1)).first() + for chunk_id in chunk_ids: + # Get the chunk by id. + chunk = session.get(Chunk, chunk_id) if chunk is not None: chunks.add(chunk) # Extend the chunk with its neighbouring chunks. diff --git a/src/raglite/_typing.py b/src/raglite/_typing.py index c80f0e4..adda9d0 100644 --- a/src/raglite/_typing.py +++ b/src/raglite/_typing.py @@ -1,8 +1,137 @@ """RAGLite typing.""" +import io +import pickle +from collections.abc import Callable from typing import Any import numpy as np +from sqlalchemy.engine import Dialect +from sqlalchemy.sql.operators import Operators +from sqlalchemy.types import Float, LargeBinary, TypeDecorator, TypeEngine, UserDefinedType FloatMatrix = np.ndarray[tuple[int, int], np.dtype[np.floating[Any]]] +FloatVector = np.ndarray[tuple[int], np.dtype[np.floating[Any]]] IntVector = np.ndarray[tuple[int], np.dtype[np.intp]] + + +class NumpyArray(TypeDecorator[np.ndarray[Any, np.dtype[np.floating[Any]]]]): + """A NumPy array column type for SQLAlchemy.""" + + impl = LargeBinary + + def process_bind_param( + self, value: np.ndarray[Any, np.dtype[np.floating[Any]]] | None, dialect: Dialect + ) -> bytes | None: + """Convert a NumPy array to bytes.""" + if value is None: + return None + buffer = io.BytesIO() + np.save(buffer, value, allow_pickle=False, fix_imports=False) + return buffer.getvalue() + + def process_result_value( + self, value: bytes | None, dialect: Dialect + ) -> np.ndarray[Any, np.dtype[np.floating[Any]]] | None: + """Convert bytes to a NumPy array.""" + if value is None: + return None + return np.load(io.BytesIO(value), allow_pickle=False, fix_imports=False) # type: ignore[no-any-return] + + +class PickledObject(TypeDecorator[object]): + """A pickled object column type for SQLAlchemy.""" + + impl = LargeBinary + + def process_bind_param(self, value: object | None, dialect: Dialect) -> bytes | None: + """Convert a Python object to bytes.""" + if value is None: + return None + return pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL, fix_imports=False) + + def process_result_value(self, value: bytes | None, dialect: Dialect) -> object | None: + """Convert bytes to a Python object.""" + if value is None: + return None + return pickle.loads(value, fix_imports=False) # type: ignore[no-any-return] # noqa: S301 + + +class HalfVecComparatorMixin(UserDefinedType.Comparator[FloatVector]): + """A mixin that provides comparison operators for halfvecs.""" + + def cosine_distance(self, other: FloatVector) -> Operators: + """Compute the cosine distance.""" + return self.op("<=>", return_type=Float)(other) + + def dot_distance(self, other: FloatVector) -> Operators: + """Compute the dot product distance.""" + return self.op("<#>", return_type=Float)(other) + + def euclidean_distance(self, other: FloatVector) -> Operators: + """Compute the Euclidean distance.""" + return self.op("<->", return_type=Float)(other) + + def l1_distance(self, other: FloatVector) -> Operators: + """Compute the L1 distance.""" + return self.op("<+>", return_type=Float)(other) + + def l2_distance(self, other: FloatVector) -> Operators: + """Compute the L2 distance.""" + return self.op("<->", return_type=Float)(other) + + +class HalfVec(UserDefinedType[FloatVector]): + """A PostgreSQL half-precision vector column type for SQLAlchemy.""" + + cache_ok = True # HalfVec is immutable. + + def __init__(self, dim: int | None = None) -> None: + super().__init__() + self.dim = dim + + def get_col_spec(self, **kwargs: Any) -> str: + return f"halfvec({self.dim})" + + def bind_processor(self, dialect: Dialect) -> Callable[[FloatVector | None], str | None]: + """Process NumPy ndarray to PostgreSQL halfvec format for bound parameters.""" + + def process(value: FloatVector | None) -> str | None: + return f"[{','.join(str(x) for x in np.ravel(value))}]" if value is not None else None + + return process + + def result_processor( + self, dialect: Dialect, coltype: Any + ) -> Callable[[str | None], FloatVector | None]: + """Process PostgreSQL halfvec format to NumPy ndarray.""" + + def process(value: str | None) -> FloatVector | None: + if value is None: + return None + return np.fromstring(value.strip("[]"), sep=",", dtype=np.float16) + + return process + + class comparator_factory(HalfVecComparatorMixin): # noqa: N801 + ... + + +class Embedding(TypeDecorator[FloatVector]): + """An embedding column type for SQLAlchemy.""" + + cache_ok = True # Embedding is immutable. + + impl = NumpyArray + + def __init__(self, dim: int = -1): + super().__init__() + self.dim = dim + + def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[FloatVector]: + if dialect.name == "postgresql": + return dialect.type_descriptor(HalfVec(self.dim)) + return dialect.type_descriptor(NumpyArray()) + + class comparator_factory(HalfVecComparatorMixin): # noqa: N801 + ... diff --git a/tests/conftest.py b/tests/conftest.py index a28e805..c5b0fa9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,14 +1,36 @@ """Fixtures for the tests.""" +import socket + import pytest from llama_cpp import Llama +from sqlalchemy import create_engine, text from raglite import RAGLiteConfig -@pytest.fixture -def simple_config() -> RAGLiteConfig: - """Create a lightweight in-memory config for testing.""" +def is_postgres_running() -> bool: + """Check if PostgreSQL is running.""" + try: + with socket.create_connection(("postgres", 5432), timeout=1): + return True + except OSError: + return False + + +@pytest.fixture( + scope="module", + params=[ + pytest.param("sqlite:///:memory:", id="SQLite"), + pytest.param( + "postgresql+pg8000://raglite_user:raglite_password@postgres:5432/postgres", + id="PostgreSQL", + marks=pytest.mark.skipif(not is_postgres_running(), reason="PostgreSQL is not running"), + ), + ], +) +def simple_config(request: pytest.FixtureRequest) -> RAGLiteConfig: + """Create a lightweight in-memory config for testing SQLite and PostgreSQL.""" # Use a lightweight embedder. embedder = Llama.from_pretrained( repo_id="ChristianAzinn/snowflake-arctic-embed-xs-gguf", # https://github.com/Snowflake-Labs/arctic-embed @@ -18,8 +40,18 @@ def simple_config() -> RAGLiteConfig: verbose=False, embedding=True, ) - # Use an in-memory SQLite database. - db_url = "sqlite:///:memory:" - # Create the config. - config = RAGLiteConfig(embedder=embedder, db_url=db_url) - return config + # Yield a SQLite config. + if "sqlite" in request.param: + sqlite_config = RAGLiteConfig(embedder=embedder, db_url=request.param) + return sqlite_config + # Yield a PostgreSQL config. + if "postgresql" in request.param: + engine = create_engine(request.param, isolation_level="AUTOCOMMIT") + with engine.connect() as conn: + conn.execute(text("DROP DATABASE IF EXISTS raglite_test")) + conn.execute(text("CREATE DATABASE raglite_test")) + postgresql_config = RAGLiteConfig( + embedder=embedder, db_url=request.param.replace("/postgres", "/raglite_test") + ) + return postgresql_config + raise ValueError diff --git a/tests/test_basic.py b/tests/test_basic.py index afa569e..e88695a 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -12,11 +12,11 @@ def test_insert_index_search(simple_config: RAGLiteConfig) -> None: insert_document(doc_path, config=simple_config) # Search for a query. query = "What does it mean for two events to be simultaneous?" - chunk_rowids, scores = hybrid_search(query, config=simple_config) - assert len(chunk_rowids) == len(scores) - assert all(isinstance(rowid, int) for rowid in chunk_rowids) + chunk_ids, scores = hybrid_search(query, config=simple_config) + assert len(chunk_ids) == len(scores) + assert all(isinstance(chunk_id, str) for chunk_id in chunk_ids) assert all(isinstance(score, float) for score in scores) # Group the chunks into segments and retrieve them. - segments = retrieve_segments(chunk_rowids, neighbors=None, config=simple_config) + segments = retrieve_segments(chunk_ids, neighbors=None, config=simple_config) assert all(isinstance(segment, str) for segment in segments) - assert "Definition of Simultaneity" in segments[0] + segments[1] + assert "Definition of Simultaneity" in "".join(segments[:2])