Skip to content

Commit

Permalink
Refactor Structure running to use Drivers
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed May 2, 2024
1 parent d9ca5d9 commit d7f4417
Show file tree
Hide file tree
Showing 15 changed files with 119 additions and 109 deletions.
5 changes: 5 additions & 0 deletions griptape/drivers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@
from .file_manager.local_file_manager_driver import LocalFileManagerDriver
from .file_manager.amazon_s3_file_manager_driver import AmazonS3FileManagerDriver

from .structure_run.griptape_cloud_structure_run_driver import GriptapeCloudStructureRunDriver
from .structure_run.local_structure_run_driver import LocalStructureRunDriver

__all__ = [
"BasePromptDriver",
"OpenAiChatPromptDriver",
Expand Down Expand Up @@ -181,4 +184,6 @@
"BaseFileManagerDriver",
"LocalFileManagerDriver",
"AmazonS3FileManagerDriver",
"GriptapeCloudStructureRunDriver",
"LocalStructureRunDriver",
]
File renamed without changes.
11 changes: 11 additions & 0 deletions griptape/drivers/structure_run/base_structure_run_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from abc import ABC, abstractmethod
from attrs import define

from griptape.artifacts.base_artifact import BaseArtifact


@define
class BaseStructureRunDriver(ABC):
@abstractmethod
def run(self, *args) -> BaseArtifact:
...
Original file line number Diff line number Diff line change
@@ -1,62 +1,33 @@
from __future__ import annotations

import time
from typing import Any, Optional
from typing import Any
from urllib.parse import urljoin
from schema import Schema, Literal
from attr import define, field
from griptape.tools.base_griptape_cloud_client import BaseGriptapeCloudClient
from griptape.utils.decorators import activity
from griptape.artifacts import InfoArtifact, TextArtifact, ErrorArtifact

from attrs import Factory, define, field

from griptape.artifacts import BaseArtifact, ErrorArtifact, InfoArtifact, TextArtifact
from griptape.drivers.structure_run.base_structure_run_driver import BaseStructureRunDriver


@define
class GriptapeCloudStructureRunClient(BaseGriptapeCloudClient):
"""
Attributes:
description: LLM-friendly structure description.
structure_id: ID of the Griptape Cloud Structure.
"""

_description: Optional[str] = field(default=None, kw_only=True)
class GriptapeCloudStructureRunDriver(BaseStructureRunDriver):
base_url: str = field(default="https://cloud.griptape.ai", kw_only=True)
api_key: str = field(kw_only=True)
headers: dict = field(
default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True), kw_only=True
)
structure_id: str = field(kw_only=True)
structure_run_wait_time_interval: int = field(default=2, kw_only=True)
structure_run_max_wait_time_attempts: int = field(default=20, kw_only=True)

@property
def description(self) -> str:
if self._description is None:
from requests import get

url = urljoin(self.base_url.strip("/"), f"/api/structures/{self.structure_id}/")

response = get(url, headers=self.headers).json()
if "description" in response:
self._description = response["description"]
else:
raise ValueError(f'Error getting Structure description: {response["message"]}')
def run(self, *args) -> BaseArtifact:
from requests import HTTPError, Response, exceptions, post

return self._description

@description.setter
def description(self, value: str) -> None:
self._description = value

@activity(
config={
"description": "Can be used to execute a Run of a Structure with the following description: {{ _self.description }}",
"schema": Schema(
{Literal("args", description="A list of string arguments to submit to the Structure Run"): list}
),
}
)
def execute_structure_run(self, params: dict) -> InfoArtifact | TextArtifact | ErrorArtifact:
from requests import post, exceptions, HTTPError, Response

args: list[str] = params["values"]["args"]
url = urljoin(self.base_url.strip("/"), f"/api/structures/{self.structure_id}/runs")

