Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Aug 30, 2024
1 parent 370a5ac commit b8cee74
Show file tree
Hide file tree
Showing 23 changed files with 169 additions and 260 deletions.
4 changes: 2 additions & 2 deletions griptape/artifacts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,17 @@
from .base_system_artifact import BaseSystemArtifact
from .error_artifact import ErrorArtifact
from .info_artifact import InfoArtifact
from .list_artifact import ListArtifact

from .text_artifact import TextArtifact
from .json_artifact import JsonArtifact
from .csv_row_artifact import CsvRowArtifact
from .table_artifact import TableArtifact

from .list_artifact import ListArtifact

from .blob_artifact import BlobArtifact

from .image_artifact import ImageArtifact

from .audio_artifact import AudioArtifact

from .action_artifact import ActionArtifact
Expand Down
12 changes: 7 additions & 5 deletions griptape/loaders/__init__.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,28 @@
from .base_loader import BaseLoader
from .base_text_loader import BaseTextLoader
from .base_file_loader import BaseFileLoader

from .text_loader import TextLoader
from .pdf_loader import PdfLoader
from .web_loader import WebLoader
from .sql_loader import SqlLoader
from .csv_loader import CsvLoader
from .dataframe_loader import DataFrameLoader
from .email_loader import EmailLoader

from .blob_loader import BlobLoader

from .image_loader import ImageLoader

from .audio_loader import AudioLoader
from .blob_loader import BlobLoader


