From 6fcc953051fa7603b747b1b4b61fcc6e51e6f661 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 25 Oct 2024 12:40:02 -0700 Subject: [PATCH] Add some tests --- .../tasks/test_base_image_generation_task.py | 15 ++++++++ tests/unit/tools/test_image_query_tool.py | 36 +++++++++++++++++++ 2 files changed, 51 insertions(+) create mode 100644 tests/unit/tools/test_image_query_tool.py diff --git a/tests/unit/tasks/test_base_image_generation_task.py b/tests/unit/tasks/test_base_image_generation_task.py index c4272d78c0..5c974aa370 100644 --- a/tests/unit/tasks/test_base_image_generation_task.py +++ b/tests/unit/tasks/test_base_image_generation_task.py @@ -36,3 +36,18 @@ def test_negative_rulesets_from_rules(self) -> None: def test_validate_output_dir(self) -> None: with pytest.raises(ValueError): MockImageGenerationTask(TextArtifact("some input"), output_dir="some/dir", output_file="some/file") + + def test__get_prompts(self): + task = MockImageGenerationTask( + TextArtifact("some input"), rulesets=[Ruleset(name="Ruleset", rules=[Rule(value="Rule")])] + ) + + assert task._get_prompts(task.input.to_text()) == ["some input", "Rule"] + + def test__get_negative_prompts(self): + task = MockImageGenerationTask( + TextArtifact("some input"), + negative_rulesets=[Ruleset(name="Negative Ruleset", rules=[Rule(value="Negative Rule")])], + ) + + assert task._get_negative_prompts() == ["Negative Rule"] diff --git a/tests/unit/tools/test_image_query_tool.py b/tests/unit/tools/test_image_query_tool.py new file mode 100644 index 0000000000..630f1bc4d8 --- /dev/null +++ b/tests/unit/tools/test_image_query_tool.py @@ -0,0 +1,36 @@ +import pytest + +from griptape.artifacts.image_artifact import ImageArtifact +from griptape.tools import ImageQueryTool +from tests.mocks.mock_image_query_driver import MockImageQueryDriver +from tests.utils import defaults + + +class TestImageQueryTool: + @pytest.fixture() + def tool(self): + task_memory = defaults.text_task_memory("memory_name") + task_memory.store_artifact("namespace", ImageArtifact(b"", format="png", width=1, height=1, name="test")) + return ImageQueryTool(input_memory=[task_memory], image_query_driver=MockImageQueryDriver()) + + def test_query_image_from_disk(self, tool): + assert tool.query_image_from_disk({"values": {"query": "test", "image_paths": []}}).value == "mock text" + + def test_query_images_from_memory(self, tool): + assert ( + tool.query_images_from_memory( + { + "values": { + "query": "test", + "memory_name": tool.input_memory[0].name, + "image_artifacts": [ + { + "image_artifact_name": "test", + "image_artifact_namespace": "namespace", + } + ], + } + } + ).value + == "mock text" + )