diff --git a/CHANGELOG.md b/CHANGELOG.md index 233e5e263..16062edb5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `TavilyWebSearchDriver` to integrate Tavily's web search capabilities. - `ExaWebSearchDriver` to integrate Exa's web search capabilities. - `Workflow.outputs` to access the outputs of a Workflow. +- `BaseFileLoader` for Loaders that load from a path. +- `BaseLoader.fetch()` method for fetching data from a source. +- `BaseLoader.parse()` method for parsing fetched data. +- `BaseFileManager.encoding` to specify the encoding when loading and saving files. +- `BaseWebScraperDriver.extract_page()` method for extracting data from an already scraped web page. +- `TextLoaderRetrievalRagModule.chunker` for specifying the chunking strategy. +- `file_utils.get_mime_type` utility for getting the MIME type of a file. ### Changed - **BREAKING**: Renamed parameters on several classes to `client`: @@ -33,7 +40,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `model_client` on `GooglePromptDriver`. - `model_client` on `GoogleTokenizer`. - **BREAKING**: Renamed parameter `pipe` on `HuggingFacePipelinePromptDriver` to `pipeline`. +- **BREAKING**: Update `pypdf` dependency to `^5.0.1`. +- **BREAKING**: Update `redis` dependency to `^5.1.0`. +- **BREAKING**: Removed `BaseFileManager.default_loader` and `BaseFileManager.loaders`. +- **BREAKING**: Loaders no longer chunk data, use a Chunker to chunk the data. +- **BREAKING**: Removed `file_utils.load_file` and `file_utils.load_files`. +- **BREAKING**: Removed `loaders-dataframe` and `loaders-audio` extras as they are no longer needed. +- **BREAKING**: `TextLoader`, `PdfLoader`, `ImageLoader`, and `AudioLoader` now take a `str | PathLike` instead of `bytes`. Passing `bytes` is still supported but deprecated. +- **BREAKING**: Removed `DataframeLoader`. 
- Several places where API clients are initialized are now lazy loaded. +- `BaseVectorStoreDriver.upsert_text_artifacts` now returns a list or dictionary of upserted vector ids. +- `LocalFileManagerDriver.workdir` is now optional. +- `filetype` is now a core dependency. +- `FileManagerTool` now uses `filetype` for more accurate file type detection. +- `BaseFileLoader.load_file()` will now either return a `TextArtifact` or a `BlobArtifact` depending on whether `BaseFileManager.encoding` is set. - `Structure.output`'s type is now `BaseArtifact` and raises an exception if the output is `None`. - **BREAKING**: Update `pypdf` dependency to `^5.0.1`. - **BREAKING**: Update `redis` dependency to `^5.1.0`. @@ -59,8 +79,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - **BREAKING**: Removed `CsvRowArtifact`. Use `TextArtifact` instead. +- **BREAKING**: Removed `DataframeLoader`. - **BREAKING**: Removed `MediaArtifact`, use `ImageArtifact` or `AudioArtifact` instead. -- **BREAKING**: `CsvLoader`, `DataframeLoader`, and `SqlLoader` now return `list[TextArtifact]`. +- **BREAKING**: `CsvLoader` and `SqlLoader` now return `ListArtifact[TextArtifact]`. - **BREAKING**: Removed `ImageArtifact.media_type`. - **BREAKING**: Removed `AudioArtifact.media_type`. - **BREAKING**: Removed `BlobArtifact.dir_name`. diff --git a/MIGRATION.md b/MIGRATION.md index 28c0f0be6..d40c1bdfe 100644 --- a/MIGRATION.md +++ b/MIGRATION.md @@ -153,6 +153,219 @@ print(image_artifact.meta["prompt"], image_artifact.meta["model"]) # Generate an ``` +## 0.31.X to 0.32.X + +### Removed `DataframeLoader` + +`DataframeLoader` has been removed. Use `CsvLoader.parse` or build `TextArtifact`s from the dataframe instead. 
+ +#### Before + +```python +DataframeLoader().load(df) +``` + +#### After +```python +# Convert the dataframe to csv bytes and parse it +CsvLoader().parse(bytes(df.to_csv(lineterminator='\r\n', index=False), encoding='utf-8')) +# Or build TextArtifacts from the dataframe +[TextArtifact(row) for row in df.to_dict(orient="records")] +``` + +### `TextLoader`, `PdfLoader`, `ImageLoader`, and `AudioLoader` now take a `str | PathLike` instead of `bytes`. + +#### Before +```python +PdfLoader().load(Path("attention.pdf").read_bytes()) +PdfLoader().load_collection([Path("attention.pdf").read_bytes(), Path("CoT.pdf").read_bytes()]) +``` + +#### After +```python +PdfLoader().load("attention.pdf") +PdfLoader().load_collection([Path("attention.pdf"), "CoT.pdf"]) +``` + +### Removed `file_utils.load_file` and `file_utils.load_files` + +`griptape.utils.file_utils.load_file` and `griptape.utils.file_utils.load_files` have been removed. +You can now pass the file path directly to the Loader. + +#### Before + +```python +PdfLoader().load(load_file("attention.pdf")) +PdfLoader().load_collection(list(load_files(["attention.pdf", "CoT.pdf"]).values())) +``` +#### After +```python +PdfLoader().load("attention.pdf") +PdfLoader().load_collection(["attention.pdf", "CoT.pdf"]) +``` + +### Loaders no longer chunk data + +Loaders no longer chunk the data after loading it. If you need to chunk the data, use a [Chunker](https://docs.griptape.ai/stable/griptape-framework/data/chunkers/) after loading the data. + +#### Before + +```python +chunks = PdfLoader().load("attention.pdf") +vector_store.upsert_text_artifacts( + { + "griptape": chunks, + } +) +``` + +#### After +```python +artifact = PdfLoader().load("attention.pdf") +chunks = Chunker().chunk(artifact) +vector_store.upsert_text_artifacts( + { + "griptape": chunks, + } +) +``` + +### Removed `MediaArtifact` + +`MediaArtifact` has been removed. Use `ImageArtifact` or `AudioArtifact` instead. 
+ +#### Before + +```python +image_media = MediaArtifact( + b"image_data", + media_type="image", + format="jpeg" +) + +audio_media = MediaArtifact( + b"audio_data", + media_type="audio", + format="wav" +) +``` + +#### After +```python +image_artifact = ImageArtifact( + b"image_data", + format="jpeg" +) + +audio_artifact = AudioArtifact( + b"audio_data", + format="wav" +) +``` + +### `ImageArtifact.format` is now required + +`ImageArtifact.format` is now a required parameter. Update any code that does not provide a `format` parameter. + +#### Before + +```python +image_artifact = ImageArtifact( + b"image_data" +) +``` + +#### After +```python +image_artifact = ImageArtifact( + b"image_data", + format="jpeg" +) +``` + +### Removed `CsvRowArtifact` + +`CsvRowArtifact` has been removed. Use `TextArtifact` instead. + +#### Before + +```python +artifact = CsvRowArtifact({"name": "John", "age": 30}) +print(artifact.value) # {"name": "John", "age": 30} +print(type(artifact.value)) # <class 'dict'> +``` + +#### After +```python +artifact = TextArtifact("name: John\nage: 30") +print(artifact.value) # name: John\nage: 30 +print(type(artifact.value)) # <class 'str'> +``` + +If you require storing a dictionary as an Artifact, you can use `GenericArtifact` instead. + +### `CsvLoader`, `DataframeLoader`, and `SqlLoader` return types + +`CsvLoader`, `DataframeLoader`, and `SqlLoader` now return a `list[TextArtifact]` instead of `list[CsvRowArtifact]`. + +If you require a dictionary, set a custom `formatter_fn` and then parse the text to a dictionary. 
+ +#### Before + +```python +results = CsvLoader().load(Path("people.csv").read_text()) + +print(results[0].value) # {"name": "John", "age": 30} +print(type(results[0].value)) # +``` + +#### After +```python +results = CsvLoader().load(Path("people.csv").read_text()) + +print(type(results)) # +print(results[0].value) # name: John\nAge: 30 +print(type(results[0].value)) # + +# Customize formatter_fn +results = CsvLoader(formatter_fn=lambda x: json.dumps(x)).load(Path("people.csv").read_text()) +print(results[0].value) # {"name": "John", "age": 30} +print(type(results[0].value)) # + +dict_results = [json.loads(result.value) for result in results] +print(dict_results[0]) # {"name": "John", "age": 30} +print(type(dict_results[0])) # +``` + +### Moved `ImageArtifact.prompt` and `ImageArtifact.model` to `ImageArtifact.meta` + +`ImageArtifact.prompt` and `ImageArtifact.model` have been moved to `ImageArtifact.meta`. + +#### Before + +```python +image_artifact = ImageArtifact( + b"image_data", + format="jpeg", + prompt="Generate an image of a cat", + model="DALL-E" +) + +print(image_artifact.prompt, image_artifact.model) # Generate an image of a cat, DALL-E +``` + +#### After +```python +image_artifact = ImageArtifact( + b"image_data", + format="jpeg", + meta={"prompt": "Generate an image of a cat", "model": "DALL-E"} +) + +print(image_artifact.meta["prompt"], image_artifact.meta["model"]) # Generate an image of a cat, DALL-E +``` + + ## 0.30.X to 0.31.X ### Exceptions Over `ErrorArtifact`s diff --git a/docs/examples/src/load_query_and_chat_marqo_1.py b/docs/examples/src/load_query_and_chat_marqo_1.py index cdcb376bb..f318abb20 100644 --- a/docs/examples/src/load_query_and_chat_marqo_1.py +++ b/docs/examples/src/load_query_and_chat_marqo_1.py @@ -1,6 +1,7 @@ import os from griptape import utils +from griptape.chunkers import TextChunker from griptape.drivers import MarqoVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader from griptape.structures 
import Agent @@ -25,11 +26,12 @@ # Load artifacts from the web artifacts = WebLoader().load("https://www.griptape.ai") +chunks = TextChunker().chunk(artifacts) # Upsert the artifacts into the vector store vector_store.upsert_text_artifacts( { - namespace: artifacts, + namespace: chunks, } ) diff --git a/docs/examples/src/query_webpage_1.py b/docs/examples/src/query_webpage_1.py index b9e3286d6..b839b1302 100644 --- a/docs/examples/src/query_webpage_1.py +++ b/docs/examples/src/query_webpage_1.py @@ -1,14 +1,15 @@ import os +from griptape.chunkers import TextChunker from griptape.drivers import LocalVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader vector_store = LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver(api_key=os.environ["OPENAI_API_KEY"])) -artifacts = WebLoader(max_tokens=100).load("https://www.griptape.ai") +artifacts = WebLoader().load("https://www.griptape.ai") +chunks = TextChunker().chunk(artifacts) -for a in artifacts: - vector_store.upsert_text_artifact(a, namespace="griptape") +vector_store.upsert_text_artifacts({"griptape": chunks}) results = vector_store.query("creativity", count=3, namespace="griptape") diff --git a/docs/examples/src/query_webpage_astra_db_1.py b/docs/examples/src/query_webpage_astra_db_1.py index 4590a6b59..aaf5e2fcc 100644 --- a/docs/examples/src/query_webpage_astra_db_1.py +++ b/docs/examples/src/query_webpage_astra_db_1.py @@ -1,5 +1,6 @@ import os +from griptape.chunkers import TextChunker from griptape.drivers import ( AstraDbVectorStoreDriver, OpenAiChatPromptDriver, @@ -43,9 +44,9 @@ ), ) -artifacts = WebLoader(max_tokens=256).load(input_blogpost) - -vector_store_driver.upsert_text_artifacts({namespace: artifacts}) +artifacts = WebLoader().load(input_blogpost) +chunks = TextChunker().chunk(artifacts) +vector_store_driver.upsert_text_artifacts({namespace: chunks}) rag_tool = RagTool( description="A DataStax blog post", diff --git a/docs/examples/src/talk_to_a_pdf_1.py 
b/docs/examples/src/talk_to_a_pdf_1.py index 3c29f4c74..be0848b9e 100644 --- a/docs/examples/src/talk_to_a_pdf_1.py +++ b/docs/examples/src/talk_to_a_pdf_1.py @@ -1,5 +1,6 @@ import requests +from griptape.chunkers import TextChunker from griptape.drivers import LocalVectorStoreDriver, OpenAiChatPromptDriver, OpenAiEmbeddingDriver from griptape.engines.rag import RagEngine from griptape.engines.rag.modules import PromptResponseRagModule, VectorStoreRetrievalRagModule @@ -30,9 +31,10 @@ rag_engine=engine, ) -artifacts = PdfLoader().load(response.content) +artifacts = PdfLoader().parse(response.content) +chunks = TextChunker().chunk(artifacts) -vector_store.upsert_text_artifacts({namespace: artifacts}) +vector_store.upsert_text_artifacts({namespace: chunks}) agent = Agent(tools=[rag_tool]) diff --git a/docs/examples/src/talk_to_a_webpage_1.py b/docs/examples/src/talk_to_a_webpage_1.py index 3e973da2d..5414c8769 100644 --- a/docs/examples/src/talk_to_a_webpage_1.py +++ b/docs/examples/src/talk_to_a_webpage_1.py @@ -1,3 +1,4 @@ +from griptape.chunkers import TextChunker from griptape.drivers import LocalVectorStoreDriver, OpenAiChatPromptDriver, OpenAiEmbeddingDriver from griptape.engines.rag import RagEngine from griptape.engines.rag.modules import PromptResponseRagModule, VectorStoreRetrievalRagModule @@ -26,8 +27,9 @@ ) artifacts = WebLoader().load("https://en.wikipedia.org/wiki/Physics") +chunks = TextChunker().chunk(artifacts) -vector_store_driver.upsert_text_artifacts({namespace: artifacts}) +vector_store_driver.upsert_text_artifacts({namespace: chunks}) rag_tool = RagTool( description="Contains information about physics. 
" "Use it to answer any physics-related questions.", diff --git a/docs/griptape-framework/data/chunkers.md b/docs/griptape-framework/data/chunkers.md index 507645923..bafbc1c80 100644 --- a/docs/griptape-framework/data/chunkers.md +++ b/docs/griptape-framework/data/chunkers.md @@ -18,3 +18,7 @@ Here is how to use a chunker: ```python --8<-- "docs/griptape-framework/data/src/chunkers_1.py" ``` + +The most common use of a Chunker is to split up a long text into smaller chunks for inserting into a Vector Database when doing Retrieval Augmented Generation (RAG). + +See [RagEngine](../../griptape-framework/engines/rag-engines.md) for more information on how to use Chunkers in RAG pipelines. diff --git a/docs/griptape-framework/data/loaders.md b/docs/griptape-framework/data/loaders.md index 0c0fc3ead..a8a8cb7c5 100644 --- a/docs/griptape-framework/data/loaders.md +++ b/docs/griptape-framework/data/loaders.md @@ -5,109 +5,96 @@ search: ## Overview -Loaders are used to load textual data from different sources and chunk it into [TextArtifact](../../reference/griptape/artifacts/text_artifact.md)s. -Each loader can be used to load a single "document" with [load()](../../reference/griptape/loaders/base_loader.md#griptape.loaders.base_loader.BaseLoader.load) or -multiple documents with [load_collection()](../../reference/griptape/loaders/base_loader.md#griptape.loaders.base_loader.BaseLoader.load_collection). +Loaders are used to load data from sources and parse it into [Artifact](../../griptape-framework/data/artifacts.md)s. +Each loader can be used to load a single "source" with [load()](../../reference/griptape/loaders/base_loader.md#griptape.loaders.base_loader.BaseLoader.load) or +multiple sources with [load_collection()](../../reference/griptape/loaders/base_loader.md#griptape.loaders.base_loader.BaseLoader.load_collection). -## PDF -!!! info - This driver requires the `loaders-pdf` [extra](../index.md#extras). 
+## File + +The following Loaders load a file using a [FileManagerDriver](../../reference/griptape/drivers/file_manager/base_file_manager_driver.md) and loads the resulting data into an [Artifact](../../griptape-framework/data/artifacts.md) for the respective file type. + +### Text -Inherits from the [TextLoader](../../reference/griptape/loaders/text_loader.md) and can be used to load PDFs from a path or from an IO stream: +Loads text files into [TextArtifact](../../griptape-framework/data/artifacts.md#text)s: ```python ---8<-- "docs/griptape-framework/data/src/loaders_1.py" +--8<-- "docs/griptape-framework/data/src/loaders_5.py" ``` -## SQL +### PDF -Can be used to load data from a SQL database into [TextArtifact](../../reference/griptape/artifacts/text_artifact.md)s: +!!! info + This driver requires the `loaders-pdf` [extra](../index.md#extras). + +Loads PDF files into [ListArtifact](../../griptape-framework/data/artifacts.md#list)s, where each element is a [TextArtifact](../../griptape-framework/data/artifacts.md#text) containing a page of the PDF: ```python ---8<-- "docs/griptape-framework/data/src/loaders_2.py" +--8<-- "docs/griptape-framework/data/src/loaders_1.py" ``` -## CSV +### CSV -Can be used to load CSV files into [TextArtifact](../../reference/griptape/artifacts/text_artifact.md)s: +Loads CSV files into [ListArtifact](../../griptape-framework/data/artifacts.md#list)s, where each element is a [TextArtifact](../../griptape-framework/data/artifacts.md#text) containing a row of the CSV: ```python --8<-- "docs/griptape-framework/data/src/loaders_3.py" ``` - -## DataFrame +### Image !!! info - This driver requires the `loaders-dataframe` [extra](../index.md#extras). + This driver requires the `loaders-image` [extra](../index.md#extras). 
+ +Loads images into [ImageArtifact](../../griptape-framework/data/artifacts.md#image)s: -Can be used to load [pandas](https://pandas.pydata.org/) [DataFrame](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.html)s into [TextArtifact](../../reference/griptape/artifacts/text_artifact.md)s: ```python ---8<-- "docs/griptape-framework/data/src/loaders_4.py" +--8<-- "docs/griptape-framework/data/src/loaders_7.py" ``` - -## Text - -Used to load arbitrary text and text files: +By default, the Image Loader will load images in their native format, but not all models work on all formats. To normalize the format of Artifacts returned by the Loader, set the `format` field. ```python ---8<-- "docs/griptape-framework/data/src/loaders_5.py" +--8<-- "docs/griptape-framework/data/src/loaders_8.py" ``` -You can set a custom [tokenizer](../../reference/griptape/loaders/text_loader.md#griptape.loaders.text_loader.TextLoader.tokenizer), [max_tokens](../../reference/griptape/loaders/text_loader.md#griptape.loaders.text_loader.TextLoader.max_tokens) parameter, and [chunker](../../reference/griptape/loaders/text_loader.md#griptape.loaders.text_loader.TextLoader.chunker). - -## Web +### Audio -!!! info - This driver requires the `loaders-web` [extra](../index.md#extras). +Loads audio files into [AudioArtifact](../../griptape-framework/data/artifacts.md#audio)s: -Inherits from the [TextLoader](../../reference/griptape/loaders/text_loader.md) and can be used to load web pages: +The Loader will load audio in its native format and populates the resulting Artifact's `format` field by making a best-effort guess of the underlying audio format using the `filetype` package. ```python ---8<-- "docs/griptape-framework/data/src/loaders_6.py" +--8<-- "docs/griptape-framework/data/src/loaders_10.py" ``` -## Image +## Web !!! info - This driver requires the `loaders-image` [extra](../index.md#extras). + This driver requires the `loaders-web` [extra](../index.md#extras). 
-The Image Loader is used to load an image as an [ImageArtifact](./artifacts.md#image). The Loader operates on image bytes that can be sourced from files on disk, downloaded images, or images in memory. +Scrapes web pages using a [WebScraperDriver](../drivers/web-scraper-drivers.md) and loads the resulting text into [TextArtifact](../../griptape-framework/data/artifacts.md#text)s. ```python ---8<-- "docs/griptape-framework/data/src/loaders_7.py" +--8<-- "docs/griptape-framework/data/src/loaders_6.py" ``` -By default, the Image Loader will load images in their native format, but not all models work on all formats. To normalize the format of Artifacts returned by the Loader, set the `format` field. +## SQL + +Loads data from a SQL database using a [SQLDriver](../drivers/sql-drivers.md) and loads the resulting data into [ListArtifact](../../griptape-framework/data/artifacts.md#list)s, where each element is a [TextArtifact](../../griptape-framework/data/artifacts.md#text) containing a row of the SQL query. ```python ---8<-- "docs/griptape-framework/data/src/loaders_8.py" +--8<-- "docs/griptape-framework/data/src/loaders_2.py" ``` - ## Email !!! info This driver requires the `loaders-email` [extra](../index.md#extras). -Can be used to load email from an imap server: +Loads data from an IMAP email server into [ListArtifact](../../reference/griptape/artifacts/list_artifact.md)s, where each element is a [TextArtifact](../../reference/griptape/artifacts/text_artifact.md) containing an email. ```python --8<-- "docs/griptape-framework/data/src/loaders_9.py" ``` - -## Audio - -!!! info - This driver requires the `loaders-audio` [extra](../index.md#extras). - -The [Audio Loader](../../reference/griptape/loaders/audio_loader.md) is used to load audio content as an [AudioArtifact](./artifacts.md#audio). The Loader operates on audio bytes that can be sourced from files on disk, downloaded audio, or audio in memory. 
- -The Loader will load audio in its native format and populates the resulting Artifact's `format` field by making a best-effort guess of the underlying audio format using the `filetype` package. - -```python ---8<-- "docs/griptape-framework/data/src/loaders_10.py" -``` diff --git a/docs/griptape-framework/data/src/loaders_1.py b/docs/griptape-framework/data/src/loaders_1.py index 2b7f31613..3732d8ac6 100644 --- a/docs/griptape-framework/data/src/loaders_1.py +++ b/docs/griptape-framework/data/src/loaders_1.py @@ -2,18 +2,15 @@ from pathlib import Path from griptape.loaders import PdfLoader -from griptape.utils import load_file, load_files urllib.request.urlretrieve("https://arxiv.org/pdf/1706.03762.pdf", "attention.pdf") # Load a single PDF file -PdfLoader().load(Path("attention.pdf").read_bytes()) -# You can also use the load_file utility function -PdfLoader().load(load_file("attention.pdf")) +PdfLoader().load("attention.pdf") +# You can also pass a Path object +PdfLoader().load(Path("attention.pdf")) urllib.request.urlretrieve("https://arxiv.org/pdf/1706.03762.pdf", "CoT.pdf") # Load multiple PDF files -PdfLoader().load_collection([Path("attention.pdf").read_bytes(), Path("CoT.pdf").read_bytes()]) -# You can also use the load_files utility function -PdfLoader().load_collection(list(load_files(["attention.pdf", "CoT.pdf"]).values())) +PdfLoader().load_collection([Path("attention.pdf"), Path("CoT.pdf")]) diff --git a/docs/griptape-framework/data/src/loaders_10.py b/docs/griptape-framework/data/src/loaders_10.py index 42a64c40e..723fbcdf1 100644 --- a/docs/griptape-framework/data/src/loaders_10.py +++ b/docs/griptape-framework/data/src/loaders_10.py @@ -1,10 +1,9 @@ from pathlib import Path from griptape.loaders import AudioLoader -from griptape.utils import load_file # Load an image from disk -audio_artifact = AudioLoader().load(Path("tests/resources/sentences.wav").read_bytes()) +AudioLoader().load("tests/resources/sentences.wav") -# You can also use the 
load_file utility function -AudioLoader().load(load_file("tests/resources/sentences.wav")) +# You can also pass a Path object +AudioLoader().load(Path("tests/resources/sentences.wav")) diff --git a/docs/griptape-framework/data/src/loaders_3.py b/docs/griptape-framework/data/src/loaders_3.py index 35af0fdfc..3bc3ceb81 100644 --- a/docs/griptape-framework/data/src/loaders_3.py +++ b/docs/griptape-framework/data/src/loaders_3.py @@ -1,16 +1,11 @@ from pathlib import Path from griptape.loaders import CsvLoader -from griptape.utils import load_file, load_files # Load a single CSV file -CsvLoader().load(Path("tests/resources/cities.csv").read_text()) -# You can also use the load_file utility function -CsvLoader().load(load_file("tests/resources/cities.csv")) +CsvLoader().load("tests/resources/cities.csv") +# You can also pass a Path object +CsvLoader().load(Path("tests/resources/cities.csv")) # Load multiple CSV files -CsvLoader().load_collection( - [Path("tests/resources/cities.csv").read_text(), Path("tests/resources/addresses.csv").read_text()] -) -# You can also use the load_files utility function -CsvLoader().load_collection(list(load_files(["tests/resources/cities.csv", "tests/resources/addresses.csv"]).values())) +CsvLoader().load_collection([Path("tests/resources/cities.csv"), "tests/resources/addresses.csv"]) diff --git a/docs/griptape-framework/data/src/loaders_4.py b/docs/griptape-framework/data/src/loaders_4.py deleted file mode 100644 index 8d5883adf..000000000 --- a/docs/griptape-framework/data/src/loaders_4.py +++ /dev/null @@ -1,13 +0,0 @@ -import urllib.request - -import pandas as pd - -from griptape.loaders import DataFrameLoader - -urllib.request.urlretrieve("https://people.sc.fsu.edu/~jburkardt/data/csv/cities.csv", "cities.csv") - -DataFrameLoader().load(pd.read_csv("cities.csv")) - -urllib.request.urlretrieve("https://people.sc.fsu.edu/~jburkardt/data/csv/addresses.csv", "addresses.csv") - 
-DataFrameLoader().load_collection([pd.read_csv("cities.csv"), pd.read_csv("addresses.csv")]) diff --git a/docs/griptape-framework/data/src/loaders_7.py b/docs/griptape-framework/data/src/loaders_7.py index 6857886e8..177471fd2 100644 --- a/docs/griptape-framework/data/src/loaders_7.py +++ b/docs/griptape-framework/data/src/loaders_7.py @@ -1,9 +1,9 @@ from pathlib import Path from griptape.loaders import ImageLoader -from griptape.utils import load_file # Load an image from disk -disk_image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) -# You can also use the load_file utility function -ImageLoader().load(load_file("tests/resources/mountain.png")) +ImageLoader().load("tests/resources/mountain.png") + +# You can also pass a Path object +ImageLoader().load(Path("tests/resources/mountain.png")) diff --git a/docs/griptape-framework/data/src/loaders_8.py b/docs/griptape-framework/data/src/loaders_8.py index e85992d45..d54c31246 100644 --- a/docs/griptape-framework/data/src/loaders_8.py +++ b/docs/griptape-framework/data/src/loaders_8.py @@ -1,16 +1,11 @@ from pathlib import Path from griptape.loaders import ImageLoader -from griptape.utils import load_file, load_files # Load a single image in BMP format -image_artifact_jpeg = ImageLoader(format="bmp").load(Path("tests/resources/mountain.png").read_bytes()) -# You can also use the load_file utility function -ImageLoader(format="bmp").load(load_file("tests/resources/mountain.png")) +ImageLoader(format="bmp").load("tests/resources/mountain.png") +# You can also pass a Path object +ImageLoader(format="bmp").load(Path("tests/resources/mountain.png")) # Load multiple images in BMP format -ImageLoader().load_collection( - [Path("tests/resources/mountain.png").read_bytes(), Path("tests/resources/cow.png").read_bytes()] -) -# You can also use the load_files utility function -ImageLoader().load_collection(list(load_files(["tests/resources/mountain.png", "tests/resources/cow.png"]).values())) 
+ImageLoader().load_collection([Path("tests/resources/mountain.png"), "tests/resources/cow.png"]) diff --git a/docs/griptape-framework/drivers/src/image_generation_drivers_8.py b/docs/griptape-framework/drivers/src/image_generation_drivers_8.py index 69437a3a5..470b47707 100644 --- a/docs/griptape-framework/drivers/src/image_generation_drivers_8.py +++ b/docs/griptape-framework/drivers/src/image_generation_drivers_8.py @@ -1,5 +1,3 @@ -from pathlib import Path - from griptape.artifacts import TextArtifact from griptape.drivers import ( HuggingFacePipelineImageGenerationDriver, @@ -11,7 +9,7 @@ from griptape.tasks import VariationImageGenerationTask prompt_artifact = TextArtifact("landscape photograph, verdant, countryside, 8k") -input_image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +input_image_artifact = ImageLoader().load("tests/resources/mountain.png") image_variation_task = VariationImageGenerationTask( input=(prompt_artifact, input_image_artifact), diff --git a/docs/griptape-framework/drivers/src/image_generation_drivers_9.py b/docs/griptape-framework/drivers/src/image_generation_drivers_9.py index 2054588d9..ab3dc3113 100644 --- a/docs/griptape-framework/drivers/src/image_generation_drivers_9.py +++ b/docs/griptape-framework/drivers/src/image_generation_drivers_9.py @@ -1,5 +1,3 @@ -from pathlib import Path - from griptape.artifacts import TextArtifact from griptape.drivers import ( HuggingFacePipelineImageGenerationDriver, @@ -11,7 +9,7 @@ from griptape.tasks import VariationImageGenerationTask prompt_artifact = TextArtifact("landscape photograph, verdant, countryside, 8k") -control_image_artifact = ImageLoader().load(Path("canny_control_image.png").read_bytes()) +control_image_artifact = ImageLoader().load("canny_control_image.png") controlnet_task = VariationImageGenerationTask( input=(prompt_artifact, control_image_artifact), diff --git a/docs/griptape-framework/drivers/src/image_query_drivers_1.py 
b/docs/griptape-framework/drivers/src/image_query_drivers_1.py index 0c9db5be7..0e0165d97 100644 --- a/docs/griptape-framework/drivers/src/image_query_drivers_1.py +++ b/docs/griptape-framework/drivers/src/image_query_drivers_1.py @@ -1,5 +1,3 @@ -from pathlib import Path - from griptape.drivers import AnthropicImageQueryDriver from griptape.engines import ImageQueryEngine from griptape.loaders import ImageLoader @@ -13,6 +11,6 @@ image_query_driver=driver, ) -image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +image_artifact = ImageLoader().load("tests/resources/mountain.png") engine.run("Describe the weather in the image", [image_artifact]) diff --git a/docs/griptape-framework/drivers/src/image_query_drivers_2.py b/docs/griptape-framework/drivers/src/image_query_drivers_2.py index 8d605c0d9..4b5b3cc9f 100644 --- a/docs/griptape-framework/drivers/src/image_query_drivers_2.py +++ b/docs/griptape-framework/drivers/src/image_query_drivers_2.py @@ -1,5 +1,3 @@ -from pathlib import Path - from griptape.drivers import AnthropicImageQueryDriver from griptape.engines import ImageQueryEngine from griptape.loaders import ImageLoader @@ -13,9 +11,9 @@ image_query_driver=driver, ) -image_artifact1 = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +image_artifact1 = ImageLoader().load("tests/resources/mountain.png") -image_artifact2 = ImageLoader().load(Path("tests/resources/cow.png").read_bytes()) +image_artifact2 = ImageLoader().load("tests/resources/cow.png") result = engine.run("Describe the weather in the image", [image_artifact1, image_artifact2]) diff --git a/docs/griptape-framework/drivers/src/image_query_drivers_3.py b/docs/griptape-framework/drivers/src/image_query_drivers_3.py index 14070312b..0653d3f6e 100644 --- a/docs/griptape-framework/drivers/src/image_query_drivers_3.py +++ b/docs/griptape-framework/drivers/src/image_query_drivers_3.py @@ -1,5 +1,3 @@ -from pathlib import Path - from griptape.drivers 
import OpenAiImageQueryDriver from griptape.engines import ImageQueryEngine from griptape.loaders import ImageLoader @@ -13,6 +11,6 @@ image_query_driver=driver, ) -image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +image_artifact = ImageLoader().load("tests/resources/mountain.png") engine.run("Describe the weather in the image", [image_artifact]) diff --git a/docs/griptape-framework/drivers/src/image_query_drivers_4.py b/docs/griptape-framework/drivers/src/image_query_drivers_4.py index 9ebf5ef59..cff4c2a10 100644 --- a/docs/griptape-framework/drivers/src/image_query_drivers_4.py +++ b/docs/griptape-framework/drivers/src/image_query_drivers_4.py @@ -1,5 +1,4 @@ import os -from pathlib import Path from griptape.drivers import AzureOpenAiImageQueryDriver from griptape.engines import ImageQueryEngine @@ -17,6 +16,6 @@ image_query_driver=driver, ) -image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +image_artifact = ImageLoader().load("tests/resources/mountain.png") engine.run("Describe the weather in the image", [image_artifact]) diff --git a/docs/griptape-framework/drivers/src/image_query_drivers_5.py b/docs/griptape-framework/drivers/src/image_query_drivers_5.py index 2bab9a7fd..c364a24cc 100644 --- a/docs/griptape-framework/drivers/src/image_query_drivers_5.py +++ b/docs/griptape-framework/drivers/src/image_query_drivers_5.py @@ -1,5 +1,3 @@ -from pathlib import Path - import boto3 from griptape.drivers import AmazonBedrockImageQueryDriver, BedrockClaudeImageQueryModelDriver @@ -16,7 +14,7 @@ engine = ImageQueryEngine(image_query_driver=driver) -image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +image_artifact = ImageLoader().load("tests/resources/mountain.png") result = engine.run("Describe the weather in the image", [image_artifact]) diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_1.py 
b/docs/griptape-framework/drivers/src/vector_store_drivers_1.py index 7f7e98e13..f531aa618 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_1.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_1.py @@ -1,5 +1,6 @@ import os +from griptape.chunkers import TextChunker from griptape.drivers import LocalVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader @@ -9,10 +10,11 @@ vector_store_driver = LocalVectorStoreDriver(embedding_driver=embedding_driver) # Load Artifacts from the web -artifacts = WebLoader(max_tokens=100).load("https://www.griptape.ai") +artifact = WebLoader().load("https://www.griptape.ai") +chunks = TextChunker(max_tokens=100).chunk(artifact) # Upsert Artifacts into the Vector Store Driver -[vector_store_driver.upsert_text_artifact(a, namespace="griptape") for a in artifacts] +vector_store_driver.upsert_text_artifacts({"griptape": chunks}) results = vector_store_driver.query(query="What is griptape?") diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_10.py b/docs/griptape-framework/drivers/src/vector_store_drivers_10.py index 39a21121d..4599cfa47 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_10.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_10.py @@ -1,5 +1,6 @@ import os +from griptape.chunkers import TextChunker from griptape.drivers import OpenAiEmbeddingDriver, QdrantVectorStoreDriver from griptape.loaders import WebLoader @@ -19,7 +20,8 @@ ) # Load Artifacts from the web -artifacts = WebLoader().load("https://www.griptape.ai") +artifact = WebLoader().load("https://www.griptape.ai") +chunks = TextChunker(max_tokens=100).chunk(artifact) # Recreate Qdrant collection vector_store_driver.client.recreate_collection( @@ -28,7 +30,7 @@ ) # Upsert Artifacts into the Vector Store Driver -[vector_store_driver.upsert_text_artifact(a, namespace="griptape") for a in artifacts] +vector_store_driver.upsert_text_artifacts({"griptape": 
chunks}) results = vector_store_driver.query(query="What is griptape?") diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_11.py b/docs/griptape-framework/drivers/src/vector_store_drivers_11.py index a8d9ceed1..144f14c59 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_11.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_11.py @@ -1,5 +1,6 @@ import os +from griptape.chunkers import TextChunker from griptape.drivers import AstraDbVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader @@ -20,10 +21,11 @@ ) # Load Artifacts from the web -artifacts = WebLoader().load("https://www.griptape.ai") +artifact = WebLoader().load("https://www.griptape.ai") +chunks = TextChunker().chunk(artifact) # Upsert Artifacts into the Vector Store Driver -[vector_store_driver.upsert_text_artifact(a, namespace="griptape") for a in artifacts] +vector_store_driver.upsert_text_artifacts({"griptape": chunks}) results = vector_store_driver.query(query="What is griptape?") diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_3.py b/docs/griptape-framework/drivers/src/vector_store_drivers_3.py index 559eaec5a..d84164cb4 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_3.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_3.py @@ -1,5 +1,6 @@ import os +from griptape.chunkers import TextChunker from griptape.drivers import OpenAiEmbeddingDriver, PineconeVectorStoreDriver from griptape.loaders import WebLoader @@ -14,10 +15,11 @@ ) # Load Artifacts from the web -artifacts = WebLoader(max_tokens=100).load("https://www.griptape.ai") +artifact = WebLoader().load("https://www.griptape.ai") +chunks = TextChunker(max_tokens=100).chunk(artifact) # Upsert Artifacts into the Vector Store Driver -[vector_store_driver.upsert_text_artifact(a, namespace="griptape") for a in artifacts] +vector_store_driver.upsert_text_artifacts({"griptape": chunks}) results = 
vector_store_driver.query(query="What is griptape?") diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_4.py b/docs/griptape-framework/drivers/src/vector_store_drivers_4.py index f2f0091a0..ad15f0744 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_4.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_4.py @@ -1,5 +1,6 @@ import os +from griptape.chunkers import TextChunker from griptape.drivers import MarqoVectorStoreDriver, OpenAiChatPromptDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader @@ -19,12 +20,13 @@ ) # Load Artifacts from the web -artifacts = WebLoader(max_tokens=200).load("https://www.griptape.ai") +artifact = WebLoader().load("https://www.griptape.ai") +chunks = TextChunker(max_tokens=200).chunk(artifact) # Upsert Artifacts into the Vector Store Driver vector_store_driver.upsert_text_artifacts( { - "griptape": artifacts, + "griptape": chunks, } ) diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_5.py b/docs/griptape-framework/drivers/src/vector_store_drivers_5.py index 7649579c7..33d541f2f 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_5.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_5.py @@ -1,5 +1,6 @@ import os +from griptape.chunkers import TextChunker from griptape.drivers import MongoDbAtlasVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader @@ -25,14 +26,11 @@ ) # Load Artifacts from the web -artifacts = WebLoader(max_tokens=200).load("https://www.griptape.ai") +artifact = WebLoader().load("https://www.griptape.ai") +chunks = TextChunker(max_tokens=200).chunk(artifact) # Upsert Artifacts into the Vector Store Driver -vector_store_driver.upsert_text_artifacts( - { - "griptape": artifacts, - } -) +vector_store_driver.upsert_text_artifacts({"griptape": chunks}) results = vector_store_driver.query(query="What is griptape?") diff --git 
a/docs/griptape-framework/drivers/src/vector_store_drivers_6.py b/docs/griptape-framework/drivers/src/vector_store_drivers_6.py index 78a7cc3e6..6a8a8bb04 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_6.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_6.py @@ -1,5 +1,6 @@ import os +from griptape.chunkers import TextChunker from griptape.drivers import AzureMongoDbVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader @@ -25,12 +26,13 @@ ) # Load Artifacts from the web -artifacts = WebLoader(max_tokens=200).load("https://www.griptape.ai") +artifact = WebLoader().load("https://www.griptape.ai") +chunks = TextChunker(max_tokens=200).chunk(artifact) # Upsert Artifacts into the Vector Store Driver vector_store_driver.upsert_text_artifacts( { - "griptape": artifacts, + "griptape": chunks, } ) diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_7.py b/docs/griptape-framework/drivers/src/vector_store_drivers_7.py index d34ff8649..7d14504bb 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_7.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_7.py @@ -1,5 +1,6 @@ import os +from griptape.chunkers import TextChunker from griptape.drivers import OpenAiEmbeddingDriver, RedisVectorStoreDriver from griptape.loaders import WebLoader @@ -15,12 +16,13 @@ ) # Load Artifacts from the web -artifacts = WebLoader(max_tokens=200).load("https://www.griptape.ai") +artifact = WebLoader().load("https://www.griptape.ai") +chunks = TextChunker(max_tokens=200).chunk(artifact) # Upsert Artifacts into the Vector Store Driver vector_store_driver.upsert_text_artifacts( { - "griptape": artifacts, + "griptape": chunks, } ) diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_8.py b/docs/griptape-framework/drivers/src/vector_store_drivers_8.py index 18e50a397..0ecb49723 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_8.py +++ 
b/docs/griptape-framework/drivers/src/vector_store_drivers_8.py @@ -2,6 +2,7 @@ import boto3 +from griptape.chunkers import TextChunker from griptape.drivers import AmazonOpenSearchVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader @@ -16,12 +17,13 @@ ) # Load Artifacts from the web -artifacts = WebLoader(max_tokens=200).load("https://www.griptape.ai") +artifact = WebLoader().load("https://www.griptape.ai") +chunks = TextChunker(max_tokens=200).chunk(artifact) # Upsert Artifacts into the Vector Store Driver vector_store_driver.upsert_text_artifacts( { - "griptape": artifacts, + "griptape": chunks, } ) diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_9.py b/docs/griptape-framework/drivers/src/vector_store_drivers_9.py index ad5abf932..9feb47761 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_9.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_9.py @@ -1,5 +1,6 @@ import os +from griptape.chunkers import TextChunker from griptape.drivers import OpenAiEmbeddingDriver, PgVectorVectorStoreDriver from griptape.loaders import WebLoader @@ -22,12 +23,13 @@ vector_store_driver.setup() # Load Artifacts from the web -artifacts = WebLoader().load("https://www.griptape.ai") +artifact = WebLoader().load("https://www.griptape.ai") +chunks = TextChunker().chunk(artifact) # Upsert Artifacts into the Vector Store Driver vector_store_driver.upsert_text_artifacts( { - "griptape": artifacts, + "griptape": chunks, } ) diff --git a/docs/griptape-framework/engines/src/audio_engines_2.py b/docs/griptape-framework/engines/src/audio_engines_2.py index c04b466f8..92c87d638 100644 --- a/docs/griptape-framework/engines/src/audio_engines_2.py +++ b/docs/griptape-framework/engines/src/audio_engines_2.py @@ -1,7 +1,6 @@ from griptape.drivers import OpenAiAudioTranscriptionDriver from griptape.engines import AudioTranscriptionEngine from griptape.loaders import AudioLoader -from griptape.utils import load_file 
driver = OpenAiAudioTranscriptionDriver(model="whisper-1") @@ -9,5 +8,5 @@ audio_transcription_driver=driver, ) -audio_artifact = AudioLoader().load(load_file("tests/resources/sentences.wav")) +audio_artifact = AudioLoader().load("tests/resources/sentences.wav") engine.run(audio_artifact) diff --git a/docs/griptape-framework/engines/src/image_generation_engines_3.py b/docs/griptape-framework/engines/src/image_generation_engines_3.py index 83822b1bc..4bcd976d4 100644 --- a/docs/griptape-framework/engines/src/image_generation_engines_3.py +++ b/docs/griptape-framework/engines/src/image_generation_engines_3.py @@ -1,5 +1,3 @@ -from pathlib import Path - from griptape.drivers import AmazonBedrockImageGenerationDriver, BedrockStableDiffusionImageGenerationModelDriver from griptape.engines import VariationImageGenerationEngine from griptape.loaders import ImageLoader @@ -15,7 +13,7 @@ image_generation_driver=driver, ) -image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +image_artifact = ImageLoader().load("tests/resources/mountain.png") engine.run( prompts=["A photo of a mountain landscape in winter"], diff --git a/docs/griptape-framework/engines/src/image_generation_engines_4.py b/docs/griptape-framework/engines/src/image_generation_engines_4.py index c258e1cce..e7b46b341 100644 --- a/docs/griptape-framework/engines/src/image_generation_engines_4.py +++ b/docs/griptape-framework/engines/src/image_generation_engines_4.py @@ -1,5 +1,3 @@ -from pathlib import Path - from griptape.drivers import AmazonBedrockImageGenerationDriver, BedrockStableDiffusionImageGenerationModelDriver from griptape.engines import InpaintingImageGenerationEngine from griptape.loaders import ImageLoader @@ -15,9 +13,9 @@ image_generation_driver=driver, ) -image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +image_artifact = ImageLoader().load("tests/resources/mountain.png") -mask_artifact = 
ImageLoader().load(Path("tests/resources/mountain-mask.png").read_bytes()) +mask_artifact = ImageLoader().load("tests/resources/mountain-mask.png") engine.run( prompts=["A photo of a castle built into the side of a mountain"], diff --git a/docs/griptape-framework/engines/src/image_generation_engines_5.py b/docs/griptape-framework/engines/src/image_generation_engines_5.py index f91a48ec0..526ebff50 100644 --- a/docs/griptape-framework/engines/src/image_generation_engines_5.py +++ b/docs/griptape-framework/engines/src/image_generation_engines_5.py @@ -1,5 +1,3 @@ -from pathlib import Path - from griptape.drivers import AmazonBedrockImageGenerationDriver, BedrockStableDiffusionImageGenerationModelDriver from griptape.engines import OutpaintingImageGenerationEngine from griptape.loaders import ImageLoader @@ -15,9 +13,9 @@ image_generation_driver=driver, ) -image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +image_artifact = ImageLoader().load("tests/resources/mountain.png") -mask_artifact = ImageLoader().load(Path("tests/resources/mountain-mask.png").read_bytes()) +mask_artifact = ImageLoader().load("tests/resources/mountain-mask.png") engine.run( prompts=["A photo of a mountain shrouded in clouds"], diff --git a/docs/griptape-framework/engines/src/image_query_engines_1.py b/docs/griptape-framework/engines/src/image_query_engines_1.py index b0920392a..c2d08e9a9 100644 --- a/docs/griptape-framework/engines/src/image_query_engines_1.py +++ b/docs/griptape-framework/engines/src/image_query_engines_1.py @@ -1,5 +1,3 @@ -from pathlib import Path - from griptape.drivers import OpenAiImageQueryDriver from griptape.engines import ImageQueryEngine from griptape.loaders import ImageLoader @@ -8,6 +6,6 @@ engine = ImageQueryEngine(image_query_driver=driver) -image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +image_artifact = ImageLoader().load("tests/resources/mountain.png") engine.run("Describe the weather 
in the image", [image_artifact]) diff --git a/docs/griptape-framework/engines/src/rag_engines_1.py b/docs/griptape-framework/engines/src/rag_engines_1.py index a8a9cc06b..6ad28545f 100644 --- a/docs/griptape-framework/engines/src/rag_engines_1.py +++ b/docs/griptape-framework/engines/src/rag_engines_1.py @@ -1,3 +1,4 @@ +from griptape.chunkers import TextChunker from griptape.drivers import LocalVectorStoreDriver, OpenAiChatPromptDriver, OpenAiEmbeddingDriver from griptape.engines.rag import RagContext, RagEngine from griptape.engines.rag.modules import PromptResponseRagModule, TranslateQueryRagModule, VectorStoreRetrievalRagModule @@ -8,12 +9,12 @@ prompt_driver = OpenAiChatPromptDriver(model="gpt-4o", temperature=0) vector_store = LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver()) -artifacts = WebLoader(max_tokens=500).load("https://www.griptape.ai") - +artifact = WebLoader().load("https://www.griptape.ai") +chunks = TextChunker(max_tokens=500).chunk(artifact) vector_store.upsert_text_artifacts( { - "griptape": artifacts, + "griptape": chunks, } ) diff --git a/docs/griptape-framework/engines/src/summary_engines_1.py b/docs/griptape-framework/engines/src/summary_engines_1.py index b5adf2a5a..5a16e4819 100644 --- a/docs/griptape-framework/engines/src/summary_engines_1.py +++ b/docs/griptape-framework/engines/src/summary_engines_1.py @@ -9,8 +9,6 @@ prompt_driver=OpenAiChatPromptDriver(model="gpt-3.5-turbo"), ) -artifacts = PdfLoader().load(response.content) +artifact = PdfLoader().parse(response.content) -text = "\n\n".join([a.value for a in artifacts]) - -engine.summarize_text(text) +engine.summarize_artifacts(artifact) diff --git a/docs/griptape-framework/structures/src/tasks_12.py b/docs/griptape-framework/structures/src/tasks_12.py index 917b50607..1fdc99e1c 100644 --- a/docs/griptape-framework/structures/src/tasks_12.py +++ b/docs/griptape-framework/structures/src/tasks_12.py @@ -1,5 +1,3 @@ -from pathlib import Path - from griptape.drivers 
import AmazonBedrockImageGenerationDriver, BedrockStableDiffusionImageGenerationModelDriver from griptape.engines import VariationImageGenerationEngine from griptape.loaders import ImageLoader @@ -18,7 +16,7 @@ ) # Load input image artifact. -image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +image_artifact = ImageLoader().load("tests/resources/mountain.png") # Instantiate a pipeline. pipeline = Pipeline() diff --git a/docs/griptape-framework/structures/src/tasks_13.py b/docs/griptape-framework/structures/src/tasks_13.py index d2aa45983..4b7616d94 100644 --- a/docs/griptape-framework/structures/src/tasks_13.py +++ b/docs/griptape-framework/structures/src/tasks_13.py @@ -1,5 +1,3 @@ -from pathlib import Path - from griptape.drivers import AmazonBedrockImageGenerationDriver, BedrockStableDiffusionImageGenerationModelDriver from griptape.engines import InpaintingImageGenerationEngine from griptape.loaders import ImageLoader @@ -18,9 +16,9 @@ ) # Load input image artifacts. -image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +image_artifact = ImageLoader().load("tests/resources/mountain.png") -mask_artifact = ImageLoader().load(Path("tests/resources/mountain-mask.png").read_bytes()) +mask_artifact = ImageLoader().load("tests/resources/mountain-mask.png") # Instantiate a pipeline. pipeline = Pipeline() diff --git a/docs/griptape-framework/structures/src/tasks_14.py b/docs/griptape-framework/structures/src/tasks_14.py index ec489096d..d2e6ba2dd 100644 --- a/docs/griptape-framework/structures/src/tasks_14.py +++ b/docs/griptape-framework/structures/src/tasks_14.py @@ -1,5 +1,3 @@ -from pathlib import Path - from griptape.drivers import AmazonBedrockImageGenerationDriver, BedrockStableDiffusionImageGenerationModelDriver from griptape.engines import OutpaintingImageGenerationEngine from griptape.loaders import ImageLoader @@ -18,9 +16,9 @@ ) # Load input image artifacts. 
-image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +image_artifact = ImageLoader().load("tests/resources/mountain.png") -mask_artifact = ImageLoader().load(Path("tests/resources/mountain-mask.png").read_bytes()) +mask_artifact = ImageLoader().load("tests/resources/mountain-mask.png") # Instantiate a pipeline. pipeline = Pipeline() diff --git a/docs/griptape-framework/structures/src/tasks_15.py b/docs/griptape-framework/structures/src/tasks_15.py index 0c60864f7..802ac3397 100644 --- a/docs/griptape-framework/structures/src/tasks_15.py +++ b/docs/griptape-framework/structures/src/tasks_15.py @@ -1,5 +1,3 @@ -from pathlib import Path - from griptape.drivers import OpenAiImageQueryDriver from griptape.engines import ImageQueryEngine from griptape.loaders import ImageLoader @@ -18,7 +16,7 @@ ) # Load the input image artifact. -image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) +image_artifact = ImageLoader().load("tests/resources/mountain.png") # Instantiate a pipeline. 
pipeline = Pipeline() diff --git a/docs/griptape-framework/structures/src/tasks_18.py b/docs/griptape-framework/structures/src/tasks_18.py index 08ece5a92..0d3312d4c 100644 --- a/docs/griptape-framework/structures/src/tasks_18.py +++ b/docs/griptape-framework/structures/src/tasks_18.py @@ -3,12 +3,11 @@ from griptape.loaders import AudioLoader from griptape.structures import Pipeline from griptape.tasks import AudioTranscriptionTask -from griptape.utils import load_file driver = OpenAiAudioTranscriptionDriver(model="whisper-1") task = AudioTranscriptionTask( - input=lambda _: AudioLoader().load(load_file("tests/resources/sentences2.wav")), + input=lambda _: AudioLoader().load("tests/resources/sentences2.wav"), audio_transcription_engine=AudioTranscriptionEngine( audio_transcription_driver=driver, ), diff --git a/docs/griptape-framework/structures/src/tasks_3.py b/docs/griptape-framework/structures/src/tasks_3.py index 6584049d0..cdfe894bd 100644 --- a/docs/griptape-framework/structures/src/tasks_3.py +++ b/docs/griptape-framework/structures/src/tasks_3.py @@ -1,10 +1,8 @@ -from pathlib import Path - from griptape.loaders import ImageLoader from griptape.structures import Agent agent = Agent() -image_artifact = ImageLoader().load(Path("tests/resources/mountain.jpg").read_bytes()) +image_artifact = ImageLoader().load("tests/resources/mountain.jpg") agent.run([image_artifact, "What's in this image?"]) diff --git a/docs/griptape-tools/official-tools/src/vector_store_tool_1.py b/docs/griptape-tools/official-tools/src/vector_store_tool_1.py index 26c87e255..bdb60d98b 100644 --- a/docs/griptape-tools/official-tools/src/vector_store_tool_1.py +++ b/docs/griptape-tools/official-tools/src/vector_store_tool_1.py @@ -1,3 +1,4 @@ +from griptape.chunkers import TextChunker from griptape.drivers import LocalVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader from griptape.structures import Agent @@ -8,8 +9,9 @@ ) artifacts = 
WebLoader().load("https://www.griptape.ai") +chunks = TextChunker().chunk(artifacts) -vector_store_driver.upsert_text_artifacts({"griptape": artifacts}) +vector_store_driver.upsert_text_artifacts({"griptape": chunks}) vector_db = VectorStoreTool( description="This DB has information about the Griptape Python framework", vector_store_driver=vector_store_driver, diff --git a/griptape/chunkers/base_chunker.py b/griptape/chunkers/base_chunker.py index 623185237..9b6ef64b9 100644 --- a/griptape/chunkers/base_chunker.py +++ b/griptape/chunkers/base_chunker.py @@ -6,6 +6,7 @@ from attrs import Attribute, Factory, define, field from griptape.artifacts import TextArtifact +from griptape.artifacts.list_artifact import ListArtifact from griptape.chunkers import ChunkSeparator from griptape.tokenizers import BaseTokenizer, OpenAiTokenizer @@ -32,8 +33,8 @@ def validate_max_tokens(self, _: Attribute, max_tokens: int) -> None: if max_tokens < 0: raise ValueError("max_tokens must be 0 or greater.") - def chunk(self, text: TextArtifact | str) -> list[TextArtifact]: - text = text.value if isinstance(text, TextArtifact) else text + def chunk(self, text: TextArtifact | ListArtifact | str) -> list[TextArtifact]: + text = text.to_text() if isinstance(text, (TextArtifact, ListArtifact)) else text return [TextArtifact(c) for c in self._chunk_recursively(text)] diff --git a/griptape/common/prompt_stack/contents/text_message_content.py b/griptape/common/prompt_stack/contents/text_message_content.py index c862564f3..39e678f28 100644 --- a/griptape/common/prompt_stack/contents/text_message_content.py +++ b/griptape/common/prompt_stack/contents/text_message_content.py @@ -4,7 +4,7 @@ from attrs import define, field -from griptape.artifacts import TextArtifact +from griptape.artifacts import BaseArtifact, TextArtifact from griptape.common import BaseDeltaMessageContent, BaseMessageContent, TextDeltaMessageContent if TYPE_CHECKING: @@ -13,7 +13,7 @@ @define class 
TextMessageContent(BaseMessageContent): - artifact: TextArtifact = field(metadata={"serializable": True}) + artifact: BaseArtifact = field(metadata={"serializable": True}) @classmethod def from_deltas(cls, deltas: Sequence[BaseDeltaMessageContent]) -> TextMessageContent: diff --git a/griptape/drivers/file_manager/base_file_manager_driver.py b/griptape/drivers/file_manager/base_file_manager_driver.py index dce538812..c904f1532 100644 --- a/griptape/drivers/file_manager/base_file_manager_driver.py +++ b/griptape/drivers/file_manager/base_file_manager_driver.py @@ -1,11 +1,11 @@ from __future__ import annotations from abc import ABC, abstractmethod +from typing import Optional -from attrs import Factory, define, field +from attrs import define, field -import griptape.loaders as loaders -from griptape.artifacts import BaseArtifact, ErrorArtifact, InfoArtifact, ListArtifact, TextArtifact +from griptape.artifacts import BlobArtifact, InfoArtifact, TextArtifact @define @@ -17,57 +17,28 @@ class BaseFileManagerDriver(ABC): loaders: Dictionary of file extension specific loaders to use for loading file contents into artifacts. 
""" - default_loader: loaders.BaseLoader = field(default=Factory(lambda: loaders.BlobLoader()), kw_only=True) - loaders: dict[str, loaders.BaseLoader] = field( - default=Factory( - lambda: { - "pdf": loaders.PdfLoader(), - "csv": loaders.CsvLoader(), - "txt": loaders.TextLoader(), - "html": loaders.TextLoader(), - "json": loaders.TextLoader(), - "yaml": loaders.TextLoader(), - "xml": loaders.TextLoader(), - "png": loaders.ImageLoader(), - "jpg": loaders.ImageLoader(), - "jpeg": loaders.ImageLoader(), - "webp": loaders.ImageLoader(), - "gif": loaders.ImageLoader(), - "bmp": loaders.ImageLoader(), - "tiff": loaders.ImageLoader(), - }, - ), - kw_only=True, - ) + workdir: str = field(kw_only=True) + encoding: Optional[str] = field(default=None, kw_only=True) - def list_files(self, path: str) -> TextArtifact | ErrorArtifact: + def list_files(self, path: str) -> TextArtifact: entries = self.try_list_files(path) return TextArtifact("\n".join(list(entries))) @abstractmethod def try_list_files(self, path: str) -> list[str]: ... - def load_file(self, path: str) -> BaseArtifact: - extension = path.split(".")[-1] - loader = self.loaders.get(extension) or self.default_loader - source = self.try_load_file(path) - result = loader.load(source) - - if isinstance(result, BaseArtifact): - return result + def load_file(self, path: str) -> BlobArtifact | TextArtifact: + if self.encoding is None: + return BlobArtifact(self.try_load_file(path)) else: - return ListArtifact(result) + return TextArtifact(self.try_load_file(path).decode(encoding=self.encoding), encoding=self.encoding) @abstractmethod def try_load_file(self, path: str) -> bytes: ... 
def save_file(self, path: str, value: bytes | str) -> InfoArtifact: - extension = path.split(".")[-1] - loader = self.loaders.get(extension) or self.default_loader - encoding = None if loader is None else loader.encoding - if isinstance(value, str): - value = value.encode() if encoding is None else value.encode(encoding=encoding) + value = value.encode() if self.encoding is None else value.encode(encoding=self.encoding) elif isinstance(value, (bytearray, memoryview)): raise ValueError(f"Unsupported type: {type(value)}") diff --git a/griptape/drivers/file_manager/local_file_manager_driver.py b/griptape/drivers/file_manager/local_file_manager_driver.py index a6f1f0726..b383ff7d7 100644 --- a/griptape/drivers/file_manager/local_file_manager_driver.py +++ b/griptape/drivers/file_manager/local_file_manager_driver.py @@ -2,6 +2,7 @@ import os from pathlib import Path +from typing import Optional from attrs import Attribute, Factory, define, field @@ -16,11 +17,11 @@ class LocalFileManagerDriver(BaseFileManagerDriver): workdir: The absolute working directory. List, load, and save operations will be performed relative to this directory. 
""" - workdir: str = field(default=Factory(lambda: os.getcwd()), kw_only=True) + workdir: Optional[str] = field(default=Factory(lambda: os.getcwd()), kw_only=True) @workdir.validator # pyright: ignore[reportAttributeAccessIssue] def validate_workdir(self, _: Attribute, workdir: str) -> None: - if not Path(workdir).is_absolute(): + if self.workdir is not None and not Path(workdir).is_absolute(): raise ValueError("Workdir must be an absolute path") def try_list_files(self, path: str) -> list[str]: @@ -41,8 +42,7 @@ def try_save_file(self, path: str, value: bytes) -> None: Path(full_path).write_bytes(value) def _full_path(self, path: str) -> str: - path = path.lstrip("/") - full_path = os.path.join(self.workdir, path) + full_path = path if self.workdir is None else os.path.join(self.workdir, path.lstrip("/")) # Need to keep the trailing slash if it was there, # because it means the path is a directory. ended_with_slash = path.endswith("/") diff --git a/griptape/drivers/vector/base_vector_store_driver.py b/griptape/drivers/vector/base_vector_store_driver.py index 50810752e..e2a394bf4 100644 --- a/griptape/drivers/vector/base_vector_store_driver.py +++ b/griptape/drivers/vector/base_vector_store_driver.py @@ -43,9 +43,9 @@ def upsert_text_artifacts( *, meta: Optional[dict] = None, **kwargs, - ) -> None: + ) -> list[str] | dict[str, list[str]]: if isinstance(artifacts, list): - utils.execute_futures_list( + return utils.execute_futures_list( [ self.futures_executor.submit(self.upsert_text_artifact, a, namespace=None, meta=meta, **kwargs) for a in artifacts @@ -65,7 +65,7 @@ def upsert_text_artifacts( ) ) - utils.execute_futures_list_dict(futures_dict) + return utils.execute_futures_list_dict(futures_dict) def upsert_text_artifact( self, @@ -89,32 +89,20 @@ def upsert_text_artifact( vector = artifact.embedding or artifact.generate_embedding(self.embedding_driver) - if isinstance(vector, list): - return self.upsert_vector(vector, vector_id=vector_id, namespace=namespace, 
meta=meta, **kwargs) - else: - raise ValueError("Vector must be an instance of 'list'.") + return self.upsert_vector(vector, vector_id=vector_id, namespace=namespace, meta=meta, **kwargs) def upsert_text( self, string: str, *, - vector_id: Optional[str] = None, namespace: Optional[str] = None, meta: Optional[dict] = None, + vector_id: Optional[str] = None, **kwargs, ) -> str: - vector_id = self._get_default_vector_id(string) if vector_id is None else vector_id - - if self.does_entry_exist(vector_id, namespace=namespace): - return vector_id - else: - return self.upsert_vector( - self.embedding_driver.embed_string(string), - vector_id=vector_id, - namespace=namespace, - meta=meta or {}, - **kwargs, - ) + return self.upsert_text_artifact( + TextArtifact(string), vector_id=vector_id, namespace=namespace, meta=meta, **kwargs + ) def does_entry_exist(self, vector_id: str, *, namespace: Optional[str] = None) -> bool: try: diff --git a/griptape/drivers/web_scraper/base_web_scraper_driver.py b/griptape/drivers/web_scraper/base_web_scraper_driver.py index ae39f8eac..0c33f3713 100644 --- a/griptape/drivers/web_scraper/base_web_scraper_driver.py +++ b/griptape/drivers/web_scraper/base_web_scraper_driver.py @@ -4,5 +4,13 @@ class BaseWebScraperDriver(ABC): + def scrape_url(self, url: str) -> TextArtifact: + source = self.fetch_url(url) + + return self.extract_page(source) + + @abstractmethod + def fetch_url(self, url: str) -> str: ... + @abstractmethod - def scrape_url(self, url: str) -> TextArtifact: ... + def extract_page(self, page: str) -> TextArtifact: ... 
diff --git a/griptape/drivers/web_scraper/markdownify_web_scraper_driver.py b/griptape/drivers/web_scraper/markdownify_web_scraper_driver.py index 654af4e97..8a41fe39e 100644 --- a/griptape/drivers/web_scraper/markdownify_web_scraper_driver.py +++ b/griptape/drivers/web_scraper/markdownify_web_scraper_driver.py @@ -38,20 +38,8 @@ class MarkdownifyWebScraperDriver(BaseWebScraperDriver): exclude_ids: list[str] = field(default=Factory(list), kw_only=True) timeout: Optional[int] = field(default=None, kw_only=True) - def scrape_url(self, url: str) -> TextArtifact: + def fetch_url(self, url: str) -> str: sync_playwright = import_optional_dependency("playwright.sync_api").sync_playwright - bs4 = import_optional_dependency("bs4") - markdownify = import_optional_dependency("markdownify") - - include_links = self.include_links - - # Custom MarkdownConverter to optionally linked urls. If include_links is False only - # the text of the link is returned. - class OptionalLinksMarkdownConverter(markdownify.MarkdownConverter): - def convert_a(self, el: Any, text: str, convert_as_inline: Any) -> str: - if include_links: - return super().convert_a(el, text, convert_as_inline) - return text with sync_playwright() as p, p.chromium.launch(headless=True) as browser: page = browser.new_page() @@ -76,28 +64,43 @@ def skip_loading_images(route: Any) -> Any: if not content: raise Exception("can't access URL") - soup = bs4.BeautifulSoup(content, "html.parser") + return content + + def extract_page(self, page: str) -> TextArtifact: + bs4 = import_optional_dependency("bs4") + markdownify = import_optional_dependency("markdownify") + include_links = self.include_links + + # Custom MarkdownConverter to optionally linked urls. If include_links is False only + # the text of the link is returned. 
+ class OptionalLinksMarkdownConverter(markdownify.MarkdownConverter): + def convert_a(self, el: Any, text: str, convert_as_inline: Any) -> str: + if include_links: + return super().convert_a(el, text, convert_as_inline) + return text + + soup = bs4.BeautifulSoup(page, "html.parser") - # Remove unwanted elements - exclude_selector = ",".join( - self.exclude_tags + [f".{c}" for c in self.exclude_classes] + [f"#{i}" for i in self.exclude_ids], - ) - if exclude_selector: - for s in soup.select(exclude_selector): - s.extract() + # Remove unwanted elements + exclude_selector = ",".join( + self.exclude_tags + [f".{c}" for c in self.exclude_classes] + [f"#{i}" for i in self.exclude_ids], + ) + if exclude_selector: + for s in soup.select(exclude_selector): + s.extract() - text = OptionalLinksMarkdownConverter().convert_soup(soup) + text = OptionalLinksMarkdownConverter().convert_soup(soup) - # Remove leading and trailing whitespace from the entire text - text = text.strip() + # Remove leading and trailing whitespace from the entire text + text = text.strip() - # Remove trailing whitespace from each line - text = re.sub(r"[ \t]+$", "", text, flags=re.MULTILINE) + # Remove trailing whitespace from each line + text = re.sub(r"[ \t]+$", "", text, flags=re.MULTILINE) - # Indent using 2 spaces instead of tabs - text = re.sub(r"(\n?\s*?)\t", r"\1 ", text) + # Indent using 2 spaces instead of tabs + text = re.sub(r"(\n?\s*?)\t", r"\1 ", text) - # Remove triple+ newlines (keep double newlines for paragraphs) - text = re.sub(r"\n\n+", "\n\n", text) + # Remove triple+ newlines (keep double newlines for paragraphs) + text = re.sub(r"\n\n+", "\n\n", text) - return TextArtifact(text) + return TextArtifact(text) diff --git a/griptape/drivers/web_scraper/proxy_web_scraper_driver.py b/griptape/drivers/web_scraper/proxy_web_scraper_driver.py index 2d785fde2..94b3914ea 100644 --- a/griptape/drivers/web_scraper/proxy_web_scraper_driver.py +++ 
b/griptape/drivers/web_scraper/proxy_web_scraper_driver.py @@ -12,6 +12,10 @@ class ProxyWebScraperDriver(BaseWebScraperDriver): proxies: dict = field(kw_only=True, metadata={"serializable": False}) params: dict = field(default=Factory(dict), kw_only=True, metadata={"serializable": True}) - def scrape_url(self, url: str) -> TextArtifact: + def fetch_url(self, url: str) -> str: response = requests.get(url, proxies=self.proxies, **self.params) - return TextArtifact(response.text) + + return response.text + + def extract_page(self, page: str) -> TextArtifact: + return TextArtifact(page) diff --git a/griptape/drivers/web_scraper/trafilatura_web_scraper_driver.py b/griptape/drivers/web_scraper/trafilatura_web_scraper_driver.py index 06f5573a4..e87af8af6 100644 --- a/griptape/drivers/web_scraper/trafilatura_web_scraper_driver.py +++ b/griptape/drivers/web_scraper/trafilatura_web_scraper_driver.py @@ -12,7 +12,7 @@ class TrafilaturaWebScraperDriver(BaseWebScraperDriver): include_links: bool = field(default=True, kw_only=True) - def scrape_url(self, url: str) -> TextArtifact: + def fetch_url(self, url: str) -> str: trafilatura = import_optional_dependency("trafilatura") use_config = trafilatura.settings.use_config @@ -29,6 +29,15 @@ def scrape_url(self, url: str) -> TextArtifact: if page is None: raise Exception("can't access URL") + + return page + + def extract_page(self, page: str) -> TextArtifact: + trafilatura = import_optional_dependency("trafilatura") + use_config = trafilatura.settings.use_config + + config = use_config() + extracted_page = trafilatura.extract( page, include_links=self.include_links, diff --git a/griptape/engines/rag/modules/retrieval/text_loader_retrieval_rag_module.py b/griptape/engines/rag/modules/retrieval/text_loader_retrieval_rag_module.py index 7e4854d00..0348a2094 100644 --- a/griptape/engines/rag/modules/retrieval/text_loader_retrieval_rag_module.py +++ b/griptape/engines/rag/modules/retrieval/text_loader_retrieval_rag_module.py @@ -6,6 
+6,7 @@ from attrs import Factory, define, field from griptape import utils +from griptape.chunkers import TextChunker from griptape.engines.rag.modules import BaseRetrievalRagModule if TYPE_CHECKING: @@ -14,12 +15,13 @@ from griptape.artifacts import TextArtifact from griptape.drivers import BaseVectorStoreDriver from griptape.engines.rag import RagContext - from griptape.loaders import BaseTextLoader + from griptape.loaders import TextLoader @define(kw_only=True) class TextLoaderRetrievalRagModule(BaseRetrievalRagModule): - loader: BaseTextLoader = field() + loader: TextLoader = field() + chunker: TextChunker = field(default=Factory(lambda: TextChunker())) vector_store_driver: BaseVectorStoreDriver = field() source: Any = field() query_params: dict[str, Any] = field(factory=dict) @@ -37,7 +39,8 @@ def run(self, context: RagContext) -> Sequence[TextArtifact]: query_params["namespace"] = namespace loader_output = self.loader.load(source) + chunks = self.chunker.chunk(loader_output) - self.vector_store_driver.upsert_text_artifacts({namespace: loader_output}) + self.vector_store_driver.upsert_text_artifacts({namespace: chunks}) return self.process_query_output_fn(self.vector_store_driver.query(context.query, **query_params)) diff --git a/griptape/loaders/__init__.py b/griptape/loaders/__init__.py index b79b0ff44..b86370607 100644 --- a/griptape/loaders/__init__.py +++ b/griptape/loaders/__init__.py @@ -1,26 +1,28 @@ from .base_loader import BaseLoader -from .base_text_loader import BaseTextLoader +from .base_file_loader import BaseFileLoader + from .text_loader import TextLoader from .pdf_loader import PdfLoader from .web_loader import WebLoader from .sql_loader import SqlLoader from .csv_loader import CsvLoader -from .dataframe_loader import DataFrameLoader from .email_loader import EmailLoader + +from .blob_loader import BlobLoader + from .image_loader import ImageLoader + from .audio_loader import AudioLoader -from .blob_loader import BlobLoader __all__ = [ 
"BaseLoader", - "BaseTextLoader", + "BaseFileLoader", "TextLoader", "PdfLoader", "WebLoader", "SqlLoader", "CsvLoader", - "DataFrameLoader", "EmailLoader", "ImageLoader", "AudioLoader", diff --git a/griptape/loaders/audio_loader.py b/griptape/loaders/audio_loader.py index 84d6b767a..0bff5c642 100644 --- a/griptape/loaders/audio_loader.py +++ b/griptape/loaders/audio_loader.py @@ -1,20 +1,15 @@ from __future__ import annotations -from typing import cast - +import filetype from attrs import define from griptape.artifacts import AudioArtifact -from griptape.loaders import BaseLoader -from griptape.utils import import_optional_dependency +from griptape.loaders.base_file_loader import BaseFileLoader @define -class AudioLoader(BaseLoader): +class AudioLoader(BaseFileLoader[AudioArtifact]): """Loads audio content into audio artifacts.""" - def load(self, source: bytes, *args, **kwargs) -> AudioArtifact: - return AudioArtifact(source, format=import_optional_dependency("filetype").guess(source).extension) - - def load_collection(self, sources: list[bytes], *args, **kwargs) -> dict[str, AudioArtifact]: - return cast(dict[str, AudioArtifact], super().load_collection(sources, *args, **kwargs)) + def parse(self, data: bytes) -> AudioArtifact: + return AudioArtifact(data, format=filetype.guess(data).extension) diff --git a/griptape/loaders/base_file_loader.py b/griptape/loaders/base_file_loader.py new file mode 100644 index 000000000..9fcffa7ae --- /dev/null +++ b/griptape/loaders/base_file_loader.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from abc import ABC +from os import PathLike +from typing import TypeVar, Union + +from attrs import Factory, define, field + +from griptape.artifacts import BaseArtifact +from griptape.drivers import BaseFileManagerDriver, LocalFileManagerDriver +from griptape.loaders import BaseLoader +from griptape.utils import deprecation_warn + +A = TypeVar("A", bound=BaseArtifact) + + +@define +class BaseFileLoader(BaseLoader[Union[str, 
PathLike], bytes, A], ABC): + file_manager_driver: BaseFileManagerDriver = field( + default=Factory(lambda: LocalFileManagerDriver(workdir=None)), + kw_only=True, + ) + encoding: str = field(default="utf-8", kw_only=True) + + def fetch(self, source: str | PathLike | bytes) -> bytes: + if isinstance(source, bytes): + deprecation_warn( + "Using bytes as the source is deprecated and will be removed in a future release. " + "Please use a string or PathLike object instead." + ) + return source + + data = self.file_manager_driver.load_file(str(source)).value + if isinstance(data, str): + return data.encode(self.encoding) + else: + return data diff --git a/griptape/loaders/base_loader.py b/griptape/loaders/base_loader.py index 14f9aa10f..f7340283b 100644 --- a/griptape/loaders/base_loader.py +++ b/griptape/loaders/base_loader.py @@ -1,48 +1,72 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar from attrs import define, field +from griptape.artifacts import BaseArtifact from griptape.mixins.futures_executor_mixin import FuturesExecutorMixin from griptape.utils.futures import execute_futures_dict from griptape.utils.hash import bytes_to_hash, str_to_hash if TYPE_CHECKING: - from collections.abc import Mapping, Sequence + from collections.abc import Mapping - from griptape.artifacts import BaseArtifact + from griptape.common import Reference + +S = TypeVar("S") # Type for the input source +F = TypeVar("F") # Type for the fetched data +A = TypeVar("A", bound=BaseArtifact) # Type for the returned Artifact @define -class BaseLoader(FuturesExecutorMixin, ABC): - encoding: Optional[str] = field(default=None, kw_only=True) +class BaseLoader(FuturesExecutorMixin, ABC, Generic[S, F, A]): + """Fetches data from a source, parses it, and returns an Artifact. + + Attributes: + reference: The optional `Reference` to set on the Artifact. 
+ """ + + reference: Optional[Reference] = field(default=None, kw_only=True) + + def load(self, source: S) -> A: + data = self.fetch(source) + + artifact = self.parse(data) + + artifact.reference = self.reference + + return artifact @abstractmethod - def load(self, source: Any, *args, **kwargs) -> BaseArtifact | Sequence[BaseArtifact]: ... + def fetch(self, source: S) -> F: + """Fetches data from the source.""" + + ... + + @abstractmethod + def parse(self, data: F) -> A: + """Parses the fetched data and returns an Artifact.""" + + ... def load_collection( self, sources: list[Any], - *args, - **kwargs, - ) -> Mapping[str, BaseArtifact | Sequence[BaseArtifact | Sequence[BaseArtifact]]]: + ) -> Mapping[str, A]: + """Loads a collection of sources and returns a dictionary of Artifacts.""" # Create a dictionary before actually submitting the jobs to the executor # to avoid duplicate work. sources_by_key = {self.to_key(source): source for source in sources} return execute_futures_dict( - { - key: self.futures_executor.submit(self.load, source, *args, **kwargs) - for key, source in sources_by_key.items() - }, + {key: self.futures_executor.submit(self.load, source) for key, source in sources_by_key.items()}, ) - def to_key(self, source: Any, *args, **kwargs) -> str: + def to_key(self, source: S) -> str: + """Converts the source to a key for the collection.""" if isinstance(source, bytes): return bytes_to_hash(source) - elif isinstance(source, str): - return str_to_hash(source) else: return str_to_hash(str(source)) diff --git a/griptape/loaders/base_text_loader.py b/griptape/loaders/base_text_loader.py deleted file mode 100644 index 196cb0087..000000000 --- a/griptape/loaders/base_text_loader.py +++ /dev/null @@ -1,65 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Optional, cast - -from attrs import Factory, define, field - -from griptape.artifacts import TextArtifact -from griptape.chunkers import 
BaseChunker, TextChunker -from griptape.loaders import BaseLoader -from griptape.tokenizers import OpenAiTokenizer - -if TYPE_CHECKING: - from griptape.common import Reference - from griptape.drivers import BaseEmbeddingDriver - - -@define -class BaseTextLoader(BaseLoader, ABC): - MAX_TOKEN_RATIO = 0.5 - - tokenizer: OpenAiTokenizer = field( - default=Factory(lambda: OpenAiTokenizer(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL)), - kw_only=True, - ) - max_tokens: int = field( - default=Factory(lambda self: round(self.tokenizer.max_input_tokens * self.MAX_TOKEN_RATIO), takes_self=True), - kw_only=True, - ) - chunker: BaseChunker = field( - default=Factory( - lambda self: TextChunker(tokenizer=self.tokenizer, max_tokens=self.max_tokens), - takes_self=True, - ), - kw_only=True, - ) - embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) - encoding: str = field(default="utf-8", kw_only=True) - reference: Optional[Reference] = field(default=None, kw_only=True) - - @abstractmethod - def load(self, source: Any, *args, **kwargs) -> list[TextArtifact]: ... 
- - def load_collection(self, sources: list[Any], *args, **kwargs) -> dict[str, list[TextArtifact]]: - return cast( - dict[str, list[TextArtifact]], - super().load_collection(sources, *args, **kwargs), - ) - - def _text_to_artifacts(self, text: str) -> list[TextArtifact]: - artifacts = [] - - chunks = self.chunker.chunk(text) if self.chunker else [TextArtifact(text)] - - for chunk in chunks: - if self.embedding_driver: - chunk.generate_embedding(self.embedding_driver) - - chunk.reference = self.reference - - chunk.encoding = self.encoding - - artifacts.append(chunk) - - return artifacts diff --git a/griptape/loaders/blob_loader.py b/griptape/loaders/blob_loader.py index d0099b47b..df148e66a 100644 --- a/griptape/loaders/blob_loader.py +++ b/griptape/loaders/blob_loader.py @@ -1,20 +1,15 @@ from __future__ import annotations -from typing import Any, cast - from attrs import define from griptape.artifacts import BlobArtifact -from griptape.loaders import BaseLoader +from griptape.loaders import BaseFileLoader @define -class BlobLoader(BaseLoader): - def load(self, source: Any, *args, **kwargs) -> BlobArtifact: +class BlobLoader(BaseFileLoader[BlobArtifact]): + def parse(self, data: bytes) -> BlobArtifact: if self.encoding is None: - return BlobArtifact(source) + return BlobArtifact(data) else: - return BlobArtifact(source, encoding=self.encoding) - - def load_collection(self, sources: list[bytes | str], *args, **kwargs) -> dict[str, BlobArtifact]: - return cast(dict[str, BlobArtifact], super().load_collection(sources, *args, **kwargs)) + return BlobArtifact(data, encoding=self.encoding) diff --git a/griptape/loaders/csv_loader.py b/griptape/loaders/csv_loader.py index bcf7029d4..4487d7aec 100644 --- a/griptape/loaders/csv_loader.py +++ b/griptape/loaders/csv_loader.py @@ -2,53 +2,25 @@ import csv from io import StringIO -from typing import TYPE_CHECKING, Callable, Optional, cast +from typing import Callable from attrs import define, field -from griptape.artifacts 
import TextArtifact -from griptape.loaders import BaseLoader - -if TYPE_CHECKING: - from griptape.drivers import BaseEmbeddingDriver +from griptape.artifacts import ListArtifact, TextArtifact +from griptape.loaders import BaseFileLoader @define -class CsvLoader(BaseLoader): - embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) +class CsvLoader(BaseFileLoader[ListArtifact[TextArtifact]]): delimiter: str = field(default=",", kw_only=True) encoding: str = field(default="utf-8", kw_only=True) formatter_fn: Callable[[dict], str] = field( default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True ) - def load(self, source: bytes | str, *args, **kwargs) -> list[TextArtifact]: - artifacts = [] - - if isinstance(source, bytes): - source = source.decode(encoding=self.encoding) - elif isinstance(source, (bytearray, memoryview)): - raise ValueError(f"Unsupported source type: {type(source)}") - - reader = csv.DictReader(StringIO(source), delimiter=self.delimiter) - chunks = [TextArtifact(self.formatter_fn(row)) for row in reader] - - if self.embedding_driver: - for chunk in chunks: - chunk.generate_embedding(self.embedding_driver) - - for chunk in chunks: - artifacts.append(chunk) - - return artifacts + def parse(self, data: bytes) -> ListArtifact[TextArtifact]: + reader = csv.DictReader(StringIO(data.decode(self.encoding)), delimiter=self.delimiter) - def load_collection( - self, - sources: list[bytes | str], - *args, - **kwargs, - ) -> dict[str, list[TextArtifact]]: - return cast( - dict[str, list[TextArtifact]], - super().load_collection(sources, *args, **kwargs), + return ListArtifact( + [TextArtifact(self.formatter_fn(row), meta={"row_num": row_num}) for row_num, row in enumerate(reader)] ) diff --git a/griptape/loaders/dataframe_loader.py b/griptape/loaders/dataframe_loader.py deleted file mode 100644 index 30d705676..000000000 --- a/griptape/loaders/dataframe_loader.py +++ /dev/null @@ -1,45 +0,0 @@ -from 
__future__ import annotations - -from typing import TYPE_CHECKING, Callable, Optional, cast - -from attrs import define, field - -from griptape.artifacts import TextArtifact -from griptape.loaders import BaseLoader -from griptape.utils import import_optional_dependency -from griptape.utils.hash import str_to_hash - -if TYPE_CHECKING: - from pandas import DataFrame - - from griptape.drivers import BaseEmbeddingDriver - - -@define -class DataFrameLoader(BaseLoader): - embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) - formatter_fn: Callable[[dict], str] = field( - default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True - ) - - def load(self, source: DataFrame, *args, **kwargs) -> list[TextArtifact]: - artifacts = [] - - chunks = [TextArtifact(self.formatter_fn(row)) for row in source.to_dict(orient="records")] - - if self.embedding_driver: - for chunk in chunks: - chunk.generate_embedding(self.embedding_driver) - - for chunk in chunks: - artifacts.append(chunk) - - return artifacts - - def load_collection(self, sources: list[DataFrame], *args, **kwargs) -> dict[str, list[TextArtifact]]: - return cast(dict[str, list[TextArtifact]], super().load_collection(sources, *args, **kwargs)) - - def to_key(self, source: DataFrame, *args, **kwargs) -> str: - hash_pandas_object = import_optional_dependency("pandas.core.util.hashing").hash_pandas_object - - return str_to_hash(str(hash_pandas_object(source, index=True).values)) diff --git a/griptape/loaders/email_loader.py b/griptape/loaders/email_loader.py index f6c9ca406..8e935cfa4 100644 --- a/griptape/loaders/email_loader.py +++ b/griptape/loaders/email_loader.py @@ -1,7 +1,7 @@ from __future__ import annotations import imaplib -from typing import Optional, cast +from typing import Optional from attrs import astuple, define, field @@ -11,7 +11,7 @@ @define -class EmailLoader(BaseLoader): +class EmailLoader(BaseLoader["EmailLoader.EmailQuery", list[bytes], 
ListArtifact]): # pyright: ignore[reportGeneralTypeIssues] @define(frozen=True) class EmailQuery: """An email retrieval query. @@ -32,11 +32,10 @@ class EmailQuery: username: str = field(kw_only=True) password: str = field(kw_only=True) - def load(self, source: EmailQuery, *args, **kwargs) -> ListArtifact: - mailparser = import_optional_dependency("mailparser") + def fetch(self, source: EmailLoader.EmailQuery) -> list[bytes]: label, key, search_criteria, max_count = astuple(source) - artifacts = [] + mail_bytes = [] with imaplib.IMAP4_SSL(self.imap_url) as client: client.login(self.username, self.password) @@ -59,19 +58,24 @@ def load(self, source: EmailQuery, *args, **kwargs) -> ListArtifact: if data is None or not data or data[0] is None: continue - message = mailparser.parse_from_bytes(data[0][1]) - - # Note: mailparser only populates the text_plain field - # if the message content type is explicitly set to 'text/plain'. - if message.text_plain: - artifacts.append(TextArtifact("\n".join(message.text_plain))) + mail_bytes.append(data[0][1]) client.close() - return ListArtifact(artifacts) + return mail_bytes + + def parse(self, data: list[bytes]) -> ListArtifact[TextArtifact]: + mailparser = import_optional_dependency("mailparser") + artifacts = [] + for byte in data: + message = mailparser.parse_from_bytes(byte) + + # Note: mailparser only populates the text_plain field + # if the message content type is explicitly set to 'text/plain'. 
+ if message.text_plain: + artifacts.append(TextArtifact("\n".join(message.text_plain))) + + return ListArtifact(artifacts) def _count_messages(self, message_numbers: bytes) -> int: return len(list(filter(None, message_numbers.decode().split(" ")))) - - def load_collection(self, sources: list[EmailQuery], *args, **kwargs) -> dict[str, ListArtifact]: - return cast(dict[str, ListArtifact], super().load_collection(sources, *args, **kwargs)) diff --git a/griptape/loaders/image_loader.py b/griptape/loaders/image_loader.py index 83060dfa8..3af3922dc 100644 --- a/griptape/loaders/image_loader.py +++ b/griptape/loaders/image_loader.py @@ -1,17 +1,17 @@ from __future__ import annotations from io import BytesIO -from typing import Optional, cast +from typing import Optional from attrs import define, field from griptape.artifacts import ImageArtifact -from griptape.loaders import BaseLoader +from griptape.loaders import BaseFileLoader from griptape.utils import import_optional_dependency @define -class ImageLoader(BaseLoader): +class ImageLoader(BaseFileLoader[ImageArtifact]): """Loads images into image artifacts. Attributes: @@ -22,36 +22,15 @@ class ImageLoader(BaseLoader): format: Optional[str] = field(default=None, kw_only=True) - FORMAT_TO_MIME_TYPE = { - "bmp": "image/bmp", - "gif": "image/gif", - "jpeg": "image/jpeg", - "png": "image/png", - "tiff": "image/tiff", - "webp": "image/webp", - } - - def load(self, source: bytes, *args, **kwargs) -> ImageArtifact: + def parse(self, data: bytes) -> ImageArtifact: pil_image = import_optional_dependency("PIL.Image") - image = pil_image.open(BytesIO(source)) + image = pil_image.open(BytesIO(data)) # Normalize format only if requested. 
if self.format is not None: byte_stream = BytesIO() image.save(byte_stream, format=self.format) image = pil_image.open(byte_stream) - source = byte_stream.getvalue() - - return ImageArtifact(source, format=image.format.lower(), width=image.width, height=image.height) - - def _get_mime_type(self, image_format: str | None) -> str: - if image_format is None: - raise ValueError("image_format is None") - - if image_format.lower() not in self.FORMAT_TO_MIME_TYPE: - raise ValueError(f"Unsupported image format {image_format}") - - return self.FORMAT_TO_MIME_TYPE[image_format.lower()] + data = byte_stream.getvalue() - def load_collection(self, sources: list[bytes], *args, **kwargs) -> dict[str, ImageArtifact]: - return cast(dict[str, ImageArtifact], super().load_collection(sources, *args, **kwargs)) + return ImageArtifact(data, format=image.format.lower(), width=image.width, height=image.height) diff --git a/griptape/loaders/pdf_loader.py b/griptape/loaders/pdf_loader.py index 419bfabf4..5bf5337ae 100644 --- a/griptape/loaders/pdf_loader.py +++ b/griptape/loaders/pdf_loader.py @@ -1,37 +1,25 @@ from __future__ import annotations from io import BytesIO -from typing import Optional, cast +from typing import Optional -from attrs import Factory, define, field +from attrs import define -from griptape.artifacts import TextArtifact -from griptape.chunkers import PdfChunker -from griptape.loaders import BaseTextLoader +from griptape.artifacts import ListArtifact, TextArtifact +from griptape.loaders.base_file_loader import BaseFileLoader from griptape.utils import import_optional_dependency @define -class PdfLoader(BaseTextLoader): - chunker: PdfChunker = field( - default=Factory(lambda self: PdfChunker(tokenizer=self.tokenizer, max_tokens=self.max_tokens), takes_self=True), - kw_only=True, - ) - encoding: None = field(default=None, kw_only=True) - - def load( +class PdfLoader(BaseFileLoader): + def parse( self, - source: bytes, + data: bytes, + *, password: Optional[str] = None, - 
*args, - **kwargs, - ) -> list[TextArtifact]: + ) -> ListArtifact: pypdf = import_optional_dependency("pypdf") - reader = pypdf.PdfReader(BytesIO(source), strict=True, password=password) - return self._text_to_artifacts("\n".join([p.extract_text() for p in reader.pages])) + reader = pypdf.PdfReader(BytesIO(data), strict=True, password=password) + pages = [TextArtifact(p.extract_text()) for p in reader.pages] - def load_collection(self, sources: list[bytes], *args, **kwargs) -> dict[str, list[TextArtifact]]: - return cast( - dict[str, list[TextArtifact]], - super().load_collection(sources, *args, **kwargs), - ) + return ListArtifact(pages) diff --git a/griptape/loaders/sql_loader.py b/griptape/loaders/sql_loader.py index 105f585cb..0c6e8bdf9 100644 --- a/griptape/loaders/sql_loader.py +++ b/griptape/loaders/sql_loader.py @@ -1,38 +1,23 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Optional, cast +from typing import Callable from attrs import define, field -from griptape.artifacts import TextArtifact +from griptape.artifacts import ListArtifact, TextArtifact +from griptape.drivers import BaseSqlDriver from griptape.loaders import BaseLoader -if TYPE_CHECKING: - from griptape.drivers import BaseEmbeddingDriver, BaseSqlDriver - @define -class SqlLoader(BaseLoader): +class SqlLoader(BaseLoader[str, list[BaseSqlDriver.RowResult], ListArtifact[TextArtifact]]): sql_driver: BaseSqlDriver = field(kw_only=True) - embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) formatter_fn: Callable[[dict], str] = field( default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True ) - def load(self, source: str, *args, **kwargs) -> list[TextArtifact]: - rows = self.sql_driver.execute_query(source) - artifacts = [] - - chunks = [TextArtifact(self.formatter_fn(row.cells)) for row in rows] if rows else [] - - if self.embedding_driver: - for chunk in chunks: - 
chunk.generate_embedding(self.embedding_driver) - - for chunk in chunks: - artifacts.append(chunk) - - return artifacts + def fetch(self, source: str) -> list[BaseSqlDriver.RowResult]: + return self.sql_driver.execute_query(source) or [] - def load_collection(self, sources: list[str], *args, **kwargs) -> dict[str, list[TextArtifact]]: - return cast(dict[str, list[TextArtifact]], super().load_collection(sources, *args, **kwargs)) + def parse(self, data: list[BaseSqlDriver.RowResult]) -> ListArtifact[TextArtifact]: + return ListArtifact([TextArtifact(self.formatter_fn(row.cells)) for row in data]) diff --git a/griptape/loaders/text_loader.py b/griptape/loaders/text_loader.py index 79e551a8e..c33eb9018 100644 --- a/griptape/loaders/text_loader.py +++ b/griptape/loaders/text_loader.py @@ -1,55 +1,17 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional, cast - -from attrs import Factory, define, field +from attrs import define, field from griptape.artifacts import TextArtifact -from griptape.chunkers import TextChunker -from griptape.loaders import BaseTextLoader -from griptape.tokenizers import OpenAiTokenizer - -if TYPE_CHECKING: - from griptape.drivers import BaseEmbeddingDriver +from griptape.loaders import BaseFileLoader @define -class TextLoader(BaseTextLoader): - MAX_TOKEN_RATIO = 0.5 - - tokenizer: OpenAiTokenizer = field( - default=Factory(lambda: OpenAiTokenizer(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL)), - kw_only=True, - ) - max_tokens: int = field( - default=Factory(lambda self: round(self.tokenizer.max_input_tokens * self.MAX_TOKEN_RATIO), takes_self=True), - kw_only=True, - ) - chunker: TextChunker = field( - default=Factory( - lambda self: TextChunker(tokenizer=self.tokenizer, max_tokens=self.max_tokens), - takes_self=True, - ), - kw_only=True, - ) - embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) +class TextLoader(BaseFileLoader[TextArtifact]): encoding: str = 
field(default="utf-8", kw_only=True) - def load(self, source: bytes | str, *args, **kwargs) -> list[TextArtifact]: - if isinstance(source, bytes): - source = source.decode(encoding=self.encoding) - elif isinstance(source, (bytearray, memoryview)): - raise ValueError(f"Unsupported source type: {type(source)}") - - return self._text_to_artifacts(source) - - def load_collection( - self, - sources: list[bytes | str], - *args, - **kwargs, - ) -> dict[str, list[TextArtifact]]: - return cast( - dict[str, list[TextArtifact]], - super().load_collection(sources, *args, **kwargs), - ) + def parse(self, data: str | bytes) -> TextArtifact: + if isinstance(data, str): + return TextArtifact(data, encoding=self.encoding) + else: + return TextArtifact(data.decode(self.encoding), encoding=self.encoding) diff --git a/griptape/loaders/web_loader.py b/griptape/loaders/web_loader.py index 720ab34a1..697c18bba 100644 --- a/griptape/loaders/web_loader.py +++ b/griptape/loaders/web_loader.py @@ -1,23 +1,21 @@ from __future__ import annotations -from typing import TYPE_CHECKING - from attrs import Factory, define, field +from griptape.artifacts import TextArtifact from griptape.drivers import BaseWebScraperDriver, TrafilaturaWebScraperDriver -from griptape.loaders import BaseTextLoader - -if TYPE_CHECKING: - from griptape.artifacts import TextArtifact +from griptape.loaders import BaseLoader @define -class WebLoader(BaseTextLoader): +class WebLoader(BaseLoader[str, str, TextArtifact]): web_scraper_driver: BaseWebScraperDriver = field( default=Factory(lambda: TrafilaturaWebScraperDriver()), kw_only=True, ) - def load(self, source: str, *args, **kwargs) -> list[TextArtifact]: - single_chunk_text_artifact = self.web_scraper_driver.scrape_url(source) - return self._text_to_artifacts(single_chunk_text_artifact.value) + def fetch(self, source: str) -> str: + return self.web_scraper_driver.fetch_url(source) + + def parse(self, data: str) -> TextArtifact: + return 
self.web_scraper_driver.extract_page(data) diff --git a/griptape/tasks/base_image_generation_task.py b/griptape/tasks/base_image_generation_task.py index 326b2a551..4a502e7cc 100644 --- a/griptape/tasks/base_image_generation_task.py +++ b/griptape/tasks/base_image_generation_task.py @@ -66,4 +66,4 @@ def all_negative_rulesets(self) -> list[Ruleset]: def _read_from_file(self, path: str) -> ImageArtifact: logger.info("Reading image from %s", os.path.abspath(path)) - return ImageLoader().load(Path(path).read_bytes()) + return ImageLoader().load(Path(path)) diff --git a/griptape/tools/audio_transcription/tool.py b/griptape/tools/audio_transcription/tool.py index 4174db209..826aeb895 100644 --- a/griptape/tools/audio_transcription/tool.py +++ b/griptape/tools/audio_transcription/tool.py @@ -1,6 +1,5 @@ from __future__ import annotations -from pathlib import Path from typing import TYPE_CHECKING, Any, cast from attrs import Factory, define, field @@ -32,7 +31,7 @@ class AudioTranscriptionTool(BaseTool): def transcribe_audio_from_disk(self, params: dict) -> TextArtifact | ErrorArtifact: audio_path = params["values"]["path"] - audio_artifact = self.audio_loader.load(Path(audio_path).read_bytes()) + audio_artifact = self.audio_loader.load(audio_path) return self.engine.run(audio_artifact) diff --git a/griptape/tools/file_manager/tool.py b/griptape/tools/file_manager/tool.py index b72f82329..8dc0c9393 100644 --- a/griptape/tools/file_manager/tool.py +++ b/griptape/tools/file_manager/tool.py @@ -5,9 +5,12 @@ from attrs import Factory, define, field from schema import Literal, Schema +import griptape.loaders as loaders from griptape.artifacts import ErrorArtifact, InfoArtifact, ListArtifact, TextArtifact from griptape.drivers import BaseFileManagerDriver, LocalFileManagerDriver +from griptape.loaders.blob_loader import BlobLoader from griptape.tools import BaseTool +from griptape.utils import get_mime_type from griptape.utils.decorators import activity @@ -21,6 +24,20 @@ class 
FileManagerTool(BaseTool): file_manager_driver: BaseFileManagerDriver = field(default=Factory(lambda: LocalFileManagerDriver()), kw_only=True) + loaders: dict[str, loaders.BaseLoader] = field( + default=Factory( + lambda self: { + "application/pdf": loaders.PdfLoader(file_manager_driver=self.file_manager_driver), + "text/csv": loaders.CsvLoader(file_manager_driver=self.file_manager_driver), + "text": loaders.TextLoader(file_manager_driver=self.file_manager_driver), + "image": loaders.ImageLoader(file_manager_driver=self.file_manager_driver), + "application/octet-stream": BlobLoader(file_manager_driver=self.file_manager_driver), + }, + takes_self=True, + ), + kw_only=True, + ) + @activity( config={ "description": "Can be used to list files on disk", @@ -51,7 +68,11 @@ def load_files_from_disk(self, params: dict) -> ListArtifact | ErrorArtifact: artifacts = [] for path in paths: - artifact = self.file_manager_driver.load_file(path) + abs_path = os.path.join(self.file_manager_driver.workdir, path) + mime_type = get_mime_type(abs_path) + loader = next((loader for key, loader in self.loaders.items() if mime_type.startswith(key))) + + artifact = loader.load(path) if isinstance(artifact, ListArtifact): artifacts.extend(artifact.value) else: diff --git a/griptape/tools/image_query/tool.py b/griptape/tools/image_query/tool.py index 9d1dbb89b..7b654bd72 100644 --- a/griptape/tools/image_query/tool.py +++ b/griptape/tools/image_query/tool.py @@ -1,6 +1,5 @@ from __future__ import annotations -from pathlib import Path from typing import TYPE_CHECKING, Any, cast from attrs import Factory, define, field @@ -41,7 +40,7 @@ def query_image_from_disk(self, params: dict) -> TextArtifact | ErrorArtifact: image_artifacts = [] for image_path in image_paths: - image_artifacts.append(self.image_loader.load(Path(image_path).read_bytes())) + image_artifacts.append(self.image_loader.load(image_path)) return self.image_query_engine.run(query, image_artifacts) diff --git 
a/griptape/tools/inpainting_image_generation/tool.py b/griptape/tools/inpainting_image_generation/tool.py index d32f481d9..b529cb637 100644 --- a/griptape/tools/inpainting_image_generation/tool.py +++ b/griptape/tools/inpainting_image_generation/tool.py @@ -1,6 +1,5 @@ from __future__ import annotations -from pathlib import Path from typing import TYPE_CHECKING, cast from attrs import define, field @@ -51,8 +50,8 @@ def image_inpainting_from_file(self, params: dict[str, dict[str, str]]) -> Image image_file = params["values"]["image_file"] mask_file = params["values"]["mask_file"] - input_artifact = self.image_loader.load(Path(image_file).read_bytes()) - mask_artifact = self.image_loader.load(Path(mask_file).read_bytes()) + input_artifact = self.image_loader.load(image_file) + mask_artifact = self.image_loader.load(mask_file) return self._generate_inpainting( prompt, negative_prompt, cast(ImageArtifact, input_artifact), cast(ImageArtifact, mask_artifact) diff --git a/griptape/tools/outpainting_image_generation/tool.py b/griptape/tools/outpainting_image_generation/tool.py index afa39e178..47863b03d 100644 --- a/griptape/tools/outpainting_image_generation/tool.py +++ b/griptape/tools/outpainting_image_generation/tool.py @@ -1,6 +1,5 @@ from __future__ import annotations -from pathlib import Path from typing import TYPE_CHECKING, cast from attrs import define, field @@ -51,8 +50,8 @@ def image_outpainting_from_file(self, params: dict[str, dict[str, str]]) -> Imag image_file = params["values"]["image_file"] mask_file = params["values"]["mask_file"] - input_artifact = self.image_loader.load(Path(image_file).read_bytes()) - mask_artifact = self.image_loader.load(Path(mask_file).read_bytes()) + input_artifact = self.image_loader.load(image_file) + mask_artifact = self.image_loader.load(mask_file) return self._generate_outpainting(prompt, negative_prompt, input_artifact, mask_artifact) diff --git a/griptape/tools/sql/tool.py b/griptape/tools/sql/tool.py index 
a84bb87be..59d0a3dba 100644 --- a/griptape/tools/sql/tool.py +++ b/griptape/tools/sql/tool.py @@ -51,6 +51,6 @@ def execute_query(self, params: dict) -> ListArtifact | InfoArtifact | ErrorArti return ErrorArtifact(f"error executing query: {e}") if len(rows) > 0: - return ListArtifact(rows) + return rows else: return InfoArtifact("No results found") diff --git a/griptape/tools/variation_image_generation/tool.py b/griptape/tools/variation_image_generation/tool.py index 0d4456c2f..1fb8c8bcc 100644 --- a/griptape/tools/variation_image_generation/tool.py +++ b/griptape/tools/variation_image_generation/tool.py @@ -1,6 +1,5 @@ from __future__ import annotations -from pathlib import Path from typing import TYPE_CHECKING, cast from attrs import define, field @@ -49,7 +48,7 @@ def image_variation_from_file(self, params: dict[str, dict[str, str]]) -> ImageA negative_prompt = params["values"]["negative_prompt"] image_file = params["values"]["image_file"] - image_artifact = self.image_loader.load(Path(image_file).read_bytes()) + image_artifact = self.image_loader.load(image_file) return self._generate_variation(prompt, negative_prompt, image_artifact) diff --git a/griptape/tools/web_scraper/tool.py b/griptape/tools/web_scraper/tool.py index 2895d5e0d..982123b30 100644 --- a/griptape/tools/web_scraper/tool.py +++ b/griptape/tools/web_scraper/tool.py @@ -4,6 +4,7 @@ from schema import Literal, Schema from griptape.artifacts import ErrorArtifact, ListArtifact +from griptape.chunkers import TextChunker from griptape.loaders import WebLoader from griptape.tools import BaseTool from griptape.utils.decorators import activity @@ -12,6 +13,7 @@ @define class WebScraperTool(BaseTool): web_loader: WebLoader = field(default=Factory(lambda: WebLoader()), kw_only=True) + text_chunker: TextChunker = field(default=Factory(lambda: TextChunker()), kw_only=True) @activity( config={ @@ -24,6 +26,8 @@ def get_content(self, params: dict) -> ListArtifact | ErrorArtifact: try: result = 
self.web_loader.load(url) - return ListArtifact(result) + chunks = TextChunker().chunk(result) + + return ListArtifact(chunks) except Exception as e: return ErrorArtifact("Error getting page content: " + str(e)) diff --git a/griptape/utils/__init__.py b/griptape/utils/__init__.py index 03725f59d..77e3f3b0a 100644 --- a/griptape/utils/__init__.py +++ b/griptape/utils/__init__.py @@ -8,7 +8,6 @@ from .futures import execute_futures_dict, execute_futures_list, execute_futures_list_dict from .token_counter import TokenCounter from .dict_utils import remove_null_values_in_dict_recursively, dict_merge, remove_key_in_dict_recursively -from .file_utils import load_file, load_files from .hash import str_to_hash from .import_utils import import_optional_dependency from .import_utils import is_dependency_installed @@ -17,6 +16,7 @@ from .deprecation import deprecation_warn from .structure_visualizer import StructureVisualizer from .reference_utils import references_from_artifacts +from .file_utils import get_mime_type def minify_json(value: str) -> str: @@ -44,8 +44,7 @@ def minify_json(value: str) -> str: "Stream", "load_artifact_from_memory", "deprecation_warn", - "load_file", - "load_files", "StructureVisualizer", "references_from_artifacts", + "get_mime_type", ] diff --git a/griptape/utils/file_utils.py b/griptape/utils/file_utils.py index 19c9f699c..0dbbbc093 100644 --- a/griptape/utils/file_utils.py +++ b/griptape/utils/file_utils.py @@ -1,38 +1,16 @@ -from __future__ import annotations +import mimetypes -from concurrent import futures -from pathlib import Path -from typing import Optional +import filetype -import griptape.utils as utils +def get_mime_type(file_path: str) -> str: + filetype_guess = filetype.guess(file_path) -def load_file(path: str) -> bytes: - """Load a file from the given path and return its content as bytes. - - Args: - path (str): The path to the file to load. - - Returns: - The content of the file. 
- """ - return Path(path).read_bytes() - - -def load_files(paths: list[str], futures_executor: Optional[futures.ThreadPoolExecutor] = None) -> dict[str, bytes]: - """Load multiple files concurrently and return a dictionary of their content. - - Args: - paths: The paths to the files to load. - futures_executor: The executor to use for concurrent loading. If None, a new ThreadPoolExecutor will be created. - - Returns: - A dictionary where the keys are a hash of the path and the values are the content of the files. - """ - if futures_executor is None: - futures_executor = futures.ThreadPoolExecutor() - - with futures_executor as executor: - return utils.execute_futures_dict( - {utils.str_to_hash(str(path)): executor.submit(load_file, path) for path in paths}, - ) + if filetype_guess is None: + type_, _ = mimetypes.guess_type(file_path) + if type_ is None: + return "application/octet-stream" + else: + return type_ + else: + return filetype_guess.mime diff --git a/mkdocs.yml b/mkdocs.yml index f065a0d01..0be4ec7e5 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -131,8 +131,8 @@ nav: - Data: - Overview: "griptape-framework/data/index.md" - Artifacts: "griptape-framework/data/artifacts.md" - - Chunkers: "griptape-framework/data/chunkers.md" - Loaders: "griptape-framework/data/loaders.md" + - Chunkers: "griptape-framework/data/chunkers.md" - Misc: - Events: "griptape-framework/misc/events.md" - Tokenizers: "griptape-framework/misc/tokenizers.md" diff --git a/poetry.lock b/poetry.lock index 34e3be20f..b0d38fc42 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1586,7 +1586,7 @@ typing = ["typing-extensions (>=4.8)"] name = "filetype" version = "1.2.0" description = "Infer file type and MIME type of any file/buffer. No external dependencies." 
-optional = true +optional = false python-versions = "*" files = [ {file = "filetype-1.2.0-py2.py3-none-any.whl", hash = "sha256:7ce71b6880181241cf7ac8697a2f1eb6a8bd9b429f7ad6d27b8db9ba5f1c2d25"}, @@ -6945,7 +6945,7 @@ doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linke test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] [extras] -all = ["anthropic", "astrapy", "beautifulsoup4", "boto3", "cohere", "diffusers", "duckduckgo-search", "elevenlabs", "filetype", "google-generativeai", "mail-parser", "markdownify", "marqo", "ollama", "opensearch-py", "opentelemetry-api", "opentelemetry-exporter-otlp-proto-http", "opentelemetry-instrumentation", "opentelemetry-instrumentation-threading", "opentelemetry-sdk", "pandas", "pgvector", "pillow", "pinecone-client", "playwright", "psycopg2-binary", "pusher", "pymongo", "pypdf", "qdrant-client", "redis", "snowflake-sqlalchemy", "sqlalchemy", "tavily-python", "trafilatura", "transformers", "voyageai"] +all = ["anthropic", "astrapy", "beautifulsoup4", "boto3", "cohere", "diffusers", "duckduckgo-search", "elevenlabs", "google-generativeai", "mail-parser", "markdownify", "marqo", "ollama", "opensearch-py", "opentelemetry-api", "opentelemetry-exporter-otlp-proto-http", "opentelemetry-instrumentation", "opentelemetry-instrumentation-threading", "opentelemetry-sdk", "pandas", "pgvector", "pillow", "pinecone-client", "playwright", "psycopg2-binary", "pusher", "pymongo", "pypdf", "qdrant-client", "redis", "snowflake-sqlalchemy", "sqlalchemy", "tavily-python", "trafilatura", "transformers", "voyageai"] drivers-embedding-amazon-bedrock = ["boto3"] drivers-embedding-amazon-sagemaker = ["boto3"] drivers-embedding-cohere = ["cohere"] @@ -6989,8 +6989,6 @@ drivers-web-scraper-trafilatura = ["trafilatura"] 
drivers-web-search-duckduckgo = ["duckduckgo-search"] drivers-web-search-exa = ["exa-py"] drivers-web-search-tavily = ["tavily-python"] -loaders-audio = ["filetype"] -loaders-dataframe = ["pandas"] loaders-email = ["mail-parser"] loaders-image = ["pillow"] loaders-pdf = ["pypdf"] @@ -6999,4 +6997,4 @@ loaders-sql = ["sqlalchemy"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "4fe9e2ded897a0af4b9019926eb93c2d3a80cce07ab4707bfe59a6bd24124ce3" +content-hash = "ca72d32879b2af60bd30f0401c0e11de24e4dcf7a9eadd06a867b3ca1db85d92" diff --git a/pyproject.toml b/pyproject.toml index ebc4daa1a..5835c03cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ numpy = "^1.26.4" stringcase = "^1.2.0" docker = "^7.1.0" requests = "^2.32.0" +filetype = "^1.2" # drivers cohere = { version = "^5.5.4", optional = true } @@ -69,7 +70,6 @@ pandas = {version = "^1.3", optional = true} pypdf = {version = "^5.0.1", optional = true} pillow = {version = "^10.2.0", optional = true} mail-parser = {version = "^3.15.0", optional = true} -filetype = {version = "^1.2", optional = true} [tool.poetry.extras] drivers-prompt-cohere = ["cohere"] @@ -145,11 +145,9 @@ drivers-observability-datadog = [ drivers-image-generation-huggingface = ["diffusers", "pillow"] -loaders-dataframe = ["pandas"] loaders-pdf = ["pypdf"] loaders-image = ["pillow"] loaders-email = ["mail-parser"] -loaders-audio = ["filetype"] loaders-sql = ["sqlalchemy"] all = [ @@ -194,7 +192,6 @@ all = [ "pandas", "pypdf", "mail-parser", - "filetype", ] [tool.poetry.group.test] diff --git a/tests/unit/drivers/file_manager/test_amazon_s3_file_manager_driver.py b/tests/unit/drivers/file_manager/test_amazon_s3_file_manager_driver.py index 0c29c1ebb..2240dee58 100644 --- a/tests/unit/drivers/file_manager/test_amazon_s3_file_manager_driver.py +++ b/tests/unit/drivers/file_manager/test_amazon_s3_file_manager_driver.py @@ -5,9 +5,9 @@ import pytest from moto import mock_aws -from griptape.artifacts import 
InfoArtifact, ListArtifact, TextArtifact +from griptape.artifacts import InfoArtifact, TextArtifact +from griptape.artifacts.blob_artifact import BlobArtifact from griptape.drivers import AmazonS3FileManagerDriver -from griptape.loaders import TextLoader from tests.utils.aws import mock_aws_credentials @@ -154,8 +154,7 @@ def test_list_files_failure(self, workdir, path, expected, driver): def test_load_file(self, driver): artifact = driver.load_file("resources/bitcoin.pdf") - assert isinstance(artifact, ListArtifact) - assert len(artifact.value) == 4 + assert isinstance(artifact, BlobArtifact) @pytest.mark.parametrize( ("workdir", "path", "expected"), @@ -185,9 +184,8 @@ def test_load_file_failure(self, workdir, path, expected, driver): def test_load_file_with_encoding(self, driver): artifact = driver.load_file("resources/test.txt") - assert isinstance(artifact, ListArtifact) - assert len(artifact.value) == 1 - assert isinstance(artifact.value[0], TextArtifact) + assert isinstance(artifact, BlobArtifact) + assert artifact.encoding == "utf-8" @pytest.mark.parametrize( ("workdir", "path", "content"), @@ -240,9 +238,7 @@ def test_save_file_failure(self, workdir, path, expected, temp_dir, driver, s3_c def test_save_file_with_encoding(self, session, bucket, get_s3_value): workdir = "/sub-folder" - driver = AmazonS3FileManagerDriver( - session=session, bucket=bucket, default_loader=TextLoader(encoding="utf-8"), loaders={}, workdir=workdir - ) + driver = AmazonS3FileManagerDriver(session=session, bucket=bucket, workdir=workdir) path = "test/foobar.txt" result = driver.save_file(path, "foobar") @@ -253,9 +249,7 @@ def test_save_file_with_encoding(self, session, bucket, get_s3_value): def test_save_and_load_file_with_encoding(self, session, bucket, get_s3_value): workdir = "/sub-folder" - driver = AmazonS3FileManagerDriver( - session=session, bucket=bucket, loaders={"txt": TextLoader(encoding="ascii")}, workdir=workdir - ) + driver = 
AmazonS3FileManagerDriver(session=session, bucket=bucket, encoding="ascii", workdir=workdir) path = "test/foobar.txt" result = driver.save_file(path, "foobar") @@ -264,13 +258,10 @@ def test_save_and_load_file_with_encoding(self, session, bucket, get_s3_value): assert get_s3_value(expected_s3_key) == "foobar" assert result.value == "Successfully saved file" - driver = AmazonS3FileManagerDriver( - session=session, bucket=bucket, default_loader=TextLoader(encoding="ascii"), loaders={}, workdir=workdir - ) + driver = AmazonS3FileManagerDriver(session=session, bucket=bucket, encoding="ascii", workdir=workdir) path = "test/foobar.txt" result = driver.load_file(path) - assert isinstance(result, ListArtifact) - assert len(result.value) == 1 - assert isinstance(result.value[0], TextArtifact) + assert isinstance(result, TextArtifact) + assert result.encoding == "ascii" diff --git a/tests/unit/drivers/file_manager/test_local_file_manager_driver.py b/tests/unit/drivers/file_manager/test_local_file_manager_driver.py index 394a838a3..99f0285bc 100644 --- a/tests/unit/drivers/file_manager/test_local_file_manager_driver.py +++ b/tests/unit/drivers/file_manager/test_local_file_manager_driver.py @@ -4,9 +4,9 @@ import pytest -from griptape.artifacts import InfoArtifact, ListArtifact, TextArtifact +from griptape.artifacts import InfoArtifact, TextArtifact +from griptape.artifacts.blob_artifact import BlobArtifact from griptape.drivers import LocalFileManagerDriver -from griptape.loaders.text_loader import TextLoader class TestLocalFileManagerDriver: @@ -127,8 +127,7 @@ def test_list_files_failure(self, workdir, path, expected, temp_dir, driver): def test_load_file(self, driver: LocalFileManagerDriver): artifact = driver.load_file("resources/bitcoin.pdf") - assert isinstance(artifact, ListArtifact) - assert len(artifact.value) == 4 + assert isinstance(artifact, BlobArtifact) @pytest.mark.parametrize( ("workdir", "path", "expected"), @@ -156,23 +155,6 @@ def 
test_load_file_failure(self, workdir, path, expected, temp_dir, driver): with pytest.raises(expected): driver.load_file(path) - def test_load_file_with_encoding(self, driver: LocalFileManagerDriver): - artifact = driver.load_file("resources/test.txt") - - assert isinstance(artifact, ListArtifact) - assert len(artifact.value) == 1 - assert isinstance(artifact.value[0], TextArtifact) - - def test_load_file_with_encoding_failure(self, driver): - driver = LocalFileManagerDriver( - default_loader=TextLoader(encoding="utf-8"), - loaders={}, - workdir=os.path.normpath(os.path.abspath(os.path.dirname(__file__) + "../../../../")), - ) - - with pytest.raises(UnicodeDecodeError): - driver.load_file("resources/bitcoin.pdf") - @pytest.mark.parametrize( ("workdir", "path", "content"), [ @@ -224,25 +206,24 @@ def test_save_file_failure(self, workdir, path, expected, temp_dir, driver): driver.save_file(path, "foobar") def test_save_file_with_encoding(self, temp_dir): - driver = LocalFileManagerDriver(default_loader=TextLoader(encoding="utf-8"), loaders={}, workdir=temp_dir) + driver = LocalFileManagerDriver(encoding="utf-8", workdir=temp_dir) result = driver.save_file(os.path.join("test", "foobar.txt"), "foobar") assert Path(os.path.join(temp_dir, "test", "foobar.txt")).read_text() == "foobar" assert result.value == "Successfully saved file" def test_save_and_load_file_with_encoding(self, temp_dir): - driver = LocalFileManagerDriver(loaders={"txt": TextLoader(encoding="ascii")}, workdir=temp_dir) + driver = LocalFileManagerDriver(encoding="ascii", workdir=temp_dir) result = driver.save_file(os.path.join("test", "foobar.txt"), "foobar") assert Path(os.path.join(temp_dir, "test", "foobar.txt")).read_text() == "foobar" assert result.value == "Successfully saved file" - driver = LocalFileManagerDriver(default_loader=TextLoader(encoding="ascii"), loaders={}, workdir=temp_dir) + driver = LocalFileManagerDriver(encoding="ascii", workdir=temp_dir) result = 
driver.load_file(os.path.join("test", "foobar.txt")) - assert isinstance(result, ListArtifact) - assert len(result.value) == 1 - assert isinstance(result.value[0], TextArtifact) + assert isinstance(result, TextArtifact) + assert result.encoding == "ascii" def _to_driver_workdir(self, temp_dir, workdir): # Treat the workdir as an absolute path, but modify it to be relative to the temp_dir. diff --git a/tests/unit/drivers/vector/test_base_local_vector_store_driver.py b/tests/unit/drivers/vector/test_base_vector_store_driver.py similarity index 95% rename from tests/unit/drivers/vector/test_base_local_vector_store_driver.py rename to tests/unit/drivers/vector/test_base_vector_store_driver.py index 20a3e2b50..8438568ad 100644 --- a/tests/unit/drivers/vector/test_base_local_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_base_vector_store_driver.py @@ -4,12 +4,13 @@ import pytest from griptape.artifacts import TextArtifact +from griptape.drivers import BaseVectorStoreDriver -class BaseLocalVectorStoreDriver(ABC): +class TestBaseVectorStoreDriver(ABC): @pytest.fixture() @abstractmethod - def driver(self): ... + def driver(self, *args, **kwargs) -> BaseVectorStoreDriver: ... 
def test_upsert(self, driver): namespace = driver.upsert_text_artifact(TextArtifact(id="foo1", value="foobar")) diff --git a/tests/unit/drivers/vector/test_local_vector_store_driver.py b/tests/unit/drivers/vector/test_local_vector_store_driver.py index 2504b2486..9722bb25d 100644 --- a/tests/unit/drivers/vector/test_local_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_local_vector_store_driver.py @@ -3,10 +3,10 @@ from griptape.artifacts import TextArtifact from griptape.drivers import LocalVectorStoreDriver from tests.mocks.mock_embedding_driver import MockEmbeddingDriver -from tests.unit.drivers.vector.test_base_local_vector_store_driver import BaseLocalVectorStoreDriver +from tests.unit.drivers.vector.test_base_vector_store_driver import TestBaseVectorStoreDriver -class TestLocalVectorStoreDriver(BaseLocalVectorStoreDriver): +class TestLocalVectorStoreDriver(TestBaseVectorStoreDriver): @pytest.fixture() def driver(self): return LocalVectorStoreDriver(embedding_driver=MockEmbeddingDriver()) diff --git a/tests/unit/drivers/vector/test_persistent_local_vector_store_driver.py b/tests/unit/drivers/vector/test_persistent_local_vector_store_driver.py index c130858b5..9fa725e80 100644 --- a/tests/unit/drivers/vector/test_persistent_local_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_persistent_local_vector_store_driver.py @@ -6,10 +6,10 @@ from griptape.artifacts import TextArtifact from griptape.drivers import LocalVectorStoreDriver from tests.mocks.mock_embedding_driver import MockEmbeddingDriver -from tests.unit.drivers.vector.test_base_local_vector_store_driver import BaseLocalVectorStoreDriver +from tests.unit.drivers.vector.test_base_vector_store_driver import TestBaseVectorStoreDriver -class TestPersistentLocalVectorStoreDriver(BaseLocalVectorStoreDriver): +class TestPersistentLocalVectorStoreDriver(TestBaseVectorStoreDriver): @pytest.fixture() def temp_dir(self): with tempfile.TemporaryDirectory() as temp_dir: diff --git 
a/tests/unit/engines/rag/modules/retrieval/test_text_loader_retrieval_rag_module.py b/tests/unit/engines/rag/modules/retrieval/test_text_loader_retrieval_rag_module.py index 69e334c7f..b2dbb613d 100644 --- a/tests/unit/engines/rag/modules/retrieval/test_text_loader_retrieval_rag_module.py +++ b/tests/unit/engines/rag/modules/retrieval/test_text_loader_retrieval_rag_module.py @@ -18,7 +18,7 @@ def test_run(self): embedding_driver = MockEmbeddingDriver() module = TextLoaderRetrievalRagModule( - loader=WebLoader(max_tokens=MAX_TOKENS, embedding_driver=embedding_driver), + loader=WebLoader(), vector_store_driver=LocalVectorStoreDriver(embedding_driver=embedding_driver), source="https://www.griptape.ai", ) diff --git a/tests/unit/loaders/conftest.py b/tests/unit/loaders/conftest.py index 1f698738a..0bbf839b8 100644 --- a/tests/unit/loaders/conftest.py +++ b/tests/unit/loaders/conftest.py @@ -1,4 +1,5 @@ import os +from io import BytesIO, StringIO from pathlib import Path import pytest @@ -14,15 +15,15 @@ def create_source(resource_path: str) -> Path: @pytest.fixture() def bytes_from_resource_path(path_from_resource_path): - def create_source(resource_path: str) -> bytes: - return Path(path_from_resource_path(resource_path)).read_bytes() + def create_source(resource_path: str) -> BytesIO: + return BytesIO(Path(path_from_resource_path(resource_path)).read_bytes()) return create_source @pytest.fixture() def str_from_resource_path(path_from_resource_path): - def test_csv_str(resource_path: str) -> str: - return Path(path_from_resource_path(resource_path)).read_text() + def test_csv_str(resource_path: str) -> StringIO: + return StringIO(Path(path_from_resource_path(resource_path)).read_text()) return test_csv_str diff --git a/tests/unit/loaders/test_audio_loader.py b/tests/unit/loaders/test_audio_loader.py index b7ebdd912..7b3516722 100644 --- a/tests/unit/loaders/test_audio_loader.py +++ b/tests/unit/loaders/test_audio_loader.py @@ -9,9 +9,9 @@ class TestAudioLoader: def 
loader(self): return AudioLoader() - @pytest.fixture() - def create_source(self, bytes_from_resource_path): - return bytes_from_resource_path + @pytest.fixture(params=["path_from_resource_path"]) + def create_source(self, request): + return request.getfixturevalue(request.param) @pytest.mark.parametrize(("resource_path", "mime_type"), [("sentences.wav", "audio/wav")]) def test_load(self, resource_path, mime_type, loader, create_source): @@ -21,7 +21,6 @@ def test_load(self, resource_path, mime_type, loader, create_source): assert isinstance(artifact, AudioArtifact) assert artifact.mime_type == mime_type - assert len(artifact.value) > 0 def test_load_collection(self, create_source, loader): resource_paths = ["sentences.wav", "sentences2.wav"] diff --git a/tests/unit/loaders/test_blob_loader.py b/tests/unit/loaders/test_blob_loader.py index 4812e669c..2042381bc 100644 --- a/tests/unit/loaders/test_blob_loader.py +++ b/tests/unit/loaders/test_blob_loader.py @@ -11,7 +11,7 @@ def loader(self, request): kwargs = {"encoding": encoding} if encoding is not None else {} return BlobLoader(**kwargs) - @pytest.fixture(params=["bytes_from_resource_path", "str_from_resource_path"]) + @pytest.fixture(params=["path_from_resource_path"]) def create_source(self, request): return request.getfixturevalue(request.param) diff --git a/tests/unit/loaders/test_csv_loader.py b/tests/unit/loaders/test_csv_loader.py index 7af409152..9c5d0febb 100644 --- a/tests/unit/loaders/test_csv_loader.py +++ b/tests/unit/loaders/test_csv_loader.py @@ -1,9 +1,6 @@ -import json - import pytest from griptape.loaders.csv_loader import CsvLoader -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver class TestCsvLoader: @@ -11,15 +8,15 @@ class TestCsvLoader: def loader(self, request): encoding = request.param if encoding is None: - return CsvLoader(embedding_driver=MockEmbeddingDriver()) + return CsvLoader() else: - return CsvLoader(embedding_driver=MockEmbeddingDriver(), encoding=encoding) + 
return CsvLoader(encoding=encoding) @pytest.fixture() def loader_with_pipe_delimiter(self): - return CsvLoader(embedding_driver=MockEmbeddingDriver(), delimiter="|") + return CsvLoader(delimiter="|") - @pytest.fixture(params=["bytes_from_resource_path", "str_from_resource_path"]) + @pytest.fixture(params=["path_from_resource_path"]) def create_source(self, request): return request.getfixturevalue(request.param) @@ -31,7 +28,6 @@ def test_load(self, loader, create_source): assert len(artifacts) == 10 first_artifact = artifacts[0] assert first_artifact.value == "Foo: foo1\nBar: bar1" - assert first_artifact.embedding == [0, 1] def test_load_delimiter(self, loader_with_pipe_delimiter, create_source): source = create_source("test-pipe.csv") @@ -41,7 +37,6 @@ def test_load_delimiter(self, loader_with_pipe_delimiter, create_source): assert len(artifacts) == 10 first_artifact = artifacts[0] assert first_artifact.value == "Bar: foo1\nFoo: bar1" - assert first_artifact.embedding == [0, 1] def test_load_collection(self, loader, create_source): resource_paths = ["test-1.csv", "test-2.csv"] @@ -53,16 +48,5 @@ def test_load_collection(self, loader, create_source): assert collection.keys() == keys assert collection[loader.to_key(sources[0])][0].value == "Foo: foo1\nBar: bar1" - assert collection[loader.to_key(sources[0])][0].embedding == [0, 1] assert collection[loader.to_key(sources[1])][0].value == "Bar: bar1\nFoo: foo1" - assert collection[loader.to_key(sources[1])][0].embedding == [0, 1] - - def test_formatter_fn(self, loader, create_source): - loader.formatter_fn = lambda value: json.dumps(value) - source = create_source("test-1.csv") - - artifacts = loader.load(source) - - assert len(artifacts) == 10 - assert artifacts[0].value == '{"Foo": "foo1", "Bar": "bar1"}' diff --git a/tests/unit/loaders/test_email_loader.py b/tests/unit/loaders/test_email_loader.py index ade062743..1812dc531 100644 --- a/tests/unit/loaders/test_email_loader.py +++ 
b/tests/unit/loaders/test_email_loader.py @@ -1,14 +1,16 @@ from __future__ import annotations import email -from email import message -from typing import Optional +from typing import TYPE_CHECKING, Optional import pytest from griptape.artifacts import ListArtifact from griptape.loaders import EmailLoader +if TYPE_CHECKING: + from email.message import Message + class TestEmailLoader: @pytest.fixture(autouse=True) @@ -127,23 +129,21 @@ def to_fetch_message(body: str, content_type: Optional[str]): return to_fetch_response(to_message(body, content_type)) -def to_fetch_response(message: message): +def to_fetch_response(message: Message): return (None, ((None, message.as_bytes()),)) -def to_message(body: str, content_type: Optional[str]) -> message: +def to_message(body: str, content_type: Optional[str]) -> Message: message = email.message_from_string(body) if content_type: message.set_type(content_type) return message -def to_value_set(artifact_or_dict: ListArtifact | dict[str, ListArtifact]) -> set[str]: - if isinstance(artifact_or_dict, ListArtifact): - return {value.value for value in artifact_or_dict.value} - elif isinstance(artifact_or_dict, dict): - return { - text_artifact.value for list_artifact in artifact_or_dict.values() for text_artifact in list_artifact.value - } +def to_value_set(artifacts: ListArtifact | dict[str, ListArtifact]) -> set[str]: + if isinstance(artifacts, dict): + return set( + {text_artifact.value for list_artifact in artifacts.values() for text_artifact in list_artifact.value} + ) else: - raise Exception + return {artifact.value for artifact in artifacts.value} diff --git a/tests/unit/loaders/test_image_loader.py b/tests/unit/loaders/test_image_loader.py index 7093894b0..9c491fb88 100644 --- a/tests/unit/loaders/test_image_loader.py +++ b/tests/unit/loaders/test_image_loader.py @@ -13,9 +13,9 @@ def loader(self): def png_loader(self): return ImageLoader(format="png") - @pytest.fixture() - def create_source(self, bytes_from_resource_path): 
- return bytes_from_resource_path + @pytest.fixture(params=["path_from_resource_path"]) + def create_source(self, request): + return request.getfixturevalue(request.param) @pytest.mark.parametrize( ("resource_path", "mime_type"), diff --git a/tests/unit/loaders/test_pdf_loader.py b/tests/unit/loaders/test_pdf_loader.py index 376a9579a..45027b95c 100644 --- a/tests/unit/loaders/test_pdf_loader.py +++ b/tests/unit/loaders/test_pdf_loader.py @@ -1,29 +1,25 @@ import pytest from griptape.loaders import PdfLoader -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver - -MAX_TOKENS = 50 class TestPdfLoader: @pytest.fixture() def loader(self): - return PdfLoader(max_tokens=MAX_TOKENS, embedding_driver=MockEmbeddingDriver()) + return PdfLoader() - @pytest.fixture() - def create_source(self, bytes_from_resource_path): - return bytes_from_resource_path + @pytest.fixture(params=["path_from_resource_path"]) + def create_source(self, request): + return request.getfixturevalue(request.param) def test_load(self, loader, create_source): source = create_source("bitcoin.pdf") - artifacts = loader.load(source) + artifact = loader.load(source) - assert len(artifacts) == 156 - assert artifacts[0].value.startswith("Bitcoin: A Peer-to-Peer") - assert artifacts[-1].value.endswith('its applications," 1957.\n9') - assert artifacts[0].embedding == [0, 1] + assert len(artifact) == 9 + assert artifact[0].value.startswith("Bitcoin: A Peer-to-Peer") + assert artifact[-1].value.endswith('its applications," 1957.\n9') def test_load_collection(self, loader, create_source): resource_paths = ["bitcoin.pdf", "bitcoin-2.pdf"] @@ -37,7 +33,6 @@ def test_load_collection(self, loader, create_source): for key in keys: artifact = collection[key] - assert len(artifact) == 156 + assert len(artifact) == 9 assert artifact[0].value.startswith("Bitcoin: A Peer-to-Peer") assert artifact[-1].value.endswith('its applications," 1957.\n9') - assert artifact[0].embedding == [0, 1] diff --git 
a/tests/unit/loaders/test_sql_loader.py b/tests/unit/loaders/test_sql_loader.py index 2ff6c7faf..4d33b634a 100644 --- a/tests/unit/loaders/test_sql_loader.py +++ b/tests/unit/loaders/test_sql_loader.py @@ -3,7 +3,6 @@ from griptape.drivers import SqlDriver from griptape.loaders import SqlLoader -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver MAX_TOKENS = 50 @@ -16,7 +15,6 @@ def loader(self): engine_url="sqlite:///:memory:", create_engine_params={"connect_args": {"check_same_thread": False}, "poolclass": StaticPool}, ), - embedding_driver=MockEmbeddingDriver(), ) sql_loader.sql_driver.execute_query( @@ -35,14 +33,12 @@ def loader(self): return sql_loader def test_load(self, loader): - artifacts = loader.load("SELECT * FROM test_table;") + artifact = loader.load("SELECT * FROM test_table;") - assert len(artifacts) == 3 - assert artifacts[0].value == "id: 1\nname: Alice\nage: 25\ncity: New York" - assert artifacts[1].value == "id: 2\nname: Bob\nage: 30\ncity: Los Angeles" - assert artifacts[2].value == "id: 3\nname: Charlie\nage: 22\ncity: Chicago" - - assert artifacts[0].embedding == [0, 1] + assert len(artifact) == 3 + assert artifact[0].value == "id: 1\nname: Alice\nage: 25\ncity: New York" + assert artifact[1].value == "id: 2\nname: Bob\nage: 30\ncity: Los Angeles" + assert artifact[2].value == "id: 3\nname: Charlie\nage: 22\ncity: Chicago" def test_load_collection(self, loader): sources = ["SELECT * FROM test_table LIMIT 1;", "SELECT * FROM test_table LIMIT 2;"] @@ -55,4 +51,3 @@ def test_load_collection(self, loader): assert artifacts[loader.to_key(sources[0])][0].value == "id: 1\nname: Alice\nage: 25\ncity: New York" assert artifacts[loader.to_key(sources[1])][0].value == "id: 1\nname: Alice\nage: 25\ncity: New York" - assert list(artifacts.values())[0][0].embedding == [0, 1] diff --git a/tests/unit/loaders/test_text_loader.py b/tests/unit/loaders/test_text_loader.py index 07527f9e6..a435610ef 100644 --- 
a/tests/unit/loaders/test_text_loader.py +++ b/tests/unit/loaders/test_text_loader.py @@ -1,9 +1,6 @@ import pytest from griptape.loaders.text_loader import TextLoader -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver - -MAX_TOKENS = 50 class TestTextLoader: @@ -11,23 +8,21 @@ class TestTextLoader: def loader(self, request): encoding = request.param if encoding is None: - return TextLoader(max_tokens=MAX_TOKENS, embedding_driver=MockEmbeddingDriver()) + return TextLoader() else: - return TextLoader(max_tokens=MAX_TOKENS, embedding_driver=MockEmbeddingDriver(), encoding=encoding) + return TextLoader(encoding=encoding) - @pytest.fixture(params=["bytes_from_resource_path", "str_from_resource_path"]) + @pytest.fixture(params=["path_from_resource_path"]) def create_source(self, request): return request.getfixturevalue(request.param) def test_load(self, loader, create_source): source = create_source("test.txt") - artifacts = loader.load(source) + artifact = loader.load(source) - assert len(artifacts) == 39 - assert artifacts[0].value.startswith("foobar foobar foobar") - assert artifacts[0].encoding == loader.encoding - assert artifacts[0].embedding == [0, 1] + assert artifact.value.startswith("foobar foobar foobar") + assert artifact.encoding == loader.encoding def test_load_collection(self, loader, create_source): resource_paths = ["test.txt"] @@ -39,9 +34,10 @@ def test_load_collection(self, loader, create_source): assert collection.keys() == keys key = next(iter(keys)) - artifacts = collection[key] - assert len(artifacts) == 39 + artifact = collection[key] - artifact = artifacts[0] - assert artifact.embedding == [0, 1] assert artifact.encoding == loader.encoding + + def test_load_deprecated_bytes(self, loader): + with pytest.warns(DeprecationWarning): + loader.load(b"test.txt") diff --git a/tests/unit/loaders/test_web_loader.py b/tests/unit/loaders/test_web_loader.py index f7cccb666..d6e958042 100644 --- a/tests/unit/loaders/test_web_loader.py +++ 
b/tests/unit/loaders/test_web_loader.py @@ -1,7 +1,6 @@ import pytest from griptape.loaders import WebLoader -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver MAX_TOKENS = 50 @@ -13,15 +12,12 @@ def _mock_trafilatura_fetch_url(self, mocker): @pytest.fixture() def loader(self): - return WebLoader(max_tokens=MAX_TOKENS, embedding_driver=MockEmbeddingDriver()) + return WebLoader() def test_load(self, loader): - artifacts = loader.load("https://github.com/griptape-ai/griptape") + artifact = loader.load("https://github.com/griptape-ai/griptape") - assert len(artifacts) == 1 - assert "foobar" in artifacts[0].value.lower() - - assert artifacts[0].embedding == [0, 1] + assert "foobar" in artifact.value.lower() def test_load_exception(self, mocker, loader): mocker.patch("trafilatura.fetch_url", side_effect=Exception("error")) @@ -38,9 +34,7 @@ def test_load_collection(self, loader): loader.to_key("https://github.com/griptape-ai/griptape"), loader.to_key("https://github.com/griptape-ai/griptape-docs"), ] - assert "foobar" in [a.value for artifact_list in artifacts.values() for a in artifact_list][0].lower() - - assert list(artifacts.values())[0][0].embedding == [0, 1] + assert "foobar" in [a.value for a in artifacts.values()] def test_empty_page_string_response(self, loader, mocker): mocker.patch("trafilatura.extract", return_value="") diff --git a/tests/unit/tools/test_file_manager.py b/tests/unit/tools/test_file_manager.py index 469918a02..4e035bdee 100644 --- a/tests/unit/tools/test_file_manager.py +++ b/tests/unit/tools/test_file_manager.py @@ -7,7 +7,6 @@ from griptape.artifacts import ListArtifact, TextArtifact from griptape.drivers.file_manager.local_file_manager_driver import LocalFileManagerDriver -from griptape.loaders.text_loader import TextLoader from griptape.tools import FileManagerTool from tests.utils import defaults @@ -36,7 +35,7 @@ def test_load_files_from_disk(self, file_manager): result = file_manager.load_files_from_disk({"values": 
{"paths": ["../../resources/bitcoin.pdf"]}}) assert isinstance(result, ListArtifact) - assert len(result.value) == 4 + assert len(result.value) == 9 def test_load_files_from_disk_with_encoding(self, file_manager): result = file_manager.load_files_from_disk({"values": {"paths": ["../../resources/test.txt"]}}) @@ -48,8 +47,7 @@ def test_load_files_from_disk_with_encoding(self, file_manager): def test_load_files_from_disk_with_encoding_failure(self): file_manager = FileManagerTool( file_manager_driver=LocalFileManagerDriver( - default_loader=TextLoader(encoding="utf-8"), - loaders={}, + encoding="utf-8", workdir=os.path.abspath(os.path.dirname(__file__)), ) ) @@ -116,9 +114,7 @@ def test_save_content_to_file(self, temp_dir): assert result.value == "Successfully saved file" def test_save_content_to_file_with_encoding(self, temp_dir): - file_manager = FileManagerTool( - file_manager_driver=LocalFileManagerDriver(default_loader=TextLoader(encoding="utf-8"), workdir=temp_dir) - ) + file_manager = FileManagerTool(file_manager_driver=LocalFileManagerDriver(encoding="utf-8", workdir=temp_dir)) result = file_manager.save_content_to_file( {"values": {"path": os.path.join("test", "foobar.txt"), "content": "foobar"}} ) @@ -127,9 +123,7 @@ def test_save_content_to_file_with_encoding(self, temp_dir): assert result.value == "Successfully saved file" def test_save_and_load_content_to_file_with_encoding(self, temp_dir): - file_manager = FileManagerTool( - file_manager_driver=LocalFileManagerDriver(loaders={"txt": TextLoader(encoding="ascii")}, workdir=temp_dir) - ) + file_manager = FileManagerTool(file_manager_driver=LocalFileManagerDriver(encoding="ascii", workdir=temp_dir)) result = file_manager.save_content_to_file( {"values": {"path": os.path.join("test", "foobar.txt"), "content": "foobar"}} ) @@ -137,11 +131,7 @@ def test_save_and_load_content_to_file_with_encoding(self, temp_dir): assert Path(os.path.join(temp_dir, "test", "foobar.txt")).read_text() == "foobar" assert 
result.value == "Successfully saved file" - file_manager = FileManagerTool( - file_manager_driver=LocalFileManagerDriver( - default_loader=TextLoader(encoding="ascii"), loaders={}, workdir=temp_dir - ) - ) + file_manager = FileManagerTool(file_manager_driver=LocalFileManagerDriver(encoding="ascii", workdir=temp_dir)) result = file_manager.load_files_from_disk({"values": {"paths": [os.path.join("test", "foobar.txt")]}}) assert isinstance(result, ListArtifact) diff --git a/tests/unit/utils/test_file_utils.py b/tests/unit/utils/test_file_utils.py deleted file mode 100644 index 00df6958d..000000000 --- a/tests/unit/utils/test_file_utils.py +++ /dev/null @@ -1,51 +0,0 @@ -import os -from concurrent import futures - -from griptape import utils -from griptape.loaders import TextLoader - -MAX_TOKENS = 50 - - -class TestFileUtils: - def test_load_file(self): - dirname = os.path.dirname(__file__) - file = utils.load_file(os.path.join(dirname, "../../resources/foobar-many.txt")) - - assert file.decode("utf-8").startswith("foobar foobar foobar") - - def test_load_files(self): - dirname = os.path.dirname(__file__) - sources = ["resources/foobar-many.txt", "resources/foobar-many.txt", "resources/small.png"] - sources = [os.path.join(dirname, "../../", source) for source in sources] - files = utils.load_files(sources, futures_executor=futures.ThreadPoolExecutor(max_workers=1)) - assert len(files) == 2 - - test_file = files[utils.str_to_hash(sources[0])] - assert test_file.decode("utf-8").startswith("foobar foobar foobar") - - small_file = files[utils.str_to_hash(sources[2])] - assert len(small_file) == 97 - assert small_file[:8] == b"\x89PNG\r\n\x1a\n" - - def test_load_file_with_loader(self): - dirname = os.path.dirname(__file__) - file = utils.load_file(os.path.join(dirname, "../../", "resources/foobar-many.txt")) - artifacts = TextLoader(max_tokens=MAX_TOKENS).load(file) - - assert len(artifacts) == 39 - assert isinstance(artifacts, list) - assert 
artifacts[0].value.startswith("foobar foobar foobar") - - def test_load_files_with_loader(self): - dirname = os.path.dirname(__file__) - sources = ["resources/foobar-many.txt"] - sources = [os.path.join(dirname, "../../", source) for source in sources] - files = utils.load_files(sources) - loader = TextLoader(max_tokens=MAX_TOKENS) - collection = loader.load_collection(list(files.values())) - - test_file_artifacts = collection[loader.to_key(files[utils.str_to_hash(sources[0])])] - assert len(test_file_artifacts) == 39 - assert isinstance(test_file_artifacts, list) - assert test_file_artifacts[0].value.startswith("foobar foobar foobar")