__all__ = [
"BaseLoader",
"BaseTextLoader",
"BaseFileLoader",
"TextLoader",
"PdfLoader",
"WebLoader",
"SqlLoader",
"CsvLoader",
"DataFrameLoader",
"EmailLoader",
"ImageLoader",
"AudioLoader",
Expand Down
7 changes: 2 additions & 5 deletions griptape/loaders/audio_loader.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

from typing import cast

from attrs import define

from griptape.artifacts import AudioArtifact
Expand All @@ -14,7 +12,6 @@ class AudioLoader(BaseLoader):
"""Loads audio content into audio artifacts."""

def load(self, source: bytes, *args, **kwargs) -> AudioArtifact:
return AudioArtifact(source, format=import_optional_dependency("filetype").guess(source).extension)
filetype = import_optional_dependency("filetype")

def load_collection(self, sources: list[bytes], *args, **kwargs) -> dict[str, AudioArtifact]:
return cast(dict[str, AudioArtifact], super().load_collection(sources, *args, **kwargs))
return AudioArtifact(source, format=filetype.guess(source).extension)
24 changes: 24 additions & 0 deletions griptape/loaders/base_file_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from __future__ import annotations

from abc import ABC
from io import BytesIO
from os import PathLike
from pathlib import Path
from typing import Optional

from attrs import define, field

from griptape.loaders import BaseLoader


@define
class BaseFileLoader(BaseLoader, ABC):
encoding: Optional[str] = field(default=None, kw_only=True)

def fetch(self, source: str | BytesIO | PathLike, *args, **kwargs) -> bytes:
if isinstance(source, (str, PathLike)):
content = Path(source).read_bytes()
elif isinstance(source, BytesIO):
content = source.read()

return content
13 changes: 11 additions & 2 deletions griptape/loaders/base_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,23 @@
from collections.abc import Mapping

from griptape.artifacts import BaseArtifact
from griptape.common import Reference


@define
class BaseLoader(FuturesExecutorMixin, ABC):
encoding: Optional[str] = field(default=None, kw_only=True)
reference: Optional[Reference] = field(default=None, kw_only=True)

def load(self, source: Any, *args, **kwargs) -> BaseArtifact:
data = self.fetch(source)

return self.parse(data)

@abstractmethod
def fetch(self, source: Any, *args, **kwargs) -> bytes: ...

@abstractmethod
def load(self, source: Any, *args, **kwargs) -> BaseArtifact: ...
def parse(self, source: bytes, *args, **kwargs) -> BaseArtifact: ...

def load_collection(
self,
Expand Down
33 changes: 0 additions & 33 deletions griptape/loaders/base_text_loader.py

This file was deleted.

11 changes: 3 additions & 8 deletions griptape/loaders/blob_loader.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,15 @@
from __future__ import annotations

from typing import Any, cast

from attrs import define

from griptape.artifacts import BlobArtifact
from griptape.loaders import BaseLoader
from griptape.loaders import BaseFileLoader


@define
class BlobLoader(BaseLoader):
def load(self, source: Any, *args, **kwargs) -> BlobArtifact:
class BlobLoader(BaseFileLoader):
def parse(self, source: bytes, *args, **kwargs) -> BlobArtifact:
if self.encoding is None:
return BlobArtifact(source)
else:
return BlobArtifact(source, encoding=self.encoding)

def load_collection(self, sources: list[bytes | str], *args, **kwargs) -> dict[str, BlobArtifact]:
return cast(dict[str, BlobArtifact], super().load_collection(sources, *args, **kwargs))
37 changes: 5 additions & 32 deletions griptape/loaders/csv_loader.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,19 @@
from __future__ import annotations

import csv
from io import StringIO
from typing import TYPE_CHECKING, Optional, cast

from attrs import define, field

from griptape.artifacts import TableArtifact
from griptape.loaders import BaseLoader

if TYPE_CHECKING:
from griptape.drivers import BaseEmbeddingDriver
from griptape.loaders.text_loader import TextLoader


@define
class CsvLoader(BaseLoader):
embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True)
class CsvLoader(TextLoader):
delimiter: str = field(default=",", kw_only=True)
encoding: str = field(default="utf-8", kw_only=True)

def load(self, source: bytes | str, *args, **kwargs) -> TableArtifact:
if isinstance(source, bytes):
source = source.decode(encoding=self.encoding)
elif isinstance(source, (bytearray, memoryview)):
raise ValueError(f"Unsupported source type: {type(source)}")

reader = csv.DictReader(StringIO(source), delimiter=self.delimiter)

artifact = TableArtifact(list(reader), delimiter=self.delimiter, fieldnames=reader.fieldnames)

if self.embedding_driver:
artifact.generate_embedding(self.embedding_driver)

return artifact
def parse(self, source: bytes, *args, **kwargs) -> TableArtifact:
reader = csv.DictReader(source.decode(self.encoding), delimiter=self.delimiter)

def load_collection(
self,
sources: list[bytes | str],
*args,
**kwargs,
) -> dict[str, TableArtifact]:
return cast(
dict[str, TableArtifact],
super().load_collection(sources, *args, **kwargs),
)
return TableArtifact(list(reader), delimiter=self.delimiter, fieldnames=reader.fieldnames)
35 changes: 0 additions & 35 deletions griptape/loaders/dataframe_loader.py

This file was deleted.

32 changes: 18 additions & 14 deletions griptape/loaders/email_loader.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import imaplib
from typing import Optional, cast
from typing import Optional

from attrs import astuple, define, field

Expand Down Expand Up @@ -32,11 +32,10 @@ class EmailQuery:
username: str = field(kw_only=True)
password: str = field(kw_only=True)

def load(self, source: EmailQuery, *args, **kwargs) -> ListArtifact:
mailparser = import_optional_dependency("mailparser")
def fetch(self, source: EmailQuery, *args, **kwargs) -> bytes:
label, key, search_criteria, max_count = astuple(source)

artifacts = []
mail_bytes = []
with imaplib.IMAP4_SSL(self.imap_url) as client:
client.login(self.username, self.password)

Expand All @@ -59,19 +58,24 @@ def load(self, source: EmailQuery, *args, **kwargs) -> ListArtifact:
if data is None or not data or data[0] is None:
continue

message = mailparser.parse_from_bytes(data[0][1])

# Note: mailparser only populates the text_plain field
# if the message content type is explicitly set to 'text/plain'.
if message.text_plain:
artifacts.append(TextArtifact("\n".join(message.text_plain)))
mail_bytes.append(data[0][1])

client.close()

return ListArtifact(artifacts)
return bytes(mail_bytes)

def parse(self, source: bytes, *args, **kwargs) -> ListArtifact:
mailparser = import_optional_dependency("mailparser")
artifacts = []
for byte in source:
message = mailparser.parse_from_bytes(byte)

# Note: mailparser only populates the text_plain field
# if the message content type is explicitly set to 'text/plain'.
if message.text_plain:
artifacts.append(TextArtifact(message.text_plain))

return ListArtifact(artifacts)

def _count_messages(self, message_numbers: bytes) -> int:
return len(list(filter(None, message_numbers.decode().split(" "))))

def load_collection(self, sources: list[EmailQuery], *args, **kwargs) -> dict[str, ListArtifact]:
return cast(dict[str, ListArtifact], super().load_collection(sources, *args, **kwargs))
45 changes: 45 additions & 0 deletions griptape/loaders/file_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from __future__ import annotations

from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Optional

from attrs import field

from griptape.loaders.base_loader import BaseLoader
from griptape.utils.futures import execute_futures_dict
from griptape.utils.hash import bytes_to_hash, str_to_hash

if TYPE_CHECKING:
from collections.abc import Mapping

from griptape.artifacts import BaseArtifact


class FileLoader(BaseLoader):
encoding: Optional[str] = field(default=None, kw_only=True)

@abstractmethod
def load(self, source: Any, *args, **kwargs) -> BaseArtifact: ...

def load_collection(
self,
sources: list[Any],
*args,
**kwargs,
) -> Mapping[str, BaseArtifact]:
# Create a dictionary before actually submitting the jobs to the executor
# to avoid duplicate work.
sources_by_key = {self.to_key(source): source for source in sources}

return execute_futures_dict(
{
key: self.futures_executor.submit(self.load, source, *args, **kwargs)
for key, source in sources_by_key.items()
},
)

def to_key(self, source: Any) -> str:
if isinstance(source, bytes):
return bytes_to_hash(source)
else:
return str_to_hash(str(source))
Loading

0 comments on commit b8cee74

Please sign in to comment.