Skip to content

Commit

Permalink
Remove ImageQueryTask
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Nov 12, 2024
1 parent ae733f0 commit ab48bb6
Show file tree
Hide file tree
Showing 15 changed files with 87 additions and 241 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- `TrafilaturaWebScraperDriver.no_ssl` parameter to disable SSL verification. Defaults to `False`.
- `CsvExtractionEngine.format_header` parameter to format the header row.
- `PromptStack.from_artifact` factory method for creating a Prompt Stack with a user message from an Artifact.

### Changed

Expand All @@ -20,6 +21,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **BREAKING**: Renamed `Structure.is_executing()` to `Structure.is_running()`.
- **BREAKING**: Removed ability to pass bytes to `BaseFileLoader.fetch`.
- **BREAKING**: Updated `CsvExtractionEngine.format_row` to format rows as comma-separated values instead of newline-separated key-value pairs.
- **BREAKING**: Removed all `ImageQueryDriver`s, use `PromptDriver`s instead.
- **BREAKING**: Removed `ImageQueryTask`, use `PromptTask` instead.
- **BREAKING**: Updated `ImageQueryTool.image_query_driver` to `ImageQueryTool.prompt_driver`.
- `BasePromptDriver.run` can now accept an Artifact in addition to a Prompt Stack.
- Improved `CsvExtractionEngine` prompts.
- Tweaked `PromptResponseRagModule` system prompt to yield answers more consistently.
- Removed `azure-core` and `azure-storage-blob` dependencies.
Expand Down
58 changes: 51 additions & 7 deletions MIGRATION.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ loader = TextLoader()
data = loader.parse(b"data")
```

### Removed `ImageQueryEngine`
### Removed `ImageQueryEngine`, `ImageQueryDriver`

`ImageQueryEngine` has been removed. Use `ImageQueryDriver` instead.
`ImageQueryEngine` has been removed. Use `PromptDriver` instead.

#### Before

Expand All @@ -45,15 +45,15 @@ engine.run("Describe the weather in the image", [image_artifact])`
#### After

```python
from griptape.drivers import OpenAiImageQueryDriver
from griptape.engines import ImageQueryEngine
from griptape.artifacts import ListArtifact, TextArtifact
from griptape.drivers import OpenAiChatPromptDriver
from griptape.loaders import ImageLoader

driver = OpenAiImageQueryDriver(model="gpt-4o", max_tokens=256)
driver = OpenAiChatPromptDriver(model="gpt-4o", max_tokens=256)

image_artifact = ImageLoader().load("mountain.png")
image_artifact = ImageLoader().load("./assets/mountain.jpg")

driver.query("Describe the weather in the image", [image_artifact])`
driver.run(ListArtifact([TextArtifact("Describe the weather in the image"), image_artifact]))
```

### Removed `InpaintingImageGenerationEngine`
Expand Down Expand Up @@ -209,6 +209,50 @@ driver.run_text_to_image(
)
```

### Removed `ImageQueryTask`, use `PromptTask` instead

`ImageQueryTask` has been removed. Use `PromptTask` instead.

#### Before

```python
from griptape.loaders import ImageLoader
from griptape.structures import Pipeline
from griptape.tasks import ImageQueryTask

image_artifact = ImageLoader().load("mountain.png")

pipeline = Pipeline(
tasks=[
ImageQueryTask(
input=("Describe the weather in the image", [image_artifact]),
)
]
)

pipeline.run("Describe the weather in the image")
```

#### After

```python
from griptape.loaders import ImageLoader
from griptape.structures import Pipeline
from griptape.tasks import PromptTask

image_artifact = ImageLoader().load("mountain.png")

pipeline = Pipeline(
tasks=[
PromptTask(
input=("Describe the weather in the image", image_artifact),
)
]
)

pipeline.run("Describe the weather in the image")
```

## 0.33.X to 0.34.X

### `AnthropicDriversConfig` Embedding Driver
Expand Down
27 changes: 0 additions & 27 deletions docs/griptape-framework/structures/src/tasks_15.py

This file was deleted.

