Skip to content

Commit

Permalink
Feat/add postgres embedding store (#53)
Browse files Browse the repository at this point in the history
* feat: Add TxtDocumentStore

* feat: Add PostgresEmbeddingStore

* docs: Add `postgres` extra to readme

* tests: Split demo feedback test into two

* tests: Init default embedding store with no args

* feat: Create index on embedding column in PostgresEmbeddingStore

* chore: Install pgvector extension in CI

* fix: Set embedding_dim in PostgresEmbeddingStore

* feat: Add more dunder methods to EmbeddingStore

* tests: Use random embeddings

* fix: Pgvector fixes

* docs: Update cov badge

* chore: Change apt package name for pgvector and postgres

* chore: Set pgvector apt package

* chore: apt package

* chore: apt package

* chore: Yes to all

* docs: Add pgvector installation link
  • Loading branch information
saattrupdan authored Aug 13, 2024
1 parent 32bbc09 commit 735c346
Show file tree
Hide file tree
Showing 8 changed files with 505 additions and 70 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ jobs:
- name: Setup PostgreSQL server
run: |
sudo apt-get update
sudo apt-get install -y postgresql
sudo apt-get install -y postgresql-common
yes '' | sudo /usr/share/postgresql-common/pgdg/apt.postgresql.org.sh
sudo apt-get install -y postgresql-16 postgresql-16-pgvector
sudo service postgresql start
sudo -u postgres psql -c "ALTER USER postgres PASSWORD 'postgres';"
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Added a `PostgresDocumentStore` that uses a PostgreSQL database to store documents.
- Added a `TxtDocumentStore` that reads documents from a single text file, separated by
newlines.
- Added a `PostgresEmbeddingStore` that uses a PostgreSQL database to store embeddings,
using the `pgvector` extension.

### Changed
- Added defaults to all arguments in each component's constructor, so that the
Expand Down
22 changes: 17 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
A package for general-purpose RAG applications.

______________________________________________________________________
[![Code Coverage](https://img.shields.io/badge/Coverage-72%25-yellow.svg)](https://github.com/alexandrainst/ragger/tree/main/tests)
[![Code Coverage](https://img.shields.io/badge/Coverage-74%25-yellow.svg)](https://github.com/alexandrainst/ragger/tree/main/tests)


Developer(s):
Expand All @@ -35,10 +35,17 @@ Installation with `poetry`:
poetry add git+ssh://[email protected]/alexandrainst/ragger.git --extras all
```

You can replace the `all` extra with any combination of `vllm`, `openai` and `demo` to
install only the components you need. For `pip`, this is done by comma-separating the
extras (e.g., `ragger[vllm,demo]`), while for `poetry`, you add multiple `--extras`
flags (e.g., `--extras vllm --extras demo`).
You can replace the `all` extra with any combination of the following, to install only
the components you need:

- `postgres`
- `vllm`
- `openai`
- `demo`

For `pip`, this is done by comma-separating the extras (e.g., `ragger[vllm,demo]`),
while for `poetry`, you add multiple `--extras` flags (e.g., `--extras vllm --extras
demo`).


## Quick Start
Expand Down Expand Up @@ -101,6 +108,11 @@ imported from `ragger.embedding_store`.

- `NumpyEmbeddingStore`: An embedding store that stores embeddings in a NumPy array.
(default)
- `PostgresEmbeddingStore`: An embedding store that uses a PostgreSQL database to store
embeddings, using the `pgvector` extension. This assumes that the PostgreSQL server is
already running, and that the `pgvector` extension is installed. See
[here](https://github.com/pgvector/pgvector?tab=readme-ov-file#installation) for more
information on how to install the extension.


### Generators
Expand Down
81 changes: 65 additions & 16 deletions src/ragger/data_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import annotated_types
import numpy as np
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict

Index = str

Expand All @@ -17,6 +17,20 @@ class Document(BaseModel):
id: Index
text: str

def __eq__(self, other: object) -> bool:
"""Check if two documents are equal.
Args:
other:
The object to compare to.
Returns:
Whether the two documents are equal.
"""
if not isinstance(other, Document):
return False
return self.id == other.id and self.text == other.text


class Embedding(BaseModel):
"""An embedding of a document."""
Expand All @@ -26,14 +40,25 @@ class Embedding(BaseModel):

model_config = ConfigDict(arbitrary_types_allowed=True)

def __eq__(self, other: object) -> bool:
"""Check if two embeddings are equal.
Args:
other:
The object to compare to.
Returns:
Whether the two embeddings are equal.
"""
if not isinstance(other, Embedding):
return False
return self.id == other.id and np.allclose(self.embedding, other.embedding)


class GeneratedAnswer(BaseModel):
"""A generated answer to a question."""

sources: typing.Annotated[
list[typing.Annotated[Index, annotated_types.Len(min_length=1)]],
Field(max_length=5),
]
sources: list[typing.Annotated[Index, annotated_types.Len(min_length=1)]]
answer: str = ""


Expand Down Expand Up @@ -216,7 +241,9 @@ def compile(
pass

@abstractmethod
def add_embeddings(self, embeddings: typing.Iterable[Embedding]) -> None:
def add_embeddings(
self, embeddings: typing.Iterable[Embedding]
) -> "EmbeddingStore":
"""Add embeddings to the store.
Args:
Expand All @@ -238,6 +265,29 @@ def get_nearest_neighbours(self, embedding: np.ndarray) -> list[Index]:
"""
...

@abstractmethod
def clear(self) -> None:
"""Clear all embeddings from the store."""
...

@abstractmethod
def remove(self) -> None:
"""Remove the embedding store."""
...

@abstractmethod
def __getitem__(self, document_id: Index) -> Embedding:
"""Fetch an embedding by its document ID.
Args:
document_id:
The ID of the document to fetch the embedding for.
Returns:
The embedding with the given document ID.
"""
...

@abstractmethod
def __contains__(self, document_id: Index) -> bool:
"""Check if a document exists in the store.
Expand All @@ -252,22 +302,21 @@ def __contains__(self, document_id: Index) -> bool:
...

@abstractmethod
def __len__(self) -> int:
"""Return the number of embeddings in the store.
def __iter__(self) -> typing.Generator[Embedding, None, None]:
"""Iterate over the embeddings in the store.
Returns:
The number of embeddings in the store.
Yields:
The embeddings in the store.
"""
...

@abstractmethod
def clear(self) -> None:
"""Clear all embeddings from the store."""
...
def __len__(self) -> int:
"""Return the number of embeddings in the store.
@abstractmethod
def remove(self) -> None:
"""Remove the embedding store."""
Returns:
The number of embeddings in the store.
"""
...

def __repr__(self) -> str:
Expand Down
Loading

0 comments on commit 735c346

Please sign in to comment.