Skip to content

Commit

Permalink
Add branch task (#1404)
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter authored Dec 12, 2024
1 parent b936419 commit fe7d255
Show file tree
Hide file tree
Showing 10 changed files with 236 additions and 4 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- `PromptTask.conversation_memory` for setting the Conversation Memory on a Prompt Task.
- `Structure.conversation_memory_strategy` for setting whether Conversation Memory Runs should be created on a per-Structure or per-Task basis. Default is `per_structure`.
- `Structure.conversation_memory_strategy` for setting whether Conversation Memory Runs should be created on a per-Structure or per-Task basis. Default is `Structure.ConversationMemoryStrategy.PER_STRUCTURE`.
- `BranchTask` for selecting which Tasks (if any) to run based on a condition.
- Support for `BranchTask` in `StructureVisualizer`.

## [1.0.0] - 2024-12-09

Expand Down
30 changes: 30 additions & 0 deletions docs/griptape-framework/structures/src/branch_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from griptape.artifacts import InfoArtifact
from griptape.loaders import ImageLoader
from griptape.structures import Workflow
from griptape.tasks import BranchTask, PromptTask


def on_run(task: BranchTask) -> InfoArtifact:
if "hot dog" in task.input.value:
return InfoArtifact("hot_dog")
else:
return InfoArtifact("not_hot_dog")


image_artifact = ImageLoader().load("tests/resources/mountain.png")

workflow = Workflow(
tasks=[
PromptTask(["What is in this image?", image_artifact], child_ids=["branch"]),
BranchTask(
"{{ parents_output_text }}",
on_run=on_run,
id="branch",
child_ids=["hot_dog", "not_hot_dog"],
),
PromptTask("Tell me about hot dogs", id="hot_dog"),
PromptTask("Tell me about anything but hotdogs", id="not_hot_dog"),
]
)

workflow.run()
16 changes: 16 additions & 0 deletions docs/griptape-framework/structures/tasks.md
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,22 @@ This task takes a python function, and authors can elect to return a custom arti
Output: "Silent code, loud impact."
```

## Branch Task

By default, a [Workflow](../structures/workflows.md) will only run a Task when all the Tasks it depends on have finished.

You can make use of [BranchTask](../../reference/griptape/tasks/branch_task.md) in order to tell the Workflow not to run all dependent tasks, but instead to pick and choose one or more paths to go down.

The `BranchTask`'s [on_run](../../reference/griptape/tasks/branch_task.md#griptape.tasks.branch_task.BranchTask.on_run) function can return one of three things:

1. An `InfoArtifact` containing the `id` of the Task to run next.
1. A `ListArtifact` of `InfoArtifact`s containing the `id`s of the Tasks to run next.
1. An _empty_ `ListArtifact` to indicate that no Tasks should be run next.

```python
--8<-- "docs/griptape-framework/structures/src/branch_task.py"
```

## Image Generation Tasks

To generate an image, use one of the following [Image Generation Tasks](../../reference/griptape/tasks/index.md). All Image Generation Tasks accept an [Image Generation Driver](../drivers/image-generation-drivers.md).
Expand Down
2 changes: 1 addition & 1 deletion griptape/structures/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def finished_tasks(self) -> list[BaseTask]:
return [s for s in self.tasks if s.is_finished()]

def is_finished(self) -> bool:
return all(s.is_finished() for s in self.tasks)
return all(not s.can_run() for s in self.tasks)

def is_running(self) -> bool:
return any(s for s in self.tasks if s.is_running())
Expand Down
2 changes: 2 additions & 0 deletions griptape/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .structure_run_task import StructureRunTask
from .audio_transcription_task import AudioTranscriptionTask
from .assistant_task import AssistantTask
from .branch_task import BranchTask

__all__ = [
"BaseTask",
Expand All @@ -40,4 +41,5 @@
"StructureRunTask",
"AudioTranscriptionTask",
"AssistantTask",
"BranchTask",
]
18 changes: 17 additions & 1 deletion griptape/tasks/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class State(Enum):
PENDING = 1
RUNNING = 2
FINISHED = 3
SKIPPED = 4

id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True, metadata={"serializable": True})
state: State = field(default=State.PENDING, kw_only=True, metadata={"serializable": True})
Expand Down Expand Up @@ -136,6 +137,9 @@ def is_finished(self) -> bool:
def is_running(self) -> bool:
return self.state == BaseTask.State.RUNNING

def is_skipped(self) -> bool:
return self.state == BaseTask.State.SKIPPED

def before_run(self) -> None:
super().before_run()
if self.structure is not None:
Expand Down Expand Up @@ -181,7 +185,19 @@ def after_run(self) -> None:
)

def can_run(self) -> bool:
return self.state == BaseTask.State.PENDING and all(parent.is_finished() for parent in self.parents)
# If this Task has been skipped or is not pending, it should not run
if self.is_skipped() or not self.is_pending():
return False

# If this Task has parents, and _all_ of them are skipped, it should not run
if self.parents and all(parent.is_skipped() for parent in self.parents):
self.state = BaseTask.State.SKIPPED
return False

# If _all_ this Task's unskipped parents are finished, it should run
unskipped_parents = [parent for parent in self.parents if not parent.is_skipped()]

return all(parent.is_finished() for parent in unskipped_parents)

def reset(self) -> BaseTask:
self.state = BaseTask.State.PENDING
Expand Down
31 changes: 31 additions & 0 deletions griptape/tasks/branch_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from __future__ import annotations

from typing import Callable, Union

from attrs import define, field

from griptape.artifacts import InfoArtifact, ListArtifact
from griptape.tasks import BaseTask, CodeExecutionTask


@define
class BranchTask(CodeExecutionTask):
on_run: Callable[[BranchTask], Union[InfoArtifact, ListArtifact[InfoArtifact]]] = field(kw_only=True)

def try_run(self) -> InfoArtifact | ListArtifact[InfoArtifact]:
result = self.on_run(self)

if isinstance(result, ListArtifact):
branch_task_ids = {artifact.value for artifact in result}
else:
branch_task_ids = {result.value}

if not all(branch_task_id in self.child_ids for branch_task_id in branch_task_ids):
raise ValueError(f"Branch task returned invalid child task id {branch_task_ids}")

if self.structure is not None:
children_to_skip = [child for child in self.children if child.id not in branch_task_ids]
for child in children_to_skip:
child.state = BaseTask.State.SKIPPED

return result
7 changes: 5 additions & 2 deletions griptape/utils/structure_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,15 @@ def __render_tasks(self, tasks: list[BaseTask]) -> str:

def __render_task(self, task: BaseTask) -> str:
from griptape.drivers import LocalStructureRunDriver
from griptape.tasks import StructureRunTask
from griptape.tasks import BranchTask, StructureRunTask

parts = []
if task.children:
children = " & ".join([f"{self.build_node_id(child)}" for child in task.children])
parts.append(f"{self.build_node_id(task)}--> {children};")
if isinstance(task, BranchTask):
parts.append(f"{self.build_node_id(task)}{{ {self.build_node_id(task)} }}-.-> {children};")
else:
parts.append(f"{self.build_node_id(task)}--> {children};")
else:
parts.append(f"{self.build_node_id(task)};")

Expand Down
107 changes: 107 additions & 0 deletions tests/unit/tasks/test_branch_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from griptape.artifacts import InfoArtifact, ListArtifact
from griptape.artifacts.error_artifact import ErrorArtifact
from griptape.structures import Workflow
from griptape.tasks import BaseTask, PromptTask
from griptape.tasks.branch_task import BranchTask


class TestBranchTask:
def test_one_branch(self):
def on_run(_: BranchTask) -> InfoArtifact:
return InfoArtifact("2")

workflow = Workflow(
tasks=[
PromptTask(id="1", child_ids=["branch"]),
BranchTask(id="branch", on_run=on_run, child_ids=["2", "3"]),
PromptTask(id="2", child_ids=["4"]),
PromptTask(id="3", child_ids=["4"]),
PromptTask(id="4"),
]
)
workflow.run()

assert workflow.find_task("1").state == BaseTask.State.FINISHED
assert workflow.find_task("branch").state == BaseTask.State.FINISHED
assert workflow.find_task("2").state == BaseTask.State.FINISHED
assert workflow.find_task("3").state == BaseTask.State.SKIPPED
assert workflow.find_task("4").state == BaseTask.State.FINISHED
assert workflow.is_finished()

def test_multi_branch(self):
def on_run(_: BranchTask) -> ListArtifact[InfoArtifact]:
return ListArtifact([InfoArtifact("2"), InfoArtifact("3")])

workflow = Workflow(
tasks=[
PromptTask(id="1", child_ids=["branch"]),
BranchTask(id="branch", on_run=on_run, child_ids=["2", "3"]),
PromptTask(id="2", child_ids=["4"]),
PromptTask(id="3", child_ids=["4"]),
PromptTask(id="4"),
]
)
workflow.run()

assert workflow.find_task("1").state == BaseTask.State.FINISHED
assert workflow.find_task("branch").state == BaseTask.State.FINISHED
assert workflow.find_task("2").state == BaseTask.State.FINISHED
assert workflow.find_task("3").state == BaseTask.State.FINISHED
assert workflow.find_task("4").state == BaseTask.State.FINISHED
assert workflow.is_finished()

def test_no_branch(self):
def on_run(_: BranchTask) -> ListArtifact[InfoArtifact]:
return ListArtifact([])

workflow = Workflow(
tasks=[
PromptTask(id="1", child_ids=["branch"]),
BranchTask(id="branch", on_run=on_run, child_ids=["2", "3"]),
PromptTask(id="2", child_ids=["4"]),
PromptTask(id="3", child_ids=["4"]),
PromptTask(id="4"),
]
)
workflow.run()

assert workflow.find_task("1").state == BaseTask.State.FINISHED
assert workflow.find_task("branch").state == BaseTask.State.FINISHED
assert workflow.find_task("2").state == BaseTask.State.SKIPPED
assert workflow.find_task("3").state == BaseTask.State.SKIPPED
assert workflow.find_task("4").state == BaseTask.State.SKIPPED
assert workflow.is_finished()

def test_no_structure(self):
def on_run(_: BranchTask) -> InfoArtifact:
return InfoArtifact("2")

task = BranchTask(input="foo", id="branch", on_run=on_run, child_ids=["2", "3"])
result = task.run()

assert result.value == "2"

def test_bad_branch(self):
def on_run(_: BranchTask) -> InfoArtifact:
return InfoArtifact("42")

workflow = Workflow(
tasks=[
PromptTask(id="1", child_ids=["branch"]),
BranchTask(id="branch", on_run=on_run, child_ids=["2", "3"]),
PromptTask(id="2", child_ids=["4"]),
PromptTask(id="3", child_ids=["4"]),
PromptTask(id="4"),
]
)
workflow.run()

assert isinstance(workflow.tasks[1].output, ErrorArtifact)
assert workflow.tasks[1].output.value == "Branch task returned invalid child task id {'42'}"

assert workflow.find_task("1").state == BaseTask.State.FINISHED
assert workflow.find_task("branch").state == BaseTask.State.FINISHED
assert workflow.find_task("2").state == BaseTask.State.PENDING
assert workflow.find_task("3").state == BaseTask.State.PENDING
assert workflow.find_task("4").state == BaseTask.State.PENDING
assert not workflow.is_finished()
24 changes: 24 additions & 0 deletions tests/unit/utils/test_structure_visualizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from griptape.artifacts import InfoArtifact, ListArtifact
from griptape.drivers import LocalStructureRunDriver
from griptape.structures import Agent, Pipeline, Workflow
from griptape.tasks import PromptTask, StructureRunTask
from griptape.tasks.branch_task import BranchTask
from griptape.utils import StructureVisualizer


Expand Down Expand Up @@ -79,3 +81,25 @@ def test_structure_run_task(self):

def test_build_node_id(self):
assert StructureVisualizer(Pipeline()).build_node_id(PromptTask("test1", id="test1")) == "Test1"

def test_branch_task(self):
def on_run(_: BranchTask) -> ListArtifact[InfoArtifact]:
return ListArtifact([])

workflow = Workflow(
tasks=[
PromptTask(id="1", child_ids=["branch"]),
BranchTask(id="branch", on_run=on_run, child_ids=["2", "3"]),
PromptTask(id="2", child_ids=["4"]),
PromptTask(id="3", child_ids=["4"]),
PromptTask(id="4"),
]
)

visualizer = StructureVisualizer(workflow)
result = visualizer.to_url()

assert (
result
== "https://mermaid.ink/svg/Z3JhcGggVEQ7CgkxLS0+IEJyYW5jaDsKCUJyYW5jaHsgQnJhbmNoIH0tLi0+IDIgJiAzOwoJMi0tPiA0OwoJMy0tPiA0OwoJNDs="
)

0 comments on commit fe7d255

Please sign in to comment.