Feat/embedder #9

Merged · 20 commits · Apr 23, 2024
3 changes: 2 additions & 1 deletion config/config.yaml
@@ -15,11 +15,12 @@ document_store_filename: document_store.jsonl

# EmbeddingStore parameters
embedding_store_type: numpy

# Embedder parameters
embedder_type: e5
document_text_field: text
embedder_id: intfloat/multilingual-e5-large
num_documents_to_retrieve: 3

# Generator parameters
generator_type: openai
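For orientation, here is a minimal sketch of how these settings can be read with OmegaConf; loading the file directly rather than through a Hydra entry point is an assumption made for illustration:

from omegaconf import OmegaConf

# Load the YAML shown above; the project may instead wire this up via Hydra.
config = OmegaConf.load("config/config.yaml")
print(config.embedder_id)                # intfloat/multilingual-e5-large
print(config.num_documents_to_retrieve)  # 3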
97 changes: 96 additions & 1 deletion src/ragger/embedder.py
@@ -1,12 +1,21 @@
"""Embed documents using a pre-trained model."""

import logging
import os
import re
from abc import ABC, abstractmethod

import numpy as np
import torch
from omegaconf import DictConfig
from sentence_transformers import SentenceTransformer

from .utils import Document

os.environ["TOKENIZERS_PARALLELISM"] = "false"

logger = logging.getLogger(__name__)


class Embedder(ABC):
"""An abstract embedder, which embeds documents using a pre-trained model."""
@@ -37,4 +46,90 @@ def embed_documents(self, documents: list[Document]) -> np.ndarray:
class E5Embedder(Embedder):
"""An embedder that uses an E5 model to embed documents."""

    def __init__(self, config: DictConfig) -> None:
        """Initialise the E5 embedder.

        Args:
            config:
                The Hydra configuration.
        """
        super().__init__(config)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.embedder = SentenceTransformer(self.config.embedder_id, device=self.device)
def embed_documents(self, documents: list[Document]) -> np.ndarray:
"""Embed a list of documents using an E5 model.

Args:
documents:
A list of documents to embed.

Returns:
An array of embeddings, where each row corresponds to a document.
"""
# Prepare the texts for embedding
texts = [document.text for document in documents]
prepared_texts = self._prepare_texts_for_embedding(texts=texts)

# Embed the texts
        embeddings = self.embedder.encode(
            sentences=prepared_texts,
            normalize_embeddings=True,
            convert_to_numpy=True,
            show_progress_bar=False,
            device=self.device,
        )
assert isinstance(embeddings, np.ndarray)
return embeddings

def embed_query(self, query: str) -> np.ndarray:
"""Embed a query.

Args:
query:
A query.

Returns:
The embedding of the query.
"""
prepared_query = self._prepare_query_for_embedding(query=query)
query_embedding = self.embedder.encode(
sentences=[prepared_query],
normalize_embeddings=True,
convert_to_numpy=True,
show_progress_bar=False,
device=self.device,
)[0]
return query_embedding

    def _prepare_texts_for_embedding(self, texts: list[str]) -> list[str]:
        """Prepare texts for embedding.

        The precise preparation depends on the embedding model and use case.

        Args:
            texts:
                The texts to prepare.

        Returns:
            The prepared texts.
        """
        return texts

    def _prepare_query_for_embedding(self, query: str) -> str:
        """Prepare a query for embedding.

        The precise preparation depends on the embedding model.

        Args:
            query:
                A query.

        Returns:
            A prepared query.
        """
        # Add a question mark at the end of the query, if not already present
        query = re.sub(r"[。\?]$", "?", query).strip()
        if not query.endswith("?"):
            query += "?"

        return query
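To make the new API concrete, a minimal usage sketch follows, reusing the model id and `Document` fields from the test fixtures below; the 1024-dimensional output is an assumption based on the hidden size of multilingual-e5-large:

from omegaconf import DictConfig

from ragger.embedder import E5Embedder
from ragger.utils import Document

config = DictConfig(dict(embedder_id="intfloat/multilingual-e5-large"))
embedder = E5Embedder(config=config)

documents = [
    Document(id="1", text="Hello, world!"),
    Document(id="2", text="Goodbye, world!"),
]
doc_embeddings = embedder.embed_documents(documents)  # shape (2, 1024)

# `embed_query` appends a trailing question mark before encoding
query_embedding = embedder.embed_query("Is this a greeting")

# The embeddings are L2-normalised, so a dot product is a cosine similarity
scores = doc_embeddings @ query_embedding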
75 changes: 74 additions & 1 deletion src/ragger/embedding_store.py
@@ -1,9 +1,11 @@
"""Store and fetch embeddings from a database."""

from abc import ABC, abstractmethod
from pathlib import Path

import numpy as np
from omegaconf import DictConfig
from transformers import AutoConfig

from .utils import Index

@@ -47,4 +49,75 @@ def get_nearest_neighbours(self, embedding: np.ndarray) -> list[Index]:
class NumpyEmbeddingStore(EmbeddingStore):
"""An embedding store that fetches embeddings from a NumPy file."""

def __init__(self, config: DictConfig) -> None:
"""Initialise the NumPy embedding store.

Args:
config:
The Hydra configuration.
"""
super().__init__(config)
self.embedding_dim = self._get_embedding_dimension()
self.embeddings = np.array([]).reshape(0, self.embedding_dim)

def _get_embedding_dimension(self) -> int:
"""This returns the embedding dimension for the embedding model.

Returns:
The embedding dimension.
"""
model_config = AutoConfig.from_pretrained(self.config.embedder_id)
return model_config.hidden_size

def add_embeddings(self, embeddings: list[np.ndarray]) -> None:
"""Add embeddings to the store.

Args:
embeddings:
A list of embeddings to add to the store.
"""
self.embeddings = np.vstack([self.embeddings, np.array(embeddings)])

def reset(self) -> None:
"""This resets the embeddings store."""
self.embeddings = np.array([]).reshape(0, self.embedding_dim)

    def save(self, path: Path | str) -> None:
        """Save the embeddings store to disk.

        This stores the embeddings as a NumPy `.npy` file at the given path.

        Args:
            path:
                The path to save the embeddings to.
        """
        path = Path(path)
        np.save(file=path, arr=self.embeddings)

    def load(self, path: Path | str) -> None:
        """Load the embeddings store from disk.

        Args:
            path:
                The path to the `.npy` file to load the embeddings store from.
        """
        path = Path(path)
        embeddings = np.load(file=path, allow_pickle=False)
        assert self.embedding_dim == embeddings.shape[1]
        self.embeddings = embeddings

def get_nearest_neighbours(self, embedding: np.ndarray) -> list[Index]:
"""Get the nearest neighbours to a given embedding.

Args:
embedding:
The embedding to find nearest neighbours for.

Returns:
A list of indices of the nearest neighbours.
"""
        # Get the top-k documents
        num_documents = self.config.num_documents_to_retrieve
        scores = self.embeddings @ embedding
        top_indices = np.argsort(scores)[::-1][:num_documents]
        # Convert to a plain list to match the declared `list[Index]` return type
        return top_indices.tolist()
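Putting the two classes together gives the retrieval loop this PR enables. A sketch mirroring the unit tests below, where the config carries the keys read by both classes:

from omegaconf import DictConfig

from ragger.embedder import E5Embedder
from ragger.embedding_store import NumpyEmbeddingStore
from ragger.utils import Document

config = DictConfig(
    dict(
        embedder_id="intfloat/multilingual-e5-large",
        num_documents_to_retrieve=2,
    )
)
embedder = E5Embedder(config=config)
store = NumpyEmbeddingStore(config=config)

documents = [
    Document(id="1", text="Hello, world!"),
    Document(id="2", text="Goodbye, world!"),
]
store.add_embeddings(list(embedder.embed_documents(documents)))

# Positions of the most similar documents, best match first
neighbours = store.get_nearest_neighbours(embedder.embed_query("Hello"))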
64 changes: 64 additions & 0 deletions tests/test_embedder.py
@@ -1 +1,65 @@
"""Unit tests for the `embedder` module."""

from typing import Generator

import numpy as np
import pytest
from omegaconf import DictConfig
from ragger.embedder import E5Embedder, Embedder
from ragger.utils import Document


class TestE5Embedder:
"""Tests for the `Embedder` class."""

@pytest.fixture(scope="class")
def embedder(self) -> Generator[E5Embedder, None, None]:
"""Initialise an Embedder for testing."""
config = DictConfig(dict(embedder_id="intfloat/multilingual-e5-large"))
embedder = E5Embedder(config=config)
yield embedder

@pytest.fixture(scope="class")
def documents(self) -> list[Document]:
"""Initialise a list of documents for testing."""
return [
Document(id="1", text="Hello, world!"),
Document(id="2", text="Goodbye, world!"),
]

@pytest.fixture(scope="class")
def query(self) -> str:
"""Initialise a query for testing."""
return "Hello, world!"

    def test_is_embedder(self):
        """Test that the E5Embedder is an Embedder subclass."""
        assert issubclass(E5Embedder, Embedder)

def test_initialisation(self, embedder):
"""Test that the Embedder can be initialised."""
assert embedder

def test_embed(self, embedder, documents):
"""Test that the Embedder can embed text."""
embeddings = embedder.embed_documents(documents)
assert isinstance(embeddings, np.ndarray)
assert embeddings.shape[0] == len(documents)

def test_embed_query(self, embedder, query):
"""Test that the Embedder can embed a query."""
embeddings = embedder.embed_query(query)
assert isinstance(embeddings, np.ndarray)

def test_prepare_query(self, embedder, query):
"""Test that the Embedder can prepare a query."""
prepared_query = embedder._prepare_query_for_embedding(query)
assert isinstance(prepared_query, str)
assert "?" in prepared_query

def test_prepare_texts(self, embedder, documents):
"""Test that the Embedder can prepare texts for embedding."""
texts = [document.text for document in documents]
prepared_texts = embedder._prepare_texts_for_embedding(texts)
assert isinstance(prepared_texts, list)
assert len(prepared_texts) == len(texts)
75 changes: 75 additions & 0 deletions tests/test_embedding_store.py
@@ -1 +1,76 @@
"""Unit tests for the `embedding_store` module."""

from tempfile import NamedTemporaryFile
from typing import Generator

import numpy as np
import pytest
from omegaconf import DictConfig
from ragger.embedding_store import EmbeddingStore, NumpyEmbeddingStore


class TestNumpyEmbeddingStore:
"""Tests for the `NumpyEmbeddingStore` class."""

@pytest.fixture(scope="class")
def embedding_store(self) -> Generator[NumpyEmbeddingStore, None, None]:
"""Initialise a NumpyEmbeddingStore for testing."""
config = DictConfig(
dict(
num_documents_to_retrieve=2,
embedder_id="intfloat/multilingual-e5-large",
)
)
store = NumpyEmbeddingStore(config=config)
yield store

@pytest.fixture(scope="class")
    def embeddings(self, embedding_store) -> list[np.ndarray]:
        """Initialise a list of embeddings for testing."""
return [
np.ones(shape=(embedding_store.embedding_dim,)),
np.zeros(shape=(embedding_store.embedding_dim,)),
]

    def test_is_embedding_store(self):
        """Test that the NumpyEmbeddingStore is an EmbeddingStore subclass."""
        assert issubclass(NumpyEmbeddingStore, EmbeddingStore)

def test_initialisation(self, embedding_store):
"""Test that the NumpyEmbeddingStore can be initialised."""
assert embedding_store

def test_add_embeddings(self, embedding_store, embeddings):
"""Test that embeddings can be added to the NumpyEmbeddingStore."""
embedding_store.add_embeddings(embeddings)
assert len(embedding_store.embeddings) == 2
assert np.array_equal(embedding_store.embeddings[0], embeddings[0])
assert np.array_equal(embedding_store.embeddings[1], embeddings[1])
embedding_store.reset()

def test_get_nearest_neighbours(self, embedding_store, embeddings):
"""Test that the nearest neighbours to an embedding can be found."""
embedding_store.add_embeddings(embeddings)
neighbours = embedding_store.get_nearest_neighbours(embeddings[0])
assert np.array_equal(np.array(neighbours), np.array([0, 1]))
neighbours = embedding_store.get_nearest_neighbours(embeddings[1])
assert np.array_equal(np.array(neighbours), np.array([1, 0]))
embedding_store.reset()

def test_reset(self, embedding_store, embeddings):
"""Test that the NumpyEmbeddingStore can be reset."""
embedding_store.add_embeddings(embeddings)
embedding_store.reset()
assert embedding_store.embeddings.shape == (0, embedding_store.embedding_dim)

def test_save_load(self, embedding_store, embeddings):
"""Test that the NumpyEmbeddingStore can be saved."""
embedding_store.add_embeddings(embeddings)
new_store = NumpyEmbeddingStore(embedding_store.config)
with NamedTemporaryFile(suffix=".npy") as file:
embedding_store.save(file.name)
new_store.load(file.name)
assert np.array_equal(new_store.embeddings, embedding_store.embeddings)
assert new_store.embedding_dim == embedding_store.embedding_dim
embedding_store.reset()