-
Notifications
You must be signed in to change notification settings - Fork 186
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor Structure running to use Drivers
- Loading branch information
1 parent
d9ca5d9
commit d7f4417
Showing
15 changed files
with
119 additions
and
109 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
11 changes: 11 additions & 0 deletions
11
griptape/drivers/structure_run/base_structure_run_driver.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from abc import ABC, abstractmethod | ||
from attrs import define | ||
|
||
from griptape.artifacts.base_artifact import BaseArtifact | ||
|
||
|
||
@define | ||
class BaseStructureRunDriver(ABC): | ||
@abstractmethod | ||
def run(self, *args) -> BaseArtifact: | ||
... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
23 changes: 23 additions & 0 deletions
23
griptape/drivers/structure_run/local_structure_run_driver.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from __future__ import annotations | ||
from typing import TYPE_CHECKING | ||
|
||
from attrs import define, field | ||
|
||
from griptape.artifacts import BaseArtifact, ErrorArtifact | ||
from griptape.drivers.structure_run.base_structure_run_driver import BaseStructureRunDriver | ||
|
||
if TYPE_CHECKING: | ||
from griptape.structures import Structure | ||
|
||
|
||
@define | ||
class LocalStructureRunDriver(BaseStructureRunDriver): | ||
structure: Structure = field(kw_only=True) | ||
|
||
def run(self, *args) -> BaseArtifact: | ||
self.structure.run(*args) | ||
|
||
if self.structure.output_task.output is not None: | ||
return self.structure.output_task.output | ||
else: | ||
return ErrorArtifact("Structure did not produce any output.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,29 +1,21 @@ | ||
from __future__ import annotations | ||
|
||
from attr import define, field | ||
from typing import TYPE_CHECKING | ||
|
||
from griptape.artifacts import BaseArtifact, ErrorArtifact | ||
from griptape.artifacts import BaseArtifact | ||
from griptape.drivers.structure_run.base_structure_run_driver import BaseStructureRunDriver | ||
from griptape.tasks import BaseTextInputTask | ||
|
||
if TYPE_CHECKING: | ||
from griptape.structures import Structure | ||
|
||
|
||
@define | ||
class StructureRunTask(BaseTextInputTask): | ||
"""Task to run a Structure. | ||
Attributes: | ||
target_structure: Structure to run. | ||
driver: Driver to run the Structure. | ||
""" | ||
|
||
target_structure: Structure = field(kw_only=True) | ||
driver: BaseStructureRunDriver = field(kw_only=True) | ||
|
||
def run(self) -> BaseArtifact: | ||
self.target_structure.run(self.input) | ||
|
||
if self.target_structure.output_task.output is not None: | ||
return self.target_structure.output_task.output | ||
else: | ||
return ErrorArtifact("Structure did not produce any output.") | ||
return self.driver.run(self.input) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
5 changes: 0 additions & 5 deletions
5
griptape/tools/griptape_cloud_structure_run_client/manifest.yml
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,40 +1,32 @@ | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING | ||
from attr import define, field | ||
from schema import Literal, Schema | ||
|
||
from griptape.artifacts import ErrorArtifact, BaseArtifact | ||
from griptape.artifacts import BaseArtifact | ||
from griptape.drivers.structure_run.base_structure_run_driver import BaseStructureRunDriver | ||
from griptape.tools.base_tool import BaseTool | ||
from griptape.utils.decorators import activity | ||
|
||
if TYPE_CHECKING: | ||
from griptape.structures import Structure | ||
|
||
|
||
@define | ||
class GriptapeStructureRunClient(BaseTool): | ||
""" | ||
Attributes: | ||
description: A description of what the Structure does. | ||
structure: A Structure to run. | ||
driver: Driver to run the Structure. | ||
""" | ||
|
||
description: str = field(kw_only=True) | ||
target_structure: Structure = field(kw_only=True) | ||
driver: BaseStructureRunDriver = field(kw_only=True) | ||
|
||
@activity( | ||
config={ | ||
"description": "Can be used to Run a Griptape Structure with the following description: {{ self.description }}", | ||
"schema": Schema({Literal("args", description="Arguments to pass to the Structure Run"): str}), | ||
"description": "Can be used to run a Griptape Structure with the following description: {{ self.description }}", | ||
"schema": Schema({Literal("args", description="Arguments to pass to the Structure Run."): str}), | ||
} | ||
) | ||
def run_structure(self, params: dict) -> BaseArtifact: | ||
args: str = params["values"]["args"] | ||
|
||
self.target_structure.run(args) | ||
|
||
if self.target_structure.output_task.output is not None: | ||
return self.target_structure.output_task.output | ||
else: | ||
return ErrorArtifact("Structure did not produce any output.") | ||
return self.driver.run(args) |
Empty file.
21 changes: 21 additions & 0 deletions
21
tests/unit/drivers/structure_run/test_griptape_cloud_structure_run_driver.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import pytest | ||
from griptape.artifacts import TextArtifact | ||
|
||
|
||
class TestGriptapeCloudStructureRunDriver: | ||
@pytest.fixture | ||
def driver(self, mocker): | ||
from griptape.drivers import GriptapeCloudStructureRunDriver | ||
|
||
mock_response = mocker.Mock() | ||
mock_response.json.return_value = {"structure_run_id": 1} | ||
mocker.patch("requests.post", return_value=mock_response) | ||
|
||
mock_response = mocker.Mock() | ||
mock_response.json.return_value = {"description": "fizz buzz", "output": "fooey booey", "status": "SUCCEEDED"} | ||
mocker.patch("requests.get", return_value=mock_response) | ||
|
||
return GriptapeCloudStructureRunDriver(base_url="https://api.griptape.ai", api_key="foo bar", structure_id="1") | ||
|
||
def test_run(self, driver): | ||
assert isinstance(driver.run("foo bar"), TextArtifact) |
24 changes: 24 additions & 0 deletions
24
tests/unit/drivers/structure_run/test_local_structure_run_driver.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import pytest | ||
from griptape.tasks import StructureRunTask | ||
from griptape.structures import Agent | ||
from tests.mocks.mock_prompt_driver import MockPromptDriver | ||
from griptape.drivers import LocalStructureRunDriver | ||
from griptape.structures import Pipeline | ||
|
||
|
||
class TestLocalStructureRunDriver: | ||
@pytest.fixture | ||
def driver(self): | ||
agent = Agent(prompt_driver=MockPromptDriver(mock_output="agent mock output")) | ||
driver = LocalStructureRunDriver(structure=agent) | ||
|
||
return driver | ||
|
||
def test_run(self, driver): | ||
pipeline = Pipeline(prompt_driver=MockPromptDriver(mock_output="pipeline mock output")) | ||
|
||
task = StructureRunTask(driver=driver) | ||
|
||
pipeline.add_task(task) | ||
|
||
assert task.run().to_text() == "agent mock output" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
27 changes: 0 additions & 27 deletions
27
tests/unit/tools/test_griptape_cloud_structure_run_client.py
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters