Skip to content

Commit

Permalink
Add RunnableMixin, implement in Structures, Tasks, Tools
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Oct 18, 2024
1 parent bf96755 commit 1a003bd
Show file tree
Hide file tree
Showing 41 changed files with 269 additions and 71 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Exponential backoff to `BaseEventListenerDriver` for retrying failed event publishing.
- `BaseTask.task_outputs` to get a dictionary of all task outputs. This has been added to `Workflow.context` and `Pipeline.context`.
- `Chat.input_fn` for customizing the input to the Chat utility.
- `RunnableMixin` which adds `on_before_run` and `on_after_run` hooks.

### Changed

Expand All @@ -25,6 +26,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **BREAKING**: Updated `EventListener.handler` return value behavior.
- If `EventListener.handler` returns `None`, the event will not be published to the `event_listener_driver`.
- If `EventListener.handler` is None, the event will be published to the `event_listener_driver` as-is.
- **BREAKING**: Renamed `BaseTask.run` to `BaseTask.try_run`.
- **BREAKING**: Renamed `BaseTask.execute` to `BaseTask.run`.
- **BREAKING**: Renamed `BaseTask.can_execute` to `BaseTool.can_run`.
- **BREAKING**: Renamed `BaseTool.run` to `BaseTool.try_run`.
- **BREAKING**: Renamed `BaseTool.execute` to `BaseTool.run`.
- Updated `EventListener.handler` return type to `Optional[BaseEvent | dict]`.
- `BaseTask.parent_outputs` type has changed from `dict[str, str | None]` to `dict[str, BaseArtifact]`.
- `Workflow.context["parent_outputs"]` type has changed from `dict[str, str | None]` to `dict[str, BaseArtifact]`.
Expand All @@ -40,6 +46,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `Chat.output_fn`'s now takes an optional kwarg parameter, `stream`.
- Implemented `SerializableMixin` in `Structure`, `BaseTask`, `BaseTool`, and `TaskMemory`
- `@activity` decorated functions can now accept kwargs that are defined in the activity schema.
- Implemented `RunnableMixin` in `Structure`, `BaseTask`, and `BaseTool`.

### Fixed

Expand Down
38 changes: 38 additions & 0 deletions docs/griptape-framework/structures/src/task_hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import json
import re

from griptape.structures import Agent
from griptape.tasks import PromptTask
from griptape.tasks.base_task import BaseTask

SSN_PATTERN = re.compile(r"\b\d{3}-\d{2}-\d{4}\b")

original_input = None


def on_before_run(task: BaseTask) -> None:
global original_input

original_input = task.input.value

if isinstance(task, PromptTask):
task.input = SSN_PATTERN.sub("xxx-xx-xxxx", task.input.value)


def on_after_run(task: BaseTask) -> None:
if task.output is not None:
task.output.value = json.dumps(
{"original_input": original_input, "masked_input": task.input.value, "output": task.output.value}, indent=2
)


agent = Agent(
tasks=[
PromptTask(
"Respond to this user: {{ args[0] }}",
on_before_run=on_before_run,
on_after_run=on_after_run,
)
]
)
agent.run("Hello! My favorite color is blue, and my social security number is 123-45-6789.")
22 changes: 22 additions & 0 deletions docs/griptape-framework/structures/tasks.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,28 @@ Additional [context](../../reference/griptape/structures/structure.md#griptape.s
sleeves, and let's get baking! 🍰🎉
```

## Hooks

All Tasks implement [RunnableMixin](../../reference/griptape/mixins/runnable_mixin.md) which provides `on_before_run` and `on_after_run` hooks for the Task lifecycle.

These hooks can be used to perform actions before and after the Task is run. For example, you can mask sensitive information before running the Task, and transform the output after the Task is run.

```python
--8<-- "docs/griptape-framework/structures/src/task_hooks.py"
```

```
[10/15/24 15:14:10] INFO PromptTask 63a0c734059c42808c87dff351adc8ab
Input: Respond to this user: Hello! My favorite color is blue, and my social security number is xxx-xx-xxxx.
[10/15/24 15:14:11] INFO PromptTask 63a0c734059c42808c87dff351adc8ab
Output: {
"original_input": "Respond to this user: Hello! My favorite color is blue, and my social security number is 123-45-6789.",
"masked_input": "Respond to this user: Hello! My favorite color is blue, and my social security number is xxx-xx-xxxx.",
"output": "Hello! It's great to hear that your favorite color is blue. However, it's important to keep your personal information, like your
social security number, private and secure. If you have any questions or need assistance, feel free to ask!"
}
```

## Prompt Task

For general purpose prompting, use the [PromptTask](../../reference/griptape/tasks/prompt_task.md):
Expand Down
35 changes: 35 additions & 0 deletions griptape/mixins/runnable_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Callable, Generic, Optional, TypeVar, cast

from attrs import define, field

# Generics magic that allows us to reference the type of the class that is implementing the mixin
T = TypeVar("T", bound="RunnableMixin")


@define()
class RunnableMixin(ABC, Generic[T]):
"""Mixin for classes that can be "run".
Implementing classes should pass themselves as the generic type to ensure that the correct type is used in the callbacks.
Attributes:
on_before_run: Optional callback that is called at the very beginning of the `run` method.
on_after_run: Optional callback that is called at the very end of the `run` method.
"""

on_before_run: Optional[Callable[[T], None]] = field(kw_only=True, default=None)
on_after_run: Optional[Callable[[T], None]] = field(kw_only=True, default=None)

def before_run(self, *args, **kwargs) -> Any:
if self.on_before_run is not None:
self.on_before_run(cast(T, self))

@abstractmethod
def run(self, *args, **kwargs) -> Any: ...

def after_run(self, *args, **kwargs) -> Any:
if self.on_after_run is not None:
self.on_after_run(cast(T, self))
2 changes: 1 addition & 1 deletion griptape/structures/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,6 @@ def add_tasks(self, *tasks: BaseTask | list[BaseTask]) -> list[BaseTask]:

@observable
def try_run(self, *args) -> Agent:
self.task.execute()
self.task.run()

return self
2 changes: 1 addition & 1 deletion griptape/structures/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __run_from_task(self, task: Optional[BaseTask]) -> None:
if task is None:
return
else:
if isinstance(task.execute(), ErrorArtifact) and self.fail_fast:
if isinstance(task.run(), ErrorArtifact) and self.fail_fast:
return
else:
self.__run_from_task(next(iter(task.children), None))
5 changes: 4 additions & 1 deletion griptape/structures/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from griptape.memory.meta import MetaMemory
from griptape.memory.structure import ConversationMemory, Run
from griptape.mixins.rule_mixin import RuleMixin
from griptape.mixins.runnable_mixin import RunnableMixin
from griptape.mixins.serializable_mixin import SerializableMixin

if TYPE_CHECKING:
Expand All @@ -21,7 +22,7 @@


@define
class Structure(ABC, RuleMixin, SerializableMixin):
class Structure(RuleMixin, SerializableMixin, RunnableMixin["Structure"], ABC):
id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True, metadata={"serializable": True})
_tasks: list[Union[BaseTask, list[BaseTask]]] = field(
factory=list, kw_only=True, alias="tasks", metadata={"serializable": True}
Expand Down Expand Up @@ -139,6 +140,7 @@ def resolve_relationships(self) -> None:

@observable
def before_run(self, args: Any) -> None:
RunnableMixin.before_run(self, args)
self._execution_args = args

[task.reset() for task in self.tasks]
Expand All @@ -155,6 +157,7 @@ def before_run(self, args: Any) -> None:

@observable
def after_run(self) -> None:
RunnableMixin.after_run(self)
if self.conversation_memory and self.output_task.output is not None:
run = Run(input=self.input_task.input, output=self.output_task.output)

Expand Down
4 changes: 2 additions & 2 deletions griptape/structures/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ def try_run(self, *args) -> Workflow:
ordered_tasks = self.order_tasks()

for task in ordered_tasks:
if task.can_execute():
future = self.futures_executor.submit(task.execute)
if task.can_run():
future = self.futures_executor.submit(task.run)
futures_list[future] = task

# Wait for all tasks to complete
Expand Down
12 changes: 6 additions & 6 deletions griptape/tasks/actions_subtask.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,14 +113,14 @@ def before_run(self) -> None:
]
logger.info("".join(parts))

def run(self) -> BaseArtifact:
def try_run(self) -> BaseArtifact:
try:
if any(isinstance(a.output, ErrorArtifact) for a in self.actions):
errors = [a.output.value for a in self.actions if isinstance(a.output, ErrorArtifact)]

self.output = ErrorArtifact("\n\n".join(errors))
else:
results = self.execute_actions(self.actions)
results = self.run_actions(self.actions)

actions_output = []
for result in results:
Expand All @@ -138,13 +138,13 @@ def run(self) -> BaseArtifact:
else:
return ErrorArtifact("no tool output")

def execute_actions(self, actions: list[ToolAction]) -> list[tuple[str, BaseArtifact]]:
return utils.execute_futures_list([self.futures_executor.submit(self.execute_action, a) for a in actions])
def run_actions(self, actions: list[ToolAction]) -> list[tuple[str, BaseArtifact]]:
return utils.execute_futures_list([self.futures_executor.submit(self.run_action, a) for a in actions])

def execute_action(self, action: ToolAction) -> tuple[str, BaseArtifact]:
def run_action(self, action: ToolAction) -> tuple[str, BaseArtifact]:
if action.tool is not None:
if action.path is not None:
output = action.tool.execute(getattr(action.tool, action.path), self, action)
output = action.tool.run(getattr(action.tool, action.path), self, action)
else:
output = ErrorArtifact("action path not found")
else:
Expand Down
2 changes: 1 addition & 1 deletion griptape/tasks/audio_transcription_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@ class AudioTranscriptionTask(BaseAudioInputTask):
kw_only=True,
)

def run(self) -> TextArtifact:
def try_run(self) -> TextArtifact:
return self.audio_transcription_engine.run(self.input)
40 changes: 21 additions & 19 deletions griptape/tasks/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,22 @@

from attrs import Factory, define, field

from griptape.artifacts import ErrorArtifact
from griptape.artifacts import BaseArtifact, ErrorArtifact
from griptape.configs import Defaults
from griptape.events import EventBus, FinishTaskEvent, StartTaskEvent
from griptape.mixins.futures_executor_mixin import FuturesExecutorMixin
from griptape.mixins.runnable_mixin import RunnableMixin
from griptape.mixins.serializable_mixin import SerializableMixin

if TYPE_CHECKING:
from griptape.artifacts import BaseArtifact
from griptape.memory.meta import BaseMetaEntry
from griptape.structures import Structure

logger = logging.getLogger(Defaults.logging_config.logger_name)


@define
class BaseTask(FuturesExecutorMixin, SerializableMixin, ABC):
class BaseTask(FuturesExecutorMixin, SerializableMixin, RunnableMixin["BaseTask"], ABC):
class State(Enum):
PENDING = 1
EXECUTING = 2
Expand Down Expand Up @@ -137,6 +137,7 @@ def is_executing(self) -> bool:
return self.state == BaseTask.State.EXECUTING

def before_run(self) -> None:
RunnableMixin.before_run(self)
if self.structure is not None:
EventBus.publish_event(
StartTaskEvent(
Expand All @@ -148,25 +149,13 @@ def before_run(self) -> None:
),
)

def after_run(self) -> None:
if self.structure is not None:
EventBus.publish_event(
FinishTaskEvent(
task_id=self.id,
task_parent_ids=self.parent_ids,
task_child_ids=self.child_ids,
task_input=self.input,
task_output=self.output,
),
)

def execute(self) -> Optional[BaseArtifact]:
def run(self) -> BaseArtifact:
try:
self.state = BaseTask.State.EXECUTING

self.before_run()

self.output = self.run()
self.output = self.try_run()

self.after_run()
except Exception as e:
Expand All @@ -178,7 +167,20 @@ def execute(self) -> Optional[BaseArtifact]:

return self.output

def can_execute(self) -> bool:
def after_run(self) -> None:
RunnableMixin.after_run(self)
if self.structure is not None:
EventBus.publish_event(
FinishTaskEvent(
task_id=self.id,
task_parent_ids=self.parent_ids,
task_child_ids=self.child_ids,
task_input=self.input,
task_output=self.output,
),
)

def can_run(self) -> bool:
return self.state == BaseTask.State.PENDING and all(parent.is_finished() for parent in self.parents)

def reset(self) -> BaseTask:
Expand All @@ -188,7 +190,7 @@ def reset(self) -> BaseTask:
return self

@abstractmethod
def run(self) -> BaseArtifact: ...
def try_run(self) -> BaseArtifact: ...

@property
def full_context(self) -> dict[str, Any]:
Expand Down
2 changes: 1 addition & 1 deletion griptape/tasks/code_execution_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@
class CodeExecutionTask(BaseTextInputTask):
run_fn: Callable[[CodeExecutionTask], BaseArtifact] = field(kw_only=True)

def run(self) -> BaseArtifact:
def try_run(self) -> BaseArtifact:
return self.run_fn(self)
2 changes: 1 addition & 1 deletion griptape/tasks/extraction_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@ class ExtractionTask(BaseTextInputTask):
extraction_engine: BaseExtractionEngine = field(kw_only=True)
args: dict = field(kw_only=True, factory=dict)

def run(self) -> ListArtifact | ErrorArtifact:
def try_run(self) -> ListArtifact | ErrorArtifact:
return self.extraction_engine.extract_artifacts(ListArtifact([self.input]), rulesets=self.rulesets, **self.args)
2 changes: 1 addition & 1 deletion griptape/tasks/image_query_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def input(
) -> None:
self._input = value

def run(self) -> TextArtifact:
def try_run(self) -> TextArtifact:
query = self.input.value[0]

if all(isinstance(artifact, ImageArtifact) for artifact in self.input.value[1:]):
Expand Down
2 changes: 1 addition & 1 deletion griptape/tasks/inpainting_image_generation_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def input(
) -> None:
self._input = value

def run(self) -> ImageArtifact:
def try_run(self) -> ImageArtifact:
prompt_artifact = self.input[0]

image_artifact = self.input[1]
Expand Down
2 changes: 1 addition & 1 deletion griptape/tasks/outpainting_image_generation_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def input(
) -> None:
self._input = value

def run(self) -> ImageArtifact:
def try_run(self) -> ImageArtifact:
prompt_artifact = self.input[0]

image_artifact = self.input[1]
Expand Down
2 changes: 1 addition & 1 deletion griptape/tasks/prompt_image_generation_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def input(self) -> TextArtifact:
def input(self, value: TextArtifact) -> None:
self._input = value

def run(self) -> ImageArtifact:
def try_run(self) -> ImageArtifact:
image_artifact = self.image_generation_engine.run(
prompts=[self.input.to_text()],
rulesets=self.rulesets,
Expand Down
2 changes: 1 addition & 1 deletion griptape/tasks/prompt_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def after_run(self) -> None:

logger.info("%s %s\nOutput: %s", self.__class__.__name__, self.id, self.output.to_text())

def run(self) -> BaseArtifact:
def try_run(self) -> BaseArtifact:
message = self.prompt_driver.run(self.prompt_stack)

return message.to_artifact()
Expand Down
2 changes: 1 addition & 1 deletion griptape/tasks/rag_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
class RagTask(BaseTextInputTask):
rag_engine: RagEngine = field(kw_only=True, default=Factory(lambda: RagEngine()))

def run(self) -> BaseArtifact:
def try_run(self) -> BaseArtifact:
outputs = self.rag_engine.process_query(self.input.to_text()).outputs

if len(outputs) > 0:
Expand Down
Loading

0 comments on commit 1a003bd

Please sign in to comment.