Skip to content

Commit

Permalink
Implement query_vector() for all vector_store_drivers
Browse files Browse the repository at this point in the history
  • Loading branch information
cjkindel committed Dec 30, 2024
1 parent ac4c034 commit a63760b
Show file tree
Hide file tree
Showing 24 changed files with 444 additions and 44 deletions.
36 changes: 31 additions & 5 deletions griptape/drivers/vector/astradb_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,19 +140,19 @@ def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorSto
for match in self.collection.find(filter=find_filter, projection={"*": 1})
]

def query(
def query_vector(
self,
query: str,
vector: list[float],
*,
count: Optional[int] = None,
namespace: Optional[str] = None,
include_vectors: bool = False,
**kwargs: Any,
) -> list[BaseVectorStoreDriver.Entry]:
"""Run a similarity search on the Astra DB store, based on a query string.
"""Run a similarity search on the Astra DB store, based on a vector list.
Args:
query: the query string.
vector: the vector to be queried.
count: the maximum number of results to return. If omitted, defaults will apply.
namespace: the namespace to filter results by.
include_vectors: whether to include vector data in the results.
Expand All @@ -168,7 +168,6 @@ def query(
find_filter_ns: dict[str, Any] = {} if namespace is None else {"namespace": namespace}
find_filter = {**(query_filter or {}), **find_filter_ns}
find_projection: Optional[dict[str, int]] = {"*": 1} if include_vectors else None
vector = self.embedding_driver.embed_string(query)
ann_limit = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
matches = self.collection.find(
filter=find_filter,
Expand All @@ -187,3 +186,30 @@ def query(
)
for match in matches
]

def query(
self,
query: str,
*,
count: Optional[int] = None,
namespace: Optional[str] = None,
include_vectors: bool = False,
**kwargs: Any,
) -> list[BaseVectorStoreDriver.Entry]:
"""Run a similarity search on the Astra DB store, based on a query string.
Args:
query: the query string.
count: the maximum number of results to return. If omitted, defaults will apply.
namespace: the namespace to filter results by.
include_vectors: whether to include vector data in the results.
kwargs: additional keyword arguments. Currently only the free-form dict `filter`
is recognized (and goes straight to the Data API query);
others will generate a warning and be ignored.
Returns:
A list of vector (`BaseVectorStoreDriver.Entry`) entries,
with their `score` attribute set to the vector similarity to the query.
"""
vector = self.embedding_driver.embed_string(query)
return self.query_vector(vector, count=count, namespace=namespace, include_vectors=include_vectors, **kwargs)
29 changes: 23 additions & 6 deletions griptape/drivers/vector/azure_mongodb_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,22 @@
class AzureMongoDbVectorStoreDriver(MongoDbAtlasVectorStoreDriver):
"""A Vector Store Driver for CosmosDB with MongoDB vCore API."""

def query(
def query_vector(
self,
query: str,
vector: list[float],
*,
count: Optional[int] = None,
namespace: Optional[str] = None,
include_vectors: bool = False,
offset: Optional[int] = None,
**kwargs,
) -> list[BaseVectorStoreDriver.Entry]:
"""Queries the MongoDB collection for documents that match the provided query string.
"""Queries the MongoDB collection for documents that match the provided vector list.
Results can be customized based on parameters like count, namespace, inclusion of vectors, offset, and index.
"""
collection = self.get_collection()

# Using the embedding driver to convert the query string into a vector
vector = self.embedding_driver.embed_string(query)

count = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
offset = offset or 0

Expand Down Expand Up @@ -63,3 +60,23 @@ def query(
)
for doc in collection.aggregate(pipeline)
]

def query(
self,
query: str,
*,
count: Optional[int] = None,
namespace: Optional[str] = None,
include_vectors: bool = False,
offset: Optional[int] = None,
**kwargs,
) -> list[BaseVectorStoreDriver.Entry]:
"""Queries the MongoDB collection for documents that match the provided query string.
Results can be customized based on parameters like count, namespace, inclusion of vectors, offset, and index.
"""
# Using the embedding driver to convert the query string into a vector
vector = self.embedding_driver.embed_string(query)
return self.query_vector(

Check warning on line 80 in griptape/drivers/vector/azure_mongodb_vector_store_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/vector/azure_mongodb_vector_store_driver.py#L79-L80

Added lines #L79 - L80 were not covered by tests
vector, count=count, namespace=namespace, include_vectors=include_vectors, offset=offset, **kwargs
)
11 changes: 11 additions & 0 deletions griptape/drivers/vector/base_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,17 @@ def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Opti
@abstractmethod
def load_entries(self, *, namespace: Optional[str] = None) -> list[Entry]: ...

@abstractmethod
def query_vector(
self,
vector: list[float],
*,
count: Optional[int] = None,
namespace: Optional[str] = None,
include_vectors: bool = False,
**kwargs,
) -> list[Entry]: ...

@abstractmethod
def query(
self,
Expand Down
11 changes: 11 additions & 0 deletions griptape/drivers/vector/dummy_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,17 @@ def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Opti
def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
raise DummyError(__class__.__name__, "load_entries")

def query_vector(
self,
vector: list[float],
*,
count: Optional[int] = None,
namespace: Optional[str] = None,
include_vectors: bool = False,
**kwargs,
) -> list[BaseVectorStoreDriver.Entry]:
raise DummyError(__class__.__name__, "query_vector")

def query(
self,
query: str,
Expand Down
11 changes: 11 additions & 0 deletions griptape/drivers/vector/griptape_cloud_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,17 @@ def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorSto
def load_artifacts(self, *, namespace: Optional[str] = None) -> ListArtifact:
raise NotImplementedError(f"{self.__class__.__name__} does not support Artifact loading.")

def query_vector(
self,
vector: list[float],
*,
count: Optional[int] = None,
namespace: Optional[str] = None,
include_vectors: bool = False,
**kwargs,
) -> NoReturn:
raise NotImplementedError(f"{self.__class__.__name__} does not support vector query.")

def query(
self,
query: str,
Expand Down
20 changes: 15 additions & 5 deletions griptape/drivers/vector/local_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,24 +78,22 @@ def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Opti
def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
return [entry for key, entry in self.entries.items() if namespace is None or entry.namespace == namespace]

def query(
def query_vector(
self,
query: str,
vector: list[float],
*,
count: Optional[int] = None,
namespace: Optional[str] = None,
include_vectors: bool = False,
**kwargs,
) -> list[BaseVectorStoreDriver.Entry]:
query_embedding = self.embedding_driver.embed_string(query)

if namespace:
entries = {k: v for (k, v) in self.entries.items() if k.startswith(f"{namespace}-")}
else:
entries = self.entries

entries_and_relatednesses = [
(entry, self.calculate_relatedness(query_embedding, entry.vector)) for entry in list(entries.values())
(entry, self.calculate_relatedness(vector, entry.vector)) for entry in list(entries.values())
]

entries_and_relatednesses.sort(key=operator.itemgetter(1), reverse=True)
Expand All @@ -113,6 +111,18 @@ def query(
for r in result
]

def query(
self,
query: str,
*,
count: Optional[int] = None,
namespace: Optional[str] = None,
include_vectors: bool = False,
**kwargs,
) -> list[BaseVectorStoreDriver.Entry]:
vector = self.embedding_driver.embed_string(query)
return self.query_vector(vector, count=count, namespace=namespace, include_vectors=include_vectors, **kwargs)

def delete_vector(self, vector_id: str) -> NoReturn:
raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")

Expand Down
11 changes: 11 additions & 0 deletions griptape/drivers/vector/marqo_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,17 @@ def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorSto

return entries

def query_vector(
self,
vector: list[float],
*,
count: Optional[int] = None,
namespace: Optional[str] = None,
include_vectors: bool = False,
**kwargs,
) -> NoReturn:
raise NotImplementedError(f"{self.__class__.__name__} does not support vector query.")

def query(
self,
query: str,
Expand Down
29 changes: 23 additions & 6 deletions griptape/drivers/vector/mongodb_atlas_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,25 +114,22 @@ def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorSto
for doc in cursor
]

def query(
def query_vector(
self,
query: str,
vector: list[float],
*,
count: Optional[int] = None,
namespace: Optional[str] = None,
include_vectors: bool = False,
offset: Optional[int] = None,
**kwargs,
) -> list[BaseVectorStoreDriver.Entry]:
"""Queries the MongoDB collection for documents that match the provided query string.
"""Queries the MongoDB collection for documents that match the provided vector list.
Results can be customized based on parameters like count, namespace, inclusion of vectors, offset, and index.
"""
collection = self.get_collection()

# Using the embedding driver to convert the query string into a vector
vector = self.embedding_driver.embed_string(query)

count = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
offset = offset or 0

Expand Down Expand Up @@ -171,6 +168,26 @@ def query(
for doc in collection.aggregate(pipeline)
]

def query(
self,
query: str,
*,
count: Optional[int] = None,
namespace: Optional[str] = None,
include_vectors: bool = False,
offset: Optional[int] = None,
**kwargs,
) -> list[BaseVectorStoreDriver.Entry]:
"""Queries the MongoDB collection for documents that match the provided query string.
Results can be customized based on parameters like count, namespace, inclusion of vectors, offset, and index.
"""
# Using the embedding driver to convert the query string into a vector
vector = self.embedding_driver.embed_string(query)
return self.query_vector(

Check warning on line 187 in griptape/drivers/vector/mongodb_atlas_vector_store_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/vector/mongodb_atlas_vector_store_driver.py#L186-L187

Added lines #L186 - L187 were not covered by tests
vector, count=count, namespace=namespace, include_vectors=include_vectors, offset=offset, **kwargs
)

def delete_vector(self, vector_id: str) -> None:
"""Deletes the vector from the collection."""
collection = self.get_collection()
Expand Down
36 changes: 32 additions & 4 deletions griptape/drivers/vector/opensearch_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,9 @@ def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorSto
for hit in response["hits"]["hits"]
]

def query(
def query_vector(
self,
query: str,
vector: list[float],
*,
count: Optional[int] = None,
namespace: Optional[str] = None,
Expand All @@ -130,15 +130,14 @@ def query(
field_name: str = "vector",
**kwargs,
) -> list[BaseVectorStoreDriver.Entry]:
"""Performs a nearest neighbor search on OpenSearch to find vectors similar to the provided query string.
"""Performs a nearest neighbor search on OpenSearch to find vectors similar to the provided vector list.
Results can be limited using the count parameter and optionally filtered by a namespace.
Returns:
A list of BaseVectorStoreDriver.Entry objects, each encapsulating the retrieved vector, its similarity score, metadata, and namespace.
"""
count = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
vector = self.embedding_driver.embed_string(query)
# Base k-NN query
query_body = {"size": count, "query": {"knn": {field_name: {"vector": vector, "k": count}}}}

Expand All @@ -165,5 +164,34 @@ def query(
for hit in response["hits"]["hits"]
]

def query(
self,
query: str,
*,
count: Optional[int] = None,
namespace: Optional[str] = None,
include_vectors: bool = False,
include_metadata: bool = True,
field_name: str = "vector",
**kwargs,
) -> list[BaseVectorStoreDriver.Entry]:
"""Performs a nearest neighbor search on OpenSearch to find vectors similar to the provided query string.
Results can be limited using the count parameter and optionally filtered by a namespace.
Returns:
A list of BaseVectorStoreDriver.Entry objects, each encapsulating the retrieved vector, its similarity score, metadata, and namespace.
"""
vector = self.embedding_driver.embed_string(query)
return self.query_vector(

Check warning on line 186 in griptape/drivers/vector/opensearch_vector_store_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/vector/opensearch_vector_store_driver.py#L185-L186

Added lines #L185 - L186 were not covered by tests
vector,
count=count,
namespace=namespace,
include_vectors=include_vectors,
include_metadata=include_metadata,
field_name=field_name,
**kwargs,
)

def delete_vector(self, vector_id: str) -> NoReturn:
raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")
Loading

0 comments on commit a63760b

Please sign in to comment.