diff --git a/griptape/drivers/__init__.py b/griptape/drivers/__init__.py index 7409900311..586511f807 100644 --- a/griptape/drivers/__init__.py +++ b/griptape/drivers/__init__.py @@ -97,6 +97,9 @@ from .file_manager.local_file_manager_driver import LocalFileManagerDriver from .file_manager.amazon_s3_file_manager_driver import AmazonS3FileManagerDriver +from .structure_run.griptape_cloud_structure_run_driver import GriptapeCloudStructureRunDriver +from .structure_run.local_structure_run_driver import LocalStructureRunDriver + __all__ = [ "BasePromptDriver", "OpenAiChatPromptDriver", @@ -181,4 +184,6 @@ "BaseFileManagerDriver", "LocalFileManagerDriver", "AmazonS3FileManagerDriver", + "GriptapeCloudStructureRunDriver", + "LocalStructureRunDriver", ] diff --git a/griptape/tools/griptape_cloud_structure_run_client/__init__.py b/griptape/drivers/structure_run/__init__.py similarity index 100% rename from griptape/tools/griptape_cloud_structure_run_client/__init__.py rename to griptape/drivers/structure_run/__init__.py diff --git a/griptape/drivers/structure_run/base_structure_run_driver.py b/griptape/drivers/structure_run/base_structure_run_driver.py new file mode 100644 index 0000000000..eff87c6f3d --- /dev/null +++ b/griptape/drivers/structure_run/base_structure_run_driver.py @@ -0,0 +1,11 @@ +from abc import ABC, abstractmethod +from attrs import define + +from griptape.artifacts.base_artifact import BaseArtifact + + +@define +class BaseStructureRunDriver(ABC): + @abstractmethod + def run(self, *args) -> BaseArtifact: + ... diff --git a/griptape/tools/griptape_cloud_structure_run_client/tool.py b/griptape/drivers/structure_run/griptape_cloud_structure_run_driver.py similarity index 53% rename from griptape/tools/griptape_cloud_structure_run_client/tool.py rename to griptape/drivers/structure_run/griptape_cloud_structure_run_driver.py index 7d4edbd720..a500fec4da 100644 --- a/griptape/tools/griptape_cloud_structure_run_client/tool.py +++ b/griptape/drivers/structure_run/griptape_cloud_structure_run_driver.py @@ -1,62 +1,33 @@ from __future__ import annotations + import time -from typing import Any, Optional +from typing import Any from urllib.parse import urljoin -from schema import Schema, Literal -from attr import define, field -from griptape.tools.base_griptape_cloud_client import BaseGriptapeCloudClient -from griptape.utils.decorators import activity -from griptape.artifacts import InfoArtifact, TextArtifact, ErrorArtifact + +from attrs import Factory, define, field + +from griptape.artifacts import BaseArtifact, ErrorArtifact, InfoArtifact, TextArtifact +from griptape.drivers.structure_run.base_structure_run_driver import BaseStructureRunDriver @define -class GriptapeCloudStructureRunClient(BaseGriptapeCloudClient): - """ - Attributes: - description: LLM-friendly structure description. - structure_id: ID of the Griptape Cloud Structure. - """ - - _description: Optional[str] = field(default=None, kw_only=True) +class GriptapeCloudStructureRunDriver(BaseStructureRunDriver): + base_url: str = field(default="https://cloud.griptape.ai", kw_only=True) + api_key: str = field(kw_only=True) + headers: dict = field( + default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True), kw_only=True + ) structure_id: str = field(kw_only=True) structure_run_wait_time_interval: int = field(default=2, kw_only=True) structure_run_max_wait_time_attempts: int = field(default=20, kw_only=True) - @property - def description(self) -> str: - if self._description is None: - from requests import get - - url = urljoin(self.base_url.strip("/"), f"/api/structures/{self.structure_id}/") - - response = get(url, headers=self.headers).json() - if "description" in response: - self._description = response["description"] - else: - raise ValueError(f'Error getting Structure description: {response["message"]}') + def run(self, *args) -> BaseArtifact: + from requests import HTTPError, Response, exceptions, post - return self._description - - @description.setter - def description(self, value: str) -> None: - self._description = value - - @activity( - config={ - "description": "Can be used to execute a Run of a Structure with the following description: {{ _self.description }}", - "schema": Schema( - {Literal("args", description="A list of string arguments to submit to the Structure Run"): list} - ), - } - ) - def execute_structure_run(self, params: dict) -> InfoArtifact | TextArtifact | ErrorArtifact: - from requests import post, exceptions, HTTPError, Response - - args: list[str] = params["values"]["args"] url = urljoin(self.base_url.strip("/"), f"/api/structures/{self.structure_id}/runs") try: - response: Response = post(url, json={"args": args}, headers=self.headers) + response: Response = post(url, json={"args": list(args)}, headers=self.headers) response.raise_for_status() response_json = response.json() return self._get_structure_run_result(response_json["structure_run_id"]) @@ -96,4 +67,5 @@ def _get_structure_run_result_attempt(self, structure_run_url: str) -> Any: response: Response = get(structure_run_url, headers=self.headers) response.raise_for_status() + return response.json() diff --git a/griptape/drivers/structure_run/local_structure_run_driver.py b/griptape/drivers/structure_run/local_structure_run_driver.py new file mode 100644 index 0000000000..631df7af2f --- /dev/null +++ b/griptape/drivers/structure_run/local_structure_run_driver.py @@ -0,0 +1,23 @@ +from __future__ import annotations +from typing import TYPE_CHECKING + +from attrs import define, field + +from griptape.artifacts import BaseArtifact, ErrorArtifact +from griptape.drivers.structure_run.base_structure_run_driver import BaseStructureRunDriver + +if TYPE_CHECKING: + from griptape.structures import Structure + + +@define +class LocalStructureRunDriver(BaseStructureRunDriver): + structure: Structure = field(kw_only=True) + + def run(self, *args) -> BaseArtifact: + self.structure.run(*args) + + if self.structure.output_task.output is not None: + return self.structure.output_task.output + else: + return ErrorArtifact("Structure did not produce any output.") diff --git a/griptape/tasks/structure_run_task.py b/griptape/tasks/structure_run_task.py index 7340ec434a..85e5942010 100644 --- a/griptape/tasks/structure_run_task.py +++ b/griptape/tasks/structure_run_task.py @@ -1,29 +1,21 @@ from __future__ import annotations from attr import define, field -from typing import TYPE_CHECKING -from griptape.artifacts import BaseArtifact, ErrorArtifact +from griptape.artifacts import BaseArtifact +from griptape.drivers.structure_run.base_structure_run_driver import BaseStructureRunDriver from griptape.tasks import BaseTextInputTask -if TYPE_CHECKING: - from griptape.structures import Structure - @define class StructureRunTask(BaseTextInputTask): """Task to run a Structure. Attributes: - target_structure: Structure to run. + driver: Driver to run the Structure. """ - target_structure: Structure = field(kw_only=True) + driver: BaseStructureRunDriver = field(kw_only=True) def run(self) -> BaseArtifact: - self.target_structure.run(self.input) - - if self.target_structure.output_task.output is not None: - return self.target_structure.output_task.output - else: - return ErrorArtifact("Structure did not produce any output.") + return self.driver.run(self.input) diff --git a/griptape/tools/__init__.py b/griptape/tools/__init__.py index a9627ed0ab..6788fc9ef8 100644 --- a/griptape/tools/__init__.py +++ b/griptape/tools/__init__.py @@ -24,7 +24,6 @@ from .inpainting_image_generation_client.tool import InpaintingImageGenerationClient from .outpainting_image_generation_client.tool import OutpaintingImageGenerationClient from .griptape_cloud_knowledge_base_client.tool import GriptapeCloudKnowledgeBaseClient -from .griptape_cloud_structure_run_client.tool import GriptapeCloudStructureRunClient from .griptape_structure_run_client.tool import GriptapeStructureRunClient from .image_query_client.tool import ImageQueryClient @@ -55,7 +54,6 @@ "InpaintingImageGenerationClient", "OutpaintingImageGenerationClient", "GriptapeCloudKnowledgeBaseClient", - "GriptapeCloudStructureRunClient", "GriptapeStructureRunClient", "ImageQueryClient", ] diff --git a/griptape/tools/griptape_cloud_structure_run_client/manifest.yml b/griptape/tools/griptape_cloud_structure_run_client/manifest.yml deleted file mode 100644 index 5741fa336a..0000000000 --- a/griptape/tools/griptape_cloud_structure_run_client/manifest.yml +++ /dev/null @@ -1,5 +0,0 @@ -version: "v1" -name: Griptape Cloud Structure Run Client -description: Tool for using the Griptape Cloud Structure Run API. -contact_email: hello@griptape.ai -legal_info_url: https://www.griptape.ai/legal diff --git a/griptape/tools/griptape_structure_run_client/tool.py b/griptape/tools/griptape_structure_run_client/tool.py index 103ede2281..deedf0bbdb 100644 --- a/griptape/tools/griptape_structure_run_client/tool.py +++ b/griptape/tools/griptape_structure_run_client/tool.py @@ -1,40 +1,32 @@ from __future__ import annotations -from typing import TYPE_CHECKING from attr import define, field from schema import Literal, Schema -from griptape.artifacts import ErrorArtifact, BaseArtifact +from griptape.artifacts import BaseArtifact +from griptape.drivers.structure_run.base_structure_run_driver import BaseStructureRunDriver from griptape.tools.base_tool import BaseTool from griptape.utils.decorators import activity -if TYPE_CHECKING: - from griptape.structures import Structure - @define class GriptapeStructureRunClient(BaseTool): """ Attributes: description: A description of what the Structure does. - structure: A Structure to run. + driver: Driver to run the Structure. """ description: str = field(kw_only=True) - target_structure: Structure = field(kw_only=True) + driver: BaseStructureRunDriver = field(kw_only=True) @activity( config={ - "description": "Can be used to Run a Griptape Structure with the following description: {{ self.description }}", - "schema": Schema({Literal("args", description="Arguments to pass to the Structure Run"): str}), + "description": "Can be used to run a Griptape Structure with the following description: {{ self.description }}", + "schema": Schema({Literal("args", description="Arguments to pass to the Structure Run."): str}), } ) def run_structure(self, params: dict) -> BaseArtifact: args: str = params["values"]["args"] - self.target_structure.run(args) - - if self.target_structure.output_task.output is not None: - return self.target_structure.output_task.output - else: - return ErrorArtifact("Structure did not produce any output.") + return self.driver.run(args) diff --git a/tests/unit/drivers/structure_run/__init__.py b/tests/unit/drivers/structure_run/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/drivers/structure_run/test_griptape_cloud_structure_run_driver.py b/tests/unit/drivers/structure_run/test_griptape_cloud_structure_run_driver.py new file mode 100644 index 0000000000..0579ab1d41 --- /dev/null +++ b/tests/unit/drivers/structure_run/test_griptape_cloud_structure_run_driver.py @@ -0,0 +1,21 @@ +import pytest +from griptape.artifacts import TextArtifact + + +class TestGriptapeCloudStructureRunDriver: + @pytest.fixture + def driver(self, mocker): + from griptape.drivers import GriptapeCloudStructureRunDriver + + mock_response = mocker.Mock() + mock_response.json.return_value = {"structure_run_id": 1} + mocker.patch("requests.post", return_value=mock_response) + + mock_response = mocker.Mock() + mock_response.json.return_value = {"description": "fizz buzz", "output": "fooey booey", "status": "SUCCEEDED"} + mocker.patch("requests.get", return_value=mock_response) + + return GriptapeCloudStructureRunDriver(base_url="https://api.griptape.ai", api_key="foo bar", structure_id="1") + + def test_run(self, driver): + assert isinstance(driver.run("foo bar"), TextArtifact) diff --git a/tests/unit/drivers/structure_run/test_local_structure_run_driver.py b/tests/unit/drivers/structure_run/test_local_structure_run_driver.py new file mode 100644 index 0000000000..596c2ec745 --- /dev/null +++ b/tests/unit/drivers/structure_run/test_local_structure_run_driver.py @@ -0,0 +1,24 @@ +import pytest +from griptape.tasks import StructureRunTask +from griptape.structures import Agent +from tests.mocks.mock_prompt_driver import MockPromptDriver +from griptape.drivers import LocalStructureRunDriver +from griptape.structures import Pipeline + + +class TestLocalStructureRunDriver: + @pytest.fixture + def driver(self): + agent = Agent(prompt_driver=MockPromptDriver(mock_output="agent mock output")) + driver = LocalStructureRunDriver(structure=agent) + + return driver + + def test_run(self, driver): + pipeline = Pipeline(prompt_driver=MockPromptDriver(mock_output="pipeline mock output")) + + task = StructureRunTask(driver=driver) + + pipeline.add_task(task) + + assert task.run().to_text() == "agent mock output" diff --git a/tests/unit/tasks/test_structure_run_task.py b/tests/unit/tasks/test_structure_run_task.py index 781dff4394..2e3d8ade1e 100644 --- a/tests/unit/tasks/test_structure_run_task.py +++ b/tests/unit/tasks/test_structure_run_task.py @@ -1,6 +1,7 @@ from griptape.tasks import StructureRunTask from griptape.structures import Agent from tests.mocks.mock_prompt_driver import MockPromptDriver +from griptape.drivers import LocalStructureRunDriver from griptape.structures import Pipeline @@ -8,8 +9,9 @@ class TestStructureRunTask: def test_run(self): agent = Agent(prompt_driver=MockPromptDriver(mock_output="agent mock output")) pipeline = Pipeline(prompt_driver=MockPromptDriver(mock_output="pipeline mock output")) + driver = LocalStructureRunDriver(structure=agent) - task = StructureRunTask(target_structure=agent) + task = StructureRunTask(driver=driver) pipeline.add_task(task) diff --git a/tests/unit/tools/test_griptape_cloud_structure_run_client.py b/tests/unit/tools/test_griptape_cloud_structure_run_client.py deleted file mode 100644 index 8656a8f772..0000000000 --- a/tests/unit/tools/test_griptape_cloud_structure_run_client.py +++ /dev/null @@ -1,27 +0,0 @@ -import pytest -from griptape.artifacts import TextArtifact - - -class TestGriptapeCloudStructureRunClient: - @pytest.fixture - def client(self, mocker): - from griptape.tools import GriptapeCloudStructureRunClient - - mock_response = mocker.Mock() - mock_response.json.return_value = {"structure_run_id": 1} - mocker.patch("requests.post", return_value=mock_response) - - mock_response = mocker.Mock() - mock_response.json.return_value = {"description": "fizz buzz", "output": "fooey booey", "status": "SUCCEEDED"} - mocker.patch("requests.get", return_value=mock_response) - - return GriptapeCloudStructureRunClient(base_url="https://api.griptape.ai", api_key="foo bar", structure_id="1") - - def test_execute_structure_run(self, client): - assert isinstance(client.execute_structure_run({"values": {"args": ["foo bar"]}}), TextArtifact) - - def test_get_structure_description(self, client): - assert client.description == "fizz buzz" - - client.description = "foo bar" - assert client.description == "foo bar" diff --git a/tests/unit/tools/test_griptape_structure_run_client.py b/tests/unit/tools/test_griptape_structure_run_client.py index baef209e5c..fb2df5628c 100644 --- a/tests/unit/tools/test_griptape_structure_run_client.py +++ b/tests/unit/tools/test_griptape_structure_run_client.py @@ -1,4 +1,5 @@ import pytest +from griptape.drivers.structure_run.local_structure_run_driver import LocalStructureRunDriver from griptape.tools import GriptapeStructureRunClient from griptape.structures import Agent from tests.mocks.mock_prompt_driver import MockPromptDriver @@ -10,7 +11,8 @@ def client(self): driver = MockPromptDriver() agent = Agent(prompt_driver=driver) - return GriptapeStructureRunClient(target_structure=agent, description="fizz buzz") + return GriptapeStructureRunClient(description="foo bar", driver=LocalStructureRunDriver(structure=agent)) def test_run_structure(self, client): assert client.run_structure({"values": {"args": "foo bar"}}).value == "mock output" + assert client.description == "foo bar"