Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: patch rerankers flashrank issue #22

Merged
merged 2 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ RAGLite is a Python package for Retrieval-Augmented Generation (RAG) with Postgr

## Installing

First, begin by installing SpaCy's multilingual sentence model:
First, begin by installing spaCy's multilingual sentence model:

```sh
# Install SpaCy's xx_sent_ud_sm:
# Install spaCy's xx_sent_ud_sm:
pip install https://github.com/explosion/spacy-models/releases/download/xx_sent_ud_sm-3.7.0/xx_sent_ud_sm-3.7.0-py3-none-any.whl
```

Expand Down
3 changes: 2 additions & 1 deletion src/raglite/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
from llama_cpp import llama_supports_gpu_offload
from sqlalchemy.engine import URL

from raglite._flashrank import PatchedFlashRankRanker as FlashRankRanker

# Suppress rerankers output on import until [1] is fixed.
# [1] https://github.com/AnswerDotAI/rerankers/issues/36
with contextlib.redirect_stdout(StringIO()):
from rerankers.models.flashrank_ranker import FlashRankRanker
from rerankers.models.ranker import BaseRanker


Expand Down
41 changes: 41 additions & 0 deletions src/raglite/_flashrank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Patched version of FlashRankRanker that fixes incorrect reranking [1].

[1] https://github.com/AnswerDotAI/rerankers/issues/39
"""

import contextlib
from io import StringIO
from typing import Any

from flashrank import RerankRequest

# Suppress rerankers output on import until [1] is fixed.
# [1] https://github.com/AnswerDotAI/rerankers/issues/36
with contextlib.redirect_stdout(StringIO()):
from rerankers.documents import Document
from rerankers.models.flashrank_ranker import FlashRankRanker
from rerankers.results import RankedResults, Result
from rerankers.utils import prep_docs


class PatchedFlashRankRanker(FlashRankRanker):
def rank(
self,
query: str,
docs: str | list[str] | Document | list[Document],
doc_ids: list[str] | list[int] | None = None,
metadata: list[dict[str, Any]] | None = None,
) -> RankedResults:
docs = prep_docs(docs, doc_ids, metadata)
passages = [{"id": doc_idx, "text": doc.text} for doc_idx, doc in enumerate(docs)]
rerank_request = RerankRequest(query=query, passages=passages)
flashrank_results = self.model.rerank(rerank_request)
ranked_results = [
Result(
document=docs[result["id"]], # This patches the incorrect ranking in the original.
score=result["score"],
rank=idx + 1,
)
for idx, result in enumerate(flashrank_results)
]
return RankedResults(results=ranked_results, query=query, has_scores=True)
34 changes: 27 additions & 7 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@

import os
import socket
import tempfile
from collections.abc import Generator
from pathlib import Path

import pytest
from sqlalchemy import create_engine, text

from raglite import RAGLiteConfig
from raglite import RAGLiteConfig, insert_document

POSTGRES_URL = "postgresql+pg8000://raglite_user:raglite_password@postgres:5432/postgres"

Expand All @@ -26,7 +29,7 @@ def is_openai_available() -> bool:


def pytest_sessionstart(session: pytest.Session) -> None:
"""Reset the PostgreSQL database."""
"""Reset the PostgreSQL and SQLite databases."""
if is_postgres_running():
engine = create_engine(POSTGRES_URL, isolation_level="AUTOCOMMIT")
with engine.connect() as conn:
Expand All @@ -35,9 +38,18 @@ def pytest_sessionstart(session: pytest.Session) -> None:
conn.execute(text(f"CREATE DATABASE raglite_test_{variant}"))


@pytest.fixture(scope="session")
def sqlite_url() -> Generator[str, None, None]:
"""Create a temporary SQLite database file and return the database URL."""
with tempfile.TemporaryDirectory() as temp_dir:
db_file = Path(temp_dir) / "raglite_test.sqlite"
yield f"sqlite:///{db_file}"


@pytest.fixture(
scope="session",
params=[
pytest.param("sqlite:///:memory:", id="sqlite"),
pytest.param("sqlite", id="sqlite"),
pytest.param(
POSTGRES_URL,
id="postgres",
Expand All @@ -47,11 +59,14 @@ def pytest_sessionstart(session: pytest.Session) -> None:
)
def database(request: pytest.FixtureRequest) -> str:
"""Get a database URL to test RAGLite with."""
db_url: str = request.param
db_url: str = (
request.getfixturevalue("sqlite_url") if request.param == "sqlite" else request.param
)
return db_url


@pytest.fixture(
scope="session",
params=[
pytest.param(
"llama-cpp-python/lm-kit/bge-m3-gguf/*Q4_K_M.gguf",
Expand All @@ -70,13 +85,18 @@ def embedder(request: pytest.FixtureRequest) -> str:
return embedder


@pytest.fixture
@pytest.fixture(scope="session")
def raglite_test_config(database: str, embedder: str) -> RAGLiteConfig:
"""Create a lightweight in-memory config for testing SQLite and PostgreSQL."""
# Select the PostgreSQL database based on the embedder.
# Select the database based on the embedder.
variant = "local" if embedder.startswith("llama-cpp-python") else "remote"
if "postgres" in database:
variant = "local" if embedder.startswith("llama-cpp-python") else "remote"
database = database.replace("/postgres", f"/raglite_test_{variant}")
elif "sqlite" in database:
database = database.replace(".sqlite", f"_{variant}.sqlite")
# Create a RAGLite config for the given database and embedder.
db_config = RAGLiteConfig(db_url=database, embedder=embedder)
# Insert a document and update the index.
doc_path = Path(__file__).parent / "specrel.pdf" # Einstein's special relativity paper.
insert_document(doc_path, config=db_config)
return db_config
51 changes: 22 additions & 29 deletions tests/test_rag.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,40 @@
"""Test RAGLite's RAG functionality."""

