From 735c3465488c5f96acb3143d2e52a8201381a326 Mon Sep 17 00:00:00 2001 From: Dan Saattrup Nielsen <47701536+saattrupdan@users.noreply.github.com> Date: Tue, 13 Aug 2024 18:15:12 +0200 Subject: [PATCH] Feat/add postgres embedding store (#53) * feat: Add TxtDocumentStore * feat: Add PostgresEmbeddingStore * docs: Add `postgres` extra to readme * tests: Split demo feedback test into two * tests: Init default embedding store with no args * feat: Create index on embedding column in PostgresEmbeddingStore * chore: Install pgvector extension in CI * fix: Set embedding_dim in PostgresEmbeddingStore * feat: Add more dunder methods to EmbeddingStore * tests: Use random embeddings * fix: Pgvector fixes * docs: Update cov badge * chore: Change apt package name for pgvector and postgres * chore: Set pgvector apt package * chore: apt package * chore: apt package * chore: Yes to all * docs: Add pgvector installation link --- .github/workflows/ci.yaml | 4 +- CHANGELOG.md | 2 + README.md | 22 ++- src/ragger/data_models.py | 81 ++++++-- src/ragger/embedding_store.py | 340 +++++++++++++++++++++++++++++++++- tests/conftest.py | 8 +- tests/test_demo.py | 45 +++-- tests/test_embedding_store.py | 73 +++++--- 8 files changed, 505 insertions(+), 70 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index a1bb9187..fd693745 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -35,7 +35,9 @@ jobs: - name: Setup PostgreSQL server run: | sudo apt-get update - sudo apt-get install -y postgresql + sudo apt-get install -y postgresql-common + yes '' | sudo /usr/share/postgresql-common/pgdg/apt.postgresql.org.sh + sudo apt-get install -y postgresql-16 postgresql-16-pgvector sudo service postgresql start sudo -u postgres psql -c "ALTER USER postgres PASSWORD 'postgres';" diff --git a/CHANGELOG.md b/CHANGELOG.md index 20dffa30..3d9dd577 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Added a `PostgresDocumentStore` that uses a PostgreSQL database to store documents. - Added a `TxtDocumentStore` that reads documents from a single text file, separated by newlines. +- Added a `PostgresEmbeddingStore` that uses a PostgreSQL database to store embeddings, + using the `pgvector` extension. ### Changed - Added defaults to all arguments in each component's constructor, so that the diff --git a/README.md b/README.md index a5411ecb..164ecc11 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ A package for general-purpose RAG applications. ______________________________________________________________________ -[![Code Coverage](https://img.shields.io/badge/Coverage-72%25-yellow.svg)](https://github.com/alexandrainst/ragger/tree/main/tests) +[![Code Coverage](https://img.shields.io/badge/Coverage-74%25-yellow.svg)](https://github.com/alexandrainst/ragger/tree/main/tests) Developer(s): @@ -35,10 +35,17 @@ Installation with `poetry`: poetry add git+ssh://git@github.com/alexandrainst/ragger.git --extras all ``` -You can replace the `all` extra with any combination of `vllm`, `openai` and `demo` to -install only the components you need. For `pip`, this is done by comma-separating the -extras (e.g., `ragger[vllm,demo]`), while for `poetry`, you add multiple `--extras` -flags (e.g., `--extras vllm --extras demo`). 
+You can replace the `all` extra with any combination of the following, to install only +the components you need: + +- `postgres` +- `vllm` +- `openai` +- `demo` + +For `pip`, this is done by comma-separating the extras (e.g., `ragger[vllm,demo]`), +while for `poetry`, you add multiple `--extras` flags (e.g., `--extras vllm --extras +demo`). ## Quick Start @@ -101,6 +108,11 @@ imported from `ragger.embedding_store`. - `NumpyEmbeddingStore`: An embedding store that stores embeddings in a NumPy array. (default) +- `PostgresEmbeddingStore`: An embedding store that uses a PostgreSQL database to store + embeddings, using the `pgvector` extension. This assumes that the PostgreSQL server is + already running, and that the `pgvector` extension is installed. See + [here](https://github.com/pgvector/pgvector?tab=readme-ov-file#installation) for more + information on how to install the extension. ### Generators diff --git a/src/ragger/data_models.py b/src/ragger/data_models.py index e1938148..1f5386b7 100644 --- a/src/ragger/data_models.py +++ b/src/ragger/data_models.py @@ -6,7 +6,7 @@ import annotated_types import numpy as np -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict Index = str @@ -17,6 +17,20 @@ class Document(BaseModel): id: Index text: str + def __eq__(self, other: object) -> bool: + """Check if two documents are equal. + + Args: + other: + The object to compare to. + + Returns: + Whether the two documents are equal. + """ + if not isinstance(other, Document): + return False + return self.id == other.id and self.text == other.text + class Embedding(BaseModel): """An embedding of a document.""" @@ -26,14 +40,25 @@ class Embedding(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) + def __eq__(self, other: object) -> bool: + """Check if two embeddings are equal. + + Args: + other: + The object to compare to. + + Returns: + Whether the two embeddings are equal. + """ + if not isinstance(other, Embedding): + return False + return self.id == other.id and np.allclose(self.embedding, other.embedding) + class GeneratedAnswer(BaseModel): """A generated answer to a question.""" - sources: typing.Annotated[ - list[typing.Annotated[Index, annotated_types.Len(min_length=1)]], - Field(max_length=5), - ] + sources: list[typing.Annotated[Index, annotated_types.Len(min_length=1)]] answer: str = "" @@ -216,7 +241,9 @@ def compile( pass @abstractmethod - def add_embeddings(self, embeddings: typing.Iterable[Embedding]) -> None: + def add_embeddings( + self, embeddings: typing.Iterable[Embedding] + ) -> "EmbeddingStore": """Add embeddings to the store. Args: @@ -238,6 +265,29 @@ def get_nearest_neighbours(self, embedding: np.ndarray) -> list[Index]: """ ... + @abstractmethod + def clear(self) -> None: + """Clear all embeddings from the store.""" + ... + + @abstractmethod + def remove(self) -> None: + """Remove the embedding store.""" + ... + + @abstractmethod + def __getitem__(self, document_id: Index) -> Embedding: + """Fetch an embedding by its document ID. + + Args: + document_id: + The ID of the document to fetch the embedding for. + + Returns: + The embedding with the given document ID. + """ + ... + @abstractmethod def __contains__(self, document_id: Index) -> bool: """Check if a document exists in the store. @@ -252,22 +302,21 @@ def __contains__(self, document_id: Index) -> bool: ... @abstractmethod - def __len__(self) -> int: - """Return the number of embeddings in the store. 
+ def __iter__(self) -> typing.Generator[Embedding, None, None]: + """Iterate over the embeddings in the store. - Returns: - The number of embeddings in the store. + Yields: + The embeddings in the store. """ ... @abstractmethod - def clear(self) -> None: - """Clear all embeddings from the store.""" - ... + def __len__(self) -> int: + """Return the number of embeddings in the store. - @abstractmethod - def remove(self) -> None: - """Remove the embedding store.""" + Returns: + The number of embeddings in the store. + """ ... def __repr__(self) -> str: diff --git a/src/ragger/embedding_store.py b/src/ragger/embedding_store.py index 06767cf2..407f7d63 100644 --- a/src/ragger/embedding_store.py +++ b/src/ragger/embedding_store.py @@ -1,10 +1,13 @@ """Store and fetch embeddings from a database.""" +import importlib.util import io import json import logging +import typing import zipfile from collections import defaultdict +from contextlib import contextmanager from pathlib import Path from typing import Iterable @@ -19,6 +22,12 @@ Index, ) +if importlib.util.find_spec("psycopg2") is not None: + import psycopg2 + +if typing.TYPE_CHECKING: + import psycopg2 + logger = logging.getLogger(__package__) @@ -80,7 +89,7 @@ def row_id_to_index(self) -> dict[int, Index]: """Return a mapping of row IDs to indices.""" return {row_id: index for index, row_id in self.index_to_row_id.items()} - def add_embeddings(self, embeddings: Iterable[Embedding]) -> None: + def add_embeddings(self, embeddings: Iterable[Embedding]) -> "EmbeddingStore": """Add embeddings to the store. Args: @@ -92,7 +101,7 @@ def add_embeddings(self, embeddings: Iterable[Embedding]) -> None: If any of the embeddings already exist in the store. """ if not embeddings: - return + return self already_existing_indices = [ embedding.id @@ -112,7 +121,7 @@ def add_embeddings(self, embeddings: Iterable[Embedding]) -> None: if embedding.id not in self.index_to_row_id ] if not embeddings: - return + return self # In case we haven't inferred the embedding dimension yet, we do it now if self.embedding_dim is None or self.embeddings is None: @@ -135,6 +144,7 @@ def add_embeddings(self, embeddings: Iterable[Embedding]) -> None: logger.info("Added embeddings to the embedding store.") self._save(path=self.path) + return self def _save(self, path: Path | str) -> None: """Save the embedding store to disk. @@ -234,6 +244,28 @@ def remove(self) -> None: """Remove the embedding store.""" self.path.unlink(missing_ok=True) + def __getitem__(self, document_id: Index) -> Embedding: + """Fetch an embedding by its document ID. + + Args: + document_id: + The ID of the document to fetch the embedding for. + + Returns: + The embedding with the given document ID. + + Raises: + KeyError: + If the document ID does not exist in the store, or if the store is empty. + """ + if self.embeddings is None: + raise KeyError("The store is empty.") + if document_id not in self.index_to_row_id: + raise KeyError(f"The document ID {document_id!r} does not exist.") + row_id = self.index_to_row_id[document_id] + embedding = self.embeddings[row_id] + return Embedding(id=document_id, embedding=embedding) + def __contains__(self, document_id: Index) -> bool: """Check if a document exists in the store. @@ -246,6 +278,18 @@ def __contains__(self, document_id: Index) -> bool: """ return document_id in self.index_to_row_id + def __iter__(self) -> typing.Generator[Embedding, None, None]: + """Iterate over the embeddings in the store. + + Yields: + The embeddings in the store. 
+        """
+        if self.embeddings is None:
+            return
+        for document_id, row_id in self.index_to_row_id.items():
+            embedding = self.embeddings[row_id]
+            yield Embedding(id=document_id, embedding=embedding)
+
     def __len__(self) -> int:
         """Return the number of embeddings in the store.
 
@@ -255,3 +299,293 @@ def __len__(self) -> int:
         if self.embeddings is None:
             return 0
         return self.embeddings.shape[0]
+
+
+class PostgresEmbeddingStore(EmbeddingStore):
+    """An embedding store that fetches embeddings from a PostgreSQL database."""
+
+    def __init__(
+        self,
+        embedding_dim: int | None = None,
+        host: str = "localhost",
+        port: int = 5432,
+        user: str | None = "postgres",
+        password: str | None = "postgres",
+        database_name: str = "postgres",
+        table_name: str = "embeddings",
+        id_column: str = "id",
+        embedding_column: str = "embedding",
+    ) -> None:
+        """Initialise the PostgreSQL embedding store.
+
+        Args:
+            embedding_dim (optional):
+                The dimension of the embeddings. If None then the dimension will be
+                inferred when embeddings are added. Defaults to None.
+            host (optional):
+                The hostname of the PostgreSQL server. Defaults to "localhost".
+            port (optional):
+                The port of the PostgreSQL server. Defaults to 5432.
+            user (optional):
+                The username to use when connecting to the PostgreSQL server. Defaults
+                to "postgres".
+            password (optional):
+                The password to use when connecting to the PostgreSQL server. Defaults
+                to "postgres".
+            database_name (optional):
+                The name of the database to use. Defaults to "postgres".
+            table_name (optional):
+                The name of the table to use. Defaults to "embeddings".
+            id_column (optional):
+                The name of the column containing the document IDs. Defaults to "id".
+            embedding_column (optional):
+                The name of the column containing the embeddings. Defaults to
+                "embedding".
+        """
+        psycopg2_not_installed = importlib.util.find_spec("psycopg2") is None
+        if psycopg2_not_installed:
+            raise ImportError(
+                "The `postgres` extra is required to use the `PostgresEmbeddingStore`. "
+                "Please install it by running `pip install ragger[postgres]@"
+                "git+ssh://git@github.com/alexandrainst/ragger.git` and try again."
+            )
+
+        self.embedding_dim = embedding_dim
+        self.host = host
+        self.port = port
+        self.user = user
+        self.password = password
+        self.database_name = database_name
+        self.table_name = table_name
+        self.id_column = id_column
+        self.embedding_column = embedding_column
+
+        with self._connect() as conn:
+            cursor = conn.cursor()
+            try:
+                cursor.execute(f"CREATE DATABASE {database_name}")
+            except psycopg2.errors.DuplicateDatabase:
+                pass
+            try:
+                cursor.execute("CREATE EXTENSION IF NOT EXISTS vector")
+            except psycopg2.errors.UniqueViolation:
+                pass
+
+        self._create_table()
+
+    def _create_table(self) -> None:
+        """Create the table in the database."""
+        if self.embedding_dim is None:
+            return
+        with self._connect() as conn:
+            cursor = conn.cursor()
+            cursor.execute(f"""
+                CREATE TABLE IF NOT EXISTS {self.table_name} (
+                    {self.id_column} TEXT PRIMARY KEY,
+                    {self.embedding_column} VECTOR({self.embedding_dim})
+                )
+            """)
+            cursor.execute(f"""
+                CREATE INDEX IF NOT EXISTS cosine_hnsw_embedding_idx
+                ON {self.table_name}
+                USING hnsw ({self.embedding_column} vector_cosine_ops)
+            """)
+
+    @contextmanager
+    def _connect(self) -> typing.Generator[psycopg2.extensions.connection, None, None]:
+        """Connect to the PostgreSQL database.
+
+        Yields:
+            The connection to the database.
+ """ + connection = psycopg2.connect( + user=self.user, + password=self.password, + host=self.host, + port=self.port, + dbname=self.database_name, + ) + connection.autocommit = True + yield connection + connection.close() + + def add_embeddings( + self, embeddings: typing.Iterable[Embedding] + ) -> "EmbeddingStore": + """Add embeddings to the store. + + Args: + embeddings: + An iterable of embeddings to add to the store. + """ + if not embeddings: + return self + + # Ensure that we can access the embeddings multiple times + embeddings = list(embeddings) + + if self.embedding_dim is None: + self.embedding_dim = embeddings[0].embedding.shape[0] + self._create_table() + + id_embedding_pairs = [ + (embedding.id, json.dumps(embedding.embedding.tolist())) + for embedding in embeddings + ] + + with self._connect() as conn: + cursor = conn.cursor() + cursor.executemany( + f""" + INSERT INTO {self.table_name} ({self.id_column}, {self.embedding_column}) + VALUES (%s, %s) + ON CONFLICT ({self.id_column}) DO UPDATE + SET {self.embedding_column} = EXCLUDED.{self.embedding_column} + """, + id_embedding_pairs, + ) + return self + + def get_nearest_neighbours( + self, embedding: np.ndarray, num_docs: int = 5 + ) -> list[Index]: + """Get the nearest neighbours to a given embedding. + + Args: + embedding: + The embedding to find nearest neighbours for. + num_docs (optional): + The number of documents to retrieve. Defaults to 5. + + Returns: + A list of indices of the nearest neighbours. + """ + if self.embedding_dim is None: + return list() + with self._connect() as conn: + cursor = conn.cursor() + cursor.execute( + f""" + SELECT {self.id_column} + FROM {self.table_name} + ORDER BY {self.embedding_column} <=> %s + LIMIT {num_docs} + """, + (json.dumps(embedding.tolist()),), + ) + return [row[0] for row in cursor.fetchall()] + + def __getitem__(self, document_id: Index) -> Embedding: + """Fetch an embedding by its document ID. + + Args: + document_id: + The ID of the document to fetch the embedding for. + + Returns: + The embedding with the given document ID. + + Raises: + KeyError: + If the document ID does not exist in the store, or if the store is empty. + """ + if self.embedding_dim is None: + raise KeyError("The store is empty.") + + with self._connect() as conn: + cursor = conn.cursor() + cursor.execute( + f""" + SELECT {self.embedding_column} + FROM {self.table_name} + WHERE {self.id_column} = %s + """, + (document_id,), + ) + result = cursor.fetchone() + if result is None: + raise KeyError(f"The document ID {document_id!r} does not exist.") + embedding = np.asarray(json.loads(result[0])) + return Embedding(id=document_id, embedding=embedding) + + def __contains__(self, document_id: Index) -> bool: + """Check if a document exists in the store. + + Args: + document_id: + The ID of the document to check. + + Returns: + Whether the document exists in the store. + """ + if self.embedding_dim is None: + return False + with self._connect() as conn: + cursor = conn.cursor() + cursor.execute( + f""" + SELECT 1 + FROM {self.table_name} + WHERE {self.id_column} = %s + """, + (document_id,), + ) + return cursor.fetchone() is not None + + def __iter__(self) -> typing.Generator[Embedding, None, None]: + """Iterate over the embeddings in the store. + + Yields: + The embeddings in the store. 
+ """ + if self.embedding_dim is None: + return + with self._connect() as conn: + cursor = conn.cursor() + cursor.execute( + f""" + SELECT {self.id_column}, {self.embedding_column} + FROM {self.table_name} + WHERE {self.embedding_column} IS NOT NULL + """ + ) + for row in cursor.fetchall(): + embedding = np.asarray(json.loads(row[1])) + yield Embedding(id=row[0], embedding=embedding) + + def __len__(self) -> int: + """Return the number of embeddings in the store. + + Returns: + The number of embeddings in the store. + """ + if self.embedding_dim is None: + return 0 + with self._connect() as conn: + cursor = conn.cursor() + cursor.execute(f""" + SELECT COUNT(*) + FROM {self.table_name} + WHERE {self.embedding_column} IS NOT NULL + """) + result = cursor.fetchone() + assert result is not None + return result[0] + + def clear(self) -> None: + """Clear all embeddings from the store.""" + if self.embedding_dim is None: + return + with self._connect() as conn: + cursor = conn.cursor() + cursor.execute(f""" + UPDATE {self.table_name} SET {self.embedding_column} = NULL + """) + + def remove(self) -> None: + """Remove the embedding store.""" + if self.embedding_dim is None: + return + with self._connect() as conn: + cursor = conn.cursor() + cursor.execute(f"DROP TABLE {self.table_name}") diff --git a/tests/conftest.py b/tests/conftest.py index 591cf19f..513a4a05 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -70,13 +70,9 @@ def default_embedder() -> typing.Generator[Embedder, None, None]: @pytest.fixture(scope="session") -def default_embedding_store( - default_embedder, -) -> typing.Generator[EmbeddingStore, None, None]: +def default_embedding_store() -> typing.Generator[EmbeddingStore, None, None]: """An embedding store for testing.""" - embedding_store = NumpyEmbeddingStore( - embedding_dim=default_embedder.embedding_dim, path=Path("test-embeddings.zip") - ) + embedding_store = NumpyEmbeddingStore() yield embedding_store embedding_store.clear() diff --git a/tests/test_demo.py b/tests/test_demo.py index e15d9112..e1b7562e 100644 --- a/tests/test_demo.py +++ b/tests/test_demo.py @@ -27,21 +27,38 @@ def test_initialisation(demo): def test_initialisation_with_feedback(rag_system): """Test the initialisation of the demo.""" - feedback_modes: list[typing.Literal["strict-feedback", "feedback"]] = [ - "strict-feedback", - "feedback", - ] - sql = "SELECT name FROM sqlite_master WHERE type='table' AND name='feedback'" - for feedback_mode in feedback_modes: - with NamedTemporaryFile(mode="w", suffix=".db") as file: - demo = Demo( - feedback_mode=feedback_mode, - rag_system=rag_system, - feedback_db_path=Path(file.name), + with NamedTemporaryFile(mode="w", suffix=".db") as file: + demo = Demo( + feedback_mode="feedback", + rag_system=rag_system, + feedback_db_path=Path(file.name), + ) + with sqlite3.connect(database=demo.feedback_db_path) as connection: + assert ( + connection.execute(""" + SELECT name FROM sqlite_master WHERE type='table' AND name='feedback' + """).fetchone() + is not None + ) + file.close() + + +def test_initialisation_with_strict_feedback(rag_system): + """Test the initialisation of the demo.""" + with NamedTemporaryFile(mode="w", suffix=".db") as file: + demo = Demo( + feedback_mode="strict-feedback", + rag_system=rag_system, + feedback_db_path=Path(file.name), + ) + with sqlite3.connect(database=demo.feedback_db_path) as connection: + assert ( + connection.execute(""" + SELECT name FROM sqlite_master WHERE type='table' AND name='feedback' + """).fetchone() + is not None ) - 
with sqlite3.connect(database=demo.feedback_db_path) as connection:
-                assert connection.execute(sql).fetchone()
-            file.close()
+        file.close()
 
 
 def test_build(demo):
diff --git a/tests/test_embedding_store.py b/tests/test_embedding_store.py
index f7ef4987..84672d45 100644
--- a/tests/test_embedding_store.py
+++ b/tests/test_embedding_store.py
@@ -9,6 +9,21 @@
 from ragger.embedding_store import EmbeddingStore
 
 
+@pytest.fixture(scope="module")
+def embeddings(default_embedder) -> typing.Generator[list[Embedding], None, None]:
+    """Initialise a list of embeddings for testing."""
+    rng = np.random.default_rng(seed=4242)
+    yield [
+        Embedding(
+            id="an id", embedding=rng.random(size=(default_embedder.embedding_dim,))
+        ),
+        Embedding(
+            id="another id",
+            embedding=rng.random(size=(default_embedder.embedding_dim,)),
+        ),
+    ]
+
+
 @pytest.fixture(
     scope="module",
     params=[
@@ -38,35 +53,11 @@ def embedding_store(
     embedding_store.remove()
 
 
-@pytest.fixture(scope="module")
-def embeddings() -> typing.Generator[list[Embedding], None, None]:
-    """Initialise a list of documents for testing."""
-    yield [
-        Embedding(id="an id", embedding=np.ones(shape=(8,))),
-        Embedding(id="another id", embedding=np.zeros(shape=(8,))),
-    ]
-
-
 def test_initialisation(embedding_store):
     """Test that the embedding store can be initialised."""
     assert isinstance(embedding_store, EmbeddingStore)
 
 
-def test_add_embeddings(embedding_store, embeddings):
-    """Test that embeddings can be added to the embedding store."""
-    embedding_store.clear()
-    embedding_store.add_embeddings(embeddings=embeddings)
-    assert len(embedding_store.embeddings) == 2
-    assert np.array_equal(
-        embedding_store.embeddings[embedding_store.index_to_row_id["an id"]],
-        embeddings[0].embedding,
-    )
-    assert np.array_equal(
-        embedding_store.embeddings[embedding_store.index_to_row_id["another id"]],
-        embeddings[1].embedding,
-    )
-
-
 def test_get_nearest_neighbours(embedding_store, embeddings):
     """Test that the nearest neighbours to an embedding can be found."""
     embedding_store.clear()
@@ -85,4 +76,36 @@ def test_clear(embedding_store, embeddings):
     """Test that the embedding store can be cleared."""
     embedding_store.add_embeddings(embeddings=embeddings)
     embedding_store.clear()
-    assert embedding_store.embeddings.shape == (0, embedding_store.embedding_dim)
+    assert len(embedding_store) == 0
+
+
+def test_getitem(embedding_store, embeddings):
+    """Test that embeddings can be fetched from the embedding store."""
+    embedding_store.clear()
+    embedding_store.add_embeddings(embeddings=embeddings)
+    for embedding in embeddings:
+        assert embedding_store[embedding.id] == embedding
+
+
+def test_getitem_missing(embedding_store, embeddings, non_existing_id):
+    """Test that fetching a missing embedding raises a KeyError."""
+    embedding_store.clear()
+    embedding_store.add_embeddings(embeddings=embeddings)
+    with pytest.raises(KeyError):
+        embedding_store[non_existing_id]
+
+
+def test_contains(embeddings, embedding_store, non_existing_id):
+    """Test that the embedding store can check if it contains an embedding."""
+    embedding_store.clear()
+    embedding_store.add_embeddings(embeddings=embeddings)
+    for embedding in embeddings:
+        assert embedding.id in embedding_store
+    assert non_existing_id not in embedding_store
+
+
+def test_len(embedding_store, embeddings):
+    """Test that the embedding store can return the number of embeddings."""
+    embedding_store.clear()
+    embedding_store.add_embeddings(embeddings=embeddings)
+    assert len(embedding_store) == len(embeddings)
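
Review note (not part of the patch): below is a minimal usage sketch of the new `PostgresEmbeddingStore`, for anyone who wants to try the feature locally. It assumes a running PostgreSQL server with the `pgvector` extension installed (mirroring the CI setup above) and the default `postgres`/`postgres` credentials; the class names, method signatures and the `Embedding` model all come from the patch itself.

```python
"""Minimal usage sketch for the PostgresEmbeddingStore added in this patch."""

import numpy as np

from ragger.data_models import Embedding
from ragger.embedding_store import PostgresEmbeddingStore

# Connects to the default "postgres" database and "embeddings" table. The embedding
# dimension is inferred from the first batch of embeddings that is added, at which
# point the table and its HNSW cosine index are created.
store = PostgresEmbeddingStore(host="localhost", port=5432)

rng = np.random.default_rng(seed=4242)
store.add_embeddings(
    embeddings=[
        Embedding(id="an id", embedding=rng.random(size=(8,))),
        Embedding(id="another id", embedding=rng.random(size=(8,))),
    ]
)

# The dunder methods work the same way as on the NumPy-backed store.
assert "an id" in store and len(store) == 2

# Nearest neighbours are ordered by cosine distance via pgvector's `<=>` operator.
query = rng.random(size=(8,))
print(store.get_nearest_neighbours(embedding=query, num_docs=2))

# Clean up: `clear` nulls out the stored embeddings, `remove` drops the table.
store.clear()
store.remove()
```

Since `_create_table` builds an HNSW index with `vector_cosine_ops`, the `<=>` ordering in `get_nearest_neighbours` is served by approximate cosine-distance search once the planner uses that index.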