Skip to content

Commit

Permalink
Clean up examples
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Aug 30, 2024
1 parent d93e1fa commit 370a5ac
Show file tree
Hide file tree
Showing 11 changed files with 33 additions and 98 deletions.
1 change: 0 additions & 1 deletion docs/examples/src/query_webpage_astra_db_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
)

artifacts = WebLoader(max_tokens=256).load(input_blogpost)

vector_store_driver.upsert_text_artifacts({namespace: artifacts})

rag_tool = RagTool(
Expand Down
1 change: 0 additions & 1 deletion docs/griptape-framework/engines/src/rag_engines_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
vector_store = LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver())
artifacts = WebLoader(max_tokens=500).load("https://www.griptape.ai")


vector_store.upsert_text_artifacts(
{
"griptape": artifacts,
Expand Down
10 changes: 4 additions & 6 deletions griptape/loaders/audio_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,8 @@
class AudioLoader(BaseLoader):
"""Loads audio content into audio artifacts."""

def load(self, source: bytes, *args, **kwargs) -> list[AudioArtifact]:
audio_artifact = [AudioArtifact(source, format=import_optional_dependency("filetype").guess(source).extension)]
def load(self, source: bytes, *args, **kwargs) -> AudioArtifact:
return AudioArtifact(source, format=import_optional_dependency("filetype").guess(source).extension)

return audio_artifact

def load_collection(self, sources: list[bytes], *args, **kwargs) -> dict[str, list[AudioArtifact]]:
return cast(dict[str, list[AudioArtifact]], super().load_collection(sources, *args, **kwargs))
def load_collection(self, sources: list[bytes], *args, **kwargs) -> dict[str, AudioArtifact]:
return cast(dict[str, AudioArtifact], super().load_collection(sources, *args, **kwargs))
10 changes: 4 additions & 6 deletions griptape/loaders/base_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from griptape.utils.hash import bytes_to_hash, str_to_hash

if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
from collections.abc import Mapping

from griptape.artifacts import BaseArtifact

Expand All @@ -20,14 +20,14 @@ class BaseLoader(FuturesExecutorMixin, ABC):
encoding: Optional[str] = field(default=None, kw_only=True)

@abstractmethod
def load(self, source: Any, *args, **kwargs) -> Sequence[BaseArtifact]: ...
def load(self, source: Any, *args, **kwargs) -> BaseArtifact: ...

def load_collection(
self,
sources: list[Any],
*args,
**kwargs,
) -> Mapping[str, Sequence[BaseArtifact]]:
) -> Mapping[str, BaseArtifact]:
# Create a dictionary before actually submitting the jobs to the executor
# to avoid duplicate work.
sources_by_key = {self.to_key(source): source for source in sources}
Expand All @@ -39,10 +39,8 @@ def load_collection(
},
)

def to_key(self, source: Any, *args, **kwargs) -> str:
def to_key(self, source: Any) -> str:
if isinstance(source, bytes):
return bytes_to_hash(source)
elif isinstance(source, str):
return str_to_hash(source)
else:
return str_to_hash(str(source))
42 changes: 5 additions & 37 deletions griptape/loaders/base_text_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,60 +6,28 @@
from attrs import Factory, define, field

from griptape.artifacts import TextArtifact
from griptape.chunkers import BaseChunker, TextChunker
from griptape.loaders import BaseLoader
from griptape.tokenizers import OpenAiTokenizer

if TYPE_CHECKING:
from griptape.common import Reference
from griptape.drivers import BaseEmbeddingDriver
from griptape.tokenizers import BaseTokenizer


@define
class BaseTextLoader(BaseLoader, ABC):
MAX_TOKEN_RATIO = 0.5

tokenizer: OpenAiTokenizer = field(
tokenizer: BaseTokenizer = field(
default=Factory(lambda: OpenAiTokenizer(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL)),
kw_only=True,
)
max_tokens: int = field(
default=Factory(lambda self: round(self.tokenizer.max_input_tokens * self.MAX_TOKEN_RATIO), takes_self=True),
kw_only=True,
)
chunker: BaseChunker = field(
default=Factory(
lambda self: TextChunker(tokenizer=self.tokenizer, max_tokens=self.max_tokens),
takes_self=True,
),
kw_only=True,
)
embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True)
encoding: str = field(default="utf-8", kw_only=True)
reference: Optional[Reference] = field(default=None, kw_only=True)

@abstractmethod
def load(self, source: Any, *args, **kwargs) -> list[TextArtifact]: ...
def load(self, source: Any, *args, **kwargs) -> TextArtifact: ...

def load_collection(self, sources: list[Any], *args, **kwargs) -> dict[str, list[TextArtifact]]:
def load_collection(self, sources: list[Any], *args, **kwargs) -> dict[str, TextArtifact]:
return cast(
dict[str, list[TextArtifact]],
dict[str, TextArtifact],
super().load_collection(sources, *args, **kwargs),
)

def _text_to_artifacts(self, text: str) -> list[TextArtifact]:
artifacts = []

chunks = self.chunker.chunk(text) if self.chunker else [TextArtifact(text)]

for chunk in chunks:
if self.embedding_driver:
chunk.generate_embedding(self.embedding_driver)

chunk.reference = self.reference

chunk.encoding = self.encoding

artifacts.append(chunk)

return artifacts
4 changes: 2 additions & 2 deletions griptape/loaders/blob_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
class BlobLoader(BaseLoader):
def load(self, source: Any, *args, **kwargs) -> BlobArtifact:
if self.encoding is None:
return [BlobArtifact(source)]
return BlobArtifact(source)
else:
return [BlobArtifact(source, encoding=self.encoding)]
return BlobArtifact(source, encoding=self.encoding)

