Skip to content

Commit

Permalink
Use generics for loaders, remove instances of path
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Sep 5, 2024
1 parent 62b6e17 commit 6b3bc8b
Show file tree
Hide file tree
Showing 34 changed files with 79 additions and 137 deletions.
4 changes: 2 additions & 2 deletions MIGRATION.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,14 @@ TextArtifact("name: John\nAge: 30")
#### Before

```python
results = CsvLoader().load(Path("people.csv").read_text())
results = CsvLoader().load("people.csv")

print(results[0].value) # {"name": "John", "age": 30}
```

#### After
```python
results = CsvLoader().load(Path("people.csv").read_text())
results = CsvLoader().load("people.csv")

print(results[0].value) # name: John\nAge: 30
print(results[0].meta["row"]) # 0
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from pathlib import Path

from griptape.artifacts import TextArtifact
from griptape.drivers import (
HuggingFacePipelineImageGenerationDriver,
Expand All @@ -11,7 +9,7 @@
from griptape.tasks import VariationImageGenerationTask

prompt_artifact = TextArtifact("landscape photograph, verdant, countryside, 8k")
input_image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes())
input_image_artifact = ImageLoader().load("tests/resources/mountain.png")

image_variation_task = VariationImageGenerationTask(
input=(prompt_artifact, input_image_artifact),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from pathlib import Path

from griptape.artifacts import TextArtifact
from griptape.drivers import (
HuggingFacePipelineImageGenerationDriver,
Expand All @@ -11,7 +9,7 @@
from griptape.tasks import VariationImageGenerationTask

prompt_artifact = TextArtifact("landscape photograph, verdant, countryside, 8k")
control_image_artifact = ImageLoader().load(Path("canny_control_image.png").read_bytes())
control_image_artifact = ImageLoader().load("canny_control_image.png")

controlnet_task = VariationImageGenerationTask(
input=(prompt_artifact, control_image_artifact),
Expand Down
4 changes: 1 addition & 3 deletions docs/griptape-framework/drivers/src/image_query_drivers_1.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from pathlib import Path

from griptape.drivers import AnthropicImageQueryDriver
from griptape.engines import ImageQueryEngine
from griptape.loaders import ImageLoader
Expand All @@ -13,6 +11,6 @@
image_query_driver=driver,
)

image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes())
image_artifact = ImageLoader().load("tests/resources/mountain.png")

engine.run("Describe the weather in the image", [image_artifact])
6 changes: 2 additions & 4 deletions docs/griptape-framework/drivers/src/image_query_drivers_2.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from pathlib import Path

from griptape.drivers import AnthropicImageQueryDriver
from griptape.engines import ImageQueryEngine
from griptape.loaders import ImageLoader
Expand All @@ -13,9 +11,9 @@
image_query_driver=driver,
)

image_artifact1 = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes())
image_artifact1 = ImageLoader().load("tests/resources/mountain.png")

image_artifact2 = ImageLoader().load(Path("tests/resources/cow.png").read_bytes())
image_artifact2 = ImageLoader().load("tests/resources/cow.png")

result = engine.run("Describe the weather in the image", [image_artifact1, image_artifact2])

Expand Down
4 changes: 1 addition & 3 deletions docs/griptape-framework/drivers/src/image_query_drivers_3.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from pathlib import Path

from griptape.drivers import OpenAiImageQueryDriver
from griptape.engines import ImageQueryEngine
from griptape.loaders import ImageLoader
Expand All @@ -13,6 +11,6 @@
image_query_driver=driver,
)

image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes())
image_artifact = ImageLoader().load("tests/resources/mountain.png")

engine.run("Describe the weather in the image", [image_artifact])
3 changes: 1 addition & 2 deletions docs/griptape-framework/drivers/src/image_query_drivers_4.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
from pathlib import Path

from griptape.drivers import AzureOpenAiImageQueryDriver
from griptape.engines import ImageQueryEngine
Expand All @@ -17,6 +16,6 @@
image_query_driver=driver,
)

image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes())
image_artifact = ImageLoader().load("tests/resources/mountain.png")

engine.run("Describe the weather in the image", [image_artifact])
4 changes: 1 addition & 3 deletions docs/griptape-framework/drivers/src/image_query_drivers_5.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from pathlib import Path

import boto3

from griptape.drivers import AmazonBedrockImageQueryDriver, BedrockClaudeImageQueryModelDriver
Expand All @@ -16,7 +14,7 @@

engine = ImageQueryEngine(image_query_driver=driver)

image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes())
image_artifact = ImageLoader().load("tests/resources/mountain.png")


result = engine.run("Describe the weather in the image", [image_artifact])
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from pathlib import Path

from griptape.drivers import AmazonBedrockImageGenerationDriver, BedrockStableDiffusionImageGenerationModelDriver
from griptape.engines import VariationImageGenerationEngine
from griptape.loaders import ImageLoader
Expand All @@ -15,7 +13,7 @@
image_generation_driver=driver,
)

image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes())
image_artifact = ImageLoader().load("tests/resources/mountain.png")

engine.run(
prompts=["A photo of a mountain landscape in winter"],
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from pathlib import Path

from griptape.drivers import AmazonBedrockImageGenerationDriver, BedrockStableDiffusionImageGenerationModelDriver
from griptape.engines import InpaintingImageGenerationEngine
from griptape.loaders import ImageLoader
Expand All @@ -15,9 +13,9 @@
image_generation_driver=driver,
)

image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes())
image_artifact = ImageLoader().load("tests/resources/mountain.png")

mask_artifact = ImageLoader().load(Path("tests/resources/mountain-mask.png").read_bytes())
mask_artifact = ImageLoader().load("tests/resources/mountain-mask.png")

engine.run(
prompts=["A photo of a castle built into the side of a mountain"],
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from pathlib import Path

from griptape.drivers import AmazonBedrockImageGenerationDriver, BedrockStableDiffusionImageGenerationModelDriver
from griptape.engines import OutpaintingImageGenerationEngine
from griptape.loaders import ImageLoader
Expand All @@ -15,9 +13,9 @@
image_generation_driver=driver,
)

image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes())
image_artifact = ImageLoader().load("tests/resources/mountain.png")

mask_artifact = ImageLoader().load(Path("tests/resources/mountain-mask.png").read_bytes())
mask_artifact = ImageLoader().load("tests/resources/mountain-mask.png")

engine.run(
prompts=["A photo of a mountain shrouded in clouds"],
Expand Down
4 changes: 1 addition & 3 deletions docs/griptape-framework/engines/src/image_query_engines_1.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from pathlib import Path

from griptape.drivers import OpenAiImageQueryDriver
from griptape.engines import ImageQueryEngine
from griptape.loaders import ImageLoader
Expand All @@ -8,6 +6,6 @@

engine = ImageQueryEngine(image_query_driver=driver)

image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes())
image_artifact = ImageLoader().load("tests/resources/mountain.png")

engine.run("Describe the weather in the image", [image_artifact])
4 changes: 1 addition & 3 deletions docs/griptape-framework/structures/src/tasks_12.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from pathlib import Path

from griptape.drivers import AmazonBedrockImageGenerationDriver, BedrockStableDiffusionImageGenerationModelDriver
from griptape.engines import VariationImageGenerationEngine
from griptape.loaders import ImageLoader
Expand All @@ -18,7 +16,7 @@
)

# Load input image artifact.
image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes())
image_artifact = ImageLoader().load("tests/resources/mountain.png")

# Instantiate a pipeline.
pipeline = Pipeline()
Expand Down
6 changes: 2 additions & 4 deletions docs/griptape-framework/structures/src/tasks_13.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from pathlib import Path

from griptape.drivers import AmazonBedrockImageGenerationDriver, BedrockStableDiffusionImageGenerationModelDriver
from griptape.engines import InpaintingImageGenerationEngine
from griptape.loaders import ImageLoader
Expand All @@ -18,9 +16,9 @@
)

# Load input image artifacts.
image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes())
image_artifact = ImageLoader().load("tests/resources/mountain.png")

mask_artifact = ImageLoader().load(Path("tests/resources/mountain-mask.png").read_bytes())
mask_artifact = ImageLoader().load("tests/resources/mountain-mask.png")

# Instantiate a pipeline.
pipeline = Pipeline()
Expand Down
6 changes: 2 additions & 4 deletions docs/griptape-framework/structures/src/tasks_14.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from pathlib import Path

from griptape.drivers import AmazonBedrockImageGenerationDriver, BedrockStableDiffusionImageGenerationModelDriver
from griptape.engines import OutpaintingImageGenerationEngine
from griptape.loaders import ImageLoader
Expand All @@ -18,9 +16,9 @@
)

# Load input image artifacts.
image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes())
image_artifact = ImageLoader().load("tests/resources/mountain.png")

mask_artifact = ImageLoader().load(Path("tests/resources/mountain-mask.png").read_bytes())
mask_artifact = ImageLoader().load("tests/resources/mountain-mask.png")

# Instantiate a pipeline.
pipeline = Pipeline()
Expand Down
4 changes: 1 addition & 3 deletions docs/griptape-framework/structures/src/tasks_15.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from pathlib import Path

from griptape.drivers import OpenAiImageQueryDriver
from griptape.engines import ImageQueryEngine
from griptape.loaders import ImageLoader
Expand All @@ -18,7 +16,7 @@
)

# Load the input image artifact.
image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes())
image_artifact = ImageLoader().load("tests/resources/mountain.png")

# Instantiate a pipeline.
pipeline = Pipeline()
Expand Down
4 changes: 1 addition & 3 deletions docs/griptape-framework/structures/src/tasks_3.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from pathlib import Path

from griptape.loaders import ImageLoader
from griptape.structures import Agent

agent = Agent()

image_artifact = ImageLoader().load(Path("tests/resources/mountain.jpg").read_bytes())
image_artifact = ImageLoader().load("tests/resources/mountain.jpg")

agent.run([image_artifact, "What's in this image?"])
9 changes: 2 additions & 7 deletions griptape/loaders/audio_loader.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

from typing import Any, cast

import filetype
from attrs import define

Expand All @@ -10,11 +8,8 @@


@define
class AudioLoader(BaseFileLoader):
class AudioLoader(BaseFileLoader[AudioArtifact]):
"""Loads audio content into audio artifacts."""

def load(self, source: Any, *args, **kwargs) -> AudioArtifact:
return cast(AudioArtifact, super().load(source, *args, **kwargs))

def parse(self, source: bytes, *args, **kwargs) -> AudioArtifact:
def parse(self, source: bytes) -> AudioArtifact:
return AudioArtifact(source, format=filetype.guess(source).extension)
17 changes: 11 additions & 6 deletions griptape/loaders/base_file_loader.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,29 @@
from __future__ import annotations

from abc import ABC
from typing import TYPE_CHECKING
from os import PathLike
from typing import TypeVar, Union

from attrs import Factory, define, field

from griptape.artifacts import BaseArtifact
from griptape.drivers import BaseFileManagerDriver, LocalFileManagerDriver
from griptape.loaders import BaseLoader

if TYPE_CHECKING:
from os import PathLike
A = TypeVar("A", bound=BaseArtifact)


@define
class BaseFileLoader(BaseLoader, ABC):
class BaseFileLoader(BaseLoader[Union[str, PathLike], bytes, A], ABC):
file_manager_driver: BaseFileManagerDriver = field(
default=Factory(lambda: LocalFileManagerDriver(workdir=None)),
kw_only=True,
)
encoding: str = field(default="utf-8", kw_only=True)

def fetch(self, source: str | PathLike, *args, **kwargs) -> str | bytes:
return self.file_manager_driver.load_file(str(source), *args, **kwargs).value
def fetch(self, source: str | PathLike) -> bytes:
data = self.file_manager_driver.load_file(str(source)).value
if isinstance(data, str):
return data.encode(self.encoding)
else:
return data
24 changes: 16 additions & 8 deletions griptape/loaders/base_loader.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,50 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar

from attrs import define, field

from griptape.artifacts import BaseArtifact
from griptape.mixins import FuturesExecutorMixin
from griptape.utils.futures import execute_futures_dict
from griptape.utils.hash import bytes_to_hash, str_to_hash

if TYPE_CHECKING:
from collections.abc import Mapping

from griptape.artifacts import BaseArtifact
from griptape.common import Reference

S = TypeVar("S")
F = TypeVar("F")
A = TypeVar("A", bound=BaseArtifact)


@define
class BaseLoader(FuturesExecutorMixin, ABC):
class BaseLoader(FuturesExecutorMixin, ABC, Generic[S, F, A]):
reference: Optional[Reference] = field(default=None, kw_only=True)

def load(self, source: Any, *args, **kwargs) -> BaseArtifact:
def load(self, source: S, *args, **kwargs) -> A:
data = self.fetch(source)

return self.parse(data)
artifact = self.parse(data)

artifact.reference = self.reference

return artifact

@abstractmethod
def fetch(self, source: Any, *args, **kwargs) -> Any: ...
def fetch(self, source: S) -> F: ...

@abstractmethod
def parse(self, source: Any, *args, **kwargs) -> BaseArtifact: ...
def parse(self, source: F) -> A: ...

def load_collection(
self,
sources: list[Any],
*args,
**kwargs,
) -> Mapping[str, BaseArtifact]:
) -> Mapping[str, A]:
# Create a dictionary before actually submitting the jobs to the executor
# to avoid duplicate work.
sources_by_key = {self.to_key(source): source for source in sources}
Expand Down
Loading

0 comments on commit 6b3bc8b

Please sign in to comment.