try:
response: Response = post(url, json={"args": args}, headers=self.headers)
response: Response = post(url, json={"args": list(args)}, headers=self.headers)
response.raise_for_status()
response_json = response.json()
return self._get_structure_run_result(response_json["structure_run_id"])
Expand Down Expand Up @@ -96,4 +67,5 @@ def _get_structure_run_result_attempt(self, structure_run_url: str) -> Any:

response: Response = get(structure_run_url, headers=self.headers)
response.raise_for_status()

return response.json()
23 changes: 23 additions & 0 deletions griptape/drivers/structure_run/local_structure_run_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from __future__ import annotations
from typing import TYPE_CHECKING

from attrs import define, field

from griptape.artifacts import BaseArtifact, ErrorArtifact
from griptape.drivers.structure_run.base_structure_run_driver import BaseStructureRunDriver

if TYPE_CHECKING:
from griptape.structures import Structure


@define
class LocalStructureRunDriver(BaseStructureRunDriver):
structure: Structure = field(kw_only=True)

def run(self, *args) -> BaseArtifact:
self.structure.run(*args)

if self.structure.output_task.output is not None:
return self.structure.output_task.output
else:
return ErrorArtifact("Structure did not produce any output.")
18 changes: 5 additions & 13 deletions griptape/tasks/structure_run_task.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,21 @@
from __future__ import annotations

from attr import define, field
from typing import TYPE_CHECKING

from griptape.artifacts import BaseArtifact, ErrorArtifact
from griptape.artifacts import BaseArtifact
from griptape.drivers.structure_run.base_structure_run_driver import BaseStructureRunDriver
from griptape.tasks import BaseTextInputTask

if TYPE_CHECKING:
from griptape.structures import Structure


@define
class StructureRunTask(BaseTextInputTask):
"""Task to run a Structure.
Attributes:
target_structure: Structure to run.
driver: Driver to run the Structure.
"""

target_structure: Structure = field(kw_only=True)
driver: BaseStructureRunDriver = field(kw_only=True)

def run(self) -> BaseArtifact:
self.target_structure.run(self.input)

if self.target_structure.output_task.output is not None:
return self.target_structure.output_task.output
else:
return ErrorArtifact("Structure did not produce any output.")
return self.driver.run(self.input)
2 changes: 0 additions & 2 deletions griptape/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from .inpainting_image_generation_client.tool import InpaintingImageGenerationClient
from .outpainting_image_generation_client.tool import OutpaintingImageGenerationClient
from .griptape_cloud_knowledge_base_client.tool import GriptapeCloudKnowledgeBaseClient
from .griptape_cloud_structure_run_client.tool import GriptapeCloudStructureRunClient
from .griptape_structure_run_client.tool import GriptapeStructureRunClient
from .image_query_client.tool import ImageQueryClient

Expand Down Expand Up @@ -55,7 +54,6 @@
"InpaintingImageGenerationClient",
"OutpaintingImageGenerationClient",
"GriptapeCloudKnowledgeBaseClient",
"GriptapeCloudStructureRunClient",
"GriptapeStructureRunClient",
"ImageQueryClient",
]

This file was deleted.

22 changes: 7 additions & 15 deletions griptape/tools/griptape_structure_run_client/tool.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,32 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from attr import define, field
from schema import Literal, Schema

from griptape.artifacts import ErrorArtifact, BaseArtifact
from griptape.artifacts import BaseArtifact
from griptape.drivers.structure_run.base_structure_run_driver import BaseStructureRunDriver
from griptape.tools.base_tool import BaseTool
from griptape.utils.decorators import activity

if TYPE_CHECKING:
from griptape.structures import Structure


@define
class GriptapeStructureRunClient(BaseTool):
"""
Attributes:
description: A description of what the Structure does.
structure: A Structure to run.
driver: Driver to run the Structure.
"""

description: str = field(kw_only=True)
target_structure: Structure = field(kw_only=True)
driver: BaseStructureRunDriver = field(kw_only=True)

@activity(
config={
"description": "Can be used to Run a Griptape Structure with the following description: {{ self.description }}",
"schema": Schema({Literal("args", description="Arguments to pass to the Structure Run"): str}),
"description": "Can be used to run a Griptape Structure with the following description: {{ self.description }}",
"schema": Schema({Literal("args", description="Arguments to pass to the Structure Run."): str}),
}
)
def run_structure(self, params: dict) -> BaseArtifact:
args: str = params["values"]["args"]

self.target_structure.run(args)

if self.target_structure.output_task.output is not None:
return self.target_structure.output_task.output
else:
return ErrorArtifact("Structure did not produce any output.")
return self.driver.run(args)
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import pytest
from griptape.artifacts import TextArtifact


class TestGriptapeCloudStructureRunDriver:
@pytest.fixture
def driver(self, mocker):
from griptape.drivers import GriptapeCloudStructureRunDriver

mock_response = mocker.Mock()
mock_response.json.return_value = {"structure_run_id": 1}
mocker.patch("requests.post", return_value=mock_response)

mock_response = mocker.Mock()
mock_response.json.return_value = {"description": "fizz buzz", "output": "fooey booey", "status": "SUCCEEDED"}
mocker.patch("requests.get", return_value=mock_response)

return GriptapeCloudStructureRunDriver(base_url="https://api.griptape.ai", api_key="foo bar", structure_id="1")

def test_run(self, driver):
assert isinstance(driver.run("foo bar"), TextArtifact)
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import pytest
from griptape.tasks import StructureRunTask
from griptape.structures import Agent
from tests.mocks.mock_prompt_driver import MockPromptDriver
from griptape.drivers import LocalStructureRunDriver
from griptape.structures import Pipeline


class TestLocalStructureRunDriver:
@pytest.fixture
def driver(self):
agent = Agent(prompt_driver=MockPromptDriver(mock_output="agent mock output"))
driver = LocalStructureRunDriver(structure=agent)

return driver

def test_run(self, driver):
pipeline = Pipeline(prompt_driver=MockPromptDriver(mock_output="pipeline mock output"))

task = StructureRunTask(driver=driver)

pipeline.add_task(task)

assert task.run().to_text() == "agent mock output"
4 changes: 3 additions & 1 deletion tests/unit/tasks/test_structure_run_task.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from griptape.tasks import StructureRunTask
from griptape.structures import Agent
from tests.mocks.mock_prompt_driver import MockPromptDriver
from griptape.drivers import LocalStructureRunDriver
from griptape.structures import Pipeline


class TestStructureRunTask:
def test_run(self):
agent = Agent(prompt_driver=MockPromptDriver(mock_output="agent mock output"))
pipeline = Pipeline(prompt_driver=MockPromptDriver(mock_output="pipeline mock output"))
driver = LocalStructureRunDriver(structure=agent)

task = StructureRunTask(target_structure=agent)
task = StructureRunTask(driver=driver)

pipeline.add_task(task)

Expand Down
27 changes: 0 additions & 27 deletions tests/unit/tools/test_griptape_cloud_structure_run_client.py

This file was deleted.

4 changes: 3 additions & 1 deletion tests/unit/tools/test_griptape_structure_run_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
from griptape.drivers.structure_run.local_structure_run_driver import LocalStructureRunDriver
from griptape.tools import GriptapeStructureRunClient
from griptape.structures import Agent
from tests.mocks.mock_prompt_driver import MockPromptDriver
Expand All @@ -10,7 +11,8 @@ def client(self):
driver = MockPromptDriver()
agent = Agent(prompt_driver=driver)

return GriptapeStructureRunClient(target_structure=agent, description="fizz buzz")
return GriptapeStructureRunClient(description="foo bar", driver=LocalStructureRunDriver(structure=agent))

def test_run_structure(self, client):
assert client.run_structure({"values": {"args": "foo bar"}}).value == "mock output"
assert client.description == "foo bar"

0 comments on commit d7f4417

Please sign in to comment.