Skip to content

Commit

Permalink
Structure/task context improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Oct 14, 2024
1 parent a68dbb9 commit 7c52ed1
Show file tree
Hide file tree
Showing 12 changed files with 66 additions and 16 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Request/response debug logging to all Prompt Drivers.
- `BaseEventListener.flush_events()` to flush events from an Event Listener.
- Exponential backoff to `BaseEventListenerDriver` for retrying failed event publishing.
- `BaseTask.task_outputs` to get a dictionary of all task outputs. This has been added to `Workflow.context` and `Pipeline.context`.

### Changed

- **BREAKING**: `BaseEventListener.publish_event` `flush` argument. Use `BaseEventListener.flush_events()` instead.
- `BaseTask.parent_outputs` type has changed from `dict[str, str | None]` to `dict[str, BaseArtifact]`.
- `Workflow.context["parent_outputs"]` type has changed from `dict[str, str | None]` to `dict[str, BaseArtifact]`.
- `Pipeline.context["parent_output"]` has changed type from `str | None` to `BaseArtifact | None`.
- `_DefaultsConfig.logging_config` and `Defaults.drivers_config` are now lazily instantiated.
- `BaseTask.add_parent`/`BaseTask.add_child` now only add the parent/child task to the structure if it is not already present.
- `BaseEventListener.flush_events()` to flush events from an Event Listener.
Expand Down
7 changes: 4 additions & 3 deletions docs/griptape-framework/structures/pipelines.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@ You can access the final output of the Pipeline by using the [output](../../refe

Pipelines have access to the following [context](../../reference/griptape/structures/pipeline.md#griptape.structures.pipeline.Pipeline.context) variables in addition to the [base context](./tasks.md#context).

- `parent_output`: output from the parent.
- `parent`: parent task.
- `child`: child task.
- `task_outputs`: dictionary containing mapping of all task IDs to their outputs.
- `parent_output`: output from the parent task if one exists, otherwise `None`.
- `parent`: parent task if one exists, otherwise `None`.
- `child`: child task if one exists, otherwise `None`.

## Pipeline

Expand Down
7 changes: 4 additions & 3 deletions docs/griptape-framework/structures/workflows.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ You can access the final output of the Workflow by using the [output](../../refe

Workflows have access to the following [context](../../reference/griptape/structures/workflow.md#griptape.structures.workflow.Workflow.context) variables in addition to the [base context](./tasks.md#context):

- `parent_outputs`: dictionary containing mapping of parent IDs to their outputs.
- `task_outputs`: dictionary containing mapping of all task IDs to their outputs.
- `parent_outputs`: dictionary containing mapping of parent task IDs to their outputs.
- `parents_output_text`: string containing the concatenated outputs of all parent tasks.
- `parents`: parent tasks referenceable by IDs.
- `children`: child tasks referenceable by IDs.
- `parents`: dictionary containing mapping of parent task IDs to their task objects.
- `children`: dictionary containing mapping of child task IDs to their task objects.

## Workflow

Expand Down
3 changes: 2 additions & 1 deletion griptape/structures/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ def context(self, task: BaseTask) -> dict[str, Any]:

context.update(
{
"parent_output": task.parents[0].output.to_text() if task.parents and task.parents[0].output else None,
"parent_output": task.parents[0].output if task.parents else None,
"task_outputs": self.task_outputs,
"parent": task.parents[0] if task.parents else None,
"child": task.children[0] if task.children else None,
},
Expand Down
4 changes: 4 additions & 0 deletions griptape/structures/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ def output(self) -> BaseArtifact:
raise ValueError("Structure's output Task has no output. Run the Structure to generate output.")
return self.output_task.output

@property
def task_outputs(self) -> dict[str, Optional[BaseArtifact]]:
return {task.id: task.output for task in self.tasks}

@property
def finished_tasks(self) -> list[BaseTask]:
return [s for s in self.tasks if s.is_finished()]
Expand Down
1 change: 1 addition & 0 deletions griptape/structures/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def context(self, task: BaseTask) -> dict[str, Any]:

context.update(
{
"task_outputs": self.task_outputs,
"parent_outputs": task.parent_outputs,
"parents_output_text": task.parents_output_text,
"parents": {parent.id: parent for parent in task.parents},
Expand Down
4 changes: 2 additions & 2 deletions griptape/tasks/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ def children(self) -> list[BaseTask]:
raise ValueError("Structure must be set to access children")

@property
def parent_outputs(self) -> dict[str, str]:
return {parent.id: parent.output.to_text() if parent.output else "" for parent in self.parents}
def parent_outputs(self) -> dict[str, BaseArtifact]:
return {parent.id: parent.output for parent in self.parents if parent.output}

@property
def parents_output_text(self) -> str:
Expand Down
13 changes: 13 additions & 0 deletions tests/unit/structures/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,16 @@ def finished_tasks(self):
def test_fail_fast(self):
with pytest.raises(ValueError):
Agent(prompt_driver=MockPromptDriver(), fail_fast=True)

def test_task_outputs(self):
task = PromptTask("test prompt")
agent = Agent(prompt_driver=MockPromptDriver())

agent.add_task(task)

assert len(agent.task_outputs) == 1
assert agent.task_outputs[task.id] is None
agent.run("hello")

assert len(agent.task_outputs) == 1
assert agent.task_outputs[task.id] == task.output
15 changes: 14 additions & 1 deletion tests/unit/structures/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,8 @@ def test_context(self):

context = pipeline.context(task)

assert context["parent_output"] == parent.output.to_text()
assert context["parent_output"] == parent.output
assert context["task_outputs"] == pipeline.task_outputs
assert context["structure"] == pipeline
assert context["parent"] == parent
assert context["child"] == child
Expand Down Expand Up @@ -398,3 +399,15 @@ def test_add_duplicate_task_directly(self):

with pytest.raises(ValueError, match=f"Duplicate task with id {task.id} found."):
pipeline.run()

def test_task_outputs(self):
task = PromptTask("test")
pipeline = Pipeline()

pipeline + [task]

assert len(pipeline.task_outputs) == 1
assert pipeline.task_outputs[task.id] is None
pipeline.run()
assert len(pipeline.task_outputs) == 1
assert pipeline.task_outputs[task.id] == task.output
17 changes: 15 additions & 2 deletions tests/unit/structures/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,13 +737,14 @@ def test_context(self):

context = workflow.context(task)

assert context["parent_outputs"] == {parent.id: ""}
assert context["parent_outputs"] == {}

workflow.run()

context = workflow.context(task)

assert context["parent_outputs"] == {parent.id: parent.output.to_text()}
assert context["task_outputs"] == workflow.task_outputs
assert context["parent_outputs"] == {parent.id: parent.output}
assert context["parents_output_text"] == "mock output"
assert context["structure"] == workflow
assert context["parents"] == {parent.id: parent}
Expand Down Expand Up @@ -966,3 +967,15 @@ def _validate_topology_4(workflow) -> None:
publish_website = workflow.find_task("publish_website")
assert publish_website.parent_ids == ["compare_movies"]
assert publish_website.child_ids == ["summarize_to_slack"]

def test_task_outputs(self):
task = PromptTask("test")
workflow = Workflow(tasks=[task])

assert len(workflow.task_outputs) == 1
assert workflow.task_outputs[task.id] is None

workflow.run()

assert len(workflow.task_outputs) == 1
assert workflow.task_outputs[task.id].value == "mock output"
5 changes: 2 additions & 3 deletions tests/unit/tasks/test_base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,8 @@ def test_parent_outputs(self, task):

parent_3.output = None
assert child.parent_outputs == {
parent_1.id: parent_1.output.to_text(),
parent_2.id: parent_2.output.to_text(),
parent_3.id: "",
parent_1.id: parent_1.output,
parent_2.id: parent_2.output,
}

def test_parents_output(self, task):
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/tasks/test_base_text_input_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_full_context(self):
context = subtask.full_context

assert context["foo"] == "bar"
assert context["parent_output"] == parent.output.to_text()
assert context["parent_output"] == parent.output
assert context["structure"] == pipeline
assert context["parent"] == parent
assert context["child"] == child
Expand Down

0 comments on commit 7c52ed1

Please sign in to comment.