From b8cee741823744066a75c5ebf10fd82bc24a3abf Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 29 Aug 2024 09:05:59 -0700 Subject: [PATCH] WIP --- griptape/artifacts/__init__.py | 4 +-- griptape/loaders/__init__.py | 12 ++++--- griptape/loaders/audio_loader.py | 7 ++-- griptape/loaders/base_file_loader.py | 24 +++++++++++++ griptape/loaders/base_loader.py | 13 +++++-- griptape/loaders/base_text_loader.py | 33 ------------------ griptape/loaders/blob_loader.py | 11 ++---- griptape/loaders/csv_loader.py | 37 +++----------------- griptape/loaders/dataframe_loader.py | 35 ------------------- griptape/loaders/email_loader.py | 32 ++++++++++-------- griptape/loaders/file_loader.py | 45 +++++++++++++++++++++++++ griptape/loaders/image_loader.py | 26 ++------------ griptape/loaders/pdf_loader.py | 29 +++++----------- griptape/loaders/sql_loader.py | 11 ++---- griptape/loaders/text_loader.py | 26 +++----------- griptape/loaders/web_loader.py | 17 +++++----- pyproject.toml | 1 - tests/unit/loaders/test_audio_loader.py | 5 ++- tests/unit/loaders/test_blob_loader.py | 4 +-- tests/unit/loaders/test_email_loader.py | 10 +++--- tests/unit/loaders/test_image_loader.py | 6 ++-- tests/unit/loaders/test_pdf_loader.py | 21 +++++------- tests/unit/loaders/test_text_loader.py | 20 ++++------- 23 files changed, 169 insertions(+), 260 deletions(-) create mode 100644 griptape/loaders/base_file_loader.py delete mode 100644 griptape/loaders/base_text_loader.py delete mode 100644 griptape/loaders/dataframe_loader.py create mode 100644 griptape/loaders/file_loader.py diff --git a/griptape/artifacts/__init__.py b/griptape/artifacts/__init__.py index fae411d3a4..e266ce25c6 100644 --- a/griptape/artifacts/__init__.py +++ b/griptape/artifacts/__init__.py @@ -3,17 +3,17 @@ from .base_system_artifact import BaseSystemArtifact from .error_artifact import ErrorArtifact from .info_artifact import InfoArtifact +from .list_artifact import ListArtifact from .text_artifact import TextArtifact from .json_artifact import JsonArtifact from .csv_row_artifact import CsvRowArtifact from .table_artifact import TableArtifact -from .list_artifact import ListArtifact - from .blob_artifact import BlobArtifact from .image_artifact import ImageArtifact + from .audio_artifact import AudioArtifact from .action_artifact import ActionArtifact diff --git a/griptape/loaders/__init__.py b/griptape/loaders/__init__.py index b79b0ff448..b863706070 100644 --- a/griptape/loaders/__init__.py +++ b/griptape/loaders/__init__.py @@ -1,26 +1,28 @@ from .base_loader import BaseLoader -from .base_text_loader import BaseTextLoader +from .base_file_loader import BaseFileLoader + from .text_loader import TextLoader from .pdf_loader import PdfLoader from .web_loader import WebLoader from .sql_loader import SqlLoader from .csv_loader import CsvLoader -from .dataframe_loader import DataFrameLoader from .email_loader import EmailLoader + +from .blob_loader import BlobLoader + from .image_loader import ImageLoader + from .audio_loader import AudioLoader -from .blob_loader import BlobLoader __all__ = [ "BaseLoader", - "BaseTextLoader", + "BaseFileLoader", "TextLoader", "PdfLoader", "WebLoader", "SqlLoader", "CsvLoader", - "DataFrameLoader", "EmailLoader", "ImageLoader", "AudioLoader", diff --git a/griptape/loaders/audio_loader.py b/griptape/loaders/audio_loader.py index 84d6b767ae..befdca034b 100644 --- a/griptape/loaders/audio_loader.py +++ b/griptape/loaders/audio_loader.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import cast - from attrs import define from griptape.artifacts import AudioArtifact @@ -14,7 +12,6 @@ class AudioLoader(BaseLoader): """Loads audio content into audio artifacts.""" def load(self, source: bytes, *args, **kwargs) -> AudioArtifact: - return AudioArtifact(source, format=import_optional_dependency("filetype").guess(source).extension) + filetype = import_optional_dependency("filetype") - def load_collection(self, sources: list[bytes], *args, **kwargs) -> dict[str, AudioArtifact]: - return cast(dict[str, AudioArtifact], super().load_collection(sources, *args, **kwargs)) + return AudioArtifact(source, format=filetype.guess(source).extension) diff --git a/griptape/loaders/base_file_loader.py b/griptape/loaders/base_file_loader.py new file mode 100644 index 0000000000..a02d0b5810 --- /dev/null +++ b/griptape/loaders/base_file_loader.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from abc import ABC +from io import BytesIO +from os import PathLike +from pathlib import Path +from typing import Optional + +from attrs import define, field + +from griptape.loaders import BaseLoader + + +@define +class BaseFileLoader(BaseLoader, ABC): + encoding: Optional[str] = field(default=None, kw_only=True) + + def fetch(self, source: str | BytesIO | PathLike, *args, **kwargs) -> bytes: + if isinstance(source, (str, PathLike)): + content = Path(source).read_bytes() + elif isinstance(source, BytesIO): + content = source.read() + + return content diff --git a/griptape/loaders/base_loader.py b/griptape/loaders/base_loader.py index 03132c38ed..e4918d8d42 100644 --- a/griptape/loaders/base_loader.py +++ b/griptape/loaders/base_loader.py @@ -13,14 +13,23 @@ from collections.abc import Mapping from griptape.artifacts import BaseArtifact + from griptape.common import Reference @define class BaseLoader(FuturesExecutorMixin, ABC): - encoding: Optional[str] = field(default=None, kw_only=True) + reference: Optional[Reference] = field(default=None, kw_only=True) + + def load(self, source: Any, *args, **kwargs) -> BaseArtifact: + data = self.fetch(source) + + return self.parse(data) + + @abstractmethod + def fetch(self, source: Any, *args, **kwargs) -> bytes: ... @abstractmethod - def load(self, source: Any, *args, **kwargs) -> BaseArtifact: ... + def parse(self, source: bytes, *args, **kwargs) -> BaseArtifact: ... def load_collection( self, diff --git a/griptape/loaders/base_text_loader.py b/griptape/loaders/base_text_loader.py deleted file mode 100644 index b750bf56d9..0000000000 --- a/griptape/loaders/base_text_loader.py +++ /dev/null @@ -1,33 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Optional, cast - -from attrs import Factory, define, field - -from griptape.artifacts import TextArtifact -from griptape.loaders import BaseLoader -from griptape.tokenizers import OpenAiTokenizer - -if TYPE_CHECKING: - from griptape.common import Reference - from griptape.tokenizers import BaseTokenizer - - -@define -class BaseTextLoader(BaseLoader, ABC): - tokenizer: BaseTokenizer = field( - default=Factory(lambda: OpenAiTokenizer(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL)), - kw_only=True, - ) - encoding: str = field(default="utf-8", kw_only=True) - reference: Optional[Reference] = field(default=None, kw_only=True) - - @abstractmethod - def load(self, source: Any, *args, **kwargs) -> TextArtifact: ... - - def load_collection(self, sources: list[Any], *args, **kwargs) -> dict[str, TextArtifact]: - return cast( - dict[str, TextArtifact], - super().load_collection(sources, *args, **kwargs), - ) diff --git a/griptape/loaders/blob_loader.py b/griptape/loaders/blob_loader.py index d0099b47bc..b6b1e44897 100644 --- a/griptape/loaders/blob_loader.py +++ b/griptape/loaders/blob_loader.py @@ -1,20 +1,15 @@ from __future__ import annotations -from typing import Any, cast - from attrs import define from griptape.artifacts import BlobArtifact -from griptape.loaders import BaseLoader +from griptape.loaders import BaseFileLoader @define -class BlobLoader(BaseLoader): - def load(self, source: Any, *args, **kwargs) -> BlobArtifact: +class BlobLoader(BaseFileLoader): + def parse(self, source: bytes, *args, **kwargs) -> BlobArtifact: if self.encoding is None: return BlobArtifact(source) else: return BlobArtifact(source, encoding=self.encoding) - - def load_collection(self, sources: list[bytes | str], *args, **kwargs) -> dict[str, BlobArtifact]: - return cast(dict[str, BlobArtifact], super().load_collection(sources, *args, **kwargs)) diff --git a/griptape/loaders/csv_loader.py b/griptape/loaders/csv_loader.py index b54f0d4be2..8dd124775e 100644 --- a/griptape/loaders/csv_loader.py +++ b/griptape/loaders/csv_loader.py @@ -1,46 +1,19 @@ from __future__ import annotations import csv -from io import StringIO -from typing import TYPE_CHECKING, Optional, cast from attrs import define, field from griptape.artifacts import TableArtifact -from griptape.loaders import BaseLoader - -if TYPE_CHECKING: - from griptape.drivers import BaseEmbeddingDriver +from griptape.loaders.text_loader import TextLoader @define -class CsvLoader(BaseLoader): - embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) +class CsvLoader(TextLoader): delimiter: str = field(default=",", kw_only=True) encoding: str = field(default="utf-8", kw_only=True) - def load(self, source: bytes | str, *args, **kwargs) -> TableArtifact: - if isinstance(source, bytes): - source = source.decode(encoding=self.encoding) - elif isinstance(source, (bytearray, memoryview)): - raise ValueError(f"Unsupported source type: {type(source)}") - - reader = csv.DictReader(StringIO(source), delimiter=self.delimiter) - - artifact = TableArtifact(list(reader), delimiter=self.delimiter, fieldnames=reader.fieldnames) - - if self.embedding_driver: - artifact.generate_embedding(self.embedding_driver) - - return artifact + def parse(self, source: bytes, *args, **kwargs) -> TableArtifact: + reader = csv.DictReader(source.decode(self.encoding), delimiter=self.delimiter) - def load_collection( - self, - sources: list[bytes | str], - *args, - **kwargs, - ) -> dict[str, TableArtifact]: - return cast( - dict[str, TableArtifact], - super().load_collection(sources, *args, **kwargs), - ) + return TableArtifact(list(reader), delimiter=self.delimiter, fieldnames=reader.fieldnames) diff --git a/griptape/loaders/dataframe_loader.py b/griptape/loaders/dataframe_loader.py deleted file mode 100644 index 1d3f628f0b..0000000000 --- a/griptape/loaders/dataframe_loader.py +++ /dev/null @@ -1,35 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Optional, cast - -from attrs import define, field - -from griptape.artifacts import TableArtifact -from griptape.loaders import BaseLoader -from griptape.utils import import_optional_dependency, str_to_hash - -if TYPE_CHECKING: - from pandas import DataFrame - - from griptape.drivers import BaseEmbeddingDriver - - -@define -class DataFrameLoader(BaseLoader): - embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) - - def load(self, source: DataFrame, *args, **kwargs) -> TableArtifact: - artifact = TableArtifact(list(source.to_dict(orient="records"))) - - if self.embedding_driver: - artifact.generate_embedding(self.embedding_driver) - - return artifact - - def load_collection(self, sources: list[DataFrame], *args, **kwargs) -> dict[str, TableArtifact]: - return cast(dict[str, TableArtifact], super().load_collection(sources, *args, **kwargs)) - - def to_key(self, source: DataFrame, *args, **kwargs) -> str: - hash_pandas_object = import_optional_dependency("pandas.core.util.hashing").hash_pandas_object - - return str_to_hash(str(hash_pandas_object(source, index=True).values)) diff --git a/griptape/loaders/email_loader.py b/griptape/loaders/email_loader.py index f6c9ca4062..771230516b 100644 --- a/griptape/loaders/email_loader.py +++ b/griptape/loaders/email_loader.py @@ -1,7 +1,7 @@ from __future__ import annotations import imaplib -from typing import Optional, cast +from typing import Optional from attrs import astuple, define, field @@ -32,11 +32,10 @@ class EmailQuery: username: str = field(kw_only=True) password: str = field(kw_only=True) - def load(self, source: EmailQuery, *args, **kwargs) -> ListArtifact: - mailparser = import_optional_dependency("mailparser") + def fetch(self, source: EmailQuery, *args, **kwargs) -> bytes: label, key, search_criteria, max_count = astuple(source) - artifacts = [] + mail_bytes = [] with imaplib.IMAP4_SSL(self.imap_url) as client: client.login(self.username, self.password) @@ -59,19 +58,24 @@ def load(self, source: EmailQuery, *args, **kwargs) -> ListArtifact: if data is None or not data or data[0] is None: continue - message = mailparser.parse_from_bytes(data[0][1]) - - # Note: mailparser only populates the text_plain field - # if the message content type is explicitly set to 'text/plain'. - if message.text_plain: - artifacts.append(TextArtifact("\n".join(message.text_plain))) + mail_bytes.append(data[0][1]) client.close() - return ListArtifact(artifacts) + return bytes(mail_bytes) + + def parse(self, source: bytes, *args, **kwargs) -> ListArtifact: + mailparser = import_optional_dependency("mailparser") + artifacts = [] + for byte in source: + message = mailparser.parse_from_bytes(byte) + + # Note: mailparser only populates the text_plain field + # if the message content type is explicitly set to 'text/plain'. + if message.text_plain: + artifacts.append(TextArtifact(message.text_plain)) + + return ListArtifact(artifacts) def _count_messages(self, message_numbers: bytes) -> int: return len(list(filter(None, message_numbers.decode().split(" ")))) - - def load_collection(self, sources: list[EmailQuery], *args, **kwargs) -> dict[str, ListArtifact]: - return cast(dict[str, ListArtifact], super().load_collection(sources, *args, **kwargs)) diff --git a/griptape/loaders/file_loader.py b/griptape/loaders/file_loader.py new file mode 100644 index 0000000000..e87865edc6 --- /dev/null +++ b/griptape/loaders/file_loader.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from abc import abstractmethod +from typing import TYPE_CHECKING, Any, Optional + +from attrs import field + +from griptape.loaders.base_loader import BaseLoader +from griptape.utils.futures import execute_futures_dict +from griptape.utils.hash import bytes_to_hash, str_to_hash + +if TYPE_CHECKING: + from collections.abc import Mapping + + from griptape.artifacts import BaseArtifact + + +class FileLoader(BaseLoader): + encoding: Optional[str] = field(default=None, kw_only=True) + + @abstractmethod + def load(self, source: Any, *args, **kwargs) -> BaseArtifact: ... + + def load_collection( + self, + sources: list[Any], + *args, + **kwargs, + ) -> Mapping[str, BaseArtifact]: + # Create a dictionary before actually submitting the jobs to the executor + # to avoid duplicate work. + sources_by_key = {self.to_key(source): source for source in sources} + + return execute_futures_dict( + { + key: self.futures_executor.submit(self.load, source, *args, **kwargs) + for key, source in sources_by_key.items() + }, + ) + + def to_key(self, source: Any) -> str: + if isinstance(source, bytes): + return bytes_to_hash(source) + else: + return str_to_hash(str(source)) diff --git a/griptape/loaders/image_loader.py b/griptape/loaders/image_loader.py index 0541d84e43..dab7fd50ef 100644 --- a/griptape/loaders/image_loader.py +++ b/griptape/loaders/image_loader.py @@ -1,7 +1,7 @@ from __future__ import annotations from io import BytesIO -from typing import Optional, cast +from typing import Optional from attrs import define, field @@ -22,20 +22,10 @@ class ImageLoader(BaseLoader): format: Optional[str] = field(default=None, kw_only=True) - FORMAT_TO_MIME_TYPE = { - "bmp": "image/bmp", - "gif": "image/gif", - "jpeg": "image/jpeg", - "png": "image/png", - "tiff": "image/tiff", - "webp": "image/webp", - } - - def load(self, source: bytes, *args, **kwargs) -> ImageArtifact: + def parse(self, source: bytes, *args, **kwargs) -> ImageArtifact: pil_image = import_optional_dependency("PIL.Image") image = pil_image.open(BytesIO(source)) - # Normalize format only if requested. if self.format is not None: byte_stream = BytesIO() image.save(byte_stream, format=self.format) @@ -43,15 +33,3 @@ def load(self, source: bytes, *args, **kwargs) -> ImageArtifact: source = byte_stream.getvalue() return ImageArtifact(source, format=image.format.lower(), width=image.width, height=image.height) - - def load_collection(self, sources: list[bytes], *args, **kwargs) -> dict[str, ImageArtifact]: - return cast(dict[str, ImageArtifact], super().load_collection(sources, *args, **kwargs)) - - def _get_mime_type(self, image_format: str | None) -> str: - if image_format is None: - raise ValueError("image_format is None") - - if image_format.lower() not in self.FORMAT_TO_MIME_TYPE: - raise ValueError(f"Unsupported image format {image_format}") - - return self.FORMAT_TO_MIME_TYPE[image_format.lower()] diff --git a/griptape/loaders/pdf_loader.py b/griptape/loaders/pdf_loader.py index 90236f2b59..4c27d010ed 100644 --- a/griptape/loaders/pdf_loader.py +++ b/griptape/loaders/pdf_loader.py @@ -1,25 +1,18 @@ from __future__ import annotations from io import BytesIO -from typing import Optional, cast +from typing import Optional -from attrs import Factory, define, field +from attrs import define -from griptape.artifacts import ListArtifact -from griptape.chunkers import PdfChunker -from griptape.loaders import BaseTextLoader +from griptape.artifacts import ListArtifact, TextArtifact +from griptape.loaders import BaseLoader from griptape.utils import import_optional_dependency @define -class PdfLoader(BaseTextLoader): - chunker: PdfChunker = field( - default=Factory(lambda self: PdfChunker(tokenizer=self.tokenizer, max_tokens=self.max_tokens), takes_self=True), - kw_only=True, - ) - encoding: None = field(default=None, kw_only=True) - - def load( +class PdfLoader(BaseLoader): + def parse( self, source: bytes, password: Optional[str] = None, @@ -27,12 +20,8 @@ def load( **kwargs, ) -> ListArtifact: pypdf = import_optional_dependency("pypdf") - reader = pypdf.PdfReader(BytesIO(source), strict=True, password=password) - return ListArtifact([p.extract_text() for p in reader.pages]) + reader = pypdf.PdfReader(BytesIO(source), strict=True, password=password) + pages = [TextArtifact(p.extract_text()) for p in reader.pages] - def load_collection(self, sources: list[bytes], *args, **kwargs) -> dict[str, ListArtifact]: - return cast( - dict[str, ListArtifact], - super().load_collection(sources, *args, **kwargs), - ) + return ListArtifact(pages) diff --git a/griptape/loaders/sql_loader.py b/griptape/loaders/sql_loader.py index d723a5a9fb..0f907b152b 100644 --- a/griptape/loaders/sql_loader.py +++ b/griptape/loaders/sql_loader.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING from attrs import define, field @@ -8,22 +8,15 @@ from griptape.loaders import BaseLoader if TYPE_CHECKING: - from griptape.drivers import BaseEmbeddingDriver, BaseSqlDriver + from griptape.drivers import BaseSqlDriver @define class SqlLoader(BaseLoader): sql_driver: BaseSqlDriver = field(kw_only=True) - embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) def load(self, source: str, *args, **kwargs) -> TableArtifact: rows = self.sql_driver.execute_query(source) artifact = TableArtifact([row.cells for row in rows] if rows else []) - if self.embedding_driver: - artifact.generate_embedding(self.embedding_driver) - return artifact - - def load_collection(self, sources: list[str], *args, **kwargs) -> dict[str, TableArtifact]: - return cast(dict[str, TableArtifact], super().load_collection(sources, *args, **kwargs)) diff --git a/griptape/loaders/text_loader.py b/griptape/loaders/text_loader.py index fd4af65fcb..8f6facd842 100644 --- a/griptape/loaders/text_loader.py +++ b/griptape/loaders/text_loader.py @@ -1,32 +1,14 @@ from __future__ import annotations -from typing import cast - from attrs import define, field from griptape.artifacts import TextArtifact -from griptape.loaders import BaseTextLoader +from griptape.loaders import BaseFileLoader @define -class TextLoader(BaseTextLoader): +class TextLoader(BaseFileLoader): encoding: str = field(default="utf-8", kw_only=True) - def load(self, source: str | bytes, *args, **kwargs) -> TextArtifact: - if isinstance(source, bytes): - source = source.decode(encoding=self.encoding) - elif isinstance(source, (bytearray, memoryview)): - raise ValueError(f"Unsupported source type: {type(source)}") - - return TextArtifact(source) - - def load_collection( - self, - sources: list[bytes | str], - *args, - **kwargs, - ) -> dict[str, TextArtifact]: - return cast( - dict[str, TextArtifact], - super().load_collection(sources, *args, **kwargs), - ) + def parse(self, source: bytes, *args, **kwargs) -> TextArtifact: + return TextArtifact(source.decode(self.encoding), encoding=self.encoding) diff --git a/griptape/loaders/web_loader.py b/griptape/loaders/web_loader.py index b2f31a88df..f97c23dcba 100644 --- a/griptape/loaders/web_loader.py +++ b/griptape/loaders/web_loader.py @@ -1,22 +1,21 @@ from __future__ import annotations -from typing import TYPE_CHECKING - from attrs import Factory, define, field +from griptape.artifacts import TextArtifact from griptape.drivers import BaseWebScraperDriver, TrafilaturaWebScraperDriver -from griptape.loaders import BaseTextLoader - -if TYPE_CHECKING: - from griptape.artifacts import TextArtifact +from griptape.loaders import BaseLoader @define -class WebLoader(BaseTextLoader): +class WebLoader(BaseLoader): web_scraper_driver: BaseWebScraperDriver = field( default=Factory(lambda: TrafilaturaWebScraperDriver()), kw_only=True, ) - def load(self, source: str, *args, **kwargs) -> TextArtifact: - return self.web_scraper_driver.scrape_url(source) + def fetch(self, source: str, *args, **kwargs) -> bytes: + return self.web_scraper_driver.scrape_url(source).value.encode() + + def parse(self, source: bytes, *args, **kwargs) -> TextArtifact: + return TextArtifact(source.decode()) diff --git a/pyproject.toml b/pyproject.toml index 2afdc5910f..261fb3764f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -150,7 +150,6 @@ drivers-image-generation-huggingface = [ "pillow", ] -loaders-dataframe = ["pandas"] loaders-pdf = ["pypdf"] loaders-image = ["pillow"] loaders-email = ["mail-parser"] diff --git a/tests/unit/loaders/test_audio_loader.py b/tests/unit/loaders/test_audio_loader.py index 425a00de7d..c6e22739b4 100644 --- a/tests/unit/loaders/test_audio_loader.py +++ b/tests/unit/loaders/test_audio_loader.py @@ -17,11 +17,10 @@ def create_source(self, bytes_from_resource_path): def test_load(self, resource_path, mime_type, loader, create_source): source = create_source(resource_path) - artifact = loader.load(source)[0] + artifact = loader.load(source) assert isinstance(artifact, AudioArtifact) assert artifact.mime_type == mime_type - assert len(artifact.value) > 0 def test_load_collection(self, create_source, loader): resource_paths = ["sentences.wav", "sentences2.wav"] @@ -32,7 +31,7 @@ def test_load_collection(self, create_source, loader): assert len(collection) == len(resource_paths) for key in collection: - artifact = collection[key][0] + artifact = collection[key] assert isinstance(artifact, AudioArtifact) assert artifact.mime_type == "audio/wav" assert len(artifact.value) > 0 diff --git a/tests/unit/loaders/test_blob_loader.py b/tests/unit/loaders/test_blob_loader.py index c824b240e1..4812e669c8 100644 --- a/tests/unit/loaders/test_blob_loader.py +++ b/tests/unit/loaders/test_blob_loader.py @@ -18,7 +18,7 @@ def create_source(self, request): def test_load(self, loader, create_source): source = create_source("test.txt") - artifact = loader.load(source)[0] + artifact = loader.load(source) assert isinstance(artifact, BlobArtifact) if loader.encoding is None: @@ -37,7 +37,7 @@ def test_load_collection(self, loader, create_source): assert collection.keys() == keys key = next(iter(keys)) - artifact = collection[key][0] + artifact = collection[key] assert isinstance(artifact, BlobArtifact) if loader.encoding is None: diff --git a/tests/unit/loaders/test_email_loader.py b/tests/unit/loaders/test_email_loader.py index 75b87551f6..1812dc531e 100644 --- a/tests/unit/loaders/test_email_loader.py +++ b/tests/unit/loaders/test_email_loader.py @@ -73,7 +73,7 @@ def test_load_with_search(self, loader, mock_search, mock_fetch, match_count): # Then mock_search.assert_called_once_with(None, "key", '"search-criteria"') assert mock_fetch.call_count == match_count - assert isinstance(list_artifact, list) + assert isinstance(list_artifact, ListArtifact) assert to_value_set(list_artifact) == {f"message-{i}" for i in range(match_count)} def test_load_returns_error_artifact_when_select_returns_non_ok(self, loader, mock_select): @@ -140,8 +140,10 @@ def to_message(body: str, content_type: Optional[str]) -> Message: return message -def to_value_set(artifacts: list | dict[str, list]) -> set[str]: +def to_value_set(artifacts: ListArtifact | dict[str, ListArtifact]) -> set[str]: if isinstance(artifacts, dict): - return set({text_artifact.value for list_artifact in artifacts.values() for text_artifact in list_artifact}) + return set( + {text_artifact.value for list_artifact in artifacts.values() for text_artifact in list_artifact.value} + ) else: - return {text_artifact.value for text_artifact in artifacts} + return {artifact.value for artifact in artifacts.value} diff --git a/tests/unit/loaders/test_image_loader.py b/tests/unit/loaders/test_image_loader.py index 54a344e8c3..7093894b00 100644 --- a/tests/unit/loaders/test_image_loader.py +++ b/tests/unit/loaders/test_image_loader.py @@ -31,7 +31,7 @@ def create_source(self, bytes_from_resource_path): def test_load(self, resource_path, mime_type, loader, create_source): source = create_source(resource_path) - artifact = loader.load(source)[0] + artifact = loader.load(source) assert isinstance(artifact, ImageArtifact) assert artifact.height == 32 @@ -45,7 +45,7 @@ def test_load(self, resource_path, mime_type, loader, create_source): def test_load_normalize(self, resource_path, png_loader, create_source): source = create_source(resource_path) - artifact = png_loader.load(source)[0] + artifact = png_loader.load(source) assert isinstance(artifact, ImageArtifact) assert artifact.height == 32 @@ -64,7 +64,7 @@ def test_load_collection(self, create_source, png_loader): assert collection.keys() == keys for key in keys: - artifact = collection[key][0] + artifact = collection[key] assert isinstance(artifact, ImageArtifact) assert artifact.height == 32 assert artifact.width == 32 diff --git a/tests/unit/loaders/test_pdf_loader.py b/tests/unit/loaders/test_pdf_loader.py index 3f4f7848e5..119481525a 100644 --- a/tests/unit/loaders/test_pdf_loader.py +++ b/tests/unit/loaders/test_pdf_loader.py @@ -1,15 +1,12 @@ import pytest from griptape.loaders import PdfLoader -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver - -MAX_TOKENS = 50 class TestPdfLoader: @pytest.fixture() def loader(self): - return PdfLoader(max_tokens=MAX_TOKENS, embedding_driver=MockEmbeddingDriver()) + return PdfLoader() @pytest.fixture() def create_source(self, bytes_from_resource_path): @@ -18,12 +15,11 @@ def create_source(self, bytes_from_resource_path): def test_load(self, loader, create_source): source = create_source("bitcoin.pdf") - artifacts = loader.load(source) + artifact = loader.load(source) - assert len(artifacts) == 151 - assert artifacts[0].value.startswith("Bitcoin: A Peer-to-Peer") - assert artifacts[-1].value.endswith('its applications," 1957.\n9') - assert artifacts[0].embedding == [0, 1] + assert len(artifact) == 9 + assert artifact.value.startswith("Bitcoin: A Peer-to-Peer") + assert artifact.value.endswith('its applications," 1957.\n9') def test_load_collection(self, loader, create_source): resource_paths = ["bitcoin.pdf", "bitcoin-2.pdf"] @@ -37,7 +33,6 @@ def test_load_collection(self, loader, create_source): for key in keys: artifact = collection[key] - assert len(artifact) == 151 - assert artifact[0].value.startswith("Bitcoin: A Peer-to-Peer") - assert artifact[-1].value.endswith('its applications," 1957.\n9') - assert artifact[0].embedding == [0, 1] + assert len(artifact) == 9 + assert artifact.value.startswith("Bitcoin: A Peer-to-Peer") + assert artifact.value.endswith('its applications," 1957.\n9') diff --git a/tests/unit/loaders/test_text_loader.py b/tests/unit/loaders/test_text_loader.py index 07527f9e62..383f3e7a44 100644 --- a/tests/unit/loaders/test_text_loader.py +++ b/tests/unit/loaders/test_text_loader.py @@ -1,9 +1,6 @@ import pytest from griptape.loaders.text_loader import TextLoader -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver - -MAX_TOKENS = 50 class TestTextLoader: @@ -11,9 +8,9 @@ class TestTextLoader: def loader(self, request): encoding = request.param if encoding is None: - return TextLoader(max_tokens=MAX_TOKENS, embedding_driver=MockEmbeddingDriver()) + return TextLoader() else: - return TextLoader(max_tokens=MAX_TOKENS, embedding_driver=MockEmbeddingDriver(), encoding=encoding) + return TextLoader(encoding=encoding) @pytest.fixture(params=["bytes_from_resource_path", "str_from_resource_path"]) def create_source(self, request): @@ -22,12 +19,10 @@ def create_source(self, request): def test_load(self, loader, create_source): source = create_source("test.txt") - artifacts = loader.load(source) + artifact = loader.load(source) - assert len(artifacts) == 39 - assert artifacts[0].value.startswith("foobar foobar foobar") - assert artifacts[0].encoding == loader.encoding - assert artifacts[0].embedding == [0, 1] + assert artifact.value.startswith("foobar foobar foobar") + assert artifact.encoding == loader.encoding def test_load_collection(self, loader, create_source): resource_paths = ["test.txt"] @@ -39,9 +34,6 @@ def test_load_collection(self, loader, create_source): assert collection.keys() == keys key = next(iter(keys)) - artifacts = collection[key] - assert len(artifacts) == 39 + artifact = collection[key] - artifact = artifacts[0] - assert artifact.embedding == [0, 1] assert artifact.encoding == loader.encoding