import os
from pathlib import Path
from typing import TYPE_CHECKING

import pytest
from llama_cpp import llama_supports_gpu_offload

from raglite import RAGLiteConfig, hybrid_search, insert_document, rag, retrieve_segments
from raglite import RAGLiteConfig, hybrid_search, rag, retrieve_chunks

if TYPE_CHECKING:
from raglite._database import Chunk
from raglite._typing import SearchMethod


def is_accelerator_available() -> bool:
"""Check if an accelerator is available."""
return llama_supports_gpu_offload() or (os.cpu_count() or 1) >= 8 # noqa: PLR2004


def test_insert_index_search(raglite_test_config: RAGLiteConfig) -> None:
"""Test inserting a document, updating the indexes, and searching for a query."""
# Insert a document and update the index.
doc_path = Path(__file__).parent / "specrel.pdf" # Einstein's special relativity paper.
insert_document(doc_path, config=raglite_test_config)
# Search for a query.
query = "What does it mean for two events to be simultaneous?"
chunk_ids, scores = hybrid_search(query, config=raglite_test_config)
assert len(chunk_ids) == len(scores)
assert all(isinstance(chunk_id, str) for chunk_id in chunk_ids)
assert all(isinstance(score, float) for score in scores)
# Group the chunks into segments and retrieve them.
segments = retrieve_segments(chunk_ids, neighbors=None, config=raglite_test_config)
assert all(isinstance(segment, str) for segment in segments)
assert "Definition of Simultaneity" in "".join(segments[:2])


@pytest.mark.skipif(not is_accelerator_available(), reason="No accelerator available")
def test_rag(raglite_test_config: RAGLiteConfig) -> None:
"""Test Retrieval-Augmented Generation."""
# Insert a document and update the index.
doc_path = Path(__file__).parent / "specrel.pdf" # Einstein's special relativity paper.
insert_document(doc_path, config=raglite_test_config)
# Answer a question with RAG.
# Assemble different types of search inputs for RAG.
prompt = "What does it mean for two events to be simultaneous?"
stream = rag(prompt, search=hybrid_search, config=raglite_test_config)
answer = ""
for update in stream:
assert isinstance(update, str)
answer += update
assert "simultaneous" in answer.lower()
search_inputs: list[SearchMethod | list[str] | list[Chunk]] = [
hybrid_search, # A search method as input.
hybrid_search(prompt, config=raglite_test_config)[0], # Chunk ids as input.
retrieve_chunks( # Chunks as input.
hybrid_search(prompt, config=raglite_test_config)[0], config=raglite_test_config
),
]
# Answer a question with RAG.
for search_input in search_inputs:
stream = rag(prompt, search=search_input, config=raglite_test_config)
answer = ""
for update in stream:
assert isinstance(update, str)
answer += update
assert "simultaneous" in answer.lower()
54 changes: 54 additions & 0 deletions tests/test_rerank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""Test RAGLite's reranking functionality."""

