diff --git a/Makefile b/Makefile index 13410094c..d14c7184d 100644 --- a/Makefile +++ b/Makefile @@ -33,22 +33,22 @@ test/integration: .PHONY: lint lint: ## Lint project. - @poetry run ruff check --fix griptape/ + @poetry run ruff check --fix .PHONY: format format: ## Format project. - @poetry run ruff format . + @poetry run ruff format .PHONY: check check: check/format check/lint check/types check/spell ## Run all checks. .PHONY: check/format check/format: - @poetry run ruff format --check griptape/ + @poetry run ruff format --check .PHONY: check/lint check/lint: - @poetry run ruff check griptape/ + @poetry run ruff check .PHONY: check/types check/types: diff --git a/docs/gen_ref_pages.py b/docs/gen_ref_pages.py index 62fe85b0c..85a3f3127 100644 --- a/docs/gen_ref_pages.py +++ b/docs/gen_ref_pages.py @@ -1,11 +1,12 @@ """Generate the code reference pages and navigation.""" -from textwrap import dedent from pathlib import Path +from textwrap import dedent + import mkdocs_gen_files -def build_reference_docs(): +def build_reference_docs() -> None: nav = mkdocs_gen_files.Nav() for path in sorted(Path("griptape").rglob("*.py")): @@ -37,8 +38,8 @@ def build_reference_docs(): index_file.write( dedent( """ - # Overview - This section of the documentation is dedicated to a reference API of Griptape. + # Overview + This section of the documentation is dedicated to a reference API of Griptape. Here you will find every class, function, and method that is available to you when using the library. """ ) diff --git a/docs/plugins/swagger_ui_plugin.py b/docs/plugins/swagger_ui_plugin.py index 6d5fb52da..499d74cf5 100644 --- a/docs/plugins/swagger_ui_plugin.py +++ b/docs/plugins/swagger_ui_plugin.py @@ -1,4 +1,5 @@ import os +from typing import Any import markdown from jinja2 import Environment, FileSystemLoader, select_autoescape @@ -11,7 +12,7 @@ } -def generate_page_contents(page): +def generate_page_contents(page: Any) -> str: spec_url = config_scheme["spec_url"] tmpl_url = config_scheme["template"] env = Environment(loader=FileSystemLoader("docs/plugins/tmpl"), autoescape=select_autoescape(["html", "xml"])) @@ -23,11 +24,11 @@ def generate_page_contents(page): return tmpl_out -def on_config(config): - print("INFO - swagger-ui plugin ENABLED") +def on_config(config: Any) -> None: + pass -def on_page_read_source(page, config): +def on_page_read_source(page: Any, config: Any) -> Any: index_path = os.path.join(config["docs_dir"], config_scheme["outfile"]) page_path = os.path.join(config["docs_dir"], page.file.src_path) if index_path == page_path: diff --git a/pyproject.toml b/pyproject.toml index 4f9c3b49d..807a2560e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -217,6 +217,7 @@ select = [ "C4", # flake8-comprehensions "ANN", # flake8-annotations "FBT", # flake8-boolean-trap + "PT", # flake8-pytest-style ] ignore = [ "UP007", # non-pep604-annotation @@ -238,12 +239,20 @@ ignore = [ "ANN101", # missing-type-self "ANN102", # missing-type-cls "ANN401", # any-type + "PT011", # pytest-raises-too-broad ] [tool.ruff.lint.pydocstyle] convention = "google" [tool.ruff.lint.per-file-ignores] -"__init__.py" = ["I"] +"__init__.py" = [ + "I" # isort +] +"tests/*" = [ + "ANN001", # missing-type-function-argument + "ANN201", # missing-return-type-undocumented-public-function + "ANN202", # missing-return-type-private-function +] [tool.ruff.lint.flake8-tidy-imports.banned-api] "attr".msg = "The attr module is deprecated, use attrs instead." diff --git a/tests/integration/drivers/vector/test_pgvector_vector_store_driver.py b/tests/integration/drivers/vector/test_pgvector_vector_store_driver.py index e0345cecb..e592fce06 100644 --- a/tests/integration/drivers/vector/test_pgvector_vector_store_driver.py +++ b/tests/integration/drivers/vector/test_pgvector_vector_store_driver.py @@ -1,9 +1,11 @@ import uuid + import pytest +from sqlalchemy import create_engine + from griptape.drivers import PgVectorVectorStoreDriver from tests.mocks.mock_embedding_driver import MockEmbeddingDriver from tests.utils.postgres import can_connect_to_postgres -from sqlalchemy import create_engine @pytest.mark.skipif(not can_connect_to_postgres(), reason="Postgres is not present") @@ -13,11 +15,11 @@ class TestPgVectorVectorStoreDriver: vec1 = [0.1, 0.2, 0.3] vec2 = [0.4, 0.5, 0.6] - @pytest.fixture + @pytest.fixture() def embedding_driver(self): return MockEmbeddingDriver() - @pytest.fixture + @pytest.fixture() def vector_store_driver(self, embedding_driver): driver = PgVectorVectorStoreDriver( connection_string=self.connection_string, embedding_driver=embedding_driver, table_name=self.table_name @@ -28,13 +30,13 @@ def vector_store_driver(self, embedding_driver): return driver def test_initialize_requires_engine_or_connection_string(self, embedding_driver): + driver = PgVectorVectorStoreDriver(embedding_driver=embedding_driver, table_name=self.table_name) with pytest.raises(ValueError): - driver = PgVectorVectorStoreDriver(embedding_driver=embedding_driver, table_name=self.table_name) driver.setup() def test_initialize_accepts_engine(self, embedding_driver): engine = create_engine(self.connection_string) - driver = PgVectorVectorStoreDriver(embedding_driver=embedding_driver, engine=engine, table_name=self.table_name) + driver = PgVectorVectorStoreDriver(embedding_driver=embedding_driver, engine=engine, table_name=self.table_name) # pyright: ignore[reportArgumentType] driver.setup() @@ -86,11 +88,9 @@ def test_can_insert_and_load_entry_with_namespace(self, vector_store_driver): assert result.vector == pytest.approx(self.vec1) def test_can_load_entries(self, vector_store_driver): - """ - Depending on when this test is executed relative to the others, - we don't know exactly how many vectors will be returned. We can - ensure that at least two exist and confirm that those are found. - """ + # Depending on when this test is executed relative to the others, + # we don't know exactly how many vectors will be returned. We can + # ensure that at least two exist and confirm that those are found. vec1_id = vector_store_driver.upsert_vector(self.vec1) vec2_id = vector_store_driver.upsert_vector(self.vec2) @@ -173,8 +173,9 @@ def test_query_returns_vectors_when_requested(self, vector_store_driver): assert results[0].vector == pytest.approx(embedding) def test_can_use_custom_table_name(self, embedding_driver, vector_store_driver): - """This test ensures at least one row exists in the default table before specifying - a custom table name. After inserting another row, we should be able to query only one + """This test ensures at least one row exists in the default table before specifying a custom table name. + + After inserting another row, we should be able to query only one vector from the table, and it should be the vector added to the table with the new name. """ vector_store_driver.upsert_vector(self.vec1) diff --git a/tests/integration/rules/test_rule.py b/tests/integration/rules/test_rule.py index f04996040..a62263c57 100644 --- a/tests/integration/rules/test_rule.py +++ b/tests/integration/rules/test_rule.py @@ -1,14 +1,15 @@ -from tests.utils.structure_tester import StructureTester import pytest +from tests.utils.structure_tester import StructureTester + class TestRule: @pytest.fixture( autouse=True, params=StructureTester.RULE_CAPABLE_PROMPT_DRIVERS, ids=StructureTester.prompt_driver_id_fn ) def structure_tester(self, request): - from griptape.structures import Agent from griptape.rules import Rule + from griptape.structures import Agent agent = Agent(prompt_driver=request.param, rules=[Rule("Your name is Tony.")]) diff --git a/tests/integration/tasks/test_csv_extraction_task.py b/tests/integration/tasks/test_csv_extraction_task.py index 4624431ca..db58b9615 100644 --- a/tests/integration/tasks/test_csv_extraction_task.py +++ b/tests/integration/tasks/test_csv_extraction_task.py @@ -1,6 +1,7 @@ -from tests.utils.structure_tester import StructureTester import pytest +from tests.utils.structure_tester import StructureTester + class TestCsvExtractionTask: @pytest.fixture( @@ -9,9 +10,9 @@ class TestCsvExtractionTask: ids=StructureTester.prompt_driver_id_fn, ) def structure_tester(self, request): - from griptape.tasks import ExtractionTask - from griptape.structures import Agent from griptape.engines import CsvExtractionEngine + from griptape.structures import Agent + from griptape.tasks import ExtractionTask columns = ["Name", "Age", "Address"] diff --git a/tests/integration/tasks/test_json_extraction_task.py b/tests/integration/tasks/test_json_extraction_task.py index fdd7140f3..115f805da 100644 --- a/tests/integration/tasks/test_json_extraction_task.py +++ b/tests/integration/tasks/test_json_extraction_task.py @@ -1,6 +1,7 @@ -from tests.utils.structure_tester import StructureTester import pytest +from tests.utils.structure_tester import StructureTester + class TestJsonExtractionTask: @pytest.fixture( @@ -9,11 +10,12 @@ class TestJsonExtractionTask: ids=StructureTester.prompt_driver_id_fn, ) def structure_tester(self, request): - from griptape.tasks import ExtractionTask - from griptape.structures import Agent - from griptape.engines import JsonExtractionEngine from schema import Schema + from griptape.engines import JsonExtractionEngine + from griptape.structures import Agent + from griptape.tasks import ExtractionTask + # Define some JSON data user_schema = Schema({"users": [{"name": str, "age": int, "location": str}]}).json_schema("UserSchema") diff --git a/tests/integration/tasks/test_prompt_task.py b/tests/integration/tasks/test_prompt_task.py index 6734df678..1d223b4ca 100644 --- a/tests/integration/tasks/test_prompt_task.py +++ b/tests/integration/tasks/test_prompt_task.py @@ -1,6 +1,7 @@ -from tests.utils.structure_tester import StructureTester import pytest +from tests.utils.structure_tester import StructureTester + class TestPromptTask: @pytest.fixture( diff --git a/tests/integration/tasks/test_rag_task.py b/tests/integration/tasks/test_rag_task.py index c0383002c..ce3a9140d 100644 --- a/tests/integration/tasks/test_rag_task.py +++ b/tests/integration/tasks/test_rag_task.py @@ -1,7 +1,8 @@ +import pytest + from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.utils.defaults import rag_engine from tests.utils.structure_tester import StructureTester -import pytest class TestRagTask: @@ -11,10 +12,10 @@ class TestRagTask: ids=StructureTester.prompt_driver_id_fn, ) def structure_tester(self, request): + from griptape.artifacts import TextArtifact + from griptape.drivers import LocalVectorStoreDriver, OpenAiEmbeddingDriver from griptape.structures import Agent from griptape.tasks import RagTask - from griptape.drivers import LocalVectorStoreDriver, OpenAiEmbeddingDriver - from griptape.artifacts import TextArtifact vector_store_driver = LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver()) artifact = TextArtifact("John Doe works as as software engineer at Griptape.") diff --git a/tests/integration/tasks/test_text_summary_task.py b/tests/integration/tasks/test_text_summary_task.py index 9cbf1d905..ff6597ba0 100644 --- a/tests/integration/tasks/test_text_summary_task.py +++ b/tests/integration/tasks/test_text_summary_task.py @@ -1,6 +1,7 @@ -from tests.utils.structure_tester import StructureTester import pytest +from tests.utils.structure_tester import StructureTester + class TestTextSummaryTask: @pytest.fixture( @@ -10,8 +11,8 @@ class TestTextSummaryTask: ) def structure_tester(self, request): from griptape.engines.summary.prompt_summary_engine import PromptSummaryEngine - from griptape.tasks import TextSummaryTask from griptape.structures import Agent + from griptape.tasks import TextSummaryTask agent = Agent(conversation_memory=None, prompt_driver=request.param) agent.add_task(TextSummaryTask(summary_engine=PromptSummaryEngine(prompt_driver=request.param))) @@ -21,17 +22,17 @@ def structure_tester(self, request): def test_summary_task(self, structure_tester): structure_tester.run( """ - Meeting transcriot: - Miguel: Hi Brant, I want to discuss the workstream for our new product launch - Brant: Sure Miguel, is there anything in particular you want to discuss? - Miguel: Yes, I want to talk about how users enter into the product. - Brant: Ok, in that case let me add in Namita. - Namita: Hey everyone - Brant: Hi Namita, Miguel wants to discuss how users enter into the product. - Miguel: its too complicated and we should remove friction. for example, why do I need to fill out additional forms? I also find it difficult to find where to access the product when I first land on the landing page. - Brant: I would also add that I think there are too many steps. - Namita: Ok, I can work on the landing page to make the product more discoverable but brant can you work on the additional forms? - Brant: Yes but I would need to work with James from another team as he needs to unblock the sign up workflow. Miguel can you document any other concerns so that I can discuss with James only once? - Miguel: Sure. + Meeting transcriot: + Miguel: Hi Brant, I want to discuss the workstream for our new product launch + Brant: Sure Miguel, is there anything in particular you want to discuss? + Miguel: Yes, I want to talk about how users enter into the product. + Brant: Ok, in that case let me add in Namita. + Namita: Hey everyone + Brant: Hi Namita, Miguel wants to discuss how users enter into the product. + Miguel: its too complicated and we should remove friction. for example, why do I need to fill out additional forms? I also find it difficult to find where to access the product when I first land on the landing page. + Brant: I would also add that I think there are too many steps. + Namita: Ok, I can work on the landing page to make the product more discoverable but brant can you work on the additional forms? + Brant: Yes but I would need to work with James from another team as he needs to unblock the sign up workflow. Miguel can you document any other concerns so that I can discuss with James only once? + Miguel: Sure. """ ) diff --git a/tests/integration/tasks/test_tool_task.py b/tests/integration/tasks/test_tool_task.py index 1fa48e98d..aee0af110 100644 --- a/tests/integration/tasks/test_tool_task.py +++ b/tests/integration/tasks/test_tool_task.py @@ -1,6 +1,7 @@ -from tests.utils.structure_tester import StructureTester import pytest +from tests.utils.structure_tester import StructureTester + class TestToolTask: @pytest.fixture( diff --git a/tests/integration/tasks/test_toolkit_task.py b/tests/integration/tasks/test_toolkit_task.py index a4d8c30c3..8dfcfdc73 100644 --- a/tests/integration/tasks/test_toolkit_task.py +++ b/tests/integration/tasks/test_toolkit_task.py @@ -1,6 +1,7 @@ -from tests.utils.structure_tester import StructureTester import pytest +from tests.utils.structure_tester import StructureTester + class TestToolkitTask: @pytest.fixture( @@ -10,9 +11,10 @@ class TestToolkitTask: ) def structure_tester(self, request): import os - from griptape.structures import Agent - from griptape.tools import WebScraper, WebSearch, TaskMemoryClient + from griptape.drivers import GoogleWebSearchDriver + from griptape.structures import Agent + from griptape.tools import TaskMemoryClient, WebScraper, WebSearch return StructureTester( Agent( diff --git a/tests/integration/test_code_blocks.py b/tests/integration/test_code_blocks.py index 3a267a529..2da683a2a 100644 --- a/tests/integration/test_code_blocks.py +++ b/tests/integration/test_code_blocks.py @@ -2,8 +2,8 @@ import os import pytest -from tests.utils.code_blocks import get_all_code_blocks, check_py_string +from tests.utils.code_blocks import check_py_string, get_all_code_blocks if "DOCS_ALL_CHANGED_FILES" in os.environ and os.environ["DOCS_ALL_CHANGED_FILES"] != "": docs_all_changed_files = os.environ["DOCS_ALL_CHANGED_FILES"].split() diff --git a/tests/integration/tools/test_calculator.py b/tests/integration/tools/test_calculator.py index 9015c7158..2547b947d 100644 --- a/tests/integration/tools/test_calculator.py +++ b/tests/integration/tools/test_calculator.py @@ -1,6 +1,7 @@ -from tests.utils.structure_tester import StructureTester import pytest +from tests.utils.structure_tester import StructureTester + class TestCalculator: @pytest.fixture( diff --git a/tests/integration/tools/test_file_manager.py b/tests/integration/tools/test_file_manager.py index 462e66470..8a283c6e8 100644 --- a/tests/integration/tools/test_file_manager.py +++ b/tests/integration/tools/test_file_manager.py @@ -1,6 +1,7 @@ -from tests.utils.structure_tester import StructureTester import pytest +from tests.utils.structure_tester import StructureTester + class TestFileManager: @pytest.fixture( diff --git a/tests/integration/tools/test_google_docs_client.py b/tests/integration/tools/test_google_docs_client.py index dfb1eb95b..4d70aac17 100644 --- a/tests/integration/tools/test_google_docs_client.py +++ b/tests/integration/tools/test_google_docs_client.py @@ -1,5 +1,7 @@ -import pytest import os + +import pytest + from tests.utils.structure_tester import StructureTester diff --git a/tests/integration/tools/test_google_drive_client.py b/tests/integration/tools/test_google_drive_client.py index 9bbbacfb5..23ebb1b32 100644 --- a/tests/integration/tools/test_google_drive_client.py +++ b/tests/integration/tools/test_google_drive_client.py @@ -1,5 +1,7 @@ -import pytest import os + +import pytest + from tests.utils.structure_tester import StructureTester diff --git a/tests/mocks/docker/fake_api.py b/tests/mocks/docker/fake_api.py index 3d5a411e5..881093057 100644 --- a/tests/mocks/docker/fake_api.py +++ b/tests/mocks/docker/fake_api.py @@ -150,7 +150,7 @@ def post_fake_create_container(): return status_code, response -def get_fake_inspect_container(tty=False): +def get_fake_inspect_container(*, tty=False): status_code = 200 response = { "Id": FAKE_CONTAINER_ID, @@ -531,7 +531,6 @@ def post_fake_secret(): f"{prefix}/{CURRENT_VERSION}/containers/{FAKE_CONTAINER_ID}/unpause": post_fake_unpause_container, f"{prefix}/{CURRENT_VERSION}/containers/{FAKE_CONTAINER_ID}/restart": post_fake_restart_container, f"{prefix}/{CURRENT_VERSION}/containers/{FAKE_CONTAINER_ID}": delete_fake_remove_container, - f"{prefix}/{CURRENT_VERSION}/images/create": post_fake_image_create, f"{prefix}/{CURRENT_VERSION}/images/{FAKE_IMAGE_ID}": delete_fake_remove_image, f"{prefix}/{CURRENT_VERSION}/images/{FAKE_IMAGE_ID}/get": get_fake_get_image, f"{prefix}/{CURRENT_VERSION}/images/load": post_fake_load_image, @@ -544,20 +543,20 @@ def post_fake_secret(): f"{prefix}/{CURRENT_VERSION}/events": get_fake_events, (f"{prefix}/{CURRENT_VERSION}/volumes", "GET"): get_fake_volume_list, (f"{prefix}/{CURRENT_VERSION}/volumes/create", "POST"): get_fake_volume, - ("{1}/{0}/volumes/{2}".format(CURRENT_VERSION, prefix, FAKE_VOLUME_NAME), "GET"): get_fake_volume, - ("{1}/{0}/volumes/{2}".format(CURRENT_VERSION, prefix, FAKE_VOLUME_NAME), "DELETE"): fake_remove_volume, - ("{1}/{0}/nodes/{2}/update?version=1".format(CURRENT_VERSION, prefix, FAKE_NODE_ID), "POST"): post_fake_update_node, + (f"{prefix}/{CURRENT_VERSION}/volumes/{FAKE_VOLUME_NAME}", "GET"): get_fake_volume, + (f"{prefix}/{CURRENT_VERSION}/volumes/{FAKE_VOLUME_NAME}", "DELETE"): fake_remove_volume, + (f"{prefix}/{CURRENT_VERSION}/nodes/{FAKE_NODE_ID}/update?version=1", "POST"): post_fake_update_node, (f"{prefix}/{CURRENT_VERSION}/swarm/join", "POST"): post_fake_join_swarm, (f"{prefix}/{CURRENT_VERSION}/networks", "GET"): get_fake_network_list, (f"{prefix}/{CURRENT_VERSION}/networks/create", "POST"): post_fake_network, - ("{1}/{0}/networks/{2}".format(CURRENT_VERSION, prefix, FAKE_NETWORK_ID), "GET"): get_fake_network, - ("{1}/{0}/networks/{2}".format(CURRENT_VERSION, prefix, FAKE_NETWORK_ID), "DELETE"): delete_fake_network, + (f"{prefix}/{CURRENT_VERSION}/networks/{FAKE_NETWORK_ID}", "GET"): get_fake_network, + (f"{prefix}/{CURRENT_VERSION}/networks/{FAKE_NETWORK_ID}", "DELETE"): delete_fake_network, ( - "{1}/{0}/networks/{2}/connect".format(CURRENT_VERSION, prefix, FAKE_NETWORK_ID), + f"{prefix}/{CURRENT_VERSION}/networks/{FAKE_NETWORK_ID}/connect", "POST", ): post_fake_network_connect, ( - "{1}/{0}/networks/{2}/disconnect".format(CURRENT_VERSION, prefix, FAKE_NETWORK_ID), + f"{prefix}/{CURRENT_VERSION}/networks/{FAKE_NETWORK_ID}/disconnect", "POST", ): post_fake_network_disconnect, f"{prefix}/{CURRENT_VERSION}/secrets/create": post_fake_secret, diff --git a/tests/mocks/docker/fake_api_client.py b/tests/mocks/docker/fake_api_client.py index 05b06216a..25df7ab83 100644 --- a/tests/mocks/docker/fake_api_client.py +++ b/tests/mocks/docker/fake_api_client.py @@ -1,15 +1,14 @@ import copy +from unittest import mock import docker from docker.constants import DEFAULT_DOCKER_API_VERSION -from unittest import mock + from . import fake_api class CopyReturnMagicMock(mock.MagicMock): - """ - A MagicMock which deep copies every return value. - """ + """A MagicMock which deep copies every return value.""" def _mock_call(self, *args, **kwargs): ret = super()._mock_call(*args, **kwargs) @@ -19,13 +18,11 @@ def _mock_call(self, *args, **kwargs): def make_fake_api_client(overrides=None): - """ - Returns non-complete fake APIClient. + """Returns non-complete fake APIClient. This returns most of the default cases correctly, but most arguments that change behaviour will not work. """ - if overrides is None: overrides = {} api_client = docker.APIClient(version=DEFAULT_DOCKER_API_VERSION) @@ -57,9 +54,7 @@ def make_fake_api_client(overrides=None): def make_fake_client(overrides=None): - """ - Returns a Client with a fake APIClient. - """ + """Returns a Client with a fake APIClient.""" client = docker.DockerClient(version=DEFAULT_DOCKER_API_VERSION) client.api = make_fake_api_client(overrides) return client diff --git a/tests/mocks/invalid_mock_tool/tool.py b/tests/mocks/invalid_mock_tool/tool.py index 91b2f78f7..fc761cae5 100644 --- a/tests/mocks/invalid_mock_tool/tool.py +++ b/tests/mocks/invalid_mock_tool/tool.py @@ -1,5 +1,6 @@ from attrs import define, field -from schema import Schema, Literal +from schema import Literal, Schema + from griptape.tools import BaseTool from griptape.utils.decorators import activity diff --git a/tests/mocks/mock_audio_input_task.py b/tests/mocks/mock_audio_input_task.py index d6a27d968..95b8c88d0 100644 --- a/tests/mocks/mock_audio_input_task.py +++ b/tests/mocks/mock_audio_input_task.py @@ -1,4 +1,5 @@ from attrs import define + from griptape.artifacts import TextArtifact from griptape.tasks.base_audio_input_task import BaseAudioInputTask diff --git a/tests/mocks/mock_embedding_driver.py b/tests/mocks/mock_embedding_driver.py index e21c56308..46d9bf515 100644 --- a/tests/mocks/mock_embedding_driver.py +++ b/tests/mocks/mock_embedding_driver.py @@ -1,4 +1,7 @@ -from attrs import field, define +from __future__ import annotations + +from attrs import define, field + from griptape.drivers import BaseEmbeddingDriver from tests.mocks.mock_tokenizer import MockTokenizer diff --git a/tests/mocks/mock_event_listener_driver.py b/tests/mocks/mock_event_listener_driver.py index 560fb8733..5833dd1c0 100644 --- a/tests/mocks/mock_event_listener_driver.py +++ b/tests/mocks/mock_event_listener_driver.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from attrs import define from griptape.drivers import BaseEventListenerDriver diff --git a/tests/mocks/mock_failing_prompt_driver.py b/tests/mocks/mock_failing_prompt_driver.py index 0dbeb8fda..18895fdc9 100644 --- a/tests/mocks/mock_failing_prompt_driver.py +++ b/tests/mocks/mock_failing_prompt_driver.py @@ -1,12 +1,17 @@ from __future__ import annotations -from collections.abc import Iterator + +from typing import TYPE_CHECKING + from attrs import define from griptape.artifacts import TextArtifact -from griptape.common import PromptStack, Message, TextMessageContent, DeltaMessage, TextDeltaMessageContent +from griptape.common import DeltaMessage, Message, PromptStack, TextDeltaMessageContent, TextMessageContent from griptape.drivers import BasePromptDriver from griptape.tokenizers import BaseTokenizer, OpenAiTokenizer +if TYPE_CHECKING: + from collections.abc import Iterator + @define class MockFailingPromptDriver(BasePromptDriver): diff --git a/tests/mocks/mock_image_generation_driver.py b/tests/mocks/mock_image_generation_driver.py index de94771e2..10de11071 100644 --- a/tests/mocks/mock_image_generation_driver.py +++ b/tests/mocks/mock_image_generation_driver.py @@ -1,5 +1,9 @@ +from __future__ import annotations + from typing import Optional + from attrs import define + from griptape.artifacts import ImageArtifact from griptape.drivers.image_generation.base_image_generation_driver import BaseImageGenerationDriver diff --git a/tests/mocks/mock_image_generation_task.py b/tests/mocks/mock_image_generation_task.py index 1c79b42a9..b55c5c995 100644 --- a/tests/mocks/mock_image_generation_task.py +++ b/tests/mocks/mock_image_generation_task.py @@ -13,7 +13,7 @@ def input(self) -> TextArtifact: return self._input @input.setter - def input(self, value: str): + def input(self, value: str) -> None: self._input = TextArtifact(value) def run(self) -> ImageArtifact: diff --git a/tests/mocks/mock_image_query_driver.py b/tests/mocks/mock_image_query_driver.py index d3bec164f..8f8cc888c 100644 --- a/tests/mocks/mock_image_query_driver.py +++ b/tests/mocks/mock_image_query_driver.py @@ -1,8 +1,11 @@ +from __future__ import annotations + from typing import Optional + from attrs import define + from griptape.artifacts import ImageArtifact, TextArtifact from griptape.drivers import BaseImageQueryDriver -from griptape.drivers.image_generation.base_image_generation_driver import BaseImageGenerationDriver @define diff --git a/tests/mocks/mock_multi_text_input_task.py b/tests/mocks/mock_multi_text_input_task.py index 7ab5aedf9..be00bbf65 100644 --- a/tests/mocks/mock_multi_text_input_task.py +++ b/tests/mocks/mock_multi_text_input_task.py @@ -1,4 +1,5 @@ from attrs import define + from griptape.artifacts import TextArtifact from griptape.tasks import BaseMultiTextInputTask diff --git a/tests/mocks/mock_prompt_driver.py b/tests/mocks/mock_prompt_driver.py index 4786b78a6..5a23dd8a2 100644 --- a/tests/mocks/mock_prompt_driver.py +++ b/tests/mocks/mock_prompt_driver.py @@ -1,17 +1,19 @@ from __future__ import annotations -from collections.abc import Iterator -from typing import Callable +from typing import TYPE_CHECKING, Callable from attrs import define, field from griptape.artifacts import TextArtifact -from griptape.common import PromptStack, Message, DeltaMessage, TextMessageContent, TextDeltaMessageContent +from griptape.common import DeltaMessage, Message, PromptStack, TextDeltaMessageContent, TextMessageContent from griptape.drivers import BasePromptDriver -from griptape.tokenizers import BaseTokenizer - from tests.mocks.mock_tokenizer import MockTokenizer +if TYPE_CHECKING: + from collections.abc import Iterator + + from griptape.tokenizers import BaseTokenizer + @define class MockPromptDriver(BasePromptDriver): diff --git a/tests/mocks/mock_serializable.py b/tests/mocks/mock_serializable.py index b02c071aa..b40ae25b4 100644 --- a/tests/mocks/mock_serializable.py +++ b/tests/mocks/mock_serializable.py @@ -1,5 +1,9 @@ -from attrs import define, field +from __future__ import annotations + from typing import Optional + +from attrs import define, field + from griptape.mixins import SerializableMixin @@ -13,4 +17,6 @@ class NestedMockSerializable(SerializableMixin): bar: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) baz: Optional[list[int]] = field(default=None, kw_only=True, metadata={"serializable": True}) secret: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False}) - nested: Optional[NestedMockSerializable] = field(default=None, kw_only=True, metadata={"serializable": True}) + nested: Optional[MockSerializable.NestedMockSerializable] = field( + default=None, kw_only=True, metadata={"serializable": True} + ) diff --git a/tests/mocks/mock_structure_config.py b/tests/mocks/mock_structure_config.py index 8309f541b..3f95288f4 100644 --- a/tests/mocks/mock_structure_config.py +++ b/tests/mocks/mock_structure_config.py @@ -1,9 +1,10 @@ -from attrs import define, field, Factory +from attrs import Factory, define, field + from griptape.config import StructureConfig +from tests.mocks.mock_embedding_driver import MockEmbeddingDriver from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver from tests.mocks.mock_image_query_driver import MockImageQueryDriver from tests.mocks.mock_prompt_driver import MockPromptDriver -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver @define diff --git a/tests/mocks/mock_task.py b/tests/mocks/mock_task.py index 42595f6eb..81aa03713 100644 --- a/tests/mocks/mock_task.py +++ b/tests/mocks/mock_task.py @@ -1,5 +1,6 @@ from attrs import define, field -from griptape.artifacts import TextArtifact, BaseArtifact + +from griptape.artifacts import BaseArtifact, TextArtifact from griptape.tasks import BaseTask diff --git a/tests/mocks/mock_text_input_task.py b/tests/mocks/mock_text_input_task.py index 930c77e74..f1439bd42 100644 --- a/tests/mocks/mock_text_input_task.py +++ b/tests/mocks/mock_text_input_task.py @@ -1,4 +1,5 @@ from attrs import define + from griptape.artifacts import TextArtifact from griptape.tasks import BaseTextInputTask diff --git a/tests/mocks/mock_tokenizer.py b/tests/mocks/mock_tokenizer.py index eff103e99..b16332ce0 100644 --- a/tests/mocks/mock_tokenizer.py +++ b/tests/mocks/mock_tokenizer.py @@ -1,5 +1,7 @@ from __future__ import annotations + from attrs import define + from griptape.tokenizers import BaseTokenizer diff --git a/tests/mocks/mock_tool/tool.py b/tests/mocks/mock_tool/tool.py index e79023a6e..7d09f391e 100644 --- a/tests/mocks/mock_tool/tool.py +++ b/tests/mocks/mock_tool/tool.py @@ -1,6 +1,7 @@ from attrs import define, field -from schema import Schema, Literal -from griptape.artifacts import TextArtifact, ErrorArtifact, BaseArtifact, ListArtifact +from schema import Literal, Schema + +from griptape.artifacts import BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact from griptape.tools import BaseTool from griptape.utils.decorators import activity @@ -49,7 +50,7 @@ def test_str_output(self, value: dict) -> str: @activity(config={"description": "test description"}) def test_no_schema(self, value: dict) -> str: - return f"no schema" + return "no schema" @activity(config={"description": "test description"}) def test_list_output(self, value: dict) -> ListArtifact: diff --git a/tests/unit/artifacts/test_action_artifact.py b/tests/unit/artifacts/test_action_artifact.py index e415bbdaf..2530ed8c3 100644 --- a/tests/unit/artifacts/test_action_artifact.py +++ b/tests/unit/artifacts/test_action_artifact.py @@ -1,7 +1,9 @@ import json + import pytest -from griptape.common import ToolAction + from griptape.artifacts import ActionArtifact, BaseArtifact +from griptape.common import ToolAction class TestActionArtifact: @@ -11,7 +13,7 @@ def action(self) -> ToolAction: def test___add__(self, action): with pytest.raises(NotImplementedError): - result = ActionArtifact(action) + ActionArtifact(action) + ActionArtifact(action) + ActionArtifact(action) def test_to_text(self, action): assert ActionArtifact(action).to_text() == json.dumps(action.to_dict()) diff --git a/tests/unit/artifacts/test_audio_artifact.py b/tests/unit/artifacts/test_audio_artifact.py index 93ea816e4..6d44c05b3 100644 --- a/tests/unit/artifacts/test_audio_artifact.py +++ b/tests/unit/artifacts/test_audio_artifact.py @@ -1,9 +1,10 @@ import pytest + from griptape.artifacts import AudioArtifact, BaseArtifact class TestAudioArtifact: - @pytest.fixture + @pytest.fixture() def audio_artifact(self): return AudioArtifact(value=b"some binary audio data", format="pcm", model="provider/model", prompt="two words") diff --git a/tests/unit/artifacts/test_base_artifact.py b/tests/unit/artifacts/test_base_artifact.py index a7d7acaaf..6cf8f4466 100644 --- a/tests/unit/artifacts/test_base_artifact.py +++ b/tests/unit/artifacts/test_base_artifact.py @@ -1,12 +1,13 @@ import pytest + from griptape.artifacts import ( BaseArtifact, - TextArtifact, + BlobArtifact, ErrorArtifact, + ImageArtifact, InfoArtifact, ListArtifact, - BlobArtifact, - ImageArtifact, + TextArtifact, ) diff --git a/tests/unit/artifacts/test_base_media_artifact.py b/tests/unit/artifacts/test_base_media_artifact.py index 2829a1e2f..c85d070fe 100644 --- a/tests/unit/artifacts/test_base_media_artifact.py +++ b/tests/unit/artifacts/test_base_media_artifact.py @@ -1,5 +1,4 @@ import pytest - from attrs import define from griptape.artifacts import MediaArtifact @@ -10,7 +9,7 @@ class TestMediaArtifact: class ImaginaryMediaArtifact(MediaArtifact): media_type: str = "imagination" - @pytest.fixture + @pytest.fixture() def media_artifact(self): return self.ImaginaryMediaArtifact(value=b"some binary dream data", format="dream") diff --git a/tests/unit/artifacts/test_blob_artifact.py b/tests/unit/artifacts/test_blob_artifact.py index 08844d241..a50a673f4 100644 --- a/tests/unit/artifacts/test_blob_artifact.py +++ b/tests/unit/artifacts/test_blob_artifact.py @@ -1,6 +1,8 @@ import base64 + import pytest -from griptape.artifacts import BlobArtifact, BaseArtifact + +from griptape.artifacts import BaseArtifact, BlobArtifact class TestBlobArtifact: diff --git a/tests/unit/artifacts/test_boolean_artifact.py b/tests/unit/artifacts/test_boolean_artifact.py index bcad67673..57bbf1662 100644 --- a/tests/unit/artifacts/test_boolean_artifact.py +++ b/tests/unit/artifacts/test_boolean_artifact.py @@ -1,4 +1,6 @@ +# ruff: noqa: FBT003 import pytest + from griptape.artifacts import BooleanArtifact @@ -13,14 +15,14 @@ def test_parse_bool(self): BooleanArtifact.parse_bool("foo") with pytest.raises(ValueError): - BooleanArtifact.parse_bool(None) + BooleanArtifact.parse_bool(None) # pyright: ignore[reportArgumentType] assert BooleanArtifact.parse_bool(True).value is True assert BooleanArtifact.parse_bool(False).value is False def test_add(self): with pytest.raises(ValueError): - BooleanArtifact(True) + BooleanArtifact(True) + BooleanArtifact(True) + BooleanArtifact(True) # pyright: ignore[reportUnusedExpression] def test_value_type_conversion(self): assert BooleanArtifact(1).value is True @@ -31,5 +33,5 @@ def test_value_type_conversion(self): assert BooleanArtifact("false").value is True assert BooleanArtifact([1]).value is True assert BooleanArtifact([]).value is False - assert BooleanArtifact(False) == False - assert BooleanArtifact(True) == True + assert BooleanArtifact(False).value is False + assert BooleanArtifact(True).value is True diff --git a/tests/unit/artifacts/test_image_artifact.py b/tests/unit/artifacts/test_image_artifact.py index 687397260..a722ebd91 100644 --- a/tests/unit/artifacts/test_image_artifact.py +++ b/tests/unit/artifacts/test_image_artifact.py @@ -1,9 +1,10 @@ import pytest -from griptape.artifacts import ImageArtifact, BaseArtifact + +from griptape.artifacts import BaseArtifact, ImageArtifact class TestImageArtifact: - @pytest.fixture + @pytest.fixture() def image_artifact(self): return ImageArtifact( value=b"some binary png image data", diff --git a/tests/unit/artifacts/test_list_artifact.py b/tests/unit/artifacts/test_list_artifact.py index 044ca8ed5..06d234645 100644 --- a/tests/unit/artifacts/test_list_artifact.py +++ b/tests/unit/artifacts/test_list_artifact.py @@ -1,5 +1,6 @@ import pytest -from griptape.artifacts import ListArtifact, TextArtifact, BlobArtifact, CsvRowArtifact + +from griptape.artifacts import BlobArtifact, CsvRowArtifact, ListArtifact, TextArtifact class TestListArtifact: diff --git a/tests/unit/artifacts/test_text_artifact.py b/tests/unit/artifacts/test_text_artifact.py index 6ea2c6697..067da0912 100644 --- a/tests/unit/artifacts/test_text_artifact.py +++ b/tests/unit/artifacts/test_text_artifact.py @@ -1,6 +1,8 @@ import json + import pytest -from griptape.artifacts import TextArtifact, BaseArtifact + +from griptape.artifacts import BaseArtifact, TextArtifact from griptape.tokenizers import OpenAiTokenizer from tests.mocks.mock_embedding_driver import MockEmbeddingDriver diff --git a/tests/unit/chunkers/test_markdown_chunker.py b/tests/unit/chunkers/test_markdown_chunker.py index 08709c092..30db64611 100644 --- a/tests/unit/chunkers/test_markdown_chunker.py +++ b/tests/unit/chunkers/test_markdown_chunker.py @@ -1,4 +1,5 @@ import pytest + from griptape.chunkers import MarkdownChunker from tests.unit.chunkers.test_text_chunker import gen_paragraph @@ -6,7 +7,7 @@ class TestTextChunker: - @pytest.fixture + @pytest.fixture() def chunker(self): return MarkdownChunker(max_tokens=MAX_TOKENS) diff --git a/tests/unit/chunkers/test_pdf_chunker.py b/tests/unit/chunkers/test_pdf_chunker.py index 605c2f6e6..dc072ca36 100644 --- a/tests/unit/chunkers/test_pdf_chunker.py +++ b/tests/unit/chunkers/test_pdf_chunker.py @@ -1,13 +1,15 @@ import os + import pytest from pypdf import PdfReader + from griptape.chunkers import PdfChunker MAX_TOKENS = 500 class TestPdfChunker: - @pytest.fixture + @pytest.fixture() def chunker(self): return PdfChunker(max_tokens=MAX_TOKENS) diff --git a/tests/unit/chunkers/test_text_chunker.py b/tests/unit/chunkers/test_text_chunker.py index 243b287e1..d9d01f16c 100644 --- a/tests/unit/chunkers/test_text_chunker.py +++ b/tests/unit/chunkers/test_text_chunker.py @@ -1,4 +1,5 @@ import pytest + from griptape.artifacts import TextArtifact from griptape.chunkers import TextChunker from tests.unit.chunkers.utils import gen_paragraph @@ -7,7 +8,7 @@ class TestTextChunker: - @pytest.fixture + @pytest.fixture() def chunker(self): return TextChunker(max_tokens=MAX_TOKENS) diff --git a/tests/unit/chunkers/utils.py b/tests/unit/chunkers/utils.py index 80335e978..b9e7b8539 100644 --- a/tests/unit/chunkers/utils.py +++ b/tests/unit/chunkers/utils.py @@ -5,7 +5,9 @@ def gen_paragraph(max_tokens: int, tokenizer: BaseTokenizer, sentence_separator: all_text = "" word = "foo" index = 0 - add_word = lambda base, w, i: sentence_separator.join([base, f"{w}-{i}"]) + + def add_word(base, w, i): + return sentence_separator.join([base, f"{w}-{i}"]) while max_tokens >= tokenizer.count_tokens(add_word(all_text, word, index)): all_text = f"{word}-{index}" if all_text == "" else add_word(all_text, word, index) diff --git a/tests/unit/common/contents/test_action_call_message_content.py b/tests/unit/common/contents/test_action_call_message_content.py index 2e2e69d27..d6c3f438f 100644 --- a/tests/unit/common/contents/test_action_call_message_content.py +++ b/tests/unit/common/contents/test_action_call_message_content.py @@ -1,6 +1,7 @@ import pytest + from griptape.artifacts.action_artifact import ActionArtifact -from griptape.common import ActionCallMessageContent, ActionCallDeltaMessageContent, ToolAction +from griptape.common import ActionCallDeltaMessageContent, ActionCallMessageContent, ToolAction class TestActionCallMessageContent: diff --git a/tests/unit/common/contents/test_action_result_message_content.py b/tests/unit/common/contents/test_action_result_message_content.py index b1bcc356d..c5eed60d9 100644 --- a/tests/unit/common/contents/test_action_result_message_content.py +++ b/tests/unit/common/contents/test_action_result_message_content.py @@ -1,4 +1,5 @@ import pytest + from griptape.artifacts.text_artifact import TextArtifact from griptape.common import ActionResultMessageContent, ToolAction diff --git a/tests/unit/common/contents/test_image_message_content.py b/tests/unit/common/contents/test_image_message_content.py index b6c1b4c4f..ff8dbe59d 100644 --- a/tests/unit/common/contents/test_image_message_content.py +++ b/tests/unit/common/contents/test_image_message_content.py @@ -1,4 +1,5 @@ import pytest + from griptape.artifacts.image_artifact import ImageArtifact from griptape.common import ImageMessageContent diff --git a/tests/unit/common/contents/test_text_message_content.py b/tests/unit/common/contents/test_text_message_content.py index 01a3c0fd4..eab9eb718 100644 --- a/tests/unit/common/contents/test_text_message_content.py +++ b/tests/unit/common/contents/test_text_message_content.py @@ -1,5 +1,5 @@ from griptape.artifacts.text_artifact import TextArtifact -from griptape.common import TextMessageContent, TextDeltaMessageContent +from griptape.common import TextDeltaMessageContent, TextMessageContent class TestTextMessageContent: diff --git a/tests/unit/common/test_action.py b/tests/unit/common/test_action.py index db5284839..8fcf09a57 100644 --- a/tests/unit/common/test_action.py +++ b/tests/unit/common/test_action.py @@ -1,5 +1,7 @@ -import pytest import json + +import pytest + from griptape.common import ToolAction diff --git a/tests/unit/common/test_prompt_stack.py b/tests/unit/common/test_prompt_stack.py index 83a16e140..e69fe710d 100644 --- a/tests/unit/common/test_prompt_stack.py +++ b/tests/unit/common/test_prompt_stack.py @@ -1,13 +1,18 @@ import pytest -from griptape.artifacts import ImageArtifact, ListArtifact, TextArtifact, ActionArtifact -from griptape.common import ImageMessageContent, PromptStack, TextMessageContent -from griptape.common import ActionCallMessageContent -from griptape.common import ActionResultMessageContent, ToolAction +from griptape.artifacts import ActionArtifact, ImageArtifact, ListArtifact, TextArtifact +from griptape.common import ( + ActionCallMessageContent, + ActionResultMessageContent, + ImageMessageContent, + PromptStack, + TextMessageContent, + ToolAction, +) class TestPromptStack: - @pytest.fixture + @pytest.fixture() def prompt_stack(self): return PromptStack() diff --git a/tests/unit/config/test_amazon_bedrock_structure_config.py b/tests/unit/config/test_amazon_bedrock_structure_config.py index 824e6ce11..afe9b3720 100644 --- a/tests/unit/config/test_amazon_bedrock_structure_config.py +++ b/tests/unit/config/test_amazon_bedrock_structure_config.py @@ -1,20 +1,21 @@ import boto3 -from pytest import fixture +import pytest + from griptape.config import AmazonBedrockStructureConfig from tests.utils.aws import mock_aws_credentials class TestAmazonBedrockStructureConfig: - @fixture(autouse=True) - def run_before_and_after_tests(self): + @pytest.fixture(autouse=True) + def _run_before_and_after_tests(self): mock_aws_credentials() - @fixture + @pytest.fixture() def config(self): mock_aws_credentials() return AmazonBedrockStructureConfig() - @fixture + @pytest.fixture() def config_with_values(self): return AmazonBedrockStructureConfig( session=boto3.Session( diff --git a/tests/unit/config/test_anthropic_structure_config.py b/tests/unit/config/test_anthropic_structure_config.py index b41309a83..05519fa5e 100644 --- a/tests/unit/config/test_anthropic_structure_config.py +++ b/tests/unit/config/test_anthropic_structure_config.py @@ -1,14 +1,15 @@ -from pytest import fixture +import pytest + from griptape.config import AnthropicStructureConfig class TestAnthropicStructureConfig: - @fixture(autouse=True) - def mock_anthropic(self, mocker): + @pytest.fixture(autouse=True) + def _mock_anthropic(self, mocker): mocker.patch("anthropic.Anthropic") mocker.patch("voyageai.Client") - @fixture + @pytest.fixture() def config(self): return AnthropicStructureConfig() diff --git a/tests/unit/config/test_azure_openai_structure_config.py b/tests/unit/config/test_azure_openai_structure_config.py index 58d557fb9..dcdc3a1dc 100644 --- a/tests/unit/config/test_azure_openai_structure_config.py +++ b/tests/unit/config/test_azure_openai_structure_config.py @@ -1,13 +1,14 @@ -from pytest import fixture +import pytest + from griptape.config import AzureOpenAiStructureConfig class TestAzureOpenAiStructureConfig: - @fixture(autouse=True) + @pytest.fixture(autouse=True) def mock_openai(self, mocker): return mocker.patch("openai.AzureOpenAI") - @fixture + @pytest.fixture() def config(self): return AzureOpenAiStructureConfig( azure_endpoint="http://localhost:8080", diff --git a/tests/unit/config/test_cohere_structure_config.py b/tests/unit/config/test_cohere_structure_config.py index 44ed3e4d8..113a589ec 100644 --- a/tests/unit/config/test_cohere_structure_config.py +++ b/tests/unit/config/test_cohere_structure_config.py @@ -1,9 +1,10 @@ -from pytest import fixture +import pytest + from griptape.config import CohereStructureConfig class TestCohereStructureConfig: - @fixture + @pytest.fixture() def config(self): return CohereStructureConfig(api_key="api_key") diff --git a/tests/unit/config/test_google_structure_config.py b/tests/unit/config/test_google_structure_config.py index 469493e2c..e193cc983 100644 --- a/tests/unit/config/test_google_structure_config.py +++ b/tests/unit/config/test_google_structure_config.py @@ -1,13 +1,14 @@ -from pytest import fixture +import pytest + from griptape.config import GoogleStructureConfig class TestGoogleStructureConfig: - @fixture(autouse=True) + @pytest.fixture(autouse=True) def mock_openai(self, mocker): return mocker.patch("google.generativeai.GenerativeModel") - @fixture + @pytest.fixture() def config(self): return GoogleStructureConfig() diff --git a/tests/unit/config/test_openai_structure_config.py b/tests/unit/config/test_openai_structure_config.py index 19321006f..8969e0ad0 100644 --- a/tests/unit/config/test_openai_structure_config.py +++ b/tests/unit/config/test_openai_structure_config.py @@ -1,13 +1,14 @@ -from pytest import fixture +import pytest + from griptape.config import OpenAiStructureConfig class TestOpenAiStructureConfig: - @fixture(autouse=True) + @pytest.fixture(autouse=True) def mock_openai(self, mocker): return mocker.patch("openai.OpenAI") - @fixture + @pytest.fixture() def config(self): return OpenAiStructureConfig() diff --git a/tests/unit/config/test_structure_config.py b/tests/unit/config/test_structure_config.py index 27aaf81c4..96a68628f 100644 --- a/tests/unit/config/test_structure_config.py +++ b/tests/unit/config/test_structure_config.py @@ -1,9 +1,10 @@ -from pytest import fixture +import pytest + from griptape.config import StructureConfig class TestStructureConfig: - @fixture + @pytest.fixture() def config(self): return StructureConfig() diff --git a/tests/unit/drivers/embedding/test_amazon_bedrock_cohere_embedding_driver.py b/tests/unit/drivers/embedding/test_amazon_bedrock_cohere_embedding_driver.py index ba8edad90..5644227c9 100644 --- a/tests/unit/drivers/embedding/test_amazon_bedrock_cohere_embedding_driver.py +++ b/tests/unit/drivers/embedding/test_amazon_bedrock_cohere_embedding_driver.py @@ -1,11 +1,13 @@ -import pytest from unittest import mock + +import pytest + from griptape.drivers import AmazonBedrockCohereEmbeddingDriver class TestAmazonBedrockCohereEmbeddingDriver: @pytest.fixture(autouse=True) - def mock_session(self, mocker): + def _mock_session(self, mocker): fake_embeddings = '{"embeddings": [[0, 1, 0]] }' mock_session_class = mocker.patch("boto3.Session") diff --git a/tests/unit/drivers/embedding/test_amazon_bedrock_titan_embedding_driver.py b/tests/unit/drivers/embedding/test_amazon_bedrock_titan_embedding_driver.py index df4455c24..4470cf62b 100644 --- a/tests/unit/drivers/embedding/test_amazon_bedrock_titan_embedding_driver.py +++ b/tests/unit/drivers/embedding/test_amazon_bedrock_titan_embedding_driver.py @@ -1,11 +1,13 @@ -import pytest from unittest import mock + +import pytest + from griptape.drivers import AmazonBedrockTitanEmbeddingDriver class TestAmazonBedrockTitanEmbeddingDriver: @pytest.fixture(autouse=True) - def mock_session(self, mocker): + def _mock_session(self, mocker): fake_embeddings = '{"embedding": [0, 1, 0]}' mock_session_class = mocker.patch("boto3.Session") diff --git a/tests/unit/drivers/embedding/test_azure_openai_embedding_driver.py b/tests/unit/drivers/embedding/test_azure_openai_embedding_driver.py index d2c20c043..7f604434b 100644 --- a/tests/unit/drivers/embedding/test_azure_openai_embedding_driver.py +++ b/tests/unit/drivers/embedding/test_azure_openai_embedding_driver.py @@ -1,5 +1,7 @@ from unittest.mock import Mock + import pytest + from griptape.drivers import AzureOpenAiEmbeddingDriver @@ -17,7 +19,7 @@ def mock_openai(self, mocker): return mock_chat_create - @pytest.fixture + @pytest.fixture() def driver(self): return AzureOpenAiEmbeddingDriver(azure_endpoint="foobar", model="gpt-4", azure_deployment="foobar") diff --git a/tests/unit/drivers/embedding/test_base_embedding_driver.py b/tests/unit/drivers/embedding/test_base_embedding_driver.py index 24b07778d..4413468b6 100644 --- a/tests/unit/drivers/embedding/test_base_embedding_driver.py +++ b/tests/unit/drivers/embedding/test_base_embedding_driver.py @@ -1,11 +1,13 @@ +from unittest.mock import patch + import pytest + from griptape.artifacts import TextArtifact from tests.mocks.mock_embedding_driver import MockEmbeddingDriver -from unittest.mock import patch class TestBaseEmbeddingDriver: - @pytest.fixture + @pytest.fixture() def driver(self): return MockEmbeddingDriver() diff --git a/tests/unit/drivers/embedding/test_cohere_embedding_driver.py b/tests/unit/drivers/embedding/test_cohere_embedding_driver.py index af6a5576d..024e0e74c 100644 --- a/tests/unit/drivers/embedding/test_cohere_embedding_driver.py +++ b/tests/unit/drivers/embedding/test_cohere_embedding_driver.py @@ -1,5 +1,7 @@ from unittest.mock import Mock + import pytest + from griptape.drivers import CohereEmbeddingDriver diff --git a/tests/unit/drivers/embedding/test_dummy_embedding_driver.py b/tests/unit/drivers/embedding/test_dummy_embedding_driver.py index 35f81bf77..af56ce6ac 100644 --- a/tests/unit/drivers/embedding/test_dummy_embedding_driver.py +++ b/tests/unit/drivers/embedding/test_dummy_embedding_driver.py @@ -1,11 +1,11 @@ -from griptape.drivers import DummyEmbeddingDriver import pytest +from griptape.drivers import DummyEmbeddingDriver from griptape.exceptions import DummyException class TestDummyEmbeddingDriver: - @pytest.fixture + @pytest.fixture() def embedding_driver(self): return DummyEmbeddingDriver() diff --git a/tests/unit/drivers/embedding/test_google_embedding_driver.py b/tests/unit/drivers/embedding/test_google_embedding_driver.py index 324b95ddb..9e756491e 100644 --- a/tests/unit/drivers/embedding/test_google_embedding_driver.py +++ b/tests/unit/drivers/embedding/test_google_embedding_driver.py @@ -1,5 +1,7 @@ from unittest.mock import MagicMock + import pytest + from griptape.drivers import GoogleEmbeddingDriver diff --git a/tests/unit/drivers/embedding/test_ollama_embedding_driver.py b/tests/unit/drivers/embedding/test_ollama_embedding_driver.py index 3886ab874..6dda23930 100644 --- a/tests/unit/drivers/embedding/test_ollama_embedding_driver.py +++ b/tests/unit/drivers/embedding/test_ollama_embedding_driver.py @@ -1,4 +1,5 @@ import pytest + from griptape.drivers import OllamaEmbeddingDriver diff --git a/tests/unit/drivers/embedding/test_openai_embedding_driver.py b/tests/unit/drivers/embedding/test_openai_embedding_driver.py index fd30dd30f..78879345a 100644 --- a/tests/unit/drivers/embedding/test_openai_embedding_driver.py +++ b/tests/unit/drivers/embedding/test_openai_embedding_driver.py @@ -1,5 +1,7 @@ -from unittest.mock import Mock, MagicMock +from unittest.mock import Mock + import pytest + from griptape.drivers import OpenAiEmbeddingDriver from griptape.tokenizers import OpenAiTokenizer diff --git a/tests/unit/drivers/embedding/test_sagemaker_jumpstart_embedding_driver.py b/tests/unit/drivers/embedding/test_sagemaker_jumpstart_embedding_driver.py index 268b47c54..09ec8ec87 100644 --- a/tests/unit/drivers/embedding/test_sagemaker_jumpstart_embedding_driver.py +++ b/tests/unit/drivers/embedding/test_sagemaker_jumpstart_embedding_driver.py @@ -1,5 +1,7 @@ -import pytest from unittest import mock + +import pytest + from griptape.drivers import AmazonSageMakerJumpstartEmbeddingDriver from griptape.tokenizers.openai_tokenizer import OpenAiTokenizer @@ -41,19 +43,17 @@ def test_try_embed_chunk(self, mock_client): ).try_embed_chunk("foobar") == [0, 2, 0] mock_client.get().read.return_value = b'{"embedding": []}' - with pytest.raises(ValueError) as e: + with pytest.raises(ValueError, match="model response is empty"): assert AmazonSageMakerJumpstartEmbeddingDriver( endpoint="test-endpoint", model="test-model", tokenizer=OpenAiTokenizer(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL), ).try_embed_chunk("foobar") == [0, 2, 0] - assert str(e) == "model response is empty" mock_client.get().read.return_value = b"{}" - with pytest.raises(ValueError) as e: + with pytest.raises(ValueError, match="invalid response from model"): assert AmazonSageMakerJumpstartEmbeddingDriver( endpoint="test-endpoint", model="test-model", tokenizer=OpenAiTokenizer(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL), ).try_embed_chunk("foobar") == [0, 2, 0] - assert str(e) == "invalid response from model" diff --git a/tests/unit/drivers/embedding/test_voyageai_embedding_driver.py b/tests/unit/drivers/embedding/test_voyageai_embedding_driver.py index 69db0213c..5371f8db0 100644 --- a/tests/unit/drivers/embedding/test_voyageai_embedding_driver.py +++ b/tests/unit/drivers/embedding/test_voyageai_embedding_driver.py @@ -1,5 +1,7 @@ -import pytest from unittest.mock import Mock + +import pytest + from griptape.drivers import VoyageAiEmbeddingDriver diff --git a/tests/unit/drivers/event_listener/test_amazon_sqs_event_listener_driver.py b/tests/unit/drivers/event_listener/test_amazon_sqs_event_listener_driver.py index 706831d67..10ef0354c 100644 --- a/tests/unit/drivers/event_listener/test_amazon_sqs_event_listener_driver.py +++ b/tests/unit/drivers/event_listener/test_amazon_sqs_event_listener_driver.py @@ -1,17 +1,18 @@ -from pytest import fixture -from moto import mock_sqs import boto3 -from tests.mocks.mock_event import MockEvent +import pytest +from moto import mock_sqs + from griptape.drivers.event_listener.amazon_sqs_event_listener_driver import AmazonSqsEventListenerDriver +from tests.mocks.mock_event import MockEvent from tests.utils.aws import mock_aws_credentials class TestAmazonSqsEventListenerDriver: - @fixture() - def run_before_and_after_tests(self): + @pytest.fixture() + def _run_before_and_after_tests(self): mock_aws_credentials() - @fixture() + @pytest.fixture() def driver(self): mock = mock_sqs() mock.start() diff --git a/tests/unit/drivers/event_listener/test_aws_iot_event_listener_driver.py b/tests/unit/drivers/event_listener/test_aws_iot_event_listener_driver.py index 9a5fe9ec0..b597a5332 100644 --- a/tests/unit/drivers/event_listener/test_aws_iot_event_listener_driver.py +++ b/tests/unit/drivers/event_listener/test_aws_iot_event_listener_driver.py @@ -1,18 +1,19 @@ -from pytest import fixture -from moto import mock_iotdata import boto3 -from tests.mocks.mock_event import MockEvent +import pytest +from moto import mock_iotdata + from griptape.drivers.event_listener.aws_iot_core_event_listener_driver import AwsIotCoreEventListenerDriver +from tests.mocks.mock_event import MockEvent from tests.utils.aws import mock_aws_credentials @mock_iotdata class TestAwsIotCoreEventListenerDriver: - @fixture(autouse=True) - def run_before_and_after_tests(self): + @pytest.fixture(autouse=True) + def _run_before_and_after_tests(self): mock_aws_credentials() - @fixture() + @pytest.fixture() def driver(self): return AwsIotCoreEventListenerDriver( iot_endpoint="foo bar", topic="fizz buzz", session=boto3.Session(region_name="us-east-1") diff --git a/tests/unit/drivers/event_listener/test_base_event_listener_driver.py b/tests/unit/drivers/event_listener/test_base_event_listener_driver.py index 383c0be89..04cfef34b 100644 --- a/tests/unit/drivers/event_listener/test_base_event_listener_driver.py +++ b/tests/unit/drivers/event_listener/test_base_event_listener_driver.py @@ -1,4 +1,5 @@ from unittest.mock import MagicMock + from tests.mocks.mock_event import MockEvent from tests.mocks.mock_event_listener_driver import MockEventListenerDriver diff --git a/tests/unit/drivers/event_listener/test_griptape_cloud_event_listener_driver.py b/tests/unit/drivers/event_listener/test_griptape_cloud_event_listener_driver.py index d27f09ec8..b651841ca 100644 --- a/tests/unit/drivers/event_listener/test_griptape_cloud_event_listener_driver.py +++ b/tests/unit/drivers/event_listener/test_griptape_cloud_event_listener_driver.py @@ -2,14 +2,13 @@ from unittest.mock import Mock import pytest -from pytest import fixture from griptape.drivers.event_listener.griptape_cloud_event_listener_driver import GriptapeCloudEventListenerDriver from tests.mocks.mock_event import MockEvent class TestGriptapeCloudEventListenerDriver: - @fixture(autouse=True) + @pytest.fixture(autouse=True) def mock_post(self, mocker): data = {"data": {"id": "test"}} @@ -18,7 +17,7 @@ def mock_post(self, mocker): return mock_post - @fixture() + @pytest.fixture() def driver(self): os.environ["GT_CLOUD_BASE_URL"] = "https://cloud123.griptape.ai" diff --git a/tests/unit/drivers/event_listener/test_pusher_event_listener_driver.py b/tests/unit/drivers/event_listener/test_pusher_event_listener_driver.py index 6f0636b5c..50856c0da 100644 --- a/tests/unit/drivers/event_listener/test_pusher_event_listener_driver.py +++ b/tests/unit/drivers/event_listener/test_pusher_event_listener_driver.py @@ -1,11 +1,13 @@ -from pytest import fixture -from tests.mocks.mock_event import MockEvent -from griptape.drivers import PusherEventListenerDriver from unittest.mock import Mock +import pytest + +from griptape.drivers import PusherEventListenerDriver +from tests.mocks.mock_event import MockEvent + class TestPusherEventListenerDriver: - @fixture(autouse=True) + @pytest.fixture(autouse=True) def mock_post(self, mocker): mock_pusher_client = mocker.patch("pusher.Pusher") mock_pusher_client.return_value.trigger.return_value = Mock() @@ -13,7 +15,7 @@ def mock_post(self, mocker): return mock_pusher_client - @fixture() + @pytest.fixture() def driver(self): return PusherEventListenerDriver( app_id="test-app-id", diff --git a/tests/unit/drivers/event_listener/test_webhook_event_listener_driver.py b/tests/unit/drivers/event_listener/test_webhook_event_listener_driver.py index 50021cbe3..f6de0d20f 100644 --- a/tests/unit/drivers/event_listener/test_webhook_event_listener_driver.py +++ b/tests/unit/drivers/event_listener/test_webhook_event_listener_driver.py @@ -1,11 +1,13 @@ from unittest.mock import Mock -from pytest import fixture -from tests.mocks.mock_event import MockEvent + +import pytest + from griptape.drivers.event_listener.webhook_event_listener_driver import WebhookEventListenerDriver +from tests.mocks.mock_event import MockEvent class TestWebhookEventListenerDriver: - @fixture(autouse=True) + @pytest.fixture(autouse=True) def mock_post(self, mocker): mock_post = mocker.patch("requests.post") mock_post.return_value = Mock(status_code=201) diff --git a/tests/unit/drivers/file_manager/test_amazon_s3_file_manager_driver.py b/tests/unit/drivers/file_manager/test_amazon_s3_file_manager_driver.py index 8d1693ade..cad76a7d6 100644 --- a/tests/unit/drivers/file_manager/test_amazon_s3_file_manager_driver.py +++ b/tests/unit/drivers/file_manager/test_amazon_s3_file_manager_driver.py @@ -1,9 +1,11 @@ import os import tempfile + import boto3 import pytest from moto import mock_s3 -from griptape.artifacts import ErrorArtifact, ListArtifact, InfoArtifact, TextArtifact + +from griptape.artifacts import ErrorArtifact, InfoArtifact, ListArtifact, TextArtifact from griptape.drivers import AmazonS3FileManagerDriver from griptape.loaders import TextLoader from tests.utils.aws import mock_aws_credentials @@ -11,35 +13,35 @@ class TestAmazonS3FileManagerDriver: @pytest.fixture(autouse=True) - def set_aws_credentials(self): + def _set_aws_credentials(self): mock_aws_credentials() - @pytest.fixture + @pytest.fixture() def session(self): mock = mock_s3() mock.start() yield boto3.Session(region_name="us-east-1") mock.stop() - @pytest.fixture + @pytest.fixture() def s3_client(self, session): - yield session.client("s3") + return session.client("s3") @pytest.fixture(autouse=True) def bucket(self, s3_client): bucket = "test-bucket" s3_client.create_bucket(Bucket=bucket) - def write_file(path: str, content: bytes): + def write_file(path: str, content: bytes) -> None: s3_client.put_object(Bucket=bucket, Key=path, Body=content) - def mkdir(path: str): + def mkdir(path: str) -> None: # S3-style empty directories, such as is created via the `Create Folder` button # in the AWS S3 console (essentially, an empty file with a trailing slash). s3_dir_key = path.rstrip("/") + "/" s3_client.put_object(Bucket=bucket, Key=s3_dir_key) - def copy_test_resource(resource_path: str): + def copy_test_resource(resource_path: str) -> None: file_dir = os.path.dirname(__file__) full_path = os.path.join(file_dir, "../../../resources", resource_path) full_path = os.path.normpath(full_path) @@ -58,18 +60,18 @@ def copy_test_resource(resource_path: str): mkdir("foo/bar-empty") mkdir("foo/bar/baz-empty") - yield bucket + return bucket - @pytest.fixture + @pytest.fixture() def driver(self, session, bucket): return AmazonS3FileManagerDriver(session=session, bucket=bucket) - @pytest.fixture + @pytest.fixture() def temp_dir(self): with tempfile.TemporaryDirectory() as temp_dir: yield temp_dir - @pytest.fixture + @pytest.fixture() def get_s3_value(self, s3_client, bucket): def _get_s3_value(key): return s3_client.get_object(Bucket=bucket, Key=key)["Body"].read().decode() @@ -82,7 +84,7 @@ def test_validate_workdir(self, workdir, session, bucket): AmazonS3FileManagerDriver(session=session, bucket=bucket, workdir=workdir) @pytest.mark.parametrize( - "workdir,path,expected", + ("workdir", "path", "expected"), [ # Valid non-empty directories (without trailing slash) ("/", "", ["foo", "foo.txt", "foo-empty", "resources"]), @@ -130,7 +132,7 @@ def test_list_files(self, workdir, path, expected, driver): assert set(filter(None, artifact.value.split("\n"))) == set(expected) @pytest.mark.parametrize( - "workdir,path,expected", + ("workdir", "path", "expected"), [ # non-existent paths ("/", "bar", "Path not found"), @@ -158,7 +160,7 @@ def test_load_file(self, driver): assert len(artifact.value) == 4 @pytest.mark.parametrize( - "workdir,path,expected", + ("workdir", "path", "expected"), [ # non-existent files or directories ("/", "bitcoin.pdf", "Path not found"), @@ -201,7 +203,7 @@ def test_load_file_with_encoding_failure(self, session, bucket): assert isinstance(artifact, ErrorArtifact) @pytest.mark.parametrize( - "workdir,path,content", + ("workdir", "path", "content"), [ # non-existent files ("/", "resources/foo.txt", "one"), @@ -226,7 +228,7 @@ def test_save_file(self, workdir, path, content, driver, get_s3_value): assert get_s3_value(expected_s3_key) == content_str @pytest.mark.parametrize( - "workdir,path,expected", + ("workdir", "path", "expected"), [ # non-existent directories ("/", "bar/", "Path is a directory"), @@ -248,11 +250,6 @@ def test_save_file_failure(self, workdir, path, expected, temp_dir, driver, s3_c artifact = driver.save_file(path, "foobar") - # loop over the files in the bucket and print them - response = s3_client.list_objects_v2(Bucket=bucket) - for obj in response.get("Contents", []): - print(obj.get("Key")) - assert isinstance(artifact, ErrorArtifact) assert artifact.value == expected diff --git a/tests/unit/drivers/file_manager/test_local_file_manager_driver.py b/tests/unit/drivers/file_manager/test_local_file_manager_driver.py index b3f4ec561..ec9963b49 100644 --- a/tests/unit/drivers/file_manager/test_local_file_manager_driver.py +++ b/tests/unit/drivers/file_manager/test_local_file_manager_driver.py @@ -1,28 +1,30 @@ import os -from pathlib import Path import tempfile +from pathlib import Path + import pytest -from griptape.artifacts import ErrorArtifact, ListArtifact, InfoArtifact, TextArtifact + +from griptape.artifacts import ErrorArtifact, InfoArtifact, ListArtifact, TextArtifact from griptape.drivers import LocalFileManagerDriver from griptape.loaders.text_loader import TextLoader class TestLocalFileManagerDriver: - @pytest.fixture + @pytest.fixture() def temp_dir(self): with tempfile.TemporaryDirectory() as temp_dir: - def write_file(path: str, content: bytes): + def write_file(path: str, content: bytes) -> None: full_path = os.path.join(temp_dir, path) os.makedirs(os.path.dirname(full_path), exist_ok=True) with open(full_path, "wb") as f: f.write(content) - def mkdir(path: str): + def mkdir(path: str) -> None: full_path = os.path.join(temp_dir, path) os.makedirs(full_path, exist_ok=True) - def copy_test_resources(resource_path: str): + def copy_test_resources(resource_path: str) -> None: file_dir = os.path.dirname(__file__) full_path = os.path.join(file_dir, "../../../resources", resource_path) full_path = os.path.normpath(full_path) @@ -46,7 +48,7 @@ def copy_test_resources(resource_path: str): yield temp_dir - @pytest.fixture + @pytest.fixture() def driver(self, temp_dir): return LocalFileManagerDriver(workdir=temp_dir) @@ -55,7 +57,7 @@ def test_validate_workdir(self): LocalFileManagerDriver(workdir="foo") @pytest.mark.parametrize( - "workdir,path,expected", + ("workdir", "path", "expected"), [ # Valid non-empty directories (without trailing slash) ("/", "", ["foo", "foo.txt", "foo-empty", "resources"]), @@ -104,7 +106,7 @@ def test_list_files(self, workdir, path, expected, temp_dir, driver): assert set(filter(None, artifact.value.split("\n"))) == set(expected) @pytest.mark.parametrize( - "workdir,path,expected", + ("workdir", "path", "expected"), [ # non-existent paths ("/", "bar", "Path not found"), @@ -133,7 +135,7 @@ def test_load_file(self, driver: LocalFileManagerDriver): assert len(artifact.value) == 4 @pytest.mark.parametrize( - "workdir,path,expected", + ("workdir", "path", "expected"), [ # # non-existent files or directories ("/", "bitcoin.pdf", "Path not found"), @@ -177,7 +179,7 @@ def test_load_file_with_encoding_failure(self): assert isinstance(artifact, ErrorArtifact) @pytest.mark.parametrize( - "workdir,path,content", + ("workdir", "path", "content"), [ # non-existent files ("/", "resources/foo.txt", "one"), @@ -202,7 +204,7 @@ def test_save_file(self, workdir, path, content, temp_dir, driver): assert Path(driver.workdir, path).read_text() == content_bytes @pytest.mark.parametrize( - "workdir,path,expected", + ("workdir", "path", "expected"), [ # non-existent directories ("/", "bar/", "Path is a directory"), diff --git a/tests/unit/drivers/image_generation/test_amazon_bedrock_stable_diffusion_image_generation_driver.py b/tests/unit/drivers/image_generation/test_amazon_bedrock_stable_diffusion_image_generation_driver.py index a2c51f58b..9aa4d3f4f 100644 --- a/tests/unit/drivers/image_generation/test_amazon_bedrock_stable_diffusion_image_generation_driver.py +++ b/tests/unit/drivers/image_generation/test_amazon_bedrock_stable_diffusion_image_generation_driver.py @@ -7,18 +7,18 @@ class TestAmazonBedrockImageGenerationDriver: - @pytest.fixture + @pytest.fixture() def bedrock_client(self): return Mock() - @pytest.fixture + @pytest.fixture() def session(self, bedrock_client): session = Mock() session.client.return_value = bedrock_client return session - @pytest.fixture + @pytest.fixture() def model_driver(self): model_driver = Mock() model_driver.text_to_image_request_parameters.return_value = {} @@ -26,7 +26,7 @@ def model_driver(self): return model_driver - @pytest.fixture + @pytest.fixture() def driver(self, session, model_driver): return AmazonBedrockImageGenerationDriver( session=session, model="stability.stable-diffusion-xl-v1", image_generation_model_driver=model_driver @@ -37,7 +37,7 @@ def test_init(self, driver): def test_init_requires_image_generation_model_driver(self, session): with pytest.raises(TypeError): - AmazonBedrockImageGenerationDriver(session=session, model="stability.stable-diffusion-xl-v1") # pyright: ignore + AmazonBedrockImageGenerationDriver(session=session, model="stability.stable-diffusion-xl-v1") # pyright: ignore[reportCallIssue] def test_try_text_to_image(self, driver): driver.bedrock_client.invoke_model.return_value = { diff --git a/tests/unit/drivers/image_generation/test_azure_openai_image_generation_driver.py b/tests/unit/drivers/image_generation/test_azure_openai_image_generation_driver.py index 2166bc28a..268708b2b 100644 --- a/tests/unit/drivers/image_generation/test_azure_openai_image_generation_driver.py +++ b/tests/unit/drivers/image_generation/test_azure_openai_image_generation_driver.py @@ -1,10 +1,12 @@ -import pytest from unittest.mock import Mock + +import pytest + from griptape.drivers import AzureOpenAiImageGenerationDriver class TestAzureOpenAiImageGenerationDriver: - @pytest.fixture + @pytest.fixture() def driver(self): return AzureOpenAiImageGenerationDriver( model="dall-e-3", @@ -27,7 +29,7 @@ def test_init_requires_endpoint(self): with pytest.raises(TypeError): AzureOpenAiImageGenerationDriver( model="dall-e-3", client=Mock(), azure_deployment="dalle-deployment", image_size="512x512" - ) # pyright: ignore + ) # pyright: ignore[reportCallIssues] def test_try_text_to_image(self, driver): driver.client.images.generate.return_value = Mock(data=[Mock(b64_json=b"aW1hZ2UgZGF0YQ==")]) diff --git a/tests/unit/drivers/image_generation/test_dummy_image_generation_driver.py b/tests/unit/drivers/image_generation/test_dummy_image_generation_driver.py index 971c39b89..2df6c6499 100644 --- a/tests/unit/drivers/image_generation/test_dummy_image_generation_driver.py +++ b/tests/unit/drivers/image_generation/test_dummy_image_generation_driver.py @@ -1,12 +1,12 @@ -from griptape.drivers import DummyImageGenerationDriver -from griptape.artifacts import ImageArtifact import pytest +from griptape.artifacts import ImageArtifact +from griptape.drivers import DummyImageGenerationDriver from griptape.exceptions import DummyException class TestDummyImageGenerationDriver: - @pytest.fixture + @pytest.fixture() def image_generation_driver(self): return DummyImageGenerationDriver() diff --git a/tests/unit/drivers/image_generation/test_leonardo_image_generation_driver.py b/tests/unit/drivers/image_generation/test_leonardo_image_generation_driver.py index 564d3616a..48805cde6 100644 --- a/tests/unit/drivers/image_generation/test_leonardo_image_generation_driver.py +++ b/tests/unit/drivers/image_generation/test_leonardo_image_generation_driver.py @@ -1,11 +1,13 @@ import uuid -from unittest.mock import Mock, PropertyMock, MagicMock +from unittest.mock import Mock + import pytest + from griptape.drivers import LeonardoImageGenerationDriver class TestLeonardoImageGenerationDriver: - @pytest.fixture + @pytest.fixture() def driver(self): requests_session = Mock() diff --git a/tests/unit/drivers/image_generation/test_openai_image_generation_driver.py b/tests/unit/drivers/image_generation/test_openai_image_generation_driver.py index 8ca488eb1..16bcd2870 100644 --- a/tests/unit/drivers/image_generation/test_openai_image_generation_driver.py +++ b/tests/unit/drivers/image_generation/test_openai_image_generation_driver.py @@ -1,10 +1,12 @@ -import pytest from unittest.mock import Mock + +import pytest + from griptape.drivers import OpenAiImageGenerationDriver class TestOpenAiImageGenerationDriver: - @pytest.fixture + @pytest.fixture() def driver(self): return OpenAiImageGenerationDriver(model="dall-e-2", client=Mock(), quality="hd", image_size="512x512") diff --git a/tests/unit/drivers/image_generation_model/test_bedrock_stable_diffusion_image_model_driver.py b/tests/unit/drivers/image_generation_model/test_bedrock_stable_diffusion_image_model_driver.py index cdd4e95b7..60583455e 100644 --- a/tests/unit/drivers/image_generation_model/test_bedrock_stable_diffusion_image_model_driver.py +++ b/tests/unit/drivers/image_generation_model/test_bedrock_stable_diffusion_image_model_driver.py @@ -7,15 +7,15 @@ class TestBedrockStableDiffusionImageGenerationModelDriver: - @pytest.fixture + @pytest.fixture() def model_driver(self): return BedrockStableDiffusionImageGenerationModelDriver() - @pytest.fixture + @pytest.fixture() def image_artifact(self): return ImageArtifact(b"image", format="png", width=1024, height=1024) - @pytest.fixture + @pytest.fixture() def mask_artifact(self): return ImageArtifact(b"mask", format="png", width=1024, height=1024) @@ -118,5 +118,5 @@ def test_get_generated_image_failed(self, model_driver): response = {"artifacts": [{"finishReason": "ERROR", "base64": base64.b64encode(image_bytes).decode("utf-8")}]} - with pytest.raises(Exception): + with pytest.raises(Exception, match="Image generation failed:"): model_driver.get_generated_image(response) diff --git a/tests/unit/drivers/image_generation_model/test_bedrock_titan_image_model_driver.py b/tests/unit/drivers/image_generation_model/test_bedrock_titan_image_model_driver.py index 8c4ed40e3..6bf4c30d5 100644 --- a/tests/unit/drivers/image_generation_model/test_bedrock_titan_image_model_driver.py +++ b/tests/unit/drivers/image_generation_model/test_bedrock_titan_image_model_driver.py @@ -5,15 +5,15 @@ class TestBedrockTitanImageGenerationModelDriver: - @pytest.fixture + @pytest.fixture() def model_driver(self): return BedrockTitanImageGenerationModelDriver() - @pytest.fixture + @pytest.fixture() def image_artifact(self): return ImageArtifact(b"image", format="png", width=1024, height=512) - @pytest.fixture + @pytest.fixture() def mask_artifact(self): return ImageArtifact(b"mask", format="png", width=1024, height=512) diff --git a/tests/unit/drivers/image_query/test_amazon_bedrock_image_query_driver.py b/tests/unit/drivers/image_query/test_amazon_bedrock_image_query_driver.py index 57336e8ea..9493ab23d 100644 --- a/tests/unit/drivers/image_query/test_amazon_bedrock_image_query_driver.py +++ b/tests/unit/drivers/image_query/test_amazon_bedrock_image_query_driver.py @@ -1,23 +1,25 @@ -import pytest import io from unittest.mock import Mock -from griptape.drivers import AmazonBedrockImageQueryDriver + +import pytest + from griptape.artifacts import ImageArtifact, TextArtifact +from griptape.drivers import AmazonBedrockImageQueryDriver class TestAmazonBedrockImageQueryDriver: - @pytest.fixture + @pytest.fixture() def bedrock_client(self, mocker): return Mock() - @pytest.fixture + @pytest.fixture() def session(self, bedrock_client): session = Mock() session.client.return_value = bedrock_client return session - @pytest.fixture + @pytest.fixture() def model_driver(self): model_driver = Mock() model_driver.image_query_request_parameters.return_value = {} @@ -25,7 +27,7 @@ def model_driver(self): return model_driver - @pytest.fixture + @pytest.fixture() def image_query_driver(self, session, model_driver): return AmazonBedrockImageQueryDriver(session=session, model="model", image_query_model_driver=model_driver) diff --git a/tests/unit/drivers/image_query/test_anthropic_image_query_driver.py b/tests/unit/drivers/image_query/test_anthropic_image_query_driver.py index 24958d58f..db4b2407c 100644 --- a/tests/unit/drivers/image_query/test_anthropic_image_query_driver.py +++ b/tests/unit/drivers/image_query/test_anthropic_image_query_driver.py @@ -1,12 +1,14 @@ -import pytest import base64 from unittest.mock import Mock -from griptape.drivers import AnthropicImageQueryDriver + +import pytest + from griptape.artifacts import ImageArtifact +from griptape.drivers import AnthropicImageQueryDriver class TestAnthropicImageQueryDriver: - @pytest.fixture + @pytest.fixture() def mock_client(self, mocker): mock_client = mocker.patch("anthropic.Anthropic") return_value = Mock(text="Content") @@ -55,7 +57,7 @@ def test_try_query_max_tokens_value(self, mock_client): assert text_artifact.value == "Content" def test_try_query_max_tokens_none(self, mock_client): - driver = AnthropicImageQueryDriver(model="test-model", max_tokens=None) # pyright: ignore + driver = AnthropicImageQueryDriver(model="test-model", max_tokens=None) # pyright: ignore[reportArgumentType] test_prompt_string = "Prompt String" test_binary_data = b"test-data" with pytest.raises(TypeError): diff --git a/tests/unit/drivers/image_query/test_azure_openai_image_query_driver.py b/tests/unit/drivers/image_query/test_azure_openai_image_query_driver.py index a44319861..a1d428197 100644 --- a/tests/unit/drivers/image_query/test_azure_openai_image_query_driver.py +++ b/tests/unit/drivers/image_query/test_azure_openai_image_query_driver.py @@ -1,11 +1,13 @@ -import pytest from unittest.mock import Mock -from griptape.drivers import AzureOpenAiImageQueryDriver + +import pytest + from griptape.artifacts import ImageArtifact +from griptape.drivers import AzureOpenAiImageQueryDriver class TestAzureOpenAiVisionImageQueryDriver: - @pytest.fixture + @pytest.fixture() def mock_completion_create(self, mocker): mock_chat_create = mocker.patch("openai.AzureOpenAI").return_value.chat.completions.create mock_choice = Mock(message=Mock(content="expected_output_text")) @@ -52,7 +54,7 @@ def test_try_query_multiple_choices(self, mock_completion_create): azure_endpoint="test-endpoint", azure_deployment="test-deployment", model="gpt-4" ) - with pytest.raises(Exception): + with pytest.raises(Exception, match="Image query responses with more than one choice are not supported yet."): driver.try_query("Prompt String", [ImageArtifact(value=b"test-data", width=100, height=100, format="png")]) def _expected_messages(self, expected_prompt_string, expected_binary_data): diff --git a/tests/unit/drivers/image_query/test_dummy_image_query_driver.py b/tests/unit/drivers/image_query/test_dummy_image_query_driver.py index 8efcfa749..02b69595f 100644 --- a/tests/unit/drivers/image_query/test_dummy_image_query_driver.py +++ b/tests/unit/drivers/image_query/test_dummy_image_query_driver.py @@ -1,12 +1,12 @@ -from griptape.drivers import DummyImageQueryDriver -from griptape.artifacts import ImageArtifact import pytest +from griptape.artifacts import ImageArtifact +from griptape.drivers import DummyImageQueryDriver from griptape.exceptions import DummyException class TestDummyImageQueryDriver: - @pytest.fixture + @pytest.fixture() def image_query_driver(self): return DummyImageQueryDriver() diff --git a/tests/unit/drivers/image_query/test_openai_image_query_driver.py b/tests/unit/drivers/image_query/test_openai_image_query_driver.py index 08f0c70c9..9c4b011a6 100644 --- a/tests/unit/drivers/image_query/test_openai_image_query_driver.py +++ b/tests/unit/drivers/image_query/test_openai_image_query_driver.py @@ -1,11 +1,13 @@ -import pytest from unittest.mock import Mock -from griptape.drivers import OpenAiImageQueryDriver + +import pytest + from griptape.artifacts import ImageArtifact +from griptape.drivers import OpenAiImageQueryDriver class TestOpenAiVisionImageQueryDriver: - @pytest.fixture + @pytest.fixture() def mock_completion_create(self, mocker): mock_chat_create = mocker.patch("openai.OpenAI").return_value.chat.completions.create mock_choice = Mock(message=Mock(content="expected_output_text")) @@ -43,7 +45,7 @@ def test_try_query_multiple_choices(self, mock_completion_create): mock_completion_create.return_value.choices.append(Mock(message=Mock(content="expected_output_text2"))) driver = OpenAiImageQueryDriver(model="gpt-4-vision-preview") - with pytest.raises(Exception): + with pytest.raises(Exception, match="Image query responses with more than one choice are not supported yet."): driver.try_query("Prompt String", [ImageArtifact(value=b"test-data", width=100, height=100, format="png")]) def _expected_messages(self, expected_prompt_string, expected_binary_data): diff --git a/tests/unit/drivers/image_query_models/test_bedrock_claude_image_query_model_driver.py b/tests/unit/drivers/image_query_models/test_bedrock_claude_image_query_model_driver.py index 14fa8ff28..c274f71dd 100644 --- a/tests/unit/drivers/image_query_models/test_bedrock_claude_image_query_model_driver.py +++ b/tests/unit/drivers/image_query_models/test_bedrock_claude_image_query_model_driver.py @@ -1,6 +1,7 @@ import pytest -from griptape.drivers import BedrockClaudeImageQueryModelDriver + from griptape.artifacts import ImageArtifact, TextArtifact +from griptape.drivers import BedrockClaudeImageQueryModelDriver class TestBedrockClaudeImageQueryModelDriver: diff --git a/tests/unit/drivers/memory/conversation/test_dynamodb_conversation_memory_driver.py b/tests/unit/drivers/memory/conversation/test_dynamodb_conversation_memory_driver.py index ba79a4def..8e700d0a5 100644 --- a/tests/unit/drivers/memory/conversation/test_dynamodb_conversation_memory_driver.py +++ b/tests/unit/drivers/memory/conversation/test_dynamodb_conversation_memory_driver.py @@ -1,12 +1,13 @@ +import boto3 import pytest from moto import mock_dynamodb -import boto3 -from tests.mocks.mock_prompt_driver import MockPromptDriver -from tests.utils.aws import mock_aws_credentials + +from griptape.drivers import AmazonDynamoDbConversationMemoryDriver from griptape.memory.structure import ConversationMemory -from griptape.tasks import PromptTask from griptape.structures import Pipeline -from griptape.drivers import AmazonDynamoDbConversationMemoryDriver +from griptape.tasks import PromptTask +from tests.mocks.mock_prompt_driver import MockPromptDriver +from tests.utils.aws import mock_aws_credentials class TestDynamoDbConversationMemoryDriver: @@ -17,7 +18,7 @@ class TestDynamoDbConversationMemoryDriver: PARTITION_KEY_VALUE = "bar" @pytest.fixture(autouse=True) - def run_before_and_after_tests(self): + def _run_before_and_after_tests(self): mock_aws_credentials() self.mock_dynamodb = mock_dynamodb() self.mock_dynamodb.start() diff --git a/tests/unit/drivers/memory/conversation/test_local_conversation_memory_driver.py b/tests/unit/drivers/memory/conversation/test_local_conversation_memory_driver.py index c794afd0e..e1a383ab9 100644 --- a/tests/unit/drivers/memory/conversation/test_local_conversation_memory_driver.py +++ b/tests/unit/drivers/memory/conversation/test_local_conversation_memory_driver.py @@ -1,17 +1,20 @@ +import contextlib import os + import pytest -from tests.mocks.mock_prompt_driver import MockPromptDriver + from griptape.drivers import LocalConversationMemoryDriver from griptape.memory.structure import ConversationMemory -from griptape.tasks import PromptTask from griptape.structures import Pipeline +from griptape.tasks import PromptTask +from tests.mocks.mock_prompt_driver import MockPromptDriver class TestLocalConversationMemoryDriver: MEMORY_FILE_PATH = "test_memory.json" @pytest.fixture(autouse=True) - def run_before_and_after_tests(self): + def _run_before_and_after_tests(self): self.__delete_file(self.MEMORY_FILE_PATH) yield @@ -28,7 +31,7 @@ def test_store(self): try: with open(self.MEMORY_FILE_PATH): - assert False + raise AssertionError() except FileNotFoundError: assert True @@ -74,8 +77,6 @@ def test_autoload(self): assert autoloaded_memory.runs[0].input.value == "test" assert autoloaded_memory.runs[0].output.value == "mock output" - def __delete_file(self, file_path): - try: + def __delete_file(self, file_path) -> None: + with contextlib.suppress(FileNotFoundError): os.remove(file_path) - except FileNotFoundError: - pass diff --git a/tests/unit/drivers/memory/conversation/test_redis_conversation_memory_driver.py b/tests/unit/drivers/memory/conversation/test_redis_conversation_memory_driver.py index 1af9d74dc..4a92a28a8 100644 --- a/tests/unit/drivers/memory/conversation/test_redis_conversation_memory_driver.py +++ b/tests/unit/drivers/memory/conversation/test_redis_conversation_memory_driver.py @@ -1,7 +1,8 @@ import pytest import redis -from griptape.memory.structure.base_conversation_memory import BaseConversationMemory + from griptape.drivers.memory.conversation.redis_conversation_memory_driver import RedisConversationMemoryDriver +from griptape.memory.structure.base_conversation_memory import BaseConversationMemory TEST_CONVERSATION = '{"type": "ConversationMemory", "runs": [{"type": "Run", "id": "729ca6be5d79433d9762eb06dfd677e2", "input": {"type": "TextArtifact", "id": "1234", "value": "Hi There, Hello"}, "output": {"type": "TextArtifact", "id": "123", "value": "Hello! How can I assist you today?"}}], "max_runs": 2}' CONVERSATION_ID = "117151897f344ff684b553d0655d8f39" @@ -13,7 +14,7 @@ class TestRedisConversationMemoryDriver: @pytest.fixture(autouse=True) - def mock_redis(self, mocker): + def _mock_redis(self, mocker): mocker.patch.object(redis.StrictRedis, "hset", return_value=None) mocker.patch.object(redis.StrictRedis, "keys", return_value=[b"test"]) mocker.patch.object(redis.StrictRedis, "hget", return_value=TEST_CONVERSATION) @@ -25,13 +26,13 @@ def mock_redis(self, mocker): mocker.patch.object(redis.StrictRedis, "ft", return_value=fake_redisearch) - @pytest.fixture + @pytest.fixture() def driver(self): return RedisConversationMemoryDriver(host=HOST, port=PORT, db=0, index=INDEX, conversation_id=CONVERSATION_ID) def test_store(self, driver): memory = BaseConversationMemory.from_json(TEST_CONVERSATION) - assert driver.store(memory) == None + assert driver.store(memory) is None def test_load(self, driver): memory = driver.load() diff --git a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py index e31b6a448..6a58b09dc 100644 --- a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py @@ -1,10 +1,8 @@ import pytest -from griptape.artifacts import ImageArtifact, TextArtifact, ListArtifact, ErrorArtifact, ActionArtifact -from griptape.common import PromptStack -from griptape.common import TextDeltaMessageContent, ActionCallDeltaMessageContent, ToolAction +from griptape.artifacts import ActionArtifact, ErrorArtifact, ImageArtifact, ListArtifact, TextArtifact +from griptape.common import ActionCallDeltaMessageContent, PromptStack, TextDeltaMessageContent, ToolAction from griptape.drivers import AmazonBedrockPromptDriver - from tests.mocks.mock_tool.tool import MockTool @@ -159,7 +157,7 @@ class TestAmazonBedrockPromptDriver: }, ] - @pytest.fixture + @pytest.fixture() def mock_converse(self, mocker): mock_converse = mocker.patch("boto3.Session").return_value.client.return_value.converse @@ -177,7 +175,7 @@ def mock_converse(self, mocker): return mock_converse - @pytest.fixture + @pytest.fixture() def mock_converse_stream(self, mocker): mock_converse_stream = mocker.patch("boto3.Session").return_value.client.return_value.converse_stream @@ -275,7 +273,7 @@ def prompt_stack(self, request): return prompt_stack - @pytest.fixture + @pytest.fixture() def messages(self): return [ {"role": "user", "content": [{"text": "user-input"}]}, diff --git a/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py index a75fc6ed0..e74797e42 100644 --- a/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py @@ -1,11 +1,13 @@ +import json +from io import BytesIO from typing import Any + +import pytest from botocore.response import StreamingBody -from griptape.tokenizers import HuggingFaceTokenizer -from griptape.drivers.prompt.amazon_sagemaker_jumpstart_prompt_driver import AmazonSageMakerJumpstartPromptDriver + from griptape.common import PromptStack -from io import BytesIO -import json -import pytest +from griptape.drivers.prompt.amazon_sagemaker_jumpstart_prompt_driver import AmazonSageMakerJumpstartPromptDriver +from griptape.tokenizers import HuggingFaceTokenizer def to_streaming_body(data: Any) -> StreamingBody: diff --git a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py index e8f6d337f..3b1343336 100644 --- a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py @@ -1,10 +1,11 @@ -from griptape.artifacts.error_artifact import ErrorArtifact -from griptape.drivers import AnthropicPromptDriver -from griptape.common import PromptStack, TextDeltaMessageContent, ActionCallDeltaMessageContent, ToolAction -from griptape.artifacts import TextArtifact, ActionArtifact, ImageArtifact, ListArtifact from unittest.mock import Mock + import pytest +from griptape.artifacts import ActionArtifact, ImageArtifact, ListArtifact, TextArtifact +from griptape.artifacts.error_artifact import ErrorArtifact +from griptape.common import ActionCallDeltaMessageContent, PromptStack, TextDeltaMessageContent, ToolAction +from griptape.drivers import AnthropicPromptDriver from tests.mocks.mock_tool.tool import MockTool @@ -131,7 +132,7 @@ class TestAnthropicPromptDriver: }, ] - @pytest.fixture + @pytest.fixture() def mock_client(self, mocker): mock_client = mocker.patch("anthropic.Anthropic") mock_tool_use = Mock(type="tool_use", id="mock-id", input={"foo": "bar"}) @@ -150,7 +151,7 @@ def mock_client(self, mocker): return mock_client - @pytest.fixture + @pytest.fixture() def mock_stream_client(self, mocker): mock_stream_client = mocker.patch("anthropic.Anthropic") @@ -263,7 +264,7 @@ def prompt_stack(self, request): return prompt_stack - @pytest.fixture + @pytest.fixture() def messages(self): return [ {"role": "user", "content": "user-input"}, diff --git a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py index 9e56b39bd..dc0b54b0a 100644 --- a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py @@ -1,13 +1,15 @@ -import pytest -from griptape.artifacts import TextArtifact, ActionArtifact from unittest.mock import Mock + +import pytest + +from griptape.artifacts import ActionArtifact, TextArtifact +from griptape.common import ActionCallDeltaMessageContent, TextDeltaMessageContent from griptape.drivers import AzureOpenAiChatPromptDriver -from griptape.common import TextDeltaMessageContent, ActionCallDeltaMessageContent from tests.unit.drivers.prompt.test_openai_chat_prompt_driver import TestOpenAiChatPromptDriverFixtureMixin class TestAzureOpenAiChatPromptDriver(TestOpenAiChatPromptDriverFixtureMixin): - @pytest.fixture + @pytest.fixture() def mock_chat_completion_create(self, mocker): mock_chat_create = mocker.patch("openai.AzureOpenAI").return_value.chat.completions.create mock_function = Mock(arguments='{"foo": "bar"}', id="mock-id") @@ -22,7 +24,7 @@ def mock_chat_completion_create(self, mocker): return mock_chat_create - @pytest.fixture + @pytest.fixture() def mock_chat_completion_stream_create(self, mocker): mock_chat_create = mocker.patch("openai.AzureOpenAI").return_value.chat.completions.create mock_tool_call_delta_header = Mock() diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index 6eb000e1f..3c2bb333e 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -1,11 +1,11 @@ +from griptape.artifacts import ErrorArtifact, TextArtifact +from griptape.common import PromptStack from griptape.common.prompt_stack.messages.message import Message from griptape.events import FinishPromptEvent, StartPromptEvent -from griptape.common import PromptStack -from tests.mocks.mock_prompt_driver import MockPromptDriver -from tests.mocks.mock_failing_prompt_driver import MockFailingPromptDriver -from griptape.artifacts import ErrorArtifact, TextArtifact -from griptape.tasks import PromptTask from griptape.structures import Pipeline +from griptape.tasks import PromptTask +from tests.mocks.mock_failing_prompt_driver import MockFailingPromptDriver +from tests.mocks.mock_prompt_driver import MockPromptDriver class TestBasePromptDriver: diff --git a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py index 167c08b34..c642b7ee0 100644 --- a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py @@ -5,9 +5,8 @@ from griptape.artifacts.action_artifact import ActionArtifact from griptape.artifacts.list_artifact import ListArtifact from griptape.artifacts.text_artifact import TextArtifact -from griptape.common import PromptStack, ToolAction +from griptape.common import ActionCallDeltaMessageContent, PromptStack, TextDeltaMessageContent, ToolAction from griptape.drivers import CoherePromptDriver -from griptape.common import TextDeltaMessageContent, ActionCallDeltaMessageContent from tests.mocks.mock_tool.tool import MockTool @@ -42,7 +41,7 @@ class TestCoherePromptDriver: }, ] - @pytest.fixture + @pytest.fixture() def mock_client(self, mocker): mock_client = mocker.patch("cohere.Client").return_value mock_tool_call = Mock(parameters={"foo": "bar"}) @@ -53,7 +52,7 @@ def mock_client(self, mocker): return mock_client - @pytest.fixture + @pytest.fixture() def mock_stream_client(self, mocker): mock_client = mocker.patch("cohere.Client").return_value mock_tool_call_delta_header = Mock() diff --git a/tests/unit/drivers/prompt/test_dummy_prompt_driver.py b/tests/unit/drivers/prompt/test_dummy_prompt_driver.py index d569b55af..203bad3d5 100644 --- a/tests/unit/drivers/prompt/test_dummy_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_dummy_prompt_driver.py @@ -1,11 +1,11 @@ -from griptape.drivers import DummyPromptDriver import pytest +from griptape.drivers import DummyPromptDriver from griptape.exceptions import DummyException class TestDummyPromptDriver: - @pytest.fixture + @pytest.fixture() def prompt_driver(self): return DummyPromptDriver() diff --git a/tests/unit/drivers/prompt/test_google_prompt_driver.py b/tests/unit/drivers/prompt/test_google_prompt_driver.py index b1a72d10d..478ef8fb9 100644 --- a/tests/unit/drivers/prompt/test_google_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_google_prompt_driver.py @@ -1,15 +1,15 @@ -from google.generativeai.types import ContentDict, GenerationConfig +from unittest.mock import Mock + +import pytest from google.generativeai.protos import FunctionCall, FunctionResponse, Part -from griptape.artifacts import TextArtifact, ImageArtifact, ActionArtifact +from google.generativeai.types import ContentDict, GenerationConfig +from google.protobuf.json_format import MessageToDict + +from griptape.artifacts import ActionArtifact, ImageArtifact, TextArtifact from griptape.artifacts.list_artifact import ListArtifact -from griptape.common import TextDeltaMessageContent, ActionCallDeltaMessageContent, ToolAction +from griptape.common import ActionCallDeltaMessageContent, PromptStack, TextDeltaMessageContent, ToolAction from griptape.drivers import GooglePromptDriver -from griptape.common import PromptStack -from unittest.mock import Mock from tests.mocks.mock_tool.tool import MockTool -from google.protobuf.json_format import MessageToDict - -import pytest class TestGooglePromptDriver: @@ -43,7 +43,7 @@ class TestGooglePromptDriver: }, ] - @pytest.fixture + @pytest.fixture() def mock_generative_model(self, mocker): mock_generative_model = mocker.patch("google.generativeai.GenerativeModel") mock_function_call = Mock(type="tool_use", id="MockTool_test", args={"foo": "bar"}) @@ -55,7 +55,7 @@ def mock_generative_model(self, mocker): return mock_generative_model - @pytest.fixture + @pytest.fixture() def mock_stream_generative_model(self, mocker): mock_generative_model = mocker.patch("google.generativeai.GenerativeModel") mock_function_call_delta = Mock(type="tool_use", id="MockTool_test", args={"foo": "bar"}) @@ -117,7 +117,7 @@ def prompt_stack(self, request): return prompt_stack - @pytest.fixture + @pytest.fixture() def messages(self): return [ {"parts": ["user-input"], "role": "user"}, diff --git a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py index ec7ea73f8..1a4e1b25b 100644 --- a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py @@ -1,10 +1,11 @@ -from griptape.drivers import HuggingFaceHubPromptDriver -from griptape.common import PromptStack, TextDeltaMessageContent import pytest +from griptape.common import PromptStack, TextDeltaMessageContent +from griptape.drivers import HuggingFaceHubPromptDriver + class TestHuggingFaceHubPromptDriver: - @pytest.fixture + @pytest.fixture() def mock_client(self, mocker): mock_client = mocker.patch("huggingface_hub.InferenceClient").return_value @@ -20,14 +21,14 @@ def tokenizer(self, mocker): return tokenizer - @pytest.fixture + @pytest.fixture() def mock_client_stream(self, mocker): mock_client = mocker.patch("huggingface_hub.InferenceClient").return_value mock_client.text_generation.return_value = iter(["model-output"]) return mock_client - @pytest.fixture + @pytest.fixture() def prompt_stack(self): prompt_stack = PromptStack() prompt_stack.add_system_message("system-input") diff --git a/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py b/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py index a63d697fb..5323f5d2d 100644 --- a/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py @@ -1,7 +1,8 @@ -from griptape.drivers import HuggingFacePipelinePromptDriver -from griptape.common import PromptStack import pytest +from griptape.common import PromptStack +from griptape.drivers import HuggingFacePipelinePromptDriver + class TestHuggingFacePipelinePromptDriver: @pytest.fixture(autouse=True) @@ -25,7 +26,7 @@ def mock_autotokenizer(self, mocker): mock_autotokenizer.encode.return_value = [1, 2, 3] return mock_autotokenizer - @pytest.fixture + @pytest.fixture() def prompt_stack(self): prompt_stack = PromptStack() prompt_stack.add_system_message("system-input") diff --git a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py index a247a77ab..e51da368a 100644 --- a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py @@ -1,12 +1,13 @@ +import pytest + +from griptape.artifacts import ImageArtifact, ListArtifact, TextArtifact +from griptape.common import PromptStack from griptape.common.prompt_stack.contents.text_delta_message_content import TextDeltaMessageContent from griptape.drivers import OllamaPromptDriver -from griptape.common import PromptStack -from griptape.artifacts import ImageArtifact, ListArtifact, TextArtifact -import pytest class TestOllamaPromptDriver: - @pytest.fixture + @pytest.fixture() def mock_client(self, mocker): mock_client = mocker.patch("ollama.Client") @@ -14,7 +15,7 @@ def mock_client(self, mocker): return mock_client - @pytest.fixture + @pytest.fixture() def mock_stream_client(self, mocker): mock_stream_client = mocker.patch("ollama.Client") mock_stream_client.return_value.chat.return_value = iter([{"message": {"content": "model-output"}}]) diff --git a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py index 59772ff23..23a40e20c 100644 --- a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py @@ -1,12 +1,13 @@ -from griptape.artifacts import ImageArtifact, ListArtifact -from griptape.artifacts import TextArtifact, ActionArtifact +from unittest.mock import Mock + +import pytest + +from griptape.artifacts import ActionArtifact, ImageArtifact, ListArtifact, TextArtifact +from griptape.common import ActionCallDeltaMessageContent, PromptStack, TextDeltaMessageContent, ToolAction from griptape.drivers import OpenAiChatPromptDriver -from griptape.common import PromptStack, TextDeltaMessageContent, ActionCallDeltaMessageContent, ToolAction from griptape.tokenizers import OpenAiTokenizer -from unittest.mock import Mock from tests.mocks.mock_tokenizer import MockTokenizer from tests.mocks.mock_tool.tool import MockTool -import pytest class TestOpenAiChatPromptDriverFixtureMixin: @@ -153,7 +154,7 @@ class TestOpenAiChatPromptDriverFixtureMixin: }, ] - @pytest.fixture + @pytest.fixture() def mock_chat_completion_create(self, mocker): mock_chat_create = mocker.patch("openai.OpenAI").return_value.chat.completions.create mock_function = Mock(arguments='{"foo": "bar"}', id="mock-id") @@ -168,7 +169,7 @@ def mock_chat_completion_create(self, mocker): return mock_chat_create - @pytest.fixture + @pytest.fixture() def mock_chat_completion_stream_create(self, mocker): mock_chat_create = mocker.patch("openai.OpenAI").return_value.chat.completions.create mock_tool_call_delta_header = Mock() @@ -206,7 +207,7 @@ def mock_chat_completion_stream_create(self, mocker): ) return mock_chat_create - @pytest.fixture + @pytest.fixture() def prompt_stack(self): prompt_stack = PromptStack() prompt_stack.tools = [MockTool()] @@ -244,7 +245,7 @@ def prompt_stack(self): ) return prompt_stack - @pytest.fixture + @pytest.fixture() def messages(self): return [ {"role": "system", "content": "system-input"}, @@ -283,7 +284,7 @@ def __init__( remaining_tokens=234, limit_requests=345, limit_tokens=456, - ): + ) -> None: self.reset_requests_in = reset_requests_in self.reset_requests_in_unit = reset_requests_in_unit self.reset_tokens_in = reset_tokens_in diff --git a/tests/unit/drivers/rerank/test_cohere_rerank_driver.py b/tests/unit/drivers/rerank/test_cohere_rerank_driver.py index 952546e5a..87a727269 100644 --- a/tests/unit/drivers/rerank/test_cohere_rerank_driver.py +++ b/tests/unit/drivers/rerank/test_cohere_rerank_driver.py @@ -1,11 +1,12 @@ import pytest -from cohere import RerankResponseResultsItemDocument, RerankResponseResultsItem +from cohere import RerankResponseResultsItem, RerankResponseResultsItemDocument + from griptape.artifacts import TextArtifact from griptape.drivers import CohereRerankDriver class TestCohereRerankDriver: - @pytest.fixture + @pytest.fixture() def mock_client(self, mocker): mock_client = mocker.patch("cohere.Client").return_value mock_client.rerank.return_value.results = [ diff --git a/tests/unit/drivers/sql/test_amazon_redshift_sql_driver.py b/tests/unit/drivers/sql/test_amazon_redshift_sql_driver.py index d67e1c557..966df481c 100644 --- a/tests/unit/drivers/sql/test_amazon_redshift_sql_driver.py +++ b/tests/unit/drivers/sql/test_amazon_redshift_sql_driver.py @@ -1,7 +1,8 @@ -import pytest import boto3 +import pytest from botocore.stub import Stubber -from griptape.drivers import BaseSqlDriver, AmazonRedshiftSqlDriver + +from griptape.drivers import AmazonRedshiftSqlDriver, BaseSqlDriver class TestAmazonRedshiftSqlDriver: @@ -41,7 +42,7 @@ class TestAmazonRedshiftSqlDriver: }, ] - @pytest.fixture + @pytest.fixture() def statement_driver(self): session = boto3.Session(region_name="us-east-1") client = session.client("redshift-data") @@ -108,7 +109,7 @@ def statement_driver(self): return AmazonRedshiftSqlDriver(database="dev", session=session, workgroup_name="dev", client=client) - @pytest.fixture + @pytest.fixture() def describe_table_driver(self): session = boto3.Session(region_name="us-east-1") client = session.client("redshift-data") diff --git a/tests/unit/drivers/sql/test_snowflake_sql_driver.py b/tests/unit/drivers/sql/test_snowflake_sql_driver.py index 91403a467..055b1e744 100644 --- a/tests/unit/drivers/sql/test_snowflake_sql_driver.py +++ b/tests/unit/drivers/sql/test_snowflake_sql_driver.py @@ -1,8 +1,10 @@ from dataclasses import dataclass from unittest import mock + import pytest -from sqlalchemy import create_engine from snowflake.connector import SnowflakeConnection +from sqlalchemy import create_engine + from griptape.drivers import BaseSqlDriver, SnowflakeSqlDriver @@ -11,7 +13,7 @@ class TestSnowflakeSqlDriver: TEST_COLUMNS = [("first_name", "VARCHAR"), ("last_name", "VARCHAR")] - @pytest.fixture + @pytest.fixture() def mock_table(self, mocker): @dataclass class Column: @@ -21,13 +23,13 @@ class Column: mock_table = mocker.MagicMock(name="table", columns=[Column("first_name"), Column("last_name")]) return mock_table - @pytest.fixture + @pytest.fixture() def mock_metadata(self, mocker): mock_meta = mocker.MagicMock(name="metadata") mock_meta.reflect.return_value = None return mock_meta - @pytest.fixture + @pytest.fixture() def mock_snowflake_engine(self, mocker): mock_engine = mocker.MagicMock(name="engine") result_mock = mocker.MagicMock(name="result") @@ -44,22 +46,22 @@ def mock_snowflake_engine(self, mocker): return mock_engine - @pytest.fixture + @pytest.fixture() def mock_snowflake_connection(self, mocker): mock_connection = mocker.MagicMock(spec=SnowflakeConnection, name="connection") return mock_connection - @pytest.fixture + @pytest.fixture() def mock_snowflake_connection_no_schema(self, mocker): mock_connection = mocker.MagicMock(spec=SnowflakeConnection, name="connection_no_schema", schema=None) return mock_connection - @pytest.fixture + @pytest.fixture() def mock_snowflake_connection_no_database(self, mocker): mock_connection = mocker.MagicMock(spec=SnowflakeConnection, name="connection_no_database", database=None) return mock_connection - @pytest.fixture + @pytest.fixture() def driver(self, mock_snowflake_engine, mock_snowflake_connection): def get_connection(): return mock_snowflake_connection diff --git a/tests/unit/drivers/sql/test_sql_driver.py b/tests/unit/drivers/sql/test_sql_driver.py index d4caf6f50..742acd0b2 100644 --- a/tests/unit/drivers/sql/test_sql_driver.py +++ b/tests/unit/drivers/sql/test_sql_driver.py @@ -1,9 +1,10 @@ import pytest + from griptape.drivers import SqlDriver class TestSqlDriver: - @pytest.fixture + @pytest.fixture() def driver(self): new_driver = SqlDriver(engine_url="sqlite:///:memory:") diff --git a/tests/unit/drivers/structure_run/test_griptape_cloud_structure_run_driver.py b/tests/unit/drivers/structure_run/test_griptape_cloud_structure_run_driver.py index b056241ec..bdd5cd3ed 100644 --- a/tests/unit/drivers/structure_run/test_griptape_cloud_structure_run_driver.py +++ b/tests/unit/drivers/structure_run/test_griptape_cloud_structure_run_driver.py @@ -1,9 +1,10 @@ import pytest -from griptape.artifacts import TextArtifact, InfoArtifact + +from griptape.artifacts import InfoArtifact, TextArtifact class TestGriptapeCloudStructureRunDriver: - @pytest.fixture + @pytest.fixture() def driver(self, mocker): from griptape.drivers import GriptapeCloudStructureRunDriver diff --git a/tests/unit/drivers/structure_run/test_local_structure_run_driver.py b/tests/unit/drivers/structure_run/test_local_structure_run_driver.py index cb7b3058e..316f7bf71 100644 --- a/tests/unit/drivers/structure_run/test_local_structure_run_driver.py +++ b/tests/unit/drivers/structure_run/test_local_structure_run_driver.py @@ -1,11 +1,9 @@ import os -import pytest -from griptape.artifacts.text_artifact import TextArtifact + +from griptape.drivers import LocalStructureRunDriver +from griptape.structures import Agent, Pipeline from griptape.tasks import StructureRunTask -from griptape.structures import Agent from tests.mocks.mock_prompt_driver import MockPromptDriver -from griptape.drivers import LocalStructureRunDriver -from griptape.structures import Pipeline class TestLocalStructureRunDriver: @@ -22,8 +20,8 @@ def test_run(self): def test_run_with_env(self): pipeline = Pipeline() - agent = Agent(prompt_driver=MockPromptDriver(mock_output=lambda _: os.environ["key"])) - driver = LocalStructureRunDriver(structure_factory_fn=lambda: agent, env={"key": "value"}) + agent = Agent(prompt_driver=MockPromptDriver(mock_output=lambda _: os.environ["KEY"])) + driver = LocalStructureRunDriver(structure_factory_fn=lambda: agent, env={"KEY": "value"}) task = StructureRunTask(driver=driver) pipeline.add_task(task) diff --git a/tests/unit/drivers/text_to_speech/test_elevenlabs_audio_generation_driver.py b/tests/unit/drivers/text_to_speech/test_elevenlabs_audio_generation_driver.py index 2d90bc2f5..26c29adcf 100644 --- a/tests/unit/drivers/text_to_speech/test_elevenlabs_audio_generation_driver.py +++ b/tests/unit/drivers/text_to_speech/test_elevenlabs_audio_generation_driver.py @@ -1,10 +1,12 @@ -import pytest from unittest.mock import Mock + +import pytest + from griptape.drivers import ElevenLabsTextToSpeechDriver class TestElevenLabsTextToSpeechDriver: - @pytest.fixture + @pytest.fixture() def driver(self): return ElevenLabsTextToSpeechDriver(model="model", client=Mock(), voice="voice", api_key="key") diff --git a/tests/unit/drivers/transcription/test_openai_audio_transcription_driver.py b/tests/unit/drivers/transcription/test_openai_audio_transcription_driver.py index 57c5a5e2e..f9c22a725 100644 --- a/tests/unit/drivers/transcription/test_openai_audio_transcription_driver.py +++ b/tests/unit/drivers/transcription/test_openai_audio_transcription_driver.py @@ -1,16 +1,17 @@ -import pytest from unittest.mock import Mock +import pytest + from griptape.artifacts import AudioArtifact from griptape.drivers import OpenAiAudioTranscriptionDriver class TestOpenAiAudioTranscriptionDriver: - @pytest.fixture + @pytest.fixture() def audio_artifact(self): return AudioArtifact(value=b"audio data", format="mp3") - @pytest.fixture + @pytest.fixture() def driver(self): return OpenAiAudioTranscriptionDriver(model="model", client=Mock(), api_key="key") diff --git a/tests/unit/drivers/vector/test_amazon_opensearch_vector_store_driver.py b/tests/unit/drivers/vector/test_amazon_opensearch_vector_store_driver.py index b66cc057e..e5988b234 100644 --- a/tests/unit/drivers/vector/test_amazon_opensearch_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_amazon_opensearch_vector_store_driver.py @@ -1,12 +1,14 @@ +from unittest.mock import Mock, create_autospec, patch + +import boto3 +import numpy as np import pytest -from unittest.mock import patch, Mock, create_autospec + from griptape.drivers import AmazonOpenSearchVectorStoreDriver -import numpy as np -import boto3 class TestAmazonOpenSearchVectorStoreDriver: - @pytest.fixture + @pytest.fixture() def driver(self): mock_session = create_autospec(boto3.Session, instance=True) mock_driver = create_autospec(AmazonOpenSearchVectorStoreDriver, instance=True, session=mock_session) diff --git a/tests/unit/drivers/vector/test_azure_mongodb_vector_store_driver.py b/tests/unit/drivers/vector/test_azure_mongodb_vector_store_driver.py index b68486914..6dd4fa5e9 100644 --- a/tests/unit/drivers/vector/test_azure_mongodb_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_azure_mongodb_vector_store_driver.py @@ -1,14 +1,13 @@ -import pytest import mongomock -from unittest.mock import patch -from pymongo.errors import OperationFailure +import pytest + from griptape.artifacts import TextArtifact from griptape.drivers import AzureMongoDbVectorStoreDriver, BaseVectorStoreDriver from tests.mocks.mock_embedding_driver import MockEmbeddingDriver class TestAzureMongoDbVectorStoreDriver: - @pytest.fixture + @pytest.fixture() def driver(self, monkeypatch): embedding_driver = MockEmbeddingDriver() return AzureMongoDbVectorStoreDriver( @@ -66,15 +65,18 @@ def test_load_entries(self, driver): vector = [0.5, 0.5, 0.5] driver.upsert_vector(vector, vector_id=vector_id_str) # ensure at least one entry exists results = list(driver.load_entries()) - assert results is not None and len(results) > 0 + assert results is not None + assert len(results) > 0 def test_delete(self, driver): vector_id_str = "123" vector = [0.5, 0.5, 0.5] driver.upsert_vector(vector, vector_id=vector_id_str) # ensure at least one entry exists results = list(driver.load_entries()) - assert results is not None and len(results) > 0 + assert results is not None + assert len(results) > 0 driver.delete_vector(vector_id_str) results = list(driver.load_entries()) - assert results is not None and len(results) == 0 + assert results is not None + assert len(results) == 0 diff --git a/tests/unit/drivers/vector/test_base_local_vector_store_driver.py b/tests/unit/drivers/vector/test_base_local_vector_store_driver.py index 8c08292dd..ac4ff8043 100644 --- a/tests/unit/drivers/vector/test_base_local_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_base_local_vector_store_driver.py @@ -1,12 +1,14 @@ from abc import ABC, abstractmethod -import pytest from unittest.mock import patch + +import pytest + from griptape.artifacts import TextArtifact from griptape.artifacts.csv_row_artifact import CsvRowArtifact class BaseLocalVectorStoreDriver(ABC): - @pytest.fixture + @pytest.fixture() @abstractmethod def driver(self): ... diff --git a/tests/unit/drivers/vector/test_dummy_vector_store_driver.py b/tests/unit/drivers/vector/test_dummy_vector_store_driver.py index 720778c38..df4867212 100644 --- a/tests/unit/drivers/vector/test_dummy_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_dummy_vector_store_driver.py @@ -1,10 +1,11 @@ import pytest + from griptape.drivers import DummyVectorStoreDriver from griptape.exceptions import DummyException class TestDummyVectorStoreDriver: - @pytest.fixture + @pytest.fixture() def vector_store_driver(self): return DummyVectorStoreDriver() diff --git a/tests/unit/drivers/vector/test_griptape_cloud_knowledge_base_vector_store_driver.py b/tests/unit/drivers/vector/test_griptape_cloud_knowledge_base_vector_store_driver.py index 957edebb8..0f52ba6c5 100644 --- a/tests/unit/drivers/vector/test_griptape_cloud_knowledge_base_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_griptape_cloud_knowledge_base_vector_store_driver.py @@ -1,5 +1,7 @@ -import pytest import uuid + +import pytest + from griptape.drivers import GriptapeCloudKnowledgeBaseVectorStoreDriver @@ -10,7 +12,7 @@ class TestGriptapeCloudKnowledgeBaseVectorStoreDriver: test_metas = [{"key": "value1"}, {"key": "value2"}] test_scores = [0.7, 0.8] - @pytest.fixture + @pytest.fixture() def driver(self, mocker): test_entries = { "entries": [ diff --git a/tests/unit/drivers/vector/test_local_vector_store_driver.py b/tests/unit/drivers/vector/test_local_vector_store_driver.py index 314f2fd6d..6f022793c 100644 --- a/tests/unit/drivers/vector/test_local_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_local_vector_store_driver.py @@ -1,4 +1,5 @@ import pytest + from griptape.artifacts import TextArtifact from griptape.drivers import LocalVectorStoreDriver from tests.mocks.mock_embedding_driver import MockEmbeddingDriver @@ -6,7 +7,7 @@ class TestLocalVectorStoreDriver(BaseLocalVectorStoreDriver): - @pytest.fixture + @pytest.fixture() def driver(self): return LocalVectorStoreDriver(embedding_driver=MockEmbeddingDriver()) diff --git a/tests/unit/drivers/vector/test_marqo_vector_store_driver.py b/tests/unit/drivers/vector/test_marqo_vector_store_driver.py index f42906035..5c2399bc5 100644 --- a/tests/unit/drivers/vector/test_marqo_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_marqo_vector_store_driver.py @@ -1,7 +1,9 @@ from collections import namedtuple + import pytest -from griptape.drivers import MarqoVectorStoreDriver + from griptape.artifacts import TextArtifact +from griptape.drivers import MarqoVectorStoreDriver from tests.mocks.mock_embedding_driver import MockEmbeddingDriver @@ -79,7 +81,7 @@ def mock_marqo(self, mocker): # Return the mock_client for use in other fixtures return mock_client - @pytest.fixture + @pytest.fixture() def driver(self, mock_marqo): return MarqoVectorStoreDriver( api_key="foobar", diff --git a/tests/unit/drivers/vector/test_mongodb_atlas_vector_store_driver.py b/tests/unit/drivers/vector/test_mongodb_atlas_vector_store_driver.py index 5b9aeed06..20cb8bdc0 100644 --- a/tests/unit/drivers/vector/test_mongodb_atlas_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_mongodb_atlas_vector_store_driver.py @@ -1,14 +1,13 @@ -import pytest import mongomock -from unittest.mock import patch -from pymongo.errors import OperationFailure +import pytest + from griptape.artifacts import TextArtifact -from griptape.drivers import MongoDbAtlasVectorStoreDriver, BaseVectorStoreDriver +from griptape.drivers import BaseVectorStoreDriver, MongoDbAtlasVectorStoreDriver from tests.mocks.mock_embedding_driver import MockEmbeddingDriver class TestMongoDbAtlasVectorStoreDriver: - @pytest.fixture + @pytest.fixture() def driver(self, monkeypatch): embedding_driver = MockEmbeddingDriver() return MongoDbAtlasVectorStoreDriver( @@ -66,15 +65,18 @@ def test_load_entries(self, driver): vector = [0.5, 0.5, 0.5] driver.upsert_vector(vector, vector_id=vector_id_str) # ensure at least one entry exists results = list(driver.load_entries()) - assert results is not None and len(results) > 0 + assert results is not None + assert len(results) > 0 def test_delete(self, driver): vector_id_str = "123" vector = [0.5, 0.5, 0.5] driver.upsert_vector(vector, vector_id=vector_id_str) # ensure at least one entry exists results = list(driver.load_entries()) - assert results is not None and len(results) > 0 + assert results is not None + assert len(results) > 0 driver.delete_vector(vector_id_str) results = list(driver.load_entries()) - assert results is not None and len(results) == 0 + assert results is not None + assert len(results) == 0 diff --git a/tests/unit/drivers/vector/test_opensearch_vector_store_driver.py b/tests/unit/drivers/vector/test_opensearch_vector_store_driver.py index d2c967caf..cef3805ab 100644 --- a/tests/unit/drivers/vector/test_opensearch_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_opensearch_vector_store_driver.py @@ -1,11 +1,13 @@ +from unittest.mock import Mock, create_autospec, patch + +import numpy as np import pytest -from unittest.mock import patch, Mock, create_autospec + from griptape.drivers import OpenSearchVectorStoreDriver -import numpy as np class TestOpenSearchVectorStoreDriver: - @pytest.fixture + @pytest.fixture() def driver(self): mock_driver = create_autospec(OpenSearchVectorStoreDriver, instance=True) mock_driver.upsert_vector.return_value = "foo" diff --git a/tests/unit/drivers/vector/test_persistent_local_vector_store_driver.py b/tests/unit/drivers/vector/test_persistent_local_vector_store_driver.py index 8f6773fc1..c130858b5 100644 --- a/tests/unit/drivers/vector/test_persistent_local_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_persistent_local_vector_store_driver.py @@ -1,6 +1,8 @@ import os import tempfile + import pytest + from griptape.artifacts import TextArtifact from griptape.drivers import LocalVectorStoreDriver from tests.mocks.mock_embedding_driver import MockEmbeddingDriver @@ -8,12 +10,12 @@ class TestPersistentLocalVectorStoreDriver(BaseLocalVectorStoreDriver): - @pytest.fixture + @pytest.fixture() def temp_dir(self): with tempfile.TemporaryDirectory() as temp_dir: yield temp_dir - @pytest.fixture + @pytest.fixture() def driver(self, temp_dir): persist_file = os.path.join(temp_dir, "store.json") diff --git a/tests/unit/drivers/vector/test_pgvector_vector_store_driver.py b/tests/unit/drivers/vector/test_pgvector_vector_store_driver.py index 3854ea4f1..29b5ad82e 100644 --- a/tests/unit/drivers/vector/test_pgvector_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_pgvector_vector_store_driver.py @@ -1,25 +1,27 @@ -from typing import Any import uuid -import pytest +from typing import Any from unittest.mock import MagicMock, Mock + +import pytest +from sqlalchemy import create_engine + from griptape.drivers import PgVectorVectorStoreDriver from tests.mocks.mock_embedding_driver import MockEmbeddingDriver -from sqlalchemy import create_engine class TestPgVectorVectorStoreDriver: connection_string = "postgresql://postgres:postgres@localhost:5432/postgres" table_name = "griptape_vectors" - @pytest.fixture + @pytest.fixture() def embedding_driver(self): return MockEmbeddingDriver() - @pytest.fixture + @pytest.fixture() def mock_engine(self): return MagicMock() - @pytest.fixture + @pytest.fixture() def mock_session(self, mocker): session = MagicMock() mock_session_manager = MagicMock() @@ -30,14 +32,14 @@ def mock_session(self, mocker): def test_initialize_requires_engine_or_connection_string(self, embedding_driver): with pytest.raises(ValueError): - driver = PgVectorVectorStoreDriver(embedding_driver=embedding_driver, table_name=self.table_name) + PgVectorVectorStoreDriver(embedding_driver=embedding_driver, table_name=self.table_name) def test_initialize_accepts_engine(self, embedding_driver): engine: Any = create_engine(self.connection_string) - driver = PgVectorVectorStoreDriver(embedding_driver=embedding_driver, engine=engine, table_name=self.table_name) + PgVectorVectorStoreDriver(embedding_driver=embedding_driver, engine=engine, table_name=self.table_name) def test_initialize_accepts_connection_string(self, embedding_driver): - driver = PgVectorVectorStoreDriver( + PgVectorVectorStoreDriver( embedding_driver=embedding_driver, connection_string=self.connection_string, table_name=self.table_name ) diff --git a/tests/unit/drivers/vector/test_pinecone_vector_storage_driver.py b/tests/unit/drivers/vector/test_pinecone_vector_storage_driver.py index 7aea4d411..0726a0c7e 100644 --- a/tests/unit/drivers/vector/test_pinecone_vector_storage_driver.py +++ b/tests/unit/drivers/vector/test_pinecone_vector_storage_driver.py @@ -1,19 +1,13 @@ import pytest -from griptape import utils from griptape.artifacts import TextArtifact from griptape.drivers import PineconeVectorStoreDriver from tests.mocks.mock_embedding_driver import MockEmbeddingDriver class TestPineconeVectorStorageDriver: - """ - This should really be under `unit` but the Pinecone client results - in tests hanging on GitHub. - """ - @pytest.fixture(autouse=True) - def mock_pinecone(self, mocker): + def _mock_pinecone(self, mocker): # Create a fake response fake_query_response = { "matches": [{"id": "foo", "values": [0, 1, 0], "score": 42, "metadata": {"foo": "bar"}}], @@ -25,7 +19,7 @@ def mock_pinecone(self, mocker): mock_client().Index().query.return_value = fake_query_response mock_client().create_index.return_value = None - @pytest.fixture + @pytest.fixture() def driver(self): return PineconeVectorStoreDriver( api_key="foobar", index_name="test", environment="test", embedding_driver=MockEmbeddingDriver() diff --git a/tests/unit/drivers/vector/test_qdrant_vector_store_driver.py b/tests/unit/drivers/vector/test_qdrant_vector_store_driver.py index 8abfbf4f7..0b22784eb 100644 --- a/tests/unit/drivers/vector/test_qdrant_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_qdrant_vector_store_driver.py @@ -1,17 +1,19 @@ -import pytest +import uuid from unittest.mock import MagicMock, patch + +import pytest + from griptape.drivers import QdrantVectorStoreDriver -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver from griptape.utils import import_optional_dependency -import uuid +from tests.mocks.mock_embedding_driver import MockEmbeddingDriver class TestQdrantVectorStoreDriver: - @pytest.fixture + @pytest.fixture() def embedding_driver(self): return MockEmbeddingDriver() - @pytest.fixture + @pytest.fixture() def mock_engine(self): return MagicMock() diff --git a/tests/unit/drivers/vector/test_redis_vector_store_driver.py b/tests/unit/drivers/vector/test_redis_vector_store_driver.py index 18759a2d7..2f74b9279 100644 --- a/tests/unit/drivers/vector/test_redis_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_redis_vector_store_driver.py @@ -1,7 +1,9 @@ from unittest.mock import MagicMock + import pytest -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver + from griptape.drivers import RedisVectorStoreDriver +from tests.mocks.mock_embedding_driver import MockEmbeddingDriver class TestRedisVectorStorageDriver: @@ -9,12 +11,12 @@ class TestRedisVectorStorageDriver: def mock_client(self, mocker): return mocker.patch("redis.Redis").return_value - @pytest.fixture + @pytest.fixture() def mock_keys(self, mock_client): mock_client.keys.return_value = [b"some_vector_id"] return mock_client.keys - @pytest.fixture + @pytest.fixture() def mock_hgetall(self, mock_client): mock_client.hgetall.return_value = { b"vector": b"\x00\x00\x80?\x00\x00\x00@\x00\x00@@", @@ -22,13 +24,13 @@ def mock_hgetall(self, mock_client): } return mock_client.hgetall - @pytest.fixture + @pytest.fixture() def driver(self): return RedisVectorStoreDriver( host="localhost", port=6379, index="test_index", db=0, embedding_driver=MockEmbeddingDriver() ) - @pytest.fixture + @pytest.fixture() def mock_search(self, mock_client): mock_client.ft.return_value.search.return_value.docs = [ MagicMock( diff --git a/tests/unit/drivers/web_scraper/test_markdownify_web_scraper_driver.py b/tests/unit/drivers/web_scraper/test_markdownify_web_scraper_driver.py index 33500839f..dbdafa98f 100644 --- a/tests/unit/drivers/web_scraper/test_markdownify_web_scraper_driver.py +++ b/tests/unit/drivers/web_scraper/test_markdownify_web_scraper_driver.py @@ -1,4 +1,5 @@ from textwrap import dedent + import pytest from griptape.drivers.web_scraper.markdownify_web_scraper_driver import MarkdownifyWebScraperDriver @@ -15,13 +16,13 @@ def mock_content(self, mock_playwright): mock_content.return_value = 'foobar' return mock_content - @pytest.fixture + @pytest.fixture() def web_scraper(self): return MarkdownifyWebScraperDriver() def test_scrape_url(self, web_scraper): artifact = web_scraper.scrape_url("https://example.com/") - assert "[foobar](foobar.com)" == artifact.value + assert artifact.value == "[foobar](foobar.com)" def test_scrape_url_whitespace(self, web_scraper, mock_content): mock_content.return_value = dedent( @@ -46,35 +47,35 @@ def test_scrape_url_whitespace(self, web_scraper, mock_content): """ ) artifact = web_scraper.scrape_url("https://example.com/") - assert "foo\n---\n\n* bar:\n + baz\n + baz\n\n + baz" == artifact.value + assert artifact.value == "foo\n---\n\n* bar:\n + baz\n + baz\n\n + baz" def test_scrape_url_no_excludes(self): web_scraper = MarkdownifyWebScraperDriver(exclude_tags=[], exclude_classes=[], exclude_ids=[]) artifact = web_scraper.scrape_url("https://example.com/") - assert "[foobar](foobar.com)" == artifact.value + assert artifact.value == "[foobar](foobar.com)" def test_scrape_url_exclude_links(self): web_scraper = MarkdownifyWebScraperDriver(include_links=False) artifact = web_scraper.scrape_url("https://example.com/") - assert "foobar" == artifact.value + assert artifact.value == "foobar" def test_scrape_url_exclude_tags(self, mock_content): mock_content.return_value = "powwow" web_scraper = MarkdownifyWebScraperDriver(exclude_tags=["wow"], exclude_classes=[], exclude_ids=[]) artifact = web_scraper.scrape_url("https://example.com/") - assert "pow" == artifact.value + assert artifact.value == "pow" def test_scrape_url_exclude_classes(self, mock_content): mock_content.return_value = 'powwow' web_scraper = MarkdownifyWebScraperDriver(exclude_tags=[], exclude_classes=["now"], exclude_ids=[]) artifact = web_scraper.scrape_url("https://example.com/") - assert "pow" == artifact.value + assert artifact.value == "pow" def test_scrape_url_exclude_ids(self, mock_content): mock_content.return_value = 'powwow' web_scraper = MarkdownifyWebScraperDriver(exclude_tags=[], exclude_classes=[], exclude_ids=["cow"]) artifact = web_scraper.scrape_url("https://example.com/") - assert "pow" == artifact.value + assert artifact.value == "pow" def test_scrape_url_raises_on_empty_string_from_playwright(self, web_scraper, mock_content): mock_content.return_value = "" diff --git a/tests/unit/drivers/web_scraper/test_proxy_web_scraper_driver.py b/tests/unit/drivers/web_scraper/test_proxy_web_scraper_driver.py index 569800c6a..95e5a2880 100644 --- a/tests/unit/drivers/web_scraper/test_proxy_web_scraper_driver.py +++ b/tests/unit/drivers/web_scraper/test_proxy_web_scraper_driver.py @@ -1,7 +1,7 @@ import pytest -from griptape.drivers import ProxyWebScraperDriver from griptape.artifacts import TextArtifact +from griptape.drivers import ProxyWebScraperDriver class TestProxyWebScraperDriver: @@ -11,11 +11,11 @@ def mock_client(self, mocker): mock_response.text = "test_scrape" return mocker.patch("requests.get", return_value=mock_response) - @pytest.fixture + @pytest.fixture() def mock_client_error(self, mocker): return mocker.patch("requests.get", side_effect=Exception("test_error")) - @pytest.fixture + @pytest.fixture() def web_scraper(self, mocker): return ProxyWebScraperDriver( proxies={"http": "http://localhost:8080", "https": "http://localhost:8080"}, @@ -26,7 +26,7 @@ def test_scrape_url(self, web_scraper, mock_client): output = web_scraper.scrape_url("https://example.com/") mock_client.assert_called_with("https://example.com/", proxies=web_scraper.proxies, test_param="test_param") assert isinstance(output, TextArtifact) - assert "test_scrape" == output.value + assert output.value == "test_scrape" def test_scrape_url_error(self, web_scraper, mock_client_error): with pytest.raises(Exception, match="test_error"): diff --git a/tests/unit/drivers/web_scraper/test_trafilatura_web_scraper_driver.py b/tests/unit/drivers/web_scraper/test_trafilatura_web_scraper_driver.py index f2ea56666..53ddf4500 100644 --- a/tests/unit/drivers/web_scraper/test_trafilatura_web_scraper_driver.py +++ b/tests/unit/drivers/web_scraper/test_trafilatura_web_scraper_driver.py @@ -5,7 +5,7 @@ class TestTrafilaturaWebScraperDriver: @pytest.fixture(autouse=True) - def mock_fetch_url(self, mocker): + def _mock_fetch_url(self, mocker): # Through trial and error, I've found that include_links in trafilatura's extract does not work # if the body of the page is not long enough, which is why I'm adding an arbitrary number of # characters to the body. @@ -13,7 +13,7 @@ def mock_fetch_url(self, mocker): "trafilatura.fetch_url" ).return_value = f'{"x"*243}foobar' - @pytest.fixture + @pytest.fixture() def web_scraper(self): return TrafilaturaWebScraperDriver(include_links=True) diff --git a/tests/unit/drivers/web_search/test_duck_duck_go_web_search_driver.py b/tests/unit/drivers/web_search/test_duck_duck_go_web_search_driver.py index fcacc274c..3d0a782eb 100644 --- a/tests/unit/drivers/web_search/test_duck_duck_go_web_search_driver.py +++ b/tests/unit/drivers/web_search/test_duck_duck_go_web_search_driver.py @@ -1,11 +1,13 @@ -import pytest import json -from griptape.drivers import DuckDuckGoWebSearchDriver + +import pytest + from griptape.artifacts import ListArtifact +from griptape.drivers import DuckDuckGoWebSearchDriver class TestDuckDuckGoWebSearchDriver: - @pytest.fixture + @pytest.fixture() def driver(self, mocker): mock_response = [ {"title": "foo", "href": "bar", "body": "baz"}, @@ -16,7 +18,7 @@ def driver(self, mocker): return DuckDuckGoWebSearchDriver() - @pytest.fixture + @pytest.fixture() def driver_with_error(self, mocker): mocker.patch("duckduckgo_search.DDGS.text", side_effect=Exception("test_error")) diff --git a/tests/unit/drivers/web_search/test_google_web_search_driver.py b/tests/unit/drivers/web_search/test_google_web_search_driver.py index 9ecb92f46..3809532de 100644 --- a/tests/unit/drivers/web_search/test_google_web_search_driver.py +++ b/tests/unit/drivers/web_search/test_google_web_search_driver.py @@ -1,13 +1,13 @@ -from pytest import fixture -import pytest -from griptape.drivers import GoogleWebSearchDriver -from griptape.artifacts import ErrorArtifact import json + +import pytest from pytest_mock import MockerFixture +from griptape.drivers import GoogleWebSearchDriver + class TestGoogleWebSearchDriver: - @fixture + @pytest.fixture() def driver(self, mocker: MockerFixture): mock_response = mocker.Mock() mocker.patch.object( @@ -19,7 +19,7 @@ def driver(self, mocker: MockerFixture): return GoogleWebSearchDriver(api_key="test", search_id="test") - @fixture + @pytest.fixture() def driver_with_error(self, mocker: MockerFixture): mock_response = mocker.Mock() mock_response.status_code = 500 diff --git a/tests/unit/engines/extraction/test_csv_extraction_engine.py b/tests/unit/engines/extraction/test_csv_extraction_engine.py index ded595d59..f69d8a0ba 100644 --- a/tests/unit/engines/extraction/test_csv_extraction_engine.py +++ b/tests/unit/engines/extraction/test_csv_extraction_engine.py @@ -1,10 +1,11 @@ import pytest + from griptape.engines import CsvExtractionEngine from tests.mocks.mock_prompt_driver import MockPromptDriver class TestCsvExtractionEngine: - @pytest.fixture + @pytest.fixture() def engine(self): return CsvExtractionEngine(prompt_driver=MockPromptDriver()) diff --git a/tests/unit/engines/extraction/test_json_extraction_engine.py b/tests/unit/engines/extraction/test_json_extraction_engine.py index 797c5de7a..d95adbb43 100644 --- a/tests/unit/engines/extraction/test_json_extraction_engine.py +++ b/tests/unit/engines/extraction/test_json_extraction_engine.py @@ -1,12 +1,13 @@ import pytest from schema import Schema + from griptape.artifacts import ErrorArtifact from griptape.engines import JsonExtractionEngine from tests.mocks.mock_prompt_driver import MockPromptDriver class TestJsonExtractionEngine: - @pytest.fixture + @pytest.fixture() def engine(self): return JsonExtractionEngine( prompt_driver=MockPromptDriver( diff --git a/tests/unit/engines/rag/modules/generation/test_footnote_prompt_response_rag_module.py b/tests/unit/engines/rag/modules/generation/test_footnote_prompt_response_rag_module.py index e5ba50a5b..385cf0c04 100644 --- a/tests/unit/engines/rag/modules/generation/test_footnote_prompt_response_rag_module.py +++ b/tests/unit/engines/rag/modules/generation/test_footnote_prompt_response_rag_module.py @@ -1,4 +1,5 @@ import pytest + from griptape.artifacts import TextArtifact from griptape.common import Reference from griptape.engines.rag import RagContext @@ -7,7 +8,7 @@ class TestFootnotePromptResponseRagModule: - @pytest.fixture + @pytest.fixture() def module(self): return FootnotePromptResponseRagModule(prompt_driver=MockPromptDriver()) diff --git a/tests/unit/engines/rag/modules/generation/test_prompt_response_rag_module.py b/tests/unit/engines/rag/modules/generation/test_prompt_response_rag_module.py index f262d6d06..2f8a912e2 100644 --- a/tests/unit/engines/rag/modules/generation/test_prompt_response_rag_module.py +++ b/tests/unit/engines/rag/modules/generation/test_prompt_response_rag_module.py @@ -1,4 +1,5 @@ import pytest + from griptape.artifacts import TextArtifact from griptape.engines.rag import RagContext from griptape.engines.rag.modules import PromptResponseRagModule @@ -6,7 +7,7 @@ class TestPromptResponseRagModule: - @pytest.fixture + @pytest.fixture() def module(self): return PromptResponseRagModule(prompt_driver=MockPromptDriver()) diff --git a/tests/unit/engines/rag/modules/generation/test_rulesets_before_response_rag_module.py b/tests/unit/engines/rag/modules/generation/test_rulesets_before_response_rag_module.py index 2750257f4..bc85cf266 100644 --- a/tests/unit/engines/rag/modules/generation/test_rulesets_before_response_rag_module.py +++ b/tests/unit/engines/rag/modules/generation/test_rulesets_before_response_rag_module.py @@ -1,6 +1,6 @@ from griptape.engines.rag import RagContext from griptape.engines.rag.modules import RulesetsBeforeResponseRagModule -from griptape.rules import Ruleset, Rule +from griptape.rules import Rule, Ruleset class TestRulesetsBeforeResponseRagModule: diff --git a/tests/unit/engines/rag/modules/generation/test_text_chunks_response_rag_module.py b/tests/unit/engines/rag/modules/generation/test_text_chunks_response_rag_module.py index 6488d650e..ae4410b2c 100644 --- a/tests/unit/engines/rag/modules/generation/test_text_chunks_response_rag_module.py +++ b/tests/unit/engines/rag/modules/generation/test_text_chunks_response_rag_module.py @@ -1,11 +1,12 @@ import pytest + from griptape.artifacts import TextArtifact from griptape.engines.rag import RagContext from griptape.engines.rag.modules import TextChunksResponseRagModule class TestTextChunksResponseRagModule: - @pytest.fixture + @pytest.fixture() def module(self): return TextChunksResponseRagModule() diff --git a/tests/unit/engines/rag/modules/retrieval/test_text_chunks_rerank_rag_module.py b/tests/unit/engines/rag/modules/retrieval/test_text_chunks_rerank_rag_module.py index fa3bfecb2..dda6e89e7 100644 --- a/tests/unit/engines/rag/modules/retrieval/test_text_chunks_rerank_rag_module.py +++ b/tests/unit/engines/rag/modules/retrieval/test_text_chunks_rerank_rag_module.py @@ -1,5 +1,6 @@ import pytest from cohere import RerankResponseResultsItem, RerankResponseResultsItemDocument + from griptape.artifacts import TextArtifact from griptape.drivers import CohereRerankDriver from griptape.engines.rag import RagContext @@ -7,7 +8,7 @@ class TestTextChunksRerankRagModule: - @pytest.fixture + @pytest.fixture() def mock_client(self, mocker): mock_client = mocker.patch("cohere.Client").return_value mock_client.rerank.return_value.results = [ diff --git a/tests/unit/engines/rag/modules/retrieval/test_text_loader_retrieval_rag_module.py b/tests/unit/engines/rag/modules/retrieval/test_text_loader_retrieval_rag_module.py index 7c69f674a..69e334c7f 100644 --- a/tests/unit/engines/rag/modules/retrieval/test_text_loader_retrieval_rag_module.py +++ b/tests/unit/engines/rag/modules/retrieval/test_text_loader_retrieval_rag_module.py @@ -11,7 +11,7 @@ class TestTextLoaderRetrievalRagModule: @pytest.fixture(autouse=True) - def mock_trafilatura_fetch_url(self, mocker): + def _mock_trafilatura_fetch_url(self, mocker): mocker.patch("trafilatura.fetch_url", return_value="foobar") def test_run(self): diff --git a/tests/unit/engines/rag/test_rag_engine.py b/tests/unit/engines/rag/test_rag_engine.py index a39c0c2f1..c3d728bb3 100644 --- a/tests/unit/engines/rag/test_rag_engine.py +++ b/tests/unit/engines/rag/test_rag_engine.py @@ -1,14 +1,15 @@ import pytest + from griptape.drivers import LocalVectorStoreDriver -from griptape.engines.rag import RagEngine, RagContext -from griptape.engines.rag.modules import VectorStoreRetrievalRagModule, PromptResponseRagModule -from griptape.engines.rag.stages import RetrievalRagStage, ResponseRagStage +from griptape.engines.rag import RagContext, RagEngine +from griptape.engines.rag.modules import PromptResponseRagModule, VectorStoreRetrievalRagModule +from griptape.engines.rag.stages import ResponseRagStage, RetrievalRagStage from tests.mocks.mock_embedding_driver import MockEmbeddingDriver from tests.mocks.mock_prompt_driver import MockPromptDriver class TestRagEngine: - @pytest.fixture + @pytest.fixture() def engine(self): return RagEngine( retrieval_stage=RetrievalRagStage( diff --git a/tests/unit/engines/summary/test_prompt_summary_engine.py b/tests/unit/engines/summary/test_prompt_summary_engine.py index 34c6e3563..e826a2b4d 100644 --- a/tests/unit/engines/summary/test_prompt_summary_engine.py +++ b/tests/unit/engines/summary/test_prompt_summary_engine.py @@ -1,13 +1,15 @@ +import os + import pytest -from griptape.artifacts import TextArtifact, ListArtifact -from griptape.engines import PromptSummaryEngine + +from griptape.artifacts import ListArtifact, TextArtifact from griptape.common import PromptStack +from griptape.engines import PromptSummaryEngine from tests.mocks.mock_prompt_driver import MockPromptDriver -import os class TestPromptSummaryEngine: - @pytest.fixture + @pytest.fixture() def engine(self): return PromptSummaryEngine(prompt_driver=MockPromptDriver()) diff --git a/tests/unit/events/test_base_event.py b/tests/unit/events/test_base_event.py index 595c90f1f..778f7c096 100644 --- a/tests/unit/events/test_base_event.py +++ b/tests/unit/events/test_base_event.py @@ -1,17 +1,19 @@ import time + import pytest + from griptape.artifacts.base_artifact import BaseArtifact from griptape.events import ( - StartPromptEvent, + BaseEvent, + CompletionChunkEvent, + FinishActionsSubtaskEvent, FinishPromptEvent, - StartTaskEvent, + FinishStructureRunEvent, FinishTaskEvent, StartActionsSubtaskEvent, - FinishActionsSubtaskEvent, - CompletionChunkEvent, + StartPromptEvent, StartStructureRunEvent, - FinishStructureRunEvent, - BaseEvent, + StartTaskEvent, ) from tests.mocks.mock_event import MockEvent diff --git a/tests/unit/events/test_completion_chunk_event.py b/tests/unit/events/test_completion_chunk_event.py index aa9618a53..943ea483f 100644 --- a/tests/unit/events/test_completion_chunk_event.py +++ b/tests/unit/events/test_completion_chunk_event.py @@ -1,9 +1,10 @@ import pytest + from griptape.events import CompletionChunkEvent class TestCompletionChunkEvent: - @pytest.fixture + @pytest.fixture() def completion_chunk_event(self): return CompletionChunkEvent(token="foo bar") diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index 2f32837e0..a79d2b6ea 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -1,27 +1,29 @@ from unittest.mock import Mock + import pytest -from griptape.events.base_event import BaseEvent -from griptape.structures import Pipeline -from griptape.tasks import ToolkitTask, ActionsSubtask + from griptape.events import ( - StartTaskEvent, + CompletionChunkEvent, + EventListener, + FinishActionsSubtaskEvent, + FinishPromptEvent, + FinishStructureRunEvent, FinishTaskEvent, StartActionsSubtaskEvent, - FinishActionsSubtaskEvent, StartPromptEvent, - FinishPromptEvent, StartStructureRunEvent, - FinishStructureRunEvent, - CompletionChunkEvent, - EventListener, + StartTaskEvent, ) +from griptape.events.base_event import BaseEvent +from griptape.structures import Pipeline +from griptape.tasks import ActionsSubtask, ToolkitTask +from tests.mocks.mock_event import MockEvent from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_tool.tool import MockTool -from tests.mocks.mock_event import MockEvent class TestEventListener: - @pytest.fixture + @pytest.fixture() def pipeline(self): task = ToolkitTask("test", tools=[MockTool(name="Tool1")]) @@ -107,7 +109,7 @@ def test_publish_event(self): mock_event_listener_driver = Mock() mock_event_listener_driver.try_publish_event_payload.return_value = None - def event_handler(_: BaseEvent): + def event_handler(_: BaseEvent) -> None: return None mock_event = MockEvent() diff --git a/tests/unit/events/test_finish_actions_subtask_event.py b/tests/unit/events/test_finish_actions_subtask_event.py index 14d7cbdde..5e2a0807a 100644 --- a/tests/unit/events/test_finish_actions_subtask_event.py +++ b/tests/unit/events/test_finish_actions_subtask_event.py @@ -1,4 +1,5 @@ import pytest + from griptape.events import FinishActionsSubtaskEvent from griptape.structures import Agent from griptape.tasks import ActionsSubtask, ToolkitTask @@ -7,7 +8,7 @@ class TestFinishActionsSubtaskEvent: - @pytest.fixture + @pytest.fixture() def finish_subtask_event(self): valid_input = ( "Thought: need to test\n" diff --git a/tests/unit/events/test_finish_prompt_event.py b/tests/unit/events/test_finish_prompt_event.py index 7443fce0c..397efe8a0 100644 --- a/tests/unit/events/test_finish_prompt_event.py +++ b/tests/unit/events/test_finish_prompt_event.py @@ -1,9 +1,10 @@ import pytest + from griptape.events import FinishPromptEvent class TestFinishPromptEvent: - @pytest.fixture + @pytest.fixture() def finish_prompt_event(self): return FinishPromptEvent(input_token_count=321, output_token_count=123, result="foo bar", model="foo bar") diff --git a/tests/unit/events/test_finish_structure_run_event.py b/tests/unit/events/test_finish_structure_run_event.py index 0e9e61f4f..9c0961314 100644 --- a/tests/unit/events/test_finish_structure_run_event.py +++ b/tests/unit/events/test_finish_structure_run_event.py @@ -5,7 +5,7 @@ class TestFinishStructureRunEvent: - @pytest.fixture + @pytest.fixture() def finish_structure_run_event(self): return FinishStructureRunEvent( structure_id="fizz", diff --git a/tests/unit/events/test_finish_task_event.py b/tests/unit/events/test_finish_task_event.py index 40e71c9ea..df1d6d42a 100644 --- a/tests/unit/events/test_finish_task_event.py +++ b/tests/unit/events/test_finish_task_event.py @@ -1,12 +1,13 @@ import pytest -from griptape.structures import Agent + from griptape.events import FinishTaskEvent +from griptape.structures import Agent from griptape.tasks import PromptTask from tests.mocks.mock_prompt_driver import MockPromptDriver class TestFinishTaskEvent: - @pytest.fixture + @pytest.fixture() def finish_task_event(self): task = PromptTask() agent = Agent(prompt_driver=MockPromptDriver()) diff --git a/tests/unit/events/test_start_actions_subtask_event.py b/tests/unit/events/test_start_actions_subtask_event.py index d8b63de22..8b628057c 100644 --- a/tests/unit/events/test_start_actions_subtask_event.py +++ b/tests/unit/events/test_start_actions_subtask_event.py @@ -1,4 +1,5 @@ import pytest + from griptape.events import StartActionsSubtaskEvent from griptape.structures import Agent from griptape.tasks import ActionsSubtask, ToolkitTask @@ -7,7 +8,7 @@ class TestStartActionsSubtaskEvent: - @pytest.fixture + @pytest.fixture() def start_subtask_event(self): valid_input = ( "Thought: need to test\n" diff --git a/tests/unit/events/test_start_prompt_event.py b/tests/unit/events/test_start_prompt_event.py index 4ef08ec5c..2d7e9368f 100644 --- a/tests/unit/events/test_start_prompt_event.py +++ b/tests/unit/events/test_start_prompt_event.py @@ -1,10 +1,11 @@ import pytest -from griptape.events import StartPromptEvent + from griptape.common import PromptStack +from griptape.events import StartPromptEvent class TestStartPromptEvent: - @pytest.fixture + @pytest.fixture() def start_prompt_event(self): prompt_stack = PromptStack() prompt_stack.add_user_message("foo") diff --git a/tests/unit/events/test_start_structure_run_event.py b/tests/unit/events/test_start_structure_run_event.py index c2f1b923d..221f1a544 100644 --- a/tests/unit/events/test_start_structure_run_event.py +++ b/tests/unit/events/test_start_structure_run_event.py @@ -1,10 +1,11 @@ import pytest + from griptape.artifacts.text_artifact import TextArtifact from griptape.events import StartStructureRunEvent class TestStartStructureRunEvent: - @pytest.fixture + @pytest.fixture() def start_structure_run_event(self): return StartStructureRunEvent( structure_id="fizz", input_task_input=TextArtifact("foo"), input_task_output=TextArtifact("bar") diff --git a/tests/unit/events/test_start_task_event.py b/tests/unit/events/test_start_task_event.py index f4d243421..ea027f147 100644 --- a/tests/unit/events/test_start_task_event.py +++ b/tests/unit/events/test_start_task_event.py @@ -1,4 +1,5 @@ import pytest + from griptape.events import StartTaskEvent from griptape.structures import Agent from griptape.tasks import PromptTask @@ -6,7 +7,7 @@ class TestStartTaskEvent: - @pytest.fixture + @pytest.fixture() def start_task_event(self): task = PromptTask() agent = Agent(prompt_driver=MockPromptDriver()) diff --git a/tests/unit/loaders/conftest.py b/tests/unit/loaders/conftest.py index e1823a154..494916be6 100644 --- a/tests/unit/loaders/conftest.py +++ b/tests/unit/loaders/conftest.py @@ -4,7 +4,7 @@ import pytest -@pytest.fixture +@pytest.fixture() def path_from_resource_path(): def create_source(resource_path: str) -> Path: return Path(os.path.join(os.path.abspath(os.path.dirname(__file__)), "../../resources", resource_path)) @@ -12,7 +12,7 @@ def create_source(resource_path: str) -> Path: return create_source -@pytest.fixture +@pytest.fixture() def bytes_from_resource_path(path_from_resource_path): def create_source(resource_path: str) -> bytes: with open(path_from_resource_path(resource_path), "rb") as f: @@ -21,7 +21,7 @@ def create_source(resource_path: str) -> bytes: return create_source -@pytest.fixture +@pytest.fixture() def str_from_resource_path(path_from_resource_path): def test_csv_str(resource_path: str) -> str: with open(path_from_resource_path(resource_path)) as f: diff --git a/tests/unit/loaders/test_audio_loader.py b/tests/unit/loaders/test_audio_loader.py index b7946da03..473fd0d9e 100644 --- a/tests/unit/loaders/test_audio_loader.py +++ b/tests/unit/loaders/test_audio_loader.py @@ -5,15 +5,15 @@ class TestAudioLoader: - @pytest.fixture + @pytest.fixture() def loader(self): return AudioLoader() - @pytest.fixture + @pytest.fixture() def create_source(self, bytes_from_resource_path): return bytes_from_resource_path - @pytest.mark.parametrize("resource_path,suffix,mime_type", [("sentences.wav", ".wav", "audio/wav")]) + @pytest.mark.parametrize(("resource_path", "suffix", "mime_type"), [("sentences.wav", ".wav", "audio/wav")]) def test_load(self, resource_path, suffix, mime_type, loader, create_source): source = create_source(resource_path) @@ -32,8 +32,7 @@ def test_load_collection(self, create_source, loader): assert len(collection) == len(resource_paths) - keys = {loader.to_key(source) for source in sources} - for key in collection.keys(): + for key in collection: artifact = collection[key] assert isinstance(artifact, AudioArtifact) assert artifact.name.endswith(".wav") diff --git a/tests/unit/loaders/test_blob_loader.py b/tests/unit/loaders/test_blob_loader.py index f2b462726..4812e669c 100644 --- a/tests/unit/loaders/test_blob_loader.py +++ b/tests/unit/loaders/test_blob_loader.py @@ -1,4 +1,5 @@ import pytest + from griptape.artifacts import BlobArtifact from griptape.loaders import BlobLoader diff --git a/tests/unit/loaders/test_csv_loader.py b/tests/unit/loaders/test_csv_loader.py index 579146ba2..a747afff7 100644 --- a/tests/unit/loaders/test_csv_loader.py +++ b/tests/unit/loaders/test_csv_loader.py @@ -1,4 +1,5 @@ import pytest + from griptape.loaders.csv_loader import CsvLoader from tests.mocks.mock_embedding_driver import MockEmbeddingDriver @@ -12,7 +13,7 @@ def loader(self, request): else: return CsvLoader(embedding_driver=MockEmbeddingDriver(), encoding=encoding) - @pytest.fixture + @pytest.fixture() def loader_with_pipe_delimiter(self): return CsvLoader(embedding_driver=MockEmbeddingDriver(), delimiter="|") diff --git a/tests/unit/loaders/test_dataframe_loader.py b/tests/unit/loaders/test_dataframe_loader.py index 536555558..5c2a57ed6 100644 --- a/tests/unit/loaders/test_dataframe_loader.py +++ b/tests/unit/loaders/test_dataframe_loader.py @@ -1,13 +1,14 @@ import os + import pandas as pd import pytest -from griptape import utils + from griptape.loaders.dataframe_loader import DataFrameLoader from tests.mocks.mock_embedding_driver import MockEmbeddingDriver class TestDataFrameLoader: - @pytest.fixture + @pytest.fixture() def loader(self): return DataFrameLoader(embedding_driver=MockEmbeddingDriver()) diff --git a/tests/unit/loaders/test_email_loader.py b/tests/unit/loaders/test_email_loader.py index ef69b8348..f1e057453 100644 --- a/tests/unit/loaders/test_email_loader.py +++ b/tests/unit/loaders/test_email_loader.py @@ -1,12 +1,14 @@ from __future__ import annotations +import email from email import message -from griptape.artifacts import ErrorArtifact, ListArtifact -from griptape.loaders import EmailLoader from typing import Optional -import email + import pytest +from griptape.artifacts import ErrorArtifact, ListArtifact +from griptape.loaders import EmailLoader + class TestEmailLoader: @pytest.fixture(autouse=True) @@ -15,7 +17,7 @@ def mock_imap_connection(self, mocker): mock_imap_connection.__enter__.return_value = mock_imap_connection return mock_imap_connection - @pytest.fixture + @pytest.fixture() def mock_login(self, mock_imap_connection): return mock_imap_connection.login @@ -25,7 +27,7 @@ def mock_select(self, mock_imap_connection): mock_select.return_value = to_select_response("OK", 1) return mock_select - @pytest.fixture + @pytest.fixture() def mock_search(self, mock_imap_connection): return mock_imap_connection.search @@ -35,7 +37,7 @@ def mock_fetch(self, mock_imap_connection): mock_fetch.return_value = to_fetch_message("message", "text/plain") return mock_fetch - @pytest.fixture + @pytest.fixture() def loader(self): return EmailLoader(imap_url="an.email.server.hostname", username="username", password="password") diff --git a/tests/unit/loaders/test_image_loader.py b/tests/unit/loaders/test_image_loader.py index 2a90d6b5d..eca4cbccc 100644 --- a/tests/unit/loaders/test_image_loader.py +++ b/tests/unit/loaders/test_image_loader.py @@ -5,20 +5,20 @@ class TestImageLoader: - @pytest.fixture + @pytest.fixture() def loader(self): return ImageLoader() - @pytest.fixture + @pytest.fixture() def png_loader(self): return ImageLoader(format="png") - @pytest.fixture + @pytest.fixture() def create_source(self, bytes_from_resource_path): return bytes_from_resource_path @pytest.mark.parametrize( - "resource_path,suffix,mime_type", + ("resource_path", "suffix", "mime_type"), [ ("small.png", ".png", "image/png"), ("small.jpg", ".jpeg", "image/jpeg"), diff --git a/tests/unit/loaders/test_pdf_loader.py b/tests/unit/loaders/test_pdf_loader.py index 0ab78b8b6..3f4f7848e 100644 --- a/tests/unit/loaders/test_pdf_loader.py +++ b/tests/unit/loaders/test_pdf_loader.py @@ -1,8 +1,5 @@ -import os -from pathlib import Path -from typing import IO import pytest -from griptape import utils + from griptape.loaders import PdfLoader from tests.mocks.mock_embedding_driver import MockEmbeddingDriver @@ -10,11 +7,11 @@ class TestPdfLoader: - @pytest.fixture + @pytest.fixture() def loader(self): return PdfLoader(max_tokens=MAX_TOKENS, embedding_driver=MockEmbeddingDriver()) - @pytest.fixture + @pytest.fixture() def create_source(self, bytes_from_resource_path): return bytes_from_resource_path diff --git a/tests/unit/loaders/test_sql_loader.py b/tests/unit/loaders/test_sql_loader.py index 8541e4fb8..fbfa6d4fa 100644 --- a/tests/unit/loaders/test_sql_loader.py +++ b/tests/unit/loaders/test_sql_loader.py @@ -1,5 +1,6 @@ import pytest from sqlalchemy.pool import StaticPool + from griptape.drivers import SqlDriver from griptape.loaders import SqlLoader from tests.mocks.mock_embedding_driver import MockEmbeddingDriver @@ -8,7 +9,7 @@ class TestSqlLoader: - @pytest.fixture + @pytest.fixture() def loader(self): sql_loader = SqlLoader( sql_driver=SqlDriver( diff --git a/tests/unit/loaders/test_text_loader.py b/tests/unit/loaders/test_text_loader.py index 0c59df12f..07527f9e6 100644 --- a/tests/unit/loaders/test_text_loader.py +++ b/tests/unit/loaders/test_text_loader.py @@ -1,4 +1,5 @@ import pytest + from griptape.loaders.text_loader import TextLoader from tests.mocks.mock_embedding_driver import MockEmbeddingDriver diff --git a/tests/unit/loaders/test_web_loader.py b/tests/unit/loaders/test_web_loader.py index e26573539..f264ce667 100644 --- a/tests/unit/loaders/test_web_loader.py +++ b/tests/unit/loaders/test_web_loader.py @@ -1,4 +1,5 @@ import pytest + from griptape.artifacts.error_artifact import ErrorArtifact from griptape.loaders import WebLoader from tests.mocks.mock_embedding_driver import MockEmbeddingDriver @@ -8,10 +9,10 @@ class TestWebLoader: @pytest.fixture(autouse=True) - def mock_trafilatura_fetch_url(self, mocker): + def _mock_trafilatura_fetch_url(self, mocker): mocker.patch("trafilatura.fetch_url", return_value="foobar") - @pytest.fixture + @pytest.fixture() def loader(self): return WebLoader(max_tokens=MAX_TOKENS, embedding_driver=MockEmbeddingDriver()) diff --git a/tests/unit/memory/meta/test_action_subtask_meta_entry.py b/tests/unit/memory/meta/test_action_subtask_meta_entry.py index f5da6ee01..0a51e7227 100644 --- a/tests/unit/memory/meta/test_action_subtask_meta_entry.py +++ b/tests/unit/memory/meta/test_action_subtask_meta_entry.py @@ -1,9 +1,10 @@ import pytest + from griptape.memory.meta import ActionSubtaskMetaEntry class TestActionSubtaskMetaEntry: - @pytest.fixture + @pytest.fixture() def entry(self): return ActionSubtaskMetaEntry(thought="foo", actions="[]", answer="baz") diff --git a/tests/unit/memory/meta/test_meta_memory.py b/tests/unit/memory/meta/test_meta_memory.py index bbdacf5b4..5a1249529 100644 --- a/tests/unit/memory/meta/test_meta_memory.py +++ b/tests/unit/memory/meta/test_meta_memory.py @@ -1,9 +1,10 @@ import pytest -from griptape.memory.meta import MetaMemory, ActionSubtaskMetaEntry + +from griptape.memory.meta import ActionSubtaskMetaEntry, MetaMemory class TestMetaMemory: - @pytest.fixture + @pytest.fixture() def memory(self): return MetaMemory() diff --git a/tests/unit/memory/structure/test_conversation_memory.py b/tests/unit/memory/structure/test_conversation_memory.py index 613d4b1fe..2ffd7b8cb 100644 --- a/tests/unit/memory/structure/test_conversation_memory.py +++ b/tests/unit/memory/structure/test_conversation_memory.py @@ -1,12 +1,12 @@ import json -from griptape.structures import Agent + +from griptape.artifacts import TextArtifact from griptape.common import PromptStack -from griptape.memory.structure import ConversationMemory, Run, BaseConversationMemory -from griptape.structures import Pipeline +from griptape.memory.structure import BaseConversationMemory, ConversationMemory, Run +from griptape.structures import Agent, Pipeline +from griptape.tasks import PromptTask from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_tokenizer import MockTokenizer -from griptape.tasks import PromptTask -from griptape.artifacts import TextArtifact class TestConversationMemory: diff --git a/tests/unit/memory/structure/test_summary_conversation_memory.py b/tests/unit/memory/structure/test_summary_conversation_memory.py index e625ac6c6..4396c7b23 100644 --- a/tests/unit/memory/structure/test_summary_conversation_memory.py +++ b/tests/unit/memory/structure/test_summary_conversation_memory.py @@ -1,9 +1,8 @@ import json - +from griptape.artifacts import TextArtifact from griptape.memory.structure import Run, SummaryConversationMemory from griptape.structures import Pipeline -from griptape.artifacts import TextArtifact from griptape.tasks import PromptTask from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_structure_config import MockStructureConfig diff --git a/tests/unit/memory/tool/storage/test_blob_artifact_storage.py b/tests/unit/memory/tool/storage/test_blob_artifact_storage.py index dd42b6bc2..c7f2cfcbd 100644 --- a/tests/unit/memory/tool/storage/test_blob_artifact_storage.py +++ b/tests/unit/memory/tool/storage/test_blob_artifact_storage.py @@ -1,10 +1,11 @@ import pytest + from griptape.artifacts import BlobArtifact, TextArtifact from griptape.memory.task.storage import BlobArtifactStorage class TestBlobArtifactStorage: - @pytest.fixture + @pytest.fixture() def storage(self): return BlobArtifactStorage() diff --git a/tests/unit/memory/tool/storage/test_text_artifact_storage.py b/tests/unit/memory/tool/storage/test_text_artifact_storage.py index 706a80f0c..64f44c581 100644 --- a/tests/unit/memory/tool/storage/test_text_artifact_storage.py +++ b/tests/unit/memory/tool/storage/test_text_artifact_storage.py @@ -1,10 +1,11 @@ import pytest + from griptape.artifacts import BlobArtifact, TextArtifact from tests.utils import defaults class TestTextArtifactStorage: - @pytest.fixture + @pytest.fixture() def storage(self): return defaults.text_tool_artifact_storage() diff --git a/tests/unit/memory/tool/test_task_memory.py b/tests/unit/memory/tool/test_task_memory.py index fc1cf75c7..53e4703a6 100644 --- a/tests/unit/memory/tool/test_task_memory.py +++ b/tests/unit/memory/tool/test_task_memory.py @@ -1,6 +1,6 @@ import pytest -from griptape.artifacts import CsvRowArtifact, BlobArtifact, ErrorArtifact, InfoArtifact -from griptape.artifacts import TextArtifact, ListArtifact + +from griptape.artifacts import BlobArtifact, CsvRowArtifact, ErrorArtifact, InfoArtifact, ListArtifact, TextArtifact from griptape.memory import TaskMemory from griptape.memory.task.storage import BlobArtifactStorage, TextArtifactStorage from griptape.structures import Agent @@ -11,10 +11,10 @@ class TestTaskMemory: @pytest.fixture(autouse=True) - def mock_griptape(self, mocker): + def _mock_griptape(self, mocker): mocker.patch("griptape.engines.CsvExtractionEngine.extract", return_value=[CsvRowArtifact({"foo": "bar"})]) - @pytest.fixture + @pytest.fixture() def memory(self): return defaults.text_task_memory("MyMemory") diff --git a/tests/unit/mixins/test_activity_mixin.py b/tests/unit/mixins/test_activity_mixin.py index 45db7cc7f..91e0c3a67 100644 --- a/tests/unit/mixins/test_activity_mixin.py +++ b/tests/unit/mixins/test_activity_mixin.py @@ -1,10 +1,11 @@ import pytest -from schema import Schema, Literal, Optional +from schema import Literal, Optional, Schema + from tests.mocks.mock_tool.tool import MockTool class TestActivityMixin: - @pytest.fixture + @pytest.fixture() def tool(self): return MockTool(test_field="hello", test_int=5) diff --git a/tests/unit/mixins/test_image_artifact_file_output_mixin.py b/tests/unit/mixins/test_image_artifact_file_output_mixin.py index 69a2f1d71..03c44e081 100644 --- a/tests/unit/mixins/test_image_artifact_file_output_mixin.py +++ b/tests/unit/mixins/test_image_artifact_file_output_mixin.py @@ -19,7 +19,7 @@ def test_output_file(self): artifact = ImageArtifact(name="test.png", value=b"test", height=1, width=1, format="png") class Test(BlobArtifactFileOutputMixin): - def run(self): + def run(self) -> None: self._write_to_file(artifact) outfile = os.path.join(tempfile.gettempdir(), artifact.name) @@ -34,7 +34,7 @@ def test_output_dir(self): artifact = ImageArtifact(name="test.png", value=b"test", height=1, width=1, format="png") class Test(BlobArtifactFileOutputMixin): - def run(self): + def run(self) -> None: self._write_to_file(artifact) outdir = tempfile.gettempdir() diff --git a/tests/unit/mixins/test_seriliazable_mixin.py b/tests/unit/mixins/test_seriliazable_mixin.py index 1704000e3..afb3d1eb4 100644 --- a/tests/unit/mixins/test_seriliazable_mixin.py +++ b/tests/unit/mixins/test_seriliazable_mixin.py @@ -1,11 +1,13 @@ import json + import pytest + +from griptape.artifacts import BaseArtifact, TextArtifact from griptape.drivers import OpenAiChatPromptDriver -from griptape.memory.structure import ConversationMemory from griptape.memory import TaskMemory -from tests.mocks.mock_serializable import MockSerializable +from griptape.memory.structure import ConversationMemory from griptape.schemas import BaseSchema -from griptape.artifacts import BaseArtifact, TextArtifact +from tests.mocks.mock_serializable import MockSerializable class TestSerializableMixin: diff --git a/tests/unit/schemas/test_base_schema.py b/tests/unit/schemas/test_base_schema.py index fcbd08c7f..f3a3f0c1f 100644 --- a/tests/unit/schemas/test_base_schema.py +++ b/tests/unit/schemas/test_base_schema.py @@ -1,13 +1,16 @@ from __future__ import annotations + from datetime import datetime +from typing import Literal, Optional, Union + import pytest -from typing import Union, Optional, Literal from marshmallow import fields + from griptape.artifacts import BaseArtifact, TextArtifact +from griptape.loaders import TextLoader from griptape.schemas import PolymorphicSchema -from griptape.schemas.bytes_field import Bytes from griptape.schemas.base_schema import BaseSchema -from griptape.loaders import TextLoader +from griptape.schemas.bytes_field import Bytes from tests.mocks.mock_serializable import MockSerializable @@ -62,8 +65,8 @@ def test_get_field_type_info(self): assert BaseSchema._get_field_type_info(list) == (list, (), False) - assert BaseSchema._get_field_type_info(Literal["foo"]) == (str, (), False) # pyright: ignore - assert BaseSchema._get_field_type_info(Literal[5]) == (int, (), False) # pyright: ignore + assert BaseSchema._get_field_type_info(Literal["foo"]) == (str, (), False) # pyright: ignore[reportArgumentType] + assert BaseSchema._get_field_type_info(Literal[5]) == (int, (), False) # pyright: ignore[reportArgumentType] def test_is_list_sequence(self): assert BaseSchema.is_list_sequence(list) diff --git a/tests/unit/structures/test_agent.py b/tests/unit/structures/test_agent.py index baceac825..a09ad0f9a 100644 --- a/tests/unit/structures/test_agent.py +++ b/tests/unit/structures/test_agent.py @@ -1,15 +1,15 @@ import pytest -from griptape.memory.structure import ConversationMemory + +from griptape.engines import PromptSummaryEngine from griptape.memory import TaskMemory +from griptape.memory.structure import ConversationMemory from griptape.memory.task.storage import TextArtifactStorage from griptape.rules import Rule, Ruleset from griptape.structures import Agent -from griptape.tasks import PromptTask, BaseTask, ToolkitTask -from griptape.engines import PromptSummaryEngine - +from griptape.tasks import BaseTask, PromptTask, ToolkitTask +from tests.mocks.mock_embedding_driver import MockEmbeddingDriver from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_tool.tool import MockTool -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver class TestAgent: @@ -49,8 +49,8 @@ def test_rules_and_rulesets(self): with pytest.raises(ValueError): Agent(rules=[Rule("foo test")], rulesets=[Ruleset("Bar", [Rule("bar test")])]) + agent = Agent() with pytest.raises(ValueError): - agent = Agent() agent.add_task(PromptTask(rules=[Rule("foo test")], rulesets=[Ruleset("Bar", [Rule("bar test")])])) def test_with_task_memory(self): @@ -149,13 +149,13 @@ def test_add_tasks(self): try: agent.add_tasks(first_task, second_task) - assert False + raise AssertionError() except ValueError: assert True try: agent + [first_task, second_task] - assert False + raise AssertionError() except ValueError: assert True diff --git a/tests/unit/structures/test_pipeline.py b/tests/unit/structures/test_pipeline.py index e63937a62..38f8abfb3 100644 --- a/tests/unit/structures/test_pipeline.py +++ b/tests/unit/structures/test_pipeline.py @@ -1,20 +1,21 @@ -import pytest import time -from griptape.artifacts import TextArtifact, ErrorArtifact +import pytest + +from griptape.artifacts import ErrorArtifact, TextArtifact +from griptape.memory.structure import ConversationMemory from griptape.memory.task.storage import TextArtifactStorage from griptape.rules import Rule, Ruleset +from griptape.structures import Pipeline +from griptape.tasks import BaseTask, CodeExecutionTask, PromptTask, ToolkitTask from griptape.tokenizers import OpenAiTokenizer -from griptape.tasks import PromptTask, BaseTask, ToolkitTask, CodeExecutionTask -from griptape.memory.structure import ConversationMemory from tests.mocks.mock_prompt_driver import MockPromptDriver -from griptape.structures import Pipeline from tests.mocks.mock_tool.tool import MockTool from tests.unit.structures.test_agent import MockEmbeddingDriver class TestPipeline: - @pytest.fixture + @pytest.fixture() def waiting_task(self): def fn(task): time.sleep(2) @@ -22,7 +23,7 @@ def fn(task): return CodeExecutionTask(run_fn=fn) - @pytest.fixture + @pytest.fixture() def error_artifact_task(self): def fn(task): return ErrorArtifact("error") @@ -76,8 +77,8 @@ def test_rules_and_rulesets(self): with pytest.raises(ValueError): Pipeline(rules=[Rule("foo test")], rulesets=[Ruleset("Bar", [Rule("bar test")])]) + pipeline = Pipeline() with pytest.raises(ValueError): - pipeline = Pipeline() pipeline.add_task(PromptTask(rules=[Rule("foo test")], rulesets=[Ruleset("Bar", [Rule("bar test")])])) def test_with_no_task_memory(self): diff --git a/tests/unit/structures/test_workflow.py b/tests/unit/structures/test_workflow.py index bf55e852f..2be164ea7 100644 --- a/tests/unit/structures/test_workflow.py +++ b/tests/unit/structures/test_workflow.py @@ -1,20 +1,20 @@ import time + import pytest -from pytest import fixture +from griptape.artifacts import ErrorArtifact, TextArtifact +from griptape.memory.structure import ConversationMemory from griptape.memory.task.storage import TextArtifactStorage -from tests.mocks.mock_prompt_driver import MockPromptDriver from griptape.rules import Rule, Ruleset -from griptape.tasks import PromptTask, BaseTask, ToolkitTask, CodeExecutionTask from griptape.structures import Workflow -from griptape.artifacts import ErrorArtifact, TextArtifact -from griptape.memory.structure import ConversationMemory -from tests.mocks.mock_tool.tool import MockTool +from griptape.tasks import BaseTask, CodeExecutionTask, PromptTask, ToolkitTask from tests.mocks.mock_embedding_driver import MockEmbeddingDriver +from tests.mocks.mock_prompt_driver import MockPromptDriver +from tests.mocks.mock_tool.tool import MockTool class TestWorkflow: - @fixture + @pytest.fixture() def waiting_task(self): def fn(task): time.sleep(2) @@ -22,7 +22,7 @@ def fn(task): return CodeExecutionTask(run_fn=fn) - @fixture + @pytest.fixture() def error_artifact_task(self): def fn(task): return ErrorArtifact("error") @@ -75,8 +75,8 @@ def test_rules_and_rulesets(self): with pytest.raises(ValueError): Workflow(rules=[Rule("foo test")], rulesets=[Ruleset("Bar", [Rule("bar test")])]) + workflow = Workflow() with pytest.raises(ValueError): - workflow = Workflow() workflow.add_task(PromptTask(rules=[Rule("foo test")], rulesets=[Ruleset("Bar", [Rule("bar test")])])) def test_with_no_task_memory(self): @@ -777,7 +777,7 @@ def test_run_with_error_artifact_no_fail_fast(self, error_artifact_task, waiting assert workflow.output is not None @staticmethod - def _validate_topology_1(workflow): + def _validate_topology_1(workflow) -> None: assert len(workflow.tasks) == 4 assert workflow.input_task.id == "task1" assert workflow.output_task.id == "task4" @@ -805,8 +805,8 @@ def _validate_topology_1(workflow): assert task4.child_ids == [] @staticmethod - def _validate_topology_2(workflow): - """Adapted from https://en.wikipedia.org/wiki/Directed_acyclic_graph#/media/File:Tred-G.svg""" + def _validate_topology_2(workflow) -> None: + """Adapted from https://en.wikipedia.org/wiki/Directed_acyclic_graph#/media/File:Tred-G.svg.""" assert len(workflow.tasks) == 5 assert workflow.input_task.id == "taska" assert workflow.output_task.id == "taske" @@ -839,7 +839,7 @@ def _validate_topology_2(workflow): assert taske.child_ids == [] @staticmethod - def _validate_topology_3(workflow): + def _validate_topology_3(workflow) -> None: assert len(workflow.tasks) == 4 assert workflow.input_task.id == "task1" assert workflow.output_task.id == "task3" @@ -867,7 +867,7 @@ def _validate_topology_3(workflow): assert task4.child_ids == ["task2"] @staticmethod - def _validate_topology_4(workflow): + def _validate_topology_4(workflow) -> None: assert len(workflow.tasks) == 9 assert workflow.input_task.id == "collect_movie_info" assert workflow.output_task.id == "summarize_to_slack" diff --git a/tests/unit/tasks/test_actions_subtask.py b/tests/unit/tasks/test_actions_subtask.py index c6e5ca038..e25a42120 100644 --- a/tests/unit/tasks/test_actions_subtask.py +++ b/tests/unit/tasks/test_actions_subtask.py @@ -1,10 +1,11 @@ import json -from griptape.artifacts import ListArtifact, TextArtifact, ActionArtifact + +from griptape.artifacts import ActionArtifact, ListArtifact, TextArtifact from griptape.artifacts.error_artifact import ErrorArtifact -from tests.mocks.mock_tool.tool import MockTool -from griptape.tasks import ToolkitTask, ActionsSubtask -from griptape.structures import Agent from griptape.common import ToolAction +from griptape.structures import Agent +from griptape.tasks import ActionsSubtask, ToolkitTask +from tests.mocks.mock_tool.tool import MockTool class TestActionsSubtask: diff --git a/tests/unit/tasks/test_audio_transcription_task.py b/tests/unit/tasks/test_audio_transcription_task.py index 3a53fd49d..f4bc0e8b8 100644 --- a/tests/unit/tasks/test_audio_transcription_task.py +++ b/tests/unit/tasks/test_audio_transcription_task.py @@ -5,17 +5,17 @@ from griptape.artifacts import AudioArtifact, TextArtifact from griptape.engines import AudioTranscriptionEngine from griptape.structures import Agent, Pipeline -from griptape.tasks import BaseTask, AudioTranscriptionTask +from griptape.tasks import AudioTranscriptionTask, BaseTask from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_structure_config import MockStructureConfig class TestAudioTranscriptionTask: - @pytest.fixture + @pytest.fixture() def audio_artifact(self): return AudioArtifact(value=b"audio data", format="mp3") - @pytest.fixture + @pytest.fixture() def audio_transcription_engine(self): return Mock() @@ -40,14 +40,9 @@ def test_config_audio_transcription_engine(self, audio_artifact): def test_run(self, audio_artifact, audio_transcription_engine): audio_transcription_engine.run.return_value = TextArtifact("mock transcription") - logger = Mock() task = AudioTranscriptionTask(audio_artifact, audio_transcription_engine=audio_transcription_engine) - pipeline = Pipeline(prompt_driver=MockPromptDriver(), logger=logger) + pipeline = Pipeline(prompt_driver=MockPromptDriver()) pipeline.add_task(task) assert pipeline.run().output.to_text() == "mock transcription" - - def test_before_run(self, audio_artifact, audio_transcription_engine): - task = AudioTranscriptionTask(audio_artifact, audio_transcription_engine=audio_transcription_engine) - task diff --git a/tests/unit/tasks/test_base_audio_input_task.py b/tests/unit/tasks/test_base_audio_input_task.py index e11074880..e16c6536c 100644 --- a/tests/unit/tasks/test_base_audio_input_task.py +++ b/tests/unit/tasks/test_base_audio_input_task.py @@ -1,12 +1,12 @@ import pytest -from tests.mocks.mock_audio_input_task import MockAudioInputTask from griptape.artifacts import AudioArtifact, TextArtifact +from tests.mocks.mock_audio_input_task import MockAudioInputTask from tests.mocks.mock_text_input_task import MockTextInputTask class TestBaseAudioInputTask: - @pytest.fixture + @pytest.fixture() def audio_artifact(self): return AudioArtifact(b"audio content", format="mp3") diff --git a/tests/unit/tasks/test_base_multi_text_input_task.py b/tests/unit/tasks/test_base_multi_text_input_task.py index ad4776aee..3d8d67a55 100644 --- a/tests/unit/tasks/test_base_multi_text_input_task.py +++ b/tests/unit/tasks/test_base_multi_text_input_task.py @@ -1,7 +1,7 @@ -from tests.mocks.mock_prompt_driver import MockPromptDriver -from griptape.structures import Pipeline from griptape.artifacts import TextArtifact +from griptape.structures import Pipeline from tests.mocks.mock_multi_text_input_task import MockMultiTextInputTask +from tests.mocks.mock_prompt_driver import MockPromptDriver class TestBaseMultiTextInputTask: diff --git a/tests/unit/tasks/test_base_task.py b/tests/unit/tasks/test_base_task.py index 7fe2810f5..7a1afcf07 100644 --- a/tests/unit/tasks/test_base_task.py +++ b/tests/unit/tasks/test_base_task.py @@ -1,9 +1,8 @@ import pytest from griptape.artifacts import TextArtifact -from griptape.structures import Agent +from griptape.structures import Agent, Workflow from griptape.tasks import ActionsSubtask -from griptape.structures import Workflow from tests.mocks.mock_embedding_driver import MockEmbeddingDriver from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_task import MockTask @@ -11,7 +10,7 @@ class TestBaseTask: - @pytest.fixture + @pytest.fixture() def task(self): agent = Agent(prompt_driver=MockPromptDriver(), embedding_driver=MockEmbeddingDriver(), tools=[MockTool()]) diff --git a/tests/unit/tasks/test_base_text_input_task.py b/tests/unit/tasks/test_base_text_input_task.py index 14c3c3f2e..86dc98805 100644 --- a/tests/unit/tasks/test_base_text_input_task.py +++ b/tests/unit/tasks/test_base_text_input_task.py @@ -1,7 +1,7 @@ -from tests.mocks.mock_prompt_driver import MockPromptDriver -from griptape.structures import Pipeline from griptape.artifacts import TextArtifact -from griptape.rules import Ruleset, Rule +from griptape.rules import Rule, Ruleset +from griptape.structures import Pipeline +from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_text_input_task import MockTextInputTask diff --git a/tests/unit/tasks/test_code_execution_task.py b/tests/unit/tasks/test_code_execution_task.py index b94714912..3178e29db 100644 --- a/tests/unit/tasks/test_code_execution_task.py +++ b/tests/unit/tasks/test_code_execution_task.py @@ -1,4 +1,4 @@ -from griptape.artifacts import BaseArtifact, TextArtifact, ErrorArtifact +from griptape.artifacts import BaseArtifact, ErrorArtifact, TextArtifact from griptape.structures import Pipeline from griptape.tasks import CodeExecutionTask from tests.mocks.mock_prompt_driver import MockPromptDriver diff --git a/tests/unit/tasks/test_csv_extraction_task.py b/tests/unit/tasks/test_csv_extraction_task.py index 9b8fb15bb..7d37c3897 100644 --- a/tests/unit/tasks/test_csv_extraction_task.py +++ b/tests/unit/tasks/test_csv_extraction_task.py @@ -8,7 +8,7 @@ class TestCsvExtractionTask: - @pytest.fixture + @pytest.fixture() def task(self): return CsvExtractionTask(args={"column_names": ["test1"]}) @@ -30,4 +30,4 @@ def test_config_extraction_engine(self, task): def test_missing_extraction_engine(self, task): with pytest.raises(ValueError): - task.extraction_engine + task.extraction_engine # noqa: B018 diff --git a/tests/unit/tasks/test_extraction_task.py b/tests/unit/tasks/test_extraction_task.py index 5e5ec09f6..afa73a506 100644 --- a/tests/unit/tasks/test_extraction_task.py +++ b/tests/unit/tasks/test_extraction_task.py @@ -1,4 +1,5 @@ import pytest + from griptape.engines import CsvExtractionEngine from griptape.structures import Agent from griptape.tasks import ExtractionTask @@ -6,7 +7,7 @@ class TestExtractionTask: - @pytest.fixture + @pytest.fixture() def task(self): return ExtractionTask( extraction_engine=CsvExtractionEngine(prompt_driver=MockPromptDriver()), args={"column_names": ["test1"]} diff --git a/tests/unit/tasks/test_image_query_task.py b/tests/unit/tasks/test_image_query_task.py index dd4940213..549009fc1 100644 --- a/tests/unit/tasks/test_image_query_task.py +++ b/tests/unit/tasks/test_image_query_task.py @@ -12,18 +12,18 @@ class TestImageQueryTask: - @pytest.fixture + @pytest.fixture() def image_query_engine(self) -> Mock: mock = Mock() mock.run.return_value = TextArtifact("image") return mock - @pytest.fixture + @pytest.fixture() def text_artifact(self): return TextArtifact(value="some text") - @pytest.fixture + @pytest.fixture() def image_artifact(self): return ImageArtifact(value=b"some image data", format="png", width=512, height=512) @@ -70,7 +70,7 @@ def test_missing_image_generation_engine(self, text_artifact, image_artifact): task = ImageQueryTask((text_artifact, [image_artifact, image_artifact])) with pytest.raises(ValueError, match="Image Query Engine"): - task.image_query_engine + task.image_query_engine # noqa: B018 def test_run(self, image_query_engine, text_artifact, image_artifact): task = ImageQueryTask((text_artifact, [image_artifact, image_artifact]), image_query_engine=image_query_engine) diff --git a/tests/unit/tasks/test_inpainting_image_generation_task.py b/tests/unit/tasks/test_inpainting_image_generation_task.py index 9dc6aff54..ff287d79a 100644 --- a/tests/unit/tasks/test_inpainting_image_generation_task.py +++ b/tests/unit/tasks/test_inpainting_image_generation_task.py @@ -1,21 +1,22 @@ -from griptape.artifacts.list_artifact import ListArtifact -from griptape.engines import InpaintingImageGenerationEngine from unittest.mock import Mock import pytest -from griptape.tasks import BaseTask, InpaintingImageGenerationTask -from griptape.artifacts import TextArtifact, ImageArtifact + +from griptape.artifacts import ImageArtifact, TextArtifact +from griptape.artifacts.list_artifact import ListArtifact +from griptape.engines import InpaintingImageGenerationEngine from griptape.structures import Agent +from griptape.tasks import BaseTask, InpaintingImageGenerationTask from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver from tests.mocks.mock_structure_config import MockStructureConfig class TestInpaintingImageGenerationTask: - @pytest.fixture + @pytest.fixture() def text_artifact(self): return TextArtifact(value="some text") - @pytest.fixture + @pytest.fixture() def image_artifact(self): return ImageArtifact(value=b"some image data", format="png", width=512, height=512) @@ -59,4 +60,4 @@ def test_missing_image_generation_engine(self, text_artifact, image_artifact): task = InpaintingImageGenerationTask((text_artifact, image_artifact, image_artifact)) with pytest.raises(ValueError): - task.image_generation_engine + task.image_generation_engine # noqa: B018 diff --git a/tests/unit/tasks/test_json_extraction_task.py b/tests/unit/tasks/test_json_extraction_task.py index 0366652b0..ba7d1ce30 100644 --- a/tests/unit/tasks/test_json_extraction_task.py +++ b/tests/unit/tasks/test_json_extraction_task.py @@ -1,14 +1,15 @@ -from griptape.engines import JsonExtractionEngine import pytest from schema import Schema + +from griptape.engines import JsonExtractionEngine from griptape.structures import Agent from griptape.tasks import JsonExtractionTask -from tests.mocks.mock_structure_config import MockStructureConfig from tests.mocks.mock_prompt_driver import MockPromptDriver +from tests.mocks.mock_structure_config import MockStructureConfig class TestJsonExtractionTask: - @pytest.fixture + @pytest.fixture() def task(self): return JsonExtractionTask("foo", args={"template_schema": Schema({"foo": "bar"}).json_schema("TemplateSchema")}) @@ -34,4 +35,4 @@ def test_config_extraction_engine(self, task): def test_missing_extraction_engine(self, task): with pytest.raises(ValueError): - task.extraction_engine + task.extraction_engine # noqa: B018 diff --git a/tests/unit/tasks/test_outpainting_image_generation_task.py b/tests/unit/tasks/test_outpainting_image_generation_task.py index 148ea133d..7de530330 100644 --- a/tests/unit/tasks/test_outpainting_image_generation_task.py +++ b/tests/unit/tasks/test_outpainting_image_generation_task.py @@ -1,10 +1,10 @@ -from griptape.artifacts.list_artifact import ListArtifact -from griptape.engines import OutpaintingImageGenerationEngine from unittest.mock import Mock import pytest from griptape.artifacts import ImageArtifact, TextArtifact +from griptape.artifacts.list_artifact import ListArtifact +from griptape.engines import OutpaintingImageGenerationEngine from griptape.structures import Agent from griptape.tasks import BaseTask, OutpaintingImageGenerationTask from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver @@ -12,11 +12,11 @@ class TestOutpaintingImageGenerationTask: - @pytest.fixture + @pytest.fixture() def text_artifact(self): return TextArtifact(value="some text") - @pytest.fixture + @pytest.fixture() def image_artifact(self): return ImageArtifact(value=b"some image data", format="png", width=512, height=512) @@ -60,4 +60,4 @@ def test_missing_image_generation_engine(self, text_artifact, image_artifact): task = OutpaintingImageGenerationTask((text_artifact, image_artifact, image_artifact)) with pytest.raises(ValueError): - task.image_generation_engine + task.image_generation_engine # noqa: B018 diff --git a/tests/unit/tasks/test_prompt_image_generation_task.py b/tests/unit/tasks/test_prompt_image_generation_task.py index 4f6117c07..c3add5720 100644 --- a/tests/unit/tasks/test_prompt_image_generation_task.py +++ b/tests/unit/tasks/test_prompt_image_generation_task.py @@ -1,4 +1,3 @@ -from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver from unittest.mock import Mock import pytest @@ -7,6 +6,7 @@ from griptape.engines import PromptImageGenerationEngine from griptape.structures import Agent from griptape.tasks import BaseTask, PromptImageGenerationTask +from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver from tests.mocks.mock_structure_config import MockStructureConfig @@ -37,4 +37,4 @@ def test_missing_summary_engine(self): task = PromptImageGenerationTask("foo bar") with pytest.raises(ValueError): - task.image_generation_engine + task.image_generation_engine # noqa: B018 diff --git a/tests/unit/tasks/test_prompt_task.py b/tests/unit/tasks/test_prompt_task.py index c76f0284b..083ea6da5 100644 --- a/tests/unit/tasks/test_prompt_task.py +++ b/tests/unit/tasks/test_prompt_task.py @@ -1,14 +1,15 @@ import pytest + from griptape.artifacts.image_artifact import ImageArtifact from griptape.artifacts.list_artifact import ListArtifact from griptape.artifacts.text_artifact import TextArtifact from griptape.memory.structure import ConversationMemory from griptape.memory.structure.run import Run -from tests.mocks.mock_structure_config import MockStructureConfig -from griptape.tasks import PromptTask from griptape.rules import Rule -from tests.mocks.mock_prompt_driver import MockPromptDriver from griptape.structures import Pipeline +from griptape.tasks import PromptTask +from tests.mocks.mock_prompt_driver import MockPromptDriver +from tests.mocks.mock_structure_config import MockStructureConfig class TestPromptTask: @@ -37,7 +38,7 @@ def test_missing_prompt_driver(self): task = PromptTask("test") with pytest.raises(ValueError): - task.prompt_driver + task.prompt_driver # noqa: B018 def test_input(self): # Str diff --git a/tests/unit/tasks/test_rag_task.py b/tests/unit/tasks/test_rag_task.py index c9b82f208..b205d385a 100644 --- a/tests/unit/tasks/test_rag_task.py +++ b/tests/unit/tasks/test_rag_task.py @@ -1,4 +1,5 @@ import pytest + from griptape.engines.rag import RagEngine from griptape.engines.rag.modules import PromptResponseRagModule from griptape.engines.rag.stages import ResponseRagStage @@ -8,7 +9,7 @@ class TestRagTask: - @pytest.fixture + @pytest.fixture() def task(self): return RagTask( input="test", diff --git a/tests/unit/tasks/test_structure_run_task.py b/tests/unit/tasks/test_structure_run_task.py index d89e98c91..1053ade9e 100644 --- a/tests/unit/tasks/test_structure_run_task.py +++ b/tests/unit/tasks/test_structure_run_task.py @@ -1,8 +1,7 @@ +from griptape.drivers import LocalStructureRunDriver +from griptape.structures import Agent, Pipeline from griptape.tasks import StructureRunTask -from griptape.structures import Agent from tests.mocks.mock_prompt_driver import MockPromptDriver -from griptape.drivers import LocalStructureRunDriver -from griptape.structures import Pipeline class TestStructureRunTask: diff --git a/tests/unit/tasks/test_text_summary_task.py b/tests/unit/tasks/test_text_summary_task.py index d7a474373..bb08f9d31 100644 --- a/tests/unit/tasks/test_text_summary_task.py +++ b/tests/unit/tasks/test_text_summary_task.py @@ -1,9 +1,10 @@ import pytest -from tests.mocks.mock_structure_config import MockStructureConfig + from griptape.engines import PromptSummaryEngine +from griptape.structures import Agent from griptape.tasks import TextSummaryTask from tests.mocks.mock_prompt_driver import MockPromptDriver -from griptape.structures import Agent +from tests.mocks.mock_structure_config import MockStructureConfig class TestTextSummaryTask: @@ -34,4 +35,4 @@ def test_missing_summary_engine(self): task = TextSummaryTask("test") with pytest.raises(ValueError): - task.summary_engine + task.summary_engine # noqa: B018 diff --git a/tests/unit/tasks/test_text_to_speech_task.py b/tests/unit/tasks/test_text_to_speech_task.py index 7a8e49364..86bc1d2ce 100644 --- a/tests/unit/tasks/test_text_to_speech_task.py +++ b/tests/unit/tasks/test_text_to_speech_task.py @@ -1,6 +1,6 @@ from unittest.mock import Mock -from griptape.artifacts import TextArtifact, AudioArtifact +from griptape.artifacts import AudioArtifact, TextArtifact from griptape.engines import TextToSpeechEngine from griptape.structures import Agent, Pipeline from griptape.tasks import BaseTask, TextToSpeechTask diff --git a/tests/unit/tasks/test_tool_task.py b/tests/unit/tasks/test_tool_task.py index 2af8b73c3..dfc679919 100644 --- a/tests/unit/tasks/test_tool_task.py +++ b/tests/unit/tasks/test_tool_task.py @@ -1,8 +1,10 @@ import json + import pytest + from griptape.artifacts import TextArtifact from griptape.structures import Agent -from griptape.tasks import ToolTask, ActionsSubtask +from griptape.tasks import ActionsSubtask, ToolTask from tests.mocks.mock_embedding_driver import MockEmbeddingDriver from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_tool.tool import MockTool @@ -163,7 +165,7 @@ class TestToolTask: "$schema": "http://json-schema.org/draft-07/schema#", } - @pytest.fixture + @pytest.fixture() def agent(self): output_dict = {"tag": "foo", "name": "MockTool", "path": "test", "input": {"values": {"test": "foobar"}}} diff --git a/tests/unit/tasks/test_toolkit_task.py b/tests/unit/tasks/test_toolkit_task.py index 2217ba70c..cd5dd21f8 100644 --- a/tests/unit/tasks/test_toolkit_task.py +++ b/tests/unit/tasks/test_toolkit_task.py @@ -1,9 +1,9 @@ from griptape.artifacts import ErrorArtifact, TextArtifact -from griptape.structures import Agent -from griptape.tasks import ToolkitTask, ActionsSubtask, PromptTask from griptape.common import ToolAction -from tests.mocks.mock_tool.tool import MockTool +from griptape.structures import Agent +from griptape.tasks import ActionsSubtask, PromptTask, ToolkitTask from tests.mocks.mock_prompt_driver import MockPromptDriver +from tests.mocks.mock_tool.tool import MockTool from tests.utils import defaults @@ -166,7 +166,7 @@ def test_init(self): try: ToolkitTask("test", tools=[MockTool(), MockTool()]) - assert False + raise AssertionError() except ValueError: assert True @@ -231,7 +231,7 @@ def test_init_from_prompt_1(self): assert subtask.output is None def test_init_from_prompt_2(self): - valid_input = """Thought: need to test\nObservation: test + valid_input = """Thought: need to test\nObservation: test observation\nAnswer: test output""" task = ToolkitTask("test", tools=[MockTool(name="Tool1")]) diff --git a/tests/unit/tasks/test_variation_image_generation_task.py b/tests/unit/tasks/test_variation_image_generation_task.py index 6a9533da3..3f865931f 100644 --- a/tests/unit/tasks/test_variation_image_generation_task.py +++ b/tests/unit/tasks/test_variation_image_generation_task.py @@ -1,21 +1,22 @@ -from griptape.artifacts.list_artifact import ListArtifact -from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver -from tests.mocks.mock_structure_config import MockStructureConfig from unittest.mock import Mock import pytest -from griptape.tasks import BaseTask, VariationImageGenerationTask -from griptape.artifacts import TextArtifact, ImageArtifact + +from griptape.artifacts import ImageArtifact, TextArtifact +from griptape.artifacts.list_artifact import ListArtifact from griptape.engines import VariationImageGenerationEngine from griptape.structures import Agent +from griptape.tasks import BaseTask, VariationImageGenerationTask +from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver +from tests.mocks.mock_structure_config import MockStructureConfig class TestVariationImageGenerationTask: - @pytest.fixture + @pytest.fixture() def text_artifact(self): return TextArtifact(value="some text") - @pytest.fixture + @pytest.fixture() def image_artifact(self): return ImageArtifact(value=b"some image data", format="png", width=512, height=512) @@ -56,4 +57,4 @@ def test_missing_summary_engine(self, text_artifact, image_artifact): task = VariationImageGenerationTask((text_artifact, image_artifact)) with pytest.raises(ValueError): - task.image_generation_engine + task.image_generation_engine # noqa: B018 diff --git a/tests/unit/tokenizers/test_amazon_bedrock_tokenizer.py b/tests/unit/tokenizers/test_amazon_bedrock_tokenizer.py index 2b77ba3dc..bb928c1c3 100644 --- a/tests/unit/tokenizers/test_amazon_bedrock_tokenizer.py +++ b/tests/unit/tokenizers/test_amazon_bedrock_tokenizer.py @@ -1,14 +1,15 @@ -from griptape.tokenizers import AmazonBedrockTokenizer import pytest +from griptape.tokenizers import AmazonBedrockTokenizer + class TestAmazonBedrockTokenizer: - @pytest.fixture + @pytest.fixture() def tokenizer(self, request): return AmazonBedrockTokenizer(model=request.param) @pytest.mark.parametrize( - "tokenizer,expected", + ("tokenizer", "expected"), [ ("anthropic.claude-v2:1", 4), ("anthropic.claude-v2", 4), @@ -21,7 +22,7 @@ def test_token_count(self, tokenizer, expected): assert tokenizer.count_tokens("foo bar huzzah") == expected @pytest.mark.parametrize( - "tokenizer,expected", + ("tokenizer", "expected"), [ ("anthropic.claude-v2", 99996), ("anthropic.claude-v2:1", 199996), @@ -34,7 +35,7 @@ def test_input_tokens_left(self, tokenizer, expected): assert tokenizer.count_input_tokens_left("foo bar huzzah") == expected @pytest.mark.parametrize( - "tokenizer,expected", + ("tokenizer", "expected"), [ ("anthropic.claude-v2", 4092), ("anthropic.claude-v2:1", 4092), diff --git a/tests/unit/tokenizers/test_anthropic_tokenizer.py b/tests/unit/tokenizers/test_anthropic_tokenizer.py index ee165d270..859ba9684 100644 --- a/tests/unit/tokenizers/test_anthropic_tokenizer.py +++ b/tests/unit/tokenizers/test_anthropic_tokenizer.py @@ -1,14 +1,15 @@ import pytest + from griptape.tokenizers import AnthropicTokenizer class TestAnthropicTokenizer: - @pytest.fixture + @pytest.fixture() def tokenizer(self, request): return AnthropicTokenizer(model=request.param) @pytest.mark.parametrize( - "tokenizer,expected", + ("tokenizer", "expected"), [("claude-2.1", 5), ("claude-2.0", 5), ("claude-3-opus", 5), ("claude-3-sonnet", 5), ("claude-3-haiku", 5)], indirect=["tokenizer"], ) @@ -16,7 +17,7 @@ def test_token_count(self, tokenizer, expected): assert tokenizer.count_tokens("foo bar huzzah") == expected @pytest.mark.parametrize( - "tokenizer,expected", + ("tokenizer", "expected"), [ ("claude-2.0", 99995), ("claude-2.1", 199995), @@ -30,7 +31,7 @@ def test_input_tokens_left(self, tokenizer, expected): assert tokenizer.count_input_tokens_left("foo bar huzzah") == expected @pytest.mark.parametrize( - "tokenizer,expected", + ("tokenizer", "expected"), [ ("claude-2.0", 4091), ("claude-2.1", 4091), diff --git a/tests/unit/tokenizers/test_base_tokenizer.py b/tests/unit/tokenizers/test_base_tokenizer.py index eed15b9b2..08fd42c72 100644 --- a/tests/unit/tokenizers/test_base_tokenizer.py +++ b/tests/unit/tokenizers/test_base_tokenizer.py @@ -1,4 +1,5 @@ import logging + from tests.mocks.mock_tokenizer import MockTokenizer diff --git a/tests/unit/tokenizers/test_cohere_tokenizer.py b/tests/unit/tokenizers/test_cohere_tokenizer.py index 9ca23f4f0..3999ec399 100644 --- a/tests/unit/tokenizers/test_cohere_tokenizer.py +++ b/tests/unit/tokenizers/test_cohere_tokenizer.py @@ -1,5 +1,6 @@ import cohere import pytest + from griptape.tokenizers import CohereTokenizer @@ -10,7 +11,7 @@ def mock_client(self, mocker): return mock_client - @pytest.fixture + @pytest.fixture() def tokenizer(self): return CohereTokenizer(model="command", client=cohere.Client("foobar")) diff --git a/tests/unit/tokenizers/test_dummy_tokenizer.py b/tests/unit/tokenizers/test_dummy_tokenizer.py index 855fb6eee..5d770d1aa 100644 --- a/tests/unit/tokenizers/test_dummy_tokenizer.py +++ b/tests/unit/tokenizers/test_dummy_tokenizer.py @@ -1,10 +1,11 @@ import pytest + from griptape.exceptions import DummyException from griptape.tokenizers import DummyTokenizer class TestDummyTokenizer: - @pytest.fixture + @pytest.fixture() def tokenizer(self): return DummyTokenizer() diff --git a/tests/unit/tokenizers/test_google_tokenizer.py b/tests/unit/tokenizers/test_google_tokenizer.py index 34510cdac..41012f5d6 100644 --- a/tests/unit/tokenizers/test_google_tokenizer.py +++ b/tests/unit/tokenizers/test_google_tokenizer.py @@ -1,5 +1,7 @@ -import pytest from unittest.mock import Mock + +import pytest + from griptape.common import PromptStack from griptape.common.prompt_stack.messages.message import Message from griptape.tokenizers import GoogleTokenizer @@ -13,22 +15,22 @@ def mock_generative_model(self, mocker): return mock_generative_model - @pytest.fixture + @pytest.fixture() def tokenizer(self, request): return GoogleTokenizer(model=request.param, api_key="1234") - @pytest.mark.parametrize("tokenizer,expected", [("gemini-pro", 5)], indirect=["tokenizer"]) + @pytest.mark.parametrize(("tokenizer", "expected"), [("gemini-pro", 5)], indirect=["tokenizer"]) def test_token_count(self, tokenizer, expected): assert tokenizer.count_tokens("foo bar huzzah") == expected assert tokenizer.count_tokens(PromptStack(messages=[Message(content="foo", role="user")])) == expected assert tokenizer.count_tokens(["foo", "bar", "huzzah"]) == expected - @pytest.mark.parametrize("tokenizer,expected", [("gemini-pro", 30715)], indirect=["tokenizer"]) + @pytest.mark.parametrize(("tokenizer", "expected"), [("gemini-pro", 30715)], indirect=["tokenizer"]) def test_input_tokens_left(self, tokenizer, expected): assert tokenizer.count_input_tokens_left("foo bar huzzah") == expected assert tokenizer.count_input_tokens_left(["foo", "bar", "huzzah"]) == expected - @pytest.mark.parametrize("tokenizer,expected", [("gemini-pro", 2043)], indirect=["tokenizer"]) + @pytest.mark.parametrize(("tokenizer", "expected"), [("gemini-pro", 2043)], indirect=["tokenizer"]) def test_output_tokens_left(self, tokenizer, expected): assert tokenizer.count_output_tokens_left("foo bar huzzah") == expected assert tokenizer.count_output_tokens_left(["foo", "bar", "huzzah"]) == expected diff --git a/tests/unit/tokenizers/test_hugging_face_tokenizer.py b/tests/unit/tokenizers/test_hugging_face_tokenizer.py index dcb309a84..e717140e3 100644 --- a/tests/unit/tokenizers/test_hugging_face_tokenizer.py +++ b/tests/unit/tokenizers/test_hugging_face_tokenizer.py @@ -3,11 +3,12 @@ environ["TRANSFORMERS_VERBOSITY"] = "error" import pytest # noqa: E402 + from griptape.tokenizers import HuggingFaceTokenizer # noqa: E402 class TestHuggingFaceTokenizer: - @pytest.fixture + @pytest.fixture() def tokenizer(self): return HuggingFaceTokenizer(model="gpt2", max_output_tokens=1024) diff --git a/tests/unit/tokenizers/test_openai_tokenizer.py b/tests/unit/tokenizers/test_openai_tokenizer.py index 4aa42a87a..697184546 100644 --- a/tests/unit/tokenizers/test_openai_tokenizer.py +++ b/tests/unit/tokenizers/test_openai_tokenizer.py @@ -1,14 +1,15 @@ import pytest + from griptape.tokenizers import OpenAiTokenizer class TestOpenAiTokenizer: - @pytest.fixture + @pytest.fixture() def tokenizer(self, request): return OpenAiTokenizer(model=request.param) @pytest.mark.parametrize( - "tokenizer,expected", + ("tokenizer", "expected"), [ ("gpt-4-1106", 5), ("gpt-4-32k", 5), @@ -34,7 +35,7 @@ def test_initialize_with_unknown_model(self): assert tokenizer.max_input_tokens == OpenAiTokenizer.DEFAULT_MAX_TOKENS - OpenAiTokenizer.TOKEN_OFFSET @pytest.mark.parametrize( - "tokenizer,expected", + ("tokenizer", "expected"), [ ("gpt-4-1106", 19), ("gpt-4-32k", 19), @@ -45,7 +46,6 @@ def test_initialize_with_unknown_model(self): ("gpt-3.5-turbo", 19), ("gpt-35-turbo-16k", 19), ("gpt-35-turbo", 19), - ("gpt-35-turbo", 19), ], indirect=["tokenizer"], ) @@ -57,7 +57,7 @@ def test_token_count_for_messages(self, tokenizer, expected): == expected ) - @pytest.mark.parametrize("tokenizer,expected", [("not-real-model", 19)], indirect=["tokenizer"]) + @pytest.mark.parametrize(("tokenizer", "expected"), [("not-real-model", 19)], indirect=["tokenizer"]) def test_token_count_for_messages_unknown_model(self, tokenizer, expected): with pytest.raises(NotImplementedError): tokenizer.count_tokens( @@ -65,7 +65,7 @@ def test_token_count_for_messages_unknown_model(self, tokenizer, expected): ) @pytest.mark.parametrize( - "tokenizer,expected", + ("tokenizer", "expected"), [ ("gpt-4-1106", 127987), ("gpt-4o", 127987), @@ -86,7 +86,7 @@ def test_input_tokens_left(self, tokenizer, expected): assert tokenizer.count_input_tokens_left("foo bar huzzah") == expected @pytest.mark.parametrize( - "tokenizer,expected", + ("tokenizer", "expected"), [ ("gpt-4-1106", 4091), ("gpt-4-32k", 4091), diff --git a/tests/unit/tokenizers/test_simple_tokenizer.py b/tests/unit/tokenizers/test_simple_tokenizer.py index a34c0d481..d2598067a 100644 --- a/tests/unit/tokenizers/test_simple_tokenizer.py +++ b/tests/unit/tokenizers/test_simple_tokenizer.py @@ -1,9 +1,10 @@ import pytest + from griptape.tokenizers import SimpleTokenizer class TestSimpleTokenizer: - @pytest.fixture + @pytest.fixture() def tokenizer(self): return SimpleTokenizer(max_input_tokens=1024, max_output_tokens=4096, characters_per_token=6) diff --git a/tests/unit/tokenizers/test_voyageai_tokenizer.py b/tests/unit/tokenizers/test_voyageai_tokenizer.py index 46c9490f7..6f631ae4e 100644 --- a/tests/unit/tokenizers/test_voyageai_tokenizer.py +++ b/tests/unit/tokenizers/test_voyageai_tokenizer.py @@ -1,4 +1,5 @@ import pytest + from griptape.tokenizers import VoyageAiTokenizer @@ -10,12 +11,12 @@ def mock_client(self, mocker): return mock_client - @pytest.fixture + @pytest.fixture() def tokenizer(self, request): return VoyageAiTokenizer(model=request.param) @pytest.mark.parametrize( - "tokenizer,expected", + ("tokenizer", "expected"), [("voyage-large-2", 5), ("voyage-code-2", 5), ("voyage-2", 5), ("voyage-lite-02-instruct", 5)], indirect=["tokenizer"], ) @@ -23,7 +24,7 @@ def test_token_count(self, tokenizer, expected): assert tokenizer.count_tokens("foo bar huzzah") == expected @pytest.mark.parametrize( - "tokenizer,expected", + ("tokenizer", "expected"), [("voyage-large-2", 15995), ("voyage-code-2", 15995), ("voyage-2", 3995), ("voyage-lite-02-instruct", 3995)], indirect=["tokenizer"], ) @@ -31,7 +32,7 @@ def test_input_tokens_left(self, tokenizer, expected): assert tokenizer.count_input_tokens_left("foo bar huzzah") == expected @pytest.mark.parametrize( - "tokenizer,expected", + ("tokenizer", "expected"), [("voyage-large-2", 0), ("voyage-code-2", 0), ("voyage-2", 0), ("voyage-lite-02-instruct", 0)], indirect=["tokenizer"], ) diff --git a/tests/unit/tools/test_aws_iam.py b/tests/unit/tools/test_aws_iam.py index 2a256425e..54dbaa5fb 100644 --- a/tests/unit/tools/test_aws_iam.py +++ b/tests/unit/tools/test_aws_iam.py @@ -1,12 +1,13 @@ -from pytest import fixture +import boto3 +import pytest + from griptape.tools import AwsIamClient from tests.utils.aws import mock_aws_credentials -import boto3 class TestAwsIamClient: - @fixture(autouse=True) - def run_before_and_after_tests(self): + @pytest.fixture(autouse=True) + def _run_before_and_after_tests(self): mock_aws_credentials() def test_get_user_policy(self): diff --git a/tests/unit/tools/test_aws_s3.py b/tests/unit/tools/test_aws_s3.py index 6fe62b71c..5c6a4c151 100644 --- a/tests/unit/tools/test_aws_s3.py +++ b/tests/unit/tools/test_aws_s3.py @@ -1,12 +1,13 @@ -from pytest import fixture +import boto3 +import pytest + from griptape.tools import AwsS3Client from tests.utils.aws import mock_aws_credentials -import boto3 class TestAwsS3Client: - @fixture(autouse=True) - def run_before_and_after_tests(self): + @pytest.fixture(autouse=True) + def _run_before_and_after_tests(self): mock_aws_credentials() def test_get_bucket_acl(self): diff --git a/tests/unit/tools/test_base_tool.py b/tests/unit/tools/test_base_tool.py index 75154b509..a4d4097c1 100644 --- a/tests/unit/tools/test_base_tool.py +++ b/tests/unit/tools/test_base_tool.py @@ -1,10 +1,12 @@ import inspect import os + import pytest import yaml -from schema import SchemaMissingKeyError, Schema, Or -from griptape.tasks import ActionsSubtask, ToolkitTask +from schema import Or, Schema, SchemaMissingKeyError + from griptape.common import ToolAction +from griptape.tasks import ActionsSubtask, ToolkitTask from tests.mocks.mock_tool.tool import MockTool from tests.utils import defaults @@ -153,7 +155,7 @@ class TestBaseTool: "$schema": "http://json-schema.org/draft-07/schema#", } - @pytest.fixture + @pytest.fixture() def tool(self): return MockTool(test_field="hello", test_int=5, test_dict={"foo": "bar"}) @@ -195,9 +197,9 @@ def test_validate(self, tool): def test_invalid_config(self): try: - from tests.mocks.invalid_mock_tool.tool import InvalidMockTool # noqa + from tests.mocks.invalid_mock_tool.tool import InvalidMockTool # noqa: F401 - assert False + raise AssertionError() except SchemaMissingKeyError: assert True @@ -252,6 +254,6 @@ def test_to_native_tool_name(self, tool): assert tool.to_native_tool_name(tool.test) == "MockTool_test" + tool.name = "mock_tool" with pytest.raises(ValueError): - tool.name = "mock_tool" tool.to_native_tool_name(tool.foo) diff --git a/tests/unit/tools/test_computer.py b/tests/unit/tools/test_computer.py index b11fd080e..95de18ae3 100644 --- a/tests/unit/tools/test_computer.py +++ b/tests/unit/tools/test_computer.py @@ -1,10 +1,11 @@ import pytest -from tests.mocks.docker.fake_api_client import make_fake_client + from griptape.tools import Computer +from tests.mocks.docker.fake_api_client import make_fake_client class TestComputer: - @pytest.fixture + @pytest.fixture() def computer(self): return Computer(docker_client=make_fake_client(), install_dependencies_on_init=False) diff --git a/tests/unit/tools/test_date_time.py b/tests/unit/tools/test_date_time.py index daa511f04..c534ae69b 100644 --- a/tests/unit/tools/test_date_time.py +++ b/tests/unit/tools/test_date_time.py @@ -1,6 +1,7 @@ -from griptape.tools import DateTime from datetime import datetime +from griptape.tools import DateTime + class TestDateTime: def test_get_current_datetime(self): diff --git a/tests/unit/tools/test_email_client.py b/tests/unit/tools/test_email_client.py index 183d7b265..cf99009b8 100644 --- a/tests/unit/tools/test_email_client.py +++ b/tests/unit/tools/test_email_client.py @@ -1,8 +1,8 @@ -from griptape.artifacts import ErrorArtifact, InfoArtifact, ListArtifact +import pytest + +from griptape.artifacts import ErrorArtifact, InfoArtifact, ListArtifact, TextArtifact from griptape.loaders.email_loader import EmailLoader -from griptape.artifacts import TextArtifact from griptape.tools import EmailClient -import pytest class TestEmailClient: @@ -27,7 +27,7 @@ def mock_smtp_ssl(self, mocker): mock_smtp_ssl.__enter__.return_value = mock_smtp_ssl return mock_smtp_ssl - @pytest.fixture + @pytest.fixture() def client(self): return EmailClient( username="fake-username", @@ -37,12 +37,12 @@ def client(self): mailboxes={"INBOX": "default mailbox for incoming email", "SENT": "default mailbox for sent email"}, ) - @pytest.fixture + @pytest.fixture() def send_params(self): return {"values": {"to": "fake@fake.fake", "subject": "fake-subject", "body": "fake-body"}} @pytest.mark.parametrize( - "values,query", + ("values", "query"), [ ({"label": "fake-label"}, EmailLoader.EmailQuery(label="fake-label")), ({"label": "fake-label", "key": "fake-key"}, EmailLoader.EmailQuery(label="fake-label", key="fake-key")), diff --git a/tests/unit/tools/test_file_manager.py b/tests/unit/tools/test_file_manager.py index 06c49f32b..57dd2c83e 100644 --- a/tests/unit/tools/test_file_manager.py +++ b/tests/unit/tools/test_file_manager.py @@ -1,9 +1,11 @@ -import os.path import os +import os.path import tempfile from pathlib import Path + import pytest -from griptape.artifacts import TextArtifact, ListArtifact + +from griptape.artifacts import ListArtifact, TextArtifact from griptape.artifacts.error_artifact import ErrorArtifact from griptape.drivers.file_manager.local_file_manager_driver import LocalFileManagerDriver from griptape.loaders.text_loader import TextLoader @@ -12,14 +14,14 @@ class TestFileManager: - @pytest.fixture + @pytest.fixture() def file_manager(self): return FileManager( input_memory=[defaults.text_task_memory("Memory1")], file_manager_driver=LocalFileManagerDriver(workdir=os.path.abspath(os.path.dirname(__file__))), ) - @pytest.fixture + @pytest.fixture() def temp_dir(self): with tempfile.TemporaryDirectory() as temp_dir: yield temp_dir diff --git a/tests/unit/tools/test_google_docs_client.py b/tests/unit/tools/test_google_docs_client.py index ad74fcaca..a42fddda3 100644 --- a/tests/unit/tools/test_google_docs_client.py +++ b/tests/unit/tools/test_google_docs_client.py @@ -2,7 +2,7 @@ class TestGoogleDocsClient: - @pytest.fixture + @pytest.fixture() def mock_docs_client(self): from griptape.tools import GoogleDocsClient diff --git a/tests/unit/tools/test_google_drive_client.py b/tests/unit/tools/test_google_drive_client.py index 5d2f62df7..55f3c168f 100644 --- a/tests/unit/tools/test_google_drive_client.py +++ b/tests/unit/tools/test_google_drive_client.py @@ -1,5 +1,5 @@ -from griptape.tools import GoogleDriveClient from griptape.artifacts import ErrorArtifact +from griptape.tools import GoogleDriveClient class TestGoogleDriveClient: diff --git a/tests/unit/tools/test_griptape_cloud_knowledge_base_client.py b/tests/unit/tools/test_griptape_cloud_knowledge_base_client.py index 9feba9cbf..7d75d8670 100644 --- a/tests/unit/tools/test_griptape_cloud_knowledge_base_client.py +++ b/tests/unit/tools/test_griptape_cloud_knowledge_base_client.py @@ -1,10 +1,11 @@ import pytest from requests import exceptions -from griptape.artifacts import TextArtifact, ErrorArtifact + +from griptape.artifacts import ErrorArtifact, TextArtifact class TestGriptapeCloudKnowledgeBaseClient: - @pytest.fixture + @pytest.fixture() def client(self, mocker): from griptape.tools import GriptapeCloudKnowledgeBaseClient @@ -22,7 +23,7 @@ def client(self, mocker): base_url="https://api.griptape.ai", api_key="foo bar", knowledge_base_id="1" ) - @pytest.fixture + @pytest.fixture() def client_no_description(self, mocker): from griptape.tools import GriptapeCloudKnowledgeBaseClient @@ -35,7 +36,7 @@ def client_no_description(self, mocker): base_url="https://api.griptape.ai", api_key="foo bar", knowledge_base_id="1" ) - @pytest.fixture + @pytest.fixture() def client_kb_not_found(self, mocker): from griptape.tools import GriptapeCloudKnowledgeBaseClient @@ -48,7 +49,7 @@ def client_kb_not_found(self, mocker): base_url="https://api.griptape.ai", api_key="foo bar", knowledge_base_id="1" ) - @pytest.fixture + @pytest.fixture() def client_kb_error(self, mocker): from griptape.tools import GriptapeCloudKnowledgeBaseClient @@ -75,10 +76,10 @@ def test_get_knowledge_base_description(self, client): def test_get_knowledge_base_description_error(self, client_no_description): exception_match_text = f"No description found for Knowledge Base {client_no_description.knowledge_base_id}. Please set a description, or manually set the `GriptapeCloudKnowledgeBaseClient.description` attribute." - with pytest.raises(ValueError, match=exception_match_text) as e: + with pytest.raises(ValueError, match=exception_match_text): client_no_description._get_knowledge_base_description() def test_get_knowledge_base_kb_error(self, client_kb_not_found): exception_match_text = f"Error accessing Knowledge Base {client_kb_not_found.knowledge_base_id}." - with pytest.raises(ValueError, match=exception_match_text) as e: + with pytest.raises(ValueError, match=exception_match_text): client_kb_not_found._get_knowledge_base_description() diff --git a/tests/unit/tools/test_inpainting_image_generation_client.py b/tests/unit/tools/test_inpainting_image_generation_client.py index 9e1d017bb..14ddcb5b6 100644 --- a/tests/unit/tools/test_inpainting_image_generation_client.py +++ b/tests/unit/tools/test_inpainting_image_generation_client.py @@ -10,18 +10,18 @@ class TestInpaintingImageGenerationClient: - @pytest.fixture + @pytest.fixture() def image_generation_engine(self) -> Mock: return Mock() - @pytest.fixture + @pytest.fixture() def image_loader(self) -> Mock: loader = Mock() loader.load.return_value = ImageArtifact(value=b"image_data", format="png", width=512, height=512) return loader - @pytest.fixture + @pytest.fixture() def image_generator(self, image_generation_engine, image_loader) -> InpaintingImageGenerationClient: return InpaintingImageGenerationClient(engine=image_generation_engine, image_loader=image_loader) @@ -53,7 +53,7 @@ def test_image_inpainting_with_outfile(self, image_generation_engine, image_load engine=image_generation_engine, output_file=outfile, image_loader=image_loader ) - image_generator.engine.run.return_value = Mock( # pyright: ignore + image_generator.engine.run.return_value = Mock( # pyright: ignore[reportFunctionMemberAccess] value=b"image data", format="png", width=512, height=512, model="test model", prompt="test prompt" ) diff --git a/tests/unit/tools/test_openweather_client.py b/tests/unit/tools/test_openweather_client.py index 319a7ec2a..89b80e164 100644 --- a/tests/unit/tools/test_openweather_client.py +++ b/tests/unit/tools/test_openweather_client.py @@ -1,16 +1,18 @@ -import pytest from unittest.mock import patch + +import pytest + from griptape.artifacts import ErrorArtifact from griptape.tools import OpenWeatherClient -@pytest.fixture +@pytest.fixture() def client(): return OpenWeatherClient(api_key="YOUR_API_KEY") class MockResponse: - def __init__(self, json_data, status_code): + def __init__(self, json_data, status_code) -> None: self.json_data = json_data self.status_code = status_code diff --git a/tests/unit/tools/test_outpainting_image_variation_client.py b/tests/unit/tools/test_outpainting_image_variation_client.py index 1a84018a4..d604574cb 100644 --- a/tests/unit/tools/test_outpainting_image_variation_client.py +++ b/tests/unit/tools/test_outpainting_image_variation_client.py @@ -10,18 +10,18 @@ class TestOutpaintingImageGenerationClient: - @pytest.fixture + @pytest.fixture() def image_generation_engine(self) -> Mock: return Mock() - @pytest.fixture + @pytest.fixture() def image_loader(self) -> Mock: loader = Mock() loader.load.return_value = ImageArtifact(value=b"image_data", format="png", width=512, height=512) return loader - @pytest.fixture + @pytest.fixture() def image_generator(self, image_generation_engine, image_loader) -> OutpaintingImageGenerationClient: return OutpaintingImageGenerationClient(engine=image_generation_engine, image_loader=image_loader) @@ -53,7 +53,7 @@ def test_image_outpainting_with_outfile(self, image_generation_engine, image_loa engine=image_generation_engine, output_file=outfile, image_loader=image_loader ) - image_generator.engine.run.return_value = Mock( # pyright: ignore + image_generator.engine.run.return_value = Mock( # pyright: ignore[reportFunctionMemberAccess] value=b"image data", format="png", width=512, height=512, model="test model", prompt="test prompt" ) diff --git a/tests/unit/tools/test_prompt_image_generation_client.py b/tests/unit/tools/test_prompt_image_generation_client.py index dffbb4239..7393d1eff 100644 --- a/tests/unit/tools/test_prompt_image_generation_client.py +++ b/tests/unit/tools/test_prompt_image_generation_client.py @@ -9,11 +9,11 @@ class TestPromptImageGenerationClient: - @pytest.fixture + @pytest.fixture() def image_generation_engine(self) -> Mock: return Mock() - @pytest.fixture + @pytest.fixture() def image_generator(self, image_generation_engine) -> PromptImageGenerationClient: return PromptImageGenerationClient(engine=image_generation_engine) @@ -36,7 +36,7 @@ def test_generate_image_with_outfile(self, image_generation_engine) -> None: outfile = f"{tempfile.gettempdir()}/{str(uuid.uuid4())}.png" image_generator = PromptImageGenerationClient(engine=image_generation_engine, output_file=outfile) - image_generator.engine.run.return_value = Mock( # pyright: ignore + image_generator.engine.run.return_value = Mock( # pyright: ignore[reportFunctionMemberAccess] value=b"image data", format="png", width=512, height=512, model="test model", prompt="test prompt" ) diff --git a/tests/unit/tools/test_rest_api_client.py b/tests/unit/tools/test_rest_api_client.py index b937b9f23..58f21d1f1 100644 --- a/tests/unit/tools/test_rest_api_client.py +++ b/tests/unit/tools/test_rest_api_client.py @@ -1,9 +1,10 @@ import pytest + from griptape.artifacts import BaseArtifact class TestRestApi: - @pytest.fixture + @pytest.fixture() def client(self): from griptape.tools import RestApiClient diff --git a/tests/unit/tools/test_sql_client.py b/tests/unit/tools/test_sql_client.py index 6584fa752..8ab61fc8f 100644 --- a/tests/unit/tools/test_sql_client.py +++ b/tests/unit/tools/test_sql_client.py @@ -1,12 +1,14 @@ +import sqlite3 + import pytest + from griptape.drivers import SqlDriver from griptape.loaders import SqlLoader from griptape.tools import SqlClient -import sqlite3 class TestSqlClient: - @pytest.fixture + @pytest.fixture() def driver(self): new_driver = SqlDriver(engine_url="sqlite:///:memory:") diff --git a/tests/unit/tools/test_structure_run_client.py b/tests/unit/tools/test_structure_run_client.py index b57bfb28f..d498b7c56 100644 --- a/tests/unit/tools/test_structure_run_client.py +++ b/tests/unit/tools/test_structure_run_client.py @@ -1,12 +1,13 @@ import pytest + from griptape.drivers.structure_run.local_structure_run_driver import LocalStructureRunDriver -from griptape.tools import StructureRunClient from griptape.structures import Agent +from griptape.tools import StructureRunClient from tests.mocks.mock_prompt_driver import MockPromptDriver class TestStructureRunClient: - @pytest.fixture + @pytest.fixture() def client(self): driver = MockPromptDriver() agent = Agent(prompt_driver=driver) diff --git a/tests/unit/tools/test_task_memory_client.py b/tests/unit/tools/test_task_memory_client.py index 3956ae415..4276b89ec 100644 --- a/tests/unit/tools/test_task_memory_client.py +++ b/tests/unit/tools/test_task_memory_client.py @@ -1,11 +1,12 @@ import pytest + from griptape.artifacts import TextArtifact from griptape.tools import TaskMemoryClient from tests.utils import defaults class TestTaskMemoryClient: - @pytest.fixture + @pytest.fixture() def tool(self): return TaskMemoryClient(off_prompt=True, input_memory=[defaults.text_task_memory("TestMemory")]) diff --git a/tests/unit/tools/test_text_to_speech_client.py b/tests/unit/tools/test_text_to_speech_client.py index 881b1234d..0b9061aa6 100644 --- a/tests/unit/tools/test_text_to_speech_client.py +++ b/tests/unit/tools/test_text_to_speech_client.py @@ -9,11 +9,11 @@ class TestTextToSpeechClient: - @pytest.fixture + @pytest.fixture() def text_to_speech_engine(self) -> Mock: return Mock() - @pytest.fixture + @pytest.fixture() def text_to_speech_client(self, text_to_speech_engine) -> TextToSpeechClient: return TextToSpeechClient(engine=text_to_speech_engine) @@ -32,7 +32,7 @@ def test_text_to_speech_with_outfile(self, text_to_speech_engine) -> None: outfile = f"{tempfile.gettempdir()}/{str(uuid.uuid4())}.mp3" text_to_speech_client = TextToSpeechClient(engine=text_to_speech_engine, output_file=outfile) - text_to_speech_client.engine.run.return_value = Mock(value=b"audio data", format="mp3") # pyright: ignore + text_to_speech_client.engine.run.return_value = Mock(value=b"audio data", format="mp3") # pyright: ignore[reportFunctionMemberAccess] audio_artifact = text_to_speech_client.text_to_speech(params={"values": {"text": "say this!"}}) diff --git a/tests/unit/tools/test_transcription_client.py b/tests/unit/tools/test_transcription_client.py index ea6bd3453..7768792d0 100644 --- a/tests/unit/tools/test_transcription_client.py +++ b/tests/unit/tools/test_transcription_client.py @@ -7,11 +7,11 @@ class TestTranscriptionClient: - @pytest.fixture + @pytest.fixture() def transcription_engine(self) -> Mock: return Mock() - @pytest.fixture + @pytest.fixture() def audio_loader(self) -> Mock: loader = Mock() loader.load.return_value = AudioArtifact(value=b"audio data", format="wav") @@ -24,7 +24,7 @@ def test_init_transcription_client(self, transcription_engine, audio_loader) -> @patch("builtins.open", mock_open(read_data=b"audio data")) def test_transcribe_audio_from_disk(self, transcription_engine, audio_loader) -> None: client = AudioTranscriptionClient(engine=transcription_engine, audio_loader=audio_loader) - client.engine.run.return_value = Mock(value="transcription") # pyright: ignore + client.engine.run.return_value = Mock(value="transcription") # pyright: ignore[reportFunctionMemberAccess] text_artifact = client.transcribe_audio_from_disk(params={"values": {"path": "audio.wav"}}) @@ -37,7 +37,7 @@ def test_transcribe_audio_from_memory(self, transcription_engine, audio_loader) memory.load_artifacts = Mock(return_value=[AudioArtifact(value=b"audio data", format="wav", name="name")]) client.find_input_memory = Mock(return_value=memory) - client.engine.run.return_value = Mock(value="transcription") # pyright: ignore + client.engine.run.return_value = Mock(value="transcription") # pyright: ignore[reportFunctionMemberAccess] text_artifact = client.transcribe_audio_from_memory( params={"values": {"memory_name": "memory", "artifact_namespace": "namespace", "artifact_name": "name"}} diff --git a/tests/unit/tools/test_variation_image_generation_client.py b/tests/unit/tools/test_variation_image_generation_client.py index b29f4fecf..ba707e5bb 100644 --- a/tests/unit/tools/test_variation_image_generation_client.py +++ b/tests/unit/tools/test_variation_image_generation_client.py @@ -10,18 +10,18 @@ class TestVariationImageGenerationClient: - @pytest.fixture + @pytest.fixture() def image_generation_engine(self) -> Mock: return Mock() - @pytest.fixture + @pytest.fixture() def image_loader(self) -> Mock: loader = Mock() loader.load.return_value = ImageArtifact(value=b"image_data", format="png", width=512, height=512) return loader - @pytest.fixture + @pytest.fixture() def image_generator(self, image_generation_engine, image_loader) -> VariationImageGenerationClient: return VariationImageGenerationClient(engine=image_generation_engine, image_loader=image_loader) @@ -54,7 +54,7 @@ def test_image_variation_with_outfile(self, image_generation_engine, image_loade engine=image_generation_engine, output_file=outfile, image_loader=image_loader ) - image_generator.engine.run.return_value = Mock( # pyright: ignore + image_generator.engine.run.return_value = Mock( # pyright: ignore[reportFunctionMemberAccess] value=b"image data", format="png", width=512, height=512, model="test model", prompt="test prompt" ) diff --git a/tests/unit/tools/test_vector_store_client.py b/tests/unit/tools/test_vector_store_client.py index 45018b847..b02dda226 100644 --- a/tests/unit/tools/test_vector_store_client.py +++ b/tests/unit/tools/test_vector_store_client.py @@ -1,5 +1,6 @@ import pytest -from griptape.artifacts import TextArtifact, ListArtifact + +from griptape.artifacts import ListArtifact, TextArtifact from griptape.drivers import LocalVectorStoreDriver from griptape.tools import VectorStoreClient from tests.mocks.mock_embedding_driver import MockEmbeddingDriver @@ -7,7 +8,7 @@ class TestVectorStoreClient: @pytest.fixture(autouse=True) - def mock_try_run(self, mocker): + def _mock_try_run(self, mocker): mocker.patch("griptape.drivers.OpenAiEmbeddingDriver.try_embed_chunk", return_value=[0, 1]) def test_search(self): @@ -16,7 +17,7 @@ def test_search(self): driver.upsert_text_artifacts({"test": [TextArtifact("foo"), TextArtifact("bar")]}) - assert set([a.value for a in tool.search({"values": {"query": "test"}})]) == {"foo", "bar"} + assert {a.value for a in tool.search({"values": {"query": "test"}})} == {"foo", "bar"} def test_search_with_namespace(self): driver = LocalVectorStoreDriver(embedding_driver=MockEmbeddingDriver()) diff --git a/tests/unit/tools/test_web_scraper.py b/tests/unit/tools/test_web_scraper.py index f46004e8f..30362ce65 100644 --- a/tests/unit/tools/test_web_scraper.py +++ b/tests/unit/tools/test_web_scraper.py @@ -1,9 +1,10 @@ import pytest + from griptape.artifacts import ListArtifact class TestWebScraper: - @pytest.fixture + @pytest.fixture() def scraper(self): from griptape.tools import WebScraper diff --git a/tests/unit/tools/test_web_search.py b/tests/unit/tools/test_web_search.py index 0abc880c8..dd447de5d 100644 --- a/tests/unit/tools/test_web_search.py +++ b/tests/unit/tools/test_web_search.py @@ -1,10 +1,11 @@ import pytest + from griptape.artifacts import BaseArtifact, ErrorArtifact, TextArtifact from griptape.tools import WebSearch class TestWebSearch: - @pytest.fixture + @pytest.fixture() def websearch_tool(self, mocker): mock_response = TextArtifact("test_response") driver = mocker.Mock() @@ -12,7 +13,7 @@ def websearch_tool(self, mocker): return WebSearch(web_search_driver=driver) - @pytest.fixture + @pytest.fixture() def websearch_tool_with_error(self, mocker): mock_response = Exception("test_error") driver = mocker.Mock() diff --git a/tests/unit/utils/test_base_tokenizer.py b/tests/unit/utils/test_base_tokenizer.py index eed15b9b2..08fd42c72 100644 --- a/tests/unit/utils/test_base_tokenizer.py +++ b/tests/unit/utils/test_base_tokenizer.py @@ -1,4 +1,5 @@ import logging + from tests.mocks.mock_tokenizer import MockTokenizer diff --git a/tests/unit/utils/test_command_runner.py b/tests/unit/utils/test_command_runner.py index 4ca3afebc..25b7fd8c3 100644 --- a/tests/unit/utils/test_command_runner.py +++ b/tests/unit/utils/test_command_runner.py @@ -1,4 +1,3 @@ -from griptape.artifacts import TextArtifact from griptape.utils import CommandRunner diff --git a/tests/unit/utils/test_conversation.py b/tests/unit/utils/test_conversation.py index cce067f73..28ee72409 100644 --- a/tests/unit/utils/test_conversation.py +++ b/tests/unit/utils/test_conversation.py @@ -1,8 +1,8 @@ -from tests.mocks.mock_prompt_driver import MockPromptDriver from griptape.memory.structure import ConversationMemory, SummaryConversationMemory -from griptape.tasks import PromptTask from griptape.structures import Pipeline +from griptape.tasks import PromptTask from griptape.utils import Conversation +from tests.mocks.mock_prompt_driver import MockPromptDriver class TestConversation: diff --git a/tests/unit/utils/test_deprecate.py b/tests/unit/utils/test_deprecate.py index 0c8064f8d..868dbd60f 100644 --- a/tests/unit/utils/test_deprecate.py +++ b/tests/unit/utils/test_deprecate.py @@ -1,4 +1,5 @@ import pytest + from griptape.utils.deprecation import deprecation_warn diff --git a/tests/unit/utils/test_dict_utils.py b/tests/unit/utils/test_dict_utils.py index 4b4e4ca08..94e870e1a 100644 --- a/tests/unit/utils/test_dict_utils.py +++ b/tests/unit/utils/test_dict_utils.py @@ -1,6 +1,7 @@ -from griptape.utils import remove_null_values_in_dict_recursively, dict_merge, remove_key_in_dict_recursively import pytest +from griptape.utils import dict_merge, remove_key_in_dict_recursively, remove_null_values_in_dict_recursively + class TestDictUtils: def test_remove_null_values_in_dict_recursively(self): diff --git a/tests/unit/utils/test_file_utils.py b/tests/unit/utils/test_file_utils.py index dbcf1044b..de1882ef5 100644 --- a/tests/unit/utils/test_file_utils.py +++ b/tests/unit/utils/test_file_utils.py @@ -1,7 +1,8 @@ import os -from griptape.loaders import TextLoader -from griptape import utils from concurrent import futures + +from griptape import utils +from griptape.loaders import TextLoader from tests.mocks.mock_embedding_driver import MockEmbeddingDriver MAX_TOKENS = 50 diff --git a/tests/unit/utils/test_futures.py b/tests/unit/utils/test_futures.py index 5e30148a9..04ddb9877 100644 --- a/tests/unit/utils/test_futures.py +++ b/tests/unit/utils/test_futures.py @@ -1,4 +1,5 @@ from concurrent import futures + from griptape import utils diff --git a/tests/unit/utils/test_import_utils.py b/tests/unit/utils/test_import_utils.py index bcfb06c87..f6b2429d9 100644 --- a/tests/unit/utils/test_import_utils.py +++ b/tests/unit/utils/test_import_utils.py @@ -1,4 +1,5 @@ import pytest + from griptape.utils import import_optional_dependency, is_dependency_installed diff --git a/tests/unit/utils/test_load_artifact_from_memory.py b/tests/unit/utils/test_load_artifact_from_memory.py index db4f7d573..946ddd6e7 100644 --- a/tests/unit/utils/test_load_artifact_from_memory.py +++ b/tests/unit/utils/test_load_artifact_from_memory.py @@ -2,43 +2,39 @@ import pytest -from griptape.artifacts import TextArtifact, ErrorArtifact, ImageArtifact +from griptape.artifacts import ImageArtifact, TextArtifact from griptape.utils import load_artifact_from_memory class TestLoadImageArtifactFromMemory: - @pytest.fixture + @pytest.fixture() def memory(self): return Mock() - @pytest.fixture + @pytest.fixture() def text_artifact(self): return TextArtifact(value="text", name="text") - @pytest.fixture + @pytest.fixture() def image_artifact(self): return ImageArtifact(value=b"image", name="image", format="png", height=32, width=32) def test_no_memory(self): with pytest.raises(ValueError): - load_artifact_from_memory(None, "", "", TextArtifact) # pyright: ignore + load_artifact_from_memory(None, "", "", TextArtifact) # pyright: ignore[reportArgumentType] def test_no_artifacts_in_memory(self, memory): memory.load_artifacts.return_value = [] - with pytest.raises(ValueError) as e: + with pytest.raises(ValueError, match="no artifacts found in namespace"): load_artifact_from_memory(memory, "", "", TextArtifact) - assert str(e) == "no artifacts found in namespace" - def test_no_artifacts_by_name(self, memory, text_artifact): memory.load_artifacts.return_value = [text_artifact] - with pytest.raises(ValueError) as e: + with pytest.raises(ValueError, match="artifact other_name not found in namespace namespace"): load_artifact_from_memory(memory, "namespace", "other_name", TextArtifact) - assert str(e) == "artifact name not found in namespace" - def test_returns_one_artifact(self, memory, text_artifact): memory.load_artifacts.return_value = [text_artifact] @@ -56,7 +52,5 @@ def test_returns_multiple_artifacts(self, memory, text_artifact, image_artifact) def test_wrong_artifact_type(self, memory, image_artifact): memory.load_artifacts.return_value = [image_artifact] - with pytest.raises(ValueError) as e: + with pytest.raises(ValueError, match="image is not of type"): load_artifact_from_memory(memory, "namespace", image_artifact.name, TextArtifact) - - assert str(e) == "artifact is not of type ImageArtifact" diff --git a/tests/unit/utils/test_message_stack.py b/tests/unit/utils/test_message_stack.py index 908388a33..799705a54 100644 --- a/tests/unit/utils/test_message_stack.py +++ b/tests/unit/utils/test_message_stack.py @@ -5,7 +5,7 @@ class TestPromptStack: - @pytest.fixture + @pytest.fixture() def prompt_stack(self): return PromptStack() diff --git a/tests/unit/utils/test_stream.py b/tests/unit/utils/test_stream.py index 33c97cc75..767d31ba0 100644 --- a/tests/unit/utils/test_stream.py +++ b/tests/unit/utils/test_stream.py @@ -1,5 +1,7 @@ -from typing import Iterator +from collections.abc import Iterator + import pytest + from griptape.structures import Agent from griptape.utils import Stream from tests.mocks.mock_prompt_driver import MockPromptDriver @@ -20,9 +22,9 @@ def test_init(self, agent): chat_stream_artifact = next(chat_stream_run) assert chat_stream_artifact.value == "mock output" + next(chat_stream_run) with pytest.raises(StopIteration): next(chat_stream_run) - next(chat_stream_run) else: with pytest.raises(ValueError): Stream(agent) diff --git a/tests/unit/utils/test_structure_visualizer.py b/tests/unit/utils/test_structure_visualizer.py index e16275a5c..f6e621b91 100644 --- a/tests/unit/utils/test_structure_visualizer.py +++ b/tests/unit/utils/test_structure_visualizer.py @@ -1,7 +1,7 @@ -from tests.mocks.mock_prompt_driver import MockPromptDriver -from griptape.utils import StructureVisualizer +from griptape.structures import Agent, Pipeline, Workflow from griptape.tasks import PromptTask -from griptape.structures import Agent, Workflow, Pipeline +from griptape.utils import StructureVisualizer +from tests.mocks.mock_prompt_driver import MockPromptDriver class TestStructureVisualizer: diff --git a/tests/utils/code_blocks.py b/tests/utils/code_blocks.py index 9cfebb987..ca5b193d1 100644 --- a/tests/utils/code_blocks.py +++ b/tests/utils/code_blocks.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import pathlib import textwrap @@ -6,7 +8,7 @@ def check_py_string(source: str) -> None: - """Exec the python source given in a new module namespace + """Exec the python source given in a new module namespace. Does not return anything, but exceptions raised by the source will propagate out unmodified diff --git a/tests/utils/defaults.py b/tests/utils/defaults.py index 5a9f6f958..bad7f0d79 100644 --- a/tests/utils/defaults.py +++ b/tests/utils/defaults.py @@ -1,11 +1,11 @@ -from griptape.artifacts import TextArtifact, BlobArtifact +from griptape.artifacts import BlobArtifact, TextArtifact from griptape.drivers import LocalVectorStoreDriver -from griptape.engines import PromptSummaryEngine, CsvExtractionEngine, JsonExtractionEngine +from griptape.engines import CsvExtractionEngine, JsonExtractionEngine, PromptSummaryEngine from griptape.engines.rag import RagEngine -from griptape.engines.rag.modules import VectorStoreRetrievalRagModule, PromptResponseRagModule -from griptape.engines.rag.stages import RetrievalRagStage, ResponseRagStage +from griptape.engines.rag.modules import PromptResponseRagModule, VectorStoreRetrievalRagModule +from griptape.engines.rag.stages import ResponseRagStage, RetrievalRagStage from griptape.memory import TaskMemory -from griptape.memory.task.storage import TextArtifactStorage, BlobArtifactStorage +from griptape.memory.task.storage import BlobArtifactStorage, TextArtifactStorage from tests.mocks.mock_embedding_driver import MockEmbeddingDriver from tests.mocks.mock_prompt_driver import MockPromptDriver diff --git a/tests/utils/postgres.py b/tests/utils/postgres.py index 1e04153bd..a320d9a05 100644 --- a/tests/utils/postgres.py +++ b/tests/utils/postgres.py @@ -1,4 +1,4 @@ -from psycopg2 import connect, OperationalError +from psycopg2 import OperationalError, connect def can_connect_to_postgres(user="postgres", password="postgres", host="localhost", port="5432", database="postgres"): diff --git a/tests/utils/structure_tester.py b/tests/utils/structure_tester.py index 8d62bc835..5b908065b 100644 --- a/tests/utils/structure_tester.py +++ b/tests/utils/structure_tester.py @@ -1,25 +1,26 @@ from __future__ import annotations -import os -from attrs import field, define -from schema import Schema, Literal -import logging + import json -from griptape.artifacts.error_artifact import ErrorArtifact +import logging +import os -from griptape.structures import Agent -from griptape.rules import Rule, Ruleset -from griptape.tasks import PromptTask -from griptape.structures import Structure +from attrs import define, field +from schema import Literal, Schema + +from griptape.artifacts.error_artifact import ErrorArtifact from griptape.drivers import ( - BasePromptDriver, AmazonBedrockPromptDriver, + AmazonSageMakerJumpstartPromptDriver, AnthropicPromptDriver, - CoherePromptDriver, - OpenAiChatPromptDriver, AzureOpenAiChatPromptDriver, - AmazonSageMakerJumpstartPromptDriver, + BasePromptDriver, + CoherePromptDriver, GooglePromptDriver, + OpenAiChatPromptDriver, ) +from griptape.rules import Rule, Ruleset +from griptape.structures import Agent, Structure +from griptape.tasks import PromptTask def get_enabled_prompt_drivers(prompt_drivers_options) -> list[BasePromptDriver]: @@ -295,7 +296,7 @@ def verify_structure_output(self, structure) -> dict: return result - def run(self, prompt, assert_correctness: bool = True) -> dict: + def run(self, prompt, *, assert_correctness: bool = True) -> dict: result = self.structure.run(prompt) if isinstance(result.output_task.output, ErrorArtifact): verified_result = {"correct": False, "explanation": f"ErrorArtifact: {result.output_task.output.to_text()}"}