Skip to content

Commit

Permalink
Fix Printing Structure After It Has Run (#1437)
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter authored Dec 13, 2024
1 parent fe7d255 commit 7073c50
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 136 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

## Fixed

- Exception when calling `Structure.to_json()` after it has run.

### Added

- `PromptTask.conversation_memory` for setting the Conversation Memory on a Prompt Task.
Expand Down
4 changes: 3 additions & 1 deletion griptape/tasks/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import uuid
from abc import ABC, abstractmethod
from copy import deepcopy
from enum import Enum
from typing import TYPE_CHECKING, Any, Optional

Expand Down Expand Up @@ -210,7 +211,8 @@ def try_run(self) -> BaseArtifact: ...

@property
def full_context(self) -> dict[str, Any]:
context = self.context
# Need to deep copy so that the serialized context doesn't contain non-serializable data
context = deepcopy(self.context)
if self.structure is not None:
context.update(self.structure.context(self))

Expand Down
45 changes: 0 additions & 45 deletions tests/unit/structures/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,51 +255,6 @@ def test_task_outputs(self):
assert len(agent.task_outputs) == 1
assert agent.task_outputs[task.id] == task.output

def test_to_dict(self):
task = PromptTask("test prompt")
agent = Agent(prompt_driver=MockPromptDriver())
agent.add_task(task)
expected_agent_dict = {
"type": "Agent",
"id": agent.id,
"tasks": [
{
"type": agent.tasks[0].type,
"id": agent.tasks[0].id,
"state": str(agent.tasks[0].state),
"parent_ids": agent.tasks[0].parent_ids,
"child_ids": agent.tasks[0].child_ids,
"max_meta_memory_entries": agent.tasks[0].max_meta_memory_entries,
"context": agent.tasks[0].context,
}
],
"conversation_memory": {
"type": agent.conversation_memory.type,
"runs": agent.conversation_memory.runs,
"meta": agent.conversation_memory.meta,
"max_runs": agent.conversation_memory.max_runs,
},
"conversation_memory_strategy": agent.conversation_memory_strategy,
}
assert agent.to_dict() == expected_agent_dict

def test_from_dict(self):
task = PromptTask("test prompt")
agent = Agent(prompt_driver=MockPromptDriver())
agent.add_task(task)

serialized_agent = agent.to_dict()
assert isinstance(serialized_agent, dict)

deserialized_agent = Agent.from_dict(serialized_agent)
assert isinstance(deserialized_agent, Agent)

assert deserialized_agent.task_outputs[task.id] is None
deserialized_agent.run()

assert len(deserialized_agent.task_outputs) == 1
assert deserialized_agent.task_outputs[task.id].value == "mock output"

def test_runnable_mixin(self):
mock_on_before_run = Mock()
mock_after_run = Mock()
Expand Down
45 changes: 0 additions & 45 deletions tests/unit/structures/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,48 +411,3 @@ def test_task_outputs(self):
pipeline.run()
assert len(pipeline.task_outputs) == 1
assert pipeline.task_outputs[task.id] == task.output

def test_to_dict(self):
task = PromptTask("test")
pipeline = Pipeline()
pipeline + [task]
expected_pipeline_dict = {
"type": pipeline.type,
"id": pipeline.id,
"tasks": [
{
"type": pipeline.tasks[0].type,
"id": pipeline.tasks[0].id,
"state": str(pipeline.tasks[0].state),
"parent_ids": pipeline.tasks[0].parent_ids,
"child_ids": pipeline.tasks[0].child_ids,
"max_meta_memory_entries": pipeline.tasks[0].max_meta_memory_entries,
"context": pipeline.tasks[0].context,
}
],
"conversation_memory": {
"type": pipeline.conversation_memory.type,
"runs": pipeline.conversation_memory.runs,
"meta": pipeline.conversation_memory.meta,
"max_runs": pipeline.conversation_memory.max_runs,
},
"conversation_memory_strategy": pipeline.conversation_memory_strategy,
"fail_fast": pipeline.fail_fast,
}
assert pipeline.to_dict() == expected_pipeline_dict

def test_from_dict(self):
task = PromptTask("test")
pipeline = Pipeline(tasks=[task])

serialized_pipeline = pipeline.to_dict()
assert isinstance(serialized_pipeline, dict)

deserialized_pipeline = Pipeline.from_dict(serialized_pipeline)
assert isinstance(deserialized_pipeline, Pipeline)

assert deserialized_pipeline.task_outputs[task.id] is None
deserialized_pipeline.run()

assert len(deserialized_pipeline.task_outputs) == 1
assert deserialized_pipeline.task_outputs[task.id].value == "mock output"
63 changes: 63 additions & 0 deletions tests/unit/structures/test_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from griptape.structures import Agent, Pipeline
from griptape.tasks import PromptTask
from tests.mocks.mock_prompt_driver import MockPromptDriver


class TestStructure:
Expand Down Expand Up @@ -54,3 +55,65 @@ def test_conversation_mode_per_task_no_memory(self):
assert pipeline.conversation_memory is not None
assert len(pipeline.conversation_memory.runs) == 1
assert pipeline.conversation_memory.runs[0].input.value == "2"

def test_to_dict(self):
task = PromptTask("test prompt")
agent = Agent(prompt_driver=MockPromptDriver())
agent.add_task(task)
expected_agent_dict = {
"type": "Agent",
"id": agent.id,
"tasks": [
{
"type": agent.tasks[0].type,
"id": agent.tasks[0].id,
"state": str(agent.tasks[0].state),
"parent_ids": agent.tasks[0].parent_ids,
"child_ids": agent.tasks[0].child_ids,
"max_meta_memory_entries": agent.tasks[0].max_meta_memory_entries,
"context": agent.tasks[0].context,
}
],
"conversation_memory": {
"type": agent.conversation_memory.type,
"runs": agent.conversation_memory.runs,
"meta": agent.conversation_memory.meta,
"max_runs": agent.conversation_memory.max_runs,
},
"conversation_memory_strategy": agent.conversation_memory_strategy,
}
assert agent.to_dict() == expected_agent_dict

agent.run()

expected_agent_dict = {
**expected_agent_dict,
"tasks": [
{
**expected_agent_dict["tasks"][0],
"state": str(agent.tasks[0].state),
}
],
"conversation_memory": {
**expected_agent_dict["conversation_memory"],
"runs": agent.conversation_memory.to_dict()["runs"],
},
}
assert agent.to_dict() == expected_agent_dict

def test_from_dict(self):
task = PromptTask("test prompt")
agent = Agent(prompt_driver=MockPromptDriver())
agent.add_task(task)

serialized_agent = agent.to_dict()
assert isinstance(serialized_agent, dict)

deserialized_agent = Agent.from_dict(serialized_agent)
assert isinstance(deserialized_agent, Agent)

assert deserialized_agent.task_outputs[task.id] is None
deserialized_agent.run()

assert len(deserialized_agent.task_outputs) == 1
assert deserialized_agent.task_outputs[task.id].value == "mock output"
45 changes: 0 additions & 45 deletions tests/unit/structures/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -979,48 +979,3 @@ def test_task_outputs(self):

assert len(workflow.task_outputs) == 1
assert workflow.task_outputs[task.id].value == "mock output"

def test_to_dict(self):
task = PromptTask("test")
workflow = Workflow(tasks=[task])

expected_workflow_dict = {
"type": workflow.type,
"id": workflow.id,
"tasks": [
{
"type": workflow.tasks[0].type,
"id": workflow.tasks[0].id,
"state": str(workflow.tasks[0].state),
"parent_ids": workflow.tasks[0].parent_ids,
"child_ids": workflow.tasks[0].child_ids,
"max_meta_memory_entries": workflow.tasks[0].max_meta_memory_entries,
"context": workflow.tasks[0].context,
}
],
"conversation_memory": {
"type": workflow.conversation_memory.type,
"runs": workflow.conversation_memory.runs,
"meta": workflow.conversation_memory.meta,
"max_runs": workflow.conversation_memory.max_runs,
},
"conversation_memory_strategy": workflow.conversation_memory_strategy,
"fail_fast": workflow.fail_fast,
}
assert workflow.to_dict() == expected_workflow_dict

def test_from_dict(self):
task = PromptTask("test")
workflow = Workflow(tasks=[task])

serialized_workflow = workflow.to_dict()
assert isinstance(serialized_workflow, dict)

deserialized_workflow = Workflow.from_dict(serialized_workflow)
assert isinstance(deserialized_workflow, Workflow)

assert deserialized_workflow.task_outputs[task.id] is None
deserialized_workflow.run()

assert len(deserialized_workflow.task_outputs) == 1
assert deserialized_workflow.task_outputs[task.id].value == "mock output"

0 comments on commit 7073c50

Please sign in to comment.