From 661bb407471e2ee153aef1c8e1d2fee8420e0361 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 27 Aug 2024 11:22:51 -0700 Subject: [PATCH 01/52] Refactor Artifacts --- .ignore | 0 CHANGELOG.md | 19 +++ MIGRATION.md | 117 ++++++++++++++++++ docs/griptape-framework/data/artifacts.md | 60 ++++----- docs/griptape-framework/data/loaders.md | 6 +- griptape/artifacts/__init__.py | 18 ++- griptape/artifacts/action_artifact.py | 13 +- griptape/artifacts/audio_artifact.py | 26 +++- griptape/artifacts/base_artifact.py | 35 +++--- griptape/artifacts/base_system_artifact.py | 10 ++ griptape/artifacts/blob_artifact.py | 29 +++-- griptape/artifacts/error_artifact.py | 14 ++- griptape/artifacts/generic_artifact.py | 10 +- griptape/artifacts/image_artifact.py | 35 ++++-- griptape/artifacts/info_artifact.py | 15 ++- griptape/artifacts/json_artifact.py | 20 ++- griptape/artifacts/list_artifact.py | 22 ++-- griptape/artifacts/media_artifact.py | 53 -------- griptape/artifacts/text_artifact.py | 28 +++-- .../amazon_bedrock_image_generation_driver.py | 12 +- ...ngface_pipeline_image_generation_driver.py | 4 +- .../leonardo_image_generation_driver.py | 12 +- .../openai_image_generation_driver.py | 3 +- .../extraction/csv_extraction_engine.py | 10 +- griptape/loaders/csv_loader.py | 10 +- griptape/loaders/dataframe_loader.py | 10 +- griptape/loaders/sql_loader.py | 10 +- ...mixin.py => artifact_file_output_mixin.py} | 8 +- griptape/tasks/base_audio_generation_task.py | 2 +- griptape/tasks/base_image_generation_task.py | 7 +- griptape/tools/base_image_generation_tool.py | 4 +- griptape/tools/query/tool.py | 4 +- griptape/tools/text_to_speech/tool.py | 2 +- tests/unit/artifacts/test_action_artifact.py | 4 - tests/unit/artifacts/test_audio_artifact.py | 14 ++- tests/unit/artifacts/test_base_artifact.py | 8 +- .../artifacts/test_base_media_artifact.py | 30 ----- tests/unit/artifacts/test_blob_artifact.py | 17 +-- tests/unit/artifacts/test_image_artifact.py | 13 +- tests/unit/artifacts/test_json_artifact.py | 14 +-- tests/unit/artifacts/test_list_artifact.py | 6 +- ...table_diffusion_image_generation_driver.py | 4 +- ...st_azure_openai_image_generation_driver.py | 9 +- .../test_leonardo_image_generation_driver.py | 4 +- .../test_openai_image_generation_driver.py | 10 +- .../test_base_local_vector_store_driver.py | 15 --- .../extraction/test_csv_extraction_engine.py | 6 +- tests/unit/loaders/test_audio_loader.py | 6 +- tests/unit/loaders/test_csv_loader.py | 18 ++- tests/unit/loaders/test_dataframe_loader.py | 52 -------- tests/unit/loaders/test_image_loader.py | 19 ++- tests/unit/loaders/test_sql_loader.py | 17 ++- tests/unit/memory/tool/test_task_memory.py | 6 +- .../test_image_artifact_file_output_mixin.py | 12 +- tests/unit/tasks/test_extraction_task.py | 2 +- tests/unit/tasks/test_prompt_task.py | 4 +- tests/unit/tools/test_extraction_tool.py | 4 +- tests/unit/tools/test_file_manager.py | 25 +--- .../test_inpainting_image_generation_tool.py | 8 +- .../test_outpainting_image_variation_tool.py | 8 +- .../test_prompt_image_generation_tool.py | 5 +- tests/unit/tools/test_sql_tool.py | 3 +- tests/unit/tools/test_text_to_speech_tool.py | 3 +- .../test_variation_image_generation_tool.py | 4 +- 64 files changed, 496 insertions(+), 482 deletions(-) create mode 100644 .ignore create mode 100644 griptape/artifacts/base_system_artifact.py delete mode 100644 griptape/artifacts/media_artifact.py rename griptape/mixins/{media_artifact_file_output_mixin.py => artifact_file_output_mixin.py} (87%) delete mode 100644 tests/unit/artifacts/test_base_media_artifact.py delete mode 100644 tests/unit/loaders/test_dataframe_loader.py diff --git a/.ignore b/.ignore new file mode 100644 index 000000000..e69de29bb diff --git a/CHANGELOG.md b/CHANGELOG.md index 1fed1b200..b57bce948 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,25 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +### Added +- `BaseArtifact.to_bytes()` method to convert an Artifact to bytes. + +### Changed +- **BREAKING**: Removed `MediaArtifact`, use `ImageArtifact` or `AudioArtifact` instead. +- **BREAKING**: Removed `BooleanArtifact`, use `JsonArtifact` instead. +- **BREAKING**: Removed `CsvRowArtifact`. +- **BREAKING**: `CsvLoader`, `DataframeLoader`, and `SqlLoader` now return `list[TextArtifact]`. +- **BREAKING**: Removed `ImageArtifact.media_type`. +- **BREAKING**: Removed `AudioArtifact.media_type`. +- **BREAKING**: Removed `BlobArtifact.dir_name`. +- **BREAKING**: Moved `ImageArtifact.prompt` and `ImageArtifact.model` into `ImageArtifact.meta`. +- **BREAKING**: `ImageArtifact.to_text()` now returns the base64 encoded image. +- Updated `JsonArtifact` value converter to properly handle more types. +- `AudioArtifact` now subclasses `BaseArtifact` instead of `MediaArtifact`. +- `ImageArtifact` now subclasses `BaseArtifact` instead of `MediaArtifact`. +- Passing a dictionary as the value to `TextArtifact` will convert to a key-value formatted string. +- Removed `__add__` method from `BaseArtifact`, implemented it where necessary. + ## [0.31.0] - 2024-09-03 **Note**: This release includes breaking changes. Please refer to the [Migration Guide](./MIGRATION.md#030x-to-031x) for details. diff --git a/MIGRATION.md b/MIGRATION.md index af8835e5b..39932d4fa 100644 --- a/MIGRATION.md +++ b/MIGRATION.md @@ -1,6 +1,123 @@ # Migration Guide This document provides instructions for migrating your codebase to accommodate breaking changes introduced in new versions of Griptape. +## 0.31.X to 0.32.X + +### Removed `MediaArtifact` + +`MediaArtifact` has been removed. Use `ImageArtifact` or `AudioArtifact` instead. + +#### Before + +```python +image_media = MediaArtifact( + b"image_data", + media_type="image", + format="jpeg" +) + +audio_media = MediaArtifact( + b"audio_data", + media_type="audio", + format="wav" +) +``` + +#### After +```python +image_artifact = ImageArtifact( + b"image_data", + format="jpeg" +) + +audio_artifact = AudioArtifact( + b"audio_data", + format="wav" +) +``` + +### Removed `BooleanArtifact` + +`BooleanArtifact` has been removed. Use `JsonArtifact` instead. + +#### Before + +```python +boolean_artifact = BooleanArtifact("true") + +print(boolean_artifact.value) # Value is True +``` + +#### After +```python +json_artifact = JsonArtifact("true") + +print(json_artifact.value) # Value is True +``` + +### Removed `CsvRowArtifact` + +`CsvRowArtifact` has been removed. Use `TextArtifact` instead. + +#### Before + +```python +CsvRowArtifact({"name": "John", "age": 30}) +``` + +#### After +```python +TextArtifact("name: John\nAge: 30") +``` + +### `CsvLoader`, `DataframeLoader`, and `SqlLoader` return types + +`CsvLoader`, `DataframeLoader`, and `SqlLoader` now return a tuple of `list[TextArtifact]` instead of `list[CsvRowArtifact]`. + +#### Before + +```python +results = CsvLoader().load(Path("people.csv").read_text()) + +print(results[0].value) # {"name": "John", "age": 30} +``` + +#### After +```python +results = CsvLoader().load(Path("people.csv").read_text()) + +print(results[0].value) # name: John\nAge: 30 +print(results[0].meta["row"]) # 0 +``` + +### Moved `ImageArtifact.prompt` and `ImageArtifact.model` to `ImageArtifact.meta` + +`ImageArtifact.prompt` and `ImageArtifact.model` have been moved to `ImageArtifact.meta`. + +#### Before + +```python +image_artifact = ImageArtifact( + b"image_data", + format="jpeg", + prompt="Generate an image of a cat", + model="DALL-E" +) + +print(image_artifact.prompt, image_artifact.model) # Generate an image of a cat, DALL-E +``` + +#### After +```python +image_artifact = ImageArtifact( + b"image_data", + format="jpeg", + meta={"prompt": "Generate an image of a cat", "model": "DALL-E"} +) + +print(image_artifact.meta["prompt"], image_artifact.meta["model"]) # Generate an image of a cat, DALL-E +``` + ## 0.30.X to 0.31.X diff --git a/docs/griptape-framework/data/artifacts.md b/docs/griptape-framework/data/artifacts.md index 8c4da02b3..b90636731 100644 --- a/docs/griptape-framework/data/artifacts.md +++ b/docs/griptape-framework/data/artifacts.md @@ -5,60 +5,54 @@ search: ## Overview -**[Artifacts](../../reference/griptape/artifacts/base_artifact.md)** are used for passing different types of data between Griptape components. All tools return artifacts that are later consumed by tasks and task memory. -Artifacts make sure framework components enforce contracts when passing and consuming data. + +**[Artifacts](../../reference/griptape/artifacts/base_artifact.md)** are used to store data that can be provided as input to or received as output from a Language Learning Model (LLM). ## Text -A [TextArtifact](../../reference/griptape/artifacts/text_artifact.md) for passing text data of arbitrary size around the framework. It can be used to count tokens with [token_count()](../../reference/griptape/artifacts/text_artifact.md#griptape.artifacts.text_artifact.TextArtifact.token_count) with a tokenizer. -It can also be used to generate a text embedding with [generate_embedding()](../../reference/griptape/artifacts/text_artifact.md#griptape.artifacts.text_artifact.TextArtifact.generate_embedding) -and access it with [embedding](../../reference/griptape/artifacts/text_artifact.md#griptape.artifacts.text_artifact.TextArtifact.embedding). +[TextArtifact](../../reference/griptape/artifacts/text_artifact.md)s store textual data. They can be used to count tokens using the [token_count()](../../reference/griptape/artifacts/text_artifact.md#griptape.artifacts.text_artifact.TextArtifact.token_count) method with a tokenizer, generate a text embedding through the [generate_embedding()](../../reference/griptape/artifacts/text_artifact.md#griptape.artifacts.text_artifact.TextArtifact.generate_embedding) method, and access the embedding with the [embedding](../../reference/griptape/artifacts/text_artifact.md#griptape.artifacts.text_artifact.TextArtifact.embedding) property. -[TaskMemory](../../reference/griptape/memory/task/task_memory.md) automatically stores [TextArtifact](../../reference/griptape/artifacts/text_artifact.md)s returned by tool activities and returns artifact IDs back to the LLM. +[TaskMemory](../../reference/griptape/memory/task/task_memory.md) automatically stores `TextArtifacts` returned by tool activities and provides their IDs back to the LLM. -## Csv Row +## Image -A [CsvRowArtifact](../../reference/griptape/artifacts/csv_row_artifact.md) for passing structured row data around the framework. It inherits from [TextArtifact](../../reference/griptape/artifacts/text_artifact.md) and overrides the -[to_text()](../../reference/griptape/artifacts/csv_row_artifact.md#griptape.artifacts.csv_row_artifact.CsvRowArtifact.to_text) method, which always returns a valid CSV row. +[ImageArtifact](../../reference/griptape/artifacts/image_artifact.md)s store image data. They include binary image data and metadata such as MIME type, dimensions, and prompt and model information for images returned by [image generation drivers](../drivers/image-generation-drivers.md). They inherit functionality from [BlobArtifacts](#blob). -## Info +## Audio -An [InfoArtifact](../../reference/griptape/artifacts/info_artifact.md) for passing short notifications back to the LLM without task memory storing them. +[AudioArtifact](../../reference/griptape/artifacts/audio_artifact.md)s store audio content, including binary audio data and metadata such as format, duration, and prompt and model information for audio returned by generative models. They inherit from [BlobArtifacts](#blob). -## Error +## Action -An [ErrorArtifact](../../reference/griptape/artifacts/error_artifact.md) is used for passing errors back to the LLM without task memory storing them. +[ActionArtifact](../../reference/griptape/artifacts/action_artifact.md)s represent actions taken by the LLM. Currently, the only supported action is [ToolAction](../../reference/griptape/common/actions/tool_action.md), which is used to execute a [Tool](../../griptape-framework/tools/index.md). -## Blob +## JSON -A [BlobArtifact](../../reference/griptape/artifacts/blob_artifact.md) for passing binary large objects (blobs) back to the LLM. -Treat it as a way to return unstructured data, such as images, videos, audio, and other files back from tools. -Each blob has a [name](../../reference/griptape/artifacts/base_artifact.md#griptape.artifacts.base_artifact.BaseArtifact.name) and -[dir](../../reference/griptape/artifacts/blob_artifact.md#griptape.artifacts.blob_artifact.BlobArtifact.dir_name) to uniquely identify stored objects. +[JsonArtifact](../../reference/griptape/artifacts/json_artifact.md)s store JSON-serializable data. Any data assigned to the `value` property is converted using `json.dumps(json.loads(value))`. -[TaskMemory](../../reference/griptape/memory/task/task_memory.md) automatically stores [BlobArtifact](../../reference/griptape/artifacts/blob_artifact.md)s returned by tool activities that can be reused by other tools. +## Generic -## Image +[GenericArtifact](../../reference/griptape/artifacts/generic_artifact.md)s act as an escape hatch for passing any type of data that does not fit into any other artifact type. While generally not recommended, they are suitable for specific scenarios. For example, see [talking to a video](../../examples/talk-to-a-video.md), which demonstrates using a `GenericArtifact` to pass a Gemini-specific video file. -An [ImageArtifact](../../reference/griptape/artifacts/image_artifact.md) is used for passing images back to the LLM. In addition to binary image data, an Image Artifact includes image metadata like MIME type, dimensions, and prompt and model information for images returned by [image generation Drivers](../drivers/image-generation-drivers.md). It inherits from [BlobArtifact](#blob). +## System Artifacts -## Audio +These Artifacts don't map to an LLM modality. They must be transformed in some way before they can be used as LLM input. -An [AudioArtifact](../../reference/griptape/artifacts/audio_artifact.md) allows the Framework to interact with audio content. An Audio Artifact includes binary audio content as well as metadata like format, duration, and prompt and model information for audio returned generative models. It inherits from [BlobArtifact](#blob). +### Blob -## Boolean +[BlobArtifact](../../reference/griptape/artifacts/blob_artifact.md)s store binary large objects (blobs) and are used to pass unstructured data back to the LLM via [InfoArtifact](#info). -A [BooleanArtifact](../../reference/griptape/artifacts/boolean_artifact.md) is used for passing boolean values around the framework. +`TaskMemory` automatically stores `BlobArtifacts` returned by tool activities, allowing them to be reused by other tools. -!!! info - Any object passed on init to `BooleanArtifact` will be coerced into a `bool` type. This might lead to unintended behavior: `BooleanArtifact("False").value is True`. Use [BooleanArtifact.parse_bool](../../reference/griptape/artifacts/boolean_artifact.md#griptape.artifacts.boolean_artifact.BooleanArtifact.parse_bool) to convert case-insensitive string literal values `"True"` and `"False"` into a `BooleanArtifact`: `BooleanArtifact.parse_bool("False").value is False`. +### Info -## Generic +[InfoArtifact](../../reference/griptape/artifacts/info_artifact.md)s store short notifications that are passed back to the LLM without being stored in Task Memory. + +### Error + +[ErrorArtifact](../../reference/griptape/artifacts/error_artifact.md)s store errors that are passed back to the LLM without being stored in Task Memory. -A [GenericArtifact](../../reference/griptape/artifacts/generic_artifact.md) can be used as an escape hatch for passing any type of data around the framework. -It is generally not recommended to use this Artifact type, but it can be used in a handful of situations where no other Artifact type fits the data being passed. -See [talking to a video](../../examples/talk-to-a-video.md) for an example of using a `GenericArtifact` to pass a Gemini-specific video file. +### List -## Json +[ListArtifact](../../reference/griptape/artifacts/list_artifact.md)s store lists of Artifacts that can be passed to the LLM. -A [JsonArtifact](../../reference/griptape/artifacts/json_artifact.md) is used for passing JSON-serliazable data around the framework. Anything passed to `value` will be converted using `json.dumps(json.loads(value))`. diff --git a/docs/griptape-framework/data/loaders.md b/docs/griptape-framework/data/loaders.md index 914fdee2a..0c0fc3ead 100644 --- a/docs/griptape-framework/data/loaders.md +++ b/docs/griptape-framework/data/loaders.md @@ -22,7 +22,7 @@ Inherits from the [TextLoader](../../reference/griptape/loaders/text_loader.md) ## SQL -Can be used to load data from a SQL database into [CsvRowArtifact](../../reference/griptape/artifacts/csv_row_artifact.md)s: +Can be used to load data from a SQL database into [TextArtifact](../../reference/griptape/artifacts/text_artifact.md)s: ```python --8<-- "docs/griptape-framework/data/src/loaders_2.py" @@ -30,7 +30,7 @@ Can be used to load data from a SQL database into [CsvRowArtifact](../../referen ## CSV -Can be used to load CSV files into [CsvRowArtifact](../../reference/griptape/artifacts/csv_row_artifact.md)s: +Can be used to load CSV files into [TextArtifact](../../reference/griptape/artifacts/text_artifact.md)s: ```python --8<-- "docs/griptape-framework/data/src/loaders_3.py" @@ -42,7 +42,7 @@ Can be used to load CSV files into [CsvRowArtifact](../../reference/griptape/art !!! info This driver requires the `loaders-dataframe` [extra](../index.md#extras). -Can be used to load [pandas](https://pandas.pydata.org/) [DataFrame](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.html)s into [CsvRowArtifact](../../reference/griptape/artifacts/csv_row_artifact.md)s: +Can be used to load [pandas](https://pandas.pydata.org/) [DataFrame](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.html)s into [TextArtifact](../../reference/griptape/artifacts/text_artifact.md)s: ```python --8<-- "docs/griptape-framework/data/src/loaders_4.py" diff --git a/griptape/artifacts/__init__.py b/griptape/artifacts/__init__.py index f39bfea8d..a647d0578 100644 --- a/griptape/artifacts/__init__.py +++ b/griptape/artifacts/__init__.py @@ -1,30 +1,28 @@ from .base_artifact import BaseArtifact -from .error_artifact import ErrorArtifact -from .info_artifact import InfoArtifact +from .base_system_artifact import BaseSystemArtifact + from .text_artifact import TextArtifact -from .json_artifact import JsonArtifact from .blob_artifact import BlobArtifact -from .boolean_artifact import BooleanArtifact -from .csv_row_artifact import CsvRowArtifact -from .list_artifact import ListArtifact -from .media_artifact import MediaArtifact from .image_artifact import ImageArtifact from .audio_artifact import AudioArtifact +from .json_artifact import JsonArtifact from .action_artifact import ActionArtifact from .generic_artifact import GenericArtifact +from .error_artifact import ErrorArtifact +from .info_artifact import InfoArtifact +from .list_artifact import ListArtifact + __all__ = [ "BaseArtifact", + "BaseSystemArtifact", "ErrorArtifact", "InfoArtifact", "TextArtifact", "JsonArtifact", "BlobArtifact", - "BooleanArtifact", - "CsvRowArtifact", "ListArtifact", - "MediaArtifact", "ImageArtifact", "AudioArtifact", "ActionArtifact", diff --git a/griptape/artifacts/action_artifact.py b/griptape/artifacts/action_artifact.py index 9772bbbab..d882d0638 100644 --- a/griptape/artifacts/action_artifact.py +++ b/griptape/artifacts/action_artifact.py @@ -5,15 +5,20 @@ from attrs import define, field from griptape.artifacts import BaseArtifact -from griptape.mixins.serializable_mixin import SerializableMixin if TYPE_CHECKING: from griptape.common import ToolAction @define() -class ActionArtifact(BaseArtifact, SerializableMixin): +class ActionArtifact(BaseArtifact): + """Represents the LLM taking an action to use a Tool. + + Attributes: + value: The Action to take. Currently only supports ToolAction. + """ + value: ToolAction = field(metadata={"serializable": True}) - def __add__(self, other: BaseArtifact) -> ActionArtifact: - raise NotImplementedError + def to_text(self) -> str: + return str(self.value) diff --git a/griptape/artifacts/audio_artifact.py b/griptape/artifacts/audio_artifact.py index 3dc67fa36..77febf332 100644 --- a/griptape/artifacts/audio_artifact.py +++ b/griptape/artifacts/audio_artifact.py @@ -1,12 +1,28 @@ from __future__ import annotations -from attrs import define +from attrs import define, field -from griptape.artifacts import MediaArtifact +from griptape.artifacts import BaseArtifact @define -class AudioArtifact(MediaArtifact): - """AudioArtifact is a type of MediaArtifact representing audio.""" +class AudioArtifact(BaseArtifact): + """Stores audio data. - media_type: str = "audio" + Attributes: + value: The audio data. + format: The audio format, e.g. "wav" or "mp3". + """ + + value: bytes = field(metadata={"serializable": True}) + format: str = field(kw_only=True, metadata={"serializable": True}) + + @property + def mime_type(self) -> str: + return f"audio/{self.format}" + + def to_bytes(self) -> bytes: + return self.value + + def to_text(self) -> str: + return f"Audio, format: {self.format}, size: {len(self.value)} bytes" diff --git a/griptape/artifacts/base_artifact.py b/griptape/artifacts/base_artifact.py index 82a0bbd23..8f94adbfe 100644 --- a/griptape/artifacts/base_artifact.py +++ b/griptape/artifacts/base_artifact.py @@ -1,6 +1,5 @@ from __future__ import annotations -import json import uuid from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Optional @@ -15,6 +14,18 @@ @define class BaseArtifact(SerializableMixin, ABC): + """Serves as the base class for all Artifacts. + + Artifacts are used to store data that can be provided as input to, or received as output from, a language model (LLM). + + Attributes: + id: The unique identifier of the Artifact. Defaults to a random UUID. + reference: The optional Reference to the Artifact. + meta: The metadata associated with the Artifact. Defaults to an empty dictionary. + name: The name of the Artifact. Defaults to the id. + value: The value of the Artifact. + """ + id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True, metadata={"serializable": True}) reference: Optional[Reference] = field(default=None, kw_only=True, metadata={"serializable": True}) meta: dict[str, Any] = field(factory=dict, kw_only=True, metadata={"serializable": True}) @@ -24,22 +35,7 @@ class BaseArtifact(SerializableMixin, ABC): metadata={"serializable": True}, ) value: Any = field() - - @classmethod - def value_to_bytes(cls, value: Any) -> bytes: - if isinstance(value, bytes): - return value - else: - return str(value).encode() - - @classmethod - def value_to_dict(cls, value: Any) -> dict: - dict_value = value if isinstance(value, dict) else json.loads(value) - - return dict(dict_value.items()) - - def to_text(self) -> str: - return str(self.value) + encoding: str = field(default="utf-8", kw_only=True) def __str__(self) -> str: return self.to_text() @@ -50,5 +46,8 @@ def __bool__(self) -> bool: def __len__(self) -> int: return len(self.value) + def to_bytes(self) -> bytes: + return self.to_text().encode() + @abstractmethod - def __add__(self, other: BaseArtifact) -> BaseArtifact: ... + def to_text(self) -> str: ... diff --git a/griptape/artifacts/base_system_artifact.py b/griptape/artifacts/base_system_artifact.py new file mode 100644 index 000000000..e5e7bfab6 --- /dev/null +++ b/griptape/artifacts/base_system_artifact.py @@ -0,0 +1,10 @@ +from abc import ABC + +from griptape.artifacts import BaseArtifact + + +class BaseSystemArtifact(BaseArtifact, ABC): + """Serves as the base class for all Artifacts specific to Griptape.""" + + def to_text(self) -> str: + return self.value diff --git a/griptape/artifacts/blob_artifact.py b/griptape/artifacts/blob_artifact.py index 0c0dcc122..072b6cb5b 100644 --- a/griptape/artifacts/blob_artifact.py +++ b/griptape/artifacts/blob_artifact.py @@ -1,26 +1,35 @@ from __future__ import annotations -import os.path -from typing import Optional +from typing import Any from attrs import define, field from griptape.artifacts import BaseArtifact +def value_to_bytes(value: Any) -> bytes: + if isinstance(value, bytes): + return value + else: + return str(value).encode() + + @define class BlobArtifact(BaseArtifact): - value: bytes = field(converter=BaseArtifact.value_to_bytes, metadata={"serializable": True}) - dir_name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) + """Stores arbitrary binary data. + + Attributes: + value: The binary data. + encoding: The encoding to use when converting the binary data to text. + encoding_error_handler: The error handler to use when converting the binary data to text. + """ + + value: bytes = field(converter=value_to_bytes, metadata={"serializable": True}) encoding: str = field(default="utf-8", kw_only=True) encoding_error_handler: str = field(default="strict", kw_only=True) - def __add__(self, other: BaseArtifact) -> BlobArtifact: - return BlobArtifact(self.value + other.value, name=self.name) - - @property - def full_path(self) -> str: - return os.path.join(self.dir_name, self.name) if self.dir_name else self.name + def to_bytes(self) -> bytes: + return self.value def to_text(self) -> str: return self.value.decode(encoding=self.encoding, errors=self.encoding_error_handler) diff --git a/griptape/artifacts/error_artifact.py b/griptape/artifacts/error_artifact.py index d065d754b..086cdb893 100644 --- a/griptape/artifacts/error_artifact.py +++ b/griptape/artifacts/error_artifact.py @@ -4,13 +4,17 @@ from attrs import define, field -from griptape.artifacts import BaseArtifact +from griptape.artifacts import BaseSystemArtifact @define -class ErrorArtifact(BaseArtifact): +class ErrorArtifact(BaseSystemArtifact): + """Represents an error that may want to be conveyed to the LLM. + + Attributes: + value: The error message. + exception: The exception that caused the error. Defaults to None. + """ + value: str = field(converter=str, metadata={"serializable": True}) exception: Optional[Exception] = field(default=None, kw_only=True, metadata={"serializable": False}) - - def __add__(self, other: BaseArtifact) -> ErrorArtifact: - return ErrorArtifact(self.value + other.value) diff --git a/griptape/artifacts/generic_artifact.py b/griptape/artifacts/generic_artifact.py index 8e0b7e38c..e90f40ef0 100644 --- a/griptape/artifacts/generic_artifact.py +++ b/griptape/artifacts/generic_artifact.py @@ -9,7 +9,13 @@ @define class GenericArtifact(BaseArtifact): + """Serves as an escape hatch for artifacts that don't fit into any other category. + + Attributes: + value: The value of the Artifact. + """ + value: Any = field(metadata={"serializable": True}) - def __add__(self, other: BaseArtifact) -> BaseArtifact: - raise NotImplementedError + def to_text(self) -> str: + return str(self.value) diff --git a/griptape/artifacts/image_artifact.py b/griptape/artifacts/image_artifact.py index e963b3881..28de3b1c9 100644 --- a/griptape/artifacts/image_artifact.py +++ b/griptape/artifacts/image_artifact.py @@ -1,23 +1,38 @@ from __future__ import annotations +import base64 + from attrs import define, field -from griptape.artifacts import MediaArtifact +from griptape.artifacts import BaseArtifact @define -class ImageArtifact(MediaArtifact): - """ImageArtifact is a type of MediaArtifact representing an image. +class ImageArtifact(BaseArtifact): + """Stores image data. Attributes: - value: Raw bytes representing media data. - media_type: The type of media, defaults to "image". - format: The format of the media, like png, jpeg, or gif. - name: Artifact name, generated using creation time and a random string. - model: Optionally specify the model used to generate the media. - prompt: Optionally specify the prompt used to generate the media. + value: The image data. + format: The format of the image data. Defaults to "png". + width: The width of the image. + height: The height of the image """ - media_type: str = "image" + value: bytes = field(metadata={"serializable": True}) + format: str = field(default="png", kw_only=True, metadata={"serializable": True}) width: int = field(kw_only=True, metadata={"serializable": True}) height: int = field(kw_only=True, metadata={"serializable": True}) + + @property + def base64(self) -> str: + return base64.b64encode(self.value).decode("utf-8") + + @property + def mime_type(self) -> str: + return f"image/{self.format}" + + def to_bytes(self) -> bytes: + return self.value + + def to_text(self) -> str: + return self.base64 diff --git a/griptape/artifacts/info_artifact.py b/griptape/artifacts/info_artifact.py index 26fe6366b..66fe94fc9 100644 --- a/griptape/artifacts/info_artifact.py +++ b/griptape/artifacts/info_artifact.py @@ -2,12 +2,17 @@ from attrs import define, field -from griptape.artifacts import BaseArtifact +from griptape.artifacts import BaseSystemArtifact @define -class InfoArtifact(BaseArtifact): - value: str = field(converter=str, metadata={"serializable": True}) +class InfoArtifact(BaseSystemArtifact): + """Represents helpful info that can be conveyed to the LLM. + + For example, "No results found" or "Please try again.". - def __add__(self, other: BaseArtifact) -> InfoArtifact: - return InfoArtifact(self.value + other.value) + Attributes: + value: The info to convey. + """ + + value: str = field(converter=str, metadata={"serializable": True}) diff --git a/griptape/artifacts/json_artifact.py b/griptape/artifacts/json_artifact.py index b292879a9..b3d45e4d9 100644 --- a/griptape/artifacts/json_artifact.py +++ b/griptape/artifacts/json_artifact.py @@ -1,7 +1,7 @@ from __future__ import annotations import json -from typing import Union +from typing import Any, Union from attrs import define, field @@ -10,12 +10,22 @@ Json = Union[dict[str, "Json"], list["Json"], str, int, float, bool, None] +def value_to_json(value: Any) -> Json: + if isinstance(value, str): + return json.loads(value) + else: + return json.loads(json.dumps(value)) + + @define class JsonArtifact(BaseArtifact): - value: Json = field(converter=lambda v: json.loads(json.dumps(v)), metadata={"serializable": True}) + """Stores JSON data. + + Attributes: + value: The JSON data. Values will automatically be converted to a JSON-compatible format. + """ + + value: Json = field(converter=value_to_json, metadata={"serializable": True}) def to_text(self) -> str: return json.dumps(self.value) - - def __add__(self, other: BaseArtifact) -> JsonArtifact: - raise NotImplementedError diff --git a/griptape/artifacts/list_artifact.py b/griptape/artifacts/list_artifact.py index 298f29c6a..00df83789 100644 --- a/griptape/artifacts/list_artifact.py +++ b/griptape/artifacts/list_artifact.py @@ -4,18 +4,27 @@ from attrs import Attribute, define, field -from griptape.artifacts import BaseArtifact +from griptape.artifacts import BaseArtifact, BaseSystemArtifact if TYPE_CHECKING: from collections.abc import Sequence @define -class ListArtifact(BaseArtifact): +class ListArtifact(BaseSystemArtifact): value: Sequence[BaseArtifact] = field(factory=list, metadata={"serializable": True}) item_separator: str = field(default="\n\n", kw_only=True, metadata={"serializable": True}) validate_uniform_types: bool = field(default=False, kw_only=True, metadata={"serializable": True}) + def __getitem__(self, key: int) -> BaseArtifact: + return self.value[key] + + def __bool__(self) -> bool: + return len(self) > 0 + + def __add__(self, other: BaseArtifact) -> BaseArtifact: + return ListArtifact(self.value + other.value) + @value.validator # pyright: ignore[reportAttributeAccessIssue] def validate_value(self, _: Attribute, value: list[BaseArtifact]) -> None: if self.validate_uniform_types and len(value) > 0: @@ -31,18 +40,9 @@ def child_type(self) -> Optional[type]: else: return None - def __getitem__(self, key: int) -> BaseArtifact: - return self.value[key] - - def __bool__(self) -> bool: - return len(self) > 0 - def to_text(self) -> str: return self.item_separator.join([v.to_text() for v in self.value]) - def __add__(self, other: BaseArtifact) -> BaseArtifact: - return ListArtifact(self.value + other.value) - def is_type(self, target_type: type) -> bool: if self.value: return isinstance(self.value[0], target_type) diff --git a/griptape/artifacts/media_artifact.py b/griptape/artifacts/media_artifact.py deleted file mode 100644 index a57217fc7..000000000 --- a/griptape/artifacts/media_artifact.py +++ /dev/null @@ -1,53 +0,0 @@ -from __future__ import annotations - -import base64 -import random -import string -import time -from typing import Optional - -from attrs import define, field - -from griptape.artifacts import BlobArtifact - - -@define -class MediaArtifact(BlobArtifact): - """MediaArtifact is a type of BlobArtifact that represents media (image, audio, video, etc.) and can be extended to support a specific media type. - - Attributes: - value: Raw bytes representing media data. - media_type: The type of media, like image, audio, or video. - format: The format of the media, like png, wav, or mp4. - name: Artifact name, generated using creation time and a random string. - model: Optionally specify the model used to generate the media. - prompt: Optionally specify the prompt used to generate the media. - """ - - media_type: str = field(default="media", kw_only=True, metadata={"serializable": True}) - format: str = field(kw_only=True, metadata={"serializable": True}) - model: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) - prompt: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) - - def __attrs_post_init__(self) -> None: - # Generating the name string requires attributes set by child classes. - # This waits until all attributes are available before generating a name. - if self.name == self.id: - self.name = self.make_name() - - @property - def mime_type(self) -> str: - return f"{self.media_type}/{self.format}" - - @property - def base64(self) -> str: - return base64.b64encode(self.value).decode("utf-8") - - def to_text(self) -> str: - return f"Media, type: {self.mime_type}, size: {len(self.value)} bytes" - - def make_name(self) -> str: - entropy = "".join(random.choices(string.ascii_lowercase + string.digits, k=4)) - fmt_time = time.strftime("%y%m%d%H%M%S", time.localtime()) - - return f"{self.media_type}_artifact_{fmt_time}_{entropy}.{self.format}" diff --git a/griptape/artifacts/text_artifact.py b/griptape/artifacts/text_artifact.py index 752f66615..3751d2c89 100644 --- a/griptape/artifacts/text_artifact.py +++ b/griptape/artifacts/text_artifact.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Any, Optional from attrs import define, field @@ -11,23 +11,36 @@ from griptape.tokenizers import BaseTokenizer +def value_to_str(value: Any) -> str: + if isinstance(value, dict): + return "\n".join(f"{key}: {val}" for key, val in value.items()) + else: + return str(value) + + @define class TextArtifact(BaseArtifact): - value: str = field(converter=str, metadata={"serializable": True}) + value: str = field(converter=value_to_str, metadata={"serializable": True}) encoding: str = field(default="utf-8", kw_only=True) encoding_error_handler: str = field(default="strict", kw_only=True) _embedding: list[float] = field(factory=list, kw_only=True) - @property - def embedding(self) -> Optional[list[float]]: - return None if len(self._embedding) == 0 else self._embedding - def __add__(self, other: BaseArtifact) -> TextArtifact: return TextArtifact(self.value + other.value) def __bool__(self) -> bool: return bool(self.value.strip()) + @property + def embedding(self) -> Optional[list[float]]: + return None if len(self._embedding) == 0 else self._embedding + + def to_text(self) -> str: + return self.value + + def to_bytes(self) -> bytes: + return str(self.value).encode(encoding=self.encoding, errors=self.encoding_error_handler) + def generate_embedding(self, driver: BaseEmbeddingDriver) -> Optional[list[float]]: self._embedding.clear() self._embedding.extend(driver.embed_string(str(self.value))) @@ -36,6 +49,3 @@ def generate_embedding(self, driver: BaseEmbeddingDriver) -> Optional[list[float def token_count(self, tokenizer: BaseTokenizer) -> int: return tokenizer.count_tokens(str(self.value)) - - def to_bytes(self) -> bytes: - return str(self.value).encode(encoding=self.encoding, errors=self.encoding_error_handler) diff --git a/griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py b/griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py index 7106c8192..4db302f6f 100644 --- a/griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py +++ b/griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py @@ -46,12 +46,11 @@ def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[ image_bytes = self._make_request(request) return ImageArtifact( - prompt=", ".join(prompts), value=image_bytes, format="png", width=self.image_width, height=self.image_height, - model=self.model, + meta={"prompt": ", ".join(prompts), "model": self.model}, ) def try_image_variation( @@ -70,12 +69,11 @@ def try_image_variation( image_bytes = self._make_request(request) return ImageArtifact( - prompt=", ".join(prompts), value=image_bytes, format="png", width=image.width, height=image.height, - model=self.model, + meta={"prompt": ", ".join(prompts), "model": self.model}, ) def try_image_inpainting( @@ -96,12 +94,11 @@ def try_image_inpainting( image_bytes = self._make_request(request) return ImageArtifact( - prompt=", ".join(prompts), value=image_bytes, format="png", width=image.width, height=image.height, - model=self.model, + meta={"prompt": ", ".join(prompts), "model": self.model}, ) def try_image_outpainting( @@ -122,12 +119,11 @@ def try_image_outpainting( image_bytes = self._make_request(request) return ImageArtifact( - prompt=", ".join(prompts), value=image_bytes, format="png", width=image.width, height=image.height, - model=self.model, + meta={"prompt": ", ".join(prompts), "model": self.model}, ) def _make_request(self, request: dict) -> bytes: diff --git a/griptape/drivers/image_generation/huggingface_pipeline_image_generation_driver.py b/griptape/drivers/image_generation/huggingface_pipeline_image_generation_driver.py index 46dbcd331..b89df1c4b 100644 --- a/griptape/drivers/image_generation/huggingface_pipeline_image_generation_driver.py +++ b/griptape/drivers/image_generation/huggingface_pipeline_image_generation_driver.py @@ -44,7 +44,7 @@ def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[ format=self.output_format.lower(), height=output_image.height, width=output_image.width, - prompt=prompt, + meta={"prompt": prompt}, ) def try_image_variation( @@ -76,7 +76,7 @@ def try_image_variation( format=self.output_format.lower(), height=output_image.height, width=output_image.width, - prompt=prompt, + meta={"prompt": prompt}, ) def try_image_inpainting( diff --git a/griptape/drivers/image_generation/leonardo_image_generation_driver.py b/griptape/drivers/image_generation/leonardo_image_generation_driver.py index e32dbb4c7..db89244bf 100644 --- a/griptape/drivers/image_generation/leonardo_image_generation_driver.py +++ b/griptape/drivers/image_generation/leonardo_image_generation_driver.py @@ -60,8 +60,10 @@ def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[ format="png", width=self.image_width, height=self.image_height, - model=self.model, - prompt=", ".join(prompts), + meta={ + "model": self.model, + "prompt": ", ".join(prompts), + }, ) def try_image_variation( @@ -87,8 +89,10 @@ def try_image_variation( format="png", width=self.image_width, height=self.image_height, - model=self.model, - prompt=", ".join(prompts), + meta={ + "model": self.model, + "prompt": ", ".join(prompts), + }, ) def try_image_outpainting( diff --git a/griptape/drivers/image_generation/openai_image_generation_driver.py b/griptape/drivers/image_generation/openai_image_generation_driver.py index 0ee50a1e2..bf77ac300 100644 --- a/griptape/drivers/image_generation/openai_image_generation_driver.py +++ b/griptape/drivers/image_generation/openai_image_generation_driver.py @@ -151,6 +151,5 @@ def _parse_image_response(self, response: ImagesResponse, prompt: str) -> ImageA format="png", width=image_dimensions[0], height=image_dimensions[1], - model=self.model, - prompt=prompt, + meta={"model": self.model, "prompt": prompt}, ) diff --git a/griptape/engines/extraction/csv_extraction_engine.py b/griptape/engines/extraction/csv_extraction_engine.py index b45bdf7f5..21d7820ac 100644 --- a/griptape/engines/extraction/csv_extraction_engine.py +++ b/griptape/engines/extraction/csv_extraction_engine.py @@ -6,7 +6,7 @@ from attrs import Factory, define, field -from griptape.artifacts import CsvRowArtifact, ListArtifact, TextArtifact +from griptape.artifacts import ListArtifact, TextArtifact from griptape.common import Message, PromptStack from griptape.engines import BaseExtractionEngine from griptape.utils import J2 @@ -36,22 +36,22 @@ def extract( item_separator="\n", ) - def text_to_csv_rows(self, text: str, column_names: list[str]) -> list[CsvRowArtifact]: + def text_to_csv_rows(self, text: str, column_names: list[str]) -> list[TextArtifact]: rows = [] with io.StringIO(text) as f: for row in csv.reader(f): - rows.append(CsvRowArtifact(dict(zip(column_names, [x.strip() for x in row])))) + rows.append(TextArtifact(dict(zip(column_names, [x.strip() for x in row])))) return rows def _extract_rec( self, artifacts: list[TextArtifact], - rows: list[CsvRowArtifact], + rows: list[TextArtifact], *, rulesets: Optional[list[Ruleset]] = None, - ) -> list[CsvRowArtifact]: + ) -> list[TextArtifact]: artifacts_text = self.chunk_joiner.join([a.value for a in artifacts]) system_prompt = self.system_template_generator.render( column_names=self.column_names, diff --git a/griptape/loaders/csv_loader.py b/griptape/loaders/csv_loader.py index 14dfe3e4a..20ac237c5 100644 --- a/griptape/loaders/csv_loader.py +++ b/griptape/loaders/csv_loader.py @@ -6,7 +6,7 @@ from attrs import define, field -from griptape.artifacts import CsvRowArtifact +from griptape.artifacts import TextArtifact from griptape.loaders import BaseLoader if TYPE_CHECKING: @@ -19,7 +19,7 @@ class CsvLoader(BaseLoader): delimiter: str = field(default=",", kw_only=True) encoding: str = field(default="utf-8", kw_only=True) - def load(self, source: bytes | str, *args, **kwargs) -> list[CsvRowArtifact]: + def load(self, source: bytes | str, *args, **kwargs) -> list[TextArtifact]: artifacts = [] if isinstance(source, bytes): @@ -28,7 +28,7 @@ def load(self, source: bytes | str, *args, **kwargs) -> list[CsvRowArtifact]: raise ValueError(f"Unsupported source type: {type(source)}") reader = csv.DictReader(StringIO(source), delimiter=self.delimiter) - chunks = [CsvRowArtifact(row) for row in reader] + chunks = [TextArtifact(row, meta={"row": row_num}) for row_num, row in enumerate(reader)] if self.embedding_driver: for chunk in chunks: @@ -44,8 +44,8 @@ def load_collection( sources: list[bytes | str], *args, **kwargs, - ) -> dict[str, list[CsvRowArtifact]]: + ) -> dict[str, list[TextArtifact]]: return cast( - dict[str, list[CsvRowArtifact]], + dict[str, list[TextArtifact]], super().load_collection(sources, *args, **kwargs), ) diff --git a/griptape/loaders/dataframe_loader.py b/griptape/loaders/dataframe_loader.py index 0b1ae1448..5ecb35ecd 100644 --- a/griptape/loaders/dataframe_loader.py +++ b/griptape/loaders/dataframe_loader.py @@ -4,7 +4,7 @@ from attrs import define, field -from griptape.artifacts import CsvRowArtifact +from griptape.artifacts import TextArtifact from griptape.loaders import BaseLoader from griptape.utils import import_optional_dependency from griptape.utils.hash import str_to_hash @@ -19,10 +19,10 @@ class DataFrameLoader(BaseLoader): embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) - def load(self, source: DataFrame, *args, **kwargs) -> list[CsvRowArtifact]: + def load(self, source: DataFrame, *args, **kwargs) -> list[TextArtifact]: artifacts = [] - chunks = [CsvRowArtifact(row) for row in source.to_dict(orient="records")] + chunks = [TextArtifact(row) for row in source.to_dict(orient="records")] if self.embedding_driver: for chunk in chunks: @@ -33,8 +33,8 @@ def load(self, source: DataFrame, *args, **kwargs) -> list[CsvRowArtifact]: return artifacts - def load_collection(self, sources: list[DataFrame], *args, **kwargs) -> dict[str, list[CsvRowArtifact]]: - return cast(dict[str, list[CsvRowArtifact]], super().load_collection(sources, *args, **kwargs)) + def load_collection(self, sources: list[DataFrame], *args, **kwargs) -> dict[str, list[TextArtifact]]: + return cast(dict[str, list[TextArtifact]], super().load_collection(sources, *args, **kwargs)) def to_key(self, source: DataFrame, *args, **kwargs) -> str: hash_pandas_object = import_optional_dependency("pandas.core.util.hashing").hash_pandas_object diff --git a/griptape/loaders/sql_loader.py b/griptape/loaders/sql_loader.py index e4522796f..14320911e 100644 --- a/griptape/loaders/sql_loader.py +++ b/griptape/loaders/sql_loader.py @@ -4,7 +4,7 @@ from attrs import define, field -from griptape.artifacts import CsvRowArtifact +from griptape.artifacts.text_artifact import TextArtifact from griptape.loaders import BaseLoader if TYPE_CHECKING: @@ -16,11 +16,11 @@ class SqlLoader(BaseLoader): sql_driver: BaseSqlDriver = field(kw_only=True) embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) - def load(self, source: str, *args, **kwargs) -> list[CsvRowArtifact]: + def load(self, source: str, *args, **kwargs) -> list[TextArtifact]: rows = self.sql_driver.execute_query(source) artifacts = [] - chunks = [CsvRowArtifact(row.cells) for row in rows] if rows else [] + chunks = [TextArtifact(row.cells, meta={"row": row_num}) for row_num, row in enumerate(rows)] if rows else [] if self.embedding_driver: for chunk in chunks: @@ -31,5 +31,5 @@ def load(self, source: str, *args, **kwargs) -> list[CsvRowArtifact]: return artifacts - def load_collection(self, sources: list[str], *args, **kwargs) -> dict[str, list[CsvRowArtifact]]: - return cast(dict[str, list[CsvRowArtifact]], super().load_collection(sources, *args, **kwargs)) + def load_collection(self, sources: list[str], *args, **kwargs) -> dict[str, list[TextArtifact]]: + return cast(dict[str, list[TextArtifact]], super().load_collection(sources, *args, **kwargs)) diff --git a/griptape/mixins/media_artifact_file_output_mixin.py b/griptape/mixins/artifact_file_output_mixin.py similarity index 87% rename from griptape/mixins/media_artifact_file_output_mixin.py rename to griptape/mixins/artifact_file_output_mixin.py index 9b9f34911..25ed8718d 100644 --- a/griptape/mixins/media_artifact_file_output_mixin.py +++ b/griptape/mixins/artifact_file_output_mixin.py @@ -7,11 +7,11 @@ from attrs import Attribute, define, field if TYPE_CHECKING: - from griptape.artifacts import BlobArtifact + from griptape.artifacts import BaseArtifact @define(slots=False) -class BlobArtifactFileOutputMixin: +class ArtifactFileOutputMixin: output_dir: Optional[str] = field(default=None, kw_only=True) output_file: Optional[str] = field(default=None, kw_only=True) @@ -31,7 +31,7 @@ def validate_output_file(self, _: Attribute, output_file: str) -> None: if self.output_dir: raise ValueError("Can't have both output_dir and output_file specified.") - def _write_to_file(self, artifact: BlobArtifact) -> None: + def _write_to_file(self, artifact: BaseArtifact) -> None: if self.output_file: outfile = self.output_file elif self.output_dir: @@ -42,4 +42,4 @@ def _write_to_file(self, artifact: BlobArtifact) -> None: if os.path.dirname(outfile): os.makedirs(os.path.dirname(outfile), exist_ok=True) - Path(outfile).write_bytes(artifact.value) + Path(outfile).write_bytes(artifact.to_bytes()) diff --git a/griptape/tasks/base_audio_generation_task.py b/griptape/tasks/base_audio_generation_task.py index fae217d54..866c9aa5c 100644 --- a/griptape/tasks/base_audio_generation_task.py +++ b/griptape/tasks/base_audio_generation_task.py @@ -14,7 +14,7 @@ @define -class BaseAudioGenerationTask(BlobArtifactFileOutputMixin, RuleMixin, BaseTask, ABC): +class BaseAudioGenerationTask(ArtifactFileOutputMixin, RuleMixin, BaseTask, ABC): def before_run(self) -> None: super().before_run() diff --git a/griptape/tasks/base_image_generation_task.py b/griptape/tasks/base_image_generation_task.py index bd36d0080..57dc99506 100644 --- a/griptape/tasks/base_image_generation_task.py +++ b/griptape/tasks/base_image_generation_task.py @@ -16,14 +16,13 @@ from griptape.tasks import BaseTask if TYPE_CHECKING: - from griptape.artifacts import MediaArtifact - + from griptape.artifacts import ImageArtifact logger = logging.getLogger(Defaults.logging_config.logger_name) @define -class BaseImageGenerationTask(BlobArtifactFileOutputMixin, RuleMixin, BaseTask, ABC): +class BaseImageGenerationTask(ArtifactFileOutputMixin, RuleMixin, BaseTask, ABC): """Provides a base class for image generation-related tasks. Attributes: @@ -65,6 +64,6 @@ def all_negative_rulesets(self) -> list[Ruleset]: return task_rulesets - def _read_from_file(self, path: str) -> MediaArtifact: + def _read_from_file(self, path: str) -> ImageArtifact: logger.info("Reading image from %s", os.path.abspath(path)) return ImageLoader().load(Path(path).read_bytes()) diff --git a/griptape/tools/base_image_generation_tool.py b/griptape/tools/base_image_generation_tool.py index ee1c37b6d..2df5d9747 100644 --- a/griptape/tools/base_image_generation_tool.py +++ b/griptape/tools/base_image_generation_tool.py @@ -1,11 +1,11 @@ from attrs import define -from griptape.mixins.media_artifact_file_output_mixin import BlobArtifactFileOutputMixin +from griptape.mixins.artifact_file_output_mixin import ArtifactFileOutputMixin from griptape.tools import BaseTool @define -class BaseImageGenerationTool(BlobArtifactFileOutputMixin, BaseTool): +class BaseImageGenerationTool(ArtifactFileOutputMixin, BaseTool): """A base class for tools that generate images from text prompts.""" PROMPT_DESCRIPTION = "Features and qualities to include in the generated image, descriptive and succinct." diff --git a/griptape/tools/query/tool.py b/griptape/tools/query/tool.py index 0089970e9..0274e7940 100644 --- a/griptape/tools/query/tool.py +++ b/griptape/tools/query/tool.py @@ -5,7 +5,7 @@ from attrs import Factory, define, field from schema import Literal, Or, Schema -from griptape.artifacts import BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact +from griptape.artifacts import ErrorArtifact, ListArtifact, TextArtifact from griptape.configs import Defaults from griptape.engines.rag import RagEngine from griptape.engines.rag.modules import ( @@ -60,7 +60,7 @@ class QueryTool(BaseTool, RuleMixin): ), }, ) - def query(self, params: dict) -> BaseArtifact: + def query(self, params: dict) -> ListArtifact | ErrorArtifact: query = params["values"]["query"] content = params["values"]["content"] diff --git a/griptape/tools/text_to_speech/tool.py b/griptape/tools/text_to_speech/tool.py index ea4982029..23b82333a 100644 --- a/griptape/tools/text_to_speech/tool.py +++ b/griptape/tools/text_to_speech/tool.py @@ -15,7 +15,7 @@ @define -class TextToSpeechTool(BlobArtifactFileOutputMixin, BaseTool): +class TextToSpeechTool(ArtifactFileOutputMixin, BaseTool): """A tool that can be used to generate speech from input text. Attributes: diff --git a/tests/unit/artifacts/test_action_artifact.py b/tests/unit/artifacts/test_action_artifact.py index 2530ed8c3..b7180b1c3 100644 --- a/tests/unit/artifacts/test_action_artifact.py +++ b/tests/unit/artifacts/test_action_artifact.py @@ -11,10 +11,6 @@ class TestActionArtifact: def action(self) -> ToolAction: return ToolAction(tag="TestTag", name="TestName", path="TestPath", input={"foo": "bar"}) - def test___add__(self, action): - with pytest.raises(NotImplementedError): - ActionArtifact(action) + ActionArtifact(action) - def test_to_text(self, action): assert ActionArtifact(action).to_text() == json.dumps(action.to_dict()) diff --git a/tests/unit/artifacts/test_audio_artifact.py b/tests/unit/artifacts/test_audio_artifact.py index 6d44c05b3..aab6af630 100644 --- a/tests/unit/artifacts/test_audio_artifact.py +++ b/tests/unit/artifacts/test_audio_artifact.py @@ -6,20 +6,22 @@ class TestAudioArtifact: @pytest.fixture() def audio_artifact(self): - return AudioArtifact(value=b"some binary audio data", format="pcm", model="provider/model", prompt="two words") + return AudioArtifact( + value=b"some binary audio data", format="pcm", meta={"model": "provider/model", "prompt": "two words"} + ) def test_mime_type(self, audio_artifact: AudioArtifact): assert audio_artifact.mime_type == "audio/pcm" def test_to_text(self, audio_artifact: AudioArtifact): - assert audio_artifact.to_text() == "Media, type: audio/pcm, size: 22 bytes" + assert audio_artifact.to_text() == "Audio, format: pcm, size: 22 bytes" def test_to_dict(self, audio_artifact: AudioArtifact): audio_dict = audio_artifact.to_dict() assert audio_dict["format"] == "pcm" - assert audio_dict["model"] == "provider/model" - assert audio_dict["prompt"] == "two words" + assert audio_dict["meta"]["model"] == "provider/model" + assert audio_dict["meta"]["prompt"] == "two words" assert audio_dict["value"] == "c29tZSBiaW5hcnkgYXVkaW8gZGF0YQ==" def test_deserialization(self, audio_artifact): @@ -31,5 +33,5 @@ def test_deserialization(self, audio_artifact): assert deserialized_artifact.value == b"some binary audio data" assert deserialized_artifact.mime_type == "audio/pcm" assert deserialized_artifact.format == "pcm" - assert deserialized_artifact.model == "provider/model" - assert deserialized_artifact.prompt == "two words" + assert deserialized_artifact.meta["model"] == "provider/model" + assert deserialized_artifact.meta["prompt"] == "two words" diff --git a/tests/unit/artifacts/test_base_artifact.py b/tests/unit/artifacts/test_base_artifact.py index 6cf8f4466..e2060879f 100644 --- a/tests/unit/artifacts/test_base_artifact.py +++ b/tests/unit/artifacts/test_base_artifact.py @@ -41,7 +41,7 @@ def test_list_artifact_from_dict(self): assert artifact.to_text() == "foobar" def test_blob_artifact_from_dict(self): - dict_value = {"type": "BlobArtifact", "value": b"Zm9vYmFy", "dir_name": "foo", "name": "bar"} + dict_value = {"type": "BlobArtifact", "value": b"Zm9vYmFy", "name": "bar"} artifact = BaseArtifact.from_dict(dict_value) assert isinstance(artifact, BlobArtifact) @@ -51,17 +51,15 @@ def test_image_artifact_from_dict(self): dict_value = { "type": "ImageArtifact", "value": b"aW1hZ2UgZGF0YQ==", - "dir_name": "foo", "format": "png", "width": 256, "height": 256, - "model": "test-model", - "prompt": "some prompt", + "meta": {"model": "test-model", "prompt": "some prompt"}, } artifact = BaseArtifact.from_dict(dict_value) assert isinstance(artifact, ImageArtifact) - assert artifact.to_text() == "Media, type: image/png, size: 10 bytes" + assert artifact.to_text() == "aW1hZ2UgZGF0YQ==" assert artifact.value == b"image data" def test_unsupported_from_dict(self): diff --git a/tests/unit/artifacts/test_base_media_artifact.py b/tests/unit/artifacts/test_base_media_artifact.py deleted file mode 100644 index c85d070fe..000000000 --- a/tests/unit/artifacts/test_base_media_artifact.py +++ /dev/null @@ -1,30 +0,0 @@ -import pytest -from attrs import define - -from griptape.artifacts import MediaArtifact - - -class TestMediaArtifact: - @define - class ImaginaryMediaArtifact(MediaArtifact): - media_type: str = "imagination" - - @pytest.fixture() - def media_artifact(self): - return self.ImaginaryMediaArtifact(value=b"some binary dream data", format="dream") - - def test_to_dict(self, media_artifact): - image_dict = media_artifact.to_dict() - - assert image_dict["format"] == "dream" - assert image_dict["value"] == "c29tZSBiaW5hcnkgZHJlYW0gZGF0YQ==" - - def test_name(self, media_artifact): - assert media_artifact.name.startswith("imagination_artifact") - assert media_artifact.name.endswith(".dream") - - def test_mime_type(self, media_artifact): - assert media_artifact.mime_type == "imagination/dream" - - def test_to_text(self, media_artifact): - assert media_artifact.to_text() == "Media, type: imagination/dream, size: 22 bytes" diff --git a/tests/unit/artifacts/test_blob_artifact.py b/tests/unit/artifacts/test_blob_artifact.py index 3d88d5793..d0f04b1e5 100644 --- a/tests/unit/artifacts/test_blob_artifact.py +++ b/tests/unit/artifacts/test_blob_artifact.py @@ -1,5 +1,4 @@ import base64 -import os import pytest @@ -30,32 +29,22 @@ def test_to_text_encoding_error_handler(self): ) def test_to_dict(self): - assert BlobArtifact(b"foobar", name="foobar.txt", dir_name="foo").to_dict()["name"] == "foobar.txt" - - def test_full_path_with_path(self): - assert BlobArtifact(b"foobar", name="foobar.txt", dir_name="foo").full_path == os.path.normpath( - "foo/foobar.txt" - ) - - def test_full_path_without_path(self): - assert BlobArtifact(b"foobar", name="foobar.txt").full_path == "foobar.txt" + assert BlobArtifact(b"foobar", name="foobar.txt").to_dict()["name"] == "foobar.txt" def test_serialization(self): - artifact = BlobArtifact(b"foobar", name="foobar.txt", dir_name="foo") + artifact = BlobArtifact(b"foobar", name="foobar.txt") artifact_dict = artifact.to_dict() assert artifact_dict["name"] == "foobar.txt" - assert artifact_dict["dir_name"] == "foo" assert base64.b64decode(artifact_dict["value"]) == b"foobar" def test_deserialization(self): - artifact = BlobArtifact(b"foobar", name="foobar.txt", dir_name="foo") + artifact = BlobArtifact(b"foobar", name="foobar.txt") artifact_dict = artifact.to_dict() deserialized_artifact = BaseArtifact.from_dict(artifact_dict) assert isinstance(deserialized_artifact, BlobArtifact) assert deserialized_artifact.name == "foobar.txt" - assert deserialized_artifact.dir_name == "foo" assert deserialized_artifact.value == b"foobar" def test_name(self): diff --git a/tests/unit/artifacts/test_image_artifact.py b/tests/unit/artifacts/test_image_artifact.py index a722ebd91..b048b3372 100644 --- a/tests/unit/artifacts/test_image_artifact.py +++ b/tests/unit/artifacts/test_image_artifact.py @@ -11,12 +11,11 @@ def image_artifact(self): format="png", width=512, height=512, - model="openai/dalle2", - prompt="a cute cat", + meta={"model": "openai/dalle2", "prompt": "a cute cat"}, ) def test_to_text(self, image_artifact: ImageArtifact): - assert image_artifact.to_text() == "Media, type: image/png, size: 26 bytes" + assert image_artifact.to_text() == "c29tZSBiaW5hcnkgcG5nIGltYWdlIGRhdGE=" def test_to_dict(self, image_artifact: ImageArtifact): image_dict = image_artifact.to_dict() @@ -24,8 +23,8 @@ def test_to_dict(self, image_artifact: ImageArtifact): assert image_dict["format"] == "png" assert image_dict["width"] == 512 assert image_dict["height"] == 512 - assert image_dict["model"] == "openai/dalle2" - assert image_dict["prompt"] == "a cute cat" + assert image_dict["meta"]["model"] == "openai/dalle2" + assert image_dict["meta"]["prompt"] == "a cute cat" assert image_dict["value"] == "c29tZSBiaW5hcnkgcG5nIGltYWdlIGRhdGE=" def test_deserialization(self, image_artifact): @@ -39,5 +38,5 @@ def test_deserialization(self, image_artifact): assert deserialized_artifact.format == "png" assert deserialized_artifact.width == 512 assert deserialized_artifact.height == 512 - assert deserialized_artifact.model == "openai/dalle2" - assert deserialized_artifact.prompt == "a cute cat" + assert deserialized_artifact.meta["model"] == "openai/dalle2" + assert deserialized_artifact.meta["prompt"] == "a cute cat" diff --git a/tests/unit/artifacts/test_json_artifact.py b/tests/unit/artifacts/test_json_artifact.py index 06f5d6297..be61e3edf 100644 --- a/tests/unit/artifacts/test_json_artifact.py +++ b/tests/unit/artifacts/test_json_artifact.py @@ -1,8 +1,6 @@ import json -import pytest - -from griptape.artifacts import JsonArtifact, TextArtifact +from griptape.artifacts import JsonArtifact class TestJsonArtifact: @@ -14,11 +12,11 @@ def test_value_type_conversion(self): assert JsonArtifact({"foo": None}).value == json.loads(json.dumps({"foo": None})) assert JsonArtifact([{"foo": {"bar": "baz"}}]).value == json.loads(json.dumps([{"foo": {"bar": "baz"}}])) assert JsonArtifact(None).value == json.loads(json.dumps(None)) - assert JsonArtifact("foo").value == json.loads(json.dumps("foo")) - - def test___add__(self): - with pytest.raises(NotImplementedError): - JsonArtifact({"foo": "bar"}) + TextArtifact("invalid json") + assert JsonArtifact('"foo"').value == "foo" + assert JsonArtifact("true").value is True + assert JsonArtifact("false").value is False + assert JsonArtifact("123").value == 123 + assert JsonArtifact("123.4").value == 123.4 def test_to_text(self): assert JsonArtifact({"foo": "bar"}).to_text() == json.dumps({"foo": "bar"}) diff --git a/tests/unit/artifacts/test_list_artifact.py b/tests/unit/artifacts/test_list_artifact.py index 06d234645..bf7004ad0 100644 --- a/tests/unit/artifacts/test_list_artifact.py +++ b/tests/unit/artifacts/test_list_artifact.py @@ -1,6 +1,7 @@ import pytest -from griptape.artifacts import BlobArtifact, CsvRowArtifact, ListArtifact, TextArtifact +from griptape.artifacts import BlobArtifact, ListArtifact, TextArtifact +from griptape.artifacts.image_artifact import ImageArtifact class TestListArtifact: @@ -32,8 +33,7 @@ def test_child_type(self): def test_is_type(self): assert ListArtifact([TextArtifact("foo")]).is_type(TextArtifact) - assert ListArtifact([CsvRowArtifact({"foo": "bar"})]).is_type(TextArtifact) - assert ListArtifact([CsvRowArtifact({"foo": "bar"})]).is_type(CsvRowArtifact) + assert ListArtifact([ImageArtifact(b"", width=1234, height=1234)]).is_type(ImageArtifact) def test_has_items(self): assert not ListArtifact().has_items() diff --git a/tests/unit/drivers/image_generation/test_amazon_bedrock_stable_diffusion_image_generation_driver.py b/tests/unit/drivers/image_generation/test_amazon_bedrock_stable_diffusion_image_generation_driver.py index 9aa4d3f4f..05e669b66 100644 --- a/tests/unit/drivers/image_generation/test_amazon_bedrock_stable_diffusion_image_generation_driver.py +++ b/tests/unit/drivers/image_generation/test_amazon_bedrock_stable_diffusion_image_generation_driver.py @@ -60,5 +60,5 @@ def test_try_text_to_image(self, driver): assert image_artifact.mime_type == "image/png" assert image_artifact.width == 512 assert image_artifact.height == 512 - assert image_artifact.model == "stability.stable-diffusion-xl-v1" - assert image_artifact.prompt == "test prompt" + assert image_artifact.meta["model"] == "stability.stable-diffusion-xl-v1" + assert image_artifact.meta["prompt"] == "test prompt" diff --git a/tests/unit/drivers/image_generation/test_azure_openai_image_generation_driver.py b/tests/unit/drivers/image_generation/test_azure_openai_image_generation_driver.py index 268708b2b..a72764211 100644 --- a/tests/unit/drivers/image_generation/test_azure_openai_image_generation_driver.py +++ b/tests/unit/drivers/image_generation/test_azure_openai_image_generation_driver.py @@ -28,7 +28,10 @@ def test_init(self, driver): def test_init_requires_endpoint(self): with pytest.raises(TypeError): AzureOpenAiImageGenerationDriver( - model="dall-e-3", client=Mock(), azure_deployment="dalle-deployment", image_size="512x512" + model="dall-e-3", + client=Mock(), + azure_deployment="dalle-deployment", + image_size="512x512", ) # pyright: ignore[reportCallIssues] def test_try_text_to_image(self, driver): @@ -40,5 +43,5 @@ def test_try_text_to_image(self, driver): assert image_artifact.mime_type == "image/png" assert image_artifact.width == 512 assert image_artifact.height == 512 - assert image_artifact.model == "dall-e-3" - assert image_artifact.prompt == "test prompt" + assert image_artifact.meta["model"] == "dall-e-3" + assert image_artifact.meta["prompt"] == "test prompt" diff --git a/tests/unit/drivers/image_generation/test_leonardo_image_generation_driver.py b/tests/unit/drivers/image_generation/test_leonardo_image_generation_driver.py index 48805cde6..ec70e2dd2 100644 --- a/tests/unit/drivers/image_generation/test_leonardo_image_generation_driver.py +++ b/tests/unit/drivers/image_generation/test_leonardo_image_generation_driver.py @@ -76,5 +76,5 @@ def test_try_text_to_image(self, driver): assert image_artifact.mime_type == "image/png" assert image_artifact.width == 512 assert image_artifact.height == 512 - assert image_artifact.model == "test_model_id" - assert image_artifact.prompt == "test_prompt" + assert image_artifact.meta["model"] == "test_model_id" + assert image_artifact.meta["prompt"] == "test_prompt" diff --git a/tests/unit/drivers/image_generation/test_openai_image_generation_driver.py b/tests/unit/drivers/image_generation/test_openai_image_generation_driver.py index 16bcd2870..ff5528fb6 100644 --- a/tests/unit/drivers/image_generation/test_openai_image_generation_driver.py +++ b/tests/unit/drivers/image_generation/test_openai_image_generation_driver.py @@ -22,8 +22,8 @@ def test_try_text_to_image(self, driver): assert image_artifact.mime_type == "image/png" assert image_artifact.width == 512 assert image_artifact.height == 512 - assert image_artifact.model == "dall-e-2" - assert image_artifact.prompt == "test prompt" + assert image_artifact.meta["model"] == "dall-e-2" + assert image_artifact.meta["prompt"] == "test prompt" def test_try_image_variation(self, driver): driver.client.images.create_variation.return_value = Mock(data=[Mock(b64_json=b"aW1hZ2UgZGF0YQ==")]) @@ -34,7 +34,7 @@ def test_try_image_variation(self, driver): assert image_artifact.mime_type == "image/png" assert image_artifact.width == 512 assert image_artifact.height == 512 - assert image_artifact.model == "dall-e-2" + assert image_artifact.meta["model"] == "dall-e-2" def test_try_image_variation_invalid_size(self, driver): driver.image_size = "1024x1792" @@ -59,8 +59,8 @@ def test_try_image_inpainting(self, driver): assert image_artifact.mime_type == "image/png" assert image_artifact.width == 512 assert image_artifact.height == 512 - assert image_artifact.model == "dall-e-2" - assert image_artifact.prompt == "test prompt" + assert image_artifact.meta["model"] == "dall-e-2" + assert image_artifact.meta["prompt"] == "test prompt" def test_try_image_inpainting_invalid_size(self, driver): driver.image_size = "1024x1792" diff --git a/tests/unit/drivers/vector/test_base_local_vector_store_driver.py b/tests/unit/drivers/vector/test_base_local_vector_store_driver.py index ac4ff8043..20a3e2b50 100644 --- a/tests/unit/drivers/vector/test_base_local_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_base_local_vector_store_driver.py @@ -4,7 +4,6 @@ import pytest from griptape.artifacts import TextArtifact -from griptape.artifacts.csv_row_artifact import CsvRowArtifact class BaseLocalVectorStoreDriver(ABC): @@ -26,20 +25,6 @@ def test_upsert(self, driver): assert len(driver.entries) == 2 - def test_upsert_csv_row(self, driver): - namespace = driver.upsert_text_artifact(CsvRowArtifact(id="foo1", value={"col": "value"})) - - assert len(driver.entries) == 1 - assert list(driver.entries.keys())[0] == namespace - - driver.upsert_text_artifact(CsvRowArtifact(id="foo1", value={"col": "value"})) - - assert len(driver.entries) == 1 - - driver.upsert_text_artifact(CsvRowArtifact(id="foo2", value={"col": "value2"})) - - assert len(driver.entries) == 2 - def test_upsert_multiple(self, driver): driver.upsert_text_artifacts({"foo": [TextArtifact("foo")], "bar": [TextArtifact("bar")]}) diff --git a/tests/unit/engines/extraction/test_csv_extraction_engine.py b/tests/unit/engines/extraction/test_csv_extraction_engine.py index 893c21d60..056df2d5a 100644 --- a/tests/unit/engines/extraction/test_csv_extraction_engine.py +++ b/tests/unit/engines/extraction/test_csv_extraction_engine.py @@ -12,11 +12,11 @@ def test_extract(self, engine): result = engine.extract("foo") assert len(result.value) == 1 - assert result.value[0].value == {"test1": "mock output"} + assert result.value[0].value == "test1: mock output" def test_text_to_csv_rows(self, engine): result = engine.text_to_csv_rows("foo,bar\nbaz,maz", ["test1", "test2"]) assert len(result) == 2 - assert result[0].value == {"test1": "foo", "test2": "bar"} - assert result[1].value == {"test1": "baz", "test2": "maz"} + assert result[0].value == "test1: foo\ntest2: bar" + assert result[1].value == "test1: baz\ntest2: maz" diff --git a/tests/unit/loaders/test_audio_loader.py b/tests/unit/loaders/test_audio_loader.py index 473fd0d9e..b7ebdd912 100644 --- a/tests/unit/loaders/test_audio_loader.py +++ b/tests/unit/loaders/test_audio_loader.py @@ -13,14 +13,13 @@ def loader(self): def create_source(self, bytes_from_resource_path): return bytes_from_resource_path - @pytest.mark.parametrize(("resource_path", "suffix", "mime_type"), [("sentences.wav", ".wav", "audio/wav")]) - def test_load(self, resource_path, suffix, mime_type, loader, create_source): + @pytest.mark.parametrize(("resource_path", "mime_type"), [("sentences.wav", "audio/wav")]) + def test_load(self, resource_path, mime_type, loader, create_source): source = create_source(resource_path) artifact = loader.load(source) assert isinstance(artifact, AudioArtifact) - assert artifact.name.endswith(suffix) assert artifact.mime_type == mime_type assert len(artifact.value) > 0 @@ -35,6 +34,5 @@ def test_load_collection(self, create_source, loader): for key in collection: artifact = collection[key] assert isinstance(artifact, AudioArtifact) - assert artifact.name.endswith(".wav") assert artifact.mime_type == "audio/wav" assert len(artifact.value) > 0 diff --git a/tests/unit/loaders/test_csv_loader.py b/tests/unit/loaders/test_csv_loader.py index a747afff7..d08b93919 100644 --- a/tests/unit/loaders/test_csv_loader.py +++ b/tests/unit/loaders/test_csv_loader.py @@ -28,8 +28,7 @@ def test_load(self, loader, create_source): assert len(artifacts) == 10 first_artifact = artifacts[0] - assert first_artifact.value["Foo"] == "foo1" - assert first_artifact.value["Bar"] == "bar1" + assert first_artifact.value == "Foo: foo1\nBar: bar1" assert first_artifact.embedding == [0, 1] def test_load_delimiter(self, loader_with_pipe_delimiter, create_source): @@ -39,8 +38,7 @@ def test_load_delimiter(self, loader_with_pipe_delimiter, create_source): assert len(artifacts) == 10 first_artifact = artifacts[0] - assert first_artifact.value["Foo"] == "bar1" - assert first_artifact.value["Bar"] == "foo1" + assert first_artifact.value == "Bar: foo1\nFoo: bar1" assert first_artifact.embedding == [0, 1] def test_load_collection(self, loader, create_source): @@ -52,10 +50,8 @@ def test_load_collection(self, loader, create_source): keys = {loader.to_key(source) for source in sources} assert collection.keys() == keys - for key in keys: - artifacts = collection[key] - assert len(artifacts) == 10 - first_artifact = artifacts[0] - assert first_artifact.value["Foo"] == "foo1" - assert first_artifact.value["Bar"] == "bar1" - assert first_artifact.embedding == [0, 1] + assert collection[loader.to_key(sources[0])][0].value == "Foo: foo1\nBar: bar1" + assert collection[loader.to_key(sources[0])][0].embedding == [0, 1] + + assert collection[loader.to_key(sources[1])][0].value == "Bar: bar1\nFoo: foo1" + assert collection[loader.to_key(sources[1])][0].embedding == [0, 1] diff --git a/tests/unit/loaders/test_dataframe_loader.py b/tests/unit/loaders/test_dataframe_loader.py deleted file mode 100644 index 5c2a57ed6..000000000 --- a/tests/unit/loaders/test_dataframe_loader.py +++ /dev/null @@ -1,52 +0,0 @@ -import os - -import pandas as pd -import pytest - -from griptape.loaders.dataframe_loader import DataFrameLoader -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver - - -class TestDataFrameLoader: - @pytest.fixture() - def loader(self): - return DataFrameLoader(embedding_driver=MockEmbeddingDriver()) - - def test_load_with_path(self, loader): - # test loading a file delimited by comma - path = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../../resources/test-1.csv") - - artifacts = loader.load(pd.read_csv(path)) - - assert len(artifacts) == 10 - first_artifact = artifacts[0].value - assert first_artifact["Foo"] == "foo1" - assert first_artifact["Bar"] == "bar1" - - assert artifacts[0].embedding == [0, 1] - - def test_load_collection_with_path(self, loader): - path1 = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../../resources/test-1.csv") - path2 = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../../resources/test-2.csv") - df1 = pd.read_csv(path1) - df2 = pd.read_csv(path2) - collection = loader.load_collection([df1, df2]) - - key1 = loader.to_key(df1) - key2 = loader.to_key(df2) - - assert list(collection.keys()) == [key1, key2] - - artifacts = collection[key1] - assert len(artifacts) == 10 - first_artifact = artifacts[0].value - assert first_artifact["Foo"] == "foo1" - assert first_artifact["Bar"] == "bar1" - - artifacts = collection[key2] - assert len(artifacts) == 10 - first_artifact = artifacts[0].value - assert first_artifact["Bar"] == "bar1" - assert first_artifact["Foo"] == "foo1" - - assert artifacts[0].embedding == [0, 1] diff --git a/tests/unit/loaders/test_image_loader.py b/tests/unit/loaders/test_image_loader.py index eca4cbccc..7093894b0 100644 --- a/tests/unit/loaders/test_image_loader.py +++ b/tests/unit/loaders/test_image_loader.py @@ -18,23 +18,22 @@ def create_source(self, bytes_from_resource_path): return bytes_from_resource_path @pytest.mark.parametrize( - ("resource_path", "suffix", "mime_type"), + ("resource_path", "mime_type"), [ - ("small.png", ".png", "image/png"), - ("small.jpg", ".jpeg", "image/jpeg"), - ("small.webp", ".webp", "image/webp"), - ("small.bmp", ".bmp", "image/bmp"), - ("small.gif", ".gif", "image/gif"), - ("small.tiff", ".tiff", "image/tiff"), + ("small.png", "image/png"), + ("small.jpg", "image/jpeg"), + ("small.webp", "image/webp"), + ("small.bmp", "image/bmp"), + ("small.gif", "image/gif"), + ("small.tiff", "image/tiff"), ], ) - def test_load(self, resource_path, suffix, mime_type, loader, create_source): + def test_load(self, resource_path, mime_type, loader, create_source): source = create_source(resource_path) artifact = loader.load(source) assert isinstance(artifact, ImageArtifact) - assert artifact.name.endswith(suffix) assert artifact.height == 32 assert artifact.width == 32 assert artifact.mime_type == mime_type @@ -49,7 +48,6 @@ def test_load_normalize(self, resource_path, png_loader, create_source): artifact = png_loader.load(source) assert isinstance(artifact, ImageArtifact) - assert artifact.name.endswith(".png") assert artifact.height == 32 assert artifact.width == 32 assert artifact.mime_type == "image/png" @@ -68,7 +66,6 @@ def test_load_collection(self, create_source, png_loader): for key in keys: artifact = collection[key] assert isinstance(artifact, ImageArtifact) - assert artifact.name.endswith(".png") assert artifact.height == 32 assert artifact.width == 32 assert artifact.mime_type == "image/png" diff --git a/tests/unit/loaders/test_sql_loader.py b/tests/unit/loaders/test_sql_loader.py index fbfa6d4fa..2ff6c7faf 100644 --- a/tests/unit/loaders/test_sql_loader.py +++ b/tests/unit/loaders/test_sql_loader.py @@ -38,24 +38,21 @@ def test_load(self, loader): artifacts = loader.load("SELECT * FROM test_table;") assert len(artifacts) == 3 - assert artifacts[0].value == {"id": 1, "name": "Alice", "age": 25, "city": "New York"} - assert artifacts[1].value == {"id": 2, "name": "Bob", "age": 30, "city": "Los Angeles"} - assert artifacts[2].value == {"id": 3, "name": "Charlie", "age": 22, "city": "Chicago"} + assert artifacts[0].value == "id: 1\nname: Alice\nage: 25\ncity: New York" + assert artifacts[1].value == "id: 2\nname: Bob\nage: 30\ncity: Los Angeles" + assert artifacts[2].value == "id: 3\nname: Charlie\nage: 22\ncity: Chicago" assert artifacts[0].embedding == [0, 1] def test_load_collection(self, loader): - artifacts = loader.load_collection(["SELECT * FROM test_table LIMIT 1;", "SELECT * FROM test_table LIMIT 2;"]) + sources = ["SELECT * FROM test_table LIMIT 1;", "SELECT * FROM test_table LIMIT 2;"] + artifacts = loader.load_collection(sources) assert list(artifacts.keys()) == [ loader.to_key("SELECT * FROM test_table LIMIT 1;"), loader.to_key("SELECT * FROM test_table LIMIT 2;"), ] - assert [a.value for artifact_list in artifacts.values() for a in artifact_list] == [ - {"age": 25, "city": "New York", "id": 1, "name": "Alice"}, - {"age": 25, "city": "New York", "id": 1, "name": "Alice"}, - {"age": 30, "city": "Los Angeles", "id": 2, "name": "Bob"}, - ] - + assert artifacts[loader.to_key(sources[0])][0].value == "id: 1\nname: Alice\nage: 25\ncity: New York" + assert artifacts[loader.to_key(sources[1])][0].value == "id: 1\nname: Alice\nage: 25\ncity: New York" assert list(artifacts.values())[0][0].embedding == [0, 1] diff --git a/tests/unit/memory/tool/test_task_memory.py b/tests/unit/memory/tool/test_task_memory.py index 2f6ffe1c9..d2575959a 100644 --- a/tests/unit/memory/tool/test_task_memory.py +++ b/tests/unit/memory/tool/test_task_memory.py @@ -1,6 +1,6 @@ import pytest -from griptape.artifacts import BlobArtifact, CsvRowArtifact, ErrorArtifact, InfoArtifact, ListArtifact, TextArtifact +from griptape.artifacts import BlobArtifact, ErrorArtifact, InfoArtifact, ListArtifact, TextArtifact from griptape.memory import TaskMemory from griptape.memory.task.storage import BlobArtifactStorage, TextArtifactStorage from griptape.structures import Agent @@ -10,10 +10,6 @@ class TestTaskMemory: - @pytest.fixture(autouse=True) - def _mock_griptape(self, mocker): - mocker.patch("griptape.engines.CsvExtractionEngine.extract", return_value=[CsvRowArtifact({"foo": "bar"})]) - @pytest.fixture() def memory(self): return defaults.text_task_memory("MyMemory") diff --git a/tests/unit/mixins/test_image_artifact_file_output_mixin.py b/tests/unit/mixins/test_image_artifact_file_output_mixin.py index cf124da39..7e2926e09 100644 --- a/tests/unit/mixins/test_image_artifact_file_output_mixin.py +++ b/tests/unit/mixins/test_image_artifact_file_output_mixin.py @@ -4,12 +4,12 @@ import pytest from griptape.artifacts import ImageArtifact -from griptape.mixins.media_artifact_file_output_mixin import BlobArtifactFileOutputMixin +from griptape.mixins.artifact_file_output_mixin import ArtifactFileOutputMixin -class TestMediaArtifactFileOutputMixin: +class TestArtifactFileOutputMixin: def test_no_output(self): - class Test(BlobArtifactFileOutputMixin): + class Test(ArtifactFileOutputMixin): pass assert Test().output_file is None @@ -18,7 +18,7 @@ class Test(BlobArtifactFileOutputMixin): def test_output_file(self): artifact = ImageArtifact(name="test.png", value=b"test", height=1, width=1, format="png") - class Test(BlobArtifactFileOutputMixin): + class Test(ArtifactFileOutputMixin): def run(self) -> None: self._write_to_file(artifact) @@ -33,7 +33,7 @@ def run(self) -> None: def test_output_dir(self): artifact = ImageArtifact(name="test.png", value=b"test", height=1, width=1, format="png") - class Test(BlobArtifactFileOutputMixin): + class Test(ArtifactFileOutputMixin): def run(self) -> None: self._write_to_file(artifact) @@ -46,7 +46,7 @@ def run(self) -> None: assert os.path.exists(os.path.join(outdir, artifact.name)) def test_output_file_and_dir(self): - class Test(BlobArtifactFileOutputMixin): + class Test(ArtifactFileOutputMixin): pass outfile = "test.txt" diff --git a/tests/unit/tasks/test_extraction_task.py b/tests/unit/tasks/test_extraction_task.py index 2d7ab442c..06d444f9b 100644 --- a/tests/unit/tasks/test_extraction_task.py +++ b/tests/unit/tasks/test_extraction_task.py @@ -18,4 +18,4 @@ def test_run(self, task): result = task.run() assert len(result.value) == 1 - assert result.value[0].value == {"test1": "mock output"} + assert result.value[0].value == "test1: mock output" diff --git a/tests/unit/tasks/test_prompt_task.py b/tests/unit/tasks/test_prompt_task.py index cfe853226..2d4456ce2 100644 --- a/tests/unit/tasks/test_prompt_task.py +++ b/tests/unit/tasks/test_prompt_task.py @@ -114,9 +114,9 @@ def test_input(self): assert task.input.value[1].width == 100 # default case - task = PromptTask({"default": "test"}) + task = PromptTask({"default": "test"}) # pyright: ignore[reportArgumentType] - assert task.input.value == str({"default": "test"}) + assert task.input.value == "default: test" def test_prompt_stack(self): task = PromptTask("{{ test }}", context={"test": "test value"}, rules=[Rule("test rule")]) diff --git a/tests/unit/tools/test_extraction_tool.py b/tests/unit/tools/test_extraction_tool.py index 1219da373..3f783e8a4 100644 --- a/tests/unit/tools/test_extraction_tool.py +++ b/tests/unit/tools/test_extraction_tool.py @@ -58,10 +58,10 @@ def test_csv_extract_artifacts(self, csv_tool): ) assert len(result.value) == 1 - assert result.value[0].value == {"test1": "mock output"} + assert result.value[0].value == "test1: mock output" def test_csv_extract_content(self, csv_tool): result = csv_tool.extract({"values": {"data": "foo"}}) assert len(result.value) == 1 - assert result.value[0].value == {"test1": "mock output"} + assert result.value[0].value == "test1: mock output" diff --git a/tests/unit/tools/test_file_manager.py b/tests/unit/tools/test_file_manager.py index 569c0a280..469918a02 100644 --- a/tests/unit/tools/test_file_manager.py +++ b/tests/unit/tools/test_file_manager.py @@ -5,7 +5,7 @@ import pytest -from griptape.artifacts import CsvRowArtifact, ListArtifact, TextArtifact +from griptape.artifacts import ListArtifact, TextArtifact from griptape.drivers.file_manager.local_file_manager_driver import LocalFileManagerDriver from griptape.loaders.text_loader import TextLoader from griptape.tools import FileManagerTool @@ -106,29 +106,6 @@ def test_save_memory_artifacts_to_disk_for_multiple_artifacts(self, temp_dir): assert Path(os.path.join(temp_dir, "test", f"{artifacts[1].name}-{file_name}")).read_text() == "baz" assert result.value == "Successfully saved memory artifacts to disk" - def test_save_memory_artifacts_to_disk_for_non_string_artifact(self, temp_dir): - memory = defaults.text_task_memory("Memory1") - artifact = CsvRowArtifact({"foo": "bar"}) - - memory.store_artifact("foobar", artifact) - - file_manager = FileManagerTool( - input_memory=[memory], file_manager_driver=LocalFileManagerDriver(workdir=temp_dir) - ) - result = file_manager.save_memory_artifacts_to_disk( - { - "values": { - "dir_name": "test", - "file_name": "foobar.txt", - "memory_name": memory.name, - "artifact_namespace": "foobar", - } - } - ) - - assert Path(os.path.join(temp_dir, "test", "foobar.txt")).read_text() == "foo\nbar" - assert result.value == "Successfully saved memory artifacts to disk" - def test_save_content_to_file(self, temp_dir): file_manager = FileManagerTool(file_manager_driver=LocalFileManagerDriver(workdir=temp_dir)) result = file_manager.save_content_to_file( diff --git a/tests/unit/tools/test_inpainting_image_generation_tool.py b/tests/unit/tools/test_inpainting_image_generation_tool.py index 45afcbc63..a558921a9 100644 --- a/tests/unit/tools/test_inpainting_image_generation_tool.py +++ b/tests/unit/tools/test_inpainting_image_generation_tool.py @@ -59,8 +59,8 @@ def test_image_inpainting_with_outfile( engine=image_generation_engine, output_file=outfile, image_loader=image_loader ) - image_generator.engine.run.return_value = Mock( # pyright: ignore[reportFunctionMemberAccess] - value=b"image data", format="png", width=512, height=512, model="test model", prompt="test prompt" + image_generator.engine.run.return_value = ImageArtifact( # pyright: ignore[reportFunctionMemberAccess] + value=b"image data", format="png", width=512, height=512 ) image_artifact = image_generator.image_inpainting_from_file( @@ -83,8 +83,8 @@ def test_image_inpainting_from_memory(self, image_generation_engine, image_artif memory.load_artifacts = Mock(return_value=[image_artifact]) image_generator.find_input_memory = Mock(return_value=memory) - image_generator.engine.run.return_value = Mock( # pyright: ignore[reportFunctionMemberAccess] - value=b"image data", format="png", width=512, height=512, model="test model", prompt="test prompt" + image_generator.engine.run.return_value = ImageArtifact( # pyright: ignore[reportFunctionMemberAccess] + value=b"image data", format="png", width=512, height=512 ) image_artifact = image_generator.image_inpainting_from_memory( diff --git a/tests/unit/tools/test_outpainting_image_variation_tool.py b/tests/unit/tools/test_outpainting_image_variation_tool.py index 4fbcbe8d4..e3f0de847 100644 --- a/tests/unit/tools/test_outpainting_image_variation_tool.py +++ b/tests/unit/tools/test_outpainting_image_variation_tool.py @@ -34,8 +34,8 @@ def test_validate_output_configs(self, image_generation_engine) -> None: OutpaintingImageGenerationTool(engine=image_generation_engine, output_dir="test", output_file="test") def test_image_outpainting(self, image_generator, path_from_resource_path) -> None: - image_generator.engine.run.return_value = Mock( - value=b"image data", format="png", width=512, height=512, model="test model", prompt="test prompt" + image_generator.engine.run.return_value = ImageArtifact( + value=b"image data", format="png", width=512, height=512 ) image_artifact = image_generator.image_outpainting_from_file( @@ -59,8 +59,8 @@ def test_image_outpainting_with_outfile( engine=image_generation_engine, output_file=outfile, image_loader=image_loader ) - image_generator.engine.run.return_value = Mock( # pyright: ignore[reportFunctionMemberAccess] - value=b"image data", format="png", width=512, height=512, model="test model", prompt="test prompt" + image_generator.engine.run.return_value = ImageArtifact( # pyright: ignore[reportFunctionMemberAccess] + value=b"image data", format="png", width=512, height=512 ) image_artifact = image_generator.image_outpainting_from_file( diff --git a/tests/unit/tools/test_prompt_image_generation_tool.py b/tests/unit/tools/test_prompt_image_generation_tool.py index a0c5c7037..4252d887e 100644 --- a/tests/unit/tools/test_prompt_image_generation_tool.py +++ b/tests/unit/tools/test_prompt_image_generation_tool.py @@ -5,6 +5,7 @@ import pytest +from griptape.artifacts.image_artifact import ImageArtifact from griptape.tools import PromptImageGenerationTool @@ -36,8 +37,8 @@ def test_generate_image_with_outfile(self, image_generation_engine) -> None: outfile = f"{tempfile.gettempdir()}/{str(uuid.uuid4())}.png" image_generator = PromptImageGenerationTool(engine=image_generation_engine, output_file=outfile) - image_generator.engine.run.return_value = Mock( # pyright: ignore[reportFunctionMemberAccess] - value=b"image data", format="png", width=512, height=512, model="test model", prompt="test prompt" + image_generator.engine.run.return_value = ImageArtifact( # pyright: ignore[reportFunctionMemberAccess] + value=b"image data", format="png", width=512, height=512 ) image_artifact = image_generator.generate_image( diff --git a/tests/unit/tools/test_sql_tool.py b/tests/unit/tools/test_sql_tool.py index 2ef50ff54..ae81965b0 100644 --- a/tests/unit/tools/test_sql_tool.py +++ b/tests/unit/tools/test_sql_tool.py @@ -26,7 +26,8 @@ def test_execute_query(self, driver): result = client.execute_query({"values": {"sql_query": "SELECT * from test_table;"}}) assert len(result.value) == 1 - assert result.value[0].value == {"id": 1, "name": "Alice", "age": 25, "city": "New York"} + assert result.value[0].value == "id: 1\nname: Alice\nage: 25\ncity: New York" + assert result.value[0].meta["row"] == 0 def test_execute_query_description(self, driver): client = SqlTool( diff --git a/tests/unit/tools/test_text_to_speech_tool.py b/tests/unit/tools/test_text_to_speech_tool.py index 8821d48fc..6f2c43bd3 100644 --- a/tests/unit/tools/test_text_to_speech_tool.py +++ b/tests/unit/tools/test_text_to_speech_tool.py @@ -5,6 +5,7 @@ import pytest +from griptape.artifacts.audio_artifact import AudioArtifact from griptape.tools.text_to_speech.tool import TextToSpeechTool @@ -32,7 +33,7 @@ def test_text_to_speech_with_outfile(self, text_to_speech_engine) -> None: outfile = f"{tempfile.gettempdir()}/{str(uuid.uuid4())}.mp3" text_to_speech_client = TextToSpeechTool(engine=text_to_speech_engine, output_file=outfile) - text_to_speech_client.engine.run.return_value = Mock(value=b"audio data", format="mp3") # pyright: ignore[reportFunctionMemberAccess] + text_to_speech_client.engine.run.return_value = AudioArtifact(value=b"audio data", format="mp3") # pyright: ignore[reportFunctionMemberAccess] audio_artifact = text_to_speech_client.text_to_speech(params={"values": {"text": "say this!"}}) diff --git a/tests/unit/tools/test_variation_image_generation_tool.py b/tests/unit/tools/test_variation_image_generation_tool.py index c4528a044..5fd3513c1 100644 --- a/tests/unit/tools/test_variation_image_generation_tool.py +++ b/tests/unit/tools/test_variation_image_generation_tool.py @@ -58,8 +58,8 @@ def test_image_variation_with_outfile(self, image_generation_engine, image_loade engine=image_generation_engine, output_file=outfile, image_loader=image_loader ) - image_generator.engine.run.return_value = Mock( # pyright: ignore[reportFunctionMemberAccess] - value=b"image data", format="png", width=512, height=512, model="test model", prompt="test prompt" + image_generator.engine.run.return_value = ImageArtifact( # pyright: ignore[reportFunctionMemberAccess] + value=b"image data", format="png", width=512, height=512 ) image_artifact = image_generator.image_variation_from_file( From 8ad5b368457d274bf48d329139c1dc12ed173cb7 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 5 Sep 2024 13:26:13 -0700 Subject: [PATCH 02/52] Artifacts are just data --- CHANGELOG.md | 5 +- MIGRATION.md | 19 ------- docs/griptape-framework/data/artifacts.md | 50 +++++++++---------- griptape/artifacts/__init__.py | 14 +++--- griptape/artifacts/base_system_artifact.py | 10 ---- griptape/artifacts/blob_artifact.py | 4 ++ griptape/artifacts/boolean_artifact.py | 17 +++++-- griptape/artifacts/csv_row_artifact.py | 34 ------------- griptape/artifacts/error_artifact.py | 7 ++- griptape/artifacts/image_artifact.py | 4 +- griptape/artifacts/info_artifact.py | 7 ++- griptape/artifacts/list_artifact.py | 4 +- griptape/artifacts/text_artifact.py | 15 +++--- .../contents/text_message_content.py | 4 +- griptape/common/prompt_stack/prompt_stack.py | 6 +-- tests/unit/artifacts/test_csv_row_artifact.py | 30 ----------- 16 files changed, 73 insertions(+), 157 deletions(-) delete mode 100644 griptape/artifacts/base_system_artifact.py delete mode 100644 griptape/artifacts/csv_row_artifact.py delete mode 100644 tests/unit/artifacts/test_csv_row_artifact.py diff --git a/CHANGELOG.md b/CHANGELOG.md index b57bce948..77f052e5a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - **BREAKING**: Removed `MediaArtifact`, use `ImageArtifact` or `AudioArtifact` instead. -- **BREAKING**: Removed `BooleanArtifact`, use `JsonArtifact` instead. - **BREAKING**: Removed `CsvRowArtifact`. - **BREAKING**: `CsvLoader`, `DataframeLoader`, and `SqlLoader` now return `list[TextArtifact]`. - **BREAKING**: Removed `ImageArtifact.media_type`. @@ -20,8 +19,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: Moved `ImageArtifact.prompt` and `ImageArtifact.model` into `ImageArtifact.meta`. - **BREAKING**: `ImageArtifact.to_text()` now returns the base64 encoded image. - Updated `JsonArtifact` value converter to properly handle more types. -- `AudioArtifact` now subclasses `BaseArtifact` instead of `MediaArtifact`. -- `ImageArtifact` now subclasses `BaseArtifact` instead of `MediaArtifact`. +- `AudioArtifact` now subclasses `BlobArtifact` instead of `MediaArtifact`. +- `ImageArtifact` now subclasses `BlobArtifact` instead of `MediaArtifact`. - Passing a dictionary as the value to `TextArtifact` will convert to a key-value formatted string. - Removed `__add__` method from `BaseArtifact`, implemented it where necessary. diff --git a/MIGRATION.md b/MIGRATION.md index 39932d4fa..ddf9b9c00 100644 --- a/MIGRATION.md +++ b/MIGRATION.md @@ -36,25 +36,6 @@ audio_artifact = AudioArtifact( ) ``` -### Removed `BooleanArtifact` - -`BooleanArtifact` has been removed. Use `JsonArtifact` instead. - -#### Before - -```python -boolean_artifact = BooleanArtifact("true") - -print(boolean_artifact.value) # Value is True -``` - -#### After -```python -json_artifact = JsonArtifact("true") - -print(json_artifact.value) # Value is True -``` - ### Removed `CsvRowArtifact` `CsvRowArtifact` has been removed. Use `TextArtifact` instead. diff --git a/docs/griptape-framework/data/artifacts.md b/docs/griptape-framework/data/artifacts.md index b90636731..2edd1ebec 100644 --- a/docs/griptape-framework/data/artifacts.md +++ b/docs/griptape-framework/data/artifacts.md @@ -5,54 +5,50 @@ search: ## Overview - -**[Artifacts](../../reference/griptape/artifacts/base_artifact.md)** are used to store data that can be provided as input to or received as output from a Language Learning Model (LLM). +**[Artifacts](../../reference/griptape/artifacts/base_artifact.md)** are the core data structure in Griptape. They are used to encapsulate data and enhance it with metadata. ## Text -[TextArtifact](../../reference/griptape/artifacts/text_artifact.md)s store textual data. They can be used to count tokens using the [token_count()](../../reference/griptape/artifacts/text_artifact.md#griptape.artifacts.text_artifact.TextArtifact.token_count) method with a tokenizer, generate a text embedding through the [generate_embedding()](../../reference/griptape/artifacts/text_artifact.md#griptape.artifacts.text_artifact.TextArtifact.generate_embedding) method, and access the embedding with the [embedding](../../reference/griptape/artifacts/text_artifact.md#griptape.artifacts.text_artifact.TextArtifact.embedding) property. +[TextArtifact](../../reference/griptape/artifacts/text_artifact.md)s store textual data. They offer methods such as [token_count()](../../reference/griptape/artifacts/text_artifact.md#griptape.artifacts.text_artifact.TextArtifact.token_count) for counting tokens with a tokenizer, and [generate_embedding()](../../reference/griptape/artifacts/text_artifact.md#griptape.artifacts.text_artifact.TextArtifact.generate_embedding) for creating text embeddings. You can also access the embedding via the [embedding](../../reference/griptape/artifacts/text_artifact.md#griptape.artifacts.text_artifact.TextArtifact.embedding) property. -[TaskMemory](../../reference/griptape/memory/task/task_memory.md) automatically stores `TextArtifacts` returned by tool activities and provides their IDs back to the LLM. +When `TextArtifact`s are returned from Tools, they will be stored in [Task Memory](../../griptape-framework/structures/task-memory.md) if the Tool has set `off_prompt=True`. -## Image +## Blob -[ImageArtifact](../../reference/griptape/artifacts/image_artifact.md)s store image data. They include binary image data and metadata such as MIME type, dimensions, and prompt and model information for images returned by [image generation drivers](../drivers/image-generation-drivers.md). They inherit functionality from [BlobArtifacts](#blob). +[BlobArtifact](../../reference/griptape/artifacts/blob_artifact.md)s store binary large objects (blobs). -## Audio +When `BlobArtifact`s are returned from Tools, they will be stored in [Task Memory](../../griptape-framework/structures/task-memory.md) if the Tool has set `off_prompt=True`. -[AudioArtifact](../../reference/griptape/artifacts/audio_artifact.md)s store audio content, including binary audio data and metadata such as format, duration, and prompt and model information for audio returned by generative models. They inherit from [BlobArtifacts](#blob). +### Image -## Action +[ImageArtifact](../../reference/griptape/artifacts/image_artifact.md)s store image data. This includes binary image data along with metadata such as MIME type and dimensions. They are a subclass of [BlobArtifacts](#blob). -[ActionArtifact](../../reference/griptape/artifacts/action_artifact.md)s represent actions taken by the LLM. Currently, the only supported action is [ToolAction](../../reference/griptape/common/actions/tool_action.md), which is used to execute a [Tool](../../griptape-framework/tools/index.md). +### Audio -## JSON +[AudioArtifact](../../reference/griptape/artifacts/audio_artifact.md)s store audio content. This includes binary audio data and metadata such as format, and duration. They are a subclass of [BlobArtifacts](#blob). -[JsonArtifact](../../reference/griptape/artifacts/json_artifact.md)s store JSON-serializable data. Any data assigned to the `value` property is converted using `json.dumps(json.loads(value))`. - -## Generic +## List -[GenericArtifact](../../reference/griptape/artifacts/generic_artifact.md)s act as an escape hatch for passing any type of data that does not fit into any other artifact type. While generally not recommended, they are suitable for specific scenarios. For example, see [talking to a video](../../examples/talk-to-a-video.md), which demonstrates using a `GenericArtifact` to pass a Gemini-specific video file. +[ListArtifact](../../reference/griptape/artifacts/list_artifact.md)s store lists of Artifacts. -## System Artifacts +When `ListArtifact`s are returned from Tools, their elements will be stored in [Task Memory](../../griptape-framework/structures/task-memory.md) if the element is either a `TextArtifact` or a `BlobArtifact` and the Tool has set `off_prompt=True`. -These Artifacts don't map to an LLM modality. They must be transformed in some way before they can be used as LLM input. +## Info -### Blob +[InfoArtifact](../../reference/griptape/artifacts/info_artifact.md)s store small pieces of textual information. These are useful for conveying messages about the execution or results of an operation, such as "No results found" or "Operation completed successfully." -[BlobArtifact](../../reference/griptape/artifacts/blob_artifact.md)s store binary large objects (blobs) and are used to pass unstructured data back to the LLM via [InfoArtifact](#info). - -`TaskMemory` automatically stores `BlobArtifacts` returned by tool activities, allowing them to be reused by other tools. +## JSON -### Info +[JsonArtifact](../../reference/griptape/artifacts/json_artifact.md)s store JSON-serializable data. Any data assigned to the `value` property is processed using `json.dumps(json.loads(value))`. -[InfoArtifact](../../reference/griptape/artifacts/info_artifact.md)s store short notifications that are passed back to the LLM without being stored in Task Memory. +## Error -### Error +[ErrorArtifact](../../reference/griptape/artifacts/error_artifact.md)s store exception information, providing a structured way to convey errors. -[ErrorArtifact](../../reference/griptape/artifacts/error_artifact.md)s store errors that are passed back to the LLM without being stored in Task Memory. +## Action -### List +[ActionArtifact](../../reference/griptape/artifacts/action_artifact.md)s represent actions taken by an LLM. Currently, the only supported action type is [ToolAction](../../reference/griptape/common/actions/tool_action.md), which is used to execute a [Tool](../../griptape-framework/tools/index.md). -[ListArtifact](../../reference/griptape/artifacts/list_artifact.md)s store lists of Artifacts that can be passed to the LLM. +## Generic +[GenericArtifact](../../reference/griptape/artifacts/generic_artifact.md)s provide a flexible way to pass data that does not fit into any other artifact category. While not generally recommended, they can be useful for specific use cases. For instance, see [talking to a video](../../examples/talk-to-a-video.md), which demonstrates using a `GenericArtifact` to pass a Gemini-specific video file. diff --git a/griptape/artifacts/__init__.py b/griptape/artifacts/__init__.py index a647d0578..0e58a8a76 100644 --- a/griptape/artifacts/__init__.py +++ b/griptape/artifacts/__init__.py @@ -1,27 +1,25 @@ from .base_artifact import BaseArtifact -from .base_system_artifact import BaseSystemArtifact - +from .error_artifact import ErrorArtifact +from .info_artifact import InfoArtifact from .text_artifact import TextArtifact +from .json_artifact import JsonArtifact from .blob_artifact import BlobArtifact +from .boolean_artifact import BooleanArtifact +from .list_artifact import ListArtifact from .image_artifact import ImageArtifact from .audio_artifact import AudioArtifact -from .json_artifact import JsonArtifact from .action_artifact import ActionArtifact from .generic_artifact import GenericArtifact -from .error_artifact import ErrorArtifact -from .info_artifact import InfoArtifact -from .list_artifact import ListArtifact - __all__ = [ "BaseArtifact", - "BaseSystemArtifact", "ErrorArtifact", "InfoArtifact", "TextArtifact", "JsonArtifact", "BlobArtifact", + "BooleanArtifact", "ListArtifact", "ImageArtifact", "AudioArtifact", diff --git a/griptape/artifacts/base_system_artifact.py b/griptape/artifacts/base_system_artifact.py deleted file mode 100644 index e5e7bfab6..000000000 --- a/griptape/artifacts/base_system_artifact.py +++ /dev/null @@ -1,10 +0,0 @@ -from abc import ABC - -from griptape.artifacts import BaseArtifact - - -class BaseSystemArtifact(BaseArtifact, ABC): - """Serves as the base class for all Artifacts specific to Griptape.""" - - def to_text(self) -> str: - return self.value diff --git a/griptape/artifacts/blob_artifact.py b/griptape/artifacts/blob_artifact.py index 072b6cb5b..9fe729a63 100644 --- a/griptape/artifacts/blob_artifact.py +++ b/griptape/artifacts/blob_artifact.py @@ -28,6 +28,10 @@ class BlobArtifact(BaseArtifact): encoding: str = field(default="utf-8", kw_only=True) encoding_error_handler: str = field(default="strict", kw_only=True) + @property + def mime_type(self) -> str: + return "application/octet-stream" + def to_bytes(self) -> bytes: return self.value diff --git a/griptape/artifacts/boolean_artifact.py b/griptape/artifacts/boolean_artifact.py index 5bcdfac9b..dd8bb1504 100644 --- a/griptape/artifacts/boolean_artifact.py +++ b/griptape/artifacts/boolean_artifact.py @@ -9,17 +9,23 @@ @define class BooleanArtifact(BaseArtifact): + """Stores a boolean value. + + Attributes: + value: The boolean value. + """ + value: bool = field(converter=bool, metadata={"serializable": True}) @classmethod - def parse_bool(cls, value: Union[str, bool]) -> BooleanArtifact: # noqa: FBT001 - """Convert a string literal or bool to a BooleanArtifact. The string must be either "true" or "false" with any casing.""" + def parse_bool(cls, value: Union[str, bool]) -> BooleanArtifact: + """Convert a string literal or bool to a BooleanArtifact. The string must be either "true" or "false".""" if value is not None: if isinstance(value, str): if value.lower() == "true": - return BooleanArtifact(True) # noqa: FBT003 + return BooleanArtifact(value=True) elif value.lower() == "false": - return BooleanArtifact(False) # noqa: FBT003 + return BooleanArtifact(value=False) elif isinstance(value, bool): return BooleanArtifact(value) raise ValueError(f"Cannot convert '{value}' to BooleanArtifact") @@ -29,3 +35,6 @@ def __add__(self, other: BaseArtifact) -> BooleanArtifact: def __eq__(self, value: object) -> bool: return self.value is value + + def to_text(self) -> str: + return str(self.value).lower() diff --git a/griptape/artifacts/csv_row_artifact.py b/griptape/artifacts/csv_row_artifact.py deleted file mode 100644 index 00f1047fc..000000000 --- a/griptape/artifacts/csv_row_artifact.py +++ /dev/null @@ -1,34 +0,0 @@ -from __future__ import annotations - -import csv -import io - -from attrs import define, field - -from griptape.artifacts import BaseArtifact, TextArtifact - - -@define -class CsvRowArtifact(TextArtifact): - value: dict[str, str] = field(converter=BaseArtifact.value_to_dict, metadata={"serializable": True}) - delimiter: str = field(default=",", kw_only=True, metadata={"serializable": True}) - - def __add__(self, other: BaseArtifact) -> CsvRowArtifact: - return CsvRowArtifact(self.value | other.value) - - def __bool__(self) -> bool: - return len(self) > 0 - - def to_text(self) -> str: - with io.StringIO() as csvfile: - writer = csv.DictWriter( - csvfile, - fieldnames=self.value.keys(), - quoting=csv.QUOTE_MINIMAL, - delimiter=self.delimiter, - ) - - writer.writeheader() - writer.writerow(self.value) - - return csvfile.getvalue().strip() diff --git a/griptape/artifacts/error_artifact.py b/griptape/artifacts/error_artifact.py index 086cdb893..27e6a37ab 100644 --- a/griptape/artifacts/error_artifact.py +++ b/griptape/artifacts/error_artifact.py @@ -4,11 +4,11 @@ from attrs import define, field -from griptape.artifacts import BaseSystemArtifact +from griptape.artifacts import BaseArtifact @define -class ErrorArtifact(BaseSystemArtifact): +class ErrorArtifact(BaseArtifact): """Represents an error that may want to be conveyed to the LLM. Attributes: @@ -18,3 +18,6 @@ class ErrorArtifact(BaseSystemArtifact): value: str = field(converter=str, metadata={"serializable": True}) exception: Optional[Exception] = field(default=None, kw_only=True, metadata={"serializable": False}) + + def to_text(self) -> str: + return self.value diff --git a/griptape/artifacts/image_artifact.py b/griptape/artifacts/image_artifact.py index 28de3b1c9..5a90ae403 100644 --- a/griptape/artifacts/image_artifact.py +++ b/griptape/artifacts/image_artifact.py @@ -4,11 +4,11 @@ from attrs import define, field -from griptape.artifacts import BaseArtifact +from griptape.artifacts import BlobArtifact @define -class ImageArtifact(BaseArtifact): +class ImageArtifact(BlobArtifact): """Stores image data. Attributes: diff --git a/griptape/artifacts/info_artifact.py b/griptape/artifacts/info_artifact.py index 66fe94fc9..3391554e9 100644 --- a/griptape/artifacts/info_artifact.py +++ b/griptape/artifacts/info_artifact.py @@ -2,11 +2,11 @@ from attrs import define, field -from griptape.artifacts import BaseSystemArtifact +from griptape.artifacts import BaseArtifact @define -class InfoArtifact(BaseSystemArtifact): +class InfoArtifact(BaseArtifact): """Represents helpful info that can be conveyed to the LLM. For example, "No results found" or "Please try again.". @@ -16,3 +16,6 @@ class InfoArtifact(BaseSystemArtifact): """ value: str = field(converter=str, metadata={"serializable": True}) + + def to_text(self) -> str: + return self.value diff --git a/griptape/artifacts/list_artifact.py b/griptape/artifacts/list_artifact.py index 00df83789..ac11d95cc 100644 --- a/griptape/artifacts/list_artifact.py +++ b/griptape/artifacts/list_artifact.py @@ -4,14 +4,14 @@ from attrs import Attribute, define, field -from griptape.artifacts import BaseArtifact, BaseSystemArtifact +from griptape.artifacts import BaseArtifact if TYPE_CHECKING: from collections.abc import Sequence @define -class ListArtifact(BaseSystemArtifact): +class ListArtifact(BaseArtifact): value: Sequence[BaseArtifact] = field(factory=list, metadata={"serializable": True}) item_separator: str = field(default="\n\n", kw_only=True, metadata={"serializable": True}) validate_uniform_types: bool = field(default=False, kw_only=True, metadata={"serializable": True}) diff --git a/griptape/artifacts/text_artifact.py b/griptape/artifacts/text_artifact.py index 3751d2c89..31d20e532 100644 --- a/griptape/artifacts/text_artifact.py +++ b/griptape/artifacts/text_artifact.py @@ -23,7 +23,7 @@ class TextArtifact(BaseArtifact): value: str = field(converter=value_to_str, metadata={"serializable": True}) encoding: str = field(default="utf-8", kw_only=True) encoding_error_handler: str = field(default="strict", kw_only=True) - _embedding: list[float] = field(factory=list, kw_only=True) + embedding: Optional[list[float]] = field(default=None, kw_only=True) def __add__(self, other: BaseArtifact) -> TextArtifact: return TextArtifact(self.value + other.value) @@ -31,19 +31,18 @@ def __add__(self, other: BaseArtifact) -> TextArtifact: def __bool__(self) -> bool: return bool(self.value.strip()) - @property - def embedding(self) -> Optional[list[float]]: - return None if len(self._embedding) == 0 else self._embedding - def to_text(self) -> str: return self.value def to_bytes(self) -> bytes: return str(self.value).encode(encoding=self.encoding, errors=self.encoding_error_handler) - def generate_embedding(self, driver: BaseEmbeddingDriver) -> Optional[list[float]]: - self._embedding.clear() - self._embedding.extend(driver.embed_string(str(self.value))) + def generate_embedding(self, driver: BaseEmbeddingDriver) -> list[float]: + if self.embedding is None: + self.embedding = [] + + self.embedding.clear() + self.embedding.extend(driver.embed_string(str(self.value))) return self.embedding diff --git a/griptape/common/prompt_stack/contents/text_message_content.py b/griptape/common/prompt_stack/contents/text_message_content.py index c862564f3..39e678f28 100644 --- a/griptape/common/prompt_stack/contents/text_message_content.py +++ b/griptape/common/prompt_stack/contents/text_message_content.py @@ -4,7 +4,7 @@ from attrs import define, field -from griptape.artifacts import TextArtifact +from griptape.artifacts import BaseArtifact, TextArtifact from griptape.common import BaseDeltaMessageContent, BaseMessageContent, TextDeltaMessageContent if TYPE_CHECKING: @@ -13,7 +13,7 @@ @define class TextMessageContent(BaseMessageContent): - artifact: TextArtifact = field(metadata={"serializable": True}) + artifact: BaseArtifact = field(metadata={"serializable": True}) @classmethod def from_deltas(cls, deltas: Sequence[BaseDeltaMessageContent]) -> TextMessageContent: diff --git a/griptape/common/prompt_stack/prompt_stack.py b/griptape/common/prompt_stack/prompt_stack.py index 6d8dfde75..77ce4ba9b 100644 --- a/griptape/common/prompt_stack/prompt_stack.py +++ b/griptape/common/prompt_stack/prompt_stack.py @@ -7,7 +7,6 @@ from griptape.artifacts import ( ActionArtifact, BaseArtifact, - ErrorArtifact, GenericArtifact, ImageArtifact, ListArtifact, @@ -70,8 +69,6 @@ def __to_message_content(self, artifact: str | BaseArtifact) -> list[BaseMessage return [ImageMessageContent(artifact)] elif isinstance(artifact, GenericArtifact): return [GenericMessageContent(artifact)] - elif isinstance(artifact, ErrorArtifact): - return [TextMessageContent(TextArtifact(artifact.to_text()))] elif isinstance(artifact, ActionArtifact): action = artifact.value output = action.output @@ -81,6 +78,7 @@ def __to_message_content(self, artifact: str | BaseArtifact) -> list[BaseMessage return [ActionResultMessageContent(output, action=action)] elif isinstance(artifact, ListArtifact): processed_contents = [self.__to_message_content(artifact) for artifact in artifact.value] + return [sub_content for processed_content in processed_contents for sub_content in processed_content] else: - raise ValueError(f"Unsupported artifact type: {type(artifact)}") + return [TextMessageContent(TextArtifact(artifact.to_text()))] diff --git a/tests/unit/artifacts/test_csv_row_artifact.py b/tests/unit/artifacts/test_csv_row_artifact.py deleted file mode 100644 index fe0b8cd64..000000000 --- a/tests/unit/artifacts/test_csv_row_artifact.py +++ /dev/null @@ -1,30 +0,0 @@ -from griptape.artifacts import CsvRowArtifact - - -class TestCsvRowArtifact: - def test_value_type_conversion(self): - assert CsvRowArtifact({"foo": "bar"}).value == {"foo": "bar"} - assert CsvRowArtifact({"foo": {"bar": "baz"}}).value == {"foo": {"bar": "baz"}} - assert CsvRowArtifact('{"foo": "bar"}').value == {"foo": "bar"} - - def test___add__(self): - assert (CsvRowArtifact({"test1": "foo"}) + CsvRowArtifact({"test2": "bar"})).value == { - "test1": "foo", - "test2": "bar", - } - - def test_to_text(self): - assert CsvRowArtifact({"test1": "foo|bar", "test2": 1}, delimiter="|").to_text() == 'test1|test2\r\n"foo|bar"|1' - - def test_to_dict(self): - assert CsvRowArtifact({"test1": "foo"}).to_dict()["value"] == {"test1": "foo"} - - def test_name(self): - artifact = CsvRowArtifact({}) - - assert artifact.name == artifact.id - assert CsvRowArtifact({}, name="bar").name == "bar" - - def test___bool__(self): - assert not bool(CsvRowArtifact({})) - assert bool(CsvRowArtifact({"foo": "bar"})) From 58a8233c51f803e3a17f241c754af3eeb7b069e5 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 21 Aug 2024 09:23:51 -0700 Subject: [PATCH 03/52] Improve ListArtifact Make type covariant --- CHANGELOG.md | 3 +++ griptape/artifacts/list_artifact.py | 19 ++++++++++++------- griptape/schemas/base_schema.py | 6 +++++- griptape/tasks/tool_task.py | 6 +++++- tests/unit/artifacts/test_list_artifact.py | 6 ++++++ 5 files changed, 31 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 77f052e5a..13d6d8f14 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `ImageArtifact` now subclasses `BlobArtifact` instead of `MediaArtifact`. - Passing a dictionary as the value to `TextArtifact` will convert to a key-value formatted string. - Removed `__add__` method from `BaseArtifact`, implemented it where necessary. +- Generic type support to `ListArtifact`. +- Iteration support to `ListArtifact`. ## [0.31.0] - 2024-09-03 @@ -45,6 +47,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: Removed the `__all__` declaration from the `griptape.mixins` module. - `Defaults.drivers_config.conversation_memory_driver` now defaults to `LocalConversationMemoryDriver` instead of `None`. - `CsvRowArtifact.to_text()` now includes the header. +- `BaseConversationMemory.prompt_driver` for use with autopruning. ### Fixed - Parsing streaming response with some OpenAI compatible services. diff --git a/griptape/artifacts/list_artifact.py b/griptape/artifacts/list_artifact.py index ac11d95cc..0e6f81ca5 100644 --- a/griptape/artifacts/list_artifact.py +++ b/griptape/artifacts/list_artifact.py @@ -1,32 +1,37 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Generic, Optional, TypeVar from attrs import Attribute, define, field from griptape.artifacts import BaseArtifact if TYPE_CHECKING: - from collections.abc import Sequence + from collections.abc import Iterator, Sequence + +T = TypeVar("T", bound=BaseArtifact, covariant=True) @define -class ListArtifact(BaseArtifact): - value: Sequence[BaseArtifact] = field(factory=list, metadata={"serializable": True}) +class ListArtifact(BaseArtifact, Generic[T]): + value: Sequence[T] = field(factory=list, metadata={"serializable": True}) item_separator: str = field(default="\n\n", kw_only=True, metadata={"serializable": True}) validate_uniform_types: bool = field(default=False, kw_only=True, metadata={"serializable": True}) - def __getitem__(self, key: int) -> BaseArtifact: + def __getitem__(self, key: int) -> T: return self.value[key] def __bool__(self) -> bool: return len(self) > 0 - def __add__(self, other: BaseArtifact) -> BaseArtifact: + def __add__(self, other: BaseArtifact) -> ListArtifact[T]: return ListArtifact(self.value + other.value) + def __iter__(self) -> Iterator[T]: + return iter(self.value) + @value.validator # pyright: ignore[reportAttributeAccessIssue] - def validate_value(self, _: Attribute, value: list[BaseArtifact]) -> None: + def validate_value(self, _: Attribute, value: list[T]) -> None: if self.validate_uniform_types and len(value) > 0: first_type = type(value[0]) diff --git a/griptape/schemas/base_schema.py b/griptape/schemas/base_schema.py index 9290c6098..dde3ae49a 100644 --- a/griptape/schemas/base_schema.py +++ b/griptape/schemas/base_schema.py @@ -2,7 +2,7 @@ from abc import ABC from collections.abc import Sequence -from typing import Any, Literal, Union, _SpecialForm, get_args, get_origin +from typing import Any, Literal, TypeVar, Union, _SpecialForm, get_args, get_origin import attrs from marshmallow import INCLUDE, Schema, fields @@ -56,6 +56,10 @@ def _get_field_for_type(cls, field_type: type) -> fields.Field | fields.Nested: field_class, args, optional = cls._get_field_type_info(field_type) + # Resolve TypeVars to their bound type + if isinstance(field_class, TypeVar): + field_class = field_class.__bound__ + if attrs.has(field_class): if ABC in field_class.__bases__: return fields.Nested(PolymorphicSchema(inner_class=field_class), allow_none=optional) diff --git a/griptape/tasks/tool_task.py b/griptape/tasks/tool_task.py index 2dcb796d8..6ffc24871 100644 --- a/griptape/tasks/tool_task.py +++ b/griptape/tasks/tool_task.py @@ -84,7 +84,11 @@ def run(self) -> BaseArtifact: subtask.after_run() if isinstance(subtask.output, ListArtifact): - self.output = subtask.output[0] + first_artifact = subtask.output[0] + if isinstance(first_artifact, BaseArtifact): + self.output = first_artifact + else: + self.output = ErrorArtifact(f"Output is not an Artifact: {type(subtask.output[0])}") else: self.output = InfoArtifact("No tool output") except Exception as e: diff --git a/tests/unit/artifacts/test_list_artifact.py b/tests/unit/artifacts/test_list_artifact.py index bf7004ad0..c56682d2c 100644 --- a/tests/unit/artifacts/test_list_artifact.py +++ b/tests/unit/artifacts/test_list_artifact.py @@ -24,6 +24,12 @@ def test___add__(self): assert artifact.value[0].value == "foo" assert artifact.value[1].value == "bar" + def test___iter__(self): + assert [a.value for a in ListArtifact([TextArtifact("foo"), TextArtifact("bar")])] == ["foo", "bar"] + + def test_type_var(self): + assert ListArtifact[TextArtifact]([TextArtifact("foo")]).value[0].value == "foo" + def test_validate_value(self): with pytest.raises(ValueError): ListArtifact([TextArtifact("foo"), BlobArtifact(b"bar")], validate_uniform_types=True) From ccb37f2d1d4b2a8ea62998cbc48aafd629b167d1 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 6 Sep 2024 10:02:52 -0700 Subject: [PATCH 04/52] Bring back CsvRowArtifact --- .ignore | 0 CHANGELOG.md | 4 +-- MIGRATION.md | 32 +++++-------------- griptape/artifacts/__init__.py | 2 ++ griptape/artifacts/csv_row_artifact.py | 28 ++++++++++++++++ griptape/artifacts/text_artifact.py | 11 ++----- .../extraction/csv_extraction_engine.py | 11 ++++--- griptape/loaders/csv_loader.py | 10 +++--- griptape/loaders/dataframe_loader.py | 10 +++--- griptape/loaders/sql_loader.py | 12 ++++--- tests/unit/artifacts/test_csv_row_artifact.py | 27 ++++++++++++++++ tests/unit/tasks/test_prompt_task.py | 4 +-- tests/unit/tools/test_sql_tool.py | 2 +- 13 files changed, 95 insertions(+), 58 deletions(-) delete mode 100644 .ignore create mode 100644 griptape/artifacts/csv_row_artifact.py create mode 100644 tests/unit/artifacts/test_csv_row_artifact.py diff --git a/.ignore b/.ignore deleted file mode 100644 index e69de29bb..000000000 diff --git a/CHANGELOG.md b/CHANGELOG.md index 13d6d8f14..ccf8084c1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,14 +10,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `BaseArtifact.to_bytes()` method to convert an Artifact to bytes. ### Changed +- **BREAKING**: Changed `CsvRowArtifact.value` from `dict` to `str`. - **BREAKING**: Removed `MediaArtifact`, use `ImageArtifact` or `AudioArtifact` instead. -- **BREAKING**: Removed `CsvRowArtifact`. - **BREAKING**: `CsvLoader`, `DataframeLoader`, and `SqlLoader` now return `list[TextArtifact]`. - **BREAKING**: Removed `ImageArtifact.media_type`. - **BREAKING**: Removed `AudioArtifact.media_type`. - **BREAKING**: Removed `BlobArtifact.dir_name`. - **BREAKING**: Moved `ImageArtifact.prompt` and `ImageArtifact.model` into `ImageArtifact.meta`. - **BREAKING**: `ImageArtifact.to_text()` now returns the base64 encoded image. +- **BREAKING**: `ImageArtifact.to_text()` now returns the base64 encoded image. - Updated `JsonArtifact` value converter to properly handle more types. - `AudioArtifact` now subclasses `BlobArtifact` instead of `MediaArtifact`. - `ImageArtifact` now subclasses `BlobArtifact` instead of `MediaArtifact`. @@ -47,7 +48,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: Removed the `__all__` declaration from the `griptape.mixins` module. - `Defaults.drivers_config.conversation_memory_driver` now defaults to `LocalConversationMemoryDriver` instead of `None`. - `CsvRowArtifact.to_text()` now includes the header. -- `BaseConversationMemory.prompt_driver` for use with autopruning. ### Fixed - Parsing streaming response with some OpenAI compatible services. diff --git a/MIGRATION.md b/MIGRATION.md index ddf9b9c00..65584297b 100644 --- a/MIGRATION.md +++ b/MIGRATION.md @@ -36,39 +36,23 @@ audio_artifact = AudioArtifact( ) ``` -### Removed `CsvRowArtifact` +### Changed `CsvRowArtifact.value` from `dict` to `str`. -`CsvRowArtifact` has been removed. Use `TextArtifact` instead. +`CsvRowArtifact`'s `value` is now a `str` instead of a `dict`. Update any logic that expects `dict` to handle `str` instead. #### Before ```python -CsvRowArtifact({"name": "John", "age": 30}) +artifact = CsvRowArtifact({"name": "John", "age": 30}) +print(artifact.value) # {"name": "John", "age": 30} +print(type(artifact.value)) # ``` #### After ```python -TextArtifact("name: John\nAge: 30") -``` - -### `CsvLoader`, `DataframeLoader`, and `SqlLoader` return types - -`CsvLoader`, `DataframeLoader`, and `SqlLoader` now return a tuple of `list[TextArtifact]` instead of `list[CsvRowArtifact]`. - -#### Before - -```python -results = CsvLoader().load(Path("people.csv").read_text()) - -print(results[0].value) # {"name": "John", "age": 30} -``` - -#### After -```python -results = CsvLoader().load(Path("people.csv").read_text()) - -print(results[0].value) # name: John\nAge: 30 -print(results[0].meta["row"]) # 0 +artifact = CsvRowArtifact({"name": "John", "age": 30}) +print(artifact.value) # name: John\nAge: 30 +print(type(artifact.value)) # ``` ### Moved `ImageArtifact.prompt` and `ImageArtifact.model` to `ImageArtifact.meta` diff --git a/griptape/artifacts/__init__.py b/griptape/artifacts/__init__.py index 0e58a8a76..c0e5f767e 100644 --- a/griptape/artifacts/__init__.py +++ b/griptape/artifacts/__init__.py @@ -5,6 +5,7 @@ from .json_artifact import JsonArtifact from .blob_artifact import BlobArtifact from .boolean_artifact import BooleanArtifact +from .csv_row_artifact import CsvRowArtifact from .list_artifact import ListArtifact from .image_artifact import ImageArtifact from .audio_artifact import AudioArtifact @@ -20,6 +21,7 @@ "JsonArtifact", "BlobArtifact", "BooleanArtifact", + "CsvRowArtifact", "ListArtifact", "ImageArtifact", "AudioArtifact", diff --git a/griptape/artifacts/csv_row_artifact.py b/griptape/artifacts/csv_row_artifact.py new file mode 100644 index 000000000..b1e2e7dfe --- /dev/null +++ b/griptape/artifacts/csv_row_artifact.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from typing import Any + +from attrs import define, field + +from griptape.artifacts import BaseArtifact, TextArtifact + + +def value_to_str(value: Any) -> str: + if isinstance(value, dict): + return "\n".join(f"{key}: {val}" for key, val in value.items()) + else: + return str(value) + + +@define +class CsvRowArtifact(TextArtifact): + """Stores a row of a CSV file. + + Attributes: + value: The row of the CSV file. If a dictionary is passed, the keys and values converted to a string. + """ + + value: str = field(converter=value_to_str, metadata={"serializable": True}) + + def __add__(self, other: BaseArtifact) -> TextArtifact: + return TextArtifact(self.value + "\n" + other.value) diff --git a/griptape/artifacts/text_artifact.py b/griptape/artifacts/text_artifact.py index 31d20e532..7a3b62e3b 100644 --- a/griptape/artifacts/text_artifact.py +++ b/griptape/artifacts/text_artifact.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Optional from attrs import define, field @@ -11,16 +11,9 @@ from griptape.tokenizers import BaseTokenizer -def value_to_str(value: Any) -> str: - if isinstance(value, dict): - return "\n".join(f"{key}: {val}" for key, val in value.items()) - else: - return str(value) - - @define class TextArtifact(BaseArtifact): - value: str = field(converter=value_to_str, metadata={"serializable": True}) + value: str = field(converter=str, metadata={"serializable": True}) encoding: str = field(default="utf-8", kw_only=True) encoding_error_handler: str = field(default="strict", kw_only=True) embedding: Optional[list[float]] = field(default=None, kw_only=True) diff --git a/griptape/engines/extraction/csv_extraction_engine.py b/griptape/engines/extraction/csv_extraction_engine.py index 21d7820ac..9f6085da6 100644 --- a/griptape/engines/extraction/csv_extraction_engine.py +++ b/griptape/engines/extraction/csv_extraction_engine.py @@ -6,7 +6,7 @@ from attrs import Factory, define, field -from griptape.artifacts import ListArtifact, TextArtifact +from griptape.artifacts import CsvRowArtifact, ListArtifact, TextArtifact from griptape.common import Message, PromptStack from griptape.engines import BaseExtractionEngine from griptape.utils import J2 @@ -32,26 +32,27 @@ def extract( self._extract_rec( cast(list[TextArtifact], text.value) if isinstance(text, ListArtifact) else [TextArtifact(text)], [], + rulesets=rulesets, ), item_separator="\n", ) - def text_to_csv_rows(self, text: str, column_names: list[str]) -> list[TextArtifact]: + def text_to_csv_rows(self, text: str, column_names: list[str]) -> list[CsvRowArtifact]: rows = [] with io.StringIO(text) as f: for row in csv.reader(f): - rows.append(TextArtifact(dict(zip(column_names, [x.strip() for x in row])))) + rows.append(CsvRowArtifact(dict(zip(column_names, [x.strip() for x in row])))) return rows def _extract_rec( self, artifacts: list[TextArtifact], - rows: list[TextArtifact], + rows: list[CsvRowArtifact], *, rulesets: Optional[list[Ruleset]] = None, - ) -> list[TextArtifact]: + ) -> list[CsvRowArtifact]: artifacts_text = self.chunk_joiner.join([a.value for a in artifacts]) system_prompt = self.system_template_generator.render( column_names=self.column_names, diff --git a/griptape/loaders/csv_loader.py b/griptape/loaders/csv_loader.py index 20ac237c5..6f6fae5f3 100644 --- a/griptape/loaders/csv_loader.py +++ b/griptape/loaders/csv_loader.py @@ -6,7 +6,7 @@ from attrs import define, field -from griptape.artifacts import TextArtifact +from griptape.artifacts import CsvRowArtifact from griptape.loaders import BaseLoader if TYPE_CHECKING: @@ -19,7 +19,7 @@ class CsvLoader(BaseLoader): delimiter: str = field(default=",", kw_only=True) encoding: str = field(default="utf-8", kw_only=True) - def load(self, source: bytes | str, *args, **kwargs) -> list[TextArtifact]: + def load(self, source: bytes | str, *args, **kwargs) -> list[CsvRowArtifact]: artifacts = [] if isinstance(source, bytes): @@ -28,7 +28,7 @@ def load(self, source: bytes | str, *args, **kwargs) -> list[TextArtifact]: raise ValueError(f"Unsupported source type: {type(source)}") reader = csv.DictReader(StringIO(source), delimiter=self.delimiter) - chunks = [TextArtifact(row, meta={"row": row_num}) for row_num, row in enumerate(reader)] + chunks = [CsvRowArtifact(row, meta={"row_num": row_num}) for row_num, row in enumerate(reader)] if self.embedding_driver: for chunk in chunks: @@ -44,8 +44,8 @@ def load_collection( sources: list[bytes | str], *args, **kwargs, - ) -> dict[str, list[TextArtifact]]: + ) -> dict[str, list[CsvRowArtifact]]: return cast( - dict[str, list[TextArtifact]], + dict[str, list[CsvRowArtifact]], super().load_collection(sources, *args, **kwargs), ) diff --git a/griptape/loaders/dataframe_loader.py b/griptape/loaders/dataframe_loader.py index 5ecb35ecd..0b1ae1448 100644 --- a/griptape/loaders/dataframe_loader.py +++ b/griptape/loaders/dataframe_loader.py @@ -4,7 +4,7 @@ from attrs import define, field -from griptape.artifacts import TextArtifact +from griptape.artifacts import CsvRowArtifact from griptape.loaders import BaseLoader from griptape.utils import import_optional_dependency from griptape.utils.hash import str_to_hash @@ -19,10 +19,10 @@ class DataFrameLoader(BaseLoader): embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) - def load(self, source: DataFrame, *args, **kwargs) -> list[TextArtifact]: + def load(self, source: DataFrame, *args, **kwargs) -> list[CsvRowArtifact]: artifacts = [] - chunks = [TextArtifact(row) for row in source.to_dict(orient="records")] + chunks = [CsvRowArtifact(row) for row in source.to_dict(orient="records")] if self.embedding_driver: for chunk in chunks: @@ -33,8 +33,8 @@ def load(self, source: DataFrame, *args, **kwargs) -> list[TextArtifact]: return artifacts - def load_collection(self, sources: list[DataFrame], *args, **kwargs) -> dict[str, list[TextArtifact]]: - return cast(dict[str, list[TextArtifact]], super().load_collection(sources, *args, **kwargs)) + def load_collection(self, sources: list[DataFrame], *args, **kwargs) -> dict[str, list[CsvRowArtifact]]: + return cast(dict[str, list[CsvRowArtifact]], super().load_collection(sources, *args, **kwargs)) def to_key(self, source: DataFrame, *args, **kwargs) -> str: hash_pandas_object = import_optional_dependency("pandas.core.util.hashing").hash_pandas_object diff --git a/griptape/loaders/sql_loader.py b/griptape/loaders/sql_loader.py index 14320911e..418ff0baf 100644 --- a/griptape/loaders/sql_loader.py +++ b/griptape/loaders/sql_loader.py @@ -4,7 +4,7 @@ from attrs import define, field -from griptape.artifacts.text_artifact import TextArtifact +from griptape.artifacts import CsvRowArtifact from griptape.loaders import BaseLoader if TYPE_CHECKING: @@ -16,11 +16,13 @@ class SqlLoader(BaseLoader): sql_driver: BaseSqlDriver = field(kw_only=True) embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) - def load(self, source: str, *args, **kwargs) -> list[TextArtifact]: + def load(self, source: str, *args, **kwargs) -> list[CsvRowArtifact]: rows = self.sql_driver.execute_query(source) artifacts = [] - chunks = [TextArtifact(row.cells, meta={"row": row_num}) for row_num, row in enumerate(rows)] if rows else [] + chunks = ( + [CsvRowArtifact(row.cells, meta={"row_num": row_num}) for row_num, row in enumerate(rows)] if rows else [] + ) if self.embedding_driver: for chunk in chunks: @@ -31,5 +33,5 @@ def load(self, source: str, *args, **kwargs) -> list[TextArtifact]: return artifacts - def load_collection(self, sources: list[str], *args, **kwargs) -> dict[str, list[TextArtifact]]: - return cast(dict[str, list[TextArtifact]], super().load_collection(sources, *args, **kwargs)) + def load_collection(self, sources: list[str], *args, **kwargs) -> dict[str, list[CsvRowArtifact]]: + return cast(dict[str, list[CsvRowArtifact]], super().load_collection(sources, *args, **kwargs)) diff --git a/tests/unit/artifacts/test_csv_row_artifact.py b/tests/unit/artifacts/test_csv_row_artifact.py new file mode 100644 index 000000000..459128525 --- /dev/null +++ b/tests/unit/artifacts/test_csv_row_artifact.py @@ -0,0 +1,27 @@ +from griptape.artifacts import CsvRowArtifact + + +class TestCsvRowArtifact: + def test_value_type_conversion(self): + assert CsvRowArtifact({"foo": "bar"}).value == "foo: bar" + assert CsvRowArtifact({"foo": {"bar": "baz"}}).value == "foo: {'bar': 'baz'}" + assert CsvRowArtifact('{"foo": "bar"}').value == '{"foo": "bar"}' + + def test___add__(self): + assert (CsvRowArtifact({"test1": "foo"}) + CsvRowArtifact({"test2": "bar"})).value == "test1: foo\ntest2: bar" + + def test_to_text(self): + assert CsvRowArtifact({"test1": "foo|bar", "test2": 1}).to_text() == "test1: foo|bar\ntest2: 1" + + def test_to_dict(self): + assert CsvRowArtifact({"test1": "foo"}).to_dict()["value"] == "test1: foo" + + def test_name(self): + artifact = CsvRowArtifact({}) + + assert artifact.name == artifact.id + assert CsvRowArtifact({}, name="bar").name == "bar" + + def test___bool__(self): + assert not bool(CsvRowArtifact({})) + assert bool(CsvRowArtifact({"foo": "bar"})) diff --git a/tests/unit/tasks/test_prompt_task.py b/tests/unit/tasks/test_prompt_task.py index 2d4456ce2..cfe853226 100644 --- a/tests/unit/tasks/test_prompt_task.py +++ b/tests/unit/tasks/test_prompt_task.py @@ -114,9 +114,9 @@ def test_input(self): assert task.input.value[1].width == 100 # default case - task = PromptTask({"default": "test"}) # pyright: ignore[reportArgumentType] + task = PromptTask({"default": "test"}) - assert task.input.value == "default: test" + assert task.input.value == str({"default": "test"}) def test_prompt_stack(self): task = PromptTask("{{ test }}", context={"test": "test value"}, rules=[Rule("test rule")]) diff --git a/tests/unit/tools/test_sql_tool.py b/tests/unit/tools/test_sql_tool.py index ae81965b0..f60401b66 100644 --- a/tests/unit/tools/test_sql_tool.py +++ b/tests/unit/tools/test_sql_tool.py @@ -27,7 +27,7 @@ def test_execute_query(self, driver): assert len(result.value) == 1 assert result.value[0].value == "id: 1\nname: Alice\nage: 25\ncity: New York" - assert result.value[0].meta["row"] == 0 + assert result.value[0].meta["row_num"] == 0 def test_execute_query_description(self, driver): client = SqlTool( From 1da14284c58ef7f2ed0b4550a6a7df5f4c598ee1 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 6 Sep 2024 10:35:45 -0700 Subject: [PATCH 05/52] Fix changelog --- CHANGELOG.md | 1 - 1 file changed, 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ccf8084c1..fc7de9260 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,7 +18,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: Removed `BlobArtifact.dir_name`. - **BREAKING**: Moved `ImageArtifact.prompt` and `ImageArtifact.model` into `ImageArtifact.meta`. - **BREAKING**: `ImageArtifact.to_text()` now returns the base64 encoded image. -- **BREAKING**: `ImageArtifact.to_text()` now returns the base64 encoded image. - Updated `JsonArtifact` value converter to properly handle more types. - `AudioArtifact` now subclasses `BlobArtifact` instead of `MediaArtifact`. - `ImageArtifact` now subclasses `BlobArtifact` instead of `MediaArtifact`. From a5bc8f8fac6f462897f3b72ecc641b097e75884e Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Mon, 9 Sep 2024 12:17:12 -0700 Subject: [PATCH 06/52] Move converters into class --- griptape/artifacts/base_artifact.py | 3 ++- griptape/artifacts/blob_artifact.py | 16 ++++++++-------- griptape/artifacts/csv_row_artifact.py | 16 ++++++++-------- griptape/artifacts/json_artifact.py | 16 ++++++++-------- 4 files changed, 26 insertions(+), 25 deletions(-) diff --git a/griptape/artifacts/base_artifact.py b/griptape/artifacts/base_artifact.py index 8f94adbfe..a6b93593c 100644 --- a/griptape/artifacts/base_artifact.py +++ b/griptape/artifacts/base_artifact.py @@ -24,6 +24,7 @@ class BaseArtifact(SerializableMixin, ABC): meta: The metadata associated with the Artifact. Defaults to an empty dictionary. name: The name of the Artifact. Defaults to the id. value: The value of the Artifact. + encoding: The encoding of the Artifact when converting to bytes. """ id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True, metadata={"serializable": True}) @@ -47,7 +48,7 @@ def __len__(self) -> int: return len(self.value) def to_bytes(self) -> bytes: - return self.to_text().encode() + return self.to_text().encode(self.encoding) @abstractmethod def to_text(self) -> str: ... diff --git a/griptape/artifacts/blob_artifact.py b/griptape/artifacts/blob_artifact.py index 9fe729a63..33ae45fa1 100644 --- a/griptape/artifacts/blob_artifact.py +++ b/griptape/artifacts/blob_artifact.py @@ -7,13 +7,6 @@ from griptape.artifacts import BaseArtifact -def value_to_bytes(value: Any) -> bytes: - if isinstance(value, bytes): - return value - else: - return str(value).encode() - - @define class BlobArtifact(BaseArtifact): """Stores arbitrary binary data. @@ -24,7 +17,7 @@ class BlobArtifact(BaseArtifact): encoding_error_handler: The error handler to use when converting the binary data to text. """ - value: bytes = field(converter=value_to_bytes, metadata={"serializable": True}) + value: bytes = field(converter=lambda value: BlobArtifact.value_to_bytes(value), metadata={"serializable": True}) encoding: str = field(default="utf-8", kw_only=True) encoding_error_handler: str = field(default="strict", kw_only=True) @@ -32,6 +25,13 @@ class BlobArtifact(BaseArtifact): def mime_type(self) -> str: return "application/octet-stream" + @classmethod + def value_to_bytes(cls, value: Any) -> bytes: + if isinstance(value, bytes): + return value + else: + return str(value).encode() + def to_bytes(self) -> bytes: return self.value diff --git a/griptape/artifacts/csv_row_artifact.py b/griptape/artifacts/csv_row_artifact.py index b1e2e7dfe..eea01fc14 100644 --- a/griptape/artifacts/csv_row_artifact.py +++ b/griptape/artifacts/csv_row_artifact.py @@ -7,13 +7,6 @@ from griptape.artifacts import BaseArtifact, TextArtifact -def value_to_str(value: Any) -> str: - if isinstance(value, dict): - return "\n".join(f"{key}: {val}" for key, val in value.items()) - else: - return str(value) - - @define class CsvRowArtifact(TextArtifact): """Stores a row of a CSV file. @@ -22,7 +15,14 @@ class CsvRowArtifact(TextArtifact): value: The row of the CSV file. If a dictionary is passed, the keys and values converted to a string. """ - value: str = field(converter=value_to_str, metadata={"serializable": True}) + value: str = field(converter=lambda value: CsvRowArtifact.value_to_str(value), metadata={"serializable": True}) def __add__(self, other: BaseArtifact) -> TextArtifact: return TextArtifact(self.value + "\n" + other.value) + + @classmethod + def value_to_str(cls, value: Any) -> str: + if isinstance(value, dict): + return "\n".join(f"{key}: {val}" for key, val in value.items()) + else: + return str(value) diff --git a/griptape/artifacts/json_artifact.py b/griptape/artifacts/json_artifact.py index b3d45e4d9..57700afd5 100644 --- a/griptape/artifacts/json_artifact.py +++ b/griptape/artifacts/json_artifact.py @@ -10,13 +10,6 @@ Json = Union[dict[str, "Json"], list["Json"], str, int, float, bool, None] -def value_to_json(value: Any) -> Json: - if isinstance(value, str): - return json.loads(value) - else: - return json.loads(json.dumps(value)) - - @define class JsonArtifact(BaseArtifact): """Stores JSON data. @@ -25,7 +18,14 @@ class JsonArtifact(BaseArtifact): value: The JSON data. Values will automatically be converted to a JSON-compatible format. """ - value: Json = field(converter=value_to_json, metadata={"serializable": True}) + value: Json = field(converter=lambda value: JsonArtifact.value_to_json(value), metadata={"serializable": True}) + + @classmethod + def value_to_json(cls, value: Any) -> Json: + if isinstance(value, str): + return json.loads(value) + else: + return json.loads(json.dumps(value)) def to_text(self) -> str: return json.dumps(self.value) From 98e3aa7ccc1b2c33666cf141211b3d95428b7563 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Mon, 9 Sep 2024 12:20:53 -0700 Subject: [PATCH 07/52] Make image format required --- CHANGELOG.md | 1 + MIGRATION.md | 20 ++++++++++++++++++++ griptape/artifacts/image_artifact.py | 4 ++-- tests/unit/artifacts/test_list_artifact.py | 2 +- 4 files changed, 24 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fc7de9260..61e0f0ae6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: Removed `BlobArtifact.dir_name`. - **BREAKING**: Moved `ImageArtifact.prompt` and `ImageArtifact.model` into `ImageArtifact.meta`. - **BREAKING**: `ImageArtifact.to_text()` now returns the base64 encoded image. +- **BREAKING**: `ImageArtifact.format` is now required. - Updated `JsonArtifact` value converter to properly handle more types. - `AudioArtifact` now subclasses `BlobArtifact` instead of `MediaArtifact`. - `ImageArtifact` now subclasses `BlobArtifact` instead of `MediaArtifact`. diff --git a/MIGRATION.md b/MIGRATION.md index 65584297b..08b4d9a7f 100644 --- a/MIGRATION.md +++ b/MIGRATION.md @@ -36,6 +36,26 @@ audio_artifact = AudioArtifact( ) ``` +### `ImageArtifact.format` is now required + +`ImageArtifact.format` is now a required parameter. Update any code that does not provide a `format` parameter. + +#### Before + +```python +image_artifact = ImageArtifact( + b"image_data" +) +``` + +#### After +```python +image_artifact = ImageArtifact( + b"image_data", + format="jpeg" +) +``` + ### Changed `CsvRowArtifact.value` from `dict` to `str`. `CsvRowArtifact`'s `value` is now a `str` instead of a `dict`. Update any logic that expects `dict` to handle `str` instead. diff --git a/griptape/artifacts/image_artifact.py b/griptape/artifacts/image_artifact.py index 5a90ae403..b51f6eaa5 100644 --- a/griptape/artifacts/image_artifact.py +++ b/griptape/artifacts/image_artifact.py @@ -13,13 +13,13 @@ class ImageArtifact(BlobArtifact): Attributes: value: The image data. - format: The format of the image data. Defaults to "png". + format: The format of the image data. width: The width of the image. height: The height of the image """ value: bytes = field(metadata={"serializable": True}) - format: str = field(default="png", kw_only=True, metadata={"serializable": True}) + format: str = field(kw_only=True, metadata={"serializable": True}) width: int = field(kw_only=True, metadata={"serializable": True}) height: int = field(kw_only=True, metadata={"serializable": True}) diff --git a/tests/unit/artifacts/test_list_artifact.py b/tests/unit/artifacts/test_list_artifact.py index c56682d2c..0d6faaa7b 100644 --- a/tests/unit/artifacts/test_list_artifact.py +++ b/tests/unit/artifacts/test_list_artifact.py @@ -39,7 +39,7 @@ def test_child_type(self): def test_is_type(self): assert ListArtifact([TextArtifact("foo")]).is_type(TextArtifact) - assert ListArtifact([ImageArtifact(b"", width=1234, height=1234)]).is_type(ImageArtifact) + assert ListArtifact([ImageArtifact(b"", width=1234, height=1234, format="png")]).is_type(ImageArtifact) def test_has_items(self): assert not ListArtifact().has_items() From 0eb93c385e1e15d370de6629dc66b262adc89311 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Mon, 9 Sep 2024 12:24:04 -0700 Subject: [PATCH 08/52] Raise error instead --- griptape/tasks/tool_task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/griptape/tasks/tool_task.py b/griptape/tasks/tool_task.py index 6ffc24871..7ae63b902 100644 --- a/griptape/tasks/tool_task.py +++ b/griptape/tasks/tool_task.py @@ -88,7 +88,7 @@ def run(self) -> BaseArtifact: if isinstance(first_artifact, BaseArtifact): self.output = first_artifact else: - self.output = ErrorArtifact(f"Output is not an Artifact: {type(subtask.output[0])}") + raise ValueError(f"Output is not an Artifact: {type(first_artifact)}") else: self.output = InfoArtifact("No tool output") except Exception as e: From 69fac0b61d03d83e3d3ee965ddcd85787992e3f3 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Mon, 9 Sep 2024 12:24:50 -0700 Subject: [PATCH 09/52] Generate embeddings before clearing --- griptape/artifacts/text_artifact.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/griptape/artifacts/text_artifact.py b/griptape/artifacts/text_artifact.py index 7a3b62e3b..da8a7cd9a 100644 --- a/griptape/artifacts/text_artifact.py +++ b/griptape/artifacts/text_artifact.py @@ -31,11 +31,12 @@ def to_bytes(self) -> bytes: return str(self.value).encode(encoding=self.encoding, errors=self.encoding_error_handler) def generate_embedding(self, driver: BaseEmbeddingDriver) -> list[float]: + embedding = driver.embed_string(str(self.value)) + if self.embedding is None: self.embedding = [] - self.embedding.clear() - self.embedding.extend(driver.embed_string(str(self.value))) + self.embedding.extend(embedding) return self.embedding From dca6c788b2a4918a821f3d5366839ef17952b070 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Mon, 9 Sep 2024 12:44:40 -0700 Subject: [PATCH 10/52] Refactor converters --- griptape/artifacts/audio_artifact.py | 6 ++---- griptape/artifacts/base_artifact.py | 6 ++++-- griptape/artifacts/blob_artifact.py | 4 ---- griptape/artifacts/image_artifact.py | 2 -- griptape/artifacts/text_artifact.py | 2 -- 5 files changed, 6 insertions(+), 14 deletions(-) diff --git a/griptape/artifacts/audio_artifact.py b/griptape/artifacts/audio_artifact.py index 77febf332..e9e38858a 100644 --- a/griptape/artifacts/audio_artifact.py +++ b/griptape/artifacts/audio_artifact.py @@ -2,19 +2,17 @@ from attrs import define, field -from griptape.artifacts import BaseArtifact +from griptape.artifacts import BlobArtifact @define -class AudioArtifact(BaseArtifact): +class AudioArtifact(BlobArtifact): """Stores audio data. Attributes: - value: The audio data. format: The audio format, e.g. "wav" or "mp3". """ - value: bytes = field(metadata={"serializable": True}) format: str = field(kw_only=True, metadata={"serializable": True}) @property diff --git a/griptape/artifacts/base_artifact.py b/griptape/artifacts/base_artifact.py index a6b93593c..4321ae00b 100644 --- a/griptape/artifacts/base_artifact.py +++ b/griptape/artifacts/base_artifact.py @@ -24,7 +24,8 @@ class BaseArtifact(SerializableMixin, ABC): meta: The metadata associated with the Artifact. Defaults to an empty dictionary. name: The name of the Artifact. Defaults to the id. value: The value of the Artifact. - encoding: The encoding of the Artifact when converting to bytes. + encoding: The encoding to use when encoding/decoding the value. + encoding_error_handler: The error handler to use when encoding/decoding the value. """ id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True, metadata={"serializable": True}) @@ -36,6 +37,7 @@ class BaseArtifact(SerializableMixin, ABC): metadata={"serializable": True}, ) value: Any = field() + encoding_error_handler: str = field(default="strict", kw_only=True) encoding: str = field(default="utf-8", kw_only=True) def __str__(self) -> str: @@ -48,7 +50,7 @@ def __len__(self) -> int: return len(self.value) def to_bytes(self) -> bytes: - return self.to_text().encode(self.encoding) + return self.to_text().encode(encoding=self.encoding, errors=self.encoding_error_handler) @abstractmethod def to_text(self) -> str: ... diff --git a/griptape/artifacts/blob_artifact.py b/griptape/artifacts/blob_artifact.py index 33ae45fa1..b881ee14e 100644 --- a/griptape/artifacts/blob_artifact.py +++ b/griptape/artifacts/blob_artifact.py @@ -13,13 +13,9 @@ class BlobArtifact(BaseArtifact): Attributes: value: The binary data. - encoding: The encoding to use when converting the binary data to text. - encoding_error_handler: The error handler to use when converting the binary data to text. """ value: bytes = field(converter=lambda value: BlobArtifact.value_to_bytes(value), metadata={"serializable": True}) - encoding: str = field(default="utf-8", kw_only=True) - encoding_error_handler: str = field(default="strict", kw_only=True) @property def mime_type(self) -> str: diff --git a/griptape/artifacts/image_artifact.py b/griptape/artifacts/image_artifact.py index b51f6eaa5..82c47c044 100644 --- a/griptape/artifacts/image_artifact.py +++ b/griptape/artifacts/image_artifact.py @@ -12,13 +12,11 @@ class ImageArtifact(BlobArtifact): """Stores image data. Attributes: - value: The image data. format: The format of the image data. width: The width of the image. height: The height of the image """ - value: bytes = field(metadata={"serializable": True}) format: str = field(kw_only=True, metadata={"serializable": True}) width: int = field(kw_only=True, metadata={"serializable": True}) height: int = field(kw_only=True, metadata={"serializable": True}) diff --git a/griptape/artifacts/text_artifact.py b/griptape/artifacts/text_artifact.py index da8a7cd9a..5f9b045ba 100644 --- a/griptape/artifacts/text_artifact.py +++ b/griptape/artifacts/text_artifact.py @@ -14,8 +14,6 @@ @define class TextArtifact(BaseArtifact): value: str = field(converter=str, metadata={"serializable": True}) - encoding: str = field(default="utf-8", kw_only=True) - encoding_error_handler: str = field(default="strict", kw_only=True) embedding: Optional[list[float]] = field(default=None, kw_only=True) def __add__(self, other: BaseArtifact) -> TextArtifact: From 50dfb460e56241f281d33ac8d6db575d191cc9d0 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 10 Sep 2024 10:47:01 -0700 Subject: [PATCH 11/52] Move converter to lambda --- griptape/artifacts/blob_artifact.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/griptape/artifacts/blob_artifact.py b/griptape/artifacts/blob_artifact.py index b881ee14e..fb464af62 100644 --- a/griptape/artifacts/blob_artifact.py +++ b/griptape/artifacts/blob_artifact.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Any - from attrs import define, field from griptape.artifacts import BaseArtifact @@ -15,19 +13,15 @@ class BlobArtifact(BaseArtifact): value: The binary data. """ - value: bytes = field(converter=lambda value: BlobArtifact.value_to_bytes(value), metadata={"serializable": True}) + value: bytes = field( + converter=lambda value: value if isinstance(value, bytes) else str(value).encode(), + metadata={"serializable": True}, + ) @property def mime_type(self) -> str: return "application/octet-stream" - @classmethod - def value_to_bytes(cls, value: Any) -> bytes: - if isinstance(value, bytes): - return value - else: - return str(value).encode() - def to_bytes(self) -> bytes: return self.value From 48dc409d420ca091cf9d360e92c70cac2fa22e40 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 10 Sep 2024 14:05:35 -0700 Subject: [PATCH 12/52] Use artifact encoding for images --- griptape/artifacts/image_artifact.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/griptape/artifacts/image_artifact.py b/griptape/artifacts/image_artifact.py index 82c47c044..52ac2fcf4 100644 --- a/griptape/artifacts/image_artifact.py +++ b/griptape/artifacts/image_artifact.py @@ -23,7 +23,7 @@ class ImageArtifact(BlobArtifact): @property def base64(self) -> str: - return base64.b64encode(self.value).decode("utf-8") + return base64.b64encode(self.value).decode(self.encoding) @property def mime_type(self) -> str: From ad801015724961a185dc6c9b1cc63844947677f2 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 10 Sep 2024 14:16:18 -0700 Subject: [PATCH 13/52] Appease code cov --- griptape/artifacts/boolean_artifact.py | 2 +- griptape/artifacts/text_artifact.py | 3 --- tests/unit/artifacts/test_blob_artifact.py | 10 ++++++++++ tests/unit/artifacts/test_boolean_artifact.py | 11 +++++++++++ tests/unit/artifacts/test_text_artifact.py | 5 +++++ 5 files changed, 27 insertions(+), 4 deletions(-) diff --git a/griptape/artifacts/boolean_artifact.py b/griptape/artifacts/boolean_artifact.py index dd8bb1504..eb135824d 100644 --- a/griptape/artifacts/boolean_artifact.py +++ b/griptape/artifacts/boolean_artifact.py @@ -34,7 +34,7 @@ def __add__(self, other: BaseArtifact) -> BooleanArtifact: raise ValueError("Cannot add BooleanArtifact with other artifacts") def __eq__(self, value: object) -> bool: - return self.value is value + return self.value == value def to_text(self) -> str: return str(self.value).lower() diff --git a/griptape/artifacts/text_artifact.py b/griptape/artifacts/text_artifact.py index 5f9b045ba..9623c8096 100644 --- a/griptape/artifacts/text_artifact.py +++ b/griptape/artifacts/text_artifact.py @@ -25,9 +25,6 @@ def __bool__(self) -> bool: def to_text(self) -> str: return self.value - def to_bytes(self) -> bytes: - return str(self.value).encode(encoding=self.encoding, errors=self.encoding_error_handler) - def generate_embedding(self, driver: BaseEmbeddingDriver) -> list[float]: embedding = driver.embed_string(str(self.value)) diff --git a/tests/unit/artifacts/test_blob_artifact.py b/tests/unit/artifacts/test_blob_artifact.py index d0f04b1e5..9db5e21f5 100644 --- a/tests/unit/artifacts/test_blob_artifact.py +++ b/tests/unit/artifacts/test_blob_artifact.py @@ -12,6 +12,9 @@ def test_value_type_conversion(self): def test_to_text(self): assert BlobArtifact(b"foobar", name="foobar.txt").to_text() == "foobar" + def test_to_bytes(self): + assert BlobArtifact(b"foo").to_bytes() == b"foo" + def test_to_text_encoding(self): assert ( BlobArtifact("ß".encode("ascii", errors="backslashreplace"), name="foobar.txt", encoding="ascii").to_text() @@ -50,6 +53,13 @@ def test_deserialization(self): def test_name(self): assert BlobArtifact(b"foo", name="bar").name == "bar" + def test_mime_type(self): + assert BlobArtifact(b"foo").mime_type == "application/octet-stream" + + def test___add__(self): + with pytest.raises(TypeError): + BlobArtifact(b"foo") + BlobArtifact(b"bar") + def test___bool__(self): assert not bool(BlobArtifact(b"")) assert bool(BlobArtifact(b"foo")) diff --git a/tests/unit/artifacts/test_boolean_artifact.py b/tests/unit/artifacts/test_boolean_artifact.py index 57bbf1662..6ed21608d 100644 --- a/tests/unit/artifacts/test_boolean_artifact.py +++ b/tests/unit/artifacts/test_boolean_artifact.py @@ -10,6 +10,7 @@ def test_parse_bool(self): assert BooleanArtifact.parse_bool("false").value is False assert BooleanArtifact.parse_bool("True").value is True assert BooleanArtifact.parse_bool("False").value is False + assert BooleanArtifact.parse_bool(True).value is True with pytest.raises(ValueError): BooleanArtifact.parse_bool("foo") @@ -35,3 +36,13 @@ def test_value_type_conversion(self): assert BooleanArtifact([]).value is False assert BooleanArtifact(False).value is False assert BooleanArtifact(True).value is True + + def test_to_text(self): + assert BooleanArtifact(True).to_text() == "true" + assert BooleanArtifact(False).to_text() == "false" + + def test__eq__(self): + assert BooleanArtifact(True) == BooleanArtifact(True) + assert BooleanArtifact(False) == BooleanArtifact(False) + assert BooleanArtifact(True) != BooleanArtifact(False) + assert BooleanArtifact(False) != BooleanArtifact(True) diff --git a/tests/unit/artifacts/test_text_artifact.py b/tests/unit/artifacts/test_text_artifact.py index 9e00e2d2e..dda256d27 100644 --- a/tests/unit/artifacts/test_text_artifact.py +++ b/tests/unit/artifacts/test_text_artifact.py @@ -18,6 +18,11 @@ def test___add__(self): def test_to_dict(self): assert TextArtifact("foobar").to_dict()["value"] == "foobar" + def test_to_bytes(self): + artifact = TextArtifact("foobar") + + assert artifact.to_bytes() == b"foobar" + def test_from_dict(self): assert BaseArtifact.from_dict(TextArtifact("foobar").to_dict()).value == "foobar" From beff800ce3d6d70a60bfb8c5f7348422a4c66cfc Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 11 Sep 2024 09:38:46 -0700 Subject: [PATCH 14/52] Fix bad merge --- griptape/tasks/base_audio_generation_task.py | 2 +- griptape/tasks/base_image_generation_task.py | 2 +- griptape/tools/text_to_speech/tool.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/griptape/tasks/base_audio_generation_task.py b/griptape/tasks/base_audio_generation_task.py index 866c9aa5c..91f7b7501 100644 --- a/griptape/tasks/base_audio_generation_task.py +++ b/griptape/tasks/base_audio_generation_task.py @@ -6,7 +6,7 @@ from attrs import define from griptape.configs import Defaults -from griptape.mixins.media_artifact_file_output_mixin import BlobArtifactFileOutputMixin +from griptape.mixins.artifact_file_output_mixin import ArtifactFileOutputMixin from griptape.mixins.rule_mixin import RuleMixin from griptape.tasks import BaseTask diff --git a/griptape/tasks/base_image_generation_task.py b/griptape/tasks/base_image_generation_task.py index 57dc99506..326b2a551 100644 --- a/griptape/tasks/base_image_generation_task.py +++ b/griptape/tasks/base_image_generation_task.py @@ -10,7 +10,7 @@ from griptape.configs import Defaults from griptape.loaders import ImageLoader -from griptape.mixins.media_artifact_file_output_mixin import BlobArtifactFileOutputMixin +from griptape.mixins.artifact_file_output_mixin import ArtifactFileOutputMixin from griptape.mixins.rule_mixin import RuleMixin from griptape.rules import Rule, Ruleset from griptape.tasks import BaseTask diff --git a/griptape/tools/text_to_speech/tool.py b/griptape/tools/text_to_speech/tool.py index 23b82333a..aca259698 100644 --- a/griptape/tools/text_to_speech/tool.py +++ b/griptape/tools/text_to_speech/tool.py @@ -5,7 +5,7 @@ from attrs import define, field from schema import Literal, Schema -from griptape.mixins.media_artifact_file_output_mixin import BlobArtifactFileOutputMixin +from griptape.mixins.artifact_file_output_mixin import ArtifactFileOutputMixin from griptape.tools import BaseTool from griptape.utils.decorators import activity From 9bb093078974aba3217cb28503f83701bbdf539b Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 11 Sep 2024 10:17:55 -0700 Subject: [PATCH 15/52] Address not --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 61e0f0ae6..ca273800e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased ### Added -- `BaseArtifact.to_bytes()` method to convert an Artifact to bytes. +- `BaseArtifact.to_bytes()` method to convert an Artifact value to bytes. ### Changed - **BREAKING**: Changed `CsvRowArtifact.value` from `dict` to `str`. From c76e5c82009266114c33f4a52a9e0570f8cb0c13 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 11 Sep 2024 10:44:38 -0700 Subject: [PATCH 16/52] Revert image to_text as base64, move base64 to parent --- CHANGELOG.md | 2 +- griptape/artifacts/blob_artifact.py | 6 ++++++ griptape/artifacts/image_artifact.py | 8 +------- tests/unit/artifacts/test_base_artifact.py | 2 +- tests/unit/artifacts/test_image_artifact.py | 2 +- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ca273800e..bcd86cee5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - `BaseArtifact.to_bytes()` method to convert an Artifact value to bytes. +- `BlobArtifact.base64` property for converting a `BlobArtifact`'s value to a base64 strings. ### Changed - **BREAKING**: Changed `CsvRowArtifact.value` from `dict` to `str`. @@ -17,7 +18,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: Removed `AudioArtifact.media_type`. - **BREAKING**: Removed `BlobArtifact.dir_name`. - **BREAKING**: Moved `ImageArtifact.prompt` and `ImageArtifact.model` into `ImageArtifact.meta`. -- **BREAKING**: `ImageArtifact.to_text()` now returns the base64 encoded image. - **BREAKING**: `ImageArtifact.format` is now required. - Updated `JsonArtifact` value converter to properly handle more types. - `AudioArtifact` now subclasses `BlobArtifact` instead of `MediaArtifact`. diff --git a/griptape/artifacts/blob_artifact.py b/griptape/artifacts/blob_artifact.py index fb464af62..7c814a052 100644 --- a/griptape/artifacts/blob_artifact.py +++ b/griptape/artifacts/blob_artifact.py @@ -1,5 +1,7 @@ from __future__ import annotations +import base64 + from attrs import define, field from griptape.artifacts import BaseArtifact @@ -18,6 +20,10 @@ class BlobArtifact(BaseArtifact): metadata={"serializable": True}, ) + @property + def base64(self) -> str: + return base64.b64encode(self.value).decode(self.encoding) + @property def mime_type(self) -> str: return "application/octet-stream" diff --git a/griptape/artifacts/image_artifact.py b/griptape/artifacts/image_artifact.py index 52ac2fcf4..616dc335a 100644 --- a/griptape/artifacts/image_artifact.py +++ b/griptape/artifacts/image_artifact.py @@ -1,7 +1,5 @@ from __future__ import annotations -import base64 - from attrs import define, field from griptape.artifacts import BlobArtifact @@ -21,10 +19,6 @@ class ImageArtifact(BlobArtifact): width: int = field(kw_only=True, metadata={"serializable": True}) height: int = field(kw_only=True, metadata={"serializable": True}) - @property - def base64(self) -> str: - return base64.b64encode(self.value).decode(self.encoding) - @property def mime_type(self) -> str: return f"image/{self.format}" @@ -33,4 +27,4 @@ def to_bytes(self) -> bytes: return self.value def to_text(self) -> str: - return self.base64 + return f"Image, format: {self.format}, size: {len(self.value)} bytes" diff --git a/tests/unit/artifacts/test_base_artifact.py b/tests/unit/artifacts/test_base_artifact.py index e2060879f..28e2761a8 100644 --- a/tests/unit/artifacts/test_base_artifact.py +++ b/tests/unit/artifacts/test_base_artifact.py @@ -59,7 +59,7 @@ def test_image_artifact_from_dict(self): artifact = BaseArtifact.from_dict(dict_value) assert isinstance(artifact, ImageArtifact) - assert artifact.to_text() == "aW1hZ2UgZGF0YQ==" + assert artifact.to_text() == "Image, format: png, size: 10 bytes" assert artifact.value == b"image data" def test_unsupported_from_dict(self): diff --git a/tests/unit/artifacts/test_image_artifact.py b/tests/unit/artifacts/test_image_artifact.py index b048b3372..a632953ae 100644 --- a/tests/unit/artifacts/test_image_artifact.py +++ b/tests/unit/artifacts/test_image_artifact.py @@ -15,7 +15,7 @@ def image_artifact(self): ) def test_to_text(self, image_artifact: ImageArtifact): - assert image_artifact.to_text() == "c29tZSBiaW5hcnkgcG5nIGltYWdlIGRhdGE=" + assert image_artifact.to_text() == "Image, format: png, size: 26 bytes" def test_to_dict(self, image_artifact: ImageArtifact): image_dict = image_artifact.to_dict() From 27609b3f36901045a72546a19a1fe3be2390c416 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 11 Sep 2024 11:54:54 -0700 Subject: [PATCH 17/52] Fix changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bcd86cee5..db8d3f6cd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,7 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Updated `JsonArtifact` value converter to properly handle more types. - `AudioArtifact` now subclasses `BlobArtifact` instead of `MediaArtifact`. - `ImageArtifact` now subclasses `BlobArtifact` instead of `MediaArtifact`. -- Passing a dictionary as the value to `TextArtifact` will convert to a key-value formatted string. +- Passing a dictionary as the value to `CsvRowArtifact` will convert to a key-value formatted string. - Removed `__add__` method from `BaseArtifact`, implemented it where necessary. - Generic type support to `ListArtifact`. - Iteration support to `ListArtifact`. From f00a35d9ac58e46e40b30d224e84f59878f0081f Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 11 Sep 2024 16:51:36 -0700 Subject: [PATCH 18/52] Remove CsvRowArtifact for final time --- CHANGELOG.md | 6 +-- MIGRATION.md | 42 +++++++++++++++++-- griptape/artifacts/__init__.py | 2 - .../extraction/csv_extraction_engine.py | 15 ++++--- griptape/loaders/csv_loader.py | 15 ++++--- griptape/loaders/dataframe_loader.py | 15 ++++--- griptape/loaders/sql_loader.py | 17 ++++---- tests/unit/artifacts/test_csv_row_artifact.py | 27 ------------ tests/unit/loaders/test_csv_loader.py | 11 +++++ tests/unit/tools/test_sql_tool.py | 1 - 10 files changed, 88 insertions(+), 63 deletions(-) delete mode 100644 tests/unit/artifacts/test_csv_row_artifact.py diff --git a/CHANGELOG.md b/CHANGELOG.md index db8d3f6cd..23e5366aa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,11 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased ### Added -- `BaseArtifact.to_bytes()` method to convert an Artifact value to bytes. +- `BaseArtifact.to_bytes()` method to convert an Artifact's value to bytes. - `BlobArtifact.base64` property for converting a `BlobArtifact`'s value to a base64 strings. +- `CsvLoader`/`SqlLoader`/`DataframeLoader` `formatter_fn` field for customizing how SQL results are formatted into `TextArtifact`s. ### Changed -- **BREAKING**: Changed `CsvRowArtifact.value` from `dict` to `str`. +- **BREAKING**: Removed `CsvRowArtifact`. Use `TextArtifact` instead. - **BREAKING**: Removed `MediaArtifact`, use `ImageArtifact` or `AudioArtifact` instead. - **BREAKING**: `CsvLoader`, `DataframeLoader`, and `SqlLoader` now return `list[TextArtifact]`. - **BREAKING**: Removed `ImageArtifact.media_type`. @@ -22,7 +23,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Updated `JsonArtifact` value converter to properly handle more types. - `AudioArtifact` now subclasses `BlobArtifact` instead of `MediaArtifact`. - `ImageArtifact` now subclasses `BlobArtifact` instead of `MediaArtifact`. -- Passing a dictionary as the value to `CsvRowArtifact` will convert to a key-value formatted string. - Removed `__add__` method from `BaseArtifact`, implemented it where necessary. - Generic type support to `ListArtifact`. - Iteration support to `ListArtifact`. diff --git a/MIGRATION.md b/MIGRATION.md index 08b4d9a7f..016a93f03 100644 --- a/MIGRATION.md +++ b/MIGRATION.md @@ -56,9 +56,9 @@ image_artifact = ImageArtifact( ) ``` -### Changed `CsvRowArtifact.value` from `dict` to `str`. +### Removed `CsvRowArtifact` -`CsvRowArtifact`'s `value` is now a `str` instead of a `dict`. Update any logic that expects `dict` to handle `str` instead. +`CsvRowArtifact` has been removed. Use `TextArtifact` instead. #### Before @@ -70,11 +70,45 @@ print(type(artifact.value)) # #### After ```python -artifact = CsvRowArtifact({"name": "John", "age": 30}) -print(artifact.value) # name: John\nAge: 30 +artifact = TextArtifact("name: John\nage: 30") +print(artifact.value) # name: John\nage: 30 print(type(artifact.value)) # ``` +If you require storing a dictionary as an Artifact, you can use `GenericArtifact` instead. + +### `CsvLoader`, `DataframeLoader`, and `SqlLoader` return types + +`CsvLoader`, `DataframeLoader`, and `SqlLoader` now return a `list[TextArtifact]` instead of `list[CsvRowArtifact]`. + +If you require a dictionary, set a custom `formatter_fn` and then parse the text to a dictionary. + +#### Before + +```python +results = CsvLoader().load(Path("people.csv").read_text()) + +print(results[0].value) # {"name": "John", "age": 30} +print(type(results[0].value)) # +``` + +#### After +```python +results = CsvLoader().load(Path("people.csv").read_text()) + +print(results[0].value) # name: John\nAge: 30 +print(type(results[0].value)) # + +# Customize formatter_fn +results = CsvLoader(formatter_fn=lambda x: json.dumps(x)).load(Path("people.csv").read_text()) +print(results[0].value) # {"name": "John", "age": 30} +print(type(results[0].value)) # + +dict_results = [json.loads(result.value) for result in results] +print(dict_results[0]) # {"name": "John", "age": 30} +print(type(dict_results[0])) # +``` + ### Moved `ImageArtifact.prompt` and `ImageArtifact.model` to `ImageArtifact.meta` `ImageArtifact.prompt` and `ImageArtifact.model` have been moved to `ImageArtifact.meta`. diff --git a/griptape/artifacts/__init__.py b/griptape/artifacts/__init__.py index c0e5f767e..0e58a8a76 100644 --- a/griptape/artifacts/__init__.py +++ b/griptape/artifacts/__init__.py @@ -5,7 +5,6 @@ from .json_artifact import JsonArtifact from .blob_artifact import BlobArtifact from .boolean_artifact import BooleanArtifact -from .csv_row_artifact import CsvRowArtifact from .list_artifact import ListArtifact from .image_artifact import ImageArtifact from .audio_artifact import AudioArtifact @@ -21,7 +20,6 @@ "JsonArtifact", "BlobArtifact", "BooleanArtifact", - "CsvRowArtifact", "ListArtifact", "ImageArtifact", "AudioArtifact", diff --git a/griptape/engines/extraction/csv_extraction_engine.py b/griptape/engines/extraction/csv_extraction_engine.py index 9f6085da6..cf8f164cc 100644 --- a/griptape/engines/extraction/csv_extraction_engine.py +++ b/griptape/engines/extraction/csv_extraction_engine.py @@ -2,11 +2,11 @@ import csv import io -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, Any, Callable, Optional, cast from attrs import Factory, define, field -from griptape.artifacts import CsvRowArtifact, ListArtifact, TextArtifact +from griptape.artifacts import ListArtifact, TextArtifact from griptape.common import Message, PromptStack from griptape.engines import BaseExtractionEngine from griptape.utils import J2 @@ -20,6 +20,9 @@ class CsvExtractionEngine(BaseExtractionEngine): column_names: list[str] = field(default=Factory(list), kw_only=True) system_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/system.j2")), kw_only=True) user_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/user.j2")), kw_only=True) + formatter_fn: Callable[[Any], str] = field( + default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True + ) def extract( self, @@ -37,22 +40,22 @@ def extract( item_separator="\n", ) - def text_to_csv_rows(self, text: str, column_names: list[str]) -> list[CsvRowArtifact]: + def text_to_csv_rows(self, text: str, column_names: list[str]) -> list[TextArtifact]: rows = [] with io.StringIO(text) as f: for row in csv.reader(f): - rows.append(CsvRowArtifact(dict(zip(column_names, [x.strip() for x in row])))) + rows.append(TextArtifact(self.formatter_fn(dict(zip(column_names, [x.strip() for x in row]))))) return rows def _extract_rec( self, artifacts: list[TextArtifact], - rows: list[CsvRowArtifact], + rows: list[TextArtifact], *, rulesets: Optional[list[Ruleset]] = None, - ) -> list[CsvRowArtifact]: + ) -> list[TextArtifact]: artifacts_text = self.chunk_joiner.join([a.value for a in artifacts]) system_prompt = self.system_template_generator.render( column_names=self.column_names, diff --git a/griptape/loaders/csv_loader.py b/griptape/loaders/csv_loader.py index 6f6fae5f3..bcf7029d4 100644 --- a/griptape/loaders/csv_loader.py +++ b/griptape/loaders/csv_loader.py @@ -2,11 +2,11 @@ import csv from io import StringIO -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, Callable, Optional, cast from attrs import define, field -from griptape.artifacts import CsvRowArtifact +from griptape.artifacts import TextArtifact from griptape.loaders import BaseLoader if TYPE_CHECKING: @@ -18,8 +18,11 @@ class CsvLoader(BaseLoader): embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) delimiter: str = field(default=",", kw_only=True) encoding: str = field(default="utf-8", kw_only=True) + formatter_fn: Callable[[dict], str] = field( + default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True + ) - def load(self, source: bytes | str, *args, **kwargs) -> list[CsvRowArtifact]: + def load(self, source: bytes | str, *args, **kwargs) -> list[TextArtifact]: artifacts = [] if isinstance(source, bytes): @@ -28,7 +31,7 @@ def load(self, source: bytes | str, *args, **kwargs) -> list[CsvRowArtifact]: raise ValueError(f"Unsupported source type: {type(source)}") reader = csv.DictReader(StringIO(source), delimiter=self.delimiter) - chunks = [CsvRowArtifact(row, meta={"row_num": row_num}) for row_num, row in enumerate(reader)] + chunks = [TextArtifact(self.formatter_fn(row)) for row in reader] if self.embedding_driver: for chunk in chunks: @@ -44,8 +47,8 @@ def load_collection( sources: list[bytes | str], *args, **kwargs, - ) -> dict[str, list[CsvRowArtifact]]: + ) -> dict[str, list[TextArtifact]]: return cast( - dict[str, list[CsvRowArtifact]], + dict[str, list[TextArtifact]], super().load_collection(sources, *args, **kwargs), ) diff --git a/griptape/loaders/dataframe_loader.py b/griptape/loaders/dataframe_loader.py index 0b1ae1448..30d705676 100644 --- a/griptape/loaders/dataframe_loader.py +++ b/griptape/loaders/dataframe_loader.py @@ -1,10 +1,10 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, Callable, Optional, cast from attrs import define, field -from griptape.artifacts import CsvRowArtifact +from griptape.artifacts import TextArtifact from griptape.loaders import BaseLoader from griptape.utils import import_optional_dependency from griptape.utils.hash import str_to_hash @@ -18,11 +18,14 @@ @define class DataFrameLoader(BaseLoader): embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) + formatter_fn: Callable[[dict], str] = field( + default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True + ) - def load(self, source: DataFrame, *args, **kwargs) -> list[CsvRowArtifact]: + def load(self, source: DataFrame, *args, **kwargs) -> list[TextArtifact]: artifacts = [] - chunks = [CsvRowArtifact(row) for row in source.to_dict(orient="records")] + chunks = [TextArtifact(self.formatter_fn(row)) for row in source.to_dict(orient="records")] if self.embedding_driver: for chunk in chunks: @@ -33,8 +36,8 @@ def load(self, source: DataFrame, *args, **kwargs) -> list[CsvRowArtifact]: return artifacts - def load_collection(self, sources: list[DataFrame], *args, **kwargs) -> dict[str, list[CsvRowArtifact]]: - return cast(dict[str, list[CsvRowArtifact]], super().load_collection(sources, *args, **kwargs)) + def load_collection(self, sources: list[DataFrame], *args, **kwargs) -> dict[str, list[TextArtifact]]: + return cast(dict[str, list[TextArtifact]], super().load_collection(sources, *args, **kwargs)) def to_key(self, source: DataFrame, *args, **kwargs) -> str: hash_pandas_object = import_optional_dependency("pandas.core.util.hashing").hash_pandas_object diff --git a/griptape/loaders/sql_loader.py b/griptape/loaders/sql_loader.py index 418ff0baf..105f585cb 100644 --- a/griptape/loaders/sql_loader.py +++ b/griptape/loaders/sql_loader.py @@ -1,10 +1,10 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, Callable, Optional, cast from attrs import define, field -from griptape.artifacts import CsvRowArtifact +from griptape.artifacts import TextArtifact from griptape.loaders import BaseLoader if TYPE_CHECKING: @@ -15,14 +15,15 @@ class SqlLoader(BaseLoader): sql_driver: BaseSqlDriver = field(kw_only=True) embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) + formatter_fn: Callable[[dict], str] = field( + default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True + ) - def load(self, source: str, *args, **kwargs) -> list[CsvRowArtifact]: + def load(self, source: str, *args, **kwargs) -> list[TextArtifact]: rows = self.sql_driver.execute_query(source) artifacts = [] - chunks = ( - [CsvRowArtifact(row.cells, meta={"row_num": row_num}) for row_num, row in enumerate(rows)] if rows else [] - ) + chunks = [TextArtifact(self.formatter_fn(row.cells)) for row in rows] if rows else [] if self.embedding_driver: for chunk in chunks: @@ -33,5 +34,5 @@ def load(self, source: str, *args, **kwargs) -> list[CsvRowArtifact]: return artifacts - def load_collection(self, sources: list[str], *args, **kwargs) -> dict[str, list[CsvRowArtifact]]: - return cast(dict[str, list[CsvRowArtifact]], super().load_collection(sources, *args, **kwargs)) + def load_collection(self, sources: list[str], *args, **kwargs) -> dict[str, list[TextArtifact]]: + return cast(dict[str, list[TextArtifact]], super().load_collection(sources, *args, **kwargs)) diff --git a/tests/unit/artifacts/test_csv_row_artifact.py b/tests/unit/artifacts/test_csv_row_artifact.py deleted file mode 100644 index 459128525..000000000 --- a/tests/unit/artifacts/test_csv_row_artifact.py +++ /dev/null @@ -1,27 +0,0 @@ -from griptape.artifacts import CsvRowArtifact - - -class TestCsvRowArtifact: - def test_value_type_conversion(self): - assert CsvRowArtifact({"foo": "bar"}).value == "foo: bar" - assert CsvRowArtifact({"foo": {"bar": "baz"}}).value == "foo: {'bar': 'baz'}" - assert CsvRowArtifact('{"foo": "bar"}').value == '{"foo": "bar"}' - - def test___add__(self): - assert (CsvRowArtifact({"test1": "foo"}) + CsvRowArtifact({"test2": "bar"})).value == "test1: foo\ntest2: bar" - - def test_to_text(self): - assert CsvRowArtifact({"test1": "foo|bar", "test2": 1}).to_text() == "test1: foo|bar\ntest2: 1" - - def test_to_dict(self): - assert CsvRowArtifact({"test1": "foo"}).to_dict()["value"] == "test1: foo" - - def test_name(self): - artifact = CsvRowArtifact({}) - - assert artifact.name == artifact.id - assert CsvRowArtifact({}, name="bar").name == "bar" - - def test___bool__(self): - assert not bool(CsvRowArtifact({})) - assert bool(CsvRowArtifact({"foo": "bar"})) diff --git a/tests/unit/loaders/test_csv_loader.py b/tests/unit/loaders/test_csv_loader.py index d08b93919..7af409152 100644 --- a/tests/unit/loaders/test_csv_loader.py +++ b/tests/unit/loaders/test_csv_loader.py @@ -1,3 +1,5 @@ +import json + import pytest from griptape.loaders.csv_loader import CsvLoader @@ -55,3 +57,12 @@ def test_load_collection(self, loader, create_source): assert collection[loader.to_key(sources[1])][0].value == "Bar: bar1\nFoo: foo1" assert collection[loader.to_key(sources[1])][0].embedding == [0, 1] + + def test_formatter_fn(self, loader, create_source): + loader.formatter_fn = lambda value: json.dumps(value) + source = create_source("test-1.csv") + + artifacts = loader.load(source) + + assert len(artifacts) == 10 + assert artifacts[0].value == '{"Foo": "foo1", "Bar": "bar1"}' diff --git a/tests/unit/tools/test_sql_tool.py b/tests/unit/tools/test_sql_tool.py index f60401b66..061b31f4d 100644 --- a/tests/unit/tools/test_sql_tool.py +++ b/tests/unit/tools/test_sql_tool.py @@ -27,7 +27,6 @@ def test_execute_query(self, driver): assert len(result.value) == 1 assert result.value[0].value == "id: 1\nname: Alice\nage: 25\ncity: New York" - assert result.value[0].meta["row_num"] == 0 def test_execute_query_description(self, driver): client = SqlTool( From 83744cfe60a107f4c3c510395871e03a87581cf3 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 13 Sep 2024 08:48:52 -0700 Subject: [PATCH 19/52] Fix feedback --- CHANGELOG.md | 2 +- griptape/engines/extraction/csv_extraction_engine.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 23e5366aa..4516c59b6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,7 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - `BaseArtifact.to_bytes()` method to convert an Artifact's value to bytes. -- `BlobArtifact.base64` property for converting a `BlobArtifact`'s value to a base64 strings. +- `BlobArtifact.base64` property for converting a `BlobArtifact`'s value to a base64 string. - `CsvLoader`/`SqlLoader`/`DataframeLoader` `formatter_fn` field for customizing how SQL results are formatted into `TextArtifact`s. ### Changed diff --git a/griptape/engines/extraction/csv_extraction_engine.py b/griptape/engines/extraction/csv_extraction_engine.py index cf8f164cc..e7be73c73 100644 --- a/griptape/engines/extraction/csv_extraction_engine.py +++ b/griptape/engines/extraction/csv_extraction_engine.py @@ -2,7 +2,7 @@ import csv import io -from typing import TYPE_CHECKING, Any, Callable, Optional, cast +from typing import TYPE_CHECKING, Callable, Optional, cast from attrs import Factory, define, field @@ -20,7 +20,7 @@ class CsvExtractionEngine(BaseExtractionEngine): column_names: list[str] = field(default=Factory(list), kw_only=True) system_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/system.j2")), kw_only=True) user_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/user.j2")), kw_only=True) - formatter_fn: Callable[[Any], str] = field( + formatter_fn: Callable[[dict], str] = field( default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True ) From de6e1f33e4ec8eb7d6108582beebe7ab6b650b97 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 27 Aug 2024 11:22:51 -0700 Subject: [PATCH 20/52] Refactor Artifacts --- .ignore | 0 griptape/artifacts/__init__.py | 11 ++++++++--- griptape/artifacts/base_system_artifact.py | 10 ++++++++++ griptape/artifacts/blob_artifact.py | 7 +++++++ griptape/artifacts/error_artifact.py | 2 +- griptape/artifacts/info_artifact.py | 2 +- griptape/artifacts/json_artifact.py | 7 +++++++ griptape/artifacts/list_artifact.py | 2 +- griptape/artifacts/text_artifact.py | 9 ++++++++- tests/unit/tasks/test_prompt_task.py | 4 ++-- 10 files changed, 45 insertions(+), 9 deletions(-) create mode 100644 .ignore create mode 100644 griptape/artifacts/base_system_artifact.py diff --git a/.ignore b/.ignore new file mode 100644 index 000000000..e69de29bb diff --git a/griptape/artifacts/__init__.py b/griptape/artifacts/__init__.py index 0e58a8a76..8acf2642b 100644 --- a/griptape/artifacts/__init__.py +++ b/griptape/artifacts/__init__.py @@ -1,19 +1,24 @@ from .base_artifact import BaseArtifact -from .error_artifact import ErrorArtifact -from .info_artifact import InfoArtifact +from .base_system_artifact import BaseSystemArtifact + from .text_artifact import TextArtifact -from .json_artifact import JsonArtifact from .blob_artifact import BlobArtifact from .boolean_artifact import BooleanArtifact from .list_artifact import ListArtifact from .image_artifact import ImageArtifact from .audio_artifact import AudioArtifact +from .json_artifact import JsonArtifact from .action_artifact import ActionArtifact from .generic_artifact import GenericArtifact +from .error_artifact import ErrorArtifact +from .info_artifact import InfoArtifact +from .list_artifact import ListArtifact + __all__ = [ "BaseArtifact", + "BaseSystemArtifact", "ErrorArtifact", "InfoArtifact", "TextArtifact", diff --git a/griptape/artifacts/base_system_artifact.py b/griptape/artifacts/base_system_artifact.py new file mode 100644 index 000000000..e5e7bfab6 --- /dev/null +++ b/griptape/artifacts/base_system_artifact.py @@ -0,0 +1,10 @@ +from abc import ABC + +from griptape.artifacts import BaseArtifact + + +class BaseSystemArtifact(BaseArtifact, ABC): + """Serves as the base class for all Artifacts specific to Griptape.""" + + def to_text(self) -> str: + return self.value diff --git a/griptape/artifacts/blob_artifact.py b/griptape/artifacts/blob_artifact.py index 7c814a052..1d20acadb 100644 --- a/griptape/artifacts/blob_artifact.py +++ b/griptape/artifacts/blob_artifact.py @@ -7,6 +7,13 @@ from griptape.artifacts import BaseArtifact +def value_to_bytes(value: Any) -> bytes: + if isinstance(value, bytes): + return value + else: + return str(value).encode() + + @define class BlobArtifact(BaseArtifact): """Stores arbitrary binary data. diff --git a/griptape/artifacts/error_artifact.py b/griptape/artifacts/error_artifact.py index 27e6a37ab..a3a8e3bb7 100644 --- a/griptape/artifacts/error_artifact.py +++ b/griptape/artifacts/error_artifact.py @@ -4,7 +4,7 @@ from attrs import define, field -from griptape.artifacts import BaseArtifact +from griptape.artifacts import BaseSystemArtifact @define diff --git a/griptape/artifacts/info_artifact.py b/griptape/artifacts/info_artifact.py index 3391554e9..44d42658a 100644 --- a/griptape/artifacts/info_artifact.py +++ b/griptape/artifacts/info_artifact.py @@ -2,7 +2,7 @@ from attrs import define, field -from griptape.artifacts import BaseArtifact +from griptape.artifacts import BaseSystemArtifact @define diff --git a/griptape/artifacts/json_artifact.py b/griptape/artifacts/json_artifact.py index 57700afd5..0f788c94f 100644 --- a/griptape/artifacts/json_artifact.py +++ b/griptape/artifacts/json_artifact.py @@ -10,6 +10,13 @@ Json = Union[dict[str, "Json"], list["Json"], str, int, float, bool, None] +def value_to_json(value: Any) -> Json: + if isinstance(value, str): + return json.loads(value) + else: + return json.loads(json.dumps(value)) + + @define class JsonArtifact(BaseArtifact): """Stores JSON data. diff --git a/griptape/artifacts/list_artifact.py b/griptape/artifacts/list_artifact.py index 0e6f81ca5..18f4e184f 100644 --- a/griptape/artifacts/list_artifact.py +++ b/griptape/artifacts/list_artifact.py @@ -4,7 +4,7 @@ from attrs import Attribute, define, field -from griptape.artifacts import BaseArtifact +from griptape.artifacts import BaseArtifact, BaseSystemArtifact if TYPE_CHECKING: from collections.abc import Iterator, Sequence diff --git a/griptape/artifacts/text_artifact.py b/griptape/artifacts/text_artifact.py index 9623c8096..11172e73d 100644 --- a/griptape/artifacts/text_artifact.py +++ b/griptape/artifacts/text_artifact.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Any, Optional from attrs import define, field @@ -11,6 +11,13 @@ from griptape.tokenizers import BaseTokenizer +def value_to_str(value: Any) -> str: + if isinstance(value, dict): + return "\n".join(f"{key}: {val}" for key, val in value.items()) + else: + return str(value) + + @define class TextArtifact(BaseArtifact): value: str = field(converter=str, metadata={"serializable": True}) diff --git a/tests/unit/tasks/test_prompt_task.py b/tests/unit/tasks/test_prompt_task.py index cfe853226..2d4456ce2 100644 --- a/tests/unit/tasks/test_prompt_task.py +++ b/tests/unit/tasks/test_prompt_task.py @@ -114,9 +114,9 @@ def test_input(self): assert task.input.value[1].width == 100 # default case - task = PromptTask({"default": "test"}) + task = PromptTask({"default": "test"}) # pyright: ignore[reportArgumentType] - assert task.input.value == str({"default": "test"}) + assert task.input.value == "default: test" def test_prompt_stack(self): task = PromptTask("{{ test }}", context={"test": "test value"}, rules=[Rule("test rule")]) From b562da09b498bf7dd38dcd362e2b365c19de0481 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 5 Sep 2024 13:26:13 -0700 Subject: [PATCH 21/52] Artifacts are just data --- griptape/artifacts/__init__.py | 11 +++-------- griptape/artifacts/base_system_artifact.py | 10 ---------- griptape/artifacts/blob_artifact.py | 4 ++++ griptape/artifacts/error_artifact.py | 2 +- griptape/artifacts/info_artifact.py | 2 +- griptape/artifacts/list_artifact.py | 2 +- 6 files changed, 10 insertions(+), 21 deletions(-) delete mode 100644 griptape/artifacts/base_system_artifact.py diff --git a/griptape/artifacts/__init__.py b/griptape/artifacts/__init__.py index 8acf2642b..0e58a8a76 100644 --- a/griptape/artifacts/__init__.py +++ b/griptape/artifacts/__init__.py @@ -1,24 +1,19 @@ from .base_artifact import BaseArtifact -from .base_system_artifact import BaseSystemArtifact - +from .error_artifact import ErrorArtifact +from .info_artifact import InfoArtifact from .text_artifact import TextArtifact +from .json_artifact import JsonArtifact from .blob_artifact import BlobArtifact from .boolean_artifact import BooleanArtifact from .list_artifact import ListArtifact from .image_artifact import ImageArtifact from .audio_artifact import AudioArtifact -from .json_artifact import JsonArtifact from .action_artifact import ActionArtifact from .generic_artifact import GenericArtifact -from .error_artifact import ErrorArtifact -from .info_artifact import InfoArtifact -from .list_artifact import ListArtifact - __all__ = [ "BaseArtifact", - "BaseSystemArtifact", "ErrorArtifact", "InfoArtifact", "TextArtifact", diff --git a/griptape/artifacts/base_system_artifact.py b/griptape/artifacts/base_system_artifact.py deleted file mode 100644 index e5e7bfab6..000000000 --- a/griptape/artifacts/base_system_artifact.py +++ /dev/null @@ -1,10 +0,0 @@ -from abc import ABC - -from griptape.artifacts import BaseArtifact - - -class BaseSystemArtifact(BaseArtifact, ABC): - """Serves as the base class for all Artifacts specific to Griptape.""" - - def to_text(self) -> str: - return self.value diff --git a/griptape/artifacts/blob_artifact.py b/griptape/artifacts/blob_artifact.py index 1d20acadb..7f7e4915a 100644 --- a/griptape/artifacts/blob_artifact.py +++ b/griptape/artifacts/blob_artifact.py @@ -35,6 +35,10 @@ def base64(self) -> str: def mime_type(self) -> str: return "application/octet-stream" + @property + def mime_type(self) -> str: + return "application/octet-stream" + def to_bytes(self) -> bytes: return self.value diff --git a/griptape/artifacts/error_artifact.py b/griptape/artifacts/error_artifact.py index a3a8e3bb7..27e6a37ab 100644 --- a/griptape/artifacts/error_artifact.py +++ b/griptape/artifacts/error_artifact.py @@ -4,7 +4,7 @@ from attrs import define, field -from griptape.artifacts import BaseSystemArtifact +from griptape.artifacts import BaseArtifact @define diff --git a/griptape/artifacts/info_artifact.py b/griptape/artifacts/info_artifact.py index 44d42658a..3391554e9 100644 --- a/griptape/artifacts/info_artifact.py +++ b/griptape/artifacts/info_artifact.py @@ -2,7 +2,7 @@ from attrs import define, field -from griptape.artifacts import BaseSystemArtifact +from griptape.artifacts import BaseArtifact @define diff --git a/griptape/artifacts/list_artifact.py b/griptape/artifacts/list_artifact.py index 18f4e184f..0e6f81ca5 100644 --- a/griptape/artifacts/list_artifact.py +++ b/griptape/artifacts/list_artifact.py @@ -4,7 +4,7 @@ from attrs import Attribute, define, field -from griptape.artifacts import BaseArtifact, BaseSystemArtifact +from griptape.artifacts import BaseArtifact if TYPE_CHECKING: from collections.abc import Iterator, Sequence From 2cba4eda3286c357766c98b9ffec5a23ce5be797 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 21 Aug 2024 09:23:51 -0700 Subject: [PATCH 22/52] Improve ListArtifact Make type covariant --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4516c59b6..07c73a21a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -48,6 +48,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: Removed the `__all__` declaration from the `griptape.mixins` module. - `Defaults.drivers_config.conversation_memory_driver` now defaults to `LocalConversationMemoryDriver` instead of `None`. - `CsvRowArtifact.to_text()` now includes the header. +- `BaseConversationMemory.prompt_driver` for use with autopruning. ### Fixed - Parsing streaming response with some OpenAI compatible services. From 72b61330ee614cc2539d0d878e7f63de04606f4a Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 6 Sep 2024 10:02:52 -0700 Subject: [PATCH 23/52] Bring back CsvRowArtifact --- .ignore | 0 CHANGELOG.md | 1 - griptape/artifacts/__init__.py | 2 ++ griptape/artifacts/text_artifact.py | 9 +-------- griptape/engines/extraction/csv_extraction_engine.py | 8 ++++---- griptape/loaders/csv_loader.py | 8 ++++---- griptape/loaders/dataframe_loader.py | 8 ++++---- griptape/loaders/sql_loader.py | 6 +++--- tests/unit/tasks/test_prompt_task.py | 4 ++-- 9 files changed, 20 insertions(+), 26 deletions(-) delete mode 100644 .ignore diff --git a/.ignore b/.ignore deleted file mode 100644 index e69de29bb..000000000 diff --git a/CHANGELOG.md b/CHANGELOG.md index 07c73a21a..4516c59b6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -48,7 +48,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: Removed the `__all__` declaration from the `griptape.mixins` module. - `Defaults.drivers_config.conversation_memory_driver` now defaults to `LocalConversationMemoryDriver` instead of `None`. - `CsvRowArtifact.to_text()` now includes the header. -- `BaseConversationMemory.prompt_driver` for use with autopruning. ### Fixed - Parsing streaming response with some OpenAI compatible services. diff --git a/griptape/artifacts/__init__.py b/griptape/artifacts/__init__.py index 0e58a8a76..c0e5f767e 100644 --- a/griptape/artifacts/__init__.py +++ b/griptape/artifacts/__init__.py @@ -5,6 +5,7 @@ from .json_artifact import JsonArtifact from .blob_artifact import BlobArtifact from .boolean_artifact import BooleanArtifact +from .csv_row_artifact import CsvRowArtifact from .list_artifact import ListArtifact from .image_artifact import ImageArtifact from .audio_artifact import AudioArtifact @@ -20,6 +21,7 @@ "JsonArtifact", "BlobArtifact", "BooleanArtifact", + "CsvRowArtifact", "ListArtifact", "ImageArtifact", "AudioArtifact", diff --git a/griptape/artifacts/text_artifact.py b/griptape/artifacts/text_artifact.py index 11172e73d..9623c8096 100644 --- a/griptape/artifacts/text_artifact.py +++ b/griptape/artifacts/text_artifact.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Optional from attrs import define, field @@ -11,13 +11,6 @@ from griptape.tokenizers import BaseTokenizer -def value_to_str(value: Any) -> str: - if isinstance(value, dict): - return "\n".join(f"{key}: {val}" for key, val in value.items()) - else: - return str(value) - - @define class TextArtifact(BaseArtifact): value: str = field(converter=str, metadata={"serializable": True}) diff --git a/griptape/engines/extraction/csv_extraction_engine.py b/griptape/engines/extraction/csv_extraction_engine.py index e7be73c73..973171c5f 100644 --- a/griptape/engines/extraction/csv_extraction_engine.py +++ b/griptape/engines/extraction/csv_extraction_engine.py @@ -6,7 +6,7 @@ from attrs import Factory, define, field -from griptape.artifacts import ListArtifact, TextArtifact +from griptape.artifacts import CsvRowArtifact, ListArtifact, TextArtifact from griptape.common import Message, PromptStack from griptape.engines import BaseExtractionEngine from griptape.utils import J2 @@ -40,7 +40,7 @@ def extract( item_separator="\n", ) - def text_to_csv_rows(self, text: str, column_names: list[str]) -> list[TextArtifact]: + def text_to_csv_rows(self, text: str, column_names: list[str]) -> list[CsvRowArtifact]: rows = [] with io.StringIO(text) as f: @@ -52,10 +52,10 @@ def text_to_csv_rows(self, text: str, column_names: list[str]) -> list[TextArtif def _extract_rec( self, artifacts: list[TextArtifact], - rows: list[TextArtifact], + rows: list[CsvRowArtifact], *, rulesets: Optional[list[Ruleset]] = None, - ) -> list[TextArtifact]: + ) -> list[CsvRowArtifact]: artifacts_text = self.chunk_joiner.join([a.value for a in artifacts]) system_prompt = self.system_template_generator.render( column_names=self.column_names, diff --git a/griptape/loaders/csv_loader.py b/griptape/loaders/csv_loader.py index bcf7029d4..cd0302586 100644 --- a/griptape/loaders/csv_loader.py +++ b/griptape/loaders/csv_loader.py @@ -6,7 +6,7 @@ from attrs import define, field -from griptape.artifacts import TextArtifact +from griptape.artifacts import CsvRowArtifact from griptape.loaders import BaseLoader if TYPE_CHECKING: @@ -22,7 +22,7 @@ class CsvLoader(BaseLoader): default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True ) - def load(self, source: bytes | str, *args, **kwargs) -> list[TextArtifact]: + def load(self, source: bytes | str, *args, **kwargs) -> list[CsvRowArtifact]: artifacts = [] if isinstance(source, bytes): @@ -47,8 +47,8 @@ def load_collection( sources: list[bytes | str], *args, **kwargs, - ) -> dict[str, list[TextArtifact]]: + ) -> dict[str, list[CsvRowArtifact]]: return cast( - dict[str, list[TextArtifact]], + dict[str, list[CsvRowArtifact]], super().load_collection(sources, *args, **kwargs), ) diff --git a/griptape/loaders/dataframe_loader.py b/griptape/loaders/dataframe_loader.py index 30d705676..9131a7682 100644 --- a/griptape/loaders/dataframe_loader.py +++ b/griptape/loaders/dataframe_loader.py @@ -4,7 +4,7 @@ from attrs import define, field -from griptape.artifacts import TextArtifact +from griptape.artifacts import CsvRowArtifact from griptape.loaders import BaseLoader from griptape.utils import import_optional_dependency from griptape.utils.hash import str_to_hash @@ -22,7 +22,7 @@ class DataFrameLoader(BaseLoader): default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True ) - def load(self, source: DataFrame, *args, **kwargs) -> list[TextArtifact]: + def load(self, source: DataFrame, *args, **kwargs) -> list[CsvRowArtifact]: artifacts = [] chunks = [TextArtifact(self.formatter_fn(row)) for row in source.to_dict(orient="records")] @@ -36,8 +36,8 @@ def load(self, source: DataFrame, *args, **kwargs) -> list[TextArtifact]: return artifacts - def load_collection(self, sources: list[DataFrame], *args, **kwargs) -> dict[str, list[TextArtifact]]: - return cast(dict[str, list[TextArtifact]], super().load_collection(sources, *args, **kwargs)) + def load_collection(self, sources: list[DataFrame], *args, **kwargs) -> dict[str, list[CsvRowArtifact]]: + return cast(dict[str, list[CsvRowArtifact]], super().load_collection(sources, *args, **kwargs)) def to_key(self, source: DataFrame, *args, **kwargs) -> str: hash_pandas_object = import_optional_dependency("pandas.core.util.hashing").hash_pandas_object diff --git a/griptape/loaders/sql_loader.py b/griptape/loaders/sql_loader.py index 105f585cb..9cff7782b 100644 --- a/griptape/loaders/sql_loader.py +++ b/griptape/loaders/sql_loader.py @@ -19,7 +19,7 @@ class SqlLoader(BaseLoader): default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True ) - def load(self, source: str, *args, **kwargs) -> list[TextArtifact]: + def load(self, source: str, *args, **kwargs) -> list[CsvRowArtifact]: rows = self.sql_driver.execute_query(source) artifacts = [] @@ -34,5 +34,5 @@ def load(self, source: str, *args, **kwargs) -> list[TextArtifact]: return artifacts - def load_collection(self, sources: list[str], *args, **kwargs) -> dict[str, list[TextArtifact]]: - return cast(dict[str, list[TextArtifact]], super().load_collection(sources, *args, **kwargs)) + def load_collection(self, sources: list[str], *args, **kwargs) -> dict[str, list[CsvRowArtifact]]: + return cast(dict[str, list[CsvRowArtifact]], super().load_collection(sources, *args, **kwargs)) diff --git a/tests/unit/tasks/test_prompt_task.py b/tests/unit/tasks/test_prompt_task.py index 2d4456ce2..cfe853226 100644 --- a/tests/unit/tasks/test_prompt_task.py +++ b/tests/unit/tasks/test_prompt_task.py @@ -114,9 +114,9 @@ def test_input(self): assert task.input.value[1].width == 100 # default case - task = PromptTask({"default": "test"}) # pyright: ignore[reportArgumentType] + task = PromptTask({"default": "test"}) - assert task.input.value == "default: test" + assert task.input.value == str({"default": "test"}) def test_prompt_stack(self): task = PromptTask("{{ test }}", context={"test": "test value"}, rules=[Rule("test rule")]) From 8a415a23550e72d7f7fced874127fb4f87b17790 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 5 Sep 2024 13:26:13 -0700 Subject: [PATCH 24/52] Artifacts are just data --- griptape/artifacts/text_artifact.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/griptape/artifacts/text_artifact.py b/griptape/artifacts/text_artifact.py index 9623c8096..735256313 100644 --- a/griptape/artifacts/text_artifact.py +++ b/griptape/artifacts/text_artifact.py @@ -33,7 +33,7 @@ def generate_embedding(self, driver: BaseEmbeddingDriver) -> list[float]: self.embedding.clear() self.embedding.extend(embedding) - return self.embedding + return self._embedding def token_count(self, tokenizer: BaseTokenizer) -> int: return tokenizer.count_tokens(str(self.value)) From ad19c5aab0c089672742ec165cb041507a12eb63 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 27 Aug 2024 11:22:51 -0700 Subject: [PATCH 25/52] Refactor Artifacts --- griptape/artifacts/__init__.py | 11 ++++++++--- griptape/artifacts/base_system_artifact.py | 10 ++++++++++ griptape/artifacts/error_artifact.py | 2 +- griptape/artifacts/info_artifact.py | 2 +- griptape/artifacts/list_artifact.py | 2 +- 5 files changed, 21 insertions(+), 6 deletions(-) create mode 100644 griptape/artifacts/base_system_artifact.py diff --git a/griptape/artifacts/__init__.py b/griptape/artifacts/__init__.py index c0e5f767e..138ce43fc 100644 --- a/griptape/artifacts/__init__.py +++ b/griptape/artifacts/__init__.py @@ -1,20 +1,25 @@ from .base_artifact import BaseArtifact -from .error_artifact import ErrorArtifact -from .info_artifact import InfoArtifact +from .base_system_artifact import BaseSystemArtifact + from .text_artifact import TextArtifact -from .json_artifact import JsonArtifact from .blob_artifact import BlobArtifact from .boolean_artifact import BooleanArtifact from .csv_row_artifact import CsvRowArtifact from .list_artifact import ListArtifact from .image_artifact import ImageArtifact from .audio_artifact import AudioArtifact +from .json_artifact import JsonArtifact from .action_artifact import ActionArtifact from .generic_artifact import GenericArtifact +from .error_artifact import ErrorArtifact +from .info_artifact import InfoArtifact +from .list_artifact import ListArtifact + __all__ = [ "BaseArtifact", + "BaseSystemArtifact", "ErrorArtifact", "InfoArtifact", "TextArtifact", diff --git a/griptape/artifacts/base_system_artifact.py b/griptape/artifacts/base_system_artifact.py new file mode 100644 index 000000000..e5e7bfab6 --- /dev/null +++ b/griptape/artifacts/base_system_artifact.py @@ -0,0 +1,10 @@ +from abc import ABC + +from griptape.artifacts import BaseArtifact + + +class BaseSystemArtifact(BaseArtifact, ABC): + """Serves as the base class for all Artifacts specific to Griptape.""" + + def to_text(self) -> str: + return self.value diff --git a/griptape/artifacts/error_artifact.py b/griptape/artifacts/error_artifact.py index 27e6a37ab..a3a8e3bb7 100644 --- a/griptape/artifacts/error_artifact.py +++ b/griptape/artifacts/error_artifact.py @@ -4,7 +4,7 @@ from attrs import define, field -from griptape.artifacts import BaseArtifact +from griptape.artifacts import BaseSystemArtifact @define diff --git a/griptape/artifacts/info_artifact.py b/griptape/artifacts/info_artifact.py index 3391554e9..44d42658a 100644 --- a/griptape/artifacts/info_artifact.py +++ b/griptape/artifacts/info_artifact.py @@ -2,7 +2,7 @@ from attrs import define, field -from griptape.artifacts import BaseArtifact +from griptape.artifacts import BaseSystemArtifact @define diff --git a/griptape/artifacts/list_artifact.py b/griptape/artifacts/list_artifact.py index 0e6f81ca5..18f4e184f 100644 --- a/griptape/artifacts/list_artifact.py +++ b/griptape/artifacts/list_artifact.py @@ -4,7 +4,7 @@ from attrs import Attribute, define, field -from griptape.artifacts import BaseArtifact +from griptape.artifacts import BaseArtifact, BaseSystemArtifact if TYPE_CHECKING: from collections.abc import Iterator, Sequence From 06d40a72a7799ce12d27c22c665e21c164baa443 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 27 Aug 2024 11:22:51 -0700 Subject: [PATCH 26/52] Standardize loaders interface --- .../src/load_query_and_chat_marqo_1.py | 4 +- docs/examples/src/query_webpage_1.py | 6 +- docs/examples/src/query_webpage_astra_db_1.py | 7 +- docs/examples/src/talk_to_a_pdf_1.py | 4 +- docs/examples/src/talk_to_a_webpage_1.py | 4 +- docs/griptape-framework/data/loaders.md | 13 ---- docs/griptape-framework/data/src/loaders_1.py | 11 ++-- .../griptape-framework/data/src/loaders_10.py | 7 +- docs/griptape-framework/data/src/loaders_3.py | 13 ++-- docs/griptape-framework/data/src/loaders_4.py | 13 ---- docs/griptape-framework/data/src/loaders_7.py | 8 +-- docs/griptape-framework/data/src/loaders_8.py | 13 ++-- .../drivers/src/vector_store_drivers_1.py | 6 +- .../drivers/src/vector_store_drivers_10.py | 6 +- .../drivers/src/vector_store_drivers_11.py | 6 +- .../drivers/src/vector_store_drivers_3.py | 6 +- .../drivers/src/vector_store_drivers_4.py | 6 +- .../drivers/src/vector_store_drivers_5.py | 10 ++- .../drivers/src/vector_store_drivers_6.py | 6 +- .../drivers/src/vector_store_drivers_7.py | 6 +- .../drivers/src/vector_store_drivers_8.py | 6 +- .../drivers/src/vector_store_drivers_9.py | 6 +- .../engines/src/audio_engines_2.py | 3 +- .../engines/src/rag_engines_1.py | 7 +- .../engines/src/summary_engines_1.py | 6 +- .../structures/src/tasks_18.py | 3 +- .../official-tools/src/vector_store_tool_1.py | 4 +- griptape/artifacts/__init__.py | 3 + griptape/artifacts/list_artifact.py | 2 + griptape/artifacts/table_artifact.py | 41 ++++++++++++ .../file_manager/base_file_manager_driver.py | 51 ++++----------- .../file_manager/local_file_manager_driver.py | 8 +-- .../web_scraper/base_web_scraper_driver.py | 10 ++- .../markdownify_web_scraper_driver.py | 65 ++++++++++--------- .../web_scraper/proxy_web_scraper_driver.py | 8 ++- .../trafilatura_web_scraper_driver.py | 11 +++- .../text_loader_retrieval_rag_module.py | 9 ++- griptape/loaders/__init__.py | 12 ++-- griptape/loaders/audio_loader.py | 16 ++--- griptape/loaders/base_file_loader.py | 24 +++++++ griptape/loaders/base_loader.py | 21 ++++-- griptape/loaders/base_text_loader.py | 65 ------------------- griptape/loaders/blob_loader.py | 10 +-- griptape/loaders/csv_loader.py | 42 ++---------- griptape/loaders/dataframe_loader.py | 45 ------------- griptape/loaders/email_loader.py | 35 ++++++---- griptape/loaders/image_loader.py | 31 ++------- griptape/loaders/pdf_loader.py | 30 ++++----- griptape/loaders/sql_loader.py | 31 +++------ griptape/loaders/text_loader.py | 55 ++++------------ griptape/loaders/web_loader.py | 21 +++--- griptape/tools/file_manager/tool.py | 23 ++++++- griptape/tools/sql/tool.py | 2 +- griptape/tools/web_scraper/tool.py | 6 +- griptape/utils/__init__.py | 5 +- griptape/utils/file_utils.py | 46 ++++--------- poetry.lock | 25 +++---- pyproject.toml | 5 +- .../test_amazon_s3_file_manager_driver.py | 29 +++------ .../test_local_file_manager_driver.py | 35 +++------- ...er.py => test_base_vector_store_driver.py} | 5 +- .../vector/test_local_vector_store_driver.py | 4 +- ...st_persistent_local_vector_store_driver.py | 4 +- .../test_text_loader_retrieval_rag_module.py | 2 +- tests/unit/loaders/conftest.py | 9 +-- tests/unit/loaders/test_audio_loader.py | 7 +- tests/unit/loaders/test_blob_loader.py | 2 +- tests/unit/loaders/test_csv_loader.py | 22 ++----- tests/unit/loaders/test_email_loader.py | 24 +++---- tests/unit/loaders/test_image_loader.py | 6 +- tests/unit/loaders/test_pdf_loader.py | 23 +++---- tests/unit/loaders/test_sql_loader.py | 15 ++--- tests/unit/loaders/test_text_loader.py | 22 ++----- tests/unit/loaders/test_web_loader.py | 14 ++-- tests/unit/tools/test_file_manager.py | 20 ++---- tests/unit/utils/test_file_utils.py | 51 --------------- 76 files changed, 484 insertions(+), 758 deletions(-) delete mode 100644 docs/griptape-framework/data/src/loaders_4.py create mode 100644 griptape/artifacts/table_artifact.py create mode 100644 griptape/loaders/base_file_loader.py delete mode 100644 griptape/loaders/base_text_loader.py delete mode 100644 griptape/loaders/dataframe_loader.py rename tests/unit/drivers/vector/{test_base_local_vector_store_driver.py => test_base_vector_store_driver.py} (95%) delete mode 100644 tests/unit/utils/test_file_utils.py diff --git a/docs/examples/src/load_query_and_chat_marqo_1.py b/docs/examples/src/load_query_and_chat_marqo_1.py index cdcb376bb..f318abb20 100644 --- a/docs/examples/src/load_query_and_chat_marqo_1.py +++ b/docs/examples/src/load_query_and_chat_marqo_1.py @@ -1,6 +1,7 @@ import os from griptape import utils +from griptape.chunkers import TextChunker from griptape.drivers import MarqoVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader from griptape.structures import Agent @@ -25,11 +26,12 @@ # Load artifacts from the web artifacts = WebLoader().load("https://www.griptape.ai") +chunks = TextChunker().chunk(artifacts) # Upsert the artifacts into the vector store vector_store.upsert_text_artifacts( { - namespace: artifacts, + namespace: chunks, } ) diff --git a/docs/examples/src/query_webpage_1.py b/docs/examples/src/query_webpage_1.py index b9e3286d6..9cc033566 100644 --- a/docs/examples/src/query_webpage_1.py +++ b/docs/examples/src/query_webpage_1.py @@ -1,13 +1,15 @@ import os +from griptape.chunkers import TextChunker from griptape.drivers import LocalVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader vector_store = LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver(api_key=os.environ["OPENAI_API_KEY"])) -artifacts = WebLoader(max_tokens=100).load("https://www.griptape.ai") +artifacts = WebLoader().load("https://www.griptape.ai") +chunks = TextChunker().chunk(artifacts) -for a in artifacts: +for a in chunks: vector_store.upsert_text_artifact(a, namespace="griptape") results = vector_store.query("creativity", count=3, namespace="griptape") diff --git a/docs/examples/src/query_webpage_astra_db_1.py b/docs/examples/src/query_webpage_astra_db_1.py index 4590a6b59..aaf5e2fcc 100644 --- a/docs/examples/src/query_webpage_astra_db_1.py +++ b/docs/examples/src/query_webpage_astra_db_1.py @@ -1,5 +1,6 @@ import os +from griptape.chunkers import TextChunker from griptape.drivers import ( AstraDbVectorStoreDriver, OpenAiChatPromptDriver, @@ -43,9 +44,9 @@ ), ) -artifacts = WebLoader(max_tokens=256).load(input_blogpost) - -vector_store_driver.upsert_text_artifacts({namespace: artifacts}) +artifacts = WebLoader().load(input_blogpost) +chunks = TextChunker().chunk(artifacts) +vector_store_driver.upsert_text_artifacts({namespace: chunks}) rag_tool = RagTool( description="A DataStax blog post", diff --git a/docs/examples/src/talk_to_a_pdf_1.py b/docs/examples/src/talk_to_a_pdf_1.py index 3c29f4c74..aa2b1f153 100644 --- a/docs/examples/src/talk_to_a_pdf_1.py +++ b/docs/examples/src/talk_to_a_pdf_1.py @@ -1,5 +1,6 @@ import requests +from griptape.chunkers import TextChunker from griptape.drivers import LocalVectorStoreDriver, OpenAiChatPromptDriver, OpenAiEmbeddingDriver from griptape.engines.rag import RagEngine from griptape.engines.rag.modules import PromptResponseRagModule, VectorStoreRetrievalRagModule @@ -31,8 +32,9 @@ ) artifacts = PdfLoader().load(response.content) +chunks = TextChunker().chunk(artifacts) -vector_store.upsert_text_artifacts({namespace: artifacts}) +vector_store.upsert_text_artifacts({namespace: chunks}) agent = Agent(tools=[rag_tool]) diff --git a/docs/examples/src/talk_to_a_webpage_1.py b/docs/examples/src/talk_to_a_webpage_1.py index 3e973da2d..5414c8769 100644 --- a/docs/examples/src/talk_to_a_webpage_1.py +++ b/docs/examples/src/talk_to_a_webpage_1.py @@ -1,3 +1,4 @@ +from griptape.chunkers import TextChunker from griptape.drivers import LocalVectorStoreDriver, OpenAiChatPromptDriver, OpenAiEmbeddingDriver from griptape.engines.rag import RagEngine from griptape.engines.rag.modules import PromptResponseRagModule, VectorStoreRetrievalRagModule @@ -26,8 +27,9 @@ ) artifacts = WebLoader().load("https://en.wikipedia.org/wiki/Physics") +chunks = TextChunker().chunk(artifacts) -vector_store_driver.upsert_text_artifacts({namespace: artifacts}) +vector_store_driver.upsert_text_artifacts({namespace: chunks}) rag_tool = RagTool( description="Contains information about physics. " "Use it to answer any physics-related questions.", diff --git a/docs/griptape-framework/data/loaders.md b/docs/griptape-framework/data/loaders.md index 0c0fc3ead..f1826e67a 100644 --- a/docs/griptape-framework/data/loaders.md +++ b/docs/griptape-framework/data/loaders.md @@ -36,19 +36,6 @@ Can be used to load CSV files into [TextArtifact](../../reference/griptape/artif --8<-- "docs/griptape-framework/data/src/loaders_3.py" ``` - -## DataFrame - -!!! info - This driver requires the `loaders-dataframe` [extra](../index.md#extras). - -Can be used to load [pandas](https://pandas.pydata.org/) [DataFrame](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.html)s into [TextArtifact](../../reference/griptape/artifacts/text_artifact.md)s: - -```python ---8<-- "docs/griptape-framework/data/src/loaders_4.py" -``` - - ## Text Used to load arbitrary text and text files: diff --git a/docs/griptape-framework/data/src/loaders_1.py b/docs/griptape-framework/data/src/loaders_1.py index 2b7f31613..3732d8ac6 100644 --- a/docs/griptape-framework/data/src/loaders_1.py +++ b/docs/griptape-framework/data/src/loaders_1.py @@ -2,18 +2,15 @@ from pathlib import Path from griptape.loaders import PdfLoader -from griptape.utils import load_file, load_files urllib.request.urlretrieve("https://arxiv.org/pdf/1706.03762.pdf", "attention.pdf") # Load a single PDF file -PdfLoader().load(Path("attention.pdf").read_bytes()) -# You can also use the load_file utility function -PdfLoader().load(load_file("attention.pdf")) +PdfLoader().load("attention.pdf") +# You can also pass a Path object +PdfLoader().load(Path("attention.pdf")) urllib.request.urlretrieve("https://arxiv.org/pdf/1706.03762.pdf", "CoT.pdf") # Load multiple PDF files -PdfLoader().load_collection([Path("attention.pdf").read_bytes(), Path("CoT.pdf").read_bytes()]) -# You can also use the load_files utility function -PdfLoader().load_collection(list(load_files(["attention.pdf", "CoT.pdf"]).values())) +PdfLoader().load_collection([Path("attention.pdf"), Path("CoT.pdf")]) diff --git a/docs/griptape-framework/data/src/loaders_10.py b/docs/griptape-framework/data/src/loaders_10.py index 42a64c40e..723fbcdf1 100644 --- a/docs/griptape-framework/data/src/loaders_10.py +++ b/docs/griptape-framework/data/src/loaders_10.py @@ -1,10 +1,9 @@ from pathlib import Path from griptape.loaders import AudioLoader -from griptape.utils import load_file # Load an image from disk -audio_artifact = AudioLoader().load(Path("tests/resources/sentences.wav").read_bytes()) +AudioLoader().load("tests/resources/sentences.wav") -# You can also use the load_file utility function -AudioLoader().load(load_file("tests/resources/sentences.wav")) +# You can also pass a Path object +AudioLoader().load(Path("tests/resources/sentences.wav")) diff --git a/docs/griptape-framework/data/src/loaders_3.py b/docs/griptape-framework/data/src/loaders_3.py index 35af0fdfc..3bc3ceb81 100644 --- a/docs/griptape-framework/data/src/loaders_3.py +++ b/docs/griptape-framework/data/src/loaders_3.py @@ -1,16 +1,11 @@ from pathlib import Path from griptape.loaders import CsvLoader -from griptape.utils import load_file, load_files # Load a single CSV file -CsvLoader().load(Path("tests/resources/cities.csv").read_text()) -# You can also use the load_file utility function -CsvLoader().load(load_file("tests/resources/cities.csv")) +CsvLoader().load("tests/resources/cities.csv") +# You can also pass a Path object +CsvLoader().load(Path("tests/resources/cities.csv")) # Load multiple CSV files -CsvLoader().load_collection( - [Path("tests/resources/cities.csv").read_text(), Path("tests/resources/addresses.csv").read_text()] -) -# You can also use the load_files utility function -CsvLoader().load_collection(list(load_files(["tests/resources/cities.csv", "tests/resources/addresses.csv"]).values())) +CsvLoader().load_collection([Path("tests/resources/cities.csv"), "tests/resources/addresses.csv"]) diff --git a/docs/griptape-framework/data/src/loaders_4.py b/docs/griptape-framework/data/src/loaders_4.py deleted file mode 100644 index 8d5883adf..000000000 --- a/docs/griptape-framework/data/src/loaders_4.py +++ /dev/null @@ -1,13 +0,0 @@ -import urllib.request - -import pandas as pd - -from griptape.loaders import DataFrameLoader - -urllib.request.urlretrieve("https://people.sc.fsu.edu/~jburkardt/data/csv/cities.csv", "cities.csv") - -DataFrameLoader().load(pd.read_csv("cities.csv")) - -urllib.request.urlretrieve("https://people.sc.fsu.edu/~jburkardt/data/csv/addresses.csv", "addresses.csv") - -DataFrameLoader().load_collection([pd.read_csv("cities.csv"), pd.read_csv("addresses.csv")]) diff --git a/docs/griptape-framework/data/src/loaders_7.py b/docs/griptape-framework/data/src/loaders_7.py index 6857886e8..177471fd2 100644 --- a/docs/griptape-framework/data/src/loaders_7.py +++ b/docs/griptape-framework/data/src/loaders_7.py @@ -1,9 +1,9 @@ from pathlib import Path from griptape.loaders import ImageLoader -from griptape.utils import load_file # Load an image from disk -disk_image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) -# You can also use the load_file utility function -ImageLoader().load(load_file("tests/resources/mountain.png")) +ImageLoader().load("tests/resources/mountain.png") + +# You can also pass a Path object +ImageLoader().load(Path("tests/resources/mountain.png")) diff --git a/docs/griptape-framework/data/src/loaders_8.py b/docs/griptape-framework/data/src/loaders_8.py index e85992d45..d54c31246 100644 --- a/docs/griptape-framework/data/src/loaders_8.py +++ b/docs/griptape-framework/data/src/loaders_8.py @@ -1,16 +1,11 @@ from pathlib import Path from griptape.loaders import ImageLoader -from griptape.utils import load_file, load_files # Load a single image in BMP format -image_artifact_jpeg = ImageLoader(format="bmp").load(Path("tests/resources/mountain.png").read_bytes()) -# You can also use the load_file utility function -ImageLoader(format="bmp").load(load_file("tests/resources/mountain.png")) +ImageLoader(format="bmp").load("tests/resources/mountain.png") +# You can also pass a Path object +ImageLoader(format="bmp").load(Path("tests/resources/mountain.png")) # Load multiple images in BMP format -ImageLoader().load_collection( - [Path("tests/resources/mountain.png").read_bytes(), Path("tests/resources/cow.png").read_bytes()] -) -# You can also use the load_files utility function -ImageLoader().load_collection(list(load_files(["tests/resources/mountain.png", "tests/resources/cow.png"]).values())) +ImageLoader().load_collection([Path("tests/resources/mountain.png"), "tests/resources/cow.png"]) diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_1.py b/docs/griptape-framework/drivers/src/vector_store_drivers_1.py index 7f7e98e13..33f611a71 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_1.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_1.py @@ -1,5 +1,6 @@ import os +from griptape.chunkers import TextChunker from griptape.drivers import LocalVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader @@ -9,10 +10,11 @@ vector_store_driver = LocalVectorStoreDriver(embedding_driver=embedding_driver) # Load Artifacts from the web -artifacts = WebLoader(max_tokens=100).load("https://www.griptape.ai") +artifact = WebLoader().load("https://www.griptape.ai") +chunks = TextChunker(max_tokens=100).chunk(artifact) # Upsert Artifacts into the Vector Store Driver -[vector_store_driver.upsert_text_artifact(a, namespace="griptape") for a in artifacts] +[vector_store_driver.upsert_text_artifact(a, namespace="griptape") for a in chunks] results = vector_store_driver.query(query="What is griptape?") diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_10.py b/docs/griptape-framework/drivers/src/vector_store_drivers_10.py index 39a21121d..f82294913 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_10.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_10.py @@ -1,5 +1,6 @@ import os +from griptape.chunkers import TextChunker from griptape.drivers import OpenAiEmbeddingDriver, QdrantVectorStoreDriver from griptape.loaders import WebLoader @@ -19,7 +20,8 @@ ) # Load Artifacts from the web -artifacts = WebLoader().load("https://www.griptape.ai") +artifact = WebLoader().load("https://www.griptape.ai") +chunks = TextChunker(max_tokens=100).chunk(artifact) # Recreate Qdrant collection vector_store_driver.client.recreate_collection( @@ -28,7 +30,7 @@ ) # Upsert Artifacts into the Vector Store Driver -[vector_store_driver.upsert_text_artifact(a, namespace="griptape") for a in artifacts] +[vector_store_driver.upsert_text_artifact(a, namespace="griptape") for a in chunks] results = vector_store_driver.query(query="What is griptape?") diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_11.py b/docs/griptape-framework/drivers/src/vector_store_drivers_11.py index a8d9ceed1..739a52c63 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_11.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_11.py @@ -1,5 +1,6 @@ import os +from griptape.chunkers import TextChunker from griptape.drivers import AstraDbVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader @@ -20,10 +21,11 @@ ) # Load Artifacts from the web -artifacts = WebLoader().load("https://www.griptape.ai") +artifact = WebLoader().load("https://www.griptape.ai") +chunks = TextChunker().chunk(artifact) # Upsert Artifacts into the Vector Store Driver -[vector_store_driver.upsert_text_artifact(a, namespace="griptape") for a in artifacts] +[vector_store_driver.upsert_text_artifact(a, namespace="griptape") for a in chunks] results = vector_store_driver.query(query="What is griptape?") diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_3.py b/docs/griptape-framework/drivers/src/vector_store_drivers_3.py index 559eaec5a..c8ff0e0b5 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_3.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_3.py @@ -1,5 +1,6 @@ import os +from griptape.chunkers import TextChunker from griptape.drivers import OpenAiEmbeddingDriver, PineconeVectorStoreDriver from griptape.loaders import WebLoader @@ -14,10 +15,11 @@ ) # Load Artifacts from the web -artifacts = WebLoader(max_tokens=100).load("https://www.griptape.ai") +artifact = WebLoader().load("https://www.griptape.ai") +chunks = TextChunker(max_tokens=100).chunk(artifact) # Upsert Artifacts into the Vector Store Driver -[vector_store_driver.upsert_text_artifact(a, namespace="griptape") for a in artifacts] +[vector_store_driver.upsert_text_artifact(a, namespace="griptape") for a in chunks] results = vector_store_driver.query(query="What is griptape?") diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_4.py b/docs/griptape-framework/drivers/src/vector_store_drivers_4.py index f2f0091a0..ad15f0744 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_4.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_4.py @@ -1,5 +1,6 @@ import os +from griptape.chunkers import TextChunker from griptape.drivers import MarqoVectorStoreDriver, OpenAiChatPromptDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader @@ -19,12 +20,13 @@ ) # Load Artifacts from the web -artifacts = WebLoader(max_tokens=200).load("https://www.griptape.ai") +artifact = WebLoader().load("https://www.griptape.ai") +chunks = TextChunker(max_tokens=200).chunk(artifact) # Upsert Artifacts into the Vector Store Driver vector_store_driver.upsert_text_artifacts( { - "griptape": artifacts, + "griptape": chunks, } ) diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_5.py b/docs/griptape-framework/drivers/src/vector_store_drivers_5.py index 7649579c7..33d541f2f 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_5.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_5.py @@ -1,5 +1,6 @@ import os +from griptape.chunkers import TextChunker from griptape.drivers import MongoDbAtlasVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader @@ -25,14 +26,11 @@ ) # Load Artifacts from the web -artifacts = WebLoader(max_tokens=200).load("https://www.griptape.ai") +artifact = WebLoader().load("https://www.griptape.ai") +chunks = TextChunker(max_tokens=200).chunk(artifact) # Upsert Artifacts into the Vector Store Driver -vector_store_driver.upsert_text_artifacts( - { - "griptape": artifacts, - } -) +vector_store_driver.upsert_text_artifacts({"griptape": chunks}) results = vector_store_driver.query(query="What is griptape?") diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_6.py b/docs/griptape-framework/drivers/src/vector_store_drivers_6.py index 78a7cc3e6..6a8a8bb04 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_6.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_6.py @@ -1,5 +1,6 @@ import os +from griptape.chunkers import TextChunker from griptape.drivers import AzureMongoDbVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader @@ -25,12 +26,13 @@ ) # Load Artifacts from the web -artifacts = WebLoader(max_tokens=200).load("https://www.griptape.ai") +artifact = WebLoader().load("https://www.griptape.ai") +chunks = TextChunker(max_tokens=200).chunk(artifact) # Upsert Artifacts into the Vector Store Driver vector_store_driver.upsert_text_artifacts( { - "griptape": artifacts, + "griptape": chunks, } ) diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_7.py b/docs/griptape-framework/drivers/src/vector_store_drivers_7.py index d34ff8649..7d14504bb 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_7.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_7.py @@ -1,5 +1,6 @@ import os +from griptape.chunkers import TextChunker from griptape.drivers import OpenAiEmbeddingDriver, RedisVectorStoreDriver from griptape.loaders import WebLoader @@ -15,12 +16,13 @@ ) # Load Artifacts from the web -artifacts = WebLoader(max_tokens=200).load("https://www.griptape.ai") +artifact = WebLoader().load("https://www.griptape.ai") +chunks = TextChunker(max_tokens=200).chunk(artifact) # Upsert Artifacts into the Vector Store Driver vector_store_driver.upsert_text_artifacts( { - "griptape": artifacts, + "griptape": chunks, } ) diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_8.py b/docs/griptape-framework/drivers/src/vector_store_drivers_8.py index 18e50a397..0ecb49723 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_8.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_8.py @@ -2,6 +2,7 @@ import boto3 +from griptape.chunkers import TextChunker from griptape.drivers import AmazonOpenSearchVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader @@ -16,12 +17,13 @@ ) # Load Artifacts from the web -artifacts = WebLoader(max_tokens=200).load("https://www.griptape.ai") +artifact = WebLoader().load("https://www.griptape.ai") +chunks = TextChunker(max_tokens=200).chunk(artifact) # Upsert Artifacts into the Vector Store Driver vector_store_driver.upsert_text_artifacts( { - "griptape": artifacts, + "griptape": chunks, } ) diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_9.py b/docs/griptape-framework/drivers/src/vector_store_drivers_9.py index ad5abf932..9feb47761 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_9.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_9.py @@ -1,5 +1,6 @@ import os +from griptape.chunkers import TextChunker from griptape.drivers import OpenAiEmbeddingDriver, PgVectorVectorStoreDriver from griptape.loaders import WebLoader @@ -22,12 +23,13 @@ vector_store_driver.setup() # Load Artifacts from the web -artifacts = WebLoader().load("https://www.griptape.ai") +artifact = WebLoader().load("https://www.griptape.ai") +chunks = TextChunker().chunk(artifact) # Upsert Artifacts into the Vector Store Driver vector_store_driver.upsert_text_artifacts( { - "griptape": artifacts, + "griptape": chunks, } ) diff --git a/docs/griptape-framework/engines/src/audio_engines_2.py b/docs/griptape-framework/engines/src/audio_engines_2.py index c04b466f8..92c87d638 100644 --- a/docs/griptape-framework/engines/src/audio_engines_2.py +++ b/docs/griptape-framework/engines/src/audio_engines_2.py @@ -1,7 +1,6 @@ from griptape.drivers import OpenAiAudioTranscriptionDriver from griptape.engines import AudioTranscriptionEngine from griptape.loaders import AudioLoader -from griptape.utils import load_file driver = OpenAiAudioTranscriptionDriver(model="whisper-1") @@ -9,5 +8,5 @@ audio_transcription_driver=driver, ) -audio_artifact = AudioLoader().load(load_file("tests/resources/sentences.wav")) +audio_artifact = AudioLoader().load("tests/resources/sentences.wav") engine.run(audio_artifact) diff --git a/docs/griptape-framework/engines/src/rag_engines_1.py b/docs/griptape-framework/engines/src/rag_engines_1.py index a8a9cc06b..6ad28545f 100644 --- a/docs/griptape-framework/engines/src/rag_engines_1.py +++ b/docs/griptape-framework/engines/src/rag_engines_1.py @@ -1,3 +1,4 @@ +from griptape.chunkers import TextChunker from griptape.drivers import LocalVectorStoreDriver, OpenAiChatPromptDriver, OpenAiEmbeddingDriver from griptape.engines.rag import RagContext, RagEngine from griptape.engines.rag.modules import PromptResponseRagModule, TranslateQueryRagModule, VectorStoreRetrievalRagModule @@ -8,12 +9,12 @@ prompt_driver = OpenAiChatPromptDriver(model="gpt-4o", temperature=0) vector_store = LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver()) -artifacts = WebLoader(max_tokens=500).load("https://www.griptape.ai") - +artifact = WebLoader().load("https://www.griptape.ai") +chunks = TextChunker(max_tokens=500).chunk(artifact) vector_store.upsert_text_artifacts( { - "griptape": artifacts, + "griptape": chunks, } ) diff --git a/docs/griptape-framework/engines/src/summary_engines_1.py b/docs/griptape-framework/engines/src/summary_engines_1.py index b5adf2a5a..201aadfa4 100644 --- a/docs/griptape-framework/engines/src/summary_engines_1.py +++ b/docs/griptape-framework/engines/src/summary_engines_1.py @@ -1,5 +1,6 @@ import requests +from griptape.chunkers import TextChunker from griptape.drivers import OpenAiChatPromptDriver from griptape.engines import PromptSummaryEngine from griptape.loaders import PdfLoader @@ -9,8 +10,9 @@ prompt_driver=OpenAiChatPromptDriver(model="gpt-3.5-turbo"), ) -artifacts = PdfLoader().load(response.content) +artifact = PdfLoader().load(response.content) +chunks = TextChunker().chunk(artifact) -text = "\n\n".join([a.value for a in artifacts]) +text = "\n\n".join([a.value for a in chunks]) engine.summarize_text(text) diff --git a/docs/griptape-framework/structures/src/tasks_18.py b/docs/griptape-framework/structures/src/tasks_18.py index 08ece5a92..0d3312d4c 100644 --- a/docs/griptape-framework/structures/src/tasks_18.py +++ b/docs/griptape-framework/structures/src/tasks_18.py @@ -3,12 +3,11 @@ from griptape.loaders import AudioLoader from griptape.structures import Pipeline from griptape.tasks import AudioTranscriptionTask -from griptape.utils import load_file driver = OpenAiAudioTranscriptionDriver(model="whisper-1") task = AudioTranscriptionTask( - input=lambda _: AudioLoader().load(load_file("tests/resources/sentences2.wav")), + input=lambda _: AudioLoader().load("tests/resources/sentences2.wav"), audio_transcription_engine=AudioTranscriptionEngine( audio_transcription_driver=driver, ), diff --git a/docs/griptape-tools/official-tools/src/vector_store_tool_1.py b/docs/griptape-tools/official-tools/src/vector_store_tool_1.py index 26c87e255..bdb60d98b 100644 --- a/docs/griptape-tools/official-tools/src/vector_store_tool_1.py +++ b/docs/griptape-tools/official-tools/src/vector_store_tool_1.py @@ -1,3 +1,4 @@ +from griptape.chunkers import TextChunker from griptape.drivers import LocalVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader from griptape.structures import Agent @@ -8,8 +9,9 @@ ) artifacts = WebLoader().load("https://www.griptape.ai") +chunks = TextChunker().chunk(artifacts) -vector_store_driver.upsert_text_artifacts({"griptape": artifacts}) +vector_store_driver.upsert_text_artifacts({"griptape": chunks}) vector_db = VectorStoreTool( description="This DB has information about the Griptape Python framework", vector_store_driver=vector_store_driver, diff --git a/griptape/artifacts/__init__.py b/griptape/artifacts/__init__.py index 138ce43fc..30c293b49 100644 --- a/griptape/artifacts/__init__.py +++ b/griptape/artifacts/__init__.py @@ -10,6 +10,8 @@ from .audio_artifact import AudioArtifact from .json_artifact import JsonArtifact from .action_artifact import ActionArtifact +from .table_artifact import TableArtifact + from .generic_artifact import GenericArtifact from .error_artifact import ErrorArtifact @@ -32,4 +34,5 @@ "AudioArtifact", "ActionArtifact", "GenericArtifact", + "TableArtifact", ] diff --git a/griptape/artifacts/list_artifact.py b/griptape/artifacts/list_artifact.py index 18f4e184f..366ae5c3a 100644 --- a/griptape/artifacts/list_artifact.py +++ b/griptape/artifacts/list_artifact.py @@ -11,6 +11,8 @@ T = TypeVar("T", bound=BaseArtifact, covariant=True) + from griptape.artifacts import BaseArtifact + @define class ListArtifact(BaseArtifact, Generic[T]): diff --git a/griptape/artifacts/table_artifact.py b/griptape/artifacts/table_artifact.py new file mode 100644 index 000000000..d3ca003fa --- /dev/null +++ b/griptape/artifacts/table_artifact.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +import csv +import io +from typing import TYPE_CHECKING, Optional + +from attrs import define, field + +from griptape.artifacts.text_artifact import TextArtifact + +if TYPE_CHECKING: + from collections.abc import Sequence + + +@define +class TableArtifact(TextArtifact): + value: list[dict] = field(factory=list, metadata={"serializable": True}) + delimiter: str = field(default=",", kw_only=True, metadata={"serializable": True}) + fieldnames: Optional[Sequence[str]] = field(factory=list, metadata={"serializable": True}) + quoting: int = field(default=csv.QUOTE_MINIMAL, kw_only=True, metadata={"serializable": True}) + line_terminator: str = field(default="\n", kw_only=True, metadata={"serializable": True}) + + def __bool__(self) -> bool: + return len(self.value) > 0 + + def to_text(self) -> str: + with io.StringIO() as csvfile: + fieldnames = (self.value[0].keys() if self.value else []) if self.fieldnames is None else self.fieldnames + + writer = csv.DictWriter( + csvfile, + fieldnames=fieldnames, + quoting=self.quoting, + delimiter=self.delimiter, + lineterminator=self.line_terminator, + ) + + writer.writeheader() + writer.writerows(self.value) + + return csvfile.getvalue().strip() diff --git a/griptape/drivers/file_manager/base_file_manager_driver.py b/griptape/drivers/file_manager/base_file_manager_driver.py index dce538812..c904f1532 100644 --- a/griptape/drivers/file_manager/base_file_manager_driver.py +++ b/griptape/drivers/file_manager/base_file_manager_driver.py @@ -1,11 +1,11 @@ from __future__ import annotations from abc import ABC, abstractmethod +from typing import Optional -from attrs import Factory, define, field +from attrs import define, field -import griptape.loaders as loaders -from griptape.artifacts import BaseArtifact, ErrorArtifact, InfoArtifact, ListArtifact, TextArtifact +from griptape.artifacts import BlobArtifact, InfoArtifact, TextArtifact @define @@ -17,57 +17,28 @@ class BaseFileManagerDriver(ABC): loaders: Dictionary of file extension specific loaders to use for loading file contents into artifacts. """ - default_loader: loaders.BaseLoader = field(default=Factory(lambda: loaders.BlobLoader()), kw_only=True) - loaders: dict[str, loaders.BaseLoader] = field( - default=Factory( - lambda: { - "pdf": loaders.PdfLoader(), - "csv": loaders.CsvLoader(), - "txt": loaders.TextLoader(), - "html": loaders.TextLoader(), - "json": loaders.TextLoader(), - "yaml": loaders.TextLoader(), - "xml": loaders.TextLoader(), - "png": loaders.ImageLoader(), - "jpg": loaders.ImageLoader(), - "jpeg": loaders.ImageLoader(), - "webp": loaders.ImageLoader(), - "gif": loaders.ImageLoader(), - "bmp": loaders.ImageLoader(), - "tiff": loaders.ImageLoader(), - }, - ), - kw_only=True, - ) + workdir: str = field(kw_only=True) + encoding: Optional[str] = field(default=None, kw_only=True) - def list_files(self, path: str) -> TextArtifact | ErrorArtifact: + def list_files(self, path: str) -> TextArtifact: entries = self.try_list_files(path) return TextArtifact("\n".join(list(entries))) @abstractmethod def try_list_files(self, path: str) -> list[str]: ... - def load_file(self, path: str) -> BaseArtifact: - extension = path.split(".")[-1] - loader = self.loaders.get(extension) or self.default_loader - source = self.try_load_file(path) - result = loader.load(source) - - if isinstance(result, BaseArtifact): - return result + def load_file(self, path: str) -> BlobArtifact | TextArtifact: + if self.encoding is None: + return BlobArtifact(self.try_load_file(path)) else: - return ListArtifact(result) + return TextArtifact(self.try_load_file(path).decode(encoding=self.encoding), encoding=self.encoding) @abstractmethod def try_load_file(self, path: str) -> bytes: ... def save_file(self, path: str, value: bytes | str) -> InfoArtifact: - extension = path.split(".")[-1] - loader = self.loaders.get(extension) or self.default_loader - encoding = None if loader is None else loader.encoding - if isinstance(value, str): - value = value.encode() if encoding is None else value.encode(encoding=encoding) + value = value.encode() if self.encoding is None else value.encode(encoding=self.encoding) elif isinstance(value, (bytearray, memoryview)): raise ValueError(f"Unsupported type: {type(value)}") diff --git a/griptape/drivers/file_manager/local_file_manager_driver.py b/griptape/drivers/file_manager/local_file_manager_driver.py index a6f1f0726..b383ff7d7 100644 --- a/griptape/drivers/file_manager/local_file_manager_driver.py +++ b/griptape/drivers/file_manager/local_file_manager_driver.py @@ -2,6 +2,7 @@ import os from pathlib import Path +from typing import Optional from attrs import Attribute, Factory, define, field @@ -16,11 +17,11 @@ class LocalFileManagerDriver(BaseFileManagerDriver): workdir: The absolute working directory. List, load, and save operations will be performed relative to this directory. """ - workdir: str = field(default=Factory(lambda: os.getcwd()), kw_only=True) + workdir: Optional[str] = field(default=Factory(lambda: os.getcwd()), kw_only=True) @workdir.validator # pyright: ignore[reportAttributeAccessIssue] def validate_workdir(self, _: Attribute, workdir: str) -> None: - if not Path(workdir).is_absolute(): + if self.workdir is not None and not Path(workdir).is_absolute(): raise ValueError("Workdir must be an absolute path") def try_list_files(self, path: str) -> list[str]: @@ -41,8 +42,7 @@ def try_save_file(self, path: str, value: bytes) -> None: Path(full_path).write_bytes(value) def _full_path(self, path: str) -> str: - path = path.lstrip("/") - full_path = os.path.join(self.workdir, path) + full_path = path if self.workdir is None else os.path.join(self.workdir, path.lstrip("/")) # Need to keep the trailing slash if it was there, # because it means the path is a directory. ended_with_slash = path.endswith("/") diff --git a/griptape/drivers/web_scraper/base_web_scraper_driver.py b/griptape/drivers/web_scraper/base_web_scraper_driver.py index ae39f8eac..0c33f3713 100644 --- a/griptape/drivers/web_scraper/base_web_scraper_driver.py +++ b/griptape/drivers/web_scraper/base_web_scraper_driver.py @@ -4,5 +4,13 @@ class BaseWebScraperDriver(ABC): + def scrape_url(self, url: str) -> TextArtifact: + source = self.fetch_url(url) + + return self.extract_page(source) + + @abstractmethod + def fetch_url(self, url: str) -> str: ... + @abstractmethod - def scrape_url(self, url: str) -> TextArtifact: ... + def extract_page(self, page: str) -> TextArtifact: ... diff --git a/griptape/drivers/web_scraper/markdownify_web_scraper_driver.py b/griptape/drivers/web_scraper/markdownify_web_scraper_driver.py index b54ff072f..4eff4948b 100644 --- a/griptape/drivers/web_scraper/markdownify_web_scraper_driver.py +++ b/griptape/drivers/web_scraper/markdownify_web_scraper_driver.py @@ -38,20 +38,8 @@ class MarkdownifyWebScraperDriver(BaseWebScraperDriver): exclude_ids: list[str] = field(default=Factory(list), kw_only=True) timeout: Optional[int] = field(default=None, kw_only=True) - def scrape_url(self, url: str) -> TextArtifact: + def fetch_url(self, url: str) -> str: sync_playwright = import_optional_dependency("playwright.sync_api").sync_playwright - bs4 = import_optional_dependency("bs4") - markdownify = import_optional_dependency("markdownify") - - include_links = self.include_links - - # Custom MarkdownConverter to optionally linked urls. If include_links is False only - # the text of the link is returned. - class OptionalLinksMarkdownConverter(markdownify.MarkdownConverter): - def convert_a(self, el: Any, text: str, convert_as_inline: Any) -> str: - if include_links: - return super().convert_a(el, text, convert_as_inline) - return text with sync_playwright() as p, p.chromium.launch(headless=True) as browser: page = browser.new_page() @@ -76,28 +64,43 @@ def skip_loading_images(route: Any) -> Any: if not content: raise Exception("can't access URL") - soup = bs4.BeautifulSoup(content, "html.parser") + return content + + def extract_page(self, page: str) -> TextArtifact: + bs4 = import_optional_dependency("bs4") + markdownify = import_optional_dependency("markdownify") + include_links = self.include_links + + # Custom MarkdownConverter to optionally linked urls. If include_links is False only + # the text of the link is returned. + class OptionalLinksMarkdownConverter(markdownify.MarkdownConverter): + def convert_a(self, el: Any, text: str, convert_as_inline: Any) -> str: + if include_links: + return super().convert_a(el, text, convert_as_inline) + return text + + soup = bs4.BeautifulSoup(page, "html.parser") - # Remove unwanted elements - exclude_selector = ",".join( - self.exclude_tags + [f".{c}" for c in self.exclude_classes] + [f"#{i}" for i in self.exclude_ids], - ) - if exclude_selector: - for s in soup.select(exclude_selector): - s.extract() + # Remove unwanted elements + exclude_selector = ",".join( + self.exclude_tags + [f".{c}" for c in self.exclude_classes] + [f"#{i}" for i in self.exclude_ids], + ) + if exclude_selector: + for s in soup.select(exclude_selector): + s.extract() - text = OptionalLinksMarkdownConverter().convert_soup(soup) + text = OptionalLinksMarkdownConverter().convert_soup(soup) - # Remove leading and trailing whitespace from the entire text - text = text.strip() + # Remove leading and trailing whitespace from the entire text + text = text.strip() - # Remove trailing whitespace from each line - text = re.sub(r"[ \t]+$", "", text, flags=re.MULTILINE) + # Remove trailing whitespace from each line + text = re.sub(r"[ \t]+$", "", text, flags=re.MULTILINE) - # Indent using 2 spaces instead of tabs - text = re.sub(r"(\n?\s*?)\t", r"\1 ", text) + # Indent using 2 spaces instead of tabs + text = re.sub(r"(\n?\s*?)\t", r"\1 ", text) - # Remove triple+ newlines (keep double newlines for paragraphs) - text = re.sub(r"\n\n+", "\n\n", text) + # Remove triple+ newlines (keep double newlines for paragraphs) + text = re.sub(r"\n\n+", "\n\n", text) - return TextArtifact(text) + return TextArtifact(text) diff --git a/griptape/drivers/web_scraper/proxy_web_scraper_driver.py b/griptape/drivers/web_scraper/proxy_web_scraper_driver.py index 2d785fde2..94b3914ea 100644 --- a/griptape/drivers/web_scraper/proxy_web_scraper_driver.py +++ b/griptape/drivers/web_scraper/proxy_web_scraper_driver.py @@ -12,6 +12,10 @@ class ProxyWebScraperDriver(BaseWebScraperDriver): proxies: dict = field(kw_only=True, metadata={"serializable": False}) params: dict = field(default=Factory(dict), kw_only=True, metadata={"serializable": True}) - def scrape_url(self, url: str) -> TextArtifact: + def fetch_url(self, url: str) -> str: response = requests.get(url, proxies=self.proxies, **self.params) - return TextArtifact(response.text) + + return response.text + + def extract_page(self, page: str) -> TextArtifact: + return TextArtifact(page) diff --git a/griptape/drivers/web_scraper/trafilatura_web_scraper_driver.py b/griptape/drivers/web_scraper/trafilatura_web_scraper_driver.py index 06f5573a4..e87af8af6 100644 --- a/griptape/drivers/web_scraper/trafilatura_web_scraper_driver.py +++ b/griptape/drivers/web_scraper/trafilatura_web_scraper_driver.py @@ -12,7 +12,7 @@ class TrafilaturaWebScraperDriver(BaseWebScraperDriver): include_links: bool = field(default=True, kw_only=True) - def scrape_url(self, url: str) -> TextArtifact: + def fetch_url(self, url: str) -> str: trafilatura = import_optional_dependency("trafilatura") use_config = trafilatura.settings.use_config @@ -29,6 +29,15 @@ def scrape_url(self, url: str) -> TextArtifact: if page is None: raise Exception("can't access URL") + + return page + + def extract_page(self, page: str) -> TextArtifact: + trafilatura = import_optional_dependency("trafilatura") + use_config = trafilatura.settings.use_config + + config = use_config() + extracted_page = trafilatura.extract( page, include_links=self.include_links, diff --git a/griptape/engines/rag/modules/retrieval/text_loader_retrieval_rag_module.py b/griptape/engines/rag/modules/retrieval/text_loader_retrieval_rag_module.py index 7e4854d00..0348a2094 100644 --- a/griptape/engines/rag/modules/retrieval/text_loader_retrieval_rag_module.py +++ b/griptape/engines/rag/modules/retrieval/text_loader_retrieval_rag_module.py @@ -6,6 +6,7 @@ from attrs import Factory, define, field from griptape import utils +from griptape.chunkers import TextChunker from griptape.engines.rag.modules import BaseRetrievalRagModule if TYPE_CHECKING: @@ -14,12 +15,13 @@ from griptape.artifacts import TextArtifact from griptape.drivers import BaseVectorStoreDriver from griptape.engines.rag import RagContext - from griptape.loaders import BaseTextLoader + from griptape.loaders import TextLoader @define(kw_only=True) class TextLoaderRetrievalRagModule(BaseRetrievalRagModule): - loader: BaseTextLoader = field() + loader: TextLoader = field() + chunker: TextChunker = field(default=Factory(lambda: TextChunker())) vector_store_driver: BaseVectorStoreDriver = field() source: Any = field() query_params: dict[str, Any] = field(factory=dict) @@ -37,7 +39,8 @@ def run(self, context: RagContext) -> Sequence[TextArtifact]: query_params["namespace"] = namespace loader_output = self.loader.load(source) + chunks = self.chunker.chunk(loader_output) - self.vector_store_driver.upsert_text_artifacts({namespace: loader_output}) + self.vector_store_driver.upsert_text_artifacts({namespace: chunks}) return self.process_query_output_fn(self.vector_store_driver.query(context.query, **query_params)) diff --git a/griptape/loaders/__init__.py b/griptape/loaders/__init__.py index b79b0ff44..b86370607 100644 --- a/griptape/loaders/__init__.py +++ b/griptape/loaders/__init__.py @@ -1,26 +1,28 @@ from .base_loader import BaseLoader -from .base_text_loader import BaseTextLoader +from .base_file_loader import BaseFileLoader + from .text_loader import TextLoader from .pdf_loader import PdfLoader from .web_loader import WebLoader from .sql_loader import SqlLoader from .csv_loader import CsvLoader -from .dataframe_loader import DataFrameLoader from .email_loader import EmailLoader + +from .blob_loader import BlobLoader + from .image_loader import ImageLoader + from .audio_loader import AudioLoader -from .blob_loader import BlobLoader __all__ = [ "BaseLoader", - "BaseTextLoader", + "BaseFileLoader", "TextLoader", "PdfLoader", "WebLoader", "SqlLoader", "CsvLoader", - "DataFrameLoader", "EmailLoader", "ImageLoader", "AudioLoader", diff --git a/griptape/loaders/audio_loader.py b/griptape/loaders/audio_loader.py index 84d6b767a..1a7caefce 100644 --- a/griptape/loaders/audio_loader.py +++ b/griptape/loaders/audio_loader.py @@ -1,20 +1,20 @@ from __future__ import annotations -from typing import cast +from typing import Any, cast +import filetype from attrs import define from griptape.artifacts import AudioArtifact -from griptape.loaders import BaseLoader -from griptape.utils import import_optional_dependency +from griptape.loaders.base_file_loader import BaseFileLoader @define -class AudioLoader(BaseLoader): +class AudioLoader(BaseFileLoader): """Loads audio content into audio artifacts.""" - def load(self, source: bytes, *args, **kwargs) -> AudioArtifact: - return AudioArtifact(source, format=import_optional_dependency("filetype").guess(source).extension) + def load(self, source: Any, *args, **kwargs) -> AudioArtifact: + return cast(AudioArtifact, super().load(source, *args, **kwargs)) - def load_collection(self, sources: list[bytes], *args, **kwargs) -> dict[str, AudioArtifact]: - return cast(dict[str, AudioArtifact], super().load_collection(sources, *args, **kwargs)) + def parse(self, source: bytes, *args, **kwargs) -> AudioArtifact: + return AudioArtifact(source, format=filetype.guess(source).extension) diff --git a/griptape/loaders/base_file_loader.py b/griptape/loaders/base_file_loader.py new file mode 100644 index 000000000..650dc303c --- /dev/null +++ b/griptape/loaders/base_file_loader.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from abc import ABC +from typing import TYPE_CHECKING + +from attrs import Factory, define, field + +from griptape.drivers import BaseFileManagerDriver, LocalFileManagerDriver +from griptape.loaders import BaseLoader + +if TYPE_CHECKING: + from os import PathLike + + +@define +class BaseFileLoader(BaseLoader, ABC): + file_manager_driver: BaseFileManagerDriver = field( + default=Factory(lambda: LocalFileManagerDriver(workdir=None)), + kw_only=True, + ) + encoding: str = field(default="utf-8", kw_only=True) + + def fetch(self, source: str | PathLike, *args, **kwargs) -> str | bytes: + return self.file_manager_driver.load_file(str(source), *args, **kwargs).value diff --git a/griptape/loaders/base_loader.py b/griptape/loaders/base_loader.py index 14f9aa10f..33c98dfe4 100644 --- a/griptape/loaders/base_loader.py +++ b/griptape/loaders/base_loader.py @@ -10,24 +10,33 @@ from griptape.utils.hash import bytes_to_hash, str_to_hash if TYPE_CHECKING: - from collections.abc import Mapping, Sequence + from collections.abc import Mapping from griptape.artifacts import BaseArtifact + from griptape.common import Reference @define class BaseLoader(FuturesExecutorMixin, ABC): - encoding: Optional[str] = field(default=None, kw_only=True) + reference: Optional[Reference] = field(default=None, kw_only=True) + + def load(self, source: Any, *args, **kwargs) -> BaseArtifact: + data = self.fetch(source) + + return self.parse(data) + + @abstractmethod + def fetch(self, source: Any, *args, **kwargs) -> Any: ... @abstractmethod - def load(self, source: Any, *args, **kwargs) -> BaseArtifact | Sequence[BaseArtifact]: ... + def parse(self, source: Any, *args, **kwargs) -> BaseArtifact: ... def load_collection( self, sources: list[Any], *args, **kwargs, - ) -> Mapping[str, BaseArtifact | Sequence[BaseArtifact | Sequence[BaseArtifact]]]: + ) -> Mapping[str, BaseArtifact]: # Create a dictionary before actually submitting the jobs to the executor # to avoid duplicate work. sources_by_key = {self.to_key(source): source for source in sources} @@ -39,10 +48,8 @@ def load_collection( }, ) - def to_key(self, source: Any, *args, **kwargs) -> str: + def to_key(self, source: Any) -> str: if isinstance(source, bytes): return bytes_to_hash(source) - elif isinstance(source, str): - return str_to_hash(source) else: return str_to_hash(str(source)) diff --git a/griptape/loaders/base_text_loader.py b/griptape/loaders/base_text_loader.py deleted file mode 100644 index 196cb0087..000000000 --- a/griptape/loaders/base_text_loader.py +++ /dev/null @@ -1,65 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Optional, cast - -from attrs import Factory, define, field - -from griptape.artifacts import TextArtifact -from griptape.chunkers import BaseChunker, TextChunker -from griptape.loaders import BaseLoader -from griptape.tokenizers import OpenAiTokenizer - -if TYPE_CHECKING: - from griptape.common import Reference - from griptape.drivers import BaseEmbeddingDriver - - -@define -class BaseTextLoader(BaseLoader, ABC): - MAX_TOKEN_RATIO = 0.5 - - tokenizer: OpenAiTokenizer = field( - default=Factory(lambda: OpenAiTokenizer(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL)), - kw_only=True, - ) - max_tokens: int = field( - default=Factory(lambda self: round(self.tokenizer.max_input_tokens * self.MAX_TOKEN_RATIO), takes_self=True), - kw_only=True, - ) - chunker: BaseChunker = field( - default=Factory( - lambda self: TextChunker(tokenizer=self.tokenizer, max_tokens=self.max_tokens), - takes_self=True, - ), - kw_only=True, - ) - embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) - encoding: str = field(default="utf-8", kw_only=True) - reference: Optional[Reference] = field(default=None, kw_only=True) - - @abstractmethod - def load(self, source: Any, *args, **kwargs) -> list[TextArtifact]: ... - - def load_collection(self, sources: list[Any], *args, **kwargs) -> dict[str, list[TextArtifact]]: - return cast( - dict[str, list[TextArtifact]], - super().load_collection(sources, *args, **kwargs), - ) - - def _text_to_artifacts(self, text: str) -> list[TextArtifact]: - artifacts = [] - - chunks = self.chunker.chunk(text) if self.chunker else [TextArtifact(text)] - - for chunk in chunks: - if self.embedding_driver: - chunk.generate_embedding(self.embedding_driver) - - chunk.reference = self.reference - - chunk.encoding = self.encoding - - artifacts.append(chunk) - - return artifacts diff --git a/griptape/loaders/blob_loader.py b/griptape/loaders/blob_loader.py index d0099b47b..dbc7e66bf 100644 --- a/griptape/loaders/blob_loader.py +++ b/griptape/loaders/blob_loader.py @@ -5,16 +5,16 @@ from attrs import define from griptape.artifacts import BlobArtifact -from griptape.loaders import BaseLoader +from griptape.loaders import BaseFileLoader @define -class BlobLoader(BaseLoader): +class BlobLoader(BaseFileLoader): def load(self, source: Any, *args, **kwargs) -> BlobArtifact: + return cast(BlobArtifact, super().load(source, *args, **kwargs)) + + def parse(self, source: bytes, *args, **kwargs) -> BlobArtifact: if self.encoding is None: return BlobArtifact(source) else: return BlobArtifact(source, encoding=self.encoding) - - def load_collection(self, sources: list[bytes | str], *args, **kwargs) -> dict[str, BlobArtifact]: - return cast(dict[str, BlobArtifact], super().load_collection(sources, *args, **kwargs)) diff --git a/griptape/loaders/csv_loader.py b/griptape/loaders/csv_loader.py index cd0302586..fb7ba19ba 100644 --- a/griptape/loaders/csv_loader.py +++ b/griptape/loaders/csv_loader.py @@ -6,49 +6,19 @@ from attrs import define, field -from griptape.artifacts import CsvRowArtifact -from griptape.loaders import BaseLoader - -if TYPE_CHECKING: - from griptape.drivers import BaseEmbeddingDriver +from griptape.artifacts import ListArtifact, TextArtifact +from griptape.loaders.base_file_loader import BaseFileLoader @define -class CsvLoader(BaseLoader): - embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) +class CsvLoader(BaseFileLoader): delimiter: str = field(default=",", kw_only=True) encoding: str = field(default="utf-8", kw_only=True) formatter_fn: Callable[[dict], str] = field( default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True ) - def load(self, source: bytes | str, *args, **kwargs) -> list[CsvRowArtifact]: - artifacts = [] - - if isinstance(source, bytes): - source = source.decode(encoding=self.encoding) - elif isinstance(source, (bytearray, memoryview)): - raise ValueError(f"Unsupported source type: {type(source)}") - - reader = csv.DictReader(StringIO(source), delimiter=self.delimiter) - chunks = [TextArtifact(self.formatter_fn(row)) for row in reader] - - if self.embedding_driver: - for chunk in chunks: - chunk.generate_embedding(self.embedding_driver) - - for chunk in chunks: - artifacts.append(chunk) - - return artifacts + def parse(self, source: bytes, *args, **kwargs) -> ListArtifact: + reader = csv.DictReader(StringIO(source.decode(self.encoding)), delimiter=self.delimiter) - def load_collection( - self, - sources: list[bytes | str], - *args, - **kwargs, - ) -> dict[str, list[CsvRowArtifact]]: - return cast( - dict[str, list[CsvRowArtifact]], - super().load_collection(sources, *args, **kwargs), - ) + return ListArtifact([TextArtifact(row, meta={"row": row_num}) for row_num, row in enumerate(reader)]) diff --git a/griptape/loaders/dataframe_loader.py b/griptape/loaders/dataframe_loader.py deleted file mode 100644 index 9131a7682..000000000 --- a/griptape/loaders/dataframe_loader.py +++ /dev/null @@ -1,45 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Callable, Optional, cast - -from attrs import define, field - -from griptape.artifacts import CsvRowArtifact -from griptape.loaders import BaseLoader -from griptape.utils import import_optional_dependency -from griptape.utils.hash import str_to_hash - -if TYPE_CHECKING: - from pandas import DataFrame - - from griptape.drivers import BaseEmbeddingDriver - - -@define -class DataFrameLoader(BaseLoader): - embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) - formatter_fn: Callable[[dict], str] = field( - default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True - ) - - def load(self, source: DataFrame, *args, **kwargs) -> list[CsvRowArtifact]: - artifacts = [] - - chunks = [TextArtifact(self.formatter_fn(row)) for row in source.to_dict(orient="records")] - - if self.embedding_driver: - for chunk in chunks: - chunk.generate_embedding(self.embedding_driver) - - for chunk in chunks: - artifacts.append(chunk) - - return artifacts - - def load_collection(self, sources: list[DataFrame], *args, **kwargs) -> dict[str, list[CsvRowArtifact]]: - return cast(dict[str, list[CsvRowArtifact]], super().load_collection(sources, *args, **kwargs)) - - def to_key(self, source: DataFrame, *args, **kwargs) -> str: - hash_pandas_object = import_optional_dependency("pandas.core.util.hashing").hash_pandas_object - - return str_to_hash(str(hash_pandas_object(source, index=True).values)) diff --git a/griptape/loaders/email_loader.py b/griptape/loaders/email_loader.py index f6c9ca406..0a3d3f600 100644 --- a/griptape/loaders/email_loader.py +++ b/griptape/loaders/email_loader.py @@ -1,7 +1,7 @@ from __future__ import annotations import imaplib -from typing import Optional, cast +from typing import Any, Optional, cast from attrs import astuple, define, field @@ -32,11 +32,13 @@ class EmailQuery: username: str = field(kw_only=True) password: str = field(kw_only=True) - def load(self, source: EmailQuery, *args, **kwargs) -> ListArtifact: - mailparser = import_optional_dependency("mailparser") + def load(self, source: Any, *args, **kwargs) -> ListArtifact: + return cast(ListArtifact, super().load(source, *args, **kwargs)) + + def fetch(self, source: EmailQuery, *args, **kwargs) -> list[bytes]: label, key, search_criteria, max_count = astuple(source) - artifacts = [] + mail_bytes = [] with imaplib.IMAP4_SSL(self.imap_url) as client: client.login(self.username, self.password) @@ -59,19 +61,24 @@ def load(self, source: EmailQuery, *args, **kwargs) -> ListArtifact: if data is None or not data or data[0] is None: continue - message = mailparser.parse_from_bytes(data[0][1]) - - # Note: mailparser only populates the text_plain field - # if the message content type is explicitly set to 'text/plain'. - if message.text_plain: - artifacts.append(TextArtifact("\n".join(message.text_plain))) + mail_bytes.append(data[0][1]) client.close() - return ListArtifact(artifacts) + return mail_bytes + + def parse(self, source: list[bytes], *args, **kwargs) -> ListArtifact: + mailparser = import_optional_dependency("mailparser") + artifacts = [] + for byte in source: + message = mailparser.parse_from_bytes(byte) + + # Note: mailparser only populates the text_plain field + # if the message content type is explicitly set to 'text/plain'. + if message.text_plain: + artifacts.append(TextArtifact("\n".join(message.text_plain))) + + return ListArtifact(artifacts) def _count_messages(self, message_numbers: bytes) -> int: return len(list(filter(None, message_numbers.decode().split(" ")))) - - def load_collection(self, sources: list[EmailQuery], *args, **kwargs) -> dict[str, ListArtifact]: - return cast(dict[str, ListArtifact], super().load_collection(sources, *args, **kwargs)) diff --git a/griptape/loaders/image_loader.py b/griptape/loaders/image_loader.py index 83060dfa8..513910c78 100644 --- a/griptape/loaders/image_loader.py +++ b/griptape/loaders/image_loader.py @@ -1,17 +1,17 @@ from __future__ import annotations from io import BytesIO -from typing import Optional, cast +from typing import Any, Optional, cast from attrs import define, field from griptape.artifacts import ImageArtifact -from griptape.loaders import BaseLoader +from griptape.loaders import BaseFileLoader from griptape.utils import import_optional_dependency @define -class ImageLoader(BaseLoader): +class ImageLoader(BaseFileLoader): """Loads images into image artifacts. Attributes: @@ -22,20 +22,13 @@ class ImageLoader(BaseLoader): format: Optional[str] = field(default=None, kw_only=True) - FORMAT_TO_MIME_TYPE = { - "bmp": "image/bmp", - "gif": "image/gif", - "jpeg": "image/jpeg", - "png": "image/png", - "tiff": "image/tiff", - "webp": "image/webp", - } + def load(self, source: Any, *args, **kwargs) -> ImageArtifact: + return cast(ImageArtifact, super().load(source, *args, **kwargs)) - def load(self, source: bytes, *args, **kwargs) -> ImageArtifact: + def parse(self, source: bytes, *args, **kwargs) -> ImageArtifact: pil_image = import_optional_dependency("PIL.Image") image = pil_image.open(BytesIO(source)) - # Normalize format only if requested. if self.format is not None: byte_stream = BytesIO() image.save(byte_stream, format=self.format) @@ -43,15 +36,3 @@ def load(self, source: bytes, *args, **kwargs) -> ImageArtifact: source = byte_stream.getvalue() return ImageArtifact(source, format=image.format.lower(), width=image.width, height=image.height) - - def _get_mime_type(self, image_format: str | None) -> str: - if image_format is None: - raise ValueError("image_format is None") - - if image_format.lower() not in self.FORMAT_TO_MIME_TYPE: - raise ValueError(f"Unsupported image format {image_format}") - - return self.FORMAT_TO_MIME_TYPE[image_format.lower()] - - def load_collection(self, sources: list[bytes], *args, **kwargs) -> dict[str, ImageArtifact]: - return cast(dict[str, ImageArtifact], super().load_collection(sources, *args, **kwargs)) diff --git a/griptape/loaders/pdf_loader.py b/griptape/loaders/pdf_loader.py index 419bfabf4..b1f901386 100644 --- a/griptape/loaders/pdf_loader.py +++ b/griptape/loaders/pdf_loader.py @@ -1,37 +1,29 @@ from __future__ import annotations from io import BytesIO -from typing import Optional, cast +from typing import Any, Optional, cast -from attrs import Factory, define, field +from attrs import define -from griptape.artifacts import TextArtifact -from griptape.chunkers import PdfChunker -from griptape.loaders import BaseTextLoader +from griptape.artifacts import ListArtifact, TextArtifact +from griptape.loaders.base_file_loader import BaseFileLoader from griptape.utils import import_optional_dependency @define -class PdfLoader(BaseTextLoader): - chunker: PdfChunker = field( - default=Factory(lambda self: PdfChunker(tokenizer=self.tokenizer, max_tokens=self.max_tokens), takes_self=True), - kw_only=True, - ) - encoding: None = field(default=None, kw_only=True) +class PdfLoader(BaseFileLoader): + def load(self, source: Any, *args, **kwargs) -> TextArtifact: + return cast(TextArtifact, super().load(source, *args, **kwargs)) - def load( + def parse( self, source: bytes, password: Optional[str] = None, *args, **kwargs, - ) -> list[TextArtifact]: + ) -> ListArtifact: pypdf = import_optional_dependency("pypdf") reader = pypdf.PdfReader(BytesIO(source), strict=True, password=password) - return self._text_to_artifacts("\n".join([p.extract_text() for p in reader.pages])) + pages = [TextArtifact(p.extract_text()) for p in reader.pages] - def load_collection(self, sources: list[bytes], *args, **kwargs) -> dict[str, list[TextArtifact]]: - return cast( - dict[str, list[TextArtifact]], - super().load_collection(sources, *args, **kwargs), - ) + return ListArtifact(pages) diff --git a/griptape/loaders/sql_loader.py b/griptape/loaders/sql_loader.py index 9cff7782b..d633eede6 100644 --- a/griptape/loaders/sql_loader.py +++ b/griptape/loaders/sql_loader.py @@ -1,38 +1,25 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Optional, cast +from typing import TYPE_CHECKING, Any, cast from attrs import define, field -from griptape.artifacts import TextArtifact +from griptape.artifacts import ListArtifact, TextArtifact from griptape.loaders import BaseLoader if TYPE_CHECKING: - from griptape.drivers import BaseEmbeddingDriver, BaseSqlDriver + from griptape.drivers import BaseSqlDriver @define class SqlLoader(BaseLoader): sql_driver: BaseSqlDriver = field(kw_only=True) - embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) - formatter_fn: Callable[[dict], str] = field( - default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True - ) - def load(self, source: str, *args, **kwargs) -> list[CsvRowArtifact]: - rows = self.sql_driver.execute_query(source) - artifacts = [] + def load(self, source: Any, *args, **kwargs) -> ListArtifact: + return cast(ListArtifact, super().load(source, *args, **kwargs)) - chunks = [TextArtifact(self.formatter_fn(row.cells)) for row in rows] if rows else [] + def fetch(self, source: str, *args, **kwargs) -> list[BaseSqlDriver.RowResult]: + return self.sql_driver.execute_query(source) or [] - if self.embedding_driver: - for chunk in chunks: - chunk.generate_embedding(self.embedding_driver) - - for chunk in chunks: - artifacts.append(chunk) - - return artifacts - - def load_collection(self, sources: list[str], *args, **kwargs) -> dict[str, list[CsvRowArtifact]]: - return cast(dict[str, list[CsvRowArtifact]], super().load_collection(sources, *args, **kwargs)) + def parse(self, source: list[BaseSqlDriver.RowResult], *args, **kwargs) -> ListArtifact: + return ListArtifact([TextArtifact(row.cells, meta={"row": row_num}) for row_num, row in enumerate(source)]) diff --git a/griptape/loaders/text_loader.py b/griptape/loaders/text_loader.py index 79e551a8e..68dc324ae 100644 --- a/griptape/loaders/text_loader.py +++ b/griptape/loaders/text_loader.py @@ -1,55 +1,22 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional, cast +from typing import Any, cast -from attrs import Factory, define, field +from attrs import define, field from griptape.artifacts import TextArtifact -from griptape.chunkers import TextChunker -from griptape.loaders import BaseTextLoader -from griptape.tokenizers import OpenAiTokenizer - -if TYPE_CHECKING: - from griptape.drivers import BaseEmbeddingDriver +from griptape.loaders import BaseFileLoader @define -class TextLoader(BaseTextLoader): - MAX_TOKEN_RATIO = 0.5 - - tokenizer: OpenAiTokenizer = field( - default=Factory(lambda: OpenAiTokenizer(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL)), - kw_only=True, - ) - max_tokens: int = field( - default=Factory(lambda self: round(self.tokenizer.max_input_tokens * self.MAX_TOKEN_RATIO), takes_self=True), - kw_only=True, - ) - chunker: TextChunker = field( - default=Factory( - lambda self: TextChunker(tokenizer=self.tokenizer, max_tokens=self.max_tokens), - takes_self=True, - ), - kw_only=True, - ) - embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) +class TextLoader(BaseFileLoader): encoding: str = field(default="utf-8", kw_only=True) - def load(self, source: bytes | str, *args, **kwargs) -> list[TextArtifact]: - if isinstance(source, bytes): - source = source.decode(encoding=self.encoding) - elif isinstance(source, (bytearray, memoryview)): - raise ValueError(f"Unsupported source type: {type(source)}") - - return self._text_to_artifacts(source) + def load(self, source: Any, *args, **kwargs) -> TextArtifact: + return cast(TextArtifact, super().load(source, *args, **kwargs)) - def load_collection( - self, - sources: list[bytes | str], - *args, - **kwargs, - ) -> dict[str, list[TextArtifact]]: - return cast( - dict[str, list[TextArtifact]], - super().load_collection(sources, *args, **kwargs), - ) + def parse(self, source: str | bytes, *args, **kwargs) -> TextArtifact: + if isinstance(source, str): + return TextArtifact(source, encoding=self.encoding) + else: + return TextArtifact(source.decode(self.encoding), encoding=self.encoding) diff --git a/griptape/loaders/web_loader.py b/griptape/loaders/web_loader.py index 720ab34a1..d31a34816 100644 --- a/griptape/loaders/web_loader.py +++ b/griptape/loaders/web_loader.py @@ -1,23 +1,26 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import Any, cast from attrs import Factory, define, field +from griptape.artifacts import TextArtifact from griptape.drivers import BaseWebScraperDriver, TrafilaturaWebScraperDriver -from griptape.loaders import BaseTextLoader - -if TYPE_CHECKING: - from griptape.artifacts import TextArtifact +from griptape.loaders import BaseLoader @define -class WebLoader(BaseTextLoader): +class WebLoader(BaseLoader): web_scraper_driver: BaseWebScraperDriver = field( default=Factory(lambda: TrafilaturaWebScraperDriver()), kw_only=True, ) - def load(self, source: str, *args, **kwargs) -> list[TextArtifact]: - single_chunk_text_artifact = self.web_scraper_driver.scrape_url(source) - return self._text_to_artifacts(single_chunk_text_artifact.value) + def load(self, source: Any, *args, **kwargs) -> TextArtifact: + return cast(TextArtifact, super().load(source, *args, **kwargs)) + + def fetch(self, source: str, *args, **kwargs) -> str: + return self.web_scraper_driver.fetch_url(source) + + def parse(self, source: str, *args, **kwargs) -> TextArtifact: + return self.web_scraper_driver.extract_page(source) diff --git a/griptape/tools/file_manager/tool.py b/griptape/tools/file_manager/tool.py index b72f82329..8dc0c9393 100644 --- a/griptape/tools/file_manager/tool.py +++ b/griptape/tools/file_manager/tool.py @@ -5,9 +5,12 @@ from attrs import Factory, define, field from schema import Literal, Schema +import griptape.loaders as loaders from griptape.artifacts import ErrorArtifact, InfoArtifact, ListArtifact, TextArtifact from griptape.drivers import BaseFileManagerDriver, LocalFileManagerDriver +from griptape.loaders.blob_loader import BlobLoader from griptape.tools import BaseTool +from griptape.utils import get_mime_type from griptape.utils.decorators import activity @@ -21,6 +24,20 @@ class FileManagerTool(BaseTool): file_manager_driver: BaseFileManagerDriver = field(default=Factory(lambda: LocalFileManagerDriver()), kw_only=True) + loaders: dict[str, loaders.BaseLoader] = field( + default=Factory( + lambda self: { + "application/pdf": loaders.PdfLoader(file_manager_driver=self.file_manager_driver), + "text/csv": loaders.CsvLoader(file_manager_driver=self.file_manager_driver), + "text": loaders.TextLoader(file_manager_driver=self.file_manager_driver), + "image": loaders.ImageLoader(file_manager_driver=self.file_manager_driver), + "application/octet-stream": BlobLoader(file_manager_driver=self.file_manager_driver), + }, + takes_self=True, + ), + kw_only=True, + ) + @activity( config={ "description": "Can be used to list files on disk", @@ -51,7 +68,11 @@ def load_files_from_disk(self, params: dict) -> ListArtifact | ErrorArtifact: artifacts = [] for path in paths: - artifact = self.file_manager_driver.load_file(path) + abs_path = os.path.join(self.file_manager_driver.workdir, path) + mime_type = get_mime_type(abs_path) + loader = next((loader for key, loader in self.loaders.items() if mime_type.startswith(key))) + + artifact = loader.load(path) if isinstance(artifact, ListArtifact): artifacts.extend(artifact.value) else: diff --git a/griptape/tools/sql/tool.py b/griptape/tools/sql/tool.py index a84bb87be..59d0a3dba 100644 --- a/griptape/tools/sql/tool.py +++ b/griptape/tools/sql/tool.py @@ -51,6 +51,6 @@ def execute_query(self, params: dict) -> ListArtifact | InfoArtifact | ErrorArti return ErrorArtifact(f"error executing query: {e}") if len(rows) > 0: - return ListArtifact(rows) + return rows else: return InfoArtifact("No results found") diff --git a/griptape/tools/web_scraper/tool.py b/griptape/tools/web_scraper/tool.py index 2895d5e0d..982123b30 100644 --- a/griptape/tools/web_scraper/tool.py +++ b/griptape/tools/web_scraper/tool.py @@ -4,6 +4,7 @@ from schema import Literal, Schema from griptape.artifacts import ErrorArtifact, ListArtifact +from griptape.chunkers import TextChunker from griptape.loaders import WebLoader from griptape.tools import BaseTool from griptape.utils.decorators import activity @@ -12,6 +13,7 @@ @define class WebScraperTool(BaseTool): web_loader: WebLoader = field(default=Factory(lambda: WebLoader()), kw_only=True) + text_chunker: TextChunker = field(default=Factory(lambda: TextChunker()), kw_only=True) @activity( config={ @@ -24,6 +26,8 @@ def get_content(self, params: dict) -> ListArtifact | ErrorArtifact: try: result = self.web_loader.load(url) - return ListArtifact(result) + chunks = TextChunker().chunk(result) + + return ListArtifact(chunks) except Exception as e: return ErrorArtifact("Error getting page content: " + str(e)) diff --git a/griptape/utils/__init__.py b/griptape/utils/__init__.py index 03725f59d..77e3f3b0a 100644 --- a/griptape/utils/__init__.py +++ b/griptape/utils/__init__.py @@ -8,7 +8,6 @@ from .futures import execute_futures_dict, execute_futures_list, execute_futures_list_dict from .token_counter import TokenCounter from .dict_utils import remove_null_values_in_dict_recursively, dict_merge, remove_key_in_dict_recursively -from .file_utils import load_file, load_files from .hash import str_to_hash from .import_utils import import_optional_dependency from .import_utils import is_dependency_installed @@ -17,6 +16,7 @@ from .deprecation import deprecation_warn from .structure_visualizer import StructureVisualizer from .reference_utils import references_from_artifacts +from .file_utils import get_mime_type def minify_json(value: str) -> str: @@ -44,8 +44,7 @@ def minify_json(value: str) -> str: "Stream", "load_artifact_from_memory", "deprecation_warn", - "load_file", - "load_files", "StructureVisualizer", "references_from_artifacts", + "get_mime_type", ] diff --git a/griptape/utils/file_utils.py b/griptape/utils/file_utils.py index 19c9f699c..0dbbbc093 100644 --- a/griptape/utils/file_utils.py +++ b/griptape/utils/file_utils.py @@ -1,38 +1,16 @@ -from __future__ import annotations +import mimetypes -from concurrent import futures -from pathlib import Path -from typing import Optional +import filetype -import griptape.utils as utils +def get_mime_type(file_path: str) -> str: + filetype_guess = filetype.guess(file_path) -def load_file(path: str) -> bytes: - """Load a file from the given path and return its content as bytes. - - Args: - path (str): The path to the file to load. - - Returns: - The content of the file. - """ - return Path(path).read_bytes() - - -def load_files(paths: list[str], futures_executor: Optional[futures.ThreadPoolExecutor] = None) -> dict[str, bytes]: - """Load multiple files concurrently and return a dictionary of their content. - - Args: - paths: The paths to the files to load. - futures_executor: The executor to use for concurrent loading. If None, a new ThreadPoolExecutor will be created. - - Returns: - A dictionary where the keys are a hash of the path and the values are the content of the files. - """ - if futures_executor is None: - futures_executor = futures.ThreadPoolExecutor() - - with futures_executor as executor: - return utils.execute_futures_dict( - {utils.str_to_hash(str(path)): executor.submit(load_file, path) for path in paths}, - ) + if filetype_guess is None: + type_, _ = mimetypes.guess_type(file_path) + if type_ is None: + return "application/octet-stream" + else: + return type_ + else: + return filetype_guess.mime diff --git a/poetry.lock b/poetry.lock index 647231351..d90a84d61 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1595,7 +1595,7 @@ typing = ["typing-extensions (>=4.8)"] name = "filetype" version = "1.2.0" description = "Infer file type and MIME type of any file/buffer. No external dependencies." -optional = true +optional = false python-versions = "*" files = [ {file = "filetype-1.2.0-py2.py3-none-any.whl", hash = "sha256:7ce71b6880181241cf7ac8697a2f1eb6a8bd9b429f7ad6d27b8db9ba5f1c2d25"}, @@ -5862,48 +5862,36 @@ python-versions = ">=3.7" files = [ {file = "SQLAlchemy-2.0.34-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:95d0b2cf8791ab5fb9e3aa3d9a79a0d5d51f55b6357eecf532a120ba3b5524db"}, {file = "SQLAlchemy-2.0.34-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:243f92596f4fd4c8bd30ab8e8dd5965afe226363d75cab2468f2c707f64cd83b"}, - {file = "SQLAlchemy-2.0.34-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9ea54f7300553af0a2a7235e9b85f4204e1fc21848f917a3213b0e0818de9a24"}, {file = "SQLAlchemy-2.0.34-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:173f5f122d2e1bff8fbd9f7811b7942bead1f5e9f371cdf9e670b327e6703ebd"}, - {file = "SQLAlchemy-2.0.34-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:196958cde924a00488e3e83ff917be3b73cd4ed8352bbc0f2989333176d1c54d"}, {file = "SQLAlchemy-2.0.34-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:bd90c221ed4e60ac9d476db967f436cfcecbd4ef744537c0f2d5291439848768"}, {file = "SQLAlchemy-2.0.34-cp310-cp310-win32.whl", hash = "sha256:3166dfff2d16fe9be3241ee60ece6fcb01cf8e74dd7c5e0b64f8e19fab44911b"}, {file = "SQLAlchemy-2.0.34-cp310-cp310-win_amd64.whl", hash = "sha256:6831a78bbd3c40f909b3e5233f87341f12d0b34a58f14115c9e94b4cdaf726d3"}, {file = "SQLAlchemy-2.0.34-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c7db3db284a0edaebe87f8f6642c2b2c27ed85c3e70064b84d1c9e4ec06d5d84"}, {file = "SQLAlchemy-2.0.34-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:430093fce0efc7941d911d34f75a70084f12f6ca5c15d19595c18753edb7c33b"}, - {file = "SQLAlchemy-2.0.34-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79cb400c360c7c210097b147c16a9e4c14688a6402445ac848f296ade6283bbc"}, {file = "SQLAlchemy-2.0.34-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fb1b30f31a36c7f3fee848391ff77eebdd3af5750bf95fbf9b8b5323edfdb4ec"}, - {file = "SQLAlchemy-2.0.34-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8fddde2368e777ea2a4891a3fb4341e910a056be0bb15303bf1b92f073b80c02"}, {file = "SQLAlchemy-2.0.34-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:80bd73ea335203b125cf1d8e50fef06be709619eb6ab9e7b891ea34b5baa2287"}, {file = "SQLAlchemy-2.0.34-cp311-cp311-win32.whl", hash = "sha256:6daeb8382d0df526372abd9cb795c992e18eed25ef2c43afe518c73f8cccb721"}, {file = "SQLAlchemy-2.0.34-cp311-cp311-win_amd64.whl", hash = "sha256:5bc08e75ed11693ecb648b7a0a4ed80da6d10845e44be0c98c03f2f880b68ff4"}, {file = "SQLAlchemy-2.0.34-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:53e68b091492c8ed2bd0141e00ad3089bcc6bf0e6ec4142ad6505b4afe64163e"}, {file = "SQLAlchemy-2.0.34-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:bcd18441a49499bf5528deaa9dee1f5c01ca491fc2791b13604e8f972877f812"}, - {file = "SQLAlchemy-2.0.34-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:165bbe0b376541092bf49542bd9827b048357f4623486096fc9aaa6d4e7c59a2"}, {file = "SQLAlchemy-2.0.34-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3330415cd387d2b88600e8e26b510d0370db9b7eaf984354a43e19c40df2e2b"}, - {file = "SQLAlchemy-2.0.34-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:97b850f73f8abbffb66ccbab6e55a195a0eb655e5dc74624d15cff4bfb35bd74"}, {file = "SQLAlchemy-2.0.34-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:7cee4c6917857fd6121ed84f56d1dc78eb1d0e87f845ab5a568aba73e78adf83"}, {file = "SQLAlchemy-2.0.34-cp312-cp312-win32.whl", hash = "sha256:fbb034f565ecbe6c530dff948239377ba859420d146d5f62f0271407ffb8c580"}, {file = "SQLAlchemy-2.0.34-cp312-cp312-win_amd64.whl", hash = "sha256:707c8f44931a4facd4149b52b75b80544a8d824162602b8cd2fe788207307f9a"}, {file = "SQLAlchemy-2.0.34-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:24af3dc43568f3780b7e1e57c49b41d98b2d940c1fd2e62d65d3928b6f95f021"}, - {file = "SQLAlchemy-2.0.34-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e60ed6ef0a35c6b76b7640fe452d0e47acc832ccbb8475de549a5cc5f90c2c06"}, {file = "SQLAlchemy-2.0.34-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:413c85cd0177c23e32dee6898c67a5f49296640041d98fddb2c40888fe4daa2e"}, - {file = "SQLAlchemy-2.0.34-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:25691f4adfb9d5e796fd48bf1432272f95f4bbe5f89c475a788f31232ea6afba"}, {file = "SQLAlchemy-2.0.34-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:526ce723265643dbc4c7efb54f56648cc30e7abe20f387d763364b3ce7506c82"}, {file = "SQLAlchemy-2.0.34-cp37-cp37m-win32.whl", hash = "sha256:13be2cc683b76977a700948411a94c67ad8faf542fa7da2a4b167f2244781cf3"}, {file = "SQLAlchemy-2.0.34-cp37-cp37m-win_amd64.whl", hash = "sha256:e54ef33ea80d464c3dcfe881eb00ad5921b60f8115ea1a30d781653edc2fd6a2"}, {file = "SQLAlchemy-2.0.34-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:43f28005141165edd11fbbf1541c920bd29e167b8bbc1fb410d4fe2269c1667a"}, {file = "SQLAlchemy-2.0.34-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b68094b165a9e930aedef90725a8fcfafe9ef95370cbb54abc0464062dbf808f"}, - {file = "SQLAlchemy-2.0.34-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6a1e03db964e9d32f112bae36f0cc1dcd1988d096cfd75d6a588a3c3def9ab2b"}, {file = "SQLAlchemy-2.0.34-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:203d46bddeaa7982f9c3cc693e5bc93db476ab5de9d4b4640d5c99ff219bee8c"}, - {file = "SQLAlchemy-2.0.34-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:ae92bebca3b1e6bd203494e5ef919a60fb6dfe4d9a47ed2453211d3bd451b9f5"}, {file = "SQLAlchemy-2.0.34-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:9661268415f450c95f72f0ac1217cc6f10256f860eed85c2ae32e75b60278ad8"}, {file = "SQLAlchemy-2.0.34-cp38-cp38-win32.whl", hash = "sha256:895184dfef8708e15f7516bd930bda7e50ead069280d2ce09ba11781b630a434"}, {file = "SQLAlchemy-2.0.34-cp38-cp38-win_amd64.whl", hash = "sha256:6e7cde3a2221aa89247944cafb1b26616380e30c63e37ed19ff0bba5e968688d"}, {file = "SQLAlchemy-2.0.34-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:dbcdf987f3aceef9763b6d7b1fd3e4ee210ddd26cac421d78b3c206d07b2700b"}, {file = "SQLAlchemy-2.0.34-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ce119fc4ce0d64124d37f66a6f2a584fddc3c5001755f8a49f1ca0a177ef9796"}, - {file = "SQLAlchemy-2.0.34-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a17d8fac6df9835d8e2b4c5523666e7051d0897a93756518a1fe101c7f47f2f0"}, {file = "SQLAlchemy-2.0.34-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ebc11c54c6ecdd07bb4efbfa1554538982f5432dfb8456958b6d46b9f834bb7"}, - {file = "SQLAlchemy-2.0.34-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:2e6965346fc1491a566e019a4a1d3dfc081ce7ac1a736536367ca305da6472a8"}, {file = "SQLAlchemy-2.0.34-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:220574e78ad986aea8e81ac68821e47ea9202b7e44f251b7ed8c66d9ae3f4278"}, {file = "SQLAlchemy-2.0.34-cp39-cp39-win32.whl", hash = "sha256:b75b00083e7fe6621ce13cfce9d4469c4774e55e8e9d38c305b37f13cf1e874c"}, {file = "SQLAlchemy-2.0.34-cp39-cp39-win_amd64.whl", hash = "sha256:c29d03e0adf3cc1a8c3ec62d176824972ae29b67a66cbb18daff3062acc6faa8"}, @@ -6375,6 +6363,11 @@ files = [ {file = "triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb"}, {file = "triton-3.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bcbf3b1c48af6a28011a5c40a5b3b9b5330530c3827716b5fbf6d7adcc1e53e9"}, {file = "triton-3.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6e5727202f7078c56f91ff13ad0c1abab14a0e7f2c87e91b12b6f64f3e8ae609"}, + {file = "triton-3.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b052da883351fdf6be3d93cedae6db3b8e3988d3b09ed221bccecfa9612230"}, + {file = "triton-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd34f19a8582af96e6291d4afce25dac08cb2a5d218c599163761e8e0827208e"}, + {file = "triton-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d5e10de8c011adeb7c878c6ce0dd6073b14367749e34467f1cff2bde1b78253"}, + {file = "triton-3.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8903767951bf86ec960b4fe4e21bc970055afc65e9d57e916d79ae3c93665e3"}, + {file = "triton-3.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41004fb1ae9a53fcb3e970745feb87f0e3c94c6ce1ba86e95fa3b8537894bef7"}, ] [package.dependencies] @@ -6956,7 +6949,7 @@ doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linke test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] [extras] -all = ["accelerate", "anthropic", "astrapy", "beautifulsoup4", "boto3", "cohere", "diffusers", "duckduckgo-search", "elevenlabs", "filetype", "google-generativeai", "mail-parser", "markdownify", "marqo", "ollama", "opensearch-py", "opentelemetry-api", "opentelemetry-exporter-otlp-proto-http", "opentelemetry-instrumentation", "opentelemetry-instrumentation-threading", "opentelemetry-sdk", "pandas", "pgvector", "pillow", "pinecone-client", "playwright", "psycopg2-binary", "pusher", "pymongo", "pypdf", "qdrant-client", "redis", "sentencepiece", "snowflake-sqlalchemy", "sqlalchemy", "torch", "trafilatura", "transformers", "voyageai"] +all = ["accelerate", "anthropic", "astrapy", "beautifulsoup4", "boto3", "cohere", "diffusers", "duckduckgo-search", "elevenlabs", "google-generativeai", "mail-parser", "markdownify", "marqo", "ollama", "opensearch-py", "opentelemetry-api", "opentelemetry-exporter-otlp-proto-http", "opentelemetry-instrumentation", "opentelemetry-instrumentation-threading", "opentelemetry-sdk", "pandas", "pgvector", "pillow", "pinecone-client", "playwright", "psycopg2-binary", "pusher", "pymongo", "pypdf", "qdrant-client", "redis", "sentencepiece", "snowflake-sqlalchemy", "sqlalchemy", "torch", "trafilatura", "transformers", "voyageai"] drivers-embedding-amazon-bedrock = ["boto3"] drivers-embedding-amazon-sagemaker = ["boto3"] drivers-embedding-cohere = ["cohere"] @@ -6998,8 +6991,6 @@ drivers-vector-redis = ["redis"] drivers-web-scraper-markdownify = ["beautifulsoup4", "markdownify", "playwright"] drivers-web-scraper-trafilatura = ["trafilatura"] drivers-web-search-duckduckgo = ["duckduckgo-search"] -loaders-audio = ["filetype"] -loaders-dataframe = ["pandas"] loaders-email = ["mail-parser"] loaders-image = ["pillow"] loaders-pdf = ["pypdf"] @@ -7008,4 +6999,4 @@ loaders-sql = ["sqlalchemy"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "ad31933841926602b18aea04927c1f86d6825d8891495210a28dc1732768499e" +content-hash = "61d8eb5080886ac092c8e9ff29c6260276acc94ec2062f9363e838c004c5e37f" diff --git a/pyproject.toml b/pyproject.toml index a0854e0a2..eb4650e5c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ numpy = "^1.26.4" stringcase = "^1.2.0" docker = "^7.1.0" requests = "^2.32.0" +filetype = "^1.2" # drivers cohere = { version = "^5.5.4", optional = true } @@ -70,7 +71,6 @@ pandas = {version = "^1.3", optional = true} pypdf = {version = "^3.9", optional = true} pillow = {version = "^10.2.0", optional = true} mail-parser = {version = "^3.15.0", optional = true} -filetype = {version = "^1.2", optional = true} [tool.poetry.extras] drivers-prompt-cohere = ["cohere"] @@ -150,11 +150,9 @@ drivers-image-generation-huggingface = [ "pillow", ] -loaders-dataframe = ["pandas"] loaders-pdf = ["pypdf"] loaders-image = ["pillow"] loaders-email = ["mail-parser"] -loaders-audio = ["filetype"] loaders-sql = ["sqlalchemy"] all = [ @@ -201,7 +199,6 @@ all = [ "pandas", "pypdf", "mail-parser", - "filetype", ] [tool.poetry.group.test] diff --git a/tests/unit/drivers/file_manager/test_amazon_s3_file_manager_driver.py b/tests/unit/drivers/file_manager/test_amazon_s3_file_manager_driver.py index 84ce61768..3d131ef7a 100644 --- a/tests/unit/drivers/file_manager/test_amazon_s3_file_manager_driver.py +++ b/tests/unit/drivers/file_manager/test_amazon_s3_file_manager_driver.py @@ -5,9 +5,9 @@ import pytest from moto import mock_s3 -from griptape.artifacts import InfoArtifact, ListArtifact, TextArtifact +from griptape.artifacts import InfoArtifact, TextArtifact +from griptape.artifacts.blob_artifact import BlobArtifact from griptape.drivers import AmazonS3FileManagerDriver -from griptape.loaders import TextLoader from tests.utils.aws import mock_aws_credentials @@ -154,8 +154,7 @@ def test_list_files_failure(self, workdir, path, expected, driver): def test_load_file(self, driver): artifact = driver.load_file("resources/bitcoin.pdf") - assert isinstance(artifact, ListArtifact) - assert len(artifact.value) == 4 + assert isinstance(artifact, BlobArtifact) @pytest.mark.parametrize( ("workdir", "path", "expected"), @@ -185,9 +184,8 @@ def test_load_file_failure(self, workdir, path, expected, driver): def test_load_file_with_encoding(self, driver): artifact = driver.load_file("resources/test.txt") - assert isinstance(artifact, ListArtifact) - assert len(artifact.value) == 1 - assert isinstance(artifact.value[0], TextArtifact) + assert isinstance(artifact, BlobArtifact) + assert artifact.encoding == "utf-8" @pytest.mark.parametrize( ("workdir", "path", "content"), @@ -240,9 +238,7 @@ def test_save_file_failure(self, workdir, path, expected, temp_dir, driver, s3_c def test_save_file_with_encoding(self, session, bucket, get_s3_value): workdir = "/sub-folder" - driver = AmazonS3FileManagerDriver( - session=session, bucket=bucket, default_loader=TextLoader(encoding="utf-8"), loaders={}, workdir=workdir - ) + driver = AmazonS3FileManagerDriver(session=session, bucket=bucket, workdir=workdir) path = "test/foobar.txt" result = driver.save_file(path, "foobar") @@ -253,9 +249,7 @@ def test_save_file_with_encoding(self, session, bucket, get_s3_value): def test_save_and_load_file_with_encoding(self, session, bucket, get_s3_value): workdir = "/sub-folder" - driver = AmazonS3FileManagerDriver( - session=session, bucket=bucket, loaders={"txt": TextLoader(encoding="ascii")}, workdir=workdir - ) + driver = AmazonS3FileManagerDriver(session=session, bucket=bucket, encoding="ascii", workdir=workdir) path = "test/foobar.txt" result = driver.save_file(path, "foobar") @@ -264,13 +258,10 @@ def test_save_and_load_file_with_encoding(self, session, bucket, get_s3_value): assert get_s3_value(expected_s3_key) == "foobar" assert result.value == "Successfully saved file" - driver = AmazonS3FileManagerDriver( - session=session, bucket=bucket, default_loader=TextLoader(encoding="ascii"), loaders={}, workdir=workdir - ) + driver = AmazonS3FileManagerDriver(session=session, bucket=bucket, encoding="ascii", workdir=workdir) path = "test/foobar.txt" result = driver.load_file(path) - assert isinstance(result, ListArtifact) - assert len(result.value) == 1 - assert isinstance(result.value[0], TextArtifact) + assert isinstance(result, TextArtifact) + assert result.encoding == "ascii" diff --git a/tests/unit/drivers/file_manager/test_local_file_manager_driver.py b/tests/unit/drivers/file_manager/test_local_file_manager_driver.py index 394a838a3..99f0285bc 100644 --- a/tests/unit/drivers/file_manager/test_local_file_manager_driver.py +++ b/tests/unit/drivers/file_manager/test_local_file_manager_driver.py @@ -4,9 +4,9 @@ import pytest -from griptape.artifacts import InfoArtifact, ListArtifact, TextArtifact +from griptape.artifacts import InfoArtifact, TextArtifact +from griptape.artifacts.blob_artifact import BlobArtifact from griptape.drivers import LocalFileManagerDriver -from griptape.loaders.text_loader import TextLoader class TestLocalFileManagerDriver: @@ -127,8 +127,7 @@ def test_list_files_failure(self, workdir, path, expected, temp_dir, driver): def test_load_file(self, driver: LocalFileManagerDriver): artifact = driver.load_file("resources/bitcoin.pdf") - assert isinstance(artifact, ListArtifact) - assert len(artifact.value) == 4 + assert isinstance(artifact, BlobArtifact) @pytest.mark.parametrize( ("workdir", "path", "expected"), @@ -156,23 +155,6 @@ def test_load_file_failure(self, workdir, path, expected, temp_dir, driver): with pytest.raises(expected): driver.load_file(path) - def test_load_file_with_encoding(self, driver: LocalFileManagerDriver): - artifact = driver.load_file("resources/test.txt") - - assert isinstance(artifact, ListArtifact) - assert len(artifact.value) == 1 - assert isinstance(artifact.value[0], TextArtifact) - - def test_load_file_with_encoding_failure(self, driver): - driver = LocalFileManagerDriver( - default_loader=TextLoader(encoding="utf-8"), - loaders={}, - workdir=os.path.normpath(os.path.abspath(os.path.dirname(__file__) + "../../../../")), - ) - - with pytest.raises(UnicodeDecodeError): - driver.load_file("resources/bitcoin.pdf") - @pytest.mark.parametrize( ("workdir", "path", "content"), [ @@ -224,25 +206,24 @@ def test_save_file_failure(self, workdir, path, expected, temp_dir, driver): driver.save_file(path, "foobar") def test_save_file_with_encoding(self, temp_dir): - driver = LocalFileManagerDriver(default_loader=TextLoader(encoding="utf-8"), loaders={}, workdir=temp_dir) + driver = LocalFileManagerDriver(encoding="utf-8", workdir=temp_dir) result = driver.save_file(os.path.join("test", "foobar.txt"), "foobar") assert Path(os.path.join(temp_dir, "test", "foobar.txt")).read_text() == "foobar" assert result.value == "Successfully saved file" def test_save_and_load_file_with_encoding(self, temp_dir): - driver = LocalFileManagerDriver(loaders={"txt": TextLoader(encoding="ascii")}, workdir=temp_dir) + driver = LocalFileManagerDriver(encoding="ascii", workdir=temp_dir) result = driver.save_file(os.path.join("test", "foobar.txt"), "foobar") assert Path(os.path.join(temp_dir, "test", "foobar.txt")).read_text() == "foobar" assert result.value == "Successfully saved file" - driver = LocalFileManagerDriver(default_loader=TextLoader(encoding="ascii"), loaders={}, workdir=temp_dir) + driver = LocalFileManagerDriver(encoding="ascii", workdir=temp_dir) result = driver.load_file(os.path.join("test", "foobar.txt")) - assert isinstance(result, ListArtifact) - assert len(result.value) == 1 - assert isinstance(result.value[0], TextArtifact) + assert isinstance(result, TextArtifact) + assert result.encoding == "ascii" def _to_driver_workdir(self, temp_dir, workdir): # Treat the workdir as an absolute path, but modify it to be relative to the temp_dir. diff --git a/tests/unit/drivers/vector/test_base_local_vector_store_driver.py b/tests/unit/drivers/vector/test_base_vector_store_driver.py similarity index 95% rename from tests/unit/drivers/vector/test_base_local_vector_store_driver.py rename to tests/unit/drivers/vector/test_base_vector_store_driver.py index 20a3e2b50..8438568ad 100644 --- a/tests/unit/drivers/vector/test_base_local_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_base_vector_store_driver.py @@ -4,12 +4,13 @@ import pytest from griptape.artifacts import TextArtifact +from griptape.drivers import BaseVectorStoreDriver -class BaseLocalVectorStoreDriver(ABC): +class TestBaseVectorStoreDriver(ABC): @pytest.fixture() @abstractmethod - def driver(self): ... + def driver(self, *args, **kwargs) -> BaseVectorStoreDriver: ... def test_upsert(self, driver): namespace = driver.upsert_text_artifact(TextArtifact(id="foo1", value="foobar")) diff --git a/tests/unit/drivers/vector/test_local_vector_store_driver.py b/tests/unit/drivers/vector/test_local_vector_store_driver.py index 2504b2486..9722bb25d 100644 --- a/tests/unit/drivers/vector/test_local_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_local_vector_store_driver.py @@ -3,10 +3,10 @@ from griptape.artifacts import TextArtifact from griptape.drivers import LocalVectorStoreDriver from tests.mocks.mock_embedding_driver import MockEmbeddingDriver -from tests.unit.drivers.vector.test_base_local_vector_store_driver import BaseLocalVectorStoreDriver +from tests.unit.drivers.vector.test_base_vector_store_driver import TestBaseVectorStoreDriver -class TestLocalVectorStoreDriver(BaseLocalVectorStoreDriver): +class TestLocalVectorStoreDriver(TestBaseVectorStoreDriver): @pytest.fixture() def driver(self): return LocalVectorStoreDriver(embedding_driver=MockEmbeddingDriver()) diff --git a/tests/unit/drivers/vector/test_persistent_local_vector_store_driver.py b/tests/unit/drivers/vector/test_persistent_local_vector_store_driver.py index c130858b5..9fa725e80 100644 --- a/tests/unit/drivers/vector/test_persistent_local_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_persistent_local_vector_store_driver.py @@ -6,10 +6,10 @@ from griptape.artifacts import TextArtifact from griptape.drivers import LocalVectorStoreDriver from tests.mocks.mock_embedding_driver import MockEmbeddingDriver -from tests.unit.drivers.vector.test_base_local_vector_store_driver import BaseLocalVectorStoreDriver +from tests.unit.drivers.vector.test_base_vector_store_driver import TestBaseVectorStoreDriver -class TestPersistentLocalVectorStoreDriver(BaseLocalVectorStoreDriver): +class TestPersistentLocalVectorStoreDriver(TestBaseVectorStoreDriver): @pytest.fixture() def temp_dir(self): with tempfile.TemporaryDirectory() as temp_dir: diff --git a/tests/unit/engines/rag/modules/retrieval/test_text_loader_retrieval_rag_module.py b/tests/unit/engines/rag/modules/retrieval/test_text_loader_retrieval_rag_module.py index 69e334c7f..b2dbb613d 100644 --- a/tests/unit/engines/rag/modules/retrieval/test_text_loader_retrieval_rag_module.py +++ b/tests/unit/engines/rag/modules/retrieval/test_text_loader_retrieval_rag_module.py @@ -18,7 +18,7 @@ def test_run(self): embedding_driver = MockEmbeddingDriver() module = TextLoaderRetrievalRagModule( - loader=WebLoader(max_tokens=MAX_TOKENS, embedding_driver=embedding_driver), + loader=WebLoader(), vector_store_driver=LocalVectorStoreDriver(embedding_driver=embedding_driver), source="https://www.griptape.ai", ) diff --git a/tests/unit/loaders/conftest.py b/tests/unit/loaders/conftest.py index 1f698738a..0bbf839b8 100644 --- a/tests/unit/loaders/conftest.py +++ b/tests/unit/loaders/conftest.py @@ -1,4 +1,5 @@ import os +from io import BytesIO, StringIO from pathlib import Path import pytest @@ -14,15 +15,15 @@ def create_source(resource_path: str) -> Path: @pytest.fixture() def bytes_from_resource_path(path_from_resource_path): - def create_source(resource_path: str) -> bytes: - return Path(path_from_resource_path(resource_path)).read_bytes() + def create_source(resource_path: str) -> BytesIO: + return BytesIO(Path(path_from_resource_path(resource_path)).read_bytes()) return create_source @pytest.fixture() def str_from_resource_path(path_from_resource_path): - def test_csv_str(resource_path: str) -> str: - return Path(path_from_resource_path(resource_path)).read_text() + def test_csv_str(resource_path: str) -> StringIO: + return StringIO(Path(path_from_resource_path(resource_path)).read_text()) return test_csv_str diff --git a/tests/unit/loaders/test_audio_loader.py b/tests/unit/loaders/test_audio_loader.py index b7ebdd912..7b3516722 100644 --- a/tests/unit/loaders/test_audio_loader.py +++ b/tests/unit/loaders/test_audio_loader.py @@ -9,9 +9,9 @@ class TestAudioLoader: def loader(self): return AudioLoader() - @pytest.fixture() - def create_source(self, bytes_from_resource_path): - return bytes_from_resource_path + @pytest.fixture(params=["path_from_resource_path"]) + def create_source(self, request): + return request.getfixturevalue(request.param) @pytest.mark.parametrize(("resource_path", "mime_type"), [("sentences.wav", "audio/wav")]) def test_load(self, resource_path, mime_type, loader, create_source): @@ -21,7 +21,6 @@ def test_load(self, resource_path, mime_type, loader, create_source): assert isinstance(artifact, AudioArtifact) assert artifact.mime_type == mime_type - assert len(artifact.value) > 0 def test_load_collection(self, create_source, loader): resource_paths = ["sentences.wav", "sentences2.wav"] diff --git a/tests/unit/loaders/test_blob_loader.py b/tests/unit/loaders/test_blob_loader.py index 4812e669c..2042381bc 100644 --- a/tests/unit/loaders/test_blob_loader.py +++ b/tests/unit/loaders/test_blob_loader.py @@ -11,7 +11,7 @@ def loader(self, request): kwargs = {"encoding": encoding} if encoding is not None else {} return BlobLoader(**kwargs) - @pytest.fixture(params=["bytes_from_resource_path", "str_from_resource_path"]) + @pytest.fixture(params=["path_from_resource_path"]) def create_source(self, request): return request.getfixturevalue(request.param) diff --git a/tests/unit/loaders/test_csv_loader.py b/tests/unit/loaders/test_csv_loader.py index 7af409152..3ef7e88b2 100644 --- a/tests/unit/loaders/test_csv_loader.py +++ b/tests/unit/loaders/test_csv_loader.py @@ -3,7 +3,6 @@ import pytest from griptape.loaders.csv_loader import CsvLoader -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver class TestCsvLoader: @@ -11,15 +10,15 @@ class TestCsvLoader: def loader(self, request): encoding = request.param if encoding is None: - return CsvLoader(embedding_driver=MockEmbeddingDriver()) + return CsvLoader() else: - return CsvLoader(embedding_driver=MockEmbeddingDriver(), encoding=encoding) + return CsvLoader(encoding=encoding) @pytest.fixture() def loader_with_pipe_delimiter(self): - return CsvLoader(embedding_driver=MockEmbeddingDriver(), delimiter="|") + return CsvLoader(delimiter="|") - @pytest.fixture(params=["bytes_from_resource_path", "str_from_resource_path"]) + @pytest.fixture(params=["path_from_resource_path"]) def create_source(self, request): return request.getfixturevalue(request.param) @@ -31,7 +30,6 @@ def test_load(self, loader, create_source): assert len(artifacts) == 10 first_artifact = artifacts[0] assert first_artifact.value == "Foo: foo1\nBar: bar1" - assert first_artifact.embedding == [0, 1] def test_load_delimiter(self, loader_with_pipe_delimiter, create_source): source = create_source("test-pipe.csv") @@ -41,7 +39,6 @@ def test_load_delimiter(self, loader_with_pipe_delimiter, create_source): assert len(artifacts) == 10 first_artifact = artifacts[0] assert first_artifact.value == "Bar: foo1\nFoo: bar1" - assert first_artifact.embedding == [0, 1] def test_load_collection(self, loader, create_source): resource_paths = ["test-1.csv", "test-2.csv"] @@ -53,16 +50,5 @@ def test_load_collection(self, loader, create_source): assert collection.keys() == keys assert collection[loader.to_key(sources[0])][0].value == "Foo: foo1\nBar: bar1" - assert collection[loader.to_key(sources[0])][0].embedding == [0, 1] assert collection[loader.to_key(sources[1])][0].value == "Bar: bar1\nFoo: foo1" - assert collection[loader.to_key(sources[1])][0].embedding == [0, 1] - - def test_formatter_fn(self, loader, create_source): - loader.formatter_fn = lambda value: json.dumps(value) - source = create_source("test-1.csv") - - artifacts = loader.load(source) - - assert len(artifacts) == 10 - assert artifacts[0].value == '{"Foo": "foo1", "Bar": "bar1"}' diff --git a/tests/unit/loaders/test_email_loader.py b/tests/unit/loaders/test_email_loader.py index ade062743..1812dc531 100644 --- a/tests/unit/loaders/test_email_loader.py +++ b/tests/unit/loaders/test_email_loader.py @@ -1,14 +1,16 @@ from __future__ import annotations import email -from email import message -from typing import Optional +from typing import TYPE_CHECKING, Optional import pytest from griptape.artifacts import ListArtifact from griptape.loaders import EmailLoader +if TYPE_CHECKING: + from email.message import Message + class TestEmailLoader: @pytest.fixture(autouse=True) @@ -127,23 +129,21 @@ def to_fetch_message(body: str, content_type: Optional[str]): return to_fetch_response(to_message(body, content_type)) -def to_fetch_response(message: message): +def to_fetch_response(message: Message): return (None, ((None, message.as_bytes()),)) -def to_message(body: str, content_type: Optional[str]) -> message: +def to_message(body: str, content_type: Optional[str]) -> Message: message = email.message_from_string(body) if content_type: message.set_type(content_type) return message -def to_value_set(artifact_or_dict: ListArtifact | dict[str, ListArtifact]) -> set[str]: - if isinstance(artifact_or_dict, ListArtifact): - return {value.value for value in artifact_or_dict.value} - elif isinstance(artifact_or_dict, dict): - return { - text_artifact.value for list_artifact in artifact_or_dict.values() for text_artifact in list_artifact.value - } +def to_value_set(artifacts: ListArtifact | dict[str, ListArtifact]) -> set[str]: + if isinstance(artifacts, dict): + return set( + {text_artifact.value for list_artifact in artifacts.values() for text_artifact in list_artifact.value} + ) else: - raise Exception + return {artifact.value for artifact in artifacts.value} diff --git a/tests/unit/loaders/test_image_loader.py b/tests/unit/loaders/test_image_loader.py index 7093894b0..9c491fb88 100644 --- a/tests/unit/loaders/test_image_loader.py +++ b/tests/unit/loaders/test_image_loader.py @@ -13,9 +13,9 @@ def loader(self): def png_loader(self): return ImageLoader(format="png") - @pytest.fixture() - def create_source(self, bytes_from_resource_path): - return bytes_from_resource_path + @pytest.fixture(params=["path_from_resource_path"]) + def create_source(self, request): + return request.getfixturevalue(request.param) @pytest.mark.parametrize( ("resource_path", "mime_type"), diff --git a/tests/unit/loaders/test_pdf_loader.py b/tests/unit/loaders/test_pdf_loader.py index 3f4f7848e..45027b95c 100644 --- a/tests/unit/loaders/test_pdf_loader.py +++ b/tests/unit/loaders/test_pdf_loader.py @@ -1,29 +1,25 @@ import pytest from griptape.loaders import PdfLoader -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver - -MAX_TOKENS = 50 class TestPdfLoader: @pytest.fixture() def loader(self): - return PdfLoader(max_tokens=MAX_TOKENS, embedding_driver=MockEmbeddingDriver()) + return PdfLoader() - @pytest.fixture() - def create_source(self, bytes_from_resource_path): - return bytes_from_resource_path + @pytest.fixture(params=["path_from_resource_path"]) + def create_source(self, request): + return request.getfixturevalue(request.param) def test_load(self, loader, create_source): source = create_source("bitcoin.pdf") - artifacts = loader.load(source) + artifact = loader.load(source) - assert len(artifacts) == 151 - assert artifacts[0].value.startswith("Bitcoin: A Peer-to-Peer") - assert artifacts[-1].value.endswith('its applications," 1957.\n9') - assert artifacts[0].embedding == [0, 1] + assert len(artifact) == 9 + assert artifact[0].value.startswith("Bitcoin: A Peer-to-Peer") + assert artifact[-1].value.endswith('its applications," 1957.\n9') def test_load_collection(self, loader, create_source): resource_paths = ["bitcoin.pdf", "bitcoin-2.pdf"] @@ -37,7 +33,6 @@ def test_load_collection(self, loader, create_source): for key in keys: artifact = collection[key] - assert len(artifact) == 151 + assert len(artifact) == 9 assert artifact[0].value.startswith("Bitcoin: A Peer-to-Peer") assert artifact[-1].value.endswith('its applications," 1957.\n9') - assert artifact[0].embedding == [0, 1] diff --git a/tests/unit/loaders/test_sql_loader.py b/tests/unit/loaders/test_sql_loader.py index 2ff6c7faf..4d33b634a 100644 --- a/tests/unit/loaders/test_sql_loader.py +++ b/tests/unit/loaders/test_sql_loader.py @@ -3,7 +3,6 @@ from griptape.drivers import SqlDriver from griptape.loaders import SqlLoader -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver MAX_TOKENS = 50 @@ -16,7 +15,6 @@ def loader(self): engine_url="sqlite:///:memory:", create_engine_params={"connect_args": {"check_same_thread": False}, "poolclass": StaticPool}, ), - embedding_driver=MockEmbeddingDriver(), ) sql_loader.sql_driver.execute_query( @@ -35,14 +33,12 @@ def loader(self): return sql_loader def test_load(self, loader): - artifacts = loader.load("SELECT * FROM test_table;") + artifact = loader.load("SELECT * FROM test_table;") - assert len(artifacts) == 3 - assert artifacts[0].value == "id: 1\nname: Alice\nage: 25\ncity: New York" - assert artifacts[1].value == "id: 2\nname: Bob\nage: 30\ncity: Los Angeles" - assert artifacts[2].value == "id: 3\nname: Charlie\nage: 22\ncity: Chicago" - - assert artifacts[0].embedding == [0, 1] + assert len(artifact) == 3 + assert artifact[0].value == "id: 1\nname: Alice\nage: 25\ncity: New York" + assert artifact[1].value == "id: 2\nname: Bob\nage: 30\ncity: Los Angeles" + assert artifact[2].value == "id: 3\nname: Charlie\nage: 22\ncity: Chicago" def test_load_collection(self, loader): sources = ["SELECT * FROM test_table LIMIT 1;", "SELECT * FROM test_table LIMIT 2;"] @@ -55,4 +51,3 @@ def test_load_collection(self, loader): assert artifacts[loader.to_key(sources[0])][0].value == "id: 1\nname: Alice\nage: 25\ncity: New York" assert artifacts[loader.to_key(sources[1])][0].value == "id: 1\nname: Alice\nage: 25\ncity: New York" - assert list(artifacts.values())[0][0].embedding == [0, 1] diff --git a/tests/unit/loaders/test_text_loader.py b/tests/unit/loaders/test_text_loader.py index 07527f9e6..c75417f56 100644 --- a/tests/unit/loaders/test_text_loader.py +++ b/tests/unit/loaders/test_text_loader.py @@ -1,9 +1,6 @@ import pytest from griptape.loaders.text_loader import TextLoader -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver - -MAX_TOKENS = 50 class TestTextLoader: @@ -11,23 +8,21 @@ class TestTextLoader: def loader(self, request): encoding = request.param if encoding is None: - return TextLoader(max_tokens=MAX_TOKENS, embedding_driver=MockEmbeddingDriver()) + return TextLoader() else: - return TextLoader(max_tokens=MAX_TOKENS, embedding_driver=MockEmbeddingDriver(), encoding=encoding) + return TextLoader(encoding=encoding) - @pytest.fixture(params=["bytes_from_resource_path", "str_from_resource_path"]) + @pytest.fixture(params=["path_from_resource_path"]) def create_source(self, request): return request.getfixturevalue(request.param) def test_load(self, loader, create_source): source = create_source("test.txt") - artifacts = loader.load(source) + artifact = loader.load(source) - assert len(artifacts) == 39 - assert artifacts[0].value.startswith("foobar foobar foobar") - assert artifacts[0].encoding == loader.encoding - assert artifacts[0].embedding == [0, 1] + assert artifact.value.startswith("foobar foobar foobar") + assert artifact.encoding == loader.encoding def test_load_collection(self, loader, create_source): resource_paths = ["test.txt"] @@ -39,9 +34,6 @@ def test_load_collection(self, loader, create_source): assert collection.keys() == keys key = next(iter(keys)) - artifacts = collection[key] - assert len(artifacts) == 39 + artifact = collection[key] - artifact = artifacts[0] - assert artifact.embedding == [0, 1] assert artifact.encoding == loader.encoding diff --git a/tests/unit/loaders/test_web_loader.py b/tests/unit/loaders/test_web_loader.py index f7cccb666..d6e958042 100644 --- a/tests/unit/loaders/test_web_loader.py +++ b/tests/unit/loaders/test_web_loader.py @@ -1,7 +1,6 @@ import pytest from griptape.loaders import WebLoader -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver MAX_TOKENS = 50 @@ -13,15 +12,12 @@ def _mock_trafilatura_fetch_url(self, mocker): @pytest.fixture() def loader(self): - return WebLoader(max_tokens=MAX_TOKENS, embedding_driver=MockEmbeddingDriver()) + return WebLoader() def test_load(self, loader): - artifacts = loader.load("https://github.com/griptape-ai/griptape") + artifact = loader.load("https://github.com/griptape-ai/griptape") - assert len(artifacts) == 1 - assert "foobar" in artifacts[0].value.lower() - - assert artifacts[0].embedding == [0, 1] + assert "foobar" in artifact.value.lower() def test_load_exception(self, mocker, loader): mocker.patch("trafilatura.fetch_url", side_effect=Exception("error")) @@ -38,9 +34,7 @@ def test_load_collection(self, loader): loader.to_key("https://github.com/griptape-ai/griptape"), loader.to_key("https://github.com/griptape-ai/griptape-docs"), ] - assert "foobar" in [a.value for artifact_list in artifacts.values() for a in artifact_list][0].lower() - - assert list(artifacts.values())[0][0].embedding == [0, 1] + assert "foobar" in [a.value for a in artifacts.values()] def test_empty_page_string_response(self, loader, mocker): mocker.patch("trafilatura.extract", return_value="") diff --git a/tests/unit/tools/test_file_manager.py b/tests/unit/tools/test_file_manager.py index 469918a02..4e035bdee 100644 --- a/tests/unit/tools/test_file_manager.py +++ b/tests/unit/tools/test_file_manager.py @@ -7,7 +7,6 @@ from griptape.artifacts import ListArtifact, TextArtifact from griptape.drivers.file_manager.local_file_manager_driver import LocalFileManagerDriver -from griptape.loaders.text_loader import TextLoader from griptape.tools import FileManagerTool from tests.utils import defaults @@ -36,7 +35,7 @@ def test_load_files_from_disk(self, file_manager): result = file_manager.load_files_from_disk({"values": {"paths": ["../../resources/bitcoin.pdf"]}}) assert isinstance(result, ListArtifact) - assert len(result.value) == 4 + assert len(result.value) == 9 def test_load_files_from_disk_with_encoding(self, file_manager): result = file_manager.load_files_from_disk({"values": {"paths": ["../../resources/test.txt"]}}) @@ -48,8 +47,7 @@ def test_load_files_from_disk_with_encoding(self, file_manager): def test_load_files_from_disk_with_encoding_failure(self): file_manager = FileManagerTool( file_manager_driver=LocalFileManagerDriver( - default_loader=TextLoader(encoding="utf-8"), - loaders={}, + encoding="utf-8", workdir=os.path.abspath(os.path.dirname(__file__)), ) ) @@ -116,9 +114,7 @@ def test_save_content_to_file(self, temp_dir): assert result.value == "Successfully saved file" def test_save_content_to_file_with_encoding(self, temp_dir): - file_manager = FileManagerTool( - file_manager_driver=LocalFileManagerDriver(default_loader=TextLoader(encoding="utf-8"), workdir=temp_dir) - ) + file_manager = FileManagerTool(file_manager_driver=LocalFileManagerDriver(encoding="utf-8", workdir=temp_dir)) result = file_manager.save_content_to_file( {"values": {"path": os.path.join("test", "foobar.txt"), "content": "foobar"}} ) @@ -127,9 +123,7 @@ def test_save_content_to_file_with_encoding(self, temp_dir): assert result.value == "Successfully saved file" def test_save_and_load_content_to_file_with_encoding(self, temp_dir): - file_manager = FileManagerTool( - file_manager_driver=LocalFileManagerDriver(loaders={"txt": TextLoader(encoding="ascii")}, workdir=temp_dir) - ) + file_manager = FileManagerTool(file_manager_driver=LocalFileManagerDriver(encoding="ascii", workdir=temp_dir)) result = file_manager.save_content_to_file( {"values": {"path": os.path.join("test", "foobar.txt"), "content": "foobar"}} ) @@ -137,11 +131,7 @@ def test_save_and_load_content_to_file_with_encoding(self, temp_dir): assert Path(os.path.join(temp_dir, "test", "foobar.txt")).read_text() == "foobar" assert result.value == "Successfully saved file" - file_manager = FileManagerTool( - file_manager_driver=LocalFileManagerDriver( - default_loader=TextLoader(encoding="ascii"), loaders={}, workdir=temp_dir - ) - ) + file_manager = FileManagerTool(file_manager_driver=LocalFileManagerDriver(encoding="ascii", workdir=temp_dir)) result = file_manager.load_files_from_disk({"values": {"paths": [os.path.join("test", "foobar.txt")]}}) assert isinstance(result, ListArtifact) diff --git a/tests/unit/utils/test_file_utils.py b/tests/unit/utils/test_file_utils.py deleted file mode 100644 index 00df6958d..000000000 --- a/tests/unit/utils/test_file_utils.py +++ /dev/null @@ -1,51 +0,0 @@ -import os -from concurrent import futures - -from griptape import utils -from griptape.loaders import TextLoader - -MAX_TOKENS = 50 - - -class TestFileUtils: - def test_load_file(self): - dirname = os.path.dirname(__file__) - file = utils.load_file(os.path.join(dirname, "../../resources/foobar-many.txt")) - - assert file.decode("utf-8").startswith("foobar foobar foobar") - - def test_load_files(self): - dirname = os.path.dirname(__file__) - sources = ["resources/foobar-many.txt", "resources/foobar-many.txt", "resources/small.png"] - sources = [os.path.join(dirname, "../../", source) for source in sources] - files = utils.load_files(sources, futures_executor=futures.ThreadPoolExecutor(max_workers=1)) - assert len(files) == 2 - - test_file = files[utils.str_to_hash(sources[0])] - assert test_file.decode("utf-8").startswith("foobar foobar foobar") - - small_file = files[utils.str_to_hash(sources[2])] - assert len(small_file) == 97 - assert small_file[:8] == b"\x89PNG\r\n\x1a\n" - - def test_load_file_with_loader(self): - dirname = os.path.dirname(__file__) - file = utils.load_file(os.path.join(dirname, "../../", "resources/foobar-many.txt")) - artifacts = TextLoader(max_tokens=MAX_TOKENS).load(file) - - assert len(artifacts) == 39 - assert isinstance(artifacts, list) - assert artifacts[0].value.startswith("foobar foobar foobar") - - def test_load_files_with_loader(self): - dirname = os.path.dirname(__file__) - sources = ["resources/foobar-many.txt"] - sources = [os.path.join(dirname, "../../", source) for source in sources] - files = utils.load_files(sources) - loader = TextLoader(max_tokens=MAX_TOKENS) - collection = loader.load_collection(list(files.values())) - - test_file_artifacts = collection[loader.to_key(files[utils.str_to_hash(sources[0])])] - assert len(test_file_artifacts) == 39 - assert isinstance(test_file_artifacts, list) - assert test_file_artifacts[0].value.startswith("foobar foobar foobar") From e5aa986f1519fa9fbabcb2d1e98a6b652936edd0 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 5 Sep 2024 11:14:56 -0700 Subject: [PATCH 27/52] Use generics for loaders, remove instances of path --- .../drivers/src/image_generation_drivers_8.py | 4 +--- .../drivers/src/image_generation_drivers_9.py | 4 +--- .../drivers/src/image_query_drivers_1.py | 4 +--- .../drivers/src/image_query_drivers_2.py | 6 ++--- .../drivers/src/image_query_drivers_3.py | 4 +--- .../drivers/src/image_query_drivers_4.py | 3 +-- .../drivers/src/image_query_drivers_5.py | 4 +--- .../engines/src/image_generation_engines_3.py | 4 +--- .../engines/src/image_generation_engines_4.py | 6 ++--- .../engines/src/image_generation_engines_5.py | 6 ++--- .../engines/src/image_query_engines_1.py | 4 +--- .../structures/src/tasks_12.py | 4 +--- .../structures/src/tasks_13.py | 6 ++--- .../structures/src/tasks_14.py | 6 ++--- .../structures/src/tasks_15.py | 4 +--- .../structures/src/tasks_3.py | 4 +--- griptape/loaders/audio_loader.py | 9 ++------ griptape/loaders/base_file_loader.py | 17 +++++++++----- griptape/loaders/base_loader.py | 23 ++++++++++++------- griptape/loaders/blob_loader.py | 7 +----- griptape/loaders/csv_loader.py | 4 ++-- griptape/loaders/email_loader.py | 11 ++++----- griptape/loaders/image_loader.py | 9 +++----- griptape/loaders/pdf_loader.py | 2 -- griptape/loaders/sql_loader.py | 15 ++++-------- griptape/loaders/text_loader.py | 9 ++------ griptape/loaders/web_loader.py | 11 +++------ griptape/tasks/base_image_generation_task.py | 2 +- griptape/tools/audio_transcription/tool.py | 3 +-- griptape/tools/image_query/tool.py | 3 +-- .../tools/inpainting_image_generation/tool.py | 5 ++-- .../outpainting_image_generation/tool.py | 5 ++-- .../tools/variation_image_generation/tool.py | 3 +-- 33 files changed, 76 insertions(+), 135 deletions(-) diff --git a/docs/griptape-framework/drivers/src/image_generation_drivers_8.py b/docs/griptape-framework/drivers/src/image_generation_drivers_8.py index 69437a3a5..470b47707 100644 --- a/docs/griptape-framework/drivers/src/image_generation_drivers_8.py +++ b/docs/griptape-framework/drivers/src/image_generation_drivers_8.py @@ -1,5 +1,3 @@ -from pathlib import Path - from griptape.artifacts import TextArtifact from griptape.drivers import ( HuggingFacePipelineImageGenerationDriver, @@ -11,7 +9,7 @@ from griptape.tasks import VariationImageGenerationTask prompt_artifact = TextArtifact("landscape photograph, verdant, countryside, 8k") -input_image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +input_image_artifact = ImageLoader().load("tests/resources/mountain.png") image_variation_task = VariationImageGenerationTask( input=(prompt_artifact, input_image_artifact), diff --git a/docs/griptape-framework/drivers/src/image_generation_drivers_9.py b/docs/griptape-framework/drivers/src/image_generation_drivers_9.py index 2054588d9..ab3dc3113 100644 --- a/docs/griptape-framework/drivers/src/image_generation_drivers_9.py +++ b/docs/griptape-framework/drivers/src/image_generation_drivers_9.py @@ -1,5 +1,3 @@ -from pathlib import Path - from griptape.artifacts import TextArtifact from griptape.drivers import ( HuggingFacePipelineImageGenerationDriver, @@ -11,7 +9,7 @@ from griptape.tasks import VariationImageGenerationTask prompt_artifact = TextArtifact("landscape photograph, verdant, countryside, 8k") -control_image_artifact = ImageLoader().load(Path("canny_control_image.png").read_bytes()) +control_image_artifact = ImageLoader().load("canny_control_image.png") controlnet_task = VariationImageGenerationTask( input=(prompt_artifact, control_image_artifact), diff --git a/docs/griptape-framework/drivers/src/image_query_drivers_1.py b/docs/griptape-framework/drivers/src/image_query_drivers_1.py index 0c9db5be7..0e0165d97 100644 --- a/docs/griptape-framework/drivers/src/image_query_drivers_1.py +++ b/docs/griptape-framework/drivers/src/image_query_drivers_1.py @@ -1,5 +1,3 @@ -from pathlib import Path - from griptape.drivers import AnthropicImageQueryDriver from griptape.engines import ImageQueryEngine from griptape.loaders import ImageLoader @@ -13,6 +11,6 @@ image_query_driver=driver, ) -image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +image_artifact = ImageLoader().load("tests/resources/mountain.png") engine.run("Describe the weather in the image", [image_artifact]) diff --git a/docs/griptape-framework/drivers/src/image_query_drivers_2.py b/docs/griptape-framework/drivers/src/image_query_drivers_2.py index 8d605c0d9..4b5b3cc9f 100644 --- a/docs/griptape-framework/drivers/src/image_query_drivers_2.py +++ b/docs/griptape-framework/drivers/src/image_query_drivers_2.py @@ -1,5 +1,3 @@ -from pathlib import Path - from griptape.drivers import AnthropicImageQueryDriver from griptape.engines import ImageQueryEngine from griptape.loaders import ImageLoader @@ -13,9 +11,9 @@ image_query_driver=driver, ) -image_artifact1 = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +image_artifact1 = ImageLoader().load("tests/resources/mountain.png") -image_artifact2 = ImageLoader().load(Path("tests/resources/cow.png").read_bytes()) +image_artifact2 = ImageLoader().load("tests/resources/cow.png") result = engine.run("Describe the weather in the image", [image_artifact1, image_artifact2]) diff --git a/docs/griptape-framework/drivers/src/image_query_drivers_3.py b/docs/griptape-framework/drivers/src/image_query_drivers_3.py index 14070312b..0653d3f6e 100644 --- a/docs/griptape-framework/drivers/src/image_query_drivers_3.py +++ b/docs/griptape-framework/drivers/src/image_query_drivers_3.py @@ -1,5 +1,3 @@ -from pathlib import Path - from griptape.drivers import OpenAiImageQueryDriver from griptape.engines import ImageQueryEngine from griptape.loaders import ImageLoader @@ -13,6 +11,6 @@ image_query_driver=driver, ) -image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +image_artifact = ImageLoader().load("tests/resources/mountain.png") engine.run("Describe the weather in the image", [image_artifact]) diff --git a/docs/griptape-framework/drivers/src/image_query_drivers_4.py b/docs/griptape-framework/drivers/src/image_query_drivers_4.py index 9ebf5ef59..cff4c2a10 100644 --- a/docs/griptape-framework/drivers/src/image_query_drivers_4.py +++ b/docs/griptape-framework/drivers/src/image_query_drivers_4.py @@ -1,5 +1,4 @@ import os -from pathlib import Path from griptape.drivers import AzureOpenAiImageQueryDriver from griptape.engines import ImageQueryEngine @@ -17,6 +16,6 @@ image_query_driver=driver, ) -image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +image_artifact = ImageLoader().load("tests/resources/mountain.png") engine.run("Describe the weather in the image", [image_artifact]) diff --git a/docs/griptape-framework/drivers/src/image_query_drivers_5.py b/docs/griptape-framework/drivers/src/image_query_drivers_5.py index 2bab9a7fd..c364a24cc 100644 --- a/docs/griptape-framework/drivers/src/image_query_drivers_5.py +++ b/docs/griptape-framework/drivers/src/image_query_drivers_5.py @@ -1,5 +1,3 @@ -from pathlib import Path - import boto3 from griptape.drivers import AmazonBedrockImageQueryDriver, BedrockClaudeImageQueryModelDriver @@ -16,7 +14,7 @@ engine = ImageQueryEngine(image_query_driver=driver) -image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +image_artifact = ImageLoader().load("tests/resources/mountain.png") result = engine.run("Describe the weather in the image", [image_artifact]) diff --git a/docs/griptape-framework/engines/src/image_generation_engines_3.py b/docs/griptape-framework/engines/src/image_generation_engines_3.py index 83822b1bc..4bcd976d4 100644 --- a/docs/griptape-framework/engines/src/image_generation_engines_3.py +++ b/docs/griptape-framework/engines/src/image_generation_engines_3.py @@ -1,5 +1,3 @@ -from pathlib import Path - from griptape.drivers import AmazonBedrockImageGenerationDriver, BedrockStableDiffusionImageGenerationModelDriver from griptape.engines import VariationImageGenerationEngine from griptape.loaders import ImageLoader @@ -15,7 +13,7 @@ image_generation_driver=driver, ) -image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +image_artifact = ImageLoader().load("tests/resources/mountain.png") engine.run( prompts=["A photo of a mountain landscape in winter"], diff --git a/docs/griptape-framework/engines/src/image_generation_engines_4.py b/docs/griptape-framework/engines/src/image_generation_engines_4.py index c258e1cce..e7b46b341 100644 --- a/docs/griptape-framework/engines/src/image_generation_engines_4.py +++ b/docs/griptape-framework/engines/src/image_generation_engines_4.py @@ -1,5 +1,3 @@ -from pathlib import Path - from griptape.drivers import AmazonBedrockImageGenerationDriver, BedrockStableDiffusionImageGenerationModelDriver from griptape.engines import InpaintingImageGenerationEngine from griptape.loaders import ImageLoader @@ -15,9 +13,9 @@ image_generation_driver=driver, ) -image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +image_artifact = ImageLoader().load("tests/resources/mountain.png") -mask_artifact = ImageLoader().load(Path("tests/resources/mountain-mask.png").read_bytes()) +mask_artifact = ImageLoader().load("tests/resources/mountain-mask.png") engine.run( prompts=["A photo of a castle built into the side of a mountain"], diff --git a/docs/griptape-framework/engines/src/image_generation_engines_5.py b/docs/griptape-framework/engines/src/image_generation_engines_5.py index f91a48ec0..526ebff50 100644 --- a/docs/griptape-framework/engines/src/image_generation_engines_5.py +++ b/docs/griptape-framework/engines/src/image_generation_engines_5.py @@ -1,5 +1,3 @@ -from pathlib import Path - from griptape.drivers import AmazonBedrockImageGenerationDriver, BedrockStableDiffusionImageGenerationModelDriver from griptape.engines import OutpaintingImageGenerationEngine from griptape.loaders import ImageLoader @@ -15,9 +13,9 @@ image_generation_driver=driver, ) -image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +image_artifact = ImageLoader().load("tests/resources/mountain.png") -mask_artifact = ImageLoader().load(Path("tests/resources/mountain-mask.png").read_bytes()) +mask_artifact = ImageLoader().load("tests/resources/mountain-mask.png") engine.run( prompts=["A photo of a mountain shrouded in clouds"], diff --git a/docs/griptape-framework/engines/src/image_query_engines_1.py b/docs/griptape-framework/engines/src/image_query_engines_1.py index b0920392a..c2d08e9a9 100644 --- a/docs/griptape-framework/engines/src/image_query_engines_1.py +++ b/docs/griptape-framework/engines/src/image_query_engines_1.py @@ -1,5 +1,3 @@ -from pathlib import Path - from griptape.drivers import OpenAiImageQueryDriver from griptape.engines import ImageQueryEngine from griptape.loaders import ImageLoader @@ -8,6 +6,6 @@ engine = ImageQueryEngine(image_query_driver=driver) -image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +image_artifact = ImageLoader().load("tests/resources/mountain.png") engine.run("Describe the weather in the image", [image_artifact]) diff --git a/docs/griptape-framework/structures/src/tasks_12.py b/docs/griptape-framework/structures/src/tasks_12.py index 917b50607..1fdc99e1c 100644 --- a/docs/griptape-framework/structures/src/tasks_12.py +++ b/docs/griptape-framework/structures/src/tasks_12.py @@ -1,5 +1,3 @@ -from pathlib import Path - from griptape.drivers import AmazonBedrockImageGenerationDriver, BedrockStableDiffusionImageGenerationModelDriver from griptape.engines import VariationImageGenerationEngine from griptape.loaders import ImageLoader @@ -18,7 +16,7 @@ ) # Load input image artifact. -image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +image_artifact = ImageLoader().load("tests/resources/mountain.png") # Instantiate a pipeline. pipeline = Pipeline() diff --git a/docs/griptape-framework/structures/src/tasks_13.py b/docs/griptape-framework/structures/src/tasks_13.py index d2aa45983..4b7616d94 100644 --- a/docs/griptape-framework/structures/src/tasks_13.py +++ b/docs/griptape-framework/structures/src/tasks_13.py @@ -1,5 +1,3 @@ -from pathlib import Path - from griptape.drivers import AmazonBedrockImageGenerationDriver, BedrockStableDiffusionImageGenerationModelDriver from griptape.engines import InpaintingImageGenerationEngine from griptape.loaders import ImageLoader @@ -18,9 +16,9 @@ ) # Load input image artifacts. -image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +image_artifact = ImageLoader().load("tests/resources/mountain.png") -mask_artifact = ImageLoader().load(Path("tests/resources/mountain-mask.png").read_bytes()) +mask_artifact = ImageLoader().load("tests/resources/mountain-mask.png") # Instantiate a pipeline. pipeline = Pipeline() diff --git a/docs/griptape-framework/structures/src/tasks_14.py b/docs/griptape-framework/structures/src/tasks_14.py index ec489096d..d2e6ba2dd 100644 --- a/docs/griptape-framework/structures/src/tasks_14.py +++ b/docs/griptape-framework/structures/src/tasks_14.py @@ -1,5 +1,3 @@ -from pathlib import Path - from griptape.drivers import AmazonBedrockImageGenerationDriver, BedrockStableDiffusionImageGenerationModelDriver from griptape.engines import OutpaintingImageGenerationEngine from griptape.loaders import ImageLoader @@ -18,9 +16,9 @@ ) # Load input image artifacts. -image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +image_artifact = ImageLoader().load("tests/resources/mountain.png") -mask_artifact = ImageLoader().load(Path("tests/resources/mountain-mask.png").read_bytes()) +mask_artifact = ImageLoader().load("tests/resources/mountain-mask.png") # Instantiate a pipeline. pipeline = Pipeline() diff --git a/docs/griptape-framework/structures/src/tasks_15.py b/docs/griptape-framework/structures/src/tasks_15.py index 0c60864f7..802ac3397 100644 --- a/docs/griptape-framework/structures/src/tasks_15.py +++ b/docs/griptape-framework/structures/src/tasks_15.py @@ -1,5 +1,3 @@ -from pathlib import Path - from griptape.drivers import OpenAiImageQueryDriver from griptape.engines import ImageQueryEngine from griptape.loaders import ImageLoader @@ -18,7 +16,7 @@ ) # Load the input image artifact. -image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +image_artifact = ImageLoader().load("tests/resources/mountain.png") # Instantiate a pipeline. pipeline = Pipeline() diff --git a/docs/griptape-framework/structures/src/tasks_3.py b/docs/griptape-framework/structures/src/tasks_3.py index 6584049d0..cdfe894bd 100644 --- a/docs/griptape-framework/structures/src/tasks_3.py +++ b/docs/griptape-framework/structures/src/tasks_3.py @@ -1,10 +1,8 @@ -from pathlib import Path - from griptape.loaders import ImageLoader from griptape.structures import Agent agent = Agent() -image_artifact = ImageLoader().load(Path("tests/resources/mountain.jpg").read_bytes()) +image_artifact = ImageLoader().load("tests/resources/mountain.jpg") agent.run([image_artifact, "What's in this image?"]) diff --git a/griptape/loaders/audio_loader.py b/griptape/loaders/audio_loader.py index 1a7caefce..4e7d9c7d0 100644 --- a/griptape/loaders/audio_loader.py +++ b/griptape/loaders/audio_loader.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Any, cast - import filetype from attrs import define @@ -10,11 +8,8 @@ @define -class AudioLoader(BaseFileLoader): +class AudioLoader(BaseFileLoader[AudioArtifact]): """Loads audio content into audio artifacts.""" - def load(self, source: Any, *args, **kwargs) -> AudioArtifact: - return cast(AudioArtifact, super().load(source, *args, **kwargs)) - - def parse(self, source: bytes, *args, **kwargs) -> AudioArtifact: + def parse(self, source: bytes) -> AudioArtifact: return AudioArtifact(source, format=filetype.guess(source).extension) diff --git a/griptape/loaders/base_file_loader.py b/griptape/loaders/base_file_loader.py index 650dc303c..36645ef8a 100644 --- a/griptape/loaders/base_file_loader.py +++ b/griptape/loaders/base_file_loader.py @@ -1,24 +1,29 @@ from __future__ import annotations from abc import ABC -from typing import TYPE_CHECKING +from os import PathLike +from typing import TypeVar, Union from attrs import Factory, define, field +from griptape.artifacts import BaseArtifact from griptape.drivers import BaseFileManagerDriver, LocalFileManagerDriver from griptape.loaders import BaseLoader -if TYPE_CHECKING: - from os import PathLike +A = TypeVar("A", bound=BaseArtifact) @define -class BaseFileLoader(BaseLoader, ABC): +class BaseFileLoader(BaseLoader[Union[str, PathLike], bytes, A], ABC): file_manager_driver: BaseFileManagerDriver = field( default=Factory(lambda: LocalFileManagerDriver(workdir=None)), kw_only=True, ) encoding: str = field(default="utf-8", kw_only=True) - def fetch(self, source: str | PathLike, *args, **kwargs) -> str | bytes: - return self.file_manager_driver.load_file(str(source), *args, **kwargs).value + def fetch(self, source: str | PathLike) -> bytes: + data = self.file_manager_driver.load_file(str(source)).value + if isinstance(data, str): + return data.encode(self.encoding) + else: + return data diff --git a/griptape/loaders/base_loader.py b/griptape/loaders/base_loader.py index 33c98dfe4..85edf9417 100644 --- a/griptape/loaders/base_loader.py +++ b/griptape/loaders/base_loader.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar from attrs import define, field @@ -12,31 +12,38 @@ if TYPE_CHECKING: from collections.abc import Mapping - from griptape.artifacts import BaseArtifact from griptape.common import Reference +S = TypeVar("S") +F = TypeVar("F") +A = TypeVar("A", bound=BaseArtifact) + @define -class BaseLoader(FuturesExecutorMixin, ABC): +class BaseLoader(FuturesExecutorMixin, ABC, Generic[S, F, A]): reference: Optional[Reference] = field(default=None, kw_only=True) - def load(self, source: Any, *args, **kwargs) -> BaseArtifact: + def load(self, source: S, *args, **kwargs) -> A: data = self.fetch(source) - return self.parse(data) + artifact = self.parse(data) + + artifact.reference = self.reference + + return artifact @abstractmethod - def fetch(self, source: Any, *args, **kwargs) -> Any: ... + def fetch(self, source: S) -> F: ... @abstractmethod - def parse(self, source: Any, *args, **kwargs) -> BaseArtifact: ... + def parse(self, source: F) -> A: ... def load_collection( self, sources: list[Any], *args, **kwargs, - ) -> Mapping[str, BaseArtifact]: + ) -> Mapping[str, A]: # Create a dictionary before actually submitting the jobs to the executor # to avoid duplicate work. sources_by_key = {self.to_key(source): source for source in sources} diff --git a/griptape/loaders/blob_loader.py b/griptape/loaders/blob_loader.py index dbc7e66bf..b54c880b5 100644 --- a/griptape/loaders/blob_loader.py +++ b/griptape/loaders/blob_loader.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Any, cast - from attrs import define from griptape.artifacts import BlobArtifact @@ -9,10 +7,7 @@ @define -class BlobLoader(BaseFileLoader): - def load(self, source: Any, *args, **kwargs) -> BlobArtifact: - return cast(BlobArtifact, super().load(source, *args, **kwargs)) - +class BlobLoader(BaseFileLoader[BlobArtifact]): def parse(self, source: bytes, *args, **kwargs) -> BlobArtifact: if self.encoding is None: return BlobArtifact(source) diff --git a/griptape/loaders/csv_loader.py b/griptape/loaders/csv_loader.py index fb7ba19ba..20feac524 100644 --- a/griptape/loaders/csv_loader.py +++ b/griptape/loaders/csv_loader.py @@ -11,14 +11,14 @@ @define -class CsvLoader(BaseFileLoader): +class CsvLoader(BaseFileLoader[ListArtifact]): delimiter: str = field(default=",", kw_only=True) encoding: str = field(default="utf-8", kw_only=True) formatter_fn: Callable[[dict], str] = field( default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True ) - def parse(self, source: bytes, *args, **kwargs) -> ListArtifact: + def parse(self, source: bytes) -> ListArtifact: reader = csv.DictReader(StringIO(source.decode(self.encoding)), delimiter=self.delimiter) return ListArtifact([TextArtifact(row, meta={"row": row_num}) for row_num, row in enumerate(reader)]) diff --git a/griptape/loaders/email_loader.py b/griptape/loaders/email_loader.py index 0a3d3f600..88cd8e967 100644 --- a/griptape/loaders/email_loader.py +++ b/griptape/loaders/email_loader.py @@ -1,7 +1,7 @@ from __future__ import annotations import imaplib -from typing import Any, Optional, cast +from typing import Optional from attrs import astuple, define, field @@ -11,7 +11,7 @@ @define -class EmailLoader(BaseLoader): +class EmailLoader(BaseLoader["EmailLoader.EmailQuery", list[bytes], ListArtifact]): # pyright: ignore[reportGeneralTypeIssues] @define(frozen=True) class EmailQuery: """An email retrieval query. @@ -32,10 +32,7 @@ class EmailQuery: username: str = field(kw_only=True) password: str = field(kw_only=True) - def load(self, source: Any, *args, **kwargs) -> ListArtifact: - return cast(ListArtifact, super().load(source, *args, **kwargs)) - - def fetch(self, source: EmailQuery, *args, **kwargs) -> list[bytes]: + def fetch(self, source: EmailLoader.EmailQuery) -> list[bytes]: label, key, search_criteria, max_count = astuple(source) mail_bytes = [] @@ -67,7 +64,7 @@ def fetch(self, source: EmailQuery, *args, **kwargs) -> list[bytes]: return mail_bytes - def parse(self, source: list[bytes], *args, **kwargs) -> ListArtifact: + def parse(self, source: list[bytes]) -> ListArtifact: mailparser = import_optional_dependency("mailparser") artifacts = [] for byte in source: diff --git a/griptape/loaders/image_loader.py b/griptape/loaders/image_loader.py index 513910c78..d5f4ca742 100644 --- a/griptape/loaders/image_loader.py +++ b/griptape/loaders/image_loader.py @@ -1,7 +1,7 @@ from __future__ import annotations from io import BytesIO -from typing import Any, Optional, cast +from typing import Optional from attrs import define, field @@ -11,7 +11,7 @@ @define -class ImageLoader(BaseFileLoader): +class ImageLoader(BaseFileLoader[ImageArtifact]): """Loads images into image artifacts. Attributes: @@ -22,10 +22,7 @@ class ImageLoader(BaseFileLoader): format: Optional[str] = field(default=None, kw_only=True) - def load(self, source: Any, *args, **kwargs) -> ImageArtifact: - return cast(ImageArtifact, super().load(source, *args, **kwargs)) - - def parse(self, source: bytes, *args, **kwargs) -> ImageArtifact: + def parse(self, source: bytes) -> ImageArtifact: pil_image = import_optional_dependency("PIL.Image") image = pil_image.open(BytesIO(source)) diff --git a/griptape/loaders/pdf_loader.py b/griptape/loaders/pdf_loader.py index b1f901386..9be195e85 100644 --- a/griptape/loaders/pdf_loader.py +++ b/griptape/loaders/pdf_loader.py @@ -19,8 +19,6 @@ def parse( self, source: bytes, password: Optional[str] = None, - *args, - **kwargs, ) -> ListArtifact: pypdf = import_optional_dependency("pypdf") reader = pypdf.PdfReader(BytesIO(source), strict=True, password=password) diff --git a/griptape/loaders/sql_loader.py b/griptape/loaders/sql_loader.py index d633eede6..594bf06e0 100644 --- a/griptape/loaders/sql_loader.py +++ b/griptape/loaders/sql_loader.py @@ -1,25 +1,18 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, cast - from attrs import define, field from griptape.artifacts import ListArtifact, TextArtifact +from griptape.drivers import BaseSqlDriver from griptape.loaders import BaseLoader -if TYPE_CHECKING: - from griptape.drivers import BaseSqlDriver - @define -class SqlLoader(BaseLoader): +class SqlLoader(BaseLoader[str, list[BaseSqlDriver.RowResult], ListArtifact]): sql_driver: BaseSqlDriver = field(kw_only=True) - def load(self, source: Any, *args, **kwargs) -> ListArtifact: - return cast(ListArtifact, super().load(source, *args, **kwargs)) - - def fetch(self, source: str, *args, **kwargs) -> list[BaseSqlDriver.RowResult]: + def fetch(self, source: str) -> list[BaseSqlDriver.RowResult]: return self.sql_driver.execute_query(source) or [] - def parse(self, source: list[BaseSqlDriver.RowResult], *args, **kwargs) -> ListArtifact: + def parse(self, source: list[BaseSqlDriver.RowResult]) -> ListArtifact: return ListArtifact([TextArtifact(row.cells, meta={"row": row_num}) for row_num, row in enumerate(source)]) diff --git a/griptape/loaders/text_loader.py b/griptape/loaders/text_loader.py index 68dc324ae..7ef649bbc 100644 --- a/griptape/loaders/text_loader.py +++ b/griptape/loaders/text_loader.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Any, cast - from attrs import define, field from griptape.artifacts import TextArtifact @@ -9,13 +7,10 @@ @define -class TextLoader(BaseFileLoader): +class TextLoader(BaseFileLoader[TextArtifact]): encoding: str = field(default="utf-8", kw_only=True) - def load(self, source: Any, *args, **kwargs) -> TextArtifact: - return cast(TextArtifact, super().load(source, *args, **kwargs)) - - def parse(self, source: str | bytes, *args, **kwargs) -> TextArtifact: + def parse(self, source: str | bytes) -> TextArtifact: if isinstance(source, str): return TextArtifact(source, encoding=self.encoding) else: diff --git a/griptape/loaders/web_loader.py b/griptape/loaders/web_loader.py index d31a34816..de4c32778 100644 --- a/griptape/loaders/web_loader.py +++ b/griptape/loaders/web_loader.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Any, cast - from attrs import Factory, define, field from griptape.artifacts import TextArtifact @@ -10,17 +8,14 @@ @define -class WebLoader(BaseLoader): +class WebLoader(BaseLoader[str, str, TextArtifact]): web_scraper_driver: BaseWebScraperDriver = field( default=Factory(lambda: TrafilaturaWebScraperDriver()), kw_only=True, ) - def load(self, source: Any, *args, **kwargs) -> TextArtifact: - return cast(TextArtifact, super().load(source, *args, **kwargs)) - - def fetch(self, source: str, *args, **kwargs) -> str: + def fetch(self, source: str) -> str: return self.web_scraper_driver.fetch_url(source) - def parse(self, source: str, *args, **kwargs) -> TextArtifact: + def parse(self, source: str) -> TextArtifact: return self.web_scraper_driver.extract_page(source) diff --git a/griptape/tasks/base_image_generation_task.py b/griptape/tasks/base_image_generation_task.py index 326b2a551..4a502e7cc 100644 --- a/griptape/tasks/base_image_generation_task.py +++ b/griptape/tasks/base_image_generation_task.py @@ -66,4 +66,4 @@ def all_negative_rulesets(self) -> list[Ruleset]: def _read_from_file(self, path: str) -> ImageArtifact: logger.info("Reading image from %s", os.path.abspath(path)) - return ImageLoader().load(Path(path).read_bytes()) + return ImageLoader().load(Path(path)) diff --git a/griptape/tools/audio_transcription/tool.py b/griptape/tools/audio_transcription/tool.py index 4174db209..826aeb895 100644 --- a/griptape/tools/audio_transcription/tool.py +++ b/griptape/tools/audio_transcription/tool.py @@ -1,6 +1,5 @@ from __future__ import annotations -from pathlib import Path from typing import TYPE_CHECKING, Any, cast from attrs import Factory, define, field @@ -32,7 +31,7 @@ class AudioTranscriptionTool(BaseTool): def transcribe_audio_from_disk(self, params: dict) -> TextArtifact | ErrorArtifact: audio_path = params["values"]["path"] - audio_artifact = self.audio_loader.load(Path(audio_path).read_bytes()) + audio_artifact = self.audio_loader.load(audio_path) return self.engine.run(audio_artifact) diff --git a/griptape/tools/image_query/tool.py b/griptape/tools/image_query/tool.py index 9d1dbb89b..7b654bd72 100644 --- a/griptape/tools/image_query/tool.py +++ b/griptape/tools/image_query/tool.py @@ -1,6 +1,5 @@ from __future__ import annotations -from pathlib import Path from typing import TYPE_CHECKING, Any, cast from attrs import Factory, define, field @@ -41,7 +40,7 @@ def query_image_from_disk(self, params: dict) -> TextArtifact | ErrorArtifact: image_artifacts = [] for image_path in image_paths: - image_artifacts.append(self.image_loader.load(Path(image_path).read_bytes())) + image_artifacts.append(self.image_loader.load(image_path)) return self.image_query_engine.run(query, image_artifacts) diff --git a/griptape/tools/inpainting_image_generation/tool.py b/griptape/tools/inpainting_image_generation/tool.py index d32f481d9..b529cb637 100644 --- a/griptape/tools/inpainting_image_generation/tool.py +++ b/griptape/tools/inpainting_image_generation/tool.py @@ -1,6 +1,5 @@ from __future__ import annotations -from pathlib import Path from typing import TYPE_CHECKING, cast from attrs import define, field @@ -51,8 +50,8 @@ def image_inpainting_from_file(self, params: dict[str, dict[str, str]]) -> Image image_file = params["values"]["image_file"] mask_file = params["values"]["mask_file"] - input_artifact = self.image_loader.load(Path(image_file).read_bytes()) - mask_artifact = self.image_loader.load(Path(mask_file).read_bytes()) + input_artifact = self.image_loader.load(image_file) + mask_artifact = self.image_loader.load(mask_file) return self._generate_inpainting( prompt, negative_prompt, cast(ImageArtifact, input_artifact), cast(ImageArtifact, mask_artifact) diff --git a/griptape/tools/outpainting_image_generation/tool.py b/griptape/tools/outpainting_image_generation/tool.py index afa39e178..47863b03d 100644 --- a/griptape/tools/outpainting_image_generation/tool.py +++ b/griptape/tools/outpainting_image_generation/tool.py @@ -1,6 +1,5 @@ from __future__ import annotations -from pathlib import Path from typing import TYPE_CHECKING, cast from attrs import define, field @@ -51,8 +50,8 @@ def image_outpainting_from_file(self, params: dict[str, dict[str, str]]) -> Imag image_file = params["values"]["image_file"] mask_file = params["values"]["mask_file"] - input_artifact = self.image_loader.load(Path(image_file).read_bytes()) - mask_artifact = self.image_loader.load(Path(mask_file).read_bytes()) + input_artifact = self.image_loader.load(image_file) + mask_artifact = self.image_loader.load(mask_file) return self._generate_outpainting(prompt, negative_prompt, input_artifact, mask_artifact) diff --git a/griptape/tools/variation_image_generation/tool.py b/griptape/tools/variation_image_generation/tool.py index 0d4456c2f..1fb8c8bcc 100644 --- a/griptape/tools/variation_image_generation/tool.py +++ b/griptape/tools/variation_image_generation/tool.py @@ -1,6 +1,5 @@ from __future__ import annotations -from pathlib import Path from typing import TYPE_CHECKING, cast from attrs import define, field @@ -49,7 +48,7 @@ def image_variation_from_file(self, params: dict[str, dict[str, str]]) -> ImageA negative_prompt = params["values"]["negative_prompt"] image_file = params["values"]["image_file"] - image_artifact = self.image_loader.load(Path(image_file).read_bytes()) + image_artifact = self.image_loader.load(image_file) return self._generate_variation(prompt, negative_prompt, image_artifact) From 293191518ae9417726cce28f04a725197577f555 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 5 Sep 2024 12:12:59 -0700 Subject: [PATCH 28/52] Start changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4516c59b6..99b7614e2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `BaseArtifact.to_bytes()` method to convert an Artifact's value to bytes. - `BlobArtifact.base64` property for converting a `BlobArtifact`'s value to a base64 string. - `CsvLoader`/`SqlLoader`/`DataframeLoader` `formatter_fn` field for customizing how SQL results are formatted into `TextArtifact`s. +- `BaseFileLoader` for Loaders that load from a path. +- `BaseLoader.fetch()` method for fetching data from a source. +- `BaseLoader.parse()` method for parsing fetched data. ### Changed - **BREAKING**: Removed `CsvRowArtifact`. Use `TextArtifact` instead. From ebcfff4b58dcd2b345fe5af1d8727c067ca86d87 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 5 Sep 2024 14:04:48 -0700 Subject: [PATCH 29/52] Fix bad merge --- griptape/artifacts/__init__.py | 5 ----- griptape/artifacts/base_system_artifact.py | 10 ---------- griptape/artifacts/error_artifact.py | 2 +- griptape/artifacts/info_artifact.py | 2 +- griptape/artifacts/list_artifact.py | 4 +--- griptape/artifacts/text_artifact.py | 2 +- 6 files changed, 4 insertions(+), 21 deletions(-) delete mode 100644 griptape/artifacts/base_system_artifact.py diff --git a/griptape/artifacts/__init__.py b/griptape/artifacts/__init__.py index 30c293b49..74e545640 100644 --- a/griptape/artifacts/__init__.py +++ b/griptape/artifacts/__init__.py @@ -1,5 +1,4 @@ from .base_artifact import BaseArtifact -from .base_system_artifact import BaseSystemArtifact from .text_artifact import TextArtifact from .blob_artifact import BlobArtifact @@ -10,18 +9,15 @@ from .audio_artifact import AudioArtifact from .json_artifact import JsonArtifact from .action_artifact import ActionArtifact -from .table_artifact import TableArtifact from .generic_artifact import GenericArtifact from .error_artifact import ErrorArtifact from .info_artifact import InfoArtifact -from .list_artifact import ListArtifact __all__ = [ "BaseArtifact", - "BaseSystemArtifact", "ErrorArtifact", "InfoArtifact", "TextArtifact", @@ -34,5 +30,4 @@ "AudioArtifact", "ActionArtifact", "GenericArtifact", - "TableArtifact", ] diff --git a/griptape/artifacts/base_system_artifact.py b/griptape/artifacts/base_system_artifact.py deleted file mode 100644 index e5e7bfab6..000000000 --- a/griptape/artifacts/base_system_artifact.py +++ /dev/null @@ -1,10 +0,0 @@ -from abc import ABC - -from griptape.artifacts import BaseArtifact - - -class BaseSystemArtifact(BaseArtifact, ABC): - """Serves as the base class for all Artifacts specific to Griptape.""" - - def to_text(self) -> str: - return self.value diff --git a/griptape/artifacts/error_artifact.py b/griptape/artifacts/error_artifact.py index a3a8e3bb7..27e6a37ab 100644 --- a/griptape/artifacts/error_artifact.py +++ b/griptape/artifacts/error_artifact.py @@ -4,7 +4,7 @@ from attrs import define, field -from griptape.artifacts import BaseSystemArtifact +from griptape.artifacts import BaseArtifact @define diff --git a/griptape/artifacts/info_artifact.py b/griptape/artifacts/info_artifact.py index 44d42658a..3391554e9 100644 --- a/griptape/artifacts/info_artifact.py +++ b/griptape/artifacts/info_artifact.py @@ -2,7 +2,7 @@ from attrs import define, field -from griptape.artifacts import BaseSystemArtifact +from griptape.artifacts import BaseArtifact @define diff --git a/griptape/artifacts/list_artifact.py b/griptape/artifacts/list_artifact.py index 366ae5c3a..0e6f81ca5 100644 --- a/griptape/artifacts/list_artifact.py +++ b/griptape/artifacts/list_artifact.py @@ -4,15 +4,13 @@ from attrs import Attribute, define, field -from griptape.artifacts import BaseArtifact, BaseSystemArtifact +from griptape.artifacts import BaseArtifact if TYPE_CHECKING: from collections.abc import Iterator, Sequence T = TypeVar("T", bound=BaseArtifact, covariant=True) - from griptape.artifacts import BaseArtifact - @define class ListArtifact(BaseArtifact, Generic[T]): diff --git a/griptape/artifacts/text_artifact.py b/griptape/artifacts/text_artifact.py index 735256313..9623c8096 100644 --- a/griptape/artifacts/text_artifact.py +++ b/griptape/artifacts/text_artifact.py @@ -33,7 +33,7 @@ def generate_embedding(self, driver: BaseEmbeddingDriver) -> list[float]: self.embedding.clear() self.embedding.extend(embedding) - return self._embedding + return self.embedding def token_count(self, tokenizer: BaseTokenizer) -> int: return tokenizer.count_tokens(str(self.value)) From 89474252825708e5b74d9b9f4632324e72eaa1c5 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 5 Sep 2024 14:15:33 -0700 Subject: [PATCH 30/52] Remove table artifact --- griptape/artifacts/table_artifact.py | 41 ---------------------------- 1 file changed, 41 deletions(-) delete mode 100644 griptape/artifacts/table_artifact.py diff --git a/griptape/artifacts/table_artifact.py b/griptape/artifacts/table_artifact.py deleted file mode 100644 index d3ca003fa..000000000 --- a/griptape/artifacts/table_artifact.py +++ /dev/null @@ -1,41 +0,0 @@ -from __future__ import annotations - -import csv -import io -from typing import TYPE_CHECKING, Optional - -from attrs import define, field - -from griptape.artifacts.text_artifact import TextArtifact - -if TYPE_CHECKING: - from collections.abc import Sequence - - -@define -class TableArtifact(TextArtifact): - value: list[dict] = field(factory=list, metadata={"serializable": True}) - delimiter: str = field(default=",", kw_only=True, metadata={"serializable": True}) - fieldnames: Optional[Sequence[str]] = field(factory=list, metadata={"serializable": True}) - quoting: int = field(default=csv.QUOTE_MINIMAL, kw_only=True, metadata={"serializable": True}) - line_terminator: str = field(default="\n", kw_only=True, metadata={"serializable": True}) - - def __bool__(self) -> bool: - return len(self.value) > 0 - - def to_text(self) -> str: - with io.StringIO() as csvfile: - fieldnames = (self.value[0].keys() if self.value else []) if self.fieldnames is None else self.fieldnames - - writer = csv.DictWriter( - csvfile, - fieldnames=fieldnames, - quoting=self.quoting, - delimiter=self.delimiter, - lineterminator=self.line_terminator, - ) - - writer.writeheader() - writer.writerows(self.value) - - return csvfile.getvalue().strip() From 05e42313b1a7b950e07d1a35a3cee5697d1d99a0 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 5 Sep 2024 14:18:49 -0700 Subject: [PATCH 31/52] More changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 99b7614e2..33fed44c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `BaseFileLoader` for Loaders that load from a path. - `BaseLoader.fetch()` method for fetching data from a source. - `BaseLoader.parse()` method for parsing fetched data. +- `BaseFileManager.encoding` to specify the encoding when loading and saving files. ### Changed - **BREAKING**: Removed `CsvRowArtifact`. Use `TextArtifact` instead. @@ -23,6 +24,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: Removed `BlobArtifact.dir_name`. - **BREAKING**: Moved `ImageArtifact.prompt` and `ImageArtifact.model` into `ImageArtifact.meta`. - **BREAKING**: `ImageArtifact.format` is now required. +- **BREAKING**: Removed `BaseFileManager.default_loader` and `BaseFileManager.loaders`. +- `BaseFileLoader.load_file()` will now either return a `TextArtifact` or a `BlobArtifact` depending on whether `BaseFileManager.encoding` is set. - Updated `JsonArtifact` value converter to properly handle more types. - `AudioArtifact` now subclasses `BlobArtifact` instead of `MediaArtifact`. - `ImageArtifact` now subclasses `BlobArtifact` instead of `MediaArtifact`. From c93f512427b73913691b3cbfea6539dd2f9c9a44 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 5 Sep 2024 14:36:37 -0700 Subject: [PATCH 32/52] Type signature cleanup --- CHANGELOG.md | 7 ++++++- docs/examples/src/talk_to_a_pdf_1.py | 2 +- .../engines/src/summary_engines_1.py | 2 +- griptape/artifacts/__init__.py | 9 +++------ griptape/chunkers/base_chunker.py | 5 +++-- griptape/loaders/audio_loader.py | 4 ++-- griptape/loaders/base_loader.py | 8 ++++---- griptape/loaders/blob_loader.py | 6 +++--- griptape/loaders/csv_loader.py | 4 ++-- griptape/loaders/email_loader.py | 4 ++-- griptape/loaders/image_loader.py | 9 +++++---- griptape/loaders/pdf_loader.py | 10 ++++------ griptape/loaders/sql_loader.py | 4 ++-- griptape/loaders/text_loader.py | 8 ++++---- griptape/loaders/web_loader.py | 4 ++-- 15 files changed, 44 insertions(+), 42 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 33fed44c7..ae297663e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,17 +14,22 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `BaseLoader.fetch()` method for fetching data from a source. - `BaseLoader.parse()` method for parsing fetched data. - `BaseFileManager.encoding` to specify the encoding when loading and saving files. +- `BaseWebScraperDriver.extract_page()` method for extracting data from an already scraped web page. +- `TextLoaderRetrievalRagModule.chunker` for specifying the chunking strategy. ### Changed - **BREAKING**: Removed `CsvRowArtifact`. Use `TextArtifact` instead. +- **BREAKING**: Removed `DataframeLoader`. - **BREAKING**: Removed `MediaArtifact`, use `ImageArtifact` or `AudioArtifact` instead. -- **BREAKING**: `CsvLoader`, `DataframeLoader`, and `SqlLoader` now return `list[TextArtifact]`. +- **BREAKING**: Removed `CsvRowArtifact`. +- **BREAKING**: `CsvLoader` and `SqlLoader` now return `ListArtifact[TextArtifact]`. - **BREAKING**: Removed `ImageArtifact.media_type`. - **BREAKING**: Removed `AudioArtifact.media_type`. - **BREAKING**: Removed `BlobArtifact.dir_name`. - **BREAKING**: Moved `ImageArtifact.prompt` and `ImageArtifact.model` into `ImageArtifact.meta`. - **BREAKING**: `ImageArtifact.format` is now required. - **BREAKING**: Removed `BaseFileManager.default_loader` and `BaseFileManager.loaders`. +- **BREAKING**: Loaders no longer chunk data, use a Chunker to chunk the data. - `BaseFileLoader.load_file()` will now either return a `TextArtifact` or a `BlobArtifact` depending on whether `BaseFileManager.encoding` is set. - Updated `JsonArtifact` value converter to properly handle more types. - `AudioArtifact` now subclasses `BlobArtifact` instead of `MediaArtifact`. diff --git a/docs/examples/src/talk_to_a_pdf_1.py b/docs/examples/src/talk_to_a_pdf_1.py index aa2b1f153..be0848b9e 100644 --- a/docs/examples/src/talk_to_a_pdf_1.py +++ b/docs/examples/src/talk_to_a_pdf_1.py @@ -31,7 +31,7 @@ rag_engine=engine, ) -artifacts = PdfLoader().load(response.content) +artifacts = PdfLoader().parse(response.content) chunks = TextChunker().chunk(artifacts) vector_store.upsert_text_artifacts({namespace: chunks}) diff --git a/docs/griptape-framework/engines/src/summary_engines_1.py b/docs/griptape-framework/engines/src/summary_engines_1.py index 201aadfa4..d84f16798 100644 --- a/docs/griptape-framework/engines/src/summary_engines_1.py +++ b/docs/griptape-framework/engines/src/summary_engines_1.py @@ -10,7 +10,7 @@ prompt_driver=OpenAiChatPromptDriver(model="gpt-3.5-turbo"), ) -artifact = PdfLoader().load(response.content) +artifact = PdfLoader().parse(response.content) chunks = TextChunker().chunk(artifact) text = "\n\n".join([a.value for a in chunks]) diff --git a/griptape/artifacts/__init__.py b/griptape/artifacts/__init__.py index 74e545640..c0e5f767e 100644 --- a/griptape/artifacts/__init__.py +++ b/griptape/artifacts/__init__.py @@ -1,20 +1,17 @@ from .base_artifact import BaseArtifact - +from .error_artifact import ErrorArtifact +from .info_artifact import InfoArtifact from .text_artifact import TextArtifact +from .json_artifact import JsonArtifact from .blob_artifact import BlobArtifact from .boolean_artifact import BooleanArtifact from .csv_row_artifact import CsvRowArtifact from .list_artifact import ListArtifact from .image_artifact import ImageArtifact from .audio_artifact import AudioArtifact -from .json_artifact import JsonArtifact from .action_artifact import ActionArtifact - from .generic_artifact import GenericArtifact -from .error_artifact import ErrorArtifact -from .info_artifact import InfoArtifact - __all__ = [ "BaseArtifact", diff --git a/griptape/chunkers/base_chunker.py b/griptape/chunkers/base_chunker.py index 623185237..9b6ef64b9 100644 --- a/griptape/chunkers/base_chunker.py +++ b/griptape/chunkers/base_chunker.py @@ -6,6 +6,7 @@ from attrs import Attribute, Factory, define, field from griptape.artifacts import TextArtifact +from griptape.artifacts.list_artifact import ListArtifact from griptape.chunkers import ChunkSeparator from griptape.tokenizers import BaseTokenizer, OpenAiTokenizer @@ -32,8 +33,8 @@ def validate_max_tokens(self, _: Attribute, max_tokens: int) -> None: if max_tokens < 0: raise ValueError("max_tokens must be 0 or greater.") - def chunk(self, text: TextArtifact | str) -> list[TextArtifact]: - text = text.value if isinstance(text, TextArtifact) else text + def chunk(self, text: TextArtifact | ListArtifact | str) -> list[TextArtifact]: + text = text.to_text() if isinstance(text, (TextArtifact, ListArtifact)) else text return [TextArtifact(c) for c in self._chunk_recursively(text)] diff --git a/griptape/loaders/audio_loader.py b/griptape/loaders/audio_loader.py index 4e7d9c7d0..0bff5c642 100644 --- a/griptape/loaders/audio_loader.py +++ b/griptape/loaders/audio_loader.py @@ -11,5 +11,5 @@ class AudioLoader(BaseFileLoader[AudioArtifact]): """Loads audio content into audio artifacts.""" - def parse(self, source: bytes) -> AudioArtifact: - return AudioArtifact(source, format=filetype.guess(source).extension) + def parse(self, data: bytes) -> AudioArtifact: + return AudioArtifact(data, format=filetype.guess(data).extension) diff --git a/griptape/loaders/base_loader.py b/griptape/loaders/base_loader.py index 85edf9417..93030a197 100644 --- a/griptape/loaders/base_loader.py +++ b/griptape/loaders/base_loader.py @@ -14,9 +14,9 @@ from griptape.common import Reference -S = TypeVar("S") -F = TypeVar("F") -A = TypeVar("A", bound=BaseArtifact) +S = TypeVar("S") # Type for the input source +F = TypeVar("F") # Type for the fetched data +A = TypeVar("A", bound=BaseArtifact) # Type for the returned Artifact @define @@ -36,7 +36,7 @@ def load(self, source: S, *args, **kwargs) -> A: def fetch(self, source: S) -> F: ... @abstractmethod - def parse(self, source: F) -> A: ... + def parse(self, data: F) -> A: ... def load_collection( self, diff --git a/griptape/loaders/blob_loader.py b/griptape/loaders/blob_loader.py index b54c880b5..df148e66a 100644 --- a/griptape/loaders/blob_loader.py +++ b/griptape/loaders/blob_loader.py @@ -8,8 +8,8 @@ @define class BlobLoader(BaseFileLoader[BlobArtifact]): - def parse(self, source: bytes, *args, **kwargs) -> BlobArtifact: + def parse(self, data: bytes) -> BlobArtifact: if self.encoding is None: - return BlobArtifact(source) + return BlobArtifact(data) else: - return BlobArtifact(source, encoding=self.encoding) + return BlobArtifact(data, encoding=self.encoding) diff --git a/griptape/loaders/csv_loader.py b/griptape/loaders/csv_loader.py index 20feac524..22697a2d3 100644 --- a/griptape/loaders/csv_loader.py +++ b/griptape/loaders/csv_loader.py @@ -18,7 +18,7 @@ class CsvLoader(BaseFileLoader[ListArtifact]): default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True ) - def parse(self, source: bytes) -> ListArtifact: - reader = csv.DictReader(StringIO(source.decode(self.encoding)), delimiter=self.delimiter) + def parse(self, data: bytes) -> ListArtifact: + reader = csv.DictReader(StringIO(data.decode(self.encoding)), delimiter=self.delimiter) return ListArtifact([TextArtifact(row, meta={"row": row_num}) for row_num, row in enumerate(reader)]) diff --git a/griptape/loaders/email_loader.py b/griptape/loaders/email_loader.py index 88cd8e967..08b22a1d6 100644 --- a/griptape/loaders/email_loader.py +++ b/griptape/loaders/email_loader.py @@ -64,10 +64,10 @@ def fetch(self, source: EmailLoader.EmailQuery) -> list[bytes]: return mail_bytes - def parse(self, source: list[bytes]) -> ListArtifact: + def parse(self, data: list[bytes]) -> ListArtifact: mailparser = import_optional_dependency("mailparser") artifacts = [] - for byte in source: + for byte in data: message = mailparser.parse_from_bytes(byte) # Note: mailparser only populates the text_plain field diff --git a/griptape/loaders/image_loader.py b/griptape/loaders/image_loader.py index d5f4ca742..3af3922dc 100644 --- a/griptape/loaders/image_loader.py +++ b/griptape/loaders/image_loader.py @@ -22,14 +22,15 @@ class ImageLoader(BaseFileLoader[ImageArtifact]): format: Optional[str] = field(default=None, kw_only=True) - def parse(self, source: bytes) -> ImageArtifact: + def parse(self, data: bytes) -> ImageArtifact: pil_image = import_optional_dependency("PIL.Image") - image = pil_image.open(BytesIO(source)) + image = pil_image.open(BytesIO(data)) + # Normalize format only if requested. if self.format is not None: byte_stream = BytesIO() image.save(byte_stream, format=self.format) image = pil_image.open(byte_stream) - source = byte_stream.getvalue() + data = byte_stream.getvalue() - return ImageArtifact(source, format=image.format.lower(), width=image.width, height=image.height) + return ImageArtifact(data, format=image.format.lower(), width=image.width, height=image.height) diff --git a/griptape/loaders/pdf_loader.py b/griptape/loaders/pdf_loader.py index 9be195e85..5bf5337ae 100644 --- a/griptape/loaders/pdf_loader.py +++ b/griptape/loaders/pdf_loader.py @@ -1,7 +1,7 @@ from __future__ import annotations from io import BytesIO -from typing import Any, Optional, cast +from typing import Optional from attrs import define @@ -12,16 +12,14 @@ @define class PdfLoader(BaseFileLoader): - def load(self, source: Any, *args, **kwargs) -> TextArtifact: - return cast(TextArtifact, super().load(source, *args, **kwargs)) - def parse( self, - source: bytes, + data: bytes, + *, password: Optional[str] = None, ) -> ListArtifact: pypdf = import_optional_dependency("pypdf") - reader = pypdf.PdfReader(BytesIO(source), strict=True, password=password) + reader = pypdf.PdfReader(BytesIO(data), strict=True, password=password) pages = [TextArtifact(p.extract_text()) for p in reader.pages] return ListArtifact(pages) diff --git a/griptape/loaders/sql_loader.py b/griptape/loaders/sql_loader.py index 594bf06e0..e61aa06d1 100644 --- a/griptape/loaders/sql_loader.py +++ b/griptape/loaders/sql_loader.py @@ -14,5 +14,5 @@ class SqlLoader(BaseLoader[str, list[BaseSqlDriver.RowResult], ListArtifact]): def fetch(self, source: str) -> list[BaseSqlDriver.RowResult]: return self.sql_driver.execute_query(source) or [] - def parse(self, source: list[BaseSqlDriver.RowResult]) -> ListArtifact: - return ListArtifact([TextArtifact(row.cells, meta={"row": row_num}) for row_num, row in enumerate(source)]) + def parse(self, data: list[BaseSqlDriver.RowResult]) -> ListArtifact: + return ListArtifact([TextArtifact(row.cells, meta={"row": row_num}) for row_num, row in enumerate(data)]) diff --git a/griptape/loaders/text_loader.py b/griptape/loaders/text_loader.py index 7ef649bbc..c33eb9018 100644 --- a/griptape/loaders/text_loader.py +++ b/griptape/loaders/text_loader.py @@ -10,8 +10,8 @@ class TextLoader(BaseFileLoader[TextArtifact]): encoding: str = field(default="utf-8", kw_only=True) - def parse(self, source: str | bytes) -> TextArtifact: - if isinstance(source, str): - return TextArtifact(source, encoding=self.encoding) + def parse(self, data: str | bytes) -> TextArtifact: + if isinstance(data, str): + return TextArtifact(data, encoding=self.encoding) else: - return TextArtifact(source.decode(self.encoding), encoding=self.encoding) + return TextArtifact(data.decode(self.encoding), encoding=self.encoding) diff --git a/griptape/loaders/web_loader.py b/griptape/loaders/web_loader.py index de4c32778..697c18bba 100644 --- a/griptape/loaders/web_loader.py +++ b/griptape/loaders/web_loader.py @@ -17,5 +17,5 @@ class WebLoader(BaseLoader[str, str, TextArtifact]): def fetch(self, source: str) -> str: return self.web_scraper_driver.fetch_url(source) - def parse(self, source: str) -> TextArtifact: - return self.web_scraper_driver.extract_page(source) + def parse(self, data: str) -> TextArtifact: + return self.web_scraper_driver.extract_page(data) From b6a105597b0013b508a39bc690df395f20712587 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 5 Sep 2024 14:46:33 -0700 Subject: [PATCH 33/52] Use new list artifact --- griptape/loaders/csv_loader.py | 2 +- griptape/loaders/email_loader.py | 2 +- griptape/loaders/sql_loader.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/griptape/loaders/csv_loader.py b/griptape/loaders/csv_loader.py index 22697a2d3..1ac1019fb 100644 --- a/griptape/loaders/csv_loader.py +++ b/griptape/loaders/csv_loader.py @@ -18,7 +18,7 @@ class CsvLoader(BaseFileLoader[ListArtifact]): default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True ) - def parse(self, data: bytes) -> ListArtifact: + def parse(self, data: bytes) -> ListArtifact[TextArtifact]: reader = csv.DictReader(StringIO(data.decode(self.encoding)), delimiter=self.delimiter) return ListArtifact([TextArtifact(row, meta={"row": row_num}) for row_num, row in enumerate(reader)]) diff --git a/griptape/loaders/email_loader.py b/griptape/loaders/email_loader.py index 08b22a1d6..8e935cfa4 100644 --- a/griptape/loaders/email_loader.py +++ b/griptape/loaders/email_loader.py @@ -64,7 +64,7 @@ def fetch(self, source: EmailLoader.EmailQuery) -> list[bytes]: return mail_bytes - def parse(self, data: list[bytes]) -> ListArtifact: + def parse(self, data: list[bytes]) -> ListArtifact[TextArtifact]: mailparser = import_optional_dependency("mailparser") artifacts = [] for byte in data: diff --git a/griptape/loaders/sql_loader.py b/griptape/loaders/sql_loader.py index e61aa06d1..fc7d1afb0 100644 --- a/griptape/loaders/sql_loader.py +++ b/griptape/loaders/sql_loader.py @@ -14,5 +14,5 @@ class SqlLoader(BaseLoader[str, list[BaseSqlDriver.RowResult], ListArtifact]): def fetch(self, source: str) -> list[BaseSqlDriver.RowResult]: return self.sql_driver.execute_query(source) or [] - def parse(self, data: list[BaseSqlDriver.RowResult]) -> ListArtifact: + def parse(self, data: list[BaseSqlDriver.RowResult]) -> ListArtifact[TextArtifact]: return ListArtifact([TextArtifact(row.cells, meta={"row": row_num}) for row_num, row in enumerate(data)]) From c67ee5669a99bfe0f91b5241617943a5604fa109 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 5 Sep 2024 14:50:56 -0700 Subject: [PATCH 34/52] More changelog --- CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ae297663e..018e11833 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `BaseFileManager.encoding` to specify the encoding when loading and saving files. - `BaseWebScraperDriver.extract_page()` method for extracting data from an already scraped web page. - `TextLoaderRetrievalRagModule.chunker` for specifying the chunking strategy. +- `file_utils.get_mime_type` utility for getting the MIME type of a file. ### Changed - **BREAKING**: Removed `CsvRowArtifact`. Use `TextArtifact` instead. @@ -30,6 +31,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: `ImageArtifact.format` is now required. - **BREAKING**: Removed `BaseFileManager.default_loader` and `BaseFileManager.loaders`. - **BREAKING**: Loaders no longer chunk data, use a Chunker to chunk the data. +- **BREAKING**: Removed `fileutils.load_file` and `fileutils.load_files`. +- **BREAKING**: Removed `loaders-dataframe` and `loaders-audio` extras. +- `filetype` is now a core dependency. +- `FileManagerTool` now uses `filetype` for more accurate file type detection. - `BaseFileLoader.load_file()` will now either return a `TextArtifact` or a `BlobArtifact` depending on whether `BaseFileManager.encoding` is set. - Updated `JsonArtifact` value converter to properly handle more types. - `AudioArtifact` now subclasses `BlobArtifact` instead of `MediaArtifact`. From cd6e9afa2462b59244d3730783a411b4cc558a08 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 5 Sep 2024 15:16:55 -0700 Subject: [PATCH 35/52] Add migration --- CHANGELOG.md | 5 ++-- MIGRATION.md | 76 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 018e11833..35d82595d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,7 +22,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: Removed `CsvRowArtifact`. Use `TextArtifact` instead. - **BREAKING**: Removed `DataframeLoader`. - **BREAKING**: Removed `MediaArtifact`, use `ImageArtifact` or `AudioArtifact` instead. -- **BREAKING**: Removed `CsvRowArtifact`. - **BREAKING**: `CsvLoader` and `SqlLoader` now return `ListArtifact[TextArtifact]`. - **BREAKING**: Removed `ImageArtifact.media_type`. - **BREAKING**: Removed `AudioArtifact.media_type`. @@ -32,7 +31,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: Removed `BaseFileManager.default_loader` and `BaseFileManager.loaders`. - **BREAKING**: Loaders no longer chunk data, use a Chunker to chunk the data. - **BREAKING**: Removed `fileutils.load_file` and `fileutils.load_files`. -- **BREAKING**: Removed `loaders-dataframe` and `loaders-audio` extras. +- **BREAKING**: Removed `loaders-dataframe` and `loaders-audio` extras as they are no longer needed. +- **BREKING**: `TextLoader`, `PdfLoader`, `ImageLoader`, and `AudioLoader` now take a `str | PathLike` instead of `bytes`. +- **BREAKING**: Removed `DataframeLoader`. - `filetype` is now a core dependency. - `FileManagerTool` now uses `filetype` for more accurate file type detection. - `BaseFileLoader.load_file()` will now either return a `TextArtifact` or a `BlobArtifact` depending on whether `BaseFileManager.encoding` is set. diff --git a/MIGRATION.md b/MIGRATION.md index 016a93f03..8f7d79409 100644 --- a/MIGRATION.md +++ b/MIGRATION.md @@ -1,8 +1,84 @@ # Migration Guide This document provides instructions for migrating your codebase to accommodate breaking changes introduced in new versions of Griptape. + ## 0.31.X to 0.32.X +### Removed `DataframeLoader` + +`DataframeLoader` has been removed. Use `CsvLoader.parse` or build `TextArtifact`s from the dataframe instead. + +#### Before + +```python +DataframeLoader().load(df) +``` + +#### After +```python +# Convert the dataframe to csv bytes and parse it +CsvLoader().parse(bytes(df.to_csv(line_terminator='\r\n', index=False), encoding='utf-8')) +# Or build TextArtifacts from the dataframe +[TextArtifact(row) for row in source.to_dict(orient="records")] +``` + +### `TextLoader`, `PdfLoader`, `ImageLoader`, and `AudioLoader` now take a `str | PathLike` instead of `bytes`. + +#### Before +```python +PdfLoader().load(Path("attention.pdf").read_bytes()) +PdfLoader().load_collection([Path("attention.pdf").read_bytes(), Path("CoT.pdf").read_bytes()]) +``` + +#### After +```python +PdfLoader().load("attention.pdf") +PdfLoader().load_collection([Path("attention.pdf"), "CoT.pdf"]) +``` + +### Removed `fileutils.load_file` and `fileutils.load_files` + +`griptape.utils.file_utils.load_file` and `griptape.utils.file_utils.load_files` have been removed. +You can now pass the file path directly to the Loader. + +#### Before + +```python +PdfLoader().load(load_file("attention.pdf").read_bytes()) +PdfLoader().load_collection(list(load_files(["attention.pdf", "CoT.pdf"]).values())) +``` + +```python +PdfLoader().load("attention.pdf") +PdfLoader().load_collection(["attention.pdf", "CoT.pdf"]) +``` + +### Loaders no longer chunk data + +Loaders no longer chunk the data after loading it. If you need to chunk the data, use a [Chunker](https://docs.griptape.ai/stable/griptape-framework/data/chunkers/) after loading the data. + +#### Before + +```python +chunks = PdfLoader().load("attention.pdf") +vector_store.upsert_text_artifacts( + { + "griptape": chunks, + } +) +``` + +#### After +```python +artifact = PdfLoader().load("attention.pdf") +chunks = Chunker().chunk(artifact) +vector_store.upsert_text_artifacts( + { + "griptape": chunks, + } +) +``` + ### Removed `MediaArtifact` `MediaArtifact` has been removed. Use `ImageArtifact` or `AudioArtifact` instead. From 9a9769598b24b166f946688b15d9d1fe35334e7a Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 6 Sep 2024 10:24:39 -0700 Subject: [PATCH 36/52] Fix bad merge --- griptape/loaders/csv_loader.py | 10 +++++----- griptape/loaders/sql_loader.py | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/griptape/loaders/csv_loader.py b/griptape/loaders/csv_loader.py index 1ac1019fb..032a4e42a 100644 --- a/griptape/loaders/csv_loader.py +++ b/griptape/loaders/csv_loader.py @@ -6,19 +6,19 @@ from attrs import define, field -from griptape.artifacts import ListArtifact, TextArtifact -from griptape.loaders.base_file_loader import BaseFileLoader +from griptape.artifacts import CsvRowArtifact, ListArtifact +from griptape.loaders import BaseFileLoader @define -class CsvLoader(BaseFileLoader[ListArtifact]): +class CsvLoader(BaseFileLoader[ListArtifact[CsvRowArtifact]]): delimiter: str = field(default=",", kw_only=True) encoding: str = field(default="utf-8", kw_only=True) formatter_fn: Callable[[dict], str] = field( default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True ) - def parse(self, data: bytes) -> ListArtifact[TextArtifact]: + def parse(self, data: bytes) -> ListArtifact[CsvRowArtifact]: reader = csv.DictReader(StringIO(data.decode(self.encoding)), delimiter=self.delimiter) - return ListArtifact([TextArtifact(row, meta={"row": row_num}) for row_num, row in enumerate(reader)]) + return ListArtifact([CsvRowArtifact(row, meta={"row": row_num}) for row_num, row in enumerate(reader)]) diff --git a/griptape/loaders/sql_loader.py b/griptape/loaders/sql_loader.py index fc7d1afb0..cfc22cbb4 100644 --- a/griptape/loaders/sql_loader.py +++ b/griptape/loaders/sql_loader.py @@ -2,7 +2,7 @@ from attrs import define, field -from griptape.artifacts import ListArtifact, TextArtifact +from griptape.artifacts import CsvRowArtifact, ListArtifact from griptape.drivers import BaseSqlDriver from griptape.loaders import BaseLoader @@ -14,5 +14,5 @@ class SqlLoader(BaseLoader[str, list[BaseSqlDriver.RowResult], ListArtifact]): def fetch(self, source: str) -> list[BaseSqlDriver.RowResult]: return self.sql_driver.execute_query(source) or [] - def parse(self, data: list[BaseSqlDriver.RowResult]) -> ListArtifact[TextArtifact]: - return ListArtifact([TextArtifact(row.cells, meta={"row": row_num}) for row_num, row in enumerate(data)]) + def parse(self, data: list[BaseSqlDriver.RowResult]) -> ListArtifact[CsvRowArtifact]: + return ListArtifact([CsvRowArtifact(row.cells, meta={"row": row_num}) for row_num, row in enumerate(data)]) From 3284e5324e2366b3bf7ed292d3462fab70534344 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 6 Sep 2024 10:34:48 -0700 Subject: [PATCH 37/52] Fix bad merge --- MIGRATION.md | 1 + griptape/loaders/csv_loader.py | 2 +- griptape/loaders/sql_loader.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/MIGRATION.md b/MIGRATION.md index 8f7d79409..fc3c297be 100644 --- a/MIGRATION.md +++ b/MIGRATION.md @@ -172,6 +172,7 @@ print(type(results[0].value)) # ```python results = CsvLoader().load(Path("people.csv").read_text()) +print(type(results)) # print(results[0].value) # name: John\nAge: 30 print(type(results[0].value)) # diff --git a/griptape/loaders/csv_loader.py b/griptape/loaders/csv_loader.py index 032a4e42a..22c2f2e00 100644 --- a/griptape/loaders/csv_loader.py +++ b/griptape/loaders/csv_loader.py @@ -21,4 +21,4 @@ class CsvLoader(BaseFileLoader[ListArtifact[CsvRowArtifact]]): def parse(self, data: bytes) -> ListArtifact[CsvRowArtifact]: reader = csv.DictReader(StringIO(data.decode(self.encoding)), delimiter=self.delimiter) - return ListArtifact([CsvRowArtifact(row, meta={"row": row_num}) for row_num, row in enumerate(reader)]) + return ListArtifact([CsvRowArtifact(row, meta={"row_num": row_num}) for row_num, row in enumerate(reader)]) diff --git a/griptape/loaders/sql_loader.py b/griptape/loaders/sql_loader.py index cfc22cbb4..5423fe55f 100644 --- a/griptape/loaders/sql_loader.py +++ b/griptape/loaders/sql_loader.py @@ -15,4 +15,4 @@ def fetch(self, source: str) -> list[BaseSqlDriver.RowResult]: return self.sql_driver.execute_query(source) or [] def parse(self, data: list[BaseSqlDriver.RowResult]) -> ListArtifact[CsvRowArtifact]: - return ListArtifact([CsvRowArtifact(row.cells, meta={"row": row_num}) for row_num, row in enumerate(data)]) + return ListArtifact([CsvRowArtifact(row.cells, meta={"row_num": row_num}) for row_num, row in enumerate(data)]) From 55ccab6c771ea28f6b43a2683101013623a5031c Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 6 Sep 2024 10:36:43 -0700 Subject: [PATCH 38/52] Fixed bad changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 35d82595d..0eff17433 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: Removed `loaders-dataframe` and `loaders-audio` extras as they are no longer needed. - **BREKING**: `TextLoader`, `PdfLoader`, `ImageLoader`, and `AudioLoader` now take a `str | PathLike` instead of `bytes`. - **BREAKING**: Removed `DataframeLoader`. +- `LocalFileManagerDriver.workdir` is now optional. - `filetype` is now a core dependency. - `FileManagerTool` now uses `filetype` for more accurate file type detection. - `BaseFileLoader.load_file()` will now either return a `TextArtifact` or a `BlobArtifact` depending on whether `BaseFileManager.encoding` is set. From 55a021a32b02f8d426adce953fc111b09aa5391f Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 6 Sep 2024 10:54:43 -0700 Subject: [PATCH 39/52] Update docs --- docs/griptape-framework/data/loaders.md | 86 ++++++++++++------------- 1 file changed, 43 insertions(+), 43 deletions(-) diff --git a/docs/griptape-framework/data/loaders.md b/docs/griptape-framework/data/loaders.md index f1826e67a..1fbe24b3a 100644 --- a/docs/griptape-framework/data/loaders.md +++ b/docs/griptape-framework/data/loaders.md @@ -5,96 +5,96 @@ search: ## Overview -Loaders are used to load textual data from different sources and chunk it into [TextArtifact](../../reference/griptape/artifacts/text_artifact.md)s. -Each loader can be used to load a single "document" with [load()](../../reference/griptape/loaders/base_loader.md#griptape.loaders.base_loader.BaseLoader.load) or -multiple documents with [load_collection()](../../reference/griptape/loaders/base_loader.md#griptape.loaders.base_loader.BaseLoader.load_collection). +Loaders are used to load data from sources and parse it into [Artifact](../../griptape-framework/data/artifacts.md)s. +Each loader can be used to load a single "source" with [load()](../../reference/griptape/loaders/base_loader.md#griptape.loaders.base_loader.BaseLoader.load) or +multiple sources with [load_collection()](../../reference/griptape/loaders/base_loader.md#griptape.loaders.base_loader.BaseLoader.load_collection). -## PDF -!!! info - This driver requires the `loaders-pdf` [extra](../index.md#extras). +## File + +The following Loaders load a file using a [FileManagerDriver](../../reference/griptape/drivers/file_manager_driver.md) and loads the resulting data into an [Artifact](../../reference/griptape/artifacts/artifact.md) for the respective file type. + +### Text -Inherits from the [TextLoader](../../reference/griptape/loaders/text_loader.md) and can be used to load PDFs from a path or from an IO stream: +Loads text files into [TextArtifact](../../reference/griptape/artifacts/text_artifact.md)s: ```python ---8<-- "docs/griptape-framework/data/src/loaders_1.py" +--8<-- "docs/griptape-framework/data/src/loaders_5.py" ``` -## SQL +### PDF -Can be used to load data from a SQL database into [TextArtifact](../../reference/griptape/artifacts/text_artifact.md)s: +!!! info + This driver requires the `loaders-pdf` [extra](../index.md#extras). + +Loads PDF files into [ListArtifact](../../reference/griptape/artifacts/list_artifact.md)s, where each element is a [TextArtifact](../../reference/griptape/artifacts/text_artifact.md) containing a page of the PDF: ```python ---8<-- "docs/griptape-framework/data/src/loaders_2.py" +--8<-- "docs/griptape-framework/data/src/loaders_1.py" ``` -## CSV +### CSV -Can be used to load CSV files into [TextArtifact](../../reference/griptape/artifacts/text_artifact.md)s: +Loads CSV files into [ListArtifact](../../reference/griptape/artifacts/list_artifact.md)s, where each element is a [TextArtifact](../../reference/griptape/artifacts/text_artifact.md) containing a row of the CSV: ```python --8<-- "docs/griptape-framework/data/src/loaders_3.py" ``` -## Text +### Image + +!!! info + This driver requires the `loaders-image` [extra](../index.md#extras). + +Loads images into [ImageArtifact](../../reference/griptape/artifacts/image_artifact.md)s: -Used to load arbitrary text and text files: ```python ---8<-- "docs/griptape-framework/data/src/loaders_5.py" +--8<-- "docs/griptape-framework/data/src/loaders_7.py" ``` -You can set a custom [tokenizer](../../reference/griptape/loaders/text_loader.md#griptape.loaders.text_loader.TextLoader.tokenizer), [max_tokens](../../reference/griptape/loaders/text_loader.md#griptape.loaders.text_loader.TextLoader.max_tokens) parameter, and [chunker](../../reference/griptape/loaders/text_loader.md#griptape.loaders.text_loader.TextLoader.chunker). +By default, the Image Loader will load images in their native format, but not all models work on all formats. To normalize the format of Artifacts returned by the Loader, set the `format` field. -## Web +```python +--8<-- "docs/griptape-framework/data/src/loaders_8.py" +``` -!!! info - This driver requires the `loaders-web` [extra](../index.md#extras). +### Audio -Inherits from the [TextLoader](../../reference/griptape/loaders/text_loader.md) and can be used to load web pages: +Loads audio files into [AudioArtifact](../../reference/griptape/artifacts/audio_artifact.md)s: + +The Loader will load audio in its native format and populates the resulting Artifact's `format` field by making a best-effort guess of the underlying audio format using the `filetype` package. ```python ---8<-- "docs/griptape-framework/data/src/loaders_6.py" +--8<-- "docs/griptape-framework/data/src/loaders_10.py" ``` -## Image +## Web !!! info - This driver requires the `loaders-image` [extra](../index.md#extras). + This driver requires the `loaders-web` [extra](../index.md#extras). -The Image Loader is used to load an image as an [ImageArtifact](./artifacts.md#image). The Loader operates on image bytes that can be sourced from files on disk, downloaded images, or images in memory. +Scrapes web pages using a [WebScraperDriver](../../reference/griptape/drivers/web_scraper_driver.md) and loads the resulting text into [TextArtifact](../../reference/griptape/artifacts/text_artifact.md)s. ```python ---8<-- "docs/griptape-framework/data/src/loaders_7.py" +--8<-- "docs/griptape-framework/data/src/loaders_6.py" ``` -By default, the Image Loader will load images in their native format, but not all models work on all formats. To normalize the format of Artifacts returned by the Loader, set the `format` field. +## SQL + +Loads data from a SQL database using a [SQLDriver](../../reference/griptape/drivers/sql_driver.md) and loads the resulting data into [ListArtifact](../../reference/griptape/artifacts/list_artifact.md)s, where each element is a [CsvRowArtifact](../../reference/griptape/artifacts/csv_row_artifact.md) containing a row of the SQL query. ```python ---8<-- "docs/griptape-framework/data/src/loaders_8.py" +--8<-- "docs/griptape-framework/data/src/loaders_2.py" ``` - ## Email !!! info This driver requires the `loaders-email` [extra](../index.md#extras). -Can be used to load email from an imap server: +Loads data from an imap email server into a [ListArtifact](../../reference/griptape/artifacts/list_artifact.md)s, where each element is a [TextArtifact](../../reference/griptape/artifacts/text_artifact.md) containing an email. ```python --8<-- "docs/griptape-framework/data/src/loaders_9.py" ``` - -## Audio - -!!! info - This driver requires the `loaders-audio` [extra](../index.md#extras). - -The [Audio Loader](../../reference/griptape/loaders/audio_loader.md) is used to load audio content as an [AudioArtifact](./artifacts.md#audio). The Loader operates on audio bytes that can be sourced from files on disk, downloaded audio, or audio in memory. - -The Loader will load audio in its native format and populates the resulting Artifact's `format` field by making a best-effort guess of the underlying audio format using the `filetype` package. - -```python ---8<-- "docs/griptape-framework/data/src/loaders_10.py" -``` From f78d29d3bf4479d7436755375934d383c70200f6 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 10 Sep 2024 14:36:52 -0700 Subject: [PATCH 40/52] Fix doc links --- docs/griptape-framework/data/loaders.md | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/griptape-framework/data/loaders.md b/docs/griptape-framework/data/loaders.md index 1fbe24b3a..a8a8cb7c5 100644 --- a/docs/griptape-framework/data/loaders.md +++ b/docs/griptape-framework/data/loaders.md @@ -12,11 +12,11 @@ multiple sources with [load_collection()](../../reference/griptape/loaders/base_ ## File -The following Loaders load a file using a [FileManagerDriver](../../reference/griptape/drivers/file_manager_driver.md) and loads the resulting data into an [Artifact](../../reference/griptape/artifacts/artifact.md) for the respective file type. +The following Loaders load a file using a [FileManagerDriver](../../reference/griptape/drivers/file_manager/base_file_manager_driver.md) and loads the resulting data into an [Artifact](../../griptape-framework/data/artifacts.md) for the respective file type. ### Text -Loads text files into [TextArtifact](../../reference/griptape/artifacts/text_artifact.md)s: +Loads text files into [TextArtifact](../../griptape-framework/data/artifacts.md#text)s: ```python --8<-- "docs/griptape-framework/data/src/loaders_5.py" @@ -27,7 +27,7 @@ Loads text files into [TextArtifact](../../reference/griptape/artifacts/text_art !!! info This driver requires the `loaders-pdf` [extra](../index.md#extras). -Loads PDF files into [ListArtifact](../../reference/griptape/artifacts/list_artifact.md)s, where each element is a [TextArtifact](../../reference/griptape/artifacts/text_artifact.md) containing a page of the PDF: +Loads PDF files into [ListArtifact](../../griptape-framework/data/artifacts.md#list)s, where each element is a [TextArtifact](../../griptape-framework/data/artifacts.md#text) containing a page of the PDF: ```python --8<-- "docs/griptape-framework/data/src/loaders_1.py" @@ -35,7 +35,7 @@ Loads PDF files into [ListArtifact](../../reference/griptape/artifacts/list_arti ### CSV -Loads CSV files into [ListArtifact](../../reference/griptape/artifacts/list_artifact.md)s, where each element is a [TextArtifact](../../reference/griptape/artifacts/text_artifact.md) containing a row of the CSV: +Loads CSV files into [ListArtifact](../../griptape-framework/data/artifacts.md#list)s, where each element is a [TextArtifact](../../griptape-framework/data/artifacts.md#text) containing a row of the CSV: ```python --8<-- "docs/griptape-framework/data/src/loaders_3.py" @@ -46,7 +46,7 @@ Loads CSV files into [ListArtifact](../../reference/griptape/artifacts/list_arti !!! info This driver requires the `loaders-image` [extra](../index.md#extras). -Loads images into [ImageArtifact](../../reference/griptape/artifacts/image_artifact.md)s: +Loads images into [ImageArtifact](../../griptape-framework/data/artifacts.md#image)s: ```python @@ -61,7 +61,7 @@ By default, the Image Loader will load images in their native format, but not al ### Audio -Loads audio files into [AudioArtifact](../../reference/griptape/artifacts/audio_artifact.md)s: +Loads audio files into [AudioArtifact](../../griptape-framework/data/artifacts.md#audio)s: The Loader will load audio in its native format and populates the resulting Artifact's `format` field by making a best-effort guess of the underlying audio format using the `filetype` package. @@ -74,7 +74,7 @@ The Loader will load audio in its native format and populates the resulting Arti !!! info This driver requires the `loaders-web` [extra](../index.md#extras). -Scrapes web pages using a [WebScraperDriver](../../reference/griptape/drivers/web_scraper_driver.md) and loads the resulting text into [TextArtifact](../../reference/griptape/artifacts/text_artifact.md)s. +Scrapes web pages using a [WebScraperDriver](../drivers/web-scraper-drivers.md) and loads the resulting text into [TextArtifact](../../griptape-framework/data/artifacts.md#text)s. ```python --8<-- "docs/griptape-framework/data/src/loaders_6.py" @@ -82,7 +82,7 @@ Scrapes web pages using a [WebScraperDriver](../../reference/griptape/drivers/we ## SQL -Loads data from a SQL database using a [SQLDriver](../../reference/griptape/drivers/sql_driver.md) and loads the resulting data into [ListArtifact](../../reference/griptape/artifacts/list_artifact.md)s, where each element is a [CsvRowArtifact](../../reference/griptape/artifacts/csv_row_artifact.md) containing a row of the SQL query. +Loads data from a SQL database using a [SQLDriver](../drivers/sql-drivers.md) and loads the resulting data into [ListArtifact](../../griptape-framework/data/artifacts.md#list)s, where each element is a [CsvRowArtifact](../../griptape-framework/data/artifacts.md#csv) containing a row of the SQL query. ```python --8<-- "docs/griptape-framework/data/src/loaders_2.py" From eda407b3735fd1944a2ae95919a800b036b4c72b Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 11 Sep 2024 09:37:21 -0700 Subject: [PATCH 41/52] Fix bad merge --- griptape/artifacts/blob_artifact.py | 11 ----------- griptape/artifacts/json_artifact.py | 7 ------- griptape/loaders/base_loader.py | 1 + 3 files changed, 1 insertion(+), 18 deletions(-) diff --git a/griptape/artifacts/blob_artifact.py b/griptape/artifacts/blob_artifact.py index 7f7e4915a..7c814a052 100644 --- a/griptape/artifacts/blob_artifact.py +++ b/griptape/artifacts/blob_artifact.py @@ -7,13 +7,6 @@ from griptape.artifacts import BaseArtifact -def value_to_bytes(value: Any) -> bytes: - if isinstance(value, bytes): - return value - else: - return str(value).encode() - - @define class BlobArtifact(BaseArtifact): """Stores arbitrary binary data. @@ -35,10 +28,6 @@ def base64(self) -> str: def mime_type(self) -> str: return "application/octet-stream" - @property - def mime_type(self) -> str: - return "application/octet-stream" - def to_bytes(self) -> bytes: return self.value diff --git a/griptape/artifacts/json_artifact.py b/griptape/artifacts/json_artifact.py index 0f788c94f..57700afd5 100644 --- a/griptape/artifacts/json_artifact.py +++ b/griptape/artifacts/json_artifact.py @@ -10,13 +10,6 @@ Json = Union[dict[str, "Json"], list["Json"], str, int, float, bool, None] -def value_to_json(value: Any) -> Json: - if isinstance(value, str): - return json.loads(value) - else: - return json.loads(json.dumps(value)) - - @define class JsonArtifact(BaseArtifact): """Stores JSON data. diff --git a/griptape/loaders/base_loader.py b/griptape/loaders/base_loader.py index 93030a197..296e3a385 100644 --- a/griptape/loaders/base_loader.py +++ b/griptape/loaders/base_loader.py @@ -5,6 +5,7 @@ from attrs import define, field +from griptape.artifacts import BaseArtifact from griptape.mixins.futures_executor_mixin import FuturesExecutorMixin from griptape.utils.futures import execute_futures_dict from griptape.utils.hash import bytes_to_hash, str_to_hash From f2d7bd7765d17b82009f49e93e1d3cba3a2cc1e2 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 11 Sep 2024 17:01:04 -0700 Subject: [PATCH 42/52] Fix bad merge --- griptape/loaders/csv_loader.py | 12 +++++++----- griptape/loaders/sql_loader.py | 13 +++++++++---- tests/unit/loaders/test_csv_loader.py | 2 -- 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/griptape/loaders/csv_loader.py b/griptape/loaders/csv_loader.py index 22c2f2e00..4487d7aec 100644 --- a/griptape/loaders/csv_loader.py +++ b/griptape/loaders/csv_loader.py @@ -2,23 +2,25 @@ import csv from io import StringIO -from typing import TYPE_CHECKING, Callable, Optional, cast +from typing import Callable from attrs import define, field -from griptape.artifacts import CsvRowArtifact, ListArtifact +from griptape.artifacts import ListArtifact, TextArtifact from griptape.loaders import BaseFileLoader @define -class CsvLoader(BaseFileLoader[ListArtifact[CsvRowArtifact]]): +class CsvLoader(BaseFileLoader[ListArtifact[TextArtifact]]): delimiter: str = field(default=",", kw_only=True) encoding: str = field(default="utf-8", kw_only=True) formatter_fn: Callable[[dict], str] = field( default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True ) - def parse(self, data: bytes) -> ListArtifact[CsvRowArtifact]: + def parse(self, data: bytes) -> ListArtifact[TextArtifact]: reader = csv.DictReader(StringIO(data.decode(self.encoding)), delimiter=self.delimiter) - return ListArtifact([CsvRowArtifact(row, meta={"row_num": row_num}) for row_num, row in enumerate(reader)]) + return ListArtifact( + [TextArtifact(self.formatter_fn(row), meta={"row_num": row_num}) for row_num, row in enumerate(reader)] + ) diff --git a/griptape/loaders/sql_loader.py b/griptape/loaders/sql_loader.py index 5423fe55f..0c6e8bdf9 100644 --- a/griptape/loaders/sql_loader.py +++ b/griptape/loaders/sql_loader.py @@ -1,18 +1,23 @@ from __future__ import annotations +from typing import Callable + from attrs import define, field -from griptape.artifacts import CsvRowArtifact, ListArtifact +from griptape.artifacts import ListArtifact, TextArtifact from griptape.drivers import BaseSqlDriver from griptape.loaders import BaseLoader @define -class SqlLoader(BaseLoader[str, list[BaseSqlDriver.RowResult], ListArtifact]): +class SqlLoader(BaseLoader[str, list[BaseSqlDriver.RowResult], ListArtifact[TextArtifact]]): sql_driver: BaseSqlDriver = field(kw_only=True) + formatter_fn: Callable[[dict], str] = field( + default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True + ) def fetch(self, source: str) -> list[BaseSqlDriver.RowResult]: return self.sql_driver.execute_query(source) or [] - def parse(self, data: list[BaseSqlDriver.RowResult]) -> ListArtifact[CsvRowArtifact]: - return ListArtifact([CsvRowArtifact(row.cells, meta={"row_num": row_num}) for row_num, row in enumerate(data)]) + def parse(self, data: list[BaseSqlDriver.RowResult]) -> ListArtifact[TextArtifact]: + return ListArtifact([TextArtifact(self.formatter_fn(row.cells)) for row in data]) diff --git a/tests/unit/loaders/test_csv_loader.py b/tests/unit/loaders/test_csv_loader.py index 3ef7e88b2..9c5d0febb 100644 --- a/tests/unit/loaders/test_csv_loader.py +++ b/tests/unit/loaders/test_csv_loader.py @@ -1,5 +1,3 @@ -import json - import pytest from griptape.loaders.csv_loader import CsvLoader From d940a984a7de97962b44d2baccc80d8426eed604 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Mon, 16 Sep 2024 10:19:04 -0700 Subject: [PATCH 43/52] Fix PR comments --- docs/examples/src/query_webpage_1.py | 4 +-- .../drivers/src/vector_store_drivers_1.py | 2 +- .../drivers/src/vector_store_drivers_10.py | 2 +- .../drivers/src/vector_store_drivers_11.py | 2 +- .../drivers/src/vector_store_drivers_3.py | 2 +- .../engines/src/summary_engines_1.py | 2 +- griptape/loaders/base_loader.py | 29 ++++++++++++------- poetry.lock | 2 +- 8 files changed, 27 insertions(+), 18 deletions(-) diff --git a/docs/examples/src/query_webpage_1.py b/docs/examples/src/query_webpage_1.py index 9cc033566..43c59e4c1 100644 --- a/docs/examples/src/query_webpage_1.py +++ b/docs/examples/src/query_webpage_1.py @@ -9,8 +9,8 @@ artifacts = WebLoader().load("https://www.griptape.ai") chunks = TextChunker().chunk(artifacts) -for a in chunks: - vector_store.upsert_text_artifact(a, namespace="griptape") +for chunk in chunks: + vector_store.upsert_text_artifact(chunk, namespace="griptape") results = vector_store.query("creativity", count=3, namespace="griptape") diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_1.py b/docs/griptape-framework/drivers/src/vector_store_drivers_1.py index 33f611a71..5011ecd3b 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_1.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_1.py @@ -14,7 +14,7 @@ chunks = TextChunker(max_tokens=100).chunk(artifact) # Upsert Artifacts into the Vector Store Driver -[vector_store_driver.upsert_text_artifact(a, namespace="griptape") for a in chunks] +[vector_store_driver.upsert_text_artifact(chunk, namespace="griptape") for chunk in chunks] results = vector_store_driver.query(query="What is griptape?") diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_10.py b/docs/griptape-framework/drivers/src/vector_store_drivers_10.py index f82294913..2ef6a6fe0 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_10.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_10.py @@ -30,7 +30,7 @@ ) # Upsert Artifacts into the Vector Store Driver -[vector_store_driver.upsert_text_artifact(a, namespace="griptape") for a in chunks] +[vector_store_driver.upsert_text_artifact(chunk, namespace="griptape") for chunk in chunks] results = vector_store_driver.query(query="What is griptape?") diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_11.py b/docs/griptape-framework/drivers/src/vector_store_drivers_11.py index 739a52c63..6972de787 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_11.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_11.py @@ -25,7 +25,7 @@ chunks = TextChunker().chunk(artifact) # Upsert Artifacts into the Vector Store Driver -[vector_store_driver.upsert_text_artifact(a, namespace="griptape") for a in chunks] +[vector_store_driver.upsert_text_artifact(chunk, namespace="griptape") for chunk in chunks] results = vector_store_driver.query(query="What is griptape?") diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_3.py b/docs/griptape-framework/drivers/src/vector_store_drivers_3.py index c8ff0e0b5..71f679d99 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_3.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_3.py @@ -19,7 +19,7 @@ chunks = TextChunker(max_tokens=100).chunk(artifact) # Upsert Artifacts into the Vector Store Driver -[vector_store_driver.upsert_text_artifact(a, namespace="griptape") for a in chunks] +[vector_store_driver.upsert_text_artifact(chunk, namespace="griptape") for chunk in chunks] results = vector_store_driver.query(query="What is griptape?") diff --git a/docs/griptape-framework/engines/src/summary_engines_1.py b/docs/griptape-framework/engines/src/summary_engines_1.py index d84f16798..a2f9ce405 100644 --- a/docs/griptape-framework/engines/src/summary_engines_1.py +++ b/docs/griptape-framework/engines/src/summary_engines_1.py @@ -13,6 +13,6 @@ artifact = PdfLoader().parse(response.content) chunks = TextChunker().chunk(artifact) -text = "\n\n".join([a.value for a in chunks]) +text = "\n\n".join([chunk.value for chunk in chunks]) engine.summarize_text(text) diff --git a/griptape/loaders/base_loader.py b/griptape/loaders/base_loader.py index 296e3a385..f7340283b 100644 --- a/griptape/loaders/base_loader.py +++ b/griptape/loaders/base_loader.py @@ -22,9 +22,15 @@ @define class BaseLoader(FuturesExecutorMixin, ABC, Generic[S, F, A]): + """Fetches data from a source, parses it, and returns an Artifact. + + Attributes: + reference: The optional `Reference` to set on the Artifact. + """ + reference: Optional[Reference] = field(default=None, kw_only=True) - def load(self, source: S, *args, **kwargs) -> A: + def load(self, source: S) -> A: data = self.fetch(source) artifact = self.parse(data) @@ -34,29 +40,32 @@ def load(self, source: S, *args, **kwargs) -> A: return artifact @abstractmethod - def fetch(self, source: S) -> F: ... + def fetch(self, source: S) -> F: + """Fetches data from the source.""" + + ... @abstractmethod - def parse(self, data: F) -> A: ... + def parse(self, data: F) -> A: + """Parses the fetched data and returns an Artifact.""" + + ... def load_collection( self, sources: list[Any], - *args, - **kwargs, ) -> Mapping[str, A]: + """Loads a collection of sources and returns a dictionary of Artifacts.""" # Create a dictionary before actually submitting the jobs to the executor # to avoid duplicate work. sources_by_key = {self.to_key(source): source for source in sources} return execute_futures_dict( - { - key: self.futures_executor.submit(self.load, source, *args, **kwargs) - for key, source in sources_by_key.items() - }, + {key: self.futures_executor.submit(self.load, source) for key, source in sources_by_key.items()}, ) - def to_key(self, source: Any) -> str: + def to_key(self, source: S) -> str: + """Converts the source to a key for the collection.""" if isinstance(source, bytes): return bytes_to_hash(source) else: diff --git a/poetry.lock b/poetry.lock index 0ec06076c..b386b6d65 100644 --- a/poetry.lock +++ b/poetry.lock @@ -7004,4 +7004,4 @@ loaders-sql = ["sqlalchemy"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "61d8eb5080886ac092c8e9ff29c6260276acc94ec2062f9363e838c004c5e37f" +content-hash = "765caae605143c2205cd71de88588ca6bd0d454a2a23f21d368746a01e31b828" From 64760f2ff6bcb739d50b6eff8f72141292d820a5 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 25 Sep 2024 15:09:38 -0700 Subject: [PATCH 44/52] Fix changelog --- CHANGELOG.md | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b58dbf9eb..03888626c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased ### Added +- `Workflow.input_tasks` and `Workflow.output_tasks` to access the input and output tasks of a Workflow. +- Ability to pass nested list of `Tasks` to `Structure.tasks` allowing for more complex declarative Structure definitions. +- Parameter `pipeline_task` on `HuggingFacePipelinePromptDriver` for creating different types of `Pipeline`s. +- `TavilyWebSearchDriver` to integrate Tavily's web search capabilities. - `BaseFileLoader` for Loaders that load from a path. - `BaseLoader.fetch()` method for fetching data from a source. - `BaseLoader.parse()` method for parsing fetched data. @@ -15,6 +19,25 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `file_utils.get_mime_type` utility for getting the MIME type of a file. ### Changed +- **BREAKING**: Renamed parameters on several classes to `client`: + - `bedrock_client` on `AmazonBedrockCohereEmbeddingDriver`. + - `bedrock_client` on `AmazonBedrockCohereEmbeddingDriver`. + - `bedrock_client` on `AmazonBedrockTitanEmbeddingDriver`. + - `bedrock_client` on `AmazonBedrockImageGenerationDriver`. + - `bedrock_client` on `AmazonBedrockImageQueryDriver`. + - `bedrock_client` on `AmazonBedrockPromptDriver`. + - `sagemaker_client` on `AmazonSageMakerJumpstartEmbeddingDriver`. + - `sagemaker_client` on `AmazonSageMakerJumpstartPromptDriver`. + - `sqs_client` on `AmazonSqsEventListenerDriver`. + - `iotdata_client` on `AwsIotCoreEventListenerDriver`. + - `s3_client` on `AmazonS3FileManagerDriver`. + - `s3_client` on `AwsS3Tool`. + - `iam_client` on `AwsIamTool`. + - `pusher_client` on `PusherEventListenerDriver`. + - `mq` on `MarqoVectorStoreDriver`. + - `model_client` on `GooglePromptDriver`. + - `model_client` on `GoogleTokenizer`. +- **BREAKING**: Renamed parameter `pipe` on `HuggingFacePipelinePromptDriver` to `pipeline`. - **BREAKING**: Removed `BaseFileManager.default_loader` and `BaseFileManager.loaders`. - **BREAKING**: Loaders no longer chunk data, use a Chunker to chunk the data. - **BREAKING**: Removed `fileutils.load_file` and `fileutils.load_files`. @@ -25,6 +48,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `filetype` is now a core dependency. - `FileManagerTool` now uses `filetype` for more accurate file type detection. - `BaseFileLoader.load_file()` will now either return a `TextArtifact` or a `BlobArtifact` depending on whether `BaseFileManager.encoding` is set. +- Several places where API clients are initialized are now lazy loaded. + ## Added - `Workflow.input_tasks` and `Workflow.output_tasks` to access the input and output tasks of a Workflow. From 97c497b54860adb84afe689eb61abc90007e1909 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 25 Sep 2024 15:15:53 -0700 Subject: [PATCH 45/52] Support deprecated loading --- CHANGELOG.md | 2 +- griptape/loaders/base_file_loader.py | 10 +++++++++- tests/unit/loaders/test_text_loader.py | 4 ++++ 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 03888626c..44d7faf35 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,7 +42,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: Loaders no longer chunk data, use a Chunker to chunk the data. - **BREAKING**: Removed `fileutils.load_file` and `fileutils.load_files`. - **BREAKING**: Removed `loaders-dataframe` and `loaders-audio` extras as they are no longer needed. -- **BREKING**: `TextLoader`, `PdfLoader`, `ImageLoader`, and `AudioLoader` now take a `str | PathLike` instead of `bytes`. +- **BREKING**: `TextLoader`, `PdfLoader`, `ImageLoader`, and `AudioLoader` now take a `str | PathLike` instead of `bytes`. Passing `bytes` is still supported but deprecated. - **BREAKING**: Removed `DataframeLoader`. - `LocalFileManagerDriver.workdir` is now optional. - `filetype` is now a core dependency. diff --git a/griptape/loaders/base_file_loader.py b/griptape/loaders/base_file_loader.py index 36645ef8a..9fcffa7ae 100644 --- a/griptape/loaders/base_file_loader.py +++ b/griptape/loaders/base_file_loader.py @@ -9,6 +9,7 @@ from griptape.artifacts import BaseArtifact from griptape.drivers import BaseFileManagerDriver, LocalFileManagerDriver from griptape.loaders import BaseLoader +from griptape.utils import deprecation_warn A = TypeVar("A", bound=BaseArtifact) @@ -21,7 +22,14 @@ class BaseFileLoader(BaseLoader[Union[str, PathLike], bytes, A], ABC): ) encoding: str = field(default="utf-8", kw_only=True) - def fetch(self, source: str | PathLike) -> bytes: + def fetch(self, source: str | PathLike | bytes) -> bytes: + if isinstance(source, bytes): + deprecation_warn( + "Using bytes as the source is deprecated and will be removed in a future release. " + "Please use a string or PathLike object instead." + ) + return source + data = self.file_manager_driver.load_file(str(source)).value if isinstance(data, str): return data.encode(self.encoding) diff --git a/tests/unit/loaders/test_text_loader.py b/tests/unit/loaders/test_text_loader.py index c75417f56..a435610ef 100644 --- a/tests/unit/loaders/test_text_loader.py +++ b/tests/unit/loaders/test_text_loader.py @@ -37,3 +37,7 @@ def test_load_collection(self, loader, create_source): artifact = collection[key] assert artifact.encoding == loader.encoding + + def test_load_deprecated_bytes(self, loader): + with pytest.warns(DeprecationWarning): + loader.load(b"test.txt") From 2763a383d466425e87a0268a94a5e42565e1f318 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 27 Sep 2024 09:27:26 -0700 Subject: [PATCH 46/52] Clean up Vector Store Driver examples --- CHANGELOG.md | 1 + docs/examples/src/query_webpage_1.py | 3 +- .../drivers/src/vector_store_drivers_1.py | 2 +- .../drivers/src/vector_store_drivers_10.py | 2 +- .../drivers/src/vector_store_drivers_11.py | 2 +- .../drivers/src/vector_store_drivers_3.py | 2 +- .../engines/src/summary_engines_1.py | 6 +--- .../vector/base_vector_store_driver.py | 28 ++++++------------- 8 files changed, 15 insertions(+), 31 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 44d7faf35..dd1cfc309 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -44,6 +44,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: Removed `loaders-dataframe` and `loaders-audio` extras as they are no longer needed. - **BREKING**: `TextLoader`, `PdfLoader`, `ImageLoader`, and `AudioLoader` now take a `str | PathLike` instead of `bytes`. Passing `bytes` is still supported but deprecated. - **BREAKING**: Removed `DataframeLoader`. +- `BaseVectorStoreDriver.upsert_text_artifacts` now returns a list or dictionary of upserted vector ids. - `LocalFileManagerDriver.workdir` is now optional. - `filetype` is now a core dependency. - `FileManagerTool` now uses `filetype` for more accurate file type detection. diff --git a/docs/examples/src/query_webpage_1.py b/docs/examples/src/query_webpage_1.py index 43c59e4c1..b839b1302 100644 --- a/docs/examples/src/query_webpage_1.py +++ b/docs/examples/src/query_webpage_1.py @@ -9,8 +9,7 @@ artifacts = WebLoader().load("https://www.griptape.ai") chunks = TextChunker().chunk(artifacts) -for chunk in chunks: - vector_store.upsert_text_artifact(chunk, namespace="griptape") +vector_store.upsert_text_artifacts({"griptape": chunks}) results = vector_store.query("creativity", count=3, namespace="griptape") diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_1.py b/docs/griptape-framework/drivers/src/vector_store_drivers_1.py index 5011ecd3b..f531aa618 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_1.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_1.py @@ -14,7 +14,7 @@ chunks = TextChunker(max_tokens=100).chunk(artifact) # Upsert Artifacts into the Vector Store Driver -[vector_store_driver.upsert_text_artifact(chunk, namespace="griptape") for chunk in chunks] +vector_store_driver.upsert_text_artifacts({"griptape": chunks}) results = vector_store_driver.query(query="What is griptape?") diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_10.py b/docs/griptape-framework/drivers/src/vector_store_drivers_10.py index 2ef6a6fe0..4599cfa47 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_10.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_10.py @@ -30,7 +30,7 @@ ) # Upsert Artifacts into the Vector Store Driver -[vector_store_driver.upsert_text_artifact(chunk, namespace="griptape") for chunk in chunks] +vector_store_driver.upsert_text_artifacts({"griptape": chunks}) results = vector_store_driver.query(query="What is griptape?") diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_11.py b/docs/griptape-framework/drivers/src/vector_store_drivers_11.py index 6972de787..144f14c59 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_11.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_11.py @@ -25,7 +25,7 @@ chunks = TextChunker().chunk(artifact) # Upsert Artifacts into the Vector Store Driver -[vector_store_driver.upsert_text_artifact(chunk, namespace="griptape") for chunk in chunks] +vector_store_driver.upsert_text_artifacts({"griptape": chunks}) results = vector_store_driver.query(query="What is griptape?") diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_3.py b/docs/griptape-framework/drivers/src/vector_store_drivers_3.py index 71f679d99..d84164cb4 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_3.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_3.py @@ -19,7 +19,7 @@ chunks = TextChunker(max_tokens=100).chunk(artifact) # Upsert Artifacts into the Vector Store Driver -[vector_store_driver.upsert_text_artifact(chunk, namespace="griptape") for chunk in chunks] +vector_store_driver.upsert_text_artifacts({"griptape": chunks}) results = vector_store_driver.query(query="What is griptape?") diff --git a/docs/griptape-framework/engines/src/summary_engines_1.py b/docs/griptape-framework/engines/src/summary_engines_1.py index a2f9ce405..5a16e4819 100644 --- a/docs/griptape-framework/engines/src/summary_engines_1.py +++ b/docs/griptape-framework/engines/src/summary_engines_1.py @@ -1,6 +1,5 @@ import requests -from griptape.chunkers import TextChunker from griptape.drivers import OpenAiChatPromptDriver from griptape.engines import PromptSummaryEngine from griptape.loaders import PdfLoader @@ -11,8 +10,5 @@ ) artifact = PdfLoader().parse(response.content) -chunks = TextChunker().chunk(artifact) -text = "\n\n".join([chunk.value for chunk in chunks]) - -engine.summarize_text(text) +engine.summarize_artifacts(artifact) diff --git a/griptape/drivers/vector/base_vector_store_driver.py b/griptape/drivers/vector/base_vector_store_driver.py index 50810752e..e2a394bf4 100644 --- a/griptape/drivers/vector/base_vector_store_driver.py +++ b/griptape/drivers/vector/base_vector_store_driver.py @@ -43,9 +43,9 @@ def upsert_text_artifacts( *, meta: Optional[dict] = None, **kwargs, - ) -> None: + ) -> list[str] | dict[str, list[str]]: if isinstance(artifacts, list): - utils.execute_futures_list( + return utils.execute_futures_list( [ self.futures_executor.submit(self.upsert_text_artifact, a, namespace=None, meta=meta, **kwargs) for a in artifacts @@ -65,7 +65,7 @@ def upsert_text_artifacts( ) ) - utils.execute_futures_list_dict(futures_dict) + return utils.execute_futures_list_dict(futures_dict) def upsert_text_artifact( self, @@ -89,32 +89,20 @@ def upsert_text_artifact( vector = artifact.embedding or artifact.generate_embedding(self.embedding_driver) - if isinstance(vector, list): - return self.upsert_vector(vector, vector_id=vector_id, namespace=namespace, meta=meta, **kwargs) - else: - raise ValueError("Vector must be an instance of 'list'.") + return self.upsert_vector(vector, vector_id=vector_id, namespace=namespace, meta=meta, **kwargs) def upsert_text( self, string: str, *, - vector_id: Optional[str] = None, namespace: Optional[str] = None, meta: Optional[dict] = None, + vector_id: Optional[str] = None, **kwargs, ) -> str: - vector_id = self._get_default_vector_id(string) if vector_id is None else vector_id - - if self.does_entry_exist(vector_id, namespace=namespace): - return vector_id - else: - return self.upsert_vector( - self.embedding_driver.embed_string(string), - vector_id=vector_id, - namespace=namespace, - meta=meta or {}, - **kwargs, - ) + return self.upsert_text_artifact( + TextArtifact(string), vector_id=vector_id, namespace=namespace, meta=meta, **kwargs + ) def does_entry_exist(self, vector_id: str, *, namespace: Optional[str] = None) -> bool: try: From 0bcf576a9b48c6d45dfc1dcf15260676c5f1c97d Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 1 Oct 2024 10:04:57 -0700 Subject: [PATCH 47/52] Regenerate lock file --- poetry.lock | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/poetry.lock b/poetry.lock index 0028938e5..6e8d56acd 100644 --- a/poetry.lock +++ b/poetry.lock @@ -6364,11 +6364,6 @@ files = [ {file = "triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb"}, {file = "triton-3.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bcbf3b1c48af6a28011a5c40a5b3b9b5330530c3827716b5fbf6d7adcc1e53e9"}, {file = "triton-3.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6e5727202f7078c56f91ff13ad0c1abab14a0e7f2c87e91b12b6f64f3e8ae609"}, - {file = "triton-3.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b052da883351fdf6be3d93cedae6db3b8e3988d3b09ed221bccecfa9612230"}, - {file = "triton-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd34f19a8582af96e6291d4afce25dac08cb2a5d218c599163761e8e0827208e"}, - {file = "triton-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d5e10de8c011adeb7c878c6ce0dd6073b14367749e34467f1cff2bde1b78253"}, - {file = "triton-3.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8903767951bf86ec960b4fe4e21bc970055afc65e9d57e916d79ae3c93665e3"}, - {file = "triton-3.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41004fb1ae9a53fcb3e970745feb87f0e3c94c6ce1ba86e95fa3b8537894bef7"}, ] [package.dependencies] @@ -6950,7 +6945,7 @@ doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linke test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] [extras] -all = ["anthropic", "astrapy", "beautifulsoup4", "boto3", "cohere", "diffusers", "duckduckgo-search", "elevenlabs", "filetype", "google-generativeai", "mail-parser", "markdownify", "marqo", "ollama", "opensearch-py", "opentelemetry-api", "opentelemetry-exporter-otlp-proto-http", "opentelemetry-instrumentation", "opentelemetry-instrumentation-threading", "opentelemetry-sdk", "pandas", "pgvector", "pillow", "pinecone-client", "playwright", "psycopg2-binary", "pusher", "pymongo", "pypdf", "qdrant-client", "redis", "snowflake-sqlalchemy", "sqlalchemy", "tavily-python", "trafilatura", "transformers", "voyageai"] +all = ["accelerate", "anthropic", "astrapy", "beautifulsoup4", "boto3", "cohere", "diffusers", "duckduckgo-search", "elevenlabs", "google-generativeai", "mail-parser", "markdownify", "marqo", "ollama", "opensearch-py", "opentelemetry-api", "opentelemetry-exporter-otlp-proto-http", "opentelemetry-instrumentation", "opentelemetry-instrumentation-threading", "opentelemetry-sdk", "pandas", "pgvector", "pillow", "pinecone-client", "playwright", "psycopg2-binary", "pusher", "pymongo", "pypdf", "qdrant-client", "redis", "sentencepiece", "snowflake-sqlalchemy", "sqlalchemy", "tavily-python", "torch", "trafilatura", "transformers", "voyageai"] drivers-embedding-amazon-bedrock = ["boto3"] drivers-embedding-amazon-sagemaker = ["boto3"] drivers-embedding-cohere = ["cohere"] @@ -6994,8 +6989,6 @@ drivers-web-scraper-trafilatura = ["trafilatura"] drivers-web-search-duckduckgo = ["duckduckgo-search"] drivers-web-search-exa = ["exa-py"] drivers-web-search-tavily = ["tavily-python"] -loaders-audio = ["filetype"] -loaders-dataframe = ["pandas"] loaders-email = ["mail-parser"] loaders-image = ["pillow"] loaders-pdf = ["pypdf"] @@ -7004,4 +6997,4 @@ loaders-sql = ["sqlalchemy"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "4fe9e2ded897a0af4b9019926eb93c2d3a80cce07ab4707bfe59a6bd24124ce3" +content-hash = "f07e987c284abf1bfbf78e845120bcaa9ff5cc1819f376cfc4790a50663c05fe" From 27e426262f5affe2044ce6b05f9be64af7df64f5 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 3 Oct 2024 10:53:59 -0700 Subject: [PATCH 48/52] Fix changelog --- CHANGELOG.md | 41 ++++++++--------------------------------- 1 file changed, 8 insertions(+), 33 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index dd1cfc309..89839a2d1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,11 +5,13 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## Unreleased -### Added + +## Added - `Workflow.input_tasks` and `Workflow.output_tasks` to access the input and output tasks of a Workflow. - Ability to pass nested list of `Tasks` to `Structure.tasks` allowing for more complex declarative Structure definitions. -- Parameter `pipeline_task` on `HuggingFacePipelinePromptDriver` for creating different types of `Pipeline`s. - `TavilyWebSearchDriver` to integrate Tavily's web search capabilities. +- `ExaWebSearchDriver` to integrate Exa's web search capabilities. +- `Workflow.outputs` to access the outputs of a Workflow. - `BaseFileLoader` for Loaders that load from a path. - `BaseLoader.fetch()` method for fetching data from a source. - `BaseLoader.parse()` method for parsing fetched data. @@ -38,48 +40,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `model_client` on `GooglePromptDriver`. - `model_client` on `GoogleTokenizer`. - **BREAKING**: Renamed parameter `pipe` on `HuggingFacePipelinePromptDriver` to `pipeline`. +- **BREAKING**: Update `pypdf` dependency to `^5.0.1`. +- **BREAKING**: Update `redis` dependency to `^5.1.0`. - **BREAKING**: Removed `BaseFileManager.default_loader` and `BaseFileManager.loaders`. - **BREAKING**: Loaders no longer chunk data, use a Chunker to chunk the data. - **BREAKING**: Removed `fileutils.load_file` and `fileutils.load_files`. - **BREAKING**: Removed `loaders-dataframe` and `loaders-audio` extras as they are no longer needed. - **BREKING**: `TextLoader`, `PdfLoader`, `ImageLoader`, and `AudioLoader` now take a `str | PathLike` instead of `bytes`. Passing `bytes` is still supported but deprecated. - **BREAKING**: Removed `DataframeLoader`. +- Several places where API clients are initialized are now lazy loaded. +- `MarkdownifyWebScraperDriver.DEFAULT_EXCLUDE_TAGS` now includes media/blob-like HTML tags - `BaseVectorStoreDriver.upsert_text_artifacts` now returns a list or dictionary of upserted vector ids. - `LocalFileManagerDriver.workdir` is now optional. - `filetype` is now a core dependency. - `FileManagerTool` now uses `filetype` for more accurate file type detection. - `BaseFileLoader.load_file()` will now either return a `TextArtifact` or a `BlobArtifact` depending on whether `BaseFileManager.encoding` is set. -- Several places where API clients are initialized are now lazy loaded. - - -## Added -- `Workflow.input_tasks` and `Workflow.output_tasks` to access the input and output tasks of a Workflow. -- Ability to pass nested list of `Tasks` to `Structure.tasks` allowing for more complex declarative Structure definitions. -- `TavilyWebSearchDriver` to integrate Tavily's web search capabilities. -- `ExaWebSearchDriver` to integrate Exa's web search capabilities. -- `Workflow.outputs` to access the outputs of a Workflow. - -### Changed -- **BREAKING**: Renamed parameters on several classes to `client`: - - `bedrock_client` on `AmazonBedrockCohereEmbeddingDriver`. - - `bedrock_client` on `AmazonBedrockCohereEmbeddingDriver`. - - `bedrock_client` on `AmazonBedrockTitanEmbeddingDriver`. - - `bedrock_client` on `AmazonBedrockImageGenerationDriver`. - - `bedrock_client` on `AmazonBedrockImageQueryDriver`. - - `bedrock_client` on `AmazonBedrockPromptDriver`. - - `sagemaker_client` on `AmazonSageMakerJumpstartEmbeddingDriver`. - - `sagemaker_client` on `AmazonSageMakerJumpstartPromptDriver`. - - `sqs_client` on `AmazonSqsEventListenerDriver`. - - `iotdata_client` on `AwsIotCoreEventListenerDriver`. - - `s3_client` on `AmazonS3FileManagerDriver`. - - `s3_client` on `AwsS3Tool`. - - `iam_client` on `AwsIamTool`. - - `pusher_client` on `PusherEventListenerDriver`. - - `mq` on `MarqoVectorStoreDriver`. - - `model_client` on `GooglePromptDriver`. - - `model_client` on `GoogleTokenizer`. -- **BREAKING**: Renamed parameter `pipe` on `HuggingFacePipelinePromptDriver` to `pipeline`. -- Several places where API clients are initialized are now lazy loaded. - `Structure.output`'s type is now `BaseArtifact` and raises an exception if the output is `None`. - **BREAKING**: Update `pypdf` dependency to `^5.0.1`. - **BREAKING**: Update `redis` dependency to `^5.1.0`. From f99d5cd6ca4e59088e0bf7ad1309a0b9ae3bdf83 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 3 Oct 2024 11:02:39 -0700 Subject: [PATCH 49/52] Update chunkers docs --- docs/griptape-framework/data/chunkers.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/griptape-framework/data/chunkers.md b/docs/griptape-framework/data/chunkers.md index 507645923..bafbc1c80 100644 --- a/docs/griptape-framework/data/chunkers.md +++ b/docs/griptape-framework/data/chunkers.md @@ -18,3 +18,7 @@ Here is how to use a chunker: ```python --8<-- "docs/griptape-framework/data/src/chunkers_1.py" ``` + +The most common use of a Chunker is to split up a long text into smaller chunks for inserting into a Vector Database when doing Retrieval Augmented Generation (RAG). + +See [RagEngine](../../griptape-framework/engines/rag-engines.md) for more information on how to use Chunkers in RAG pipelines. From 0418e7cd5a9a134fd843980428cb63fe8800ddc6 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 3 Oct 2024 11:03:08 -0700 Subject: [PATCH 50/52] Update doc link order --- mkdocs.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mkdocs.yml b/mkdocs.yml index f065a0d01..0be4ec7e5 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -131,8 +131,8 @@ nav: - Data: - Overview: "griptape-framework/data/index.md" - Artifacts: "griptape-framework/data/artifacts.md" - - Chunkers: "griptape-framework/data/chunkers.md" - Loaders: "griptape-framework/data/loaders.md" + - Chunkers: "griptape-framework/data/chunkers.md" - Misc: - Events: "griptape-framework/misc/events.md" - Tokenizers: "griptape-framework/misc/tokenizers.md" From 23090b613949b53555fa3e705f701749e9a7b296 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 4 Oct 2024 13:29:50 -0700 Subject: [PATCH 51/52] Fix changelog --- CHANGELOG.md | 1 - 1 file changed, 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 89839a2d1..16062edb5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -49,7 +49,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREKING**: `TextLoader`, `PdfLoader`, `ImageLoader`, and `AudioLoader` now take a `str | PathLike` instead of `bytes`. Passing `bytes` is still supported but deprecated. - **BREAKING**: Removed `DataframeLoader`. - Several places where API clients are initialized are now lazy loaded. -- `MarkdownifyWebScraperDriver.DEFAULT_EXCLUDE_TAGS` now includes media/blob-like HTML tags - `BaseVectorStoreDriver.upsert_text_artifacts` now returns a list or dictionary of upserted vector ids. - `LocalFileManagerDriver.workdir` is now optional. - `filetype` is now a core dependency. From 46740992df05e8718504a9fb59f945a6c73f73fa Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 4 Oct 2024 13:32:15 -0700 Subject: [PATCH 52/52] Regenerate lock filee --- poetry.lock | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/poetry.lock b/poetry.lock index 6e8d56acd..b0d38fc42 100644 --- a/poetry.lock +++ b/poetry.lock @@ -6945,7 +6945,7 @@ doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linke test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] [extras] -all = ["accelerate", "anthropic", "astrapy", "beautifulsoup4", "boto3", "cohere", "diffusers", "duckduckgo-search", "elevenlabs", "google-generativeai", "mail-parser", "markdownify", "marqo", "ollama", "opensearch-py", "opentelemetry-api", "opentelemetry-exporter-otlp-proto-http", "opentelemetry-instrumentation", "opentelemetry-instrumentation-threading", "opentelemetry-sdk", "pandas", "pgvector", "pillow", "pinecone-client", "playwright", "psycopg2-binary", "pusher", "pymongo", "pypdf", "qdrant-client", "redis", "sentencepiece", "snowflake-sqlalchemy", "sqlalchemy", "tavily-python", "torch", "trafilatura", "transformers", "voyageai"] +all = ["anthropic", "astrapy", "beautifulsoup4", "boto3", "cohere", "diffusers", "duckduckgo-search", "elevenlabs", "google-generativeai", "mail-parser", "markdownify", "marqo", "ollama", "opensearch-py", "opentelemetry-api", "opentelemetry-exporter-otlp-proto-http", "opentelemetry-instrumentation", "opentelemetry-instrumentation-threading", "opentelemetry-sdk", "pandas", "pgvector", "pillow", "pinecone-client", "playwright", "psycopg2-binary", "pusher", "pymongo", "pypdf", "qdrant-client", "redis", "snowflake-sqlalchemy", "sqlalchemy", "tavily-python", "trafilatura", "transformers", "voyageai"] drivers-embedding-amazon-bedrock = ["boto3"] drivers-embedding-amazon-sagemaker = ["boto3"] drivers-embedding-cohere = ["cohere"] @@ -6997,4 +6997,4 @@ loaders-sql = ["sqlalchemy"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "f07e987c284abf1bfbf78e845120bcaa9ff5cc1819f376cfc4790a50663c05fe" +content-hash = "ca72d32879b2af60bd30f0401c0e11de24e4dcf7a9eadd06a867b3ca1db85d92"