import pytest
from rerankers.models.ranker import BaseRanker

from raglite import RAGLiteConfig, hybrid_search, rerank, retrieve_chunks
from raglite._database import Chunk
from raglite._flashrank import PatchedFlashRankRanker as FlashRankRanker


@pytest.fixture(
params=[
pytest.param(None, id="no_reranker"),
pytest.param(FlashRankRanker("ms-marco-MiniLM-L-12-v2", verbose=0), id="flashrank_english"),
pytest.param(
(
("en", FlashRankRanker("ms-marco-MiniLM-L-12-v2", verbose=0)),
("other", FlashRankRanker("ms-marco-MultiBERT-L-12", verbose=0)),
),
id="flashrank_multilingual",
),
],
)
def reranker(
request: pytest.FixtureRequest,
) -> BaseRanker | tuple[tuple[str, BaseRanker], ...] | None:
"""Get a reranker to test RAGLite with."""
reranker: BaseRanker | tuple[tuple[str, BaseRanker], ...] | None = request.param
return reranker


def test_reranker(
raglite_test_config: RAGLiteConfig,
reranker: BaseRanker | tuple[tuple[str, BaseRanker], ...] | None,
) -> None:
"""Test inserting a document, updating the indexes, and searching for a query."""
# Update the config with the reranker.
raglite_test_config = RAGLiteConfig(
db_url=raglite_test_config.db_url, embedder=raglite_test_config.embedder, reranker=reranker
)
# Search for a query.
query = "What does it mean for two events to be simultaneous?"
chunk_ids, _ = hybrid_search(query, num_results=3, config=raglite_test_config)
# Retrieve the chunks.
chunks = retrieve_chunks(chunk_ids, config=raglite_test_config)
assert all(isinstance(chunk, Chunk) for chunk in chunks)
# Rerank the chunks given an inverted chunk order.
reranked_chunks = rerank(query, chunks[::-1], config=raglite_test_config)
if reranker is not None and "text-embedding-3-small" not in raglite_test_config.embedder:
assert reranked_chunks[:3] == chunks[:3]
# Test that we can also rerank given the chunk_ids only.
reranked_chunks = rerank(query, chunk_ids[::-1], config=raglite_test_config)
if reranker is not None and "text-embedding-3-small" not in raglite_test_config.embedder:
assert reranked_chunks[:3] == chunks[:3]
47 changes: 47 additions & 0 deletions tests/test_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Test RAGLite's search functionality."""

import pytest

from raglite import (
RAGLiteConfig,
hybrid_search,
keyword_search,
retrieve_chunks,
retrieve_segments,
vector_search,
)
from raglite._database import Chunk
from raglite._typing import SearchMethod


@pytest.fixture(
params=[
pytest.param(keyword_search, id="keyword_search"),
pytest.param(vector_search, id="vector_search"),
pytest.param(hybrid_search, id="hybrid_search"),
],
)
def search_method(
request: pytest.FixtureRequest,
) -> SearchMethod:
"""Get a search method to test RAGLite with."""
search_method: SearchMethod = request.param
return search_method


def test_search(raglite_test_config: RAGLiteConfig, search_method: SearchMethod) -> None:
"""Test searching for a query."""
# Search for a query.
query = "What does it mean for two events to be simultaneous?"
num_results = 5
chunk_ids, scores = search_method(query, num_results=num_results, config=raglite_test_config)
assert len(chunk_ids) == len(scores) == num_results
assert all(isinstance(chunk_id, str) for chunk_id in chunk_ids)
assert all(isinstance(score, float) for score in scores)
# Retrieve the chunks.
chunks = retrieve_chunks(chunk_ids, config=raglite_test_config)
assert all(isinstance(chunk, Chunk) for chunk in chunks)
assert any("Definition of Simultaneity" in str(chunk) for chunk in chunks)
# Extend the chunks with their neighbours and group them into contiguous segments.
segments = retrieve_segments(chunk_ids, neighbors=(-1, 1), config=raglite_test_config)
assert all(isinstance(segment, str) for segment in segments)
Loading