From 62c7f7a3e522a2a05c6b1ce5eb272cceccebbc16 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 14 Aug 2024 14:13:11 -0700 Subject: [PATCH] Standardize loaders interface --- CHANGELOG.md | 2 - .../src/load_query_and_chat_marqo_1.py | 4 +- docs/examples/src/query_webpage_1.py | 6 +- docs/examples/src/query_webpage_astra_db_1.py | 7 +- docs/examples/src/talk_to_a_pdf_1.py | 4 +- docs/examples/src/talk_to_a_webpage_1.py | 4 +- docs/griptape-framework/data/loaders.md | 13 ---- docs/griptape-framework/data/src/loaders_1.py | 11 ++-- .../griptape-framework/data/src/loaders_10.py | 7 +- docs/griptape-framework/data/src/loaders_3.py | 13 ++-- docs/griptape-framework/data/src/loaders_4.py | 13 ---- docs/griptape-framework/data/src/loaders_7.py | 8 +-- docs/griptape-framework/data/src/loaders_8.py | 13 ++-- .../drivers/src/vector_store_drivers_1.py | 6 +- .../drivers/src/vector_store_drivers_10.py | 6 +- .../drivers/src/vector_store_drivers_11.py | 6 +- .../drivers/src/vector_store_drivers_3.py | 6 +- .../drivers/src/vector_store_drivers_4.py | 6 +- .../drivers/src/vector_store_drivers_5.py | 10 ++- .../drivers/src/vector_store_drivers_6.py | 6 +- .../drivers/src/vector_store_drivers_7.py | 6 +- .../drivers/src/vector_store_drivers_8.py | 6 +- .../drivers/src/vector_store_drivers_9.py | 6 +- .../engines/src/audio_engines_2.py | 3 +- .../engines/src/rag_engines_1.py | 7 +- .../engines/src/summary_engines_1.py | 6 +- .../structures/src/tasks_18.py | 3 +- .../official-tools/src/vector_store_tool_1.py | 4 +- griptape/artifacts/__init__.py | 1 + griptape/artifacts/blob_artifact.py | 2 +- griptape/artifacts/json_artifact.py | 6 +- .../file_manager/base_file_manager_driver.py | 51 ++++----------- .../file_manager/local_file_manager_driver.py | 8 +-- .../web_scraper/base_web_scraper_driver.py | 10 ++- .../markdownify_web_scraper_driver.py | 65 ++++++++++--------- .../web_scraper/proxy_web_scraper_driver.py | 8 ++- .../trafilatura_web_scraper_driver.py | 11 +++- .../text_loader_retrieval_rag_module.py | 9 ++- griptape/loaders/__init__.py | 12 ++-- griptape/loaders/audio_loader.py | 16 ++--- griptape/loaders/base_file_loader.py | 24 +++++++ griptape/loaders/base_loader.py | 21 ++++-- griptape/loaders/base_text_loader.py | 65 ------------------- griptape/loaders/blob_loader.py | 10 +-- griptape/loaders/csv_loader.py | 38 ++--------- griptape/loaders/dataframe_loader.py | 36 ---------- griptape/loaders/email_loader.py | 35 ++++++---- griptape/loaders/image_loader.py | 31 ++------- griptape/loaders/pdf_loader.py | 30 ++++----- griptape/loaders/sql_loader.py | 22 +++---- griptape/loaders/text_loader.py | 55 ++++------------ griptape/loaders/web_loader.py | 21 +++--- griptape/tools/file_manager/tool.py | 23 ++++++- griptape/tools/sql/tool.py | 5 +- griptape/tools/web_scraper/tool.py | 6 +- griptape/utils/__init__.py | 5 +- griptape/utils/file_utils.py | 46 ++++--------- poetry.lock | 25 +++---- pyproject.toml | 5 +- .../test_amazon_s3_file_manager_driver.py | 29 +++------ .../test_local_file_manager_driver.py | 35 +++------- ...er.py => test_base_vector_store_driver.py} | 5 +- .../vector/test_local_vector_store_driver.py | 4 +- ...st_persistent_local_vector_store_driver.py | 4 +- .../test_text_loader_retrieval_rag_module.py | 2 +- tests/unit/loaders/conftest.py | 9 +-- tests/unit/loaders/test_audio_loader.py | 7 +- tests/unit/loaders/test_blob_loader.py | 2 +- tests/unit/loaders/test_csv_loader.py | 17 ++--- tests/unit/loaders/test_email_loader.py | 24 +++---- tests/unit/loaders/test_image_loader.py | 6 +- tests/unit/loaders/test_pdf_loader.py | 23 +++---- tests/unit/loaders/test_sql_loader.py | 13 ++-- tests/unit/loaders/test_text_loader.py | 22 ++----- tests/unit/loaders/test_web_loader.py | 14 ++-- tests/unit/tools/test_file_manager.py | 20 ++---- tests/unit/utils/test_file_utils.py | 51 --------------- 77 files changed, 446 insertions(+), 735 deletions(-) delete mode 100644 docs/griptape-framework/data/src/loaders_4.py create mode 100644 griptape/loaders/base_file_loader.py delete mode 100644 griptape/loaders/base_text_loader.py delete mode 100644 griptape/loaders/dataframe_loader.py rename tests/unit/drivers/vector/{test_base_local_vector_store_driver.py => test_base_vector_store_driver.py} (95%) delete mode 100644 tests/unit/utils/test_file_utils.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 627ba90c6c..e5de54e3be 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,7 +31,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Parameter `meta: dict` on `BaseEvent`. -- `TableArtifact` for storing CSV data. ### Changed - **BREAKING**: Drivers, Loaders, and Engines now raise exceptions rather than returning `ErrorArtifacts`. @@ -40,7 +39,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: `BaseConversationMemoryDriver.load` now returns `tuple[list[Run], dict]`. This represents the runs and metadata. - **BREAKING**: `BaseConversationMemoryDriver.store` now takes `runs: list[Run]` and `metadata: dict` as input. - **BREAKING**: Parameter `file_path` on `LocalConversationMemoryDriver` renamed to `persist_file` and is now type `Optional[str]`. -- **BREAKING**: `CsvLoader` now returns a `TableArtifact` instead of a `list[CsvRowArtifact]`. - `Defaults.drivers_config.conversation_memory_driver` now defaults to `LocalConversationMemoryDriver` instead of `None`. - `CsvRowArtifact.to_text()` now includes the header. diff --git a/docs/examples/src/load_query_and_chat_marqo_1.py b/docs/examples/src/load_query_and_chat_marqo_1.py index cdcb376bb2..f318abb204 100644 --- a/docs/examples/src/load_query_and_chat_marqo_1.py +++ b/docs/examples/src/load_query_and_chat_marqo_1.py @@ -1,6 +1,7 @@ import os from griptape import utils +from griptape.chunkers import TextChunker from griptape.drivers import MarqoVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader from griptape.structures import Agent @@ -25,11 +26,12 @@ # Load artifacts from the web artifacts = WebLoader().load("https://www.griptape.ai") +chunks = TextChunker().chunk(artifacts) # Upsert the artifacts into the vector store vector_store.upsert_text_artifacts( { - namespace: artifacts, + namespace: chunks, } ) diff --git a/docs/examples/src/query_webpage_1.py b/docs/examples/src/query_webpage_1.py index b9e3286d65..9cc0335665 100644 --- a/docs/examples/src/query_webpage_1.py +++ b/docs/examples/src/query_webpage_1.py @@ -1,13 +1,15 @@ import os +from griptape.chunkers import TextChunker from griptape.drivers import LocalVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader vector_store = LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver(api_key=os.environ["OPENAI_API_KEY"])) -artifacts = WebLoader(max_tokens=100).load("https://www.griptape.ai") +artifacts = WebLoader().load("https://www.griptape.ai") +chunks = TextChunker().chunk(artifacts) -for a in artifacts: +for a in chunks: vector_store.upsert_text_artifact(a, namespace="griptape") results = vector_store.query("creativity", count=3, namespace="griptape") diff --git a/docs/examples/src/query_webpage_astra_db_1.py b/docs/examples/src/query_webpage_astra_db_1.py index 4590a6b595..aaf5e2fccd 100644 --- a/docs/examples/src/query_webpage_astra_db_1.py +++ b/docs/examples/src/query_webpage_astra_db_1.py @@ -1,5 +1,6 @@ import os +from griptape.chunkers import TextChunker from griptape.drivers import ( AstraDbVectorStoreDriver, OpenAiChatPromptDriver, @@ -43,9 +44,9 @@ ), ) -artifacts = WebLoader(max_tokens=256).load(input_blogpost) - -vector_store_driver.upsert_text_artifacts({namespace: artifacts}) +artifacts = WebLoader().load(input_blogpost) +chunks = TextChunker().chunk(artifacts) +vector_store_driver.upsert_text_artifacts({namespace: chunks}) rag_tool = RagTool( description="A DataStax blog post", diff --git a/docs/examples/src/talk_to_a_pdf_1.py b/docs/examples/src/talk_to_a_pdf_1.py index 3c29f4c741..aa2b1f1537 100644 --- a/docs/examples/src/talk_to_a_pdf_1.py +++ b/docs/examples/src/talk_to_a_pdf_1.py @@ -1,5 +1,6 @@ import requests +from griptape.chunkers import TextChunker from griptape.drivers import LocalVectorStoreDriver, OpenAiChatPromptDriver, OpenAiEmbeddingDriver from griptape.engines.rag import RagEngine from griptape.engines.rag.modules import PromptResponseRagModule, VectorStoreRetrievalRagModule @@ -31,8 +32,9 @@ ) artifacts = PdfLoader().load(response.content) +chunks = TextChunker().chunk(artifacts) -vector_store.upsert_text_artifacts({namespace: artifacts}) +vector_store.upsert_text_artifacts({namespace: chunks}) agent = Agent(tools=[rag_tool]) diff --git a/docs/examples/src/talk_to_a_webpage_1.py b/docs/examples/src/talk_to_a_webpage_1.py index 3e973da2d7..5414c8769f 100644 --- a/docs/examples/src/talk_to_a_webpage_1.py +++ b/docs/examples/src/talk_to_a_webpage_1.py @@ -1,3 +1,4 @@ +from griptape.chunkers import TextChunker from griptape.drivers import LocalVectorStoreDriver, OpenAiChatPromptDriver, OpenAiEmbeddingDriver from griptape.engines.rag import RagEngine from griptape.engines.rag.modules import PromptResponseRagModule, VectorStoreRetrievalRagModule @@ -26,8 +27,9 @@ ) artifacts = WebLoader().load("https://en.wikipedia.org/wiki/Physics") +chunks = TextChunker().chunk(artifacts) -vector_store_driver.upsert_text_artifacts({namespace: artifacts}) +vector_store_driver.upsert_text_artifacts({namespace: chunks}) rag_tool = RagTool( description="Contains information about physics. " "Use it to answer any physics-related questions.", diff --git a/docs/griptape-framework/data/loaders.md b/docs/griptape-framework/data/loaders.md index 0c0fc3eadc..f1826e67af 100644 --- a/docs/griptape-framework/data/loaders.md +++ b/docs/griptape-framework/data/loaders.md @@ -36,19 +36,6 @@ Can be used to load CSV files into [TextArtifact](../../reference/griptape/artif --8<-- "docs/griptape-framework/data/src/loaders_3.py" ``` - -## DataFrame - -!!! info - This driver requires the `loaders-dataframe` [extra](../index.md#extras). - -Can be used to load [pandas](https://pandas.pydata.org/) [DataFrame](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.html)s into [TextArtifact](../../reference/griptape/artifacts/text_artifact.md)s: - -```python ---8<-- "docs/griptape-framework/data/src/loaders_4.py" -``` - - ## Text Used to load arbitrary text and text files: diff --git a/docs/griptape-framework/data/src/loaders_1.py b/docs/griptape-framework/data/src/loaders_1.py index 2b7f316134..3732d8ac60 100644 --- a/docs/griptape-framework/data/src/loaders_1.py +++ b/docs/griptape-framework/data/src/loaders_1.py @@ -2,18 +2,15 @@ from pathlib import Path from griptape.loaders import PdfLoader -from griptape.utils import load_file, load_files urllib.request.urlretrieve("https://arxiv.org/pdf/1706.03762.pdf", "attention.pdf") # Load a single PDF file -PdfLoader().load(Path("attention.pdf").read_bytes()) -# You can also use the load_file utility function -PdfLoader().load(load_file("attention.pdf")) +PdfLoader().load("attention.pdf") +# You can also pass a Path object +PdfLoader().load(Path("attention.pdf")) urllib.request.urlretrieve("https://arxiv.org/pdf/1706.03762.pdf", "CoT.pdf") # Load multiple PDF files -PdfLoader().load_collection([Path("attention.pdf").read_bytes(), Path("CoT.pdf").read_bytes()]) -# You can also use the load_files utility function -PdfLoader().load_collection(list(load_files(["attention.pdf", "CoT.pdf"]).values())) +PdfLoader().load_collection([Path("attention.pdf"), Path("CoT.pdf")]) diff --git a/docs/griptape-framework/data/src/loaders_10.py b/docs/griptape-framework/data/src/loaders_10.py index 42a64c40ed..723fbcdf1f 100644 --- a/docs/griptape-framework/data/src/loaders_10.py +++ b/docs/griptape-framework/data/src/loaders_10.py @@ -1,10 +1,9 @@ from pathlib import Path from griptape.loaders import AudioLoader -from griptape.utils import load_file # Load an image from disk -audio_artifact = AudioLoader().load(Path("tests/resources/sentences.wav").read_bytes()) +AudioLoader().load("tests/resources/sentences.wav") -# You can also use the load_file utility function -AudioLoader().load(load_file("tests/resources/sentences.wav")) +# You can also pass a Path object +AudioLoader().load(Path("tests/resources/sentences.wav")) diff --git a/docs/griptape-framework/data/src/loaders_3.py b/docs/griptape-framework/data/src/loaders_3.py index 35af0fdfc9..3bc3ceb817 100644 --- a/docs/griptape-framework/data/src/loaders_3.py +++ b/docs/griptape-framework/data/src/loaders_3.py @@ -1,16 +1,11 @@ from pathlib import Path from griptape.loaders import CsvLoader -from griptape.utils import load_file, load_files # Load a single CSV file -CsvLoader().load(Path("tests/resources/cities.csv").read_text()) -# You can also use the load_file utility function -CsvLoader().load(load_file("tests/resources/cities.csv")) +CsvLoader().load("tests/resources/cities.csv") +# You can also pass a Path object +CsvLoader().load(Path("tests/resources/cities.csv")) # Load multiple CSV files -CsvLoader().load_collection( - [Path("tests/resources/cities.csv").read_text(), Path("tests/resources/addresses.csv").read_text()] -) -# You can also use the load_files utility function -CsvLoader().load_collection(list(load_files(["tests/resources/cities.csv", "tests/resources/addresses.csv"]).values())) +CsvLoader().load_collection([Path("tests/resources/cities.csv"), "tests/resources/addresses.csv"]) diff --git a/docs/griptape-framework/data/src/loaders_4.py b/docs/griptape-framework/data/src/loaders_4.py deleted file mode 100644 index 8d5883adf2..0000000000 --- a/docs/griptape-framework/data/src/loaders_4.py +++ /dev/null @@ -1,13 +0,0 @@ -import urllib.request - -import pandas as pd - -from griptape.loaders import DataFrameLoader - -urllib.request.urlretrieve("https://people.sc.fsu.edu/~jburkardt/data/csv/cities.csv", "cities.csv") - -DataFrameLoader().load(pd.read_csv("cities.csv")) - -urllib.request.urlretrieve("https://people.sc.fsu.edu/~jburkardt/data/csv/addresses.csv", "addresses.csv") - -DataFrameLoader().load_collection([pd.read_csv("cities.csv"), pd.read_csv("addresses.csv")]) diff --git a/docs/griptape-framework/data/src/loaders_7.py b/docs/griptape-framework/data/src/loaders_7.py index 6857886e8d..177471fd2f 100644 --- a/docs/griptape-framework/data/src/loaders_7.py +++ b/docs/griptape-framework/data/src/loaders_7.py @@ -1,9 +1,9 @@ from pathlib import Path from griptape.loaders import ImageLoader -from griptape.utils import load_file # Load an image from disk -disk_image_artifact = ImageLoader().load(Path("tests/resources/mountain.png").read_bytes()) -# You can also use the load_file utility function -ImageLoader().load(load_file("tests/resources/mountain.png")) +ImageLoader().load("tests/resources/mountain.png") + +# You can also pass a Path object +ImageLoader().load(Path("tests/resources/mountain.png")) diff --git a/docs/griptape-framework/data/src/loaders_8.py b/docs/griptape-framework/data/src/loaders_8.py index e85992d45c..d54c312467 100644 --- a/docs/griptape-framework/data/src/loaders_8.py +++ b/docs/griptape-framework/data/src/loaders_8.py @@ -1,16 +1,11 @@ from pathlib import Path from griptape.loaders import ImageLoader -from griptape.utils import load_file, load_files # Load a single image in BMP format -image_artifact_jpeg = ImageLoader(format="bmp").load(Path("tests/resources/mountain.png").read_bytes()) -# You can also use the load_file utility function -ImageLoader(format="bmp").load(load_file("tests/resources/mountain.png")) +ImageLoader(format="bmp").load("tests/resources/mountain.png") +# You can also pass a Path object +ImageLoader(format="bmp").load(Path("tests/resources/mountain.png")) # Load multiple images in BMP format -ImageLoader().load_collection( - [Path("tests/resources/mountain.png").read_bytes(), Path("tests/resources/cow.png").read_bytes()] -) -# You can also use the load_files utility function -ImageLoader().load_collection(list(load_files(["tests/resources/mountain.png", "tests/resources/cow.png"]).values())) +ImageLoader().load_collection([Path("tests/resources/mountain.png"), "tests/resources/cow.png"]) diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_1.py b/docs/griptape-framework/drivers/src/vector_store_drivers_1.py index 7f7e98e134..33f611a71f 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_1.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_1.py @@ -1,5 +1,6 @@ import os +from griptape.chunkers import TextChunker from griptape.drivers import LocalVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader @@ -9,10 +10,11 @@ vector_store_driver = LocalVectorStoreDriver(embedding_driver=embedding_driver) # Load Artifacts from the web -artifacts = WebLoader(max_tokens=100).load("https://www.griptape.ai") +artifact = WebLoader().load("https://www.griptape.ai") +chunks = TextChunker(max_tokens=100).chunk(artifact) # Upsert Artifacts into the Vector Store Driver -[vector_store_driver.upsert_text_artifact(a, namespace="griptape") for a in artifacts] +[vector_store_driver.upsert_text_artifact(a, namespace="griptape") for a in chunks] results = vector_store_driver.query(query="What is griptape?") diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_10.py b/docs/griptape-framework/drivers/src/vector_store_drivers_10.py index 39a21121d5..f822949133 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_10.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_10.py @@ -1,5 +1,6 @@ import os +from griptape.chunkers import TextChunker from griptape.drivers import OpenAiEmbeddingDriver, QdrantVectorStoreDriver from griptape.loaders import WebLoader @@ -19,7 +20,8 @@ ) # Load Artifacts from the web -artifacts = WebLoader().load("https://www.griptape.ai") +artifact = WebLoader().load("https://www.griptape.ai") +chunks = TextChunker(max_tokens=100).chunk(artifact) # Recreate Qdrant collection vector_store_driver.client.recreate_collection( @@ -28,7 +30,7 @@ ) # Upsert Artifacts into the Vector Store Driver -[vector_store_driver.upsert_text_artifact(a, namespace="griptape") for a in artifacts] +[vector_store_driver.upsert_text_artifact(a, namespace="griptape") for a in chunks] results = vector_store_driver.query(query="What is griptape?") diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_11.py b/docs/griptape-framework/drivers/src/vector_store_drivers_11.py index a8d9ceed12..739a52c632 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_11.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_11.py @@ -1,5 +1,6 @@ import os +from griptape.chunkers import TextChunker from griptape.drivers import AstraDbVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader @@ -20,10 +21,11 @@ ) # Load Artifacts from the web -artifacts = WebLoader().load("https://www.griptape.ai") +artifact = WebLoader().load("https://www.griptape.ai") +chunks = TextChunker().chunk(artifact) # Upsert Artifacts into the Vector Store Driver -[vector_store_driver.upsert_text_artifact(a, namespace="griptape") for a in artifacts] +[vector_store_driver.upsert_text_artifact(a, namespace="griptape") for a in chunks] results = vector_store_driver.query(query="What is griptape?") diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_3.py b/docs/griptape-framework/drivers/src/vector_store_drivers_3.py index 559eaec5a7..c8ff0e0b5a 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_3.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_3.py @@ -1,5 +1,6 @@ import os +from griptape.chunkers import TextChunker from griptape.drivers import OpenAiEmbeddingDriver, PineconeVectorStoreDriver from griptape.loaders import WebLoader @@ -14,10 +15,11 @@ ) # Load Artifacts from the web -artifacts = WebLoader(max_tokens=100).load("https://www.griptape.ai") +artifact = WebLoader().load("https://www.griptape.ai") +chunks = TextChunker(max_tokens=100).chunk(artifact) # Upsert Artifacts into the Vector Store Driver -[vector_store_driver.upsert_text_artifact(a, namespace="griptape") for a in artifacts] +[vector_store_driver.upsert_text_artifact(a, namespace="griptape") for a in chunks] results = vector_store_driver.query(query="What is griptape?") diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_4.py b/docs/griptape-framework/drivers/src/vector_store_drivers_4.py index f2f0091a09..ad15f0744d 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_4.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_4.py @@ -1,5 +1,6 @@ import os +from griptape.chunkers import TextChunker from griptape.drivers import MarqoVectorStoreDriver, OpenAiChatPromptDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader @@ -19,12 +20,13 @@ ) # Load Artifacts from the web -artifacts = WebLoader(max_tokens=200).load("https://www.griptape.ai") +artifact = WebLoader().load("https://www.griptape.ai") +chunks = TextChunker(max_tokens=200).chunk(artifact) # Upsert Artifacts into the Vector Store Driver vector_store_driver.upsert_text_artifacts( { - "griptape": artifacts, + "griptape": chunks, } ) diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_5.py b/docs/griptape-framework/drivers/src/vector_store_drivers_5.py index 7649579c75..33d541f2f1 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_5.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_5.py @@ -1,5 +1,6 @@ import os +from griptape.chunkers import TextChunker from griptape.drivers import MongoDbAtlasVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader @@ -25,14 +26,11 @@ ) # Load Artifacts from the web -artifacts = WebLoader(max_tokens=200).load("https://www.griptape.ai") +artifact = WebLoader().load("https://www.griptape.ai") +chunks = TextChunker(max_tokens=200).chunk(artifact) # Upsert Artifacts into the Vector Store Driver -vector_store_driver.upsert_text_artifacts( - { - "griptape": artifacts, - } -) +vector_store_driver.upsert_text_artifacts({"griptape": chunks}) results = vector_store_driver.query(query="What is griptape?") diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_6.py b/docs/griptape-framework/drivers/src/vector_store_drivers_6.py index 78a7cc3e64..6a8a8bb043 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_6.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_6.py @@ -1,5 +1,6 @@ import os +from griptape.chunkers import TextChunker from griptape.drivers import AzureMongoDbVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader @@ -25,12 +26,13 @@ ) # Load Artifacts from the web -artifacts = WebLoader(max_tokens=200).load("https://www.griptape.ai") +artifact = WebLoader().load("https://www.griptape.ai") +chunks = TextChunker(max_tokens=200).chunk(artifact) # Upsert Artifacts into the Vector Store Driver vector_store_driver.upsert_text_artifacts( { - "griptape": artifacts, + "griptape": chunks, } ) diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_7.py b/docs/griptape-framework/drivers/src/vector_store_drivers_7.py index d34ff86493..7d14504bb4 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_7.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_7.py @@ -1,5 +1,6 @@ import os +from griptape.chunkers import TextChunker from griptape.drivers import OpenAiEmbeddingDriver, RedisVectorStoreDriver from griptape.loaders import WebLoader @@ -15,12 +16,13 @@ ) # Load Artifacts from the web -artifacts = WebLoader(max_tokens=200).load("https://www.griptape.ai") +artifact = WebLoader().load("https://www.griptape.ai") +chunks = TextChunker(max_tokens=200).chunk(artifact) # Upsert Artifacts into the Vector Store Driver vector_store_driver.upsert_text_artifacts( { - "griptape": artifacts, + "griptape": chunks, } ) diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_8.py b/docs/griptape-framework/drivers/src/vector_store_drivers_8.py index 18e50a397c..0ecb497234 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_8.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_8.py @@ -2,6 +2,7 @@ import boto3 +from griptape.chunkers import TextChunker from griptape.drivers import AmazonOpenSearchVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader @@ -16,12 +17,13 @@ ) # Load Artifacts from the web -artifacts = WebLoader(max_tokens=200).load("https://www.griptape.ai") +artifact = WebLoader().load("https://www.griptape.ai") +chunks = TextChunker(max_tokens=200).chunk(artifact) # Upsert Artifacts into the Vector Store Driver vector_store_driver.upsert_text_artifacts( { - "griptape": artifacts, + "griptape": chunks, } ) diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_9.py b/docs/griptape-framework/drivers/src/vector_store_drivers_9.py index ad5abf9322..9feb47761b 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_9.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_9.py @@ -1,5 +1,6 @@ import os +from griptape.chunkers import TextChunker from griptape.drivers import OpenAiEmbeddingDriver, PgVectorVectorStoreDriver from griptape.loaders import WebLoader @@ -22,12 +23,13 @@ vector_store_driver.setup() # Load Artifacts from the web -artifacts = WebLoader().load("https://www.griptape.ai") +artifact = WebLoader().load("https://www.griptape.ai") +chunks = TextChunker().chunk(artifact) # Upsert Artifacts into the Vector Store Driver vector_store_driver.upsert_text_artifacts( { - "griptape": artifacts, + "griptape": chunks, } ) diff --git a/docs/griptape-framework/engines/src/audio_engines_2.py b/docs/griptape-framework/engines/src/audio_engines_2.py index c04b466f82..92c87d6387 100644 --- a/docs/griptape-framework/engines/src/audio_engines_2.py +++ b/docs/griptape-framework/engines/src/audio_engines_2.py @@ -1,7 +1,6 @@ from griptape.drivers import OpenAiAudioTranscriptionDriver from griptape.engines import AudioTranscriptionEngine from griptape.loaders import AudioLoader -from griptape.utils import load_file driver = OpenAiAudioTranscriptionDriver(model="whisper-1") @@ -9,5 +8,5 @@ audio_transcription_driver=driver, ) -audio_artifact = AudioLoader().load(load_file("tests/resources/sentences.wav")) +audio_artifact = AudioLoader().load("tests/resources/sentences.wav") engine.run(audio_artifact) diff --git a/docs/griptape-framework/engines/src/rag_engines_1.py b/docs/griptape-framework/engines/src/rag_engines_1.py index a8a9cc06bf..6ad28545ff 100644 --- a/docs/griptape-framework/engines/src/rag_engines_1.py +++ b/docs/griptape-framework/engines/src/rag_engines_1.py @@ -1,3 +1,4 @@ +from griptape.chunkers import TextChunker from griptape.drivers import LocalVectorStoreDriver, OpenAiChatPromptDriver, OpenAiEmbeddingDriver from griptape.engines.rag import RagContext, RagEngine from griptape.engines.rag.modules import PromptResponseRagModule, TranslateQueryRagModule, VectorStoreRetrievalRagModule @@ -8,12 +9,12 @@ prompt_driver = OpenAiChatPromptDriver(model="gpt-4o", temperature=0) vector_store = LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver()) -artifacts = WebLoader(max_tokens=500).load("https://www.griptape.ai") - +artifact = WebLoader().load("https://www.griptape.ai") +chunks = TextChunker(max_tokens=500).chunk(artifact) vector_store.upsert_text_artifacts( { - "griptape": artifacts, + "griptape": chunks, } ) diff --git a/docs/griptape-framework/engines/src/summary_engines_1.py b/docs/griptape-framework/engines/src/summary_engines_1.py index b5adf2a5a9..201aadfa42 100644 --- a/docs/griptape-framework/engines/src/summary_engines_1.py +++ b/docs/griptape-framework/engines/src/summary_engines_1.py @@ -1,5 +1,6 @@ import requests +from griptape.chunkers import TextChunker from griptape.drivers import OpenAiChatPromptDriver from griptape.engines import PromptSummaryEngine from griptape.loaders import PdfLoader @@ -9,8 +10,9 @@ prompt_driver=OpenAiChatPromptDriver(model="gpt-3.5-turbo"), ) -artifacts = PdfLoader().load(response.content) +artifact = PdfLoader().load(response.content) +chunks = TextChunker().chunk(artifact) -text = "\n\n".join([a.value for a in artifacts]) +text = "\n\n".join([a.value for a in chunks]) engine.summarize_text(text) diff --git a/docs/griptape-framework/structures/src/tasks_18.py b/docs/griptape-framework/structures/src/tasks_18.py index 08ece5a92a..0d3312d4cf 100644 --- a/docs/griptape-framework/structures/src/tasks_18.py +++ b/docs/griptape-framework/structures/src/tasks_18.py @@ -3,12 +3,11 @@ from griptape.loaders import AudioLoader from griptape.structures import Pipeline from griptape.tasks import AudioTranscriptionTask -from griptape.utils import load_file driver = OpenAiAudioTranscriptionDriver(model="whisper-1") task = AudioTranscriptionTask( - input=lambda _: AudioLoader().load(load_file("tests/resources/sentences2.wav")), + input=lambda _: AudioLoader().load("tests/resources/sentences2.wav"), audio_transcription_engine=AudioTranscriptionEngine( audio_transcription_driver=driver, ), diff --git a/docs/griptape-tools/official-tools/src/vector_store_tool_1.py b/docs/griptape-tools/official-tools/src/vector_store_tool_1.py index 26c87e2552..bdb60d98b3 100644 --- a/docs/griptape-tools/official-tools/src/vector_store_tool_1.py +++ b/docs/griptape-tools/official-tools/src/vector_store_tool_1.py @@ -1,3 +1,4 @@ +from griptape.chunkers import TextChunker from griptape.drivers import LocalVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader from griptape.structures import Agent @@ -8,8 +9,9 @@ ) artifacts = WebLoader().load("https://www.griptape.ai") +chunks = TextChunker().chunk(artifacts) -vector_store_driver.upsert_text_artifacts({"griptape": artifacts}) +vector_store_driver.upsert_text_artifacts({"griptape": chunks}) vector_db = VectorStoreTool( description="This DB has information about the Griptape Python framework", vector_store_driver=vector_store_driver, diff --git a/griptape/artifacts/__init__.py b/griptape/artifacts/__init__.py index 2099a78de2..d56cfa95e3 100644 --- a/griptape/artifacts/__init__.py +++ b/griptape/artifacts/__init__.py @@ -7,6 +7,7 @@ from .audio_artifact import AudioArtifact from .json_artifact import JsonArtifact from .action_artifact import ActionArtifact +from .table_artifact import TableArtifact from .generic_artifact import GenericArtifact diff --git a/griptape/artifacts/blob_artifact.py b/griptape/artifacts/blob_artifact.py index 26e657005e..29e35b53d6 100644 --- a/griptape/artifacts/blob_artifact.py +++ b/griptape/artifacts/blob_artifact.py @@ -2,7 +2,7 @@ from typing import Any -from attrs import Converter, define, field +from attrs import define, field from griptape.artifacts import BaseSystemArtifact diff --git a/griptape/artifacts/json_artifact.py b/griptape/artifacts/json_artifact.py index fd514ce9c6..b3d45e4d93 100644 --- a/griptape/artifacts/json_artifact.py +++ b/griptape/artifacts/json_artifact.py @@ -3,9 +3,11 @@ import json from typing import Any, Union -from attrs import Converter, define, field +from attrs import define, field -from griptape.artifacts.text_artifact import TextArtifact +from griptape.artifacts import BaseArtifact + +Json = Union[dict[str, "Json"], list["Json"], str, int, float, bool, None] def value_to_json(value: Any) -> Json: diff --git a/griptape/drivers/file_manager/base_file_manager_driver.py b/griptape/drivers/file_manager/base_file_manager_driver.py index dce5388122..c904f15327 100644 --- a/griptape/drivers/file_manager/base_file_manager_driver.py +++ b/griptape/drivers/file_manager/base_file_manager_driver.py @@ -1,11 +1,11 @@ from __future__ import annotations from abc import ABC, abstractmethod +from typing import Optional -from attrs import Factory, define, field +from attrs import define, field -import griptape.loaders as loaders -from griptape.artifacts import BaseArtifact, ErrorArtifact, InfoArtifact, ListArtifact, TextArtifact +from griptape.artifacts import BlobArtifact, InfoArtifact, TextArtifact @define @@ -17,57 +17,28 @@ class BaseFileManagerDriver(ABC): loaders: Dictionary of file extension specific loaders to use for loading file contents into artifacts. """ - default_loader: loaders.BaseLoader = field(default=Factory(lambda: loaders.BlobLoader()), kw_only=True) - loaders: dict[str, loaders.BaseLoader] = field( - default=Factory( - lambda: { - "pdf": loaders.PdfLoader(), - "csv": loaders.CsvLoader(), - "txt": loaders.TextLoader(), - "html": loaders.TextLoader(), - "json": loaders.TextLoader(), - "yaml": loaders.TextLoader(), - "xml": loaders.TextLoader(), - "png": loaders.ImageLoader(), - "jpg": loaders.ImageLoader(), - "jpeg": loaders.ImageLoader(), - "webp": loaders.ImageLoader(), - "gif": loaders.ImageLoader(), - "bmp": loaders.ImageLoader(), - "tiff": loaders.ImageLoader(), - }, - ), - kw_only=True, - ) + workdir: str = field(kw_only=True) + encoding: Optional[str] = field(default=None, kw_only=True) - def list_files(self, path: str) -> TextArtifact | ErrorArtifact: + def list_files(self, path: str) -> TextArtifact: entries = self.try_list_files(path) return TextArtifact("\n".join(list(entries))) @abstractmethod def try_list_files(self, path: str) -> list[str]: ... - def load_file(self, path: str) -> BaseArtifact: - extension = path.split(".")[-1] - loader = self.loaders.get(extension) or self.default_loader - source = self.try_load_file(path) - result = loader.load(source) - - if isinstance(result, BaseArtifact): - return result + def load_file(self, path: str) -> BlobArtifact | TextArtifact: + if self.encoding is None: + return BlobArtifact(self.try_load_file(path)) else: - return ListArtifact(result) + return TextArtifact(self.try_load_file(path).decode(encoding=self.encoding), encoding=self.encoding) @abstractmethod def try_load_file(self, path: str) -> bytes: ... def save_file(self, path: str, value: bytes | str) -> InfoArtifact: - extension = path.split(".")[-1] - loader = self.loaders.get(extension) or self.default_loader - encoding = None if loader is None else loader.encoding - if isinstance(value, str): - value = value.encode() if encoding is None else value.encode(encoding=encoding) + value = value.encode() if self.encoding is None else value.encode(encoding=self.encoding) elif isinstance(value, (bytearray, memoryview)): raise ValueError(f"Unsupported type: {type(value)}") diff --git a/griptape/drivers/file_manager/local_file_manager_driver.py b/griptape/drivers/file_manager/local_file_manager_driver.py index a6f1f0726c..b383ff7d76 100644 --- a/griptape/drivers/file_manager/local_file_manager_driver.py +++ b/griptape/drivers/file_manager/local_file_manager_driver.py @@ -2,6 +2,7 @@ import os from pathlib import Path +from typing import Optional from attrs import Attribute, Factory, define, field @@ -16,11 +17,11 @@ class LocalFileManagerDriver(BaseFileManagerDriver): workdir: The absolute working directory. List, load, and save operations will be performed relative to this directory. """ - workdir: str = field(default=Factory(lambda: os.getcwd()), kw_only=True) + workdir: Optional[str] = field(default=Factory(lambda: os.getcwd()), kw_only=True) @workdir.validator # pyright: ignore[reportAttributeAccessIssue] def validate_workdir(self, _: Attribute, workdir: str) -> None: - if not Path(workdir).is_absolute(): + if self.workdir is not None and not Path(workdir).is_absolute(): raise ValueError("Workdir must be an absolute path") def try_list_files(self, path: str) -> list[str]: @@ -41,8 +42,7 @@ def try_save_file(self, path: str, value: bytes) -> None: Path(full_path).write_bytes(value) def _full_path(self, path: str) -> str: - path = path.lstrip("/") - full_path = os.path.join(self.workdir, path) + full_path = path if self.workdir is None else os.path.join(self.workdir, path.lstrip("/")) # Need to keep the trailing slash if it was there, # because it means the path is a directory. ended_with_slash = path.endswith("/") diff --git a/griptape/drivers/web_scraper/base_web_scraper_driver.py b/griptape/drivers/web_scraper/base_web_scraper_driver.py index ae39f8eacf..0c33f3713c 100644 --- a/griptape/drivers/web_scraper/base_web_scraper_driver.py +++ b/griptape/drivers/web_scraper/base_web_scraper_driver.py @@ -4,5 +4,13 @@ class BaseWebScraperDriver(ABC): + def scrape_url(self, url: str) -> TextArtifact: + source = self.fetch_url(url) + + return self.extract_page(source) + + @abstractmethod + def fetch_url(self, url: str) -> str: ... + @abstractmethod - def scrape_url(self, url: str) -> TextArtifact: ... + def extract_page(self, page: str) -> TextArtifact: ... diff --git a/griptape/drivers/web_scraper/markdownify_web_scraper_driver.py b/griptape/drivers/web_scraper/markdownify_web_scraper_driver.py index b54ff072f9..4eff4948b5 100644 --- a/griptape/drivers/web_scraper/markdownify_web_scraper_driver.py +++ b/griptape/drivers/web_scraper/markdownify_web_scraper_driver.py @@ -38,20 +38,8 @@ class MarkdownifyWebScraperDriver(BaseWebScraperDriver): exclude_ids: list[str] = field(default=Factory(list), kw_only=True) timeout: Optional[int] = field(default=None, kw_only=True) - def scrape_url(self, url: str) -> TextArtifact: + def fetch_url(self, url: str) -> str: sync_playwright = import_optional_dependency("playwright.sync_api").sync_playwright - bs4 = import_optional_dependency("bs4") - markdownify = import_optional_dependency("markdownify") - - include_links = self.include_links - - # Custom MarkdownConverter to optionally linked urls. If include_links is False only - # the text of the link is returned. - class OptionalLinksMarkdownConverter(markdownify.MarkdownConverter): - def convert_a(self, el: Any, text: str, convert_as_inline: Any) -> str: - if include_links: - return super().convert_a(el, text, convert_as_inline) - return text with sync_playwright() as p, p.chromium.launch(headless=True) as browser: page = browser.new_page() @@ -76,28 +64,43 @@ def skip_loading_images(route: Any) -> Any: if not content: raise Exception("can't access URL") - soup = bs4.BeautifulSoup(content, "html.parser") + return content + + def extract_page(self, page: str) -> TextArtifact: + bs4 = import_optional_dependency("bs4") + markdownify = import_optional_dependency("markdownify") + include_links = self.include_links + + # Custom MarkdownConverter to optionally linked urls. If include_links is False only + # the text of the link is returned. + class OptionalLinksMarkdownConverter(markdownify.MarkdownConverter): + def convert_a(self, el: Any, text: str, convert_as_inline: Any) -> str: + if include_links: + return super().convert_a(el, text, convert_as_inline) + return text + + soup = bs4.BeautifulSoup(page, "html.parser") - # Remove unwanted elements - exclude_selector = ",".join( - self.exclude_tags + [f".{c}" for c in self.exclude_classes] + [f"#{i}" for i in self.exclude_ids], - ) - if exclude_selector: - for s in soup.select(exclude_selector): - s.extract() + # Remove unwanted elements + exclude_selector = ",".join( + self.exclude_tags + [f".{c}" for c in self.exclude_classes] + [f"#{i}" for i in self.exclude_ids], + ) + if exclude_selector: + for s in soup.select(exclude_selector): + s.extract() - text = OptionalLinksMarkdownConverter().convert_soup(soup) + text = OptionalLinksMarkdownConverter().convert_soup(soup) - # Remove leading and trailing whitespace from the entire text - text = text.strip() + # Remove leading and trailing whitespace from the entire text + text = text.strip() - # Remove trailing whitespace from each line - text = re.sub(r"[ \t]+$", "", text, flags=re.MULTILINE) + # Remove trailing whitespace from each line + text = re.sub(r"[ \t]+$", "", text, flags=re.MULTILINE) - # Indent using 2 spaces instead of tabs - text = re.sub(r"(\n?\s*?)\t", r"\1 ", text) + # Indent using 2 spaces instead of tabs + text = re.sub(r"(\n?\s*?)\t", r"\1 ", text) - # Remove triple+ newlines (keep double newlines for paragraphs) - text = re.sub(r"\n\n+", "\n\n", text) + # Remove triple+ newlines (keep double newlines for paragraphs) + text = re.sub(r"\n\n+", "\n\n", text) - return TextArtifact(text) + return TextArtifact(text) diff --git a/griptape/drivers/web_scraper/proxy_web_scraper_driver.py b/griptape/drivers/web_scraper/proxy_web_scraper_driver.py index 2d785fde26..94b3914eae 100644 --- a/griptape/drivers/web_scraper/proxy_web_scraper_driver.py +++ b/griptape/drivers/web_scraper/proxy_web_scraper_driver.py @@ -12,6 +12,10 @@ class ProxyWebScraperDriver(BaseWebScraperDriver): proxies: dict = field(kw_only=True, metadata={"serializable": False}) params: dict = field(default=Factory(dict), kw_only=True, metadata={"serializable": True}) - def scrape_url(self, url: str) -> TextArtifact: + def fetch_url(self, url: str) -> str: response = requests.get(url, proxies=self.proxies, **self.params) - return TextArtifact(response.text) + + return response.text + + def extract_page(self, page: str) -> TextArtifact: + return TextArtifact(page) diff --git a/griptape/drivers/web_scraper/trafilatura_web_scraper_driver.py b/griptape/drivers/web_scraper/trafilatura_web_scraper_driver.py index 06f5573a4a..e87af8af63 100644 --- a/griptape/drivers/web_scraper/trafilatura_web_scraper_driver.py +++ b/griptape/drivers/web_scraper/trafilatura_web_scraper_driver.py @@ -12,7 +12,7 @@ class TrafilaturaWebScraperDriver(BaseWebScraperDriver): include_links: bool = field(default=True, kw_only=True) - def scrape_url(self, url: str) -> TextArtifact: + def fetch_url(self, url: str) -> str: trafilatura = import_optional_dependency("trafilatura") use_config = trafilatura.settings.use_config @@ -29,6 +29,15 @@ def scrape_url(self, url: str) -> TextArtifact: if page is None: raise Exception("can't access URL") + + return page + + def extract_page(self, page: str) -> TextArtifact: + trafilatura = import_optional_dependency("trafilatura") + use_config = trafilatura.settings.use_config + + config = use_config() + extracted_page = trafilatura.extract( page, include_links=self.include_links, diff --git a/griptape/engines/rag/modules/retrieval/text_loader_retrieval_rag_module.py b/griptape/engines/rag/modules/retrieval/text_loader_retrieval_rag_module.py index 7e4854d00c..0348a20946 100644 --- a/griptape/engines/rag/modules/retrieval/text_loader_retrieval_rag_module.py +++ b/griptape/engines/rag/modules/retrieval/text_loader_retrieval_rag_module.py @@ -6,6 +6,7 @@ from attrs import Factory, define, field from griptape import utils +from griptape.chunkers import TextChunker from griptape.engines.rag.modules import BaseRetrievalRagModule if TYPE_CHECKING: @@ -14,12 +15,13 @@ from griptape.artifacts import TextArtifact from griptape.drivers import BaseVectorStoreDriver from griptape.engines.rag import RagContext - from griptape.loaders import BaseTextLoader + from griptape.loaders import TextLoader @define(kw_only=True) class TextLoaderRetrievalRagModule(BaseRetrievalRagModule): - loader: BaseTextLoader = field() + loader: TextLoader = field() + chunker: TextChunker = field(default=Factory(lambda: TextChunker())) vector_store_driver: BaseVectorStoreDriver = field() source: Any = field() query_params: dict[str, Any] = field(factory=dict) @@ -37,7 +39,8 @@ def run(self, context: RagContext) -> Sequence[TextArtifact]: query_params["namespace"] = namespace loader_output = self.loader.load(source) + chunks = self.chunker.chunk(loader_output) - self.vector_store_driver.upsert_text_artifacts({namespace: loader_output}) + self.vector_store_driver.upsert_text_artifacts({namespace: chunks}) return self.process_query_output_fn(self.vector_store_driver.query(context.query, **query_params)) diff --git a/griptape/loaders/__init__.py b/griptape/loaders/__init__.py index b79b0ff448..b863706070 100644 --- a/griptape/loaders/__init__.py +++ b/griptape/loaders/__init__.py @@ -1,26 +1,28 @@ from .base_loader import BaseLoader -from .base_text_loader import BaseTextLoader +from .base_file_loader import BaseFileLoader + from .text_loader import TextLoader from .pdf_loader import PdfLoader from .web_loader import WebLoader from .sql_loader import SqlLoader from .csv_loader import CsvLoader -from .dataframe_loader import DataFrameLoader from .email_loader import EmailLoader + +from .blob_loader import BlobLoader + from .image_loader import ImageLoader + from .audio_loader import AudioLoader -from .blob_loader import BlobLoader __all__ = [ "BaseLoader", - "BaseTextLoader", + "BaseFileLoader", "TextLoader", "PdfLoader", "WebLoader", "SqlLoader", "CsvLoader", - "DataFrameLoader", "EmailLoader", "ImageLoader", "AudioLoader", diff --git a/griptape/loaders/audio_loader.py b/griptape/loaders/audio_loader.py index 84d6b767ae..1a7caefce5 100644 --- a/griptape/loaders/audio_loader.py +++ b/griptape/loaders/audio_loader.py @@ -1,20 +1,20 @@ from __future__ import annotations -from typing import cast +from typing import Any, cast +import filetype from attrs import define from griptape.artifacts import AudioArtifact -from griptape.loaders import BaseLoader -from griptape.utils import import_optional_dependency +from griptape.loaders.base_file_loader import BaseFileLoader @define -class AudioLoader(BaseLoader): +class AudioLoader(BaseFileLoader): """Loads audio content into audio artifacts.""" - def load(self, source: bytes, *args, **kwargs) -> AudioArtifact: - return AudioArtifact(source, format=import_optional_dependency("filetype").guess(source).extension) + def load(self, source: Any, *args, **kwargs) -> AudioArtifact: + return cast(AudioArtifact, super().load(source, *args, **kwargs)) - def load_collection(self, sources: list[bytes], *args, **kwargs) -> dict[str, AudioArtifact]: - return cast(dict[str, AudioArtifact], super().load_collection(sources, *args, **kwargs)) + def parse(self, source: bytes, *args, **kwargs) -> AudioArtifact: + return AudioArtifact(source, format=filetype.guess(source).extension) diff --git a/griptape/loaders/base_file_loader.py b/griptape/loaders/base_file_loader.py new file mode 100644 index 0000000000..650dc303cd --- /dev/null +++ b/griptape/loaders/base_file_loader.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from abc import ABC +from typing import TYPE_CHECKING + +from attrs import Factory, define, field + +from griptape.drivers import BaseFileManagerDriver, LocalFileManagerDriver +from griptape.loaders import BaseLoader + +if TYPE_CHECKING: + from os import PathLike + + +@define +class BaseFileLoader(BaseLoader, ABC): + file_manager_driver: BaseFileManagerDriver = field( + default=Factory(lambda: LocalFileManagerDriver(workdir=None)), + kw_only=True, + ) + encoding: str = field(default="utf-8", kw_only=True) + + def fetch(self, source: str | PathLike, *args, **kwargs) -> str | bytes: + return self.file_manager_driver.load_file(str(source), *args, **kwargs).value diff --git a/griptape/loaders/base_loader.py b/griptape/loaders/base_loader.py index 525b4df0ad..7c16c8ded7 100644 --- a/griptape/loaders/base_loader.py +++ b/griptape/loaders/base_loader.py @@ -10,24 +10,33 @@ from griptape.utils.hash import bytes_to_hash, str_to_hash if TYPE_CHECKING: - from collections.abc import Mapping, Sequence + from collections.abc import Mapping from griptape.artifacts import BaseArtifact + from griptape.common import Reference @define class BaseLoader(FuturesExecutorMixin, ABC): - encoding: Optional[str] = field(default=None, kw_only=True) + reference: Optional[Reference] = field(default=None, kw_only=True) + + def load(self, source: Any, *args, **kwargs) -> BaseArtifact: + data = self.fetch(source) + + return self.parse(data) + + @abstractmethod + def fetch(self, source: Any, *args, **kwargs) -> Any: ... @abstractmethod - def load(self, source: Any, *args, **kwargs) -> BaseArtifact | Sequence[BaseArtifact]: ... + def parse(self, source: Any, *args, **kwargs) -> BaseArtifact: ... def load_collection( self, sources: list[Any], *args, **kwargs, - ) -> Mapping[str, BaseArtifact | Sequence[BaseArtifact | Sequence[BaseArtifact]]]: + ) -> Mapping[str, BaseArtifact]: # Create a dictionary before actually submitting the jobs to the executor # to avoid duplicate work. sources_by_key = {self.to_key(source): source for source in sources} @@ -39,10 +48,8 @@ def load_collection( }, ) - def to_key(self, source: Any, *args, **kwargs) -> str: + def to_key(self, source: Any) -> str: if isinstance(source, bytes): return bytes_to_hash(source) - elif isinstance(source, str): - return str_to_hash(source) else: return str_to_hash(str(source)) diff --git a/griptape/loaders/base_text_loader.py b/griptape/loaders/base_text_loader.py deleted file mode 100644 index 196cb0087c..0000000000 --- a/griptape/loaders/base_text_loader.py +++ /dev/null @@ -1,65 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Optional, cast - -from attrs import Factory, define, field - -from griptape.artifacts import TextArtifact -from griptape.chunkers import BaseChunker, TextChunker -from griptape.loaders import BaseLoader -from griptape.tokenizers import OpenAiTokenizer - -if TYPE_CHECKING: - from griptape.common import Reference - from griptape.drivers import BaseEmbeddingDriver - - -@define -class BaseTextLoader(BaseLoader, ABC): - MAX_TOKEN_RATIO = 0.5 - - tokenizer: OpenAiTokenizer = field( - default=Factory(lambda: OpenAiTokenizer(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL)), - kw_only=True, - ) - max_tokens: int = field( - default=Factory(lambda self: round(self.tokenizer.max_input_tokens * self.MAX_TOKEN_RATIO), takes_self=True), - kw_only=True, - ) - chunker: BaseChunker = field( - default=Factory( - lambda self: TextChunker(tokenizer=self.tokenizer, max_tokens=self.max_tokens), - takes_self=True, - ), - kw_only=True, - ) - embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) - encoding: str = field(default="utf-8", kw_only=True) - reference: Optional[Reference] = field(default=None, kw_only=True) - - @abstractmethod - def load(self, source: Any, *args, **kwargs) -> list[TextArtifact]: ... - - def load_collection(self, sources: list[Any], *args, **kwargs) -> dict[str, list[TextArtifact]]: - return cast( - dict[str, list[TextArtifact]], - super().load_collection(sources, *args, **kwargs), - ) - - def _text_to_artifacts(self, text: str) -> list[TextArtifact]: - artifacts = [] - - chunks = self.chunker.chunk(text) if self.chunker else [TextArtifact(text)] - - for chunk in chunks: - if self.embedding_driver: - chunk.generate_embedding(self.embedding_driver) - - chunk.reference = self.reference - - chunk.encoding = self.encoding - - artifacts.append(chunk) - - return artifacts diff --git a/griptape/loaders/blob_loader.py b/griptape/loaders/blob_loader.py index d0099b47bc..dbc7e66bfb 100644 --- a/griptape/loaders/blob_loader.py +++ b/griptape/loaders/blob_loader.py @@ -5,16 +5,16 @@ from attrs import define from griptape.artifacts import BlobArtifact -from griptape.loaders import BaseLoader +from griptape.loaders import BaseFileLoader @define -class BlobLoader(BaseLoader): +class BlobLoader(BaseFileLoader): def load(self, source: Any, *args, **kwargs) -> BlobArtifact: + return cast(BlobArtifact, super().load(source, *args, **kwargs)) + + def parse(self, source: bytes, *args, **kwargs) -> BlobArtifact: if self.encoding is None: return BlobArtifact(source) else: return BlobArtifact(source, encoding=self.encoding) - - def load_collection(self, sources: list[bytes | str], *args, **kwargs) -> dict[str, BlobArtifact]: - return cast(dict[str, BlobArtifact], super().load_collection(sources, *args, **kwargs)) diff --git a/griptape/loaders/csv_loader.py b/griptape/loaders/csv_loader.py index ab00a3216b..222952401b 100644 --- a/griptape/loaders/csv_loader.py +++ b/griptape/loaders/csv_loader.py @@ -2,45 +2,19 @@ import csv from io import StringIO -from typing import TYPE_CHECKING, Optional, cast from attrs import define, field -from griptape.artifacts import TextArtifact -from griptape.loaders import BaseLoader - -if TYPE_CHECKING: - from griptape.drivers import BaseEmbeddingDriver +from griptape.artifacts import ListArtifact, TextArtifact +from griptape.loaders.base_file_loader import BaseFileLoader @define -class CsvLoader(BaseLoader): - embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) +class CsvLoader(BaseFileLoader): delimiter: str = field(default=",", kw_only=True) encoding: str = field(default="utf-8", kw_only=True) - def load(self, source: bytes | str, *args, **kwargs) -> TableArtifact: - if isinstance(source, bytes): - source = source.decode(encoding=self.encoding) - elif isinstance(source, (bytearray, memoryview)): - raise ValueError(f"Unsupported source type: {type(source)}") - - reader = csv.DictReader(StringIO(source), delimiter=self.delimiter) - - artifact = TableArtifact(list(reader), delimiter=self.delimiter, fieldnames=reader.fieldnames) - - if self.embedding_driver: - artifact.generate_embedding(self.embedding_driver) - - return artifact + def parse(self, source: bytes, *args, **kwargs) -> ListArtifact: + reader = csv.DictReader(StringIO(source.decode(self.encoding)), delimiter=self.delimiter) - def load_collection( - self, - sources: list[bytes | str], - *args, - **kwargs, - ) -> dict[str, TableArtifact]: - return cast( - dict[str, TableArtifact], - super().load_collection(sources, *args, **kwargs), - ) + return ListArtifact([TextArtifact(row, meta={"row": row_num}) for row_num, row in enumerate(reader)]) diff --git a/griptape/loaders/dataframe_loader.py b/griptape/loaders/dataframe_loader.py deleted file mode 100644 index 5fbbd51d16..0000000000 --- a/griptape/loaders/dataframe_loader.py +++ /dev/null @@ -1,36 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Optional, cast - -from attrs import define, field - -from griptape.artifacts import TableArtifact -from griptape.loaders import BaseLoader -from griptape.utils import import_optional_dependency -from griptape.utils.hash import str_to_hash - -if TYPE_CHECKING: - from pandas import DataFrame - - from griptape.drivers import BaseEmbeddingDriver - - -@define -class DataFrameLoader(BaseLoader): - embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) - - def load(self, source: DataFrame, *args, **kwargs) -> TableArtifact: - artifact = TableArtifact(list(source.to_dict(orient="records"))) - - if self.embedding_driver: - artifact.generate_embedding(self.embedding_driver) - - return artifact - - def load_collection(self, sources: list[DataFrame], *args, **kwargs) -> dict[str, TableArtifact]: - return cast(dict[str, TableArtifact], super().load_collection(sources, *args, **kwargs)) - - def to_key(self, source: DataFrame, *args, **kwargs) -> str: - hash_pandas_object = import_optional_dependency("pandas.core.util.hashing").hash_pandas_object - - return str_to_hash(str(hash_pandas_object(source, index=True).values)) diff --git a/griptape/loaders/email_loader.py b/griptape/loaders/email_loader.py index f6c9ca4062..0a3d3f600c 100644 --- a/griptape/loaders/email_loader.py +++ b/griptape/loaders/email_loader.py @@ -1,7 +1,7 @@ from __future__ import annotations import imaplib -from typing import Optional, cast +from typing import Any, Optional, cast from attrs import astuple, define, field @@ -32,11 +32,13 @@ class EmailQuery: username: str = field(kw_only=True) password: str = field(kw_only=True) - def load(self, source: EmailQuery, *args, **kwargs) -> ListArtifact: - mailparser = import_optional_dependency("mailparser") + def load(self, source: Any, *args, **kwargs) -> ListArtifact: + return cast(ListArtifact, super().load(source, *args, **kwargs)) + + def fetch(self, source: EmailQuery, *args, **kwargs) -> list[bytes]: label, key, search_criteria, max_count = astuple(source) - artifacts = [] + mail_bytes = [] with imaplib.IMAP4_SSL(self.imap_url) as client: client.login(self.username, self.password) @@ -59,19 +61,24 @@ def load(self, source: EmailQuery, *args, **kwargs) -> ListArtifact: if data is None or not data or data[0] is None: continue - message = mailparser.parse_from_bytes(data[0][1]) - - # Note: mailparser only populates the text_plain field - # if the message content type is explicitly set to 'text/plain'. - if message.text_plain: - artifacts.append(TextArtifact("\n".join(message.text_plain))) + mail_bytes.append(data[0][1]) client.close() - return ListArtifact(artifacts) + return mail_bytes + + def parse(self, source: list[bytes], *args, **kwargs) -> ListArtifact: + mailparser = import_optional_dependency("mailparser") + artifacts = [] + for byte in source: + message = mailparser.parse_from_bytes(byte) + + # Note: mailparser only populates the text_plain field + # if the message content type is explicitly set to 'text/plain'. + if message.text_plain: + artifacts.append(TextArtifact("\n".join(message.text_plain))) + + return ListArtifact(artifacts) def _count_messages(self, message_numbers: bytes) -> int: return len(list(filter(None, message_numbers.decode().split(" ")))) - - def load_collection(self, sources: list[EmailQuery], *args, **kwargs) -> dict[str, ListArtifact]: - return cast(dict[str, ListArtifact], super().load_collection(sources, *args, **kwargs)) diff --git a/griptape/loaders/image_loader.py b/griptape/loaders/image_loader.py index 83060dfa89..513910c78c 100644 --- a/griptape/loaders/image_loader.py +++ b/griptape/loaders/image_loader.py @@ -1,17 +1,17 @@ from __future__ import annotations from io import BytesIO -from typing import Optional, cast +from typing import Any, Optional, cast from attrs import define, field from griptape.artifacts import ImageArtifact -from griptape.loaders import BaseLoader +from griptape.loaders import BaseFileLoader from griptape.utils import import_optional_dependency @define -class ImageLoader(BaseLoader): +class ImageLoader(BaseFileLoader): """Loads images into image artifacts. Attributes: @@ -22,20 +22,13 @@ class ImageLoader(BaseLoader): format: Optional[str] = field(default=None, kw_only=True) - FORMAT_TO_MIME_TYPE = { - "bmp": "image/bmp", - "gif": "image/gif", - "jpeg": "image/jpeg", - "png": "image/png", - "tiff": "image/tiff", - "webp": "image/webp", - } + def load(self, source: Any, *args, **kwargs) -> ImageArtifact: + return cast(ImageArtifact, super().load(source, *args, **kwargs)) - def load(self, source: bytes, *args, **kwargs) -> ImageArtifact: + def parse(self, source: bytes, *args, **kwargs) -> ImageArtifact: pil_image = import_optional_dependency("PIL.Image") image = pil_image.open(BytesIO(source)) - # Normalize format only if requested. if self.format is not None: byte_stream = BytesIO() image.save(byte_stream, format=self.format) @@ -43,15 +36,3 @@ def load(self, source: bytes, *args, **kwargs) -> ImageArtifact: source = byte_stream.getvalue() return ImageArtifact(source, format=image.format.lower(), width=image.width, height=image.height) - - def _get_mime_type(self, image_format: str | None) -> str: - if image_format is None: - raise ValueError("image_format is None") - - if image_format.lower() not in self.FORMAT_TO_MIME_TYPE: - raise ValueError(f"Unsupported image format {image_format}") - - return self.FORMAT_TO_MIME_TYPE[image_format.lower()] - - def load_collection(self, sources: list[bytes], *args, **kwargs) -> dict[str, ImageArtifact]: - return cast(dict[str, ImageArtifact], super().load_collection(sources, *args, **kwargs)) diff --git a/griptape/loaders/pdf_loader.py b/griptape/loaders/pdf_loader.py index 419bfabf4b..b1f9013864 100644 --- a/griptape/loaders/pdf_loader.py +++ b/griptape/loaders/pdf_loader.py @@ -1,37 +1,29 @@ from __future__ import annotations from io import BytesIO -from typing import Optional, cast +from typing import Any, Optional, cast -from attrs import Factory, define, field +from attrs import define -from griptape.artifacts import TextArtifact -from griptape.chunkers import PdfChunker -from griptape.loaders import BaseTextLoader +from griptape.artifacts import ListArtifact, TextArtifact +from griptape.loaders.base_file_loader import BaseFileLoader from griptape.utils import import_optional_dependency @define -class PdfLoader(BaseTextLoader): - chunker: PdfChunker = field( - default=Factory(lambda self: PdfChunker(tokenizer=self.tokenizer, max_tokens=self.max_tokens), takes_self=True), - kw_only=True, - ) - encoding: None = field(default=None, kw_only=True) +class PdfLoader(BaseFileLoader): + def load(self, source: Any, *args, **kwargs) -> TextArtifact: + return cast(TextArtifact, super().load(source, *args, **kwargs)) - def load( + def parse( self, source: bytes, password: Optional[str] = None, *args, **kwargs, - ) -> list[TextArtifact]: + ) -> ListArtifact: pypdf = import_optional_dependency("pypdf") reader = pypdf.PdfReader(BytesIO(source), strict=True, password=password) - return self._text_to_artifacts("\n".join([p.extract_text() for p in reader.pages])) + pages = [TextArtifact(p.extract_text()) for p in reader.pages] - def load_collection(self, sources: list[bytes], *args, **kwargs) -> dict[str, list[TextArtifact]]: - return cast( - dict[str, list[TextArtifact]], - super().load_collection(sources, *args, **kwargs), - ) + return ListArtifact(pages) diff --git a/griptape/loaders/sql_loader.py b/griptape/loaders/sql_loader.py index d723a5a9fb..d633eede60 100644 --- a/griptape/loaders/sql_loader.py +++ b/griptape/loaders/sql_loader.py @@ -1,29 +1,25 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, Any, cast from attrs import define, field -from griptape.artifacts import TableArtifact +from griptape.artifacts import ListArtifact, TextArtifact from griptape.loaders import BaseLoader if TYPE_CHECKING: - from griptape.drivers import BaseEmbeddingDriver, BaseSqlDriver + from griptape.drivers import BaseSqlDriver @define class SqlLoader(BaseLoader): sql_driver: BaseSqlDriver = field(kw_only=True) - embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) - def load(self, source: str, *args, **kwargs) -> TableArtifact: - rows = self.sql_driver.execute_query(source) - artifact = TableArtifact([row.cells for row in rows] if rows else []) + def load(self, source: Any, *args, **kwargs) -> ListArtifact: + return cast(ListArtifact, super().load(source, *args, **kwargs)) - if self.embedding_driver: - artifact.generate_embedding(self.embedding_driver) + def fetch(self, source: str, *args, **kwargs) -> list[BaseSqlDriver.RowResult]: + return self.sql_driver.execute_query(source) or [] - return artifact - - def load_collection(self, sources: list[str], *args, **kwargs) -> dict[str, TableArtifact]: - return cast(dict[str, TableArtifact], super().load_collection(sources, *args, **kwargs)) + def parse(self, source: list[BaseSqlDriver.RowResult], *args, **kwargs) -> ListArtifact: + return ListArtifact([TextArtifact(row.cells, meta={"row": row_num}) for row_num, row in enumerate(source)]) diff --git a/griptape/loaders/text_loader.py b/griptape/loaders/text_loader.py index 79e551a8e3..68dc324aef 100644 --- a/griptape/loaders/text_loader.py +++ b/griptape/loaders/text_loader.py @@ -1,55 +1,22 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional, cast +from typing import Any, cast -from attrs import Factory, define, field +from attrs import define, field from griptape.artifacts import TextArtifact -from griptape.chunkers import TextChunker -from griptape.loaders import BaseTextLoader -from griptape.tokenizers import OpenAiTokenizer - -if TYPE_CHECKING: - from griptape.drivers import BaseEmbeddingDriver +from griptape.loaders import BaseFileLoader @define -class TextLoader(BaseTextLoader): - MAX_TOKEN_RATIO = 0.5 - - tokenizer: OpenAiTokenizer = field( - default=Factory(lambda: OpenAiTokenizer(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL)), - kw_only=True, - ) - max_tokens: int = field( - default=Factory(lambda self: round(self.tokenizer.max_input_tokens * self.MAX_TOKEN_RATIO), takes_self=True), - kw_only=True, - ) - chunker: TextChunker = field( - default=Factory( - lambda self: TextChunker(tokenizer=self.tokenizer, max_tokens=self.max_tokens), - takes_self=True, - ), - kw_only=True, - ) - embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) +class TextLoader(BaseFileLoader): encoding: str = field(default="utf-8", kw_only=True) - def load(self, source: bytes | str, *args, **kwargs) -> list[TextArtifact]: - if isinstance(source, bytes): - source = source.decode(encoding=self.encoding) - elif isinstance(source, (bytearray, memoryview)): - raise ValueError(f"Unsupported source type: {type(source)}") - - return self._text_to_artifacts(source) + def load(self, source: Any, *args, **kwargs) -> TextArtifact: + return cast(TextArtifact, super().load(source, *args, **kwargs)) - def load_collection( - self, - sources: list[bytes | str], - *args, - **kwargs, - ) -> dict[str, list[TextArtifact]]: - return cast( - dict[str, list[TextArtifact]], - super().load_collection(sources, *args, **kwargs), - ) + def parse(self, source: str | bytes, *args, **kwargs) -> TextArtifact: + if isinstance(source, str): + return TextArtifact(source, encoding=self.encoding) + else: + return TextArtifact(source.decode(self.encoding), encoding=self.encoding) diff --git a/griptape/loaders/web_loader.py b/griptape/loaders/web_loader.py index 720ab34a19..d31a348167 100644 --- a/griptape/loaders/web_loader.py +++ b/griptape/loaders/web_loader.py @@ -1,23 +1,26 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import Any, cast from attrs import Factory, define, field +from griptape.artifacts import TextArtifact from griptape.drivers import BaseWebScraperDriver, TrafilaturaWebScraperDriver -from griptape.loaders import BaseTextLoader - -if TYPE_CHECKING: - from griptape.artifacts import TextArtifact +from griptape.loaders import BaseLoader @define -class WebLoader(BaseTextLoader): +class WebLoader(BaseLoader): web_scraper_driver: BaseWebScraperDriver = field( default=Factory(lambda: TrafilaturaWebScraperDriver()), kw_only=True, ) - def load(self, source: str, *args, **kwargs) -> list[TextArtifact]: - single_chunk_text_artifact = self.web_scraper_driver.scrape_url(source) - return self._text_to_artifacts(single_chunk_text_artifact.value) + def load(self, source: Any, *args, **kwargs) -> TextArtifact: + return cast(TextArtifact, super().load(source, *args, **kwargs)) + + def fetch(self, source: str, *args, **kwargs) -> str: + return self.web_scraper_driver.fetch_url(source) + + def parse(self, source: str, *args, **kwargs) -> TextArtifact: + return self.web_scraper_driver.extract_page(source) diff --git a/griptape/tools/file_manager/tool.py b/griptape/tools/file_manager/tool.py index b72f823292..8dc0c93936 100644 --- a/griptape/tools/file_manager/tool.py +++ b/griptape/tools/file_manager/tool.py @@ -5,9 +5,12 @@ from attrs import Factory, define, field from schema import Literal, Schema +import griptape.loaders as loaders from griptape.artifacts import ErrorArtifact, InfoArtifact, ListArtifact, TextArtifact from griptape.drivers import BaseFileManagerDriver, LocalFileManagerDriver +from griptape.loaders.blob_loader import BlobLoader from griptape.tools import BaseTool +from griptape.utils import get_mime_type from griptape.utils.decorators import activity @@ -21,6 +24,20 @@ class FileManagerTool(BaseTool): file_manager_driver: BaseFileManagerDriver = field(default=Factory(lambda: LocalFileManagerDriver()), kw_only=True) + loaders: dict[str, loaders.BaseLoader] = field( + default=Factory( + lambda self: { + "application/pdf": loaders.PdfLoader(file_manager_driver=self.file_manager_driver), + "text/csv": loaders.CsvLoader(file_manager_driver=self.file_manager_driver), + "text": loaders.TextLoader(file_manager_driver=self.file_manager_driver), + "image": loaders.ImageLoader(file_manager_driver=self.file_manager_driver), + "application/octet-stream": BlobLoader(file_manager_driver=self.file_manager_driver), + }, + takes_self=True, + ), + kw_only=True, + ) + @activity( config={ "description": "Can be used to list files on disk", @@ -51,7 +68,11 @@ def load_files_from_disk(self, params: dict) -> ListArtifact | ErrorArtifact: artifacts = [] for path in paths: - artifact = self.file_manager_driver.load_file(path) + abs_path = os.path.join(self.file_manager_driver.workdir, path) + mime_type = get_mime_type(abs_path) + loader = next((loader for key, loader in self.loaders.items() if mime_type.startswith(key))) + + artifact = loader.load(path) if isinstance(artifact, ListArtifact): artifacts.extend(artifact.value) else: diff --git a/griptape/tools/sql/tool.py b/griptape/tools/sql/tool.py index aca41b4ac6..59d0a3dbaa 100644 --- a/griptape/tools/sql/tool.py +++ b/griptape/tools/sql/tool.py @@ -5,12 +5,11 @@ from attrs import define, field from schema import Schema -from griptape.artifacts import ErrorArtifact, InfoArtifact +from griptape.artifacts import ErrorArtifact, InfoArtifact, ListArtifact from griptape.tools import BaseTool from griptape.utils.decorators import activity if TYPE_CHECKING: - from griptape.artifacts import TableArtifact from griptape.loaders import SqlLoader @@ -44,7 +43,7 @@ def table_schema(self) -> Optional[str]: "schema": Schema({"sql_query": str}), }, ) - def execute_query(self, params: dict) -> TableArtifact | InfoArtifact | ErrorArtifact: + def execute_query(self, params: dict) -> ListArtifact | InfoArtifact | ErrorArtifact: try: query = params["values"]["sql_query"] rows = self.sql_loader.load(query) diff --git a/griptape/tools/web_scraper/tool.py b/griptape/tools/web_scraper/tool.py index 2895d5e0dd..982123b305 100644 --- a/griptape/tools/web_scraper/tool.py +++ b/griptape/tools/web_scraper/tool.py @@ -4,6 +4,7 @@ from schema import Literal, Schema from griptape.artifacts import ErrorArtifact, ListArtifact +from griptape.chunkers import TextChunker from griptape.loaders import WebLoader from griptape.tools import BaseTool from griptape.utils.decorators import activity @@ -12,6 +13,7 @@ @define class WebScraperTool(BaseTool): web_loader: WebLoader = field(default=Factory(lambda: WebLoader()), kw_only=True) + text_chunker: TextChunker = field(default=Factory(lambda: TextChunker()), kw_only=True) @activity( config={ @@ -24,6 +26,8 @@ def get_content(self, params: dict) -> ListArtifact | ErrorArtifact: try: result = self.web_loader.load(url) - return ListArtifact(result) + chunks = TextChunker().chunk(result) + + return ListArtifact(chunks) except Exception as e: return ErrorArtifact("Error getting page content: " + str(e)) diff --git a/griptape/utils/__init__.py b/griptape/utils/__init__.py index 03725f59d4..77e3f3b0a0 100644 --- a/griptape/utils/__init__.py +++ b/griptape/utils/__init__.py @@ -8,7 +8,6 @@ from .futures import execute_futures_dict, execute_futures_list, execute_futures_list_dict from .token_counter import TokenCounter from .dict_utils import remove_null_values_in_dict_recursively, dict_merge, remove_key_in_dict_recursively -from .file_utils import load_file, load_files from .hash import str_to_hash from .import_utils import import_optional_dependency from .import_utils import is_dependency_installed @@ -17,6 +16,7 @@ from .deprecation import deprecation_warn from .structure_visualizer import StructureVisualizer from .reference_utils import references_from_artifacts +from .file_utils import get_mime_type def minify_json(value: str) -> str: @@ -44,8 +44,7 @@ def minify_json(value: str) -> str: "Stream", "load_artifact_from_memory", "deprecation_warn", - "load_file", - "load_files", "StructureVisualizer", "references_from_artifacts", + "get_mime_type", ] diff --git a/griptape/utils/file_utils.py b/griptape/utils/file_utils.py index 19c9f699ce..0dbbbc0933 100644 --- a/griptape/utils/file_utils.py +++ b/griptape/utils/file_utils.py @@ -1,38 +1,16 @@ -from __future__ import annotations +import mimetypes -from concurrent import futures -from pathlib import Path -from typing import Optional +import filetype -import griptape.utils as utils +def get_mime_type(file_path: str) -> str: + filetype_guess = filetype.guess(file_path) -def load_file(path: str) -> bytes: - """Load a file from the given path and return its content as bytes. - - Args: - path (str): The path to the file to load. - - Returns: - The content of the file. - """ - return Path(path).read_bytes() - - -def load_files(paths: list[str], futures_executor: Optional[futures.ThreadPoolExecutor] = None) -> dict[str, bytes]: - """Load multiple files concurrently and return a dictionary of their content. - - Args: - paths: The paths to the files to load. - futures_executor: The executor to use for concurrent loading. If None, a new ThreadPoolExecutor will be created. - - Returns: - A dictionary where the keys are a hash of the path and the values are the content of the files. - """ - if futures_executor is None: - futures_executor = futures.ThreadPoolExecutor() - - with futures_executor as executor: - return utils.execute_futures_dict( - {utils.str_to_hash(str(path)): executor.submit(load_file, path) for path in paths}, - ) + if filetype_guess is None: + type_, _ = mimetypes.guess_type(file_path) + if type_ is None: + return "application/octet-stream" + else: + return type_ + else: + return filetype_guess.mime diff --git a/poetry.lock b/poetry.lock index 640b3bab08..e8db049a99 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1595,7 +1595,7 @@ typing = ["typing-extensions (>=4.8)"] name = "filetype" version = "1.2.0" description = "Infer file type and MIME type of any file/buffer. No external dependencies." -optional = true +optional = false python-versions = "*" files = [ {file = "filetype-1.2.0-py2.py3-none-any.whl", hash = "sha256:7ce71b6880181241cf7ac8697a2f1eb6a8bd9b429f7ad6d27b8db9ba5f1c2d25"}, @@ -5862,48 +5862,36 @@ python-versions = ">=3.7" files = [ {file = "SQLAlchemy-2.0.34-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:95d0b2cf8791ab5fb9e3aa3d9a79a0d5d51f55b6357eecf532a120ba3b5524db"}, {file = "SQLAlchemy-2.0.34-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:243f92596f4fd4c8bd30ab8e8dd5965afe226363d75cab2468f2c707f64cd83b"}, - {file = "SQLAlchemy-2.0.34-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9ea54f7300553af0a2a7235e9b85f4204e1fc21848f917a3213b0e0818de9a24"}, {file = "SQLAlchemy-2.0.34-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:173f5f122d2e1bff8fbd9f7811b7942bead1f5e9f371cdf9e670b327e6703ebd"}, - {file = "SQLAlchemy-2.0.34-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:196958cde924a00488e3e83ff917be3b73cd4ed8352bbc0f2989333176d1c54d"}, {file = "SQLAlchemy-2.0.34-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:bd90c221ed4e60ac9d476db967f436cfcecbd4ef744537c0f2d5291439848768"}, {file = "SQLAlchemy-2.0.34-cp310-cp310-win32.whl", hash = "sha256:3166dfff2d16fe9be3241ee60ece6fcb01cf8e74dd7c5e0b64f8e19fab44911b"}, {file = "SQLAlchemy-2.0.34-cp310-cp310-win_amd64.whl", hash = "sha256:6831a78bbd3c40f909b3e5233f87341f12d0b34a58f14115c9e94b4cdaf726d3"}, {file = "SQLAlchemy-2.0.34-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c7db3db284a0edaebe87f8f6642c2b2c27ed85c3e70064b84d1c9e4ec06d5d84"}, {file = "SQLAlchemy-2.0.34-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:430093fce0efc7941d911d34f75a70084f12f6ca5c15d19595c18753edb7c33b"}, - {file = "SQLAlchemy-2.0.34-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79cb400c360c7c210097b147c16a9e4c14688a6402445ac848f296ade6283bbc"}, {file = "SQLAlchemy-2.0.34-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fb1b30f31a36c7f3fee848391ff77eebdd3af5750bf95fbf9b8b5323edfdb4ec"}, - {file = "SQLAlchemy-2.0.34-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8fddde2368e777ea2a4891a3fb4341e910a056be0bb15303bf1b92f073b80c02"}, {file = "SQLAlchemy-2.0.34-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:80bd73ea335203b125cf1d8e50fef06be709619eb6ab9e7b891ea34b5baa2287"}, {file = "SQLAlchemy-2.0.34-cp311-cp311-win32.whl", hash = "sha256:6daeb8382d0df526372abd9cb795c992e18eed25ef2c43afe518c73f8cccb721"}, {file = "SQLAlchemy-2.0.34-cp311-cp311-win_amd64.whl", hash = "sha256:5bc08e75ed11693ecb648b7a0a4ed80da6d10845e44be0c98c03f2f880b68ff4"}, {file = "SQLAlchemy-2.0.34-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:53e68b091492c8ed2bd0141e00ad3089bcc6bf0e6ec4142ad6505b4afe64163e"}, {file = "SQLAlchemy-2.0.34-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:bcd18441a49499bf5528deaa9dee1f5c01ca491fc2791b13604e8f972877f812"}, - {file = "SQLAlchemy-2.0.34-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:165bbe0b376541092bf49542bd9827b048357f4623486096fc9aaa6d4e7c59a2"}, {file = "SQLAlchemy-2.0.34-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3330415cd387d2b88600e8e26b510d0370db9b7eaf984354a43e19c40df2e2b"}, - {file = "SQLAlchemy-2.0.34-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:97b850f73f8abbffb66ccbab6e55a195a0eb655e5dc74624d15cff4bfb35bd74"}, {file = "SQLAlchemy-2.0.34-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:7cee4c6917857fd6121ed84f56d1dc78eb1d0e87f845ab5a568aba73e78adf83"}, {file = "SQLAlchemy-2.0.34-cp312-cp312-win32.whl", hash = "sha256:fbb034f565ecbe6c530dff948239377ba859420d146d5f62f0271407ffb8c580"}, {file = "SQLAlchemy-2.0.34-cp312-cp312-win_amd64.whl", hash = "sha256:707c8f44931a4facd4149b52b75b80544a8d824162602b8cd2fe788207307f9a"}, {file = "SQLAlchemy-2.0.34-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:24af3dc43568f3780b7e1e57c49b41d98b2d940c1fd2e62d65d3928b6f95f021"}, - {file = "SQLAlchemy-2.0.34-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e60ed6ef0a35c6b76b7640fe452d0e47acc832ccbb8475de549a5cc5f90c2c06"}, {file = "SQLAlchemy-2.0.34-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:413c85cd0177c23e32dee6898c67a5f49296640041d98fddb2c40888fe4daa2e"}, - {file = "SQLAlchemy-2.0.34-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:25691f4adfb9d5e796fd48bf1432272f95f4bbe5f89c475a788f31232ea6afba"}, {file = "SQLAlchemy-2.0.34-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:526ce723265643dbc4c7efb54f56648cc30e7abe20f387d763364b3ce7506c82"}, {file = "SQLAlchemy-2.0.34-cp37-cp37m-win32.whl", hash = "sha256:13be2cc683b76977a700948411a94c67ad8faf542fa7da2a4b167f2244781cf3"}, {file = "SQLAlchemy-2.0.34-cp37-cp37m-win_amd64.whl", hash = "sha256:e54ef33ea80d464c3dcfe881eb00ad5921b60f8115ea1a30d781653edc2fd6a2"}, {file = "SQLAlchemy-2.0.34-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:43f28005141165edd11fbbf1541c920bd29e167b8bbc1fb410d4fe2269c1667a"}, {file = "SQLAlchemy-2.0.34-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b68094b165a9e930aedef90725a8fcfafe9ef95370cbb54abc0464062dbf808f"}, - {file = "SQLAlchemy-2.0.34-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6a1e03db964e9d32f112bae36f0cc1dcd1988d096cfd75d6a588a3c3def9ab2b"}, {file = "SQLAlchemy-2.0.34-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:203d46bddeaa7982f9c3cc693e5bc93db476ab5de9d4b4640d5c99ff219bee8c"}, - {file = "SQLAlchemy-2.0.34-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:ae92bebca3b1e6bd203494e5ef919a60fb6dfe4d9a47ed2453211d3bd451b9f5"}, {file = "SQLAlchemy-2.0.34-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:9661268415f450c95f72f0ac1217cc6f10256f860eed85c2ae32e75b60278ad8"}, {file = "SQLAlchemy-2.0.34-cp38-cp38-win32.whl", hash = "sha256:895184dfef8708e15f7516bd930bda7e50ead069280d2ce09ba11781b630a434"}, {file = "SQLAlchemy-2.0.34-cp38-cp38-win_amd64.whl", hash = "sha256:6e7cde3a2221aa89247944cafb1b26616380e30c63e37ed19ff0bba5e968688d"}, {file = "SQLAlchemy-2.0.34-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:dbcdf987f3aceef9763b6d7b1fd3e4ee210ddd26cac421d78b3c206d07b2700b"}, {file = "SQLAlchemy-2.0.34-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ce119fc4ce0d64124d37f66a6f2a584fddc3c5001755f8a49f1ca0a177ef9796"}, - {file = "SQLAlchemy-2.0.34-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a17d8fac6df9835d8e2b4c5523666e7051d0897a93756518a1fe101c7f47f2f0"}, {file = "SQLAlchemy-2.0.34-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ebc11c54c6ecdd07bb4efbfa1554538982f5432dfb8456958b6d46b9f834bb7"}, - {file = "SQLAlchemy-2.0.34-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:2e6965346fc1491a566e019a4a1d3dfc081ce7ac1a736536367ca305da6472a8"}, {file = "SQLAlchemy-2.0.34-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:220574e78ad986aea8e81ac68821e47ea9202b7e44f251b7ed8c66d9ae3f4278"}, {file = "SQLAlchemy-2.0.34-cp39-cp39-win32.whl", hash = "sha256:b75b00083e7fe6621ce13cfce9d4469c4774e55e8e9d38c305b37f13cf1e874c"}, {file = "SQLAlchemy-2.0.34-cp39-cp39-win_amd64.whl", hash = "sha256:c29d03e0adf3cc1a8c3ec62d176824972ae29b67a66cbb18daff3062acc6faa8"}, @@ -6375,6 +6363,11 @@ files = [ {file = "triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb"}, {file = "triton-3.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bcbf3b1c48af6a28011a5c40a5b3b9b5330530c3827716b5fbf6d7adcc1e53e9"}, {file = "triton-3.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6e5727202f7078c56f91ff13ad0c1abab14a0e7f2c87e91b12b6f64f3e8ae609"}, + {file = "triton-3.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b052da883351fdf6be3d93cedae6db3b8e3988d3b09ed221bccecfa9612230"}, + {file = "triton-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd34f19a8582af96e6291d4afce25dac08cb2a5d218c599163761e8e0827208e"}, + {file = "triton-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d5e10de8c011adeb7c878c6ce0dd6073b14367749e34467f1cff2bde1b78253"}, + {file = "triton-3.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8903767951bf86ec960b4fe4e21bc970055afc65e9d57e916d79ae3c93665e3"}, + {file = "triton-3.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41004fb1ae9a53fcb3e970745feb87f0e3c94c6ce1ba86e95fa3b8537894bef7"}, ] [package.dependencies] @@ -6956,7 +6949,7 @@ doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linke test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] [extras] -all = ["accelerate", "anthropic", "astrapy", "beautifulsoup4", "boto3", "cohere", "diffusers", "duckduckgo-search", "elevenlabs", "filetype", "google-generativeai", "mail-parser", "markdownify", "marqo", "ollama", "opensearch-py", "opentelemetry-api", "opentelemetry-exporter-otlp-proto-http", "opentelemetry-instrumentation", "opentelemetry-instrumentation-threading", "opentelemetry-sdk", "pandas", "pgvector", "pillow", "pinecone-client", "playwright", "psycopg2-binary", "pusher", "pymongo", "pypdf", "qdrant-client", "redis", "sentencepiece", "snowflake-sqlalchemy", "sqlalchemy", "torch", "trafilatura", "transformers", "voyageai"] +all = ["accelerate", "anthropic", "astrapy", "beautifulsoup4", "boto3", "cohere", "diffusers", "duckduckgo-search", "elevenlabs", "google-generativeai", "mail-parser", "markdownify", "marqo", "ollama", "opensearch-py", "opentelemetry-api", "opentelemetry-exporter-otlp-proto-http", "opentelemetry-instrumentation", "opentelemetry-instrumentation-threading", "opentelemetry-sdk", "pandas", "pgvector", "pillow", "pinecone-client", "playwright", "psycopg2-binary", "pusher", "pymongo", "pypdf", "qdrant-client", "redis", "sentencepiece", "snowflake-sqlalchemy", "sqlalchemy", "torch", "trafilatura", "transformers", "voyageai"] drivers-embedding-amazon-bedrock = ["boto3"] drivers-embedding-amazon-sagemaker = ["boto3"] drivers-embedding-cohere = ["cohere"] @@ -6998,8 +6991,6 @@ drivers-vector-redis = ["redis"] drivers-web-scraper-markdownify = ["beautifulsoup4", "markdownify", "playwright"] drivers-web-scraper-trafilatura = ["trafilatura"] drivers-web-search-duckduckgo = ["duckduckgo-search"] -loaders-audio = ["filetype"] -loaders-dataframe = ["pandas"] loaders-email = ["mail-parser"] loaders-image = ["pillow"] loaders-pdf = ["pypdf"] @@ -7008,4 +6999,4 @@ loaders-sql = ["sqlalchemy"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "ad31933841926602b18aea04927c1f86d6825d8891495210a28dc1732768499e" +content-hash = "61d8eb5080886ac092c8e9ff29c6260276acc94ec2062f9363e838c004c5e37f" diff --git a/pyproject.toml b/pyproject.toml index a0854e0a2d..eb4650e5cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ numpy = "^1.26.4" stringcase = "^1.2.0" docker = "^7.1.0" requests = "^2.32.0" +filetype = "^1.2" # drivers cohere = { version = "^5.5.4", optional = true } @@ -70,7 +71,6 @@ pandas = {version = "^1.3", optional = true} pypdf = {version = "^3.9", optional = true} pillow = {version = "^10.2.0", optional = true} mail-parser = {version = "^3.15.0", optional = true} -filetype = {version = "^1.2", optional = true} [tool.poetry.extras] drivers-prompt-cohere = ["cohere"] @@ -150,11 +150,9 @@ drivers-image-generation-huggingface = [ "pillow", ] -loaders-dataframe = ["pandas"] loaders-pdf = ["pypdf"] loaders-image = ["pillow"] loaders-email = ["mail-parser"] -loaders-audio = ["filetype"] loaders-sql = ["sqlalchemy"] all = [ @@ -201,7 +199,6 @@ all = [ "pandas", "pypdf", "mail-parser", - "filetype", ] [tool.poetry.group.test] diff --git a/tests/unit/drivers/file_manager/test_amazon_s3_file_manager_driver.py b/tests/unit/drivers/file_manager/test_amazon_s3_file_manager_driver.py index 84ce617685..3d131ef7ae 100644 --- a/tests/unit/drivers/file_manager/test_amazon_s3_file_manager_driver.py +++ b/tests/unit/drivers/file_manager/test_amazon_s3_file_manager_driver.py @@ -5,9 +5,9 @@ import pytest from moto import mock_s3 -from griptape.artifacts import InfoArtifact, ListArtifact, TextArtifact +from griptape.artifacts import InfoArtifact, TextArtifact +from griptape.artifacts.blob_artifact import BlobArtifact from griptape.drivers import AmazonS3FileManagerDriver -from griptape.loaders import TextLoader from tests.utils.aws import mock_aws_credentials @@ -154,8 +154,7 @@ def test_list_files_failure(self, workdir, path, expected, driver): def test_load_file(self, driver): artifact = driver.load_file("resources/bitcoin.pdf") - assert isinstance(artifact, ListArtifact) - assert len(artifact.value) == 4 + assert isinstance(artifact, BlobArtifact) @pytest.mark.parametrize( ("workdir", "path", "expected"), @@ -185,9 +184,8 @@ def test_load_file_failure(self, workdir, path, expected, driver): def test_load_file_with_encoding(self, driver): artifact = driver.load_file("resources/test.txt") - assert isinstance(artifact, ListArtifact) - assert len(artifact.value) == 1 - assert isinstance(artifact.value[0], TextArtifact) + assert isinstance(artifact, BlobArtifact) + assert artifact.encoding == "utf-8" @pytest.mark.parametrize( ("workdir", "path", "content"), @@ -240,9 +238,7 @@ def test_save_file_failure(self, workdir, path, expected, temp_dir, driver, s3_c def test_save_file_with_encoding(self, session, bucket, get_s3_value): workdir = "/sub-folder" - driver = AmazonS3FileManagerDriver( - session=session, bucket=bucket, default_loader=TextLoader(encoding="utf-8"), loaders={}, workdir=workdir - ) + driver = AmazonS3FileManagerDriver(session=session, bucket=bucket, workdir=workdir) path = "test/foobar.txt" result = driver.save_file(path, "foobar") @@ -253,9 +249,7 @@ def test_save_file_with_encoding(self, session, bucket, get_s3_value): def test_save_and_load_file_with_encoding(self, session, bucket, get_s3_value): workdir = "/sub-folder" - driver = AmazonS3FileManagerDriver( - session=session, bucket=bucket, loaders={"txt": TextLoader(encoding="ascii")}, workdir=workdir - ) + driver = AmazonS3FileManagerDriver(session=session, bucket=bucket, encoding="ascii", workdir=workdir) path = "test/foobar.txt" result = driver.save_file(path, "foobar") @@ -264,13 +258,10 @@ def test_save_and_load_file_with_encoding(self, session, bucket, get_s3_value): assert get_s3_value(expected_s3_key) == "foobar" assert result.value == "Successfully saved file" - driver = AmazonS3FileManagerDriver( - session=session, bucket=bucket, default_loader=TextLoader(encoding="ascii"), loaders={}, workdir=workdir - ) + driver = AmazonS3FileManagerDriver(session=session, bucket=bucket, encoding="ascii", workdir=workdir) path = "test/foobar.txt" result = driver.load_file(path) - assert isinstance(result, ListArtifact) - assert len(result.value) == 1 - assert isinstance(result.value[0], TextArtifact) + assert isinstance(result, TextArtifact) + assert result.encoding == "ascii" diff --git a/tests/unit/drivers/file_manager/test_local_file_manager_driver.py b/tests/unit/drivers/file_manager/test_local_file_manager_driver.py index 394a838a3f..99f0285bc7 100644 --- a/tests/unit/drivers/file_manager/test_local_file_manager_driver.py +++ b/tests/unit/drivers/file_manager/test_local_file_manager_driver.py @@ -4,9 +4,9 @@ import pytest -from griptape.artifacts import InfoArtifact, ListArtifact, TextArtifact +from griptape.artifacts import InfoArtifact, TextArtifact +from griptape.artifacts.blob_artifact import BlobArtifact from griptape.drivers import LocalFileManagerDriver -from griptape.loaders.text_loader import TextLoader class TestLocalFileManagerDriver: @@ -127,8 +127,7 @@ def test_list_files_failure(self, workdir, path, expected, temp_dir, driver): def test_load_file(self, driver: LocalFileManagerDriver): artifact = driver.load_file("resources/bitcoin.pdf") - assert isinstance(artifact, ListArtifact) - assert len(artifact.value) == 4 + assert isinstance(artifact, BlobArtifact) @pytest.mark.parametrize( ("workdir", "path", "expected"), @@ -156,23 +155,6 @@ def test_load_file_failure(self, workdir, path, expected, temp_dir, driver): with pytest.raises(expected): driver.load_file(path) - def test_load_file_with_encoding(self, driver: LocalFileManagerDriver): - artifact = driver.load_file("resources/test.txt") - - assert isinstance(artifact, ListArtifact) - assert len(artifact.value) == 1 - assert isinstance(artifact.value[0], TextArtifact) - - def test_load_file_with_encoding_failure(self, driver): - driver = LocalFileManagerDriver( - default_loader=TextLoader(encoding="utf-8"), - loaders={}, - workdir=os.path.normpath(os.path.abspath(os.path.dirname(__file__) + "../../../../")), - ) - - with pytest.raises(UnicodeDecodeError): - driver.load_file("resources/bitcoin.pdf") - @pytest.mark.parametrize( ("workdir", "path", "content"), [ @@ -224,25 +206,24 @@ def test_save_file_failure(self, workdir, path, expected, temp_dir, driver): driver.save_file(path, "foobar") def test_save_file_with_encoding(self, temp_dir): - driver = LocalFileManagerDriver(default_loader=TextLoader(encoding="utf-8"), loaders={}, workdir=temp_dir) + driver = LocalFileManagerDriver(encoding="utf-8", workdir=temp_dir) result = driver.save_file(os.path.join("test", "foobar.txt"), "foobar") assert Path(os.path.join(temp_dir, "test", "foobar.txt")).read_text() == "foobar" assert result.value == "Successfully saved file" def test_save_and_load_file_with_encoding(self, temp_dir): - driver = LocalFileManagerDriver(loaders={"txt": TextLoader(encoding="ascii")}, workdir=temp_dir) + driver = LocalFileManagerDriver(encoding="ascii", workdir=temp_dir) result = driver.save_file(os.path.join("test", "foobar.txt"), "foobar") assert Path(os.path.join(temp_dir, "test", "foobar.txt")).read_text() == "foobar" assert result.value == "Successfully saved file" - driver = LocalFileManagerDriver(default_loader=TextLoader(encoding="ascii"), loaders={}, workdir=temp_dir) + driver = LocalFileManagerDriver(encoding="ascii", workdir=temp_dir) result = driver.load_file(os.path.join("test", "foobar.txt")) - assert isinstance(result, ListArtifact) - assert len(result.value) == 1 - assert isinstance(result.value[0], TextArtifact) + assert isinstance(result, TextArtifact) + assert result.encoding == "ascii" def _to_driver_workdir(self, temp_dir, workdir): # Treat the workdir as an absolute path, but modify it to be relative to the temp_dir. diff --git a/tests/unit/drivers/vector/test_base_local_vector_store_driver.py b/tests/unit/drivers/vector/test_base_vector_store_driver.py similarity index 95% rename from tests/unit/drivers/vector/test_base_local_vector_store_driver.py rename to tests/unit/drivers/vector/test_base_vector_store_driver.py index 20a3e2b50c..8438568add 100644 --- a/tests/unit/drivers/vector/test_base_local_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_base_vector_store_driver.py @@ -4,12 +4,13 @@ import pytest from griptape.artifacts import TextArtifact +from griptape.drivers import BaseVectorStoreDriver -class BaseLocalVectorStoreDriver(ABC): +class TestBaseVectorStoreDriver(ABC): @pytest.fixture() @abstractmethod - def driver(self): ... + def driver(self, *args, **kwargs) -> BaseVectorStoreDriver: ... def test_upsert(self, driver): namespace = driver.upsert_text_artifact(TextArtifact(id="foo1", value="foobar")) diff --git a/tests/unit/drivers/vector/test_local_vector_store_driver.py b/tests/unit/drivers/vector/test_local_vector_store_driver.py index 2504b24864..9722bb25d6 100644 --- a/tests/unit/drivers/vector/test_local_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_local_vector_store_driver.py @@ -3,10 +3,10 @@ from griptape.artifacts import TextArtifact from griptape.drivers import LocalVectorStoreDriver from tests.mocks.mock_embedding_driver import MockEmbeddingDriver -from tests.unit.drivers.vector.test_base_local_vector_store_driver import BaseLocalVectorStoreDriver +from tests.unit.drivers.vector.test_base_vector_store_driver import TestBaseVectorStoreDriver -class TestLocalVectorStoreDriver(BaseLocalVectorStoreDriver): +class TestLocalVectorStoreDriver(TestBaseVectorStoreDriver): @pytest.fixture() def driver(self): return LocalVectorStoreDriver(embedding_driver=MockEmbeddingDriver()) diff --git a/tests/unit/drivers/vector/test_persistent_local_vector_store_driver.py b/tests/unit/drivers/vector/test_persistent_local_vector_store_driver.py index c130858b55..9fa725e80a 100644 --- a/tests/unit/drivers/vector/test_persistent_local_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_persistent_local_vector_store_driver.py @@ -6,10 +6,10 @@ from griptape.artifacts import TextArtifact from griptape.drivers import LocalVectorStoreDriver from tests.mocks.mock_embedding_driver import MockEmbeddingDriver -from tests.unit.drivers.vector.test_base_local_vector_store_driver import BaseLocalVectorStoreDriver +from tests.unit.drivers.vector.test_base_vector_store_driver import TestBaseVectorStoreDriver -class TestPersistentLocalVectorStoreDriver(BaseLocalVectorStoreDriver): +class TestPersistentLocalVectorStoreDriver(TestBaseVectorStoreDriver): @pytest.fixture() def temp_dir(self): with tempfile.TemporaryDirectory() as temp_dir: diff --git a/tests/unit/engines/rag/modules/retrieval/test_text_loader_retrieval_rag_module.py b/tests/unit/engines/rag/modules/retrieval/test_text_loader_retrieval_rag_module.py index 69e334c7fa..b2dbb613d2 100644 --- a/tests/unit/engines/rag/modules/retrieval/test_text_loader_retrieval_rag_module.py +++ b/tests/unit/engines/rag/modules/retrieval/test_text_loader_retrieval_rag_module.py @@ -18,7 +18,7 @@ def test_run(self): embedding_driver = MockEmbeddingDriver() module = TextLoaderRetrievalRagModule( - loader=WebLoader(max_tokens=MAX_TOKENS, embedding_driver=embedding_driver), + loader=WebLoader(), vector_store_driver=LocalVectorStoreDriver(embedding_driver=embedding_driver), source="https://www.griptape.ai", ) diff --git a/tests/unit/loaders/conftest.py b/tests/unit/loaders/conftest.py index 1f698738ab..0bbf839b8a 100644 --- a/tests/unit/loaders/conftest.py +++ b/tests/unit/loaders/conftest.py @@ -1,4 +1,5 @@ import os +from io import BytesIO, StringIO from pathlib import Path import pytest @@ -14,15 +15,15 @@ def create_source(resource_path: str) -> Path: @pytest.fixture() def bytes_from_resource_path(path_from_resource_path): - def create_source(resource_path: str) -> bytes: - return Path(path_from_resource_path(resource_path)).read_bytes() + def create_source(resource_path: str) -> BytesIO: + return BytesIO(Path(path_from_resource_path(resource_path)).read_bytes()) return create_source @pytest.fixture() def str_from_resource_path(path_from_resource_path): - def test_csv_str(resource_path: str) -> str: - return Path(path_from_resource_path(resource_path)).read_text() + def test_csv_str(resource_path: str) -> StringIO: + return StringIO(Path(path_from_resource_path(resource_path)).read_text()) return test_csv_str diff --git a/tests/unit/loaders/test_audio_loader.py b/tests/unit/loaders/test_audio_loader.py index b7ebdd9120..7b35167222 100644 --- a/tests/unit/loaders/test_audio_loader.py +++ b/tests/unit/loaders/test_audio_loader.py @@ -9,9 +9,9 @@ class TestAudioLoader: def loader(self): return AudioLoader() - @pytest.fixture() - def create_source(self, bytes_from_resource_path): - return bytes_from_resource_path + @pytest.fixture(params=["path_from_resource_path"]) + def create_source(self, request): + return request.getfixturevalue(request.param) @pytest.mark.parametrize(("resource_path", "mime_type"), [("sentences.wav", "audio/wav")]) def test_load(self, resource_path, mime_type, loader, create_source): @@ -21,7 +21,6 @@ def test_load(self, resource_path, mime_type, loader, create_source): assert isinstance(artifact, AudioArtifact) assert artifact.mime_type == mime_type - assert len(artifact.value) > 0 def test_load_collection(self, create_source, loader): resource_paths = ["sentences.wav", "sentences2.wav"] diff --git a/tests/unit/loaders/test_blob_loader.py b/tests/unit/loaders/test_blob_loader.py index 4812e669c8..2042381bc4 100644 --- a/tests/unit/loaders/test_blob_loader.py +++ b/tests/unit/loaders/test_blob_loader.py @@ -11,7 +11,7 @@ def loader(self, request): kwargs = {"encoding": encoding} if encoding is not None else {} return BlobLoader(**kwargs) - @pytest.fixture(params=["bytes_from_resource_path", "str_from_resource_path"]) + @pytest.fixture(params=["path_from_resource_path"]) def create_source(self, request): return request.getfixturevalue(request.param) diff --git a/tests/unit/loaders/test_csv_loader.py b/tests/unit/loaders/test_csv_loader.py index cfe251ff32..9c5d0febb7 100644 --- a/tests/unit/loaders/test_csv_loader.py +++ b/tests/unit/loaders/test_csv_loader.py @@ -1,7 +1,6 @@ import pytest from griptape.loaders.csv_loader import CsvLoader -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver class TestCsvLoader: @@ -9,37 +8,35 @@ class TestCsvLoader: def loader(self, request): encoding = request.param if encoding is None: - return CsvLoader(embedding_driver=MockEmbeddingDriver()) + return CsvLoader() else: - return CsvLoader(encoding=encoding, embedding_driver=MockEmbeddingDriver()) + return CsvLoader(encoding=encoding) @pytest.fixture() def loader_with_pipe_delimiter(self): - return CsvLoader(delimiter="|", embedding_driver=MockEmbeddingDriver()) + return CsvLoader(delimiter="|") - @pytest.fixture(params=["bytes_from_resource_path", "str_from_resource_path"]) + @pytest.fixture(params=["path_from_resource_path"]) def create_source(self, request): return request.getfixturevalue(request.param) def test_load(self, loader, create_source): source = create_source("test-1.csv") - artifact = loader.load(source) + artifacts = loader.load(source) assert len(artifacts) == 10 first_artifact = artifacts[0] assert first_artifact.value == "Foo: foo1\nBar: bar1" - assert first_artifact.embedding == [0, 1] def test_load_delimiter(self, loader_with_pipe_delimiter, create_source): source = create_source("test-pipe.csv") - artifact = loader_with_pipe_delimiter.load(source) + artifacts = loader_with_pipe_delimiter.load(source) assert len(artifacts) == 10 first_artifact = artifacts[0] assert first_artifact.value == "Bar: foo1\nFoo: bar1" - assert first_artifact.embedding == [0, 1] def test_load_collection(self, loader, create_source): resource_paths = ["test-1.csv", "test-2.csv"] @@ -51,7 +48,5 @@ def test_load_collection(self, loader, create_source): assert collection.keys() == keys assert collection[loader.to_key(sources[0])][0].value == "Foo: foo1\nBar: bar1" - assert collection[loader.to_key(sources[0])][0].embedding == [0, 1] assert collection[loader.to_key(sources[1])][0].value == "Bar: bar1\nFoo: foo1" - assert collection[loader.to_key(sources[1])][0].embedding == [0, 1] diff --git a/tests/unit/loaders/test_email_loader.py b/tests/unit/loaders/test_email_loader.py index ade0627435..1812dc531e 100644 --- a/tests/unit/loaders/test_email_loader.py +++ b/tests/unit/loaders/test_email_loader.py @@ -1,14 +1,16 @@ from __future__ import annotations import email -from email import message -from typing import Optional +from typing import TYPE_CHECKING, Optional import pytest from griptape.artifacts import ListArtifact from griptape.loaders import EmailLoader +if TYPE_CHECKING: + from email.message import Message + class TestEmailLoader: @pytest.fixture(autouse=True) @@ -127,23 +129,21 @@ def to_fetch_message(body: str, content_type: Optional[str]): return to_fetch_response(to_message(body, content_type)) -def to_fetch_response(message: message): +def to_fetch_response(message: Message): return (None, ((None, message.as_bytes()),)) -def to_message(body: str, content_type: Optional[str]) -> message: +def to_message(body: str, content_type: Optional[str]) -> Message: message = email.message_from_string(body) if content_type: message.set_type(content_type) return message -def to_value_set(artifact_or_dict: ListArtifact | dict[str, ListArtifact]) -> set[str]: - if isinstance(artifact_or_dict, ListArtifact): - return {value.value for value in artifact_or_dict.value} - elif isinstance(artifact_or_dict, dict): - return { - text_artifact.value for list_artifact in artifact_or_dict.values() for text_artifact in list_artifact.value - } +def to_value_set(artifacts: ListArtifact | dict[str, ListArtifact]) -> set[str]: + if isinstance(artifacts, dict): + return set( + {text_artifact.value for list_artifact in artifacts.values() for text_artifact in list_artifact.value} + ) else: - raise Exception + return {artifact.value for artifact in artifacts.value} diff --git a/tests/unit/loaders/test_image_loader.py b/tests/unit/loaders/test_image_loader.py index 7093894b00..9c491fb881 100644 --- a/tests/unit/loaders/test_image_loader.py +++ b/tests/unit/loaders/test_image_loader.py @@ -13,9 +13,9 @@ def loader(self): def png_loader(self): return ImageLoader(format="png") - @pytest.fixture() - def create_source(self, bytes_from_resource_path): - return bytes_from_resource_path + @pytest.fixture(params=["path_from_resource_path"]) + def create_source(self, request): + return request.getfixturevalue(request.param) @pytest.mark.parametrize( ("resource_path", "mime_type"), diff --git a/tests/unit/loaders/test_pdf_loader.py b/tests/unit/loaders/test_pdf_loader.py index 3f4f7848e5..45027b95ca 100644 --- a/tests/unit/loaders/test_pdf_loader.py +++ b/tests/unit/loaders/test_pdf_loader.py @@ -1,29 +1,25 @@ import pytest from griptape.loaders import PdfLoader -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver - -MAX_TOKENS = 50 class TestPdfLoader: @pytest.fixture() def loader(self): - return PdfLoader(max_tokens=MAX_TOKENS, embedding_driver=MockEmbeddingDriver()) + return PdfLoader() - @pytest.fixture() - def create_source(self, bytes_from_resource_path): - return bytes_from_resource_path + @pytest.fixture(params=["path_from_resource_path"]) + def create_source(self, request): + return request.getfixturevalue(request.param) def test_load(self, loader, create_source): source = create_source("bitcoin.pdf") - artifacts = loader.load(source) + artifact = loader.load(source) - assert len(artifacts) == 151 - assert artifacts[0].value.startswith("Bitcoin: A Peer-to-Peer") - assert artifacts[-1].value.endswith('its applications," 1957.\n9') - assert artifacts[0].embedding == [0, 1] + assert len(artifact) == 9 + assert artifact[0].value.startswith("Bitcoin: A Peer-to-Peer") + assert artifact[-1].value.endswith('its applications," 1957.\n9') def test_load_collection(self, loader, create_source): resource_paths = ["bitcoin.pdf", "bitcoin-2.pdf"] @@ -37,7 +33,6 @@ def test_load_collection(self, loader, create_source): for key in keys: artifact = collection[key] - assert len(artifact) == 151 + assert len(artifact) == 9 assert artifact[0].value.startswith("Bitcoin: A Peer-to-Peer") assert artifact[-1].value.endswith('its applications," 1957.\n9') - assert artifact[0].embedding == [0, 1] diff --git a/tests/unit/loaders/test_sql_loader.py b/tests/unit/loaders/test_sql_loader.py index 16ca2af5ae..4d33b634a8 100644 --- a/tests/unit/loaders/test_sql_loader.py +++ b/tests/unit/loaders/test_sql_loader.py @@ -3,7 +3,6 @@ from griptape.drivers import SqlDriver from griptape.loaders import SqlLoader -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver MAX_TOKENS = 50 @@ -16,7 +15,6 @@ def loader(self): engine_url="sqlite:///:memory:", create_engine_params={"connect_args": {"check_same_thread": False}, "poolclass": StaticPool}, ), - embedding_driver=MockEmbeddingDriver(), ) sql_loader.sql_driver.execute_query( @@ -37,12 +35,10 @@ def loader(self): def test_load(self, loader): artifact = loader.load("SELECT * FROM test_table;") - assert len(artifacts) == 3 - assert artifacts[0].value == "id: 1\nname: Alice\nage: 25\ncity: New York" - assert artifacts[1].value == "id: 2\nname: Bob\nage: 30\ncity: Los Angeles" - assert artifacts[2].value == "id: 3\nname: Charlie\nage: 22\ncity: Chicago" - - assert artifact.embedding == [0, 1] + assert len(artifact) == 3 + assert artifact[0].value == "id: 1\nname: Alice\nage: 25\ncity: New York" + assert artifact[1].value == "id: 2\nname: Bob\nage: 30\ncity: Los Angeles" + assert artifact[2].value == "id: 3\nname: Charlie\nage: 22\ncity: Chicago" def test_load_collection(self, loader): sources = ["SELECT * FROM test_table LIMIT 1;", "SELECT * FROM test_table LIMIT 2;"] @@ -55,4 +51,3 @@ def test_load_collection(self, loader): assert artifacts[loader.to_key(sources[0])][0].value == "id: 1\nname: Alice\nage: 25\ncity: New York" assert artifacts[loader.to_key(sources[1])][0].value == "id: 1\nname: Alice\nage: 25\ncity: New York" - assert list(artifacts.values())[0][0].embedding == [0, 1] diff --git a/tests/unit/loaders/test_text_loader.py b/tests/unit/loaders/test_text_loader.py index 07527f9e62..c75417f56f 100644 --- a/tests/unit/loaders/test_text_loader.py +++ b/tests/unit/loaders/test_text_loader.py @@ -1,9 +1,6 @@ import pytest from griptape.loaders.text_loader import TextLoader -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver - -MAX_TOKENS = 50 class TestTextLoader: @@ -11,23 +8,21 @@ class TestTextLoader: def loader(self, request): encoding = request.param if encoding is None: - return TextLoader(max_tokens=MAX_TOKENS, embedding_driver=MockEmbeddingDriver()) + return TextLoader() else: - return TextLoader(max_tokens=MAX_TOKENS, embedding_driver=MockEmbeddingDriver(), encoding=encoding) + return TextLoader(encoding=encoding) - @pytest.fixture(params=["bytes_from_resource_path", "str_from_resource_path"]) + @pytest.fixture(params=["path_from_resource_path"]) def create_source(self, request): return request.getfixturevalue(request.param) def test_load(self, loader, create_source): source = create_source("test.txt") - artifacts = loader.load(source) + artifact = loader.load(source) - assert len(artifacts) == 39 - assert artifacts[0].value.startswith("foobar foobar foobar") - assert artifacts[0].encoding == loader.encoding - assert artifacts[0].embedding == [0, 1] + assert artifact.value.startswith("foobar foobar foobar") + assert artifact.encoding == loader.encoding def test_load_collection(self, loader, create_source): resource_paths = ["test.txt"] @@ -39,9 +34,6 @@ def test_load_collection(self, loader, create_source): assert collection.keys() == keys key = next(iter(keys)) - artifacts = collection[key] - assert len(artifacts) == 39 + artifact = collection[key] - artifact = artifacts[0] - assert artifact.embedding == [0, 1] assert artifact.encoding == loader.encoding diff --git a/tests/unit/loaders/test_web_loader.py b/tests/unit/loaders/test_web_loader.py index f7cccb6664..d6e958042b 100644 --- a/tests/unit/loaders/test_web_loader.py +++ b/tests/unit/loaders/test_web_loader.py @@ -1,7 +1,6 @@ import pytest from griptape.loaders import WebLoader -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver MAX_TOKENS = 50 @@ -13,15 +12,12 @@ def _mock_trafilatura_fetch_url(self, mocker): @pytest.fixture() def loader(self): - return WebLoader(max_tokens=MAX_TOKENS, embedding_driver=MockEmbeddingDriver()) + return WebLoader() def test_load(self, loader): - artifacts = loader.load("https://github.com/griptape-ai/griptape") + artifact = loader.load("https://github.com/griptape-ai/griptape") - assert len(artifacts) == 1 - assert "foobar" in artifacts[0].value.lower() - - assert artifacts[0].embedding == [0, 1] + assert "foobar" in artifact.value.lower() def test_load_exception(self, mocker, loader): mocker.patch("trafilatura.fetch_url", side_effect=Exception("error")) @@ -38,9 +34,7 @@ def test_load_collection(self, loader): loader.to_key("https://github.com/griptape-ai/griptape"), loader.to_key("https://github.com/griptape-ai/griptape-docs"), ] - assert "foobar" in [a.value for artifact_list in artifacts.values() for a in artifact_list][0].lower() - - assert list(artifacts.values())[0][0].embedding == [0, 1] + assert "foobar" in [a.value for a in artifacts.values()] def test_empty_page_string_response(self, loader, mocker): mocker.patch("trafilatura.extract", return_value="") diff --git a/tests/unit/tools/test_file_manager.py b/tests/unit/tools/test_file_manager.py index 469918a025..4e035bdee9 100644 --- a/tests/unit/tools/test_file_manager.py +++ b/tests/unit/tools/test_file_manager.py @@ -7,7 +7,6 @@ from griptape.artifacts import ListArtifact, TextArtifact from griptape.drivers.file_manager.local_file_manager_driver import LocalFileManagerDriver -from griptape.loaders.text_loader import TextLoader from griptape.tools import FileManagerTool from tests.utils import defaults @@ -36,7 +35,7 @@ def test_load_files_from_disk(self, file_manager): result = file_manager.load_files_from_disk({"values": {"paths": ["../../resources/bitcoin.pdf"]}}) assert isinstance(result, ListArtifact) - assert len(result.value) == 4 + assert len(result.value) == 9 def test_load_files_from_disk_with_encoding(self, file_manager): result = file_manager.load_files_from_disk({"values": {"paths": ["../../resources/test.txt"]}}) @@ -48,8 +47,7 @@ def test_load_files_from_disk_with_encoding(self, file_manager): def test_load_files_from_disk_with_encoding_failure(self): file_manager = FileManagerTool( file_manager_driver=LocalFileManagerDriver( - default_loader=TextLoader(encoding="utf-8"), - loaders={}, + encoding="utf-8", workdir=os.path.abspath(os.path.dirname(__file__)), ) ) @@ -116,9 +114,7 @@ def test_save_content_to_file(self, temp_dir): assert result.value == "Successfully saved file" def test_save_content_to_file_with_encoding(self, temp_dir): - file_manager = FileManagerTool( - file_manager_driver=LocalFileManagerDriver(default_loader=TextLoader(encoding="utf-8"), workdir=temp_dir) - ) + file_manager = FileManagerTool(file_manager_driver=LocalFileManagerDriver(encoding="utf-8", workdir=temp_dir)) result = file_manager.save_content_to_file( {"values": {"path": os.path.join("test", "foobar.txt"), "content": "foobar"}} ) @@ -127,9 +123,7 @@ def test_save_content_to_file_with_encoding(self, temp_dir): assert result.value == "Successfully saved file" def test_save_and_load_content_to_file_with_encoding(self, temp_dir): - file_manager = FileManagerTool( - file_manager_driver=LocalFileManagerDriver(loaders={"txt": TextLoader(encoding="ascii")}, workdir=temp_dir) - ) + file_manager = FileManagerTool(file_manager_driver=LocalFileManagerDriver(encoding="ascii", workdir=temp_dir)) result = file_manager.save_content_to_file( {"values": {"path": os.path.join("test", "foobar.txt"), "content": "foobar"}} ) @@ -137,11 +131,7 @@ def test_save_and_load_content_to_file_with_encoding(self, temp_dir): assert Path(os.path.join(temp_dir, "test", "foobar.txt")).read_text() == "foobar" assert result.value == "Successfully saved file" - file_manager = FileManagerTool( - file_manager_driver=LocalFileManagerDriver( - default_loader=TextLoader(encoding="ascii"), loaders={}, workdir=temp_dir - ) - ) + file_manager = FileManagerTool(file_manager_driver=LocalFileManagerDriver(encoding="ascii", workdir=temp_dir)) result = file_manager.load_files_from_disk({"values": {"paths": [os.path.join("test", "foobar.txt")]}}) assert isinstance(result, ListArtifact) diff --git a/tests/unit/utils/test_file_utils.py b/tests/unit/utils/test_file_utils.py deleted file mode 100644 index 00df6958df..0000000000 --- a/tests/unit/utils/test_file_utils.py +++ /dev/null @@ -1,51 +0,0 @@ -import os -from concurrent import futures - -from griptape import utils -from griptape.loaders import TextLoader - -MAX_TOKENS = 50 - - -class TestFileUtils: - def test_load_file(self): - dirname = os.path.dirname(__file__) - file = utils.load_file(os.path.join(dirname, "../../resources/foobar-many.txt")) - - assert file.decode("utf-8").startswith("foobar foobar foobar") - - def test_load_files(self): - dirname = os.path.dirname(__file__) - sources = ["resources/foobar-many.txt", "resources/foobar-many.txt", "resources/small.png"] - sources = [os.path.join(dirname, "../../", source) for source in sources] - files = utils.load_files(sources, futures_executor=futures.ThreadPoolExecutor(max_workers=1)) - assert len(files) == 2 - - test_file = files[utils.str_to_hash(sources[0])] - assert test_file.decode("utf-8").startswith("foobar foobar foobar") - - small_file = files[utils.str_to_hash(sources[2])] - assert len(small_file) == 97 - assert small_file[:8] == b"\x89PNG\r\n\x1a\n" - - def test_load_file_with_loader(self): - dirname = os.path.dirname(__file__) - file = utils.load_file(os.path.join(dirname, "../../", "resources/foobar-many.txt")) - artifacts = TextLoader(max_tokens=MAX_TOKENS).load(file) - - assert len(artifacts) == 39 - assert isinstance(artifacts, list) - assert artifacts[0].value.startswith("foobar foobar foobar") - - def test_load_files_with_loader(self): - dirname = os.path.dirname(__file__) - sources = ["resources/foobar-many.txt"] - sources = [os.path.join(dirname, "../../", source) for source in sources] - files = utils.load_files(sources) - loader = TextLoader(max_tokens=MAX_TOKENS) - collection = loader.load_collection(list(files.values())) - - test_file_artifacts = collection[loader.to_key(files[utils.str_to_hash(sources[0])])] - assert len(test_file_artifacts) == 39 - assert isinstance(test_file_artifacts, list) - assert test_file_artifacts[0].value.startswith("foobar foobar foobar")