diff --git a/CHANGELOG.md b/CHANGELOG.md index d29a5fb4d1..fa886636ff 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -66,6 +66,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: `BaseExtractionEngine` no longer catches exceptions and returns `ErrorArtifact`s. - **BREAKING**: `JsonExtractionEngine.template_schema` is now required. - **BREAKING**: `CsvExtractionEngine.column_names` is now required. +- **BREAKING**: Renamed`RuleMixin.all_rulesets` to `RuleMixin.rulesets`. - `JsonExtractionEngine.extract_artifacts` now returns a `ListArtifact[JsonArtifact]`. - `CsvExtractionEngine.extract_artifacts` now returns a `ListArtifact[CsvRowArtifact]`. - Remove `manifest.yml` requirements for custom tool creation. @@ -73,6 +74,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Anthropic native Tool calling. - Empty `ActionsSubtask.thought` being logged. +- `RuleMixin` no longer prevents setting `rulesets` _and_ `rules` at the same time. +- `PromptTask` will merge in its Structure's Rulesets and Rules. +- `PromptTask` not checking whether Structure was set before building Prompt Stack. ## [0.32.0] - 2024-09-17 diff --git a/griptape/engines/rag/modules/response/prompt_response_rag_module.py b/griptape/engines/rag/modules/response/prompt_response_rag_module.py index fbb8ed7e60..b62a0eba3c 100644 --- a/griptape/engines/rag/modules/response/prompt_response_rag_module.py +++ b/griptape/engines/rag/modules/response/prompt_response_rag_module.py @@ -56,8 +56,8 @@ def run(self, context: RagContext) -> BaseArtifact: def default_system_template_generator(self, context: RagContext, artifacts: list[TextArtifact]) -> str: params: dict[str, Any] = {"text_chunks": [c.to_text() for c in artifacts]} - if len(self.all_rulesets) > 0: - params["rulesets"] = J2("rulesets/rulesets.j2").render(rulesets=self.all_rulesets) + if len(self.rulesets) > 0: + params["rulesets"] = J2("rulesets/rulesets.j2").render(rulesets=self.rulesets) if self.metadata is not None: params["metadata"] = J2("engines/rag/modules/response/metadata/system.j2").render(metadata=self.metadata) diff --git a/griptape/mixins/rule_mixin.py b/griptape/mixins/rule_mixin.py index 7fe6a6346d..3e111a8f51 100644 --- a/griptape/mixins/rule_mixin.py +++ b/griptape/mixins/rule_mixin.py @@ -1,56 +1,22 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional - -from attrs import Attribute, define, field +from attrs import define, field from griptape.rules import BaseRule, Ruleset -if TYPE_CHECKING: - from griptape.structures import Structure - @define(slots=False) class RuleMixin: DEFAULT_RULESET_NAME = "Default Ruleset" - ADDITIONAL_RULESET_NAME = "Additional Ruleset" - rulesets: list[Ruleset] = field(factory=list, kw_only=True) + _rulesets: list[Ruleset] = field(factory=list, kw_only=True, alias="rulesets") rules: list[BaseRule] = field(factory=list, kw_only=True) - structure: Optional[Structure] = field(default=None, kw_only=True) - - @rulesets.validator # pyright: ignore[reportAttributeAccessIssue] - def validate_rulesets(self, _: Attribute, rulesets: list[Ruleset]) -> None: - if not rulesets: - return - - if self.rules: - raise ValueError("Can't have both rulesets and rules specified.") - - @rules.validator # pyright: ignore[reportAttributeAccessIssue] - def validate_rules(self, _: Attribute, rules: list[BaseRule]) -> None: - if not rules: - return - - if self.rulesets: - raise ValueError("Can't have both rules and rulesets specified.") @property - def all_rulesets(self) -> list[Ruleset]: - structure_rulesets = [] - - if self.structure: - if self.structure.rulesets: - structure_rulesets = self.structure.rulesets - elif self.structure.rules: - structure_rulesets = [Ruleset(name=self.DEFAULT_RULESET_NAME, rules=self.structure.rules)] + def rulesets(self) -> list[Ruleset]: + rulesets = self._rulesets - task_rulesets = [] - if self.rulesets: - task_rulesets = self.rulesets - elif self.rules: - task_ruleset_name = self.ADDITIONAL_RULESET_NAME if structure_rulesets else self.DEFAULT_RULESET_NAME - - task_rulesets = [Ruleset(name=task_ruleset_name, rules=self.rules)] + if self.rules: + rulesets.append(Ruleset(name=self.DEFAULT_RULESET_NAME, rules=self.rules)) - return structure_rulesets + task_rulesets + return rulesets diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index d5a4f4fc0b..d2e6d2f3fb 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -4,26 +4,24 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Optional -from attrs import Attribute, Factory, define, field +from attrs import Factory, define, field from griptape.common import observable from griptape.events import EventBus, FinishStructureRunEvent, StartStructureRunEvent from griptape.memory import TaskMemory from griptape.memory.meta import MetaMemory from griptape.memory.structure import ConversationMemory, Run +from griptape.mixins.rule_mixin import RuleMixin if TYPE_CHECKING: from griptape.artifacts import BaseArtifact from griptape.memory.structure import BaseConversationMemory - from griptape.rules import BaseRule, Rule, Ruleset from griptape.tasks import BaseTask @define -class Structure(ABC): +class Structure(ABC, RuleMixin): id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True) - rulesets: list[Ruleset] = field(factory=list, kw_only=True) - rules: list[BaseRule] = field(factory=list, kw_only=True) _tasks: list[BaseTask | list[BaseTask]] = field(factory=list, kw_only=True, alias="tasks") conversation_memory: Optional[BaseConversationMemory] = field( default=Factory(lambda: ConversationMemory()), @@ -37,22 +35,6 @@ class Structure(ABC): fail_fast: bool = field(default=True, kw_only=True) _execution_args: tuple = () - @rulesets.validator # pyright: ignore[reportAttributeAccessIssue] - def validate_rulesets(self, _: Attribute, rulesets: list[Ruleset]) -> None: - if not rulesets: - return - - if self.rules: - raise ValueError("can't have both rulesets and rules specified") - - @rules.validator # pyright: ignore[reportAttributeAccessIssue] - def validate_rules(self, _: Attribute, rules: list[Rule]) -> None: - if not rules: - return - - if self.rulesets: - raise ValueError("can't have both rules and rulesets specified") - def __attrs_post_init__(self) -> None: tasks = self._tasks.copy() self._tasks.clear() diff --git a/griptape/tasks/base_task.py b/griptape/tasks/base_task.py index 2c67430359..7e6d9d1155 100644 --- a/griptape/tasks/base_task.py +++ b/griptape/tasks/base_task.py @@ -78,7 +78,7 @@ def parents_output_text(self) -> str: @property def meta_memories(self) -> list[BaseMetaEntry]: - if self.structure and self.structure.meta_memory: + if self.structure is not None and self.structure.meta_memory: if self.max_meta_memory_entries: return self.structure.meta_memory.entries[: self.max_meta_memory_entries] else: @@ -191,7 +191,7 @@ def run(self) -> BaseArtifact: ... @property def full_context(self) -> dict[str, Any]: - if self.structure: + if self.structure is not None: structure_context = self.structure.context(self) structure_context.update(self.context) diff --git a/griptape/tasks/extraction_task.py b/griptape/tasks/extraction_task.py index 234840401f..43096dcede 100644 --- a/griptape/tasks/extraction_task.py +++ b/griptape/tasks/extraction_task.py @@ -18,6 +18,4 @@ class ExtractionTask(BaseTextInputTask): args: dict = field(kw_only=True, factory=dict) def run(self) -> ListArtifact | ErrorArtifact: - return self.extraction_engine.extract_artifacts( - ListArtifact([self.input]), rulesets=self.all_rulesets, **self.args - ) + return self.extraction_engine.extract_artifacts(ListArtifact([self.input]), rulesets=self.rulesets, **self.args) diff --git a/griptape/tasks/inpainting_image_generation_task.py b/griptape/tasks/inpainting_image_generation_task.py index 0ed28a11bf..649f9e3fb6 100644 --- a/griptape/tasks/inpainting_image_generation_task.py +++ b/griptape/tasks/inpainting_image_generation_task.py @@ -74,7 +74,7 @@ def run(self) -> ImageArtifact: prompts=[prompt_artifact.to_text()], image=image_artifact, mask=mask_artifact, - rulesets=self.all_rulesets, + rulesets=self.rulesets, negative_rulesets=self.negative_rulesets, ) diff --git a/griptape/tasks/outpainting_image_generation_task.py b/griptape/tasks/outpainting_image_generation_task.py index 6b63709db7..019f74fa1a 100644 --- a/griptape/tasks/outpainting_image_generation_task.py +++ b/griptape/tasks/outpainting_image_generation_task.py @@ -74,7 +74,7 @@ def run(self) -> ImageArtifact: prompts=[prompt_artifact.to_text()], image=image_artifact, mask=mask_artifact, - rulesets=self.all_rulesets, + rulesets=self.rulesets, negative_rulesets=self.negative_rulesets, ) diff --git a/griptape/tasks/prompt_image_generation_task.py b/griptape/tasks/prompt_image_generation_task.py index 4d33563928..d2ebf79c27 100644 --- a/griptape/tasks/prompt_image_generation_task.py +++ b/griptape/tasks/prompt_image_generation_task.py @@ -53,7 +53,7 @@ def input(self, value: TextArtifact) -> None: def run(self) -> ImageArtifact: image_artifact = self.image_generation_engine.run( prompts=[self.input.to_text()], - rulesets=self.all_rulesets, + rulesets=self.rulesets, negative_rulesets=self.negative_rulesets, ) diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index d2dd20b36f..127fabf48f 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -9,6 +9,7 @@ from griptape.common import PromptStack from griptape.configs import Defaults from griptape.mixins.rule_mixin import RuleMixin +from griptape.rules import Ruleset from griptape.tasks import BaseTask from griptape.utils import J2 @@ -32,6 +33,22 @@ class PromptTask(RuleMixin, BaseTask): alias="input", ) + @property + def rulesets(self) -> list: + default_rules = self.rules + rulesets = self._rulesets + + if self.structure is not None: + if self.structure._rulesets: + rulesets = self.structure._rulesets + self._rulesets + if self.structure.rules: + default_rules = self.structure.rules + self.rules + + if default_rules: + rulesets.append(Ruleset(name=self.DEFAULT_RULESET_NAME, rules=default_rules)) + + return rulesets + @property def input(self) -> BaseArtifact: return self._process_task_input(self._input) @@ -45,7 +62,7 @@ def input(self, value: str | list | tuple | BaseArtifact | Callable[[BaseTask], @property def prompt_stack(self) -> PromptStack: stack = PromptStack() - memory = self.structure.conversation_memory + memory = self.structure.conversation_memory if self.structure is not None else None system_template = self.generate_system_template(self) if system_template: @@ -64,7 +81,7 @@ def prompt_stack(self) -> PromptStack: def default_system_template_generator(self, _: PromptTask) -> str: return J2("tasks/prompt_task/system.j2").render( - rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.all_rulesets), + rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.rulesets), ) def before_run(self) -> None: diff --git a/griptape/tasks/text_summary_task.py b/griptape/tasks/text_summary_task.py index dc1a7b8be4..4861510d6b 100644 --- a/griptape/tasks/text_summary_task.py +++ b/griptape/tasks/text_summary_task.py @@ -17,4 +17,4 @@ class TextSummaryTask(BaseTextInputTask): summary_engine: BaseSummaryEngine = field(default=Factory(lambda: PromptSummaryEngine()), kw_only=True) def run(self) -> TextArtifact: - return TextArtifact(self.summary_engine.summarize_text(self.input.to_text(), rulesets=self.all_rulesets)) + return TextArtifact(self.summary_engine.summarize_text(self.input.to_text(), rulesets=self.rulesets)) diff --git a/griptape/tasks/text_to_speech_task.py b/griptape/tasks/text_to_speech_task.py index 680a67603b..c131d69bc7 100644 --- a/griptape/tasks/text_to_speech_task.py +++ b/griptape/tasks/text_to_speech_task.py @@ -35,7 +35,7 @@ def input(self, value: TextArtifact) -> None: self._input = value def run(self) -> AudioArtifact: - audio_artifact = self.text_to_speech_engine.run(prompts=[self.input.to_text()], rulesets=self.all_rulesets) + audio_artifact = self.text_to_speech_engine.run(prompts=[self.input.to_text()], rulesets=self.rulesets) if self.output_dir or self.output_file: self._write_to_file(audio_artifact) diff --git a/griptape/tasks/tool_task.py b/griptape/tasks/tool_task.py index 7ae63b9027..38d6e1512f 100644 --- a/griptape/tasks/tool_task.py +++ b/griptape/tasks/tool_task.py @@ -51,7 +51,7 @@ def preprocess(self, structure: Structure) -> ToolTask: def default_system_template_generator(self, _: PromptTask) -> str: return J2("tasks/tool_task/system.j2").render( - rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.all_rulesets), + rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.rulesets), action_schema=utils.minify_json(json.dumps(self.tool.schema())), meta_memory=J2("memory/meta/meta_memory.j2").render(meta_memories=self.meta_memories), use_native_tools=self.prompt_driver.use_native_tools, diff --git a/griptape/tasks/toolkit_task.py b/griptape/tasks/toolkit_task.py index ed9860c660..2a4a926bdc 100644 --- a/griptape/tasks/toolkit_task.py +++ b/griptape/tasks/toolkit_task.py @@ -68,7 +68,7 @@ def tool_output_memory(self) -> list[TaskMemory]: @property def prompt_stack(self) -> PromptStack: stack = PromptStack(tools=self.tools) - memory = self.structure.conversation_memory + memory = self.structure.conversation_memory if self.structure is not None else None stack.add_system_message(self.generate_system_template(self)) @@ -113,7 +113,7 @@ def prompt_stack(self) -> PromptStack: stack.add_assistant_message(self.generate_assistant_subtask_template(s)) stack.add_user_message(self.generate_user_subtask_template(s)) - if memory: + if memory is not None: # inserting at index 1 to place memory right after system prompt memory.add_to_prompt_stack(self.prompt_driver, stack, 1) @@ -132,7 +132,7 @@ def default_system_template_generator(self, _: PromptTask) -> str: schema["minItems"] = 1 # The `schema` library doesn't support `minItems` so we must add it manually. return J2("tasks/toolkit_task/system.j2").render( - rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.all_rulesets), + rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.rulesets), action_names=str.join(", ", [tool.name for tool in self.tools]), actions_schema=utils.minify_json(json.dumps(schema)), meta_memory=J2("memory/meta/meta_memory.j2").render(meta_memories=self.meta_memories), diff --git a/griptape/tasks/variation_image_generation_task.py b/griptape/tasks/variation_image_generation_task.py index e3feaeac52..ddc16178b6 100644 --- a/griptape/tasks/variation_image_generation_task.py +++ b/griptape/tasks/variation_image_generation_task.py @@ -66,7 +66,7 @@ def run(self) -> ImageArtifact: output_image_artifact = self.image_generation_engine.run( prompts=[prompt_artifact.to_text()], image=image_artifact, - rulesets=self.all_rulesets, + rulesets=self.rulesets, negative_rulesets=self.negative_rulesets, ) diff --git a/griptape/tools/prompt_summary/tool.py b/griptape/tools/prompt_summary/tool.py index 5175073808..e12b195474 100644 --- a/griptape/tools/prompt_summary/tool.py +++ b/griptape/tools/prompt_summary/tool.py @@ -52,4 +52,4 @@ def summarize(self, params: dict) -> BaseArtifact: else: return ErrorArtifact("memory not found") - return self.prompt_summary_engine.summarize_artifacts(artifacts, rulesets=self.all_rulesets) + return self.prompt_summary_engine.summarize_artifacts(artifacts, rulesets=self.rulesets) diff --git a/tests/unit/mixins/test_rule_mixin.py b/tests/unit/mixins/test_rule_mixin.py index 393d721a3f..851c31333c 100644 --- a/tests/unit/mixins/test_rule_mixin.py +++ b/tests/unit/mixins/test_rule_mixin.py @@ -1,5 +1,3 @@ -import pytest - from griptape.mixins.rule_mixin import RuleMixin from griptape.rules import Rule, Ruleset from griptape.structures import Agent @@ -23,8 +21,10 @@ def test_rulesets(self): assert mixin.rulesets == [ruleset] def test_rules_and_rulesets(self): - with pytest.raises(ValueError): - RuleMixin(rules=[Rule("foo")], rulesets=[Ruleset("bar", [Rule("baz")])]) + mixin = RuleMixin(rules=[Rule("foo")], rulesets=[Ruleset("bar", [Rule("baz")])]) + + assert mixin.rules == [Rule("foo")] + assert mixin.rulesets == [Ruleset("bar", [Rule("baz")]), Ruleset("Default Ruleset", [Rule("foo")])] def test_inherits_structure_rulesets(self): # Tests that a task using the mixin inherits rulesets from its structure. @@ -35,4 +35,4 @@ def test_inherits_structure_rulesets(self): task = PromptTask(rulesets=[ruleset2]) agent.add_task(task) - assert task.all_rulesets == [ruleset1, ruleset2] + assert task.rulesets == [ruleset1, ruleset2] diff --git a/tests/unit/structures/test_agent.py b/tests/unit/structures/test_agent.py index 1b9fa4157b..d522a43a27 100644 --- a/tests/unit/structures/test_agent.py +++ b/tests/unit/structures/test_agent.py @@ -29,9 +29,9 @@ def test_rulesets(self): agent.add_task(PromptTask(rulesets=[Ruleset("Bar", [Rule("bar test")])])) assert isinstance(agent.task, PromptTask) - assert len(agent.task.all_rulesets) == 2 - assert agent.task.all_rulesets[0].name == "Foo" - assert agent.task.all_rulesets[1].name == "Bar" + assert len(agent.task.rulesets) == 2 + assert agent.task.rulesets[0].name == "Foo" + assert agent.task.rulesets[1].name == "Bar" def test_rules(self): agent = Agent(rules=[Rule("foo test")]) @@ -39,17 +39,22 @@ def test_rules(self): agent.add_task(PromptTask(rules=[Rule("bar test")])) assert isinstance(agent.task, PromptTask) - assert len(agent.task.all_rulesets) == 2 - assert agent.task.all_rulesets[0].name == "Default Ruleset" - assert agent.task.all_rulesets[1].name == "Additional Ruleset" + assert len(agent.task.rulesets) == 1 + assert agent.task.rulesets[0].name == "Default Ruleset" + assert len(agent.task.rulesets[0].rules) == 2 + assert agent.task.rulesets[0].rules[0].value == "foo test" + assert agent.task.rulesets[0].rules[1].value == "bar test" def test_rules_and_rulesets(self): - with pytest.raises(ValueError): - Agent(rules=[Rule("foo test")], rulesets=[Ruleset("Bar", [Rule("bar test")])]) + agent = Agent(rules=[Rule("foo test")], rulesets=[Ruleset("Bar", [Rule("bar test")])]) + assert len(agent.rulesets) == 2 + assert len(agent.rules) == 1 agent = Agent() - with pytest.raises(ValueError): - agent.add_task(PromptTask(rules=[Rule("foo test")], rulesets=[Ruleset("Bar", [Rule("bar test")])])) + agent.add_task(PromptTask(rules=[Rule("foo test")], rulesets=[Ruleset("Bar", [Rule("bar test")])])) + assert isinstance(agent.task, PromptTask) + assert len(agent.task.rulesets) == 2 + assert len(agent.task.rules) == 1 def test_with_task_memory(self): agent = Agent(tools=[MockTool(off_prompt=True)]) diff --git a/tests/unit/structures/test_pipeline.py b/tests/unit/structures/test_pipeline.py index 377f717910..e6fa6d7702 100644 --- a/tests/unit/structures/test_pipeline.py +++ b/tests/unit/structures/test_pipeline.py @@ -45,14 +45,14 @@ def test_rulesets(self): ) assert isinstance(pipeline.tasks[0], PromptTask) - assert len(pipeline.tasks[0].all_rulesets) == 2 - assert pipeline.tasks[0].all_rulesets[0].name == "Foo" - assert pipeline.tasks[0].all_rulesets[1].name == "Bar" + assert len(pipeline.tasks[0].rulesets) == 2 + assert pipeline.tasks[0].rulesets[0].name == "Foo" + assert pipeline.tasks[0].rulesets[1].name == "Bar" assert isinstance(pipeline.tasks[1], PromptTask) - assert len(pipeline.tasks[1].all_rulesets) == 2 - assert pipeline.tasks[1].all_rulesets[0].name == "Foo" - assert pipeline.tasks[1].all_rulesets[1].name == "Baz" + assert len(pipeline.tasks[1].rulesets) == 2 + assert pipeline.tasks[1].rulesets[0].name == "Foo" + assert pipeline.tasks[1].rulesets[1].name == "Baz" def test_rules(self): pipeline = Pipeline(rules=[Rule("foo test")]) @@ -60,21 +60,24 @@ def test_rules(self): pipeline.add_tasks(PromptTask(rules=[Rule("bar test")]), PromptTask(rules=[Rule("baz test")])) assert isinstance(pipeline.tasks[0], PromptTask) - assert len(pipeline.tasks[0].all_rulesets) == 2 - assert pipeline.tasks[0].all_rulesets[0].name == "Default Ruleset" - assert pipeline.tasks[0].all_rulesets[1].name == "Additional Ruleset" + assert len(pipeline.tasks[0].rulesets) == 1 + assert pipeline.tasks[0].rulesets[0].name == "Default Ruleset" + assert len(pipeline.tasks[0].rulesets[0].rules) == 2 assert isinstance(pipeline.tasks[1], PromptTask) - assert pipeline.tasks[1].all_rulesets[0].name == "Default Ruleset" - assert pipeline.tasks[1].all_rulesets[1].name == "Additional Ruleset" + assert pipeline.tasks[1].rulesets[0].name == "Default Ruleset" + assert len(pipeline.tasks[1].rulesets[0].rules) == 2 def test_rules_and_rulesets(self): - with pytest.raises(ValueError): - Pipeline(rules=[Rule("foo test")], rulesets=[Ruleset("Bar", [Rule("bar test")])]) + pipeline = Pipeline(rules=[Rule("foo test")], rulesets=[Ruleset("Bar", [Rule("bar test")])]) + assert len(pipeline.rulesets) == 2 + assert len(pipeline.rules) == 1 pipeline = Pipeline() - with pytest.raises(ValueError): - pipeline.add_task(PromptTask(rules=[Rule("foo test")], rulesets=[Ruleset("Bar", [Rule("bar test")])])) + pipeline.add_task(PromptTask(rules=[Rule("foo test")], rulesets=[Ruleset("Bar", [Rule("bar test")])])) + assert isinstance(pipeline.tasks[0], PromptTask) + assert len(pipeline.tasks[0].rulesets) == 2 + assert len(pipeline.tasks[0].rules) == 1 def test_with_no_task_memory(self): pipeline = Pipeline() diff --git a/tests/unit/structures/test_workflow.py b/tests/unit/structures/test_workflow.py index 03698a73ca..e3bd8e886e 100644 --- a/tests/unit/structures/test_workflow.py +++ b/tests/unit/structures/test_workflow.py @@ -42,14 +42,14 @@ def test_rulesets(self): ) assert isinstance(workflow.tasks[0], PromptTask) - assert len(workflow.tasks[0].all_rulesets) == 2 - assert workflow.tasks[0].all_rulesets[0].name == "Foo" - assert workflow.tasks[0].all_rulesets[1].name == "Bar" + assert len(workflow.tasks[0].rulesets) == 2 + assert workflow.tasks[0].rulesets[0].name == "Foo" + assert workflow.tasks[0].rulesets[1].name == "Bar" assert isinstance(workflow.tasks[1], PromptTask) - assert len(workflow.tasks[1].all_rulesets) == 2 - assert workflow.tasks[1].all_rulesets[0].name == "Foo" - assert workflow.tasks[1].all_rulesets[1].name == "Baz" + assert len(workflow.tasks[1].rulesets) == 2 + assert workflow.tasks[1].rulesets[0].name == "Foo" + assert workflow.tasks[1].rulesets[1].name == "Baz" def test_rules(self): workflow = Workflow(rules=[Rule("foo test")]) @@ -57,22 +57,25 @@ def test_rules(self): workflow.add_tasks(PromptTask(rules=[Rule("bar test")]), PromptTask(rules=[Rule("baz test")])) assert isinstance(workflow.tasks[0], PromptTask) - assert len(workflow.tasks[0].all_rulesets) == 2 - assert workflow.tasks[0].all_rulesets[0].name == "Default Ruleset" - assert workflow.tasks[0].all_rulesets[1].name == "Additional Ruleset" + assert len(workflow.tasks[0].rulesets) == 1 + assert workflow.tasks[0].rulesets[0].name == "Default Ruleset" + assert len(workflow.tasks[0].rulesets[0].rules) == 2 assert isinstance(workflow.tasks[1], PromptTask) - assert len(workflow.tasks[1].all_rulesets) == 2 - assert workflow.tasks[1].all_rulesets[0].name == "Default Ruleset" - assert workflow.tasks[1].all_rulesets[1].name == "Additional Ruleset" + assert len(workflow.tasks[1].rulesets) == 1 + assert workflow.tasks[1].rulesets[0].name == "Default Ruleset" + assert len(workflow.tasks[1].rulesets[0].rules) == 2 def test_rules_and_rulesets(self): - with pytest.raises(ValueError): - Workflow(rules=[Rule("foo test")], rulesets=[Ruleset("Bar", [Rule("bar test")])]) + workflow = Workflow(rules=[Rule("foo test")], rulesets=[Ruleset("Bar", [Rule("bar test")])]) + assert len(workflow.rulesets) == 2 + assert len(workflow.rules) == 1 workflow = Workflow() - with pytest.raises(ValueError): - workflow.add_task(PromptTask(rules=[Rule("foo test")], rulesets=[Ruleset("Bar", [Rule("bar test")])])) + workflow.add_task(PromptTask(rules=[Rule("foo test")], rulesets=[Ruleset("Bar", [Rule("bar test")])])) + assert isinstance(workflow.tasks[0], PromptTask) + assert len(workflow.tasks[0].rulesets) == 2 + assert len(workflow.tasks[0].rules) == 1 def test_with_no_task_memory(self): workflow = Workflow() diff --git a/tests/unit/tasks/test_base_text_input_task.py b/tests/unit/tasks/test_base_text_input_task.py index ff6afe42b3..4735e89b44 100644 --- a/tests/unit/tasks/test_base_text_input_task.py +++ b/tests/unit/tasks/test_base_text_input_task.py @@ -49,11 +49,11 @@ def test_rulesets(self): rulesets=[Ruleset("Foo", [Rule("foo test")]), Ruleset("Bar", [Rule("bar test")])] ) - assert len(prompt_task.all_rulesets) == 2 - assert prompt_task.all_rulesets[0].name == "Foo" - assert prompt_task.all_rulesets[1].name == "Bar" + assert len(prompt_task.rulesets) == 2 + assert prompt_task.rulesets[0].name == "Foo" + assert prompt_task.rulesets[1].name == "Bar" def test_rules(self): prompt_task = MockTextInputTask(rules=[Rule("foo test"), Rule("bar test")]) - assert prompt_task.all_rulesets[0].name == "Default Ruleset" + assert prompt_task.rulesets[0].name == "Default Ruleset" diff --git a/tests/utils/structure_tester.py b/tests/utils/structure_tester.py index 9fadf4e369..ac5b3c771f 100644 --- a/tests/utils/structure_tester.py +++ b/tests/utils/structure_tester.py @@ -248,7 +248,7 @@ def verify_structure_output(self, structure) -> dict: task_names = [task.__class__.__name__ for task in structure.tasks] prompt = structure.input_task.input.to_text() actual = structure.output.to_text() - rules = [rule.value for ruleset in structure.input_task.all_rulesets for rule in ruleset.rules] + rules = [rule.value for ruleset in structure.input_task.rulesets for rule in ruleset.rules] agent = Agent( rulesets=[