diff --git a/docs/examples/src/load_query_and_chat_marqo_1.py b/docs/examples/src/load_query_and_chat_marqo_1.py index 013a0264f7..cdcb376bb2 100644 --- a/docs/examples/src/load_query_and_chat_marqo_1.py +++ b/docs/examples/src/load_query_and_chat_marqo_1.py @@ -1,7 +1,6 @@ import os from griptape import utils -from griptape.artifacts.error_artifact import ErrorArtifact from griptape.drivers import MarqoVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader from griptape.structures import Agent @@ -27,9 +26,6 @@ # Load artifacts from the web artifacts = WebLoader().load("https://www.griptape.ai") -if isinstance(artifacts, ErrorArtifact): - raise Exception(artifacts.value) - # Upsert the artifacts into the vector store vector_store.upsert_text_artifacts( { diff --git a/docs/examples/src/query_webpage_1.py b/docs/examples/src/query_webpage_1.py index 2ea32b7187..b9e3286d65 100644 --- a/docs/examples/src/query_webpage_1.py +++ b/docs/examples/src/query_webpage_1.py @@ -1,14 +1,11 @@ import os -from griptape.artifacts.error_artifact import ErrorArtifact from griptape.drivers import LocalVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader vector_store = LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver(api_key=os.environ["OPENAI_API_KEY"])) artifacts = WebLoader(max_tokens=100).load("https://www.griptape.ai") -if isinstance(artifacts, ErrorArtifact): - raise Exception(artifacts.value) for a in artifacts: vector_store.upsert_text_artifact(a, namespace="griptape") diff --git a/docs/examples/src/query_webpage_astra_db_1.py b/docs/examples/src/query_webpage_astra_db_1.py index 3309e1dcdf..4590a6b595 100644 --- a/docs/examples/src/query_webpage_astra_db_1.py +++ b/docs/examples/src/query_webpage_astra_db_1.py @@ -1,6 +1,5 @@ import os -from griptape.artifacts import ErrorArtifact from griptape.drivers import ( AstraDbVectorStoreDriver, OpenAiChatPromptDriver, @@ -45,8 +44,7 @@ ) artifacts = WebLoader(max_tokens=256).load(input_blogpost) -if isinstance(artifacts, ErrorArtifact): - raise Exception(artifacts.value) + vector_store_driver.upsert_text_artifacts({namespace: artifacts}) rag_tool = RagTool( diff --git a/docs/examples/src/talk_to_a_pdf_1.py b/docs/examples/src/talk_to_a_pdf_1.py index b4ab720297..3c29f4c741 100644 --- a/docs/examples/src/talk_to_a_pdf_1.py +++ b/docs/examples/src/talk_to_a_pdf_1.py @@ -1,6 +1,5 @@ import requests -from griptape.artifacts.error_artifact import ErrorArtifact from griptape.drivers import LocalVectorStoreDriver, OpenAiChatPromptDriver, OpenAiEmbeddingDriver from griptape.engines.rag import RagEngine from griptape.engines.rag.modules import PromptResponseRagModule, VectorStoreRetrievalRagModule @@ -32,8 +31,6 @@ ) artifacts = PdfLoader().load(response.content) -if isinstance(artifacts, ErrorArtifact): - raise Exception(artifacts.value) vector_store.upsert_text_artifacts({namespace: artifacts}) diff --git a/docs/examples/src/talk_to_a_webpage_1.py b/docs/examples/src/talk_to_a_webpage_1.py index 0412ed977a..3e973da2d7 100644 --- a/docs/examples/src/talk_to_a_webpage_1.py +++ b/docs/examples/src/talk_to_a_webpage_1.py @@ -1,4 +1,3 @@ -from griptape.artifacts.error_artifact import ErrorArtifact from griptape.drivers import LocalVectorStoreDriver, OpenAiChatPromptDriver, OpenAiEmbeddingDriver from griptape.engines.rag import RagEngine from griptape.engines.rag.modules import PromptResponseRagModule, VectorStoreRetrievalRagModule @@ -28,9 +27,6 @@ artifacts = WebLoader().load("https://en.wikipedia.org/wiki/Physics") -if isinstance(artifacts, ErrorArtifact): - raise Exception(artifacts.value) - vector_store_driver.upsert_text_artifacts({namespace: artifacts}) rag_tool = RagTool( diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_1.py b/docs/griptape-framework/drivers/src/vector_store_drivers_1.py index a4e54da3a0..7f7e98e134 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_1.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_1.py @@ -1,6 +1,5 @@ import os -from griptape.artifacts import ErrorArtifact from griptape.drivers import LocalVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader @@ -12,8 +11,6 @@ # Load Artifacts from the web artifacts = WebLoader(max_tokens=100).load("https://www.griptape.ai") -if isinstance(artifacts, ErrorArtifact): - raise Exception(artifacts.value) # Upsert Artifacts into the Vector Store Driver [vector_store_driver.upsert_text_artifact(a, namespace="griptape") for a in artifacts] diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_10.py b/docs/griptape-framework/drivers/src/vector_store_drivers_10.py index b7645bd824..39a21121d5 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_10.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_10.py @@ -1,6 +1,5 @@ import os -from griptape.artifacts.error_artifact import ErrorArtifact from griptape.drivers import OpenAiEmbeddingDriver, QdrantVectorStoreDriver from griptape.loaders import WebLoader @@ -22,9 +21,6 @@ # Load Artifacts from the web artifacts = WebLoader().load("https://www.griptape.ai") -if isinstance(artifacts, ErrorArtifact): - raise Exception(artifacts.value) - # Recreate Qdrant collection vector_store_driver.client.recreate_collection( collection_name=vector_store_driver.collection_name, diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_11.py b/docs/griptape-framework/drivers/src/vector_store_drivers_11.py index 965f97715f..a8d9ceed12 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_11.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_11.py @@ -1,6 +1,5 @@ import os -from griptape.artifacts.error_artifact import ErrorArtifact from griptape.drivers import AstraDbVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader @@ -23,9 +22,6 @@ # Load Artifacts from the web artifacts = WebLoader().load("https://www.griptape.ai") -if isinstance(artifacts, ErrorArtifact): - raise Exception(artifacts.value) - # Upsert Artifacts into the Vector Store Driver [vector_store_driver.upsert_text_artifact(a, namespace="griptape") for a in artifacts] diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_3.py b/docs/griptape-framework/drivers/src/vector_store_drivers_3.py index d2cfc8142f..559eaec5a7 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_3.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_3.py @@ -1,6 +1,5 @@ import os -from griptape.artifacts import ErrorArtifact from griptape.drivers import OpenAiEmbeddingDriver, PineconeVectorStoreDriver from griptape.loaders import WebLoader @@ -17,9 +16,6 @@ # Load Artifacts from the web artifacts = WebLoader(max_tokens=100).load("https://www.griptape.ai") -if isinstance(artifacts, ErrorArtifact): - raise Exception(artifacts.value) - # Upsert Artifacts into the Vector Store Driver [vector_store_driver.upsert_text_artifact(a, namespace="griptape") for a in artifacts] diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_4.py b/docs/griptape-framework/drivers/src/vector_store_drivers_4.py index fe35f1ff55..f2f0091a09 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_4.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_4.py @@ -1,6 +1,5 @@ import os -from griptape.artifacts import ErrorArtifact from griptape.drivers import MarqoVectorStoreDriver, OpenAiChatPromptDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader @@ -22,9 +21,6 @@ # Load Artifacts from the web artifacts = WebLoader(max_tokens=200).load("https://www.griptape.ai") -if isinstance(artifacts, ErrorArtifact): - raise Exception(artifacts.value) - # Upsert Artifacts into the Vector Store Driver vector_store_driver.upsert_text_artifacts( { diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_5.py b/docs/griptape-framework/drivers/src/vector_store_drivers_5.py index 867195a486..7649579c75 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_5.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_5.py @@ -1,6 +1,5 @@ import os -from griptape.artifacts.error_artifact import ErrorArtifact from griptape.drivers import MongoDbAtlasVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader @@ -28,9 +27,6 @@ # Load Artifacts from the web artifacts = WebLoader(max_tokens=200).load("https://www.griptape.ai") -if isinstance(artifacts, ErrorArtifact): - raise Exception(artifacts.value) - # Upsert Artifacts into the Vector Store Driver vector_store_driver.upsert_text_artifacts( { diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_6.py b/docs/griptape-framework/drivers/src/vector_store_drivers_6.py index 9c5c9cab65..78a7cc3e64 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_6.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_6.py @@ -1,6 +1,5 @@ import os -from griptape.artifacts.error_artifact import ErrorArtifact from griptape.drivers import AzureMongoDbVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader @@ -28,9 +27,6 @@ # Load Artifacts from the web artifacts = WebLoader(max_tokens=200).load("https://www.griptape.ai") -if isinstance(artifacts, ErrorArtifact): - raise Exception(artifacts.value) - # Upsert Artifacts into the Vector Store Driver vector_store_driver.upsert_text_artifacts( { diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_7.py b/docs/griptape-framework/drivers/src/vector_store_drivers_7.py index c08d9ff3b7..d34ff86493 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_7.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_7.py @@ -1,6 +1,5 @@ import os -from griptape.artifacts.error_artifact import ErrorArtifact from griptape.drivers import OpenAiEmbeddingDriver, RedisVectorStoreDriver from griptape.loaders import WebLoader @@ -18,9 +17,6 @@ # Load Artifacts from the web artifacts = WebLoader(max_tokens=200).load("https://www.griptape.ai") -if isinstance(artifacts, ErrorArtifact): - raise Exception(artifacts.value) - # Upsert Artifacts into the Vector Store Driver vector_store_driver.upsert_text_artifacts( { diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_8.py b/docs/griptape-framework/drivers/src/vector_store_drivers_8.py index a57363eb36..18e50a397c 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_8.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_8.py @@ -2,7 +2,6 @@ import boto3 -from griptape.artifacts.error_artifact import ErrorArtifact from griptape.drivers import AmazonOpenSearchVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader @@ -19,9 +18,6 @@ # Load Artifacts from the web artifacts = WebLoader(max_tokens=200).load("https://www.griptape.ai") -if isinstance(artifacts, ErrorArtifact): - raise Exception(artifacts.value) - # Upsert Artifacts into the Vector Store Driver vector_store_driver.upsert_text_artifacts( { diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_9.py b/docs/griptape-framework/drivers/src/vector_store_drivers_9.py index c5aface635..ad5abf9322 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_9.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_9.py @@ -1,6 +1,5 @@ import os -from griptape.artifacts.error_artifact import ErrorArtifact from griptape.drivers import OpenAiEmbeddingDriver, PgVectorVectorStoreDriver from griptape.loaders import WebLoader @@ -25,9 +24,6 @@ # Load Artifacts from the web artifacts = WebLoader().load("https://www.griptape.ai") -if isinstance(artifacts, ErrorArtifact): - raise Exception(artifacts.value) - # Upsert Artifacts into the Vector Store Driver vector_store_driver.upsert_text_artifacts( { diff --git a/docs/griptape-framework/engines/src/rag_engines_1.py b/docs/griptape-framework/engines/src/rag_engines_1.py index c257cd4df2..a8a9cc06bf 100644 --- a/docs/griptape-framework/engines/src/rag_engines_1.py +++ b/docs/griptape-framework/engines/src/rag_engines_1.py @@ -1,4 +1,3 @@ -from griptape.artifacts import ErrorArtifact from griptape.drivers import LocalVectorStoreDriver, OpenAiChatPromptDriver, OpenAiEmbeddingDriver from griptape.engines.rag import RagContext, RagEngine from griptape.engines.rag.modules import PromptResponseRagModule, TranslateQueryRagModule, VectorStoreRetrievalRagModule @@ -11,8 +10,6 @@ vector_store = LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver()) artifacts = WebLoader(max_tokens=500).load("https://www.griptape.ai") -if isinstance(artifacts, ErrorArtifact): - raise Exception(artifacts.value) vector_store.upsert_text_artifacts( { diff --git a/docs/griptape-framework/engines/src/summary_engines_1.py b/docs/griptape-framework/engines/src/summary_engines_1.py index 092665b37b..b5adf2a5a9 100644 --- a/docs/griptape-framework/engines/src/summary_engines_1.py +++ b/docs/griptape-framework/engines/src/summary_engines_1.py @@ -1,6 +1,5 @@ import requests -from griptape.artifacts.error_artifact import ErrorArtifact from griptape.drivers import OpenAiChatPromptDriver from griptape.engines import PromptSummaryEngine from griptape.loaders import PdfLoader @@ -12,9 +11,6 @@ artifacts = PdfLoader().load(response.content) -if isinstance(artifacts, ErrorArtifact): - raise Exception(artifacts.value) - text = "\n\n".join([a.value for a in artifacts]) engine.summarize_text(text) diff --git a/docs/griptape-tools/official-tools/src/vector_store_tool_1.py b/docs/griptape-tools/official-tools/src/vector_store_tool_1.py index 266398d5e3..26c87e2552 100644 --- a/docs/griptape-tools/official-tools/src/vector_store_tool_1.py +++ b/docs/griptape-tools/official-tools/src/vector_store_tool_1.py @@ -1,4 +1,3 @@ -from griptape.artifacts.error_artifact import ErrorArtifact from griptape.drivers import LocalVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader from griptape.structures import Agent @@ -9,8 +8,6 @@ ) artifacts = WebLoader().load("https://www.griptape.ai") -if isinstance(artifacts, ErrorArtifact): - raise Exception(artifacts.value) vector_store_driver.upsert_text_artifacts({"griptape": artifacts}) vector_db = VectorStoreTool( diff --git a/griptape/drivers/file_manager/base_file_manager_driver.py b/griptape/drivers/file_manager/base_file_manager_driver.py index 1c4f1dd6ac..dce5388122 100644 --- a/griptape/drivers/file_manager/base_file_manager_driver.py +++ b/griptape/drivers/file_manager/base_file_manager_driver.py @@ -41,60 +41,39 @@ class BaseFileManagerDriver(ABC): ) def list_files(self, path: str) -> TextArtifact | ErrorArtifact: - try: - entries = self.try_list_files(path) - return TextArtifact("\n".join(list(entries))) - except FileNotFoundError: - return ErrorArtifact("Path not found") - except NotADirectoryError: - return ErrorArtifact("Path is not a directory") - except Exception as e: - return ErrorArtifact(f"Failed to list files: {str(e)}") + entries = self.try_list_files(path) + return TextArtifact("\n".join(list(entries))) @abstractmethod def try_list_files(self, path: str) -> list[str]: ... def load_file(self, path: str) -> BaseArtifact: - try: - extension = path.split(".")[-1] - loader = self.loaders.get(extension) or self.default_loader - source = self.try_load_file(path) - result = loader.load(source) - - if isinstance(result, BaseArtifact): - return result - else: - return ListArtifact(result) - except FileNotFoundError: - return ErrorArtifact("Path not found") - except IsADirectoryError: - return ErrorArtifact("Path is a directory") - except NotADirectoryError: - return ErrorArtifact("Not a directory") - except Exception as e: - return ErrorArtifact(f"Failed to load file: {str(e)}") + extension = path.split(".")[-1] + loader = self.loaders.get(extension) or self.default_loader + source = self.try_load_file(path) + result = loader.load(source) + + if isinstance(result, BaseArtifact): + return result + else: + return ListArtifact(result) @abstractmethod def try_load_file(self, path: str) -> bytes: ... - def save_file(self, path: str, value: bytes | str) -> InfoArtifact | ErrorArtifact: - try: - extension = path.split(".")[-1] - loader = self.loaders.get(extension) or self.default_loader - encoding = None if loader is None else loader.encoding + def save_file(self, path: str, value: bytes | str) -> InfoArtifact: + extension = path.split(".")[-1] + loader = self.loaders.get(extension) or self.default_loader + encoding = None if loader is None else loader.encoding - if isinstance(value, str): - value = value.encode() if encoding is None else value.encode(encoding=encoding) - elif isinstance(value, (bytearray, memoryview)): - raise ValueError(f"Unsupported type: {type(value)}") + if isinstance(value, str): + value = value.encode() if encoding is None else value.encode(encoding=encoding) + elif isinstance(value, (bytearray, memoryview)): + raise ValueError(f"Unsupported type: {type(value)}") - self.try_save_file(path, value) + self.try_save_file(path, value) - return InfoArtifact("Successfully saved file") - except IsADirectoryError: - return ErrorArtifact("Path is a directory") - except Exception as e: - return ErrorArtifact(f"Failed to save file: {str(e)}") + return InfoArtifact("Successfully saved file") @abstractmethod def try_save_file(self, path: str, value: bytes) -> None: ... diff --git a/griptape/drivers/structure_run/griptape_cloud_structure_run_driver.py b/griptape/drivers/structure_run/griptape_cloud_structure_run_driver.py index 305d149954..a6e2064b6e 100644 --- a/griptape/drivers/structure_run/griptape_cloud_structure_run_driver.py +++ b/griptape/drivers/structure_run/griptape_cloud_structure_run_driver.py @@ -6,7 +6,7 @@ from attrs import Factory, define, field -from griptape.artifacts import BaseArtifact, ErrorArtifact, InfoArtifact +from griptape.artifacts import BaseArtifact, InfoArtifact from griptape.drivers.structure_run.base_structure_run_driver import BaseStructureRunDriver @@ -23,28 +23,25 @@ class GriptapeCloudStructureRunDriver(BaseStructureRunDriver): structure_run_max_wait_time_attempts: int = field(default=20, kw_only=True) async_run: bool = field(default=False, kw_only=True) - def try_run(self, *args: BaseArtifact) -> BaseArtifact: - from requests import HTTPError, Response, exceptions, post + def try_run(self, *args: BaseArtifact) -> BaseArtifact | InfoArtifact: + from requests import Response, post url = urljoin(self.base_url.strip("/"), f"/api/structures/{self.structure_id}/runs") - try: - response: Response = post( - url, - json={"args": [arg.value for arg in args], "env": self.env}, - headers=self.headers, - ) - response.raise_for_status() - response_json = response.json() - - if self.async_run: - return InfoArtifact("Run started successfully") - else: - return self._get_structure_run_result(response_json["structure_run_id"]) - except (exceptions.RequestException, HTTPError) as err: - return ErrorArtifact(str(err)) - - def _get_structure_run_result(self, structure_run_id: str) -> InfoArtifact | BaseArtifact | ErrorArtifact: + response: Response = post( + url, + json={"args": [arg.value for arg in args], "env": self.env}, + headers=self.headers, + ) + response.raise_for_status() + response_json = response.json() + + if self.async_run: + return InfoArtifact("Run started successfully") + else: + return self._get_structure_run_result(response_json["structure_run_id"]) + + def _get_structure_run_result(self, structure_run_id: str) -> BaseArtifact | InfoArtifact: url = urljoin(self.base_url.strip("/"), f"/api/structure-runs/{structure_run_id}") result = self._get_structure_run_result_attempt(url) @@ -59,12 +56,10 @@ def _get_structure_run_result(self, structure_run_id: str) -> InfoArtifact | Bas status = result["status"] if wait_attempts >= self.structure_run_max_wait_time_attempts: - return ErrorArtifact( - f"Failed to get Run result after {self.structure_run_max_wait_time_attempts} attempts.", - ) + raise Exception(f"Failed to get Run result after {self.structure_run_max_wait_time_attempts} attempts.") if status != "SUCCEEDED": - return ErrorArtifact(result) + raise Exception(f"Run failed with status: {status}") if "output" in result: return BaseArtifact.from_dict(result["output"]) diff --git a/griptape/engines/extraction/base_extraction_engine.py b/griptape/engines/extraction/base_extraction_engine.py index fb1fab6c47..d3a50585db 100644 --- a/griptape/engines/extraction/base_extraction_engine.py +++ b/griptape/engines/extraction/base_extraction_engine.py @@ -9,7 +9,7 @@ from griptape.configs import Defaults if TYPE_CHECKING: - from griptape.artifacts import ErrorArtifact, ListArtifact + from griptape.artifacts import ListArtifact from griptape.drivers import BasePromptDriver from griptape.rules import Ruleset @@ -54,4 +54,4 @@ def extract( *, rulesets: Optional[list[Ruleset]] = None, **kwargs, - ) -> ListArtifact | ErrorArtifact: ... + ) -> ListArtifact: ... diff --git a/griptape/engines/extraction/csv_extraction_engine.py b/griptape/engines/extraction/csv_extraction_engine.py index c9c040f651..b45bdf7f5e 100644 --- a/griptape/engines/extraction/csv_extraction_engine.py +++ b/griptape/engines/extraction/csv_extraction_engine.py @@ -6,7 +6,7 @@ from attrs import Factory, define, field -from griptape.artifacts import CsvRowArtifact, ErrorArtifact, ListArtifact, TextArtifact +from griptape.artifacts import CsvRowArtifact, ListArtifact, TextArtifact from griptape.common import Message, PromptStack from griptape.engines import BaseExtractionEngine from griptape.utils import J2 @@ -27,17 +27,14 @@ def extract( *, rulesets: Optional[list[Ruleset]] = None, **kwargs, - ) -> ListArtifact | ErrorArtifact: - try: - return ListArtifact( - self._extract_rec( - cast(list[TextArtifact], text.value) if isinstance(text, ListArtifact) else [TextArtifact(text)], - [], - ), - item_separator="\n", - ) - except Exception as e: - return ErrorArtifact(f"error extracting CSV rows: {e}") + ) -> ListArtifact: + return ListArtifact( + self._extract_rec( + cast(list[TextArtifact], text.value) if isinstance(text, ListArtifact) else [TextArtifact(text)], + [], + ), + item_separator="\n", + ) def text_to_csv_rows(self, text: str, column_names: list[str]) -> list[CsvRowArtifact]: rows = [] diff --git a/griptape/engines/extraction/json_extraction_engine.py b/griptape/engines/extraction/json_extraction_engine.py index 8f2f4a3fed..a4cd3a4386 100644 --- a/griptape/engines/extraction/json_extraction_engine.py +++ b/griptape/engines/extraction/json_extraction_engine.py @@ -6,7 +6,7 @@ from attrs import Factory, define, field -from griptape.artifacts import ErrorArtifact, ListArtifact, TextArtifact +from griptape.artifacts import ListArtifact, TextArtifact from griptape.common import PromptStack from griptape.common.prompt_stack.messages.message import Message from griptape.engines import BaseExtractionEngine @@ -32,18 +32,15 @@ def extract( *, rulesets: Optional[list[Ruleset]] = None, **kwargs, - ) -> ListArtifact | ErrorArtifact: - try: - return ListArtifact( - self._extract_rec( - cast(list[TextArtifact], text.value) if isinstance(text, ListArtifact) else [TextArtifact(text)], - [], - rulesets=rulesets, - ), - item_separator="\n", - ) - except Exception as e: - return ErrorArtifact(f"error extracting JSON: {e}") + ) -> ListArtifact: + return ListArtifact( + self._extract_rec( + cast(list[TextArtifact], text.value) if isinstance(text, ListArtifact) else [TextArtifact(text)], + [], + rulesets=rulesets, + ), + item_separator="\n", + ) def json_to_text_artifacts(self, json_input: str) -> list[TextArtifact]: json_matches = re.findall(self.JSON_PATTERN, json_input, re.DOTALL) diff --git a/griptape/engines/rag/modules/retrieval/text_loader_retrieval_rag_module.py b/griptape/engines/rag/modules/retrieval/text_loader_retrieval_rag_module.py index 4f53cc5f9a..7e4854d00c 100644 --- a/griptape/engines/rag/modules/retrieval/text_loader_retrieval_rag_module.py +++ b/griptape/engines/rag/modules/retrieval/text_loader_retrieval_rag_module.py @@ -6,12 +6,12 @@ from attrs import Factory, define, field from griptape import utils -from griptape.artifacts import ErrorArtifact, TextArtifact from griptape.engines.rag.modules import BaseRetrievalRagModule if TYPE_CHECKING: from collections.abc import Sequence + from griptape.artifacts import TextArtifact from griptape.drivers import BaseVectorStoreDriver from griptape.engines.rag import RagContext from griptape.loaders import BaseTextLoader @@ -38,8 +38,6 @@ def run(self, context: RagContext) -> Sequence[TextArtifact]: loader_output = self.loader.load(source) - if isinstance(loader_output, ErrorArtifact): - raise Exception(loader_output.to_text() if loader_output.exception is None else loader_output.exception) self.vector_store_driver.upsert_text_artifacts({namespace: loader_output}) return self.process_query_output_fn(self.vector_store_driver.query(context.query, **query_params)) diff --git a/griptape/loaders/base_text_loader.py b/griptape/loaders/base_text_loader.py index 369f3f1fc8..196cb0087c 100644 --- a/griptape/loaders/base_text_loader.py +++ b/griptape/loaders/base_text_loader.py @@ -1,12 +1,11 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Optional, cast from attrs import Factory, define, field from griptape.artifacts import TextArtifact -from griptape.artifacts.error_artifact import ErrorArtifact from griptape.chunkers import BaseChunker, TextChunker from griptape.loaders import BaseLoader from griptape.tokenizers import OpenAiTokenizer @@ -40,11 +39,11 @@ class BaseTextLoader(BaseLoader, ABC): reference: Optional[Reference] = field(default=None, kw_only=True) @abstractmethod - def load(self, source: Any, *args, **kwargs) -> ErrorArtifact | list[TextArtifact]: ... + def load(self, source: Any, *args, **kwargs) -> list[TextArtifact]: ... - def load_collection(self, sources: list[Any], *args, **kwargs) -> dict[str, ErrorArtifact | list[TextArtifact]]: + def load_collection(self, sources: list[Any], *args, **kwargs) -> dict[str, list[TextArtifact]]: return cast( - dict[str, Union[ErrorArtifact, list[TextArtifact]]], + dict[str, list[TextArtifact]], super().load_collection(sources, *args, **kwargs), ) diff --git a/griptape/loaders/blob_loader.py b/griptape/loaders/blob_loader.py index fffabb8498..d0099b47bc 100644 --- a/griptape/loaders/blob_loader.py +++ b/griptape/loaders/blob_loader.py @@ -1,20 +1,20 @@ from __future__ import annotations -from typing import Any, Union, cast +from typing import Any, cast from attrs import define -from griptape.artifacts import BlobArtifact, ErrorArtifact +from griptape.artifacts import BlobArtifact from griptape.loaders import BaseLoader @define class BlobLoader(BaseLoader): - def load(self, source: Any, *args, **kwargs) -> BlobArtifact | ErrorArtifact: + def load(self, source: Any, *args, **kwargs) -> BlobArtifact: if self.encoding is None: return BlobArtifact(source) else: return BlobArtifact(source, encoding=self.encoding) - def load_collection(self, sources: list[bytes | str], *args, **kwargs) -> dict[str, BlobArtifact | ErrorArtifact]: - return cast(dict[str, Union[BlobArtifact, ErrorArtifact]], super().load_collection(sources, *args, **kwargs)) + def load_collection(self, sources: list[bytes | str], *args, **kwargs) -> dict[str, BlobArtifact]: + return cast(dict[str, BlobArtifact], super().load_collection(sources, *args, **kwargs)) diff --git a/griptape/loaders/csv_loader.py b/griptape/loaders/csv_loader.py index dc73ca52c4..14dfe3e4a6 100644 --- a/griptape/loaders/csv_loader.py +++ b/griptape/loaders/csv_loader.py @@ -2,11 +2,11 @@ import csv from io import StringIO -from typing import TYPE_CHECKING, Optional, Union, cast +from typing import TYPE_CHECKING, Optional, cast from attrs import define, field -from griptape.artifacts import CsvRowArtifact, ErrorArtifact +from griptape.artifacts import CsvRowArtifact from griptape.loaders import BaseLoader if TYPE_CHECKING: @@ -19,16 +19,13 @@ class CsvLoader(BaseLoader): delimiter: str = field(default=",", kw_only=True) encoding: str = field(default="utf-8", kw_only=True) - def load(self, source: bytes | str, *args, **kwargs) -> ErrorArtifact | list[CsvRowArtifact]: + def load(self, source: bytes | str, *args, **kwargs) -> list[CsvRowArtifact]: artifacts = [] if isinstance(source, bytes): - try: - source = source.decode(encoding=self.encoding) - except UnicodeDecodeError: - return ErrorArtifact(f"Failed to decode bytes to string using encoding: {self.encoding}") + source = source.decode(encoding=self.encoding) elif isinstance(source, (bytearray, memoryview)): - return ErrorArtifact(f"Unsupported source type: {type(source)}") + raise ValueError(f"Unsupported source type: {type(source)}") reader = csv.DictReader(StringIO(source), delimiter=self.delimiter) chunks = [CsvRowArtifact(row) for row in reader] @@ -47,8 +44,8 @@ def load_collection( sources: list[bytes | str], *args, **kwargs, - ) -> dict[str, ErrorArtifact | list[CsvRowArtifact]]: + ) -> dict[str, list[CsvRowArtifact]]: return cast( - dict[str, Union[ErrorArtifact, list[CsvRowArtifact]]], + dict[str, list[CsvRowArtifact]], super().load_collection(sources, *args, **kwargs), ) diff --git a/griptape/loaders/email_loader.py b/griptape/loaders/email_loader.py index 82f34bd8ae..f6c9ca4062 100644 --- a/griptape/loaders/email_loader.py +++ b/griptape/loaders/email_loader.py @@ -1,12 +1,11 @@ from __future__ import annotations import imaplib -import logging -from typing import Optional, Union, cast +from typing import Optional, cast from attrs import astuple, define, field -from griptape.artifacts import ErrorArtifact, ListArtifact, TextArtifact +from griptape.artifacts import ListArtifact, TextArtifact from griptape.loaders import BaseLoader from griptape.utils import import_optional_dependency @@ -33,50 +32,46 @@ class EmailQuery: username: str = field(kw_only=True) password: str = field(kw_only=True) - def load(self, source: EmailQuery, *args, **kwargs) -> ListArtifact | ErrorArtifact: + def load(self, source: EmailQuery, *args, **kwargs) -> ListArtifact: mailparser = import_optional_dependency("mailparser") label, key, search_criteria, max_count = astuple(source) artifacts = [] - try: - with imaplib.IMAP4_SSL(self.imap_url) as client: - client.login(self.username, self.password) + with imaplib.IMAP4_SSL(self.imap_url) as client: + client.login(self.username, self.password) - mailbox = client.select(f'"{label}"', readonly=True) - if mailbox[0] != "OK": - raise Exception(mailbox[1][0].decode()) + mailbox = client.select(f'"{label}"', readonly=True) + if mailbox[0] != "OK": + raise Exception(mailbox[1][0].decode()) - if key and search_criteria: - _typ, [message_numbers] = client.search(None, key, f'"{search_criteria}"') - messages_count = self._count_messages(message_numbers) - elif len(mailbox) > 1 and mailbox[1] and mailbox[1][0] is not None: - messages_count = int(mailbox[1][0]) - else: - raise Exception("unable to parse number of messages") + if key and search_criteria: + _typ, [message_numbers] = client.search(None, key, f'"{search_criteria}"') + messages_count = self._count_messages(message_numbers) + elif len(mailbox) > 1 and mailbox[1] and mailbox[1][0] is not None: + messages_count = int(mailbox[1][0]) + else: + raise Exception("unable to parse number of messages") - top_n = max(0, messages_count - max_count) if max_count else 0 - for i in range(messages_count, top_n, -1): - _result, data = client.fetch(str(i), "(RFC822)") + top_n = max(0, messages_count - max_count) if max_count else 0 + for i in range(messages_count, top_n, -1): + _result, data = client.fetch(str(i), "(RFC822)") - if data is None or not data or data[0] is None: - continue + if data is None or not data or data[0] is None: + continue - message = mailparser.parse_from_bytes(data[0][1]) + message = mailparser.parse_from_bytes(data[0][1]) - # Note: mailparser only populates the text_plain field - # if the message content type is explicitly set to 'text/plain'. - if message.text_plain: - artifacts.append(TextArtifact("\n".join(message.text_plain))) + # Note: mailparser only populates the text_plain field + # if the message content type is explicitly set to 'text/plain'. + if message.text_plain: + artifacts.append(TextArtifact("\n".join(message.text_plain))) - client.close() + client.close() - return ListArtifact(artifacts) - except Exception as e: - logging.error(e) - return ErrorArtifact(f"error retrieving email: {e}") + return ListArtifact(artifacts) def _count_messages(self, message_numbers: bytes) -> int: return len(list(filter(None, message_numbers.decode().split(" ")))) - def load_collection(self, sources: list[EmailQuery], *args, **kwargs) -> dict[str, ListArtifact | ErrorArtifact]: - return cast(dict[str, Union[ListArtifact, ErrorArtifact]], super().load_collection(sources, *args, **kwargs)) + def load_collection(self, sources: list[EmailQuery], *args, **kwargs) -> dict[str, ListArtifact]: + return cast(dict[str, ListArtifact], super().load_collection(sources, *args, **kwargs)) diff --git a/griptape/loaders/pdf_loader.py b/griptape/loaders/pdf_loader.py index b38e2cd77e..419bfabf4b 100644 --- a/griptape/loaders/pdf_loader.py +++ b/griptape/loaders/pdf_loader.py @@ -1,12 +1,11 @@ from __future__ import annotations from io import BytesIO -from typing import Optional, Union, cast +from typing import Optional, cast from attrs import Factory, define, field from griptape.artifacts import TextArtifact -from griptape.artifacts.error_artifact import ErrorArtifact from griptape.chunkers import PdfChunker from griptape.loaders import BaseTextLoader from griptape.utils import import_optional_dependency @@ -26,13 +25,13 @@ def load( password: Optional[str] = None, *args, **kwargs, - ) -> ErrorArtifact | list[TextArtifact]: + ) -> list[TextArtifact]: pypdf = import_optional_dependency("pypdf") reader = pypdf.PdfReader(BytesIO(source), strict=True, password=password) return self._text_to_artifacts("\n".join([p.extract_text() for p in reader.pages])) - def load_collection(self, sources: list[bytes], *args, **kwargs) -> dict[str, ErrorArtifact | list[TextArtifact]]: + def load_collection(self, sources: list[bytes], *args, **kwargs) -> dict[str, list[TextArtifact]]: return cast( - dict[str, Union[ErrorArtifact, list[TextArtifact]]], + dict[str, list[TextArtifact]], super().load_collection(sources, *args, **kwargs), ) diff --git a/griptape/loaders/text_loader.py b/griptape/loaders/text_loader.py index e356a2cdbd..79e551a8e3 100644 --- a/griptape/loaders/text_loader.py +++ b/griptape/loaders/text_loader.py @@ -1,11 +1,10 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional, Union, cast +from typing import TYPE_CHECKING, Optional, cast from attrs import Factory, define, field from griptape.artifacts import TextArtifact -from griptape.artifacts.error_artifact import ErrorArtifact from griptape.chunkers import TextChunker from griptape.loaders import BaseTextLoader from griptape.tokenizers import OpenAiTokenizer @@ -36,14 +35,11 @@ class TextLoader(BaseTextLoader): embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) encoding: str = field(default="utf-8", kw_only=True) - def load(self, source: bytes | str, *args, **kwargs) -> ErrorArtifact | list[TextArtifact]: + def load(self, source: bytes | str, *args, **kwargs) -> list[TextArtifact]: if isinstance(source, bytes): - try: - source = source.decode(encoding=self.encoding) - except UnicodeDecodeError: - return ErrorArtifact(f"Failed to decode bytes to string using encoding: {self.encoding}") + source = source.decode(encoding=self.encoding) elif isinstance(source, (bytearray, memoryview)): - return ErrorArtifact(f"Unsupported source type: {type(source)}") + raise ValueError(f"Unsupported source type: {type(source)}") return self._text_to_artifacts(source) @@ -52,8 +48,8 @@ def load_collection( sources: list[bytes | str], *args, **kwargs, - ) -> dict[str, ErrorArtifact | list[TextArtifact]]: + ) -> dict[str, list[TextArtifact]]: return cast( - dict[str, Union[ErrorArtifact, list[TextArtifact]]], + dict[str, list[TextArtifact]], super().load_collection(sources, *args, **kwargs), ) diff --git a/griptape/loaders/web_loader.py b/griptape/loaders/web_loader.py index 3798f94886..720ab34a19 100644 --- a/griptape/loaders/web_loader.py +++ b/griptape/loaders/web_loader.py @@ -4,7 +4,6 @@ from attrs import Factory, define, field -from griptape.artifacts.error_artifact import ErrorArtifact from griptape.drivers import BaseWebScraperDriver, TrafilaturaWebScraperDriver from griptape.loaders import BaseTextLoader @@ -19,9 +18,6 @@ class WebLoader(BaseTextLoader): kw_only=True, ) - def load(self, source: str, *args, **kwargs) -> ErrorArtifact | list[TextArtifact]: - try: - single_chunk_text_artifact = self.web_scraper_driver.scrape_url(source) - return self._text_to_artifacts(single_chunk_text_artifact.value) - except Exception as e: - return ErrorArtifact(f"Error loading from source: {source}", exception=e) + def load(self, source: str, *args, **kwargs) -> list[TextArtifact]: + single_chunk_text_artifact = self.web_scraper_driver.scrape_url(source) + return self._text_to_artifacts(single_chunk_text_artifact.value) diff --git a/griptape/tasks/code_execution_task.py b/griptape/tasks/code_execution_task.py index 68e0d66ada..4c5113a9ac 100644 --- a/griptape/tasks/code_execution_task.py +++ b/griptape/tasks/code_execution_task.py @@ -16,4 +16,4 @@ def run(self) -> BaseArtifact: try: return self.run_fn(self) except Exception as e: - return ErrorArtifact(f"error during Code Execution Task: {e}") + return ErrorArtifact(f"error during Code Execution Task: {e}", exception=e) diff --git a/griptape/tools/file_manager/tool.py b/griptape/tools/file_manager/tool.py index ece6a0e923..2ca14d565d 100644 --- a/griptape/tools/file_manager/tool.py +++ b/griptape/tools/file_manager/tool.py @@ -93,9 +93,16 @@ def save_memory_artifacts_to_disk(self, params: dict) -> ErrorArtifact | InfoArt for artifact in list_artifact.value: formatted_file_name = f"{artifact.name}-{file_name}" if len(list_artifact) > 1 else file_name - result = self.file_manager_driver.save_file(os.path.join(dir_name, formatted_file_name), artifact.value) - if isinstance(result, ErrorArtifact): - return result + try: + self.file_manager_driver.save_file(os.path.join(dir_name, formatted_file_name), artifact.value) + except FileNotFoundError: + return ErrorArtifact("Path not found") + except IsADirectoryError: + return ErrorArtifact("Path is a directory") + except NotADirectoryError: + return ErrorArtifact("Not a directory") + except Exception as e: + return ErrorArtifact(f"Failed to load file: {str(e)}") return InfoArtifact("Successfully saved memory artifacts to disk") diff --git a/griptape/tools/variation_image_generation/tool.py b/griptape/tools/variation_image_generation/tool.py index 9691f62068..0d4456c2fb 100644 --- a/griptape/tools/variation_image_generation/tool.py +++ b/griptape/tools/variation_image_generation/tool.py @@ -51,9 +51,6 @@ def image_variation_from_file(self, params: dict[str, dict[str, str]]) -> ImageA image_artifact = self.image_loader.load(Path(image_file).read_bytes()) - if isinstance(image_artifact, ErrorArtifact): - return image_artifact - return self._generate_variation(prompt, negative_prompt, image_artifact) @activity( diff --git a/griptape/tools/web_scraper/tool.py b/griptape/tools/web_scraper/tool.py index c27aaa0660..2895d5e0dd 100644 --- a/griptape/tools/web_scraper/tool.py +++ b/griptape/tools/web_scraper/tool.py @@ -24,9 +24,6 @@ def get_content(self, params: dict) -> ListArtifact | ErrorArtifact: try: result = self.web_loader.load(url) - if isinstance(result, ErrorArtifact): - return result - else: - return ListArtifact(result) + return ListArtifact(result) except Exception as e: return ErrorArtifact("Error getting page content: " + str(e)) diff --git a/tests/unit/drivers/file_manager/test_amazon_s3_file_manager_driver.py b/tests/unit/drivers/file_manager/test_amazon_s3_file_manager_driver.py index e3ec78eeb3..84ce617685 100644 --- a/tests/unit/drivers/file_manager/test_amazon_s3_file_manager_driver.py +++ b/tests/unit/drivers/file_manager/test_amazon_s3_file_manager_driver.py @@ -5,7 +5,7 @@ import pytest from moto import mock_s3 -from griptape.artifacts import ErrorArtifact, InfoArtifact, ListArtifact, TextArtifact +from griptape.artifacts import InfoArtifact, ListArtifact, TextArtifact from griptape.drivers import AmazonS3FileManagerDriver from griptape.loaders import TextLoader from tests.utils.aws import mock_aws_credentials @@ -135,23 +135,21 @@ def test_list_files(self, workdir, path, expected, driver): ("workdir", "path", "expected"), [ # non-existent paths - ("/", "bar", "Path not found"), - ("/", "bar/", "Path not found"), - ("/", "bitcoin.pdf", "Path not found"), + ("/", "bar", FileNotFoundError), + ("/", "bar/", FileNotFoundError), + ("/", "bitcoin.pdf", FileNotFoundError), # # paths to files (not directories) - ("/", "foo.txt", "Path is not a directory"), - ("/", "/foo.txt", "Path is not a directory"), - ("/resources", "bitcoin.pdf", "Path is not a directory"), - ("/resources", "/bitcoin.pdf", "Path is not a directory"), + ("/", "foo.txt", NotADirectoryError), + ("/", "/foo.txt", NotADirectoryError), + ("/resources", "bitcoin.pdf", NotADirectoryError), + ("/resources", "/bitcoin.pdf", NotADirectoryError), ], ) def test_list_files_failure(self, workdir, path, expected, driver): driver.workdir = workdir - artifact = driver.list_files(path) - - assert isinstance(artifact, ErrorArtifact) - assert artifact.value == expected + with pytest.raises(expected): + driver.list_files(path) def test_load_file(self, driver): artifact = driver.load_file("resources/bitcoin.pdf") @@ -163,28 +161,26 @@ def test_load_file(self, driver): ("workdir", "path", "expected"), [ # non-existent files or directories - ("/", "bitcoin.pdf", "Path not found"), - ("/resources", "foo.txt", "Path not found"), - ("/", "bar/", "Path is a directory"), + ("/", "bitcoin.pdf", FileNotFoundError), + ("/resources", "foo.txt", FileNotFoundError), + ("/", "bar/", IsADirectoryError), # existing files with trailing slash - ("/", "resources/bitcoin.pdf/", "Path is a directory"), - ("/resources", "bitcoin.pdf/", "Path is a directory"), + ("/", "resources/bitcoin.pdf/", IsADirectoryError), + ("/resources", "bitcoin.pdf/", IsADirectoryError), # directories -- not files - ("/", "", "Path is a directory"), - ("/", "/", "Path is a directory"), - ("/", "resources", "Path is a directory"), - ("/", "resources/", "Path is a directory"), - ("/resources", "", "Path is a directory"), - ("/resources", "/", "Path is a directory"), + ("/", "", IsADirectoryError), + ("/", "/", IsADirectoryError), + ("/", "resources", IsADirectoryError), + ("/", "resources/", IsADirectoryError), + ("/resources", "", IsADirectoryError), + ("/resources", "/", IsADirectoryError), ], ) def test_load_file_failure(self, workdir, path, expected, driver): driver.workdir = workdir - artifact = driver.load_file(path) - - assert isinstance(artifact, ErrorArtifact) - assert artifact.value == expected + with pytest.raises(expected): + driver.load_file(path) def test_load_file_with_encoding(self, driver): artifact = driver.load_file("resources/test.txt") @@ -193,15 +189,6 @@ def test_load_file_with_encoding(self, driver): assert len(artifact.value) == 1 assert isinstance(artifact.value[0], TextArtifact) - def test_load_file_with_encoding_failure(self, session, bucket): - driver = AmazonS3FileManagerDriver( - session=session, bucket=bucket, default_loader=TextLoader(encoding="utf-8"), loaders={} - ) - - artifact = driver.load_file("resources/bitcoin.pdf") - - assert isinstance(artifact, ErrorArtifact) - @pytest.mark.parametrize( ("workdir", "path", "content"), [ @@ -231,27 +218,25 @@ def test_save_file(self, workdir, path, content, driver, get_s3_value): ("workdir", "path", "expected"), [ # non-existent directories - ("/", "bar/", "Path is a directory"), - ("/", "/bar/", "Path is a directory"), + ("/", "bar/", IsADirectoryError), + ("/", "/bar/", IsADirectoryError), # # existing directories - ("/", "", "Path is a directory"), - ("/", "/", "Path is a directory"), - ("/", "resources", "Path is a directory"), - ("/", "resources/", "Path is a directory"), - ("/resources", "", "Path is a directory"), - ("/resources", "/", "Path is a directory"), + ("/", "", IsADirectoryError), + ("/", "/", IsADirectoryError), + ("/", "resources", IsADirectoryError), + ("/", "resources/", IsADirectoryError), + ("/resources", "", IsADirectoryError), + ("/resources", "/", IsADirectoryError), # existing files with trailing slash - ("/", "resources/bitcoin.pdf/", "Path is a directory"), - ("/resources", "bitcoin.pdf/", "Path is a directory"), + ("/", "resources/bitcoin.pdf/", IsADirectoryError), + ("/resources", "bitcoin.pdf/", IsADirectoryError), ], ) def test_save_file_failure(self, workdir, path, expected, temp_dir, driver, s3_client, bucket): driver.workdir = workdir - artifact = driver.save_file(path, "foobar") - - assert isinstance(artifact, ErrorArtifact) - assert artifact.value == expected + with pytest.raises(expected): + driver.save_file(path, "foobar") def test_save_file_with_encoding(self, session, bucket, get_s3_value): workdir = "/sub-folder" diff --git a/tests/unit/drivers/file_manager/test_local_file_manager_driver.py b/tests/unit/drivers/file_manager/test_local_file_manager_driver.py index a7c244f099..394a838a3f 100644 --- a/tests/unit/drivers/file_manager/test_local_file_manager_driver.py +++ b/tests/unit/drivers/file_manager/test_local_file_manager_driver.py @@ -4,7 +4,7 @@ import pytest -from griptape.artifacts import ErrorArtifact, InfoArtifact, ListArtifact, TextArtifact +from griptape.artifacts import InfoArtifact, ListArtifact, TextArtifact from griptape.drivers import LocalFileManagerDriver from griptape.loaders.text_loader import TextLoader @@ -107,24 +107,22 @@ def test_list_files(self, workdir, path, expected, temp_dir, driver): ("workdir", "path", "expected"), [ # non-existent paths - ("/", "bar", "Path not found"), - ("/", "bar/", "Path not found"), - ("/", "bitcoin.pdf", "Path not found"), + ("/", "bar", FileNotFoundError), + ("/", "bar/", FileNotFoundError), + ("/", "bitcoin.pdf", FileNotFoundError), # # paths to files (not directories) - ("/", "foo.txt", "Path is not a directory"), - ("/", "/foo.txt", "Path is not a directory"), - ("/resources", "bitcoin.pdf", "Path is not a directory"), - ("/resources", "/bitcoin.pdf", "Path is not a directory"), + ("/", "foo.txt", NotADirectoryError), + ("/", "/foo.txt", NotADirectoryError), + ("/resources", "bitcoin.pdf", NotADirectoryError), + ("/resources", "/bitcoin.pdf", NotADirectoryError), ], ) def test_list_files_failure(self, workdir, path, expected, temp_dir, driver): # Treat the workdir as an absolute path, but modify it to be relative to the temp_dir. driver.workdir = self._to_driver_workdir(temp_dir, workdir) - artifact = driver.list_files(path) - - assert isinstance(artifact, ErrorArtifact) - assert artifact.value == expected + with pytest.raises(expected): + driver.list_files(path) def test_load_file(self, driver: LocalFileManagerDriver): artifact = driver.load_file("resources/bitcoin.pdf") @@ -136,29 +134,27 @@ def test_load_file(self, driver: LocalFileManagerDriver): ("workdir", "path", "expected"), [ # # non-existent files or directories - ("/", "bitcoin.pdf", "Path not found"), - ("/resources", "foo.txt", "Path not found"), - ("/", "bar/", "Path is a directory"), + ("/", "bitcoin.pdf", FileNotFoundError), + ("/resources", "foo.txt", FileNotFoundError), + ("/", "bar/", IsADirectoryError), # existing files with trailing slash - ("/", "resources/bitcoin.pdf/", "Path is a directory"), - ("/resources", "bitcoin.pdf/", "Path is a directory"), + ("/", "resources/bitcoin.pdf/", IsADirectoryError), + ("/resources", "bitcoin.pdf/", IsADirectoryError), # directories -- not files - ("/", "", "Path is a directory"), - ("/", "/", "Path is a directory"), - ("/", "resources", "Path is a directory"), - ("/", "resources/", "Path is a directory"), - ("/resources", "", "Path is a directory"), - ("/resources", "/", "Path is a directory"), + ("/", "", IsADirectoryError), + ("/", "/", IsADirectoryError), + ("/", "resources", IsADirectoryError), + ("/", "resources/", IsADirectoryError), + ("/resources", "", IsADirectoryError), + ("/resources", "/", IsADirectoryError), ], ) def test_load_file_failure(self, workdir, path, expected, temp_dir, driver): # Treat the workdir as an absolute path, but modify it to be relative to the temp_dir. driver.workdir = self._to_driver_workdir(temp_dir, workdir) - artifact = driver.load_file(path) - - assert isinstance(artifact, ErrorArtifact) - assert artifact.value == expected + with pytest.raises(expected): + driver.load_file(path) def test_load_file_with_encoding(self, driver: LocalFileManagerDriver): artifact = driver.load_file("resources/test.txt") @@ -167,14 +163,15 @@ def test_load_file_with_encoding(self, driver: LocalFileManagerDriver): assert len(artifact.value) == 1 assert isinstance(artifact.value[0], TextArtifact) - def test_load_file_with_encoding_failure(self): + def test_load_file_with_encoding_failure(self, driver): driver = LocalFileManagerDriver( - default_loader=TextLoader(encoding="utf-8"), loaders={}, workdir=os.path.abspath(os.path.dirname(__file__)) + default_loader=TextLoader(encoding="utf-8"), + loaders={}, + workdir=os.path.normpath(os.path.abspath(os.path.dirname(__file__) + "../../../../")), ) - artifact = driver.load_file("resources/bitcoin.pdf") - - assert isinstance(artifact, ErrorArtifact) + with pytest.raises(UnicodeDecodeError): + driver.load_file("resources/bitcoin.pdf") @pytest.mark.parametrize( ("workdir", "path", "content"), @@ -205,28 +202,26 @@ def test_save_file(self, workdir, path, content, temp_dir, driver): ("workdir", "path", "expected"), [ # non-existent directories - ("/", "bar/", "Path is a directory"), - ("/", "/bar/", "Path is a directory"), + ("/", "bar/", IsADirectoryError), + ("/", "/bar/", IsADirectoryError), # existing directories - ("/", "", "Path is a directory"), - ("/", "/", "Path is a directory"), - ("/", "resources", "Path is a directory"), - ("/", "resources/", "Path is a directory"), - ("/resources", "", "Path is a directory"), - ("/resources", "/", "Path is a directory"), + ("/", "", IsADirectoryError), + ("/", "/", IsADirectoryError), + ("/", "resources", IsADirectoryError), + ("/", "resources/", IsADirectoryError), + ("/resources", "", IsADirectoryError), + ("/resources", "/", IsADirectoryError), # existing files with trailing slash - ("/", "resources/bitcoin.pdf/", "Path is a directory"), - ("/resources", "bitcoin.pdf/", "Path is a directory"), + ("/", "resources/bitcoin.pdf/", IsADirectoryError), + ("/resources", "bitcoin.pdf/", IsADirectoryError), ], ) def test_save_file_failure(self, workdir, path, expected, temp_dir, driver): # Treat the workdir as an absolute path, but modify it to be relative to the temp_dir. driver.workdir = self._to_driver_workdir(temp_dir, workdir) - artifact = driver.save_file(path, "foobar") - - assert isinstance(artifact, ErrorArtifact) - assert artifact.value == expected + with pytest.raises(expected): + driver.save_file(path, "foobar") def test_save_file_with_encoding(self, temp_dir): driver = LocalFileManagerDriver(default_loader=TextLoader(encoding="utf-8"), loaders={}, workdir=temp_dir) diff --git a/tests/unit/engines/extraction/test_json_extraction_engine.py b/tests/unit/engines/extraction/test_json_extraction_engine.py index 48430f1e54..9d6442579e 100644 --- a/tests/unit/engines/extraction/test_json_extraction_engine.py +++ b/tests/unit/engines/extraction/test_json_extraction_engine.py @@ -1,7 +1,6 @@ import pytest from schema import Schema -from griptape.artifacts import ErrorArtifact from griptape.engines import JsonExtractionEngine from tests.mocks.mock_prompt_driver import MockPromptDriver @@ -25,7 +24,9 @@ def test_extract(self, engine): def test_extract_error(self, engine): engine.template_schema = lambda: "non serializable" - assert isinstance(engine.extract("foo"), ErrorArtifact) + + with pytest.raises(TypeError): + engine.extract("foo") def test_json_to_text_artifacts(self, engine): assert [ diff --git a/tests/unit/loaders/test_email_loader.py b/tests/unit/loaders/test_email_loader.py index f1e057453e..ade0627435 100644 --- a/tests/unit/loaders/test_email_loader.py +++ b/tests/unit/loaders/test_email_loader.py @@ -6,7 +6,7 @@ import pytest -from griptape.artifacts import ErrorArtifact, ListArtifact +from griptape.artifacts import ListArtifact from griptape.loaders import EmailLoader @@ -79,20 +79,16 @@ def test_load_returns_error_artifact_when_select_returns_non_ok(self, loader, mo mock_select.return_value = (None, [b"NOT-OK"]) # When - artifact = loader.load(EmailLoader.EmailQuery(label="INBOX")) - - # Then - assert isinstance(artifact, ErrorArtifact) + with pytest.raises(Exception, match="NOT-OK"): + loader.load(EmailLoader.EmailQuery(label="INBOX")) def test_load_returns_error_artifact_when_login_throws(self, loader, mock_login): # Given mock_login.side_effect = Exception("login-failed") # When - artifact = loader.load(EmailLoader.EmailQuery(label="INBOX")) - - # Then - assert isinstance(artifact, ErrorArtifact) + with pytest.raises(Exception, match="login-failed"): + loader.load(EmailLoader.EmailQuery(label="INBOX")) def test_load_collection(self, loader, mock_fetch): # Given diff --git a/tests/unit/loaders/test_web_loader.py b/tests/unit/loaders/test_web_loader.py index f264ce667a..f7cccb6664 100644 --- a/tests/unit/loaders/test_web_loader.py +++ b/tests/unit/loaders/test_web_loader.py @@ -1,6 +1,5 @@ import pytest -from griptape.artifacts.error_artifact import ErrorArtifact from griptape.loaders import WebLoader from tests.mocks.mock_embedding_driver import MockEmbeddingDriver @@ -27,10 +26,8 @@ def test_load(self, loader): def test_load_exception(self, mocker, loader): mocker.patch("trafilatura.fetch_url", side_effect=Exception("error")) source = "https://github.com/griptape-ai/griptape" - artifact = loader.load(source) - - assert isinstance(artifact, ErrorArtifact) - assert f"Error loading from source: {source}" == artifact.value + with pytest.raises(Exception, match="error"): + loader.load(source) def test_load_collection(self, loader): artifacts = loader.load_collection( @@ -48,13 +45,11 @@ def test_load_collection(self, loader): def test_empty_page_string_response(self, loader, mocker): mocker.patch("trafilatura.extract", return_value="") - artifact = loader.load("https://example.com/") - assert isinstance(artifact, ErrorArtifact) - assert str(artifact.exception) == "can't extract page" + with pytest.raises(Exception, match="can't extract page"): + loader.load("https://example.com/") def test_empty_page_none_response(self, loader, mocker): mocker.patch("trafilatura.extract", return_value=None) - artifact = loader.load("https://example.com/") - assert isinstance(artifact, ErrorArtifact) - assert str(artifact.exception) == "can't extract page" + with pytest.raises(Exception, match="can't extract page"): + loader.load("https://example.com/") diff --git a/tests/unit/tools/test_file_manager.py b/tests/unit/tools/test_file_manager.py index dccf2f1a2d..469918a025 100644 --- a/tests/unit/tools/test_file_manager.py +++ b/tests/unit/tools/test_file_manager.py @@ -6,7 +6,6 @@ import pytest from griptape.artifacts import ListArtifact, TextArtifact -from griptape.artifacts.error_artifact import ErrorArtifact from griptape.drivers.file_manager.local_file_manager_driver import LocalFileManagerDriver from griptape.loaders.text_loader import TextLoader from griptape.tools import FileManagerTool @@ -55,9 +54,8 @@ def test_load_files_from_disk_with_encoding_failure(self): ) ) - result = file_manager.load_files_from_disk({"values": {"paths": ["../../resources/bitcoin.pdf"]}}) - - assert isinstance(result.value[0], ErrorArtifact) + with pytest.raises(UnicodeDecodeError): + file_manager.load_files_from_disk({"values": {"paths": ["../../resources/bitcoin.pdf"]}}) def test_save_memory_artifacts_to_disk_for_one_artifact(self, temp_dir): memory = defaults.text_task_memory("Memory1")