Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Aug 30, 2024
1 parent b8cee74 commit 63469be
Show file tree
Hide file tree
Showing 30 changed files with 169 additions and 313 deletions.
48 changes: 9 additions & 39 deletions griptape/drivers/file_manager/base_file_manager_driver.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Optional

from attrs import Factory, define, field
from attrs import define, field

import griptape.loaders as loaders
from griptape.artifacts import BaseArtifact, ErrorArtifact, InfoArtifact, ListArtifact, TextArtifact
from griptape.artifacts import BlobArtifact, ErrorArtifact, InfoArtifact, TextArtifact


@define
Expand All @@ -17,28 +17,7 @@ class BaseFileManagerDriver(ABC):
loaders: Dictionary of file extension specific loaders to use for loading file contents into artifacts.
"""

default_loader: loaders.BaseLoader = field(default=Factory(lambda: loaders.BlobLoader()), kw_only=True)
loaders: dict[str, loaders.BaseLoader] = field(
default=Factory(
lambda: {
"pdf": loaders.PdfLoader(),
"csv": loaders.CsvLoader(),
"txt": loaders.TextLoader(),
"html": loaders.TextLoader(),
"json": loaders.TextLoader(),
"yaml": loaders.TextLoader(),
"xml": loaders.TextLoader(),
"png": loaders.ImageLoader(),
"jpg": loaders.ImageLoader(),
"jpeg": loaders.ImageLoader(),
"webp": loaders.ImageLoader(),
"gif": loaders.ImageLoader(),
"bmp": loaders.ImageLoader(),
"tiff": loaders.ImageLoader(),
},
),
kw_only=True,
)
encoding: Optional[str] = field(default=None, kw_only=True)

def list_files(self, path: str) -> TextArtifact | ErrorArtifact:
entries = self.try_list_files(path)
Expand All @@ -47,27 +26,18 @@ def list_files(self, path: str) -> TextArtifact | ErrorArtifact:
# Backend-specific directory listing; returns the entry names found under
# ``path`` (consumed by list_files, which renders them — see L59-60).
@abstractmethod
def try_list_files(self, path: str) -> list[str]: ...

def load_file(self, path: str) -> BlobArtifact:
    """Load the file at *path* into a BlobArtifact.

    The driver's optional ``encoding`` is attached to the artifact so
    downstream consumers can decode the raw bytes.

    Args:
        path: Backend-specific path of the file to read.

    Returns:
        A BlobArtifact wrapping the raw bytes from ``try_load_file``.
    """
    # Read once and branch only on the artifact construction.
    data = self.try_load_file(path)
    if self.encoding is None:
        return BlobArtifact(data)
    return BlobArtifact(data, encoding=self.encoding)

# Backend-specific read; returns the raw bytes of the file at ``path``.
@abstractmethod
def try_load_file(self, path: str) -> bytes: ...

def save_file(self, path: str, value: bytes | str) -> InfoArtifact:
extension = path.split(".")[-1]
loader = self.loaders.get(extension) or self.default_loader
encoding = None if loader is None else loader.encoding

if isinstance(value, str):
value = value.encode() if encoding is None else value.encode(encoding=encoding)
value = value.encode() if self.encoding is None else value.encode(encoding=self.encoding)
elif isinstance(value, (bytearray, memoryview)):
raise ValueError(f"Unsupported type: {type(value)}")

Expand Down
7 changes: 4 additions & 3 deletions griptape/drivers/file_manager/local_file_manager_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import os
from pathlib import Path
from typing import Optional

from attrs import Attribute, Factory, define, field

Expand All @@ -16,11 +17,11 @@ class LocalFileManagerDriver(BaseFileManagerDriver):
workdir: The absolute working directory. List, load, and save operations will be performed relative to this directory.
"""

workdir: str = field(default=Factory(lambda: os.getcwd()), kw_only=True)
workdir: Optional[str] = field(default=Factory(lambda: os.getcwd()), kw_only=True)

