diff --git a/redis/commands/search/aggregation.py b/redis/commands/search/aggregation.py index 42c3547b0b..5638f1d662 100644 --- a/redis/commands/search/aggregation.py +++ b/redis/commands/search/aggregation.py @@ -112,6 +112,7 @@ def __init__(self, query: str = "*") -> None: self._cursor = [] self._dialect = None self._add_scores = False + self._scorer = "TFIDF" def load(self, *fields: List[str]) -> "AggregateRequest": """ @@ -300,6 +301,17 @@ def add_scores(self) -> "AggregateRequest": self._add_scores = True return self + def scorer(self, scorer: str) -> "AggregateRequest": + """ + Use a different scoring function to evaluate document relevance. + Default is `TFIDF`. + + :param scorer: The scoring function to use + (e.g. `TFIDF.DOCNORM` or `BM25`) + """ + self._scorer = scorer + return self + def verbatim(self) -> "AggregateRequest": self._verbatim = True return self @@ -323,6 +335,9 @@ def build_args(self) -> List[str]: if self._verbatim: ret.append("VERBATIM") + if self._scorer: + ret.extend(["SCORER", self._scorer]) + if self._add_scores: ret.append("ADDSCORES") @@ -332,6 +347,7 @@ def build_args(self) -> List[str]: if self._loadall: ret.append("LOAD") ret.append("*") + elif self._loadfields: ret.append("LOAD") ret.append(str(len(self._loadfields))) diff --git a/tests/test_asyncio/test_search.py b/tests/test_asyncio/test_search.py index 0e6fe22131..fb813b0bc7 100644 --- a/tests/test_asyncio/test_search.py +++ b/tests/test_asyncio/test_search.py @@ -1556,6 +1556,61 @@ async def test_aggregations_add_scores(decoded_r: redis.Redis): assert res.rows[1] == ["__score", "0.2"] +@pytest.mark.redismod +@skip_ifmodversion_lt("2.10.05", "search") +async def test_aggregations_hybrid_scoring(decoded_r: redis.Redis): + assert await decoded_r.ft().create_index( + ( + TextField("name", sortable=True, weight=5.0), + TextField("description", sortable=True, weight=5.0), + VectorField( + "vector", + "HNSW", + {"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "COSINE"}, + ), + ) + ) + + assert await decoded_r.hset( + "doc1", + mapping={ + "name": "cat book", + "description": "an animal book about cats", + "vector": np.array([0.1, 0.2]).astype(np.float32).tobytes(), + }, + ) + assert await decoded_r.hset( + "doc2", + mapping={ + "name": "dog book", + "description": "an animal book about dogs", + "vector": np.array([0.2, 0.1]).astype(np.float32).tobytes(), + }, + ) + + query_string = "(@description:animal)=>[KNN 3 @vector $vec_param AS dist]" + req = ( + aggregations.AggregateRequest(query_string) + .scorer("BM25") + .add_scores() + .apply(hybrid_score="@__score + @dist") + .load("*") + .dialect(4) + ) + + res = await decoded_r.ft().aggregate( + req, + query_params={"vec_param": np.array([0.11, 0.22]).astype(np.float32).tobytes()}, + ) + + if isinstance(res, dict): + assert len(res["results"]) == 2 + else: + assert len(res.rows) == 2 + for row in res.rows: + len(row) == 6 + + @pytest.mark.redismod @skip_if_redis_enterprise() async def test_search_commands_in_pipeline(decoded_r: redis.Redis): diff --git a/tests/test_search.py b/tests/test_search.py index dde59f0f87..0f0e7bb309 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -1466,6 +1466,61 @@ def test_aggregations_add_scores(client): assert res.rows[1] == ["__score", "0.2"] +@pytest.mark.redismod +@skip_ifmodversion_lt("2.10.05", "search") +async def test_aggregations_hybrid_scoring(client): + client.ft().create_index( + ( + TextField("name", sortable=True, weight=5.0), + TextField("description", sortable=True, weight=5.0), + VectorField( + "vector", + "HNSW", + {"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "COSINE"}, + ), + ) + ) + + client.hset( + "doc1", + mapping={ + "name": "cat book", + "description": "an animal book about cats", + "vector": np.array([0.1, 0.2]).astype(np.float32).tobytes(), + }, + ) + client.hset( + "doc2", + mapping={ + "name": "dog book", + "description": "an animal book about dogs", + "vector": np.array([0.2, 0.1]).astype(np.float32).tobytes(), + }, + ) + + query_string = "(@description:animal)=>[KNN 3 @vector $vec_param AS dist]" + req = ( + aggregations.AggregateRequest(query_string) + .scorer("BM25") + .add_scores() + .apply(hybrid_score="@__score + @dist") + .load("*") + .dialect(4) + ) + + res = client.ft().aggregate( + req, + query_params={"vec_param": np.array([0.11, 0.21]).astype(np.float32).tobytes()}, + ) + + if isinstance(res, dict): + assert len(res["results"]) == 2 + else: + assert len(res.rows) == 2 + for row in res.rows: + len(row) == 6 + + @pytest.mark.redismod @skip_ifmodversion_lt("2.0.0", "search") def test_index_definition(client):