Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Sep 4, 2024
1 parent 1bbe4f0 commit a568189
Show file tree
Hide file tree
Showing 9 changed files with 37 additions and 9 deletions.
5 changes: 5 additions & 0 deletions griptape/loaders/audio_loader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import Any, cast

import filetype
from attrs import define

Expand All @@ -11,5 +13,8 @@
class AudioLoader(BaseFileLoader):
"""Loads audio content into audio artifacts."""

def load(self, source: Any, *args, **kwargs) -> AudioArtifact:
return cast(AudioArtifact, super().load(source, *args, **kwargs))

def parse(self, source: bytes, *args, **kwargs) -> AudioArtifact:
return AudioArtifact(source, format=filetype.guess(source).extension)
2 changes: 1 addition & 1 deletion griptape/loaders/base_file_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,5 @@ class BaseFileLoader(BaseLoader, ABC):
)
encoding: str = field(default="utf-8", kw_only=True)

def fetch(self, source: str | PathLike, *args, **kwargs) -> bytes:
def fetch(self, source: str | PathLike, *args, **kwargs) -> str | bytes:
return self.file_manager_driver.load_file(str(source), *args, **kwargs).value
5 changes: 5 additions & 0 deletions griptape/loaders/blob_loader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import Any, cast

from attrs import define

from griptape.artifacts import BlobArtifact
Expand All @@ -8,6 +10,9 @@

@define
class BlobLoader(BaseFileLoader):
def load(self, source: Any, *args, **kwargs) -> BlobArtifact:
return cast(BlobArtifact, super().load(source, *args, **kwargs))

def parse(self, source: bytes, *args, **kwargs) -> BlobArtifact:
if self.encoding is None:
return BlobArtifact(source)
Expand Down
5 changes: 4 additions & 1 deletion griptape/loaders/email_loader.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import imaplib
from typing import Optional
from typing import Any, Optional, cast

from attrs import astuple, define, field

Expand Down Expand Up @@ -32,6 +32,9 @@ class EmailQuery:
username: str = field(kw_only=True)
password: str = field(kw_only=True)

def load(self, source: Any, *args, **kwargs) -> ListArtifact:
return cast(ListArtifact, super().load(source, *args, **kwargs))

def fetch(self, source: EmailQuery, *args, **kwargs) -> list[bytes]:
label, key, search_criteria, max_count = astuple(source)

Expand Down
5 changes: 4 additions & 1 deletion griptape/loaders/image_loader.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from io import BytesIO
from typing import Optional
from typing import Any, Optional, cast

from attrs import define, field

Expand All @@ -22,6 +22,9 @@ class ImageLoader(BaseFileLoader):

format: Optional[str] = field(default=None, kw_only=True)

def load(self, source: Any, *args, **kwargs) -> ImageArtifact:
return cast(ImageArtifact, super().load(source, *args, **kwargs))

def parse(self, source: bytes, *args, **kwargs) -> ImageArtifact:
pil_image = import_optional_dependency("PIL.Image")
image = pil_image.open(BytesIO(source))
Expand Down
5 changes: 4 additions & 1 deletion griptape/loaders/pdf_loader.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from io import BytesIO
from typing import Optional
from typing import Any, Optional, cast

from attrs import define

Expand All @@ -12,6 +12,9 @@

@define
class PdfLoader(BaseFileLoader):
def load(self, source: Any, *args, **kwargs) -> TextArtifact:
return cast(TextArtifact, super().load(source, *args, **kwargs))

def parse(
self,
source: bytes,
Expand Down
5 changes: 4 additions & 1 deletion griptape/loaders/sql_loader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, cast

from attrs import define, field

Expand All @@ -15,6 +15,9 @@
class SqlLoader(BaseLoader):
sql_driver: BaseSqlDriver = field(kw_only=True)

def load(self, source: Any, *args, **kwargs) -> TableArtifact:
return cast(TableArtifact, super().load(source, *args, **kwargs))

def fetch(self, source: str, *args, **kwargs) -> list[BaseSqlDriver.RowResult]:
return self.sql_driver.execute_query(source) or []

Expand Down
5 changes: 5 additions & 0 deletions griptape/loaders/text_loader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import Any, cast

from attrs import define, field

from griptape.artifacts import TextArtifact
Expand All @@ -10,6 +12,9 @@
class TextLoader(BaseFileLoader):
encoding: str = field(default="utf-8", kw_only=True)

def load(self, source: Any, *args, **kwargs) -> TextArtifact:
return cast(TextArtifact, super().load(source, *args, **kwargs))

def parse(self, source: str | bytes, *args, **kwargs) -> TextArtifact:
if isinstance(source, str):
return TextArtifact(source, encoding=self.encoding)
Expand Down
9 changes: 5 additions & 4 deletions griptape/loaders/web_loader.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Any, cast

from attrs import Factory, define, field

from griptape.artifacts import TextArtifact
from griptape.drivers import BaseWebScraperDriver, TrafilaturaWebScraperDriver
from griptape.loaders import BaseLoader

if TYPE_CHECKING:
from griptape.artifacts import TextArtifact


@define
class WebLoader(BaseLoader):
Expand All @@ -18,6 +16,9 @@ class WebLoader(BaseLoader):
kw_only=True,
)

def load(self, source: Any, *args, **kwargs) -> TextArtifact:
return cast(TextArtifact, super().load(source, *args, **kwargs))

def fetch(self, source: str, *args, **kwargs) -> str:
return self.web_scraper_driver.fetch_url(source)

Expand Down

0 comments on commit a568189

Please sign in to comment.