@workdir.validator  # pyright: ignore[reportAttributeAccessIssue]
def validate_workdir(self, _: Attribute, workdir: str) -> None:
    """Reject relative working directories; ``None`` (no workdir) is allowed.

    Raises:
        ValueError: If ``workdir`` is set but not an absolute path.
    """
    # Validate the value being assigned rather than self.workdir: mixing the
    # two is inconsistent, and the attribute may not reflect the new value
    # at validation time inside the attrs-generated __init__.
    if workdir is not None and not Path(workdir).is_absolute():
        raise ValueError("Workdir must be an absolute path")

def try_list_files(self, path: str) -> list[str]:
Expand All @@ -42,7 +43,7 @@ def try_save_file(self, path: str, value: bytes) -> None:

def _full_path(self, path: str) -> str:
path = path.lstrip("/")
full_path = os.path.join(self.workdir, path)
full_path = os.path.join(self.workdir, path) if self.workdir else path
# Need to keep the trailing slash if it was there,
# because it means the path is a directory.
ended_with_slash = path.endswith("/")
Expand Down
10 changes: 9 additions & 1 deletion griptape/drivers/web_scraper/base_web_scraper_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,13 @@


class BaseWebScraperDriver(ABC):
    """Template for web scraper drivers.

    ``scrape_url`` composes the two abstract steps — fetch the raw page,
    then extract its text — so subclasses only implement ``fetch_url`` and
    ``extract_page``.
    """

    def scrape_url(self, url: str) -> TextArtifact:
        """Fetch *url* and extract its content as a TextArtifact."""
        source = self.fetch_url(url)

        return self.extract_page(source)

    # Retrieve the raw page content (e.g. HTML) for *url*.
    @abstractmethod
    def fetch_url(self, url: str) -> str: ...

    # Turn previously fetched page content into a TextArtifact.
    @abstractmethod
    def extract_page(self, page: str) -> TextArtifact: ...
65 changes: 34 additions & 31 deletions griptape/drivers/web_scraper/markdownify_web_scraper_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,20 +38,8 @@ class MarkdownifyWebScraperDriver(BaseWebScraperDriver):
exclude_ids: list[str] = field(default=Factory(list), kw_only=True)
timeout: Optional[int] = field(default=None, kw_only=True)

def scrape_url(self, url: str) -> TextArtifact:
def fetch_url(self, url: str) -> str:
sync_playwright = import_optional_dependency("playwright.sync_api").sync_playwright
bs4 = import_optional_dependency("bs4")
markdownify = import_optional_dependency("markdownify")

include_links = self.include_links

        # Custom MarkdownConverter to optionally include linked urls. If include_links is False only
# the text of the link is returned.
class OptionalLinksMarkdownConverter(markdownify.MarkdownConverter):
def convert_a(self, el: Any, text: str, convert_as_inline: Any) -> str:
if include_links:
return super().convert_a(el, text, convert_as_inline)
return text

with sync_playwright() as p, p.chromium.launch(headless=True) as browser:
page = browser.new_page()
Expand All @@ -76,28 +64,43 @@ def skip_loading_images(route: Any) -> Any:
if not content:
raise Exception("can't access URL")

soup = bs4.BeautifulSoup(content, "html.parser")
return content