def load_collection(self, sources: list[bytes | str], *args, **kwargs) -> dict[str, BlobArtifact]:
return cast(dict[str, BlobArtifact], super().load_collection(sources, *args, **kwargs))
3 changes: 1 addition & 2 deletions griptape/loaders/dataframe_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@

from griptape.artifacts import TableArtifact
from griptape.loaders import BaseLoader
from griptape.utils import import_optional_dependency
from griptape.utils.hash import str_to_hash
from griptape.utils import import_optional_dependency, str_to_hash

if TYPE_CHECKING:
from pandas import DataFrame
Expand Down
10 changes: 4 additions & 6 deletions griptape/loaders/image_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class ImageLoader(BaseLoader):
"webp": "image/webp",
}

def load(self, source: bytes, *args, **kwargs) -> list[ImageArtifact]:
def load(self, source: bytes, *args, **kwargs) -> ImageArtifact:
pil_image = import_optional_dependency("PIL.Image")
image = pil_image.open(BytesIO(source))

Expand All @@ -42,12 +42,10 @@ def load(self, source: bytes, *args, **kwargs) -> list[ImageArtifact]:
image = pil_image.open(byte_stream)
source = byte_stream.getvalue()

image_artifact = ImageArtifact(source, format=image.format.lower(), width=image.width, height=image.height)
return ImageArtifact(source, format=image.format.lower(), width=image.width, height=image.height)

return [image_artifact]

def load_collection(self, sources: list[bytes], *args, **kwargs) -> dict[str, list[ImageArtifact]]:
return cast(dict[str, list[ImageArtifact]], super().load_collection(sources, *args, **kwargs))
def load_collection(self, sources: list[bytes], *args, **kwargs) -> dict[str, ImageArtifact]:
return cast(dict[str, ImageArtifact], super().load_collection(sources, *args, **kwargs))

def _get_mime_type(self, image_format: str | None) -> str:
if image_format is None:
Expand Down
10 changes: 5 additions & 5 deletions griptape/loaders/pdf_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from attrs import Factory, define, field

from griptape.artifacts import TextArtifact
from griptape.artifacts import ListArtifact
from griptape.chunkers import PdfChunker
from griptape.loaders import BaseTextLoader
from griptape.utils import import_optional_dependency
Expand All @@ -25,14 +25,14 @@ def load(
password: Optional[str] = None,
*args,
**kwargs,
) -> list[TextArtifact]:
) -> ListArtifact:
pypdf = import_optional_dependency("pypdf")
reader = pypdf.PdfReader(BytesIO(source), strict=True, password=password)

return self._text_to_artifacts("\n".join([p.extract_text() for p in reader.pages]))
return ListArtifact([p.extract_text() for p in reader.pages])

def load_collection(self, sources: list[bytes], *args, **kwargs) -> dict[str, list[TextArtifact]]:
def load_collection(self, sources: list[bytes], *args, **kwargs) -> dict[str, ListArtifact]:
return cast(
dict[str, list[TextArtifact]],
dict[str, ListArtifact],
super().load_collection(sources, *args, **kwargs),
)
35 changes: 6 additions & 29 deletions griptape/loaders/text_loader.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,32 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Optional, cast
from typing import cast

from attrs import Factory, define, field
from attrs import define, field

from griptape.artifacts import TextArtifact
from griptape.chunkers import TextChunker
from griptape.loaders import BaseTextLoader
from griptape.tokenizers import OpenAiTokenizer

if TYPE_CHECKING:
from griptape.drivers import BaseEmbeddingDriver


@define
class TextLoader(BaseTextLoader):
MAX_TOKEN_RATIO = 0.5

tokenizer: OpenAiTokenizer = field(
default=Factory(lambda: OpenAiTokenizer(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL)),
kw_only=True,
)
max_tokens: int = field(
default=Factory(lambda self: round(self.tokenizer.max_input_tokens * self.MAX_TOKEN_RATIO), takes_self=True),
kw_only=True,
)
chunker: TextChunker = field(
default=Factory(
lambda self: TextChunker(tokenizer=self.tokenizer, max_tokens=self.max_tokens),
takes_self=True,
),
kw_only=True,
)
embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True)
encoding: str = field(default="utf-8", kw_only=True)

def load(self, source: bytes | str, *args, **kwargs) -> list[TextArtifact]:
def load(self, source: str | bytes, *args, **kwargs) -> TextArtifact:
if isinstance(source, bytes):
source = source.decode(encoding=self.encoding)
elif isinstance(source, (bytearray, memoryview)):
raise ValueError(f"Unsupported source type: {type(source)}")

return self._text_to_artifacts(source)
return TextArtifact(source)

def load_collection(
self,
sources: list[bytes | str],
*args,
**kwargs,
) -> dict[str, list[TextArtifact]]:
) -> dict[str, TextArtifact]:
return cast(
dict[str, list[TextArtifact]],
dict[str, TextArtifact],
super().load_collection(sources, *args, **kwargs),
)
5 changes: 2 additions & 3 deletions griptape/loaders/web_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,5 @@ class WebLoader(BaseTextLoader):
kw_only=True,
)

def load(self, source: str, *args, **kwargs) -> list[TextArtifact]:
single_chunk_text_artifact = self.web_scraper_driver.scrape_url(source)
return self._text_to_artifacts(single_chunk_text_artifact.value)
def load(self, source: str, *args, **kwargs) -> TextArtifact:
return self.web_scraper_driver.scrape_url(source)

0 comments on commit 370a5ac

Please sign in to comment.