From 7c52ed15e5c467564e6db02b43436b0ac65f9e3f Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Mon, 14 Oct 2024 13:29:51 -0700 Subject: [PATCH] Structure/task context improvements --- CHANGELOG.md | 4 ++++ docs/griptape-framework/structures/pipelines.md | 7 ++++--- docs/griptape-framework/structures/workflows.md | 7 ++++--- griptape/structures/pipeline.py | 3 ++- griptape/structures/structure.py | 4 ++++ griptape/structures/workflow.py | 1 + griptape/tasks/base_task.py | 4 ++-- tests/unit/structures/test_agent.py | 13 +++++++++++++ tests/unit/structures/test_pipeline.py | 15 ++++++++++++++- tests/unit/structures/test_workflow.py | 17 +++++++++++++++-- tests/unit/tasks/test_base_task.py | 5 ++--- tests/unit/tasks/test_base_text_input_task.py | 2 +- 12 files changed, 66 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c0dc4ba21..7b8c7059e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,10 +13,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Request/response debug logging to all Prompt Drivers. - `BaseEventListener.flush_events()` to flush events from an Event Listener. - Exponential backoff to `BaseEventListenerDriver` for retrying failed event publishing. +- `BaseTask.task_outputs` to get a dictionary of all task outputs. This has been added to `Workflow.context` and `Pipeline.context`. ### Changed - **BREAKING**: `BaseEventListener.publish_event` `flush` argument. Use `BaseEventListener.flush_events()` instead. +- `BaseTask.parent_outputs` type has changed from `dict[str, str | None]` to `dict[str, BaseArtifact]`. +- `Workflow.context["parent_outputs"]` type has changed from `dict[str, str | None]` to `dict[str, BaseArtifact]`. +- `Pipeline.context["parent_output"]` has changed type from `str | None` to `BaseArtifact | None`. - `_DefaultsConfig.logging_config` and `Defaults.drivers_config` are now lazily instantiated. - `BaseTask.add_parent`/`BaseTask.add_child` now only add the parent/child task to the structure if it is not already present. - `BaseEventListener.flush_events()` to flush events from an Event Listener. diff --git a/docs/griptape-framework/structures/pipelines.md b/docs/griptape-framework/structures/pipelines.md index fd293be30..28910863e 100644 --- a/docs/griptape-framework/structures/pipelines.md +++ b/docs/griptape-framework/structures/pipelines.md @@ -13,9 +13,10 @@ You can access the final output of the Pipeline by using the [output](../../refe Pipelines have access to the following [context](../../reference/griptape/structures/pipeline.md#griptape.structures.pipeline.Pipeline.context) variables in addition to the [base context](./tasks.md#context). -- `parent_output`: output from the parent. -- `parent`: parent task. -- `child`: child task. +- `task_outputs`: dictionary containing mapping of all task IDs to their outputs. +- `parent_output`: output from the parent task if one exists, otherwise `None`. +- `parent`: parent task if one exists, otherwise `None`. +- `child`: child task if one exists, otherwise `None`. ## Pipeline diff --git a/docs/griptape-framework/structures/workflows.md b/docs/griptape-framework/structures/workflows.md index 85a4cec68..f1eff199a 100644 --- a/docs/griptape-framework/structures/workflows.md +++ b/docs/griptape-framework/structures/workflows.md @@ -13,10 +13,11 @@ You can access the final output of the Workflow by using the [output](../../refe Workflows have access to the following [context](../../reference/griptape/structures/workflow.md#griptape.structures.workflow.Workflow.context) variables in addition to the [base context](./tasks.md#context): -- `parent_outputs`: dictionary containing mapping of parent IDs to their outputs. +- `task_outputs`: dictionary containing mapping of all task IDs to their outputs. +- `parent_outputs`: dictionary containing mapping of parent task IDs to their outputs. - `parents_output_text`: string containing the concatenated outputs of all parent tasks. -- `parents`: parent tasks referenceable by IDs. -- `children`: child tasks referenceable by IDs. +- `parents`: dictionary containing mapping of parent task IDs to their task objects. +- `children`: dictionary containing mapping of child task IDs to their task objects. ## Workflow diff --git a/griptape/structures/pipeline.py b/griptape/structures/pipeline.py index a5134a964..497e77bda 100644 --- a/griptape/structures/pipeline.py +++ b/griptape/structures/pipeline.py @@ -59,7 +59,8 @@ def context(self, task: BaseTask) -> dict[str, Any]: context.update( { - "parent_output": task.parents[0].output.to_text() if task.parents and task.parents[0].output else None, + "parent_output": task.parents[0].output if task.parents else None, + "task_outputs": self.task_outputs, "parent": task.parents[0] if task.parents else None, "child": task.children[0] if task.children else None, }, diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index d2e6d2f3f..29ee6281c 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -72,6 +72,10 @@ def output(self) -> BaseArtifact: raise ValueError("Structure's output Task has no output. Run the Structure to generate output.") return self.output_task.output + @property + def task_outputs(self) -> dict[str, Optional[BaseArtifact]]: + return {task.id: task.output for task in self.tasks} + @property def finished_tasks(self) -> list[BaseTask]: return [s for s in self.tasks if s.is_finished()] diff --git a/griptape/structures/workflow.py b/griptape/structures/workflow.py index 1c65c59c4..99af20dc2 100644 --- a/griptape/structures/workflow.py +++ b/griptape/structures/workflow.py @@ -125,6 +125,7 @@ def context(self, task: BaseTask) -> dict[str, Any]: context.update( { + "task_outputs": self.task_outputs, "parent_outputs": task.parent_outputs, "parents_output_text": task.parents_output_text, "parents": {parent.id: parent for parent in task.parents}, diff --git a/griptape/tasks/base_task.py b/griptape/tasks/base_task.py index a8ed10829..ff1f6a11d 100644 --- a/griptape/tasks/base_task.py +++ b/griptape/tasks/base_task.py @@ -69,8 +69,8 @@ def children(self) -> list[BaseTask]: raise ValueError("Structure must be set to access children") @property - def parent_outputs(self) -> dict[str, str]: - return {parent.id: parent.output.to_text() if parent.output else "" for parent in self.parents} + def parent_outputs(self) -> dict[str, BaseArtifact]: + return {parent.id: parent.output for parent in self.parents if parent.output} @property def parents_output_text(self) -> str: diff --git a/tests/unit/structures/test_agent.py b/tests/unit/structures/test_agent.py index d522a43a2..97697749e 100644 --- a/tests/unit/structures/test_agent.py +++ b/tests/unit/structures/test_agent.py @@ -239,3 +239,16 @@ def finished_tasks(self): def test_fail_fast(self): with pytest.raises(ValueError): Agent(prompt_driver=MockPromptDriver(), fail_fast=True) + + def test_task_outputs(self): + task = PromptTask("test prompt") + agent = Agent(prompt_driver=MockPromptDriver()) + + agent.add_task(task) + + assert len(agent.task_outputs) == 1 + assert agent.task_outputs[task.id] is None + agent.run("hello") + + assert len(agent.task_outputs) == 1 + assert agent.task_outputs[task.id] == task.output diff --git a/tests/unit/structures/test_pipeline.py b/tests/unit/structures/test_pipeline.py index e6fa6d770..b2580de4c 100644 --- a/tests/unit/structures/test_pipeline.py +++ b/tests/unit/structures/test_pipeline.py @@ -360,7 +360,8 @@ def test_context(self): context = pipeline.context(task) - assert context["parent_output"] == parent.output.to_text() + assert context["parent_output"] == parent.output + assert context["task_outputs"] == pipeline.task_outputs assert context["structure"] == pipeline assert context["parent"] == parent assert context["child"] == child @@ -398,3 +399,15 @@ def test_add_duplicate_task_directly(self): with pytest.raises(ValueError, match=f"Duplicate task with id {task.id} found."): pipeline.run() + + def test_task_outputs(self): + task = PromptTask("test") + pipeline = Pipeline() + + pipeline + [task] + + assert len(pipeline.task_outputs) == 1 + assert pipeline.task_outputs[task.id] is None + pipeline.run() + assert len(pipeline.task_outputs) == 1 + assert pipeline.task_outputs[task.id] == task.output diff --git a/tests/unit/structures/test_workflow.py b/tests/unit/structures/test_workflow.py index e3bd8e886..1a9b4e2d1 100644 --- a/tests/unit/structures/test_workflow.py +++ b/tests/unit/structures/test_workflow.py @@ -737,13 +737,14 @@ def test_context(self): context = workflow.context(task) - assert context["parent_outputs"] == {parent.id: ""} + assert context["parent_outputs"] == {} workflow.run() context = workflow.context(task) - assert context["parent_outputs"] == {parent.id: parent.output.to_text()} + assert context["task_outputs"] == workflow.task_outputs + assert context["parent_outputs"] == {parent.id: parent.output} assert context["parents_output_text"] == "mock output" assert context["structure"] == workflow assert context["parents"] == {parent.id: parent} @@ -966,3 +967,15 @@ def _validate_topology_4(workflow) -> None: publish_website = workflow.find_task("publish_website") assert publish_website.parent_ids == ["compare_movies"] assert publish_website.child_ids == ["summarize_to_slack"] + + def test_task_outputs(self): + task = PromptTask("test") + workflow = Workflow(tasks=[task]) + + assert len(workflow.task_outputs) == 1 + assert workflow.task_outputs[task.id] is None + + workflow.run() + + assert len(workflow.task_outputs) == 1 + assert workflow.task_outputs[task.id].value == "mock output" diff --git a/tests/unit/tasks/test_base_task.py b/tests/unit/tasks/test_base_task.py index cf5d3e847..cd94aeef6 100644 --- a/tests/unit/tasks/test_base_task.py +++ b/tests/unit/tasks/test_base_task.py @@ -52,9 +52,8 @@ def test_parent_outputs(self, task): parent_3.output = None assert child.parent_outputs == { - parent_1.id: parent_1.output.to_text(), - parent_2.id: parent_2.output.to_text(), - parent_3.id: "", + parent_1.id: parent_1.output, + parent_2.id: parent_2.output, } def test_parents_output(self, task): diff --git a/tests/unit/tasks/test_base_text_input_task.py b/tests/unit/tasks/test_base_text_input_task.py index 0adab72ff..1be3904c5 100644 --- a/tests/unit/tasks/test_base_text_input_task.py +++ b/tests/unit/tasks/test_base_text_input_task.py @@ -44,7 +44,7 @@ def test_full_context(self): context = subtask.full_context assert context["foo"] == "bar" - assert context["parent_output"] == parent.output.to_text() + assert context["parent_output"] == parent.output assert context["structure"] == pipeline assert context["parent"] == parent assert context["child"] == child