def extract_page(self, page: str) -> TextArtifact:
    """Convert fetched HTML into a markdown TextArtifact.

    Elements matching ``exclude_tags``, ``exclude_classes``, or
    ``exclude_ids`` are stripped; link targets are dropped when
    ``include_links`` is False; whitespace is normalized.

    Args:
        page: Raw HTML of the page.

    Returns:
        A TextArtifact containing the markdown rendering of the page.
    """
    bs4 = import_optional_dependency("bs4")
    markdownify = import_optional_dependency("markdownify")
    include_links = self.include_links

    # Custom MarkdownConverter to optionally include linked urls. If
    # include_links is False only the text of the link is returned.
    class OptionalLinksMarkdownConverter(markdownify.MarkdownConverter):
        def convert_a(self, el: Any, text: str, convert_as_inline: Any) -> str:
            if include_links:
                return super().convert_a(el, text, convert_as_inline)
            return text

    soup = bs4.BeautifulSoup(page, "html.parser")

    # Remove unwanted elements via a single combined CSS selector.
    exclude_selector = ",".join(
        self.exclude_tags + [f".{c}" for c in self.exclude_classes] + [f"#{i}" for i in self.exclude_ids],
    )
    if exclude_selector:
        for s in soup.select(exclude_selector):
            s.extract()

    text = OptionalLinksMarkdownConverter().convert_soup(soup)

    # Remove leading and trailing whitespace from the entire text
    text = text.strip()

    # Remove trailing whitespace from each line
    text = re.sub(r"[ \t]+$", "", text, flags=re.MULTILINE)

    # Indent using 2 spaces instead of tabs
    # NOTE(review): the replacement uses two spaces to match this comment;
    # the scraped diff showed a single space — confirm against upstream.
    text = re.sub(r"(\n?\s*?)\t", r"\1  ", text)

    # Remove triple+ newlines (keep double newlines for paragraphs)
    text = re.sub(r"\n\n+", "\n\n", text)

    return TextArtifact(text)
8 changes: 6 additions & 2 deletions griptape/drivers/web_scraper/proxy_web_scraper_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ class ProxyWebScraperDriver(BaseWebScraperDriver):
proxies: dict = field(kw_only=True, metadata={"serializable": False})
params: dict = field(default=Factory(dict), kw_only=True, metadata={"serializable": True})

def fetch_url(self, url: str) -> str:
    """Fetch *url* through the configured proxies and return the response body.

    Args:
        url: Address to retrieve; ``self.params`` is forwarded to requests.get.

    Returns:
        The response body as text.
    """
    response = requests.get(url, proxies=self.proxies, **self.params)

    return response.text

def extract_page(self, page: str) -> TextArtifact:
    """Wrap the already-fetched page body in a TextArtifact, unmodified."""
    artifact = TextArtifact(page)
    return artifact
11 changes: 10 additions & 1 deletion griptape/drivers/web_scraper/trafilatura_web_scraper_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
class TrafilaturaWebScraperDriver(BaseWebScraperDriver):
include_links: bool = field(default=True, kw_only=True)

def scrape_url(self, url: str) -> TextArtifact:
def fetch_url(self, url: str) -> str:
trafilatura = import_optional_dependency("trafilatura")
use_config = trafilatura.settings.use_config

Expand All @@ -29,6 +29,15 @@ def scrape_url(self, url: str) -> TextArtifact:

if page is None:
raise Exception("can't access URL")

return page

def extract_page(self, page: str) -> TextArtifact:
trafilatura = import_optional_dependency("trafilatura")
use_config = trafilatura.settings.use_config

config = use_config()

