From 1a003bd392c9e2924aa9095ca799977935f9c5a1 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 15 Oct 2024 11:45:36 -0700 Subject: [PATCH] Add RunnableMixin, implement in Structures, Tasks, Tools --- CHANGELOG.md | 7 ++++ .../structures/src/task_hooks.py | 38 ++++++++++++++++++ docs/griptape-framework/structures/tasks.md | 22 ++++++++++ griptape/mixins/runnable_mixin.py | 35 ++++++++++++++++ griptape/structures/agent.py | 2 +- griptape/structures/pipeline.py | 2 +- griptape/structures/structure.py | 5 ++- griptape/structures/workflow.py | 4 +- griptape/tasks/actions_subtask.py | 12 +++--- griptape/tasks/audio_transcription_task.py | 2 +- griptape/tasks/base_task.py | 40 ++++++++++--------- griptape/tasks/code_execution_task.py | 2 +- griptape/tasks/extraction_task.py | 2 +- griptape/tasks/image_query_task.py | 2 +- .../tasks/inpainting_image_generation_task.py | 2 +- .../outpainting_image_generation_task.py | 2 +- .../tasks/prompt_image_generation_task.py | 2 +- griptape/tasks/prompt_task.py | 2 +- griptape/tasks/rag_task.py | 2 +- griptape/tasks/structure_run_task.py | 2 +- griptape/tasks/text_summary_task.py | 2 +- griptape/tasks/text_to_speech_task.py | 2 +- griptape/tasks/tool_task.py | 2 +- griptape/tasks/toolkit_task.py | 2 +- .../tasks/variation_image_generation_task.py | 2 +- griptape/tools/base_tool.py | 13 ++++-- tests/mocks/mock_audio_input_task.py | 2 +- tests/mocks/mock_image_generation_task.py | 6 +-- tests/mocks/mock_task.py | 2 +- tests/mocks/mock_text_input_task.py | 2 +- tests/unit/mixins/test_runnable_mixin.py | 21 ++++++++++ tests/unit/structures/test_agent.py | 13 ++++++ tests/unit/tasks/test_actions_subtask.py | 4 +- tests/unit/tasks/test_base_task.py | 20 +++++++++- tests/unit/tasks/test_code_execution_task.py | 6 +-- tests/unit/tasks/test_image_query_task.py | 2 +- .../test_inpainting_image_generation_task.py | 4 +- .../test_outpainting_image_generation_task.py | 4 +- tests/unit/tasks/test_prompt_task.py | 25 ++++++++++++ .../test_variation_image_generation_task.py | 2 +- tests/unit/tools/test_base_tool.py | 17 ++++++-- 41 files changed, 269 insertions(+), 71 deletions(-) create mode 100644 docs/griptape-framework/structures/src/task_hooks.py create mode 100644 griptape/mixins/runnable_mixin.py create mode 100644 tests/unit/mixins/test_runnable_mixin.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 2127121944..68e70a64fd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Exponential backoff to `BaseEventListenerDriver` for retrying failed event publishing. - `BaseTask.task_outputs` to get a dictionary of all task outputs. This has been added to `Workflow.context` and `Pipeline.context`. - `Chat.input_fn` for customizing the input to the Chat utility. +- `RunnableMixin` which adds `on_before_run` and `on_after_run` hooks. ### Changed @@ -25,6 +26,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: Updated `EventListener.handler` return value behavior. - If `EventListener.handler` returns `None`, the event will not be published to the `event_listener_driver`. - If `EventListener.handler` is None, the event will be published to the `event_listener_driver` as-is. +- **BREAKING**: Renamed `BaseTask.run` to `BaseTask.try_run`. +- **BREAKING**: Renamed `BaseTask.execute` to `BaseTask.run`. +- **BREAKING**: Renamed `BaseTask.can_execute` to `BaseTool.can_run`. +- **BREAKING**: Renamed `BaseTool.run` to `BaseTool.try_run`. +- **BREAKING**: Renamed `BaseTool.execute` to `BaseTool.run`. - Updated `EventListener.handler` return type to `Optional[BaseEvent | dict]`. - `BaseTask.parent_outputs` type has changed from `dict[str, str | None]` to `dict[str, BaseArtifact]`. - `Workflow.context["parent_outputs"]` type has changed from `dict[str, str | None]` to `dict[str, BaseArtifact]`. @@ -40,6 +46,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `Chat.output_fn`'s now takes an optional kwarg parameter, `stream`. - Implemented `SerializableMixin` in `Structure`, `BaseTask`, `BaseTool`, and `TaskMemory` - `@activity` decorated functions can now accept kwargs that are defined in the activity schema. +- Implemented `RunnableMixin` in `Structure`, `BaseTask`, and `BaseTool`. ### Fixed diff --git a/docs/griptape-framework/structures/src/task_hooks.py b/docs/griptape-framework/structures/src/task_hooks.py new file mode 100644 index 0000000000..e2e884a10f --- /dev/null +++ b/docs/griptape-framework/structures/src/task_hooks.py @@ -0,0 +1,38 @@ +import json +import re + +from griptape.structures import Agent +from griptape.tasks import PromptTask +from griptape.tasks.base_task import BaseTask + +SSN_PATTERN = re.compile(r"\b\d{3}-\d{2}-\d{4}\b") + +original_input = None + + +def on_before_run(task: BaseTask) -> None: + global original_input + + original_input = task.input.value + + if isinstance(task, PromptTask): + task.input = SSN_PATTERN.sub("xxx-xx-xxxx", task.input.value) + + +def on_after_run(task: BaseTask) -> None: + if task.output is not None: + task.output.value = json.dumps( + {"original_input": original_input, "masked_input": task.input.value, "output": task.output.value}, indent=2 + ) + + +agent = Agent( + tasks=[ + PromptTask( + "Respond to this user: {{ args[0] }}", + on_before_run=on_before_run, + on_after_run=on_after_run, + ) + ] +) +agent.run("Hello! My favorite color is blue, and my social security number is 123-45-6789.") diff --git a/docs/griptape-framework/structures/tasks.md b/docs/griptape-framework/structures/tasks.md index 7472010205..ef268a3ce6 100644 --- a/docs/griptape-framework/structures/tasks.md +++ b/docs/griptape-framework/structures/tasks.md @@ -54,6 +54,28 @@ Additional [context](../../reference/griptape/structures/structure.md#griptape.s sleeves, and let's get baking! 🍰🎉 ``` +## Hooks + +All Tasks implement [RunnableMixin](../../reference/griptape/mixins/runnable_mixin.md) which provides `on_before_run` and `on_after_run` hooks for the Task lifecycle. + +These hooks can be used to perform actions before and after the Task is run. For example, you can mask sensitive information before running the Task, and transform the output after the Task is run. + +```python +--8<-- "docs/griptape-framework/structures/src/task_hooks.py" +``` + +``` +[10/15/24 15:14:10] INFO PromptTask 63a0c734059c42808c87dff351adc8ab + Input: Respond to this user: Hello! My favorite color is blue, and my social security number is xxx-xx-xxxx. +[10/15/24 15:14:11] INFO PromptTask 63a0c734059c42808c87dff351adc8ab + Output: { + "original_input": "Respond to this user: Hello! My favorite color is blue, and my social security number is 123-45-6789.", + "masked_input": "Respond to this user: Hello! My favorite color is blue, and my social security number is xxx-xx-xxxx.", + "output": "Hello! It's great to hear that your favorite color is blue. However, it's important to keep your personal information, like your + social security number, private and secure. If you have any questions or need assistance, feel free to ask!" + } +``` + ## Prompt Task For general purpose prompting, use the [PromptTask](../../reference/griptape/tasks/prompt_task.md): diff --git a/griptape/mixins/runnable_mixin.py b/griptape/mixins/runnable_mixin.py new file mode 100644 index 0000000000..4571e21088 --- /dev/null +++ b/griptape/mixins/runnable_mixin.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Callable, Generic, Optional, TypeVar, cast + +from attrs import define, field + +# Generics magic that allows us to reference the type of the class that is implementing the mixin +T = TypeVar("T", bound="RunnableMixin") + + +@define() +class RunnableMixin(ABC, Generic[T]): + """Mixin for classes that can be "run". + + Implementing classes should pass themselves as the generic type to ensure that the correct type is used in the callbacks. + + Attributes: + on_before_run: Optional callback that is called at the very beginning of the `run` method. + on_after_run: Optional callback that is called at the very end of the `run` method. + """ + + on_before_run: Optional[Callable[[T], None]] = field(kw_only=True, default=None) + on_after_run: Optional[Callable[[T], None]] = field(kw_only=True, default=None) + + def before_run(self, *args, **kwargs) -> Any: + if self.on_before_run is not None: + self.on_before_run(cast(T, self)) + + @abstractmethod + def run(self, *args, **kwargs) -> Any: ... + + def after_run(self, *args, **kwargs) -> Any: + if self.on_after_run is not None: + self.on_after_run(cast(T, self)) diff --git a/griptape/structures/agent.py b/griptape/structures/agent.py index 121220bc26..be1b73e342 100644 --- a/griptape/structures/agent.py +++ b/griptape/structures/agent.py @@ -74,6 +74,6 @@ def add_tasks(self, *tasks: BaseTask | list[BaseTask]) -> list[BaseTask]: @observable def try_run(self, *args) -> Agent: - self.task.execute() + self.task.run() return self diff --git a/griptape/structures/pipeline.py b/griptape/structures/pipeline.py index 497e77bdaa..dcf8503f77 100644 --- a/griptape/structures/pipeline.py +++ b/griptape/structures/pipeline.py @@ -72,7 +72,7 @@ def __run_from_task(self, task: Optional[BaseTask]) -> None: if task is None: return else: - if isinstance(task.execute(), ErrorArtifact) and self.fail_fast: + if isinstance(task.run(), ErrorArtifact) and self.fail_fast: return else: self.__run_from_task(next(iter(task.children), None)) diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index c2702fdbfa..ce715a7b4e 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -12,6 +12,7 @@ from griptape.memory.meta import MetaMemory from griptape.memory.structure import ConversationMemory, Run from griptape.mixins.rule_mixin import RuleMixin +from griptape.mixins.runnable_mixin import RunnableMixin from griptape.mixins.serializable_mixin import SerializableMixin if TYPE_CHECKING: @@ -21,7 +22,7 @@ @define -class Structure(ABC, RuleMixin, SerializableMixin): +class Structure(RuleMixin, SerializableMixin, RunnableMixin["Structure"], ABC): id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True, metadata={"serializable": True}) _tasks: list[Union[BaseTask, list[BaseTask]]] = field( factory=list, kw_only=True, alias="tasks", metadata={"serializable": True} @@ -139,6 +140,7 @@ def resolve_relationships(self) -> None: @observable def before_run(self, args: Any) -> None: + RunnableMixin.before_run(self, args) self._execution_args = args [task.reset() for task in self.tasks] @@ -155,6 +157,7 @@ def before_run(self, args: Any) -> None: @observable def after_run(self) -> None: + RunnableMixin.after_run(self) if self.conversation_memory and self.output_task.output is not None: run = Run(input=self.input_task.input, output=self.output_task.output) diff --git a/griptape/structures/workflow.py b/griptape/structures/workflow.py index 99af20dc29..9777fb520e 100644 --- a/griptape/structures/workflow.py +++ b/griptape/structures/workflow.py @@ -107,8 +107,8 @@ def try_run(self, *args) -> Workflow: ordered_tasks = self.order_tasks() for task in ordered_tasks: - if task.can_execute(): - future = self.futures_executor.submit(task.execute) + if task.can_run(): + future = self.futures_executor.submit(task.run) futures_list[future] = task # Wait for all tasks to complete diff --git a/griptape/tasks/actions_subtask.py b/griptape/tasks/actions_subtask.py index 8d68ae9963..3e730546fa 100644 --- a/griptape/tasks/actions_subtask.py +++ b/griptape/tasks/actions_subtask.py @@ -113,14 +113,14 @@ def before_run(self) -> None: ] logger.info("".join(parts)) - def run(self) -> BaseArtifact: + def try_run(self) -> BaseArtifact: try: if any(isinstance(a.output, ErrorArtifact) for a in self.actions): errors = [a.output.value for a in self.actions if isinstance(a.output, ErrorArtifact)] self.output = ErrorArtifact("\n\n".join(errors)) else: - results = self.execute_actions(self.actions) + results = self.run_actions(self.actions) actions_output = [] for result in results: @@ -138,13 +138,13 @@ def run(self) -> BaseArtifact: else: return ErrorArtifact("no tool output") - def execute_actions(self, actions: list[ToolAction]) -> list[tuple[str, BaseArtifact]]: - return utils.execute_futures_list([self.futures_executor.submit(self.execute_action, a) for a in actions]) + def run_actions(self, actions: list[ToolAction]) -> list[tuple[str, BaseArtifact]]: + return utils.execute_futures_list([self.futures_executor.submit(self.run_action, a) for a in actions]) - def execute_action(self, action: ToolAction) -> tuple[str, BaseArtifact]: + def run_action(self, action: ToolAction) -> tuple[str, BaseArtifact]: if action.tool is not None: if action.path is not None: - output = action.tool.execute(getattr(action.tool, action.path), self, action) + output = action.tool.run(getattr(action.tool, action.path), self, action) else: output = ErrorArtifact("action path not found") else: diff --git a/griptape/tasks/audio_transcription_task.py b/griptape/tasks/audio_transcription_task.py index 3d83cf7e79..819f166ec8 100644 --- a/griptape/tasks/audio_transcription_task.py +++ b/griptape/tasks/audio_transcription_task.py @@ -18,5 +18,5 @@ class AudioTranscriptionTask(BaseAudioInputTask): kw_only=True, ) - def run(self) -> TextArtifact: + def try_run(self) -> TextArtifact: return self.audio_transcription_engine.run(self.input) diff --git a/griptape/tasks/base_task.py b/griptape/tasks/base_task.py index b6012c7e1b..3b28c900d7 100644 --- a/griptape/tasks/base_task.py +++ b/griptape/tasks/base_task.py @@ -8,14 +8,14 @@ from attrs import Factory, define, field -from griptape.artifacts import ErrorArtifact +from griptape.artifacts import BaseArtifact, ErrorArtifact from griptape.configs import Defaults from griptape.events import EventBus, FinishTaskEvent, StartTaskEvent from griptape.mixins.futures_executor_mixin import FuturesExecutorMixin +from griptape.mixins.runnable_mixin import RunnableMixin from griptape.mixins.serializable_mixin import SerializableMixin if TYPE_CHECKING: - from griptape.artifacts import BaseArtifact from griptape.memory.meta import BaseMetaEntry from griptape.structures import Structure @@ -23,7 +23,7 @@ @define -class BaseTask(FuturesExecutorMixin, SerializableMixin, ABC): +class BaseTask(FuturesExecutorMixin, SerializableMixin, RunnableMixin["BaseTask"], ABC): class State(Enum): PENDING = 1 EXECUTING = 2 @@ -137,6 +137,7 @@ def is_executing(self) -> bool: return self.state == BaseTask.State.EXECUTING def before_run(self) -> None: + RunnableMixin.before_run(self) if self.structure is not None: EventBus.publish_event( StartTaskEvent( @@ -148,25 +149,13 @@ def before_run(self) -> None: ), ) - def after_run(self) -> None: - if self.structure is not None: - EventBus.publish_event( - FinishTaskEvent( - task_id=self.id, - task_parent_ids=self.parent_ids, - task_child_ids=self.child_ids, - task_input=self.input, - task_output=self.output, - ), - ) - - def execute(self) -> Optional[BaseArtifact]: + def run(self) -> BaseArtifact: try: self.state = BaseTask.State.EXECUTING self.before_run() - self.output = self.run() + self.output = self.try_run() self.after_run() except Exception as e: @@ -178,7 +167,20 @@ def execute(self) -> Optional[BaseArtifact]: return self.output - def can_execute(self) -> bool: + def after_run(self) -> None: + RunnableMixin.after_run(self) + if self.structure is not None: + EventBus.publish_event( + FinishTaskEvent( + task_id=self.id, + task_parent_ids=self.parent_ids, + task_child_ids=self.child_ids, + task_input=self.input, + task_output=self.output, + ), + ) + + def can_run(self) -> bool: return self.state == BaseTask.State.PENDING and all(parent.is_finished() for parent in self.parents) def reset(self) -> BaseTask: @@ -188,7 +190,7 @@ def reset(self) -> BaseTask: return self @abstractmethod - def run(self) -> BaseArtifact: ... + def try_run(self) -> BaseArtifact: ... @property def full_context(self) -> dict[str, Any]: diff --git a/griptape/tasks/code_execution_task.py b/griptape/tasks/code_execution_task.py index d627382fd0..390d08e911 100644 --- a/griptape/tasks/code_execution_task.py +++ b/griptape/tasks/code_execution_task.py @@ -14,5 +14,5 @@ class CodeExecutionTask(BaseTextInputTask): run_fn: Callable[[CodeExecutionTask], BaseArtifact] = field(kw_only=True) - def run(self) -> BaseArtifact: + def try_run(self) -> BaseArtifact: return self.run_fn(self) diff --git a/griptape/tasks/extraction_task.py b/griptape/tasks/extraction_task.py index 43096dcede..35cc83003e 100644 --- a/griptape/tasks/extraction_task.py +++ b/griptape/tasks/extraction_task.py @@ -17,5 +17,5 @@ class ExtractionTask(BaseTextInputTask): extraction_engine: BaseExtractionEngine = field(kw_only=True) args: dict = field(kw_only=True, factory=dict) - def run(self) -> ListArtifact | ErrorArtifact: + def try_run(self) -> ListArtifact | ErrorArtifact: return self.extraction_engine.extract_artifacts(ListArtifact([self.input]), rulesets=self.rulesets, **self.args) diff --git a/griptape/tasks/image_query_task.py b/griptape/tasks/image_query_task.py index 5d1fcc79ab..abd14ec567 100644 --- a/griptape/tasks/image_query_task.py +++ b/griptape/tasks/image_query_task.py @@ -64,7 +64,7 @@ def input( ) -> None: self._input = value - def run(self) -> TextArtifact: + def try_run(self) -> TextArtifact: query = self.input.value[0] if all(isinstance(artifact, ImageArtifact) for artifact in self.input.value[1:]): diff --git a/griptape/tasks/inpainting_image_generation_task.py b/griptape/tasks/inpainting_image_generation_task.py index a00e345fbe..88f868c72f 100644 --- a/griptape/tasks/inpainting_image_generation_task.py +++ b/griptape/tasks/inpainting_image_generation_task.py @@ -59,7 +59,7 @@ def input( ) -> None: self._input = value - def run(self) -> ImageArtifact: + def try_run(self) -> ImageArtifact: prompt_artifact = self.input[0] image_artifact = self.input[1] diff --git a/griptape/tasks/outpainting_image_generation_task.py b/griptape/tasks/outpainting_image_generation_task.py index ee928c8003..60fbff4570 100644 --- a/griptape/tasks/outpainting_image_generation_task.py +++ b/griptape/tasks/outpainting_image_generation_task.py @@ -59,7 +59,7 @@ def input( ) -> None: self._input = value - def run(self) -> ImageArtifact: + def try_run(self) -> ImageArtifact: prompt_artifact = self.input[0] image_artifact = self.input[1] diff --git a/griptape/tasks/prompt_image_generation_task.py b/griptape/tasks/prompt_image_generation_task.py index 5676f4d65f..a76c8d0d8d 100644 --- a/griptape/tasks/prompt_image_generation_task.py +++ b/griptape/tasks/prompt_image_generation_task.py @@ -50,7 +50,7 @@ def input(self) -> TextArtifact: def input(self, value: TextArtifact) -> None: self._input = value - def run(self) -> ImageArtifact: + def try_run(self) -> ImageArtifact: image_artifact = self.image_generation_engine.run( prompts=[self.input.to_text()], rulesets=self.rulesets, diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index ed3ffa4524..98eb9e3095 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -94,7 +94,7 @@ def after_run(self) -> None: logger.info("%s %s\nOutput: %s", self.__class__.__name__, self.id, self.output.to_text()) - def run(self) -> BaseArtifact: + def try_run(self) -> BaseArtifact: message = self.prompt_driver.run(self.prompt_stack) return message.to_artifact() diff --git a/griptape/tasks/rag_task.py b/griptape/tasks/rag_task.py index b7ea8d7c7d..3244e6c1a9 100644 --- a/griptape/tasks/rag_task.py +++ b/griptape/tasks/rag_task.py @@ -11,7 +11,7 @@ class RagTask(BaseTextInputTask): rag_engine: RagEngine = field(kw_only=True, default=Factory(lambda: RagEngine())) - def run(self) -> BaseArtifact: + def try_run(self) -> BaseArtifact: outputs = self.rag_engine.process_query(self.input.to_text()).outputs if len(outputs) > 0: diff --git a/griptape/tasks/structure_run_task.py b/griptape/tasks/structure_run_task.py index 6860958aa4..db5f8bc497 100644 --- a/griptape/tasks/structure_run_task.py +++ b/griptape/tasks/structure_run_task.py @@ -22,7 +22,7 @@ class StructureRunTask(PromptTask): driver: BaseStructureRunDriver = field(kw_only=True) - def run(self) -> BaseArtifact: + def try_run(self) -> BaseArtifact: if isinstance(self.input, ListArtifact): return self.driver.run(*self.input.value) else: diff --git a/griptape/tasks/text_summary_task.py b/griptape/tasks/text_summary_task.py index 4861510d6b..5cb7510acc 100644 --- a/griptape/tasks/text_summary_task.py +++ b/griptape/tasks/text_summary_task.py @@ -16,5 +16,5 @@ class TextSummaryTask(BaseTextInputTask): summary_engine: BaseSummaryEngine = field(default=Factory(lambda: PromptSummaryEngine()), kw_only=True) - def run(self) -> TextArtifact: + def try_run(self) -> TextArtifact: return TextArtifact(self.summary_engine.summarize_text(self.input.to_text(), rulesets=self.rulesets)) diff --git a/griptape/tasks/text_to_speech_task.py b/griptape/tasks/text_to_speech_task.py index 5f897164cd..ef67ca44d4 100644 --- a/griptape/tasks/text_to_speech_task.py +++ b/griptape/tasks/text_to_speech_task.py @@ -34,7 +34,7 @@ def input(self) -> TextArtifact: def input(self, value: TextArtifact) -> None: self._input = value - def run(self) -> AudioArtifact: + def try_run(self) -> AudioArtifact: audio_artifact = self.text_to_speech_engine.run(prompts=[self.input.to_text()], rulesets=self.rulesets) if self.output_dir or self.output_file: diff --git a/griptape/tasks/tool_task.py b/griptape/tasks/tool_task.py index a9a36ddb6a..325400f5c6 100644 --- a/griptape/tasks/tool_task.py +++ b/griptape/tasks/tool_task.py @@ -60,7 +60,7 @@ def default_system_template_generator(self, _: PromptTask) -> str: def actions_schema(self) -> Schema: return self._actions_schema_for_tools([self.tool]) - def run(self) -> BaseArtifact: + def try_run(self) -> BaseArtifact: result = self.prompt_driver.run(prompt_stack=self.prompt_stack) if self.prompt_driver.use_native_tools: diff --git a/griptape/tasks/toolkit_task.py b/griptape/tasks/toolkit_task.py index 2a4a926bdc..00fc981d8d 100644 --- a/griptape/tasks/toolkit_task.py +++ b/griptape/tasks/toolkit_task.py @@ -165,7 +165,7 @@ def set_default_tools_memory(self, memory: TaskMemory) -> None: if tool.output_memory is None and tool.off_prompt: tool.output_memory = {getattr(a, "name"): [self.task_memory] for a in tool.activities()} - def run(self) -> BaseArtifact: + def try_run(self) -> BaseArtifact: from griptape.tasks import ActionsSubtask self.subtasks.clear() diff --git a/griptape/tasks/variation_image_generation_task.py b/griptape/tasks/variation_image_generation_task.py index c443cd08b6..c0db1a64bb 100644 --- a/griptape/tasks/variation_image_generation_task.py +++ b/griptape/tasks/variation_image_generation_task.py @@ -56,7 +56,7 @@ def input(self) -> ListArtifact: def input(self, value: tuple[str | TextArtifact, ImageArtifact] | Callable[[BaseTask], ListArtifact]) -> None: self._input = value - def run(self) -> ImageArtifact: + def try_run(self) -> ImageArtifact: prompt_artifact = self.input[0] image_artifact = self.input[1] diff --git a/griptape/tools/base_tool.py b/griptape/tools/base_tool.py index 7efa9f77f5..41e513f718 100644 --- a/griptape/tools/base_tool.py +++ b/griptape/tools/base_tool.py @@ -16,6 +16,7 @@ from griptape.artifacts import BaseArtifact, ErrorArtifact, InfoArtifact, TextArtifact from griptape.common import observable from griptape.mixins.activity_mixin import ActivityMixin +from griptape.mixins.runnable_mixin import RunnableMixin from griptape.mixins.serializable_mixin import SerializableMixin if TYPE_CHECKING: @@ -25,7 +26,7 @@ @define -class BaseTool(ActivityMixin, SerializableMixin, ABC): +class BaseTool(ActivityMixin, SerializableMixin, RunnableMixin["BaseTool"], ABC): """Abstract class for all tools to inherit from for. Attributes: @@ -112,11 +113,11 @@ def activity_schemas(self) -> list[Schema]: return schemas - def execute(self, activity: Callable, subtask: ActionsSubtask, action: ToolAction) -> BaseArtifact: + def run(self, activity: Callable, subtask: ActionsSubtask, action: ToolAction) -> BaseArtifact: try: output = self.before_run(activity, subtask, action) - output = self.run(activity, subtask, action, output) + output = self.try_run(activity, subtask, action, output) output = self.after_run(activity, subtask, action, output) except Exception as e: @@ -125,10 +126,12 @@ def execute(self, activity: Callable, subtask: ActionsSubtask, action: ToolActio return output def before_run(self, activity: Callable, subtask: ActionsSubtask, action: ToolAction) -> Optional[dict]: + RunnableMixin.before_run(self) + return action.input @observable(tags=["Tool.run()"]) - def run( + def try_run( self, activity: Callable, subtask: ActionsSubtask, @@ -153,6 +156,8 @@ def after_run( action: ToolAction, value: BaseArtifact, ) -> BaseArtifact: + RunnableMixin.after_run(self) + if value: if self.output_memory: output_memories = self.output_memory[getattr(activity, "name")] or [] diff --git a/tests/mocks/mock_audio_input_task.py b/tests/mocks/mock_audio_input_task.py index 95b8c88d08..cd358c92ba 100644 --- a/tests/mocks/mock_audio_input_task.py +++ b/tests/mocks/mock_audio_input_task.py @@ -6,5 +6,5 @@ @define class MockAudioInputTask(BaseAudioInputTask): - def run(self) -> TextArtifact: + def try_run(self) -> TextArtifact: return TextArtifact(self.input.to_text()) diff --git a/tests/mocks/mock_image_generation_task.py b/tests/mocks/mock_image_generation_task.py index b55c5c9953..bc0d8e35ff 100644 --- a/tests/mocks/mock_image_generation_task.py +++ b/tests/mocks/mock_image_generation_task.py @@ -1,4 +1,4 @@ -from attrs import define, field +from attrs import Factory, define, field from griptape.artifacts import ImageArtifact, TextArtifact from griptape.tasks import BaseImageGenerationTask @@ -6,7 +6,7 @@ @define class MockImageGenerationTask(BaseImageGenerationTask): - _input: TextArtifact = field(default="input") + _input: TextArtifact = field(default=Factory(lambda: TextArtifact("input"))) @property def input(self) -> TextArtifact: @@ -16,5 +16,5 @@ def input(self) -> TextArtifact: def input(self, value: str) -> None: self._input = TextArtifact(value) - def run(self) -> ImageArtifact: + def try_run(self) -> ImageArtifact: return ImageArtifact(value=b"image data", format="png", width=512, height=512) diff --git a/tests/mocks/mock_task.py b/tests/mocks/mock_task.py index 81aa037137..86f0254b6e 100644 --- a/tests/mocks/mock_task.py +++ b/tests/mocks/mock_task.py @@ -12,5 +12,5 @@ class MockTask(BaseTask): def input(self) -> BaseArtifact: return TextArtifact(self.mock_input) - def run(self) -> BaseArtifact: + def try_run(self) -> BaseArtifact: return self.input diff --git a/tests/mocks/mock_text_input_task.py b/tests/mocks/mock_text_input_task.py index f1439bd428..149fb80596 100644 --- a/tests/mocks/mock_text_input_task.py +++ b/tests/mocks/mock_text_input_task.py @@ -6,5 +6,5 @@ @define class MockTextInputTask(BaseTextInputTask): - def run(self) -> TextArtifact: + def try_run(self) -> TextArtifact: return TextArtifact(self.input.to_text()) diff --git a/tests/unit/mixins/test_runnable_mixin.py b/tests/unit/mixins/test_runnable_mixin.py new file mode 100644 index 0000000000..b3e0a3f538 --- /dev/null +++ b/tests/unit/mixins/test_runnable_mixin.py @@ -0,0 +1,21 @@ +from unittest.mock import Mock + +from tests.unit.tasks.test_base_task import MockTask + + +class TestRunnableMixin: + def test_before_run(self): + mock_on_before_run = Mock() + mock_task = MockTask(on_before_run=mock_on_before_run) + + mock_task.run() + + assert mock_on_before_run.called + + def test_after_run(self): + mock_on_after_run = Mock() + mock_task = MockTask(on_after_run=mock_on_after_run) + + mock_task.run() + + assert mock_on_after_run.called diff --git a/tests/unit/structures/test_agent.py b/tests/unit/structures/test_agent.py index b8b6bb1b4d..86f4a1141c 100644 --- a/tests/unit/structures/test_agent.py +++ b/tests/unit/structures/test_agent.py @@ -1,3 +1,5 @@ +from unittest.mock import Mock + import pytest from griptape.memory import TaskMemory @@ -296,3 +298,14 @@ def test_from_dict(self): assert len(deserialized_agent.task_outputs) == 1 assert deserialized_agent.task_outputs[task.id].value == "mock output" + + def test_runnable_mixin(self): + mock_on_before_run = Mock() + mock_after_run = Mock() + agent = Agent(prompt_driver=MockPromptDriver(), on_before_run=mock_on_before_run, on_after_run=mock_after_run) + + args = "test" + agent.run(args) + + mock_on_before_run.assert_called_once_with(agent) + mock_after_run.assert_called_once_with(agent) diff --git a/tests/unit/tasks/test_actions_subtask.py b/tests/unit/tasks/test_actions_subtask.py index e25a42120a..1521c756e5 100644 --- a/tests/unit/tasks/test_actions_subtask.py +++ b/tests/unit/tasks/test_actions_subtask.py @@ -173,7 +173,7 @@ def test_execute_tool(self): task = ToolkitTask(tools=[MockTool()]) Agent().add_task(task) subtask = task.add_subtask(ActionsSubtask(valid_input)) - subtask.execute() + subtask.run() assert isinstance(subtask.output, ListArtifact) assert isinstance(subtask.output.value[0], TextArtifact) @@ -188,7 +188,7 @@ def test_execute_tool_exception(self): task = ToolkitTask(tools=[MockTool()]) Agent().add_task(task) subtask = task.add_subtask(ActionsSubtask(valid_input)) - subtask.execute() + subtask.run() assert isinstance(subtask.output, ListArtifact) assert isinstance(subtask.output.value[0], ErrorArtifact) diff --git a/tests/unit/tasks/test_base_task.py b/tests/unit/tasks/test_base_task.py index 3437eb117b..85addfe3d4 100644 --- a/tests/unit/tasks/test_base_task.py +++ b/tests/unit/tasks/test_base_task.py @@ -111,8 +111,8 @@ def test_children_property_no_structure(self, task): assert len(parent.children) == 3 - def test_execute_publish_events(self, task): - task.execute() + def test_run_publish_events(self, task): + task.run() assert EventBus.event_listeners[0].handler.call_count == 2 @@ -191,3 +191,19 @@ def test_from_dict(self): assert str(workflow.tasks[0].state) == "State.FINISHED" assert workflow.tasks[0].id == deserialized_task.id assert workflow.tasks[0].output.value == "foobar" + + def test_runnable_mixin(self): + mock_on_before_run = Mock() + mock_after_run = Mock() + task = MockTask("foobar", on_before_run=mock_on_before_run, on_after_run=mock_after_run) + + task.run() + + mock_on_before_run.assert_called_once_with(task) + mock_after_run.assert_called_once_with(task) + + def test_full_context(self, task): + task.structure = Agent() + task.structure._execution_args = ("foo", "bar") + + assert task.full_context == {"args": ("foo", "bar"), "structure": task.structure} diff --git a/tests/unit/tasks/test_code_execution_task.py b/tests/unit/tasks/test_code_execution_task.py index f0eb37ede0..8e69f53a3e 100644 --- a/tests/unit/tasks/test_code_execution_task.py +++ b/tests/unit/tasks/test_code_execution_task.py @@ -23,7 +23,7 @@ class TestCodeExecutionTask: def test_hello_world_fn(self): task = CodeExecutionTask(run_fn=hello_world) - assert task.run().value == "Hello World!" + assert task.try_run().value == "Hello World!" # Using a Pipeline # Overriding the input because we are implementing the task not the Pipeline @@ -31,11 +31,11 @@ def test_noop_fn(self): pipeline = Pipeline() task = CodeExecutionTask("No Op", run_fn=non_outputting) pipeline.add_task(task) - temp = task.run() + temp = task.try_run() assert temp.value == "No Op" def test_error_fn(self): task = CodeExecutionTask(run_fn=deliberate_exception) with pytest.raises(ValueError): - task.run() + task.try_run() diff --git a/tests/unit/tasks/test_image_query_task.py b/tests/unit/tasks/test_image_query_task.py index 01c116772a..349340ad4f 100644 --- a/tests/unit/tasks/test_image_query_task.py +++ b/tests/unit/tasks/test_image_query_task.py @@ -73,4 +73,4 @@ def test_run(self, image_query_engine, text_artifact, image_artifact): def test_bad_run(self, image_query_engine, text_artifact, image_artifact): with pytest.raises(ValueError, match="All inputs"): - ImageQueryTask(("foo", [image_artifact, text_artifact]), image_query_engine=image_query_engine).run() + ImageQueryTask(("foo", [image_artifact, text_artifact]), image_query_engine=image_query_engine).try_run() diff --git a/tests/unit/tasks/test_inpainting_image_generation_task.py b/tests/unit/tasks/test_inpainting_image_generation_task.py index 5c4507d49c..94f2d69a8d 100644 --- a/tests/unit/tasks/test_inpainting_image_generation_task.py +++ b/tests/unit/tasks/test_inpainting_image_generation_task.py @@ -43,10 +43,10 @@ def test_list_input(self, text_artifact: TextArtifact, image_artifact: ImageArti def test_bad_input(self, image_artifact): with pytest.raises(ValueError): - InpaintingImageGenerationTask(("foo", "bar", image_artifact)).run() # pyright: ignore[reportArgumentType] + InpaintingImageGenerationTask(("foo", "bar", image_artifact)).try_run() # pyright: ignore[reportArgumentType] with pytest.raises(ValueError): - InpaintingImageGenerationTask(("foo", image_artifact, "baz")).run() # pyright: ignore[reportArgumentType] + InpaintingImageGenerationTask(("foo", image_artifact, "baz")).try_run() # pyright: ignore[reportArgumentType] def test_config_image_generation_engine(self, text_artifact, image_artifact): task = InpaintingImageGenerationTask((text_artifact, image_artifact, image_artifact)) diff --git a/tests/unit/tasks/test_outpainting_image_generation_task.py b/tests/unit/tasks/test_outpainting_image_generation_task.py index ba5e52a820..6218c4a60b 100644 --- a/tests/unit/tasks/test_outpainting_image_generation_task.py +++ b/tests/unit/tasks/test_outpainting_image_generation_task.py @@ -43,10 +43,10 @@ def test_list_input(self, text_artifact: TextArtifact, image_artifact: ImageArti def test_bad_input(self, image_artifact): with pytest.raises(ValueError): - OutpaintingImageGenerationTask(("foo", "bar", image_artifact)).run() # pyright: ignore[reportArgumentType] + OutpaintingImageGenerationTask(("foo", "bar", image_artifact)).try_run() # pyright: ignore[reportArgumentType] with pytest.raises(ValueError): - OutpaintingImageGenerationTask(("foo", image_artifact, "baz")).run() # pyright: ignore[reportArgumentType] + OutpaintingImageGenerationTask(("foo", image_artifact, "baz")).try_run() # pyright: ignore[reportArgumentType] def test_config_image_generation_engine(self, text_artifact, image_artifact): task = OutpaintingImageGenerationTask((text_artifact, image_artifact, image_artifact)) diff --git a/tests/unit/tasks/test_prompt_task.py b/tests/unit/tasks/test_prompt_task.py index cfe8532260..0b60f09bd9 100644 --- a/tests/unit/tasks/test_prompt_task.py +++ b/tests/unit/tasks/test_prompt_task.py @@ -32,6 +32,15 @@ def test_config_prompt_driver(self): assert isinstance(task.prompt_driver, MockPromptDriver) def test_input(self): + # Structure context + pipeline = Pipeline() + task = PromptTask() + pipeline.add_task(task) + pipeline._execution_args = ("foo", "bar") + assert task.input.value == "foo" + pipeline._execution_args = ("fizz", "buzz") + assert task.input.value == "fizz" + # Str task = PromptTask("test") @@ -118,6 +127,22 @@ def test_input(self): assert task.input.value == str({"default": "test"}) + def test_input_context(self): + pipeline = Pipeline( + tasks=[ + PromptTask( + "foo", + prompt_driver=MockPromptDriver(), + on_before_run=lambda task: task.children[0].input, + ), + PromptTask("{{ parent_output }}", prompt_driver=MockPromptDriver()), + ] + ) + + pipeline.run() + + assert pipeline.tasks[1].input.value == "mock output" + def test_prompt_stack(self): task = PromptTask("{{ test }}", context={"test": "test value"}, rules=[Rule("test rule")]) diff --git a/tests/unit/tasks/test_variation_image_generation_task.py b/tests/unit/tasks/test_variation_image_generation_task.py index f6afbf03e6..4c471a4f75 100644 --- a/tests/unit/tasks/test_variation_image_generation_task.py +++ b/tests/unit/tasks/test_variation_image_generation_task.py @@ -43,7 +43,7 @@ def test_list_input(self, text_artifact: TextArtifact, image_artifact: ImageArti def test_bad_input(self, image_artifact): with pytest.raises(ValueError): - VariationImageGenerationTask(("foo", "bar")).run() # pyright: ignore[reportArgumentType] + VariationImageGenerationTask(("foo", "bar")).try_run() # pyright: ignore[reportArgumentType] def test_config_image_generation_engine(self, text_artifact, image_artifact): task = VariationImageGenerationTask((text_artifact, image_artifact)) diff --git a/tests/unit/tools/test_base_tool.py b/tests/unit/tools/test_base_tool.py index 5ac3849d59..a48889d2b9 100644 --- a/tests/unit/tools/test_base_tool.py +++ b/tests/unit/tools/test_base_tool.py @@ -1,5 +1,6 @@ import inspect import os +from unittest.mock import Mock import pytest from schema import Or, Schema, SchemaMissingKeyError @@ -248,9 +249,9 @@ def test_find_input_memory(self): assert MockTool().find_input_memory("foo") is None assert MockTool(input_memory=[defaults.text_task_memory("foo")]).find_input_memory("foo") is not None - def test_execute(self, tool): + def test_run(self, tool): action = ToolAction(input={}, name="", tag="") - assert tool.execute(tool.test_list_output, ActionsSubtask("foo"), action).to_text() == "foo\n\nbar" + assert tool.run(tool.test_list_output, ActionsSubtask("foo"), action).to_text() == "foo\n\nbar" def test_schema(self, tool): tool = MockTool() @@ -308,10 +309,20 @@ def test_from_dict(self, tool): deserialized_tool = MockTool.from_dict(serialized_tool) assert isinstance(deserialized_tool, BaseTool) - assert deserialized_tool.execute(tool.test_list_output, ActionsSubtask("foo"), action).to_text() == "foo\n\nbar" + assert deserialized_tool.run(tool.test_list_output, ActionsSubtask("foo"), action).to_text() == "foo\n\nbar" def test_method_kwargs_var_injection(self, tool): tool = MockToolKwargs() params = {"values": {"test_kwarg": "foo", "test_kwarg_kwargs": "bar"}} assert tool.test_with_kwargs(params) == "ack foo" + + def test_runnable_mixin(self, tool): + mock_on_before_run = Mock() + mock_after_run = Mock() + tool = MockTool(on_before_run=mock_on_before_run, on_after_run=mock_after_run) + + tool.run(tool.test_list_output, ActionsSubtask("foo"), ToolAction(input={}, name="", tag="")).to_text() + + mock_on_before_run.assert_called_once_with(tool) + mock_after_run.assert_called_once_with(tool)