From 32c2ab3ede530f0c21cfdd35a2ff99e8ff49a6b9 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 24 Jul 2024 13:44:18 -0700 Subject: [PATCH] Remove image query drivers/task/engine --- CHANGELOG.md | 4 + docs/griptape-framework/data/index.md | 2 - .../drivers/image-query-drivers.md | 162 ------------------ .../engines/image-query-engines.md | 30 ---- docs/griptape-framework/structures/config.md | 6 +- docs/griptape-framework/structures/tasks.md | 42 ----- .../official-tools/image-query-client.md | 9 +- griptape/common/prompt_stack/prompt_stack.py | 2 +- .../config/amazon_bedrock_structure_config.py | 14 -- griptape/config/anthropic_structure_config.py | 7 - .../config/azure_openai_structure_config.py | 17 -- griptape/config/base_structure_config.py | 3 - griptape/config/openai_structure_config.py | 7 - griptape/config/structure_config.py | 7 - griptape/drivers/__init__.py | 20 --- griptape/drivers/image_query/__init__.py | 0 .../amazon_bedrock_image_query_driver.py | 43 ----- .../anthropic_image_query_driver.py | 64 ------- .../azure_openai_image_query_driver.py | 50 ------ .../image_query/base_image_query_driver.py | 41 ----- .../base_multi_model_image_query_driver.py | 23 --- .../image_query/dummy_image_query_driver.py | 20 --- .../image_query/openai_image_query_driver.py | 55 ------ .../drivers/image_query_model/__init__.py | 0 .../base_image_query_model_driver.py | 20 --- ...bedrock_claude_image_query_model_driver.py | 40 ----- griptape/engines/__init__.py | 2 - griptape/engines/image_query/__init__.py | 0 .../engines/image_query/image_query_engine.py | 17 -- griptape/events/__init__.py | 4 - griptape/events/base_image_query_event.py | 9 - griptape/events/finish_image_query_event.py | 8 - griptape/events/start_image_query_event.py | 11 -- griptape/schemas/base_schema.py | 2 - griptape/tasks/__init__.py | 2 - griptape/tasks/image_query_task.py | 90 ---------- griptape/tools/image_query_client/tool.py | 47 ++++- mkdocs.yml | 1 - 
tests/mocks/mock_image_query_driver.py | 16 -- tests/mocks/mock_structure_config.py | 4 - .../test_amazon_bedrock_structure_config.py | 12 -- .../config/test_anthropic_structure_config.py | 5 - .../test_azure_openai_structure_config.py | 11 -- .../config/test_cohere_structure_config.py | 1 - .../config/test_google_structure_config.py | 1 - .../config/test_openai_structure_config.py | 9 - tests/unit/config/test_structure_config.py | 2 - tests/unit/drivers/image_query/__init__.py | 0 .../test_amazon_bedrock_image_query_driver.py | 52 ------ .../test_anthropic_image_query_driver.py | 97 ----------- .../test_azure_openai_image_query_driver.py | 70 -------- .../test_base_image_query_driver.py | 26 --- .../test_dummy_image_query_driver.py | 18 -- .../test_openai_image_query_driver.py | 61 ------- ...bedrock_claude_image_query_model_driver.py | 42 ----- tests/unit/tasks/test_image_query_task.py | 83 --------- 56 files changed, 49 insertions(+), 1342 deletions(-) delete mode 100644 docs/griptape-framework/drivers/image-query-drivers.md delete mode 100644 docs/griptape-framework/engines/image-query-engines.md delete mode 100644 griptape/drivers/image_query/__init__.py delete mode 100644 griptape/drivers/image_query/amazon_bedrock_image_query_driver.py delete mode 100644 griptape/drivers/image_query/anthropic_image_query_driver.py delete mode 100644 griptape/drivers/image_query/azure_openai_image_query_driver.py delete mode 100644 griptape/drivers/image_query/base_image_query_driver.py delete mode 100644 griptape/drivers/image_query/base_multi_model_image_query_driver.py delete mode 100644 griptape/drivers/image_query/dummy_image_query_driver.py delete mode 100644 griptape/drivers/image_query/openai_image_query_driver.py delete mode 100644 griptape/drivers/image_query_model/__init__.py delete mode 100644 griptape/drivers/image_query_model/base_image_query_model_driver.py delete mode 100644 griptape/drivers/image_query_model/bedrock_claude_image_query_model_driver.py delete 
mode 100644 griptape/engines/image_query/__init__.py delete mode 100644 griptape/engines/image_query/image_query_engine.py delete mode 100644 griptape/events/base_image_query_event.py delete mode 100644 griptape/events/finish_image_query_event.py delete mode 100644 griptape/events/start_image_query_event.py delete mode 100644 griptape/tasks/image_query_task.py delete mode 100644 tests/mocks/mock_image_query_driver.py delete mode 100644 tests/unit/drivers/image_query/__init__.py delete mode 100644 tests/unit/drivers/image_query/test_amazon_bedrock_image_query_driver.py delete mode 100644 tests/unit/drivers/image_query/test_anthropic_image_query_driver.py delete mode 100644 tests/unit/drivers/image_query/test_azure_openai_image_query_driver.py delete mode 100644 tests/unit/drivers/image_query/test_base_image_query_driver.py delete mode 100644 tests/unit/drivers/image_query/test_dummy_image_query_driver.py delete mode 100644 tests/unit/drivers/image_query/test_openai_image_query_driver.py delete mode 100644 tests/unit/drivers/image_query_models/test_bedrock_claude_image_query_model_driver.py delete mode 100644 tests/unit/tasks/test_image_query_task.py diff --git a/CHANGELOG.md b/CHANGELOG.md index efe70904d..153708e06 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -49,6 +49,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: Renamed `drivers-vector-postgresql` extra to `drivers-vector-pgvector`. - **BREAKING**: Update `marqo` dependency to `^3.7.0`. - **BREAKING**: Removed `drivers-sql-postgresql` extra. Use `drivers-sql` extra and install necessary drivers (i.e. `psycopg2`) separately. +- **BREAKING**: Removed `ImageQueryDriver` in favor of using `PromptDriver` with `TextArtifact` and `ImageArtifact` inputs. +- **BREAKING**: Removed `ImageQueryEngine`, in favor of using `PromptDriver` directly. +- **BREAKING**: Removed `ImageQueryTask`, in favor of `PromptTask` with `TextArtifact` and `ImageArtifact` inputs. 
+- **BREAKING**: `ImageQueryClient` now takes a `PromptDriver` instead of an `ImageQueryEngine`. - Removed unnecessary `sqlalchemy-redshift` dependency in `drivers-sql-amazon-redshift` extra. - Removed unnecessary `transformers` dependency in `drivers-prompt-huggingface` extra. - Removed unnecessary `huggingface-hub` dependency in `drivers-prompt-huggingface-pipeline` extra. diff --git a/docs/griptape-framework/data/index.md b/docs/griptape-framework/data/index.md index 3e4359737..ce6fad56a 100644 --- a/docs/griptape-framework/data/index.md +++ b/docs/griptape-framework/data/index.md @@ -18,8 +18,6 @@ Griptape provides several abstractions for working with data. [Extraction Engines](../engines/extraction-engines.md) are used for extracting structured content. -[Image Query Engines](../engines/image-query-engines.md) are used for querying images with text. - [Image Generation Engines](../engines/image-generation-engines.md) are used for generating images. [Summary Engines](../engines/summary-engines.md) are used for summarizing text content. diff --git a/docs/griptape-framework/drivers/image-query-drivers.md b/docs/griptape-framework/drivers/image-query-drivers.md deleted file mode 100644 index 04e3ebaee..000000000 --- a/docs/griptape-framework/drivers/image-query-drivers.md +++ /dev/null @@ -1,162 +0,0 @@ ---- -search: - boost: 2 ---- - -## Overview - -Image Query Drivers are used by [Image Query Engines](../engines/image-query-engines.md) to execute natural language queries on the contents of images. You can specify the provider and model used to query the image by providing the Engine with a particular Image Query Driver. - -!!! info - All Image Query Drivers default to a `max_tokens` of 256. It is recommended that you set this value to correspond to the desired response length. - -## Image Query Drivers - -### Anthropic - -!!! 
info - To tune `max_tokens`, see [Anthropic's documentation on image tokens](https://docs.anthropic.com/claude/docs/vision#image-costs) for more information on how to relate token count to response length. - -The [AnthropicImageQueryDriver](../../reference/griptape/drivers/image_query/anthropic_image_query_driver.md) is used to query images using Anthropic's Claude 3 multi-modal model. Here is an example of how to use it: - -```python -from griptape.drivers import AnthropicImageQueryDriver -from griptape.engines import ImageQueryEngine -from griptape.loaders import ImageLoader - -driver = AnthropicImageQueryDriver( - model="claude-3-sonnet-20240229", - max_tokens=1024, -) - -engine = ImageQueryEngine( - image_query_driver=driver, -) - -with open("tests/resources/mountain.png", "rb") as f: - image_artifact = ImageLoader().load(f.read()) - -engine.run("Describe the weather in the image", [image_artifact]) -``` - -You can also specify multiple images with a single text prompt. This applies the same text prompt to all images specified, up to a max of 20. However, you will still receive one text response from the model currently. - -```python -from griptape.drivers import AnthropicImageQueryDriver -from griptape.engines import ImageQueryEngine -from griptape.loaders import ImageLoader - -driver = AnthropicImageQueryDriver( - model="claude-3-sonnet-20240229", - max_tokens=1024, -) - -engine = ImageQueryEngine( - image_query_driver=driver, -) - -with open("tests/resources/mountain.png", "rb") as f: - image_artifact1 = ImageLoader().load(f.read()) - -with open("tests/resources/cow.png", "rb") as f: - image_artifact2 = ImageLoader().load(f.read()) - -result = engine.run("Describe the weather in the image", [image_artifact1, image_artifact2]) - -print(result) -``` - -### OpenAI - -!!! info - While the `max_tokens` field is optional, it is recommended to set this to a value that corresponds to the desired response length. 
Without an explicit value, the model will default to very short responses. See [OpenAI's documentation](https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them) for more information on how to relate token count to response length. - -The [OpenAiVisionImageQueryDriver](../../reference/griptape/drivers/image_query/openai_image_query_driver.md) is used to query images using the OpenAI Vision API. Here is an example of how to use it: - -```python -from griptape.drivers import OpenAiImageQueryDriver -from griptape.engines import ImageQueryEngine -from griptape.loaders import ImageLoader - -driver = OpenAiImageQueryDriver( - model="gpt-4o", - max_tokens=256, -) - -engine = ImageQueryEngine( - image_query_driver=driver, -) - -with open("tests/resources/mountain.png", "rb") as f: - image_artifact = ImageLoader().load(f.read()) - -engine.run("Describe the weather in the image", [image_artifact]) -``` - -### Azure OpenAI - -!!! info - In order to use the `gpt-4-vision-preview` model on Azure OpenAI, the `gpt-4` model must be deployed with the version set to `vision-preview`. More information can be found in the [Azure OpenAI documentation](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/gpt-with-vision). - -The [AzureOpenAiVisionImageQueryDriver](../../reference/griptape/drivers/image_query/azure_openai_image_query_driver.md) is used to query images using the Azure OpenAI Vision API. 
Here is an example of how to use it: - -```python -import os -from griptape.drivers import AzureOpenAiImageQueryDriver -from griptape.engines import ImageQueryEngine -from griptape.loaders import ImageLoader - -driver = AzureOpenAiImageQueryDriver( - azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_2"], - api_key=os.environ["AZURE_OPENAI_API_KEY_2"], - model="gpt-4o", - azure_deployment="gpt-4o", - max_tokens=256, -) - -engine = ImageQueryEngine( - image_query_driver=driver, -) - -with open("tests/resources/mountain.png", "rb") as f: - image_artifact = ImageLoader().load(f.read()) - -engine.run("Describe the weather in the image", [image_artifact]) -``` - -### Amazon Bedrock - -The [Amazon Bedrock Image Query Driver](../../reference/griptape/drivers/image_query/amazon_bedrock_image_query_driver.md) provides multi-model access to image query models hosted by Amazon Bedrock. This Driver manages API calls to the Bedrock API, while the specific Model Drivers below format the API requests and parse the responses. - -#### Claude - -The [BedrockClaudeImageQueryModelDriver](../../reference/griptape/drivers/image_query_model/bedrock_claude_image_query_model_driver.md) provides support for Claude models hosted by Bedrock. 
- -```python -from griptape.drivers import AmazonBedrockImageQueryDriver, BedrockClaudeImageQueryModelDriver -from griptape.engines import ImageQueryEngine -from griptape.loaders import ImageLoader -import boto3 - -session = boto3.Session( - region_name="us-west-2" -) - -driver = AmazonBedrockImageQueryDriver( - image_query_model_driver=BedrockClaudeImageQueryModelDriver(), - model="anthropic.claude-3-sonnet-20240229-v1:0", - session=session -) - -engine = ImageQueryEngine( - image_query_driver=driver -) - -with open("tests/resources/mountain.png", "rb") as f: - image_artifact = ImageLoader().load(f.read()) - - -result = engine.run("Describe the weather in the image", [image_artifact]) - -print(result) -``` diff --git a/docs/griptape-framework/engines/image-query-engines.md b/docs/griptape-framework/engines/image-query-engines.md deleted file mode 100644 index 1db247cb7..000000000 --- a/docs/griptape-framework/engines/image-query-engines.md +++ /dev/null @@ -1,30 +0,0 @@ ---- -search: - boost: 2 ---- - -## Image Query Engines - -The [Image Query Engine](../../reference/griptape/engines/image_query/image_query_engine.md) allows you to perform natural language queries on the contents of images. You can specify the provider and model used to query the image by providing the Engine with a particular [Image Query Driver](../drivers/image-query-drivers.md). - -All Image Query Drivers default to a `max_tokens` of 256. You can tune this value based on your use case and the [Image Query Driver](../drivers/image-query-drivers.md) you are providing. 
- -```python -from griptape.drivers import OpenAiImageQueryDriver -from griptape.engines import ImageQueryEngine -from griptape.loaders import ImageLoader - -driver = OpenAiImageQueryDriver( - model="gpt-4o", - max_tokens=256 -) - -engine = ImageQueryEngine( - image_query_driver=driver -) - -with open("tests/resources/mountain.png", "rb") as f: - image_artifact = ImageLoader().load(f.read()) - -engine.run("Describe the weather in the image", [image_artifact]) -``` diff --git a/docs/griptape-framework/structures/config.md b/docs/griptape-framework/structures/config.md index 3f510eb86..8213142bc 100644 --- a/docs/griptape-framework/structures/config.md +++ b/docs/griptape-framework/structures/config.md @@ -41,11 +41,7 @@ agent = Agent( config=AzureOpenAiStructureConfig( azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_3"], api_key=os.environ["AZURE_OPENAI_API_KEY_3"] - ).merge_config({ - "image_query_driver": { - "azure_deployment": "gpt-4o", - }, - }), + ) ) ``` diff --git a/docs/griptape-framework/structures/tasks.md b/docs/griptape-framework/structures/tasks.md index 6d479578b..c1781fc3d 100644 --- a/docs/griptape-framework/structures/tasks.md +++ b/docs/griptape-framework/structures/tasks.md @@ -655,48 +655,6 @@ pipeline.add_task( pipeline.run("An image of a mountain shrouded by clouds") ``` -## Image Query Task - -The [Image Query Task](../../reference/griptape/tasks/image_query_task.md) performs a natural language query on one or more input images. This Task uses an [Image Query Engine](../engines/image-query-engines.md) configured with an [Image Query Driver](../drivers/image-query-drivers.md) to perform the query. The functionality provided by this Task depend on the capabilities of the model provided by the Driver. - -This Task accepts two inputs: a query (represented by either a string or a [Text Artifact](../data/artifacts.md#textartifact)) and a list of [Image Artifacts](../data/artifacts.md#imageartifact) or a Callable returning these two values. 
- -```python -from griptape.engines import ImageQueryEngine -from griptape.drivers import OpenAiImageQueryDriver -from griptape.tasks import ImageQueryTask -from griptape.loaders import ImageLoader -from griptape.structures import Pipeline - -# Create a driver configured to use OpenAI's GPT-4 Vision model. -driver = OpenAiImageQueryDriver( - model="gpt-4o", - max_tokens=100, -) - -# Create an engine configured to use the driver. -engine = ImageQueryEngine( - image_query_driver=driver, -) - -# Load the input image artifact. -with open("tests/resources/mountain.png", "rb") as f: - image_artifact = ImageLoader().load(f.read()) - -# Instantiate a pipeline. -pipeline = Pipeline() - -# Add an ImageQueryTask to the pipeline. -pipeline.add_task( - ImageQueryTask( - input=("{{ args[0] }}", [image_artifact]), - image_query_engine=engine, - ) -) - -pipeline.run("Describe the weather in the image") -``` - ## Structure Run Task The [Structure Run Task](../../reference/griptape/tasks/structure_run_task.md) runs another Structure with a given input. This Task is useful for orchestrating multiple specialized Structures in a single run. Note that the input to the Task is a tuple of arguments that will be passed to the Structure. diff --git a/docs/griptape-tools/official-tools/image-query-client.md b/docs/griptape-tools/official-tools/image-query-client.md index a1044fb91..483d7cd68 100644 --- a/docs/griptape-tools/official-tools/image-query-client.md +++ b/docs/griptape-tools/official-tools/image-query-client.md @@ -5,7 +5,7 @@ This tool allows Agents to execute natural language queries on the contents of i ```python from griptape.structures import Agent from griptape.tools import ImageQueryClient -from griptape.drivers import OpenAiImageQueryDriver +from griptape.drivers import OpenAiChatPromptDriver from griptape.engines import ImageQueryEngine # Create an Image Query Driver. 
@@ -13,14 +13,9 @@ driver = OpenAiImageQueryDriver( model="gpt-4o" ) -# Create an Image Query Engine configured to use the driver. -engine = ImageQueryEngine( -    image_query_driver=driver, -) - -# Create an Image Query Client configured to use the engine. tool = ImageQueryClient( -    image_query_engine=engine, +    prompt_driver=driver, ) # Create an agent and provide the tool to it. diff --git a/griptape/common/prompt_stack/prompt_stack.py b/griptape/common/prompt_stack/prompt_stack.py index 3186dac89..a49780fdd 100644 --- a/griptape/common/prompt_stack/prompt_stack.py +++ b/griptape/common/prompt_stack/prompt_stack.py @@ -22,7 +22,7 @@ @define class PromptStack(SerializableMixin): - messages: list[Message] = field(factory=list, kw_only=True, metadata={"serializable": True}) + messages: list[Message] = field(factory=list, metadata={"serializable": True}) tools: list[BaseTool] = field(factory=list, kw_only=True) @property diff --git a/griptape/config/amazon_bedrock_structure_config.py b/griptape/config/amazon_bedrock_structure_config.py index 3ad7f8f48..786eead59 100644 --- a/griptape/config/amazon_bedrock_structure_config.py +++ b/griptape/config/amazon_bedrock_structure_config.py @@ -7,14 +7,12 @@ from griptape.config import StructureConfig from griptape.drivers import ( AmazonBedrockImageGenerationDriver, - AmazonBedrockImageQueryDriver, AmazonBedrockPromptDriver, AmazonBedrockTitanEmbeddingDriver, BaseEmbeddingDriver, BaseImageGenerationDriver, BasePromptDriver, BaseVectorStoreDriver, - BedrockClaudeImageQueryModelDriver, BedrockTitanImageGenerationModelDriver, LocalVectorStoreDriver, ) @@ -63,18 +61,6 @@ class AmazonBedrockStructureConfig(StructureConfig): kw_only=True, metadata={"serializable": True}, ) - image_query_driver: BaseImageGenerationDriver = field( - default=Factory( - lambda self: AmazonBedrockImageQueryDriver( - session=self.session, - model="anthropic.claude-3-5-sonnet-20240620-v1:0", - image_query_model_driver=BedrockClaudeImageQueryModelDriver(), - 
), - takes_self=True, - ), - kw_only=True, - metadata={"serializable": True}, - ) vector_store_driver: BaseVectorStoreDriver = field( default=Factory(lambda self: LocalVectorStoreDriver(embedding_driver=self.embedding_driver), takes_self=True), kw_only=True, diff --git a/griptape/config/anthropic_structure_config.py b/griptape/config/anthropic_structure_config.py index 1bb5bf49b..91b7653b3 100644 --- a/griptape/config/anthropic_structure_config.py +++ b/griptape/config/anthropic_structure_config.py @@ -2,10 +2,8 @@ from griptape.config import StructureConfig from griptape.drivers import ( - AnthropicImageQueryDriver, AnthropicPromptDriver, BaseEmbeddingDriver, - BaseImageQueryDriver, BasePromptDriver, BaseVectorStoreDriver, LocalVectorStoreDriver, @@ -32,8 +30,3 @@ class AnthropicStructureConfig(StructureConfig): kw_only=True, metadata={"serializable": True}, ) - image_query_driver: BaseImageQueryDriver = field( - default=Factory(lambda: AnthropicImageQueryDriver(model="claude-3-5-sonnet-20240620")), - kw_only=True, - metadata={"serializable": True}, - ) diff --git a/griptape/config/azure_openai_structure_config.py b/griptape/config/azure_openai_structure_config.py index ce0303e34..beafd94c8 100644 --- a/griptape/config/azure_openai_structure_config.py +++ b/griptape/config/azure_openai_structure_config.py @@ -9,10 +9,8 @@ AzureOpenAiChatPromptDriver, AzureOpenAiEmbeddingDriver, AzureOpenAiImageGenerationDriver, - AzureOpenAiImageQueryDriver, BaseEmbeddingDriver, BaseImageGenerationDriver, - BaseImageQueryDriver, BasePromptDriver, BaseVectorStoreDriver, LocalVectorStoreDriver, @@ -30,7 +28,6 @@ class AzureOpenAiStructureConfig(StructureConfig): api_key: An optional Azure API key. prompt_driver: An Azure OpenAI Chat Prompt Driver. image_generation_driver: An Azure OpenAI Image Generation Driver. - image_query_driver: An Azure OpenAI Vision Image Query Driver. embedding_driver: An Azure OpenAI Embedding Driver. vector_store_driver: A Local Vector Store Driver. 
""" @@ -72,20 +69,6 @@ class AzureOpenAiStructureConfig(StructureConfig): metadata={"serializable": True}, kw_only=True, ) - image_query_driver: BaseImageQueryDriver = field( - default=Factory( - lambda self: AzureOpenAiImageQueryDriver( - model="gpt-4o", - azure_endpoint=self.azure_endpoint, - api_key=self.api_key, - azure_ad_token=self.azure_ad_token, - azure_ad_token_provider=self.azure_ad_token_provider, - ), - takes_self=True, - ), - metadata={"serializable": True}, - kw_only=True, - ) embedding_driver: BaseEmbeddingDriver = field( default=Factory( lambda self: AzureOpenAiEmbeddingDriver( diff --git a/griptape/config/base_structure_config.py b/griptape/config/base_structure_config.py index 31949cd2f..3bf13e6a2 100644 --- a/griptape/config/base_structure_config.py +++ b/griptape/config/base_structure_config.py @@ -16,7 +16,6 @@ BaseConversationMemoryDriver, BaseEmbeddingDriver, BaseImageGenerationDriver, - BaseImageQueryDriver, BasePromptDriver, BaseTextToSpeechDriver, BaseVectorStoreDriver, @@ -28,7 +27,6 @@ class BaseStructureConfig(BaseConfig, ABC): prompt_driver: BasePromptDriver = field(kw_only=True, metadata={"serializable": True}) image_generation_driver: BaseImageGenerationDriver = field(kw_only=True, metadata={"serializable": True}) - image_query_driver: BaseImageQueryDriver = field(kw_only=True, metadata={"serializable": True}) embedding_driver: BaseEmbeddingDriver = field(kw_only=True, metadata={"serializable": True}) vector_store_driver: BaseVectorStoreDriver = field(kw_only=True, metadata={"serializable": True}) conversation_memory_driver: Optional[BaseConversationMemoryDriver] = field( @@ -47,7 +45,6 @@ def drivers(self) -> list: return [ self.prompt_driver, self.image_generation_driver, - self.image_query_driver, self.embedding_driver, self.vector_store_driver, self.conversation_memory_driver, diff --git a/griptape/config/openai_structure_config.py b/griptape/config/openai_structure_config.py index 63806dfc9..ef13277a8 100644 --- 
a/griptape/config/openai_structure_config.py +++ b/griptape/config/openai_structure_config.py @@ -5,7 +5,6 @@ BaseAudioTranscriptionDriver, BaseEmbeddingDriver, BaseImageGenerationDriver, - BaseImageQueryDriver, BasePromptDriver, BaseTextToSpeechDriver, BaseVectorStoreDriver, @@ -14,7 +13,6 @@ OpenAiChatPromptDriver, OpenAiEmbeddingDriver, OpenAiImageGenerationDriver, - OpenAiImageQueryDriver, OpenAiTextToSpeechDriver, ) @@ -31,11 +29,6 @@ class OpenAiStructureConfig(StructureConfig): kw_only=True, metadata={"serializable": True}, ) - image_query_driver: BaseImageQueryDriver = field( - default=Factory(lambda: OpenAiImageQueryDriver(model="gpt-4o")), - kw_only=True, - metadata={"serializable": True}, - ) embedding_driver: BaseEmbeddingDriver = field( default=Factory(lambda: OpenAiEmbeddingDriver(model="text-embedding-3-small")), metadata={"serializable": True}, diff --git a/griptape/config/structure_config.py b/griptape/config/structure_config.py index ef95012ce..49a41e068 100644 --- a/griptape/config/structure_config.py +++ b/griptape/config/structure_config.py @@ -10,14 +10,12 @@ BaseConversationMemoryDriver, BaseEmbeddingDriver, BaseImageGenerationDriver, - BaseImageQueryDriver, BasePromptDriver, BaseTextToSpeechDriver, BaseVectorStoreDriver, DummyAudioTranscriptionDriver, DummyEmbeddingDriver, DummyImageGenerationDriver, - DummyImageQueryDriver, DummyPromptDriver, DummyTextToSpeechDriver, DummyVectorStoreDriver, @@ -36,11 +34,6 @@ class StructureConfig(BaseStructureConfig): default=Factory(lambda: DummyImageGenerationDriver()), metadata={"serializable": True}, ) - image_query_driver: BaseImageQueryDriver = field( - kw_only=True, - default=Factory(lambda: DummyImageQueryDriver()), - metadata={"serializable": True}, - ) embedding_driver: BaseEmbeddingDriver = field( kw_only=True, default=Factory(lambda: DummyEmbeddingDriver()), diff --git a/griptape/drivers/__init__.py b/griptape/drivers/__init__.py index 9e1790b01..787e701ee 100644 --- 
a/griptape/drivers/__init__.py +++ b/griptape/drivers/__init__.py @@ -79,17 +79,6 @@ HuggingFacePipelineImageGenerationDriver, ) -from .image_query_model.base_image_query_model_driver import BaseImageQueryModelDriver -from .image_query_model.bedrock_claude_image_query_model_driver import BedrockClaudeImageQueryModelDriver - -from .image_query.base_image_query_driver import BaseImageQueryDriver -from .image_query.base_multi_model_image_query_driver import BaseMultiModelImageQueryDriver -from .image_query.dummy_image_query_driver import DummyImageQueryDriver -from .image_query.openai_image_query_driver import OpenAiImageQueryDriver -from .image_query.anthropic_image_query_driver import AnthropicImageQueryDriver -from .image_query.azure_openai_image_query_driver import AzureOpenAiImageQueryDriver -from .image_query.amazon_bedrock_image_query_driver import AmazonBedrockImageQueryDriver - from .web_scraper.base_web_scraper_driver import BaseWebScraperDriver from .web_scraper.trafilatura_web_scraper_driver import TrafilaturaWebScraperDriver from .web_scraper.markdownify_web_scraper_driver import MarkdownifyWebScraperDriver @@ -194,15 +183,6 @@ "AzureOpenAiImageGenerationDriver", "DummyImageGenerationDriver", "HuggingFacePipelineImageGenerationDriver", - "BaseImageQueryModelDriver", - "BedrockClaudeImageQueryModelDriver", - "BaseImageQueryDriver", - "OpenAiImageQueryDriver", - "AzureOpenAiImageQueryDriver", - "DummyImageQueryDriver", - "AnthropicImageQueryDriver", - "BaseMultiModelImageQueryDriver", - "AmazonBedrockImageQueryDriver", "BaseWebScraperDriver", "TrafilaturaWebScraperDriver", "MarkdownifyWebScraperDriver", diff --git a/griptape/drivers/image_query/__init__.py b/griptape/drivers/image_query/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/griptape/drivers/image_query/amazon_bedrock_image_query_driver.py b/griptape/drivers/image_query/amazon_bedrock_image_query_driver.py deleted file mode 100644 index 46406d972..000000000 --- 
a/griptape/drivers/image_query/amazon_bedrock_image_query_driver.py +++ /dev/null @@ -1,43 +0,0 @@ -from __future__ import annotations - -import json -from typing import TYPE_CHECKING, Any - -from attrs import Factory, define, field - -from griptape.drivers import BaseMultiModelImageQueryDriver -from griptape.utils import import_optional_dependency - -if TYPE_CHECKING: - import boto3 - - from griptape.artifacts import ImageArtifact, TextArtifact - - -@define -class AmazonBedrockImageQueryDriver(BaseMultiModelImageQueryDriver): - session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True) - bedrock_client: Any = field( - default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True), - kw_only=True, - ) - - def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact: - payload = self.image_query_model_driver.image_query_request_parameters(query, images, self.max_tokens) - - response = self.bedrock_client.invoke_model( - modelId=self.model, - contentType="application/json", - accept="application/json", - body=json.dumps(payload), - ) - - response_body = json.loads(response.get("body").read()) - - if response_body is None: - raise ValueError("Model response is empty") - - try: - return self.image_query_model_driver.process_output(response_body) - except Exception as e: - raise ValueError(f"Output is unable to be processed as returned {e}") from e diff --git a/griptape/drivers/image_query/anthropic_image_query_driver.py b/griptape/drivers/image_query/anthropic_image_query_driver.py deleted file mode 100644 index bd19862ec..000000000 --- a/griptape/drivers/image_query/anthropic_image_query_driver.py +++ /dev/null @@ -1,64 +0,0 @@ -from __future__ import annotations - -from typing import Any, Optional - -from attrs import Factory, define, field - -from griptape.artifacts import ImageArtifact, TextArtifact -from griptape.drivers import BaseImageQueryDriver -from griptape.utils 
import import_optional_dependency - - -@define -class AnthropicImageQueryDriver(BaseImageQueryDriver): - """Anthropic Image Query Driver. - - Attributes: - api_key: Anthropic API key. - model: Anthropic model name. - client: Custom `Anthropic` client. - """ - - api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False}) - model: str = field(kw_only=True, metadata={"serializable": True}) - client: Any = field( - default=Factory( - lambda self: import_optional_dependency("anthropic").Anthropic(api_key=self.api_key), - takes_self=True, - ), - kw_only=True, - ) - - def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact: - if self.max_tokens is None: - raise TypeError("max_output_tokens can't be empty") - - response = self.client.messages.create(**self._base_params(query, images)) - content_blocks = response.content - - if len(content_blocks) < 1: - raise ValueError("Response content is empty") - - text_content = content_blocks[0].text - - return TextArtifact(text_content) - - def _base_params(self, text_query: str, images: list[ImageArtifact]) -> dict: - content = [self._construct_image_message(image) for image in images] - content.append(self._construct_text_message(text_query)) - messages = self._construct_messages(content) - params = {"model": self.model, "messages": messages, "max_tokens": self.max_tokens} - - return params - - def _construct_image_message(self, image_data: ImageArtifact) -> dict: - data = image_data.base64 - media_type = image_data.mime_type - - return {"source": {"data": data, "media_type": media_type, "type": "base64"}, "type": "image"} - - def _construct_text_message(self, query: str) -> dict: - return {"text": query, "type": "text"} - - def _construct_messages(self, content: list) -> list: - return [{"content": content, "role": "user"}] diff --git a/griptape/drivers/image_query/azure_openai_image_query_driver.py b/griptape/drivers/image_query/azure_openai_image_query_driver.py deleted file 
mode 100644 index 04492e471..000000000 --- a/griptape/drivers/image_query/azure_openai_image_query_driver.py +++ /dev/null @@ -1,50 +0,0 @@ -from __future__ import annotations - -from typing import Callable, Optional - -import openai -from attrs import Factory, define, field - -from griptape.drivers.image_query.openai_image_query_driver import OpenAiImageQueryDriver - - -@define -class AzureOpenAiImageQueryDriver(OpenAiImageQueryDriver): - """Driver for Azure-hosted OpenAI image query API. - - Attributes: - azure_deployment: An optional Azure OpenAi deployment id. Defaults to the model name. - azure_endpoint: An Azure OpenAi endpoint. - azure_ad_token: An optional Azure Active Directory token. - azure_ad_token_provider: An optional Azure Active Directory token provider. - api_version: An Azure OpenAi API version. - client: An `openai.AzureOpenAI` client. - """ - - azure_deployment: str = field( - kw_only=True, - default=Factory(lambda self: self.model, takes_self=True), - metadata={"serializable": True}, - ) - azure_endpoint: str = field(kw_only=True, metadata={"serializable": True}) - azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False}) - azure_ad_token_provider: Optional[Callable[[], str]] = field( - kw_only=True, - default=None, - metadata={"serializable": False}, - ) - api_version: str = field(default="2024-02-01", kw_only=True, metadata={"serializable": True}) - client: openai.AzureOpenAI = field( - default=Factory( - lambda self: openai.AzureOpenAI( - organization=self.organization, - api_key=self.api_key, - api_version=self.api_version, - azure_endpoint=self.azure_endpoint, - azure_deployment=self.azure_deployment, - azure_ad_token=self.azure_ad_token, - azure_ad_token_provider=self.azure_ad_token_provider, - ), - takes_self=True, - ), - ) diff --git a/griptape/drivers/image_query/base_image_query_driver.py b/griptape/drivers/image_query/base_image_query_driver.py deleted file mode 100644 index 
b39f198d4..000000000 --- a/griptape/drivers/image_query/base_image_query_driver.py +++ /dev/null @@ -1,41 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING - -from attrs import define, field - -from griptape.events import FinishImageQueryEvent, StartImageQueryEvent -from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin - -if TYPE_CHECKING: - from griptape.artifacts import ImageArtifact, TextArtifact - - -@define -class BaseImageQueryDriver(EventPublisherMixin, SerializableMixin, ExponentialBackoffMixin, ABC): - max_tokens: int = field(default=256, kw_only=True, metadata={"serializable": True}) - - def before_run(self, query: str, images: list[ImageArtifact]) -> None: - self.publish_event( - StartImageQueryEvent(query=query, images_info=[image.to_text() for image in images]), - ) - - def after_run(self, result: str) -> None: - self.publish_event(FinishImageQueryEvent(result=result)) - - def query(self, query: str, images: list[ImageArtifact]) -> TextArtifact: - for attempt in self.retrying(): - with attempt: - self.before_run(query, images) - - result = self.try_query(query, images) - - self.after_run(result.value) - - return result - else: - raise Exception("image query driver failed after all retry attempts") - - @abstractmethod - def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact: ... 
diff --git a/griptape/drivers/image_query/base_multi_model_image_query_driver.py b/griptape/drivers/image_query/base_multi_model_image_query_driver.py deleted file mode 100644 index 52af617e8..000000000 --- a/griptape/drivers/image_query/base_multi_model_image_query_driver.py +++ /dev/null @@ -1,23 +0,0 @@ -from __future__ import annotations - -from abc import ABC - -from attrs import define, field - -from griptape.drivers import BaseImageQueryDriver, BaseImageQueryModelDriver - - -@define -class BaseMultiModelImageQueryDriver(BaseImageQueryDriver, ABC): - """Image Query Driver for platforms like Amazon Bedrock that host many LLM models. - - Instances of this Image Query Driver require a Image Query Model Driver which is used to structure the - image generation request in the format required by the model and to process the output. - - Attributes: - model: Model name to use - image_query_model_driver: Image Model Driver to use. - """ - - model: str = field(kw_only=True, metadata={"serializable": True}) - image_query_model_driver: BaseImageQueryModelDriver = field(kw_only=True, metadata={"serializable": True}) diff --git a/griptape/drivers/image_query/dummy_image_query_driver.py b/griptape/drivers/image_query/dummy_image_query_driver.py deleted file mode 100644 index 62820efd7..000000000 --- a/griptape/drivers/image_query/dummy_image_query_driver.py +++ /dev/null @@ -1,20 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from attrs import define, field - -from griptape.drivers import BaseImageQueryDriver -from griptape.exceptions import DummyError - -if TYPE_CHECKING: - from griptape.artifacts import ImageArtifact, TextArtifact - - -@define -class DummyImageQueryDriver(BaseImageQueryDriver): - model: None = field(init=False, default=None, kw_only=True) - max_tokens: None = field(init=False, default=None, kw_only=True) - - def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact: - raise 
DummyError(__class__.__name__, "try_query") diff --git a/griptape/drivers/image_query/openai_image_query_driver.py b/griptape/drivers/image_query/openai_image_query_driver.py deleted file mode 100644 index b607c97f5..000000000 --- a/griptape/drivers/image_query/openai_image_query_driver.py +++ /dev/null @@ -1,55 +0,0 @@ -from __future__ import annotations - -from typing import Literal, Optional - -import openai -from attrs import Factory, define, field -from openai.types.chat import ( - ChatCompletionContentPartImageParam, - ChatCompletionContentPartParam, - ChatCompletionContentPartTextParam, - ChatCompletionUserMessageParam, -) - -from griptape.artifacts import ImageArtifact, TextArtifact -from griptape.drivers.image_query.base_image_query_driver import BaseImageQueryDriver - - -@define -class OpenAiImageQueryDriver(BaseImageQueryDriver): - model: str = field(kw_only=True, metadata={"serializable": True}) - api_type: str = field(default=openai.api_type, kw_only=True) - api_version: Optional[str] = field(default=openai.api_version, kw_only=True, metadata={"serializable": True}) - base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) - api_key: Optional[str] = field(default=None, kw_only=True) - organization: Optional[str] = field(default=openai.organization, kw_only=True, metadata={"serializable": True}) - image_quality: Literal["auto", "low", "high"] = field(default="auto", kw_only=True, metadata={"serializable": True}) - client: openai.OpenAI = field( - default=Factory( - lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization), - takes_self=True, - ), - ) - - def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact: - message_parts: list[ChatCompletionContentPartParam] = [ - ChatCompletionContentPartTextParam(type="text", text=query), - ] - - for image in images: - message_parts.append( - ChatCompletionContentPartImageParam( - type="image_url", - 
image_url={"url": f"data:{image.mime_type};base64,{image.base64}", "detail": self.image_quality}, - ), - ) - - messages = ChatCompletionUserMessageParam(content=message_parts, role="user") - params = {"model": self.model, "messages": [messages], "max_tokens": self.max_tokens} - - response = self.client.chat.completions.create(**params) - - if len(response.choices) != 1: - raise Exception("Image query responses with more than one choice are not supported yet.") - - return TextArtifact(response.choices[0].message.content) diff --git a/griptape/drivers/image_query_model/__init__.py b/griptape/drivers/image_query_model/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/griptape/drivers/image_query_model/base_image_query_model_driver.py b/griptape/drivers/image_query_model/base_image_query_model_driver.py deleted file mode 100644 index 5f60367d5..000000000 --- a/griptape/drivers/image_query_model/base_image_query_model_driver.py +++ /dev/null @@ -1,20 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING - -from attrs import define - -from griptape.mixins import SerializableMixin - -if TYPE_CHECKING: - from griptape.artifacts import ImageArtifact, TextArtifact - - -@define -class BaseImageQueryModelDriver(SerializableMixin, ABC): - @abstractmethod - def image_query_request_parameters(self, query: str, images: list[ImageArtifact], max_tokens: int) -> dict: ... - - @abstractmethod - def process_output(self, output: dict) -> TextArtifact: ... 
diff --git a/griptape/drivers/image_query_model/bedrock_claude_image_query_model_driver.py b/griptape/drivers/image_query_model/bedrock_claude_image_query_model_driver.py deleted file mode 100644 index 8260ce3d5..000000000 --- a/griptape/drivers/image_query_model/bedrock_claude_image_query_model_driver.py +++ /dev/null @@ -1,40 +0,0 @@ -from __future__ import annotations - -from attrs import define - -from griptape.artifacts import ImageArtifact, TextArtifact -from griptape.drivers import BaseImageQueryModelDriver - - -@define -class BedrockClaudeImageQueryModelDriver(BaseImageQueryModelDriver): - ANTHROPIC_VERSION = "bedrock-2023-05-31" # static string for AWS: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html#api-inference-examples-claude-multimodal-code-example - - def image_query_request_parameters(self, query: str, images: list[ImageArtifact], max_tokens: int) -> dict: - content = [self._construct_image_message(image) for image in images] - content.append(self._construct_text_message(query)) - messages = self._construct_messages(content) - input_params = {"messages": messages, "anthropic_version": self.ANTHROPIC_VERSION, "max_tokens": max_tokens} - - return input_params - - def process_output(self, output: dict) -> TextArtifact: - content_blocks = output["content"] - if len(content_blocks) < 1: - raise ValueError("Response content is empty") - - text_content = content_blocks[0]["text"] - - return TextArtifact(text_content) - - def _construct_image_message(self, image_data: ImageArtifact) -> dict: - data = image_data.base64 - media_type = image_data.mime_type - - return {"source": {"data": data, "media_type": media_type, "type": "base64"}, "type": "image"} - - def _construct_text_message(self, query: str) -> dict: - return {"text": query, "type": "text"} - - def _construct_messages(self, content: list) -> list: - return [{"content": content, "role": "user"}] diff --git a/griptape/engines/__init__.py 
b/griptape/engines/__init__.py index 7835b2238..d1b4b1e4d 100644 --- a/griptape/engines/__init__.py +++ b/griptape/engines/__init__.py @@ -8,7 +8,6 @@ from .image.variation_image_generation_engine import VariationImageGenerationEngine from .image.inpainting_image_generation_engine import InpaintingImageGenerationEngine from .image.outpainting_image_generation_engine import OutpaintingImageGenerationEngine -from .image_query.image_query_engine import ImageQueryEngine from .audio.text_to_speech_engine import TextToSpeechEngine from .audio.audio_transcription_engine import AudioTranscriptionEngine @@ -23,7 +22,6 @@ "VariationImageGenerationEngine", "InpaintingImageGenerationEngine", "OutpaintingImageGenerationEngine", - "ImageQueryEngine", "TextToSpeechEngine", "AudioTranscriptionEngine", ] diff --git a/griptape/engines/image_query/__init__.py b/griptape/engines/image_query/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/griptape/engines/image_query/image_query_engine.py b/griptape/engines/image_query/image_query_engine.py deleted file mode 100644 index d0a1e99d4..000000000 --- a/griptape/engines/image_query/image_query_engine.py +++ /dev/null @@ -1,17 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from attrs import define, field - -if TYPE_CHECKING: - from griptape.artifacts import ImageArtifact, TextArtifact - from griptape.drivers import BaseImageQueryDriver - - -@define -class ImageQueryEngine: - image_query_driver: BaseImageQueryDriver = field(kw_only=True) - - def run(self, query: str, images: list[ImageArtifact]) -> TextArtifact: - return self.image_query_driver.query(query, images) diff --git a/griptape/events/__init__.py b/griptape/events/__init__.py index 944a309eb..c332ca88f 100644 --- a/griptape/events/__init__.py +++ b/griptape/events/__init__.py @@ -14,8 +14,6 @@ from .event_listener import EventListener from .start_image_generation_event import StartImageGenerationEvent from 
.finish_image_generation_event import FinishImageGenerationEvent -from .start_image_query_event import StartImageQueryEvent -from .finish_image_query_event import FinishImageQueryEvent from .base_text_to_speech_event import BaseTextToSpeechEvent from .start_text_to_speech_event import StartTextToSpeechEvent from .finish_text_to_speech_event import FinishTextToSpeechEvent @@ -40,8 +38,6 @@ "EventListener", "StartImageGenerationEvent", "FinishImageGenerationEvent", - "StartImageQueryEvent", - "FinishImageQueryEvent", "BaseTextToSpeechEvent", "StartTextToSpeechEvent", "FinishTextToSpeechEvent", diff --git a/griptape/events/base_image_query_event.py b/griptape/events/base_image_query_event.py deleted file mode 100644 index b634f2bf0..000000000 --- a/griptape/events/base_image_query_event.py +++ /dev/null @@ -1,9 +0,0 @@ -from abc import ABC - -from attrs import define - -from griptape.events import BaseEvent - - -@define -class BaseImageQueryEvent(BaseEvent, ABC): ... diff --git a/griptape/events/finish_image_query_event.py b/griptape/events/finish_image_query_event.py deleted file mode 100644 index 3eb2e7ccb..000000000 --- a/griptape/events/finish_image_query_event.py +++ /dev/null @@ -1,8 +0,0 @@ -from attrs import define, field - -from griptape.events.base_image_query_event import BaseImageQueryEvent - - -@define -class FinishImageQueryEvent(BaseImageQueryEvent): - result: str = field(kw_only=True, metadata={"serializable": True}) diff --git a/griptape/events/start_image_query_event.py b/griptape/events/start_image_query_event.py deleted file mode 100644 index 8deeaaa5a..000000000 --- a/griptape/events/start_image_query_event.py +++ /dev/null @@ -1,11 +0,0 @@ -from __future__ import annotations - -from attrs import define, field - -from griptape.events.base_image_query_event import BaseImageQueryEvent - - -@define -class StartImageQueryEvent(BaseImageQueryEvent): - query: str = field(kw_only=True, metadata={"serializable": True}) - images_info: list[str] = 
field(kw_only=True, metadata={"serializable": True}) diff --git a/griptape/schemas/base_schema.py b/griptape/schemas/base_schema.py index f2172b119..7ecc31e05 100644 --- a/griptape/schemas/base_schema.py +++ b/griptape/schemas/base_schema.py @@ -118,7 +118,6 @@ def _resolve_types(cls, attrs_cls: type) -> None: BaseConversationMemoryDriver, BaseEmbeddingDriver, BaseImageGenerationDriver, - BaseImageQueryDriver, BasePromptDriver, BaseTextToSpeechDriver, BaseVectorStoreDriver, @@ -135,7 +134,6 @@ def _resolve_types(cls, attrs_cls: type) -> None: localns={ "Any": Any, "BasePromptDriver": BasePromptDriver, - "BaseImageQueryDriver": BaseImageQueryDriver, "BaseEmbeddingDriver": BaseEmbeddingDriver, "BaseVectorStoreDriver": BaseVectorStoreDriver, "BaseTextToSpeechDriver": BaseTextToSpeechDriver, diff --git a/griptape/tasks/__init__.py b/griptape/tasks/__init__.py index 764d1669a..f9836d65e 100644 --- a/griptape/tasks/__init__.py +++ b/griptape/tasks/__init__.py @@ -16,7 +16,6 @@ from .inpainting_image_generation_task import InpaintingImageGenerationTask from .outpainting_image_generation_task import OutpaintingImageGenerationTask from .variation_image_generation_task import VariationImageGenerationTask -from .image_query_task import ImageQueryTask from .base_audio_generation_task import BaseAudioGenerationTask from .text_to_speech_task import TextToSpeechTask from .structure_run_task import StructureRunTask @@ -41,7 +40,6 @@ "VariationImageGenerationTask", "InpaintingImageGenerationTask", "OutpaintingImageGenerationTask", - "ImageQueryTask", "BaseAudioGenerationTask", "TextToSpeechTask", "StructureRunTask", diff --git a/griptape/tasks/image_query_task.py b/griptape/tasks/image_query_task.py deleted file mode 100644 index ea1b53739..000000000 --- a/griptape/tasks/image_query_task.py +++ /dev/null @@ -1,90 +0,0 @@ -from __future__ import annotations - -from typing import Callable - -from attrs import define, field - -from griptape.artifacts import ImageArtifact, 
ListArtifact, TextArtifact -from griptape.engines import ImageQueryEngine -from griptape.tasks import BaseTask -from griptape.utils import J2 - - -@define -class ImageQueryTask(BaseTask): - """A task that executes a natural language query on one or more input images. - - Accepts a text prompt and a list of - images as input in one of the following formats: - - tuple of (template string, list[ImageArtifact]) - - tuple of (TextArtifact, list[ImageArtifact]) - - Callable that returns a tuple of (TextArtifact, list[ImageArtifact]). - - Attributes: - image_query_engine: The engine used to execute the query. - """ - - _image_query_engine: ImageQueryEngine = field(default=None, kw_only=True, alias="image_query_engine") - _input: ( - tuple[str, list[ImageArtifact]] - | tuple[TextArtifact, list[ImageArtifact]] - | Callable[[BaseTask], ListArtifact] - | ListArtifact - ) = field(default=None, alias="input") - - @property - def input(self) -> ListArtifact: - if isinstance(self._input, ListArtifact): - return self._input - elif isinstance(self._input, tuple): - if isinstance(self._input[0], TextArtifact): - query_text = self._input[0] - else: - query_text = TextArtifact(J2().render_from_string(self._input[0], **self.full_context)) - - return ListArtifact([query_text, *self._input[1]]) - elif isinstance(self._input, Callable): - return self._input(self) - else: - raise ValueError( - "Input must be a tuple of a TextArtifact and a list of ImageArtifacts or a callable that " - "returns a tuple of a TextArtifact and a list of ImageArtifacts.", - ) - - @input.setter - def input( - self, - value: ( - tuple[str, list[ImageArtifact]] - | tuple[TextArtifact, list[ImageArtifact]] - | Callable[[BaseTask], ListArtifact] - ), - ) -> None: - self._input = value - - @property - def image_query_engine(self) -> ImageQueryEngine: - if self._image_query_engine is None: - if self.structure is not None: - self._image_query_engine = 
ImageQueryEngine(image_query_driver=self.structure.config.image_query_driver) - else: - raise ValueError("Image Query Engine is not set.") - return self._image_query_engine - - @image_query_engine.setter - def image_query_engine(self, value: ImageQueryEngine) -> None: - self._image_query_engine = value - - def run(self) -> TextArtifact: - query = self.input.value[0] - - if all(isinstance(artifact, ImageArtifact) for artifact in self.input.value[1:]): - image_artifacts = [ - image_artifact for image_artifact in self.input.value[1:] if isinstance(image_artifact, ImageArtifact) - ] - else: - raise ValueError("All inputs after the query must be ImageArtifacts.") - - self.output = self.image_query_engine.run(query.value, image_artifacts) - - return self.output diff --git a/griptape/tools/image_query_client/tool.py b/griptape/tools/image_query_client/tool.py index a10929b13..4881e8d2b 100644 --- a/griptape/tools/image_query_client/tool.py +++ b/griptape/tools/image_query_client/tool.py @@ -7,18 +7,19 @@ from schema import Literal, Schema from griptape.artifacts import BlobArtifact, ErrorArtifact, ImageArtifact, TextArtifact +from griptape.common import ImageMessageContent, Message, PromptStack, TextMessageContent from griptape.loaders import ImageLoader from griptape.tools import BaseTool from griptape.utils import load_artifact_from_memory from griptape.utils.decorators import activity if TYPE_CHECKING: - from griptape.engines import ImageQueryEngine + from griptape.drivers import BasePromptDriver @define class ImageQueryClient(BaseTool): - image_query_engine: ImageQueryEngine = field(kw_only=True) + prompt_driver: BasePromptDriver = field(kw_only=True) image_loader: ImageLoader = field(default=Factory(lambda: ImageLoader()), kw_only=True) @activity( @@ -36,14 +37,31 @@ class ImageQueryClient(BaseTool): }, ) def query_image_from_disk(self, params: dict) -> TextArtifact | ErrorArtifact: - query = params["values"]["query"] + query: str = params["values"]["query"] 
image_paths = params["values"]["image_paths"] - image_artifacts = [] + image_artifacts: list[ImageArtifact] = [] for image_path in image_paths: image_artifacts.append(self.image_loader.load(Path(image_path).read_bytes())) - return self.image_query_engine.run(query, image_artifacts) + output = self.prompt_driver.run( + PromptStack( + [ + Message( + content=[ + *[ImageMessageContent(image_artifact) for image_artifact in image_artifacts], + TextMessageContent(TextArtifact(query)), + ], + role=Message.USER_ROLE, + ) + ] + ) + ).to_artifact() + + if isinstance(output, (TextArtifact, ErrorArtifact)): + return output + else: + raise Exception("Invalid output type") @activity( config={ @@ -95,4 +113,21 @@ def query_images_from_memory(self, params: dict[str, Any]) -> TextArtifact | Err except Exception as e: return ErrorArtifact(str(e)) - return self.image_query_engine.run(query, image_artifacts) + output = self.prompt_driver.run( + PromptStack( + [ + Message( + content=[ + *[ImageMessageContent(image_artifact) for image_artifact in image_artifacts], + TextMessageContent(TextArtifact(query)), + ], + role=Message.USER_ROLE, + ) + ] + ) + ).to_artifact() + + if isinstance(output, (TextArtifact, ErrorArtifact)): + return output + else: + raise Exception("Invalid output type") diff --git a/mkdocs.yml b/mkdocs.yml index 6c9bffa13..c5f8ed758 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -95,7 +95,6 @@ nav: - Building Custom Tools: "griptape-tools/custom-tools/index.md" - Engines: - RAG Engines: "griptape-framework/engines/rag-engines.md" - - Image Query Engines: "griptape-framework/engines/image-query-engines.md" - Extraction Engines: "griptape-framework/engines/extraction-engines.md" - Summary Engines: "griptape-framework/engines/summary-engines.md" - Image Generation Engines: "griptape-framework/engines/image-generation-engines.md" diff --git a/tests/mocks/mock_image_query_driver.py b/tests/mocks/mock_image_query_driver.py deleted file mode 100644 index 8f8cc888c..000000000 --- 
a/tests/mocks/mock_image_query_driver.py +++ /dev/null @@ -1,16 +0,0 @@ -from __future__ import annotations - -from typing import Optional - -from attrs import define - -from griptape.artifacts import ImageArtifact, TextArtifact -from griptape.drivers import BaseImageQueryDriver - - -@define -class MockImageQueryDriver(BaseImageQueryDriver): - model: Optional[str] = None - - def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact: - return TextArtifact(value="mock text") diff --git a/tests/mocks/mock_structure_config.py b/tests/mocks/mock_structure_config.py index 3f95288f4..f3070aad0 100644 --- a/tests/mocks/mock_structure_config.py +++ b/tests/mocks/mock_structure_config.py @@ -3,7 +3,6 @@ from griptape.config import StructureConfig from tests.mocks.mock_embedding_driver import MockEmbeddingDriver from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver -from tests.mocks.mock_image_query_driver import MockImageQueryDriver from tests.mocks.mock_prompt_driver import MockPromptDriver @@ -15,9 +14,6 @@ class MockStructureConfig(StructureConfig): image_generation_driver: MockImageGenerationDriver = field( default=Factory(lambda: MockImageGenerationDriver(model="dall-e-2")), metadata={"serializable": True} ) - image_query_driver: MockImageQueryDriver = field( - default=Factory(lambda: MockImageQueryDriver(model="gpt-4-vision-preview")), metadata={"serializable": True} - ) embedding_driver: MockEmbeddingDriver = field( default=Factory(lambda: MockEmbeddingDriver(model="text-embedding-3-small")), metadata={"serializable": True} ) diff --git a/tests/unit/config/test_amazon_bedrock_structure_config.py b/tests/unit/config/test_amazon_bedrock_structure_config.py index afe9b3720..37084eb42 100644 --- a/tests/unit/config/test_amazon_bedrock_structure_config.py +++ b/tests/unit/config/test_amazon_bedrock_structure_config.py @@ -40,12 +40,6 @@ def test_to_dict(self, config): "seed": None, "type": "AmazonBedrockImageGenerationDriver", }, 
- "image_query_driver": { - "type": "AmazonBedrockImageQueryDriver", - "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", - "max_tokens": 256, - "image_query_model_driver": {"type": "BedrockClaudeImageQueryModelDriver"}, - }, "prompt_driver": { "max_tokens": None, "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", @@ -93,12 +87,6 @@ def test_to_dict_with_values(self, config_with_values): "seed": None, "type": "AmazonBedrockImageGenerationDriver", }, - "image_query_driver": { - "type": "AmazonBedrockImageQueryDriver", - "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", - "max_tokens": 256, - "image_query_model_driver": {"type": "BedrockClaudeImageQueryModelDriver"}, - }, "prompt_driver": { "max_tokens": None, "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", diff --git a/tests/unit/config/test_anthropic_structure_config.py b/tests/unit/config/test_anthropic_structure_config.py index 05519fa5e..172a007da 100644 --- a/tests/unit/config/test_anthropic_structure_config.py +++ b/tests/unit/config/test_anthropic_structure_config.py @@ -27,11 +27,6 @@ def test_to_dict(self, config): "use_native_tools": True, }, "image_generation_driver": {"type": "DummyImageGenerationDriver"}, - "image_query_driver": { - "type": "AnthropicImageQueryDriver", - "model": "claude-3-5-sonnet-20240620", - "max_tokens": 256, - }, "embedding_driver": { "type": "VoyageAiEmbeddingDriver", "model": "voyage-large-2", diff --git a/tests/unit/config/test_azure_openai_structure_config.py b/tests/unit/config/test_azure_openai_structure_config.py index dcdc3a1dc..3fe85fa6a 100644 --- a/tests/unit/config/test_azure_openai_structure_config.py +++ b/tests/unit/config/test_azure_openai_structure_config.py @@ -59,17 +59,6 @@ def test_to_dict(self, config): "style": None, "type": "AzureOpenAiImageGenerationDriver", }, - "image_query_driver": { - "base_url": None, - "image_quality": "auto", - "max_tokens": 256, - "model": "gpt-4o", - "api_version": "2024-02-01", - "azure_deployment": "gpt-4o", - 
"azure_endpoint": "http://localhost:8080", - "organization": None, - "type": "AzureOpenAiImageQueryDriver", - }, "vector_store_driver": { "embedding_driver": { "base_url": None, diff --git a/tests/unit/config/test_cohere_structure_config.py b/tests/unit/config/test_cohere_structure_config.py index 113a589ec..26cd2fdb5 100644 --- a/tests/unit/config/test_cohere_structure_config.py +++ b/tests/unit/config/test_cohere_structure_config.py @@ -12,7 +12,6 @@ def test_to_dict(self, config): assert config.to_dict() == { "type": "CohereStructureConfig", "image_generation_driver": {"type": "DummyImageGenerationDriver"}, - "image_query_driver": {"type": "DummyImageQueryDriver"}, "conversation_memory_driver": None, "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, diff --git a/tests/unit/config/test_google_structure_config.py b/tests/unit/config/test_google_structure_config.py index e193cc983..a2d809441 100644 --- a/tests/unit/config/test_google_structure_config.py +++ b/tests/unit/config/test_google_structure_config.py @@ -27,7 +27,6 @@ def test_to_dict(self, config): "use_native_tools": True, }, "image_generation_driver": {"type": "DummyImageGenerationDriver"}, - "image_query_driver": {"type": "DummyImageQueryDriver"}, "embedding_driver": { "type": "GoogleEmbeddingDriver", "model": "models/embedding-001", diff --git a/tests/unit/config/test_openai_structure_config.py b/tests/unit/config/test_openai_structure_config.py index 8969e0ad0..4525073fb 100644 --- a/tests/unit/config/test_openai_structure_config.py +++ b/tests/unit/config/test_openai_structure_config.py @@ -46,15 +46,6 @@ def test_to_dict(self, config): "style": None, "type": "OpenAiImageGenerationDriver", }, - "image_query_driver": { - "api_version": None, - "base_url": None, - "image_quality": "auto", - "max_tokens": 256, - "model": "gpt-4o", - "organization": None, - "type": "OpenAiImageQueryDriver", - }, "vector_store_driver": { 
"embedding_driver": { "base_url": None, diff --git a/tests/unit/config/test_structure_config.py b/tests/unit/config/test_structure_config.py index b9e3477e4..af2366d16 100644 --- a/tests/unit/config/test_structure_config.py +++ b/tests/unit/config/test_structure_config.py @@ -22,7 +22,6 @@ def test_to_dict(self, config): "conversation_memory_driver": None, "embedding_driver": {"type": "DummyEmbeddingDriver"}, "image_generation_driver": {"type": "DummyImageGenerationDriver"}, - "image_query_driver": {"type": "DummyImageQueryDriver"}, "vector_store_driver": { "embedding_driver": {"type": "DummyEmbeddingDriver"}, "type": "DummyVectorStoreDriver", @@ -66,7 +65,6 @@ def test_drivers(self, config): assert config.drivers == [ config.prompt_driver, config.image_generation_driver, - config.image_query_driver, config.embedding_driver, config.vector_store_driver, config.conversation_memory_driver, diff --git a/tests/unit/drivers/image_query/__init__.py b/tests/unit/drivers/image_query/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/unit/drivers/image_query/test_amazon_bedrock_image_query_driver.py b/tests/unit/drivers/image_query/test_amazon_bedrock_image_query_driver.py deleted file mode 100644 index 9493ab23d..000000000 --- a/tests/unit/drivers/image_query/test_amazon_bedrock_image_query_driver.py +++ /dev/null @@ -1,52 +0,0 @@ -import io -from unittest.mock import Mock - -import pytest - -from griptape.artifacts import ImageArtifact, TextArtifact -from griptape.drivers import AmazonBedrockImageQueryDriver - - -class TestAmazonBedrockImageQueryDriver: - @pytest.fixture() - def bedrock_client(self, mocker): - return Mock() - - @pytest.fixture() - def session(self, bedrock_client): - session = Mock() - session.client.return_value = bedrock_client - - return session - - @pytest.fixture() - def model_driver(self): - model_driver = Mock() - model_driver.image_query_request_parameters.return_value = {} - model_driver.process_output.return_value 
= TextArtifact("content") - - return model_driver - - @pytest.fixture() - def image_query_driver(self, session, model_driver): - return AmazonBedrockImageQueryDriver(session=session, model="model", image_query_model_driver=model_driver) - - def test_init(self, image_query_driver): - assert image_query_driver - - def test_try_query(self, image_query_driver): - image_query_driver.bedrock_client.invoke_model.return_value = {"body": io.BytesIO(b"""{"content": []}""")} - - text_artifact = image_query_driver.try_query( - "Prompt String", [ImageArtifact(value=b"test-data", width=100, height=100, format="png")] - ) - - assert text_artifact.value == "content" - - def test_try_query_no_body(self, image_query_driver): - image_query_driver.bedrock_client.invoke_model.return_value = {"body": io.BytesIO(b"")} - - with pytest.raises(ValueError): - image_query_driver.try_query( - "Prompt String", [ImageArtifact(value=b"test-data", width=100, height=100, format="png")] - ) diff --git a/tests/unit/drivers/image_query/test_anthropic_image_query_driver.py b/tests/unit/drivers/image_query/test_anthropic_image_query_driver.py deleted file mode 100644 index db4b2407c..000000000 --- a/tests/unit/drivers/image_query/test_anthropic_image_query_driver.py +++ /dev/null @@ -1,97 +0,0 @@ -import base64 -from unittest.mock import Mock - -import pytest - -from griptape.artifacts import ImageArtifact -from griptape.drivers import AnthropicImageQueryDriver - - -class TestAnthropicImageQueryDriver: - @pytest.fixture() - def mock_client(self, mocker): - mock_client = mocker.patch("anthropic.Anthropic") - return_value = Mock(text="Content") - mock_client.return_value.messages.create.return_value.content = [return_value] - - return mock_client - - @pytest.mark.parametrize( - "model", [("claude-3-haiku-20240307"), ("claude-3-sonnet-20240229"), ("claude-3-opus-20240229")] - ) - def test_init(self, model): - assert AnthropicImageQueryDriver(model=model) - - def test_try_query(self, mock_client): - driver 
= AnthropicImageQueryDriver(model="test-model") - test_prompt_string = "Prompt String" - test_binary_data = b"test-data" - - text_artifact = driver.try_query( - test_prompt_string, [ImageArtifact(value=test_binary_data, width=100, height=100, format="png")] - ) - - expected_message = self._expected_message(test_binary_data, "image/png", test_prompt_string) - - mock_client.return_value.messages.create.assert_called_once_with( - model=driver.model, max_tokens=256, messages=[expected_message] - ) - - assert text_artifact.value == "Content" - - def test_try_query_max_tokens_value(self, mock_client): - driver = AnthropicImageQueryDriver(model="test-model", max_tokens=1024) - test_prompt_string = "Prompt String" - test_binary_data = b"test-data" - - text_artifact = driver.try_query( - test_prompt_string, [ImageArtifact(value=test_binary_data, width=100, height=100, format="png")] - ) - - expected_message = self._expected_message(test_binary_data, "image/png", test_prompt_string) - - mock_client.return_value.messages.create.assert_called_once_with( - model=driver.model, max_tokens=1024, messages=[expected_message] - ) - - assert text_artifact.value == "Content" - - def test_try_query_max_tokens_none(self, mock_client): - driver = AnthropicImageQueryDriver(model="test-model", max_tokens=None) # pyright: ignore[reportArgumentType] - test_prompt_string = "Prompt String" - test_binary_data = b"test-data" - with pytest.raises(TypeError): - driver.try_query( - test_prompt_string, [ImageArtifact(value=test_binary_data, width=100, height=100, format="png")] - ) - - def test_try_query_wrong_media_type(self, mock_client): - driver = AnthropicImageQueryDriver(model="test-model") - test_prompt_string = "Prompt String" - test_binary_data = b"test-data" - - # we expect this to pass Griptape code as the model will error appropriately - text_artifact = driver.try_query( - test_prompt_string, [ImageArtifact(value=test_binary_data, width=100, height=100, format="exr")] - ) - - 
expected_message = self._expected_message(test_binary_data, "image/exr", test_prompt_string) - - mock_client.return_value.messages.create.assert_called_once_with( - model=driver.model, messages=[expected_message], max_tokens=256 - ) - - assert text_artifact.value == "Content" - - def _expected_message(self, expected_data, expected_media_type, expected_prompt_string): - encoded_data = base64.b64encode(expected_data).decode("utf-8") - return { - "content": [ - { - "source": {"data": encoded_data, "media_type": expected_media_type, "type": "base64"}, - "type": "image", - }, - {"text": expected_prompt_string, "type": "text"}, - ], - "role": "user", - } diff --git a/tests/unit/drivers/image_query/test_azure_openai_image_query_driver.py b/tests/unit/drivers/image_query/test_azure_openai_image_query_driver.py deleted file mode 100644 index a1d428197..000000000 --- a/tests/unit/drivers/image_query/test_azure_openai_image_query_driver.py +++ /dev/null @@ -1,70 +0,0 @@ -from unittest.mock import Mock - -import pytest - -from griptape.artifacts import ImageArtifact -from griptape.drivers import AzureOpenAiImageQueryDriver - - -class TestAzureOpenAiVisionImageQueryDriver: - @pytest.fixture() - def mock_completion_create(self, mocker): - mock_chat_create = mocker.patch("openai.AzureOpenAI").return_value.chat.completions.create - mock_choice = Mock(message=Mock(content="expected_output_text")) - mock_chat_create.return_value.choices = [mock_choice] - return mock_chat_create - - def test_init(self): - assert AzureOpenAiImageQueryDriver( - azure_endpoint="test-endpoint", azure_deployment="test-deployment", model="gpt-4" - ) - assert AzureOpenAiImageQueryDriver(azure_endpoint="test-endpoint", model="gpt-4").azure_deployment == "gpt-4" - - def test_try_query_defaults(self, mock_completion_create): - driver = AzureOpenAiImageQueryDriver( - azure_endpoint="test-endpoint", azure_deployment="test-deployment", model="gpt-4" - ) - test_prompt_string = "Prompt String" - test_binary_data = 
b"test-data" - test_image = ImageArtifact(value=test_binary_data, width=100, height=100, format="png") - text_artifact = driver.try_query(test_prompt_string, [test_image]) - - messages = self._expected_messages(test_prompt_string, test_image.base64) - - mock_completion_create.assert_called_once_with(model=driver.model, messages=[messages], max_tokens=256) - - assert text_artifact.value == "expected_output_text" - - def test_try_query_max_tokens(self, mock_completion_create): - driver = AzureOpenAiImageQueryDriver( - azure_endpoint="test-endpoint", azure_deployment="test-deployment", model="gpt-4", max_tokens=1024 - ) - test_prompt_string = "Prompt String" - test_binary_data = b"test-data" - test_image = ImageArtifact(value=test_binary_data, width=100, height=100, format="png") - driver.try_query(test_prompt_string, [test_image]) - - messages = self._expected_messages(test_prompt_string, test_image.base64) - - mock_completion_create.assert_called_once_with(model=driver.model, messages=[messages], max_tokens=1024) - - def test_try_query_multiple_choices(self, mock_completion_create): - mock_completion_create.return_value.choices.append(Mock(message=Mock(content="expected_output_text2"))) - driver = AzureOpenAiImageQueryDriver( - azure_endpoint="test-endpoint", azure_deployment="test-deployment", model="gpt-4" - ) - - with pytest.raises(Exception, match="Image query responses with more than one choice are not supported yet."): - driver.try_query("Prompt String", [ImageArtifact(value=b"test-data", width=100, height=100, format="png")]) - - def _expected_messages(self, expected_prompt_string, expected_binary_data): - return { - "content": [ - {"type": "text", "text": expected_prompt_string}, - { - "type": "image_url", - "image_url": {"url": f"data:image/png;base64,{expected_binary_data}", "detail": "auto"}, - }, - ], - "role": "user", - } diff --git a/tests/unit/drivers/image_query/test_base_image_query_driver.py 
b/tests/unit/drivers/image_query/test_base_image_query_driver.py deleted file mode 100644 index 14de15f2d..000000000 --- a/tests/unit/drivers/image_query/test_base_image_query_driver.py +++ /dev/null @@ -1,26 +0,0 @@ -from unittest.mock import Mock - -import pytest - -from griptape.events.event_listener import EventListener -from tests.mocks.mock_image_query_driver import MockImageQueryDriver - - -class TestBaseImageQueryDriver: - @pytest.fixture() - def driver(self): - return MockImageQueryDriver(model="foo") - - def test_query_publishes_events(self, driver): - mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) - - driver.query("foo", []) - - call_args = mock_handler.call_args_list - - args, _kwargs = call_args[0] - assert args[0].type == "StartImageQueryEvent" - - args, _kwargs = call_args[1] - assert args[0].type == "FinishImageQueryEvent" diff --git a/tests/unit/drivers/image_query/test_dummy_image_query_driver.py b/tests/unit/drivers/image_query/test_dummy_image_query_driver.py deleted file mode 100644 index 9da59c435..000000000 --- a/tests/unit/drivers/image_query/test_dummy_image_query_driver.py +++ /dev/null @@ -1,18 +0,0 @@ -import pytest - -from griptape.artifacts import ImageArtifact -from griptape.drivers import DummyImageQueryDriver -from griptape.exceptions import DummyError - - -class TestDummyImageQueryDriver: - @pytest.fixture() - def image_query_driver(self): - return DummyImageQueryDriver() - - def test_init(self, image_query_driver): - assert image_query_driver - - def test_try_query(self, image_query_driver): - with pytest.raises(DummyError): - image_query_driver.try_query("Prompt", [ImageArtifact(value=b"", width=100, height=100, format="png")]) diff --git a/tests/unit/drivers/image_query/test_openai_image_query_driver.py b/tests/unit/drivers/image_query/test_openai_image_query_driver.py deleted file mode 100644 index 9c4b011a6..000000000 --- a/tests/unit/drivers/image_query/test_openai_image_query_driver.py 
+++ /dev/null @@ -1,61 +0,0 @@ -from unittest.mock import Mock - -import pytest - -from griptape.artifacts import ImageArtifact -from griptape.drivers import OpenAiImageQueryDriver - - -class TestOpenAiVisionImageQueryDriver: - @pytest.fixture() - def mock_completion_create(self, mocker): - mock_chat_create = mocker.patch("openai.OpenAI").return_value.chat.completions.create - mock_choice = Mock(message=Mock(content="expected_output_text")) - mock_chat_create.return_value.choices = [mock_choice] - return mock_chat_create - - def test_init(self): - assert OpenAiImageQueryDriver(model="gpt-4-vision-preview") - - def test_try_query_defaults(self, mock_completion_create): - driver = OpenAiImageQueryDriver(model="gpt-4-vision-preview") - test_prompt_string = "Prompt String" - test_binary_data = b"test-data" - test_image = ImageArtifact(value=test_binary_data, width=100, height=100, format="png") - text_artifact = driver.try_query(test_prompt_string, [test_image]) - - messages = self._expected_messages(test_prompt_string, test_image.base64) - - mock_completion_create.assert_called_once_with(model=driver.model, messages=[messages], max_tokens=256) - - assert text_artifact.value == "expected_output_text" - - def test_try_query_max_tokens(self, mock_completion_create): - driver = OpenAiImageQueryDriver(model="gpt-4-vision-preview", max_tokens=1024) - test_prompt_string = "Prompt String" - test_binary_data = b"test-data" - test_image = ImageArtifact(value=test_binary_data, width=100, height=100, format="png") - driver.try_query(test_prompt_string, [test_image]) - - messages = self._expected_messages(test_prompt_string, test_image.base64) - - mock_completion_create.assert_called_once_with(model=driver.model, messages=[messages], max_tokens=1024) - - def test_try_query_multiple_choices(self, mock_completion_create): - mock_completion_create.return_value.choices.append(Mock(message=Mock(content="expected_output_text2"))) - driver = 
OpenAiImageQueryDriver(model="gpt-4-vision-preview") - - with pytest.raises(Exception, match="Image query responses with more than one choice are not supported yet."): - driver.try_query("Prompt String", [ImageArtifact(value=b"test-data", width=100, height=100, format="png")]) - - def _expected_messages(self, expected_prompt_string, expected_binary_data): - return { - "content": [ - {"type": "text", "text": expected_prompt_string}, - { - "type": "image_url", - "image_url": {"url": f"data:image/png;base64,{expected_binary_data}", "detail": "auto"}, - }, - ], - "role": "user", - } diff --git a/tests/unit/drivers/image_query_models/test_bedrock_claude_image_query_model_driver.py b/tests/unit/drivers/image_query_models/test_bedrock_claude_image_query_model_driver.py deleted file mode 100644 index c274f71dd..000000000 --- a/tests/unit/drivers/image_query_models/test_bedrock_claude_image_query_model_driver.py +++ /dev/null @@ -1,42 +0,0 @@ -import pytest - -from griptape.artifacts import ImageArtifact, TextArtifact -from griptape.drivers import BedrockClaudeImageQueryModelDriver - - -class TestBedrockClaudeImageQueryModelDriver: - def test_init(self): - assert BedrockClaudeImageQueryModelDriver() - - def test_image_query_request_parameters(self): - model_driver = BedrockClaudeImageQueryModelDriver() - params = model_driver.image_query_request_parameters( - "Prompt String", [ImageArtifact(value=b"test-data", width=100, height=100, format="png")], 256 - ) - - assert isinstance(params, dict) - assert "anthropic_version" in params - assert params["anthropic_version"] == "bedrock-2023-05-31" - assert "messages" in params - assert len(params["messages"]) == 1 - assert "max_tokens" in params - assert params["max_tokens"] == 256 - - def test_process_output(self): - model_driver = BedrockClaudeImageQueryModelDriver() - output = model_driver.process_output({"content": [{"text": "Content"}]}) - - assert isinstance(output, TextArtifact) - assert output.value == "Content" - - def 
test_process_output_no_content_key(self): - with pytest.raises(KeyError): - BedrockClaudeImageQueryModelDriver().process_output({"explicitly-not-content": ["ContentBlock"]}) - - def test_process_output_bad_length(self): - with pytest.raises(ValueError): - BedrockClaudeImageQueryModelDriver().process_output({"content": []}) - - def test_process_output_no_text_key(self): - with pytest.raises(KeyError): - BedrockClaudeImageQueryModelDriver().process_output({"content": [{"not-text": "Content"}]}) diff --git a/tests/unit/tasks/test_image_query_task.py b/tests/unit/tasks/test_image_query_task.py deleted file mode 100644 index 447faa01c..000000000 --- a/tests/unit/tasks/test_image_query_task.py +++ /dev/null @@ -1,83 +0,0 @@ -from unittest.mock import Mock - -import pytest - -from griptape.artifacts import ImageArtifact, TextArtifact -from griptape.artifacts.list_artifact import ListArtifact -from griptape.engines import ImageQueryEngine -from griptape.structures import Agent -from griptape.tasks import BaseTask, ImageQueryTask -from tests.mocks.mock_image_query_driver import MockImageQueryDriver -from tests.mocks.mock_structure_config import MockStructureConfig - - -class TestImageQueryTask: - @pytest.fixture() - def image_query_engine(self) -> Mock: - mock = Mock() - mock.run.return_value = TextArtifact("image") - - return mock - - @pytest.fixture() - def text_artifact(self): - return TextArtifact(value="some text") - - @pytest.fixture() - def image_artifact(self): - return ImageArtifact(value=b"some image data", format="png", width=512, height=512) - - def test_text_inputs(self, text_artifact: TextArtifact, image_artifact: ImageArtifact): - task = ImageQueryTask((text_artifact.value, [image_artifact, image_artifact])) - - assert task.input.value[0].value == text_artifact.value - assert task.input.value[1] == image_artifact - assert task.input.value[2] == image_artifact - - def test_artifact_inputs(self, text_artifact: TextArtifact, image_artifact: ImageArtifact): - 
input_tuple = (text_artifact, [image_artifact, image_artifact]) - task = ImageQueryTask(input_tuple) - - assert task.input.value[0] == text_artifact - assert task.input.value[1] == image_artifact - assert task.input.value[2] == image_artifact - - def test_callable_input(self, text_artifact: TextArtifact, image_artifact: ImageArtifact): - artifacts = [text_artifact, image_artifact, image_artifact] - - def callable_input(task: BaseTask) -> ListArtifact: - return ListArtifact(value=artifacts) - - task = ImageQueryTask(callable_input) - - assert task.input.value == artifacts - - def test_list_input(self, text_artifact: TextArtifact, image_artifact: ImageArtifact): - artifacts = [text_artifact, image_artifact, image_artifact] - - task = ImageQueryTask(ListArtifact(value=artifacts)) - - assert task.input.value == artifacts - - def test_config_image_generation_engine(self, text_artifact, image_artifact): - task = ImageQueryTask((text_artifact, [image_artifact, image_artifact])) - Agent(config=MockStructureConfig()).add_task(task) - - assert isinstance(task.image_query_engine, ImageQueryEngine) - assert isinstance(task.image_query_engine.image_query_driver, MockImageQueryDriver) - - def test_missing_image_generation_engine(self, text_artifact, image_artifact): - task = ImageQueryTask((text_artifact, [image_artifact, image_artifact])) - - with pytest.raises(ValueError, match="Image Query Engine"): - task.image_query_engine # noqa: B018 - - def test_run(self, image_query_engine, text_artifact, image_artifact): - task = ImageQueryTask((text_artifact, [image_artifact, image_artifact]), image_query_engine=image_query_engine) - task.run() - - assert task.output.value == "image" - - def test_bad_run(self, image_query_engine, text_artifact, image_artifact): - with pytest.raises(ValueError, match="All inputs"): - ImageQueryTask(("foo", [image_artifact, text_artifact]), image_query_engine=image_query_engine).run()