extracted_page = trafilatura.extract(
page,
include_links=self.include_links,
Expand Down
6 changes: 3 additions & 3 deletions griptape/loaders/audio_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
from attrs import define

from griptape.artifacts import AudioArtifact
from griptape.loaders import BaseLoader
from griptape.loaders.base_file_loader import BaseFileLoader
from griptape.utils import import_optional_dependency


@define
class AudioLoader(BaseFileLoader):
    """Loads audio content into audio artifacts."""

    def parse(self, source: bytes, *args, **kwargs) -> AudioArtifact:
        """Build an AudioArtifact from raw audio bytes.

        The container format is sniffed from the bytes themselves via the
        optional ``filetype`` dependency.
        """
        filetype = import_optional_dependency("filetype")

        # NOTE(review): filetype.guess() returns None for unrecognized data,
        # which would raise AttributeError on .extension — confirm callers
        # only pass valid audio bytes.
        return AudioArtifact(source, format=filetype.guess(source).extension)
26 changes: 13 additions & 13 deletions griptape/loaders/base_file_loader.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
from __future__ import annotations

from abc import ABC
from io import BytesIO
from os import PathLike
from pathlib import Path
from typing import Optional
from typing import TYPE_CHECKING

from attrs import define, field
from attrs import Factory, define, field

from griptape.drivers import BaseFileManagerDriver, LocalFileManagerDriver
from griptape.loaders import BaseLoader

if TYPE_CHECKING:
from os import PathLike


@define
class BaseFileLoader(BaseLoader, ABC):
    """Base loader for file-backed sources.

    Attributes:
        file_manager_driver: Driver used to read file contents; defaults to a
            LocalFileManagerDriver with no working directory.
        encoding: Text encoding for subclasses that decode the bytes.
    """

    file_manager_driver: BaseFileManagerDriver = field(
        default=Factory(lambda: LocalFileManagerDriver(workdir=None)),
        kw_only=True,
    )
    encoding: str = field(default="utf-8", kw_only=True)

    def fetch(self, source: str | PathLike, *args, **kwargs) -> bytes:
        """Read *source* via the file manager driver and return its raw bytes."""
        # load_file returns a BlobArtifact (see BaseFileManagerDriver.load_file),
        # so unwrap .value to honor the declared bytes return type; it also
        # accepts only a path, so extra args must not be forwarded.
        # NOTE(review): assumes BlobArtifact stores its bytes in .value — confirm.
        return self.file_manager_driver.load_file(str(source)).value
4 changes: 2 additions & 2 deletions griptape/loaders/base_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ def load(self, source: Any, *args, **kwargs) -> BaseArtifact:
return self.parse(data)

# Retrieve raw data for *source*; the concrete return type is loader-specific.
@abstractmethod
def fetch(self, source: Any, *args, **kwargs) -> Any: ...

# Convert fetched data into an artifact; input type matches fetch's output.
@abstractmethod
def parse(self, source: Any, *args, **kwargs) -> BaseArtifact: ...

def load_collection(
self,
Expand Down
3 changes: 2 additions & 1 deletion griptape/loaders/csv_loader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import csv
from io import StringIO

from attrs import define, field

Expand All @@ -14,6 +15,6 @@ class CsvLoader(TextLoader):
encoding: str = field(default="utf-8", kw_only=True)

def parse(self, source: bytes, *args, **kwargs) -> TableArtifact:
    """Decode CSV bytes and parse the rows into a TableArtifact.

    Args:
        source: Raw CSV file contents, decoded with ``self.encoding``.

    Returns:
        A TableArtifact of row dicts, carrying the delimiter and fieldnames.
    """
    # DictReader needs a file-like object of lines; wrapping the decoded text
    # in StringIO prevents it from iterating character-by-character.
    reader = csv.DictReader(StringIO(source.decode(self.encoding)), delimiter=self.delimiter)

    return TableArtifact(list(reader), delimiter=self.delimiter, fieldnames=reader.fieldnames)
8 changes: 4 additions & 4 deletions griptape/loaders/email_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class EmailQuery:
username: str = field(kw_only=True)
password: str = field(kw_only=True)

def fetch(self, source: EmailQuery, *args, **kwargs) -> bytes:
def fetch(self, source: EmailQuery, *args, **kwargs) -> list[bytes]:
label, key, search_criteria, max_count = astuple(source)

mail_bytes = []
Expand Down Expand Up @@ -62,9 +62,9 @@ def fetch(self, source: EmailQuery, *args, **kwargs) -> bytes:

client.close()

return bytes(mail_bytes)
return mail_bytes

def parse(self, source: bytes, *args, **kwargs) -> ListArtifact:
def parse(self, source: list[bytes], *args, **kwargs) -> ListArtifact:
mailparser = import_optional_dependency("mailparser")
artifacts = []
for byte in source:
Expand All @@ -73,7 +73,7 @@ def parse(self, source: bytes, *args, **kwargs) -> ListArtifact:
# Note: mailparser only populates the text_plain field
# if the message content type is explicitly set to 'text/plain'.
if message.text_plain:
artifacts.append(TextArtifact(message.text_plain))
artifacts.append(TextArtifact("\n".join(message.text_plain)))

return ListArtifact(artifacts)

Expand Down
Loading

0 comments on commit 63469be

Please sign in to comment.