From 501063bf723f63db81a5fb34da7e90f05f16c2ab Mon Sep 17 00:00:00 2001 From: Manolo Santos Date: Thu, 12 Dec 2024 22:33:55 +0100 Subject: [PATCH] fix: Rerank test --- tests/test_rerank.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/tests/test_rerank.py b/tests/test_rerank.py index 2ec359a..6a14afa 100644 --- a/tests/test_rerank.py +++ b/tests/test_rerank.py @@ -1,6 +1,7 @@ """Test RAGLite's reranking functionality.""" import random +from functools import partial from typing import TypeVar import pytest @@ -8,7 +9,14 @@ from rerankers.models.ranker import BaseRanker from scipy.stats import kendalltau -from raglite import RAGLiteConfig, rerank_chunks, retrieve_chunks, vector_search +from raglite import ( + RAGLiteConfig, + hybrid_search, + keyword_search, + rerank_chunks, + retrieve_chunks, + vector_search, +) from raglite._database import Chunk T = TypeVar("T") @@ -54,7 +62,15 @@ def test_reranker( ) # Search for a query. query = "What does it mean for two events to be simultaneous?" - chunk_ids, _ = vector_search(query, config=raglite_test_config, max_chunks=20) + chunk_ids, _ = hybrid_search( + query, + config=raglite_test_config, + subsearches=[ + partial(vector_search, max_chunks=40, config=raglite_test_config), + partial(keyword_search, max_chunks=40, config=raglite_test_config), + ], + max_chunks=20, + ) # Retrieve the chunks. chunks = retrieve_chunks(chunk_ids, config=raglite_test_config) assert all(isinstance(chunk, Chunk) for chunk in chunks)