-
Notifications
You must be signed in to change notification settings - Fork 184
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
370a5ac
commit b8cee74
Showing
23 changed files
with
169 additions
and
260 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from __future__ import annotations | ||
|
||
from abc import ABC | ||
from io import BytesIO | ||
from os import PathLike | ||
from pathlib import Path | ||
from typing import Optional | ||
|
||
from attrs import define, field | ||
|
||
from griptape.loaders import BaseLoader | ||
|
||
|
||
@define | ||
class BaseFileLoader(BaseLoader, ABC): | ||
encoding: Optional[str] = field(default=None, kw_only=True) | ||
|
||
def fetch(self, source: str | BytesIO | PathLike, *args, **kwargs) -> bytes: | ||
if isinstance(source, (str, PathLike)): | ||
content = Path(source).read_bytes() | ||
elif isinstance(source, BytesIO): | ||
content = source.read() | ||
|
||
return content |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,15 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Any, cast | ||
|
||
from attrs import define | ||
|
||
from griptape.artifacts import BlobArtifact | ||
from griptape.loaders import BaseLoader | ||
from griptape.loaders import BaseFileLoader | ||
|
||
|
||
@define | ||
class BlobLoader(BaseLoader): | ||
def load(self, source: Any, *args, **kwargs) -> BlobArtifact: | ||
class BlobLoader(BaseFileLoader): | ||
def parse(self, source: bytes, *args, **kwargs) -> BlobArtifact: | ||
if self.encoding is None: | ||
return BlobArtifact(source) | ||
else: | ||
return BlobArtifact(source, encoding=self.encoding) | ||
|
||
def load_collection(self, sources: list[bytes | str], *args, **kwargs) -> dict[str, BlobArtifact]: | ||
return cast(dict[str, BlobArtifact], super().load_collection(sources, *args, **kwargs)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,46 +1,19 @@ | ||
from __future__ import annotations | ||
|
||
import csv | ||
from io import StringIO | ||
from typing import TYPE_CHECKING, Optional, cast | ||
|
||
from attrs import define, field | ||
|
||
from griptape.artifacts import TableArtifact | ||
from griptape.loaders import BaseLoader | ||
|
||
if TYPE_CHECKING: | ||
from griptape.drivers import BaseEmbeddingDriver | ||
from griptape.loaders.text_loader import TextLoader | ||
|
||
|
||
@define | ||
class CsvLoader(BaseLoader): | ||
embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) | ||
class CsvLoader(TextLoader): | ||
delimiter: str = field(default=",", kw_only=True) | ||
encoding: str = field(default="utf-8", kw_only=True) | ||
|
||
def load(self, source: bytes | str, *args, **kwargs) -> TableArtifact: | ||
if isinstance(source, bytes): | ||
source = source.decode(encoding=self.encoding) | ||
elif isinstance(source, (bytearray, memoryview)): | ||
raise ValueError(f"Unsupported source type: {type(source)}") | ||
|
||
reader = csv.DictReader(StringIO(source), delimiter=self.delimiter) | ||
|
||
artifact = TableArtifact(list(reader), delimiter=self.delimiter, fieldnames=reader.fieldnames) | ||
|
||
if self.embedding_driver: | ||
artifact.generate_embedding(self.embedding_driver) | ||
|
||
return artifact | ||
def parse(self, source: bytes, *args, **kwargs) -> TableArtifact: | ||
reader = csv.DictReader(source.decode(self.encoding), delimiter=self.delimiter) | ||
|
||
def load_collection( | ||
self, | ||
sources: list[bytes | str], | ||
*args, | ||
**kwargs, | ||
) -> dict[str, TableArtifact]: | ||
return cast( | ||
dict[str, TableArtifact], | ||
super().load_collection(sources, *args, **kwargs), | ||
) | ||
return TableArtifact(list(reader), delimiter=self.delimiter, fieldnames=reader.fieldnames) |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from __future__ import annotations | ||
|
||
from abc import abstractmethod | ||
from typing import TYPE_CHECKING, Any, Optional | ||
|
||
from attrs import field | ||
|
||
from griptape.loaders.base_loader import BaseLoader | ||
from griptape.utils.futures import execute_futures_dict | ||
from griptape.utils.hash import bytes_to_hash, str_to_hash | ||
|
||
if TYPE_CHECKING: | ||
from collections.abc import Mapping | ||
|
||
from griptape.artifacts import BaseArtifact | ||
|
||
|
||
class FileLoader(BaseLoader): | ||
encoding: Optional[str] = field(default=None, kw_only=True) | ||
|
||
@abstractmethod | ||
def load(self, source: Any, *args, **kwargs) -> BaseArtifact: ... | ||
|
||
def load_collection( | ||
self, | ||
sources: list[Any], | ||
*args, | ||
**kwargs, | ||
) -> Mapping[str, BaseArtifact]: | ||
# Create a dictionary before actually submitting the jobs to the executor | ||
# to avoid duplicate work. | ||
sources_by_key = {self.to_key(source): source for source in sources} | ||
|
||
return execute_futures_dict( | ||
{ | ||
key: self.futures_executor.submit(self.load, source, *args, **kwargs) | ||
for key, source in sources_by_key.items() | ||
}, | ||
) | ||
|
||
def to_key(self, source: Any) -> str: | ||
if isinstance(source, bytes): | ||
return bytes_to_hash(source) | ||
else: | ||
return str_to_hash(str(source)) |
Oops, something went wrong.