diff --git a/.cruft.json b/.cruft.json
index c5664de..a44ca99 100644
--- a/.cruft.json
+++ b/.cruft.json
@@ -1,15 +1,15 @@
{
- "template": "https://github.com/radix-ai/poetry-cookiecutter",
+ "template": "https://github.com/superlinear-ai/poetry-cookiecutter",
"commit": "a969f1d182ec39d7d27ccb1116cf60ba736adcfa",
"checkout": null,
"context": {
"cookiecutter": {
"project_type": "package",
"project_name": "RAGLite",
- "project_description": "A RAG extension for SQLite.",
- "project_url": "https://github.com/radix-ai/raglite",
+ "project_description": "A Python package for Retrieval-Augmented Generation (RAG) with SQLite or PostgreSQL.",
+ "project_url": "https://github.com/superlinear-ai/raglite",
"author_name": "Laurent Sorber",
- "author_email": "laurent@radix.ai",
+ "author_email": "laurent@superlinear.eu",
"python_version": "3.10",
"development_environment": "strict",
"with_conventional_commits": "1",
@@ -22,8 +22,8 @@
"__docstring_style": "NumPy",
"__project_name_kebab_case": "raglite",
"__project_name_snake_case": "raglite",
- "_template": "https://github.com/radix-ai/poetry-cookiecutter"
+ "_template": "https://github.com/superlinear-ai/poetry-cookiecutter"
}
},
"directory": null
-}
+}
\ No newline at end of file
diff --git a/README.md b/README.md
index 47f0a4f..29af5f0 100644
--- a/README.md
+++ b/README.md
@@ -1,22 +1,23 @@
-[![Open in Dev Containers](https://img.shields.io/static/v1?label=Dev%20Containers&message=Open&color=blue&logo=visualstudiocode)](https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://github.com/radix-ai/raglite) [![Open in GitHub Codespaces](https://img.shields.io/static/v1?label=GitHub%20Codespaces&message=Open&color=blue&logo=github)](https://github.com/codespaces/new?hide_repo_select=true&ref=main&repo=812973394&skip_quickstart=true)
+[![Open in Dev Containers](https://img.shields.io/static/v1?label=Dev%20Containers&message=Open&color=blue&logo=visualstudiocode)](https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://github.com/superlinear-ai/raglite) [![Open in GitHub Codespaces](https://img.shields.io/static/v1?label=GitHub%20Codespaces&message=Open&color=blue&logo=github)](https://github.com/codespaces/new?hide_repo_select=true&ref=main&repo=812973394&skip_quickstart=true)
# 🧵 RAGLite
-RAGLite is a Python package for Retrieval-Augmented Generation (RAG) with SQLite.
+RAGLite is a Python package for Retrieval-Augmented Generation (RAG) with PostgreSQL or SQLite.
## Features
1. ❤️ Only lightweight and permissive open source dependencies (e.g., no [PyTorch](https://github.com/pytorch/pytorch), [LangChain](https://github.com/langchain-ai/langchain), or [PyMuPDF](https://github.com/pymupdf/PyMuPDF))
-2. ๐ Fully local RAG with [llama-cpp-python](https://github.com/abetlen/llama-cpp-python) as an LLM provider and [SQLite](https://github.com/sqlite/sqlite) as a local database
-3. ๐ Acceleration with Metal on macOS and with CUDA on Linux and Windows
-4. ๐ PDF to Markdown conversion on top of [pdftext](https://github.com/VikParuchuri/pdftext) and [pypdfium2](https://github.com/pypdfium2-team/pypdfium2)
-5. โ๏ธ Optimal [level 4 semantic chunking](https://medium.com/@anuragmishra_27746/five-levels-of-chunking-strategies-in-rag-notes-from-gregs-video-7b735895694d) by solving a [binary integer programming problem](https://en.wikipedia.org/wiki/Integer_programming)
-6. ๐ Markdown-based [contextual chunk headings](https://d-star.ai/solving-the-out-of-context-chunk-problem-for-rag)
-7. ๐ Combined sentence-level and chunk-level matching with [multi-vector chunk retrieval](https://python.langchain.com/v0.2/docs/how_to/multi_vector/)
-8. ๐ Optimal [closed-form linear query adapter](src/raglite/_query_adapter.py) by solving an [orthogonal Procrustes problem](https://en.wikipedia.org/wiki/Orthogonal_Procrustes_problem)
-9. ๐ [Hybrid search](https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf) that combines [SQLite's BM25 full-text search](https://sqlite.org/fts5.html) with [PyNNDescent's ANN vector search](https://github.com/lmcinnes/pynndescent)
-10. โ๏ธ Optional support for conversion of any input document to Markdown with [Pandoc](https://github.com/jgm/pandoc)
-11. ✅ Optional support for evaluation of retrieval and generation with [Ragas](https://github.com/explodinggradients/ragas)
+2. 🧠 Your choice of local LLM with [llama-cpp-python](https://github.com/abetlen/llama-cpp-python)
+3. 💾 Your choice of [PostgreSQL](https://github.com/postgres/postgres) or [SQLite](https://github.com/sqlite/sqlite) as a full-text & vector search database
+4. ๐ Acceleration with Metal on macOS and with CUDA on Linux and Windows
+5. ๐ PDF to Markdown conversion on top of [pdftext](https://github.com/VikParuchuri/pdftext) and [pypdfium2](https://github.com/pypdfium2-team/pypdfium2)
+6. โ๏ธ Optimal [level 4 semantic chunking](https://medium.com/@anuragmishra_27746/five-levels-of-chunking-strategies-in-rag-notes-from-gregs-video-7b735895694d) by solving a [binary integer programming problem](https://en.wikipedia.org/wiki/Integer_programming)
+7. ๐ Markdown-based [contextual chunk headings](https://d-star.ai/solving-the-out-of-context-chunk-problem-for-rag)
+8. ๐ Combined sentence-level and chunk-level matching with [multi-vector chunk retrieval](https://python.langchain.com/v0.2/docs/how_to/multi_vector/)
+9. ๐ Optimal [closed-form linear query adapter](src/raglite/_query_adapter.py) by solving an [orthogonal Procrustes problem](https://en.wikipedia.org/wiki/Orthogonal_Procrustes_problem)
+10. ๐ [Hybrid search](https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf) that combines the database's built-in full-text search ([tsvector](https://www.postgresql.org/docs/current/datatype-textsearch.html) in PostgreSQL, [FTS5](https://www.sqlite.org/fts5.html) in SQLite) with their native vector search extensions ([pgvector](https://github.com/pgvector/pgvector) in PostgreSQL, [sqlite-vec](https://github.com/asg017/sqlite-vec) in SQLite)
+11. โ๏ธ Optional support for conversion of any input document to Markdown with [Pandoc](https://github.com/jgm/pandoc)
+12. ✅ Optional support for evaluation of retrieval and generation performance with [Ragas](https://github.com/explodinggradients/ragas)
## Installing
@@ -145,7 +146,7 @@ evaluation_df = evaluate(answered_evals_df, config=my_config)
The following development environments are supported:
1. ⭐️ _GitHub Codespaces_: click on _Code_ and select _Create codespace_ to start a Dev Container with [GitHub Codespaces](https://github.com/features/codespaces).
-1. ⭐️ _Dev Container (with container volume)_: click on [Open in Dev Containers](https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://github.com/radix-ai/raglite) to clone this repository in a container volume and create a Dev Container with VS Code.
+1. ⭐️ _Dev Container (with container volume)_: click on [Open in Dev Containers](https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://github.com/superlinear-ai/raglite) to clone this repository in a container volume and create a Dev Container with VS Code.
1. _Dev Container_: clone this repository, open it with VS Code, and run Ctrl/⌘ + ⇧ + P → _Dev Containers: Reopen in Container_.
1. _PyCharm_: clone this repository, open it with PyCharm, and [configure Docker Compose as a remote interpreter](https://www.jetbrains.com/help/pycharm/using-docker-compose-as-a-remote-interpreter.html#docker-compose-remote) with the `dev` service.
1. _Terminal_: clone this repository, open it with your terminal, and run `docker compose up --detach dev` to start a Dev Container in the background, and then run `docker compose exec dev zsh` to open a shell prompt in the Dev Container.
diff --git a/docker-compose.yml b/docker-compose.yml
index 0c889b3..e2eb595 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -12,6 +12,10 @@ services:
GID: ${GID:-1000}
environment:
- POETRY_PYPI_TOKEN_PYPI
+ depends_on:
+ - postgres
+ networks:
+ - raglite-network
volumes:
- ..:/workspaces
- command-history-volume:/home/user/.history/
@@ -21,15 +25,14 @@ services:
stdin_open: true
tty: true
entrypoint: []
- command:
- [
- "sh",
- "-c",
- "sudo chown user $$SSH_AUTH_SOCK && cp --update /opt/build/poetry/poetry.lock /workspaces/raglite/ && mkdir -p /workspaces/raglite/.git/hooks/ && cp --update /opt/build/git/* /workspaces/raglite/.git/hooks/ && zsh"
- ]
+ command: [ "sh", "-c", "sudo chown user $$SSH_AUTH_SOCK && cp --update /opt/build/poetry/poetry.lock /workspaces/raglite/ && mkdir -p /workspaces/raglite/.git/hooks/ && cp --update /opt/build/git/* /workspaces/raglite/.git/hooks/ && zsh" ]
environment:
- POETRY_PYPI_TOKEN_PYPI
- SSH_AUTH_SOCK=/run/host-services/ssh-auth.sock
+ depends_on:
+ - postgres
+ networks:
+ - raglite-network
volumes:
- ~/.gitconfig:/etc/gitconfig
- ~/.ssh/known_hosts:/home/user/.ssh/known_hosts
@@ -37,5 +40,19 @@ services:
profiles:
- dev
+ postgres:
+ image: pgvector/pgvector:pg16
+ environment:
+ POSTGRES_USER: raglite_user
+ POSTGRES_PASSWORD: raglite_password
+ networks:
+ - raglite-network
+ tmpfs:
+ - /var/lib/postgresql/data
+
+networks:
+ raglite-network:
+ driver: bridge
+
volumes:
command-history-volume:
diff --git a/poetry.lock b/poetry.lock
index f84bfeb..e4d26af 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -210,6 +210,17 @@ types-python-dateutil = ">=2.8.10"
doc = ["doc8", "sphinx (>=7.0.0)", "sphinx-autobuild", "sphinx-autodoc-typehints", "sphinx_rtd_theme (>=1.3.0)"]
test = ["dateparser (==1.*)", "pre-commit", "pytest", "pytest-cov", "pytest-mock", "pytz (==2021.1)", "simplejson (==3.*)"]
+[[package]]
+name = "asn1crypto"
+version = "1.5.1"
+description = "Fast ASN.1 parser and serializer with definitions for private keys, public keys, certificates, CRL, OCSP, CMS, PKCS#3, PKCS#7, PKCS#8, PKCS#12, PKCS#5, X.509 and TSP"
+optional = false
+python-versions = "*"
+files = [
+ {file = "asn1crypto-1.5.1-py2.py3-none-any.whl", hash = "sha256:db4e40728b728508912cbb3d44f19ce188f218e9eba635821bb4b68564f8fd67"},
+ {file = "asn1crypto-1.5.1.tar.gz", hash = "sha256:13ae38502be632115abf8a24cbe5f4da52e3b5231990aff31123c805306ccb9c"},
+]
+
[[package]]
name = "asttokens"
version = "2.4.1"
@@ -3437,6 +3448,21 @@ files = [
[package.dependencies]
ptyprocess = ">=0.5"
+[[package]]
+name = "pg8000"
+version = "1.31.2"
+description = "PostgreSQL interface library"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "pg8000-1.31.2-py3-none-any.whl", hash = "sha256:436c771ede71af4d4c22ba867a30add0bc5c942d7ab27fadbb6934a487ecc8f6"},
+ {file = "pg8000-1.31.2.tar.gz", hash = "sha256:1ea46cf09d8eca07fe7eaadefd7951e37bee7fabe675df164f1a572ffb300876"},
+]
+
+[package.dependencies]
+python-dateutil = ">=2.8.2"
+scramp = ">=1.4.5"
+
[[package]]
name = "pillow"
version = "10.4.0"
@@ -4745,6 +4771,20 @@ dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy (==1.10.0)", "pycodest
doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.13.1)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0)", "sphinx-design (>=0.4.0)"]
test = ["Cython", "array-api-strict", "asv", "gmpy2", "hypothesis (>=6.30)", "meson", "mpmath", "ninja", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"]
+[[package]]
+name = "scramp"
+version = "1.4.5"
+description = "An implementation of the SCRAM protocol."
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "scramp-1.4.5-py3-none-any.whl", hash = "sha256:50e37c464fc67f37994e35bee4151e3d8f9320e9c204fca83a5d313c121bbbe7"},
+ {file = "scramp-1.4.5.tar.gz", hash = "sha256:be3fbe774ca577a7a658117dca014e5d254d158cecae3dd60332dfe33ce6d78e"},
+]
+
+[package.dependencies]
+asn1crypto = ">=1.5.1"
+
[[package]]
name = "setuptools"
version = "72.2.0"
@@ -5900,4 +5940,4 @@ ragas = ["ragas"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<4.0"
-content-hash = "a65d8c1115521aa69301e35158c5d6f6dc0618f745c6373c86dd677a853e320c"
+content-hash = "8178f5e494205788ce704f31e778aea16ce12d2e055d704c1c706fa2c1e8e5e9"
diff --git a/pyproject.toml b/pyproject.toml
index ae6fc6c..b8bc88b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,10 +5,10 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry] # https://python-poetry.org/docs/pyproject/
name = "raglite"
version = "0.0.0"
-description = "A Python package for Retrieval-Augmented Generation (RAG) with SQLite."
-authors = ["Laurent Sorber <laurent@radix.ai>"]
+description = "A Python package for Retrieval-Augmented Generation (RAG) with SQLite or PostgreSQL."
+authors = ["Laurent Sorber <laurent@superlinear.eu>"]
readme = "README.md"
-repository = "https://github.com/radix-ai/raglite"
+repository = "https://github.com/superlinear-ai/raglite"
[tool.commitizen] # https://commitizen-tools.github.io/commitizen/config/
bump_message = "bump(release): v$current_version → v$new_version"
@@ -51,6 +51,7 @@ pydantic = ">=2.7.0"
# Approximate Nearest Neighbors:
pynndescent = ">=0.5.12"
# Storage:
+pg8000 = ">=1.31.2"
sqlmodel-slim = ">=0.0.18"
# Progress:
tqdm = ">=4.66.0"
diff --git a/src/raglite/__init__.py b/src/raglite/__init__.py
index c77f710..a5ef385 100644
--- a/src/raglite/__init__.py
+++ b/src/raglite/__init__.py
@@ -2,7 +2,7 @@
from raglite._config import RAGLiteConfig
from raglite._eval import answer_evals, evaluate, insert_evals
-from raglite._index import insert_document, update_vector_index
+from raglite._index import insert_document
from raglite._query_adapter import update_query_adapter
from raglite._rag import rag
from raglite._search import (
@@ -18,7 +18,6 @@
"RAGLiteConfig",
# Index
"insert_document",
- "update_vector_index",
# Search
"fusion_search",
"hybrid_search",
diff --git a/src/raglite/_config.py b/src/raglite/_config.py
index e6c74c5..a76daf8 100644
--- a/src/raglite/_config.py
+++ b/src/raglite/_config.py
@@ -3,8 +3,6 @@
from dataclasses import dataclass, field
from functools import lru_cache
-import numpy as np
-import numpy.typing as npt
from llama_cpp import Llama, LlamaRAMCache, llama_supports_gpu_offload # type: ignore[attr-defined]
from sqlalchemy.engine import URL
@@ -12,14 +10,14 @@
@lru_cache(maxsize=1)
def default_llm() -> Llama:
"""Get default LLM."""
- # Select the best available LLM for the given accelerator.
+ # Select the best available LLM for the given accelerator:
+ # - Llama-3.1-8B-instruct on GPU.
+ # - Phi-3.5-mini-instruct on CPU.
if llama_supports_gpu_offload():
- # Llama-3.1-8B-instruct on GPU.
repo_id = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF" # https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct
filename = "*Q4_K_M.gguf"
n_ctx = 8192
else:
- # Phi-3.1-mini-128k-instruct on CPU.
repo_id = "bartowski/Phi-3.5-mini-instruct-GGUF" # https://huggingface.co/microsoft/Phi-3.5-mini-instruct
filename = "*Q4_K_M.gguf"
n_ctx = 4096
@@ -61,7 +59,6 @@ class RAGLiteConfig:
# Embedder config used for indexing.
embedder: Llama = field(default_factory=default_embedder)
embedder_batch_size: int = 128
- embedder_dtype: npt.DTypeLike = np.float16
embedder_normalize: bool = True
sentence_embedding_weight: float = 0.5 # Between 0 (chunk level) and 1 (sentence level).
# Chunker config used to partition documents into chunks.
@@ -70,7 +67,5 @@ class RAGLiteConfig:
# Database config.
db_url: str | URL = "sqlite:///raglite.sqlite"
# Vector search config.
- vector_search_index_id: str = "default"
vector_search_index_metric: str = "cosine" # The query adapter supports "dot" and "cosine".
- # Query adapter config.
- enable_query_adapter: bool = True
+ vector_search_query_adapter: bool = True
diff --git a/src/raglite/_database.py b/src/raglite/_database.py
index 2bedcd8..b4bafa4 100644
--- a/src/raglite/_database.py
+++ b/src/raglite/_database.py
@@ -1,7 +1,6 @@
-"""SQLite tables for RAGLite."""
+"""PostgreSQL or SQLite database tables for RAGLite."""
-import io
-import pickle
+import datetime
from functools import lru_cache
from hashlib import sha256
from pathlib import Path
@@ -9,12 +8,22 @@
import numpy as np
from markdown_it import MarkdownIt
-from pynndescent import NNDescent
-from sqlalchemy.engine import URL, Dialect, Engine, make_url
-from sqlalchemy.types import LargeBinary, TypeDecorator
-from sqlmodel import JSON, Column, Field, Relationship, Session, SQLModel, create_engine, text
-
-from raglite._typing import FloatMatrix
+from pydantic import ConfigDict
+from sqlalchemy.engine import Engine, make_url
+from sqlmodel import (
+ JSON,
+ Column,
+ Field,
+ Relationship,
+ Session,
+ SQLModel,
+ create_engine,
+ select,
+ text,
+)
+
+from raglite._config import RAGLiteConfig
+from raglite._typing import Embedding, FloatMatrix, FloatVector, PickledObject
def hash_bytes(data: bytes, max_len: int = 16) -> str:
@@ -22,55 +31,17 @@ def hash_bytes(data: bytes, max_len: int = 16) -> str:
return sha256(data, usedforsecurity=False).hexdigest()[:max_len]
-class NumpyArray(TypeDecorator[np.ndarray[Any, np.dtype[np.floating[Any]]]]):
- """A NumPy array column type for SQLAlchemy."""
-
- impl = LargeBinary
-
- def process_bind_param(
- self, value: np.ndarray[Any, np.dtype[np.floating[Any]]] | None, dialect: Dialect
- ) -> bytes | None:
- """Convert a NumPy array to bytes."""
- if value is None:
- return None
- buffer = io.BytesIO()
- np.save(buffer, value, allow_pickle=False, fix_imports=False)
- return buffer.getvalue()
-
- def process_result_value(
- self, value: bytes | None, dialect: Dialect
- ) -> np.ndarray[Any, np.dtype[np.floating[Any]]] | None:
- """Convert bytes to a NumPy array."""
- if value is None:
- return None
- return np.load(io.BytesIO(value), allow_pickle=False, fix_imports=False) # type: ignore[no-any-return]
-
-
-class PickledObject(TypeDecorator[object]):
- """A pickled object column type for SQLAlchemy."""
-
- impl = LargeBinary
-
- def process_bind_param(self, value: object | None, dialect: Dialect) -> bytes | None:
- """Convert a Python object to bytes."""
- if value is None:
- return None
- return pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL, fix_imports=False)
-
- def process_result_value(self, value: bytes | None, dialect: Dialect) -> object | None:
- """Convert bytes to a Python object."""
- if value is None:
- return None
- return pickle.loads(value, fix_imports=False) # type: ignore[no-any-return] # noqa: S301
-
-
class Document(SQLModel, table=True):
"""A document."""
+ # Enable JSON columns.
+ model_config = ConfigDict(arbitrary_types_allowed=True) # type: ignore[assignment]
+
+ # Table columns.
id: str = Field(..., primary_key=True)
filename: str
url: str | None = Field(default=None)
- metadata_: dict[str, Any] = Field(default={}, sa_column=Column("metadata", JSON))
+ metadata_: dict[str, Any] = Field(default_factory=dict, sa_column=Column("metadata", JSON))
# Add relationships so we can access document.chunks and document.evals.
chunks: list["Chunk"] = Relationship(back_populates="document")
@@ -90,26 +61,24 @@ def from_path(doc_path: Path, **kwargs: Any) -> "Document":
},
)
- # Enable support for JSON columns.
- class Config:
- """Table configuration."""
-
- arbitrary_types_allowed = True
-
class Chunk(SQLModel, table=True):
"""A document chunk."""
+ # Enable JSON columns.
+ model_config = ConfigDict(arbitrary_types_allowed=True) # type: ignore[assignment]
+
+ # Table columns.
id: str = Field(..., primary_key=True)
document_id: str = Field(..., foreign_key="document.id", index=True)
index: int = Field(..., index=True)
headings: str
body: str
- multi_vector_embedding: FloatMatrix = Field(..., sa_column=Column(NumpyArray))
- metadata_: dict[str, Any] = Field(default={}, sa_column=Column("metadata", JSON))
+ metadata_: dict[str, Any] = Field(default_factory=dict, sa_column=Column("metadata", JSON))
- # Add relationship so we can access chunk.document.
+ # Add relationships so we can access chunk.document and chunk.embeddings.
document: Document = Relationship(back_populates="chunks")
+ embeddings: list["ChunkEmbedding"] = Relationship(back_populates="chunk")
@staticmethod
def from_body(
@@ -117,7 +86,6 @@ def from_body(
index: int,
body: str,
headings: str = "",
- multi_vector_embedding: FloatMatrix | None = None,
**kwargs: Any,
) -> "Chunk":
"""Create a chunk from Markdown."""
@@ -127,9 +95,6 @@ def from_body(
index=index,
headings=headings,
body=body,
- multi_vector_embedding=multi_vector_embedding
- if multi_vector_embedding is not None
- else np.empty(0),
metadata_=kwargs,
)
@@ -151,6 +116,12 @@ def extract_headings(self) -> str:
headings = "\n".join([heading for heading in heading_lines if heading])
return headings
+ @property
+ def embedding_matrix(self) -> FloatMatrix:
+ """Return this chunk's multi-vector embedding matrix."""
+ # Uses the relationship chunk.embeddings to access the chunk_embedding table.
+ return np.vstack([embedding.embedding[np.newaxis, :] for embedding in self.embeddings])
+
def __str__(self) -> str:
"""Context representation of this chunk."""
return f"{self.headings.strip()}\n\n{self.body.strip()}".strip()
@@ -158,29 +129,75 @@ def __str__(self) -> str:
def __hash__(self) -> int:
return hash(self.id)
- # Enable support for JSON and NumpyArray columns.
- class Config:
- """Table configuration."""
- arbitrary_types_allowed = True
+class ChunkEmbedding(SQLModel, table=True):
+ """A (sub-)chunk embedding."""
+ __tablename__ = "chunk_embedding"
-class VectorSearchChunkIndex(SQLModel, table=True):
- """A vector search index for chunks."""
+ # Enable Embedding columns.
+ model_config = ConfigDict(arbitrary_types_allowed=True) # type: ignore[assignment]
- __tablename__ = "vs_chunk_index" # Vector search chunk index.
+ # Table columns.
+ id: int = Field(..., primary_key=True)
+ chunk_id: str = Field(..., foreign_key="chunk.id", index=True)
+ embedding: FloatVector = Field(..., sa_column=Column(Embedding(dim=-1)))
+ # Add relationship so we can access embedding.chunk.
+ chunk: Chunk = Relationship(back_populates="embeddings")
+
+ @classmethod
+ def set_embedding_dim(cls, dim: int) -> None:
+ """Modify the embedding column's dimension after class definition."""
+ cls.__table__.c["embedding"].type.dim = dim # type: ignore[attr-defined]
+
+
+class IndexMetadata(SQLModel, table=True):
+ """Vector and keyword search index metadata."""
+
+ __tablename__ = "index_metadata"
+
+ # Enable PickledObject columns.
+ model_config = ConfigDict(arbitrary_types_allowed=True) # type: ignore[assignment]
+
+ # Table columns.
id: str = Field(..., primary_key=True)
- chunk_sizes: list[int] = Field(default=[], sa_column=Column(JSON))
- index: NNDescent | None = Field(default=None, sa_column=Column(PickledObject))
- query_adapter: FloatMatrix | None = Field(default=None, sa_column=Column(NumpyArray))
- metadata_: dict[str, Any] = Field(default={}, sa_column=Column("metadata", JSON))
+ version: datetime.datetime = Field(
+ default_factory=lambda: datetime.datetime.now(datetime.timezone.utc)
+ )
+ metadata_: dict[str, Any] = Field(
+ default_factory=dict, sa_column=Column("metadata", PickledObject)
+ )
+
+ @staticmethod
+ def _get_version(id_: str, *, config: RAGLiteConfig | None = None) -> datetime.datetime | None:
+ """Get the version of the index metadata with a given id."""
+ engine = create_database_engine(config)
+ with Session(engine) as session:
+ version = session.exec(
+ select(IndexMetadata.version).where(IndexMetadata.id == id_)
+ ).first()
+ return version
- # Enable support for JSON, PickledObject, and NumpyArray columns.
- class Config:
- """Table configuration."""
+ @staticmethod
+ @lru_cache(maxsize=4)
+ def _get(
+ id_: str, version: datetime.datetime | None, *, config: RAGLiteConfig | None = None
+ ) -> dict[str, Any] | None:
+ if version is None:
+ return None
+ engine = create_database_engine(config)
+ with Session(engine) as session:
+ index_metadata_record = session.get(IndexMetadata, id_)
+ if index_metadata_record is None:
+ return None
+ return index_metadata_record.metadata_
- arbitrary_types_allowed = True
+ @staticmethod
+ def get(id_: str = "default", *, config: RAGLiteConfig | None = None) -> dict[str, Any]:
+ version = IndexMetadata._get_version(id_, config=config)
+ metadata = IndexMetadata._get(id_, version, config=config) or {}
+ return metadata
class Eval(SQLModel, table=True):
@@ -188,13 +205,17 @@ class Eval(SQLModel, table=True):
__tablename__ = "eval"
+ # Enable JSON columns.
+ model_config = ConfigDict(arbitrary_types_allowed=True) # type: ignore[assignment]
+
+ # Table columns.
id: str = Field(..., primary_key=True)
document_id: str = Field(..., foreign_key="document.id", index=True)
- chunk_ids: list[str] = Field(default=[], sa_column=Column(JSON))
+ chunk_ids: list[str] = Field(default_factory=list, sa_column=Column(JSON))
question: str
- contexts: list[str] = Field(default=[], sa_column=Column(JSON))
+ contexts: list[str] = Field(default_factory=list, sa_column=Column(JSON))
ground_truth: str
- metadata_: dict[str, Any] = Field(default={}, sa_column=Column("metadata", JSON))
+ metadata_: dict[str, Any] = Field(default_factory=dict, sa_column=Column("metadata", JSON))
# Add relationship so we can access eval.document.
document: Document = Relationship(back_populates="evals")
@@ -219,56 +240,94 @@ def from_chunks(
metadata_=kwargs,
)
- # Enable support for JSON columns.
- class Config:
- """Table configuration."""
-
- arbitrary_types_allowed = True
-
@lru_cache(maxsize=1)
-def create_database_engine(db_url: str | URL = "sqlite:///raglite.sqlite") -> Engine:
+def create_database_engine(config: RAGLiteConfig | None = None) -> Engine:
"""Create a database engine and initialize it."""
- # Parse the database URL.
- db_url = make_url(db_url)
- assert db_url.get_backend_name() == "sqlite", "RAGLite currently only supports SQLite."
- # Optimize SQLite performance.
- pragmas = {"journal_mode": "WAL", "synchronous": "NORMAL"}
- db_url = db_url.update_query_dict(pragmas, append=True)
+ # Parse the database URL and validate that the database backend is supported.
+ config = config or RAGLiteConfig()
+ db_url = make_url(config.db_url)
+ db_backend = db_url.get_backend_name()
+ # Update database configuration.
+ connect_args = {}
+ if db_backend == "postgresql":
+ # Select the pg8000 driver if not set (psycopg2 is the default), and prefer SSL.
+ if "+" not in db_url.drivername:
+ db_url = db_url.set(drivername="postgresql+pg8000")
+ # Support setting the sslmode for pg8000.
+ if "pg8000" in db_url.drivername and "sslmode" in db_url.query:
+ query = dict(db_url.query)
+ if query.pop("sslmode") != "disable":
+ connect_args["ssl_context"] = True
+ db_url = db_url.set(query=query)
+ elif db_backend == "sqlite":
+ # Optimize SQLite performance.
+ pragmas = {"journal_mode": "WAL", "synchronous": "NORMAL"}
+ db_url = db_url.update_query_dict(pragmas, append=True)
+ else:
+ error_message = "RAGLite only supports PostgreSQL and SQLite."
+ raise ValueError(error_message)
# Create the engine.
- engine = create_engine(db_url)
+ engine = create_engine(db_url, pool_pre_ping=True, connect_args=connect_args)
+ # Install database extensions.
+ if db_backend == "postgresql":
+ with Session(engine) as session:
+ session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
+ session.commit()
# Create all SQLModel tables.
+ ChunkEmbedding.set_embedding_dim(config.embedder.n_embd())
SQLModel.metadata.create_all(engine)
- # Create a virtual table for full-text search on the chunk table.
- # We use the chunk table as an external content table [1] to avoid duplicating the data.
- # [1] https://www.sqlite.org/fts5.html#external_content_tables
- with Session(engine) as session:
- session.execute(
- text("""
- CREATE VIRTUAL TABLE IF NOT EXISTS fts_chunk_index USING fts5(body, content='chunk', content_rowid='rowid');
- """)
- )
- session.execute(
- text("""
- CREATE TRIGGER IF NOT EXISTS fts_chunk_index_auto_insert AFTER INSERT ON chunk BEGIN
- INSERT INTO fts_chunk_index(rowid, body) VALUES (new.rowid, new.body);
- END;
- """)
- )
- session.execute(
- text("""
- CREATE TRIGGER IF NOT EXISTS fts_chunk_index_auto_delete AFTER DELETE ON chunk BEGIN
- INSERT INTO fts_chunk_index(fts_chunk_index, rowid, body) VALUES('delete', old.rowid, old.body);
- END;
- """)
- )
- session.execute(
- text("""
- CREATE TRIGGER IF NOT EXISTS fts_chunk_index_auto_update AFTER UPDATE ON chunk BEGIN
- INSERT INTO fts_chunk_index(fts_chunk_index, rowid, body) VALUES('delete', old.rowid, old.body);
- INSERT INTO fts_chunk_index(rowid, body) VALUES (new.rowid, new.body);
- END;
- """)
- )
- session.commit()
+ # Create backend-specific indexes.
+ if db_backend == "postgresql":
+ # Create a full-text search index with `tsvector` and a vector search index with `pgvector`.
+ with Session(engine) as session:
+ metrics = {"cosine": "cosine", "dot": "ip", "euclidean": "l2", "l1": "l1", "l2": "l2"}
+ session.execute(
+ text("""
+ CREATE INDEX IF NOT EXISTS fts_chunk_index ON chunk USING GIN (to_tsvector('simple', body));
+ """)
+ )
+ session.execute(
+ text(f"""
+ CREATE INDEX IF NOT EXISTS vs_chunk_index ON chunk_embedding
+ USING hnsw (
+ (embedding::halfvec({config.embedder.n_embd()}))
+ halfvec_{metrics[config.vector_search_index_metric]}_ops
+ );
+ """)
+ )
+ session.commit()
+ elif db_backend == "sqlite":
+ # Create a virtual table for full-text search on the chunk table.
+ # We use the chunk table as an external content table [1] to avoid duplicating the data.
+ # [1] https://www.sqlite.org/fts5.html#external_content_tables
+ with Session(engine) as session:
+ session.execute(
+ text("""
+ CREATE VIRTUAL TABLE IF NOT EXISTS fts_chunk_index USING fts5(body, content='chunk', content_rowid='rowid');
+ """)
+ )
+ session.execute(
+ text("""
+ CREATE TRIGGER IF NOT EXISTS fts_chunk_index_auto_insert AFTER INSERT ON chunk BEGIN
+ INSERT INTO fts_chunk_index(rowid, body) VALUES (new.rowid, new.body);
+ END;
+ """)
+ )
+ session.execute(
+ text("""
+ CREATE TRIGGER IF NOT EXISTS fts_chunk_index_auto_delete AFTER DELETE ON chunk BEGIN
+ INSERT INTO fts_chunk_index(fts_chunk_index, rowid, body) VALUES('delete', old.rowid, old.body);
+ END;
+ """)
+ )
+ session.execute(
+ text("""
+ CREATE TRIGGER IF NOT EXISTS fts_chunk_index_auto_update AFTER UPDATE ON chunk BEGIN
+ INSERT INTO fts_chunk_index(fts_chunk_index, rowid, body) VALUES('delete', old.rowid, old.body);
+ INSERT INTO fts_chunk_index(rowid, body) VALUES (new.rowid, new.body);
+ END;
+ """)
+ )
+ session.commit()
return engine
diff --git a/src/raglite/_embed.py b/src/raglite/_embed.py
index 7aa4a93..02a873d 100644
--- a/src/raglite/_embed.py
+++ b/src/raglite/_embed.py
@@ -22,8 +22,8 @@ def _embed_string_batch(
# Normalise embeddings to unit norm.
if config.embedder_normalize:
embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)
- # Cast to the configured dtype after normalisation.
- embeddings = embeddings.astype(config.embedder_dtype)
+ # Cast to half precision after normalisation.
+ embeddings = embeddings.astype(np.float16)
return embeddings
diff --git a/src/raglite/_eval.py b/src/raglite/_eval.py
index daf2605..425464f 100644
--- a/src/raglite/_eval.py
+++ b/src/raglite/_eval.py
@@ -55,7 +55,7 @@ def validate_question(cls, value: str) -> str:
return value
config = config or RAGLiteConfig()
- engine = create_database_engine(config.db_url)
+ engine = create_database_engine(config)
with Session(engine) as session:
for _ in trange(num_evals, desc="Generating evals", unit="eval", dynamic_ncols=True):
# Sample a random document from the database.
@@ -73,12 +73,12 @@ def validate_question(cls, value: str) -> str:
if seed_chunk is None:
continue
# Expand the seed chunk into a set of related chunks.
- related_chunk_rowids, _ = vector_search(
- np.mean(seed_chunk.multi_vector_embedding, axis=0, keepdims=True),
+ related_chunk_ids, _ = vector_search(
+ np.mean(seed_chunk.embedding_matrix, axis=0, keepdims=True),
num_results=randint(2, max_contexts_per_eval // 2), # noqa: S311
config=config,
)
- related_chunks = retrieve_segments(related_chunk_rowids, config=config)
+ related_chunks = retrieve_segments(related_chunk_ids, config=config)
# Extract a question from the seed chunk's related chunks.
try:
question_response = extract_with_llm(
@@ -89,13 +89,10 @@ def validate_question(cls, value: str) -> str:
else:
question = question_response.question
# Search for candidate chunks to answer the generated question.
- candidate_chunk_rowids, _ = hybrid_search(
+ candidate_chunk_ids, _ = hybrid_search(
question, num_results=max_contexts_per_eval, config=config
)
- candidate_chunks = [
- session.exec(select(Chunk).offset(chunk_rowid - 1)).first()
- for chunk_rowid in candidate_chunk_rowids
- ]
+ candidate_chunks = [session.get(Chunk, chunk_id) for chunk_id in candidate_chunk_ids]
# Determine which candidate chunks are relevant to answer the generated question.
class ContextEvalResponse(BaseModel):
@@ -170,14 +167,14 @@ class AnswerResponse(BaseModel):
def answer_evals(
num_evals: int = 100,
- search: Callable[[str], tuple[list[int], list[float]]] = hybrid_search,
+ search: Callable[[str], tuple[list[str], list[float]]] = hybrid_search,
*,
config: RAGLiteConfig | None = None,
) -> pd.DataFrame:
"""Read evals from the database and answer them with RAG."""
# Read evals from the database.
config = config or RAGLiteConfig()
- engine = create_database_engine(config.db_url)
+ engine = create_database_engine(config)
with Session(engine) as session:
evals = session.exec(select(Eval).limit(num_evals)).all()
# Answer evals with RAG.
@@ -187,8 +184,8 @@ def answer_evals(
response = rag(eval_.question, search=search, config=config)
answer = "".join(response)
answers.append(answer)
- chunk_rowids, _ = search(eval_.question, config=config) # type: ignore[call-arg]
- contexts.append(retrieve_segments(chunk_rowids))
+ chunk_ids, _ = search(eval_.question, config=config) # type: ignore[call-arg]
+ contexts.append(retrieve_segments(chunk_ids, config=config))
# Collect the answered evals.
answered_evals: dict[str, list[str] | list[list[str]]] = {
"question": [eval_.question for eval_ in evals],
diff --git a/src/raglite/_index.py b/src/raglite/_index.py
index 81e531d..bfa2e25 100644
--- a/src/raglite/_index.py
+++ b/src/raglite/_index.py
@@ -1,16 +1,15 @@
"""Index documents."""
-from copy import deepcopy
from functools import partial
from pathlib import Path
import numpy as np
-from pynndescent import NNDescent
+from sqlalchemy.engine import make_url
from sqlmodel import Session, select
from tqdm.auto import tqdm
from raglite._config import RAGLiteConfig
-from raglite._database import Chunk, Document, VectorSearchChunkIndex, create_database_engine
+from raglite._database import Chunk, ChunkEmbedding, Document, IndexMetadata, create_database_engine
from raglite._embed import embed_strings
from raglite._markdown import document_to_markdown
from raglite._split_chunks import split_chunks
@@ -23,8 +22,8 @@ def _create_chunk_records(
chunks: list[str],
sentence_embeddings: list[FloatMatrix],
config: RAGLiteConfig,
-) -> list[Chunk]:
- """Process chunks into chunk records comprising headings, body, and a multi-vector embedding."""
+) -> tuple[list[Chunk], list[list[ChunkEmbedding]]]:
+ """Process chunks into chunk and chunk embedding records."""
# Create the chunk records.
chunk_records, headings = [], ""
for i, chunk in enumerate(chunks):
@@ -37,28 +36,28 @@ def _create_chunk_records(
contextualized_embeddings = embed_strings([str(chunk) for chunk in chunks], config=config)
# Set the chunk's multi-vector embedding as a linear combination of its sentence embeddings
# (for local context) and an embedding of the contextualised chunk (for global context).
- for record, sentence_embedding, contextualized_embedding in zip(
+ α = config.sentence_embedding_weight # noqa: PLC2401
+ chunk_embedding_records = []
+ for chunk_record, sentence_embedding, contextualized_embedding in zip(
chunk_records, sentence_embeddings, contextualized_embeddings, strict=True
):
- chunk_embedding = (
- config.sentence_embedding_weight * sentence_embedding
- + (1 - config.sentence_embedding_weight) * contextualized_embedding[np.newaxis, :]
- )
+ chunk_embedding = α * sentence_embedding + (1 - α) * contextualized_embedding[np.newaxis, :]
chunk_embedding = chunk_embedding / np.linalg.norm(chunk_embedding, axis=1, keepdims=True)
- record.multi_vector_embedding = chunk_embedding
- return chunk_records
+ chunk_embedding_records.append(
+ [ChunkEmbedding(chunk_id=chunk_record.id, embedding=row) for row in chunk_embedding]
+ )
+ return chunk_records, chunk_embedding_records
-def insert_document(
- doc_path: Path, *, update_index: bool = True, config: RAGLiteConfig | None = None
-) -> None:
+def insert_document(doc_path: Path, *, config: RAGLiteConfig | None = None) -> None:
"""Insert a document into the database and update the index."""
# Use the default config if not provided.
config = config or RAGLiteConfig()
+ db_backend = make_url(config.db_url).get_backend_name()
# Preprocess the document into chunks.
with tqdm(total=4, unit="step", dynamic_ncols=True) as pbar:
pbar.set_description("Initializing database")
- engine = create_database_engine(config.db_url)
+ engine = create_database_engine(config)
pbar.update(1)
pbar.set_description("Converting to Markdown")
doc = document_to_markdown(doc_path)
@@ -81,61 +80,62 @@ def insert_document(
if session.get(Document, document_record.id) is None:
session.add(document_record)
session.commit()
- # Create the chunk records.
- chunk_records = _create_chunk_records(
+ # Create the chunk records to insert into the chunk table.
+ chunk_records, chunk_embedding_records = _create_chunk_records(
document_record.id, chunks, sentence_embeddings, config
)
- # Store the chunk records.
- for chunk_record in tqdm(
- chunk_records, desc="Storing chunks", unit="chunk", dynamic_ncols=True
+ # Store the chunk and chunk embedding records.
+ for chunk_record, chunk_embedding_record in tqdm(
+ zip(chunk_records, chunk_embedding_records, strict=True),
+ desc="Storing chunks" if db_backend == "sqlite" else "Storing and indexing chunks",
+ total=len(chunk_records),
+ unit="chunk",
+ dynamic_ncols=True,
):
if session.get(Chunk, chunk_record.id) is not None:
continue
session.add(chunk_record)
+ session.add_all(chunk_embedding_record)
session.commit()
- # Update the vector search chunk index.
- if update_index:
- update_vector_index(config)
-
+ # Manually update the vector search chunk index for SQLite.
+ if db_backend == "sqlite":
+ from pynndescent import NNDescent
-def update_vector_index(config: RAGLiteConfig | None = None) -> None:
- """Update the vector search chunk index with any unindexed chunks."""
- config = config or RAGLiteConfig()
- engine = create_database_engine(config.db_url)
- with Session(engine) as session:
- # Get the vector search chunk index from the database, or create a new one.
- vector_search_chunk_index = session.get(
- VectorSearchChunkIndex, config.vector_search_index_id
- ) or VectorSearchChunkIndex(id=config.vector_search_index_id)
- num_chunks_indexed = len(vector_search_chunk_index.chunk_sizes)
- # Get the unindexed chunks.
- statement = select(Chunk).offset(num_chunks_indexed)
- unindexed_chunks = session.exec(statement).all()
- num_chunks_unindexed = len(unindexed_chunks)
- # Index the unindexed chunks.
- with tqdm(
- total=num_chunks_indexed + num_chunks_unindexed,
- desc="Indexing chunks",
- unit="chunk",
- dynamic_ncols=True,
- ) as pbar:
- # Fit or update the ANN index.
- pbar.update(num_chunks_indexed)
- if num_chunks_unindexed == 0:
+ with Session(engine) as session:
+ # Get the vector search chunk index from the database, or create a new one.
+ index_metadata = session.get(IndexMetadata, "default") or IndexMetadata(id="default")
+ chunk_ids = index_metadata.metadata_.get("chunk_ids", [])
+ chunk_sizes = index_metadata.metadata_.get("chunk_sizes", [])
+ # Get the unindexed chunks.
+ unindexed_chunks = list(session.exec(select(Chunk).offset(len(chunk_ids))).all())
+ if not unindexed_chunks:
return
- X_unindexed = np.vstack([chunk.multi_vector_embedding for chunk in unindexed_chunks]) # noqa: N806
- if num_chunks_indexed == 0:
- nndescent = NNDescent(X_unindexed, metric=config.vector_search_index_metric)
- else:
- nndescent = deepcopy(vector_search_chunk_index.index)
- nndescent.update(X_unindexed)
- nndescent.prepare()
- # Mark the vector search chunk index as dirty.
- vector_search_chunk_index.index = nndescent
- vector_search_chunk_index.chunk_sizes = vector_search_chunk_index.chunk_sizes + [
- chunk.multi_vector_embedding.shape[0] for chunk in unindexed_chunks
- ]
- # Store the updated vector search chunk index.
- session.add(vector_search_chunk_index)
- session.commit()
- pbar.update(num_chunks_unindexed)
+ # Assemble the unindexed chunk embeddings into a NumPy array.
+ unindexed_chunk_embeddings = [chunk.embedding_matrix for chunk in unindexed_chunks]
+ X = np.vstack(unindexed_chunk_embeddings) # noqa: N806
+ # Index the unindexed chunks.
+ with tqdm(
+ total=len(unindexed_chunks),
+ desc="Indexing chunks",
+ unit="chunk",
+ dynamic_ncols=True,
+ ) as pbar:
+ # Fit or update the ANN index.
+ if len(chunk_ids) == 0:
+ nndescent = NNDescent(X, metric=config.vector_search_index_metric)
+ else:
+ nndescent = index_metadata.metadata_["index"]
+ nndescent.update(X)
+ # Prepare the ANN index so it can handle query vectors not in the training set.
+ nndescent.prepare()
+ # Update the index metadata and mark it as dirty by recreating the dictionary.
+ index_metadata.metadata_ = {
+ **index_metadata.metadata_,
+ "index": nndescent,
+ "chunk_ids": chunk_ids + [c.id for c in unindexed_chunks],
+ "chunk_sizes": chunk_sizes + [len(em) for em in unindexed_chunk_embeddings],
+ }
+ # Store the updated vector search chunk index.
+ session.add(index_metadata)
+ session.commit()
+ pbar.update(len(unindexed_chunks))
diff --git a/src/raglite/_query_adapter.py b/src/raglite/_query_adapter.py
index bd3fed9..bedb9a7 100644
--- a/src/raglite/_query_adapter.py
+++ b/src/raglite/_query_adapter.py
@@ -1,16 +1,16 @@
"""Compute and update an optimal query adapter."""
import numpy as np
-from sqlmodel import Session, select
+from sqlmodel import Session, col, select
from tqdm.auto import tqdm
from raglite._config import RAGLiteConfig
-from raglite._database import Chunk, Eval, VectorSearchChunkIndex, create_database_engine
+from raglite._database import Chunk, Eval, IndexMetadata, create_database_engine
from raglite._embed import embed_strings
from raglite._search import vector_search
-def update_query_adapter( # noqa: C901, PLR0915
+def update_query_adapter( # noqa: PLR0915
*,
max_triplets: int = 4096,
max_triplets_per_eval: int = 64,
@@ -63,8 +63,10 @@ def update_query_adapter( # noqa: C901, PLR0915
C := 5% * A, the optimal α is then given by αA + (1 - α)B = C => α = (B - C) / (B - A).
"""
config = config or RAGLiteConfig()
- config_no_query_adapter = RAGLiteConfig(**{**config.__dict__, "enable_query_adapter": False})
- engine = create_database_engine(config.db_url)
+ config_no_query_adapter = RAGLiteConfig(
+ **{**config.__dict__, "vector_search_query_adapter": False}
+ )
+ engine = create_database_engine(config)
with Session(engine) as session:
# Get random evals from the database.
evals = session.exec(
@@ -83,34 +85,25 @@ def update_query_adapter( # noqa: C901, PLR0915
# Embed the question.
question_embedding = embed_strings([eval_.question], config=config)
# Retrieve chunks that would be used to answer the question.
- chunk_rowids, _ = vector_search(
+ chunk_ids, _ = vector_search(
question_embedding, num_results=optimize_top_k, config=config_no_query_adapter
)
- retrieved_chunks = [
- session.exec(select(Chunk).offset(chunk_rowid - 1)).first()
- for chunk_rowid in chunk_rowids
- ]
+ retrieved_chunks = sorted(session.exec(select(Chunk).where(col(Chunk.id).in_(chunk_ids))).all(), key=lambda chunk: chunk_ids.index(chunk.id)) # IN returns rows unordered; restore the search ranking.
# Extract (q, p, n) triplets by comparing the retrieved chunks with the eval.
num_triplets = 0
for i, retrieved_chunk in enumerate(retrieved_chunks):
- # Raise an error if the retrieved chunk is None.
- if retrieved_chunk is None:
- error_message = (
- f"The chunk with rowid {chunk_rowids[i]} is missing from the database."
- )
- raise ValueError(error_message)
# Select irrelevant chunks.
if retrieved_chunk.id not in eval_.chunk_ids:
# Look up all positive chunks (each represented by the mean of its multi-vector
# embedding) that are ranked lower than this negative one (represented by the
# embedding in the multi-vector embedding that best matches the query).
p_mean = [
- np.mean(chunk.multi_vector_embedding, axis=0, keepdims=True)
+ np.mean(chunk.embedding_matrix, axis=0, keepdims=True)
for chunk in retrieved_chunks[i + 1 :]
if chunk is not None and chunk.id in eval_.chunk_ids
]
- n_top = retrieved_chunk.multi_vector_embedding[
- np.argmax(retrieved_chunk.multi_vector_embedding @ question_embedding.T),
+ n_top = retrieved_chunk.embedding_matrix[
+ np.argmax(retrieved_chunk.embedding_matrix @ question_embedding.T),
np.newaxis,
:,
]
@@ -159,9 +152,7 @@ def update_query_adapter( # noqa: C901, PLR0915
error_message = f"Unsupported ANN metric: {config.vector_search_index_metric}"
raise ValueError(error_message)
# Store the optimal query adapter in the database.
- vector_search_chunk_index = session.get(
- VectorSearchChunkIndex, config.vector_search_index_id
- ) or VectorSearchChunkIndex(id=config.vector_search_index_id)
- vector_search_chunk_index.query_adapter = A_star
- session.add(vector_search_chunk_index)
+ index_metadata = session.get(IndexMetadata, "default") or IndexMetadata(id="default")
+ index_metadata.metadata_ = {**index_metadata.metadata_, "query_adapter": A_star}
+ session.add(index_metadata)
session.commit()
diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py
index 0235f06..08be304 100644
--- a/src/raglite/_rag.py
+++ b/src/raglite/_rag.py
@@ -11,7 +11,7 @@ def rag(
*,
max_contexts: int = 5,
context_neighbors: tuple[int, ...] | None = (-1, 1),
- search: Callable[[str], tuple[list[int], list[float]]] = hybrid_search,
+ search: Callable[[str], tuple[list[str], list[float]]] = hybrid_search,
config: RAGLiteConfig | None = None,
) -> Iterator[str]:
"""Retrieval-augmented generation."""
@@ -22,8 +22,8 @@ def rag(
max_tokens_per_context *= 1 + len(context_neighbors or [])
max_contexts = min(max_contexts, max_tokens // max_tokens_per_context)
# Retrieve relevant contexts.
- chunk_rowids, _ = search(prompt, num_results=max_contexts, config=config) # type: ignore[call-arg]
- segments = retrieve_segments(chunk_rowids, neighbors=context_neighbors)
+ chunk_ids, _ = search(prompt, num_results=max_contexts, config=config) # type: ignore[call-arg]
+ segments = retrieve_segments(chunk_ids, neighbors=context_neighbors, config=config)
# Respond with an LLM.
contexts = "\n\n".join(
f'\n{segment.strip()}\n'
diff --git a/src/raglite/_search.py b/src/raglite/_search.py
index 73c2dd8..2f3a5ba 100644
--- a/src/raglite/_search.py
+++ b/src/raglite/_search.py
@@ -3,40 +3,19 @@
import re
import string
from collections import defaultdict
-from functools import lru_cache
from itertools import groupby
-from typing import Annotated, ClassVar
+from typing import Annotated, ClassVar, cast
import numpy as np
from pydantic import BaseModel, Field
-from pynndescent import NNDescent
+from sqlalchemy.engine import make_url
from sqlmodel import Session, select, text
from raglite._config import RAGLiteConfig
-from raglite._database import Chunk, VectorSearchChunkIndex, create_database_engine
+from raglite._database import Chunk, ChunkEmbedding, IndexMetadata, create_database_engine
from raglite._embed import embed_strings
from raglite._extract import extract_with_llm
-from raglite._typing import FloatMatrix, IntVector
-
-
-@lru_cache(maxsize=1)
-def _vector_search_chunk_index(
- config: RAGLiteConfig,
-) -> tuple[NNDescent, IntVector, FloatMatrix | None]:
- engine = create_database_engine(config.db_url)
- with Session(engine) as session:
- vector_search_chunk_index = session.get(
- VectorSearchChunkIndex, config.vector_search_index_id
- )
- if vector_search_chunk_index is None:
- error_message = "First run `update_vector_index()` to create a vector search index."
- raise ValueError(error_message)
- index = vector_search_chunk_index.index
- chunk_size_cumsum = np.cumsum(
- np.asarray(vector_search_chunk_index.chunk_sizes, dtype=np.intp)
- )
- query_adapter = vector_search_chunk_index.query_adapter
- return index, chunk_size_cumsum, query_adapter
+from raglite._typing import FloatMatrix
def vector_search(
@@ -44,101 +23,148 @@ def vector_search(
*,
num_results: int = 3,
config: RAGLiteConfig | None = None,
-) -> tuple[list[int], list[float]]:
+) -> tuple[list[str], list[float]]:
"""Search chunks using ANN vector search."""
- # Retrieve the index from the database.
+ # Read the config.
config = config or RAGLiteConfig()
- index, chunk_size_cumsum, Q = _vector_search_chunk_index(config) # noqa: N806
+ db_backend = make_url(config.db_url).get_backend_name()
+ # Get the index metadata (including the query adapter, and in the case of SQLite, the index).
+ index_metadata = IndexMetadata.get("default", config=config)
# Embed the prompt.
prompt_embedding = (
- embed_strings([prompt], config=config)
+ embed_strings([prompt], config=config)[0, :]
if isinstance(prompt, str)
- else np.reshape(prompt, (1, -1))
+ else np.ravel(prompt)
)
- # Apply the query adapter.
- if config.enable_query_adapter and Q is not None:
- prompt_embedding = (Q @ prompt_embedding[0, :])[np.newaxis, :].astype(config.embedder_dtype)
- # Find the neighbouring multi-vector indices.
- multi_vector_indices, cosine_distance = index.query(prompt_embedding, k=8 * num_results)
- cosine_similarity = 1 - cosine_distance[0, :]
- # Transform the multi-vector indices into chunk rowids.
- chunk_rowids = np.searchsorted(chunk_size_cumsum, multi_vector_indices[0, :], side="right") + 1
- # Score each unique chunk rowid as the mean cosine similarity of its multi-vector hits.
- # Chunk rowids with fewer hits are padded with the minimum cosine similarity of the result set.
- unique_chunk_rowids, counts = np.unique(chunk_rowids, return_counts=True)
+ # Apply the query adapter to the prompt embedding.
+ Q = index_metadata.get("query_adapter") # noqa: N806
+ if config.vector_search_query_adapter and Q is not None:
+ prompt_embedding = (Q @ prompt_embedding).astype(prompt_embedding.dtype)
+ # Search for the multi-vector chunk embeddings that are most similar to the prompt embedding.
+ if db_backend == "postgresql":
+ # Check that the selected metric is supported by pgvector.
+ metrics = {"cosine": "<=>", "dot": "<#>", "euclidean": "<->", "l1": "<+>", "l2": "<->"}
+ if config.vector_search_index_metric not in metrics:
+ error_message = f"Unsupported metric {config.vector_search_index_metric}."
+ raise ValueError(error_message)
+ # With pgvector, we can obtain the nearest neighbours and similarities with a single query.
+ engine = create_database_engine(config)
+ with Session(engine) as session:
+ distance_func = getattr(
+ ChunkEmbedding.embedding, f"{config.vector_search_index_metric}_distance"
+ )
+ distance = distance_func(prompt_embedding).label("distance")
+ results = session.exec(
+ select(ChunkEmbedding.chunk_id, distance).order_by(distance).limit(8 * num_results)
+ )
+ chunk_ids_, distance = zip(*results, strict=True)
+ chunk_ids, similarity = np.asarray(chunk_ids_), 1.0 - np.asarray(distance)
+ elif db_backend == "sqlite":
+ # Load the NNDescent index.
+ index = index_metadata.get("index")
+ ids = np.asarray(index_metadata.get("chunk_ids"))
+ cumsum = np.cumsum(np.asarray(index_metadata.get("chunk_sizes")))
+ # Find the neighbouring multi-vector indices.
+ from pynndescent import NNDescent
+
+ multi_vector_indices, distance = cast(NNDescent, index).query(
+ prompt_embedding[np.newaxis, :], k=8 * num_results
+ )
+ similarity = 1 - distance[0, :]
+ # Transform the multi-vector indices into chunk indices, and then to chunk ids.
+ chunk_indices = np.searchsorted(cumsum, multi_vector_indices[0, :], side="right") + 1
+ chunk_ids = np.asarray([ids[chunk_index - 1] for chunk_index in chunk_indices])
+ # Score each unique chunk id as the mean similarity of its multi-vector hits. Chunk ids with
+ # fewer hits are padded with the minimum similarity of the result set.
+ unique_chunk_ids, counts = np.unique(chunk_ids, return_counts=True)
score = np.full(
- (len(unique_chunk_rowids), np.max(counts)),
- np.min(cosine_similarity),
- dtype=cosine_similarity.dtype,
+ (len(unique_chunk_ids), np.max(counts)), np.min(similarity), dtype=similarity.dtype
)
- for i, (unique_chunk_rowid, count) in enumerate(zip(unique_chunk_rowids, counts, strict=True)):
- score[i, :count] = cosine_similarity[chunk_rowids == unique_chunk_rowid]
- pooled_cosine_similarity = np.mean(score, axis=1)
- # Sort the chunk rowids by adjusted cosine similarity.
- sorted_indices = np.argsort(pooled_cosine_similarity)[::-1]
- unique_chunk_rowids = unique_chunk_rowids[sorted_indices][:num_results]
- pooled_cosine_similarity = pooled_cosine_similarity[sorted_indices][:num_results]
- return unique_chunk_rowids.tolist(), pooled_cosine_similarity.tolist()
-
-
-def _prompt_to_fts_query(prompt: str) -> str:
- """Convert a prompt to an FTS5 query."""
- # https://www.sqlite.org/fts5.html#full_text_query_syntax
- prompt = re.sub(f"[{re.escape(string.punctuation)}]", "", prompt)
- fts_query = " OR ".join(prompt.split())
- return fts_query
+ for i, (unique_chunk_id, count) in enumerate(zip(unique_chunk_ids, counts, strict=True)):
+ score[i, :count] = similarity[chunk_ids == unique_chunk_id]
+ pooled_similarity = np.mean(score, axis=1)
+ # Sort the chunk ids by their adjusted similarity.
+ sorted_indices = np.argsort(pooled_similarity)[::-1]
+ unique_chunk_ids = unique_chunk_ids[sorted_indices][:num_results]
+ pooled_similarity = pooled_similarity[sorted_indices][:num_results]
+ return unique_chunk_ids.tolist(), pooled_similarity.tolist()
def keyword_search(
prompt: str, *, num_results: int = 3, config: RAGLiteConfig | None = None
-) -> tuple[list[int], list[float]]:
+) -> tuple[list[str], list[float]]:
"""Search chunks using BM25 keyword search."""
+ # Read the config.
config = config or RAGLiteConfig()
- engine = create_database_engine(config.db_url)
+ db_backend = make_url(config.db_url).get_backend_name()
+ # Connect to the database.
+ engine = create_database_engine(config)
with Session(engine) as session:
- # Perform the full-text search query using the BM25 ranking.
- statement = text(
- "SELECT chunk.rowid, bm25(fts_chunk_index) FROM chunk JOIN fts_chunk_index ON chunk.rowid = fts_chunk_index.rowid WHERE fts_chunk_index MATCH :match ORDER BY rank LIMIT :limit;"
- )
- results = session.execute(
- statement, params={"match": _prompt_to_fts_query(prompt), "limit": num_results}
- )
- # Unpack the results and make FTS5's negative BM25 scores [1] positive.
- # https://www.sqlite.org/fts5.html#the_bm25_function
- chunk_rowids, bm25_score = zip(*results, strict=True)
- chunk_rowids, bm25_score = list(chunk_rowids), [-s for s in bm25_score] # type: ignore[assignment]
- return chunk_rowids, bm25_score # type: ignore[return-value]
+ if db_backend == "postgresql":
+ # Convert the prompt to a tsquery [1].
+ # [1] https://www.postgresql.org/docs/current/textsearch-controls.html
+ prompt_escaped = re.sub(r"[&|!():<>\"]", " ", prompt)
+ tsv_query = " | ".join(prompt_escaped.split())
+ # Perform full-text search with tsvector.
+ statement = text("""
+ SELECT id as chunk_id, ts_rank(to_tsvector('simple', body), to_tsquery('simple', :query)) AS score
+ FROM chunk
+ WHERE to_tsvector('simple', body) @@ to_tsquery('simple', :query)
+ ORDER BY score DESC
+ LIMIT :limit;
+ """)
+ results = session.execute(statement, params={"query": tsv_query, "limit": num_results})
+ elif db_backend == "sqlite":
+ # Convert the prompt to an FTS5 query [1].
+ # [1] https://www.sqlite.org/fts5.html#full_text_query_syntax
+ prompt_escaped = re.sub(f"[{re.escape(string.punctuation)}]", "", prompt)
+ fts5_query = " OR ".join(prompt_escaped.split())
+ # Perform full-text search with FTS5. In FTS5, BM25 scores are negative [1], so we
+ # negate them to make them positive.
+ # [1] https://www.sqlite.org/fts5.html#the_bm25_function
+ statement = text("""
+ SELECT chunk.id as chunk_id, -bm25(fts_chunk_index) as score
+ FROM chunk JOIN fts_chunk_index ON chunk.rowid = fts_chunk_index.rowid
+ WHERE fts_chunk_index MATCH :match
+ ORDER BY score DESC
+ LIMIT :limit;
+ """)
+ results = session.execute(statement, params={"match": fts5_query, "limit": num_results})
+ # Unpack the results row by row so that a search without matches yields two empty lists.
+ rows = list(results)
+ chunk_ids, keyword_score = [row[0] for row in rows], [row[1] for row in rows]
+ return chunk_ids, keyword_score
def reciprocal_rank_fusion(
- rankings: list[list[int]], *, k: int = 60
-) -> tuple[list[int], list[float]]:
+ rankings: list[list[str]], *, k: int = 60
+) -> tuple[list[str], list[float]]:
"""Reciprocal Rank Fusion."""
# Compute the RRF score.
- rowids = {rowid for ranking in rankings for rowid in ranking}
- rowid_score: defaultdict[int, float] = defaultdict(float)
+ chunk_ids = {chunk_id for ranking in rankings for chunk_id in ranking}
+ chunk_id_score: defaultdict[str, float] = defaultdict(float)
for ranking in rankings:
- rowid_index = {rowid: i for i, rowid in enumerate(ranking)}
- for rowid in rowids:
- rowid_score[rowid] += 1 / (k + rowid_index.get(rowid, len(rowid_index)))
+ chunk_id_index = {chunk_id: i for i, chunk_id in enumerate(ranking)}
+ for chunk_id in chunk_ids:
+ chunk_id_score[chunk_id] += 1 / (k + chunk_id_index.get(chunk_id, len(chunk_id_index)))
# Rank RRF results according to descending RRF score.
- rrf_rowids, rrf_score = zip(
- *sorted(rowid_score.items(), key=lambda x: x[1], reverse=True), strict=True
+ rrf_chunk_ids, rrf_score = zip(
+ *sorted(chunk_id_score.items(), key=lambda x: x[1], reverse=True), strict=True
)
- return list(rrf_rowids), list(rrf_score)
+ return list(rrf_chunk_ids), list(rrf_score)
def hybrid_search(
prompt: str, *, num_results: int = 3, num_rerank: int = 100, config: RAGLiteConfig | None = None
-) -> tuple[list[int], list[float]]:
+) -> tuple[list[str], list[float]]:
"""Search chunks by combining ANN vector search with BM25 keyword search."""
# Run both searches.
chunks_vector, _ = vector_search(prompt, num_results=num_rerank, config=config)
- chunks_bm25, _ = keyword_search(prompt, num_results=num_rerank, config=config)
+ chunks_keyword, _ = keyword_search(prompt, num_results=num_rerank, config=config)
# Combine the results with Reciprocal Rank Fusion (RRF).
- chunk_rowids, hybrid_score = reciprocal_rank_fusion([chunks_vector, chunks_bm25])
- chunk_rowids, hybrid_score = chunk_rowids[:num_results], hybrid_score[:num_results]
- return chunk_rowids, hybrid_score
+ chunk_ids, hybrid_score = reciprocal_rank_fusion([chunks_vector, chunks_keyword])
+ chunk_ids, hybrid_score = chunk_ids[:num_results], hybrid_score[:num_results]
+ return chunk_ids, hybrid_score
def fusion_search(
@@ -147,7 +173,7 @@ def fusion_search(
num_results: int = 5,
num_rerank: int = 100,
config: RAGLiteConfig | None = None,
-) -> tuple[list[int], list[float]]:
+) -> tuple[list[str], list[float]]:
"""Search for chunks with the RAG-Fusion method."""
class QueriesResponse(BaseModel):
@@ -172,31 +198,31 @@ class QueriesResponse(BaseModel):
for query in queries:
# Run both searches.
chunks_vector, _ = vector_search(query, num_results=num_rerank, config=config)
- chunks_bm25, _ = keyword_search(query, num_results=num_rerank, config=config)
+ chunks_keyword, _ = keyword_search(query, num_results=num_rerank, config=config)
# Add results to the rankings.
rankings.append(chunks_vector)
- rankings.append(chunks_bm25)
+ rankings.append(chunks_keyword)
# Combine all the search results with Reciprocal Rank Fusion (RRF).
- chunk_rowids, fusion_score = reciprocal_rank_fusion(rankings)
- chunk_rowids, fusion_score = chunk_rowids[:num_results], fusion_score[:num_results]
- return chunk_rowids, fusion_score
+ chunk_ids, fusion_score = reciprocal_rank_fusion(rankings)
+ chunk_ids, fusion_score = chunk_ids[:num_results], fusion_score[:num_results]
+ return chunk_ids, fusion_score
def retrieve_segments(
- chunk_rowids: list[int],
+ chunk_ids: list[str],
*,
neighbors: tuple[int, ...] | None = (-1, 1),
config: RAGLiteConfig | None = None,
) -> list[str]:
- """Group the chunks into contiguous segments and retrieve them."""
- # Get the chunks by rowid and extend them with their neighbours.
+ """Group chunks into contiguous segments and retrieve them."""
+ # Get the chunks and extend them with their neighbours.
config = config or RAGLiteConfig()
chunks = set()
- engine = create_database_engine(config.db_url)
+ engine = create_database_engine(config)
with Session(engine) as session:
- for chunk_rowid in chunk_rowids:
- # Get the chunk at the given rowid.
- chunk = session.exec(select(Chunk).offset(chunk_rowid - 1)).first()
+ for chunk_id in chunk_ids:
+ # Get the chunk by id.
+ chunk = session.get(Chunk, chunk_id)
if chunk is not None:
chunks.add(chunk)
# Extend the chunk with its neighbouring chunks.
diff --git a/src/raglite/_typing.py b/src/raglite/_typing.py
index c80f0e4..adda9d0 100644
--- a/src/raglite/_typing.py
+++ b/src/raglite/_typing.py
@@ -1,8 +1,137 @@
"""RAGLite typing."""
+import io
+import pickle
+from collections.abc import Callable
from typing import Any
import numpy as np
+from sqlalchemy.engine import Dialect
+from sqlalchemy.sql.operators import Operators
+from sqlalchemy.types import Float, LargeBinary, TypeDecorator, TypeEngine, UserDefinedType
FloatMatrix = np.ndarray[tuple[int, int], np.dtype[np.floating[Any]]]
+FloatVector = np.ndarray[tuple[int], np.dtype[np.floating[Any]]]
IntVector = np.ndarray[tuple[int], np.dtype[np.intp]]
+
+
+class NumpyArray(TypeDecorator[np.ndarray[Any, np.dtype[np.floating[Any]]]]):
+ """A NumPy array column type for SQLAlchemy."""
+
+ impl = LargeBinary
+
+ def process_bind_param(
+ self, value: np.ndarray[Any, np.dtype[np.floating[Any]]] | None, dialect: Dialect
+ ) -> bytes | None:
+ """Convert a NumPy array to bytes."""
+ if value is None:
+ return None
+ buffer = io.BytesIO()
+ np.save(buffer, value, allow_pickle=False, fix_imports=False)
+ return buffer.getvalue()
+
+ def process_result_value(
+ self, value: bytes | None, dialect: Dialect
+ ) -> np.ndarray[Any, np.dtype[np.floating[Any]]] | None:
+ """Convert bytes to a NumPy array."""
+ if value is None:
+ return None
+ return np.load(io.BytesIO(value), allow_pickle=False, fix_imports=False) # type: ignore[no-any-return]
+
+
+class PickledObject(TypeDecorator[object]):
+ """A pickled object column type for SQLAlchemy."""
+
+ impl = LargeBinary
+
+ def process_bind_param(self, value: object | None, dialect: Dialect) -> bytes | None:
+ """Convert a Python object to bytes."""
+ if value is None:
+ return None
+ return pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL, fix_imports=False)
+
+ def process_result_value(self, value: bytes | None, dialect: Dialect) -> object | None:
+ """Convert bytes to a Python object."""
+ if value is None:
+ return None
+ return pickle.loads(value, fix_imports=False) # type: ignore[no-any-return] # noqa: S301
+
+
+class HalfVecComparatorMixin(UserDefinedType.Comparator[FloatVector]):
+ """A mixin that provides comparison operators for halfvecs."""
+
+ def cosine_distance(self, other: FloatVector) -> Operators:
+ """Compute the cosine distance."""
+ return self.op("<=>", return_type=Float)(other)
+
+ def dot_distance(self, other: FloatVector) -> Operators:
+ """Compute the dot product distance."""
+ return self.op("<#>", return_type=Float)(other)
+
+ def euclidean_distance(self, other: FloatVector) -> Operators:
+ """Compute the Euclidean distance."""
+ return self.op("<->", return_type=Float)(other)
+
+ def l1_distance(self, other: FloatVector) -> Operators:
+ """Compute the L1 distance."""
+ return self.op("<+>", return_type=Float)(other)
+
+ def l2_distance(self, other: FloatVector) -> Operators:
+ """Compute the L2 distance."""
+ return self.op("<->", return_type=Float)(other)
+
+
+class HalfVec(UserDefinedType[FloatVector]):
+ """A PostgreSQL half-precision vector column type for SQLAlchemy."""
+
+ cache_ok = True # HalfVec is immutable.
+
+ def __init__(self, dim: int | None = None) -> None:
+ super().__init__()
+ self.dim = dim
+
+ def get_col_spec(self, **kwargs: Any) -> str:
+ return f"halfvec({self.dim})"
+
+ def bind_processor(self, dialect: Dialect) -> Callable[[FloatVector | None], str | None]:
+ """Process NumPy ndarray to PostgreSQL halfvec format for bound parameters."""
+
+ def process(value: FloatVector | None) -> str | None:
+ return f"[{','.join(str(x) for x in np.ravel(value))}]" if value is not None else None
+
+ return process
+
+ def result_processor(
+ self, dialect: Dialect, coltype: Any
+ ) -> Callable[[str | None], FloatVector | None]:
+ """Process PostgreSQL halfvec format to NumPy ndarray."""
+
+ def process(value: str | None) -> FloatVector | None:
+ if value is None:
+ return None
+ return np.fromstring(value.strip("[]"), sep=",", dtype=np.float16)
+
+ return process
+
+ class comparator_factory(HalfVecComparatorMixin): # noqa: N801
+ ...
+
+
+class Embedding(TypeDecorator[FloatVector]):
+ """An embedding column type for SQLAlchemy."""
+
+ cache_ok = True # Embedding is immutable.
+
+ impl = NumpyArray
+
+ def __init__(self, dim: int = -1):
+ super().__init__()
+ self.dim = dim
+
+ def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[FloatVector]:
+ if dialect.name == "postgresql":
+ return dialect.type_descriptor(HalfVec(self.dim))
+ return dialect.type_descriptor(NumpyArray())
+
+ class comparator_factory(HalfVecComparatorMixin): # noqa: N801
+ ...
diff --git a/tests/conftest.py b/tests/conftest.py
index a28e805..4439923 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,14 +1,35 @@
"""Fixtures for the tests."""
+import socket
+
import pytest
from llama_cpp import Llama
+from sqlalchemy import create_engine, text
from raglite import RAGLiteConfig
-@pytest.fixture
-def simple_config() -> RAGLiteConfig:
- """Create a lightweight in-memory config for testing."""
+def is_postgres_running() -> bool:
+    """Report whether a TCP connection to the `postgres` host on port 5432 succeeds."""
+    try:
+        socket.create_connection(("postgres", 5432), timeout=1).close()
+    except OSError:
+        return False
+    return True
+
+
+@pytest.fixture(
+ params=[
+ pytest.param("sqlite:///:memory:", id="SQLite"),
+ pytest.param(
+ "postgresql+pg8000://raglite_user:raglite_password@postgres:5432/postgres",
+ id="PostgreSQL",
+ marks=pytest.mark.skipif(not is_postgres_running(), reason="PostgreSQL is not running"),
+ ),
+ ]
+)
+def simple_config(request: pytest.FixtureRequest) -> RAGLiteConfig:
+ """Create a lightweight in-memory config for testing SQLite and PostgreSQL."""
# Use a lightweight embedder.
embedder = Llama.from_pretrained(
repo_id="ChristianAzinn/snowflake-arctic-embed-xs-gguf", # https://github.com/Snowflake-Labs/arctic-embed
@@ -18,8 +39,18 @@ def simple_config() -> RAGLiteConfig:
verbose=False,
embedding=True,
)
- # Use an in-memory SQLite database.
- db_url = "sqlite:///:memory:"
- # Create the config.
- config = RAGLiteConfig(embedder=embedder, db_url=db_url)
- return config
+ # Return a SQLite config.
+ if "sqlite" in request.param:
+     return RAGLiteConfig(embedder=embedder, db_url=request.param)
+ # Return a PostgreSQL config that points at a freshly recreated test database.
+ if "postgresql" in request.param:
+     # Connect in AUTOCOMMIT mode: PostgreSQL refuses CREATE/DROP DATABASE inside a transaction.
+     engine = create_engine(request.param, isolation_level="AUTOCOMMIT")
+     with engine.connect() as conn:
+         conn.execute(text("DROP DATABASE IF EXISTS raglite_test_db"))
+         conn.execute(text("CREATE DATABASE raglite_test_db"))
+     engine.dispose()  # Release the pooled admin connection.
+     # The admin URL ends in "/postgres", so swap only that trailing database name.
+     test_db_url = request.param.rsplit("/", 1)[0] + "/raglite_test_db"
+     return RAGLiteConfig(embedder=embedder, db_url=test_db_url)
+ raise ValueError("Expected a SQLite or PostgreSQL database URL.")
diff --git a/tests/test_basic.py b/tests/test_basic.py
index afa569e..abf0e3c 100644
--- a/tests/test_basic.py
+++ b/tests/test_basic.py
@@ -12,11 +12,11 @@ def test_insert_index_search(simple_config: RAGLiteConfig) -> None:
insert_document(doc_path, config=simple_config)
# Search for a query.
query = "What does it mean for two events to be simultaneous?"
- chunk_rowids, scores = hybrid_search(query, config=simple_config)
- assert len(chunk_rowids) == len(scores)
- assert all(isinstance(rowid, int) for rowid in chunk_rowids)
+ chunk_ids, scores = hybrid_search(query, config=simple_config)
+ assert len(chunk_ids) == len(scores)
+ assert all(isinstance(chunk_id, str) for chunk_id in chunk_ids)
assert all(isinstance(score, float) for score in scores)
# Group the chunks into segments and retrieve them.
- segments = retrieve_segments(chunk_rowids, neighbors=None, config=simple_config)
+ segments = retrieve_segments(chunk_ids, neighbors=None, config=simple_config)
assert all(isinstance(segment, str) for segment in segments)
assert "Definition of Simultaneity" in segments[0] + segments[1]