-
Notifications
You must be signed in to change notification settings - Fork 186
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
b936419
commit fe7d255
Showing
10 changed files
with
236 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from griptape.artifacts import InfoArtifact | ||
from griptape.loaders import ImageLoader | ||
from griptape.structures import Workflow | ||
from griptape.tasks import BranchTask, PromptTask | ||
|
||
|
||
def on_run(task: BranchTask) -> InfoArtifact: | ||
if "hot dog" in task.input.value: | ||
return InfoArtifact("hot_dog") | ||
else: | ||
return InfoArtifact("not_hot_dog") | ||
|
||
|
||
image_artifact = ImageLoader().load("tests/resources/mountain.png") | ||
|
||
workflow = Workflow( | ||
tasks=[ | ||
PromptTask(["What is in this image?", image_artifact], child_ids=["branch"]), | ||
BranchTask( | ||
"{{ parents_output_text }}", | ||
on_run=on_run, | ||
id="branch", | ||
child_ids=["hot_dog", "not_hot_dog"], | ||
), | ||
PromptTask("Tell me about hot dogs", id="hot_dog"), | ||
PromptTask("Tell me about anything but hotdogs", id="not_hot_dog"), | ||
] | ||
) | ||
|
||
workflow.run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Callable, Union | ||
|
||
from attrs import define, field | ||
|
||
from griptape.artifacts import InfoArtifact, ListArtifact | ||
from griptape.tasks import BaseTask, CodeExecutionTask | ||
|
||
|
||
@define | ||
class BranchTask(CodeExecutionTask): | ||
on_run: Callable[[BranchTask], Union[InfoArtifact, ListArtifact[InfoArtifact]]] = field(kw_only=True) | ||
|
||
def try_run(self) -> InfoArtifact | ListArtifact[InfoArtifact]: | ||
result = self.on_run(self) | ||
|
||
if isinstance(result, ListArtifact): | ||
branch_task_ids = {artifact.value for artifact in result} | ||
else: | ||
branch_task_ids = {result.value} | ||
|
||
if not all(branch_task_id in self.child_ids for branch_task_id in branch_task_ids): | ||
raise ValueError(f"Branch task returned invalid child task id {branch_task_ids}") | ||
|
||
if self.structure is not None: | ||
children_to_skip = [child for child in self.children if child.id not in branch_task_ids] | ||
for child in children_to_skip: | ||
child.state = BaseTask.State.SKIPPED | ||
|
||
return result |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
from griptape.artifacts import InfoArtifact, ListArtifact | ||
from griptape.artifacts.error_artifact import ErrorArtifact | ||
from griptape.structures import Workflow | ||
from griptape.tasks import BaseTask, PromptTask | ||
from griptape.tasks.branch_task import BranchTask | ||
|
||
|
||
class TestBranchTask: | ||
def test_one_branch(self): | ||
def on_run(_: BranchTask) -> InfoArtifact: | ||
return InfoArtifact("2") | ||
|
||
workflow = Workflow( | ||
tasks=[ | ||
PromptTask(id="1", child_ids=["branch"]), | ||
BranchTask(id="branch", on_run=on_run, child_ids=["2", "3"]), | ||
PromptTask(id="2", child_ids=["4"]), | ||
PromptTask(id="3", child_ids=["4"]), | ||
PromptTask(id="4"), | ||
] | ||
) | ||
workflow.run() | ||
|
||
assert workflow.find_task("1").state == BaseTask.State.FINISHED | ||
assert workflow.find_task("branch").state == BaseTask.State.FINISHED | ||
assert workflow.find_task("2").state == BaseTask.State.FINISHED | ||
assert workflow.find_task("3").state == BaseTask.State.SKIPPED | ||
assert workflow.find_task("4").state == BaseTask.State.FINISHED | ||
assert workflow.is_finished() | ||
|
||
def test_multi_branch(self): | ||
def on_run(_: BranchTask) -> ListArtifact[InfoArtifact]: | ||
return ListArtifact([InfoArtifact("2"), InfoArtifact("3")]) | ||
|
||
workflow = Workflow( | ||
tasks=[ | ||
PromptTask(id="1", child_ids=["branch"]), | ||
BranchTask(id="branch", on_run=on_run, child_ids=["2", "3"]), | ||
PromptTask(id="2", child_ids=["4"]), | ||
PromptTask(id="3", child_ids=["4"]), | ||
PromptTask(id="4"), | ||
] | ||
) | ||
workflow.run() | ||
|
||
assert workflow.find_task("1").state == BaseTask.State.FINISHED | ||
assert workflow.find_task("branch").state == BaseTask.State.FINISHED | ||
assert workflow.find_task("2").state == BaseTask.State.FINISHED | ||
assert workflow.find_task("3").state == BaseTask.State.FINISHED | ||
assert workflow.find_task("4").state == BaseTask.State.FINISHED | ||
assert workflow.is_finished() | ||
|
||
def test_no_branch(self): | ||
def on_run(_: BranchTask) -> ListArtifact[InfoArtifact]: | ||
return ListArtifact([]) | ||
|
||
workflow = Workflow( | ||
tasks=[ | ||
PromptTask(id="1", child_ids=["branch"]), | ||
BranchTask(id="branch", on_run=on_run, child_ids=["2", "3"]), | ||
PromptTask(id="2", child_ids=["4"]), | ||
PromptTask(id="3", child_ids=["4"]), | ||
PromptTask(id="4"), | ||
] | ||
) | ||
workflow.run() | ||
|
||
assert workflow.find_task("1").state == BaseTask.State.FINISHED | ||
assert workflow.find_task("branch").state == BaseTask.State.FINISHED | ||
assert workflow.find_task("2").state == BaseTask.State.SKIPPED | ||
assert workflow.find_task("3").state == BaseTask.State.SKIPPED | ||
assert workflow.find_task("4").state == BaseTask.State.SKIPPED | ||
assert workflow.is_finished() | ||
|
||
def test_no_structure(self): | ||
def on_run(_: BranchTask) -> InfoArtifact: | ||
return InfoArtifact("2") | ||
|
||
task = BranchTask(input="foo", id="branch", on_run=on_run, child_ids=["2", "3"]) | ||
result = task.run() | ||
|
||
assert result.value == "2" | ||
|
||
def test_bad_branch(self): | ||
def on_run(_: BranchTask) -> InfoArtifact: | ||
return InfoArtifact("42") | ||
|
||
workflow = Workflow( | ||
tasks=[ | ||
PromptTask(id="1", child_ids=["branch"]), | ||
BranchTask(id="branch", on_run=on_run, child_ids=["2", "3"]), | ||
PromptTask(id="2", child_ids=["4"]), | ||
PromptTask(id="3", child_ids=["4"]), | ||
PromptTask(id="4"), | ||
] | ||
) | ||
workflow.run() | ||
|
||
assert isinstance(workflow.tasks[1].output, ErrorArtifact) | ||
assert workflow.tasks[1].output.value == "Branch task returned invalid child task id {'42'}" | ||
|
||
assert workflow.find_task("1").state == BaseTask.State.FINISHED | ||
assert workflow.find_task("branch").state == BaseTask.State.FINISHED | ||
assert workflow.find_task("2").state == BaseTask.State.PENDING | ||
assert workflow.find_task("3").state == BaseTask.State.PENDING | ||
assert workflow.find_task("4").state == BaseTask.State.PENDING | ||
assert not workflow.is_finished() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters