Skip to content

Commit

Permalink
Refactor converters
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Sep 9, 2024
1 parent 4ac583a commit 7300f52
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 14 deletions.
6 changes: 2 additions & 4 deletions griptape/artifacts/audio_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,17 @@

from attrs import define, field

from griptape.artifacts import BaseArtifact
from griptape.artifacts import BlobArtifact


@define
class AudioArtifact(BaseArtifact):
class AudioArtifact(BlobArtifact):
"""Stores audio data.
Attributes:
value: The audio data.
format: The audio format, e.g. "wav" or "mp3".
"""

value: bytes = field(metadata={"serializable": True})
format: str = field(kw_only=True, metadata={"serializable": True})

@property
Expand Down
6 changes: 4 additions & 2 deletions griptape/artifacts/base_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ class BaseArtifact(SerializableMixin, ABC):
meta: The metadata associated with the Artifact. Defaults to an empty dictionary.
name: The name of the Artifact. Defaults to the id.
value: The value of the Artifact.
encoding: The encoding of the Artifact when converting to bytes.
encoding: The encoding to use when encoding/decoding the value.
encoding_error_handler: The error handler to use when encoding/decoding the value.
"""

id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True, metadata={"serializable": True})
Expand All @@ -36,6 +37,7 @@ class BaseArtifact(SerializableMixin, ABC):
metadata={"serializable": True},
)
value: Any = field()
encoding_error_handler: str = field(default="strict", kw_only=True)
encoding: str = field(default="utf-8", kw_only=True)

def __str__(self) -> str:
Expand All @@ -48,7 +50,7 @@ def __len__(self) -> int:
return len(self.value)

def to_bytes(self) -> bytes:
return self.to_text().encode(self.encoding)
return self.to_text().encode(encoding=self.encoding, errors=self.encoding_error_handler)

Check warning on line 53 in griptape/artifacts/base_artifact.py

View check run for this annotation

Codecov / codecov/patch

griptape/artifacts/base_artifact.py#L53

Added line #L53 was not covered by tests

@abstractmethod
def to_text(self) -> str: ...
4 changes: 0 additions & 4 deletions griptape/artifacts/blob_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,9 @@ class BlobArtifact(BaseArtifact):
Attributes:
value: The binary data.
encoding: The encoding to use when converting the binary data to text.
encoding_error_handler: The error handler to use when converting the binary data to text.
"""

value: bytes = field(converter=lambda value: BlobArtifact.value_to_bytes(value), metadata={"serializable": True})
encoding: str = field(default="utf-8", kw_only=True)
encoding_error_handler: str = field(default="strict", kw_only=True)

@property
def mime_type(self) -> str:
Expand Down
2 changes: 0 additions & 2 deletions griptape/artifacts/image_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,11 @@ class ImageArtifact(BlobArtifact):
"""Stores image data.
Attributes:
value: The image data.
format: The format of the image data.
width: The width of the image.
height: The height of the image
"""

value: bytes = field(metadata={"serializable": True})
format: str = field(kw_only=True, metadata={"serializable": True})
width: int = field(kw_only=True, metadata={"serializable": True})
height: int = field(kw_only=True, metadata={"serializable": True})
Expand Down
2 changes: 0 additions & 2 deletions griptape/artifacts/text_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
@define
class TextArtifact(BaseArtifact):
value: str = field(converter=str, metadata={"serializable": True})
encoding: str = field(default="utf-8", kw_only=True)
encoding_error_handler: str = field(default="strict", kw_only=True)
embedding: Optional[list[float]] = field(default=None, kw_only=True)

def __add__(self, other: BaseArtifact) -> TextArtifact:
Expand Down

0 comments on commit 7300f52

Please sign in to comment.