diff --git a/MIGRATION.md b/MIGRATION.md index 39932d4fac..27a3ae76f6 100644 --- a/MIGRATION.md +++ b/MIGRATION.md @@ -77,14 +77,14 @@ TextArtifact("name: John\nAge: 30") #### Before ```python -results = CsvLoader().load(Path("people.csv").read_text()) +results = CsvLoader().load("people.csv") print(results[0].value) # {"name": "John", "age": 30} ``` #### After ```python -results = CsvLoader().load(Path("people.csv").read_text()) +results = CsvLoader().load("people.csv") print(results[0].value) # name: John\nAge: 30 print(results[0].meta["row"]) # 0 diff --git a/docs/griptape-framework/drivers/src/image_generation_drivers_8.py b/docs/griptape-framework/drivers/src/image_generation_drivers_8.py index 69437a3a5d..470b477074 100644 --- a/docs/griptape-framework/drivers/src/image_generation_drivers_8.py +++ b/docs/griptape-framework/drivers/src/image_generation_drivers_8.py @@ -1,5 +1,3 @@ -from pathlib import Path - from griptape.artifacts import TextArtifact from griptape.drivers import ( HuggingFacePipelineImageGenerationDriver, @@ -11,7 +9,7 @@ from griptape.tasks import VariationImageGenerationTask prompt_artifact = TextArtifact("landscape photograph, verdant, countryside, 8k") -input_image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +input_image_artifact = ImageLoader().load("tests/resources/mountain.png") image_variation_task = VariationImageGenerationTask( input=(prompt_artifact, input_image_artifact), diff --git a/docs/griptape-framework/drivers/src/image_generation_drivers_9.py b/docs/griptape-framework/drivers/src/image_generation_drivers_9.py index 2054588d93..ab3dc31138 100644 --- a/docs/griptape-framework/drivers/src/image_generation_drivers_9.py +++ b/docs/griptape-framework/drivers/src/image_generation_drivers_9.py @@ -1,5 +1,3 @@ -from pathlib import Path - from griptape.artifacts import TextArtifact from griptape.drivers import ( HuggingFacePipelineImageGenerationDriver, @@ -11,7 +9,7 @@ from griptape.tasks import VariationImageGenerationTask prompt_artifact = TextArtifact("landscape photograph, verdant, countryside, 8k") -control_image_artifact = ImageLoader().load(Path("canny_control_image.png").read_bytes()) +control_image_artifact = ImageLoader().load("canny_control_image.png") controlnet_task = VariationImageGenerationTask( input=(prompt_artifact, control_image_artifact), diff --git a/docs/griptape-framework/drivers/src/image_query_drivers_1.py b/docs/griptape-framework/drivers/src/image_query_drivers_1.py index 0c9db5be7b..0e0165d97e 100644 --- a/docs/griptape-framework/drivers/src/image_query_drivers_1.py +++ b/docs/griptape-framework/drivers/src/image_query_drivers_1.py @@ -1,5 +1,3 @@ -from pathlib import Path - from griptape.drivers import AnthropicImageQueryDriver from griptape.engines import ImageQueryEngine from griptape.loaders import ImageLoader @@ -13,6 +11,6 @@ image_query_driver=driver, ) -image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +image_artifact = ImageLoader().load("tests/resources/mountain.png") engine.run("Describe the weather in the image", [image_artifact]) diff --git a/docs/griptape-framework/drivers/src/image_query_drivers_2.py b/docs/griptape-framework/drivers/src/image_query_drivers_2.py index 8d605c0d97..4b5b3cc9f9 100644 --- a/docs/griptape-framework/drivers/src/image_query_drivers_2.py +++ b/docs/griptape-framework/drivers/src/image_query_drivers_2.py @@ -1,5 +1,3 @@ -from pathlib import Path - from griptape.drivers import AnthropicImageQueryDriver from griptape.engines import ImageQueryEngine from griptape.loaders import ImageLoader @@ -13,9 +11,9 @@ image_query_driver=driver, ) -image_artifact1 = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +image_artifact1 = ImageLoader().load("tests/resources/mountain.png") -image_artifact2 = ImageLoader().load(Path("tests/resources/cow.png").read_bytes()) +image_artifact2 = ImageLoader().load("tests/resources/cow.png") result = engine.run("Describe the weather in the image", [image_artifact1, image_artifact2]) diff --git a/docs/griptape-framework/drivers/src/image_query_drivers_3.py b/docs/griptape-framework/drivers/src/image_query_drivers_3.py index 14070312b3..0653d3f6e7 100644 --- a/docs/griptape-framework/drivers/src/image_query_drivers_3.py +++ b/docs/griptape-framework/drivers/src/image_query_drivers_3.py @@ -1,5 +1,3 @@ -from pathlib import Path - from griptape.drivers import OpenAiImageQueryDriver from griptape.engines import ImageQueryEngine from griptape.loaders import ImageLoader @@ -13,6 +11,6 @@ image_query_driver=driver, ) -image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +image_artifact = ImageLoader().load("tests/resources/mountain.png") engine.run("Describe the weather in the image", [image_artifact]) diff --git a/docs/griptape-framework/drivers/src/image_query_drivers_4.py b/docs/griptape-framework/drivers/src/image_query_drivers_4.py index 9ebf5ef597..cff4c2a10c 100644 --- a/docs/griptape-framework/drivers/src/image_query_drivers_4.py +++ b/docs/griptape-framework/drivers/src/image_query_drivers_4.py @@ -1,5 +1,4 @@ import os -from pathlib import Path from griptape.drivers import AzureOpenAiImageQueryDriver from griptape.engines import ImageQueryEngine @@ -17,6 +16,6 @@ image_query_driver=driver, ) -image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +image_artifact = ImageLoader().load("tests/resources/mountain.png") engine.run("Describe the weather in the image", [image_artifact]) diff --git a/docs/griptape-framework/drivers/src/image_query_drivers_5.py b/docs/griptape-framework/drivers/src/image_query_drivers_5.py index 2bab9a7fd4..c364a24cca 100644 --- a/docs/griptape-framework/drivers/src/image_query_drivers_5.py +++ b/docs/griptape-framework/drivers/src/image_query_drivers_5.py @@ -1,5 +1,3 @@ -from pathlib import Path - import boto3 from griptape.drivers import AmazonBedrockImageQueryDriver, BedrockClaudeImageQueryModelDriver @@ -16,7 +14,7 @@ engine = ImageQueryEngine(image_query_driver=driver) -image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +image_artifact = ImageLoader().load("tests/resources/mountain.png") result = engine.run("Describe the weather in the image", [image_artifact]) diff --git a/docs/griptape-framework/engines/src/image_generation_engines_3.py b/docs/griptape-framework/engines/src/image_generation_engines_3.py index 83822b1bc1..4bcd976d4c 100644 --- a/docs/griptape-framework/engines/src/image_generation_engines_3.py +++ b/docs/griptape-framework/engines/src/image_generation_engines_3.py @@ -1,5 +1,3 @@ -from pathlib import Path - from griptape.drivers import AmazonBedrockImageGenerationDriver, BedrockStableDiffusionImageGenerationModelDriver from griptape.engines import VariationImageGenerationEngine from griptape.loaders import ImageLoader @@ -15,7 +13,7 @@ image_generation_driver=driver, ) -image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +image_artifact = ImageLoader().load("tests/resources/mountain.png") engine.run( prompts=["A photo of a mountain landscape in winter"], diff --git a/docs/griptape-framework/engines/src/image_generation_engines_4.py b/docs/griptape-framework/engines/src/image_generation_engines_4.py index c258e1cce6..e7b46b341d 100644 --- a/docs/griptape-framework/engines/src/image_generation_engines_4.py +++ b/docs/griptape-framework/engines/src/image_generation_engines_4.py @@ -1,5 +1,3 @@ -from pathlib import Path - from griptape.drivers import AmazonBedrockImageGenerationDriver, BedrockStableDiffusionImageGenerationModelDriver from griptape.engines import InpaintingImageGenerationEngine from griptape.loaders import ImageLoader @@ -15,9 +13,9 @@ image_generation_driver=driver, ) -image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +image_artifact = ImageLoader().load("tests/resources/mountain.png") -mask_artifact = ImageLoader().load(Path("tests/resources/mountain-mask.png").read_bytes()) +mask_artifact = ImageLoader().load("tests/resources/mountain-mask.png") engine.run( prompts=["A photo of a castle built into the side of a mountain"], diff --git a/docs/griptape-framework/engines/src/image_generation_engines_5.py b/docs/griptape-framework/engines/src/image_generation_engines_5.py index f91a48ec0e..526ebff504 100644 --- a/docs/griptape-framework/engines/src/image_generation_engines_5.py +++ b/docs/griptape-framework/engines/src/image_generation_engines_5.py @@ -1,5 +1,3 @@ -from pathlib import Path - from griptape.drivers import AmazonBedrockImageGenerationDriver, BedrockStableDiffusionImageGenerationModelDriver from griptape.engines import OutpaintingImageGenerationEngine from griptape.loaders import ImageLoader @@ -15,9 +13,9 @@ image_generation_driver=driver, ) -image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +image_artifact = ImageLoader().load("tests/resources/mountain.png") -mask_artifact = ImageLoader().load(Path("tests/resources/mountain-mask.png").read_bytes()) +mask_artifact = ImageLoader().load("tests/resources/mountain-mask.png") engine.run( prompts=["A photo of a mountain shrouded in clouds"], diff --git a/docs/griptape-framework/engines/src/image_query_engines_1.py b/docs/griptape-framework/engines/src/image_query_engines_1.py index b0920392ac..c2d08e9a96 100644 --- a/docs/griptape-framework/engines/src/image_query_engines_1.py +++ b/docs/griptape-framework/engines/src/image_query_engines_1.py @@ -1,5 +1,3 @@ -from pathlib import Path - from griptape.drivers import OpenAiImageQueryDriver from griptape.engines import ImageQueryEngine from griptape.loaders import ImageLoader @@ -8,6 +6,6 @@ engine = ImageQueryEngine(image_query_driver=driver) -image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +image_artifact = ImageLoader().load("tests/resources/mountain.png") engine.run("Describe the weather in the image", [image_artifact]) diff --git a/docs/griptape-framework/structures/src/tasks_12.py b/docs/griptape-framework/structures/src/tasks_12.py index 917b506071..1fdc99e1cb 100644 --- a/docs/griptape-framework/structures/src/tasks_12.py +++ b/docs/griptape-framework/structures/src/tasks_12.py @@ -1,5 +1,3 @@ -from pathlib import Path - from griptape.drivers import AmazonBedrockImageGenerationDriver, BedrockStableDiffusionImageGenerationModelDriver from griptape.engines import VariationImageGenerationEngine from griptape.loaders import ImageLoader @@ -18,7 +16,7 @@ ) # Load input image artifact. -image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +image_artifact = ImageLoader().load("tests/resources/mountain.png") # Instantiate a pipeline. pipeline = Pipeline() diff --git a/docs/griptape-framework/structures/src/tasks_13.py b/docs/griptape-framework/structures/src/tasks_13.py index d2aa459837..4b7616d949 100644 --- a/docs/griptape-framework/structures/src/tasks_13.py +++ b/docs/griptape-framework/structures/src/tasks_13.py @@ -1,5 +1,3 @@ -from pathlib import Path - from griptape.drivers import AmazonBedrockImageGenerationDriver, BedrockStableDiffusionImageGenerationModelDriver from griptape.engines import InpaintingImageGenerationEngine from griptape.loaders import ImageLoader @@ -18,9 +16,9 @@ ) # Load input image artifacts. -image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +image_artifact = ImageLoader().load("tests/resources/mountain.png") -mask_artifact = ImageLoader().load(Path("tests/resources/mountain-mask.png").read_bytes()) +mask_artifact = ImageLoader().load("tests/resources/mountain-mask.png") # Instantiate a pipeline. pipeline = Pipeline() diff --git a/docs/griptape-framework/structures/src/tasks_14.py b/docs/griptape-framework/structures/src/tasks_14.py index ec489096d7..d2e6ba2dd1 100644 --- a/docs/griptape-framework/structures/src/tasks_14.py +++ b/docs/griptape-framework/structures/src/tasks_14.py @@ -1,5 +1,3 @@ -from pathlib import Path - from griptape.drivers import AmazonBedrockImageGenerationDriver, BedrockStableDiffusionImageGenerationModelDriver from griptape.engines import OutpaintingImageGenerationEngine from griptape.loaders import ImageLoader @@ -18,9 +16,9 @@ ) # Load input image artifacts. -image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +image_artifact = ImageLoader().load("tests/resources/mountain.png") -mask_artifact = ImageLoader().load(Path("tests/resources/mountain-mask.png").read_bytes()) +mask_artifact = ImageLoader().load("tests/resources/mountain-mask.png") # Instantiate a pipeline. pipeline = Pipeline() diff --git a/docs/griptape-framework/structures/src/tasks_15.py b/docs/griptape-framework/structures/src/tasks_15.py index 0c60864f7a..802ac33979 100644 --- a/docs/griptape-framework/structures/src/tasks_15.py +++ b/docs/griptape-framework/structures/src/tasks_15.py @@ -1,5 +1,3 @@ -from pathlib import Path - from griptape.drivers import OpenAiImageQueryDriver from griptape.engines import ImageQueryEngine from griptape.loaders import ImageLoader @@ -18,7 +16,7 @@ ) # Load the input image artifact. -image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +image_artifact = ImageLoader().load("tests/resources/mountain.png") # Instantiate a pipeline. pipeline = Pipeline() diff --git a/docs/griptape-framework/structures/src/tasks_3.py b/docs/griptape-framework/structures/src/tasks_3.py index 6584049d07..cdfe894bd3 100644 --- a/docs/griptape-framework/structures/src/tasks_3.py +++ b/docs/griptape-framework/structures/src/tasks_3.py @@ -1,10 +1,8 @@ -from pathlib import Path - from griptape.loaders import ImageLoader from griptape.structures import Agent agent = Agent() -image_artifact = ImageLoader().load(Path("tests/resources/mountain.jpg").read_bytes()) +image_artifact = ImageLoader().load("tests/resources/mountain.jpg") agent.run([image_artifact, "What's in this image?"]) diff --git a/griptape/loaders/audio_loader.py b/griptape/loaders/audio_loader.py index 1a7caefce5..4e7d9c7d05 100644 --- a/griptape/loaders/audio_loader.py +++ b/griptape/loaders/audio_loader.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Any, cast - import filetype from attrs import define @@ -10,11 +8,8 @@ @define -class AudioLoader(BaseFileLoader): +class AudioLoader(BaseFileLoader[AudioArtifact]): """Loads audio content into audio artifacts.""" - def load(self, source: Any, *args, **kwargs) -> AudioArtifact: - return cast(AudioArtifact, super().load(source, *args, **kwargs)) - - def parse(self, source: bytes, *args, **kwargs) -> AudioArtifact: + def parse(self, source: bytes) -> AudioArtifact: return AudioArtifact(source, format=filetype.guess(source).extension) diff --git a/griptape/loaders/base_file_loader.py b/griptape/loaders/base_file_loader.py index 650dc303cd..36645ef8a8 100644 --- a/griptape/loaders/base_file_loader.py +++ b/griptape/loaders/base_file_loader.py @@ -1,24 +1,29 @@ from __future__ import annotations from abc import ABC -from typing import TYPE_CHECKING +from os import PathLike +from typing import TypeVar, Union from attrs import Factory, define, field +from griptape.artifacts import BaseArtifact from griptape.drivers import BaseFileManagerDriver, LocalFileManagerDriver from griptape.loaders import BaseLoader -if TYPE_CHECKING: - from os import PathLike +A = TypeVar("A", bound=BaseArtifact) @define -class BaseFileLoader(BaseLoader, ABC): +class BaseFileLoader(BaseLoader[Union[str, PathLike], bytes, A], ABC): file_manager_driver: BaseFileManagerDriver = field( default=Factory(lambda: LocalFileManagerDriver(workdir=None)), kw_only=True, ) encoding: str = field(default="utf-8", kw_only=True) - def fetch(self, source: str | PathLike, *args, **kwargs) -> str | bytes: - return self.file_manager_driver.load_file(str(source), *args, **kwargs).value + def fetch(self, source: str | PathLike) -> bytes: + data = self.file_manager_driver.load_file(str(source)).value + if isinstance(data, str): + return data.encode(self.encoding) + else: + return data diff --git a/griptape/loaders/base_loader.py b/griptape/loaders/base_loader.py index 7c16c8ded7..a1e06b6b86 100644 --- a/griptape/loaders/base_loader.py +++ b/griptape/loaders/base_loader.py @@ -1,10 +1,11 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar from attrs import define, field +from griptape.artifacts import BaseArtifact from griptape.mixins import FuturesExecutorMixin from griptape.utils.futures import execute_futures_dict from griptape.utils.hash import bytes_to_hash, str_to_hash @@ -12,31 +13,38 @@ if TYPE_CHECKING: from collections.abc import Mapping - from griptape.artifacts import BaseArtifact from griptape.common import Reference +S = TypeVar("S") +F = TypeVar("F") +A = TypeVar("A", bound=BaseArtifact) + @define -class BaseLoader(FuturesExecutorMixin, ABC): +class BaseLoader(FuturesExecutorMixin, ABC, Generic[S, F, A]): reference: Optional[Reference] = field(default=None, kw_only=True) - def load(self, source: Any, *args, **kwargs) -> BaseArtifact: + def load(self, source: S, *args, **kwargs) -> A: data = self.fetch(source) - return self.parse(data) + artifact = self.parse(data) + + artifact.reference = self.reference + + return artifact @abstractmethod - def fetch(self, source: Any, *args, **kwargs) -> Any: ... + def fetch(self, source: S) -> F: ... @abstractmethod - def parse(self, source: Any, *args, **kwargs) -> BaseArtifact: ... + def parse(self, source: F) -> A: ... def load_collection( self, sources: list[Any], *args, **kwargs, - ) -> Mapping[str, BaseArtifact]: + ) -> Mapping[str, A]: # Create a dictionary before actually submitting the jobs to the executor # to avoid duplicate work. sources_by_key = {self.to_key(source): source for source in sources} diff --git a/griptape/loaders/blob_loader.py b/griptape/loaders/blob_loader.py index dbc7e66bfb..b54c880b58 100644 --- a/griptape/loaders/blob_loader.py +++ b/griptape/loaders/blob_loader.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Any, cast - from attrs import define from griptape.artifacts import BlobArtifact @@ -9,10 +7,7 @@ @define -class BlobLoader(BaseFileLoader): - def load(self, source: Any, *args, **kwargs) -> BlobArtifact: - return cast(BlobArtifact, super().load(source, *args, **kwargs)) - +class BlobLoader(BaseFileLoader[BlobArtifact]): def parse(self, source: bytes, *args, **kwargs) -> BlobArtifact: if self.encoding is None: return BlobArtifact(source) diff --git a/griptape/loaders/csv_loader.py b/griptape/loaders/csv_loader.py index 222952401b..ad0382d6e3 100644 --- a/griptape/loaders/csv_loader.py +++ b/griptape/loaders/csv_loader.py @@ -10,11 +10,11 @@ @define -class CsvLoader(BaseFileLoader): +class CsvLoader(BaseFileLoader[ListArtifact]): delimiter: str = field(default=",", kw_only=True) encoding: str = field(default="utf-8", kw_only=True) - def parse(self, source: bytes, *args, **kwargs) -> ListArtifact: + def parse(self, source: bytes) -> ListArtifact: reader = csv.DictReader(StringIO(source.decode(self.encoding)), delimiter=self.delimiter) return ListArtifact([TextArtifact(row, meta={"row": row_num}) for row_num, row in enumerate(reader)]) diff --git a/griptape/loaders/email_loader.py b/griptape/loaders/email_loader.py index 0a3d3f600c..88cd8e967c 100644 --- a/griptape/loaders/email_loader.py +++ b/griptape/loaders/email_loader.py @@ -1,7 +1,7 @@ from __future__ import annotations import imaplib -from typing import Any, Optional, cast +from typing import Optional from attrs import astuple, define, field @@ -11,7 +11,7 @@ @define -class EmailLoader(BaseLoader): +class EmailLoader(BaseLoader["EmailLoader.EmailQuery", list[bytes], ListArtifact]): # pyright: ignore[reportGeneralTypeIssues] @define(frozen=True) class EmailQuery: """An email retrieval query. @@ -32,10 +32,7 @@ class EmailQuery: username: str = field(kw_only=True) password: str = field(kw_only=True) - def load(self, source: Any, *args, **kwargs) -> ListArtifact: - return cast(ListArtifact, super().load(source, *args, **kwargs)) - - def fetch(self, source: EmailQuery, *args, **kwargs) -> list[bytes]: + def fetch(self, source: EmailLoader.EmailQuery) -> list[bytes]: label, key, search_criteria, max_count = astuple(source) mail_bytes = [] @@ -67,7 +64,7 @@ def fetch(self, source: EmailQuery, *args, **kwargs) -> list[bytes]: return mail_bytes - def parse(self, source: list[bytes], *args, **kwargs) -> ListArtifact: + def parse(self, source: list[bytes]) -> ListArtifact: mailparser = import_optional_dependency("mailparser") artifacts = [] for byte in source: diff --git a/griptape/loaders/image_loader.py b/griptape/loaders/image_loader.py index 513910c78c..d5f4ca7427 100644 --- a/griptape/loaders/image_loader.py +++ b/griptape/loaders/image_loader.py @@ -1,7 +1,7 @@ from __future__ import annotations from io import BytesIO -from typing import Any, Optional, cast +from typing import Optional from attrs import define, field @@ -11,7 +11,7 @@ @define -class ImageLoader(BaseFileLoader): +class ImageLoader(BaseFileLoader[ImageArtifact]): """Loads images into image artifacts. Attributes: @@ -22,10 +22,7 @@ class ImageLoader(BaseFileLoader): format: Optional[str] = field(default=None, kw_only=True) - def load(self, source: Any, *args, **kwargs) -> ImageArtifact: - return cast(ImageArtifact, super().load(source, *args, **kwargs)) - - def parse(self, source: bytes, *args, **kwargs) -> ImageArtifact: + def parse(self, source: bytes) -> ImageArtifact: pil_image = import_optional_dependency("PIL.Image") image = pil_image.open(BytesIO(source)) diff --git a/griptape/loaders/pdf_loader.py b/griptape/loaders/pdf_loader.py index b1f9013864..9be195e856 100644 --- a/griptape/loaders/pdf_loader.py +++ b/griptape/loaders/pdf_loader.py @@ -19,8 +19,6 @@ def parse( self, source: bytes, password: Optional[str] = None, - *args, - **kwargs, ) -> ListArtifact: pypdf = import_optional_dependency("pypdf") reader = pypdf.PdfReader(BytesIO(source), strict=True, password=password) diff --git a/griptape/loaders/sql_loader.py b/griptape/loaders/sql_loader.py index d633eede60..594bf06e0c 100644 --- a/griptape/loaders/sql_loader.py +++ b/griptape/loaders/sql_loader.py @@ -1,25 +1,18 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, cast - from attrs import define, field from griptape.artifacts import ListArtifact, TextArtifact +from griptape.drivers import BaseSqlDriver from griptape.loaders import BaseLoader -if TYPE_CHECKING: - from griptape.drivers import BaseSqlDriver - @define -class SqlLoader(BaseLoader): +class SqlLoader(BaseLoader[str, list[BaseSqlDriver.RowResult], ListArtifact]): sql_driver: BaseSqlDriver = field(kw_only=True) - def load(self, source: Any, *args, **kwargs) -> ListArtifact: - return cast(ListArtifact, super().load(source, *args, **kwargs)) - - def fetch(self, source: str, *args, **kwargs) -> list[BaseSqlDriver.RowResult]: + def fetch(self, source: str) -> list[BaseSqlDriver.RowResult]: return self.sql_driver.execute_query(source) or [] - def parse(self, source: list[BaseSqlDriver.RowResult], *args, **kwargs) -> ListArtifact: + def parse(self, source: list[BaseSqlDriver.RowResult]) -> ListArtifact: return ListArtifact([TextArtifact(row.cells, meta={"row": row_num}) for row_num, row in enumerate(source)]) diff --git a/griptape/loaders/text_loader.py b/griptape/loaders/text_loader.py index 68dc324aef..7ef649bbc3 100644 --- a/griptape/loaders/text_loader.py +++ b/griptape/loaders/text_loader.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Any, cast - from attrs import define, field from griptape.artifacts import TextArtifact @@ -9,13 +7,10 @@ @define -class TextLoader(BaseFileLoader): +class TextLoader(BaseFileLoader[TextArtifact]): encoding: str = field(default="utf-8", kw_only=True) - def load(self, source: Any, *args, **kwargs) -> TextArtifact: - return cast(TextArtifact, super().load(source, *args, **kwargs)) - - def parse(self, source: str | bytes, *args, **kwargs) -> TextArtifact: + def parse(self, source: str | bytes) -> TextArtifact: if isinstance(source, str): return TextArtifact(source, encoding=self.encoding) else: diff --git a/griptape/loaders/web_loader.py b/griptape/loaders/web_loader.py index d31a348167..de4c32778e 100644 --- a/griptape/loaders/web_loader.py +++ b/griptape/loaders/web_loader.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Any, cast - from attrs import Factory, define, field from griptape.artifacts import TextArtifact @@ -10,17 +8,14 @@ @define -class WebLoader(BaseLoader): +class WebLoader(BaseLoader[str, str, TextArtifact]): web_scraper_driver: BaseWebScraperDriver = field( default=Factory(lambda: TrafilaturaWebScraperDriver()), kw_only=True, ) - def load(self, source: Any, *args, **kwargs) -> TextArtifact: - return cast(TextArtifact, super().load(source, *args, **kwargs)) - - def fetch(self, source: str, *args, **kwargs) -> str: + def fetch(self, source: str) -> str: return self.web_scraper_driver.fetch_url(source) - def parse(self, source: str, *args, **kwargs) -> TextArtifact: + def parse(self, source: str) -> TextArtifact: return self.web_scraper_driver.extract_page(source) diff --git a/griptape/tasks/base_image_generation_task.py b/griptape/tasks/base_image_generation_task.py index 43b10a7c25..67028b00c3 100644 --- a/griptape/tasks/base_image_generation_task.py +++ b/griptape/tasks/base_image_generation_task.py @@ -65,4 +65,4 @@ def all_negative_rulesets(self) -> list[Ruleset]: def _read_from_file(self, path: str) -> ImageArtifact: logger.info("Reading image from %s", os.path.abspath(path)) - return ImageLoader().load(Path(path).read_bytes()) + return ImageLoader().load(Path(path)) diff --git a/griptape/tools/audio_transcription/tool.py b/griptape/tools/audio_transcription/tool.py index 4174db2090..826aeb895b 100644 --- a/griptape/tools/audio_transcription/tool.py +++ b/griptape/tools/audio_transcription/tool.py @@ -1,6 +1,5 @@ from __future__ import annotations -from pathlib import Path from typing import TYPE_CHECKING, Any, cast from attrs import Factory, define, field @@ -32,7 +31,7 @@ class AudioTranscriptionTool(BaseTool): def transcribe_audio_from_disk(self, params: dict) -> TextArtifact | ErrorArtifact: audio_path = params["values"]["path"] - audio_artifact = self.audio_loader.load(Path(audio_path).read_bytes()) + audio_artifact = self.audio_loader.load(audio_path) return self.engine.run(audio_artifact) diff --git a/griptape/tools/image_query/tool.py b/griptape/tools/image_query/tool.py index 9d1dbb89b5..7b654bd722 100644 --- a/griptape/tools/image_query/tool.py +++ b/griptape/tools/image_query/tool.py @@ -1,6 +1,5 @@ from __future__ import annotations -from pathlib import Path from typing import TYPE_CHECKING, Any, cast from attrs import Factory, define, field @@ -41,7 +40,7 @@ def query_image_from_disk(self, params: dict) -> TextArtifact | ErrorArtifact: image_artifacts = [] for image_path in image_paths: - image_artifacts.append(self.image_loader.load(Path(image_path).read_bytes())) + image_artifacts.append(self.image_loader.load(image_path)) return self.image_query_engine.run(query, image_artifacts) diff --git a/griptape/tools/inpainting_image_generation/tool.py b/griptape/tools/inpainting_image_generation/tool.py index d32f481d9d..b529cb637a 100644 --- a/griptape/tools/inpainting_image_generation/tool.py +++ b/griptape/tools/inpainting_image_generation/tool.py @@ -1,6 +1,5 @@ from __future__ import annotations -from pathlib import Path from typing import TYPE_CHECKING, cast from attrs import define, field @@ -51,8 +50,8 @@ def image_inpainting_from_file(self, params: dict[str, dict[str, str]]) -> Image image_file = params["values"]["image_file"] mask_file = params["values"]["mask_file"] - input_artifact = self.image_loader.load(Path(image_file).read_bytes()) - mask_artifact = self.image_loader.load(Path(mask_file).read_bytes()) + input_artifact = self.image_loader.load(image_file) + mask_artifact = self.image_loader.load(mask_file) return self._generate_inpainting( prompt, negative_prompt, cast(ImageArtifact, input_artifact), cast(ImageArtifact, mask_artifact) diff --git a/griptape/tools/outpainting_image_generation/tool.py b/griptape/tools/outpainting_image_generation/tool.py index afa39e178c..47863b03dd 100644 --- a/griptape/tools/outpainting_image_generation/tool.py +++ b/griptape/tools/outpainting_image_generation/tool.py @@ -1,6 +1,5 @@ from __future__ import annotations -from pathlib import Path from typing import TYPE_CHECKING, cast from attrs import define, field @@ -51,8 +50,8 @@ def image_outpainting_from_file(self, params: dict[str, dict[str, str]]) -> Imag image_file = params["values"]["image_file"] mask_file = params["values"]["mask_file"] - input_artifact = self.image_loader.load(Path(image_file).read_bytes()) - mask_artifact = self.image_loader.load(Path(mask_file).read_bytes()) + input_artifact = self.image_loader.load(image_file) + mask_artifact = self.image_loader.load(mask_file) return self._generate_outpainting(prompt, negative_prompt, input_artifact, mask_artifact) diff --git a/griptape/tools/variation_image_generation/tool.py b/griptape/tools/variation_image_generation/tool.py index 0d4456c2fb..1fb8c8bcc1 100644 --- a/griptape/tools/variation_image_generation/tool.py +++ b/griptape/tools/variation_image_generation/tool.py @@ -1,6 +1,5 @@ from __future__ import annotations -from pathlib import Path from typing import TYPE_CHECKING, cast from attrs import define, field @@ -49,7 +48,7 @@ def image_variation_from_file(self, params: dict[str, dict[str, str]]) -> ImageA negative_prompt = params["values"]["negative_prompt"] image_file = params["values"]["image_file"] - image_artifact = self.image_loader.load(Path(image_file).read_bytes()) + image_artifact = self.image_loader.load(image_file) return self._generate_variation(prompt, negative_prompt, image_artifact)