diff --git a/tests/test_rerank.py b/tests/test_rerank.py index 2ec359a..6a14afa 100644 --- a/tests/test_rerank.py +++ b/tests/test_rerank.py @@ -1,6 +1,7 @@ """Test RAGLite's reranking functionality.""" import random +from functools import partial from typing import TypeVar import pytest @@ -8,7 +9,14 @@ from rerankers.models.ranker import BaseRanker from scipy.stats import kendalltau -from raglite import RAGLiteConfig, rerank_chunks, retrieve_chunks, vector_search +from raglite import ( + RAGLiteConfig, + hybrid_search, + keyword_search, + rerank_chunks, + retrieve_chunks, + vector_search, +) from raglite._database import Chunk T = TypeVar("T") @@ -54,7 +62,15 @@ def test_reranker( ) # Search for a query. query = "What does it mean for two events to be simultaneous?" - chunk_ids, _ = vector_search(query, config=raglite_test_config, max_chunks=20) + chunk_ids, _ = hybrid_search( + query, + config=raglite_test_config, + subsearches=[ + partial(vector_search, max_chunks=40, config=raglite_test_config), + partial(keyword_search, max_chunks=40, config=raglite_test_config), + ], + max_chunks=20, + ) # Retrieve the chunks. chunks = retrieve_chunks(chunk_ids, config=raglite_test_config) assert all(isinstance(chunk, Chunk) for chunk in chunks)