10 changes: 0 additions & 10 deletions docs/griptape-framework/structures/tasks.md
Original file line number Diff line number Diff line change
Expand Up @@ -362,16 +362,6 @@ The [Outpainting Image Generation Task](../../reference/griptape/tasks/outpainti
--8<-- "docs/griptape-framework/structures/src/tasks_14.py"
```

## Image Query Task

The [Image Query Task](../../reference/griptape/tasks/image_query_task.md) performs a natural language query on one or more input images. This Task uses an [Prompt Driver](../drivers/prompt-drivers.md) to perform the query. The functionality provided by this Task depend on the capabilities of the model provided by the Driver.

This Task accepts two inputs: a query (represented by either a string or a [Text Artifact](../data/artifacts.md#text)) and a list of [Image Artifacts](../data/artifacts.md#image) or a Callable returning these two values.

```python
--8<-- "docs/griptape-framework/structures/src/tasks_15.py"
```

## Structure Run Task

The [Structure Run Task](../../reference/griptape/tasks/structure_run_task.md) runs another Structure with a given input.
Expand Down
7 changes: 7 additions & 0 deletions griptape/common/prompt_stack/prompt_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ def add_user_message(self, artifact: str | BaseArtifact) -> Message:
def add_assistant_message(self, artifact: str | BaseArtifact) -> Message:
return self.add_message(artifact, Message.ASSISTANT_ROLE)

@classmethod
def from_artifact(cls, artifact: BaseArtifact) -> PromptStack:
prompt_stack = cls()
prompt_stack.add_user_message(artifact)

return prompt_stack

def __to_message_content(self, artifact: str | BaseArtifact) -> list[BaseMessageContent]:
if isinstance(artifact, str):
return [TextMessageContent(TextArtifact(artifact))]
Expand Down
10 changes: 8 additions & 2 deletions griptape/drivers/prompt/base_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from attrs import Factory, define, field

from griptape.artifacts.base_artifact import BaseArtifact
from griptape.common import (
ActionCallDeltaMessageContent,
ActionCallMessageContent,
Expand Down Expand Up @@ -71,7 +72,12 @@ def after_run(self, result: Message) -> None:
)

@observable(tags=["PromptDriver.run()"])
def run(self, prompt_stack: PromptStack) -> Message:
def run(self, prompt_input: PromptStack | BaseArtifact) -> Message:
if isinstance(prompt_input, BaseArtifact):
prompt_stack = PromptStack.from_artifact(prompt_input)
else:
prompt_stack = prompt_input

for attempt in self.retrying():
with attempt:
self.before_run(prompt_stack)
Expand All @@ -85,7 +91,7 @@ def run(self, prompt_stack: PromptStack) -> Message:
raise Exception("prompt driver failed after all retry attempts")

def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str:
"""Converts a Prompt Stack to a string for token counting or model input.
"""Converts a Prompt Stack to a string for token counting or model prompt_input.
This base implementation is only a rough approximation, and should be overridden by subclasses with model-specific tokens.
Expand Down
2 changes: 1 addition & 1 deletion griptape/memory/structure/summary_conversation_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def summarize_runs(self, previous_summary: str | None, runs: list[Run]) -> str |
if len(runs) > 0:
summary = self.summarize_conversation_get_template.render(summary=previous_summary, runs=runs)
return self.prompt_driver.run(
prompt_stack=PromptStack(messages=[Message(summary, role=Message.USER_ROLE)]),
PromptStack(messages=[Message(summary, role=Message.USER_ROLE)]),
).to_text()
else:
return previous_summary
Expand Down
2 changes: 0 additions & 2 deletions griptape/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from .inpainting_image_generation_task import InpaintingImageGenerationTask
from .outpainting_image_generation_task import OutpaintingImageGenerationTask
from .variation_image_generation_task import VariationImageGenerationTask
from .image_query_task import ImageQueryTask
from .base_audio_generation_task import BaseAudioGenerationTask
from .text_to_speech_task import TextToSpeechTask
from .structure_run_task import StructureRunTask
Expand All @@ -35,7 +34,6 @@
"VariationImageGenerationTask",
"InpaintingImageGenerationTask",
"OutpaintingImageGenerationTask",
"ImageQueryTask",
"BaseAudioGenerationTask",
"TextToSpeechTask",
"StructureRunTask",
Expand Down
100 changes: 0 additions & 100 deletions griptape/tasks/image_query_task.py

This file was deleted.

2 changes: 1 addition & 1 deletion griptape/tasks/tool_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def actions_schema(self) -> Schema:
return self._actions_schema_for_tools([self.tool])

def try_run(self) -> BaseArtifact:
result = self.prompt_driver.run(prompt_stack=self.prompt_stack)
result = self.prompt_driver.run(self.prompt_stack)

if self.prompt_driver.use_native_tools:
subtask_input = result.to_artifact()
Expand Down
2 changes: 1 addition & 1 deletion griptape/tasks/toolkit_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def try_run(self) -> BaseArtifact:
subtask.run()
subtask.after_run()

result = self.prompt_driver.run(prompt_stack=self.prompt_stack)
result = self.prompt_driver.run(self.prompt_stack)
subtask = self.add_subtask(ActionsSubtask(result.to_artifact()))
else:
break
Expand Down
27 changes: 6 additions & 21 deletions griptape/tools/image_query/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from schema import Literal, Schema

from griptape.artifacts import BlobArtifact, ErrorArtifact, ImageArtifact, TextArtifact
from griptape.common import ImageMessageContent, Message, PromptStack, TextMessageContent
from griptape.artifacts.list_artifact import ListArtifact
from griptape.common import PromptStack
from griptape.loaders import ImageLoader
from griptape.tools import BaseTool
from griptape.utils import load_artifact_from_memory
Expand Down Expand Up @@ -46,16 +47,8 @@ def query_image_from_disk(self, params: dict) -> TextArtifact | ErrorArtifact:
return cast(
TextArtifact,
self.prompt_driver.run(
PromptStack(
messages=[
Message(
role=Message.USER_ROLE,
content=[
TextMessageContent(TextArtifact(query)),
*[ImageMessageContent(image_artifact) for image_artifact in image_artifacts],
],
),
]
PromptStack.from_artifact(
ListArtifact([TextArtifact(query), *image_artifacts]),
)
).to_artifact(),
)
Expand Down Expand Up @@ -113,16 +106,8 @@ def query_images_from_memory(self, params: dict[str, Any]) -> TextArtifact | Err
return cast(
TextArtifact,
self.prompt_driver.run(
PromptStack(
messages=[
Message(
role=Message.USER_ROLE,
content=[
TextMessageContent(TextArtifact(query)),
*[ImageMessageContent(image_artifact) for image_artifact in image_artifacts],
],
),
]
PromptStack.from_artifact(
ListArtifact([TextArtifact(query), *image_artifacts]),
)
).to_artifact(),
)
6 changes: 6 additions & 0 deletions tests/unit/common/test_prompt_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,9 @@ def test_add_assistant_message(self, prompt_stack):

assert prompt_stack.messages[0].role == "assistant"
assert prompt_stack.messages[0].content[0].artifact.value == "foo"

def test_from_artifact(self):
prompt_stack = PromptStack.from_artifact(TextArtifact("foo"))

assert prompt_stack.messages[0].role == "user"
assert prompt_stack.messages[0].content[0].artifact.value == "foo"
Loading

0 comments on commit ab48bb6

Please sign in to comment.