diff --git a/griptape/drivers/file_manager/base_file_manager_driver.py b/griptape/drivers/file_manager/base_file_manager_driver.py index dce5388122..c7c79ce851 100644 --- a/griptape/drivers/file_manager/base_file_manager_driver.py +++ b/griptape/drivers/file_manager/base_file_manager_driver.py @@ -1,11 +1,11 @@ from __future__ import annotations from abc import ABC, abstractmethod +from typing import Optional -from attrs import Factory, define, field +from attrs import define, field -import griptape.loaders as loaders -from griptape.artifacts import BaseArtifact, ErrorArtifact, InfoArtifact, ListArtifact, TextArtifact +from griptape.artifacts import BlobArtifact, ErrorArtifact, InfoArtifact, TextArtifact @define @@ -17,28 +17,7 @@ class BaseFileManagerDriver(ABC): loaders: Dictionary of file extension specific loaders to use for loading file contents into artifacts. """ - default_loader: loaders.BaseLoader = field(default=Factory(lambda: loaders.BlobLoader()), kw_only=True) - loaders: dict[str, loaders.BaseLoader] = field( - default=Factory( - lambda: { - "pdf": loaders.PdfLoader(), - "csv": loaders.CsvLoader(), - "txt": loaders.TextLoader(), - "html": loaders.TextLoader(), - "json": loaders.TextLoader(), - "yaml": loaders.TextLoader(), - "xml": loaders.TextLoader(), - "png": loaders.ImageLoader(), - "jpg": loaders.ImageLoader(), - "jpeg": loaders.ImageLoader(), - "webp": loaders.ImageLoader(), - "gif": loaders.ImageLoader(), - "bmp": loaders.ImageLoader(), - "tiff": loaders.ImageLoader(), - }, - ), - kw_only=True, - ) + encoding: Optional[str] = field(default=None, kw_only=True) def list_files(self, path: str) -> TextArtifact | ErrorArtifact: entries = self.try_list_files(path) @@ -47,27 +26,18 @@ def list_files(self, path: str) -> TextArtifact | ErrorArtifact: @abstractmethod def try_list_files(self, path: str) -> list[str]: ... - def load_file(self, path: str) -> BaseArtifact: - extension = path.split(".")[-1] - loader = self.loaders.get(extension) or self.default_loader - source = self.try_load_file(path) - result = loader.load(source) - - if isinstance(result, BaseArtifact): - return result + def load_file(self, path: str) -> BlobArtifact: + if self.encoding is None: + return BlobArtifact(self.try_load_file(path)) else: - return ListArtifact(result) + return BlobArtifact(self.try_load_file(path), encoding=self.encoding) @abstractmethod def try_load_file(self, path: str) -> bytes: ... def save_file(self, path: str, value: bytes | str) -> InfoArtifact: - extension = path.split(".")[-1] - loader = self.loaders.get(extension) or self.default_loader - encoding = None if loader is None else loader.encoding - if isinstance(value, str): - value = value.encode() if encoding is None else value.encode(encoding=encoding) + value = value.encode() if self.encoding is None else value.encode(encoding=self.encoding) elif isinstance(value, (bytearray, memoryview)): raise ValueError(f"Unsupported type: {type(value)}") diff --git a/griptape/drivers/file_manager/local_file_manager_driver.py b/griptape/drivers/file_manager/local_file_manager_driver.py index a6f1f0726c..e366f57904 100644 --- a/griptape/drivers/file_manager/local_file_manager_driver.py +++ b/griptape/drivers/file_manager/local_file_manager_driver.py @@ -2,6 +2,7 @@ import os from pathlib import Path +from typing import Optional from attrs import Attribute, Factory, define, field @@ -16,11 +17,11 @@ class LocalFileManagerDriver(BaseFileManagerDriver): workdir: The absolute working directory. List, load, and save operations will be performed relative to this directory. """ - workdir: str = field(default=Factory(lambda: os.getcwd()), kw_only=True) + workdir: Optional[str] = field(default=Factory(lambda: os.getcwd()), kw_only=True) @workdir.validator # pyright: ignore[reportAttributeAccessIssue] def validate_workdir(self, _: Attribute, workdir: str) -> None: - if not Path(workdir).is_absolute(): + if self.workdir is not None and not Path(workdir).is_absolute(): raise ValueError("Workdir must be an absolute path") def try_list_files(self, path: str) -> list[str]: @@ -42,7 +43,7 @@ def try_save_file(self, path: str, value: bytes) -> None: def _full_path(self, path: str) -> str: path = path.lstrip("/") - full_path = os.path.join(self.workdir, path) + full_path = os.path.join(self.workdir, path) if self.workdir else path # Need to keep the trailing slash if it was there, # because it means the path is a directory. ended_with_slash = path.endswith("/") diff --git a/griptape/drivers/web_scraper/base_web_scraper_driver.py b/griptape/drivers/web_scraper/base_web_scraper_driver.py index ae39f8eacf..0c33f3713c 100644 --- a/griptape/drivers/web_scraper/base_web_scraper_driver.py +++ b/griptape/drivers/web_scraper/base_web_scraper_driver.py @@ -4,5 +4,13 @@ class BaseWebScraperDriver(ABC): + def scrape_url(self, url: str) -> TextArtifact: + source = self.fetch_url(url) + + return self.extract_page(source) + + @abstractmethod + def fetch_url(self, url: str) -> str: ... + @abstractmethod - def scrape_url(self, url: str) -> TextArtifact: ... + def extract_page(self, page: str) -> TextArtifact: ... diff --git a/griptape/drivers/web_scraper/markdownify_web_scraper_driver.py b/griptape/drivers/web_scraper/markdownify_web_scraper_driver.py index b54ff072f9..4eff4948b5 100644 --- a/griptape/drivers/web_scraper/markdownify_web_scraper_driver.py +++ b/griptape/drivers/web_scraper/markdownify_web_scraper_driver.py @@ -38,20 +38,8 @@ class MarkdownifyWebScraperDriver(BaseWebScraperDriver): exclude_ids: list[str] = field(default=Factory(list), kw_only=True) timeout: Optional[int] = field(default=None, kw_only=True) - def scrape_url(self, url: str) -> TextArtifact: + def fetch_url(self, url: str) -> str: sync_playwright = import_optional_dependency("playwright.sync_api").sync_playwright - bs4 = import_optional_dependency("bs4") - markdownify = import_optional_dependency("markdownify") - - include_links = self.include_links - - # Custom MarkdownConverter to optionally linked urls. If include_links is False only - # the text of the link is returned. - class OptionalLinksMarkdownConverter(markdownify.MarkdownConverter): - def convert_a(self, el: Any, text: str, convert_as_inline: Any) -> str: - if include_links: - return super().convert_a(el, text, convert_as_inline) - return text with sync_playwright() as p, p.chromium.launch(headless=True) as browser: page = browser.new_page() @@ -76,28 +64,43 @@ def skip_loading_images(route: Any) -> Any: if not content: raise Exception("can't access URL") - soup = bs4.BeautifulSoup(content, "html.parser") + return content + + def extract_page(self, page: str) -> TextArtifact: + bs4 = import_optional_dependency("bs4") + markdownify = import_optional_dependency("markdownify") + include_links = self.include_links + + # Custom MarkdownConverter to optionally linked urls. If include_links is False only + # the text of the link is returned. + class OptionalLinksMarkdownConverter(markdownify.MarkdownConverter): + def convert_a(self, el: Any, text: str, convert_as_inline: Any) -> str: + if include_links: + return super().convert_a(el, text, convert_as_inline) + return text + + soup = bs4.BeautifulSoup(page, "html.parser") - # Remove unwanted elements - exclude_selector = ",".join( - self.exclude_tags + [f".{c}" for c in self.exclude_classes] + [f"#{i}" for i in self.exclude_ids], - ) - if exclude_selector: - for s in soup.select(exclude_selector): - s.extract() + # Remove unwanted elements + exclude_selector = ",".join( + self.exclude_tags + [f".{c}" for c in self.exclude_classes] + [f"#{i}" for i in self.exclude_ids], + ) + if exclude_selector: + for s in soup.select(exclude_selector): + s.extract() - text = OptionalLinksMarkdownConverter().convert_soup(soup) + text = OptionalLinksMarkdownConverter().convert_soup(soup) - # Remove leading and trailing whitespace from the entire text - text = text.strip() + # Remove leading and trailing whitespace from the entire text + text = text.strip() - # Remove trailing whitespace from each line - text = re.sub(r"[ \t]+$", "", text, flags=re.MULTILINE) + # Remove trailing whitespace from each line + text = re.sub(r"[ \t]+$", "", text, flags=re.MULTILINE) - # Indent using 2 spaces instead of tabs - text = re.sub(r"(\n?\s*?)\t", r"\1 ", text) + # Indent using 2 spaces instead of tabs + text = re.sub(r"(\n?\s*?)\t", r"\1 ", text) - # Remove triple+ newlines (keep double newlines for paragraphs) - text = re.sub(r"\n\n+", "\n\n", text) + # Remove triple+ newlines (keep double newlines for paragraphs) + text = re.sub(r"\n\n+", "\n\n", text) - return TextArtifact(text) + return TextArtifact(text) diff --git a/griptape/drivers/web_scraper/proxy_web_scraper_driver.py b/griptape/drivers/web_scraper/proxy_web_scraper_driver.py index 2d785fde26..94b3914eae 100644 --- a/griptape/drivers/web_scraper/proxy_web_scraper_driver.py +++ b/griptape/drivers/web_scraper/proxy_web_scraper_driver.py @@ -12,6 +12,10 @@ class ProxyWebScraperDriver(BaseWebScraperDriver): proxies: dict = field(kw_only=True, metadata={"serializable": False}) params: dict = field(default=Factory(dict), kw_only=True, metadata={"serializable": True}) - def scrape_url(self, url: str) -> TextArtifact: + def fetch_url(self, url: str) -> str: response = requests.get(url, proxies=self.proxies, **self.params) - return TextArtifact(response.text) + + return response.text + + def extract_page(self, page: str) -> TextArtifact: + return TextArtifact(page) diff --git a/griptape/drivers/web_scraper/trafilatura_web_scraper_driver.py b/griptape/drivers/web_scraper/trafilatura_web_scraper_driver.py index 06f5573a4a..e87af8af63 100644 --- a/griptape/drivers/web_scraper/trafilatura_web_scraper_driver.py +++ b/griptape/drivers/web_scraper/trafilatura_web_scraper_driver.py @@ -12,7 +12,7 @@ class TrafilaturaWebScraperDriver(BaseWebScraperDriver): include_links: bool = field(default=True, kw_only=True) - def scrape_url(self, url: str) -> TextArtifact: + def fetch_url(self, url: str) -> str: trafilatura = import_optional_dependency("trafilatura") use_config = trafilatura.settings.use_config @@ -29,6 +29,15 @@ def scrape_url(self, url: str) -> TextArtifact: if page is None: raise Exception("can't access URL") + + return page + + def extract_page(self, page: str) -> TextArtifact: + trafilatura = import_optional_dependency("trafilatura") + use_config = trafilatura.settings.use_config + + config = use_config() + extracted_page = trafilatura.extract( page, include_links=self.include_links, diff --git a/griptape/loaders/audio_loader.py b/griptape/loaders/audio_loader.py index befdca034b..a23ce7cc75 100644 --- a/griptape/loaders/audio_loader.py +++ b/griptape/loaders/audio_loader.py @@ -3,15 +3,15 @@ from attrs import define from griptape.artifacts import AudioArtifact -from griptape.loaders import BaseLoader +from griptape.loaders.base_file_loader import BaseFileLoader from griptape.utils import import_optional_dependency @define -class AudioLoader(BaseLoader): +class AudioLoader(BaseFileLoader): """Loads audio content into audio artifacts.""" - def load(self, source: bytes, *args, **kwargs) -> AudioArtifact: + def parse(self, source: bytes, *args, **kwargs) -> AudioArtifact: filetype = import_optional_dependency("filetype") return AudioArtifact(source, format=filetype.guess(source).extension) diff --git a/griptape/loaders/base_file_loader.py b/griptape/loaders/base_file_loader.py index a02d0b5810..b6f2e14c3b 100644 --- a/griptape/loaders/base_file_loader.py +++ b/griptape/loaders/base_file_loader.py @@ -1,24 +1,24 @@ from __future__ import annotations from abc import ABC -from io import BytesIO -from os import PathLike -from pathlib import Path -from typing import Optional +from typing import TYPE_CHECKING -from attrs import define, field +from attrs import Factory, define, field +from griptape.drivers import BaseFileManagerDriver, LocalFileManagerDriver from griptape.loaders import BaseLoader +if TYPE_CHECKING: + from os import PathLike + @define class BaseFileLoader(BaseLoader, ABC): - encoding: Optional[str] = field(default=None, kw_only=True) - - def fetch(self, source: str | BytesIO | PathLike, *args, **kwargs) -> bytes: - if isinstance(source, (str, PathLike)): - content = Path(source).read_bytes() - elif isinstance(source, BytesIO): - content = source.read() + file_manager_driver: BaseFileManagerDriver = field( + default=Factory(lambda: LocalFileManagerDriver(workdir=None)), + kw_only=True, + ) + encoding: str = field(default="utf-8", kw_only=True) - return content + def fetch(self, source: str | PathLike, *args, **kwargs) -> bytes: + return self.file_manager_driver.load_file(str(source), *args, **kwargs) diff --git a/griptape/loaders/base_loader.py b/griptape/loaders/base_loader.py index e4918d8d42..7c16c8ded7 100644 --- a/griptape/loaders/base_loader.py +++ b/griptape/loaders/base_loader.py @@ -26,10 +26,10 @@ def load(self, source: Any, *args, **kwargs) -> BaseArtifact: return self.parse(data) @abstractmethod - def fetch(self, source: Any, *args, **kwargs) -> bytes: ... + def fetch(self, source: Any, *args, **kwargs) -> Any: ... @abstractmethod - def parse(self, source: bytes, *args, **kwargs) -> BaseArtifact: ... + def parse(self, source: Any, *args, **kwargs) -> BaseArtifact: ... def load_collection( self, diff --git a/griptape/loaders/csv_loader.py b/griptape/loaders/csv_loader.py index 8dd124775e..029e47ac81 100644 --- a/griptape/loaders/csv_loader.py +++ b/griptape/loaders/csv_loader.py @@ -1,6 +1,7 @@ from __future__ import annotations import csv +from io import StringIO from attrs import define, field @@ -14,6 +15,6 @@ class CsvLoader(TextLoader): encoding: str = field(default="utf-8", kw_only=True) def parse(self, source: bytes, *args, **kwargs) -> TableArtifact: - reader = csv.DictReader(source.decode(self.encoding), delimiter=self.delimiter) + reader = csv.DictReader(StringIO(source.decode(self.encoding)), delimiter=self.delimiter) return TableArtifact(list(reader), delimiter=self.delimiter, fieldnames=reader.fieldnames) diff --git a/griptape/loaders/email_loader.py b/griptape/loaders/email_loader.py index 771230516b..d1c35b9fbd 100644 --- a/griptape/loaders/email_loader.py +++ b/griptape/loaders/email_loader.py @@ -32,7 +32,7 @@ class EmailQuery: username: str = field(kw_only=True) password: str = field(kw_only=True) - def fetch(self, source: EmailQuery, *args, **kwargs) -> bytes: + def fetch(self, source: EmailQuery, *args, **kwargs) -> list[bytes]: label, key, search_criteria, max_count = astuple(source) mail_bytes = [] @@ -62,9 +62,9 @@ def fetch(self, source: EmailQuery, *args, **kwargs) -> bytes: client.close() - return bytes(mail_bytes) + return mail_bytes - def parse(self, source: bytes, *args, **kwargs) -> ListArtifact: + def parse(self, source: list[bytes], *args, **kwargs) -> ListArtifact: mailparser = import_optional_dependency("mailparser") artifacts = [] for byte in source: @@ -73,7 +73,7 @@ def parse(self, source: bytes, *args, **kwargs) -> ListArtifact: # Note: mailparser only populates the text_plain field # if the message content type is explicitly set to 'text/plain'. if message.text_plain: - artifacts.append(TextArtifact(message.text_plain)) + artifacts.append(TextArtifact("\n".join(message.text_plain))) return ListArtifact(artifacts) diff --git a/griptape/loaders/file_loader.py b/griptape/loaders/file_loader.py deleted file mode 100644 index e87865edc6..0000000000 --- a/griptape/loaders/file_loader.py +++ /dev/null @@ -1,45 +0,0 @@ -from __future__ import annotations - -from abc import abstractmethod -from typing import TYPE_CHECKING, Any, Optional - -from attrs import field - -from griptape.loaders.base_loader import BaseLoader -from griptape.utils.futures import execute_futures_dict -from griptape.utils.hash import bytes_to_hash, str_to_hash - -if TYPE_CHECKING: - from collections.abc import Mapping - - from griptape.artifacts import BaseArtifact - - -class FileLoader(BaseLoader): - encoding: Optional[str] = field(default=None, kw_only=True) - - @abstractmethod - def load(self, source: Any, *args, **kwargs) -> BaseArtifact: ... - - def load_collection( - self, - sources: list[Any], - *args, - **kwargs, - ) -> Mapping[str, BaseArtifact]: - # Create a dictionary before actually submitting the jobs to the executor - # to avoid duplicate work. - sources_by_key = {self.to_key(source): source for source in sources} - - return execute_futures_dict( - { - key: self.futures_executor.submit(self.load, source, *args, **kwargs) - for key, source in sources_by_key.items() - }, - ) - - def to_key(self, source: Any) -> str: - if isinstance(source, bytes): - return bytes_to_hash(source) - else: - return str_to_hash(str(source)) diff --git a/griptape/loaders/image_loader.py b/griptape/loaders/image_loader.py index dab7fd50ef..e9f5a8f52c 100644 --- a/griptape/loaders/image_loader.py +++ b/griptape/loaders/image_loader.py @@ -6,12 +6,12 @@ from attrs import define, field from griptape.artifacts import ImageArtifact -from griptape.loaders import BaseLoader +from griptape.loaders import BaseFileLoader from griptape.utils import import_optional_dependency @define -class ImageLoader(BaseLoader): +class ImageLoader(BaseFileLoader): """Loads images into image artifacts. Attributes: diff --git a/griptape/loaders/pdf_loader.py b/griptape/loaders/pdf_loader.py index 4c27d010ed..a9bb366048 100644 --- a/griptape/loaders/pdf_loader.py +++ b/griptape/loaders/pdf_loader.py @@ -6,12 +6,12 @@ from attrs import define from griptape.artifacts import ListArtifact, TextArtifact -from griptape.loaders import BaseLoader +from griptape.loaders.base_file_loader import BaseFileLoader from griptape.utils import import_optional_dependency @define -class PdfLoader(BaseLoader): +class PdfLoader(BaseFileLoader): def parse( self, source: bytes, @@ -20,7 +20,6 @@ def parse( **kwargs, ) -> ListArtifact: pypdf = import_optional_dependency("pypdf") - reader = pypdf.PdfReader(BytesIO(source), strict=True, password=password) pages = [TextArtifact(p.extract_text()) for p in reader.pages] diff --git a/griptape/loaders/sql_loader.py b/griptape/loaders/sql_loader.py index 0f907b152b..5a4c95180b 100644 --- a/griptape/loaders/sql_loader.py +++ b/griptape/loaders/sql_loader.py @@ -15,8 +15,8 @@ class SqlLoader(BaseLoader): sql_driver: BaseSqlDriver = field(kw_only=True) - def load(self, source: str, *args, **kwargs) -> TableArtifact: - rows = self.sql_driver.execute_query(source) - artifact = TableArtifact([row.cells for row in rows] if rows else []) + def fetch(self, source: str, *args, **kwargs) -> list[BaseSqlDriver.RowResult]: + return self.sql_driver.execute_query(source) or [] - return artifact + def parse(self, source: list[BaseSqlDriver.RowResult], *args, **kwargs) -> TableArtifact: + return TableArtifact([row.cells for row in source]) diff --git a/griptape/loaders/web_loader.py b/griptape/loaders/web_loader.py index f97c23dcba..4d31a81e42 100644 --- a/griptape/loaders/web_loader.py +++ b/griptape/loaders/web_loader.py @@ -1,11 +1,15 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from attrs import Factory, define, field -from griptape.artifacts import TextArtifact from griptape.drivers import BaseWebScraperDriver, TrafilaturaWebScraperDriver from griptape.loaders import BaseLoader +if TYPE_CHECKING: + from griptape.artifacts import TextArtifact + @define class WebLoader(BaseLoader): @@ -14,8 +18,8 @@ class WebLoader(BaseLoader): kw_only=True, ) - def fetch(self, source: str, *args, **kwargs) -> bytes: - return self.web_scraper_driver.scrape_url(source).value.encode() + def fetch(self, source: str, *args, **kwargs) -> str: + return self.web_scraper_driver.fetch_url(source) - def parse(self, source: bytes, *args, **kwargs) -> TextArtifact: - return TextArtifact(source.decode()) + def parse(self, source: str, *args, **kwargs) -> TextArtifact: + return self.web_scraper_driver.extract_page(source) diff --git a/griptape/utils/file_utils.py b/griptape/utils/file_utils.py index 19c9f699ce..9059686ed1 100644 --- a/griptape/utils/file_utils.py +++ b/griptape/utils/file_utils.py @@ -5,9 +5,10 @@ from typing import Optional import griptape.utils as utils +from tests.unit.loaders.conftest import BytesIO -def load_file(path: str) -> bytes: +def load_file(path: str) -> BytesIO: """Load a file from the given path and return its content as bytes. Args: @@ -16,10 +17,10 @@ def load_file(path: str) -> bytes: Returns: The content of the file. """ - return Path(path).read_bytes() + return BytesIO(Path(path).read_bytes()) -def load_files(paths: list[str], futures_executor: Optional[futures.ThreadPoolExecutor] = None) -> dict[str, bytes]: +def load_files(paths: list[str], futures_executor: Optional[futures.ThreadPoolExecutor] = None) -> dict[str, BytesIO]: """Load multiple files concurrently and return a dictionary of their content. Args: diff --git a/tests/unit/drivers/file_manager/test_amazon_s3_file_manager_driver.py b/tests/unit/drivers/file_manager/test_amazon_s3_file_manager_driver.py index 84ce617685..8ca343daee 100644 --- a/tests/unit/drivers/file_manager/test_amazon_s3_file_manager_driver.py +++ b/tests/unit/drivers/file_manager/test_amazon_s3_file_manager_driver.py @@ -6,8 +6,8 @@ from moto import mock_s3 from griptape.artifacts import InfoArtifact, ListArtifact, TextArtifact +from griptape.artifacts.blob_artifact import BlobArtifact from griptape.drivers import AmazonS3FileManagerDriver -from griptape.loaders import TextLoader from tests.utils.aws import mock_aws_credentials @@ -154,8 +154,7 @@ def test_list_files_failure(self, workdir, path, expected, driver): def test_load_file(self, driver): artifact = driver.load_file("resources/bitcoin.pdf") - assert isinstance(artifact, ListArtifact) - assert len(artifact.value) == 4 + assert isinstance(artifact, BlobArtifact) @pytest.mark.parametrize( ("workdir", "path", "expected"), @@ -185,9 +184,8 @@ def test_load_file_failure(self, workdir, path, expected, driver): def test_load_file_with_encoding(self, driver): artifact = driver.load_file("resources/test.txt") - assert isinstance(artifact, ListArtifact) - assert len(artifact.value) == 1 - assert isinstance(artifact.value[0], TextArtifact) + assert isinstance(artifact, BlobArtifact) + assert artifact.encoding == "utf-8" @pytest.mark.parametrize( ("workdir", "path", "content"), @@ -240,9 +238,7 @@ def test_save_file_failure(self, workdir, path, expected, temp_dir, driver, s3_c def test_save_file_with_encoding(self, session, bucket, get_s3_value): workdir = "/sub-folder" - driver = AmazonS3FileManagerDriver( - session=session, bucket=bucket, default_loader=TextLoader(encoding="utf-8"), loaders={}, workdir=workdir - ) + driver = AmazonS3FileManagerDriver(session=session, bucket=bucket, workdir=workdir) path = "test/foobar.txt" result = driver.save_file(path, "foobar") @@ -253,9 +249,7 @@ def test_save_file_with_encoding(self, session, bucket, get_s3_value): def test_save_and_load_file_with_encoding(self, session, bucket, get_s3_value): workdir = "/sub-folder" - driver = AmazonS3FileManagerDriver( - session=session, bucket=bucket, loaders={"txt": TextLoader(encoding="ascii")}, workdir=workdir - ) + driver = AmazonS3FileManagerDriver(session=session, bucket=bucket, encoding="ascii", workdir=workdir) path = "test/foobar.txt" result = driver.save_file(path, "foobar") @@ -264,13 +258,10 @@ def test_save_and_load_file_with_encoding(self, session, bucket, get_s3_value): assert get_s3_value(expected_s3_key) == "foobar" assert result.value == "Successfully saved file" - driver = AmazonS3FileManagerDriver( - session=session, bucket=bucket, default_loader=TextLoader(encoding="ascii"), loaders={}, workdir=workdir - ) + driver = AmazonS3FileManagerDriver(session=session, bucket=bucket, encoding="ascii", workdir=workdir) path = "test/foobar.txt" result = driver.load_file(path) - assert isinstance(result, ListArtifact) - assert len(result.value) == 1 - assert isinstance(result.value[0], TextArtifact) + assert isinstance(result, BlobArtifact) + assert result.encoding == "ascii" diff --git a/tests/unit/drivers/file_manager/test_local_file_manager_driver.py b/tests/unit/drivers/file_manager/test_local_file_manager_driver.py index 394a838a3f..e08e412703 100644 --- a/tests/unit/drivers/file_manager/test_local_file_manager_driver.py +++ b/tests/unit/drivers/file_manager/test_local_file_manager_driver.py @@ -4,9 +4,9 @@ import pytest -from griptape.artifacts import InfoArtifact, ListArtifact, TextArtifact +from griptape.artifacts import InfoArtifact, TextArtifact +from griptape.artifacts.blob_artifact import BlobArtifact from griptape.drivers import LocalFileManagerDriver -from griptape.loaders.text_loader import TextLoader class TestLocalFileManagerDriver: @@ -127,8 +127,7 @@ def test_list_files_failure(self, workdir, path, expected, temp_dir, driver): def test_load_file(self, driver: LocalFileManagerDriver): artifact = driver.load_file("resources/bitcoin.pdf") - assert isinstance(artifact, ListArtifact) - assert len(artifact.value) == 4 + assert isinstance(artifact, BlobArtifact) @pytest.mark.parametrize( ("workdir", "path", "expected"), @@ -156,23 +155,6 @@ def test_load_file_failure(self, workdir, path, expected, temp_dir, driver): with pytest.raises(expected): driver.load_file(path) - def test_load_file_with_encoding(self, driver: LocalFileManagerDriver): - artifact = driver.load_file("resources/test.txt") - - assert isinstance(artifact, ListArtifact) - assert len(artifact.value) == 1 - assert isinstance(artifact.value[0], TextArtifact) - - def test_load_file_with_encoding_failure(self, driver): - driver = LocalFileManagerDriver( - default_loader=TextLoader(encoding="utf-8"), - loaders={}, - workdir=os.path.normpath(os.path.abspath(os.path.dirname(__file__) + "../../../../")), - ) - - with pytest.raises(UnicodeDecodeError): - driver.load_file("resources/bitcoin.pdf") - @pytest.mark.parametrize( ("workdir", "path", "content"), [ @@ -224,25 +206,24 @@ def test_save_file_failure(self, workdir, path, expected, temp_dir, driver): driver.save_file(path, "foobar") def test_save_file_with_encoding(self, temp_dir): - driver = LocalFileManagerDriver(default_loader=TextLoader(encoding="utf-8"), loaders={}, workdir=temp_dir) + driver = LocalFileManagerDriver(encoding="utf-8", workdir=temp_dir) result = driver.save_file(os.path.join("test", "foobar.txt"), "foobar") assert Path(os.path.join(temp_dir, "test", "foobar.txt")).read_text() == "foobar" assert result.value == "Successfully saved file" def test_save_and_load_file_with_encoding(self, temp_dir): - driver = LocalFileManagerDriver(loaders={"txt": TextLoader(encoding="ascii")}, workdir=temp_dir) + driver = LocalFileManagerDriver(encoding="ascii", workdir=temp_dir) result = driver.save_file(os.path.join("test", "foobar.txt"), "foobar") assert Path(os.path.join(temp_dir, "test", "foobar.txt")).read_text() == "foobar" assert result.value == "Successfully saved file" - driver = LocalFileManagerDriver(default_loader=TextLoader(encoding="ascii"), loaders={}, workdir=temp_dir) + driver = LocalFileManagerDriver(encoding="ascii", workdir=temp_dir) result = driver.load_file(os.path.join("test", "foobar.txt")) - assert isinstance(result, ListArtifact) - assert len(result.value) == 1 - assert isinstance(result.value[0], TextArtifact) + assert isinstance(result, BlobArtifact) + assert result.encoding == "ascii" def _to_driver_workdir(self, temp_dir, workdir): # Treat the workdir as an absolute path, but modify it to be relative to the temp_dir. diff --git a/tests/unit/loaders/conftest.py b/tests/unit/loaders/conftest.py index 1f698738ab..0bbf839b8a 100644 --- a/tests/unit/loaders/conftest.py +++ b/tests/unit/loaders/conftest.py @@ -1,4 +1,5 @@ import os +from io import BytesIO, StringIO from pathlib import Path import pytest @@ -14,15 +15,15 @@ def create_source(resource_path: str) -> Path: @pytest.fixture() def bytes_from_resource_path(path_from_resource_path): - def create_source(resource_path: str) -> bytes: - return Path(path_from_resource_path(resource_path)).read_bytes() + def create_source(resource_path: str) -> BytesIO: + return BytesIO(Path(path_from_resource_path(resource_path)).read_bytes()) return create_source @pytest.fixture() def str_from_resource_path(path_from_resource_path): - def test_csv_str(resource_path: str) -> str: - return Path(path_from_resource_path(resource_path)).read_text() + def test_csv_str(resource_path: str) -> StringIO: + return StringIO(Path(path_from_resource_path(resource_path)).read_text()) return test_csv_str diff --git a/tests/unit/loaders/test_audio_loader.py b/tests/unit/loaders/test_audio_loader.py index c6e22739b4..7b35167222 100644 --- a/tests/unit/loaders/test_audio_loader.py +++ b/tests/unit/loaders/test_audio_loader.py @@ -9,9 +9,9 @@ class TestAudioLoader: def loader(self): return AudioLoader() - @pytest.fixture() - def create_source(self, bytes_from_resource_path): - return bytes_from_resource_path + @pytest.fixture(params=["path_from_resource_path"]) + def create_source(self, request): + return request.getfixturevalue(request.param) @pytest.mark.parametrize(("resource_path", "mime_type"), [("sentences.wav", "audio/wav")]) def test_load(self, resource_path, mime_type, loader, create_source): diff --git a/tests/unit/loaders/test_blob_loader.py b/tests/unit/loaders/test_blob_loader.py index 4812e669c8..2042381bc4 100644 --- a/tests/unit/loaders/test_blob_loader.py +++ b/tests/unit/loaders/test_blob_loader.py @@ -11,7 +11,7 @@ def loader(self, request): kwargs = {"encoding": encoding} if encoding is not None else {} return BlobLoader(**kwargs) - @pytest.fixture(params=["bytes_from_resource_path", "str_from_resource_path"]) + @pytest.fixture(params=["path_from_resource_path"]) def create_source(self, request): return request.getfixturevalue(request.param) diff --git a/tests/unit/loaders/test_csv_loader.py b/tests/unit/loaders/test_csv_loader.py index a63322290f..d92c8aef62 100644 --- a/tests/unit/loaders/test_csv_loader.py +++ b/tests/unit/loaders/test_csv_loader.py @@ -1,7 +1,6 @@ import pytest from griptape.loaders.csv_loader import CsvLoader -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver class TestCsvLoader: @@ -9,15 +8,15 @@ class TestCsvLoader: def loader(self, request): encoding = request.param if encoding is None: - return CsvLoader(embedding_driver=MockEmbeddingDriver()) + return CsvLoader() else: - return CsvLoader(encoding=encoding, embedding_driver=MockEmbeddingDriver()) + return CsvLoader(encoding=encoding) @pytest.fixture() def loader_with_pipe_delimiter(self): - return CsvLoader(delimiter="|", embedding_driver=MockEmbeddingDriver()) + return CsvLoader(delimiter="|") - @pytest.fixture(params=["bytes_from_resource_path", "str_from_resource_path"]) + @pytest.fixture(params=["path_from_resource_path"]) def create_source(self, request): return request.getfixturevalue(request.param) @@ -30,7 +29,6 @@ def test_load(self, loader, create_source): first_artifact = artifact.value[0] assert first_artifact["Foo"] == "foo1" assert first_artifact["Bar"] == "bar1" - assert artifact.embedding == [0, 1] def test_load_delimiter(self, loader_with_pipe_delimiter, create_source): source = create_source("test-pipe.csv") @@ -41,7 +39,6 @@ def test_load_delimiter(self, loader_with_pipe_delimiter, create_source): first_artifact = artifact.value[0] assert first_artifact["Foo"] == "bar1" assert first_artifact["Bar"] == "foo1" - assert artifact.embedding == [0, 1] def test_load_collection(self, loader, create_source): resource_paths = ["test-1.csv", "test-2.csv"] @@ -58,7 +55,6 @@ def test_load_collection(self, loader, create_source): first_artifact = artifact.value[0] assert first_artifact["Foo"] == "foo1" assert first_artifact["Bar"] == "bar1" - assert artifact.embedding == [0, 1] def test_to_text(self, loader, create_source): source = create_source("test-1.csv") diff --git a/tests/unit/loaders/test_dataframe_loader.py b/tests/unit/loaders/test_dataframe_loader.py deleted file mode 100644 index fa9a540844..0000000000 --- a/tests/unit/loaders/test_dataframe_loader.py +++ /dev/null @@ -1,52 +0,0 @@ -import os - -import pandas as pd -import pytest - -from griptape.loaders.dataframe_loader import DataFrameLoader -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver - - -class TestDataFrameLoader: - @pytest.fixture() - def loader(self): - return DataFrameLoader(embedding_driver=MockEmbeddingDriver()) - - def test_load_with_path(self, loader): - # test loading a file delimited by comma - path = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../../resources/test-1.csv") - - artifact = loader.load(pd.read_csv(path)) - - assert len(artifact) == 10 - first_artifact = artifact.value[0] - assert first_artifact["Foo"] == "foo1" - assert first_artifact["Bar"] == "bar1" - - assert artifact.embedding == [0, 1] - - def test_load_collection_with_path(self, loader): - path1 = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../../resources/test-1.csv") - path2 = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../../resources/test-2.csv") - df1 = pd.read_csv(path1) - df2 = pd.read_csv(path2) - collection = loader.load_collection([df1, df2]) - - key1 = loader.to_key(df1) - key2 = loader.to_key(df2) - - assert list(collection.keys()) == [key1, key2] - - artifact = collection[key1] - assert len(artifact) == 10 - first_artifact = artifact.value[0] - assert first_artifact["Foo"] == "foo1" - assert first_artifact["Bar"] == "bar1" - - artifact = collection[key2] - assert len(artifact) == 10 - first_artifact = artifact.value[0] - assert first_artifact["Bar"] == "bar1" - assert first_artifact["Foo"] == "foo1" - - assert artifact.embedding == [0, 1] diff --git a/tests/unit/loaders/test_image_loader.py b/tests/unit/loaders/test_image_loader.py index 7093894b00..9c491fb881 100644 --- a/tests/unit/loaders/test_image_loader.py +++ b/tests/unit/loaders/test_image_loader.py @@ -13,9 +13,9 @@ def loader(self): def png_loader(self): return ImageLoader(format="png") - @pytest.fixture() - def create_source(self, bytes_from_resource_path): - return bytes_from_resource_path + @pytest.fixture(params=["path_from_resource_path"]) + def create_source(self, request): + return request.getfixturevalue(request.param) @pytest.mark.parametrize( ("resource_path", "mime_type"), diff --git a/tests/unit/loaders/test_pdf_loader.py b/tests/unit/loaders/test_pdf_loader.py index 119481525a..45027b95ca 100644 --- a/tests/unit/loaders/test_pdf_loader.py +++ b/tests/unit/loaders/test_pdf_loader.py @@ -8,9 +8,9 @@ class TestPdfLoader: def loader(self): return PdfLoader() - @pytest.fixture() - def create_source(self, bytes_from_resource_path): - return bytes_from_resource_path + @pytest.fixture(params=["path_from_resource_path"]) + def create_source(self, request): + return request.getfixturevalue(request.param) def test_load(self, loader, create_source): source = create_source("bitcoin.pdf") @@ -18,8 +18,8 @@ def test_load(self, loader, create_source): artifact = loader.load(source) assert len(artifact) == 9 - assert artifact.value.startswith("Bitcoin: A Peer-to-Peer") - assert artifact.value.endswith('its applications," 1957.\n9') + assert artifact[0].value.startswith("Bitcoin: A Peer-to-Peer") + assert artifact[-1].value.endswith('its applications," 1957.\n9') def test_load_collection(self, loader, create_source): resource_paths = ["bitcoin.pdf", "bitcoin-2.pdf"] @@ -34,5 +34,5 @@ def test_load_collection(self, loader, create_source): for key in keys: artifact = collection[key] assert len(artifact) == 9 - assert artifact.value.startswith("Bitcoin: A Peer-to-Peer") - assert artifact.value.endswith('its applications," 1957.\n9') + assert artifact[0].value.startswith("Bitcoin: A Peer-to-Peer") + assert artifact[-1].value.endswith('its applications," 1957.\n9') diff --git a/tests/unit/loaders/test_sql_loader.py b/tests/unit/loaders/test_sql_loader.py index e977d3c5f5..b02e58a5aa 100644 --- a/tests/unit/loaders/test_sql_loader.py +++ b/tests/unit/loaders/test_sql_loader.py @@ -3,7 +3,6 @@ from griptape.drivers import SqlDriver from griptape.loaders import SqlLoader -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver MAX_TOKENS = 50 @@ -16,7 +15,6 @@ def loader(self): engine_url="sqlite:///:memory:", create_engine_params={"connect_args": {"check_same_thread": False}, "poolclass": StaticPool}, ), - embedding_driver=MockEmbeddingDriver(), ) sql_loader.sql_driver.execute_query( @@ -42,8 +40,6 @@ def test_load(self, loader): assert artifact.value[1] == {"id": 2, "name": "Bob", "age": 30, "city": "Los Angeles"} assert artifact.value[2] == {"id": 3, "name": "Charlie", "age": 22, "city": "Chicago"} - assert artifact.embedding == [0, 1] - def test_load_collection(self, loader): artifacts = loader.load_collection(["SELECT * FROM test_table LIMIT 1;", "SELECT * FROM test_table LIMIT 2;"]) @@ -57,5 +53,3 @@ def test_load_collection(self, loader): {"age": 25, "city": "New York", "id": 1, "name": "Alice"}, {"age": 30, "city": "Los Angeles", "id": 2, "name": "Bob"}, ] - - assert list(artifacts.values())[0].embedding == [0, 1] diff --git a/tests/unit/loaders/test_text_loader.py b/tests/unit/loaders/test_text_loader.py index 383f3e7a44..c75417f56f 100644 --- a/tests/unit/loaders/test_text_loader.py +++ b/tests/unit/loaders/test_text_loader.py @@ -12,7 +12,7 @@ def loader(self, request): else: return TextLoader(encoding=encoding) - @pytest.fixture(params=["bytes_from_resource_path", "str_from_resource_path"]) + @pytest.fixture(params=["path_from_resource_path"]) def create_source(self, request): return request.getfixturevalue(request.param) diff --git a/tests/unit/loaders/test_web_loader.py b/tests/unit/loaders/test_web_loader.py index f7cccb6664..d6e958042b 100644 --- a/tests/unit/loaders/test_web_loader.py +++ b/tests/unit/loaders/test_web_loader.py @@ -1,7 +1,6 @@ import pytest from griptape.loaders import WebLoader -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver MAX_TOKENS = 50 @@ -13,15 +12,12 @@ def _mock_trafilatura_fetch_url(self, mocker): @pytest.fixture() def loader(self): - return WebLoader(max_tokens=MAX_TOKENS, embedding_driver=MockEmbeddingDriver()) + return WebLoader() def test_load(self, loader): - artifacts = loader.load("https://github.com/griptape-ai/griptape") + artifact = loader.load("https://github.com/griptape-ai/griptape") - assert len(artifacts) == 1 - assert "foobar" in artifacts[0].value.lower() - - assert artifacts[0].embedding == [0, 1] + assert "foobar" in artifact.value.lower() def test_load_exception(self, mocker, loader): mocker.patch("trafilatura.fetch_url", side_effect=Exception("error")) @@ -38,9 +34,7 @@ def test_load_collection(self, loader): loader.to_key("https://github.com/griptape-ai/griptape"), loader.to_key("https://github.com/griptape-ai/griptape-docs"), ] - assert "foobar" in [a.value for artifact_list in artifacts.values() for a in artifact_list][0].lower() - - assert list(artifacts.values())[0][0].embedding == [0, 1] + assert "foobar" in [a.value for a in artifacts.values()] def test_empty_page_string_response(self, loader, mocker): mocker.patch("trafilatura.extract", return_value="") diff --git a/tests/unit/utils/test_file_utils.py b/tests/unit/utils/test_file_utils.py index 00df6958df..1f3de5a9ff 100644 --- a/tests/unit/utils/test_file_utils.py +++ b/tests/unit/utils/test_file_utils.py @@ -10,7 +10,7 @@ class TestFileUtils: def test_load_file(self): dirname = os.path.dirname(__file__) - file = utils.load_file(os.path.join(dirname, "../../resources/foobar-many.txt")) + file = utils.load_file(os.path.join(dirname, "../../resources/foobar-many.txt")).read() assert file.decode("utf-8").startswith("foobar foobar foobar") @@ -21,31 +21,27 @@ def test_load_files(self): files = utils.load_files(sources, futures_executor=futures.ThreadPoolExecutor(max_workers=1)) assert len(files) == 2 - test_file = files[utils.str_to_hash(sources[0])] + test_file = files[utils.str_to_hash(sources[0])].read() assert test_file.decode("utf-8").startswith("foobar foobar foobar") - small_file = files[utils.str_to_hash(sources[2])] + small_file = files[utils.str_to_hash(sources[2])].read() assert len(small_file) == 97 assert small_file[:8] == b"\x89PNG\r\n\x1a\n" def test_load_file_with_loader(self): dirname = os.path.dirname(__file__) file = utils.load_file(os.path.join(dirname, "../../", "resources/foobar-many.txt")) - artifacts = TextLoader(max_tokens=MAX_TOKENS).load(file) + artifact = TextLoader().load(file) - assert len(artifacts) == 39 - assert isinstance(artifacts, list) - assert artifacts[0].value.startswith("foobar foobar foobar") + assert artifact.value.startswith("foobar foobar foobar") def test_load_files_with_loader(self): dirname = os.path.dirname(__file__) sources = ["resources/foobar-many.txt"] sources = [os.path.join(dirname, "../../", source) for source in sources] files = utils.load_files(sources) - loader = TextLoader(max_tokens=MAX_TOKENS) + loader = TextLoader() collection = loader.load_collection(list(files.values())) test_file_artifacts = collection[loader.to_key(files[utils.str_to_hash(sources[0])])] - assert len(test_file_artifacts) == 39 - assert isinstance(test_file_artifacts, list) - assert test_file_artifacts[0].value.startswith("foobar foobar foobar") + assert test_file_artifacts.value.startswith("foobar foobar foobar")