From a63760b6cacfbc90b003c8fc7ca74822daada4ed Mon Sep 17 00:00:00 2001 From: cjkindel Date: Mon, 30 Dec 2024 11:06:19 -0800 Subject: [PATCH] Implement query_vector() for all vector_store_drivers --- .../vector/astradb_vector_store_driver.py | 36 ++++++++++-- .../azure_mongodb_vector_store_driver.py | 29 ++++++++-- .../vector/base_vector_store_driver.py | 11 ++++ .../vector/dummy_vector_store_driver.py | 11 ++++ .../griptape_cloud_vector_store_driver.py | 11 ++++ .../vector/local_vector_store_driver.py | 20 +++++-- .../vector/marqo_vector_store_driver.py | 11 ++++ .../mongodb_atlas_vector_store_driver.py | 29 ++++++++-- .../vector/opensearch_vector_store_driver.py | 36 ++++++++++-- .../vector/pgvector_vector_store_driver.py | 27 +++++++-- .../vector/pinecone_vector_store_driver.py | 26 +++++++-- .../vector/qdrant_vector_store_driver.py | 33 +++++++++-- .../vector/redis_vector_store_driver.py | 25 +++++++-- .../test_astra_db_vector_store_driver.py | 22 ++++++++ .../test_azure_mongodb_vector_store_driver.py | 16 ++++++ .../vector/test_dummy_vector_store_driver.py | 4 ++ ...loud_knowledge_base_vector_store_driver.py | 4 ++ .../vector/test_marqo_vector_store_driver.py | 4 ++ .../test_mongodb_atlas_vector_store_driver.py | 16 ++++++ .../test_opensearch_vector_store_driver.py | 10 ++++ .../test_pgvector_vector_store_driver.py | 55 +++++++++++++++++++ .../test_pinecone_vector_storage_driver.py | 6 ++ .../vector/test_qdrant_vector_store_driver.py | 26 +++++++++ .../vector/test_redis_vector_store_driver.py | 20 +++++++ 24 files changed, 444 insertions(+), 44 deletions(-) diff --git a/griptape/drivers/vector/astradb_vector_store_driver.py b/griptape/drivers/vector/astradb_vector_store_driver.py index 1e8398809..0456e16e9 100644 --- a/griptape/drivers/vector/astradb_vector_store_driver.py +++ b/griptape/drivers/vector/astradb_vector_store_driver.py @@ -140,19 +140,19 @@ def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorSto for match in 
self.collection.find(filter=find_filter, projection={"*": 1}) ] - def query( + def query_vector( self, - query: str, + vector: list[float], *, count: Optional[int] = None, namespace: Optional[str] = None, include_vectors: bool = False, **kwargs: Any, ) -> list[BaseVectorStoreDriver.Entry]: - """Run a similarity search on the Astra DB store, based on a query string. + """Run a similarity search on the Astra DB store, based on a vector list. Args: - query: the query string. + vector: the vector to be queried. count: the maximum number of results to return. If omitted, defaults will apply. namespace: the namespace to filter results by. include_vectors: whether to include vector data in the results. @@ -168,7 +168,6 @@ def query( find_filter_ns: dict[str, Any] = {} if namespace is None else {"namespace": namespace} find_filter = {**(query_filter or {}), **find_filter_ns} find_projection: Optional[dict[str, int]] = {"*": 1} if include_vectors else None - vector = self.embedding_driver.embed_string(query) ann_limit = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT matches = self.collection.find( filter=find_filter, @@ -187,3 +186,30 @@ def query( ) for match in matches ] + + def query( + self, + query: str, + *, + count: Optional[int] = None, + namespace: Optional[str] = None, + include_vectors: bool = False, + **kwargs: Any, + ) -> list[BaseVectorStoreDriver.Entry]: + """Run a similarity search on the Astra DB store, based on a query string. + + Args: + query: the query string. + count: the maximum number of results to return. If omitted, defaults will apply. + namespace: the namespace to filter results by. + include_vectors: whether to include vector data in the results. + kwargs: additional keyword arguments. Currently only the free-form dict `filter` + is recognized (and goes straight to the Data API query); + others will generate a warning and be ignored. 
+ + Returns: + A list of vector (`BaseVectorStoreDriver.Entry`) entries, + with their `score` attribute set to the vector similarity to the query. + """ + vector = self.embedding_driver.embed_string(query) + return self.query_vector(vector, count=count, namespace=namespace, include_vectors=include_vectors, **kwargs) diff --git a/griptape/drivers/vector/azure_mongodb_vector_store_driver.py b/griptape/drivers/vector/azure_mongodb_vector_store_driver.py index 993f7a300..06df781fb 100644 --- a/griptape/drivers/vector/azure_mongodb_vector_store_driver.py +++ b/griptape/drivers/vector/azure_mongodb_vector_store_driver.py @@ -11,9 +11,9 @@ class AzureMongoDbVectorStoreDriver(MongoDbAtlasVectorStoreDriver): """A Vector Store Driver for CosmosDB with MongoDB vCore API.""" - def query( + def query_vector( self, - query: str, + vector: list[float], *, count: Optional[int] = None, namespace: Optional[str] = None, @@ -21,15 +21,12 @@ def query( offset: Optional[int] = None, **kwargs, ) -> list[BaseVectorStoreDriver.Entry]: - """Queries the MongoDB collection for documents that match the provided query string. + """Queries the MongoDB collection for documents that match the provided vector list. Results can be customized based on parameters like count, namespace, inclusion of vectors, offset, and index. """ collection = self.get_collection() - # Using the embedding driver to convert the query string into a vector - vector = self.embedding_driver.embed_string(query) - count = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT offset = offset or 0 @@ -63,3 +60,23 @@ def query( ) for doc in collection.aggregate(pipeline) ] + + def query( + self, + query: str, + *, + count: Optional[int] = None, + namespace: Optional[str] = None, + include_vectors: bool = False, + offset: Optional[int] = None, + **kwargs, + ) -> list[BaseVectorStoreDriver.Entry]: + """Queries the MongoDB collection for documents that match the provided query string. 
+ + Results can be customized based on parameters like count, namespace, inclusion of vectors, offset, and index. + """ + # Using the embedding driver to convert the query string into a vector + vector = self.embedding_driver.embed_string(query) + return self.query_vector( + vector, count=count, namespace=namespace, include_vectors=include_vectors, offset=offset, **kwargs + ) diff --git a/griptape/drivers/vector/base_vector_store_driver.py b/griptape/drivers/vector/base_vector_store_driver.py index 13aa3f193..c1bfc653d 100644 --- a/griptape/drivers/vector/base_vector_store_driver.py +++ b/griptape/drivers/vector/base_vector_store_driver.py @@ -139,6 +139,17 @@ def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Opti @abstractmethod def load_entries(self, *, namespace: Optional[str] = None) -> list[Entry]: ... + @abstractmethod + def query_vector( + self, + vector: list[float], + *, + count: Optional[int] = None, + namespace: Optional[str] = None, + include_vectors: bool = False, + **kwargs, + ) -> list[Entry]: ... 
+ @abstractmethod def query( self, diff --git a/griptape/drivers/vector/dummy_vector_store_driver.py b/griptape/drivers/vector/dummy_vector_store_driver.py index 3bbfbd304..45a6f7ab8 100644 --- a/griptape/drivers/vector/dummy_vector_store_driver.py +++ b/griptape/drivers/vector/dummy_vector_store_driver.py @@ -35,6 +35,17 @@ def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Opti def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]: raise DummyError(__class__.__name__, "load_entries") + def query_vector( + self, + vector: list[float], + *, + count: Optional[int] = None, + namespace: Optional[str] = None, + include_vectors: bool = False, + **kwargs, + ) -> list[BaseVectorStoreDriver.Entry]: + raise DummyError(__class__.__name__, "query_vector") + def query( self, query: str, diff --git a/griptape/drivers/vector/griptape_cloud_vector_store_driver.py b/griptape/drivers/vector/griptape_cloud_vector_store_driver.py index 9f902b976..4ca584d0b 100644 --- a/griptape/drivers/vector/griptape_cloud_vector_store_driver.py +++ b/griptape/drivers/vector/griptape_cloud_vector_store_driver.py @@ -79,6 +79,17 @@ def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorSto def load_artifacts(self, *, namespace: Optional[str] = None) -> ListArtifact: raise NotImplementedError(f"{self.__class__.__name__} does not support Artifact loading.") + def query_vector( + self, + vector: list[float], + *, + count: Optional[int] = None, + namespace: Optional[str] = None, + include_vectors: bool = False, + **kwargs, + ) -> NoReturn: + raise NotImplementedError(f"{self.__class__.__name__} does not support vector query.") + def query( self, query: str, diff --git a/griptape/drivers/vector/local_vector_store_driver.py b/griptape/drivers/vector/local_vector_store_driver.py index 557937431..8d03d6787 100644 --- a/griptape/drivers/vector/local_vector_store_driver.py +++ 
b/griptape/drivers/vector/local_vector_store_driver.py @@ -78,24 +78,22 @@ def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Opti def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]: return [entry for key, entry in self.entries.items() if namespace is None or entry.namespace == namespace] - def query( + def query_vector( self, - query: str, + vector: list[float], *, count: Optional[int] = None, namespace: Optional[str] = None, include_vectors: bool = False, **kwargs, ) -> list[BaseVectorStoreDriver.Entry]: - query_embedding = self.embedding_driver.embed_string(query) - if namespace: entries = {k: v for (k, v) in self.entries.items() if k.startswith(f"{namespace}-")} else: entries = self.entries entries_and_relatednesses = [ - (entry, self.calculate_relatedness(query_embedding, entry.vector)) for entry in list(entries.values()) + (entry, self.calculate_relatedness(vector, entry.vector)) for entry in list(entries.values()) ] entries_and_relatednesses.sort(key=operator.itemgetter(1), reverse=True) @@ -113,6 +111,18 @@ def query( for r in result ] + def query( + self, + query: str, + *, + count: Optional[int] = None, + namespace: Optional[str] = None, + include_vectors: bool = False, + **kwargs, + ) -> list[BaseVectorStoreDriver.Entry]: + vector = self.embedding_driver.embed_string(query) + return self.query_vector(vector, count=count, namespace=namespace, include_vectors=include_vectors, **kwargs) + def delete_vector(self, vector_id: str) -> NoReturn: raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.") diff --git a/griptape/drivers/vector/marqo_vector_store_driver.py b/griptape/drivers/vector/marqo_vector_store_driver.py index ce431d38d..f1bd86a3c 100644 --- a/griptape/drivers/vector/marqo_vector_store_driver.py +++ b/griptape/drivers/vector/marqo_vector_store_driver.py @@ -165,6 +165,17 @@ def load_entries(self, *, namespace: Optional[str] = None) -> 
list[BaseVectorSto return entries + def query_vector( + self, + vector: list[float], + *, + count: Optional[int] = None, + namespace: Optional[str] = None, + include_vectors: bool = False, + **kwargs, + ) -> NoReturn: + raise NotImplementedError(f"{self.__class__.__name__} does not support vector query.") + def query( self, query: str, diff --git a/griptape/drivers/vector/mongodb_atlas_vector_store_driver.py b/griptape/drivers/vector/mongodb_atlas_vector_store_driver.py index a6f32620a..d02f48c41 100644 --- a/griptape/drivers/vector/mongodb_atlas_vector_store_driver.py +++ b/griptape/drivers/vector/mongodb_atlas_vector_store_driver.py @@ -114,9 +114,9 @@ def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorSto for doc in cursor ] - def query( + def query_vector( self, - query: str, + vector: list[float], *, count: Optional[int] = None, namespace: Optional[str] = None, @@ -124,15 +124,12 @@ def query( offset: Optional[int] = None, **kwargs, ) -> list[BaseVectorStoreDriver.Entry]: - """Queries the MongoDB collection for documents that match the provided query string. + """Queries the MongoDB collection for documents that match the provided vector list. Results can be customized based on parameters like count, namespace, inclusion of vectors, offset, and index. """ collection = self.get_collection() - # Using the embedding driver to convert the query string into a vector - vector = self.embedding_driver.embed_string(query) - count = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT offset = offset or 0 @@ -171,6 +168,26 @@ def query( for doc in collection.aggregate(pipeline) ] + def query( + self, + query: str, + *, + count: Optional[int] = None, + namespace: Optional[str] = None, + include_vectors: bool = False, + offset: Optional[int] = None, + **kwargs, + ) -> list[BaseVectorStoreDriver.Entry]: + """Queries the MongoDB collection for documents that match the provided query string. 
+ + Results can be customized based on parameters like count, namespace, inclusion of vectors, offset, and index. + """ + # Using the embedding driver to convert the query string into a vector + vector = self.embedding_driver.embed_string(query) + return self.query_vector( + vector, count=count, namespace=namespace, include_vectors=include_vectors, offset=offset, **kwargs + ) + def delete_vector(self, vector_id: str) -> None: """Deletes the vector from the collection.""" collection = self.get_collection() diff --git a/griptape/drivers/vector/opensearch_vector_store_driver.py b/griptape/drivers/vector/opensearch_vector_store_driver.py index 5f247f6db..a1311b9fc 100644 --- a/griptape/drivers/vector/opensearch_vector_store_driver.py +++ b/griptape/drivers/vector/opensearch_vector_store_driver.py @@ -119,9 +119,9 @@ def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorSto for hit in response["hits"]["hits"] ] - def query( + def query_vector( self, - query: str, + vector: list[float], *, count: Optional[int] = None, namespace: Optional[str] = None, @@ -130,7 +130,7 @@ def query( field_name: str = "vector", **kwargs, ) -> list[BaseVectorStoreDriver.Entry]: - """Performs a nearest neighbor search on OpenSearch to find vectors similar to the provided query string. + """Performs a nearest neighbor search on OpenSearch to find vectors similar to the provided vector list. Results can be limited using the count parameter and optionally filtered by a namespace. @@ -138,7 +138,6 @@ def query( A list of BaseVectorStoreDriver.Entry objects, each encapsulating the retrieved vector, its similarity score, metadata, and namespace. 
""" count = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT - vector = self.embedding_driver.embed_string(query) # Base k-NN query query_body = {"size": count, "query": {"knn": {field_name: {"vector": vector, "k": count}}}} @@ -165,5 +164,34 @@ def query( for hit in response["hits"]["hits"] ] + def query( + self, + query: str, + *, + count: Optional[int] = None, + namespace: Optional[str] = None, + include_vectors: bool = False, + include_metadata: bool = True, + field_name: str = "vector", + **kwargs, + ) -> list[BaseVectorStoreDriver.Entry]: + """Performs a nearest neighbor search on OpenSearch to find vectors similar to the provided query string. + + Results can be limited using the count parameter and optionally filtered by a namespace. + + Returns: + A list of BaseVectorStoreDriver.Entry objects, each encapsulating the retrieved vector, its similarity score, metadata, and namespace. + """ + vector = self.embedding_driver.embed_string(query) + return self.query_vector( + vector, + count=count, + namespace=namespace, + include_vectors=include_vectors, + include_metadata=include_metadata, + field_name=field_name, + **kwargs, + ) + def delete_vector(self, vector_id: str) -> NoReturn: raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.") diff --git a/griptape/drivers/vector/pgvector_vector_store_driver.py b/griptape/drivers/vector/pgvector_vector_store_driver.py index 038925df3..254cb82b8 100644 --- a/griptape/drivers/vector/pgvector_vector_store_driver.py +++ b/griptape/drivers/vector/pgvector_vector_store_driver.py @@ -127,9 +127,9 @@ def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorSto for result in results ] - def query( + def query_vector( self, - query: str, + vector: list[float], *, count: Optional[int] = BaseVectorStoreDriver.DEFAULT_QUERY_COUNT, namespace: Optional[str] = None, @@ -152,8 +152,6 @@ def query( op = distance_metrics[distance_metric] with sqlalchemy_orm.Session(self.engine) as 
session: - vector = self.embedding_driver.embed_string(query) - # The query should return both the vector and the distance metric score. query_result = session.query(self._model, op(vector).label("score")).order_by(op(vector)) # pyright: ignore[reportOptionalCall] @@ -182,6 +180,27 @@ def query( for result in results ] + def query( + self, + query: str, + *, + count: Optional[int] = BaseVectorStoreDriver.DEFAULT_QUERY_COUNT, + namespace: Optional[str] = None, + include_vectors: bool = False, + distance_metric: str = "cosine_distance", + **kwargs, + ) -> list[BaseVectorStoreDriver.Entry]: + """Performs a search on the collection to find vectors similar to the provided query string, optionally filtering to only those that match the provided namespace.""" + vector = self.embedding_driver.embed_string(query) + return self.query_vector( + vector, + count=count, + namespace=namespace, + include_vectors=include_vectors, + distance_metric=distance_metric, + **kwargs, + ) + def default_vector_model(self) -> Any: pgvector_sqlalchemy = import_optional_dependency("pgvector.sqlalchemy") sqlalchemy = import_optional_dependency("sqlalchemy") diff --git a/griptape/drivers/vector/pinecone_vector_store_driver.py b/griptape/drivers/vector/pinecone_vector_store_driver.py index 81e593e72..22d2b7b40 100644 --- a/griptape/drivers/vector/pinecone_vector_store_driver.py +++ b/griptape/drivers/vector/pinecone_vector_store_driver.py @@ -87,9 +87,9 @@ def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorSto for r in results["matches"] ] - def query( + def query_vector( self, - query: str, + vector: list[float], *, count: Optional[int] = None, namespace: Optional[str] = None, @@ -97,8 +97,6 @@ def query( include_metadata: bool = True, **kwargs, ) -> list[BaseVectorStoreDriver.Entry]: - vector = self.embedding_driver.embed_string(query) - params = { "top_k": count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT, "namespace": namespace, @@ -119,5 +117,25 @@ def query( for r
in results["matches"] ] + def query( + self, + query: str, + *, + count: Optional[int] = None, + namespace: Optional[str] = None, + include_vectors: bool = False, + include_metadata: bool = True, + **kwargs, + ) -> list[BaseVectorStoreDriver.Entry]: + vector = self.embedding_driver.embed_string(query) + return self.query_vector( + vector, + count=count, + namespace=namespace, + include_vectors=include_vectors, + include_metadata=include_metadata, + **kwargs, + ) + def delete_vector(self, vector_id: str) -> NoReturn: raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.") diff --git a/griptape/drivers/vector/qdrant_vector_store_driver.py b/griptape/drivers/vector/qdrant_vector_store_driver.py index 79cf64f37..58c9586f4 100644 --- a/griptape/drivers/vector/qdrant_vector_store_driver.py +++ b/griptape/drivers/vector/qdrant_vector_store_driver.py @@ -93,9 +93,9 @@ def delete_vector(self, vector_id: str) -> None: if deletion_response.status == import_optional_dependency("qdrant_client.http.models").UpdateStatus.COMPLETED: logging.info("ID %s is successfully deleted", vector_id) - def query( + def query_vector( self, - query: str, + vector: list[float], *, count: Optional[int] = None, namespace: Optional[str] = None, @@ -105,7 +105,7 @@ def query( """Query the Qdrant collection based on a query vector. Parameters: - query (str): Query string. + vector (list[float]): Query vector. count (Optional[int]): Optional number of results to return. namespace (Optional[str]): Optional namespace of the vectors. include_vectors (bool): Whether to include vectors in the results. @@ -113,10 +113,8 @@ def query( Returns: list[BaseVectorStoreDriver.Entry]: List of Entry objects.
""" - query_vector = self.embedding_driver.embed_string(query) - # Create a search request - request = {"collection_name": self.collection_name, "query_vector": query_vector, "limit": count} + request = {"collection_name": self.collection_name, "query_vector": vector, "limit": count} request = {k: v for k, v in request.items() if v is not None} results = self.client.search(**request) @@ -131,6 +129,29 @@ def query( for result in results ] + def query( + self, + query: str, + *, + count: Optional[int] = None, + namespace: Optional[str] = None, + include_vectors: bool = False, + **kwargs, + ) -> list[BaseVectorStoreDriver.Entry]: + """Query the Qdrant collection based on a query vector. + + Parameters: + query (str): Query string. + count (Optional[int]): Optional number of results to return. + namespace (Optional[str]): Optional namespace of the vectors. + include_vectors (bool): Whether to include vectors in the results. + + Returns: + list[BaseVectorStoreDriver.Entry]: List of Entry objects. 
+ """ + vector = self.embedding_driver.embed_string(query) + return self.query_vector(vector, count=count, namespace=namespace, include_vectors=include_vectors, **kwargs) + def upsert_vector( self, vector: list[float], diff --git a/griptape/drivers/vector/redis_vector_store_driver.py b/griptape/drivers/vector/redis_vector_store_driver.py index d220878f3..0bddb7fc6 100644 --- a/griptape/drivers/vector/redis_vector_store_driver.py +++ b/griptape/drivers/vector/redis_vector_store_driver.py @@ -107,9 +107,9 @@ def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorSto return entries - def query( + def query_vector( self, - query: str, + vector: list[float], *, count: Optional[int] = None, namespace: Optional[str] = None, @@ -125,8 +125,6 @@ def query( """ search_query = import_optional_dependency("redis.commands.search.query") - vector = self.embedding_driver.embed_string(query) - filter_expression = f"(@namespace:{{{namespace}}})" if namespace else "*" query_expression = ( search_query.Query(f"{filter_expression}=>[KNN {count or 10} @vector $vector as score]") @@ -157,6 +155,25 @@ def query( ) return query_results + def query( + self, + query: str, + *, + count: Optional[int] = None, + namespace: Optional[str] = None, + include_vectors: bool = False, + **kwargs, + ) -> list[BaseVectorStoreDriver.Entry]: + """Performs a nearest neighbor search on Redis to find vectors similar to the provided query string. + + Results can be limited using the count parameter and optionally filtered by a namespace. + + Returns: + A list of BaseVectorStoreDriver.Entry objects, each encapsulating the retrieved vector, its similarity score, metadata, and namespace.
+ """ + vector = self.embedding_driver.embed_string(query) + return self.query_vector(vector, count=count, namespace=namespace, include_vectors=include_vectors, **kwargs) + def _generate_key(self, vector_id: str, namespace: Optional[str] = None) -> str: """Generates a Redis key using the provided vector ID and optionally a namespace.""" return f"{namespace}:{vector_id}" if namespace else vector_id diff --git a/tests/unit/drivers/vector/test_astra_db_vector_store_driver.py b/tests/unit/drivers/vector/test_astra_db_vector_store_driver.py index b544a3494..c21914b62 100644 --- a/tests/unit/drivers/vector/test_astra_db_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_astra_db_vector_store_driver.py @@ -111,6 +111,28 @@ def test_load_entries(self, driver, mock_collection, one_entry): projection={"*": 1}, ) + def test_query_vector_allparams(self, driver, mock_collection, one_query_entry): + entries1 = driver.query_vector([0.0, 0.5], count=999, namespace="some_namespace", include_vectors=True) + assert entries1 == [one_query_entry] + mock_collection.return_value.find.assert_called_once_with( + filter={"namespace": "some_namespace"}, + sort={"$vector": [0.0, 0.5]}, + limit=999, + projection={"*": 1}, + include_similarity=True, + ) + + def test_query_vector_minparams(self, driver, mock_collection, one_query_entry): + entries0 = driver.query_vector([0.0, 0.5]) + assert entries0 == [one_query_entry] + mock_collection.return_value.find.assert_called_once_with( + filter={}, + sort={"$vector": [0.0, 0.5]}, + limit=BaseVectorStoreDriver.DEFAULT_QUERY_COUNT, + projection=None, + include_similarity=True, + ) + def test_query_allparams(self, driver, mock_collection, one_query_entry): entries1 = driver.query("some query", count=999, namespace="some_namespace", include_vectors=True) assert entries1 == [one_query_entry] diff --git a/tests/unit/drivers/vector/test_azure_mongodb_vector_store_driver.py b/tests/unit/drivers/vector/test_azure_mongodb_vector_store_driver.py index 
6dd4fa5e9..6306e9998 100644 --- a/tests/unit/drivers/vector/test_azure_mongodb_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_azure_mongodb_vector_store_driver.py @@ -37,6 +37,22 @@ def test_upsert_text(self, driver): test_id = driver.upsert_text(text, vector_id=vector_id_str) assert test_id == vector_id_str + def test_query_vector(self, driver, monkeypatch): + mock_query_result = [ + BaseVectorStoreDriver.Entry("foo", [0.5, 0.5, 0.5], score=0.0, meta={}, namespace=None), + BaseVectorStoreDriver.Entry("foo", vector=[0.5, 0.5, 0.5], score=0.0, meta={}, namespace=None), + ] + + monkeypatch.setattr(AzureMongoDbVectorStoreDriver, "query_vector", lambda *args, **kwargs: mock_query_result) + + query_vector = [0.0, 0.5, 1.0] + results = driver.query_vector(query_vector, include_vectors=True) + assert len(results) == len(mock_query_result) + for result, expected in zip(results, mock_query_result): + assert result.id == expected.id + assert result.vector == expected.vector + assert isinstance(result, BaseVectorStoreDriver.Entry) + def test_query(self, driver, monkeypatch): mock_query_result = [ BaseVectorStoreDriver.Entry("foo", [0.5, 0.5, 0.5], score=0.0, meta={}, namespace=None), diff --git a/tests/unit/drivers/vector/test_dummy_vector_store_driver.py b/tests/unit/drivers/vector/test_dummy_vector_store_driver.py index 3cde52ebc..518c3abda 100644 --- a/tests/unit/drivers/vector/test_dummy_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_dummy_vector_store_driver.py @@ -25,6 +25,10 @@ def test_load_entries(self, vector_store_driver): with pytest.raises(DummyError): vector_store_driver.load_entries(namespace="foo bar huzzah") + def test_query_vector(self, vector_store_driver): + with pytest.raises(DummyError): + vector_store_driver.query_vector([0.0, 0.5]) + def test_query(self, vector_store_driver): with pytest.raises(DummyError): vector_store_driver.query("foo bar huzzah") diff --git 
a/tests/unit/drivers/vector/test_griptape_cloud_knowledge_base_vector_store_driver.py b/tests/unit/drivers/vector/test_griptape_cloud_knowledge_base_vector_store_driver.py index f87b30444..3727292de 100644 --- a/tests/unit/drivers/vector/test_griptape_cloud_knowledge_base_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_griptape_cloud_knowledge_base_vector_store_driver.py @@ -40,6 +40,10 @@ def driver(self, mocker): return GriptapeCloudVectorStoreDriver(api_key="foo bar", knowledge_base_id="1") + def test_query_vector(self, driver): + with pytest.raises(NotImplementedError): + driver.query_vector([0.0, 0.5]) + def test_query(self, driver): result = driver.query( "some query", count=10, namespace="foo", include_vectors=True, distance_metric="bar", filter={"foo": "bar"} diff --git a/tests/unit/drivers/vector/test_marqo_vector_store_driver.py b/tests/unit/drivers/vector/test_marqo_vector_store_driver.py index 521c8670d..de5125b78 100644 --- a/tests/unit/drivers/vector/test_marqo_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_marqo_vector_store_driver.py @@ -117,6 +117,10 @@ def test_upsert_text_artifact(self, driver, mock_marqo): } assert result == expected_return_value["items"][0]["_id"] + def test_query_vector(self, driver): + with pytest.raises(NotImplementedError): + driver.query_vector([0.0, 0.5]) + def test_search(self, driver, mock_marqo): results = driver.query("Test query") mock_marqo.index().search.assert_called() diff --git a/tests/unit/drivers/vector/test_mongodb_atlas_vector_store_driver.py b/tests/unit/drivers/vector/test_mongodb_atlas_vector_store_driver.py index 20cb8bdc0..f4599f5ac 100644 --- a/tests/unit/drivers/vector/test_mongodb_atlas_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_mongodb_atlas_vector_store_driver.py @@ -37,6 +37,22 @@ def test_upsert_text(self, driver): test_id = driver.upsert_text(text, vector_id=vector_id_str) assert test_id == vector_id_str + def test_query_vector(self, driver, 
monkeypatch): + mock_query_result = [ + BaseVectorStoreDriver.Entry("foo", [0.5, 0.5, 0.5], score=0.0, meta={}, namespace=None), + BaseVectorStoreDriver.Entry("foo", vector=[0.5, 0.5, 0.5], score=0.0, meta={}, namespace=None), + ] + + monkeypatch.setattr(MongoDbAtlasVectorStoreDriver, "query_vector", lambda *args, **kwargs: mock_query_result) + + query_vector = [0.0, 0.5, 1.0] + results = driver.query_vector(query_vector, include_vectors=True) + assert len(results) == len(mock_query_result) + for result, expected in zip(results, mock_query_result): + assert result.id == expected.id + assert result.vector == expected.vector + assert isinstance(result, BaseVectorStoreDriver.Entry) + def test_query(self, driver, monkeypatch): mock_query_result = [ BaseVectorStoreDriver.Entry("foo", [0.5, 0.5, 0.5], score=0.0, meta={}, namespace=None), diff --git a/tests/unit/drivers/vector/test_opensearch_vector_store_driver.py b/tests/unit/drivers/vector/test_opensearch_vector_store_driver.py index cef3805ab..2fe2006cf 100644 --- a/tests/unit/drivers/vector/test_opensearch_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_opensearch_vector_store_driver.py @@ -41,6 +41,16 @@ def test_load_entries(self, driver): assert np.allclose(entries[0].vector, [0.7, 0.8, 0.9], atol=1e-6) assert entries[0].meta is None + def test_query_vector(self, driver): + mock_result = Mock() + mock_result.id = "query_result" + + with patch.object(driver, "query_vector", return_value=[mock_result]): + query_string = [0.0, 0.5, 1.0] + results = driver.query_vector(query_string, count=5, namespace="company") + assert len(results) == 1, "Expected results from the query" + assert results[0].id == "query_result", "Expected a result id" + def test_query(self, driver): mock_result = Mock() mock_result.id = "query_result" diff --git a/tests/unit/drivers/vector/test_pgvector_vector_store_driver.py b/tests/unit/drivers/vector/test_pgvector_vector_store_driver.py index 72833f5a3..90ede58a5 100644 --- 
a/tests/unit/drivers/vector/test_pgvector_vector_store_driver.py
+++ b/tests/unit/drivers/vector/test_pgvector_vector_store_driver.py
@@ -102,6 +102,61 @@ def test_load_entries(self, mock_session, mock_engine):
         assert entries[0].meta == test_metas[0]
         assert entries[1].meta == test_metas[1]
 
+    def test_query_vector_invalid_distance_metric(self, mock_engine):
+        driver = PgVectorVectorStoreDriver(
+            embedding_driver=MockEmbeddingDriver(), engine=mock_engine, table_name=self.table_name
+        )
+
+        with pytest.raises(ValueError):
+            driver.query_vector([0.0, 0.5], distance_metric="invalid")
+
+    def test_query_vector(self, mock_session, mock_engine):
+        test_ids = [str(uuid.uuid4()), str(uuid.uuid4())]
+        test_vecs = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
+        test_namespaces = [str(uuid.uuid4()), str(uuid.uuid4())]
+        test_metas = [{"key": "value1"}, {"key": "value2"}]
+        test_result = [
+            [Mock(id=test_ids[0], vector=test_vecs[0], namespace=test_namespaces[0], meta=test_metas[0]), 0.1],
+            [Mock(id=test_ids[1], vector=test_vecs[1], namespace=test_namespaces[1], meta=test_metas[1]), 0.9],
+        ]
+        mock_session.query().order_by().limit().all.return_value = test_result
+
+        driver = PgVectorVectorStoreDriver(
+            embedding_driver=MockEmbeddingDriver(), engine=mock_engine, table_name=self.table_name
+        )
+
+        result = driver.query_vector([0.0, 0.5], include_vectors=True)
+
+        assert result[0].id == test_ids[0]
+        assert result[1].id == test_ids[1]
+        assert result[0].vector == test_vecs[0]
+        assert result[1].vector == test_vecs[1]
+        assert result[0].namespace == test_namespaces[0]
+        assert result[1].namespace == test_namespaces[1]
+        assert result[0].meta == test_metas[0]
+        assert result[1].meta == test_metas[1]
+
+    def test_query_vector_filter(self, mock_session, mock_engine):
+        test_ids = [str(uuid.uuid4()), str(uuid.uuid4())]
+        test_vecs = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
+        test_namespaces = [str(uuid.uuid4()), str(uuid.uuid4())]
+        test_metas = [{"key": "value1"}, {"key": "value2"}]
+        test_result = [
+            [Mock(id=test_ids[0], vector=test_vecs[0], namespace=test_namespaces[0], meta=test_metas[0]), 0.1]
+        ]
+        mock_session.query().order_by().filter_by().limit().all.return_value = test_result
+
+        driver = PgVectorVectorStoreDriver(
+            embedding_driver=MockEmbeddingDriver(), engine=mock_engine, table_name=self.table_name
+        )
+
+        result = driver.query_vector([0.0, 0.5], include_vectors=True, filter={"namespace": test_namespaces[0]})
+
+        assert result[0].id == test_ids[0]
+        assert result[0].vector == test_vecs[0]
+        assert result[0].namespace == test_namespaces[0]
+        assert result[0].meta == test_metas[0]
+
     def test_query_invalid_distance_metric(self, mock_engine):
         driver = PgVectorVectorStoreDriver(
             embedding_driver=MockEmbeddingDriver(), engine=mock_engine, table_name=self.table_name
diff --git a/tests/unit/drivers/vector/test_pinecone_vector_storage_driver.py b/tests/unit/drivers/vector/test_pinecone_vector_storage_driver.py
index 8be38c51e..6d9f952f9 100644
--- a/tests/unit/drivers/vector/test_pinecone_vector_storage_driver.py
+++ b/tests/unit/drivers/vector/test_pinecone_vector_storage_driver.py
@@ -48,6 +48,12 @@ def test_upsert_text(self, driver):
         assert driver.upsert_text("foo", vector_id="foo") == "foo"
         assert isinstance(driver.upsert_text("foo"), str)
 
+    def test_query_vector(self, driver):
+        results = driver.query_vector([0.0, 0.5])
+
+        assert results[0].vector == [0, 1, 0]
+        assert results[0].id == "foo"
+
     def test_query(self, driver):
         results = driver.query("test")
 
diff --git a/tests/unit/drivers/vector/test_qdrant_vector_store_driver.py b/tests/unit/drivers/vector/test_qdrant_vector_store_driver.py
index ae86c5f42..87264efb9 100644
--- a/tests/unit/drivers/vector/test_qdrant_vector_store_driver.py
+++ b/tests/unit/drivers/vector/test_qdrant_vector_store_driver.py
@@ -57,6 +57,32 @@ def test_delete_vector(self, driver):
             points_selector=mock_import.return_value.PointIdsList(points=[vector_id]),
         )
 
+    def test_query_vector(self, driver):
+        mock_query_result = [
+            MagicMock(
+                id="foo", vector=[0, 1, 0], score=42, payload={"foo": "bar", "_score": 0.99, "_tensor_facets": []}
+            )
+        ]
+
+        with (
+            patch.object(driver.client, "search", return_value=mock_query_result) as mock_search,
+        ):
+            vector = [0.1, 0.2, 0.3]
+            count = 10
+            include_vectors = True
+
+            results = driver.query_vector(vector, count=count, include_vectors=include_vectors)
+
+            mock_search.assert_called_once_with(
+                collection_name=driver.collection_name, query_vector=[0.1, 0.2, 0.3], limit=count
+            )
+
+            assert len(results) == 1
+            assert results[0].id == "foo"
+            assert results[0].vector == ([0, 1, 0] if include_vectors else [])  # `==` binds tighter than if/else
+            assert results[0].score == 42
+            assert results[0].meta == {"foo": "bar"}
+
     def test_query(self, driver):
         mock_query_result = [
             MagicMock(
diff --git a/tests/unit/drivers/vector/test_redis_vector_store_driver.py b/tests/unit/drivers/vector/test_redis_vector_store_driver.py
index 2f74b9279..de70a0a83 100644
--- a/tests/unit/drivers/vector/test_redis_vector_store_driver.py
+++ b/tests/unit/drivers/vector/test_redis_vector_store_driver.py
@@ -78,6 +78,26 @@ def test_load_entries_with_namespace(self, driver, mock_keys, mock_hgetall):
         assert entries[0].vector == [1.0, 2.0, 3.0]
         assert entries[0].meta == {"foo": "bar"}
 
+    def test_query_vector(self, driver, mock_search):
+        results = driver.query_vector([0.0, 0.5])
+        mock_search.assert_called_once()
+        assert len(results) == 1
+        assert results[0].namespace == "some_namespace"
+        assert results[0].id == "some_vector_id"
+        assert results[0].score == 0.456198036671
+        assert results[0].meta == {"foo": "bar"}
+        assert results[0].vector is None
+
+    def test_query_vector_with_include_vectors(self, driver, mock_search):
+        results = driver.query_vector([0.0, 0.5], include_vectors=True)
+        mock_search.assert_called_once()
+        assert len(results) == 1
+        assert results[0].namespace == "some_namespace"
+        assert results[0].id == "some_vector_id"
+        assert results[0].score == 0.456198036671
+        assert results[0].meta == {"foo": "bar"}
+        assert results[0].vector == [1.0, 2.0, 3.0]
+
     def test_query(self, driver, mock_search):
         results = driver.query("Some query")
         mock_search.assert_called_once()