From ca39b3b6e65a3069b7879b0f19f48b49808ffaef Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Mon, 16 Dec 2024 11:26:06 -0800 Subject: [PATCH] Implement serialization in rulesets (#1451) --- CHANGELOG.md | 4 ++ griptape/mixins/rule_mixin.py | 13 +++-- griptape/rules/base_rule.py | 8 ++- griptape/rules/json_schema_rule.py | 4 +- griptape/rules/rule.py | 4 +- griptape/rules/ruleset.py | 9 +-- griptape/schemas/base_schema.py | 4 ++ tests/unit/mixins/test_rule_mixin.py | 75 ++++++++++++++++++++++++- tests/unit/structures/test_structure.py | 2 + tests/unit/tasks/test_tool_task.py | 1 + 10 files changed, 108 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index adf93ffd7..42246aa52 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `BranchTask` for selecting which Tasks (if any) to run based on a condition. - Support for `BranchTask` in `StructureVisualizer`. +### Changed + +- Rulesets can now be serialized and deserialized. + ### Fixed - Exception when calling `Structure.to_json()` after it has run. diff --git a/griptape/mixins/rule_mixin.py b/griptape/mixins/rule_mixin.py index b00504118..cae731865 100644 --- a/griptape/mixins/rule_mixin.py +++ b/griptape/mixins/rule_mixin.py @@ -1,22 +1,27 @@ from __future__ import annotations -from attrs import define, field +import uuid +from attrs import Factory, define, field + +from griptape.mixins.serializable_mixin import SerializableMixin from griptape.rules import BaseRule, Ruleset @define(slots=False) -class RuleMixin: +class RuleMixin(SerializableMixin): DEFAULT_RULESET_NAME = "Default Ruleset" - _rulesets: list[Ruleset] = field(factory=list, kw_only=True, alias="rulesets") + _rulesets: list[Ruleset] = field(factory=list, kw_only=True, alias="rulesets", metadata={"serializable": True}) rules: list[BaseRule] = field(factory=list, kw_only=True) + _default_ruleset_name: str = field(default=Factory(lambda: RuleMixin.DEFAULT_RULESET_NAME), kw_only=True) + _default_ruleset_id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True) @property def rulesets(self) -> list[Ruleset]: rulesets = self._rulesets.copy() if self.rules: - rulesets.append(Ruleset(name=self.DEFAULT_RULESET_NAME, rules=self.rules)) + rulesets.append(Ruleset(id=self._default_ruleset_id, name=self._default_ruleset_name, rules=self.rules)) return rulesets diff --git a/griptape/rules/base_rule.py b/griptape/rules/base_rule.py index a3880d49f..baf66dee3 100644 --- a/griptape/rules/base_rule.py +++ b/griptape/rules/base_rule.py @@ -5,10 +5,12 @@ from attrs import define, field +from griptape.mixins.serializable_mixin import SerializableMixin -@define(frozen=True) -class BaseRule(ABC): - value: Any = field() + +@define() +class BaseRule(ABC, SerializableMixin): + value: Any = field(metadata={"serializable": True}) meta: dict[str, Any] = field(factory=dict, kw_only=True) def __str__(self) -> str: diff --git a/griptape/rules/json_schema_rule.py b/griptape/rules/json_schema_rule.py index ce41f26db..c068eb4a1 100644 --- a/griptape/rules/json_schema_rule.py +++ b/griptape/rules/json_schema_rule.py @@ -8,9 +8,9 @@ from griptape.utils import J2 -@define(frozen=True) +@define() class JsonSchemaRule(BaseRule): - value: dict = field() + value: dict = field(metadata={"serializable": True}) generate_template: J2 = field(default=Factory(lambda: J2("rules/json_schema.j2"))) def to_text(self) -> str: diff --git a/griptape/rules/rule.py b/griptape/rules/rule.py index 952770adf..2abc55eed 100644 --- a/griptape/rules/rule.py +++ b/griptape/rules/rule.py @@ -5,9 +5,9 @@ from griptape.rules import BaseRule -@define(frozen=True) +@define() class Rule(BaseRule): - value: str = field() + value: str = field(metadata={"serializable": True}) def to_text(self) -> str: return self.value diff --git a/griptape/rules/ruleset.py b/griptape/rules/ruleset.py index a4194cf3f..970e937d5 100644 --- a/griptape/rules/ruleset.py +++ b/griptape/rules/ruleset.py @@ -6,6 +6,7 @@ from attrs import Factory, define, field from griptape.configs import Defaults +from griptape.mixins.serializable_mixin import SerializableMixin if TYPE_CHECKING: from collections.abc import Sequence @@ -15,8 +16,8 @@ @define -class Ruleset: - id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True) +class Ruleset(SerializableMixin): + id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True, metadata={"serializable": True}) name: str = field( default=Factory(lambda self: self.id, takes_self=True), metadata={"serializable": True}, @@ -24,8 +25,8 @@ class Ruleset: ruleset_driver: BaseRulesetDriver = field( default=Factory(lambda: Defaults.drivers_config.ruleset_driver), kw_only=True ) - meta: dict[str, Any] = field(factory=dict, kw_only=True) - rules: Sequence[BaseRule] = field(factory=list) + meta: dict[str, Any] = field(factory=dict, kw_only=True, metadata={"serializable": True}) + rules: Sequence[BaseRule] = field(factory=list, metadata={"serializable": True}) def __attrs_post_init__(self) -> None: rules, meta = self.ruleset_driver.load(self.name) diff --git a/griptape/schemas/base_schema.py b/griptape/schemas/base_schema.py index 4a049752b..0738e1258 100644 --- a/griptape/schemas/base_schema.py +++ b/griptape/schemas/base_schema.py @@ -174,6 +174,8 @@ def _resolve_types(cls, attrs_cls: type) -> None: from griptape.memory import TaskMemory from griptape.memory.structure import BaseConversationMemory, Run from griptape.memory.task.storage import BaseArtifactStorage + from griptape.rules.base_rule import BaseRule + from griptape.rules.ruleset import Ruleset from griptape.structures import Structure from griptape.tasks import BaseTask from griptape.tokenizers import BaseTokenizer @@ -210,6 +212,8 @@ def _resolve_types(cls, attrs_cls: type) -> None: "State": BaseTask.State, "BaseConversationMemory": BaseConversationMemory, "BaseArtifactStorage": BaseArtifactStorage, + "BaseRule": BaseRule, + "Ruleset": Ruleset, # Third party modules "Client": import_optional_dependency("cohere").Client if is_dependency_installed("cohere") else Any, "GenerativeModel": import_optional_dependency("google.generativeai").GenerativeModel diff --git a/tests/unit/mixins/test_rule_mixin.py b/tests/unit/mixins/test_rule_mixin.py index cc45a2b26..19d3cbeac 100644 --- a/tests/unit/mixins/test_rule_mixin.py +++ b/tests/unit/mixins/test_rule_mixin.py @@ -1,5 +1,5 @@ from griptape.mixins.rule_mixin import RuleMixin -from griptape.rules import Rule, Ruleset +from griptape.rules import JsonSchemaRule, Rule, Ruleset from griptape.structures import Agent from griptape.tasks import PromptTask @@ -41,3 +41,76 @@ def test_inherits_structure_rulesets(self): agent.add_task(task) assert task.rulesets == [ruleset1, ruleset2] + + def test_to_dict(self): + mixin = RuleMixin( + rules=[ + Rule("foo"), + JsonSchemaRule( + { + "type": "object", + "properties": { + "foo": {"type": "string"}, + }, + "required": ["foo"], + } + ), + ], + rulesets=[Ruleset("bar", [Rule("baz")])], + ) + + assert mixin.to_dict() == { + "rulesets": [ + { + "id": mixin.rulesets[0].id, + "meta": {}, + "name": "bar", + "rules": [{"type": "Rule", "value": "baz"}], + "type": "Ruleset", + }, + { + "name": "Default Ruleset", + "id": mixin.rulesets[1].id, + "meta": {}, + "rules": [ + {"type": "Rule", "value": "foo"}, + { + "type": "JsonSchemaRule", + "value": { + "properties": {"foo": {"type": "string"}}, + "required": ["foo"], + "type": "object", + }, + }, + ], + "type": "Ruleset", + }, + ], + "type": "RuleMixin", + } + + def test_from_dict(self): + mixin = RuleMixin( + rules=[ + Rule("foo"), + JsonSchemaRule( + { + "type": "object", + "properties": { + "foo": {"type": "string"}, + }, + "required": ["foo"], + } + ), + ], + rulesets=[Ruleset("bar", [Rule("baz")])], + ) + + new_mixin = RuleMixin.from_dict(mixin.to_dict()) + + for idx, _ in enumerate(new_mixin.rulesets): + rules = mixin.rulesets[idx].rules + new_rules = new_mixin.rulesets[idx].rules + for idx, _ in enumerate(rules): + assert rules[idx].value == new_rules[idx].value + assert rules[idx].meta == new_rules[idx].meta diff --git a/tests/unit/structures/test_structure.py b/tests/unit/structures/test_structure.py index 38af9d6ab..756a011e6 100644 --- a/tests/unit/structures/test_structure.py +++ b/tests/unit/structures/test_structure.py @@ -72,8 +72,10 @@ def test_to_dict(self): "child_ids": agent.tasks[0].child_ids, "max_meta_memory_entries": agent.tasks[0].max_meta_memory_entries, "context": agent.tasks[0].context, + "rulesets": [], } ], + "rulesets": [], "conversation_memory": { "type": agent.conversation_memory.type, "runs": agent.conversation_memory.runs, diff --git a/tests/unit/tasks/test_tool_task.py b/tests/unit/tasks/test_tool_task.py index cb2a6b341..598f85ed3 100644 --- a/tests/unit/tasks/test_tool_task.py +++ b/tests/unit/tasks/test_tool_task.py @@ -250,6 +250,7 @@ def test_to_dict(self): "child_ids": task.child_ids, "max_meta_memory_entries": task.max_meta_memory_entries, "context": task.context, + "rulesets": [], "tool": { "type": task.tool.type, "name": task.tool.name,