Skip to content

Commit

Permalink
Fixes, test fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewfrench committed Apr 23, 2024
1 parent c0572ef commit d05e557
Show file tree
Hide file tree
Showing 12 changed files with 39 additions and 32 deletions.
2 changes: 2 additions & 0 deletions griptape/drivers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
from .file_manager.amazon_s3_file_manager_driver import AmazonS3FileManagerDriver

from .audio_generation.base_audio_generation_driver import BaseAudioGenerationDriver
from .audio_generation.dummy_audio_generation_driver import DummyAudioGenerationDriver
from .audio_generation.elevenlabs_audio_generation_driver import ElevenLabsAudioGenerationDriver

__all__ = [
Expand Down Expand Up @@ -185,5 +186,6 @@
"LocalFileManagerDriver",
"AmazonS3FileManagerDriver",
"BaseAudioGenerationDriver",
"DummyAudioGenerationDriver",
"ElevenLabsAudioGenerationDriver",
]
12 changes: 6 additions & 6 deletions griptape/tasks/image_query_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from attr import define, field

from griptape.artifacts import MediaArtifact, TextArtifact
from griptape.artifacts import ImageArtifact, TextArtifact
from griptape.engines import ImageQueryEngine
from griptape.tasks import BaseTask
from griptape.utils import J2
Expand All @@ -23,12 +23,12 @@ class ImageQueryTask(BaseTask):
"""

_image_query_engine: ImageQueryEngine = field(default=None, kw_only=True, alias="image_query_engine")
_input: tuple[str, list[MediaArtifact]] | tuple[TextArtifact, list[MediaArtifact]] | Callable[
[BaseTask], tuple[TextArtifact, list[MediaArtifact]]
_input: tuple[str, list[ImageArtifact]] | tuple[TextArtifact, list[ImageArtifact]] | Callable[
[BaseTask], tuple[TextArtifact, list[ImageArtifact]]
] = field(default=None, alias="input")

@property
def input(self) -> tuple[TextArtifact, list[MediaArtifact]]:
def input(self) -> tuple[TextArtifact, list[ImageArtifact]]:
if isinstance(self._input, tuple):
if isinstance(self._input[0], TextArtifact):
query_text = self._input[0]
Expand All @@ -47,8 +47,8 @@ def input(self) -> tuple[TextArtifact, list[MediaArtifact]]:
@input.setter
def input(
self,
value: tuple[TextArtifact, list[MediaArtifact]]
| Callable[[BaseTask], tuple[TextArtifact, list[MediaArtifact]]],
value: tuple[TextArtifact, list[ImageArtifact]]
| Callable[[BaseTask], tuple[TextArtifact, list[ImageArtifact]]],
) -> None:
self._input = value

Expand Down
12 changes: 6 additions & 6 deletions griptape/tasks/inpainting_image_generation_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from attr import define, field

from griptape.engines import InpaintingImageGenerationEngine
from griptape.artifacts import MediaArtifact, TextArtifact
from griptape.artifacts import ImageArtifact, TextArtifact
from griptape.tasks import BaseImageGenerationTask, BaseTask
from griptape.utils import J2

Expand All @@ -30,12 +30,12 @@ class InpaintingImageGenerationTask(BaseImageGenerationTask):
default=None, kw_only=True, alias="image_generation_engine"
)
_input: (
tuple[str | TextArtifact, MediaArtifact, MediaArtifact]
| Callable[[BaseTask], tuple[TextArtifact, MediaArtifact, MediaArtifact]]
tuple[str | TextArtifact, ImageArtifact, ImageArtifact]
| Callable[[BaseTask], tuple[TextArtifact, ImageArtifact, ImageArtifact]]
) = field(default=None)

@property
def input(self) -> tuple[TextArtifact, MediaArtifact, MediaArtifact]:
def input(self) -> tuple[TextArtifact, ImageArtifact, ImageArtifact]:
if isinstance(self._input, tuple):
if isinstance(self._input[0], TextArtifact):
input_text = self._input[0]
Expand All @@ -49,7 +49,7 @@ def input(self) -> tuple[TextArtifact, MediaArtifact, MediaArtifact]:
raise ValueError("Input must be a tuple of (text, image, mask) or a callable that returns such a tuple.")

@input.setter
def input(self, value: tuple[TextArtifact, MediaArtifact, MediaArtifact]) -> None:
def input(self, value: tuple[TextArtifact, ImageArtifact, ImageArtifact]) -> None:
self._input = value

@property
Expand All @@ -67,7 +67,7 @@ def image_generation_engine(self) -> InpaintingImageGenerationEngine:
def image_generation_engine(self, value: InpaintingImageGenerationEngine) -> None:
self._image_generation_engine = value

def run(self) -> MediaArtifact:
def run(self) -> ImageArtifact:
prompt_artifact = self.input[0]
image_artifact = self.input[1]
mask_artifact = self.input[2]
Expand Down
12 changes: 6 additions & 6 deletions griptape/tasks/outpainting_image_generation_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from attr import define, field

from griptape.engines import OutpaintingImageGenerationEngine
from griptape.artifacts import MediaArtifact, TextArtifact
from griptape.artifacts import ImageArtifact, TextArtifact
from griptape.tasks import BaseImageGenerationTask, BaseTask
from griptape.utils import J2

Expand All @@ -29,12 +29,12 @@ class OutpaintingImageGenerationTask(BaseImageGenerationTask):
_image_generation_engine: OutpaintingImageGenerationEngine = field(
default=None, kw_only=True, alias="image_generation_engine"
)
_input: tuple[str | TextArtifact, MediaArtifact, MediaArtifact] | Callable[
[BaseTask], tuple[TextArtifact, MediaArtifact, MediaArtifact]
_input: tuple[str | TextArtifact, ImageArtifact, ImageArtifact] | Callable[
[BaseTask], tuple[TextArtifact, ImageArtifact, ImageArtifact]
] = field(default=None)

@property
def input(self) -> tuple[TextArtifact, MediaArtifact, MediaArtifact]:
def input(self) -> tuple[TextArtifact, ImageArtifact, ImageArtifact]:
if isinstance(self._input, tuple):
if isinstance(self._input[0], TextArtifact):
input_text = self._input[0]
Expand All @@ -48,7 +48,7 @@ def input(self) -> tuple[TextArtifact, MediaArtifact, MediaArtifact]:
raise ValueError("Input must be a tuple of (text, image, mask) or a callable that returns such a tuple.")

@input.setter
def input(self, value: tuple[TextArtifact, MediaArtifact, MediaArtifact]) -> None:
def input(self, value: tuple[TextArtifact, ImageArtifact, ImageArtifact]) -> None:
self._input = value

@property
Expand All @@ -67,7 +67,7 @@ def image_generation_engine(self) -> OutpaintingImageGenerationEngine:
def image_generation_engine(self, value: OutpaintingImageGenerationEngine) -> None:
self._image_generation_engine = value

def run(self) -> MediaArtifact:
def run(self) -> ImageArtifact:
prompt_artifact = self.input[0]
image_artifact = self.input[1]
mask_artifact = self.input[2]
Expand Down
6 changes: 3 additions & 3 deletions griptape/tasks/prompt_image_generation_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

from typing import Callable

from attr import define, field, Factory
from attr import define, field

from griptape.engines import PromptImageGenerationEngine
from griptape.artifacts import MediaArtifact, TextArtifact
from griptape.artifacts import ImageArtifact, TextArtifact
from griptape.tasks import BaseImageGenerationTask, BaseTask
from griptape.utils import J2

Expand Down Expand Up @@ -60,7 +60,7 @@ def image_generation_engine(self) -> PromptImageGenerationEngine:
def image_generation_engine(self, value: PromptImageGenerationEngine) -> None:
self._image_generation_engine = value

def run(self) -> MediaArtifact:
def run(self) -> ImageArtifact:
image_artifact = self.image_generation_engine.run(
prompts=[self.input.to_text()], rulesets=self.all_rulesets, negative_rulesets=self.negative_rulesets
)
Expand Down
10 changes: 5 additions & 5 deletions griptape/tasks/variation_image_generation_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from attr import define, field

from griptape.engines import VariationImageGenerationEngine
from griptape.artifacts import MediaArtifact, TextArtifact
from griptape.artifacts import ImageArtifact, TextArtifact
from griptape.tasks import BaseImageGenerationTask, BaseTask
from griptape.utils import J2

Expand All @@ -29,12 +29,12 @@ class VariationImageGenerationTask(BaseImageGenerationTask):
_image_generation_engine: VariationImageGenerationEngine = field(
default=None, kw_only=True, alias="image_generation_engine"
)
_input: tuple[str | TextArtifact, MediaArtifact] | Callable[[BaseTask], tuple[TextArtifact, MediaArtifact]] = field(
_input: tuple[str | TextArtifact, ImageArtifact] | Callable[[BaseTask], tuple[TextArtifact, ImageArtifact]] = field(
default=None
)

@property
def input(self) -> tuple[TextArtifact, MediaArtifact]:
def input(self) -> tuple[TextArtifact, ImageArtifact]:
if isinstance(self._input, tuple):
if isinstance(self._input[0], TextArtifact):
input_text = self._input[0]
Expand All @@ -48,7 +48,7 @@ def input(self) -> tuple[TextArtifact, MediaArtifact]:
raise ValueError("Input must be a tuple of (text, image) or a callable that returns such a tuple.")

@input.setter
def input(self, value: tuple[TextArtifact, MediaArtifact]) -> None:
def input(self, value: tuple[TextArtifact, ImageArtifact]) -> None:
self._input = value

@property
Expand All @@ -66,7 +66,7 @@ def image_generation_engine(self) -> VariationImageGenerationEngine:
def image_generation_engine(self, value: VariationImageGenerationEngine) -> None:
self._image_generation_engine = value

def run(self) -> MediaArtifact:
def run(self) -> ImageArtifact:
prompt_artifact = self.input[0]
image_artifact = self.input[1]

Expand Down
1 change: 1 addition & 0 deletions tests/unit/config/test_amazon_bedrock_structure_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def test_to_dict(self, config):
"seed": None,
"type": "AmazonBedrockImageGenerationDriver",
},
"audio_generation_driver": {"type": "DummyAudioGenerationDriver"},
"image_query_driver": {
"type": "AmazonBedrockImageQueryDriver",
"model": "anthropic.claude-3-sonnet-20240229-v1:0",
Expand Down
1 change: 1 addition & 0 deletions tests/unit/config/test_anthropic_structure_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def test_to_dict(self, config):
"top_k": 250,
},
"image_generation_driver": {"type": "DummyImageGenerationDriver"},
"audio_generation_driver": {"type": "DummyAudioGenerationDriver"},
"image_query_driver": {
"type": "AnthropicImageQueryDriver",
"model": "claude-3-opus-20240229",
Expand Down
1 change: 1 addition & 0 deletions tests/unit/config/test_google_structure_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def test_to_dict(self, config):
"top_k": None,
},
"image_generation_driver": {"type": "DummyImageGenerationDriver"},
"audio_generation_driver": {"type": "DummyAudioGenerationDriver"},
"image_query_driver": {"type": "DummyImageQueryDriver"},
"embedding_driver": {
"type": "GoogleEmbeddingDriver",
Expand Down
1 change: 1 addition & 0 deletions tests/unit/config/test_openai_structure_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def test_to_dict(self, config):
"style": None,
"type": "OpenAiImageGenerationDriver",
},
"audio_generation_driver": {"type": "DummyAudioGenerationDriver"},
"image_query_driver": {
"api_version": None,
"base_url": None,
Expand Down
1 change: 1 addition & 0 deletions tests/unit/config/test_structure_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def test_to_dict(self, config):
"conversation_memory_driver": None,
"embedding_driver": {"type": "DummyEmbeddingDriver"},
"image_generation_driver": {"type": "DummyImageGenerationDriver"},
"audio_generation_driver": {"type": "DummyAudioGenerationDriver"},
"image_query_driver": {"type": "DummyImageQueryDriver"},
"vector_store_driver": {
"embedding_driver": {"type": "DummyEmbeddingDriver"},
Expand Down
12 changes: 6 additions & 6 deletions tests/unit/mixins/test_image_artifact_file_output_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
import pytest

from griptape.artifacts import ImageArtifact
from griptape.mixins import ImageArtifactFileOutputMixin
from griptape.mixins import MediaArtifactFileOutputMixin


class TestImageArtifactFileOutputMixin:
class TestMediaArtifactFileOutputMixin:
def test_no_output(self):
class Test(ImageArtifactFileOutputMixin):
class Test(MediaArtifactFileOutputMixin):
pass

assert Test().output_file is None
Expand All @@ -18,7 +18,7 @@ class Test(ImageArtifactFileOutputMixin):
def test_output_file(self):
artifact = ImageArtifact(name="test.png", value=b"test", height=1, width=1, format="png")

class Test(ImageArtifactFileOutputMixin):
class Test(MediaArtifactFileOutputMixin):
def run(self):
self._write_to_file(artifact)

Expand All @@ -33,7 +33,7 @@ def run(self):
def test_output_dir(self):
artifact = ImageArtifact(name="test.png", value=b"test", height=1, width=1, format="png")

class Test(ImageArtifactFileOutputMixin):
class Test(MediaArtifactFileOutputMixin):
def run(self):
self._write_to_file(artifact)

Expand All @@ -46,7 +46,7 @@ def run(self):
assert os.path.exists(os.path.join(outdir, artifact.name))

def test_output_file_and_dir(self):
class Test(ImageArtifactFileOutputMixin):
class Test(MediaArtifactFileOutputMixin):
pass

outfile = "test.txt"
Expand Down

0 comments on commit d05e557

Please sign in to comment.