diff --git a/CHANGELOG.md b/CHANGELOG.md index 97c34314a..a8d6002e1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed ### Fixed +- Parameter `count` for `QdrantVectorStoreDriver.query` now optional as per documentation. ## [0.28.2] - 2024-07-12 ### Fixed diff --git a/docs/griptape-framework/drivers/vector-store-drivers.md b/docs/griptape-framework/drivers/vector-store-drivers.md index ea2b72a56..c1f54ab2e 100644 --- a/docs/griptape-framework/drivers/vector-store-drivers.md +++ b/docs/griptape-framework/drivers/vector-store-drivers.md @@ -24,7 +24,6 @@ The [LocalVectorStoreDriver](../../reference/griptape/drivers/vector/local_vecto ```python import os -from griptape.artifacts import BaseArtifact from griptape.drivers import LocalVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader @@ -40,16 +39,11 @@ artifacts = WebLoader(max_tokens=100).load("https://www.griptape.ai") # Upsert Artifacts into the Vector Store Driver [vector_store_driver.upsert_text_artifact(a, namespace="griptape") for a in artifacts] -results = vector_store_driver.query( - "creativity", - count=3, - namespace="griptape" -) +results = vector_store_driver.query(query="What is griptape?") values = [r.to_artifact().value for r in results] print("\n\n".join(values)) - ``` ### Griptape Cloud Knowledge Base @@ -58,7 +52,6 @@ The [GriptapeCloudKnowledgeBaseVectorStoreDriver](../../reference/griptape/drive ```python import os -from griptape.artifacts import BaseArtifact from griptape.drivers import GriptapeCloudKnowledgeBaseVectorStoreDriver @@ -68,12 +61,11 @@ gt_cloud_knowledge_base_id = os.environ["GRIPTAPE_CLOUD_KB_ID"] vector_store_driver = GriptapeCloudKnowledgeBaseVectorStoreDriver(api_key=gt_cloud_api_key, knowledge_base_id=gt_cloud_knowledge_base_id) -results =vector_store_driver.query(query="What is griptape?") +results = vector_store_driver.query(query="What is griptape?") values = [r.to_artifact().value for r in results] print("\n\n".join(values)) - ``` ### Pinecone @@ -86,31 +78,10 @@ The [PineconeVectorStoreDriver](../../reference/griptape/drivers/vector/pinecone Here is an example of how the Driver can be used to load and query information in a Pinecone cluster: ```python -import os -import hashlib -import json -from urllib.request import urlopen +import os from griptape.drivers import PineconeVectorStoreDriver, OpenAiEmbeddingDriver +from griptape.loaders import WebLoader -def load_data(driver: PineconeVectorStoreDriver) -> None: - response = urlopen( - "https://raw.githubusercontent.com/wedeploy-examples/" - "supermarket-web-example/master/products.json" - ) - - for product in json.loads(response.read()): - driver.upsert_text( - product["description"], - vector_id=hashlib.md5(product["title"].encode()).hexdigest(), - meta={ - "title": product["title"], - "description": product["description"], - "type": product["type"], - "price": product["price"], - "rating": product["rating"], - }, - namespace="supermarket-products", - ) # Initialize an Embedding Driver embedding_driver = OpenAiEmbeddingDriver(api_key=os.environ["OPENAI_API_KEY"]) @@ -118,18 +89,17 @@ embedding_driver = OpenAiEmbeddingDriver(api_key=os.environ["OPENAI_API_KEY"]) vector_store_driver = PineconeVectorStoreDriver( api_key=os.environ["PINECONE_API_KEY"], environment=os.environ["PINECONE_ENVIRONMENT"], - index_name=os.environ['PINECONE_INDEX_NAME'], + index_name=os.environ["PINECONE_INDEX_NAME"], embedding_driver=embedding_driver, ) -load_data(vector_store_driver) +# Load Artifacts from the web +artifacts = WebLoader(max_tokens=100).load("https://www.griptape.ai") -results = vector_store_driver.query( - "fruit", - count=3, - filter={"price": {"$lte": 15}, "rating": {"$gte": 4}}, - namespace="supermarket-products", -) +# Upsert Artifacts into the Vector Store Driver +[vector_store_driver.upsert_text_artifact(a, namespace="griptape") for a in artifacts] + +results = vector_store_driver.query(query="What is griptape?") values = [r.to_artifact().value for r in results] @@ -175,7 +145,7 @@ vector_store_driver.upsert_text_artifacts( } ) -results =vector_store_driver.query(query="What is griptape?") +results = vector_store_driver.query(query="What is griptape?") values = [r.to_artifact().value for r in results] @@ -227,7 +197,7 @@ vector_store_driver.upsert_text_artifacts( } ) -results =vector_store_driver.query(query="What is griptape?") +results = vector_store_driver.query(query="What is griptape?") values = [r.to_artifact().value for r in results] @@ -298,7 +268,7 @@ vector_store_driver.upsert_text_artifacts( } ) -results =vector_store_driver.query(query="What is griptape?") +results = vector_store_driver.query(query="What is griptape?") values = [r.to_artifact().value for r in results] @@ -341,7 +311,7 @@ vector_store_driver.upsert_text_artifacts( } ) -results =vector_store_driver.query(query="What is griptape?") +results = vector_store_driver.query(query="What is griptape?") values = [r.to_artifact().value for r in results] @@ -388,7 +358,7 @@ vector_store_driver.upsert_text_artifacts( } ) -results =vector_store_driver.query(query="What is griptape?") +results = vector_store_driver.query(query="What is griptape?") values = [r.to_artifact().value for r in results] @@ -450,7 +420,7 @@ vector_store_driver.upsert_text_artifacts( } ) -results =vector_store_driver.query(query="What is griptape?") +results = vector_store_driver.query(query="What is griptape?") values = [r.to_artifact().value for r in results] @@ -468,54 +438,40 @@ Here is an example of how the Driver can be used to query information in a Qdran ```python import os -from griptape.drivers import QdrantVectorStoreDriver, HuggingFaceHubEmbeddingDriver -from griptape.tokenizers import HuggingFaceTokenizer +from griptape.drivers import QdrantVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader # Set up environment variables -embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2" host = os.environ["QDRANT_CLUSTER_ENDPOINT"] -huggingface_token = os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"] +api_key = os.environ["QDRANT_CLUSTER_API_KEY"] -# Initialize HuggingFace Embedding Driver -embedding_driver = HuggingFaceHubEmbeddingDriver( - api_token=huggingface_token, - model=embedding_model_name, - tokenizer=HuggingFaceTokenizer(model=embedding_model_name, max_output_tokens=512), -) +# Initialize an Embedding Driver. +embedding_driver = OpenAiEmbeddingDriver(api_key=os.environ["OPENAI_API_KEY"]) -# Initialize Qdrant Vector Store Driver vector_store_driver = QdrantVectorStoreDriver( url=host, collection_name="griptape", content_payload_key="content", embedding_driver=embedding_driver, - api_key=os.environ["QDRANT_CLUSTER_API_KEY"], + api_key=api_key, ) # Load Artifacts from the web artifacts = WebLoader().load("https://www.griptape.ai") -# Encode text to get embeddings -embeddings = embedding_driver.embed_text_artifact(artifacts[0]) - # Recreate Qdrant collection vector_store_driver.client.recreate_collection( collection_name=vector_store_driver.collection_name, vectors_config={ - "size": len(embeddings), + "size": 1536, "distance": vector_store_driver.distance }, ) -# Upsert vector into Qdrant -vector_store_driver.upsert_vector( - vector=embeddings, - vector_id=str(artifacts[0].id), - content=artifacts[0].value -) +# Upsert Artifacts into the Vector Store Driver +[vector_store_driver.upsert_text_artifact(a, namespace="griptape") for a in artifacts] -results =vector_store_driver.query(query="What is griptape?") +results = vector_store_driver.query(query="What is griptape?") values = [r.to_artifact().value for r in results] diff --git a/griptape/drivers/vector/qdrant_vector_store_driver.py b/griptape/drivers/vector/qdrant_vector_store_driver.py index 34345b6df..a5162f754 100644 --- a/griptape/drivers/vector/qdrant_vector_store_driver.py +++ b/griptape/drivers/vector/qdrant_vector_store_driver.py @@ -106,7 +106,9 @@ def query( query_vector = self.embedding_driver.embed_string(query) # Create a search request - results = self.client.search(collection_name=self.collection_name, query_vector=query_vector, limit=count) + request = {"collection_name": self.collection_name, "query_vector": query_vector, "limit": count} + request = {k: v for k, v in request.items() if v is not None} + results = self.client.search(**request) # Convert results to QueryResult objects query_results = [