diff --git a/CHANGELOG.md b/CHANGELOG.md index a43c39aad..544f9d5cc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `TrafilaturaWebScraperDriver.no_ssl` parameter to disable SSL verification. Defaults to `False`. - `CsvExtractionEngine.format_header` parameter to format the header row. +- `PromptStack.from_artifact` factory method for creating a Prompt Stack with a user message from an Artifact. ### Changed @@ -20,6 +21,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: Renamed `Structure.is_executing()` to `Structure.is_running()`. - **BREAKING**: Removed ability to pass bytes to `BaseFileLoader.fetch`. - **BREAKING**: Updated `CsvExtractionEngine.format_row` to format rows as comma-separated values instead of newline-separated key-value pairs. +- **BREAKING**: Removed all `ImageQueryDriver`s, use `PromptDriver`s instead. +- **BREAKING**: Removed `ImageQueryTask`, use `PromptTask` instead. +- **BREAKING**: Updated `ImageQueryTool.image_query_driver` to `ImageQueryTool.prompt_driver`. +- `BasePromptDriver.run` can now accept an Artifact in addition to a Prompt Stack. - Improved `CsvExtractionEngine` prompts. - Tweaked `PromptResponseRagModule` system prompt to yield answers more consistently. - Removed `azure-core` and `azure-storage-blob` dependencies. diff --git a/MIGRATION.md b/MIGRATION.md index f3a21da80..7c795a349 100644 --- a/MIGRATION.md +++ b/MIGRATION.md @@ -22,9 +22,9 @@ loader = TextLoader() data = loader.parse(b"data") ``` -### Removed `ImageQueryEngine` +### Removed `ImageQueryEngine`, `ImageQueryDriver` -`ImageQueryEngine` has been removed. Use `ImageQueryDriver` instead. +`ImageQueryEngine` and `ImageQueryDriver` have been removed. Use `PromptDriver` instead. 
#### Before @@ -45,15 +45,15 @@ engine.run("Describe the weather in the image", [image_artifact])` #### After ```python -from griptape.drivers import OpenAiImageQueryDriver -from griptape.engines import ImageQueryEngine +from griptape.artifacts import ListArtifact, TextArtifact +from griptape.drivers import OpenAiChatPromptDriver from griptape.loaders import ImageLoader -driver = OpenAiImageQueryDriver(model="gpt-4o", max_tokens=256) +driver = OpenAiChatPromptDriver(model="gpt-4o", max_tokens=256) -image_artifact = ImageLoader().load("mountain.png") +image_artifact = ImageLoader().load("./assets/mountain.jpg") -driver.query("Describe the weather in the image", [image_artifact])` +driver.run(ListArtifact([TextArtifact("Describe the weather in the image"), image_artifact])) ``` ### Removed `InpaintingImageGenerationEngine` @@ -209,6 +209,50 @@ driver.run_text_to_image( ) ``` +### Removed `ImageQueryTask`, use `PromptTask` instead + +`ImageQueryTask` has been removed. Use `PromptTask` instead. 
+ +#### Before + +```python +from griptape.loaders import ImageLoader +from griptape.structures import Pipeline +from griptape.tasks import ImageQueryTask + +image_artifact = ImageLoader().load("mountain.png") + +pipeline = Pipeline( + tasks=[ + ImageQueryTask( + input=("Describe the weather in the image", [image_artifact]), + ) + ] +) + +pipeline.run("Describe the weather in the image") +``` + +#### After + +```python +from griptape.loaders import ImageLoader +from griptape.structures import Pipeline +from griptape.tasks import PromptTask + +image_artifact = ImageLoader().load("mountain.png") + +pipeline = Pipeline( + tasks=[ + PromptTask( + input=("Describe the weather in the image", image_artifact), + ) + ] +) + +pipeline.run("Describe the weather in the image") +``` + ## 0.33.X to 0.34.X ### `AnthropicDriversConfig` Embedding Driver diff --git a/README.md b/README.md index 4010ed405..49929c69e 100644 --- a/README.md +++ b/README.md @@ -36,11 +36,10 @@ Tools provide capabilities for LLMs to interact with data and services. Griptape Drivers facilitate interactions with external resources and services: -- 🗣️ **Prompt Drivers** manage textual interactions with LLMs. +- 🗣️ **Prompt Drivers** manage textual and image interactions with LLMs. - 🔢 **Embedding Drivers** generate vector embeddings from textual inputs. - 💾 **Vector Store Drivers** manage the storage and retrieval of embeddings. - 🎨 **Image Generation Drivers** create images from text descriptions. -- 🔎 **Image Query Drivers** query images from text queries. - 💼 **SQL Drivers** interact with SQL databases. - 🌐 **Web Scraper Drivers** extract information from web pages. - 🧠 **Conversation Memory Drivers** manage the storage and retrieval of conversational data. 
diff --git a/docs/griptape-framework/drivers/image-query-drivers.md b/docs/griptape-framework/drivers/image-query-drivers.md deleted file mode 100644 index 0f40d15fc..000000000 --- a/docs/griptape-framework/drivers/image-query-drivers.md +++ /dev/null @@ -1,64 +0,0 @@ ---- -search: - boost: 2 ---- - -## Overview - -Image Query Drivers execute natural language queries on the contents of images. You can specify the provider and model used to query the image by providing the Engine with a particular Image Query Driver. - -!!! info - All Image Query Drivers default to a `max_tokens` of 256. It is recommended that you set this value to correspond to the desired response length. - -## Image Query Drivers - -### Anthropic - -!!! info - To tune `max_tokens`, see [Anthropic's documentation on image tokens](https://docs.anthropic.com/claude/docs/vision#image-costs) for more information on how to relate token count to response length. - -The [AnthropicImageQueryDriver](../../reference/griptape/drivers/image_query/anthropic_image_query_driver.md) is used to query images using Anthropic's Claude 3 multi-modal model. Here is an example of how to use it: - -```python ---8<-- "docs/griptape-framework/drivers/src/image_query_drivers_1.py" -``` - -You can also specify multiple images with a single text prompt. This applies the same text prompt to all images specified, up to a max of 20. However, you will still receive one text response from the model currently. - -```python ---8<-- "docs/griptape-framework/drivers/src/image_query_drivers_2.py" -``` - -### OpenAI - -!!! info - While the `max_tokens` field is optional, it is recommended to set this to a value that corresponds to the desired response length. Without an explicit value, the model will default to very short responses. See [OpenAI's documentation](https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them) for more information on how to relate token count to response length. 
- -The [OpenAiVisionImageQueryDriver](../../reference/griptape/drivers/image_query/openai_image_query_driver.md) is used to query images using the OpenAI Vision API. Here is an example of how to use it: - -```python ---8<-- "docs/griptape-framework/drivers/src/image_query_drivers_3.py" -``` - -### Azure OpenAI - -!!! info - In order to use the `gpt-4-vision-preview` model on Azure OpenAI, the `gpt-4` model must be deployed with the version set to `vision-preview`. More information can be found in the [Azure OpenAI documentation](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/gpt-with-vision). - -The [AzureOpenAiVisionImageQueryDriver](../../reference/griptape/drivers/image_query/azure_openai_image_query_driver.md) is used to query images using the Azure OpenAI Vision API. Here is an example of how to use it: - -```python ---8<-- "docs/griptape-framework/drivers/src/image_query_drivers_4.py" -``` - -### Amazon Bedrock - -The [Amazon Bedrock Image Query Driver](../../reference/griptape/drivers/image_query/amazon_bedrock_image_query_driver.md) provides multi-model access to image query models hosted by Amazon Bedrock. This Driver manages API calls to the Bedrock API, while the specific Model Drivers below format the API requests and parse the responses. - -#### Claude - -The [BedrockClaudeImageQueryModelDriver](../../reference/griptape/drivers/image_query_model/bedrock_claude_image_query_model_driver.md) provides support for Claude models hosted by Bedrock. 
- -```python ---8<-- "docs/griptape-framework/drivers/src/image_query_drivers_5.py" -``` diff --git a/docs/griptape-framework/drivers/prompt-drivers.md b/docs/griptape-framework/drivers/prompt-drivers.md index 131c596cf..9df4aae0a 100644 --- a/docs/griptape-framework/drivers/prompt-drivers.md +++ b/docs/griptape-framework/drivers/prompt-drivers.md @@ -19,6 +19,12 @@ Or use them independently: --8<-- "docs/griptape-framework/drivers/src/prompt_drivers_2.py" ``` +You can pass images to the Driver if the model supports it: + +```python +--8<-- "docs/griptape-framework/drivers/src/prompt_drivers_images.py" +``` + ## Prompt Drivers Griptape offers the following Prompt Drivers for interacting with LLMs. diff --git a/docs/griptape-framework/drivers/src/image_query_drivers_1.py b/docs/griptape-framework/drivers/src/image_query_drivers_1.py deleted file mode 100644 index e2f537b08..000000000 --- a/docs/griptape-framework/drivers/src/image_query_drivers_1.py +++ /dev/null @@ -1,12 +0,0 @@ -from griptape.drivers import AnthropicImageQueryDriver -from griptape.loaders import ImageLoader - -driver = AnthropicImageQueryDriver( - model="claude-3-sonnet-20240229", - max_tokens=1024, -) - - -image_artifact = ImageLoader().load("tests/resources/mountain.png") - -driver.query("Describe the weather in the image", [image_artifact]) diff --git a/docs/griptape-framework/drivers/src/image_query_drivers_2.py b/docs/griptape-framework/drivers/src/image_query_drivers_2.py deleted file mode 100644 index d620c9675..000000000 --- a/docs/griptape-framework/drivers/src/image_query_drivers_2.py +++ /dev/null @@ -1,15 +0,0 @@ -from griptape.drivers import AnthropicImageQueryDriver -from griptape.loaders import ImageLoader - -driver = AnthropicImageQueryDriver( - model="claude-3-sonnet-20240229", - max_tokens=1024, -) - -image_artifact1 = ImageLoader().load("tests/resources/mountain.png") - -image_artifact2 = ImageLoader().load("tests/resources/cow.png") - -result = driver.query("Describe the weather 
in the image", [image_artifact1, image_artifact2]) - -print(result) diff --git a/docs/griptape-framework/drivers/src/image_query_drivers_3.py b/docs/griptape-framework/drivers/src/image_query_drivers_3.py deleted file mode 100644 index e6b1bb35b..000000000 --- a/docs/griptape-framework/drivers/src/image_query_drivers_3.py +++ /dev/null @@ -1,11 +0,0 @@ -from griptape.drivers import OpenAiImageQueryDriver -from griptape.loaders import ImageLoader - -driver = OpenAiImageQueryDriver( - model="gpt-4o", - max_tokens=256, -) - -image_artifact = ImageLoader().load("tests/resources/mountain.png") - -driver.query("Describe the weather in the image", [image_artifact]) diff --git a/docs/griptape-framework/drivers/src/image_query_drivers_4.py b/docs/griptape-framework/drivers/src/image_query_drivers_4.py deleted file mode 100644 index b5f480f63..000000000 --- a/docs/griptape-framework/drivers/src/image_query_drivers_4.py +++ /dev/null @@ -1,17 +0,0 @@ -import os - -from griptape.drivers import AzureOpenAiImageQueryDriver -from griptape.loaders import ImageLoader - -driver = AzureOpenAiImageQueryDriver( - azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_2"], - api_key=os.environ["AZURE_OPENAI_API_KEY_2"], - model="gpt-4o", - azure_deployment="gpt-4o", - max_tokens=256, -) - - -image_artifact = ImageLoader().load("tests/resources/mountain.png") - -driver.query("Describe the weather in the image", [image_artifact]) diff --git a/docs/griptape-framework/drivers/src/image_query_drivers_5.py b/docs/griptape-framework/drivers/src/image_query_drivers_5.py deleted file mode 100644 index 332c84188..000000000 --- a/docs/griptape-framework/drivers/src/image_query_drivers_5.py +++ /dev/null @@ -1,19 +0,0 @@ -import boto3 - -from griptape.drivers import AmazonBedrockImageQueryDriver, BedrockClaudeImageQueryModelDriver -from griptape.loaders import ImageLoader - -session = boto3.Session(region_name="us-west-2") - -driver = AmazonBedrockImageQueryDriver( - 
image_query_model_driver=BedrockClaudeImageQueryModelDriver(), - model="anthropic.claude-3-sonnet-20240229-v1:0", - session=session, -) - -image_artifact = ImageLoader().load("tests/resources/mountain.png") - - -result = driver.query("Describe the weather in the image", [image_artifact]) - -print(result) diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_images.py b/docs/griptape-framework/drivers/src/prompt_drivers_images.py new file mode 100644 index 000000000..244308ef5 --- /dev/null +++ b/docs/griptape-framework/drivers/src/prompt_drivers_images.py @@ -0,0 +1,10 @@ +from griptape.artifacts import ListArtifact, TextArtifact +from griptape.drivers import OpenAiChatPromptDriver +from griptape.loaders import ImageLoader + +driver = OpenAiChatPromptDriver(model="gpt-4o", max_tokens=256) + +image_artifact = ImageLoader().load("./tests/resources/mountain.jpg") +text_artifact = TextArtifact("Describe the weather in the image") + +driver.run(ListArtifact([text_artifact, image_artifact])) diff --git a/docs/griptape-framework/structures/src/tasks_15.py b/docs/griptape-framework/structures/src/tasks_15.py deleted file mode 100644 index 3f679b1ed..000000000 --- a/docs/griptape-framework/structures/src/tasks_15.py +++ /dev/null @@ -1,27 +0,0 @@ -from griptape.drivers import OpenAiImageQueryDriver -from griptape.loaders import ImageLoader -from griptape.structures import Pipeline -from griptape.tasks import ImageQueryTask - -# Create a driver configured to use OpenAI's GPT-4 Vision model. -driver = OpenAiImageQueryDriver( - model="gpt-4o", - max_tokens=100, -) - - -# Load the input image artifact. -image_artifact = ImageLoader().load("tests/resources/mountain.png") - -# Instantiate a pipeline. -pipeline = Pipeline() - -# Add an ImageQueryTask to the pipeline. 
-pipeline.add_task( - ImageQueryTask( - input=("{{ args[0] }}", [image_artifact]), - image_query_driver=driver, - ) -) - -pipeline.run("Describe the weather in the image") diff --git a/docs/griptape-framework/structures/task-memory.md b/docs/griptape-framework/structures/task-memory.md index 07ff7cee3..691dba6e4 100644 --- a/docs/griptape-framework/structures/task-memory.md +++ b/docs/griptape-framework/structures/task-memory.md @@ -279,6 +279,7 @@ Today, these include: - [ExtractionTool](../../griptape-tools/official-tools/extraction-tool.md) - [RagClient](../../griptape-tools/official-tools/rag-tool.md) - [FileManagerTool](../../griptape-tools/official-tools/file-manager-tool.md) +- [ImageQueryTool](../../griptape-tools/official-tools/image-query-tool.md) ## Task Memory Considerations diff --git a/docs/griptape-framework/structures/tasks.md b/docs/griptape-framework/structures/tasks.md index 2df1e4e1f..0101e405a 100644 --- a/docs/griptape-framework/structures/tasks.md +++ b/docs/griptape-framework/structures/tasks.md @@ -362,16 +362,6 @@ The [Outpainting Image Generation Task](../../reference/griptape/tasks/outpainti --8<-- "docs/griptape-framework/structures/src/tasks_14.py" ``` -## Image Query Task - -The [Image Query Task](../../reference/griptape/tasks/image_query_task.md) performs a natural language query on one or more input images. This Task uses an [Image Query Driver](../drivers/image-query-drivers.md) to perform the query. The functionality provided by this Task depend on the capabilities of the model provided by the Driver. - -This Task accepts two inputs: a query (represented by either a string or a [Text Artifact](../data/artifacts.md#text)) and a list of [Image Artifacts](../data/artifacts.md#image) or a Callable returning these two values. 
- -```python ---8<-- "docs/griptape-framework/structures/src/tasks_15.py" -``` - ## Structure Run Task The [Structure Run Task](../../reference/griptape/tasks/structure_run_task.md) runs another Structure with a given input. diff --git a/docs/griptape-tools/official-tools/src/image_query_tool_1.py b/docs/griptape-tools/official-tools/src/image_query_tool_1.py index f892e7d79..254ec6087 100644 --- a/docs/griptape-tools/official-tools/src/image_query_tool_1.py +++ b/docs/griptape-tools/official-tools/src/image_query_tool_1.py @@ -1,13 +1,12 @@ -from griptape.drivers import OpenAiImageQueryDriver +from griptape.drivers import OpenAiChatPromptDriver from griptape.structures import Agent from griptape.tools import ImageQueryTool -# Create an Image Query Driver. -driver = OpenAiImageQueryDriver(model="gpt-4o") +driver = OpenAiChatPromptDriver(model="gpt-4o") # Create an Image Query Tool configured to use the engine. tool = ImageQueryTool( - image_query_driver=driver, + prompt_driver=driver, ) # Create an agent and provide the tool to it. 
diff --git a/griptape/common/prompt_stack/prompt_stack.py b/griptape/common/prompt_stack/prompt_stack.py index 77ce4ba9b..3b1b8ef74 100644 --- a/griptape/common/prompt_stack/prompt_stack.py +++ b/griptape/common/prompt_stack/prompt_stack.py @@ -60,6 +60,13 @@ def add_user_message(self, artifact: str | BaseArtifact) -> Message: def add_assistant_message(self, artifact: str | BaseArtifact) -> Message: return self.add_message(artifact, Message.ASSISTANT_ROLE) + @classmethod + def from_artifact(cls, artifact: BaseArtifact) -> PromptStack: + prompt_stack = cls() + prompt_stack.add_user_message(artifact) + + return prompt_stack + def __to_message_content(self, artifact: str | BaseArtifact) -> list[BaseMessageContent]: if isinstance(artifact, str): return [TextMessageContent(TextArtifact(artifact))] diff --git a/griptape/configs/drivers/amazon_bedrock_drivers_config.py b/griptape/configs/drivers/amazon_bedrock_drivers_config.py index 7a54ac522..c8b798aec 100644 --- a/griptape/configs/drivers/amazon_bedrock_drivers_config.py +++ b/griptape/configs/drivers/amazon_bedrock_drivers_config.py @@ -7,10 +7,8 @@ from griptape.configs.drivers import DriversConfig from griptape.drivers import ( AmazonBedrockImageGenerationDriver, - AmazonBedrockImageQueryDriver, AmazonBedrockPromptDriver, AmazonBedrockTitanEmbeddingDriver, - BedrockClaudeImageQueryModelDriver, BedrockTitanImageGenerationModelDriver, LocalVectorStoreDriver, ) @@ -45,14 +43,6 @@ def image_generation_driver(self) -> AmazonBedrockImageGenerationDriver: image_generation_model_driver=BedrockTitanImageGenerationModelDriver(), ) - @lazy_property() - def image_query_driver(self) -> AmazonBedrockImageQueryDriver: - return AmazonBedrockImageQueryDriver( - session=self.session, - model="anthropic.claude-3-5-sonnet-20240620-v1:0", - image_query_model_driver=BedrockClaudeImageQueryModelDriver(), - ) - @lazy_property() def vector_store_driver(self) -> LocalVectorStoreDriver: return LocalVectorStoreDriver( diff --git 
a/griptape/configs/drivers/anthropic_drivers_config.py b/griptape/configs/drivers/anthropic_drivers_config.py index 6a0fa52d4..3b3ee0f2d 100644 --- a/griptape/configs/drivers/anthropic_drivers_config.py +++ b/griptape/configs/drivers/anthropic_drivers_config.py @@ -2,7 +2,6 @@ from griptape.configs.drivers import DriversConfig from griptape.drivers import ( - AnthropicImageQueryDriver, AnthropicPromptDriver, ) from griptape.utils.decorators import lazy_property @@ -13,7 +12,3 @@ class AnthropicDriversConfig(DriversConfig): @lazy_property() def prompt_driver(self) -> AnthropicPromptDriver: return AnthropicPromptDriver(model="claude-3-5-sonnet-20240620") - - @lazy_property() - def image_query_driver(self) -> AnthropicImageQueryDriver: - return AnthropicImageQueryDriver(model="claude-3-5-sonnet-20240620") diff --git a/griptape/configs/drivers/azure_openai_drivers_config.py b/griptape/configs/drivers/azure_openai_drivers_config.py index 3ced8f9cd..b173830c0 100644 --- a/griptape/configs/drivers/azure_openai_drivers_config.py +++ b/griptape/configs/drivers/azure_openai_drivers_config.py @@ -9,7 +9,6 @@ AzureOpenAiChatPromptDriver, AzureOpenAiEmbeddingDriver, AzureOpenAiImageGenerationDriver, - AzureOpenAiImageQueryDriver, AzureOpenAiTextToSpeechDriver, LocalVectorStoreDriver, ) @@ -27,7 +26,6 @@ class AzureOpenAiDriversConfig(DriversConfig): api_key: An optional Azure API key. prompt_driver: An Azure OpenAI Chat Prompt Driver. image_generation_driver: An Azure OpenAI Image Generation Driver. - image_query_driver: An Azure OpenAI Vision Image Query Driver. embedding_driver: An Azure OpenAI Embedding Driver. vector_store_driver: A Local Vector Store Driver. 
""" @@ -72,16 +70,6 @@ def image_generation_driver(self) -> AzureOpenAiImageGenerationDriver: image_size="512x512", ) - @lazy_property() - def image_query_driver(self) -> AzureOpenAiImageQueryDriver: - return AzureOpenAiImageQueryDriver( - model="gpt-4o", - azure_endpoint=self.azure_endpoint, - api_key=self.api_key, - azure_ad_token=self.azure_ad_token, - azure_ad_token_provider=self.azure_ad_token_provider, - ) - @lazy_property() def vector_store_driver(self) -> LocalVectorStoreDriver: return LocalVectorStoreDriver( diff --git a/griptape/configs/drivers/base_drivers_config.py b/griptape/configs/drivers/base_drivers_config.py index 1c1ae149b..127804f94 100644 --- a/griptape/configs/drivers/base_drivers_config.py +++ b/griptape/configs/drivers/base_drivers_config.py @@ -14,7 +14,6 @@ BaseConversationMemoryDriver, BaseEmbeddingDriver, BaseImageGenerationDriver, - BaseImageQueryDriver, BasePromptDriver, BaseRulesetDriver, BaseTextToSpeechDriver, @@ -30,9 +29,6 @@ class BaseDriversConfig(ABC, SerializableMixin): _image_generation_driver: BaseImageGenerationDriver = field( kw_only=True, default=None, metadata={"serializable": True}, alias="image_generation_driver" ) - _image_query_driver: BaseImageQueryDriver = field( - kw_only=True, default=None, metadata={"serializable": True}, alias="image_query_driver" - ) _embedding_driver: BaseEmbeddingDriver = field( kw_only=True, default=None, metadata={"serializable": True}, alias="embedding_driver" ) @@ -79,10 +75,6 @@ def prompt_driver(self) -> BasePromptDriver: ... @abstractmethod def image_generation_driver(self) -> BaseImageGenerationDriver: ... - @lazy_property() - @abstractmethod - def image_query_driver(self) -> BaseImageQueryDriver: ... - @lazy_property() @abstractmethod def embedding_driver(self) -> BaseEmbeddingDriver: ... 
diff --git a/griptape/configs/drivers/drivers_config.py b/griptape/configs/drivers/drivers_config.py index 5261640c8..b90afce34 100644 --- a/griptape/configs/drivers/drivers_config.py +++ b/griptape/configs/drivers/drivers_config.py @@ -9,7 +9,6 @@ DummyAudioTranscriptionDriver, DummyEmbeddingDriver, DummyImageGenerationDriver, - DummyImageQueryDriver, DummyPromptDriver, DummyTextToSpeechDriver, DummyVectorStoreDriver, @@ -24,7 +23,6 @@ BaseConversationMemoryDriver, BaseEmbeddingDriver, BaseImageGenerationDriver, - BaseImageQueryDriver, BasePromptDriver, BaseRulesetDriver, BaseTextToSpeechDriver, @@ -42,10 +40,6 @@ def prompt_driver(self) -> BasePromptDriver: def image_generation_driver(self) -> BaseImageGenerationDriver: return DummyImageGenerationDriver() - @lazy_property() - def image_query_driver(self) -> BaseImageQueryDriver: - return DummyImageQueryDriver() - @lazy_property() def embedding_driver(self) -> BaseEmbeddingDriver: return DummyEmbeddingDriver() diff --git a/griptape/configs/drivers/openai_drivers_config.py b/griptape/configs/drivers/openai_drivers_config.py index ec1a4dc79..3448dd4a1 100644 --- a/griptape/configs/drivers/openai_drivers_config.py +++ b/griptape/configs/drivers/openai_drivers_config.py @@ -7,7 +7,6 @@ OpenAiChatPromptDriver, OpenAiEmbeddingDriver, OpenAiImageGenerationDriver, - OpenAiImageQueryDriver, OpenAiTextToSpeechDriver, ) from griptape.utils.decorators import lazy_property @@ -23,10 +22,6 @@ def prompt_driver(self) -> OpenAiChatPromptDriver: def image_generation_driver(self) -> OpenAiImageGenerationDriver: return OpenAiImageGenerationDriver(model="dall-e-2", image_size="512x512") - @lazy_property() - def image_query_driver(self) -> OpenAiImageQueryDriver: - return OpenAiImageQueryDriver(model="gpt-4o") - @lazy_property() def embedding_driver(self) -> OpenAiEmbeddingDriver: return OpenAiEmbeddingDriver(model="text-embedding-3-small") diff --git a/griptape/drivers/__init__.py b/griptape/drivers/__init__.py index 
4acbc9a19..7bc79ade4 100644 --- a/griptape/drivers/__init__.py +++ b/griptape/drivers/__init__.py @@ -80,17 +80,6 @@ HuggingFacePipelineImageGenerationDriver, ) -from .image_query_model.base_image_query_model_driver import BaseImageQueryModelDriver -from .image_query_model.bedrock_claude_image_query_model_driver import BedrockClaudeImageQueryModelDriver - -from .image_query.base_image_query_driver import BaseImageQueryDriver -from .image_query.base_multi_model_image_query_driver import BaseMultiModelImageQueryDriver -from .image_query.dummy_image_query_driver import DummyImageQueryDriver -from .image_query.openai_image_query_driver import OpenAiImageQueryDriver -from .image_query.anthropic_image_query_driver import AnthropicImageQueryDriver -from .image_query.azure_openai_image_query_driver import AzureOpenAiImageQueryDriver -from .image_query.amazon_bedrock_image_query_driver import AmazonBedrockImageQueryDriver - from .web_scraper.base_web_scraper_driver import BaseWebScraperDriver from .web_scraper.trafilatura_web_scraper_driver import TrafilaturaWebScraperDriver from .web_scraper.markdownify_web_scraper_driver import MarkdownifyWebScraperDriver @@ -204,15 +193,6 @@ "AzureOpenAiImageGenerationDriver", "DummyImageGenerationDriver", "HuggingFacePipelineImageGenerationDriver", - "BaseImageQueryModelDriver", - "BedrockClaudeImageQueryModelDriver", - "BaseImageQueryDriver", - "OpenAiImageQueryDriver", - "AzureOpenAiImageQueryDriver", - "DummyImageQueryDriver", - "AnthropicImageQueryDriver", - "BaseMultiModelImageQueryDriver", - "AmazonBedrockImageQueryDriver", "BaseWebScraperDriver", "TrafilaturaWebScraperDriver", "MarkdownifyWebScraperDriver", diff --git a/griptape/drivers/image_query/__init__.py b/griptape/drivers/image_query/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/griptape/drivers/image_query/amazon_bedrock_image_query_driver.py b/griptape/drivers/image_query/amazon_bedrock_image_query_driver.py deleted file mode 100644 index 
9742cb9c7..000000000 --- a/griptape/drivers/image_query/amazon_bedrock_image_query_driver.py +++ /dev/null @@ -1,46 +0,0 @@ -from __future__ import annotations - -import json -from typing import TYPE_CHECKING - -from attrs import Factory, define, field - -from griptape.drivers import BaseMultiModelImageQueryDriver -from griptape.utils import import_optional_dependency -from griptape.utils.decorators import lazy_property - -if TYPE_CHECKING: - import boto3 - from mypy_boto3_bedrock import BedrockClient - - from griptape.artifacts import ImageArtifact, TextArtifact - - -@define -class AmazonBedrockImageQueryDriver(BaseMultiModelImageQueryDriver): - session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True) - _client: BedrockClient = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) - - @lazy_property() - def client(self) -> BedrockClient: - return self.session.client("bedrock-runtime") - - def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact: - payload = self.image_query_model_driver.image_query_request_parameters(query, images, self.max_tokens) - - response = self.client.invoke_model( - modelId=self.model, - contentType="application/json", - accept="application/json", - body=json.dumps(payload), - ) - - response_body = json.loads(response.get("body").read()) - - if response_body is None: - raise ValueError("Model response is empty") - - try: - return self.image_query_model_driver.process_output(response_body) - except Exception as e: - raise ValueError(f"Output is unable to be processed as returned {e}") from e diff --git a/griptape/drivers/image_query/anthropic_image_query_driver.py b/griptape/drivers/image_query/anthropic_image_query_driver.py deleted file mode 100644 index 191d95373..000000000 --- a/griptape/drivers/image_query/anthropic_image_query_driver.py +++ /dev/null @@ -1,64 +0,0 @@ -from __future__ import annotations - -from typing import 
TYPE_CHECKING, Optional - -from attrs import define, field - -from griptape.artifacts import ImageArtifact, TextArtifact -from griptape.drivers import BaseImageQueryDriver -from griptape.utils import import_optional_dependency -from griptape.utils.decorators import lazy_property - -if TYPE_CHECKING: - from anthropic import Anthropic - - -@define -class AnthropicImageQueryDriver(BaseImageQueryDriver): - """Anthropic Image Query Driver. - - Attributes: - api_key: Anthropic API key. - model: Anthropic model name. - client: Custom `Anthropic` client. - """ - - api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False}) - model: str = field(kw_only=True, metadata={"serializable": True}) - _client: Anthropic = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) - - @lazy_property() - def client(self) -> Anthropic: - return import_optional_dependency("anthropic").Anthropic(api_key=self.api_key) - - def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact: - if self.max_tokens is None: - raise TypeError("max_output_tokens can't be empty") - - response = self.client.messages.create(**self._base_params(query, images)) - content_blocks = response.content - - if len(content_blocks) < 1: - raise ValueError("Response content is empty") - - text_content = content_blocks[0].text - - return TextArtifact(text_content) - - def _base_params(self, text_query: str, images: list[ImageArtifact]) -> dict: - content = [self._construct_image_message(image) for image in images] - content.append(self._construct_text_message(text_query)) - messages = self._construct_messages(content) - return {"model": self.model, "messages": messages, "max_tokens": self.max_tokens} - - def _construct_image_message(self, image_data: ImageArtifact) -> dict: - data = image_data.base64 - media_type = image_data.mime_type - - return {"source": {"data": data, "media_type": media_type, "type": "base64"}, "type": "image"} - - def 
_construct_text_message(self, query: str) -> dict: - return {"text": query, "type": "text"} - - def _construct_messages(self, content: list) -> list: - return [{"content": content, "role": "user"}] diff --git a/griptape/drivers/image_query/azure_openai_image_query_driver.py b/griptape/drivers/image_query/azure_openai_image_query_driver.py deleted file mode 100644 index 637fa11cc..000000000 --- a/griptape/drivers/image_query/azure_openai_image_query_driver.py +++ /dev/null @@ -1,50 +0,0 @@ -from __future__ import annotations - -from typing import Callable, Optional - -import openai -from attrs import Factory, define, field - -from griptape.drivers.image_query.openai_image_query_driver import OpenAiImageQueryDriver -from griptape.utils.decorators import lazy_property - - -@define -class AzureOpenAiImageQueryDriver(OpenAiImageQueryDriver): - """Driver for Azure-hosted OpenAI image query API. - - Attributes: - azure_deployment: An optional Azure OpenAi deployment id. Defaults to the model name. - azure_endpoint: An Azure OpenAi endpoint. - azure_ad_token: An optional Azure Active Directory token. - azure_ad_token_provider: An optional Azure Active Directory token provider. - api_version: An Azure OpenAi API version. - client: An `openai.AzureOpenAI` client. 
- """ - - azure_deployment: str = field( - kw_only=True, - default=Factory(lambda self: self.model, takes_self=True), - metadata={"serializable": True}, - ) - azure_endpoint: str = field(kw_only=True, metadata={"serializable": True}) - azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False}) - azure_ad_token_provider: Optional[Callable[[], str]] = field( - kw_only=True, - default=None, - metadata={"serializable": False}, - ) - api_version: str = field(default="2024-02-01", kw_only=True, metadata={"serializable": True}) - _client: openai.AzureOpenAI = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) - - @lazy_property() - def client(self) -> openai.AzureOpenAI: - return openai.AzureOpenAI( - organization=self.organization, - api_key=self.api_key, - api_version=self.api_version, - azure_endpoint=self.azure_endpoint, - azure_deployment=self.azure_deployment, - azure_ad_token=self.azure_ad_token, - azure_ad_token_provider=self.azure_ad_token_provider, - ) diff --git a/griptape/drivers/image_query/base_image_query_driver.py b/griptape/drivers/image_query/base_image_query_driver.py deleted file mode 100644 index ecfe0ca6e..000000000 --- a/griptape/drivers/image_query/base_image_query_driver.py +++ /dev/null @@ -1,42 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING - -from attrs import define, field - -from griptape.events import EventBus, FinishImageQueryEvent, StartImageQueryEvent -from griptape.mixins.exponential_backoff_mixin import ExponentialBackoffMixin -from griptape.mixins.serializable_mixin import SerializableMixin - -if TYPE_CHECKING: - from griptape.artifacts import ImageArtifact, TextArtifact - - -@define -class BaseImageQueryDriver(SerializableMixin, ExponentialBackoffMixin, ABC): - max_tokens: int = field(default=256, kw_only=True, metadata={"serializable": True}) - - def before_run(self, query: str, images: 
list[ImageArtifact]) -> None: - EventBus.publish_event( - StartImageQueryEvent(query=query, images_info=[image.to_text() for image in images]), - ) - - def after_run(self, result: str) -> None: - EventBus.publish_event(FinishImageQueryEvent(result=result)) - - def query(self, query: str, images: list[ImageArtifact]) -> TextArtifact: - for attempt in self.retrying(): - with attempt: - self.before_run(query, images) - - result = self.try_query(query, images) - - self.after_run(result.value) - - return result - else: - raise Exception("image query driver failed after all retry attempts") - - @abstractmethod - def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact: ... diff --git a/griptape/drivers/image_query/base_multi_model_image_query_driver.py b/griptape/drivers/image_query/base_multi_model_image_query_driver.py deleted file mode 100644 index 52af617e8..000000000 --- a/griptape/drivers/image_query/base_multi_model_image_query_driver.py +++ /dev/null @@ -1,23 +0,0 @@ -from __future__ import annotations - -from abc import ABC - -from attrs import define, field - -from griptape.drivers import BaseImageQueryDriver, BaseImageQueryModelDriver - - -@define -class BaseMultiModelImageQueryDriver(BaseImageQueryDriver, ABC): - """Image Query Driver for platforms like Amazon Bedrock that host many LLM models. - - Instances of this Image Query Driver require a Image Query Model Driver which is used to structure the - image generation request in the format required by the model and to process the output. - - Attributes: - model: Model name to use - image_query_model_driver: Image Model Driver to use. 
- """ - - model: str = field(kw_only=True, metadata={"serializable": True}) - image_query_model_driver: BaseImageQueryModelDriver = field(kw_only=True, metadata={"serializable": True}) diff --git a/griptape/drivers/image_query/dummy_image_query_driver.py b/griptape/drivers/image_query/dummy_image_query_driver.py deleted file mode 100644 index 62820efd7..000000000 --- a/griptape/drivers/image_query/dummy_image_query_driver.py +++ /dev/null @@ -1,20 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from attrs import define, field - -from griptape.drivers import BaseImageQueryDriver -from griptape.exceptions import DummyError - -if TYPE_CHECKING: - from griptape.artifacts import ImageArtifact, TextArtifact - - -@define -class DummyImageQueryDriver(BaseImageQueryDriver): - model: None = field(init=False, default=None, kw_only=True) - max_tokens: None = field(init=False, default=None, kw_only=True) - - def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact: - raise DummyError(__class__.__name__, "try_query") diff --git a/griptape/drivers/image_query/openai_image_query_driver.py b/griptape/drivers/image_query/openai_image_query_driver.py deleted file mode 100644 index f0ef9e148..000000000 --- a/griptape/drivers/image_query/openai_image_query_driver.py +++ /dev/null @@ -1,55 +0,0 @@ -from __future__ import annotations - -from typing import Literal, Optional - -import openai -from attrs import define, field -from openai.types.chat import ( - ChatCompletionContentPartImageParam, - ChatCompletionContentPartParam, - ChatCompletionContentPartTextParam, - ChatCompletionUserMessageParam, -) - -from griptape.artifacts import ImageArtifact, TextArtifact -from griptape.drivers.image_query.base_image_query_driver import BaseImageQueryDriver -from griptape.utils.decorators import lazy_property - - -@define -class OpenAiImageQueryDriver(BaseImageQueryDriver): - model: str = field(kw_only=True, metadata={"serializable": True}) - 
api_type: Optional[str] = field(default=openai.api_type, kw_only=True) - api_version: Optional[str] = field(default=openai.api_version, kw_only=True, metadata={"serializable": True}) - base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) - api_key: Optional[str] = field(default=None, kw_only=True) - organization: Optional[str] = field(default=openai.organization, kw_only=True, metadata={"serializable": True}) - image_quality: Literal["auto", "low", "high"] = field(default="auto", kw_only=True, metadata={"serializable": True}) - _client: openai.OpenAI = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) - - @lazy_property() - def client(self) -> openai.OpenAI: - return openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization) - - def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact: - message_parts: list[ChatCompletionContentPartParam] = [ - ChatCompletionContentPartTextParam(type="text", text=query), - ] - - for image in images: - message_parts.append( - ChatCompletionContentPartImageParam( - type="image_url", - image_url={"url": f"data:{image.mime_type};base64,{image.base64}", "detail": self.image_quality}, - ), - ) - - messages = ChatCompletionUserMessageParam(content=message_parts, role="user") - params = {"model": self.model, "messages": [messages], "max_tokens": self.max_tokens} - - response = self.client.chat.completions.create(**params) - - if len(response.choices) != 1: - raise Exception("Image query responses with more than one choice are not supported yet.") - - return TextArtifact(response.choices[0].message.content) diff --git a/griptape/drivers/image_query_model/__init__.py b/griptape/drivers/image_query_model/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/griptape/drivers/image_query_model/base_image_query_model_driver.py b/griptape/drivers/image_query_model/base_image_query_model_driver.py 
deleted file mode 100644 index ac97ee3c1..000000000 --- a/griptape/drivers/image_query_model/base_image_query_model_driver.py +++ /dev/null @@ -1,20 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING - -from attrs import define - -from griptape.mixins.serializable_mixin import SerializableMixin - -if TYPE_CHECKING: - from griptape.artifacts import ImageArtifact, TextArtifact - - -@define -class BaseImageQueryModelDriver(SerializableMixin, ABC): - @abstractmethod - def image_query_request_parameters(self, query: str, images: list[ImageArtifact], max_tokens: int) -> dict: ... - - @abstractmethod - def process_output(self, output: dict) -> TextArtifact: ... diff --git a/griptape/drivers/image_query_model/bedrock_claude_image_query_model_driver.py b/griptape/drivers/image_query_model/bedrock_claude_image_query_model_driver.py deleted file mode 100644 index 1785550a0..000000000 --- a/griptape/drivers/image_query_model/bedrock_claude_image_query_model_driver.py +++ /dev/null @@ -1,38 +0,0 @@ -from __future__ import annotations - -from attrs import define - -from griptape.artifacts import ImageArtifact, TextArtifact -from griptape.drivers import BaseImageQueryModelDriver - - -@define -class BedrockClaudeImageQueryModelDriver(BaseImageQueryModelDriver): - ANTHROPIC_VERSION = "bedrock-2023-05-31" # static string for AWS: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html#api-inference-examples-claude-multimodal-code-example - - def image_query_request_parameters(self, query: str, images: list[ImageArtifact], max_tokens: int) -> dict: - content = [self._construct_image_message(image) for image in images] - content.append(self._construct_text_message(query)) - messages = self._construct_messages(content) - return {"messages": messages, "anthropic_version": self.ANTHROPIC_VERSION, "max_tokens": max_tokens} - - def process_output(self, output: dict) -> 
TextArtifact: - content_blocks = output["content"] - if len(content_blocks) < 1: - raise ValueError("Response content is empty") - - text_content = content_blocks[0]["text"] - - return TextArtifact(text_content) - - def _construct_image_message(self, image_data: ImageArtifact) -> dict: - data = image_data.base64 - media_type = image_data.mime_type - - return {"source": {"data": data, "media_type": media_type, "type": "base64"}, "type": "image"} - - def _construct_text_message(self, query: str) -> dict: - return {"text": query, "type": "text"} - - def _construct_messages(self, content: list) -> list: - return [{"content": content, "role": "user"}] diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index 1524d7ed9..707f67644 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -5,6 +5,7 @@ from attrs import Factory, define, field +from griptape.artifacts.base_artifact import BaseArtifact from griptape.common import ( ActionCallDeltaMessageContent, ActionCallMessageContent, @@ -71,7 +72,12 @@ def after_run(self, result: Message) -> None: ) @observable(tags=["PromptDriver.run()"]) - def run(self, prompt_stack: PromptStack) -> Message: + def run(self, prompt_input: PromptStack | BaseArtifact) -> Message: + if isinstance(prompt_input, BaseArtifact): + prompt_stack = PromptStack.from_artifact(prompt_input) + else: + prompt_stack = prompt_input + for attempt in self.retrying(): with attempt: self.before_run(prompt_stack) @@ -85,7 +91,7 @@ def run(self, prompt_stack: PromptStack) -> Message: raise Exception("prompt driver failed after all retry attempts") def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: - """Converts a Prompt Stack to a string for token counting or model input. + """Converts a Prompt Stack to a string for token counting or model prompt_input. 
This base implementation is only a rough approximation, and should be overridden by subclasses with model-specific tokens. diff --git a/griptape/memory/structure/summary_conversation_memory.py b/griptape/memory/structure/summary_conversation_memory.py index a8aa7fa34..c7a0ea172 100644 --- a/griptape/memory/structure/summary_conversation_memory.py +++ b/griptape/memory/structure/summary_conversation_memory.py @@ -67,7 +67,7 @@ def summarize_runs(self, previous_summary: str | None, runs: list[Run]) -> str | if len(runs) > 0: summary = self.summarize_conversation_get_template.render(summary=previous_summary, runs=runs) return self.prompt_driver.run( - prompt_stack=PromptStack(messages=[Message(summary, role=Message.USER_ROLE)]), + PromptStack(messages=[Message(summary, role=Message.USER_ROLE)]), ).to_text() else: return previous_summary diff --git a/griptape/schemas/base_schema.py b/griptape/schemas/base_schema.py index 6a801e9db..4a049752b 100644 --- a/griptape/schemas/base_schema.py +++ b/griptape/schemas/base_schema.py @@ -165,7 +165,6 @@ def _resolve_types(cls, attrs_cls: type) -> None: BaseConversationMemoryDriver, BaseEmbeddingDriver, BaseImageGenerationDriver, - BaseImageQueryDriver, BasePromptDriver, BaseRulesetDriver, BaseTextToSpeechDriver, @@ -186,7 +185,6 @@ def _resolve_types(cls, attrs_cls: type) -> None: localns={ "Any": Any, "BasePromptDriver": BasePromptDriver, - "BaseImageQueryDriver": BaseImageQueryDriver, "BaseEmbeddingDriver": BaseEmbeddingDriver, "BaseVectorStoreDriver": BaseVectorStoreDriver, "BaseTextToSpeechDriver": BaseTextToSpeechDriver, diff --git a/griptape/tasks/__init__.py b/griptape/tasks/__init__.py index 4f65a4226..f28fe3ee9 100644 --- a/griptape/tasks/__init__.py +++ b/griptape/tasks/__init__.py @@ -13,7 +13,6 @@ from .inpainting_image_generation_task import InpaintingImageGenerationTask from .outpainting_image_generation_task import OutpaintingImageGenerationTask from .variation_image_generation_task import 
VariationImageGenerationTask -from .image_query_task import ImageQueryTask from .base_audio_generation_task import BaseAudioGenerationTask from .text_to_speech_task import TextToSpeechTask from .structure_run_task import StructureRunTask @@ -35,7 +34,6 @@ "VariationImageGenerationTask", "InpaintingImageGenerationTask", "OutpaintingImageGenerationTask", - "ImageQueryTask", "BaseAudioGenerationTask", "TextToSpeechTask", "StructureRunTask", diff --git a/griptape/tasks/image_query_task.py b/griptape/tasks/image_query_task.py deleted file mode 100644 index c8fe9abcc..000000000 --- a/griptape/tasks/image_query_task.py +++ /dev/null @@ -1,84 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Callable, Union - -from attrs import Factory, define, field - -from griptape.artifacts import ImageArtifact, ListArtifact, TextArtifact -from griptape.configs.defaults_config import Defaults -from griptape.tasks import BaseTask -from griptape.utils import J2 - -if TYPE_CHECKING: - from griptape.drivers import BaseImageQueryDriver - - -@define -class ImageQueryTask(BaseTask): - """A task that executes a natural language query on one or more input images. - - Accepts a text prompt and a list of - images as input in one of the following formats: - - tuple of (template string, list[ImageArtifact]) - - tuple of (TextArtifact, list[ImageArtifact]) - - Callable that returns a tuple of (TextArtifact, list[ImageArtifact]). - - Attributes: - image_query_driver: The driver used to execute the query. 
- """ - - image_query_driver: BaseImageQueryDriver = field( - default=Factory(lambda: Defaults.drivers_config.image_query_driver), kw_only=True - ) - _input: Union[ - tuple[str, list[ImageArtifact]], - tuple[TextArtifact, list[ImageArtifact]], - Callable[[BaseTask], ListArtifact], - ListArtifact, - ] = field(default=None, alias="input") - - @property - def input(self) -> ListArtifact: - if isinstance(self._input, ListArtifact): - return self._input - elif isinstance(self._input, tuple): - if isinstance(self._input[0], TextArtifact): - query_text = self._input[0] - else: - query_text = TextArtifact(J2().render_from_string(self._input[0], **self.full_context)) - - return ListArtifact([query_text, *self._input[1]]) - elif isinstance(self._input, Callable): - return self._input(self) - else: - raise ValueError( - "Input must be a tuple of a TextArtifact and a list of ImageArtifacts or a callable that " - "returns a tuple of a TextArtifact and a list of ImageArtifacts.", - ) - - @input.setter - def input( - self, - value: ( - Union[ - tuple[str, list[ImageArtifact]], - tuple[TextArtifact, list[ImageArtifact]], - Callable[[BaseTask], ListArtifact], - ] - ), - ) -> None: - self._input = value - - def try_run(self) -> TextArtifact: - query = self.input.value[0] - - if all(isinstance(artifact, ImageArtifact) for artifact in self.input.value[1:]): - image_artifacts = [ - image_artifact for image_artifact in self.input.value[1:] if isinstance(image_artifact, ImageArtifact) - ] - else: - raise ValueError("All inputs after the query must be ImageArtifacts.") - - self.output = self.image_query_driver.query(query.value, image_artifacts) - - return self.output diff --git a/griptape/tasks/tool_task.py b/griptape/tasks/tool_task.py index 07b762167..d3345bc4e 100644 --- a/griptape/tasks/tool_task.py +++ b/griptape/tasks/tool_task.py @@ -61,7 +61,7 @@ def actions_schema(self) -> Schema: return self._actions_schema_for_tools([self.tool]) def try_run(self) -> BaseArtifact: - result = 
self.prompt_driver.run(prompt_stack=self.prompt_stack) + result = self.prompt_driver.run(self.prompt_stack) if self.prompt_driver.use_native_tools: subtask_input = result.to_artifact() diff --git a/griptape/tasks/toolkit_task.py b/griptape/tasks/toolkit_task.py index 088ccd52d..f94899ab3 100644 --- a/griptape/tasks/toolkit_task.py +++ b/griptape/tasks/toolkit_task.py @@ -185,7 +185,7 @@ def try_run(self) -> BaseArtifact: subtask.run() subtask.after_run() - result = self.prompt_driver.run(prompt_stack=self.prompt_stack) + result = self.prompt_driver.run(self.prompt_stack) subtask = self.add_subtask(ActionsSubtask(result.to_artifact())) else: break diff --git a/griptape/tools/image_query/tool.py b/griptape/tools/image_query/tool.py index 193f810d8..de49a37a7 100644 --- a/griptape/tools/image_query/tool.py +++ b/griptape/tools/image_query/tool.py @@ -6,18 +6,20 @@ from schema import Literal, Schema from griptape.artifacts import BlobArtifact, ErrorArtifact, ImageArtifact, TextArtifact +from griptape.artifacts.list_artifact import ListArtifact +from griptape.common import PromptStack from griptape.loaders import ImageLoader from griptape.tools import BaseTool from griptape.utils import load_artifact_from_memory from griptape.utils.decorators import activity if TYPE_CHECKING: - from griptape.drivers import BaseImageQueryDriver + from griptape.drivers import BasePromptDriver @define class ImageQueryTool(BaseTool): - image_query_driver: BaseImageQueryDriver = field(kw_only=True) + prompt_driver: BasePromptDriver = field(kw_only=True) image_loader: ImageLoader = field(default=Factory(lambda: ImageLoader()), kw_only=True) @activity( @@ -42,7 +44,14 @@ def query_image_from_disk(self, params: dict) -> TextArtifact | ErrorArtifact: for image_path in image_paths: image_artifacts.append(self.image_loader.load(image_path)) - return self.image_query_driver.query(query, image_artifacts) + return cast( + TextArtifact, + self.prompt_driver.run( + PromptStack.from_artifact( + 
ListArtifact([TextArtifact(query), *image_artifacts]), + ) + ).to_artifact(), + ) @activity( config={ @@ -94,4 +103,11 @@ def query_images_from_memory(self, params: dict[str, Any]) -> TextArtifact | Err except Exception as e: return ErrorArtifact(str(e)) - return self.image_query_driver.query(query, image_artifacts) + return cast( + TextArtifact, + self.prompt_driver.run( + PromptStack.from_artifact( + ListArtifact([TextArtifact(query), *image_artifacts]), + ) + ).to_artifact(), + ) diff --git a/mkdocs.yml b/mkdocs.yml index c54fc8b9c..6b35aa15c 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -120,7 +120,6 @@ nav: - Vector Store Drivers: "griptape-framework/drivers/vector-store-drivers.md" - Image Generation Drivers: "griptape-framework/drivers/image-generation-drivers.md" - SQL Drivers: "griptape-framework/drivers/sql-drivers.md" - - Image Query Drivers: "griptape-framework/drivers/image-query-drivers.md" - Web Scraper Drivers: "griptape-framework/drivers/web-scraper-drivers.md" - Conversation Memory Drivers: "griptape-framework/drivers/conversation-memory-drivers.md" - Event Listener Drivers: "griptape-framework/drivers/event-listener-drivers.md" diff --git a/tests/mocks/mock_drivers_config.py b/tests/mocks/mock_drivers_config.py index aa9683dbd..edf263157 100644 --- a/tests/mocks/mock_drivers_config.py +++ b/tests/mocks/mock_drivers_config.py @@ -5,7 +5,6 @@ from griptape.utils.decorators import lazy_property from tests.mocks.mock_embedding_driver import MockEmbeddingDriver from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver -from tests.mocks.mock_image_query_driver import MockImageQueryDriver from tests.mocks.mock_prompt_driver import MockPromptDriver @@ -19,10 +18,6 @@ def prompt_driver(self) -> MockPromptDriver: def image_generation_driver(self) -> MockImageGenerationDriver: return MockImageGenerationDriver() - @lazy_property() - def image_query_driver(self) -> MockImageQueryDriver: - return MockImageQueryDriver() - @lazy_property() def 
embedding_driver(self) -> MockEmbeddingDriver: return MockEmbeddingDriver() diff --git a/tests/mocks/mock_image_query_driver.py b/tests/mocks/mock_image_query_driver.py deleted file mode 100644 index 8f8cc888c..000000000 --- a/tests/mocks/mock_image_query_driver.py +++ /dev/null @@ -1,16 +0,0 @@ -from __future__ import annotations - -from typing import Optional - -from attrs import define - -from griptape.artifacts import ImageArtifact, TextArtifact -from griptape.drivers import BaseImageQueryDriver - - -@define -class MockImageQueryDriver(BaseImageQueryDriver): - model: Optional[str] = None - - def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact: - return TextArtifact(value="mock text") diff --git a/tests/unit/common/test_prompt_stack.py b/tests/unit/common/test_prompt_stack.py index 983dccc4c..f24238097 100644 --- a/tests/unit/common/test_prompt_stack.py +++ b/tests/unit/common/test_prompt_stack.py @@ -108,3 +108,9 @@ def test_add_assistant_message(self, prompt_stack): assert prompt_stack.messages[0].role == "assistant" assert prompt_stack.messages[0].content[0].artifact.value == "foo" + + def test_from_artifact(self): + prompt_stack = PromptStack.from_artifact(TextArtifact("foo")) + + assert prompt_stack.messages[0].role == "user" + assert prompt_stack.messages[0].content[0].artifact.value == "foo" diff --git a/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py b/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py index bdde495de..52408922c 100644 --- a/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py +++ b/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py @@ -43,12 +43,6 @@ def test_to_dict(self, config): "seed": None, "type": "AmazonBedrockImageGenerationDriver", }, - "image_query_driver": { - "type": "AmazonBedrockImageQueryDriver", - "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", - "max_tokens": 256, - "image_query_model_driver": {"type": 
"BedrockClaudeImageQueryModelDriver"}, - }, "prompt_driver": { "max_tokens": None, "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", @@ -104,12 +98,6 @@ def test_to_dict_with_values(self, config_with_values): "seed": None, "type": "AmazonBedrockImageGenerationDriver", }, - "image_query_driver": { - "type": "AmazonBedrockImageQueryDriver", - "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", - "max_tokens": 256, - "image_query_model_driver": {"type": "BedrockClaudeImageQueryModelDriver"}, - }, "prompt_driver": { "max_tokens": None, "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", diff --git a/tests/unit/configs/drivers/test_anthropic_drivers_config.py b/tests/unit/configs/drivers/test_anthropic_drivers_config.py index bd232283f..8a6f25ef2 100644 --- a/tests/unit/configs/drivers/test_anthropic_drivers_config.py +++ b/tests/unit/configs/drivers/test_anthropic_drivers_config.py @@ -28,11 +28,6 @@ def test_to_dict(self, config): "extra_params": {}, }, "image_generation_driver": {"type": "DummyImageGenerationDriver"}, - "image_query_driver": { - "type": "AnthropicImageQueryDriver", - "model": "claude-3-5-sonnet-20240620", - "max_tokens": 256, - }, "embedding_driver": { "type": "DummyEmbeddingDriver", }, diff --git a/tests/unit/configs/drivers/test_azure_openai_drivers_config.py b/tests/unit/configs/drivers/test_azure_openai_drivers_config.py index a4af1692f..b8d006778 100644 --- a/tests/unit/configs/drivers/test_azure_openai_drivers_config.py +++ b/tests/unit/configs/drivers/test_azure_openai_drivers_config.py @@ -63,17 +63,6 @@ def test_to_dict(self, config): "style": None, "type": "AzureOpenAiImageGenerationDriver", }, - "image_query_driver": { - "base_url": None, - "image_quality": "auto", - "max_tokens": 256, - "model": "gpt-4o", - "api_version": "2024-02-01", - "azure_deployment": "gpt-4o", - "azure_endpoint": "http://localhost:8080", - "organization": None, - "type": "AzureOpenAiImageQueryDriver", - }, "vector_store_driver": { "embedding_driver": { 
"base_url": None, diff --git a/tests/unit/configs/drivers/test_cohere_drivers_config.py b/tests/unit/configs/drivers/test_cohere_drivers_config.py index 0032b6e7d..65295da52 100644 --- a/tests/unit/configs/drivers/test_cohere_drivers_config.py +++ b/tests/unit/configs/drivers/test_cohere_drivers_config.py @@ -12,7 +12,6 @@ def test_to_dict(self, config): assert config.to_dict() == { "type": "CohereDriversConfig", "image_generation_driver": {"type": "DummyImageGenerationDriver"}, - "image_query_driver": {"type": "DummyImageQueryDriver"}, "conversation_memory_driver": { "type": "LocalConversationMemoryDriver", "persist_file": None, diff --git a/tests/unit/configs/drivers/test_drivers_config.py b/tests/unit/configs/drivers/test_drivers_config.py index a1138769b..ca3cea60e 100644 --- a/tests/unit/configs/drivers/test_drivers_config.py +++ b/tests/unit/configs/drivers/test_drivers_config.py @@ -26,7 +26,6 @@ def test_to_dict(self, config): }, "embedding_driver": {"type": "DummyEmbeddingDriver"}, "image_generation_driver": {"type": "DummyImageGenerationDriver"}, - "image_query_driver": {"type": "DummyImageQueryDriver"}, "vector_store_driver": { "embedding_driver": {"type": "DummyEmbeddingDriver"}, "type": "DummyVectorStoreDriver", @@ -64,7 +63,6 @@ def test_lazy_init(self): assert Defaults.drivers_config._prompt_driver is None assert Defaults.drivers_config._image_generation_driver is None - assert Defaults.drivers_config._image_query_driver is None assert Defaults.drivers_config._embedding_driver is None assert Defaults.drivers_config._vector_store_driver is None assert Defaults.drivers_config._conversation_memory_driver is None @@ -74,7 +72,6 @@ def test_lazy_init(self): assert Defaults.drivers_config.prompt_driver is not None assert Defaults.drivers_config.image_generation_driver is not None - assert Defaults.drivers_config.image_query_driver is not None assert Defaults.drivers_config.embedding_driver is not None assert Defaults.drivers_config.vector_store_driver is 
not None assert Defaults.drivers_config.conversation_memory_driver is not None @@ -84,7 +81,6 @@ def test_lazy_init(self): assert Defaults.drivers_config._prompt_driver is not None assert Defaults.drivers_config._image_generation_driver is not None - assert Defaults.drivers_config._image_query_driver is not None assert Defaults.drivers_config._embedding_driver is not None assert Defaults.drivers_config._vector_store_driver is not None assert Defaults.drivers_config._conversation_memory_driver is not None diff --git a/tests/unit/configs/drivers/test_google_drivers_config.py b/tests/unit/configs/drivers/test_google_drivers_config.py index 8eacda7c6..c1459a400 100644 --- a/tests/unit/configs/drivers/test_google_drivers_config.py +++ b/tests/unit/configs/drivers/test_google_drivers_config.py @@ -28,7 +28,6 @@ def test_to_dict(self, config): "extra_params": {}, }, "image_generation_driver": {"type": "DummyImageGenerationDriver"}, - "image_query_driver": {"type": "DummyImageQueryDriver"}, "embedding_driver": { "type": "GoogleEmbeddingDriver", "model": "models/embedding-001", diff --git a/tests/unit/configs/drivers/test_openai_driver_config.py b/tests/unit/configs/drivers/test_openai_driver_config.py index 09ceccfdc..337896483 100644 --- a/tests/unit/configs/drivers/test_openai_driver_config.py +++ b/tests/unit/configs/drivers/test_openai_driver_config.py @@ -50,15 +50,6 @@ def test_to_dict(self, config): "style": None, "type": "OpenAiImageGenerationDriver", }, - "image_query_driver": { - "api_version": None, - "base_url": None, - "image_quality": "auto", - "max_tokens": 256, - "model": "gpt-4o", - "organization": None, - "type": "OpenAiImageQueryDriver", - }, "vector_store_driver": { "embedding_driver": { "base_url": None, diff --git a/tests/unit/drivers/image_query/__init__.py b/tests/unit/drivers/image_query/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/unit/drivers/image_query/test_amazon_bedrock_image_query_driver.py 
b/tests/unit/drivers/image_query/test_amazon_bedrock_image_query_driver.py deleted file mode 100644 index 66b23d0c3..000000000 --- a/tests/unit/drivers/image_query/test_amazon_bedrock_image_query_driver.py +++ /dev/null @@ -1,52 +0,0 @@ -import io -from unittest.mock import Mock - -import pytest - -from griptape.artifacts import ImageArtifact, TextArtifact -from griptape.drivers import AmazonBedrockImageQueryDriver - - -class TestAmazonBedrockImageQueryDriver: - @pytest.fixture() - def client(self, mocker): - return Mock() - - @pytest.fixture() - def session(self, client): - session = Mock() - session.client.return_value = client - - return session - - @pytest.fixture() - def model_driver(self): - model_driver = Mock() - model_driver.image_query_request_parameters.return_value = {} - model_driver.process_output.return_value = TextArtifact("content") - - return model_driver - - @pytest.fixture() - def image_query_driver(self, session, model_driver): - return AmazonBedrockImageQueryDriver(session=session, model="model", image_query_model_driver=model_driver) - - def test_init(self, image_query_driver): - assert image_query_driver - - def test_try_query(self, image_query_driver): - image_query_driver.client.invoke_model.return_value = {"body": io.BytesIO(b"""{"content": []}""")} - - text_artifact = image_query_driver.try_query( - "Prompt String", [ImageArtifact(value=b"test-data", width=100, height=100, format="png")] - ) - - assert text_artifact.value == "content" - - def test_try_query_no_body(self, image_query_driver): - image_query_driver.client.invoke_model.return_value = {"body": io.BytesIO(b"")} - - with pytest.raises(ValueError): - image_query_driver.try_query( - "Prompt String", [ImageArtifact(value=b"test-data", width=100, height=100, format="png")] - ) diff --git a/tests/unit/drivers/image_query/test_anthropic_image_query_driver.py b/tests/unit/drivers/image_query/test_anthropic_image_query_driver.py deleted file mode 100644 index db4b2407c..000000000 --- 
a/tests/unit/drivers/image_query/test_anthropic_image_query_driver.py +++ /dev/null @@ -1,97 +0,0 @@ -import base64 -from unittest.mock import Mock - -import pytest - -from griptape.artifacts import ImageArtifact -from griptape.drivers import AnthropicImageQueryDriver - - -class TestAnthropicImageQueryDriver: - @pytest.fixture() - def mock_client(self, mocker): - mock_client = mocker.patch("anthropic.Anthropic") - return_value = Mock(text="Content") - mock_client.return_value.messages.create.return_value.content = [return_value] - - return mock_client - - @pytest.mark.parametrize( - "model", [("claude-3-haiku-20240307"), ("claude-3-sonnet-20240229"), ("claude-3-opus-20240229")] - ) - def test_init(self, model): - assert AnthropicImageQueryDriver(model=model) - - def test_try_query(self, mock_client): - driver = AnthropicImageQueryDriver(model="test-model") - test_prompt_string = "Prompt String" - test_binary_data = b"test-data" - - text_artifact = driver.try_query( - test_prompt_string, [ImageArtifact(value=test_binary_data, width=100, height=100, format="png")] - ) - - expected_message = self._expected_message(test_binary_data, "image/png", test_prompt_string) - - mock_client.return_value.messages.create.assert_called_once_with( - model=driver.model, max_tokens=256, messages=[expected_message] - ) - - assert text_artifact.value == "Content" - - def test_try_query_max_tokens_value(self, mock_client): - driver = AnthropicImageQueryDriver(model="test-model", max_tokens=1024) - test_prompt_string = "Prompt String" - test_binary_data = b"test-data" - - text_artifact = driver.try_query( - test_prompt_string, [ImageArtifact(value=test_binary_data, width=100, height=100, format="png")] - ) - - expected_message = self._expected_message(test_binary_data, "image/png", test_prompt_string) - - mock_client.return_value.messages.create.assert_called_once_with( - model=driver.model, max_tokens=1024, messages=[expected_message] - ) - - assert text_artifact.value == "Content" - - 
def test_try_query_max_tokens_none(self, mock_client): - driver = AnthropicImageQueryDriver(model="test-model", max_tokens=None) # pyright: ignore[reportArgumentType] - test_prompt_string = "Prompt String" - test_binary_data = b"test-data" - with pytest.raises(TypeError): - driver.try_query( - test_prompt_string, [ImageArtifact(value=test_binary_data, width=100, height=100, format="png")] - ) - - def test_try_query_wrong_media_type(self, mock_client): - driver = AnthropicImageQueryDriver(model="test-model") - test_prompt_string = "Prompt String" - test_binary_data = b"test-data" - - # we expect this to pass Griptape code as the model will error appropriately - text_artifact = driver.try_query( - test_prompt_string, [ImageArtifact(value=test_binary_data, width=100, height=100, format="exr")] - ) - - expected_message = self._expected_message(test_binary_data, "image/exr", test_prompt_string) - - mock_client.return_value.messages.create.assert_called_once_with( - model=driver.model, messages=[expected_message], max_tokens=256 - ) - - assert text_artifact.value == "Content" - - def _expected_message(self, expected_data, expected_media_type, expected_prompt_string): - encoded_data = base64.b64encode(expected_data).decode("utf-8") - return { - "content": [ - { - "source": {"data": encoded_data, "media_type": expected_media_type, "type": "base64"}, - "type": "image", - }, - {"text": expected_prompt_string, "type": "text"}, - ], - "role": "user", - } diff --git a/tests/unit/drivers/image_query/test_azure_openai_image_query_driver.py b/tests/unit/drivers/image_query/test_azure_openai_image_query_driver.py deleted file mode 100644 index a1d428197..000000000 --- a/tests/unit/drivers/image_query/test_azure_openai_image_query_driver.py +++ /dev/null @@ -1,70 +0,0 @@ -from unittest.mock import Mock - -import pytest - -from griptape.artifacts import ImageArtifact -from griptape.drivers import AzureOpenAiImageQueryDriver - - -class TestAzureOpenAiVisionImageQueryDriver: - 
@pytest.fixture() - def mock_completion_create(self, mocker): - mock_chat_create = mocker.patch("openai.AzureOpenAI").return_value.chat.completions.create - mock_choice = Mock(message=Mock(content="expected_output_text")) - mock_chat_create.return_value.choices = [mock_choice] - return mock_chat_create - - def test_init(self): - assert AzureOpenAiImageQueryDriver( - azure_endpoint="test-endpoint", azure_deployment="test-deployment", model="gpt-4" - ) - assert AzureOpenAiImageQueryDriver(azure_endpoint="test-endpoint", model="gpt-4").azure_deployment == "gpt-4" - - def test_try_query_defaults(self, mock_completion_create): - driver = AzureOpenAiImageQueryDriver( - azure_endpoint="test-endpoint", azure_deployment="test-deployment", model="gpt-4" - ) - test_prompt_string = "Prompt String" - test_binary_data = b"test-data" - test_image = ImageArtifact(value=test_binary_data, width=100, height=100, format="png") - text_artifact = driver.try_query(test_prompt_string, [test_image]) - - messages = self._expected_messages(test_prompt_string, test_image.base64) - - mock_completion_create.assert_called_once_with(model=driver.model, messages=[messages], max_tokens=256) - - assert text_artifact.value == "expected_output_text" - - def test_try_query_max_tokens(self, mock_completion_create): - driver = AzureOpenAiImageQueryDriver( - azure_endpoint="test-endpoint", azure_deployment="test-deployment", model="gpt-4", max_tokens=1024 - ) - test_prompt_string = "Prompt String" - test_binary_data = b"test-data" - test_image = ImageArtifact(value=test_binary_data, width=100, height=100, format="png") - driver.try_query(test_prompt_string, [test_image]) - - messages = self._expected_messages(test_prompt_string, test_image.base64) - - mock_completion_create.assert_called_once_with(model=driver.model, messages=[messages], max_tokens=1024) - - def test_try_query_multiple_choices(self, mock_completion_create): - 
mock_completion_create.return_value.choices.append(Mock(message=Mock(content="expected_output_text2"))) - driver = AzureOpenAiImageQueryDriver( - azure_endpoint="test-endpoint", azure_deployment="test-deployment", model="gpt-4" - ) - - with pytest.raises(Exception, match="Image query responses with more than one choice are not supported yet."): - driver.try_query("Prompt String", [ImageArtifact(value=b"test-data", width=100, height=100, format="png")]) - - def _expected_messages(self, expected_prompt_string, expected_binary_data): - return { - "content": [ - {"type": "text", "text": expected_prompt_string}, - { - "type": "image_url", - "image_url": {"url": f"data:image/png;base64,{expected_binary_data}", "detail": "auto"}, - }, - ], - "role": "user", - } diff --git a/tests/unit/drivers/image_query/test_base_image_query_driver.py b/tests/unit/drivers/image_query/test_base_image_query_driver.py deleted file mode 100644 index 652ee11c5..000000000 --- a/tests/unit/drivers/image_query/test_base_image_query_driver.py +++ /dev/null @@ -1,26 +0,0 @@ -from unittest.mock import Mock - -import pytest - -from griptape.events import EventBus, EventListener -from tests.mocks.mock_image_query_driver import MockImageQueryDriver - - -class TestBaseImageQueryDriver: - @pytest.fixture() - def driver(self): - return MockImageQueryDriver(model="foo") - - def test_query_publishes_events(self, driver): - mock_handler = Mock() - EventBus.add_event_listener(EventListener(on_event=mock_handler)) - - driver.query("foo", []) - - call_args = mock_handler.call_args_list - - args, _kwargs = call_args[0] - assert args[0].type == "StartImageQueryEvent" - - args, _kwargs = call_args[1] - assert args[0].type == "FinishImageQueryEvent" diff --git a/tests/unit/drivers/image_query/test_dummy_image_query_driver.py b/tests/unit/drivers/image_query/test_dummy_image_query_driver.py deleted file mode 100644 index 9da59c435..000000000 --- a/tests/unit/drivers/image_query/test_dummy_image_query_driver.py +++ 
/dev/null @@ -1,18 +0,0 @@ -import pytest - -from griptape.artifacts import ImageArtifact -from griptape.drivers import DummyImageQueryDriver -from griptape.exceptions import DummyError - - -class TestDummyImageQueryDriver: - @pytest.fixture() - def image_query_driver(self): - return DummyImageQueryDriver() - - def test_init(self, image_query_driver): - assert image_query_driver - - def test_try_query(self, image_query_driver): - with pytest.raises(DummyError): - image_query_driver.try_query("Prompt", [ImageArtifact(value=b"", width=100, height=100, format="png")]) diff --git a/tests/unit/drivers/image_query/test_openai_image_query_driver.py b/tests/unit/drivers/image_query/test_openai_image_query_driver.py deleted file mode 100644 index 9c4b011a6..000000000 --- a/tests/unit/drivers/image_query/test_openai_image_query_driver.py +++ /dev/null @@ -1,61 +0,0 @@ -from unittest.mock import Mock - -import pytest - -from griptape.artifacts import ImageArtifact -from griptape.drivers import OpenAiImageQueryDriver - - -class TestOpenAiVisionImageQueryDriver: - @pytest.fixture() - def mock_completion_create(self, mocker): - mock_chat_create = mocker.patch("openai.OpenAI").return_value.chat.completions.create - mock_choice = Mock(message=Mock(content="expected_output_text")) - mock_chat_create.return_value.choices = [mock_choice] - return mock_chat_create - - def test_init(self): - assert OpenAiImageQueryDriver(model="gpt-4-vision-preview") - - def test_try_query_defaults(self, mock_completion_create): - driver = OpenAiImageQueryDriver(model="gpt-4-vision-preview") - test_prompt_string = "Prompt String" - test_binary_data = b"test-data" - test_image = ImageArtifact(value=test_binary_data, width=100, height=100, format="png") - text_artifact = driver.try_query(test_prompt_string, [test_image]) - - messages = self._expected_messages(test_prompt_string, test_image.base64) - - mock_completion_create.assert_called_once_with(model=driver.model, messages=[messages], max_tokens=256) 
- - assert text_artifact.value == "expected_output_text" - - def test_try_query_max_tokens(self, mock_completion_create): - driver = OpenAiImageQueryDriver(model="gpt-4-vision-preview", max_tokens=1024) - test_prompt_string = "Prompt String" - test_binary_data = b"test-data" - test_image = ImageArtifact(value=test_binary_data, width=100, height=100, format="png") - driver.try_query(test_prompt_string, [test_image]) - - messages = self._expected_messages(test_prompt_string, test_image.base64) - - mock_completion_create.assert_called_once_with(model=driver.model, messages=[messages], max_tokens=1024) - - def test_try_query_multiple_choices(self, mock_completion_create): - mock_completion_create.return_value.choices.append(Mock(message=Mock(content="expected_output_text2"))) - driver = OpenAiImageQueryDriver(model="gpt-4-vision-preview") - - with pytest.raises(Exception, match="Image query responses with more than one choice are not supported yet."): - driver.try_query("Prompt String", [ImageArtifact(value=b"test-data", width=100, height=100, format="png")]) - - def _expected_messages(self, expected_prompt_string, expected_binary_data): - return { - "content": [ - {"type": "text", "text": expected_prompt_string}, - { - "type": "image_url", - "image_url": {"url": f"data:image/png;base64,{expected_binary_data}", "detail": "auto"}, - }, - ], - "role": "user", - } diff --git a/tests/unit/drivers/image_query_models/test_bedrock_claude_image_query_model_driver.py b/tests/unit/drivers/image_query_models/test_bedrock_claude_image_query_model_driver.py deleted file mode 100644 index c274f71dd..000000000 --- a/tests/unit/drivers/image_query_models/test_bedrock_claude_image_query_model_driver.py +++ /dev/null @@ -1,42 +0,0 @@ -import pytest - -from griptape.artifacts import ImageArtifact, TextArtifact -from griptape.drivers import BedrockClaudeImageQueryModelDriver - - -class TestBedrockClaudeImageQueryModelDriver: - def test_init(self): - assert 
BedrockClaudeImageQueryModelDriver() - - def test_image_query_request_parameters(self): - model_driver = BedrockClaudeImageQueryModelDriver() - params = model_driver.image_query_request_parameters( - "Prompt String", [ImageArtifact(value=b"test-data", width=100, height=100, format="png")], 256 - ) - - assert isinstance(params, dict) - assert "anthropic_version" in params - assert params["anthropic_version"] == "bedrock-2023-05-31" - assert "messages" in params - assert len(params["messages"]) == 1 - assert "max_tokens" in params - assert params["max_tokens"] == 256 - - def test_process_output(self): - model_driver = BedrockClaudeImageQueryModelDriver() - output = model_driver.process_output({"content": [{"text": "Content"}]}) - - assert isinstance(output, TextArtifact) - assert output.value == "Content" - - def test_process_output_no_content_key(self): - with pytest.raises(KeyError): - BedrockClaudeImageQueryModelDriver().process_output({"explicitly-not-content": ["ContentBlock"]}) - - def test_process_output_bad_length(self): - with pytest.raises(ValueError): - BedrockClaudeImageQueryModelDriver().process_output({"content": []}) - - def test_process_output_no_text_key(self): - with pytest.raises(KeyError): - BedrockClaudeImageQueryModelDriver().process_output({"content": [{"not-text": "Content"}]}) diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index 3efe85c98..f9ad70573 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -39,6 +39,7 @@ def test_run_via_pipeline_publishes_events(self, mocker): def test_run(self): assert isinstance(MockPromptDriver().run(PromptStack(messages=[])), Message) + assert isinstance(MockPromptDriver().run(TextArtifact("")), Message) def test_run_with_stream(self): result = MockPromptDriver(stream=True).run(PromptStack(messages=[])) diff --git a/tests/unit/tasks/test_image_query_task.py 
b/tests/unit/tasks/test_image_query_task.py deleted file mode 100644 index ef196457b..000000000 --- a/tests/unit/tasks/test_image_query_task.py +++ /dev/null @@ -1,74 +0,0 @@ -from unittest.mock import Mock - -import pytest - -from griptape.artifacts import ImageArtifact, TextArtifact -from griptape.artifacts.list_artifact import ListArtifact -from griptape.structures import Agent -from griptape.tasks import BaseTask, ImageQueryTask -from tests.mocks.mock_image_query_driver import MockImageQueryDriver - - -class TestImageQueryTask: - @pytest.fixture() - def image_query_driver(self) -> Mock: - mock = Mock() - mock.query.return_value = TextArtifact("image") - - return mock - - @pytest.fixture() - def text_artifact(self): - return TextArtifact(value="some text") - - @pytest.fixture() - def image_artifact(self): - return ImageArtifact(value=b"some image data", format="png", width=512, height=512) - - def test_text_inputs(self, text_artifact: TextArtifact, image_artifact: ImageArtifact): - task = ImageQueryTask((text_artifact.value, [image_artifact, image_artifact])) - - assert task.input.value[0].value == text_artifact.value - assert task.input.value[1] == image_artifact - assert task.input.value[2] == image_artifact - - def test_artifact_inputs(self, text_artifact: TextArtifact, image_artifact: ImageArtifact): - input_tuple = (text_artifact, [image_artifact, image_artifact]) - task = ImageQueryTask(input_tuple) - - assert task.input.value[0] == text_artifact - assert task.input.value[1] == image_artifact - assert task.input.value[2] == image_artifact - - def test_callable_input(self, text_artifact: TextArtifact, image_artifact: ImageArtifact): - artifacts = [text_artifact, image_artifact, image_artifact] - - def callable_input(task: BaseTask) -> ListArtifact: - return ListArtifact(value=artifacts) - - task = ImageQueryTask(callable_input) - - assert task.input.value == artifacts - - def test_list_input(self, text_artifact: TextArtifact, image_artifact: ImageArtifact): 
- artifacts = [text_artifact, image_artifact, image_artifact] - - task = ImageQueryTask(ListArtifact(value=artifacts)) - - assert task.input.value == artifacts - - def test_config_image_generation_driver(self, text_artifact, image_artifact): - task = ImageQueryTask((text_artifact, [image_artifact, image_artifact])) - Agent().add_task(task) - - assert isinstance(task.image_query_driver, MockImageQueryDriver) - - def test_run(self, image_query_driver, text_artifact, image_artifact): - task = ImageQueryTask((text_artifact, [image_artifact, image_artifact]), image_query_driver=image_query_driver) - task.run() - - assert task.output.value == "image" - - def test_bad_try_run(self, image_query_driver, text_artifact, image_artifact): - with pytest.raises(ValueError, match="All inputs"): - ImageQueryTask(("foo", [image_artifact, text_artifact]), image_query_driver=image_query_driver).try_run() diff --git a/tests/unit/tools/test_image_query_tool.py b/tests/unit/tools/test_image_query_tool.py index 630f1bc4d..c13b41f0d 100644 --- a/tests/unit/tools/test_image_query_tool.py +++ b/tests/unit/tools/test_image_query_tool.py @@ -2,7 +2,7 @@ from griptape.artifacts.image_artifact import ImageArtifact from griptape.tools import ImageQueryTool -from tests.mocks.mock_image_query_driver import MockImageQueryDriver +from tests.mocks.mock_drivers_config import MockPromptDriver from tests.utils import defaults @@ -11,7 +11,7 @@ class TestImageQueryTool: def tool(self): task_memory = defaults.text_task_memory("memory_name") task_memory.store_artifact("namespace", ImageArtifact(b"", format="png", width=1, height=1, name="test")) - return ImageQueryTool(input_memory=[task_memory], image_query_driver=MockImageQueryDriver()) + return ImageQueryTool(input_memory=[task_memory], prompt_driver=MockPromptDriver(mock_output="mock text")) def test_query_image_from_disk(self, tool): assert tool.query_image_from_disk({"values": {"query": "test", "image_paths": []}}).value == "mock text"