From bf000887799429869b0a0e254f1ed3d553b2f098 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 14 Aug 2024 09:02:31 -0700 Subject: [PATCH] Modify arguments/return type of add_parent and add_child (#1048) --- CHANGELOG.md | 3 ++ .../structures/workflows.md | 28 +++++++++++ griptape/tasks/actions_subtask.py | 20 ++++---- griptape/tasks/base_task.py | 46 +++++++++++-------- tests/unit/structures/test_workflow.py | 14 +++--- tests/unit/tasks/test_base_task.py | 36 +++++++++++++++ 6 files changed, 111 insertions(+), 36 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ef91010ae..8b69e7f32 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Global event bus, `griptape.events.event_bus`, for publishing and subscribing to events. - Global config, `griptape.config.config`, for setting global configuration defaults. - Unique name generation for all `RagEngine` modules. +- Support for bitshift composition in `BaseTask` for adding parent/child tasks. ### Changed - **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `event_bus`. @@ -31,8 +32,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: Dropped `Client` from all Tool names for better naming consistency. - **BREAKING**: Dropped `_client` suffix from all Tool packages. - **BREAKING**: Added `Tool` suffix to all Tool names for better naming consistency. +- **BREAKING**: `BaseTask.add_parent/child` now take a `BaseTask` instead of `str | BaseTask`. - Engines that previously required Drivers now pull from `griptape.config.config.drivers` by default. - `BaseTask.add_parent/child` will now call `self.structure.add_task` if possible. +- `BaseTask.add_parent/child` now returns `self`, allowing for chaining. ## [0.29.1] - 2024-08-02 diff --git a/docs/griptape-framework/structures/workflows.md b/docs/griptape-framework/structures/workflows.md index 5f7c271fc..87bb2ae89 100644 --- a/docs/griptape-framework/structures/workflows.md +++ b/docs/griptape-framework/structures/workflows.md @@ -201,3 +201,31 @@ output: [06/18/24 09:52:23] INFO PromptTask new-animal Output: elephant ``` + +### Bitshift Composition + +Task relationships can also be set up with the Python bitshift operators `>>` and `<<`. The following four statements are all functionally equivalent: + +```python +task1 >> task2 +task1.add_child(task2) + +task2 << task1 +task2.add_parent(task1) +``` + +When using the bitshift to compose operators, the relationship is set in the direction that the bitshift operator points. +For example, `task1 >> task2` means that `task1` runs first and `task2` runs second. +Multiple operators can be composed – keep in mind the chain is executed left-to-right and the rightmost object is always returned. For example: + +```python +task1 >> task2 >> task3 << task4 +``` + +is equivalent to: + +```python +task1.add_child(task2) +task2.add_child(task3) +task3.add_parent(task4) +``` diff --git a/griptape/tasks/actions_subtask.py b/griptape/tasks/actions_subtask.py index 0f885d260..7d2fb5efd 100644 --- a/griptape/tasks/actions_subtask.py +++ b/griptape/tasks/actions_subtask.py @@ -68,17 +68,15 @@ def children(self) -> list[BaseTask]: else: raise Exception("ActionSubtask must be attached to a Task that implements ActionSubtaskOriginMixin.") - def add_child(self, child: str | BaseTask) -> None: - child_id = child if isinstance(child, str) else child.id - - if child_id not in self.child_ids: - self.child_ids.append(child_id) - - def add_parent(self, parent: str | BaseTask) -> None: - parent_id = parent if isinstance(parent, str) else parent.id - - if parent_id not in self.parent_ids: - self.parent_ids.append(parent_id) + def add_child(self, child: BaseTask) -> BaseTask: + if child.id not in self.child_ids: + self.child_ids.append(child.id) + return child + + def add_parent(self, parent: BaseTask) -> BaseTask: + if parent.id not in self.parent_ids: + self.parent_ids.append(parent.id) + return parent def attach_to(self, parent_task: BaseTask) -> None: self.parent_task_id = parent_task.id diff --git a/griptape/tasks/base_task.py b/griptape/tasks/base_task.py index b3086bebb..1899eccd6 100644 --- a/griptape/tasks/base_task.py +++ b/griptape/tasks/base_task.py @@ -42,6 +42,16 @@ class State(Enum): kw_only=True, ) + def __rshift__(self, other: BaseTask) -> BaseTask: + self.add_child(other) + + return other + + def __lshift__(self, other: BaseTask) -> BaseTask: + self.add_parent(other) + + return other + def __attrs_post_init__(self) -> None: if self.structure is not None: self.structure.add_task(self) @@ -83,37 +93,37 @@ def meta_memories(self) -> list[BaseMetaEntry]: def __str__(self) -> str: return str(self.output.value) - def add_parents(self, parents: list[str | BaseTask]) -> None: + def add_parents(self, parents: list[BaseTask]) -> None: for parent in parents: self.add_parent(parent) - def add_parent(self, parent: str | BaseTask) -> None: - parent_id = parent if isinstance(parent, str) else parent.id + def add_parent(self, parent: BaseTask) -> BaseTask: + if parent.id not in self.parent_ids: + self.parent_ids.append(parent.id) - if parent_id not in self.parent_ids: - self.parent_ids.append(parent_id) + if self.id not in parent.child_ids: + parent.child_ids.append(self.id) - if isinstance(parent, BaseTask): - parent.add_child(self.id) + if self.structure is not None: + self.structure.add_task(parent) - if self.structure is not None: - self.structure.add_task(parent) + return self - def add_children(self, children: list[str | BaseTask]) -> None: + def add_children(self, children: list[BaseTask]) -> None: for child in children: self.add_child(child) - def add_child(self, child: str | BaseTask) -> None: - child_id = child if isinstance(child, str) else child.id + def add_child(self, child: BaseTask) -> BaseTask: + if child.id not in self.child_ids: + self.child_ids.append(child.id) - if child_id not in self.child_ids: - self.child_ids.append(child_id) + if self.id not in child.parent_ids: + child.parent_ids.append(self.id) - if isinstance(child, BaseTask): - child.add_parent(self.id) + if self.structure is not None: + self.structure.add_task(child) - if self.structure is not None: - self.structure.add_task(child) + return self def preprocess(self, structure: Structure) -> BaseTask: self.structure = structure diff --git a/tests/unit/structures/test_workflow.py b/tests/unit/structures/test_workflow.py index 79c9868e1..45610bf20 100644 --- a/tests/unit/structures/test_workflow.py +++ b/tests/unit/structures/test_workflow.py @@ -277,8 +277,8 @@ def test_run_topology_1_imperative_parents(self): task3 = PromptTask("test3", id="task3") task4 = PromptTask("test4", id="task4") task2.add_parent(task1) - task3.add_parent("task1") - task4.add_parents([task2, "task3"]) + task3.add_parent(task1) + task4.add_parents([task2, task3]) workflow = Workflow(tasks=[task1, task2, task3, task4]) workflow.run() @@ -306,8 +306,8 @@ def test_run_topology_1_imperative_parents_structure_init(self): task3 = PromptTask("test3", id="task3", structure=workflow) task4 = PromptTask("test4", id="task4", structure=workflow) task2.add_parent(task1) - task3.add_parent("task1") - task4.add_parents([task2, "task3"]) + task3.add_parent(task1) + task4.add_parents([task2, task3]) workflow.run() @@ -432,9 +432,9 @@ def test_run_topology_2_imperative_parents(self): taskd = PromptTask("testd", id="taskd") taske = PromptTask("teste", id="taske") taskb.add_parent(taska) - taskc.add_parent("taska") + taskc.add_parent(taska) taskd.add_parents([taska, taskb, taskc]) - taske.add_parents(["taska", taskd, "taskc"]) + taske.add_parents([taska, taskd, taskc]) workflow = Workflow(tasks=[taska, taskb, taskc, taskd, taske]) workflow.run() @@ -466,7 +466,7 @@ def test_run_topology_2_imperative_mixed(self): taska.add_children([taskb, taskc, taskd, taske]) taskb.add_child(taskd) taskd.add_parent(taskc) - taske.add_parents(["taska", taskd, "taskc"]) + taske.add_parents([taska, taskd, taskc]) workflow = Workflow(tasks=[taska, taskb, taskc, taskd, taske]) workflow.run() diff --git a/tests/unit/tasks/test_base_task.py b/tests/unit/tasks/test_base_task.py index 1b45b4e98..e87a99ccb 100644 --- a/tests/unit/tasks/test_base_task.py +++ b/tests/unit/tasks/test_base_task.py @@ -116,3 +116,39 @@ def test_execute_publish_events(self, task): task.execute() assert event_bus.event_listeners[0].handler.call_count == 2 + + def test_add_parent(self, task): + parent = MockTask("parent foobar", id="parent_foobar") + + result = task.add_parent(parent) + + assert parent.id in task.parent_ids + assert task.id in parent.child_ids + assert result == task + + def test_add_child(self, task): + child = MockTask("child foobar", id="child_foobar") + + result = task.add_child(child) + + assert child.id in task.child_ids + assert task.id in child.parent_ids + assert result == task + + def test_add_parent_bitshift(self, task): + parent = MockTask("parent foobar", id="parent_foobar") + + added_task = task << parent + + assert parent.id in task.parent_ids + assert task.id in parent.child_ids + assert added_task == parent + + def test_add_child_bitshift(self, task): + child = MockTask("child foobar", id="child_foobar") + + added_task = task >> child + + assert child.id in task.child_ids + assert task.id in child.parent_ids + assert added_task == child