Skip to content

Commit

Permalink
Change conversation_memory_strategy from enum to string (#1427)
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter authored Dec 12, 2024
1 parent 5c1b89b commit b936419
Show file tree
Hide file tree
Showing 11 changed files with 16 additions and 22 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- `PromptTask.conversation_memory` for setting the Conversation Memory on a Prompt Task.
- `Structure.conversation_memory_strategy` for setting whether Conversation Memory Runs should be created on a per-Structure or per-Task basis. Default is `Structure.ConversationMemoryStrategy.PER_STRUCTURE`.
- `Structure.conversation_memory_strategy` for setting whether Conversation Memory Runs should be created on a per-Structure or per-Task basis. Default is `per_structure`.

## [1.0.0] - 2024-12-09

Expand Down
2 changes: 1 addition & 1 deletion docs/griptape-framework/structures/conversation-memory.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ In this example, the `improve` Task is "forgotten" after the Structure's run is

#### Per Task

You can change when Conversation Memory Runs are created by modifying [Structure.conversation_memory_strategy](../../reference/griptape/structures/structure.md#griptape.structures.Structure.conversation_memory_strategy) from the default [PER_STRUCTURE](../../reference/griptape/structures/structure.md#griptape.structures.structure.ConversationMemoryStrategy.PER_STRUCTURE) to [PER_TASK](../../reference/griptape/structures/structure.md#griptape.structures.structure.ConversationMemoryStrategy.PER_TASK).
You can change when Conversation Memory Runs are created by modifying [Structure.conversation_memory_strategy](../../reference/griptape/structures/structure.md#griptape.structures.Structure.conversation_memory_strategy) from the default `per_structure` to `per_task`.

```python
--8<-- "docs/griptape-framework/structures/src/conversation_memory_per_task.py"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from griptape.tasks import PromptTask

pipeline = Pipeline(
conversation_memory_strategy=Pipeline.ConversationMemoryStrategy.PER_TASK,
conversation_memory_strategy="per_task",
tasks=[
PromptTask("Respond to this request: {{ args[0] }}", id="input"),
PromptTask("Improve the writing", id="improve"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from griptape.tasks import PromptTask

pipeline = Pipeline(
conversation_memory_strategy=Pipeline.ConversationMemoryStrategy.PER_TASK,
conversation_memory_strategy="per_task",
tasks=[
PromptTask("Respond to this request: {{ args[0] }}", id="input"),
PromptTask("Improve the writing of this: {{ parent_output }}", id="improve", conversation_memory=None),
Expand Down
1 change: 0 additions & 1 deletion griptape/schemas/base_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,6 @@ def _resolve_types(cls, attrs_cls: type) -> None:
"Sequence": Sequence,
"TaskMemory": TaskMemory,
"State": BaseTask.State,
"ConversationMemoryStrategy": Structure.ConversationMemoryStrategy,
"BaseConversationMemory": BaseConversationMemory,
"BaseArtifactStorage": BaseArtifactStorage,
# Third party modules
Expand Down
13 changes: 4 additions & 9 deletions griptape/structures/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

import uuid
from abc import ABC, abstractmethod
from enum import Enum
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any, Literal, Optional, Union

from attrs import Factory, define, field

Expand All @@ -24,10 +23,6 @@

@define
class Structure(RuleMixin, SerializableMixin, RunnableMixin["Structure"], ABC):
class ConversationMemoryStrategy(Enum):
PER_STRUCTURE = 1
PER_TASK = 2

id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True, metadata={"serializable": True})
_tasks: list[Union[BaseTask, list[BaseTask]]] = field(
factory=list, kw_only=True, alias="tasks", metadata={"serializable": True}
Expand All @@ -37,8 +32,8 @@ class ConversationMemoryStrategy(Enum):
kw_only=True,
metadata={"serializable": True},
)
conversation_memory_strategy: ConversationMemoryStrategy = field(
default=ConversationMemoryStrategy.PER_STRUCTURE, kw_only=True, metadata={"serializable": True}
conversation_memory_strategy: Literal["per_structure", "per_task"] = field(
default="per_structure", kw_only=True, metadata={"serializable": True}
)
task_memory: TaskMemory = field(
default=Factory(lambda self: TaskMemory(), takes_self=True),
Expand Down Expand Up @@ -172,7 +167,7 @@ def after_run(self) -> None:

if self.output_task is not None:
if (
self.conversation_memory_strategy == self.ConversationMemoryStrategy.PER_STRUCTURE
self.conversation_memory_strategy == "per_structure"
and self.conversation_memory is not None
and self.input_task is not None
and self.output_task.output is not None
Expand Down
2 changes: 1 addition & 1 deletion griptape/tasks/prompt_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def after_run(self) -> None:
structure = self.structure
if (
structure is not None
and structure.conversation_memory_strategy == structure.ConversationMemoryStrategy.PER_TASK
and structure.conversation_memory_strategy == "per_task"
and self.conversation_memory is not None
and self.output is not None
):
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/structures/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def test_to_dict(self):
"meta": agent.conversation_memory.meta,
"max_runs": agent.conversation_memory.max_runs,
},
"conversation_memory_strategy": str(agent.conversation_memory_strategy),
"conversation_memory_strategy": agent.conversation_memory_strategy,
}
assert agent.to_dict() == expected_agent_dict

Expand Down
2 changes: 1 addition & 1 deletion tests/unit/structures/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def test_to_dict(self):
"meta": pipeline.conversation_memory.meta,
"max_runs": pipeline.conversation_memory.max_runs,
},
"conversation_memory_strategy": str(pipeline.conversation_memory_strategy),
"conversation_memory_strategy": pipeline.conversation_memory_strategy,
"fail_fast": pipeline.fail_fast,
}
assert pipeline.to_dict() == expected_pipeline_dict
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/structures/test_structure.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from griptape.structures import Agent, Pipeline, Structure
from griptape.structures import Agent, Pipeline
from griptape.tasks import PromptTask


Expand All @@ -20,7 +20,7 @@ def test_output(self):

def test_conversation_mode_per_structure(self):
pipeline = Pipeline(
conversation_memory_strategy=Structure.ConversationMemoryStrategy.PER_STRUCTURE,
conversation_memory_strategy="per_structure",
tasks=[PromptTask("1"), PromptTask("2")],
)

Expand All @@ -32,7 +32,7 @@ def test_conversation_mode_per_structure(self):

def test_conversation_mode_per_task(self):
pipeline = Pipeline(
conversation_memory_strategy=Structure.ConversationMemoryStrategy.PER_TASK,
conversation_memory_strategy="per_task",
tasks=[PromptTask("1"), PromptTask("2")],
)

Expand All @@ -46,7 +46,7 @@ def test_conversation_mode_per_task(self):

def test_conversation_mode_per_task_no_memory(self):
pipeline = Pipeline(
conversation_memory_strategy=Structure.ConversationMemoryStrategy.PER_TASK,
conversation_memory_strategy="per_task",
tasks=[PromptTask(conversation_memory=None), PromptTask("2")],
)

Expand Down
2 changes: 1 addition & 1 deletion tests/unit/structures/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,7 +1004,7 @@ def test_to_dict(self):
"meta": workflow.conversation_memory.meta,
"max_runs": workflow.conversation_memory.max_runs,
},
"conversation_memory_strategy": str(workflow.conversation_memory_strategy),
"conversation_memory_strategy": workflow.conversation_memory_strategy,
"fail_fast": workflow.fail_fast,
}
assert workflow.to_dict() == expected_workflow_dict
Expand Down

0 comments on commit b936419

Please sign in to comment.