Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Qdrant query count not optional #972

Merged
merged 12 commits into from
Jul 13, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed

### Fixed
- Parameter `count` for `QdrantVectorStoreDriver.query` now optional as per documentation.

## [0.28.2] - 2024-07-12
### Fixed
Expand Down
96 changes: 26 additions & 70 deletions docs/griptape-framework/drivers/vector-store-drivers.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ The [LocalVectorStoreDriver](../../reference/griptape/drivers/vector/local_vecto

```python
import os
from griptape.artifacts import BaseArtifact
from griptape.drivers import LocalVectorStoreDriver, OpenAiEmbeddingDriver
from griptape.loaders import WebLoader

Expand All @@ -40,16 +39,11 @@ artifacts = WebLoader(max_tokens=100).load("https://www.griptape.ai")
# Upsert Artifacts into the Vector Store Driver
[vector_store_driver.upsert_text_artifact(a, namespace="griptape") for a in artifacts]

results = vector_store_driver.query(
"creativity",
count=3,
namespace="griptape"
)
results = vector_store_driver.query(query="What is griptape?")

values = [r.to_artifact().value for r in results]

print("\n\n".join(values))

```

### Griptape Cloud Knowledge Base
Expand All @@ -58,7 +52,6 @@ The [GriptapeCloudKnowledgeBaseVectorStoreDriver](../../reference/griptape/drive

```python
import os
from griptape.artifacts import BaseArtifact
from griptape.drivers import GriptapeCloudKnowledgeBaseVectorStoreDriver


Expand All @@ -68,12 +61,11 @@ gt_cloud_knowledge_base_id = os.environ["GRIPTAPE_CLOUD_KB_ID"]

vector_store_driver = GriptapeCloudKnowledgeBaseVectorStoreDriver(api_key=gt_cloud_api_key, knowledge_base_id=gt_cloud_knowledge_base_id)

results =vector_store_driver.query(query="What is griptape?")
results = vector_store_driver.query(query="What is griptape?")

values = [r.to_artifact().value for r in results]

print("\n\n".join(values))

```

### Pinecone
Expand All @@ -86,50 +78,28 @@ The [PineconeVectorStoreDriver](../../reference/griptape/drivers/vector/pinecone
Here is an example of how the Driver can be used to load and query information in a Pinecone cluster:

```python
import os
import hashlib
import json
from urllib.request import urlopen
import os
from griptape.drivers import PineconeVectorStoreDriver, OpenAiEmbeddingDriver
from griptape.loaders import WebLoader

def load_data(driver: PineconeVectorStoreDriver) -> None:
response = urlopen(
"https://raw.githubusercontent.com/wedeploy-examples/"
"supermarket-web-example/master/products.json"
)

for product in json.loads(response.read()):
driver.upsert_text(
product["description"],
vector_id=hashlib.md5(product["title"].encode()).hexdigest(),
meta={
"title": product["title"],
"description": product["description"],
"type": product["type"],
"price": product["price"],
"rating": product["rating"],
},
namespace="supermarket-products",
)

# Initialize an Embedding Driver
embedding_driver = OpenAiEmbeddingDriver(api_key=os.environ["OPENAI_API_KEY"])

vector_store_driver = PineconeVectorStoreDriver(
api_key=os.environ["PINECONE_API_KEY"],
environment=os.environ["PINECONE_ENVIRONMENT"],
index_name=os.environ['PINECONE_INDEX_NAME'],
index_name=os.environ["PINECONE_INDEX_NAME"],
embedding_driver=embedding_driver,
)

load_data(vector_store_driver)
# Load Artifacts from the web
artifacts = WebLoader(max_tokens=100).load("https://www.griptape.ai")

results = vector_store_driver.query(
"fruit",
count=3,
filter={"price": {"$lte": 15}, "rating": {"$gte": 4}},
namespace="supermarket-products",
)
# Upsert Artifacts into the Vector Store Driver
[vector_store_driver.upsert_text_artifact(a, namespace="griptape") for a in artifacts]

results = vector_store_driver.query(query="What is griptape?")

values = [r.to_artifact().value for r in results]

Expand Down Expand Up @@ -175,7 +145,7 @@ vector_store_driver.upsert_text_artifacts(
}
)

results =vector_store_driver.query(query="What is griptape?")
results = vector_store_driver.query(query="What is griptape?")

values = [r.to_artifact().value for r in results]

Expand Down Expand Up @@ -227,7 +197,7 @@ vector_store_driver.upsert_text_artifacts(
}
)

results =vector_store_driver.query(query="What is griptape?")
results = vector_store_driver.query(query="What is griptape?")

values = [r.to_artifact().value for r in results]

Expand Down Expand Up @@ -298,7 +268,7 @@ vector_store_driver.upsert_text_artifacts(
}
)

results =vector_store_driver.query(query="What is griptape?")
results = vector_store_driver.query(query="What is griptape?")

values = [r.to_artifact().value for r in results]

Expand Down Expand Up @@ -341,7 +311,7 @@ vector_store_driver.upsert_text_artifacts(
}
)

results =vector_store_driver.query(query="What is griptape?")
results = vector_store_driver.query(query="What is griptape?")

values = [r.to_artifact().value for r in results]

Expand Down Expand Up @@ -388,7 +358,7 @@ vector_store_driver.upsert_text_artifacts(
}
)

results =vector_store_driver.query(query="What is griptape?")
results = vector_store_driver.query(query="What is griptape?")

values = [r.to_artifact().value for r in results]

Expand Down Expand Up @@ -450,7 +420,7 @@ vector_store_driver.upsert_text_artifacts(
}
)

results =vector_store_driver.query(query="What is griptape?")
results = vector_store_driver.query(query="What is griptape?")

values = [r.to_artifact().value for r in results]

Expand All @@ -468,54 +438,40 @@ Here is an example of how the Driver can be used to query information in a Qdran

```python
import os
from griptape.drivers import QdrantVectorStoreDriver, HuggingFaceHubEmbeddingDriver
from griptape.tokenizers import HuggingFaceTokenizer
from griptape.drivers import QdrantVectorStoreDriver, OpenAiEmbeddingDriver
from griptape.loaders import WebLoader

# Set up environment variables
embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2"
host = os.environ["QDRANT_CLUSTER_ENDPOINT"]
huggingface_token = os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"]
api_key = os.environ["QDRANT_CLUSTER_API_KEY"]

# Initialize HuggingFace Embedding Driver
embedding_driver = HuggingFaceHubEmbeddingDriver(
api_token=huggingface_token,
model=embedding_model_name,
tokenizer=HuggingFaceTokenizer(model=embedding_model_name, max_output_tokens=512),
)
# Initialize an Embedding Driver.
embedding_driver = OpenAiEmbeddingDriver(api_key=os.environ["OPENAI_API_KEY"])

# Initialize Qdrant Vector Store Driver
vector_store_driver = QdrantVectorStoreDriver(
url=host,
collection_name="griptape",
content_payload_key="content",
embedding_driver=embedding_driver,
api_key=os.environ["QDRANT_CLUSTER_API_KEY"],
api_key=api_key,
)

# Load Artifacts from the web
artifacts = WebLoader().load("https://www.griptape.ai")

# Encode text to get embeddings
embeddings = embedding_driver.embed_text_artifact(artifacts[0])

# Recreate Qdrant collection
vector_store_driver.client.recreate_collection(
collection_name=vector_store_driver.collection_name,
vectors_config={
"size": len(embeddings),
"size": 1536,
"distance": vector_store_driver.distance
},
)

# Upsert vector into Qdrant
vector_store_driver.upsert_vector(
vector=embeddings,
vector_id=str(artifacts[0].id),
content=artifacts[0].value
)
# Upsert Artifacts into the Vector Store Driver
[vector_store_driver.upsert_text_artifact(a, namespace="griptape") for a in artifacts]

results =vector_store_driver.query(query="What is griptape?")
results = vector_store_driver.query(query="What is griptape?")

values = [r.to_artifact().value for r in results]

Expand Down
4 changes: 3 additions & 1 deletion griptape/drivers/vector/qdrant_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,9 @@ def query(
query_vector = self.embedding_driver.embed_string(query)

# Create a search request
results = self.client.search(collection_name=self.collection_name, query_vector=query_vector, limit=count)
request = {"collection_name": self.collection_name, "query_vector": query_vector, "limit": count}
request = {k: v for k, v in request.items() if v is not None}
results = self.client.search(**request)

# Convert results to QueryResult objects
query_results = [
Expand Down
Loading