From adf4836e0ee9c79da8f073a84716cb872dd7c519 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 7 Aug 2024 16:23:25 -0700 Subject: [PATCH 01/63] Add global event bus --- CHANGELOG.md | 3 + docs/griptape-framework/misc/events.md | 61 ++++++++++--------- griptape/config/base_structure_config.py | 40 ------------ .../base_audio_transcription_driver.py | 10 +-- .../embedding/base_embedding_driver.py | 4 +- .../base_image_generation_driver.py | 10 +-- .../image_query/base_image_query_driver.py | 10 +-- .../base_conversation_memory_driver.py | 4 +- griptape/drivers/prompt/base_prompt_driver.py | 16 ++--- .../base_text_to_speech_driver.py | 9 +-- .../vector/base_vector_store_driver.py | 4 +- griptape/events/__init__.py | 2 + .../event_bus.py} | 5 +- griptape/mixins/__init__.py | 2 - griptape/structures/structure.py | 12 ++-- griptape/tasks/actions_subtask.py | 6 +- griptape/tasks/base_task.py | 6 +- griptape/utils/stream.py | 9 +-- tests/unit/config/test_structure_config.py | 35 ----------- tests/unit/conftest.py | 12 ++++ .../test_base_audio_transcription_driver.py | 4 +- .../test_base_image_generation_driver.py | 9 +-- .../test_base_image_query_driver.py | 4 +- .../drivers/prompt/test_base_prompt_driver.py | 7 +-- .../test_base_audio_transcription_driver.py | 4 +- tests/unit/events/test_event_bus.py | 45 ++++++++++++++ tests/unit/events/test_event_listener.py | 29 ++++----- tests/unit/mixins/test_events_mixin.py | 59 ------------------ tests/unit/tasks/test_base_task.py | 5 +- 29 files changed, 176 insertions(+), 250 deletions(-) rename griptape/{mixins/event_publisher_mixin.py => events/event_bus.py} (96%) create mode 100644 tests/unit/conftest.py create mode 100644 tests/unit/events/test_event_bus.py delete mode 100644 tests/unit/mixins/test_events_mixin.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 3582ec02c..ea88983f3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,8 +11,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Ability to set custom schema properties on Tool Activities via `extra_schema_properties`. - Parameter `structure` to `BaseTask`. - Method `try_find_task` to `Structure`. +- Global event bus, `griptape.events.EventBus`, for publishing and subscribing to events. ### Changed +- **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `EventBus`. +- **BREAKING**: Removed `EventPublisherMixin`. - `BaseTask.add_parent/child` will now call `self.structure.add_task` if possible. ## [0.29.0] - 2024-07-30 diff --git a/docs/griptape-framework/misc/events.md b/docs/griptape-framework/misc/events.md index 1f50fd6d0..187321dc6 100644 --- a/docs/griptape-framework/misc/events.md +++ b/docs/griptape-framework/misc/events.md @@ -5,7 +5,7 @@ search: ## Overview -You can use [EventListener](../../reference/griptape/events/event_listener.md)s to listen for events during a Structure's execution. +You can configure the global [EventBus](../../reference/griptape/events/event_bus.md) with [EventListener](../../reference/griptape/events/event_listener.md)s to listen for various framework events. See [Event Listener Drivers](../drivers/event-listener-drivers.md) for examples on forwarding events to external services. ## Specific Event Types @@ -23,15 +23,14 @@ from griptape.events import ( StartPromptEvent, FinishPromptEvent, EventListener, + EventBus ) def handler(event: BaseEvent): print(event.__class__) - -agent = Agent( - event_listeners=[ +EventBus.event_listeners=[ EventListener( handler, event_types=[ @@ -44,7 +43,8 @@ agent = Agent( ], ) ] -) + +agent = Agent() agent.run("tell me about griptape") ``` @@ -69,7 +69,8 @@ Or listen to all events: ```python from griptape.structures import Agent -from griptape.events import BaseEvent, EventListener +from griptape.events import BaseEvent, EventListener, EventBus + def handler1(event: BaseEvent): @@ -79,13 +80,12 @@ def handler1(event: BaseEvent): def handler2(event: BaseEvent): print("Handler 2", event.__class__) - -agent = Agent( - event_listeners=[ +EventBus.event_listeners=[ EventListener(handler1), EventListener(handler2), ] -) + +agent = Agent() agent.run("tell me about griptape") ``` @@ -131,7 +131,7 @@ Handler 2 list: - return [ - self.prompt_driver, - self.image_generation_driver, - self.image_query_driver, - self.embedding_driver, - self.vector_store_driver, - self.conversation_memory_driver, - self.text_to_speech_driver, - self.audio_transcription_driver, - ] - - @property - def structure(self) -> Optional[Structure]: - return self._structure - - @structure.setter - def structure(self, structure: Structure) -> None: - if structure != self.structure: - event_publisher_drivers = [ - driver for driver in self.drivers if driver is not None and isinstance(driver, EventPublisherMixin) - ] - - for driver in event_publisher_drivers: - if self._event_listener is not None: - driver.remove_event_listener(self._event_listener) - - self._event_listener = EventListener(structure.publish_event) - for driver in event_publisher_drivers: - driver.add_event_listener(self._event_listener) - - self._structure = structure - def merge_config(self, config: dict) -> BaseStructureConfig: base_config = self.to_dict() merged_config = dict_merge(base_config, config) diff --git a/griptape/drivers/audio_transcription/base_audio_transcription_driver.py b/griptape/drivers/audio_transcription/base_audio_transcription_driver.py index c81ea1d5b..ae46c474c 100644 --- a/griptape/drivers/audio_transcription/base_audio_transcription_driver.py +++ b/griptape/drivers/audio_transcription/base_audio_transcription_driver.py @@ -5,22 +5,22 @@ from attrs import define, field -from griptape.events import FinishAudioTranscriptionEvent, StartAudioTranscriptionEvent -from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin +from griptape.events import EventBus, FinishAudioTranscriptionEvent, StartAudioTranscriptionEvent +from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: from griptape.artifacts import AudioArtifact, TextArtifact @define -class BaseAudioTranscriptionDriver(EventPublisherMixin, SerializableMixin, ExponentialBackoffMixin, ABC): +class BaseAudioTranscriptionDriver(SerializableMixin, ExponentialBackoffMixin, ABC): model: str = field(kw_only=True, metadata={"serializable": True}) def before_run(self) -> None: - self.publish_event(StartAudioTranscriptionEvent()) + EventBus.publish_event(StartAudioTranscriptionEvent()) def after_run(self) -> None: - self.publish_event(FinishAudioTranscriptionEvent()) + EventBus.publish_event(FinishAudioTranscriptionEvent()) def run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/embedding/base_embedding_driver.py b/griptape/drivers/embedding/base_embedding_driver.py index 690726060..8998f00e5 100644 --- a/griptape/drivers/embedding/base_embedding_driver.py +++ b/griptape/drivers/embedding/base_embedding_driver.py @@ -7,7 +7,7 @@ from attrs import define, field from griptape.chunkers import BaseChunker, TextChunker -from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin +from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: from griptape.artifacts import TextArtifact @@ -15,7 +15,7 @@ @define -class BaseEmbeddingDriver(EventPublisherMixin, SerializableMixin, ExponentialBackoffMixin, ABC): +class BaseEmbeddingDriver(SerializableMixin, ExponentialBackoffMixin, ABC): """Base Embedding Driver. Attributes: diff --git a/griptape/drivers/image_generation/base_image_generation_driver.py b/griptape/drivers/image_generation/base_image_generation_driver.py index f500d6d09..8dfca5945 100644 --- a/griptape/drivers/image_generation/base_image_generation_driver.py +++ b/griptape/drivers/image_generation/base_image_generation_driver.py @@ -5,22 +5,22 @@ from attrs import define, field -from griptape.events import FinishImageGenerationEvent, StartImageGenerationEvent -from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin +from griptape.events import EventBus, FinishImageGenerationEvent, StartImageGenerationEvent +from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: from griptape.artifacts import ImageArtifact @define -class BaseImageGenerationDriver(EventPublisherMixin, SerializableMixin, ExponentialBackoffMixin, ABC): +class BaseImageGenerationDriver(SerializableMixin, ExponentialBackoffMixin, ABC): model: str = field(kw_only=True, metadata={"serializable": True}) def before_run(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> None: - self.publish_event(StartImageGenerationEvent(prompts=prompts, negative_prompts=negative_prompts)) + EventBus.publish_event(StartImageGenerationEvent(prompts=prompts, negative_prompts=negative_prompts)) def after_run(self) -> None: - self.publish_event(FinishImageGenerationEvent()) + EventBus.publish_event(FinishImageGenerationEvent()) def run_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/image_query/base_image_query_driver.py b/griptape/drivers/image_query/base_image_query_driver.py index b39f198d4..28c571328 100644 --- a/griptape/drivers/image_query/base_image_query_driver.py +++ b/griptape/drivers/image_query/base_image_query_driver.py @@ -5,24 +5,24 @@ from attrs import define, field -from griptape.events import FinishImageQueryEvent, StartImageQueryEvent -from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin +from griptape.events import EventBus, FinishImageQueryEvent, StartImageQueryEvent +from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: from griptape.artifacts import ImageArtifact, TextArtifact @define -class BaseImageQueryDriver(EventPublisherMixin, SerializableMixin, ExponentialBackoffMixin, ABC): +class BaseImageQueryDriver(SerializableMixin, ExponentialBackoffMixin, ABC): max_tokens: int = field(default=256, kw_only=True, metadata={"serializable": True}) def before_run(self, query: str, images: list[ImageArtifact]) -> None: - self.publish_event( + EventBus.publish_event( StartImageQueryEvent(query=query, images_info=[image.to_text() for image in images]), ) def after_run(self, result: str) -> None: - self.publish_event(FinishImageQueryEvent(result=result)) + EventBus.publish_event(FinishImageQueryEvent(result=result)) def query(self, query: str, images: list[ImageArtifact]) -> TextArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/memory/conversation/base_conversation_memory_driver.py b/griptape/drivers/memory/conversation/base_conversation_memory_driver.py index f13b82c29..1caeb902f 100644 --- a/griptape/drivers/memory/conversation/base_conversation_memory_driver.py +++ b/griptape/drivers/memory/conversation/base_conversation_memory_driver.py @@ -3,13 +3,13 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Optional -from griptape.mixins import EventPublisherMixin, SerializableMixin +from griptape.mixins import SerializableMixin if TYPE_CHECKING: from griptape.memory.structure import BaseConversationMemory -class BaseConversationMemoryDriver(EventPublisherMixin, SerializableMixin, ABC): +class BaseConversationMemoryDriver(SerializableMixin, ABC): @abstractmethod def store(self, memory: BaseConversationMemory) -> None: ... diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index e5fd0408d..94e46e75d 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -16,8 +16,8 @@ TextMessageContent, observable, ) -from griptape.events import CompletionChunkEvent, FinishPromptEvent, StartPromptEvent -from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin +from griptape.events import CompletionChunkEvent, EventBus, FinishPromptEvent, StartPromptEvent +from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: from collections.abc import Iterator @@ -26,7 +26,7 @@ @define(kw_only=True) -class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, EventPublisherMixin, ABC): +class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): """Base class for the Prompt Drivers. Attributes: @@ -49,10 +49,10 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, EventPublishe use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True}) def before_run(self, prompt_stack: PromptStack) -> None: - self.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack)) + EventBus.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack)) def after_run(self, result: Message) -> None: - self.publish_event( + EventBus.publish_event( FinishPromptEvent( model=self.model, result=result.value, @@ -128,12 +128,12 @@ def __process_stream(self, prompt_stack: PromptStack) -> Message: else: delta_contents[content.index] = [content] if isinstance(content, TextDeltaMessageContent): - self.publish_event(CompletionChunkEvent(token=content.text)) + EventBus.publish_event(CompletionChunkEvent(token=content.text)) elif isinstance(content, ActionCallDeltaMessageContent): if content.tag is not None and content.name is not None and content.path is not None: - self.publish_event(CompletionChunkEvent(token=str(content))) + EventBus.publish_event(CompletionChunkEvent(token=str(content))) elif content.partial_input is not None: - self.publish_event(CompletionChunkEvent(token=content.partial_input)) + EventBus.publish_event(CompletionChunkEvent(token=content.partial_input)) # Build a complete content from the content deltas result = self.__build_message(list(delta_contents.values()), usage) diff --git a/griptape/drivers/text_to_speech/base_text_to_speech_driver.py b/griptape/drivers/text_to_speech/base_text_to_speech_driver.py index 788d92974..cb11cc498 100644 --- a/griptape/drivers/text_to_speech/base_text_to_speech_driver.py +++ b/griptape/drivers/text_to_speech/base_text_to_speech_driver.py @@ -5,23 +5,24 @@ from attrs import define, field +from griptape.events import EventBus from griptape.events.finish_text_to_speech_event import FinishTextToSpeechEvent from griptape.events.start_text_to_speech_event import StartTextToSpeechEvent -from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin +from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: from griptape.artifacts.audio_artifact import AudioArtifact @define -class BaseTextToSpeechDriver(SerializableMixin, ExponentialBackoffMixin, EventPublisherMixin, ABC): +class BaseTextToSpeechDriver(SerializableMixin, ExponentialBackoffMixin, ABC): model: str = field(kw_only=True, metadata={"serializable": True}) def before_run(self, prompts: list[str]) -> None: - self.publish_event(StartTextToSpeechEvent(prompts=prompts)) + EventBus.publish_event(StartTextToSpeechEvent(prompts=prompts)) def after_run(self) -> None: - self.publish_event(FinishTextToSpeechEvent()) + EventBus.publish_event(FinishTextToSpeechEvent()) def run_text_to_audio(self, prompts: list[str]) -> AudioArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/vector/base_vector_store_driver.py b/griptape/drivers/vector/base_vector_store_driver.py index d1da78188..ed1f2d589 100644 --- a/griptape/drivers/vector/base_vector_store_driver.py +++ b/griptape/drivers/vector/base_vector_store_driver.py @@ -10,14 +10,14 @@ from griptape import utils from griptape.artifacts import BaseArtifact, ListArtifact, TextArtifact -from griptape.mixins import EventPublisherMixin, SerializableMixin +from griptape.mixins import SerializableMixin if TYPE_CHECKING: from griptape.drivers import BaseEmbeddingDriver @define -class BaseVectorStoreDriver(EventPublisherMixin, SerializableMixin, ABC): +class BaseVectorStoreDriver(SerializableMixin, ABC): DEFAULT_QUERY_COUNT = 5 @dataclass diff --git a/griptape/events/__init__.py b/griptape/events/__init__.py index 944a309eb..b3e2f3a79 100644 --- a/griptape/events/__init__.py +++ b/griptape/events/__init__.py @@ -22,6 +22,7 @@ from .base_audio_transcription_event import BaseAudioTranscriptionEvent from .start_audio_transcription_event import StartAudioTranscriptionEvent from .finish_audio_transcription_event import FinishAudioTranscriptionEvent +from .event_bus import EventBus __all__ = [ "BaseEvent", @@ -48,4 +49,5 @@ "BaseAudioTranscriptionEvent", "StartAudioTranscriptionEvent", "FinishAudioTranscriptionEvent", + "EventBus", ] diff --git a/griptape/mixins/event_publisher_mixin.py b/griptape/events/event_bus.py similarity index 96% rename from griptape/mixins/event_publisher_mixin.py rename to griptape/events/event_bus.py index 67a302ed6..9239e66bd 100644 --- a/griptape/mixins/event_publisher_mixin.py +++ b/griptape/events/event_bus.py @@ -9,7 +9,7 @@ @define -class EventPublisherMixin: +class _EventBus: event_listeners: list[EventListener] = field(factory=list, kw_only=True) def add_event_listeners(self, event_listeners: list[EventListener]) -> list[EventListener]: @@ -32,3 +32,6 @@ def remove_event_listener(self, event_listener: EventListener) -> None: def publish_event(self, event: BaseEvent, *, flush: bool = False) -> None: for event_listener in self.event_listeners: event_listener.publish_event(event, flush=flush) + + +EventBus = _EventBus() diff --git a/griptape/mixins/__init__.py b/griptape/mixins/__init__.py index 944027c59..d9eea53c2 100644 --- a/griptape/mixins/__init__.py +++ b/griptape/mixins/__init__.py @@ -4,7 +4,6 @@ from .rule_mixin import RuleMixin from .serializable_mixin import SerializableMixin from .media_artifact_file_output_mixin import BlobArtifactFileOutputMixin -from .event_publisher_mixin import EventPublisherMixin __all__ = [ "ActivityMixin", @@ -13,5 +12,4 @@ "RuleMixin", "BlobArtifactFileOutputMixin", "SerializableMixin", - "EventPublisherMixin", ] diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index 079e0b741..df7113c23 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -28,13 +28,11 @@ VectorStoreRetrievalRagModule, ) from griptape.engines.rag.stages import ResponseRagStage, RetrievalRagStage -from griptape.events.finish_structure_run_event import FinishStructureRunEvent -from griptape.events.start_structure_run_event import StartStructureRunEvent +from griptape.events import EventBus, FinishStructureRunEvent, StartStructureRunEvent from griptape.memory import TaskMemory from griptape.memory.meta import MetaMemory from griptape.memory.structure import ConversationMemory from griptape.memory.task.storage import BlobArtifactStorage, TextArtifactStorage -from griptape.mixins import EventPublisherMixin from griptape.utils import deprecation_warn if TYPE_CHECKING: @@ -44,7 +42,7 @@ @define -class Structure(ABC, EventPublisherMixin): +class Structure(ABC): LOGGER_NAME = "griptape" id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True) @@ -97,8 +95,6 @@ def __attrs_post_init__(self) -> None: if self.conversation_memory is not None: self.conversation_memory.structure = self - self.config.structure = self - tasks = self.tasks.copy() self.tasks.clear() self.add_tasks(*tasks) @@ -261,7 +257,7 @@ def before_run(self, args: Any) -> None: [task.reset() for task in self.tasks] - self.publish_event( + EventBus.publish_event( StartStructureRunEvent( structure_id=self.id, input_task_input=self.input_task.input, @@ -273,7 +269,7 @@ def before_run(self, args: Any) -> None: @observable def after_run(self) -> None: - self.publish_event( + EventBus.publish_event( FinishStructureRunEvent( structure_id=self.id, output_task_input=self.output_task.input, diff --git a/griptape/tasks/actions_subtask.py b/griptape/tasks/actions_subtask.py index cde59d0ef..07f49f52a 100644 --- a/griptape/tasks/actions_subtask.py +++ b/griptape/tasks/actions_subtask.py @@ -10,7 +10,7 @@ from griptape import utils from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact from griptape.common import ToolAction -from griptape.events import FinishActionsSubtaskEvent, StartActionsSubtaskEvent +from griptape.events import EventBus, FinishActionsSubtaskEvent, StartActionsSubtaskEvent from griptape.mixins import ActionsSubtaskOriginMixin from griptape.tasks import BaseTask from griptape.utils import remove_null_values_in_dict_recursively @@ -91,7 +91,7 @@ def attach_to(self, parent_task: BaseTask) -> None: self.output = ErrorArtifact(f"ToolAction input parsing error: {e}", exception=e) def before_run(self) -> None: - self.structure.publish_event( + EventBus.publish_event( StartActionsSubtaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, @@ -157,7 +157,7 @@ def execute_action(self, action: ToolAction) -> tuple[str, BaseArtifact]: def after_run(self) -> None: response = self.output.to_text() if isinstance(self.output, BaseArtifact) else str(self.output) - self.structure.publish_event( + EventBus.publish_event( FinishActionsSubtaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, diff --git a/griptape/tasks/base_task.py b/griptape/tasks/base_task.py index 8c50e4df9..9a8361e6c 100644 --- a/griptape/tasks/base_task.py +++ b/griptape/tasks/base_task.py @@ -9,7 +9,7 @@ from attrs import Factory, define, field from griptape.artifacts import ErrorArtifact -from griptape.events import FinishTaskEvent, StartTaskEvent +from griptape.events import EventBus, FinishTaskEvent, StartTaskEvent if TYPE_CHECKING: from griptape.artifacts import BaseArtifact @@ -127,7 +127,7 @@ def is_executing(self) -> bool: def before_run(self) -> None: if self.structure is not None: - self.structure.publish_event( + EventBus.publish_event( StartTaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, @@ -139,7 +139,7 @@ def before_run(self) -> None: def after_run(self) -> None: if self.structure is not None: - self.structure.publish_event( + EventBus.publish_event( FinishTaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, diff --git a/griptape/utils/stream.py b/griptape/utils/stream.py index bf33e5df8..4a7899b2a 100644 --- a/griptape/utils/stream.py +++ b/griptape/utils/stream.py @@ -7,10 +7,7 @@ from attrs import Attribute, Factory, define, field from griptape.artifacts.text_artifact import TextArtifact -from griptape.events.completion_chunk_event import CompletionChunkEvent -from griptape.events.event_listener import EventListener -from griptape.events.finish_prompt_event import FinishPromptEvent -from griptape.events.finish_structure_run_event import FinishStructureRunEvent +from griptape.events import CompletionChunkEvent, EventBus, EventListener, FinishPromptEvent, FinishStructureRunEvent if TYPE_CHECKING: from collections.abc import Iterator @@ -64,8 +61,8 @@ def event_handler(event: BaseEvent) -> None: handler=event_handler, event_types=[CompletionChunkEvent, FinishPromptEvent, FinishStructureRunEvent], ) - self.structure.add_event_listener(stream_event_listener) + EventBus.add_event_listener(stream_event_listener) self.structure.run(*args) - self.structure.remove_event_listener(stream_event_listener) + EventBus.remove_event_listener(stream_event_listener) diff --git a/tests/unit/config/test_structure_config.py b/tests/unit/config/test_structure_config.py index b9e3477e4..96a68628f 100644 --- a/tests/unit/config/test_structure_config.py +++ b/tests/unit/config/test_structure_config.py @@ -1,7 +1,6 @@ import pytest from griptape.config import StructureConfig -from griptape.structures import Agent class TestStructureConfig: @@ -61,37 +60,3 @@ def test_dot_update(self, config): config.prompt_driver.max_tokens = 10 assert config.prompt_driver.max_tokens == 10 - - def test_drivers(self, config): - assert config.drivers == [ - config.prompt_driver, - config.image_generation_driver, - config.image_query_driver, - config.embedding_driver, - config.vector_store_driver, - config.conversation_memory_driver, - config.text_to_speech_driver, - config.audio_transcription_driver, - ] - - def test_structure(self, config): - structure_1 = Agent( - config=config, - ) - - assert config.structure == structure_1 - assert config._event_listener is not None - for driver in config.drivers: - if driver is not None: - assert config._event_listener in driver.event_listeners - assert len(driver.event_listeners) == 1 - - structure_2 = Agent( - config=config, - ) - assert config.structure == structure_2 - assert config._event_listener is not None - for driver in config.drivers: - if driver is not None: - assert config._event_listener in driver.event_listeners - assert len(driver.event_listeners) == 1 diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py new file mode 100644 index 000000000..0be2f9758 --- /dev/null +++ b/tests/unit/conftest.py @@ -0,0 +1,12 @@ +import pytest + +from griptape.events import EventBus + + +@pytest.fixture(autouse=True) +def event_bus(): + EventBus.event_listeners = [] + + yield EventBus + + EventBus.event_listeners = [] diff --git a/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py b/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py index 519e40f57..fc41837fd 100644 --- a/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py +++ b/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py @@ -3,7 +3,7 @@ import pytest from griptape.artifacts import AudioArtifact -from griptape.events.event_listener import EventListener +from griptape.events import EventBus, EventListener from tests.mocks.mock_audio_transcription_driver import MockAudioTranscriptionDriver @@ -14,7 +14,7 @@ def driver(self): def test_run_publish_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run( AudioArtifact( diff --git a/tests/unit/drivers/image_generation/test_base_image_generation_driver.py b/tests/unit/drivers/image_generation/test_base_image_generation_driver.py index 7447b2c08..96b615a58 100644 --- a/tests/unit/drivers/image_generation/test_base_image_generation_driver.py +++ b/tests/unit/drivers/image_generation/test_base_image_generation_driver.py @@ -3,6 +3,7 @@ import pytest from griptape.artifacts.image_artifact import ImageArtifact +from griptape.events import EventBus from griptape.events.event_listener import EventListener from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver @@ -14,7 +15,7 @@ def driver(self): def test_run_text_to_image_publish_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run_text_to_image( ["foo", "bar"], @@ -30,7 +31,7 @@ def test_run_text_to_image_publish_events(self, driver): def test_run_image_variation_publish_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_variation( ["foo", "bar"], @@ -52,7 +53,7 @@ def test_run_image_variation_publish_events(self, driver): def test_run_image_image_inpainting_publish_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_inpainting( ["foo", "bar"], @@ -80,7 +81,7 @@ def test_run_image_image_inpainting_publish_events(self, driver): def test_run_image_image_outpainting_publish_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_outpainting( ["foo", "bar"], diff --git a/tests/unit/drivers/image_query/test_base_image_query_driver.py b/tests/unit/drivers/image_query/test_base_image_query_driver.py index 14de15f2d..a77fb268e 100644 --- a/tests/unit/drivers/image_query/test_base_image_query_driver.py +++ b/tests/unit/drivers/image_query/test_base_image_query_driver.py @@ -2,7 +2,7 @@ import pytest -from griptape.events.event_listener import EventListener +from griptape.events import EventBus, EventListener from tests.mocks.mock_image_query_driver import MockImageQueryDriver @@ -13,7 +13,7 @@ def driver(self): def test_query_publishes_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.query("foo", []) diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index 2708b0a88..5b6b0c600 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -1,7 +1,7 @@ from griptape.artifacts import ErrorArtifact, TextArtifact from griptape.common import Message, PromptStack from griptape.events import FinishPromptEvent, StartPromptEvent -from griptape.mixins import EventPublisherMixin +from griptape.events.event_bus import _EventBus from griptape.structures import Pipeline from griptape.tasks import PromptTask, ToolkitTask from tests.mocks.mock_failing_prompt_driver import MockFailingPromptDriver @@ -27,7 +27,7 @@ def test_run_via_pipeline_retries_failure(self): assert isinstance(pipeline.run().output_task.output, ErrorArtifact) def test_run_via_pipeline_publishes_events(self, mocker): - mock_publish_event = mocker.patch.object(EventPublisherMixin, "publish_event") + mock_publish_event = mocker.patch.object(_EventBus, "publish_event") driver = MockPromptDriver() pipeline = Pipeline(prompt_driver=driver) pipeline.add_task(PromptTask("test")) @@ -42,8 +42,7 @@ def test_run(self): assert isinstance(MockPromptDriver().run(PromptStack(messages=[])), Message) def test_run_with_stream(self): - pipeline = Pipeline() - result = MockPromptDriver(stream=True, event_listeners=pipeline.event_listeners).run(PromptStack(messages=[])) + result = MockPromptDriver(stream=True).run(PromptStack(messages=[])) assert isinstance(result, Message) assert result.value == "mock output" diff --git a/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py b/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py index 8af5dc827..ab448c7c1 100644 --- a/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py +++ b/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py @@ -2,7 +2,7 @@ import pytest -from griptape.events.event_listener import EventListener +from griptape.events import EventBus, EventListener from tests.mocks.mock_text_to_speech_driver import MockTextToSpeechDriver @@ -13,7 +13,7 @@ def driver(self): def test_text_to_audio_publish_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run_text_to_audio( ["foo", "bar"], diff --git a/tests/unit/events/test_event_bus.py b/tests/unit/events/test_event_bus.py new file mode 100644 index 000000000..fd862913e --- /dev/null +++ b/tests/unit/events/test_event_bus.py @@ -0,0 +1,45 @@ +from unittest.mock import Mock + +from griptape.events import EventBus, EventListener +from tests.mocks.mock_event import MockEvent + + +class TestEventBus: + def test_add_event_listeners(self): + EventBus.add_event_listeners([EventListener(), EventListener()]) + assert len(EventBus.event_listeners) == 2 + + def test_remove_event_listeners(self): + listeners = [EventListener(), EventListener()] + EventBus.add_event_listeners(listeners) + EventBus.remove_event_listeners(listeners) + assert len(EventBus.event_listeners) == 0 + + def test_add_event_listener(self): + EventBus.add_event_listener(EventListener()) + EventBus.add_event_listener(EventListener()) + + assert len(EventBus.event_listeners) == 2 + + def test_remove_event_listener(self): + listener = EventListener() + EventBus.add_event_listener(listener) + EventBus.remove_event_listener(listener) + + assert len(EventBus.event_listeners) == 0 + + def test_remove_unknown_event_listener(self): + EventBus.remove_event_listener(EventListener()) + + def test_publish_event(self): + # Given + mock_handler = Mock() + mock_handler.return_value = None + EventBus.event_listeners = [EventListener(handler=mock_handler)] + mock_event = MockEvent() + + # When + EventBus.publish_event(mock_event) + + # Then + mock_handler.assert_called_once_with(mock_event) diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index b245c2be9..5601aef34 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -4,6 +4,7 @@ from griptape.events import ( CompletionChunkEvent, + EventBus, EventListener, FinishActionsSubtaskEvent, FinishPromptEvent, @@ -37,7 +38,7 @@ def test_untyped_listeners(self, pipeline): event_handler_1 = Mock() event_handler_2 = Mock() - pipeline.event_listeners = [EventListener(handler=event_handler_1), EventListener(handler=event_handler_2)] + EventBus.event_listeners = [EventListener(handler=event_handler_1), EventListener(handler=event_handler_2)] # can't mock subtask events, so must manually call pipeline.tasks[0].subtasks[0].before_run() @@ -58,7 +59,7 @@ def test_typed_listeners(self, pipeline): finish_structure_run_event_handler = Mock() completion_chunk_handler = Mock() - pipeline.event_listeners = [ + EventBus.event_listeners = [ EventListener(start_prompt_event_handler, event_types=[StartPromptEvent]), EventListener(finish_prompt_event_handler, event_types=[FinishPromptEvent]), EventListener(start_task_event_handler, event_types=[StartTaskEvent]), @@ -86,25 +87,25 @@ def test_typed_listeners(self, pipeline): completion_chunk_handler.assert_called_once() def test_add_remove_event_listener(self, pipeline): - pipeline.event_listeners = [] + EventBus.event_listeners = [] mock1 = Mock() mock2 = Mock() # duplicate event listeners will only get added once - event_listener_1 = pipeline.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) - pipeline.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) + event_listener_1 = EventBus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) + EventBus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) - event_listener_3 = pipeline.add_event_listener(EventListener(mock1, event_types=[FinishPromptEvent])) - event_listener_4 = pipeline.add_event_listener(EventListener(mock2, event_types=[StartPromptEvent])) + event_listener_3 = EventBus.add_event_listener(EventListener(mock1, event_types=[FinishPromptEvent])) + event_listener_4 = EventBus.add_event_listener(EventListener(mock2, event_types=[StartPromptEvent])) - event_listener_5 = pipeline.add_event_listener(EventListener(mock2)) + event_listener_5 = EventBus.add_event_listener(EventListener(mock2)) - assert len(pipeline.event_listeners) == 4 + assert len(EventBus.event_listeners) == 4 - pipeline.remove_event_listener(event_listener_1) - pipeline.remove_event_listener(event_listener_3) - pipeline.remove_event_listener(event_listener_4) - pipeline.remove_event_listener(event_listener_5) - assert len(pipeline.event_listeners) == 0 + EventBus.remove_event_listener(event_listener_1) + EventBus.remove_event_listener(event_listener_3) + EventBus.remove_event_listener(event_listener_4) + EventBus.remove_event_listener(event_listener_5) + assert len(EventBus.event_listeners) == 0 def test_publish_event(self): mock_event_listener_driver = Mock() diff --git a/tests/unit/mixins/test_events_mixin.py b/tests/unit/mixins/test_events_mixin.py deleted file mode 100644 index 99f5541ba..000000000 --- a/tests/unit/mixins/test_events_mixin.py +++ /dev/null @@ -1,59 +0,0 @@ -from unittest.mock import Mock - -from griptape.events import EventListener -from griptape.mixins import EventPublisherMixin -from tests.mocks.mock_event import MockEvent - - -class TestEventsMixin: - def test_init(self): - assert EventPublisherMixin() - - def test_add_event_listeners(self): - mixin = EventPublisherMixin() - - mixin.add_event_listeners([EventListener(), EventListener()]) - assert len(mixin.event_listeners) == 2 - - def test_remove_event_listeners(self): - mixin = EventPublisherMixin() - - listeners = [EventListener(), EventListener()] - mixin.add_event_listeners(listeners) - mixin.remove_event_listeners(listeners) - assert len(mixin.event_listeners) == 0 - - def test_add_event_listener(self): - mixin = EventPublisherMixin() - - mixin.add_event_listener(EventListener()) - mixin.add_event_listener(EventListener()) - - assert len(mixin.event_listeners) == 2 - - def test_remove_event_listener(self): - mixin = EventPublisherMixin() - - listener = EventListener() - mixin.add_event_listener(listener) - mixin.remove_event_listener(listener) - - assert len(mixin.event_listeners) == 0 - - def test_remove_unknown_event_listener(self): - mixin = EventPublisherMixin() - - mixin.remove_event_listener(EventListener()) - - def test_publish_event(self): - # Given - mock_handler = Mock() - mock_handler.return_value = None - mixin = EventPublisherMixin(event_listeners=[EventListener(handler=mock_handler)]) - mock_event = MockEvent() - - # When - mixin.publish_event(mock_event) - - # Then - mock_handler.assert_called_once_with(mock_event) diff --git a/tests/unit/tasks/test_base_task.py b/tests/unit/tasks/test_base_task.py index 4f4b43d40..636515106 100644 --- a/tests/unit/tasks/test_base_task.py +++ b/tests/unit/tasks/test_base_task.py @@ -3,6 +3,7 @@ import pytest from griptape.artifacts import TextArtifact +from griptape.events import EventBus from griptape.events.event_listener import EventListener from griptape.structures import Agent, Workflow from griptape.tasks import ActionsSubtask @@ -15,11 +16,11 @@ class TestBaseTask: @pytest.fixture() def task(self): + EventBus.event_listeners = [EventListener(handler=Mock())] agent = Agent( prompt_driver=MockPromptDriver(), embedding_driver=MockEmbeddingDriver(), tools=[MockTool()], - event_listeners=[EventListener(handler=Mock())], ) agent.add_task(MockTask("foobar", max_meta_memory_entries=2)) @@ -117,4 +118,4 @@ def test_children_property_no_structure(self, task): def test_execute_publish_events(self, task): task.execute() - assert task.structure.event_listeners[0].handler.call_count == 2 + assert EventBus.event_listeners[0].handler.call_count == 2 From ae20c82af62747b04ec1053a0674dadb10db8fda Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Mon, 5 Aug 2024 17:31:12 -0700 Subject: [PATCH 02/63] WIP --- griptape/config/__init__.py | 2 ++ griptape/config/config.py | 3 ++ ...zon_dynamodb_conversation_memory_driver.py | 5 ++- .../local_conversation_memory_driver.py | 8 +++-- .../redis_conversation_memory_driver.py | 5 ++- .../audio/audio_transcription_engine.py | 7 +++-- .../engines/audio/text_to_speech_engine.py | 8 +++-- .../extraction/base_extraction_engine.py | 3 +- .../image/base_image_generation_engine.py | 8 +++-- .../engines/image_query/image_query_engine.py | 6 ++-- .../response/prompt_response_rag_module.py | 3 +- .../vector_store_retrieval_rag_module.py | 3 +- .../engines/summary/prompt_summary_engine.py | 6 ++-- .../structure/base_conversation_memory.py | 7 +++-- .../structure/summary_conversation_memory.py | 19 ++---------- .../task/storage/text_artifact_storage.py | 5 +-- griptape/structures/structure.py | 31 +++++++------------ griptape/tasks/audio_transcription_task.py | 22 ++----------- griptape/tasks/csv_extraction_task.py | 17 ++-------- griptape/tasks/extraction_task.py | 6 +--- griptape/tasks/image_query_task.py | 17 ++-------- .../tasks/inpainting_image_generation_task.py | 22 ++----------- griptape/tasks/json_extraction_task.py | 17 ++-------- .../outpainting_image_generation_task.py | 23 ++------------ .../tasks/prompt_image_generation_task.py | 22 ++----------- griptape/tasks/prompt_task.py | 12 ++----- griptape/tasks/rag_task.py | 23 ++------------ griptape/tasks/text_summary_task.py | 19 ++---------- griptape/tasks/text_to_speech_task.py | 19 ++---------- .../tasks/variation_image_generation_task.py | 22 ++----------- tests/unit/conftest.py | 16 ++++++++++ tests/unit/events/test_event_listener.py | 5 +-- .../test_summary_conversation_memory.py | 3 +- .../tasks/test_audio_transcription_task.py | 3 +- tests/unit/tasks/test_csv_extraction_task.py | 9 ++---- tests/unit/tasks/test_image_query_task.py | 9 +----- .../test_inpainting_image_generation_task.py | 9 +----- tests/unit/tasks/test_json_extraction_task.py | 13 ++------ .../test_outpainting_image_generation_task.py | 9 +----- .../test_prompt_image_generation_task.py | 11 +------ tests/unit/tasks/test_prompt_task.py | 11 +------ tests/unit/tasks/test_text_summary_task.py | 11 +------ tests/unit/tasks/test_text_to_speech_task.py | 3 +- tests/unit/tasks/test_toolkit_task.py | 9 ++++-- .../test_variation_image_generation_task.py | 9 +----- 45 files changed, 144 insertions(+), 356 deletions(-) create mode 100644 griptape/config/config.py diff --git a/griptape/config/__init__.py b/griptape/config/__init__.py index 541eb0db0..4b0f8eb28 100644 --- a/griptape/config/__init__.py +++ b/griptape/config/__init__.py @@ -9,6 +9,7 @@ from .anthropic_structure_config import AnthropicStructureConfig from .google_structure_config import GoogleStructureConfig from .cohere_structure_config import CohereStructureConfig +from .config import Config __all__ = [ @@ -21,4 +22,5 @@ "AnthropicStructureConfig", "GoogleStructureConfig", "CohereStructureConfig", + "Config", ] diff --git a/griptape/config/config.py b/griptape/config/config.py new file mode 100644 index 000000000..e3017f8b6 --- /dev/null +++ b/griptape/config/config.py @@ -0,0 +1,3 @@ +from .openai_structure_config import OpenAiStructureConfig + +Config = OpenAiStructureConfig() diff --git a/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py b/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py index e52174c28..44f214d7c 100644 --- a/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py +++ b/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py @@ -5,12 +5,13 @@ from attrs import Factory, define, field from griptape.drivers import BaseConversationMemoryDriver -from griptape.memory.structure import BaseConversationMemory from griptape.utils import import_optional_dependency if TYPE_CHECKING: import boto3 + from griptape.memory.structure import BaseConversationMemory + @define class AmazonDynamoDbConversationMemoryDriver(BaseConversationMemoryDriver): @@ -38,6 +39,8 @@ def store(self, memory: BaseConversationMemory) -> None: ) def load(self) -> Optional[BaseConversationMemory]: + from griptape.memory.structure import BaseConversationMemory + response = self.table.get_item(Key=self._get_key()) if "Item" in response and self.value_attribute_key in response["Item"]: diff --git a/griptape/drivers/memory/conversation/local_conversation_memory_driver.py b/griptape/drivers/memory/conversation/local_conversation_memory_driver.py index 8d6399e13..f7b6e7d6e 100644 --- a/griptape/drivers/memory/conversation/local_conversation_memory_driver.py +++ b/griptape/drivers/memory/conversation/local_conversation_memory_driver.py @@ -2,12 +2,14 @@ import os from pathlib import Path -from typing import Optional +from typing import TYPE_CHECKING, Optional from attrs import define, field from griptape.drivers import BaseConversationMemoryDriver -from griptape.memory.structure import BaseConversationMemory + +if TYPE_CHECKING: + from griptape.memory.structure import BaseConversationMemory @define @@ -18,6 +20,8 @@ def store(self, memory: BaseConversationMemory) -> None: Path(self.file_path).write_text(memory.to_json()) def load(self) -> Optional[BaseConversationMemory]: + from griptape.memory.structure import BaseConversationMemory + if not os.path.exists(self.file_path): return None memory = BaseConversationMemory.from_json(Path(self.file_path).read_text()) diff --git a/griptape/drivers/memory/conversation/redis_conversation_memory_driver.py b/griptape/drivers/memory/conversation/redis_conversation_memory_driver.py index 2ba3737e8..9afc2f204 100644 --- a/griptape/drivers/memory/conversation/redis_conversation_memory_driver.py +++ b/griptape/drivers/memory/conversation/redis_conversation_memory_driver.py @@ -6,12 +6,13 @@ from attrs import Factory, define, field from griptape.drivers import BaseConversationMemoryDriver -from griptape.memory.structure import BaseConversationMemory from griptape.utils.import_utils import import_optional_dependency if TYPE_CHECKING: from redis import Redis + from griptape.memory.structure import BaseConversationMemory + @define class RedisConversationMemoryDriver(BaseConversationMemoryDriver): @@ -54,6 +55,8 @@ def store(self, memory: BaseConversationMemory) -> None: self.client.hset(self.index, self.conversation_id, memory.to_json()) def load(self) -> Optional[BaseConversationMemory]: + from griptape.memory.structure import BaseConversationMemory + key = self.index memory_json = self.client.hget(key, self.conversation_id) if memory_json: diff --git a/griptape/engines/audio/audio_transcription_engine.py b/griptape/engines/audio/audio_transcription_engine.py index 3631b2d17..a3769842d 100644 --- a/griptape/engines/audio/audio_transcription_engine.py +++ b/griptape/engines/audio/audio_transcription_engine.py @@ -1,12 +1,15 @@ -from attrs import define, field +from attrs import Factory, define, field from griptape.artifacts import AudioArtifact, TextArtifact +from griptape.config import Config from griptape.drivers import BaseAudioTranscriptionDriver @define class AudioTranscriptionEngine: - audio_transcription_driver: BaseAudioTranscriptionDriver = field(kw_only=True) + audio_transcription_driver: BaseAudioTranscriptionDriver = field( + default=Factory(lambda: Config.audio_transcription_driver), kw_only=True + ) def run(self, audio: AudioArtifact, *args, **kwargs) -> TextArtifact: return self.audio_transcription_driver.try_run(audio) diff --git a/griptape/engines/audio/text_to_speech_engine.py b/griptape/engines/audio/text_to_speech_engine.py index af5d5a494..361ecc127 100644 --- a/griptape/engines/audio/text_to_speech_engine.py +++ b/griptape/engines/audio/text_to_speech_engine.py @@ -2,7 +2,9 @@ from typing import TYPE_CHECKING -from attrs import define, field +from attrs import Factory, define, field + +from griptape.config import Config if TYPE_CHECKING: from griptape.artifacts.audio_artifact import AudioArtifact @@ -11,7 +13,9 @@ @define class TextToSpeechEngine: - text_to_speech_driver: BaseTextToSpeechDriver = field(kw_only=True) + text_to_speech_driver: BaseTextToSpeechDriver = field( + default=Factory(lambda: Config.text_to_speech_driver), kw_only=True + ) def run(self, prompts: list[str], *args, **kwargs) -> AudioArtifact: return self.text_to_speech_driver.try_text_to_audio(prompts=prompts) diff --git a/griptape/engines/extraction/base_extraction_engine.py b/griptape/engines/extraction/base_extraction_engine.py index f263ee0aa..3ff6a96e3 100644 --- a/griptape/engines/extraction/base_extraction_engine.py +++ b/griptape/engines/extraction/base_extraction_engine.py @@ -6,6 +6,7 @@ from attrs import Attribute, Factory, define, field from griptape.chunkers import BaseChunker, TextChunker +from griptape.config import Config if TYPE_CHECKING: from griptape.artifacts import ErrorArtifact, ListArtifact @@ -17,7 +18,7 @@ class BaseExtractionEngine(ABC): max_token_multiplier: float = field(default=0.5, kw_only=True) chunk_joiner: str = field(default="\n\n", kw_only=True) - prompt_driver: BasePromptDriver = field(kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.prompt_driver), kw_only=True) chunker: BaseChunker = field( default=Factory( lambda self: TextChunker(tokenizer=self.prompt_driver.tokenizer, max_tokens=self.max_chunker_tokens), diff --git a/griptape/engines/image/base_image_generation_engine.py b/griptape/engines/image/base_image_generation_engine.py index 47a853871..eabf38be3 100644 --- a/griptape/engines/image/base_image_generation_engine.py +++ b/griptape/engines/image/base_image_generation_engine.py @@ -3,7 +3,9 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Optional -from attrs import define, field +from attrs import Factory, define, field + +from griptape.config import Config if TYPE_CHECKING: from griptape.artifacts import ImageArtifact @@ -13,7 +15,9 @@ @define class BaseImageGenerationEngine(ABC): - image_generation_driver: BaseImageGenerationDriver = field(kw_only=True) + image_generation_driver: BaseImageGenerationDriver = field( + kw_only=True, default=Factory(lambda: Config.image_generation_driver) + ) @abstractmethod def run(self, prompts: list[str], *args, rulesets: Optional[list[Ruleset]], **kwargs) -> ImageArtifact: ... diff --git a/griptape/engines/image_query/image_query_engine.py b/griptape/engines/image_query/image_query_engine.py index d0a1e99d4..ed6a64ee3 100644 --- a/griptape/engines/image_query/image_query_engine.py +++ b/griptape/engines/image_query/image_query_engine.py @@ -2,7 +2,9 @@ from typing import TYPE_CHECKING -from attrs import define, field +from attrs import Factory, define, field + +from griptape.config import Config if TYPE_CHECKING: from griptape.artifacts import ImageArtifact, TextArtifact @@ -11,7 +13,7 @@ @define class ImageQueryEngine: - image_query_driver: BaseImageQueryDriver = field(kw_only=True) + image_query_driver: BaseImageQueryDriver = field(default=Factory(lambda: Config.image_query_driver), kw_only=True) def run(self, query: str, images: list[ImageArtifact]) -> TextArtifact: return self.image_query_driver.query(query, images) diff --git a/griptape/engines/rag/modules/response/prompt_response_rag_module.py b/griptape/engines/rag/modules/response/prompt_response_rag_module.py index 0b7cbd953..2e7b486b6 100644 --- a/griptape/engines/rag/modules/response/prompt_response_rag_module.py +++ b/griptape/engines/rag/modules/response/prompt_response_rag_module.py @@ -5,6 +5,7 @@ from attrs import Factory, define, field from griptape.artifacts.text_artifact import TextArtifact +from griptape.config import Config from griptape.engines.rag.modules import BaseResponseRagModule from griptape.utils import J2 @@ -16,7 +17,7 @@ @define(kw_only=True) class PromptResponseRagModule(BaseResponseRagModule): answer_token_offset: int = field(default=400) - prompt_driver: BasePromptDriver = field() + prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.prompt_driver), kw_only=True) generate_system_template: Callable[[RagContext, list[TextArtifact]], str] = field( default=Factory(lambda self: self.default_system_template_generator, takes_self=True), ) diff --git a/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py b/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py index 0a07b4c50..b0deca67d 100644 --- a/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py +++ b/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py @@ -5,6 +5,7 @@ from attrs import Factory, define, field from griptape import utils +from griptape.config import Config from griptape.engines.rag.modules import BaseRetrievalRagModule if TYPE_CHECKING: @@ -17,7 +18,7 @@ @define(kw_only=True) class VectorStoreRetrievalRagModule(BaseRetrievalRagModule): - vector_store_driver: BaseVectorStoreDriver = field() + vector_store_driver: BaseVectorStoreDriver = field(default=Factory(lambda: Config.vector_store_driver)) query_params: dict[str, Any] = field(factory=dict) process_query_output_fn: Callable[[list[BaseVectorStoreDriver.Entry]], Sequence[TextArtifact]] = field( default=Factory(lambda: lambda es: [e.to_artifact() for e in es]), diff --git a/griptape/engines/summary/prompt_summary_engine.py b/griptape/engines/summary/prompt_summary_engine.py index c5d8e695d..d06ebaa2f 100644 --- a/griptape/engines/summary/prompt_summary_engine.py +++ b/griptape/engines/summary/prompt_summary_engine.py @@ -6,8 +6,8 @@ from griptape.artifacts import ListArtifact, TextArtifact from griptape.chunkers import BaseChunker, TextChunker -from griptape.common import PromptStack -from griptape.common.prompt_stack.messages.message import Message +from griptape.common import Message, PromptStack +from griptape.config import Config from griptape.engines import BaseSummaryEngine from griptape.utils import J2 @@ -22,7 +22,7 @@ class PromptSummaryEngine(BaseSummaryEngine): max_token_multiplier: float = field(default=0.5, kw_only=True) system_template_generator: J2 = field(default=Factory(lambda: J2("engines/summary/system.j2")), kw_only=True) user_template_generator: J2 = field(default=Factory(lambda: J2("engines/summary/user.j2")), kw_only=True) - prompt_driver: BasePromptDriver = field(kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.prompt_driver), kw_only=True) chunker: BaseChunker = field( default=Factory( lambda self: TextChunker(tokenizer=self.prompt_driver.tokenizer, max_tokens=self.max_chunker_tokens), diff --git a/griptape/memory/structure/base_conversation_memory.py b/griptape/memory/structure/base_conversation_memory.py index c3d3c501e..8794288c8 100644 --- a/griptape/memory/structure/base_conversation_memory.py +++ b/griptape/memory/structure/base_conversation_memory.py @@ -3,9 +3,10 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Optional -from attrs import define, field +from attrs import Factory, define, field from griptape.common import PromptStack +from griptape.config import Config from griptape.mixins import SerializableMixin if TYPE_CHECKING: @@ -16,7 +17,9 @@ @define class BaseConversationMemory(SerializableMixin, ABC): - driver: Optional[BaseConversationMemoryDriver] = field(default=None, kw_only=True) + driver: Optional[BaseConversationMemoryDriver] = field( + default=Factory(lambda: Config.conversation_memory_driver), kw_only=True + ) runs: list[Run] = field(factory=list, kw_only=True, metadata={"serializable": True}) structure: Structure = field(init=False) autoload: bool = field(default=True, kw_only=True) diff --git a/griptape/memory/structure/summary_conversation_memory.py b/griptape/memory/structure/summary_conversation_memory.py index f29bbb767..807775d63 100644 --- a/griptape/memory/structure/summary_conversation_memory.py +++ b/griptape/memory/structure/summary_conversation_memory.py @@ -5,8 +5,8 @@ from attrs import Factory, define, field -from griptape.common import PromptStack -from griptape.common.prompt_stack.messages.message import Message +from griptape.common import Message, PromptStack +from griptape.config import Config from griptape.memory.structure import ConversationMemory from griptape.utils import J2 @@ -18,7 +18,7 @@ @define class SummaryConversationMemory(ConversationMemory): offset: int = field(default=1, kw_only=True, metadata={"serializable": True}) - _prompt_driver: BasePromptDriver = field(kw_only=True, default=None, alias="prompt_driver") + prompt_driver: BasePromptDriver = field(kw_only=True, default=Factory(lambda: Config.prompt_driver)) summary: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) summary_index: int = field(default=0, kw_only=True, metadata={"serializable": True}) summary_template_generator: J2 = field(default=Factory(lambda: J2("memory/conversation/summary.j2")), kw_only=True) @@ -27,19 +27,6 @@ class SummaryConversationMemory(ConversationMemory): kw_only=True, ) - @property - def prompt_driver(self) -> BasePromptDriver: - if self._prompt_driver is None: - if self.structure is not None: - self._prompt_driver = self.structure.config.prompt_driver - else: - raise ValueError("Prompt Driver is not set.") - return self._prompt_driver - - @prompt_driver.setter - def prompt_driver(self, value: BasePromptDriver) -> None: - self._prompt_driver = value - def to_prompt_stack(self, last_n: Optional[int] = None) -> PromptStack: stack = PromptStack() if self.summary: diff --git a/griptape/memory/task/storage/text_artifact_storage.py b/griptape/memory/task/storage/text_artifact_storage.py index 8e66c5aba..8a918c5f2 100644 --- a/griptape/memory/task/storage/text_artifact_storage.py +++ b/griptape/memory/task/storage/text_artifact_storage.py @@ -2,9 +2,10 @@ from typing import TYPE_CHECKING, Any, Optional -from attrs import Attribute, define, field +from attrs import Attribute, Factory, define, field from griptape.artifacts import BaseArtifact, InfoArtifact, ListArtifact, TextArtifact +from griptape.config import Config from griptape.engines.rag import RagContext, RagEngine from griptape.memory.task.storage import BaseArtifactStorage @@ -15,7 +16,7 @@ @define(kw_only=True) class TextArtifactStorage(BaseArtifactStorage): - vector_store_driver: BaseVectorStoreDriver = field() + vector_store_driver: BaseVectorStoreDriver = field(default=Factory(lambda: Config.vector_store_driver)) rag_engine: Optional[RagEngine] = field(default=None) retrieval_rag_module_name: Optional[str] = field(default=None) summary_engine: Optional[BaseSummaryEngine] = field(default=None) diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index df7113c23..9f1fa9a2b 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -11,7 +11,7 @@ from griptape.artifacts import BaseArtifact, BlobArtifact, TextArtifact from griptape.common import observable -from griptape.config import BaseStructureConfig, OpenAiStructureConfig, StructureConfig +from griptape.config import BaseStructureConfig, Config from griptape.drivers import ( BaseEmbeddingDriver, BasePromptDriver, @@ -59,10 +59,7 @@ class Structure(ABC): custom_logger: Optional[Logger] = field(default=None, kw_only=True) logger_level: int = field(default=logging.INFO, kw_only=True) conversation_memory: Optional[BaseConversationMemory] = field( - default=Factory( - lambda self: ConversationMemory(driver=self.config.conversation_memory_driver), - takes_self=True, - ), + default=Factory(lambda: ConversationMemory()), kw_only=True, ) rag_engine: RagEngine = field(default=Factory(lambda self: self.default_rag_engine, takes_self=True), kw_only=True) @@ -154,8 +151,6 @@ def finished_tasks(self) -> list[BaseTask]: @property def default_config(self) -> BaseStructureConfig: if self.prompt_driver is not None or self.embedding_driver is not None or self.stream is not None: - config = StructureConfig() - prompt_driver = OpenAiChatPromptDriver(model="gpt-4o") if self.prompt_driver is None else self.prompt_driver embedding_driver = OpenAiEmbeddingDriver() if self.embedding_driver is None else self.embedding_driver @@ -165,26 +160,24 @@ def default_config(self) -> BaseStructureConfig: vector_store_driver = LocalVectorStoreDriver(embedding_driver=embedding_driver) - config.prompt_driver = prompt_driver - config.vector_store_driver = vector_store_driver - config.embedding_driver = embedding_driver - else: - config = OpenAiStructureConfig() + Config.prompt_driver = prompt_driver + Config.vector_store_driver = vector_store_driver + Config.embedding_driver = embedding_driver - return config + return Config @property def default_rag_engine(self) -> RagEngine: return RagEngine( retrieval_stage=RetrievalRagStage( - retrieval_modules=[VectorStoreRetrievalRagModule(vector_store_driver=self.config.vector_store_driver)], + retrieval_modules=[VectorStoreRetrievalRagModule()], ), response_stage=ResponseRagStage( before_response_modules=[ RulesetsBeforeResponseRagModule(rulesets=self.rulesets), MetadataBeforeResponseRagModule(), ], - response_module=PromptResponseRagModule(prompt_driver=self.config.prompt_driver), + response_module=PromptResponseRagModule(), ), ) @@ -195,10 +188,10 @@ def default_task_memory(self) -> TaskMemory: TextArtifact: TextArtifactStorage( rag_engine=self.rag_engine, retrieval_rag_module_name="VectorStoreRetrievalRagModule", - vector_store_driver=self.config.vector_store_driver, - summary_engine=PromptSummaryEngine(prompt_driver=self.config.prompt_driver), - csv_extraction_engine=CsvExtractionEngine(prompt_driver=self.config.prompt_driver), - json_extraction_engine=JsonExtractionEngine(prompt_driver=self.config.prompt_driver), + vector_store_driver=Config.vector_store_driver, + summary_engine=PromptSummaryEngine(prompt_driver=Config.prompt_driver), + csv_extraction_engine=CsvExtractionEngine(prompt_driver=Config.prompt_driver), + json_extraction_engine=JsonExtractionEngine(prompt_driver=Config.prompt_driver), ), BlobArtifact: BlobArtifactStorage(), }, diff --git a/griptape/tasks/audio_transcription_task.py b/griptape/tasks/audio_transcription_task.py index 3a4b17b9e..3d83cf7e7 100644 --- a/griptape/tasks/audio_transcription_task.py +++ b/griptape/tasks/audio_transcription_task.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING -from attrs import define, field +from attrs import Factory, define, field from griptape.engines import AudioTranscriptionEngine from griptape.tasks.base_audio_input_task import BaseAudioInputTask @@ -13,26 +13,10 @@ @define class AudioTranscriptionTask(BaseAudioInputTask): - _audio_transcription_engine: AudioTranscriptionEngine = field( - default=None, + audio_transcription_engine: AudioTranscriptionEngine = field( + default=Factory(lambda: AudioTranscriptionEngine()), kw_only=True, - alias="audio_transcription_engine", ) - @property - def audio_transcription_engine(self) -> AudioTranscriptionEngine: - if self._audio_transcription_engine is None: - if self.structure is not None: - self._audio_transcription_engine = AudioTranscriptionEngine( - audio_transcription_driver=self.structure.config.audio_transcription_driver, - ) - else: - raise ValueError("Audio Generation Engine is not set.") - return self._audio_transcription_engine - - @audio_transcription_engine.setter - def audio_transcription_engine(self, value: AudioTranscriptionEngine) -> None: - self._audio_transcription_engine = value - def run(self) -> TextArtifact: return self.audio_transcription_engine.run(self.input) diff --git a/griptape/tasks/csv_extraction_task.py b/griptape/tasks/csv_extraction_task.py index 538596dfe..c252893de 100644 --- a/griptape/tasks/csv_extraction_task.py +++ b/griptape/tasks/csv_extraction_task.py @@ -1,6 +1,6 @@ from __future__ import annotations -from attrs import define, field +from attrs import Factory, define, field from griptape.engines import CsvExtractionEngine from griptape.tasks import ExtractionTask @@ -8,17 +8,4 @@ @define class CsvExtractionTask(ExtractionTask): - _extraction_engine: CsvExtractionEngine = field(default=None, kw_only=True, alias="extraction_engine") - - @property - def extraction_engine(self) -> CsvExtractionEngine: - if self._extraction_engine is None: - if self.structure is not None: - self._extraction_engine = CsvExtractionEngine(prompt_driver=self.structure.config.prompt_driver) - else: - raise ValueError("Extraction Engine is not set.") - return self._extraction_engine - - @extraction_engine.setter - def extraction_engine(self, value: CsvExtractionEngine) -> None: - self._extraction_engine = value + extraction_engine: CsvExtractionEngine = field(default=Factory(lambda: CsvExtractionEngine()), kw_only=True) diff --git a/griptape/tasks/extraction_task.py b/griptape/tasks/extraction_task.py index d8f492693..a1c18eff0 100644 --- a/griptape/tasks/extraction_task.py +++ b/griptape/tasks/extraction_task.py @@ -13,12 +13,8 @@ @define class ExtractionTask(BaseTextInputTask): - _extraction_engine: BaseExtractionEngine = field(kw_only=True, default=None, alias="extraction_engine") + extraction_engine: BaseExtractionEngine = field(kw_only=True) args: dict = field(kw_only=True) - @property - def extraction_engine(self) -> BaseExtractionEngine: - return self._extraction_engine - def run(self) -> ListArtifact | ErrorArtifact: return self.extraction_engine.extract(self.input.to_text(), rulesets=self.all_rulesets, **self.args) diff --git a/griptape/tasks/image_query_task.py b/griptape/tasks/image_query_task.py index ea1b53739..1c77bbc0a 100644 --- a/griptape/tasks/image_query_task.py +++ b/griptape/tasks/image_query_task.py @@ -2,7 +2,7 @@ from typing import Callable -from attrs import define, field +from attrs import Factory, define, field from griptape.artifacts import ImageArtifact, ListArtifact, TextArtifact from griptape.engines import ImageQueryEngine @@ -24,7 +24,7 @@ class ImageQueryTask(BaseTask): image_query_engine: The engine used to execute the query. """ - _image_query_engine: ImageQueryEngine = field(default=None, kw_only=True, alias="image_query_engine") + image_query_engine: ImageQueryEngine = field(default=Factory(lambda: ImageQueryEngine()), kw_only=True) _input: ( tuple[str, list[ImageArtifact]] | tuple[TextArtifact, list[ImageArtifact]] @@ -62,19 +62,6 @@ def input( ) -> None: self._input = value - @property - def image_query_engine(self) -> ImageQueryEngine: - if self._image_query_engine is None: - if self.structure is not None: - self._image_query_engine = ImageQueryEngine(image_query_driver=self.structure.config.image_query_driver) - else: - raise ValueError("Image Query Engine is not set.") - return self._image_query_engine - - @image_query_engine.setter - def image_query_engine(self, value: ImageQueryEngine) -> None: - self._image_query_engine = value - def run(self) -> TextArtifact: query = self.input.value[0] diff --git a/griptape/tasks/inpainting_image_generation_task.py b/griptape/tasks/inpainting_image_generation_task.py index 2096c60e4..ec8014672 100644 --- a/griptape/tasks/inpainting_image_generation_task.py +++ b/griptape/tasks/inpainting_image_generation_task.py @@ -2,7 +2,7 @@ from typing import Callable -from attrs import define, field +from attrs import Factory, define, field from griptape.artifacts import ImageArtifact, ListArtifact, TextArtifact from griptape.engines import InpaintingImageGenerationEngine @@ -28,10 +28,9 @@ class InpaintingImageGenerationTask(BaseImageGenerationTask): output_file: If provided, the generated image will be written to disk as output_file. """ - _image_generation_engine: InpaintingImageGenerationEngine = field( - default=None, + image_generation_engine: InpaintingImageGenerationEngine = field( + default=Factory(lambda: InpaintingImageGenerationEngine()), kw_only=True, - alias="image_generation_engine", ) _input: ( tuple[str | TextArtifact, ImageArtifact, ImageArtifact] | Callable[[BaseTask], ListArtifact] | ListArtifact @@ -60,21 +59,6 @@ def input( ) -> None: self._input = value - @property - def image_generation_engine(self) -> InpaintingImageGenerationEngine: - if self._image_generation_engine is None: - if self.structure is not None: - self._image_generation_engine = InpaintingImageGenerationEngine( - image_generation_driver=self.structure.config.image_generation_driver, - ) - else: - raise ValueError("Image Generation Engine is not set.") - return self._image_generation_engine - - @image_generation_engine.setter - def image_generation_engine(self, value: InpaintingImageGenerationEngine) -> None: - self._image_generation_engine = value - def run(self) -> ImageArtifact: prompt_artifact = self.input[0] diff --git a/griptape/tasks/json_extraction_task.py b/griptape/tasks/json_extraction_task.py index ce51b316f..94db187da 100644 --- a/griptape/tasks/json_extraction_task.py +++ b/griptape/tasks/json_extraction_task.py @@ -1,6 +1,6 @@ from __future__ import annotations -from attrs import define, field +from attrs import Factory, define, field from griptape.engines import JsonExtractionEngine from griptape.tasks import ExtractionTask @@ -8,17 +8,4 @@ @define class JsonExtractionTask(ExtractionTask): - _extraction_engine: JsonExtractionEngine = field(default=None, kw_only=True, alias="extraction_engine") - - @property - def extraction_engine(self) -> JsonExtractionEngine: - if self._extraction_engine is None: - if self.structure is not None: - self._extraction_engine = JsonExtractionEngine(prompt_driver=self.structure.config.prompt_driver) - else: - raise ValueError("Extraction Engine is not set.") - return self._extraction_engine - - @extraction_engine.setter - def extraction_engine(self, value: JsonExtractionEngine) -> None: - self._extraction_engine = value + extraction_engine: JsonExtractionEngine = field(default=Factory(lambda: JsonExtractionEngine()), kw_only=True) diff --git a/griptape/tasks/outpainting_image_generation_task.py b/griptape/tasks/outpainting_image_generation_task.py index a23fafd0f..bee3293a1 100644 --- a/griptape/tasks/outpainting_image_generation_task.py +++ b/griptape/tasks/outpainting_image_generation_task.py @@ -2,7 +2,7 @@ from typing import Callable -from attrs import define, field +from attrs import Factory, define, field from griptape.artifacts import ImageArtifact, ListArtifact, TextArtifact from griptape.engines import OutpaintingImageGenerationEngine @@ -28,10 +28,9 @@ class OutpaintingImageGenerationTask(BaseImageGenerationTask): output_file: If provided, the generated image will be written to disk as output_file. """ - _image_generation_engine: OutpaintingImageGenerationEngine = field( - default=None, + image_generation_engine: OutpaintingImageGenerationEngine = field( + default=Factory(lambda: OutpaintingImageGenerationEngine()), kw_only=True, - alias="image_generation_engine", ) _input: ( tuple[str | TextArtifact, ImageArtifact, ImageArtifact] | Callable[[BaseTask], ListArtifact] | ListArtifact @@ -60,22 +59,6 @@ def input( ) -> None: self._input = value - @property - def image_generation_engine(self) -> OutpaintingImageGenerationEngine: - if self._image_generation_engine is None: - if self.structure is not None: - self._image_generation_engine = OutpaintingImageGenerationEngine( - image_generation_driver=self.structure.config.image_generation_driver, - ) - else: - raise ValueError("Image Generation Engine is not set.") - - return self._image_generation_engine - - @image_generation_engine.setter - def image_generation_engine(self, value: OutpaintingImageGenerationEngine) -> None: - self._image_generation_engine = value - def run(self) -> ImageArtifact: prompt_artifact = self.input[0] diff --git a/griptape/tasks/prompt_image_generation_task.py b/griptape/tasks/prompt_image_generation_task.py index 66cffab3e..efc3faf2d 100644 --- a/griptape/tasks/prompt_image_generation_task.py +++ b/griptape/tasks/prompt_image_generation_task.py @@ -2,7 +2,7 @@ from typing import Callable -from attrs import define, field +from attrs import Factory, define, field from griptape.artifacts import ImageArtifact, TextArtifact from griptape.engines import PromptImageGenerationEngine @@ -30,10 +30,9 @@ class PromptImageGenerationTask(BaseImageGenerationTask): DEFAULT_INPUT_TEMPLATE = "{{ args[0] }}" _input: str | TextArtifact | Callable[[BaseTask], TextArtifact] = field(default=DEFAULT_INPUT_TEMPLATE) - _image_generation_engine: PromptImageGenerationEngine = field( - default=None, + image_generation_engine: PromptImageGenerationEngine = field( + default=Factory(lambda: PromptImageGenerationEngine()), kw_only=True, - alias="image_generation_engine", ) @property @@ -49,21 +48,6 @@ def input(self) -> TextArtifact: def input(self, value: TextArtifact) -> None: self._input = value - @property - def image_generation_engine(self) -> PromptImageGenerationEngine: - if self._image_generation_engine is None: - if self.structure is not None: - self._image_generation_engine = PromptImageGenerationEngine( - image_generation_driver=self.structure.config.image_generation_driver, - ) - else: - raise ValueError("Image Generation Engine is not set.") - return self._image_generation_engine - - @image_generation_engine.setter - def image_generation_engine(self, value: PromptImageGenerationEngine) -> None: - self._image_generation_engine = value - def run(self) -> ImageArtifact: image_artifact = self.image_generation_engine.run( prompts=[self.input.to_text()], diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 386ebe239..9f698787f 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -6,6 +6,7 @@ from griptape.artifacts import BaseArtifact, ListArtifact, TextArtifact from griptape.common import PromptStack +from griptape.config import Config from griptape.mixins import RuleMixin from griptape.tasks import BaseTask from griptape.utils import J2 @@ -16,7 +17,7 @@ @define class PromptTask(RuleMixin, BaseTask): - _prompt_driver: Optional[BasePromptDriver] = field(default=None, kw_only=True, alias="prompt_driver") + prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.prompt_driver), kw_only=True) generate_system_template: Callable[[PromptTask], str] = field( default=Factory(lambda self: self.default_system_template_generator, takes_self=True), kw_only=True, @@ -56,15 +57,6 @@ def prompt_stack(self) -> PromptStack: return stack - @property - def prompt_driver(self) -> BasePromptDriver: - if self._prompt_driver is None: - if self.structure is not None: - self._prompt_driver = self.structure.config.prompt_driver - else: - raise ValueError("Prompt Driver is not set") - return self._prompt_driver - def default_system_template_generator(self, _: PromptTask) -> str: return J2("tasks/prompt_task/system.j2").render( rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.all_rulesets), diff --git a/griptape/tasks/rag_task.py b/griptape/tasks/rag_task.py index 3f88f34d1..97b295209 100644 --- a/griptape/tasks/rag_task.py +++ b/griptape/tasks/rag_task.py @@ -1,32 +1,15 @@ from __future__ import annotations -from typing import TYPE_CHECKING - -from attrs import define, field +from attrs import Factory, define, field from griptape.artifacts import BaseArtifact, ErrorArtifact +from griptape.engines.rag import RagEngine from griptape.tasks import BaseTextInputTask -if TYPE_CHECKING: - from griptape.engines.rag import RagEngine - @define class RagTask(BaseTextInputTask): - _rag_engine: RagEngine = field(kw_only=True, default=None, alias="rag_engine") - - @property - def rag_engine(self) -> RagEngine: - if self._rag_engine is None: - if self.structure is not None: - self._rag_engine = self.structure.rag_engine - else: - raise ValueError("rag_engine is not set.") - return self._rag_engine - - @rag_engine.setter - def rag_engine(self, value: RagEngine) -> None: - self._rag_engine = value + rag_engine: RagEngine = field(kw_only=True, default=Factory(lambda: RagEngine())) def run(self) -> BaseArtifact: result = self.rag_engine.process_query(self.input.to_text()).output diff --git a/griptape/tasks/text_summary_task.py b/griptape/tasks/text_summary_task.py index 5bd1b547e..dc1a7b8be 100644 --- a/griptape/tasks/text_summary_task.py +++ b/griptape/tasks/text_summary_task.py @@ -1,8 +1,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING -from attrs import define, field +from attrs import Factory, define, field from griptape.artifacts import TextArtifact from griptape.engines import PromptSummaryEngine @@ -14,20 +14,7 @@ @define class TextSummaryTask(BaseTextInputTask): - _summary_engine: Optional[BaseSummaryEngine] = field(default=None, alias="summary_engine") - - @property - def summary_engine(self) -> Optional[BaseSummaryEngine]: - if self._summary_engine is None: - if self.structure is not None: - self._summary_engine = PromptSummaryEngine(prompt_driver=self.structure.config.prompt_driver) - else: - raise ValueError("Summary Engine is not set.") - return self._summary_engine - - @summary_engine.setter - def summary_engine(self, value: BaseSummaryEngine) -> None: - self._summary_engine = value + summary_engine: BaseSummaryEngine = field(default=Factory(lambda: PromptSummaryEngine()), kw_only=True) def run(self) -> TextArtifact: return TextArtifact(self.summary_engine.summarize_text(self.input.to_text(), rulesets=self.all_rulesets)) diff --git a/griptape/tasks/text_to_speech_task.py b/griptape/tasks/text_to_speech_task.py index 3ca503dfe..680a67603 100644 --- a/griptape/tasks/text_to_speech_task.py +++ b/griptape/tasks/text_to_speech_task.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Callable -from attrs import define, field +from attrs import Factory, define, field from griptape.artifacts import TextArtifact from griptape.engines import TextToSpeechEngine @@ -19,7 +19,7 @@ class TextToSpeechTask(BaseAudioGenerationTask): DEFAULT_INPUT_TEMPLATE = "{{ args[0] }}" _input: str | TextArtifact | Callable[[BaseTask], TextArtifact] = field(default=DEFAULT_INPUT_TEMPLATE) - _text_to_speech_engine: TextToSpeechEngine = field(default=None, kw_only=True, alias="text_to_speech_engine") + text_to_speech_engine: TextToSpeechEngine = field(default=Factory(lambda: TextToSpeechEngine()), kw_only=True) @property def input(self) -> TextArtifact: @@ -34,21 +34,6 @@ def input(self) -> TextArtifact: def input(self, value: TextArtifact) -> None: self._input = value - @property - def text_to_speech_engine(self) -> TextToSpeechEngine: - if self._text_to_speech_engine is None: - if self.structure is not None: - self._text_to_speech_engine = TextToSpeechEngine( - text_to_speech_driver=self.structure.config.text_to_speech_driver, - ) - else: - raise ValueError("Audio Generation Engine is not set.") - return self._text_to_speech_engine - - @text_to_speech_engine.setter - def text_to_speech_engine(self, value: TextToSpeechEngine) -> None: - self._text_to_speech_engine = value - def run(self) -> AudioArtifact: audio_artifact = self.text_to_speech_engine.run(prompts=[self.input.to_text()], rulesets=self.all_rulesets) diff --git a/griptape/tasks/variation_image_generation_task.py b/griptape/tasks/variation_image_generation_task.py index df4579efa..6295b1af6 100644 --- a/griptape/tasks/variation_image_generation_task.py +++ b/griptape/tasks/variation_image_generation_task.py @@ -2,7 +2,7 @@ from typing import Callable -from attrs import define, field +from attrs import Factory, define, field from griptape.artifacts import ImageArtifact, ListArtifact, TextArtifact from griptape.engines import VariationImageGenerationEngine @@ -28,10 +28,9 @@ class VariationImageGenerationTask(BaseImageGenerationTask): output_file: If provided, the generated image will be written to disk as output_file. """ - _image_generation_engine: VariationImageGenerationEngine = field( - default=None, + image_generation_engine: VariationImageGenerationEngine = field( + default=Factory(lambda: VariationImageGenerationEngine()), kw_only=True, - alias="image_generation_engine", ) _input: tuple[str | TextArtifact, ImageArtifact] | Callable[[BaseTask], ListArtifact] | ListArtifact = field( default=None, @@ -57,21 +56,6 @@ def input(self) -> ListArtifact: def input(self, value: tuple[str | TextArtifact, ImageArtifact] | Callable[[BaseTask], ListArtifact]) -> None: self._input = value - @property - def image_generation_engine(self) -> VariationImageGenerationEngine: - if self._image_generation_engine is None: - if self.structure is not None: - self._image_generation_engine = VariationImageGenerationEngine( - image_generation_driver=self.structure.config.image_generation_driver, - ) - else: - raise ValueError("Image Generation Engine is not set.") - return self._image_generation_engine - - @image_generation_engine.setter - def image_generation_engine(self, value: VariationImageGenerationEngine) -> None: - self._image_generation_engine = value - def run(self) -> ImageArtifact: prompt_artifact = self.input[0] diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 0be2f9758..7d2f8203d 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,6 +1,8 @@ import pytest +from griptape.config import Config from griptape.events import EventBus +from tests.mocks.mock_structure_config import MockStructureConfig @pytest.fixture(autouse=True) @@ -10,3 +12,17 @@ def event_bus(): yield EventBus EventBus.event_listeners = [] + + +@pytest.fixture(autouse=True) +def mock_config(): + mock_structure_config = MockStructureConfig() + Config.prompt_driver = mock_structure_config.prompt_driver + Config.image_generation_driver = mock_structure_config.image_generation_driver + Config.image_query_driver = mock_structure_config.image_query_driver + Config.embedding_driver = mock_structure_config.embedding_driver + Config.vector_store_driver = mock_structure_config.vector_store_driver + Config.text_to_speech_driver = mock_structure_config.text_to_speech_driver + Config.audio_transcription_driver = mock_structure_config.audio_transcription_driver + + return Config diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index 5601aef34..92c5a3653 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -25,8 +25,9 @@ class TestEventListener: @pytest.fixture() - def pipeline(self): + def pipeline(self, mock_config): task = ToolkitTask("test", tools=[MockTool(name="Tool1")]) + mock_config.prompt_driver = MockPromptDriver(stream=True) pipeline = Pipeline(prompt_driver=MockPromptDriver(stream=True)) pipeline.add_task(task) @@ -34,7 +35,7 @@ def pipeline(self): task.add_subtask(ActionsSubtask("foo")) return pipeline - def test_untyped_listeners(self, pipeline): + def test_untyped_listeners(self, pipeline, mock_config): event_handler_1 = Mock() event_handler_2 = Mock() diff --git a/tests/unit/memory/structure/test_summary_conversation_memory.py b/tests/unit/memory/structure/test_summary_conversation_memory.py index 4396c7b23..579be214e 100644 --- a/tests/unit/memory/structure/test_summary_conversation_memory.py +++ b/tests/unit/memory/structure/test_summary_conversation_memory.py @@ -5,7 +5,6 @@ from griptape.structures import Pipeline from griptape.tasks import PromptTask from tests.mocks.mock_prompt_driver import MockPromptDriver -from tests.mocks.mock_structure_config import MockStructureConfig class TestSummaryConversationMemory: @@ -85,7 +84,7 @@ def test_from_json(self): def test_config_prompt_driver(self): memory = SummaryConversationMemory() - pipeline = Pipeline(conversation_memory=memory, config=MockStructureConfig()) + pipeline = Pipeline(conversation_memory=memory) pipeline.add_tasks(PromptTask("test")) diff --git a/tests/unit/tasks/test_audio_transcription_task.py b/tests/unit/tasks/test_audio_transcription_task.py index 734e111cf..4cc860c32 100644 --- a/tests/unit/tasks/test_audio_transcription_task.py +++ b/tests/unit/tasks/test_audio_transcription_task.py @@ -7,7 +7,6 @@ from griptape.structures import Agent, Pipeline from griptape.tasks import AudioTranscriptionTask, BaseTask from tests.mocks.mock_prompt_driver import MockPromptDriver -from tests.mocks.mock_structure_config import MockStructureConfig class TestAudioTranscriptionTask: @@ -34,7 +33,7 @@ def callable_input(task: BaseTask) -> AudioArtifact: def test_config_audio_transcription_engine(self, audio_artifact): task = AudioTranscriptionTask(audio_artifact) - Agent(config=MockStructureConfig()).add_task(task) + Agent().add_task(task) assert isinstance(task.audio_transcription_engine, AudioTranscriptionEngine) diff --git a/tests/unit/tasks/test_csv_extraction_task.py b/tests/unit/tasks/test_csv_extraction_task.py index 7d37c3897..ec8f70b23 100644 --- a/tests/unit/tasks/test_csv_extraction_task.py +++ b/tests/unit/tasks/test_csv_extraction_task.py @@ -4,7 +4,6 @@ from griptape.structures import Agent from griptape.tasks import CsvExtractionTask from tests.mocks.mock_prompt_driver import MockPromptDriver -from tests.mocks.mock_structure_config import MockStructureConfig class TestCsvExtractionTask: @@ -13,7 +12,7 @@ def task(self): return CsvExtractionTask(args={"column_names": ["test1"]}) def test_run(self, task): - agent = Agent(config=MockStructureConfig()) + agent = Agent() agent.add_task(task) @@ -23,11 +22,7 @@ def test_run(self, task): assert result.value[0].value == {"test1": "mock output"} def test_config_extraction_engine(self, task): - Agent(config=MockStructureConfig()).add_task(task) + Agent().add_task(task) assert isinstance(task.extraction_engine, CsvExtractionEngine) assert isinstance(task.extraction_engine.prompt_driver, MockPromptDriver) - - def test_missing_extraction_engine(self, task): - with pytest.raises(ValueError): - task.extraction_engine # noqa: B018 diff --git a/tests/unit/tasks/test_image_query_task.py b/tests/unit/tasks/test_image_query_task.py index 447faa01c..01c116772 100644 --- a/tests/unit/tasks/test_image_query_task.py +++ b/tests/unit/tasks/test_image_query_task.py @@ -8,7 +8,6 @@ from griptape.structures import Agent from griptape.tasks import BaseTask, ImageQueryTask from tests.mocks.mock_image_query_driver import MockImageQueryDriver -from tests.mocks.mock_structure_config import MockStructureConfig class TestImageQueryTask: @@ -61,17 +60,11 @@ def test_list_input(self, text_artifact: TextArtifact, image_artifact: ImageArti def test_config_image_generation_engine(self, text_artifact, image_artifact): task = ImageQueryTask((text_artifact, [image_artifact, image_artifact])) - Agent(config=MockStructureConfig()).add_task(task) + Agent().add_task(task) assert isinstance(task.image_query_engine, ImageQueryEngine) assert isinstance(task.image_query_engine.image_query_driver, MockImageQueryDriver) - def test_missing_image_generation_engine(self, text_artifact, image_artifact): - task = ImageQueryTask((text_artifact, [image_artifact, image_artifact])) - - with pytest.raises(ValueError, match="Image Query Engine"): - task.image_query_engine # noqa: B018 - def test_run(self, image_query_engine, text_artifact, image_artifact): task = ImageQueryTask((text_artifact, [image_artifact, image_artifact]), image_query_engine=image_query_engine) task.run() diff --git a/tests/unit/tasks/test_inpainting_image_generation_task.py b/tests/unit/tasks/test_inpainting_image_generation_task.py index 61c437bb7..5c4507d49 100644 --- a/tests/unit/tasks/test_inpainting_image_generation_task.py +++ b/tests/unit/tasks/test_inpainting_image_generation_task.py @@ -8,7 +8,6 @@ from griptape.structures import Agent from griptape.tasks import BaseTask, InpaintingImageGenerationTask from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver -from tests.mocks.mock_structure_config import MockStructureConfig class TestInpaintingImageGenerationTask: @@ -51,13 +50,7 @@ def test_bad_input(self, image_artifact): def test_config_image_generation_engine(self, text_artifact, image_artifact): task = InpaintingImageGenerationTask((text_artifact, image_artifact, image_artifact)) - Agent(config=MockStructureConfig()).add_task(task) + Agent().add_task(task) assert isinstance(task.image_generation_engine, InpaintingImageGenerationEngine) assert isinstance(task.image_generation_engine.image_generation_driver, MockImageGenerationDriver) - - def test_missing_image_generation_engine(self, text_artifact, image_artifact): - task = InpaintingImageGenerationTask((text_artifact, image_artifact, image_artifact)) - - with pytest.raises(ValueError): - task.image_generation_engine # noqa: B018 diff --git a/tests/unit/tasks/test_json_extraction_task.py b/tests/unit/tasks/test_json_extraction_task.py index ba7d1ce30..0189e6679 100644 --- a/tests/unit/tasks/test_json_extraction_task.py +++ b/tests/unit/tasks/test_json_extraction_task.py @@ -5,7 +5,6 @@ from griptape.structures import Agent from griptape.tasks import JsonExtractionTask from tests.mocks.mock_prompt_driver import MockPromptDriver -from tests.mocks.mock_structure_config import MockStructureConfig class TestJsonExtractionTask: @@ -13,11 +12,9 @@ class TestJsonExtractionTask: def task(self): return JsonExtractionTask("foo", args={"template_schema": Schema({"foo": "bar"}).json_schema("TemplateSchema")}) - def test_run(self, task): - mock_config = MockStructureConfig() - assert isinstance(mock_config.prompt_driver, MockPromptDriver) + def test_run(self, task, mock_config): mock_config.prompt_driver.mock_output = '[{"test_key_1": "test_value_1"}, {"test_key_2": "test_value_2"}]' - agent = Agent(config=mock_config) + agent = Agent() agent.add_task(task) @@ -28,11 +25,7 @@ def test_run(self, task): assert result.value[1].value == '{"test_key_2": "test_value_2"}' def test_config_extraction_engine(self, task): - Agent(config=MockStructureConfig()).add_task(task) + Agent().add_task(task) assert isinstance(task.extraction_engine, JsonExtractionEngine) assert isinstance(task.extraction_engine.prompt_driver, MockPromptDriver) - - def test_missing_extraction_engine(self, task): - with pytest.raises(ValueError): - task.extraction_engine # noqa: B018 diff --git a/tests/unit/tasks/test_outpainting_image_generation_task.py b/tests/unit/tasks/test_outpainting_image_generation_task.py index 593451120..ba5e52a82 100644 --- a/tests/unit/tasks/test_outpainting_image_generation_task.py +++ b/tests/unit/tasks/test_outpainting_image_generation_task.py @@ -8,7 +8,6 @@ from griptape.structures import Agent from griptape.tasks import BaseTask, OutpaintingImageGenerationTask from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver -from tests.mocks.mock_structure_config import MockStructureConfig class TestOutpaintingImageGenerationTask: @@ -51,13 +50,7 @@ def test_bad_input(self, image_artifact): def test_config_image_generation_engine(self, text_artifact, image_artifact): task = OutpaintingImageGenerationTask((text_artifact, image_artifact, image_artifact)) - Agent(config=MockStructureConfig()).add_task(task) + Agent().add_task(task) assert isinstance(task.image_generation_engine, OutpaintingImageGenerationEngine) assert isinstance(task.image_generation_engine.image_generation_driver, MockImageGenerationDriver) - - def test_missing_image_generation_engine(self, text_artifact, image_artifact): - task = OutpaintingImageGenerationTask((text_artifact, image_artifact, image_artifact)) - - with pytest.raises(ValueError): - task.image_generation_engine # noqa: B018 diff --git a/tests/unit/tasks/test_prompt_image_generation_task.py b/tests/unit/tasks/test_prompt_image_generation_task.py index 1c4b639fb..3ad0302f2 100644 --- a/tests/unit/tasks/test_prompt_image_generation_task.py +++ b/tests/unit/tasks/test_prompt_image_generation_task.py @@ -1,13 +1,10 @@ from unittest.mock import Mock -import pytest - from griptape.artifacts import TextArtifact from griptape.engines import PromptImageGenerationEngine from griptape.structures import Agent from griptape.tasks import BaseTask, PromptImageGenerationTask from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver -from tests.mocks.mock_structure_config import MockStructureConfig class TestPromptImageGenerationTask: @@ -28,13 +25,7 @@ def callable_input(task: BaseTask) -> TextArtifact: def test_config_image_generation_engine_engine(self): task = PromptImageGenerationTask("foo bar") - Agent(config=MockStructureConfig()).add_task(task) + Agent().add_task(task) assert isinstance(task.image_generation_engine, PromptImageGenerationEngine) assert isinstance(task.image_generation_engine.image_generation_driver, MockImageGenerationDriver) - - def test_missing_summary_engine(self): - task = PromptImageGenerationTask("foo bar") - - with pytest.raises(ValueError): - task.image_generation_engine # noqa: B018 diff --git a/tests/unit/tasks/test_prompt_task.py b/tests/unit/tasks/test_prompt_task.py index 083ea6da5..4a618e0d1 100644 --- a/tests/unit/tasks/test_prompt_task.py +++ b/tests/unit/tasks/test_prompt_task.py @@ -1,5 +1,3 @@ -import pytest - from griptape.artifacts.image_artifact import ImageArtifact from griptape.artifacts.list_artifact import ListArtifact from griptape.artifacts.text_artifact import TextArtifact @@ -9,7 +7,6 @@ from griptape.structures import Pipeline from griptape.tasks import PromptTask from tests.mocks.mock_prompt_driver import MockPromptDriver -from tests.mocks.mock_structure_config import MockStructureConfig class TestPromptTask: @@ -30,16 +27,10 @@ def test_to_text(self): def test_config_prompt_driver(self): task = PromptTask("test") - Pipeline(config=MockStructureConfig()).add_task(task) + Pipeline().add_task(task) assert isinstance(task.prompt_driver, MockPromptDriver) - def test_missing_prompt_driver(self): - task = PromptTask("test") - - with pytest.raises(ValueError): - task.prompt_driver # noqa: B018 - def test_input(self): # Str task = PromptTask("test") diff --git a/tests/unit/tasks/test_text_summary_task.py b/tests/unit/tasks/test_text_summary_task.py index bb08f9d31..438d2bae4 100644 --- a/tests/unit/tasks/test_text_summary_task.py +++ b/tests/unit/tasks/test_text_summary_task.py @@ -1,10 +1,7 @@ -import pytest - from griptape.engines import PromptSummaryEngine from griptape.structures import Agent from griptape.tasks import TextSummaryTask from tests.mocks.mock_prompt_driver import MockPromptDriver -from tests.mocks.mock_structure_config import MockStructureConfig class TestTextSummaryTask: @@ -26,13 +23,7 @@ def test_context_propagation(self): def test_config_summary_engine(self): task = TextSummaryTask("test") - Agent(config=MockStructureConfig()).add_task(task) + Agent().add_task(task) assert isinstance(task.summary_engine, PromptSummaryEngine) assert isinstance(task.summary_engine.prompt_driver, MockPromptDriver) - - def test_missing_summary_engine(self): - task = TextSummaryTask("test") - - with pytest.raises(ValueError): - task.summary_engine # noqa: B018 diff --git a/tests/unit/tasks/test_text_to_speech_task.py b/tests/unit/tasks/test_text_to_speech_task.py index bf1f19d5a..3c629c69d 100644 --- a/tests/unit/tasks/test_text_to_speech_task.py +++ b/tests/unit/tasks/test_text_to_speech_task.py @@ -5,7 +5,6 @@ from griptape.structures import Agent, Pipeline from griptape.tasks import BaseTask, TextToSpeechTask from tests.mocks.mock_prompt_driver import MockPromptDriver -from tests.mocks.mock_structure_config import MockStructureConfig class TestTextToSpeechTask: @@ -26,7 +25,7 @@ def callable_input(task: BaseTask) -> TextArtifact: def test_config_text_to_speech_engine(self): task = TextToSpeechTask("foo bar") - Agent(config=MockStructureConfig()).add_task(task) + Agent().add_task(task) assert isinstance(task.text_to_speech_engine, TextToSpeechEngine) diff --git a/tests/unit/tasks/test_toolkit_task.py b/tests/unit/tasks/test_toolkit_task.py index cd5dd21f8..a47e4687b 100644 --- a/tests/unit/tasks/test_toolkit_task.py +++ b/tests/unit/tasks/test_toolkit_task.py @@ -170,8 +170,9 @@ def test_init(self): except ValueError: assert True - def test_run(self): + def test_run(self, mock_config): output = """Answer: done""" + mock_config.prompt_driver.mock_output = output task = ToolkitTask("test", tools=[MockTool(name="Tool1"), MockTool(name="Tool2")]) agent = Agent(prompt_driver=MockPromptDriver(mock_output=output)) @@ -184,8 +185,9 @@ def test_run(self): assert len(task.subtasks) == 1 assert result.output_task.output.to_text() == "done" - def test_run_max_subtasks(self): + def test_run_max_subtasks(self, mock_config): output = 'Actions: [{"tag": "foo", "name": "Tool1", "path": "test", "input": {"values": {"test": "value"}}}]' + mock_config.prompt_driver.mock_output = output task = ToolkitTask("test", tools=[MockTool(name="Tool1")], max_subtasks=3) agent = Agent(prompt_driver=MockPromptDriver(mock_output=output)) @@ -197,8 +199,9 @@ def test_run_max_subtasks(self): assert len(task.subtasks) == 3 assert isinstance(task.output, ErrorArtifact) - def test_run_invalid_react_prompt(self): + def test_run_invalid_react_prompt(self, mock_config): output = """foo bar""" + mock_config.prompt_driver.mock_output = output task = ToolkitTask("test", tools=[MockTool(name="Tool1")], max_subtasks=3) agent = Agent(prompt_driver=MockPromptDriver(mock_output=output)) diff --git a/tests/unit/tasks/test_variation_image_generation_task.py b/tests/unit/tasks/test_variation_image_generation_task.py index a910fb8e0..f6afbf03e 100644 --- a/tests/unit/tasks/test_variation_image_generation_task.py +++ b/tests/unit/tasks/test_variation_image_generation_task.py @@ -8,7 +8,6 @@ from griptape.structures import Agent from griptape.tasks import BaseTask, VariationImageGenerationTask from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver -from tests.mocks.mock_structure_config import MockStructureConfig class TestVariationImageGenerationTask: @@ -48,13 +47,7 @@ def test_bad_input(self, image_artifact): def test_config_image_generation_engine(self, text_artifact, image_artifact): task = VariationImageGenerationTask((text_artifact, image_artifact)) - Agent(config=MockStructureConfig()).add_task(task) + Agent().add_task(task) assert isinstance(task.image_generation_engine, VariationImageGenerationEngine) assert isinstance(task.image_generation_engine.image_generation_driver, MockImageGenerationDriver) - - def test_missing_summary_engine(self, text_artifact, image_artifact): - task = VariationImageGenerationTask((text_artifact, image_artifact)) - - with pytest.raises(ValueError): - task.image_generation_engine # noqa: B018 From 44f9ebd3b3f9195bf4d850b2fa551b289708dc6d Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 6 Aug 2024 11:20:36 -0700 Subject: [PATCH 03/63] WIP Event listners --- griptape/config/base_structure_config.py | 2 +- griptape/config/structure_config.py | 22 ++++---- griptape/drivers/prompt/base_prompt_driver.py | 2 + .../structure/base_conversation_memory.py | 2 +- griptape/structures/agent.py | 1 + griptape/structures/structure.py | 50 +------------------ griptape/utils/stream.py | 6 ++- .../test_azure_openai_structure_config.py | 15 +----- tests/unit/config/test_structure_config.py | 25 +--------- tests/unit/utils/test_stream.py | 7 +-- 10 files changed, 30 insertions(+), 102 deletions(-) diff --git a/griptape/config/base_structure_config.py b/griptape/config/base_structure_config.py index c2aa82d7e..84743c4da 100644 --- a/griptape/config/base_structure_config.py +++ b/griptape/config/base_structure_config.py @@ -22,7 +22,7 @@ @define -class BaseStructureConfig(BaseConfig, ABC): +class BaseStructureConfig(BaseConfig, ABC, EventPublisherMixin): prompt_driver: BasePromptDriver = field(kw_only=True, metadata={"serializable": True}) image_generation_driver: BaseImageGenerationDriver = field(kw_only=True, metadata={"serializable": True}) image_query_driver: BaseImageQueryDriver = field(kw_only=True, metadata={"serializable": True}) diff --git a/griptape/config/structure_config.py b/griptape/config/structure_config.py index ef95012ce..d68b6e2e2 100644 --- a/griptape/config/structure_config.py +++ b/griptape/config/structure_config.py @@ -1,19 +1,11 @@ from __future__ import annotations -from typing import Optional +from typing import TYPE_CHECKING, Optional from attrs import Factory, define, field from griptape.config import BaseStructureConfig from griptape.drivers import ( - BaseAudioTranscriptionDriver, - BaseConversationMemoryDriver, - BaseEmbeddingDriver, - BaseImageGenerationDriver, - BaseImageQueryDriver, - BasePromptDriver, - BaseTextToSpeechDriver, - BaseVectorStoreDriver, DummyAudioTranscriptionDriver, DummyEmbeddingDriver, DummyImageGenerationDriver, @@ -23,6 +15,18 @@ DummyVectorStoreDriver, ) +if TYPE_CHECKING: + from griptape.drivers import ( + BaseAudioTranscriptionDriver, + BaseConversationMemoryDriver, + BaseEmbeddingDriver, + BaseImageGenerationDriver, + BaseImageQueryDriver, + BasePromptDriver, + BaseTextToSpeechDriver, + BaseVectorStoreDriver, + ) + @define class StructureConfig(BaseStructureConfig): diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index 94e46e75d..b6c28560b 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -113,6 +113,8 @@ def __process_run(self, prompt_stack: PromptStack) -> Message: return result def __process_stream(self, prompt_stack: PromptStack) -> Message: + from griptape.config import Config + delta_contents: dict[int, list[BaseDeltaMessageContent]] = {} usage = DeltaMessage.Usage() diff --git a/griptape/memory/structure/base_conversation_memory.py b/griptape/memory/structure/base_conversation_memory.py index 8794288c8..fb1cfdd8b 100644 --- a/griptape/memory/structure/base_conversation_memory.py +++ b/griptape/memory/structure/base_conversation_memory.py @@ -67,7 +67,7 @@ def add_to_prompt_stack(self, prompt_stack: PromptStack, index: Optional[int] = if self.autoprune and hasattr(self, "structure"): should_prune = True - prompt_driver = self.structure.config.prompt_driver + prompt_driver = Config.prompt_driver temp_stack = PromptStack() # Try to determine how many Conversation Memory runs we can diff --git a/griptape/structures/agent.py b/griptape/structures/agent.py index b133a7b6b..31e0a424f 100644 --- a/griptape/structures/agent.py +++ b/griptape/structures/agent.py @@ -32,6 +32,7 @@ def validate_fail_fast(self, _: Attribute, fail_fast: bool) -> None: # noqa: FB def __attrs_post_init__(self) -> None: super().__attrs_post_init__() + if len(self.tasks) == 0: if self.tools: task = ToolkitTask(self.input, tools=self.tools, max_meta_memory_entries=self.max_meta_memory_entries) diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index 9f1fa9a2b..73b5e617a 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -11,14 +11,7 @@ from griptape.artifacts import BaseArtifact, BlobArtifact, TextArtifact from griptape.common import observable -from griptape.config import BaseStructureConfig, Config -from griptape.drivers import ( - BaseEmbeddingDriver, - BasePromptDriver, - LocalVectorStoreDriver, - OpenAiChatPromptDriver, - OpenAiEmbeddingDriver, -) +from griptape.config import Config from griptape.engines import CsvExtractionEngine, JsonExtractionEngine, PromptSummaryEngine from griptape.engines.rag import RagEngine from griptape.engines.rag.modules import ( @@ -33,7 +26,6 @@ from griptape.memory.meta import MetaMemory from griptape.memory.structure import ConversationMemory from griptape.memory.task.storage import BlobArtifactStorage, TextArtifactStorage -from griptape.utils import deprecation_warn if TYPE_CHECKING: from griptape.memory.structure import BaseConversationMemory @@ -46,13 +38,6 @@ class Structure(ABC): LOGGER_NAME = "griptape" id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True) - stream: Optional[bool] = field(default=None, kw_only=True) - prompt_driver: Optional[BasePromptDriver] = field(default=None) - embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) - config: BaseStructureConfig = field( - default=Factory(lambda self: self.default_config, takes_self=True), - kw_only=True, - ) rulesets: list[Ruleset] = field(factory=list, kw_only=True) rules: list[Rule] = field(factory=list, kw_only=True) tasks: list[BaseTask] = field(factory=list, kw_only=True) @@ -99,21 +84,6 @@ def __attrs_post_init__(self) -> None: def __add__(self, other: BaseTask | list[BaseTask]) -> list[BaseTask]: return self.add_tasks(*other) if isinstance(other, list) else self + [other] - @prompt_driver.validator # pyright: ignore[reportAttributeAccessIssue] - def validate_prompt_driver(self, attribute: Attribute, value: BasePromptDriver) -> None: - if value is not None: - deprecation_warn(f"`{attribute.name}` is deprecated, use `config.prompt_driver` instead.") - - @embedding_driver.validator # pyright: ignore[reportAttributeAccessIssue] - def validate_embedding_driver(self, attribute: Attribute, value: BaseEmbeddingDriver) -> None: - if value is not None: - deprecation_warn(f"`{attribute.name}` is deprecated, use `config.embedding_driver` instead.") - - @stream.validator # pyright: ignore[reportAttributeAccessIssue] - def validate_stream(self, attribute: Attribute, value: bool) -> None: # noqa: FBT001 - if value is not None: - deprecation_warn(f"`{attribute.name}` is deprecated, use `config.prompt_driver.stream` instead.") - @property def execution_args(self) -> tuple: return self._execution_args @@ -148,24 +118,6 @@ def output(self) -> Optional[BaseArtifact]: def finished_tasks(self) -> list[BaseTask]: return [s for s in self.tasks if s.is_finished()] - @property - def default_config(self) -> BaseStructureConfig: - if self.prompt_driver is not None or self.embedding_driver is not None or self.stream is not None: - prompt_driver = OpenAiChatPromptDriver(model="gpt-4o") if self.prompt_driver is None else self.prompt_driver - - embedding_driver = OpenAiEmbeddingDriver() if self.embedding_driver is None else self.embedding_driver - - if self.stream is not None: - prompt_driver.stream = self.stream - - vector_store_driver = LocalVectorStoreDriver(embedding_driver=embedding_driver) - - Config.prompt_driver = prompt_driver - Config.vector_store_driver = vector_store_driver - Config.embedding_driver = embedding_driver - - return Config - @property def default_rag_engine(self) -> RagEngine: return RagEngine( diff --git a/griptape/utils/stream.py b/griptape/utils/stream.py index 4a7899b2a..7b5381202 100644 --- a/griptape/utils/stream.py +++ b/griptape/utils/stream.py @@ -34,7 +34,9 @@ class Stream: @structure.validator # pyright: ignore[reportAttributeAccessIssue] def validate_structure(self, _: Attribute, structure: Structure) -> None: - if not structure.config.prompt_driver.stream: + from griptape.config import Config + + if not Config.prompt_driver.stream: raise ValueError("prompt driver does not have streaming enabled, enable with stream=True") _event_queue: Queue[BaseEvent] = field(default=Factory(lambda: Queue())) @@ -54,6 +56,8 @@ def run(self, *args) -> Iterator[TextArtifact]: t.join() def _run_structure(self, *args) -> None: + from griptape.config import Config + def event_handler(event: BaseEvent) -> None: self._event_queue.put(event) diff --git a/tests/unit/config/test_azure_openai_structure_config.py b/tests/unit/config/test_azure_openai_structure_config.py index dcdc3a1dc..abeb6b878 100644 --- a/tests/unit/config/test_azure_openai_structure_config.py +++ b/tests/unit/config/test_azure_openai_structure_config.py @@ -8,7 +8,7 @@ class TestAzureOpenAiStructureConfig: def mock_openai(self, mocker): return mocker.patch("openai.AzureOpenAI") - @pytest.fixture() + @pytest.fixture def config(self): return AzureOpenAiStructureConfig( azure_endpoint="http://localhost:8080", @@ -85,16 +85,3 @@ def test_to_dict(self, config): "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, } - - def test_from_dict(self, config: AzureOpenAiStructureConfig): - assert AzureOpenAiStructureConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() - - # override values in the dict config - # serialize and deserialize the config - new_config = config.merge_config( - { - "prompt_driver": {"azure_deployment": "new-test-gpt-4"}, - "embedding_driver": {"model": "new-text-embedding-3-small"}, - } - ).to_dict() - assert AzureOpenAiStructureConfig.from_dict(new_config).to_dict() == new_config diff --git a/tests/unit/config/test_structure_config.py b/tests/unit/config/test_structure_config.py index 96a68628f..c7f02034f 100644 --- a/tests/unit/config/test_structure_config.py +++ b/tests/unit/config/test_structure_config.py @@ -4,7 +4,7 @@ class TestStructureConfig: - @pytest.fixture() + @pytest.fixture def config(self): return StructureConfig() @@ -33,29 +33,6 @@ def test_to_dict(self, config): def test_from_dict(self, config): assert StructureConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() - def test_unchanged_merge_config(self, config): - assert ( - config.merge_config( - { - "type": "StructureConfig", - "prompt_driver": { - "type": "DummyPromptDriver", - "temperature": 0.1, - "max_tokens": None, - "stream": False, - }, - } - ).to_dict() - == config.to_dict() - ) - - def test_changed_merge_config(self, config): - config = config.merge_config( - {"prompt_driver": {"type": "DummyPromptDriver", "temperature": 0.1, "max_tokens": None, "stream": False}} - ) - - assert config.prompt_driver.temperature == 0.1 - def test_dot_update(self, config): config.prompt_driver.max_tokens = 10 diff --git a/tests/unit/utils/test_stream.py b/tests/unit/utils/test_stream.py index da6695139..555daa4fd 100644 --- a/tests/unit/utils/test_stream.py +++ b/tests/unit/utils/test_stream.py @@ -2,18 +2,19 @@ import pytest +from griptape.config import Config from griptape.structures import Agent from griptape.utils import Stream -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestStream: @pytest.fixture(params=[True, False]) def agent(self, request): - return Agent(prompt_driver=MockPromptDriver(stream=request.param, max_attempts=0)) + Config.prompt_driver.stream = request.param + return Agent() def test_init(self, agent): - if agent.prompt_driver.stream: + if Config.prompt_driver.stream: chat_stream = Stream(agent) assert chat_stream.structure == agent From 6b76237978b756650e7c1ab5ce5c7cfc97968f85 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 6 Aug 2024 12:12:25 -0700 Subject: [PATCH 04/63] Fix tests --- griptape/utils/chat.py | 8 +- tests/mocks/mock_structure_config.py | 5 + .../test_azure_openai_structure_config.py | 2 +- tests/unit/config/test_structure_config.py | 2 +- .../test_base_audio_transcription_driver.py | 2 +- ...est_dynamodb_conversation_memory_driver.py | 13 +-- .../test_local_conversation_memory_driver.py | 10 +- ...est_open_telemetry_observability_driver.py | 3 +- .../drivers/prompt/test_base_prompt_driver.py | 24 ++-- .../test_local_structure_run_driver.py | 7 +- .../extraction/test_csv_extraction_engine.py | 3 +- ...est_footnote_prompt_response_rag_module.py | 3 +- .../test_prompt_response_rag_module.py | 3 +- tests/unit/engines/rag/test_rag_engine.py | 23 +--- .../summary/test_prompt_summary_engine.py | 6 +- tests/unit/events/test_event_listener.py | 6 +- .../test_finish_actions_subtask_event.py | 3 +- tests/unit/events/test_finish_task_event.py | 3 +- .../test_start_actions_subtask_event.py | 3 +- tests/unit/events/test_start_task_event.py | 3 +- .../structure/test_conversation_memory.py | 16 ++- .../test_summary_conversation_memory.py | 8 +- tests/unit/structures/test_agent.py | 68 ++---------- tests/unit/structures/test_pipeline.py | 63 +++-------- tests/unit/structures/test_workflow.py | 104 ++++++------------ .../tasks/test_audio_transcription_task.py | 3 +- .../tasks/test_base_multi_text_input_task.py | 3 +- tests/unit/tasks/test_base_task.py | 5 +- tests/unit/tasks/test_base_text_input_task.py | 3 +- tests/unit/tasks/test_code_execution_task.py | 3 +- tests/unit/tasks/test_extraction_task.py | 5 +- tests/unit/tasks/test_prompt_task.py | 2 +- tests/unit/tasks/test_rag_task.py | 7 +- tests/unit/tasks/test_structure_run_task.py | 8 +- tests/unit/tasks/test_text_summary_task.py | 2 +- tests/unit/tasks/test_text_to_speech_task.py | 3 +- tests/unit/tasks/test_tool_task.py | 10 +- tests/unit/tasks/test_toolkit_task.py | 7 +- tests/unit/tools/test_structure_run_client.py | 4 +- tests/unit/utils/test_chat.py | 3 +- tests/unit/utils/test_conversation.py | 10 +- tests/unit/utils/test_file_utils.py | 5 +- tests/unit/utils/test_structure_visualizer.py | 5 +- tests/utils/defaults.py | 6 +- tests/utils/test_reference_utils.py | 3 +- 45 files changed, 164 insertions(+), 324 deletions(-) diff --git a/griptape/utils/chat.py b/griptape/utils/chat.py index e98eeaa4d..a8bdc9b13 100644 --- a/griptape/utils/chat.py +++ b/griptape/utils/chat.py @@ -25,12 +25,16 @@ class Chat: ) def default_output_fn(self, text: str) -> None: - if self.structure.config.prompt_driver.stream: + from griptape.config import Config + + if Config.prompt_driver.stream: print(text, end="", flush=True) # noqa: T201 else: print(text) # noqa: T201 def start(self) -> None: + from griptape.config import Config + if self.intro_text: self.output_fn(self.intro_text) while True: @@ -40,7 +44,7 @@ def start(self) -> None: self.output_fn(self.exiting_text) break - if self.structure.config.prompt_driver.stream: + if Config.prompt_driver.stream: self.output_fn(self.processing_text + "\n") stream = Stream(self.structure).run(question) first_chunk = next(stream) diff --git a/tests/mocks/mock_structure_config.py b/tests/mocks/mock_structure_config.py index 3f95288f4..0b374449d 100644 --- a/tests/mocks/mock_structure_config.py +++ b/tests/mocks/mock_structure_config.py @@ -1,6 +1,7 @@ from attrs import Factory, define, field from griptape.config import StructureConfig +from griptape.drivers.vector.local_vector_store_driver import LocalVectorStoreDriver from tests.mocks.mock_embedding_driver import MockEmbeddingDriver from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver from tests.mocks.mock_image_query_driver import MockImageQueryDriver @@ -21,3 +22,7 @@ class MockStructureConfig(StructureConfig): embedding_driver: MockEmbeddingDriver = field( default=Factory(lambda: MockEmbeddingDriver(model="text-embedding-3-small")), metadata={"serializable": True} ) + vector_store_driver: LocalVectorStoreDriver = field( + default=Factory(lambda self: LocalVectorStoreDriver(embedding_driver=self.embedding_driver), takes_self=True), + metadata={"serializable": True}, + ) diff --git a/tests/unit/config/test_azure_openai_structure_config.py b/tests/unit/config/test_azure_openai_structure_config.py index abeb6b878..810cb41a1 100644 --- a/tests/unit/config/test_azure_openai_structure_config.py +++ b/tests/unit/config/test_azure_openai_structure_config.py @@ -8,7 +8,7 @@ class TestAzureOpenAiStructureConfig: def mock_openai(self, mocker): return mocker.patch("openai.AzureOpenAI") - @pytest.fixture + @pytest.fixture() def config(self): return AzureOpenAiStructureConfig( azure_endpoint="http://localhost:8080", diff --git a/tests/unit/config/test_structure_config.py b/tests/unit/config/test_structure_config.py index c7f02034f..cce97647e 100644 --- a/tests/unit/config/test_structure_config.py +++ b/tests/unit/config/test_structure_config.py @@ -4,7 +4,7 @@ class TestStructureConfig: - @pytest.fixture + @pytest.fixture() def config(self): return StructureConfig() diff --git a/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py b/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py index fc41837fd..29aecfdf9 100644 --- a/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py +++ b/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py @@ -12,7 +12,7 @@ class TestBaseAudioTranscriptionDriver: def driver(self): return MockAudioTranscriptionDriver() - def test_run_publish_events(self, driver): + def test_run_publish_events(self, driver, mock_config): mock_handler = Mock() EventBus.add_event_listener(EventListener(handler=mock_handler)) diff --git a/tests/unit/drivers/memory/conversation/test_dynamodb_conversation_memory_driver.py b/tests/unit/drivers/memory/conversation/test_dynamodb_conversation_memory_driver.py index 8e700d0a5..f1a5df1be 100644 --- a/tests/unit/drivers/memory/conversation/test_dynamodb_conversation_memory_driver.py +++ b/tests/unit/drivers/memory/conversation/test_dynamodb_conversation_memory_driver.py @@ -6,7 +6,6 @@ from griptape.memory.structure import ConversationMemory from griptape.structures import Pipeline from griptape.tasks import PromptTask -from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.utils.aws import mock_aws_credentials @@ -40,7 +39,6 @@ def test_store(self): session = boto3.Session(region_name=self.AWS_REGION) dynamodb = session.resource("dynamodb") table = dynamodb.Table(self.DYNAMODB_TABLE_NAME) - prompt_driver = MockPromptDriver() memory_driver = AmazonDynamoDbConversationMemoryDriver( session=session, table_name=self.DYNAMODB_TABLE_NAME, @@ -49,7 +47,7 @@ def test_store(self): partition_key_value=self.PARTITION_KEY_VALUE, ) memory = ConversationMemory(driver=memory_driver) - pipeline = Pipeline(prompt_driver=prompt_driver, conversation_memory=memory) + pipeline = Pipeline(conversation_memory=memory) pipeline.add_task(PromptTask("test")) @@ -65,7 +63,6 @@ def test_store_with_sort_key(self): session = boto3.Session(region_name=self.AWS_REGION) dynamodb = session.resource("dynamodb") table = dynamodb.Table(self.DYNAMODB_TABLE_NAME) - prompt_driver = MockPromptDriver() memory_driver = AmazonDynamoDbConversationMemoryDriver( session=session, table_name=self.DYNAMODB_TABLE_NAME, @@ -76,7 +73,7 @@ def test_store_with_sort_key(self): sort_key_value="foo", ) memory = ConversationMemory(driver=memory_driver) - pipeline = Pipeline(prompt_driver=prompt_driver, conversation_memory=memory) + pipeline = Pipeline(conversation_memory=memory) pipeline.add_task(PromptTask("test")) @@ -89,7 +86,6 @@ def test_store_with_sort_key(self): assert "Item" in response def test_load(self): - prompt_driver = MockPromptDriver() memory_driver = AmazonDynamoDbConversationMemoryDriver( session=boto3.Session(region_name=self.AWS_REGION), table_name=self.DYNAMODB_TABLE_NAME, @@ -98,7 +94,7 @@ def test_load(self): partition_key_value=self.PARTITION_KEY_VALUE, ) memory = ConversationMemory(driver=memory_driver) - pipeline = Pipeline(prompt_driver=prompt_driver, conversation_memory=memory) + pipeline = Pipeline(conversation_memory=memory) pipeline.add_task(PromptTask("test")) @@ -113,7 +109,6 @@ def test_load(self): assert new_memory.runs[0].output.value == "mock output" def test_load_with_sort_key(self): - prompt_driver = MockPromptDriver() memory_driver = AmazonDynamoDbConversationMemoryDriver( session=boto3.Session(region_name=self.AWS_REGION), table_name=self.DYNAMODB_TABLE_NAME, @@ -124,7 +119,7 @@ def test_load_with_sort_key(self): sort_key_value="foo", ) memory = ConversationMemory(driver=memory_driver) - pipeline = Pipeline(prompt_driver=prompt_driver, conversation_memory=memory) + pipeline = Pipeline(conversation_memory=memory) pipeline.add_task(PromptTask("test")) diff --git a/tests/unit/drivers/memory/conversation/test_local_conversation_memory_driver.py b/tests/unit/drivers/memory/conversation/test_local_conversation_memory_driver.py index e1a383ab9..dff66d0fc 100644 --- a/tests/unit/drivers/memory/conversation/test_local_conversation_memory_driver.py +++ b/tests/unit/drivers/memory/conversation/test_local_conversation_memory_driver.py @@ -7,7 +7,6 @@ from griptape.memory.structure import ConversationMemory from griptape.structures import Pipeline from griptape.tasks import PromptTask -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestLocalConversationMemoryDriver: @@ -22,10 +21,9 @@ def _run_before_and_after_tests(self): self.__delete_file(self.MEMORY_FILE_PATH) def test_store(self): - prompt_driver = MockPromptDriver() memory_driver = LocalConversationMemoryDriver(file_path=self.MEMORY_FILE_PATH) memory = ConversationMemory(driver=memory_driver, autoload=False) - pipeline = Pipeline(prompt_driver=prompt_driver, conversation_memory=memory) + pipeline = Pipeline(conversation_memory=memory) pipeline.add_task(PromptTask("test")) @@ -41,10 +39,9 @@ def test_store(self): assert True def test_load(self): - prompt_driver = MockPromptDriver() memory_driver = LocalConversationMemoryDriver(file_path=self.MEMORY_FILE_PATH) memory = ConversationMemory(driver=memory_driver, autoload=False, max_runs=5) - pipeline = Pipeline(prompt_driver=prompt_driver, conversation_memory=memory) + pipeline = Pipeline(conversation_memory=memory) pipeline.add_task(PromptTask("test")) @@ -60,10 +57,9 @@ def test_load(self): assert new_memory.max_runs == 5 def test_autoload(self): - prompt_driver = MockPromptDriver() memory_driver = LocalConversationMemoryDriver(file_path=self.MEMORY_FILE_PATH) memory = ConversationMemory(driver=memory_driver) - pipeline = Pipeline(prompt_driver=prompt_driver, conversation_memory=memory) + pipeline = Pipeline(conversation_memory=memory) pipeline.add_task(PromptTask("test")) diff --git a/tests/unit/drivers/observability/test_open_telemetry_observability_driver.py b/tests/unit/drivers/observability/test_open_telemetry_observability_driver.py index 4f7ce50f0..758505b26 100644 --- a/tests/unit/drivers/observability/test_open_telemetry_observability_driver.py +++ b/tests/unit/drivers/observability/test_open_telemetry_observability_driver.py @@ -8,7 +8,6 @@ from griptape.drivers import OpenTelemetryObservabilityDriver from griptape.observability.observability import Observability from griptape.structures.agent import Agent -from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.utils.expected_spans import ExpectedSpan, ExpectedSpans @@ -170,7 +169,7 @@ def test_observability_agent(self, driver, mock_span_exporter): ) with Observability(observability_driver=driver): - agent = Agent(prompt_driver=MockPromptDriver()) + agent = Agent() agent.run("Hi") assert mock_span_exporter.export.call_count == 1 diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index 5b6b0c600..d95e7a5a7 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -10,17 +10,17 @@ class TestBasePromptDriver: - def test_run_via_pipeline_retries_success(self): - driver = MockPromptDriver(max_attempts=1) - pipeline = Pipeline(prompt_driver=driver) + def test_run_via_pipeline_retries_success(self, mock_config): + mock_config.prompt_driver = MockPromptDriver(max_attempts=2) + pipeline = Pipeline() pipeline.add_task(PromptTask("test")) assert isinstance(pipeline.run().output_task.output, TextArtifact) - def test_run_via_pipeline_retries_failure(self): - driver = MockFailingPromptDriver(max_failures=2, max_attempts=1) - pipeline = Pipeline(prompt_driver=driver) + def test_run_via_pipeline_retries_failure(self, mock_config): + mock_config.prompt_driver = MockFailingPromptDriver(max_failures=2, max_attempts=1) + pipeline = Pipeline() pipeline.add_task(PromptTask("test")) @@ -46,9 +46,9 @@ def test_run_with_stream(self): assert isinstance(result, Message) assert result.value == "mock output" - def test_run_with_tools(self): - driver = MockPromptDriver(max_attempts=1, use_native_tools=True) - pipeline = Pipeline(prompt_driver=driver) + def test_run_with_tools(self, mock_config): + mock_config.prompt_driver = MockPromptDriver(max_attempts=1, use_native_tools=True) + pipeline = Pipeline() pipeline.add_task(ToolkitTask(tools=[MockTool()])) @@ -56,9 +56,9 @@ def test_run_with_tools(self): assert isinstance(output, TextArtifact) assert output.value == "mock output" - def test_run_with_tools_and_stream(self): - driver = MockPromptDriver(max_attempts=1, stream=True, use_native_tools=True) - pipeline = Pipeline(prompt_driver=driver) + def test_run_with_tools_and_stream(self, mock_config): + mock_config.driver = MockPromptDriver(max_attempts=1, stream=True, use_native_tools=True) + pipeline = Pipeline() pipeline.add_task(ToolkitTask(tools=[MockTool()])) diff --git a/tests/unit/drivers/structure_run/test_local_structure_run_driver.py b/tests/unit/drivers/structure_run/test_local_structure_run_driver.py index 316f7bf71..318a41aa2 100644 --- a/tests/unit/drivers/structure_run/test_local_structure_run_driver.py +++ b/tests/unit/drivers/structure_run/test_local_structure_run_driver.py @@ -9,7 +9,7 @@ class TestLocalStructureRunDriver: def test_run(self): pipeline = Pipeline() - driver = LocalStructureRunDriver(structure_factory_fn=lambda: Agent(prompt_driver=MockPromptDriver())) + driver = LocalStructureRunDriver(structure_factory_fn=lambda: Agent()) task = StructureRunTask(driver=driver) @@ -17,10 +17,11 @@ def test_run(self): assert task.run().to_text() == "mock output" - def test_run_with_env(self): + def test_run_with_env(self, mock_config): pipeline = Pipeline() - agent = Agent(prompt_driver=MockPromptDriver(mock_output=lambda _: os.environ["KEY"])) + mock_config.prompt_driver = MockPromptDriver(mock_output=lambda _: os.environ["KEY"]) + agent = Agent() driver = LocalStructureRunDriver(structure_factory_fn=lambda: agent, env={"KEY": "value"}) task = StructureRunTask(driver=driver) diff --git a/tests/unit/engines/extraction/test_csv_extraction_engine.py b/tests/unit/engines/extraction/test_csv_extraction_engine.py index f69d8a0ba..d84fc7cdd 100644 --- a/tests/unit/engines/extraction/test_csv_extraction_engine.py +++ b/tests/unit/engines/extraction/test_csv_extraction_engine.py @@ -1,13 +1,12 @@ import pytest from griptape.engines import CsvExtractionEngine -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestCsvExtractionEngine: @pytest.fixture() def engine(self): - return CsvExtractionEngine(prompt_driver=MockPromptDriver()) + return CsvExtractionEngine() def test_extract(self, engine): result = engine.extract("foo", column_names=["test1"]) diff --git a/tests/unit/engines/rag/modules/generation/test_footnote_prompt_response_rag_module.py b/tests/unit/engines/rag/modules/generation/test_footnote_prompt_response_rag_module.py index 385cf0c04..f7819c6d7 100644 --- a/tests/unit/engines/rag/modules/generation/test_footnote_prompt_response_rag_module.py +++ b/tests/unit/engines/rag/modules/generation/test_footnote_prompt_response_rag_module.py @@ -4,13 +4,12 @@ from griptape.common import Reference from griptape.engines.rag import RagContext from griptape.engines.rag.modules import FootnotePromptResponseRagModule -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestFootnotePromptResponseRagModule: @pytest.fixture() def module(self): - return FootnotePromptResponseRagModule(prompt_driver=MockPromptDriver()) + return FootnotePromptResponseRagModule() def test_run(self, module): assert module.run(RagContext(query="test")).output.value == "mock output" diff --git a/tests/unit/engines/rag/modules/generation/test_prompt_response_rag_module.py b/tests/unit/engines/rag/modules/generation/test_prompt_response_rag_module.py index 2f8a912e2..0e3526a52 100644 --- a/tests/unit/engines/rag/modules/generation/test_prompt_response_rag_module.py +++ b/tests/unit/engines/rag/modules/generation/test_prompt_response_rag_module.py @@ -3,13 +3,12 @@ from griptape.artifacts import TextArtifact from griptape.engines.rag import RagContext from griptape.engines.rag.modules import PromptResponseRagModule -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestPromptResponseRagModule: @pytest.fixture() def module(self): - return PromptResponseRagModule(prompt_driver=MockPromptDriver()) + return PromptResponseRagModule() def test_run(self, module): assert module.run(RagContext(query="test")).output.value == "mock output" diff --git a/tests/unit/engines/rag/test_rag_engine.py b/tests/unit/engines/rag/test_rag_engine.py index c3d728bb3..40ab4af4d 100644 --- a/tests/unit/engines/rag/test_rag_engine.py +++ b/tests/unit/engines/rag/test_rag_engine.py @@ -1,36 +1,25 @@ import pytest -from griptape.drivers import LocalVectorStoreDriver from griptape.engines.rag import RagContext, RagEngine from griptape.engines.rag.modules import PromptResponseRagModule, VectorStoreRetrievalRagModule from griptape.engines.rag.stages import ResponseRagStage, RetrievalRagStage -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestRagEngine: @pytest.fixture() def engine(self): return RagEngine( - retrieval_stage=RetrievalRagStage( - retrieval_modules=[ - VectorStoreRetrievalRagModule( - vector_store_driver=LocalVectorStoreDriver(embedding_driver=MockEmbeddingDriver()) - ) - ] - ), - response_stage=ResponseRagStage(response_module=PromptResponseRagModule(prompt_driver=MockPromptDriver())), + retrieval_stage=RetrievalRagStage(retrieval_modules=[VectorStoreRetrievalRagModule()]), + response_stage=ResponseRagStage(response_module=PromptResponseRagModule()), ) def test_module_name_uniqueness(self): - vector_store_driver = LocalVectorStoreDriver(embedding_driver=MockEmbeddingDriver()) - with pytest.raises(ValueError): RagEngine( retrieval_stage=RetrievalRagStage( retrieval_modules=[ - VectorStoreRetrievalRagModule(name="test", vector_store_driver=vector_store_driver), - VectorStoreRetrievalRagModule(name="test", vector_store_driver=vector_store_driver), + VectorStoreRetrievalRagModule(name="test"), + VectorStoreRetrievalRagModule(name="test"), ] ) ) @@ -38,8 +27,8 @@ def test_module_name_uniqueness(self): assert RagEngine( retrieval_stage=RetrievalRagStage( retrieval_modules=[ - VectorStoreRetrievalRagModule(name="test1", vector_store_driver=vector_store_driver), - VectorStoreRetrievalRagModule(name="test2", vector_store_driver=vector_store_driver), + VectorStoreRetrievalRagModule(name="test1"), + VectorStoreRetrievalRagModule(name="test2"), ] ) ) diff --git a/tests/unit/engines/summary/test_prompt_summary_engine.py b/tests/unit/engines/summary/test_prompt_summary_engine.py index 4d9c65e03..138444ae3 100644 --- a/tests/unit/engines/summary/test_prompt_summary_engine.py +++ b/tests/unit/engines/summary/test_prompt_summary_engine.py @@ -12,7 +12,7 @@ class TestPromptSummaryEngine: @pytest.fixture() def engine(self): - return PromptSummaryEngine(prompt_driver=MockPromptDriver()) + return PromptSummaryEngine() def test_summarize_text(self, engine): assert engine.summarize_text("foobar") == "mock output" @@ -24,10 +24,10 @@ def test_summarize_artifacts(self, engine): def test_max_token_multiplier_invalid(self, engine): with pytest.raises(ValueError): - PromptSummaryEngine(prompt_driver=MockPromptDriver(), max_token_multiplier=0) + PromptSummaryEngine(max_token_multiplier=0) with pytest.raises(ValueError): - PromptSummaryEngine(prompt_driver=MockPromptDriver(), max_token_multiplier=10000) + PromptSummaryEngine(max_token_multiplier=10000) def test_chunked_summary(self, engine): def smaller_input(prompt_stack: PromptStack): diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index 92c5a3653..ed978db78 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -26,10 +26,10 @@ class TestEventListener: @pytest.fixture() def pipeline(self, mock_config): - task = ToolkitTask("test", tools=[MockTool(name="Tool1")]) mock_config.prompt_driver = MockPromptDriver(stream=True) + task = ToolkitTask("test", tools=[MockTool(name="Tool1")]) - pipeline = Pipeline(prompt_driver=MockPromptDriver(stream=True)) + pipeline = Pipeline() pipeline.add_task(task) task.add_subtask(ActionsSubtask("foo")) @@ -49,7 +49,7 @@ def test_untyped_listeners(self, pipeline, mock_config): assert event_handler_1.call_count == 9 assert event_handler_2.call_count == 9 - def test_typed_listeners(self, pipeline): + def test_typed_listeners(self, pipeline, mock_config): start_prompt_event_handler = Mock() finish_prompt_event_handler = Mock() start_task_event_handler = Mock() diff --git a/tests/unit/events/test_finish_actions_subtask_event.py b/tests/unit/events/test_finish_actions_subtask_event.py index 5e2a0807a..5fc35755b 100644 --- a/tests/unit/events/test_finish_actions_subtask_event.py +++ b/tests/unit/events/test_finish_actions_subtask_event.py @@ -3,7 +3,6 @@ from griptape.events import FinishActionsSubtaskEvent from griptape.structures import Agent from griptape.tasks import ActionsSubtask, ToolkitTask -from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_tool.tool import MockTool @@ -17,7 +16,7 @@ def finish_subtask_event(self): "Answer: test output" ) task = ToolkitTask(tools=[MockTool()]) - agent = Agent(prompt_driver=MockPromptDriver()) + agent = Agent() agent.add_task(task) subtask = ActionsSubtask(valid_input) task.add_subtask(subtask) diff --git a/tests/unit/events/test_finish_task_event.py b/tests/unit/events/test_finish_task_event.py index df1d6d42a..2568752bb 100644 --- a/tests/unit/events/test_finish_task_event.py +++ b/tests/unit/events/test_finish_task_event.py @@ -3,14 +3,13 @@ from griptape.events import FinishTaskEvent from griptape.structures import Agent from griptape.tasks import PromptTask -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestFinishTaskEvent: @pytest.fixture() def finish_task_event(self): task = PromptTask() - agent = Agent(prompt_driver=MockPromptDriver()) + agent = Agent() agent.add_task(task) agent.run() diff --git a/tests/unit/events/test_start_actions_subtask_event.py b/tests/unit/events/test_start_actions_subtask_event.py index 8b628057c..b7236911f 100644 --- a/tests/unit/events/test_start_actions_subtask_event.py +++ b/tests/unit/events/test_start_actions_subtask_event.py @@ -3,7 +3,6 @@ from griptape.events import StartActionsSubtaskEvent from griptape.structures import Agent from griptape.tasks import ActionsSubtask, ToolkitTask -from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_tool.tool import MockTool @@ -17,7 +16,7 @@ def start_subtask_event(self): "Answer: test output" ) task = ToolkitTask(tools=[MockTool()]) - agent = Agent(prompt_driver=MockPromptDriver()) + agent = Agent() agent.add_task(task) subtask = ActionsSubtask(valid_input) task.add_subtask(subtask) diff --git a/tests/unit/events/test_start_task_event.py b/tests/unit/events/test_start_task_event.py index ea027f147..111d35934 100644 --- a/tests/unit/events/test_start_task_event.py +++ b/tests/unit/events/test_start_task_event.py @@ -3,14 +3,13 @@ from griptape.events import StartTaskEvent from griptape.structures import Agent from griptape.tasks import PromptTask -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestStartTaskEvent: @pytest.fixture() def start_task_event(self): task = PromptTask() - agent = Agent(prompt_driver=MockPromptDriver()) + agent = Agent() agent.add_task(task) agent.run() diff --git a/tests/unit/memory/structure/test_conversation_memory.py b/tests/unit/memory/structure/test_conversation_memory.py index 2ffd7b8cb..77cebf193 100644 --- a/tests/unit/memory/structure/test_conversation_memory.py +++ b/tests/unit/memory/structure/test_conversation_memory.py @@ -60,7 +60,7 @@ def test_from_json(self): def test_buffering(self): memory = ConversationMemory(max_runs=2) - pipeline = Pipeline(conversation_memory=memory, prompt_driver=MockPromptDriver()) + pipeline = Pipeline(conversation_memory=memory) pipeline.add_tasks(PromptTask()) @@ -75,7 +75,7 @@ def test_buffering(self): assert pipeline.conversation_memory.runs[1].input.value == "run5" def test_add_to_prompt_stack_autopruing_disabled(self): - agent = Agent(prompt_driver=MockPromptDriver()) + agent = Agent() memory = ConversationMemory( autoprune=False, runs=[ @@ -94,9 +94,11 @@ def test_add_to_prompt_stack_autopruing_disabled(self): assert len(prompt_stack.messages) == 12 - def test_add_to_prompt_stack_autopruning_enabled(self): + def test_add_to_prompt_stack_autopruning_enabled(self, mock_config): # All memory is pruned. - agent = Agent(prompt_driver=MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=0))) + + mock_config.prompt_driver = MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=0)) + agent = Agent() memory = ConversationMemory( autoprune=True, runs=[ @@ -117,7 +119,8 @@ def test_add_to_prompt_stack_autopruning_enabled(self): assert len(prompt_stack.messages) == 3 # No memory is pruned. - agent = Agent(prompt_driver=MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=1000))) + mock_config.prompt_driver = MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=1000)) + agent = Agent() memory = ConversationMemory( autoprune=True, runs=[ @@ -140,7 +143,8 @@ def test_add_to_prompt_stack_autopruning_enabled(self): # One memory is pruned. # MockTokenizer's max_input_tokens set to one below the sum of memory + system prompt tokens # so that a single memory is pruned. - agent = Agent(prompt_driver=MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=160))) + mock_config.prompt_driver = MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=160)) + agent = Agent() memory = ConversationMemory( autoprune=True, runs=[ diff --git a/tests/unit/memory/structure/test_summary_conversation_memory.py b/tests/unit/memory/structure/test_summary_conversation_memory.py index 579be214e..42246e349 100644 --- a/tests/unit/memory/structure/test_summary_conversation_memory.py +++ b/tests/unit/memory/structure/test_summary_conversation_memory.py @@ -9,9 +9,9 @@ class TestSummaryConversationMemory: def test_unsummarized_subtasks(self): - memory = SummaryConversationMemory(offset=1, prompt_driver=MockPromptDriver()) + memory = SummaryConversationMemory(offset=1) - pipeline = Pipeline(conversation_memory=memory, prompt_driver=MockPromptDriver()) + pipeline = Pipeline(conversation_memory=memory) pipeline.add_tasks(PromptTask("test")) @@ -23,9 +23,9 @@ def test_unsummarized_subtasks(self): assert len(memory.unsummarized_runs()) == 1 def test_after_run(self): - memory = SummaryConversationMemory(offset=1, prompt_driver=MockPromptDriver()) + memory = SummaryConversationMemory(offset=1) - pipeline = Pipeline(conversation_memory=memory, prompt_driver=MockPromptDriver()) + pipeline = Pipeline(conversation_memory=memory) pipeline.add_tasks(PromptTask("test")) diff --git a/tests/unit/structures/test_agent.py b/tests/unit/structures/test_agent.py index a09ad0f9a..15e1399b6 100644 --- a/tests/unit/structures/test_agent.py +++ b/tests/unit/structures/test_agent.py @@ -1,23 +1,17 @@ import pytest -from griptape.engines import PromptSummaryEngine from griptape.memory import TaskMemory from griptape.memory.structure import ConversationMemory -from griptape.memory.task.storage import TextArtifactStorage from griptape.rules import Rule, Ruleset from griptape.structures import Agent from griptape.tasks import BaseTask, PromptTask, ToolkitTask -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver -from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_tool.tool import MockTool class TestAgent: def test_init(self): - driver = MockPromptDriver() - agent = Agent(prompt_driver=driver, rulesets=[Ruleset("TestRuleset", [Rule("test")])]) + agent = Agent(rulesets=[Ruleset("TestRuleset", [Rule("test")])]) - assert agent.prompt_driver is driver assert isinstance(agent.task, PromptTask) assert isinstance(agent.task, PromptTask) assert agent.rulesets[0].name == "TestRuleset" @@ -76,18 +70,6 @@ def test_with_no_task_memory_and_empty_tool_output_memory(self): assert agent.tools[0].input_memory[0] == agent.task_memory assert agent.tools[0].output_memory == {} - def test_embedding_driver(self): - embedding_driver = MockEmbeddingDriver() - agent = Agent(tools=[MockTool()], embedding_driver=embedding_driver) - - storage = list(agent.task_memory.artifact_storages.values())[0] - assert isinstance(storage, TextArtifactStorage) - memory_embedding_driver = storage.rag_engine.retrieval_stage.retrieval_modules[ - 0 - ].vector_store_driver.embedding_driver - - assert memory_embedding_driver == embedding_driver - def test_without_default_task_memory(self): agent = Agent(task_memory=None, tools=[MockTool()]) @@ -95,7 +77,7 @@ def test_without_default_task_memory(self): assert agent.tools[0].output_memory is None def test_with_memory(self): - agent = Agent(prompt_driver=MockPromptDriver(), conversation_memory=ConversationMemory()) + agent = Agent(conversation_memory=ConversationMemory()) assert agent.conversation_memory is not None assert len(agent.conversation_memory.runs) == 0 @@ -117,7 +99,7 @@ def test_tasks_initialization(self): assert agent.tasks[0] == task def test_add_task(self): - agent = Agent(prompt_driver=MockPromptDriver()) + agent = Agent() assert len(agent.tasks) == 1 @@ -145,7 +127,7 @@ def test_add_tasks(self): first_task = PromptTask("test1") second_task = PromptTask("test2") - agent = Agent(prompt_driver=MockPromptDriver()) + agent = Agent() try: agent.add_tasks(first_task, second_task) @@ -160,7 +142,7 @@ def test_add_tasks(self): assert True def test_prompt_stack_without_memory(self): - agent = Agent(prompt_driver=MockPromptDriver(), conversation_memory=None, rules=[Rule("test")]) + agent = Agent(conversation_memory=None, rules=[Rule("test")]) task1 = PromptTask("test") @@ -177,7 +159,7 @@ def test_prompt_stack_without_memory(self): assert len(task1.prompt_stack.messages) == 3 def test_prompt_stack_with_memory(self): - agent = Agent(prompt_driver=MockPromptDriver(), conversation_memory=ConversationMemory(), rules=[Rule("test")]) + agent = Agent(conversation_memory=ConversationMemory(), rules=[Rule("test")]) task1 = PromptTask("test") @@ -195,7 +177,7 @@ def test_prompt_stack_with_memory(self): def test_run(self): task = PromptTask("test") - agent = Agent(prompt_driver=MockPromptDriver()) + agent = Agent() agent.add_task(task) assert task.state == BaseTask.State.PENDING @@ -207,7 +189,7 @@ def test_run(self): def test_run_with_args(self): task = PromptTask("{{ args[0] }}-{{ args[1] }}") - agent = Agent(prompt_driver=MockPromptDriver()) + agent = Agent() agent.add_task(task) agent._execution_args = ("test1", "test2") @@ -220,7 +202,7 @@ def test_run_with_args(self): def test_context(self): task = PromptTask("test prompt") - agent = Agent(prompt_driver=MockPromptDriver()) + agent = Agent() agent.add_task(task) @@ -230,37 +212,9 @@ def test_context(self): assert context["structure"] == agent - def test_task_memory_defaults(self): - prompt_driver = MockPromptDriver() - embedding_driver = MockEmbeddingDriver() - agent = Agent(prompt_driver=prompt_driver, embedding_driver=embedding_driver) - - storage = list(agent.task_memory.artifact_storages.values())[0] - assert isinstance(storage, TextArtifactStorage) - - assert storage.rag_engine.response_stage.response_module.prompt_driver == prompt_driver - assert ( - storage.rag_engine.retrieval_stage.retrieval_modules[0].vector_store_driver.embedding_driver - == embedding_driver - ) - assert isinstance(storage.summary_engine, PromptSummaryEngine) - assert storage.summary_engine.prompt_driver == prompt_driver - assert storage.csv_extraction_engine.prompt_driver == prompt_driver - assert storage.json_extraction_engine.prompt_driver == prompt_driver - - def test_deprecation(self): - with pytest.deprecated_call(): - Agent(prompt_driver=MockPromptDriver()) - - with pytest.deprecated_call(): - Agent(embedding_driver=MockEmbeddingDriver()) - - with pytest.deprecated_call(): - Agent(stream=True) - def finished_tasks(self): task = PromptTask("test prompt") - agent = Agent(prompt_driver=MockPromptDriver()) + agent = Agent() agent.add_task(task) @@ -270,4 +224,4 @@ def finished_tasks(self): def test_fail_fast(self): with pytest.raises(ValueError): - Agent(prompt_driver=MockPromptDriver(), fail_fast=True) + Agent(fail_fast=True) diff --git a/tests/unit/structures/test_pipeline.py b/tests/unit/structures/test_pipeline.py index 306fd7bd2..a7f7f40c1 100644 --- a/tests/unit/structures/test_pipeline.py +++ b/tests/unit/structures/test_pipeline.py @@ -4,14 +4,11 @@ from griptape.artifacts import ErrorArtifact, TextArtifact from griptape.memory.structure import ConversationMemory -from griptape.memory.task.storage import TextArtifactStorage from griptape.rules import Rule, Ruleset from griptape.structures import Pipeline from griptape.tasks import BaseTask, CodeExecutionTask, PromptTask, ToolkitTask from griptape.tokenizers import OpenAiTokenizer -from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_tool.tool import MockTool -from tests.unit.structures.test_agent import MockEmbeddingDriver class TestPipeline: @@ -31,10 +28,8 @@ def fn(task): return CodeExecutionTask(run_fn=fn) def test_init(self): - driver = MockPromptDriver() - pipeline = Pipeline(prompt_driver=driver, rulesets=[Ruleset("TestRuleset", [Rule("test")])]) + pipeline = Pipeline(rulesets=[Ruleset("TestRuleset", [Rule("test")])]) - assert pipeline.prompt_driver is driver assert pipeline.input_task is None assert pipeline.output_task is None assert pipeline.rulesets[0].name == "TestRuleset" @@ -103,20 +98,6 @@ def test_with_task_memory(self): assert pipeline.tasks[0].tools[0].output_memory is not None assert pipeline.tasks[0].tools[0].output_memory["test"][0] == pipeline.task_memory - def test_embedding_driver(self): - embedding_driver = MockEmbeddingDriver() - pipeline = Pipeline(embedding_driver=embedding_driver) - - pipeline.add_task(ToolkitTask(tools=[MockTool()])) - - storage = list(pipeline.task_memory.artifact_storages.values())[0] - assert isinstance(storage, TextArtifactStorage) - memory_embedding_driver = storage.rag_engine.retrieval_stage.retrieval_modules[ - 0 - ].vector_store_driver.embedding_driver - - assert memory_embedding_driver == embedding_driver - def test_with_task_memory_and_empty_tool_output_memory(self): pipeline = Pipeline() @@ -139,7 +120,7 @@ def test_with_memory(self): second_task = PromptTask("test2") third_task = PromptTask("test3") - pipeline = Pipeline(prompt_driver=MockPromptDriver(), conversation_memory=ConversationMemory()) + pipeline = Pipeline(conversation_memory=ConversationMemory()) pipeline + [first_task, second_task, third_task] @@ -174,7 +155,7 @@ def test_tasks_order(self): second_task = PromptTask("test2") third_task = PromptTask("test3") - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline + first_task pipeline + second_task @@ -189,7 +170,7 @@ def test_add_task(self): first_task = PromptTask("test1") second_task = PromptTask("test2") - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline + first_task pipeline + second_task @@ -208,7 +189,7 @@ def test_add_tasks(self): first_task = PromptTask("test1") second_task = PromptTask("test2") - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline + [first_task, second_task] @@ -227,7 +208,7 @@ def test_insert_task_in_middle(self): second_task = PromptTask("test2", id="test2") third_task = PromptTask("test3", id="test3") - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline + [first_task, second_task] pipeline.insert_task(first_task, third_task) @@ -251,7 +232,7 @@ def test_insert_task_at_end(self): second_task = PromptTask("test2", id="test2") third_task = PromptTask("test3", id="test3") - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline + [first_task, second_task] pipeline.insert_task(second_task, third_task) @@ -271,7 +252,7 @@ def test_insert_task_at_end(self): assert [child.id for child in third_task.children] == [] def test_prompt_stack_without_memory(self): - pipeline = Pipeline(conversation_memory=None, prompt_driver=MockPromptDriver(), rules=[Rule("test")]) + pipeline = Pipeline(conversation_memory=None, rules=[Rule("test")]) task1 = PromptTask("test") task2 = PromptTask("test") @@ -292,7 +273,7 @@ def test_prompt_stack_without_memory(self): assert len(task2.prompt_stack.messages) == 3 def test_prompt_stack_with_memory(self): - pipeline = Pipeline(prompt_driver=MockPromptDriver(), rules=[Rule("test")]) + pipeline = Pipeline(rules=[Rule("test")]) task1 = PromptTask("test") task2 = PromptTask("test") @@ -321,7 +302,7 @@ def test_text_artifact_token_count(self): def test_run(self): task = PromptTask("test") - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline + task assert task.state == BaseTask.State.PENDING @@ -333,7 +314,7 @@ def test_run(self): def test_run_with_args(self): task = PromptTask("{{ args[0] }}-{{ args[1] }}") - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline + [task] pipeline._execution_args = ("test1", "test2") @@ -348,7 +329,7 @@ def test_context(self): parent = PromptTask("parent") task = PromptTask("test") child = PromptTask("child") - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline + [parent, task, child] @@ -365,35 +346,23 @@ def test_context(self): assert context["parent"] == parent assert context["child"] == child - def test_deprecation(self): - with pytest.deprecated_call(): - Pipeline(prompt_driver=MockPromptDriver()) - - with pytest.deprecated_call(): - Pipeline(embedding_driver=MockEmbeddingDriver()) - - with pytest.deprecated_call(): - Pipeline(stream=True) - def test_run_with_error_artifact(self, error_artifact_task, waiting_task): end_task = PromptTask("end") - pipeline = Pipeline(prompt_driver=MockPromptDriver(), tasks=[waiting_task, error_artifact_task, end_task]) + pipeline = Pipeline(tasks=[waiting_task, error_artifact_task, end_task]) pipeline.run() assert pipeline.output is None def test_run_with_error_artifact_no_fail_fast(self, error_artifact_task, waiting_task): end_task = PromptTask("end") - pipeline = Pipeline( - prompt_driver=MockPromptDriver(), tasks=[waiting_task, error_artifact_task, end_task], fail_fast=False - ) + pipeline = Pipeline(tasks=[waiting_task, error_artifact_task, end_task], fail_fast=False) pipeline.run() assert pipeline.output is not None def test_add_duplicate_task(self): task = PromptTask("test") - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline + task pipeline + task @@ -402,7 +371,7 @@ def test_add_duplicate_task(self): def test_add_duplicate_task_directly(self): task = PromptTask("test") - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline + task pipeline.tasks.append(task) diff --git a/tests/unit/structures/test_workflow.py b/tests/unit/structures/test_workflow.py index 242de29c5..79c9868e1 100644 --- a/tests/unit/structures/test_workflow.py +++ b/tests/unit/structures/test_workflow.py @@ -4,12 +4,9 @@ from griptape.artifacts import ErrorArtifact, TextArtifact from griptape.memory.structure import ConversationMemory -from griptape.memory.task.storage import TextArtifactStorage from griptape.rules import Rule, Ruleset from griptape.structures import Workflow from griptape.tasks import BaseTask, CodeExecutionTask, PromptTask, ToolkitTask -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver -from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_tool.tool import MockTool @@ -30,10 +27,8 @@ def fn(task): return CodeExecutionTask(run_fn=fn) def test_init(self): - driver = MockPromptDriver() - workflow = Workflow(prompt_driver=driver, rulesets=[Ruleset("TestRuleset", [Rule("test")])]) + workflow = Workflow(rulesets=[Ruleset("TestRuleset", [Rule("test")])]) - assert workflow.prompt_driver is driver assert len(workflow.tasks) == 0 assert workflow.rulesets[0].name == "TestRuleset" assert workflow.rulesets[0].rules[0].value == "test" @@ -100,20 +95,6 @@ def test_with_task_memory(self): assert workflow.tasks[0].tools[0].output_memory is not None assert workflow.tasks[0].tools[0].output_memory["test"][0] == workflow.task_memory - def test_embedding_driver(self): - embedding_driver = MockEmbeddingDriver() - workflow = Workflow(embedding_driver=embedding_driver) - - workflow.add_task(ToolkitTask(tools=[MockTool()])) - - storage = list(workflow.task_memory.artifact_storages.values())[0] - assert isinstance(storage, TextArtifactStorage) - memory_embedding_driver = storage.rag_engine.retrieval_stage.retrieval_modules[ - 0 - ].vector_store_driver.embedding_driver - - assert memory_embedding_driver == embedding_driver - def test_with_task_memory_and_empty_tool_output_memory(self): workflow = Workflow() @@ -136,7 +117,7 @@ def test_with_memory(self): second_task = PromptTask("test2") third_task = PromptTask("test3") - workflow = Workflow(prompt_driver=MockPromptDriver(), conversation_memory=ConversationMemory()) + workflow = Workflow(conversation_memory=ConversationMemory()) workflow + [first_task, second_task, third_task] @@ -170,7 +151,7 @@ def test_add_task(self): first_task = PromptTask("test1") second_task = PromptTask("test2") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() workflow + first_task workflow.add_task(second_task) @@ -189,7 +170,7 @@ def test_add_tasks(self): first_task = PromptTask("test1") second_task = PromptTask("test2") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() workflow + [first_task, second_task] @@ -206,7 +187,7 @@ def test_add_tasks(self): def test_run(self): task1 = PromptTask("test") task2 = PromptTask("test") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() workflow + [task1, task2] assert task1.state == BaseTask.State.PENDING @@ -219,7 +200,7 @@ def test_run(self): def test_run_with_args(self): task = PromptTask("{{ args[0] }}-{{ args[1] }}") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() workflow + task workflow._execution_args = ("test1", "test2") @@ -241,7 +222,7 @@ def test_run_with_args(self): ], ) def test_run_raises_on_missing_parent_or_child_id(self, tasks): - workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=tasks) + workflow = Workflow(tasks=tasks) with pytest.raises(ValueError) as e: workflow.run() @@ -250,7 +231,6 @@ def test_run_raises_on_missing_parent_or_child_id(self, tasks): def test_run_topology_1_declarative_parents(self): workflow = Workflow( - prompt_driver=MockPromptDriver(), tasks=[ PromptTask("test1", id="task1"), PromptTask("test2", id="task2", parent_ids=["task1"]), @@ -265,7 +245,6 @@ def test_run_topology_1_declarative_parents(self): def test_run_topology_1_declarative_children(self): workflow = Workflow( - prompt_driver=MockPromptDriver(), tasks=[ PromptTask("test1", id="task1", child_ids=["task2", "task3"]), PromptTask("test2", id="task2", child_ids=["task4"]), @@ -280,7 +259,6 @@ def test_run_topology_1_declarative_children(self): def test_run_topology_1_declarative_mixed(self): workflow = Workflow( - prompt_driver=MockPromptDriver(), tasks=[ PromptTask("test1", id="task1", child_ids=["task3"]), PromptTask("test2", id="task2", parent_ids=["task1"], child_ids=["task4"]), @@ -301,7 +279,7 @@ def test_run_topology_1_imperative_parents(self): task2.add_parent(task1) task3.add_parent("task1") task4.add_parents([task2, "task3"]) - workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[task1, task2, task3, task4]) + workflow = Workflow(tasks=[task1, task2, task3, task4]) workflow.run() @@ -315,14 +293,14 @@ def test_run_topology_1_imperative_children(self): task1.add_children([task2, task3]) task2.add_child(task4) task3.add_child(task4) - workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[task1, task2, task3, task4]) + workflow = Workflow(tasks=[task1, task2, task3, task4]) workflow.run() self._validate_topology_1(workflow) def test_run_topology_1_imperative_parents_structure_init(self): - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() task1 = PromptTask("test1", id="task1") task2 = PromptTask("test2", id="task2", structure=workflow) task3 = PromptTask("test3", id="task3", structure=workflow) @@ -336,7 +314,7 @@ def test_run_topology_1_imperative_parents_structure_init(self): self._validate_topology_1(workflow) def test_run_topology_1_imperative_children_structure_init(self): - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() task1 = PromptTask("test1", id="task1", structure=workflow) task2 = PromptTask("test2", id="task2", structure=workflow) task3 = PromptTask("test3", id="task3", structure=workflow) @@ -356,7 +334,7 @@ def test_run_topology_1_imperative_mixed(self): task4 = PromptTask("test4", id="task4") task1.add_children([task2, task3]) task4.add_parents([task2, task3]) - workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[task1, task2, task3, task4]) + workflow = Workflow(tasks=[task1, task2, task3, task4]) workflow.run() @@ -367,7 +345,7 @@ def test_run_topology_1_imperative_insert(self): task2 = PromptTask("test2", id="task2") task3 = PromptTask("test3", id="task3") task4 = PromptTask("test4", id="task4") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() # task1 splits into task2 and task3 # task2 and task3 converge into task4 @@ -384,7 +362,7 @@ def test_run_topology_1_missing_parent(self): task2 = PromptTask("test2", id="task2") task3 = PromptTask("test3", id="task3") task4 = PromptTask("test4", id="task4") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() # task1 never added to workflow workflow + task4 @@ -396,7 +374,7 @@ def test_run_topology_1_id_equality(self): task2 = PromptTask("test2", id="task2") task3 = PromptTask("test3", id="task3") task4 = PromptTask("test4", id="task4") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() # task4 never added to workflow workflow + task1 @@ -410,7 +388,7 @@ def test_run_topology_1_object_equality(self): task2 = PromptTask("test2", id="task2") task3 = PromptTask("test3", id="task3") task4 = PromptTask("test4", id="task4") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() workflow + task1 workflow + task4 @@ -419,7 +397,6 @@ def test_run_topology_1_object_equality(self): def test_run_topology_2_declarative_parents(self): workflow = Workflow( - prompt_driver=MockPromptDriver(), tasks=[ PromptTask("testa", id="taska"), PromptTask("testb", id="taskb", parent_ids=["taska"]), @@ -435,7 +412,6 @@ def test_run_topology_2_declarative_parents(self): def test_run_topology_2_declarative_children(self): workflow = Workflow( - prompt_driver=MockPromptDriver(), tasks=[ PromptTask("testa", id="taska", child_ids=["taskb", "taskc", "taskd", "taske"]), PromptTask("testb", id="taskb", child_ids=["taskd"]), @@ -459,7 +435,7 @@ def test_run_topology_2_imperative_parents(self): taskc.add_parent("taska") taskd.add_parents([taska, taskb, taskc]) taske.add_parents(["taska", taskd, "taskc"]) - workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[taska, taskb, taskc, taskd, taske]) + workflow = Workflow(tasks=[taska, taskb, taskc, taskd, taske]) workflow.run() @@ -475,7 +451,7 @@ def test_run_topology_2_imperative_children(self): taskb.add_child(taskd) taskc.add_children([taskd, taske]) taskd.add_child(taske) - workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[taska, taskb, taskc, taskd, taske]) + workflow = Workflow(tasks=[taska, taskb, taskc, taskd, taske]) workflow.run() @@ -491,7 +467,7 @@ def test_run_topology_2_imperative_mixed(self): taskb.add_child(taskd) taskd.add_parent(taskc) taske.add_parents(["taska", taskd, "taskc"]) - workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[taska, taskb, taskc, taskd, taske]) + workflow = Workflow(tasks=[taska, taskb, taskc, taskd, taske]) workflow.run() @@ -503,7 +479,7 @@ def test_run_topology_2_imperative_insert(self): taskc = PromptTask("testc", id="taskc") taskd = PromptTask("testd", id="taskd") taske = PromptTask("teste", id="taske") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() workflow.add_task(taska) workflow.add_task(taske) taske.add_parent(taska) @@ -517,7 +493,6 @@ def test_run_topology_2_imperative_insert(self): def test_run_topology_3_declarative_parents(self): workflow = Workflow( - prompt_driver=MockPromptDriver(), tasks=[ PromptTask("test1", id="task1"), PromptTask("test2", id="task2", parent_ids=["task4"]), @@ -532,7 +507,6 @@ def test_run_topology_3_declarative_parents(self): def test_run_topology_3_declarative_children(self): workflow = Workflow( - prompt_driver=MockPromptDriver(), tasks=[ PromptTask("test1", id="task1", child_ids=["task4"]), PromptTask("test2", id="task2", child_ids=["task3"]), @@ -547,7 +521,6 @@ def test_run_topology_3_declarative_children(self): def test_run_topology_3_declarative_mixed(self): workflow = Workflow( - prompt_driver=MockPromptDriver(), tasks=[ PromptTask("test1", id="task1"), PromptTask("test2", id="task2", parent_ids=["task4"], child_ids=["task3"]), @@ -565,7 +538,7 @@ def test_run_topology_3_imperative_insert(self): task2 = PromptTask("test2", id="task2") task3 = PromptTask("test3", id="task3") task4 = PromptTask("test4", id="task4") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() workflow + task1 workflow + task2 @@ -580,7 +553,6 @@ def test_run_topology_3_imperative_insert(self): def test_run_topology_4_declarative_parents(self): workflow = Workflow( - prompt_driver=MockPromptDriver(), tasks=[ PromptTask(id="collect_movie_info"), PromptTask(id="movie_info_1", parent_ids=["collect_movie_info"]), @@ -600,7 +572,6 @@ def test_run_topology_4_declarative_parents(self): def test_run_topology_4_declarative_children(self): workflow = Workflow( - prompt_driver=MockPromptDriver(), tasks=[ PromptTask(id="collect_movie_info", child_ids=["movie_info_1", "movie_info_2", "movie_info_3"]), PromptTask(id="movie_info_1", child_ids=["compare_movies"]), @@ -620,7 +591,6 @@ def test_run_topology_4_declarative_children(self): def test_run_topology_4_declarative_mixed(self): workflow = Workflow( - prompt_driver=MockPromptDriver(), tasks=[ PromptTask(id="collect_movie_info"), PromptTask(id="movie_info_1", parent_ids=["collect_movie_info"], child_ids=["compare_movies"]), @@ -650,7 +620,7 @@ def test_run_topology_4_imperative_insert(self): publish_website = PromptTask(id="publish_website") movie_info_3 = PromptTask(id="movie_info_3") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() workflow.add_tasks(collect_movie_info, summarize_to_slack) workflow.insert_tasks(collect_movie_info, [movie_info_1, movie_info_2, movie_info_3], summarize_to_slack) workflow.insert_tasks([movie_info_1, movie_info_2, movie_info_3], compare_movies, summarize_to_slack) @@ -672,7 +642,7 @@ def test_run_topology_4_imperative_insert(self): ], ) def test_run_raises_on_cycle(self, tasks): - workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=tasks) + workflow = Workflow(tasks=tasks) with pytest.raises(ValueError) as e: workflow.run() @@ -684,7 +654,7 @@ def test_input_task(self): task2 = PromptTask("prompt2") task3 = PromptTask("prompt3") task4 = PromptTask("prompt4") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() workflow + task1 workflow + task4 @@ -697,7 +667,7 @@ def test_output_task(self): task2 = PromptTask("prompt2") task3 = PromptTask("prompt3") task4 = PromptTask("prompt4") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() workflow + task1 workflow + task4 @@ -709,7 +679,7 @@ def test_output_task(self): task1.add_children([task2, task3]) # task4 is the final task, but its defined at index 0 - workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[task4, task1, task2, task3]) + workflow = Workflow(tasks=[task4, task1, task2, task3]) # output_task topologically should be task4 assert task4 == workflow.output_task @@ -719,7 +689,7 @@ def test_to_graph(self): task2 = PromptTask("prompt2", id="task2") task3 = PromptTask("prompt3", id="task3") task4 = PromptTask("prompt4", id="task4") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() workflow + task1 workflow + task4 @@ -736,7 +706,7 @@ def test_order_tasks(self): task2 = PromptTask("prompt2", id="task2") task3 = PromptTask("prompt3", id="task3") task4 = PromptTask("prompt4", id="task4") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() workflow + task1 workflow + task4 @@ -753,7 +723,7 @@ def test_context(self): parent = PromptTask("parent") task = PromptTask("test") child = PromptTask("child") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() workflow + parent workflow + task @@ -776,20 +746,10 @@ def test_context(self): assert context["parents"] == {parent.id: parent} assert context["children"] == {child.id: child} - def test_deprecation(self): - with pytest.deprecated_call(): - Workflow(prompt_driver=MockPromptDriver()) - - with pytest.deprecated_call(): - Workflow(embedding_driver=MockEmbeddingDriver()) - - with pytest.deprecated_call(): - Workflow(stream=True) - def test_run_with_error_artifact(self, error_artifact_task, waiting_task): end_task = PromptTask("end") end_task.add_parents([error_artifact_task, waiting_task]) - workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[waiting_task, error_artifact_task, end_task]) + workflow = Workflow(tasks=[waiting_task, error_artifact_task, end_task]) workflow.run() assert workflow.output is None @@ -797,9 +757,7 @@ def test_run_with_error_artifact(self, error_artifact_task, waiting_task): def test_run_with_error_artifact_no_fail_fast(self, error_artifact_task, waiting_task): end_task = PromptTask("end") end_task.add_parents([error_artifact_task, waiting_task]) - workflow = Workflow( - prompt_driver=MockPromptDriver(), tasks=[waiting_task, error_artifact_task, end_task], fail_fast=False - ) + workflow = Workflow(tasks=[waiting_task, error_artifact_task, end_task], fail_fast=False) workflow.run() assert workflow.output is not None diff --git a/tests/unit/tasks/test_audio_transcription_task.py b/tests/unit/tasks/test_audio_transcription_task.py index 4cc860c32..33405ad10 100644 --- a/tests/unit/tasks/test_audio_transcription_task.py +++ b/tests/unit/tasks/test_audio_transcription_task.py @@ -6,7 +6,6 @@ from griptape.engines import AudioTranscriptionEngine from griptape.structures import Agent, Pipeline from griptape.tasks import AudioTranscriptionTask, BaseTask -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestAudioTranscriptionTask: @@ -41,7 +40,7 @@ def test_run(self, audio_artifact, audio_transcription_engine): audio_transcription_engine.run.return_value = TextArtifact("mock transcription") task = AudioTranscriptionTask(audio_artifact, audio_transcription_engine=audio_transcription_engine) - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline.add_task(task) assert pipeline.run().output.to_text() == "mock transcription" diff --git a/tests/unit/tasks/test_base_multi_text_input_task.py b/tests/unit/tasks/test_base_multi_text_input_task.py index 3d8d67a55..8eaa832ae 100644 --- a/tests/unit/tasks/test_base_multi_text_input_task.py +++ b/tests/unit/tasks/test_base_multi_text_input_task.py @@ -1,7 +1,6 @@ from griptape.artifacts import TextArtifact from griptape.structures import Pipeline from tests.mocks.mock_multi_text_input_task import MockMultiTextInputTask -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestBaseMultiTextInputTask: @@ -42,7 +41,7 @@ def test_full_context(self): parent = MockMultiTextInputTask(("parent1", "parent2")) subtask = MockMultiTextInputTask(("test1", "test2"), context={"foo": "bar"}) child = MockMultiTextInputTask(("child2", "child2")) - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline.add_tasks(parent, subtask, child) diff --git a/tests/unit/tasks/test_base_task.py b/tests/unit/tasks/test_base_task.py index 636515106..d22ef35f7 100644 --- a/tests/unit/tasks/test_base_task.py +++ b/tests/unit/tasks/test_base_task.py @@ -7,8 +7,6 @@ from griptape.events.event_listener import EventListener from griptape.structures import Agent, Workflow from griptape.tasks import ActionsSubtask -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver -from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_task import MockTask from tests.mocks.mock_tool.tool import MockTool @@ -18,10 +16,9 @@ class TestBaseTask: def task(self): EventBus.event_listeners = [EventListener(handler=Mock())] agent = Agent( - prompt_driver=MockPromptDriver(), - embedding_driver=MockEmbeddingDriver(), tools=[MockTool()], ) + Config.event_listeners = [EventListener(handler=Mock())] agent.add_task(MockTask("foobar", max_meta_memory_entries=2)) diff --git a/tests/unit/tasks/test_base_text_input_task.py b/tests/unit/tasks/test_base_text_input_task.py index 86dc98805..ff6afe42b 100644 --- a/tests/unit/tasks/test_base_text_input_task.py +++ b/tests/unit/tasks/test_base_text_input_task.py @@ -1,7 +1,6 @@ from griptape.artifacts import TextArtifact from griptape.rules import Rule, Ruleset from griptape.structures import Pipeline -from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_text_input_task import MockTextInputTask @@ -31,7 +30,7 @@ def test_full_context(self): parent = MockTextInputTask("parent") subtask = MockTextInputTask("test", context={"foo": "bar"}) child = MockTextInputTask("child") - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline.add_tasks(parent, subtask, child) diff --git a/tests/unit/tasks/test_code_execution_task.py b/tests/unit/tasks/test_code_execution_task.py index 3178e29db..e2c492fad 100644 --- a/tests/unit/tasks/test_code_execution_task.py +++ b/tests/unit/tasks/test_code_execution_task.py @@ -1,7 +1,6 @@ from griptape.artifacts import BaseArtifact, ErrorArtifact, TextArtifact from griptape.structures import Pipeline from griptape.tasks import CodeExecutionTask -from tests.mocks.mock_prompt_driver import MockPromptDriver def hello_world(task: CodeExecutionTask) -> BaseArtifact: @@ -27,7 +26,7 @@ def test_hello_world_fn(self): # Using a Pipeline # Overriding the input because we are implementing the task not the Pipeline def test_noop_fn(self): - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() task = CodeExecutionTask("No Op", run_fn=non_outputting) pipeline.add_task(task) temp = task.run() diff --git a/tests/unit/tasks/test_extraction_task.py b/tests/unit/tasks/test_extraction_task.py index afa73a506..76a4c3bd2 100644 --- a/tests/unit/tasks/test_extraction_task.py +++ b/tests/unit/tasks/test_extraction_task.py @@ -3,15 +3,12 @@ from griptape.engines import CsvExtractionEngine from griptape.structures import Agent from griptape.tasks import ExtractionTask -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestExtractionTask: @pytest.fixture() def task(self): - return ExtractionTask( - extraction_engine=CsvExtractionEngine(prompt_driver=MockPromptDriver()), args={"column_names": ["test1"]} - ) + return ExtractionTask(extraction_engine=CsvExtractionEngine(), args={"column_names": ["test1"]}) def test_run(self, task): agent = Agent() diff --git a/tests/unit/tasks/test_prompt_task.py b/tests/unit/tasks/test_prompt_task.py index 4a618e0d1..cfe853226 100644 --- a/tests/unit/tasks/test_prompt_task.py +++ b/tests/unit/tasks/test_prompt_task.py @@ -12,7 +12,7 @@ class TestPromptTask: def test_run(self): task = PromptTask("test") - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline.add_task(task) diff --git a/tests/unit/tasks/test_rag_task.py b/tests/unit/tasks/test_rag_task.py index b205d385a..f70a61bdd 100644 --- a/tests/unit/tasks/test_rag_task.py +++ b/tests/unit/tasks/test_rag_task.py @@ -5,7 +5,6 @@ from griptape.engines.rag.stages import ResponseRagStage from griptape.structures import Agent from griptape.tasks import RagTask -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestRagTask: @@ -13,11 +12,7 @@ class TestRagTask: def task(self): return RagTask( input="test", - rag_engine=RagEngine( - response_stage=ResponseRagStage( - response_module=PromptResponseRagModule(prompt_driver=MockPromptDriver()) - ) - ), + rag_engine=RagEngine(response_stage=ResponseRagStage(response_module=PromptResponseRagModule())), ) def test_run(self, task): diff --git a/tests/unit/tasks/test_structure_run_task.py b/tests/unit/tasks/test_structure_run_task.py index 1053ade9e..8df0e6598 100644 --- a/tests/unit/tasks/test_structure_run_task.py +++ b/tests/unit/tasks/test_structure_run_task.py @@ -5,9 +5,11 @@ class TestStructureRunTask: - def test_run(self): - agent = Agent(prompt_driver=MockPromptDriver(mock_output="agent mock output")) - pipeline = Pipeline(prompt_driver=MockPromptDriver(mock_output="pipeline mock output")) + def test_run(self, mock_config): + mock_config.prompt_driver = MockPromptDriver(mock_output="agent mock output") + agent = Agent() + mock_config.prompt_driver = MockPromptDriver(mock_output="pipeline mock output") + pipeline = Pipeline() driver = LocalStructureRunDriver(structure_factory_fn=lambda: agent) task = StructureRunTask(driver=driver) diff --git a/tests/unit/tasks/test_text_summary_task.py b/tests/unit/tasks/test_text_summary_task.py index 438d2bae4..f83075f2a 100644 --- a/tests/unit/tasks/test_text_summary_task.py +++ b/tests/unit/tasks/test_text_summary_task.py @@ -6,7 +6,7 @@ class TestTextSummaryTask: def test_run(self): - task = TextSummaryTask("test", summary_engine=PromptSummaryEngine(prompt_driver=MockPromptDriver())) + task = TextSummaryTask("test", summary_engine=PromptSummaryEngine()) agent = Agent() agent.add_task(task) diff --git a/tests/unit/tasks/test_text_to_speech_task.py b/tests/unit/tasks/test_text_to_speech_task.py index 3c629c69d..44348fef0 100644 --- a/tests/unit/tasks/test_text_to_speech_task.py +++ b/tests/unit/tasks/test_text_to_speech_task.py @@ -4,7 +4,6 @@ from griptape.engines import TextToSpeechEngine from griptape.structures import Agent, Pipeline from griptape.tasks import BaseTask, TextToSpeechTask -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestTextToSpeechTask: @@ -40,7 +39,7 @@ def test_run(self): text_to_speech_engine.run.return_value = AudioArtifact(b"audio content", format="mp3") task = TextToSpeechTask("some text", text_to_speech_engine=text_to_speech_engine) - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline.add_task(task) assert isinstance(pipeline.run().output, AudioArtifact) diff --git a/tests/unit/tasks/test_tool_task.py b/tests/unit/tasks/test_tool_task.py index dfc679919..90a7075fa 100644 --- a/tests/unit/tasks/test_tool_task.py +++ b/tests/unit/tasks/test_tool_task.py @@ -5,7 +5,6 @@ from griptape.artifacts import TextArtifact from griptape.structures import Agent from griptape.tasks import ActionsSubtask, ToolTask -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_tool.tool import MockTool from tests.utils import defaults @@ -166,13 +165,12 @@ class TestToolTask: } @pytest.fixture() - def agent(self): + def agent(self, mock_config): output_dict = {"tag": "foo", "name": "MockTool", "path": "test", "input": {"values": {"test": "foobar"}}} - return Agent( - prompt_driver=MockPromptDriver(mock_output=f"```python foo bar\n{json.dumps(output_dict)}"), - embedding_driver=MockEmbeddingDriver(), - ) + mock_config.prompt_driver = MockPromptDriver(mock_output=f"```python foo bar\n{json.dumps(output_dict)}") + + return Agent() def test_run_without_memory(self, agent): task = ToolTask(tool=MockTool()) diff --git a/tests/unit/tasks/test_toolkit_task.py b/tests/unit/tasks/test_toolkit_task.py index a47e4687b..1b89ddf70 100644 --- a/tests/unit/tasks/test_toolkit_task.py +++ b/tests/unit/tasks/test_toolkit_task.py @@ -2,7 +2,6 @@ from griptape.common import ToolAction from griptape.structures import Agent from griptape.tasks import ActionsSubtask, PromptTask, ToolkitTask -from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_tool.tool import MockTool from tests.utils import defaults @@ -175,7 +174,7 @@ def test_run(self, mock_config): mock_config.prompt_driver.mock_output = output task = ToolkitTask("test", tools=[MockTool(name="Tool1"), MockTool(name="Tool2")]) - agent = Agent(prompt_driver=MockPromptDriver(mock_output=output)) + agent = Agent() agent.add_task(task) @@ -190,7 +189,7 @@ def test_run_max_subtasks(self, mock_config): mock_config.prompt_driver.mock_output = output task = ToolkitTask("test", tools=[MockTool(name="Tool1")], max_subtasks=3) - agent = Agent(prompt_driver=MockPromptDriver(mock_output=output)) + agent = Agent() agent.add_task(task) @@ -204,7 +203,7 @@ def test_run_invalid_react_prompt(self, mock_config): mock_config.prompt_driver.mock_output = output task = ToolkitTask("test", tools=[MockTool(name="Tool1")], max_subtasks=3) - agent = Agent(prompt_driver=MockPromptDriver(mock_output=output)) + agent = Agent() agent.add_task(task) diff --git a/tests/unit/tools/test_structure_run_client.py b/tests/unit/tools/test_structure_run_client.py index d498b7c56..ee76d4da1 100644 --- a/tests/unit/tools/test_structure_run_client.py +++ b/tests/unit/tools/test_structure_run_client.py @@ -3,14 +3,12 @@ from griptape.drivers.structure_run.local_structure_run_driver import LocalStructureRunDriver from griptape.structures import Agent from griptape.tools import StructureRunClient -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestStructureRunClient: @pytest.fixture() def client(self): - driver = MockPromptDriver() - agent = Agent(prompt_driver=driver) + agent = Agent() return StructureRunClient( description="foo bar", driver=LocalStructureRunDriver(structure_factory_fn=lambda: agent) diff --git a/tests/unit/utils/test_chat.py b/tests/unit/utils/test_chat.py index 42ecc59c3..5f97d1baf 100644 --- a/tests/unit/utils/test_chat.py +++ b/tests/unit/utils/test_chat.py @@ -1,14 +1,13 @@ from griptape.memory.structure import ConversationMemory from griptape.structures import Agent from griptape.utils import Chat -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestConversation: def test_init(self): import logging - agent = Agent(prompt_driver=MockPromptDriver(), conversation_memory=ConversationMemory()) + agent = Agent(conversation_memory=ConversationMemory()) chat = Chat( agent, diff --git a/tests/unit/utils/test_conversation.py b/tests/unit/utils/test_conversation.py index 28ee72409..a07d15cdb 100644 --- a/tests/unit/utils/test_conversation.py +++ b/tests/unit/utils/test_conversation.py @@ -2,12 +2,11 @@ from griptape.structures import Pipeline from griptape.tasks import PromptTask from griptape.utils import Conversation -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestConversation: def test_lines(self): - pipeline = Pipeline(prompt_driver=MockPromptDriver(), conversation_memory=ConversationMemory()) + pipeline = Pipeline(conversation_memory=ConversationMemory()) pipeline.add_tasks(PromptTask("question 1")) @@ -22,7 +21,7 @@ def test_lines(self): assert lines[3] == "A: mock output" def test_prompt_stack_conversation_memory(self): - pipeline = Pipeline(prompt_driver=MockPromptDriver(), conversation_memory=ConversationMemory()) + pipeline = Pipeline(conversation_memory=ConversationMemory()) pipeline.add_tasks(PromptTask("question 1")) @@ -36,8 +35,7 @@ def test_prompt_stack_conversation_memory(self): def test_prompt_stack_summary_conversation_memory(self): pipeline = Pipeline( - prompt_driver=MockPromptDriver(), - conversation_memory=SummaryConversationMemory(summary="foobar", prompt_driver=MockPromptDriver()), + conversation_memory=SummaryConversationMemory(summary="foobar"), ) pipeline.add_tasks(PromptTask("question 1")) @@ -52,7 +50,7 @@ def test_prompt_stack_summary_conversation_memory(self): assert lines[2] == "assistant: mock output" def test___str__(self): - pipeline = Pipeline(prompt_driver=MockPromptDriver(), conversation_memory=ConversationMemory()) + pipeline = Pipeline(conversation_memory=ConversationMemory()) pipeline.add_tasks(PromptTask("question 1")) diff --git a/tests/unit/utils/test_file_utils.py b/tests/unit/utils/test_file_utils.py index a9c122126..00df6958d 100644 --- a/tests/unit/utils/test_file_utils.py +++ b/tests/unit/utils/test_file_utils.py @@ -3,7 +3,6 @@ from griptape import utils from griptape.loaders import TextLoader -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver MAX_TOKENS = 50 @@ -32,7 +31,7 @@ def test_load_files(self): def test_load_file_with_loader(self): dirname = os.path.dirname(__file__) file = utils.load_file(os.path.join(dirname, "../../", "resources/foobar-many.txt")) - artifacts = TextLoader(max_tokens=MAX_TOKENS, embedding_driver=MockEmbeddingDriver()).load(file) + artifacts = TextLoader(max_tokens=MAX_TOKENS).load(file) assert len(artifacts) == 39 assert isinstance(artifacts, list) @@ -43,7 +42,7 @@ def test_load_files_with_loader(self): sources = ["resources/foobar-many.txt"] sources = [os.path.join(dirname, "../../", source) for source in sources] files = utils.load_files(sources) - loader = TextLoader(max_tokens=MAX_TOKENS, embedding_driver=MockEmbeddingDriver()) + loader = TextLoader(max_tokens=MAX_TOKENS) collection = loader.load_collection(list(files.values())) test_file_artifacts = collection[loader.to_key(files[utils.str_to_hash(sources[0])])] diff --git a/tests/unit/utils/test_structure_visualizer.py b/tests/unit/utils/test_structure_visualizer.py index f6e621b91..8a055cb21 100644 --- a/tests/unit/utils/test_structure_visualizer.py +++ b/tests/unit/utils/test_structure_visualizer.py @@ -1,12 +1,11 @@ from griptape.structures import Agent, Pipeline, Workflow from griptape.tasks import PromptTask from griptape.utils import StructureVisualizer -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestStructureVisualizer: def test_agent(self): - agent = Agent(prompt_driver=MockPromptDriver(), tasks=[PromptTask("test1", id="task1")]) + agent = Agent(tasks=[PromptTask("test1", id="task1")]) visualizer = StructureVisualizer(agent) result = visualizer.to_url() @@ -15,7 +14,6 @@ def test_agent(self): def test_pipeline(self): pipeline = Pipeline( - prompt_driver=MockPromptDriver(), tasks=[ PromptTask("test1", id="task1"), PromptTask("test2", id="task2"), @@ -34,7 +32,6 @@ def test_pipeline(self): def test_workflow(self): workflow = Workflow( - prompt_driver=MockPromptDriver(), tasks=[ PromptTask("test1", id="task1"), PromptTask("test2", id="task2", parent_ids=["task1"]), diff --git a/tests/utils/defaults.py b/tests/utils/defaults.py index bad7f0d79..e3bcde29b 100644 --- a/tests/utils/defaults.py +++ b/tests/utils/defaults.py @@ -17,9 +17,9 @@ def text_tool_artifact_storage(): rag_engine=rag_engine(MockPromptDriver(), vector_store_driver), vector_store_driver=vector_store_driver, retrieval_rag_module_name="VectorStoreRetrievalRagModule", - summary_engine=PromptSummaryEngine(prompt_driver=MockPromptDriver()), - csv_extraction_engine=CsvExtractionEngine(prompt_driver=MockPromptDriver()), - json_extraction_engine=JsonExtractionEngine(prompt_driver=MockPromptDriver()), + summary_engine=PromptSummaryEngine(), + csv_extraction_engine=CsvExtractionEngine(), + json_extraction_engine=JsonExtractionEngine(), ) diff --git a/tests/utils/test_reference_utils.py b/tests/utils/test_reference_utils.py index c3491f5d0..47da18713 100644 --- a/tests/utils/test_reference_utils.py +++ b/tests/utils/test_reference_utils.py @@ -1,12 +1,11 @@ from griptape.artifacts import TextArtifact from griptape.common import Reference from griptape.engines.rag.modules import PromptResponseRagModule -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestReferenceUtils: def test_references_from_artifacts(self): - module = PromptResponseRagModule(prompt_driver=MockPromptDriver()) + module = PromptResponseRagModule() reference1 = Reference(title="foo") reference2 = Reference(title="bar") artifacts = [ From e8c1fff7155dcdfd6ecf3b8dcdf5322ff3f48a49 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 6 Aug 2024 12:21:06 -0700 Subject: [PATCH 05/63] Namespace config --- CHANGELOG.md | 4 ++-- griptape/config/base_structure_config.py | 9 +-------- griptape/config/config.py | 16 +++++++++++++++- .../engines/audio/audio_transcription_engine.py | 2 +- griptape/engines/audio/text_to_speech_engine.py | 2 +- .../engines/extraction/base_extraction_engine.py | 2 +- .../image/base_image_generation_engine.py | 2 +- .../engines/image_query/image_query_engine.py | 4 +++- .../response/prompt_response_rag_module.py | 2 +- .../vector_store_retrieval_rag_module.py | 2 +- .../engines/summary/prompt_summary_engine.py | 2 +- .../memory/structure/base_conversation_memory.py | 4 ++-- .../structure/summary_conversation_memory.py | 2 +- .../memory/task/storage/text_artifact_storage.py | 2 +- griptape/structures/structure.py | 8 ++++---- griptape/tasks/prompt_task.py | 2 +- griptape/utils/chat.py | 4 ++-- griptape/utils/stream.py | 2 +- tests/unit/conftest.py | 9 +-------- .../drivers/prompt/test_base_prompt_driver.py | 6 +++--- .../test_local_structure_run_driver.py | 2 +- tests/unit/events/test_event_listener.py | 2 +- .../memory/structure/test_conversation_memory.py | 8 +++++--- tests/unit/tasks/test_json_extraction_task.py | 4 +++- tests/unit/tasks/test_structure_run_task.py | 4 ++-- tests/unit/tasks/test_tool_task.py | 4 +++- tests/unit/tasks/test_toolkit_task.py | 6 +++--- tests/unit/utils/test_stream.py | 4 ++-- 28 files changed, 64 insertions(+), 56 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ea88983f3..c0720ca47 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -263,7 +263,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - **BREAKING**: Updated OpenAI-based image query drivers to remove Vision from the name. - **BREAKING**: `off_prompt` now defaults to `False` on all Tools, making Task Memory something that must be explicitly opted into. -- **BREAKING**: Removed `StructureConfig.global_drivers`. Pass Drivers directly to the Structure Config instead. +- **BREAKING**: Removed `StructureConfig.drivers.global_drivers`. Pass Drivers directly to the Structure Config instead. - **BREAKING**: Removed `StructureConfig.task_memory` in favor of configuring directly on the Structure. - **BREAKING**: Updated OpenAI-based image query drivers to remove Vision from the name. - **BREAKING**: `off_prompt` now defaults to `False` on all Tools, making Task Memory something that must be explicitly opted into. @@ -391,7 +391,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Deprecation warnings not displaying for `Structure.prompt_driver`, `Structure.embedding_driver`, and `Structure.stream`. - `DummyException` error message not fully displaying. -- `StructureConfig.task_memory` not defaulting to using `StructureConfig.global_drivers` by default. +- `StructureConfig.task_memory` not defaulting to using `StructureConfig.drivers.global_drivers` by default. ## [0.23.1] - 2024-03-07 diff --git a/griptape/config/base_structure_config.py b/griptape/config/base_structure_config.py index 84743c4da..bc9238df2 100644 --- a/griptape/config/base_structure_config.py +++ b/griptape/config/base_structure_config.py @@ -6,7 +6,6 @@ from attrs import define, field from griptape.config import BaseConfig -from griptape.utils import dict_merge if TYPE_CHECKING: from griptape.drivers import ( @@ -22,7 +21,7 @@ @define -class BaseStructureConfig(BaseConfig, ABC, EventPublisherMixin): +class BaseStructureConfig(BaseConfig, ABC): prompt_driver: BasePromptDriver = field(kw_only=True, metadata={"serializable": True}) image_generation_driver: BaseImageGenerationDriver = field(kw_only=True, metadata={"serializable": True}) image_query_driver: BaseImageQueryDriver = field(kw_only=True, metadata={"serializable": True}) @@ -35,9 +34,3 @@ class BaseStructureConfig(BaseConfig, ABC, EventPublisherMixin): ) text_to_speech_driver: BaseTextToSpeechDriver = field(kw_only=True, metadata={"serializable": True}) audio_transcription_driver: BaseAudioTranscriptionDriver = field(kw_only=True, metadata={"serializable": True}) - - def merge_config(self, config: dict) -> BaseStructureConfig: - base_config = self.to_dict() - merged_config = dict_merge(base_config, config) - - return BaseStructureConfig.from_dict(merged_config) diff --git a/griptape/config/config.py b/griptape/config/config.py index e3017f8b6..3985abca2 100644 --- a/griptape/config/config.py +++ b/griptape/config/config.py @@ -1,3 +1,17 @@ +from attrs import define + +from griptape.config.base_config import BaseConfig +from griptape.config.base_structure_config import BaseStructureConfig +from griptape.mixins.event_publisher_mixin import EventPublisherMixin + from .openai_structure_config import OpenAiStructureConfig -Config = OpenAiStructureConfig() + +@define +class _Config(BaseConfig, EventPublisherMixin): + drivers: BaseStructureConfig + + +Config = _Config( + drivers=OpenAiStructureConfig(), +) diff --git a/griptape/engines/audio/audio_transcription_engine.py b/griptape/engines/audio/audio_transcription_engine.py index a3769842d..aad669d70 100644 --- a/griptape/engines/audio/audio_transcription_engine.py +++ b/griptape/engines/audio/audio_transcription_engine.py @@ -8,7 +8,7 @@ @define class AudioTranscriptionEngine: audio_transcription_driver: BaseAudioTranscriptionDriver = field( - default=Factory(lambda: Config.audio_transcription_driver), kw_only=True + default=Factory(lambda: Config.drivers.audio_transcription_driver), kw_only=True ) def run(self, audio: AudioArtifact, *args, **kwargs) -> TextArtifact: diff --git a/griptape/engines/audio/text_to_speech_engine.py b/griptape/engines/audio/text_to_speech_engine.py index 361ecc127..16634ce45 100644 --- a/griptape/engines/audio/text_to_speech_engine.py +++ b/griptape/engines/audio/text_to_speech_engine.py @@ -14,7 +14,7 @@ @define class TextToSpeechEngine: text_to_speech_driver: BaseTextToSpeechDriver = field( - default=Factory(lambda: Config.text_to_speech_driver), kw_only=True + default=Factory(lambda: Config.drivers.text_to_speech_driver), kw_only=True ) def run(self, prompts: list[str], *args, **kwargs) -> AudioArtifact: diff --git a/griptape/engines/extraction/base_extraction_engine.py b/griptape/engines/extraction/base_extraction_engine.py index 3ff6a96e3..03826ab43 100644 --- a/griptape/engines/extraction/base_extraction_engine.py +++ b/griptape/engines/extraction/base_extraction_engine.py @@ -18,7 +18,7 @@ class BaseExtractionEngine(ABC): max_token_multiplier: float = field(default=0.5, kw_only=True) chunk_joiner: str = field(default="\n\n", kw_only=True) - prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.prompt_driver), kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.drivers.prompt_driver), kw_only=True) chunker: BaseChunker = field( default=Factory( lambda self: TextChunker(tokenizer=self.prompt_driver.tokenizer, max_tokens=self.max_chunker_tokens), diff --git a/griptape/engines/image/base_image_generation_engine.py b/griptape/engines/image/base_image_generation_engine.py index eabf38be3..4187dde79 100644 --- a/griptape/engines/image/base_image_generation_engine.py +++ b/griptape/engines/image/base_image_generation_engine.py @@ -16,7 +16,7 @@ @define class BaseImageGenerationEngine(ABC): image_generation_driver: BaseImageGenerationDriver = field( - kw_only=True, default=Factory(lambda: Config.image_generation_driver) + kw_only=True, default=Factory(lambda: Config.drivers.image_generation_driver) ) @abstractmethod diff --git a/griptape/engines/image_query/image_query_engine.py b/griptape/engines/image_query/image_query_engine.py index ed6a64ee3..5090e2f27 100644 --- a/griptape/engines/image_query/image_query_engine.py +++ b/griptape/engines/image_query/image_query_engine.py @@ -13,7 +13,9 @@ @define class ImageQueryEngine: - image_query_driver: BaseImageQueryDriver = field(default=Factory(lambda: Config.image_query_driver), kw_only=True) + image_query_driver: BaseImageQueryDriver = field( + default=Factory(lambda: Config.drivers.image_query_driver), kw_only=True + ) def run(self, query: str, images: list[ImageArtifact]) -> TextArtifact: return self.image_query_driver.query(query, images) diff --git a/griptape/engines/rag/modules/response/prompt_response_rag_module.py b/griptape/engines/rag/modules/response/prompt_response_rag_module.py index 2e7b486b6..979723beb 100644 --- a/griptape/engines/rag/modules/response/prompt_response_rag_module.py +++ b/griptape/engines/rag/modules/response/prompt_response_rag_module.py @@ -17,7 +17,7 @@ @define(kw_only=True) class PromptResponseRagModule(BaseResponseRagModule): answer_token_offset: int = field(default=400) - prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.prompt_driver), kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.drivers.prompt_driver), kw_only=True) generate_system_template: Callable[[RagContext, list[TextArtifact]], str] = field( default=Factory(lambda self: self.default_system_template_generator, takes_self=True), ) diff --git a/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py b/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py index b0deca67d..392a6836d 100644 --- a/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py +++ b/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py @@ -18,7 +18,7 @@ @define(kw_only=True) class VectorStoreRetrievalRagModule(BaseRetrievalRagModule): - vector_store_driver: BaseVectorStoreDriver = field(default=Factory(lambda: Config.vector_store_driver)) + vector_store_driver: BaseVectorStoreDriver = field(default=Factory(lambda: Config.drivers.vector_store_driver)) query_params: dict[str, Any] = field(factory=dict) process_query_output_fn: Callable[[list[BaseVectorStoreDriver.Entry]], Sequence[TextArtifact]] = field( default=Factory(lambda: lambda es: [e.to_artifact() for e in es]), diff --git a/griptape/engines/summary/prompt_summary_engine.py b/griptape/engines/summary/prompt_summary_engine.py index d06ebaa2f..2586a8e0c 100644 --- a/griptape/engines/summary/prompt_summary_engine.py +++ b/griptape/engines/summary/prompt_summary_engine.py @@ -22,7 +22,7 @@ class PromptSummaryEngine(BaseSummaryEngine): max_token_multiplier: float = field(default=0.5, kw_only=True) system_template_generator: J2 = field(default=Factory(lambda: J2("engines/summary/system.j2")), kw_only=True) user_template_generator: J2 = field(default=Factory(lambda: J2("engines/summary/user.j2")), kw_only=True) - prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.prompt_driver), kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.drivers.prompt_driver), kw_only=True) chunker: BaseChunker = field( default=Factory( lambda self: TextChunker(tokenizer=self.prompt_driver.tokenizer, max_tokens=self.max_chunker_tokens), diff --git a/griptape/memory/structure/base_conversation_memory.py b/griptape/memory/structure/base_conversation_memory.py index fb1cfdd8b..3c3a0aaca 100644 --- a/griptape/memory/structure/base_conversation_memory.py +++ b/griptape/memory/structure/base_conversation_memory.py @@ -18,7 +18,7 @@ @define class BaseConversationMemory(SerializableMixin, ABC): driver: Optional[BaseConversationMemoryDriver] = field( - default=Factory(lambda: Config.conversation_memory_driver), kw_only=True + default=Factory(lambda: Config.drivers.conversation_memory_driver), kw_only=True ) runs: list[Run] = field(factory=list, kw_only=True, metadata={"serializable": True}) structure: Structure = field(init=False) @@ -67,7 +67,7 @@ def add_to_prompt_stack(self, prompt_stack: PromptStack, index: Optional[int] = if self.autoprune and hasattr(self, "structure"): should_prune = True - prompt_driver = Config.prompt_driver + prompt_driver = Config.drivers.prompt_driver temp_stack = PromptStack() # Try to determine how many Conversation Memory runs we can diff --git a/griptape/memory/structure/summary_conversation_memory.py b/griptape/memory/structure/summary_conversation_memory.py index 807775d63..161a68eb3 100644 --- a/griptape/memory/structure/summary_conversation_memory.py +++ b/griptape/memory/structure/summary_conversation_memory.py @@ -18,7 +18,7 @@ @define class SummaryConversationMemory(ConversationMemory): offset: int = field(default=1, kw_only=True, metadata={"serializable": True}) - prompt_driver: BasePromptDriver = field(kw_only=True, default=Factory(lambda: Config.prompt_driver)) + prompt_driver: BasePromptDriver = field(kw_only=True, default=Factory(lambda: Config.drivers.prompt_driver)) summary: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) summary_index: int = field(default=0, kw_only=True, metadata={"serializable": True}) summary_template_generator: J2 = field(default=Factory(lambda: J2("memory/conversation/summary.j2")), kw_only=True) diff --git a/griptape/memory/task/storage/text_artifact_storage.py b/griptape/memory/task/storage/text_artifact_storage.py index 8a918c5f2..134274648 100644 --- a/griptape/memory/task/storage/text_artifact_storage.py +++ b/griptape/memory/task/storage/text_artifact_storage.py @@ -16,7 +16,7 @@ @define(kw_only=True) class TextArtifactStorage(BaseArtifactStorage): - vector_store_driver: BaseVectorStoreDriver = field(default=Factory(lambda: Config.vector_store_driver)) + vector_store_driver: BaseVectorStoreDriver = field(default=Factory(lambda: Config.drivers.vector_store_driver)) rag_engine: Optional[RagEngine] = field(default=None) retrieval_rag_module_name: Optional[str] = field(default=None) summary_engine: Optional[BaseSummaryEngine] = field(default=None) diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index 73b5e617a..010e8ef1f 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -140,10 +140,10 @@ def default_task_memory(self) -> TaskMemory: TextArtifact: TextArtifactStorage( rag_engine=self.rag_engine, retrieval_rag_module_name="VectorStoreRetrievalRagModule", - vector_store_driver=Config.vector_store_driver, - summary_engine=PromptSummaryEngine(prompt_driver=Config.prompt_driver), - csv_extraction_engine=CsvExtractionEngine(prompt_driver=Config.prompt_driver), - json_extraction_engine=JsonExtractionEngine(prompt_driver=Config.prompt_driver), + vector_store_driver=Config.drivers.vector_store_driver, + summary_engine=PromptSummaryEngine(prompt_driver=Config.drivers.prompt_driver), + csv_extraction_engine=CsvExtractionEngine(prompt_driver=Config.drivers.prompt_driver), + json_extraction_engine=JsonExtractionEngine(prompt_driver=Config.drivers.prompt_driver), ), BlobArtifact: BlobArtifactStorage(), }, diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 9f698787f..6997c9558 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -17,7 +17,7 @@ @define class PromptTask(RuleMixin, BaseTask): - prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.prompt_driver), kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.drivers.prompt_driver), kw_only=True) generate_system_template: Callable[[PromptTask], str] = field( default=Factory(lambda self: self.default_system_template_generator, takes_self=True), kw_only=True, diff --git a/griptape/utils/chat.py b/griptape/utils/chat.py index a8bdc9b13..6455efd14 100644 --- a/griptape/utils/chat.py +++ b/griptape/utils/chat.py @@ -27,7 +27,7 @@ class Chat: def default_output_fn(self, text: str) -> None: from griptape.config import Config - if Config.prompt_driver.stream: + if Config.drivers.prompt_driver.stream: print(text, end="", flush=True) # noqa: T201 else: print(text) # noqa: T201 @@ -44,7 +44,7 @@ def start(self) -> None: self.output_fn(self.exiting_text) break - if Config.prompt_driver.stream: + if Config.drivers.prompt_driver.stream: self.output_fn(self.processing_text + "\n") stream = Stream(self.structure).run(question) first_chunk = next(stream) diff --git a/griptape/utils/stream.py b/griptape/utils/stream.py index 7b5381202..7c716787b 100644 --- a/griptape/utils/stream.py +++ b/griptape/utils/stream.py @@ -36,7 +36,7 @@ class Stream: def validate_structure(self, _: Attribute, structure: Structure) -> None: from griptape.config import Config - if not Config.prompt_driver.stream: + if not Config.drivers.prompt_driver.stream: raise ValueError("prompt driver does not have streaming enabled, enable with stream=True") _event_queue: Queue[BaseEvent] = field(default=Factory(lambda: Queue())) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 7d2f8203d..e49de0021 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -16,13 +16,6 @@ def event_bus(): @pytest.fixture(autouse=True) def mock_config(): - mock_structure_config = MockStructureConfig() - Config.prompt_driver = mock_structure_config.prompt_driver - Config.image_generation_driver = mock_structure_config.image_generation_driver - Config.image_query_driver = mock_structure_config.image_query_driver - Config.embedding_driver = mock_structure_config.embedding_driver - Config.vector_store_driver = mock_structure_config.vector_store_driver - Config.text_to_speech_driver = mock_structure_config.text_to_speech_driver - Config.audio_transcription_driver = mock_structure_config.audio_transcription_driver + Config.drivers = MockStructureConfig() return Config diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index d95e7a5a7..84fd0bed1 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -11,7 +11,7 @@ class TestBasePromptDriver: def test_run_via_pipeline_retries_success(self, mock_config): - mock_config.prompt_driver = MockPromptDriver(max_attempts=2) + mock_config.drivers.prompt_driver = MockPromptDriver(max_attempts=2) pipeline = Pipeline() pipeline.add_task(PromptTask("test")) @@ -19,7 +19,7 @@ def test_run_via_pipeline_retries_success(self, mock_config): assert isinstance(pipeline.run().output_task.output, TextArtifact) def test_run_via_pipeline_retries_failure(self, mock_config): - mock_config.prompt_driver = MockFailingPromptDriver(max_failures=2, max_attempts=1) + mock_config.drivers.prompt_driver = MockFailingPromptDriver(max_failures=2, max_attempts=1) pipeline = Pipeline() pipeline.add_task(PromptTask("test")) @@ -47,7 +47,7 @@ def test_run_with_stream(self): assert result.value == "mock output" def test_run_with_tools(self, mock_config): - mock_config.prompt_driver = MockPromptDriver(max_attempts=1, use_native_tools=True) + mock_config.drivers.prompt_driver = MockPromptDriver(max_attempts=1, use_native_tools=True) pipeline = Pipeline() pipeline.add_task(ToolkitTask(tools=[MockTool()])) diff --git a/tests/unit/drivers/structure_run/test_local_structure_run_driver.py b/tests/unit/drivers/structure_run/test_local_structure_run_driver.py index 318a41aa2..b2e9c069b 100644 --- a/tests/unit/drivers/structure_run/test_local_structure_run_driver.py +++ b/tests/unit/drivers/structure_run/test_local_structure_run_driver.py @@ -20,7 +20,7 @@ def test_run(self): def test_run_with_env(self, mock_config): pipeline = Pipeline() - mock_config.prompt_driver = MockPromptDriver(mock_output=lambda _: os.environ["KEY"]) + mock_config.drivers.prompt_driver = MockPromptDriver(mock_output=lambda _: os.environ["KEY"]) agent = Agent() driver = LocalStructureRunDriver(structure_factory_fn=lambda: agent, env={"KEY": "value"}) task = StructureRunTask(driver=driver) diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index ed978db78..d2681877f 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -26,7 +26,7 @@ class TestEventListener: @pytest.fixture() def pipeline(self, mock_config): - mock_config.prompt_driver = MockPromptDriver(stream=True) + mock_config.drivers.prompt_driver = MockPromptDriver(stream=True) task = ToolkitTask("test", tools=[MockTool(name="Tool1")]) pipeline = Pipeline() diff --git a/tests/unit/memory/structure/test_conversation_memory.py b/tests/unit/memory/structure/test_conversation_memory.py index 77cebf193..06e54e6c4 100644 --- a/tests/unit/memory/structure/test_conversation_memory.py +++ b/tests/unit/memory/structure/test_conversation_memory.py @@ -97,7 +97,7 @@ def test_add_to_prompt_stack_autopruing_disabled(self): def test_add_to_prompt_stack_autopruning_enabled(self, mock_config): # All memory is pruned. - mock_config.prompt_driver = MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=0)) + mock_config.drivers.prompt_driver = MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=0)) agent = Agent() memory = ConversationMemory( autoprune=True, @@ -119,7 +119,9 @@ def test_add_to_prompt_stack_autopruning_enabled(self, mock_config): assert len(prompt_stack.messages) == 3 # No memory is pruned. - mock_config.prompt_driver = MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=1000)) + mock_config.drivers.prompt_driver = MockPromptDriver( + tokenizer=MockTokenizer(model="foo", max_input_tokens=1000) + ) agent = Agent() memory = ConversationMemory( autoprune=True, @@ -143,7 +145,7 @@ def test_add_to_prompt_stack_autopruning_enabled(self, mock_config): # One memory is pruned. # MockTokenizer's max_input_tokens set to one below the sum of memory + system prompt tokens # so that a single memory is pruned. - mock_config.prompt_driver = MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=160)) + mock_config.drivers.prompt_driver = MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=160)) agent = Agent() memory = ConversationMemory( autoprune=True, diff --git a/tests/unit/tasks/test_json_extraction_task.py b/tests/unit/tasks/test_json_extraction_task.py index 0189e6679..3eef4eec3 100644 --- a/tests/unit/tasks/test_json_extraction_task.py +++ b/tests/unit/tasks/test_json_extraction_task.py @@ -13,7 +13,9 @@ def task(self): return JsonExtractionTask("foo", args={"template_schema": Schema({"foo": "bar"}).json_schema("TemplateSchema")}) def test_run(self, task, mock_config): - mock_config.prompt_driver.mock_output = '[{"test_key_1": "test_value_1"}, {"test_key_2": "test_value_2"}]' + mock_config.drivers.prompt_driver.mock_output = ( + '[{"test_key_1": "test_value_1"}, {"test_key_2": "test_value_2"}]' + ) agent = Agent() agent.add_task(task) diff --git a/tests/unit/tasks/test_structure_run_task.py b/tests/unit/tasks/test_structure_run_task.py index 8df0e6598..2c0dc1b28 100644 --- a/tests/unit/tasks/test_structure_run_task.py +++ b/tests/unit/tasks/test_structure_run_task.py @@ -6,9 +6,9 @@ class TestStructureRunTask: def test_run(self, mock_config): - mock_config.prompt_driver = MockPromptDriver(mock_output="agent mock output") + mock_config.drivers.prompt_driver = MockPromptDriver(mock_output="agent mock output") agent = Agent() - mock_config.prompt_driver = MockPromptDriver(mock_output="pipeline mock output") + mock_config.drivers.prompt_driver = MockPromptDriver(mock_output="pipeline mock output") pipeline = Pipeline() driver = LocalStructureRunDriver(structure_factory_fn=lambda: agent) diff --git a/tests/unit/tasks/test_tool_task.py b/tests/unit/tasks/test_tool_task.py index 90a7075fa..70ab05e12 100644 --- a/tests/unit/tasks/test_tool_task.py +++ b/tests/unit/tasks/test_tool_task.py @@ -168,7 +168,9 @@ class TestToolTask: def agent(self, mock_config): output_dict = {"tag": "foo", "name": "MockTool", "path": "test", "input": {"values": {"test": "foobar"}}} - mock_config.prompt_driver = MockPromptDriver(mock_output=f"```python foo bar\n{json.dumps(output_dict)}") + mock_config.drivers.prompt_driver = MockPromptDriver( + mock_output=f"```python foo bar\n{json.dumps(output_dict)}" + ) return Agent() diff --git a/tests/unit/tasks/test_toolkit_task.py b/tests/unit/tasks/test_toolkit_task.py index 1b89ddf70..15f5a59b1 100644 --- a/tests/unit/tasks/test_toolkit_task.py +++ b/tests/unit/tasks/test_toolkit_task.py @@ -171,7 +171,7 @@ def test_init(self): def test_run(self, mock_config): output = """Answer: done""" - mock_config.prompt_driver.mock_output = output + mock_config.drivers.prompt_driver.mock_output = output task = ToolkitTask("test", tools=[MockTool(name="Tool1"), MockTool(name="Tool2")]) agent = Agent() @@ -186,7 +186,7 @@ def test_run(self, mock_config): def test_run_max_subtasks(self, mock_config): output = 'Actions: [{"tag": "foo", "name": "Tool1", "path": "test", "input": {"values": {"test": "value"}}}]' - mock_config.prompt_driver.mock_output = output + mock_config.drivers.prompt_driver.mock_output = output task = ToolkitTask("test", tools=[MockTool(name="Tool1")], max_subtasks=3) agent = Agent() @@ -200,7 +200,7 @@ def test_run_max_subtasks(self, mock_config): def test_run_invalid_react_prompt(self, mock_config): output = """foo bar""" - mock_config.prompt_driver.mock_output = output + mock_config.drivers.prompt_driver.mock_output = output task = ToolkitTask("test", tools=[MockTool(name="Tool1")], max_subtasks=3) agent = Agent() diff --git a/tests/unit/utils/test_stream.py b/tests/unit/utils/test_stream.py index 555daa4fd..48dbaae29 100644 --- a/tests/unit/utils/test_stream.py +++ b/tests/unit/utils/test_stream.py @@ -10,11 +10,11 @@ class TestStream: @pytest.fixture(params=[True, False]) def agent(self, request): - Config.prompt_driver.stream = request.param + Config.drivers.prompt_driver.stream = request.param return Agent() def test_init(self, agent): - if Config.prompt_driver.stream: + if Config.drivers.prompt_driver.stream: chat_stream = Stream(agent) assert chat_stream.structure == agent From 514665f8e7bdd9582ea1332045971a5fa563b21e Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 6 Aug 2024 12:35:52 -0700 Subject: [PATCH 06/63] Rename driver config fields --- docs/examples/multiple-agent-shared-memory.md | 9 ++- docs/examples/talk-to-a-video.md | 4 +- .../drivers/embedding-drivers.md | 6 +- .../drivers/event-listener-drivers.md | 4 +- .../drivers/prompt-drivers.md | 61 ++++++++--------- docs/griptape-framework/misc/events.md | 10 +-- docs/griptape-framework/structures/config.md | 66 +++++++++---------- .../structures/task-memory.md | 4 +- .../official-tools/rest-api-client.md | 6 +- griptape/config/__init__.py | 32 ++++----- ...fig.py => amazon_bedrock_driver_config.py} | 16 ++--- ...e_config.py => anthropic_driver_config.py} | 12 ++-- ...onfig.py => azure_openai_driver_config.py} | 16 ++--- griptape/config/base_driver_config.py | 34 ++++++++++ griptape/config/base_structure_config.py | 36 ---------- ...ture_config.py => cohere_driver_config.py} | 12 ++-- griptape/config/config.py | 8 +-- .../{structure_config.py => driver_config.py} | 20 +++--- ...ture_config.py => google_driver_config.py} | 10 +-- ...ture_config.py => openai_driver_config.py} | 18 ++--- .../audio/audio_transcription_engine.py | 2 +- .../engines/audio/text_to_speech_engine.py | 2 +- .../extraction/base_extraction_engine.py | 2 +- .../image/base_image_generation_engine.py | 2 +- .../engines/image_query/image_query_engine.py | 4 +- .../response/prompt_response_rag_module.py | 2 +- .../vector_store_retrieval_rag_module.py | 2 +- .../engines/summary/prompt_summary_engine.py | 2 +- griptape/exceptions/dummy_exception.py | 2 +- .../structure/base_conversation_memory.py | 4 +- .../structure/summary_conversation_memory.py | 2 +- .../task/storage/text_artifact_storage.py | 2 +- griptape/structures/structure.py | 8 +-- griptape/tasks/prompt_task.py | 2 +- griptape/utils/chat.py | 4 +- griptape/utils/stream.py | 2 +- ...ucture_config.py => mock_driver_config.py} | 18 +++-- ...y => test_amazon_bedrock_driver_config.py} | 47 +++++++------ ...fig.py => test_anthropic_driver_config.py} | 24 +++---- ....py => test_azure_openai_driver_config.py} | 22 +++---- ...config.py => test_cohere_driver_config.py} | 22 +++---- tests/unit/config/test_driver_config.py | 39 +++++++++++ ...config.py => test_google_driver_config.py} | 24 +++---- ...config.py => test_openai_driver_config.py} | 24 +++---- tests/unit/config/test_structure_config.py | 39 ----------- tests/unit/conftest.py | 4 +- .../drivers/prompt/test_base_prompt_driver.py | 6 +- .../test_local_structure_run_driver.py | 2 +- tests/unit/events/test_event_listener.py | 2 +- .../structure/test_conversation_memory.py | 8 +-- tests/unit/tasks/test_json_extraction_task.py | 4 +- tests/unit/tasks/test_structure_run_task.py | 4 +- tests/unit/tasks/test_tool_task.py | 4 +- tests/unit/tasks/test_toolkit_task.py | 6 +- tests/unit/utils/test_stream.py | 4 +- tests/utils/structure_tester.py | 4 +- 56 files changed, 355 insertions(+), 380 deletions(-) rename griptape/config/{amazon_bedrock_structure_config.py => amazon_bedrock_driver_config.py} (84%) rename griptape/config/{anthropic_structure_config.py => anthropic_driver_config.py} (76%) rename griptape/config/{azure_openai_structure_config.py => azure_openai_driver_config.py} (89%) create mode 100644 griptape/config/base_driver_config.py delete mode 100644 griptape/config/base_structure_config.py rename griptape/config/{cohere_structure_config.py => cohere_driver_config.py} (76%) rename griptape/config/{structure_config.py => driver_config.py} (74%) rename griptape/config/{google_structure_config.py => google_driver_config.py} (75%) rename griptape/config/{openai_structure_config.py => openai_driver_config.py} (76%) rename tests/mocks/{mock_structure_config.py => mock_driver_config.py} (63%) rename tests/unit/config/{test_amazon_bedrock_structure_config.py => test_amazon_bedrock_driver_config.py} (71%) rename tests/unit/config/{test_anthropic_structure_config.py => test_anthropic_driver_config.py} (65%) rename tests/unit/config/{test_azure_openai_structure_config.py => test_azure_openai_driver_config.py} (84%) rename tests/unit/config/{test_cohere_structure_config.py => test_cohere_driver_config.py} (59%) create mode 100644 tests/unit/config/test_driver_config.py rename tests/unit/config/{test_google_structure_config.py => test_google_driver_config.py} (63%) rename tests/unit/config/{test_openai_structure_config.py => test_openai_driver_config.py} (81%) delete mode 100644 tests/unit/config/test_structure_config.py diff --git a/docs/examples/multiple-agent-shared-memory.md b/docs/examples/multiple-agent-shared-memory.md index 109394d49..e6b092965 100644 --- a/docs/examples/multiple-agent-shared-memory.md +++ b/docs/examples/multiple-agent-shared-memory.md @@ -11,8 +11,7 @@ import os from griptape.tools import WebScraper, TaskMemoryClient from griptape.structures import Agent from griptape.drivers import AzureOpenAiEmbeddingDriver, AzureMongoDbVectorStoreDriver -from griptape.config import AzureOpenAiStructureConfig - +from griptape.config import AzureOpenAiDriverConfig AZURE_OPENAI_ENDPOINT_1 = os.environ["AZURE_OPENAI_ENDPOINT_1"] AZURE_OPENAI_API_KEY_1 = os.environ["AZURE_OPENAI_API_KEY_1"] @@ -26,7 +25,6 @@ MONGODB_INDEX_NAME = os.environ["MONGODB_INDEX_NAME"] MONGODB_VECTOR_PATH = os.environ["MONGODB_VECTOR_PATH"] MONGODB_CONNECTION_STRING = f"mongodb+srv://{MONGODB_USERNAME}:{MONGODB_PASSWORD}@{MONGODB_HOST}/{MONGODB_DATABASE_NAME}?tls=true&authMechanism=SCRAM-SHA-256&retrywrites=false&maxIdleTimeMS=120000" - embedding_driver = AzureOpenAiEmbeddingDriver( model='text-embedding-ada-002', azure_endpoint=AZURE_OPENAI_ENDPOINT_1, @@ -42,7 +40,7 @@ mongo_driver = AzureMongoDbVectorStoreDriver( vector_path=MONGODB_VECTOR_PATH, ) -config = AzureOpenAiStructureConfig( +config = AzureOpenAiDriverConfig( azure_endpoint=AZURE_OPENAI_ENDPOINT_1, vector_store_driver=mongo_driver, embedding_driver=embedding_driver, @@ -64,6 +62,7 @@ asker = Agent( ) if __name__ == "__main__": - loader.run("Load https://medium.com/enterprise-rag/a-first-intro-to-complex-rag-retrieval-augmented-generation-a8624d70090f") + loader.run( + "Load https://medium.com/enterprise-rag/a-first-intro-to-complex-rag-retrieval-augmented-generation-a8624d70090f") asker.run("why is retrieval augmented generation useful?") ``` diff --git a/docs/examples/talk-to-a-video.md b/docs/examples/talk-to-a-video.md index 9673bd1c3..310b6d407 100644 --- a/docs/examples/talk-to-a-video.md +++ b/docs/examples/talk-to-a-video.md @@ -7,7 +7,7 @@ import time from griptape.structures import Agent from griptape.tasks import PromptTask from griptape.artifacts import GenericArtifact, TextArtifact -from griptape.config import GoogleStructureConfig +from griptape.config import GoogleDriverConfig import google.generativeai as genai video_file = genai.upload_file(path="tests/resources/griptape-comfyui.mp4") @@ -19,7 +19,7 @@ if video_file.state.name == "FAILED": raise ValueError(video_file.state.name) agent = Agent( - config=GoogleStructureConfig(), + config=GoogleDriverConfig(), input=[ GenericArtifact(video_file), TextArtifact("Answer this question regarding the video: {{ args[0] }}"), diff --git a/docs/griptape-framework/drivers/embedding-drivers.md b/docs/griptape-framework/drivers/embedding-drivers.md index 567aa13e4..de2f2d379 100644 --- a/docs/griptape-framework/drivers/embedding-drivers.md +++ b/docs/griptape-framework/drivers/embedding-drivers.md @@ -211,7 +211,7 @@ print(embeddings[:3]) ``` ### Override Default Structure Embedding Driver -Here is how you can override the Embedding Driver that is used by default in Structures. +Here is how you can override the Embedding Driver that is used by default in Structures. ```python from griptape.structures import Agent @@ -220,11 +220,11 @@ from griptape.drivers import ( OpenAiChatPromptDriver, VoyageAiEmbeddingDriver, ) -from griptape.config import StructureConfig +from griptape.config import DriverConfig agent = Agent( tools=[WebScraper(off_prompt=True), TaskMemoryClient(off_prompt=False)], - config=StructureConfig( + config=DriverConfig( prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"), embedding_driver=VoyageAiEmbeddingDriver(), ), diff --git a/docs/griptape-framework/drivers/event-listener-drivers.md b/docs/griptape-framework/drivers/event-listener-drivers.md index 73453afb6..0adb0b10f 100644 --- a/docs/griptape-framework/drivers/event-listener-drivers.md +++ b/docs/griptape-framework/drivers/event-listener-drivers.md @@ -123,7 +123,7 @@ The [AwsIotCoreEventListenerDriver](../../reference/griptape/drivers/event_liste ```python import os -from griptape.config import StructureConfig +from griptape.config import DriverConfig from griptape.drivers import AwsIotCoreEventListenerDriver, OpenAiChatPromptDriver from griptape.events import ( EventListener, @@ -138,7 +138,7 @@ agent = Agent( value="You will be provided with a text, and your task is to extract the airport codes from it." ) ], - config=StructureConfig( + config=DriverConfig( prompt_driver=OpenAiChatPromptDriver( model="gpt-3.5-turbo", temperature=0.7 ) diff --git a/docs/griptape-framework/drivers/prompt-drivers.md b/docs/griptape-framework/drivers/prompt-drivers.md index ab749bf7c..8693cc6ff 100644 --- a/docs/griptape-framework/drivers/prompt-drivers.md +++ b/docs/griptape-framework/drivers/prompt-drivers.md @@ -13,10 +13,10 @@ You can instantiate drivers and pass them to structures: from griptape.structures import Agent from griptape.drivers import OpenAiChatPromptDriver from griptape.rules import Rule -from griptape.config import StructureConfig +from griptape.config import DriverConfig agent = Agent( - config=StructureConfig( + config=DriverConfig( prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", temperature=0.3), ), input="You will be provided with a tweet, and your task is to classify its sentiment as positive, neutral, or negative. Tweet: {{ args[0] }}", @@ -71,10 +71,10 @@ import os from griptape.structures import Agent from griptape.drivers import OpenAiChatPromptDriver from griptape.rules import Rule -from griptape.config import StructureConfig +from griptape.config import DriverConfig agent = Agent( - config=StructureConfig( + config=DriverConfig( prompt_driver=OpenAiChatPromptDriver( api_key=os.environ["OPENAI_API_KEY"], temperature=0.1, @@ -106,10 +106,10 @@ Simply set the `base_url` to the service's API endpoint and the `model` to the m from griptape.structures import Agent from griptape.drivers import OpenAiChatPromptDriver from griptape.rules import Rule -from griptape.config import StructureConfig +from griptape.config import DriverConfig agent = Agent( - config=StructureConfig( + config=DriverConfig( prompt_driver=OpenAiChatPromptDriver( base_url="http://127.0.0.1:1234/v1", model="lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF", stream=True @@ -134,10 +134,10 @@ import os from griptape.structures import Agent from griptape.rules import Rule from griptape.drivers import AzureOpenAiChatPromptDriver -from griptape.config import StructureConfig +from griptape.config import DriverConfig agent = Agent( - config=StructureConfig( + config=DriverConfig( prompt_driver=AzureOpenAiChatPromptDriver( api_key=os.environ["AZURE_OPENAI_API_KEY_1"], model="gpt-3.5-turbo", @@ -168,10 +168,10 @@ This driver uses [Cohere tool use](https://docs.cohere.com/docs/tools) when usin import os from griptape.structures import Agent from griptape.drivers import CoherePromptDriver -from griptape.config import StructureConfig +from griptape.config import DriverConfig agent = Agent( - config=StructureConfig( + config=DriverConfig( prompt_driver=CoherePromptDriver( model="command-r", api_key=os.environ['COHERE_API_KEY'], @@ -194,10 +194,10 @@ This driver uses [Anthropic tool use](https://docs.anthropic.com/en/docs/build-w import os from griptape.structures import Agent from griptape.drivers import AnthropicPromptDriver -from griptape.config import StructureConfig +from griptape.config import DriverConfig agent = Agent( - config=StructureConfig( + config=DriverConfig( prompt_driver=AnthropicPromptDriver( model="claude-3-opus-20240229", api_key=os.environ['ANTHROPIC_API_KEY'], @@ -220,10 +220,10 @@ This driver uses [Gemini function calling](https://ai.google.dev/gemini-api/docs import os from griptape.structures import Agent from griptape.drivers import GooglePromptDriver -from griptape.config import StructureConfig +from griptape.config import DriverConfig agent = Agent( - config=StructureConfig( + config=DriverConfig( prompt_driver=GooglePromptDriver( model="gemini-pro", api_key=os.environ['GOOGLE_API_KEY'], @@ -248,10 +248,10 @@ All models supported by the Converse API are available for use with this driver. from griptape.structures import Agent from griptape.drivers import AmazonBedrockPromptDriver from griptape.rules import Rule -from griptape.config import StructureConfig +from griptape.config import DriverConfig agent = Agent( - config=StructureConfig( + config=DriverConfig( prompt_driver=AmazonBedrockPromptDriver( model="anthropic.claude-3-sonnet-20240229-v1:0", ) @@ -285,14 +285,13 @@ The [OllamaPromptDriver](../../reference/griptape/drivers/prompt/ollama_prompt_d This driver uses [Ollama tool calling](https://ollama.com/blog/tool-support) when using [Tools](../tools/index.md). ```python -from griptape.config import StructureConfig +from griptape.config import DriverConfig from griptape.drivers import OllamaPromptDriver from griptape.tools import Calculator from griptape.structures import Agent - agent = Agent( - config=StructureConfig( + config=DriverConfig( prompt_driver=OllamaPromptDriver( model="llama3.1", ), @@ -319,11 +318,10 @@ import os from griptape.structures import Agent from griptape.drivers import HuggingFaceHubPromptDriver from griptape.rules import Rule, Ruleset -from griptape.config import StructureConfig - +from griptape.config import DriverConfig agent = Agent( - config=StructureConfig( + config=DriverConfig( prompt_driver=HuggingFaceHubPromptDriver( model="HuggingFaceH4/zephyr-7b-beta", api_token=os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"], @@ -335,8 +333,8 @@ agent = Agent( rules=[ Rule( value="You are Girafatron, a giraffe-obsessed robot. You are talking to a human. " - "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. " - "Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe." + "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. " + "Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe." ) ], ) @@ -354,11 +352,10 @@ The [HuggingFaceHubPromptDriver](#hugging-face-hub) also supports [Text Generati import os from griptape.structures import Agent from griptape.drivers import HuggingFaceHubPromptDriver -from griptape.config import StructureConfig - +from griptape.config import DriverConfig agent = Agent( - config=StructureConfig( + config=DriverConfig( prompt_driver=HuggingFaceHubPromptDriver( model="http://127.0.0.1:8080", api_token=os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"], @@ -383,11 +380,10 @@ The [HuggingFacePipelinePromptDriver](../../reference/griptape/drivers/prompt/hu from griptape.structures import Agent from griptape.drivers import HuggingFacePipelinePromptDriver from griptape.rules import Rule, Ruleset -from griptape.config import StructureConfig - +from griptape.config import DriverConfig agent = Agent( - config=StructureConfig( + config=DriverConfig( prompt_driver=HuggingFacePipelinePromptDriver( model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", ) @@ -417,7 +413,6 @@ The [AmazonSageMakerJumpstartPromptDriver](../../reference/griptape/drivers/prom Amazon Sagemaker Jumpstart provides a wide range of models with varying capabilities. This Driver has been primarily _chat-optimized_ models that have a [Huggingface Chat Template](https://huggingface.co/docs/transformers/en/chat_templating) available. If your model does not fit this use-case, we suggest sub-classing [AmazonSageMakerJumpstartPromptDriver](../../reference/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.md) and overriding the `_to_model_input` and `_to_model_params` methods. - ```python title="PYTEST_IGNORE" import os @@ -426,10 +421,10 @@ from griptape.drivers import ( AmazonSageMakerJumpstartPromptDriver, SageMakerFalconPromptModelDriver, ) -from griptape.config import StructureConfig +from griptape.config import DriverConfig agent = Agent( - config=StructureConfig( + config=DriverConfig( prompt_driver=AmazonSageMakerJumpstartPromptDriver( endpoint=os.environ["SAGEMAKER_LLAMA_3_INSTRUCT_ENDPOINT_NAME"], model="meta-llama/Meta-Llama-3-8B-Instruct", diff --git a/docs/griptape-framework/misc/events.md b/docs/griptape-framework/misc/events.md index 187321dc6..dfb6e2db3 100644 --- a/docs/griptape-framework/misc/events.md +++ b/docs/griptape-framework/misc/events.md @@ -135,7 +135,7 @@ from griptape.events import CompletionChunkEvent, EventListener, EventBus from griptape.tasks import ToolkitTask from griptape.structures import Pipeline from griptape.tools import WebScraper, TaskMemoryClient -from griptape.config import OpenAiStructureConfig +from griptape.config import OpenAiDriverConfig from griptape.drivers import OpenAiChatPromptDriver @@ -148,7 +148,7 @@ EventBus.event_listeners = [ ] pipeline = Pipeline( - config=OpenAiStructureConfig( + config=OpenAiDriverConfig( prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", stream=True) ), ) @@ -172,10 +172,10 @@ from griptape.tasks import ToolkitTask from griptape.structures import Pipeline from griptape.tools import WebScraper, TaskMemoryClient - pipeline = Pipeline() -pipeline.config.prompt_driver.stream = True -pipeline.add_tasks(ToolkitTask("Based on https://griptape.ai, tell me what griptape is.", tools=[WebScraper(off_prompt=True), TaskMemoryClient(off_prompt=False)])) +pipeline.config.prompt.stream = True +pipeline.add_tasks(ToolkitTask("Based on https://griptape.ai, tell me what griptape is.", + tools=[WebScraper(off_prompt=True), TaskMemoryClient(off_prompt=False)])) for artifact in Stream(pipeline).run(): print(artifact.value, end="", flush=True) diff --git a/docs/griptape-framework/structures/config.md b/docs/griptape-framework/structures/config.md index 3f510eb86..17fb9e5da 100644 --- a/docs/griptape-framework/structures/config.md +++ b/docs/griptape-framework/structures/config.md @@ -5,44 +5,42 @@ search: ## Overview -The [StructureConfig](../../reference/griptape/config/structure_config.md) class allows for the customization of Structures within Griptape, enabling specific settings such as Drivers to be defined for Tasks. +The [StructureConfig](../../reference/griptape/config/driver_config.md) class allows for the customization of Structures within Griptape, enabling specific settings such as Drivers to be defined for Tasks. ### Premade Configs -Griptape provides predefined [StructureConfig](../../reference/griptape/config/structure_config.md)'s for widely used services that provide APIs for most Driver types Griptape offers. +Griptape provides predefined [StructureConfig](../../reference/griptape/config/driver_config.md)'s for widely used services that provide APIs for most Driver types Griptape offers. #### OpenAI -The [OpenAI Structure Config](../../reference/griptape/config/openai_structure_config.md) provides default Drivers for OpenAI's APIs. This is the default config for all Structures. - +The [OpenAI Structure Config](../../reference/griptape/config/openai_driver_config.md) provides default Drivers for OpenAI's APIs. This is the default config for all Structures. ```python from griptape.structures import Agent -from griptape.config import OpenAiStructureConfig +from griptape.config import OpenAiDriverConfig agent = Agent( - config=OpenAiStructureConfig() + config=OpenAiDriverConfig() ) -agent = Agent() # This is equivalent to the above +agent = Agent() # This is equivalent to the above ``` #### Azure OpenAI -The [Azure OpenAI Structure Config](../../reference/griptape/config/azure_openai_structure_config.md) provides default Drivers for Azure's OpenAI APIs. - +The [Azure OpenAI Structure Config](../../reference/griptape/config/azure_openai_driver_config.md) provides default Drivers for Azure's OpenAI APIs. ```python import os from griptape.structures import Agent -from griptape.config import AzureOpenAiStructureConfig +from griptape.config import AzureOpenAiDriverConfig agent = Agent( - config=AzureOpenAiStructureConfig( + config=AzureOpenAiDriverConfig( azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_3"], api_key=os.environ["AZURE_OPENAI_API_KEY_3"] ).merge_config({ - "image_query_driver": { + "image_query": { "azure_deployment": "gpt-4o", }, }), @@ -50,16 +48,16 @@ agent = Agent( ``` #### Amazon Bedrock -The [Amazon Bedrock Structure Config](../../reference/griptape/config/amazon_bedrock_structure_config.md) provides default Drivers for Amazon Bedrock's APIs. +The [Amazon Bedrock Structure Config](../../reference/griptape/config/amazon_bedrock_driver_config.md) provides default Drivers for Amazon Bedrock's APIs. ```python import os import boto3 from griptape.structures import Agent -from griptape.config import AmazonBedrockStructureConfig +from griptape.config import AmazonBedrockDriverConfig agent = Agent( - config=AmazonBedrockStructureConfig( + config=AmazonBedrockDriverConfig( session=boto3.Session( region_name=os.environ["AWS_DEFAULT_REGION"], aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"], @@ -70,63 +68,61 @@ agent = Agent( ``` #### Google -The [Google Structure Config](../../reference/griptape/config/google_structure_config.md) provides default Drivers for Google's Gemini APIs. +The [Google Structure Config](../../reference/griptape/config/google_driver_config.md) provides default Drivers for Google's Gemini APIs. ```python from griptape.structures import Agent -from griptape.config import GoogleStructureConfig +from griptape.config import GoogleDriverConfig agent = Agent( - config=GoogleStructureConfig() + config=GoogleDriverConfig() ) ``` #### Anthropic -The [Anthropic Structure Config](../../reference/griptape/config/anthropic_structure_config.md) provides default Drivers for Anthropic's APIs. +The [Anthropic Structure Config](../../reference/griptape/config/anthropic_driver_config.md) provides default Drivers for Anthropic's APIs. !!! info Anthropic does not provide an embeddings API which means you will need to use another service for embeddings. The `AnthropicStructureConfig` defaults to using `VoyageAiEmbeddingDriver` which integrates with [VoyageAI](https://www.voyageai.com/), the service used in Anthropic's [embeddings documentation](https://docs.anthropic.com/claude/docs/embeddings). To override the default embedding driver, see: [Override Default Structure Embedding Driver](../drivers/embedding-drivers.md#override-default-structure-embedding-driver). - ```python from griptape.structures import Agent -from griptape.config import AnthropicStructureConfig +from griptape.config import AnthropicDriverConfig agent = Agent( - config=AnthropicStructureConfig() + config=AnthropicDriverConfig() ) ``` #### Cohere -The [Cohere Structure Config](../../reference/griptape/config/cohere_structure_config.md) provides default Drivers for Cohere's APIs. - +The [Cohere Structure Config](../../reference/griptape/config/cohere_driver_config.md) provides default Drivers for Cohere's APIs. ```python import os -from griptape.config import CohereStructureConfig +from griptape.config import CohereDriverConfig from griptape.structures import Agent -agent = Agent(config=CohereStructureConfig(api_key=os.environ["COHERE_API_KEY"])) +agent = Agent(config=CohereDriverConfig(api_key=os.environ["COHERE_API_KEY"])) ``` ### Custom Configs -You can create your own [StructureConfig](../../reference/griptape/config/structure_config.md) by overriding relevant Drivers. -The [StructureConfig](../../reference/griptape/config/structure_config.md) class includes "Dummy" Drivers for all types, which throw a [DummyError](../../reference/griptape/exceptions/dummy_exception.md) if invoked without being overridden. +You can create your own [StructureConfig](../../reference/griptape/config/driver_config.md) by overriding relevant Drivers. +The [StructureConfig](../../reference/griptape/config/driver_config.md) class includes "Dummy" Drivers for all types, which throw a [DummyError](../../reference/griptape/exceptions/dummy_exception.md) if invoked without being overridden. This approach ensures that you are informed through clear error messages if you attempt to use Structures without proper Driver configurations. ```python import os from griptape.structures import Agent -from griptape.config import StructureConfig +from griptape.config import DriverConfig from griptape.drivers import AnthropicPromptDriver agent = Agent( - config=StructureConfig( + config=DriverConfig( prompt_driver=AnthropicPromptDriver( model="claude-3-sonnet-20240229", api_key=os.environ["ANTHROPIC_API_KEY"], @@ -141,14 +137,14 @@ Configuration classes in Griptape offer utility methods for loading, saving, and ```python from griptape.structures import Agent -from griptape.config import AmazonBedrockStructureConfig +from griptape.config import AmazonBedrockDriverConfig from griptape.drivers import AmazonBedrockCohereEmbeddingDriver -custom_config = AmazonBedrockStructureConfig() +custom_config = AmazonBedrockDriverConfig() custom_config.embedding_driver = AmazonBedrockCohereEmbeddingDriver() custom_config.merge_config( { - "embedding_driver": { + "embedding": { "base_url": None, "model": "text-embedding-3-small", "organization": None, @@ -157,11 +153,11 @@ custom_config.merge_config( } ) serialized_config = custom_config.to_json() -deserialized_config = AmazonBedrockStructureConfig.from_json(serialized_config) +deserialized_config = AmazonBedrockDriverConfig.from_json(serialized_config) agent = Agent( config=deserialized_config.merge_config({ - "prompt_driver" : { + "prompt": { "model": "anthropic.claude-3-sonnet-20240229-v1:0", }, }), diff --git a/docs/griptape-framework/structures/task-memory.md b/docs/griptape-framework/structures/task-memory.md index ea4a787f6..3184c4096 100644 --- a/docs/griptape-framework/structures/task-memory.md +++ b/docs/griptape-framework/structures/task-memory.md @@ -206,7 +206,7 @@ In this example, GPT-4 _never_ sees the contents of the page, only that it was s ```python from griptape.artifacts import TextArtifact from griptape.config import ( - OpenAiStructureConfig, + OpenAiDriverConfig, ) from griptape.drivers import ( LocalVectorStoreDriver, @@ -223,7 +223,7 @@ from griptape.tools import FileManager, TaskMemoryClient, WebScraper vector_store_driver = LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver()) agent = Agent( - config=OpenAiStructureConfig( + config=OpenAiDriverConfig( prompt_driver=OpenAiChatPromptDriver(model="gpt-4"), ), task_memory=TaskMemory( diff --git a/docs/griptape-tools/official-tools/rest-api-client.md b/docs/griptape-tools/official-tools/rest-api-client.md index 07ddccf86..304ec00ec 100644 --- a/docs/griptape-tools/official-tools/rest-api-client.md +++ b/docs/griptape-tools/official-tools/rest-api-client.md @@ -6,7 +6,7 @@ The [RestApiClient](../../reference/griptape/tools/rest_api_client/tool.md) tool ### Example The following example is built using [https://jsonplaceholder.typicode.com/guide/](https://jsonplaceholder.typicode.com/guide/). - + ```python from json import dumps from griptape.drivers import OpenAiChatPromptDriver @@ -14,7 +14,7 @@ from griptape.memory.structure import ConversationMemory from griptape.structures import Pipeline from griptape.tasks import ToolkitTask from griptape.tools import RestApiClient -from griptape.config import StructureConfig +from griptape.config import DriverConfig posts_client = RestApiClient( base_url="https://jsonplaceholder.typicode.com", @@ -117,7 +117,7 @@ posts_client = RestApiClient( pipeline = Pipeline( conversation_memory=ConversationMemory(), - config = StructureConfig( + config=DriverConfig( prompt_driver=OpenAiChatPromptDriver( model="gpt-4o", temperature=0.1 diff --git a/griptape/config/__init__.py b/griptape/config/__init__.py index 4b0f8eb28..7450d7738 100644 --- a/griptape/config/__init__.py +++ b/griptape/config/__init__.py @@ -1,26 +1,26 @@ from .base_config import BaseConfig -from .base_structure_config import BaseStructureConfig +from .base_driver_config import BaseDriverConfig -from .structure_config import StructureConfig -from .openai_structure_config import OpenAiStructureConfig -from .azure_openai_structure_config import AzureOpenAiStructureConfig -from .amazon_bedrock_structure_config import AmazonBedrockStructureConfig -from .anthropic_structure_config import AnthropicStructureConfig -from .google_structure_config import GoogleStructureConfig -from .cohere_structure_config import CohereStructureConfig +from .driver_config import DriverConfig +from .openai_driver_config import OpenAiDriverConfig +from .azure_openai_driver_config import AzureOpenAiDriverConfig +from .amazon_bedrock_driver_config import AmazonBedrockDriverConfig +from .anthropic_driver_config import AnthropicDriverConfig +from .google_driver_config import GoogleDriverConfig +from .cohere_driver_config import CohereDriverConfig from .config import Config __all__ = [ "BaseConfig", - "BaseStructureConfig", - "StructureConfig", - "OpenAiStructureConfig", - "AzureOpenAiStructureConfig", - "AmazonBedrockStructureConfig", - "AnthropicStructureConfig", - "GoogleStructureConfig", - "CohereStructureConfig", + "BaseDriverConfig", + "DriverConfig", + "OpenAiDriverConfig", + "AzureOpenAiDriverConfig", + "AmazonBedrockDriverConfig", + "AnthropicDriverConfig", + "GoogleDriverConfig", + "CohereDriverConfig", "Config", ] diff --git a/griptape/config/amazon_bedrock_structure_config.py b/griptape/config/amazon_bedrock_driver_config.py similarity index 84% rename from griptape/config/amazon_bedrock_structure_config.py rename to griptape/config/amazon_bedrock_driver_config.py index 3ad7f8f48..a07300638 100644 --- a/griptape/config/amazon_bedrock_structure_config.py +++ b/griptape/config/amazon_bedrock_driver_config.py @@ -4,7 +4,7 @@ from attrs import Factory, define, field -from griptape.config import StructureConfig +from griptape.config import DriverConfig from griptape.drivers import ( AmazonBedrockImageGenerationDriver, AmazonBedrockImageQueryDriver, @@ -25,14 +25,14 @@ @define -class AmazonBedrockStructureConfig(StructureConfig): +class AmazonBedrockDriverConfig(DriverConfig): session: boto3.Session = field( default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True, metadata={"serializable": False}, ) - prompt_driver: BasePromptDriver = field( + prompt: BasePromptDriver = field( default=Factory( lambda self: AmazonBedrockPromptDriver( session=self.session, @@ -43,7 +43,7 @@ class AmazonBedrockStructureConfig(StructureConfig): kw_only=True, metadata={"serializable": True}, ) - embedding_driver: BaseEmbeddingDriver = field( + embedding: BaseEmbeddingDriver = field( default=Factory( lambda self: AmazonBedrockTitanEmbeddingDriver(session=self.session, model="amazon.titan-embed-text-v1"), takes_self=True, @@ -51,7 +51,7 @@ class AmazonBedrockStructureConfig(StructureConfig): kw_only=True, metadata={"serializable": True}, ) - image_generation_driver: BaseImageGenerationDriver = field( + image_generation: BaseImageGenerationDriver = field( default=Factory( lambda self: AmazonBedrockImageGenerationDriver( session=self.session, @@ -63,7 +63,7 @@ class AmazonBedrockStructureConfig(StructureConfig): kw_only=True, metadata={"serializable": True}, ) - image_query_driver: BaseImageGenerationDriver = field( + image_query: BaseImageGenerationDriver = field( default=Factory( lambda self: AmazonBedrockImageQueryDriver( session=self.session, @@ -75,8 +75,8 @@ class AmazonBedrockStructureConfig(StructureConfig): kw_only=True, metadata={"serializable": True}, ) - vector_store_driver: BaseVectorStoreDriver = field( - default=Factory(lambda self: LocalVectorStoreDriver(embedding_driver=self.embedding_driver), takes_self=True), + vector_store: BaseVectorStoreDriver = field( + default=Factory(lambda self: LocalVectorStoreDriver(embedding_driver=self.embedding), takes_self=True), kw_only=True, metadata={"serializable": True}, ) diff --git a/griptape/config/anthropic_structure_config.py b/griptape/config/anthropic_driver_config.py similarity index 76% rename from griptape/config/anthropic_structure_config.py rename to griptape/config/anthropic_driver_config.py index 1bb5bf49b..642a3fced 100644 --- a/griptape/config/anthropic_structure_config.py +++ b/griptape/config/anthropic_driver_config.py @@ -1,6 +1,6 @@ from attrs import Factory, define, field -from griptape.config import StructureConfig +from griptape.config import DriverConfig from griptape.drivers import ( AnthropicImageQueryDriver, AnthropicPromptDriver, @@ -14,25 +14,25 @@ @define -class AnthropicStructureConfig(StructureConfig): - prompt_driver: BasePromptDriver = field( +class AnthropicDriverConfig(DriverConfig): + prompt: BasePromptDriver = field( default=Factory(lambda: AnthropicPromptDriver(model="claude-3-5-sonnet-20240620")), metadata={"serializable": True}, kw_only=True, ) - embedding_driver: BaseEmbeddingDriver = field( + embedding: BaseEmbeddingDriver = field( default=Factory(lambda: VoyageAiEmbeddingDriver(model="voyage-large-2")), metadata={"serializable": True}, kw_only=True, ) - vector_store_driver: BaseVectorStoreDriver = field( + vector_store: BaseVectorStoreDriver = field( default=Factory( lambda: LocalVectorStoreDriver(embedding_driver=VoyageAiEmbeddingDriver(model="voyage-large-2")), ), kw_only=True, metadata={"serializable": True}, ) - image_query_driver: BaseImageQueryDriver = field( + image_query: BaseImageQueryDriver = field( default=Factory(lambda: AnthropicImageQueryDriver(model="claude-3-5-sonnet-20240620")), kw_only=True, metadata={"serializable": True}, diff --git a/griptape/config/azure_openai_structure_config.py b/griptape/config/azure_openai_driver_config.py similarity index 89% rename from griptape/config/azure_openai_structure_config.py rename to griptape/config/azure_openai_driver_config.py index ce0303e34..ef965fa28 100644 --- a/griptape/config/azure_openai_structure_config.py +++ b/griptape/config/azure_openai_driver_config.py @@ -4,7 +4,7 @@ from attrs import Factory, define, field -from griptape.config import StructureConfig +from griptape.config import DriverConfig from griptape.drivers import ( AzureOpenAiChatPromptDriver, AzureOpenAiEmbeddingDriver, @@ -20,7 +20,7 @@ @define -class AzureOpenAiStructureConfig(StructureConfig): +class AzureOpenAiDriverConfig(DriverConfig): """Azure OpenAI Structure Configuration. Attributes: @@ -43,7 +43,7 @@ class AzureOpenAiStructureConfig(StructureConfig): metadata={"serializable": False}, ) api_key: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False}) - prompt_driver: BasePromptDriver = field( + prompt: BasePromptDriver = field( default=Factory( lambda self: AzureOpenAiChatPromptDriver( model="gpt-4o", @@ -57,7 +57,7 @@ class AzureOpenAiStructureConfig(StructureConfig): metadata={"serializable": True}, kw_only=True, ) - image_generation_driver: BaseImageGenerationDriver = field( + image_generation: BaseImageGenerationDriver = field( default=Factory( lambda self: AzureOpenAiImageGenerationDriver( model="dall-e-2", @@ -72,7 +72,7 @@ class AzureOpenAiStructureConfig(StructureConfig): metadata={"serializable": True}, kw_only=True, ) - image_query_driver: BaseImageQueryDriver = field( + image_query: BaseImageQueryDriver = field( default=Factory( lambda self: AzureOpenAiImageQueryDriver( model="gpt-4o", @@ -86,7 +86,7 @@ class AzureOpenAiStructureConfig(StructureConfig): metadata={"serializable": True}, kw_only=True, ) - embedding_driver: BaseEmbeddingDriver = field( + embedding: BaseEmbeddingDriver = field( default=Factory( lambda self: AzureOpenAiEmbeddingDriver( model="text-embedding-3-small", @@ -100,8 +100,8 @@ class AzureOpenAiStructureConfig(StructureConfig): metadata={"serializable": True}, kw_only=True, ) - vector_store_driver: BaseVectorStoreDriver = field( - default=Factory(lambda self: LocalVectorStoreDriver(embedding_driver=self.embedding_driver), takes_self=True), + vector_store: BaseVectorStoreDriver = field( + default=Factory(lambda self: LocalVectorStoreDriver(embedding_driver=self.embedding), takes_self=True), metadata={"serializable": True}, kw_only=True, ) diff --git a/griptape/config/base_driver_config.py b/griptape/config/base_driver_config.py new file mode 100644 index 000000000..46ff181d3 --- /dev/null +++ b/griptape/config/base_driver_config.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from abc import ABC +from typing import TYPE_CHECKING, Optional + +from attrs import define, field + +if TYPE_CHECKING: + from griptape.drivers import ( + BaseAudioTranscriptionDriver, + BaseConversationMemoryDriver, + BaseEmbeddingDriver, + BaseImageGenerationDriver, + BaseImageQueryDriver, + BasePromptDriver, + BaseTextToSpeechDriver, + BaseVectorStoreDriver, + ) + + +@define +class BaseDriverConfig(ABC): + prompt: BasePromptDriver = field(kw_only=True, metadata={"serializable": True}) + image_generation: BaseImageGenerationDriver = field(kw_only=True, metadata={"serializable": True}) + image_query: BaseImageQueryDriver = field(kw_only=True, metadata={"serializable": True}) + embedding: BaseEmbeddingDriver = field(kw_only=True, metadata={"serializable": True}) + vector_store: BaseVectorStoreDriver = field(kw_only=True, metadata={"serializable": True}) + conversation_memory: Optional[BaseConversationMemoryDriver] = field( + default=None, + kw_only=True, + metadata={"serializable": True}, + ) + text_to_speech: BaseTextToSpeechDriver = field(kw_only=True, metadata={"serializable": True}) + audio_transcription: BaseAudioTranscriptionDriver = field(kw_only=True, metadata={"serializable": True}) diff --git a/griptape/config/base_structure_config.py b/griptape/config/base_structure_config.py deleted file mode 100644 index bc9238df2..000000000 --- a/griptape/config/base_structure_config.py +++ /dev/null @@ -1,36 +0,0 @@ -from __future__ import annotations - -from abc import ABC -from typing import TYPE_CHECKING, Optional - -from attrs import define, field - -from griptape.config import BaseConfig - -if TYPE_CHECKING: - from griptape.drivers import ( - BaseAudioTranscriptionDriver, - BaseConversationMemoryDriver, - BaseEmbeddingDriver, - BaseImageGenerationDriver, - BaseImageQueryDriver, - BasePromptDriver, - BaseTextToSpeechDriver, - BaseVectorStoreDriver, - ) - - -@define -class BaseStructureConfig(BaseConfig, ABC): - prompt_driver: BasePromptDriver = field(kw_only=True, metadata={"serializable": True}) - image_generation_driver: BaseImageGenerationDriver = field(kw_only=True, metadata={"serializable": True}) - image_query_driver: BaseImageQueryDriver = field(kw_only=True, metadata={"serializable": True}) - embedding_driver: BaseEmbeddingDriver = field(kw_only=True, metadata={"serializable": True}) - vector_store_driver: BaseVectorStoreDriver = field(kw_only=True, metadata={"serializable": True}) - conversation_memory_driver: Optional[BaseConversationMemoryDriver] = field( - default=None, - kw_only=True, - metadata={"serializable": True}, - ) - text_to_speech_driver: BaseTextToSpeechDriver = field(kw_only=True, metadata={"serializable": True}) - audio_transcription_driver: BaseAudioTranscriptionDriver = field(kw_only=True, metadata={"serializable": True}) diff --git a/griptape/config/cohere_structure_config.py b/griptape/config/cohere_driver_config.py similarity index 76% rename from griptape/config/cohere_structure_config.py rename to griptape/config/cohere_driver_config.py index 2e896b9b0..7195f550f 100644 --- a/griptape/config/cohere_structure_config.py +++ b/griptape/config/cohere_driver_config.py @@ -1,6 +1,6 @@ from attrs import Factory, define, field -from griptape.config import StructureConfig +from griptape.config import DriverConfig from griptape.drivers import ( BaseEmbeddingDriver, BasePromptDriver, @@ -12,15 +12,15 @@ @define -class CohereStructureConfig(StructureConfig): +class CohereDriverConfig(DriverConfig): api_key: str = field(metadata={"serializable": False}, kw_only=True) - prompt_driver: BasePromptDriver = field( + prompt: BasePromptDriver = field( default=Factory(lambda self: CoherePromptDriver(model="command-r", api_key=self.api_key), takes_self=True), metadata={"serializable": True}, kw_only=True, ) - embedding_driver: BaseEmbeddingDriver = field( + embedding: BaseEmbeddingDriver = field( default=Factory( lambda self: CohereEmbeddingDriver( model="embed-english-v3.0", @@ -32,8 +32,8 @@ class CohereStructureConfig(StructureConfig): metadata={"serializable": True}, kw_only=True, ) - vector_store_driver: BaseVectorStoreDriver = field( - default=Factory(lambda self: LocalVectorStoreDriver(embedding_driver=self.embedding_driver), takes_self=True), + vector_store: BaseVectorStoreDriver = field( + default=Factory(lambda self: LocalVectorStoreDriver(embedding_driver=self.embedding), takes_self=True), kw_only=True, metadata={"serializable": True}, ) diff --git a/griptape/config/config.py b/griptape/config/config.py index 3985abca2..f325d1265 100644 --- a/griptape/config/config.py +++ b/griptape/config/config.py @@ -1,17 +1,17 @@ from attrs import define from griptape.config.base_config import BaseConfig -from griptape.config.base_structure_config import BaseStructureConfig +from griptape.config.base_driver_config import BaseDriverConfig from griptape.mixins.event_publisher_mixin import EventPublisherMixin -from .openai_structure_config import OpenAiStructureConfig +from .openai_driver_config import OpenAiDriverConfig @define class _Config(BaseConfig, EventPublisherMixin): - drivers: BaseStructureConfig + drivers: BaseDriverConfig Config = _Config( - drivers=OpenAiStructureConfig(), + drivers=OpenAiDriverConfig(), ) diff --git a/griptape/config/structure_config.py b/griptape/config/driver_config.py similarity index 74% rename from griptape/config/structure_config.py rename to griptape/config/driver_config.py index d68b6e2e2..325591258 100644 --- a/griptape/config/structure_config.py +++ b/griptape/config/driver_config.py @@ -4,7 +4,7 @@ from attrs import Factory, define, field -from griptape.config import BaseStructureConfig +from griptape.config import BaseDriverConfig from griptape.drivers import ( DummyAudioTranscriptionDriver, DummyEmbeddingDriver, @@ -29,43 +29,43 @@ @define -class StructureConfig(BaseStructureConfig): - prompt_driver: BasePromptDriver = field( +class DriverConfig(BaseDriverConfig): + prompt: BasePromptDriver = field( kw_only=True, default=Factory(lambda: DummyPromptDriver()), metadata={"serializable": True}, ) - image_generation_driver: BaseImageGenerationDriver = field( + image_generation: BaseImageGenerationDriver = field( kw_only=True, default=Factory(lambda: DummyImageGenerationDriver()), metadata={"serializable": True}, ) - image_query_driver: BaseImageQueryDriver = field( + image_query: BaseImageQueryDriver = field( kw_only=True, default=Factory(lambda: DummyImageQueryDriver()), metadata={"serializable": True}, ) - embedding_driver: BaseEmbeddingDriver = field( + embedding: BaseEmbeddingDriver = field( kw_only=True, default=Factory(lambda: DummyEmbeddingDriver()), metadata={"serializable": True}, ) - vector_store_driver: BaseVectorStoreDriver = field( + vector_store: BaseVectorStoreDriver = field( default=Factory(lambda: DummyVectorStoreDriver()), kw_only=True, metadata={"serializable": True}, ) - conversation_memory_driver: Optional[BaseConversationMemoryDriver] = field( + conversation_memory: Optional[BaseConversationMemoryDriver] = field( default=None, kw_only=True, metadata={"serializable": True}, ) - text_to_speech_driver: BaseTextToSpeechDriver = field( + text_to_speech: BaseTextToSpeechDriver = field( default=Factory(lambda: DummyTextToSpeechDriver()), kw_only=True, metadata={"serializable": True}, ) - audio_transcription_driver: BaseAudioTranscriptionDriver = field( + audio_transcription: BaseAudioTranscriptionDriver = field( default=Factory(lambda: DummyAudioTranscriptionDriver()), kw_only=True, metadata={"serializable": True}, diff --git a/griptape/config/google_structure_config.py b/griptape/config/google_driver_config.py similarity index 75% rename from griptape/config/google_structure_config.py rename to griptape/config/google_driver_config.py index 66ed90b4b..a1089f0ee 100644 --- a/griptape/config/google_structure_config.py +++ b/griptape/config/google_driver_config.py @@ -1,6 +1,6 @@ from attrs import Factory, define, field -from griptape.config import StructureConfig +from griptape.config import DriverConfig from griptape.drivers import ( BaseEmbeddingDriver, BasePromptDriver, @@ -12,18 +12,18 @@ @define -class GoogleStructureConfig(StructureConfig): - prompt_driver: BasePromptDriver = field( +class GoogleDriverConfig(DriverConfig): + prompt: BasePromptDriver = field( default=Factory(lambda: GooglePromptDriver(model="gemini-1.5-pro")), kw_only=True, metadata={"serializable": True}, ) - embedding_driver: BaseEmbeddingDriver = field( + embedding: BaseEmbeddingDriver = field( default=Factory(lambda: GoogleEmbeddingDriver(model="models/embedding-001")), kw_only=True, metadata={"serializable": True}, ) - vector_store_driver: BaseVectorStoreDriver = field( + vector_store: BaseVectorStoreDriver = field( default=Factory( lambda: LocalVectorStoreDriver(embedding_driver=GoogleEmbeddingDriver(model="models/embedding-001")), ), diff --git a/griptape/config/openai_structure_config.py b/griptape/config/openai_driver_config.py similarity index 76% rename from griptape/config/openai_structure_config.py rename to griptape/config/openai_driver_config.py index 63806dfc9..35ccde43d 100644 --- a/griptape/config/openai_structure_config.py +++ b/griptape/config/openai_driver_config.py @@ -1,6 +1,6 @@ from attrs import Factory, define, field -from griptape.config import StructureConfig +from griptape.config import DriverConfig from griptape.drivers import ( BaseAudioTranscriptionDriver, BaseEmbeddingDriver, @@ -20,40 +20,40 @@ @define -class OpenAiStructureConfig(StructureConfig): - prompt_driver: BasePromptDriver = field( +class OpenAiDriverConfig(DriverConfig): + prompt: BasePromptDriver = field( default=Factory(lambda: OpenAiChatPromptDriver(model="gpt-4o")), metadata={"serializable": True}, kw_only=True, ) - image_generation_driver: BaseImageGenerationDriver = field( + image_generation: BaseImageGenerationDriver = field( default=Factory(lambda: OpenAiImageGenerationDriver(model="dall-e-2", image_size="512x512")), kw_only=True, metadata={"serializable": True}, ) - image_query_driver: BaseImageQueryDriver = field( + image_query: BaseImageQueryDriver = field( default=Factory(lambda: OpenAiImageQueryDriver(model="gpt-4o")), kw_only=True, metadata={"serializable": True}, ) - embedding_driver: BaseEmbeddingDriver = field( + embedding: BaseEmbeddingDriver = field( default=Factory(lambda: OpenAiEmbeddingDriver(model="text-embedding-3-small")), metadata={"serializable": True}, kw_only=True, ) - vector_store_driver: BaseVectorStoreDriver = field( + vector_store: BaseVectorStoreDriver = field( default=Factory( lambda: LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver(model="text-embedding-3-small")), ), kw_only=True, metadata={"serializable": True}, ) - text_to_speech_driver: BaseTextToSpeechDriver = field( + text_to_speech: BaseTextToSpeechDriver = field( default=Factory(lambda: OpenAiTextToSpeechDriver(model="tts")), kw_only=True, metadata={"serializable": True}, ) - audio_transcription_driver: BaseAudioTranscriptionDriver = field( + audio_transcription: BaseAudioTranscriptionDriver = field( default=Factory(lambda: OpenAiAudioTranscriptionDriver(model="whisper-1")), kw_only=True, metadata={"serializable": True}, diff --git a/griptape/engines/audio/audio_transcription_engine.py b/griptape/engines/audio/audio_transcription_engine.py index aad669d70..51022e47c 100644 --- a/griptape/engines/audio/audio_transcription_engine.py +++ b/griptape/engines/audio/audio_transcription_engine.py @@ -8,7 +8,7 @@ @define class AudioTranscriptionEngine: audio_transcription_driver: BaseAudioTranscriptionDriver = field( - default=Factory(lambda: Config.drivers.audio_transcription_driver), kw_only=True + default=Factory(lambda: Config.drivers.audio_transcription), kw_only=True ) def run(self, audio: AudioArtifact, *args, **kwargs) -> TextArtifact: diff --git a/griptape/engines/audio/text_to_speech_engine.py b/griptape/engines/audio/text_to_speech_engine.py index 16634ce45..a163c36fd 100644 --- a/griptape/engines/audio/text_to_speech_engine.py +++ b/griptape/engines/audio/text_to_speech_engine.py @@ -14,7 +14,7 @@ @define class TextToSpeechEngine: text_to_speech_driver: BaseTextToSpeechDriver = field( - default=Factory(lambda: Config.drivers.text_to_speech_driver), kw_only=True + default=Factory(lambda: Config.drivers.text_to_speech), kw_only=True ) def run(self, prompts: list[str], *args, **kwargs) -> AudioArtifact: diff --git a/griptape/engines/extraction/base_extraction_engine.py b/griptape/engines/extraction/base_extraction_engine.py index 03826ab43..a1bcbdee2 100644 --- a/griptape/engines/extraction/base_extraction_engine.py +++ b/griptape/engines/extraction/base_extraction_engine.py @@ -18,7 +18,7 @@ class BaseExtractionEngine(ABC): max_token_multiplier: float = field(default=0.5, kw_only=True) chunk_joiner: str = field(default="\n\n", kw_only=True) - prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.drivers.prompt_driver), kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.drivers.prompt), kw_only=True) chunker: BaseChunker = field( default=Factory( lambda self: TextChunker(tokenizer=self.prompt_driver.tokenizer, max_tokens=self.max_chunker_tokens), diff --git a/griptape/engines/image/base_image_generation_engine.py b/griptape/engines/image/base_image_generation_engine.py index 4187dde79..921d600c7 100644 --- a/griptape/engines/image/base_image_generation_engine.py +++ b/griptape/engines/image/base_image_generation_engine.py @@ -16,7 +16,7 @@ @define class BaseImageGenerationEngine(ABC): image_generation_driver: BaseImageGenerationDriver = field( - kw_only=True, default=Factory(lambda: Config.drivers.image_generation_driver) + kw_only=True, default=Factory(lambda: Config.drivers.image_generation) ) @abstractmethod diff --git a/griptape/engines/image_query/image_query_engine.py b/griptape/engines/image_query/image_query_engine.py index 5090e2f27..d85e6012d 100644 --- a/griptape/engines/image_query/image_query_engine.py +++ b/griptape/engines/image_query/image_query_engine.py @@ -13,9 +13,7 @@ @define class ImageQueryEngine: - image_query_driver: BaseImageQueryDriver = field( - default=Factory(lambda: Config.drivers.image_query_driver), kw_only=True - ) + image_query_driver: BaseImageQueryDriver = field(default=Factory(lambda: Config.drivers.image_query), kw_only=True) def run(self, query: str, images: list[ImageArtifact]) -> TextArtifact: return self.image_query_driver.query(query, images) diff --git a/griptape/engines/rag/modules/response/prompt_response_rag_module.py b/griptape/engines/rag/modules/response/prompt_response_rag_module.py index 979723beb..8e421d792 100644 --- a/griptape/engines/rag/modules/response/prompt_response_rag_module.py +++ b/griptape/engines/rag/modules/response/prompt_response_rag_module.py @@ -17,7 +17,7 @@ @define(kw_only=True) class PromptResponseRagModule(BaseResponseRagModule): answer_token_offset: int = field(default=400) - prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.drivers.prompt_driver), kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.drivers.prompt), kw_only=True) generate_system_template: Callable[[RagContext, list[TextArtifact]], str] = field( default=Factory(lambda self: self.default_system_template_generator, takes_self=True), ) diff --git a/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py b/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py index 392a6836d..4daa10e54 100644 --- a/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py +++ b/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py @@ -18,7 +18,7 @@ @define(kw_only=True) class VectorStoreRetrievalRagModule(BaseRetrievalRagModule): - vector_store_driver: BaseVectorStoreDriver = field(default=Factory(lambda: Config.drivers.vector_store_driver)) + vector_store_driver: BaseVectorStoreDriver = field(default=Factory(lambda: Config.drivers.vector_store)) query_params: dict[str, Any] = field(factory=dict) process_query_output_fn: Callable[[list[BaseVectorStoreDriver.Entry]], Sequence[TextArtifact]] = field( default=Factory(lambda: lambda es: [e.to_artifact() for e in es]), diff --git a/griptape/engines/summary/prompt_summary_engine.py b/griptape/engines/summary/prompt_summary_engine.py index 2586a8e0c..1c45fa5ea 100644 --- a/griptape/engines/summary/prompt_summary_engine.py +++ b/griptape/engines/summary/prompt_summary_engine.py @@ -22,7 +22,7 @@ class PromptSummaryEngine(BaseSummaryEngine): max_token_multiplier: float = field(default=0.5, kw_only=True) system_template_generator: J2 = field(default=Factory(lambda: J2("engines/summary/system.j2")), kw_only=True) user_template_generator: J2 = field(default=Factory(lambda: J2("engines/summary/user.j2")), kw_only=True) - prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.drivers.prompt_driver), kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.drivers.prompt), kw_only=True) chunker: BaseChunker = field( default=Factory( lambda self: TextChunker(tokenizer=self.prompt_driver.tokenizer, max_tokens=self.max_chunker_tokens), diff --git a/griptape/exceptions/dummy_exception.py b/griptape/exceptions/dummy_exception.py index 815cb245f..172aeadc6 100644 --- a/griptape/exceptions/dummy_exception.py +++ b/griptape/exceptions/dummy_exception.py @@ -2,7 +2,7 @@ class DummyError(Exception): def __init__(self, dummy_class_name: str, dummy_method_name: str) -> None: message = ( f"You have attempted to use a {dummy_class_name}'s {dummy_method_name} method. " - "This likely originated from using a `StructureConfig` without providing a Driver required for this feature." + "This likely originated from using a `DriverConfig` without providing a Driver required for this feature." ) super().__init__(message) diff --git a/griptape/memory/structure/base_conversation_memory.py b/griptape/memory/structure/base_conversation_memory.py index 3c3a0aaca..e7c8ed488 100644 --- a/griptape/memory/structure/base_conversation_memory.py +++ b/griptape/memory/structure/base_conversation_memory.py @@ -18,7 +18,7 @@ @define class BaseConversationMemory(SerializableMixin, ABC): driver: Optional[BaseConversationMemoryDriver] = field( - default=Factory(lambda: Config.drivers.conversation_memory_driver), kw_only=True + default=Factory(lambda: Config.drivers.conversation_memory), kw_only=True ) runs: list[Run] = field(factory=list, kw_only=True, metadata={"serializable": True}) structure: Structure = field(init=False) @@ -67,7 +67,7 @@ def add_to_prompt_stack(self, prompt_stack: PromptStack, index: Optional[int] = if self.autoprune and hasattr(self, "structure"): should_prune = True - prompt_driver = Config.drivers.prompt_driver + prompt_driver = Config.drivers.prompt temp_stack = PromptStack() # Try to determine how many Conversation Memory runs we can diff --git a/griptape/memory/structure/summary_conversation_memory.py b/griptape/memory/structure/summary_conversation_memory.py index 161a68eb3..50be69a61 100644 --- a/griptape/memory/structure/summary_conversation_memory.py +++ b/griptape/memory/structure/summary_conversation_memory.py @@ -18,7 +18,7 @@ @define class SummaryConversationMemory(ConversationMemory): offset: int = field(default=1, kw_only=True, metadata={"serializable": True}) - prompt_driver: BasePromptDriver = field(kw_only=True, default=Factory(lambda: Config.drivers.prompt_driver)) + prompt_driver: BasePromptDriver = field(kw_only=True, default=Factory(lambda: Config.drivers.prompt)) summary: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) summary_index: int = field(default=0, kw_only=True, metadata={"serializable": True}) summary_template_generator: J2 = field(default=Factory(lambda: J2("memory/conversation/summary.j2")), kw_only=True) diff --git a/griptape/memory/task/storage/text_artifact_storage.py b/griptape/memory/task/storage/text_artifact_storage.py index 134274648..460581997 100644 --- a/griptape/memory/task/storage/text_artifact_storage.py +++ b/griptape/memory/task/storage/text_artifact_storage.py @@ -16,7 +16,7 @@ @define(kw_only=True) class TextArtifactStorage(BaseArtifactStorage): - vector_store_driver: BaseVectorStoreDriver = field(default=Factory(lambda: Config.drivers.vector_store_driver)) + vector_store_driver: BaseVectorStoreDriver = field(default=Factory(lambda: Config.drivers.vector_store)) rag_engine: Optional[RagEngine] = field(default=None) retrieval_rag_module_name: Optional[str] = field(default=None) summary_engine: Optional[BaseSummaryEngine] = field(default=None) diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index 010e8ef1f..4db8228e3 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -140,10 +140,10 @@ def default_task_memory(self) -> TaskMemory: TextArtifact: TextArtifactStorage( rag_engine=self.rag_engine, retrieval_rag_module_name="VectorStoreRetrievalRagModule", - vector_store_driver=Config.drivers.vector_store_driver, - summary_engine=PromptSummaryEngine(prompt_driver=Config.drivers.prompt_driver), - csv_extraction_engine=CsvExtractionEngine(prompt_driver=Config.drivers.prompt_driver), - json_extraction_engine=JsonExtractionEngine(prompt_driver=Config.drivers.prompt_driver), + vector_store_driver=Config.drivers.vector_store, + summary_engine=PromptSummaryEngine(prompt_driver=Config.drivers.prompt), + csv_extraction_engine=CsvExtractionEngine(prompt_driver=Config.drivers.prompt), + json_extraction_engine=JsonExtractionEngine(prompt_driver=Config.drivers.prompt), ), BlobArtifact: BlobArtifactStorage(), }, diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 6997c9558..19580b642 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -17,7 +17,7 @@ @define class PromptTask(RuleMixin, BaseTask): - prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.drivers.prompt_driver), kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.drivers.prompt), kw_only=True) generate_system_template: Callable[[PromptTask], str] = field( default=Factory(lambda self: self.default_system_template_generator, takes_self=True), kw_only=True, diff --git a/griptape/utils/chat.py b/griptape/utils/chat.py index 6455efd14..99b5a7dc3 100644 --- a/griptape/utils/chat.py +++ b/griptape/utils/chat.py @@ -27,7 +27,7 @@ class Chat: def default_output_fn(self, text: str) -> None: from griptape.config import Config - if Config.drivers.prompt_driver.stream: + if Config.drivers.prompt.stream: print(text, end="", flush=True) # noqa: T201 else: print(text) # noqa: T201 @@ -44,7 +44,7 @@ def start(self) -> None: self.output_fn(self.exiting_text) break - if Config.drivers.prompt_driver.stream: + if Config.drivers.prompt.stream: self.output_fn(self.processing_text + "\n") stream = Stream(self.structure).run(question) first_chunk = next(stream) diff --git a/griptape/utils/stream.py b/griptape/utils/stream.py index 7c716787b..cb5266378 100644 --- a/griptape/utils/stream.py +++ b/griptape/utils/stream.py @@ -36,7 +36,7 @@ class Stream: def validate_structure(self, _: Attribute, structure: Structure) -> None: from griptape.config import Config - if not Config.drivers.prompt_driver.stream: + if not Config.drivers.prompt.stream: raise ValueError("prompt driver does not have streaming enabled, enable with stream=True") _event_queue: Queue[BaseEvent] = field(default=Factory(lambda: Queue())) diff --git a/tests/mocks/mock_structure_config.py b/tests/mocks/mock_driver_config.py similarity index 63% rename from tests/mocks/mock_structure_config.py rename to tests/mocks/mock_driver_config.py index 0b374449d..c7407b8bc 100644 --- a/tests/mocks/mock_structure_config.py +++ b/tests/mocks/mock_driver_config.py @@ -1,6 +1,6 @@ from attrs import Factory, define, field -from griptape.config import StructureConfig +from griptape.config import DriverConfig from griptape.drivers.vector.local_vector_store_driver import LocalVectorStoreDriver from tests.mocks.mock_embedding_driver import MockEmbeddingDriver from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver @@ -9,20 +9,18 @@ @define -class MockStructureConfig(StructureConfig): - prompt_driver: MockPromptDriver = field( - default=Factory(lambda: MockPromptDriver()), metadata={"serializable": True} - ) - image_generation_driver: MockImageGenerationDriver = field( +class MockDriverConfig(DriverConfig): + prompt: MockPromptDriver = field(default=Factory(lambda: MockPromptDriver()), metadata={"serializable": True}) + image_generation: MockImageGenerationDriver = field( default=Factory(lambda: MockImageGenerationDriver(model="dall-e-2")), metadata={"serializable": True} ) - image_query_driver: MockImageQueryDriver = field( + image_query: MockImageQueryDriver = field( default=Factory(lambda: MockImageQueryDriver(model="gpt-4-vision-preview")), metadata={"serializable": True} ) - embedding_driver: MockEmbeddingDriver = field( + embedding: MockEmbeddingDriver = field( default=Factory(lambda: MockEmbeddingDriver(model="text-embedding-3-small")), metadata={"serializable": True} ) - vector_store_driver: LocalVectorStoreDriver = field( - default=Factory(lambda self: LocalVectorStoreDriver(embedding_driver=self.embedding_driver), takes_self=True), + vector_store: LocalVectorStoreDriver = field( + default=Factory(lambda self: LocalVectorStoreDriver(embedding_driver=self.embedding), takes_self=True), metadata={"serializable": True}, ) diff --git a/tests/unit/config/test_amazon_bedrock_structure_config.py b/tests/unit/config/test_amazon_bedrock_driver_config.py similarity index 71% rename from tests/unit/config/test_amazon_bedrock_structure_config.py rename to tests/unit/config/test_amazon_bedrock_driver_config.py index afe9b3720..4fdbfedbc 100644 --- a/tests/unit/config/test_amazon_bedrock_structure_config.py +++ b/tests/unit/config/test_amazon_bedrock_driver_config.py @@ -1,7 +1,7 @@ import boto3 import pytest -from griptape.config import AmazonBedrockStructureConfig +from griptape.config import AmazonBedrockDriverConfig from tests.utils.aws import mock_aws_credentials @@ -13,11 +13,11 @@ def _run_before_and_after_tests(self): @pytest.fixture() def config(self): mock_aws_credentials() - return AmazonBedrockStructureConfig() + return AmazonBedrockDriverConfig() @pytest.fixture() def config_with_values(self): - return AmazonBedrockStructureConfig( + return AmazonBedrockDriverConfig( session=boto3.Session( aws_access_key_id="testing", aws_secret_access_key="testing", region_name="region-value" ) @@ -25,9 +25,9 @@ def config_with_values(self): def test_to_dict(self, config): assert config.to_dict() == { - "conversation_memory_driver": None, - "embedding_driver": {"model": "amazon.titan-embed-text-v1", "type": "AmazonBedrockTitanEmbeddingDriver"}, - "image_generation_driver": { + "conversation_memory": None, + "embedding": {"model": "amazon.titan-embed-text-v1", "type": "AmazonBedrockTitanEmbeddingDriver"}, + "image_generation": { "image_generation_model_driver": { "cfg_scale": 7, "outpainting_mode": "PRECISE", @@ -40,13 +40,13 @@ def test_to_dict(self, config): "seed": None, "type": "AmazonBedrockImageGenerationDriver", }, - "image_query_driver": { + "image_query": { "type": "AmazonBedrockImageQueryDriver", "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", "max_tokens": 256, "image_query_model_driver": {"type": "BedrockClaudeImageQueryModelDriver"}, }, - "prompt_driver": { + "prompt": { "max_tokens": None, "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", "stream": False, @@ -55,32 +55,31 @@ def test_to_dict(self, config): "tool_choice": {"auto": {}}, "use_native_tools": True, }, - "vector_store_driver": { + "vector_store": { "embedding_driver": { "model": "amazon.titan-embed-text-v1", "type": "AmazonBedrockTitanEmbeddingDriver", }, "type": "LocalVectorStoreDriver", }, - "type": "AmazonBedrockStructureConfig", - "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, - "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, + "type": "AmazonBedrockDriverConfig", + "text_to_speech": {"type": "DummyTextToSpeechDriver"}, + "audio_transcription": {"type": "DummyAudioTranscriptionDriver"}, } def test_from_dict(self, config): - assert AmazonBedrockStructureConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() + assert AmazonBedrockDriverConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() def test_from_dict_with_values(self, config_with_values): assert ( - AmazonBedrockStructureConfig.from_dict(config_with_values.to_dict()).to_dict() - == config_with_values.to_dict() + AmazonBedrockDriverConfig.from_dict(config_with_values.to_dict()).to_dict() == config_with_values.to_dict() ) def test_to_dict_with_values(self, config_with_values): assert config_with_values.to_dict() == { - "conversation_memory_driver": None, - "embedding_driver": {"model": "amazon.titan-embed-text-v1", "type": "AmazonBedrockTitanEmbeddingDriver"}, - "image_generation_driver": { + "conversation_memory": None, + "embedding": {"model": "amazon.titan-embed-text-v1", "type": "AmazonBedrockTitanEmbeddingDriver"}, + "image_generation": { "image_generation_model_driver": { "cfg_scale": 7, "outpainting_mode": "PRECISE", @@ -93,13 +92,13 @@ def test_to_dict_with_values(self, config_with_values): "seed": None, "type": "AmazonBedrockImageGenerationDriver", }, - "image_query_driver": { + "image_query": { "type": "AmazonBedrockImageQueryDriver", "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", "max_tokens": 256, "image_query_model_driver": {"type": "BedrockClaudeImageQueryModelDriver"}, }, - "prompt_driver": { + "prompt": { "max_tokens": None, "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", "stream": False, @@ -108,15 +107,15 @@ def test_to_dict_with_values(self, config_with_values): "tool_choice": {"auto": {}}, "use_native_tools": True, }, - "vector_store_driver": { + "vector_store": { "embedding_driver": { "model": "amazon.titan-embed-text-v1", "type": "AmazonBedrockTitanEmbeddingDriver", }, "type": "LocalVectorStoreDriver", }, - "type": "AmazonBedrockStructureConfig", - "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, - "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, + "type": "AmazonBedrockDriverConfig", + "text_to_speech": {"type": "DummyTextToSpeechDriver"}, + "audio_transcription": {"type": "DummyAudioTranscriptionDriver"}, } assert config_with_values.session.region_name == "region-value" diff --git a/tests/unit/config/test_anthropic_structure_config.py b/tests/unit/config/test_anthropic_driver_config.py similarity index 65% rename from tests/unit/config/test_anthropic_structure_config.py rename to tests/unit/config/test_anthropic_driver_config.py index 05519fa5e..654e7ddf3 100644 --- a/tests/unit/config/test_anthropic_structure_config.py +++ b/tests/unit/config/test_anthropic_driver_config.py @@ -1,6 +1,6 @@ import pytest -from griptape.config import AnthropicStructureConfig +from griptape.config import AnthropicDriverConfig class TestAnthropicStructureConfig: @@ -11,12 +11,12 @@ def _mock_anthropic(self, mocker): @pytest.fixture() def config(self): - return AnthropicStructureConfig() + return AnthropicDriverConfig() def test_to_dict(self, config): assert config.to_dict() == { - "type": "AnthropicStructureConfig", - "prompt_driver": { + "type": "AnthropicDriverConfig", + "prompt": { "type": "AnthropicPromptDriver", "temperature": 0.1, "max_tokens": 1000, @@ -26,18 +26,18 @@ def test_to_dict(self, config): "top_k": 250, "use_native_tools": True, }, - "image_generation_driver": {"type": "DummyImageGenerationDriver"}, - "image_query_driver": { + "image_generation": {"type": "DummyImageGenerationDriver"}, + "image_query": { "type": "AnthropicImageQueryDriver", "model": "claude-3-5-sonnet-20240620", "max_tokens": 256, }, - "embedding_driver": { + "embedding": { "type": "VoyageAiEmbeddingDriver", "model": "voyage-large-2", "input_type": "document", }, - "vector_store_driver": { + "vector_store": { "type": "LocalVectorStoreDriver", "embedding_driver": { "type": "VoyageAiEmbeddingDriver", @@ -45,10 +45,10 @@ def test_to_dict(self, config): "input_type": "document", }, }, - "conversation_memory_driver": None, - "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, - "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, + "conversation_memory": None, + "text_to_speech": {"type": "DummyTextToSpeechDriver"}, + "audio_transcription": {"type": "DummyAudioTranscriptionDriver"}, } def test_from_dict(self, config): - assert AnthropicStructureConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() + assert AnthropicDriverConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() diff --git a/tests/unit/config/test_azure_openai_structure_config.py b/tests/unit/config/test_azure_openai_driver_config.py similarity index 84% rename from tests/unit/config/test_azure_openai_structure_config.py rename to tests/unit/config/test_azure_openai_driver_config.py index 810cb41a1..5c43f3522 100644 --- a/tests/unit/config/test_azure_openai_structure_config.py +++ b/tests/unit/config/test_azure_openai_driver_config.py @@ -1,6 +1,6 @@ import pytest -from griptape.config import AzureOpenAiStructureConfig +from griptape.config import AzureOpenAiDriverConfig class TestAzureOpenAiStructureConfig: @@ -10,7 +10,7 @@ def mock_openai(self, mocker): @pytest.fixture() def config(self): - return AzureOpenAiStructureConfig( + return AzureOpenAiDriverConfig( azure_endpoint="http://localhost:8080", azure_ad_token="test-token", azure_ad_token_provider=lambda: "test-provider", @@ -18,9 +18,9 @@ def config(self): def test_to_dict(self, config): assert config.to_dict() == { - "type": "AzureOpenAiStructureConfig", + "type": "AzureOpenAiDriverConfig", "azure_endpoint": "http://localhost:8080", - "prompt_driver": { + "prompt": { "type": "AzureOpenAiChatPromptDriver", "base_url": None, "model": "gpt-4o", @@ -36,8 +36,8 @@ def test_to_dict(self, config): "user": "", "use_native_tools": True, }, - "conversation_memory_driver": None, - "embedding_driver": { + "conversation_memory": None, + "embedding": { "base_url": None, "model": "text-embedding-3-small", "api_version": "2023-05-15", @@ -46,7 +46,7 @@ def test_to_dict(self, config): "organization": None, "type": "AzureOpenAiEmbeddingDriver", }, - "image_generation_driver": { + "image_generation": { "api_version": "2024-02-01", "base_url": None, "image_size": "512x512", @@ -59,7 +59,7 @@ def test_to_dict(self, config): "style": None, "type": "AzureOpenAiImageGenerationDriver", }, - "image_query_driver": { + "image_query": { "base_url": None, "image_quality": "auto", "max_tokens": 256, @@ -70,7 +70,7 @@ def test_to_dict(self, config): "organization": None, "type": "AzureOpenAiImageQueryDriver", }, - "vector_store_driver": { + "vector_store": { "embedding_driver": { "base_url": None, "model": "text-embedding-3-small", @@ -82,6 +82,6 @@ def test_to_dict(self, config): }, "type": "LocalVectorStoreDriver", }, - "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, - "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, + "text_to_speech": {"type": "DummyTextToSpeechDriver"}, + "audio_transcription": {"type": "DummyAudioTranscriptionDriver"}, } diff --git a/tests/unit/config/test_cohere_structure_config.py b/tests/unit/config/test_cohere_driver_config.py similarity index 59% rename from tests/unit/config/test_cohere_structure_config.py rename to tests/unit/config/test_cohere_driver_config.py index 113a589ec..c056cabeb 100644 --- a/tests/unit/config/test_cohere_structure_config.py +++ b/tests/unit/config/test_cohere_driver_config.py @@ -1,22 +1,22 @@ import pytest -from griptape.config import CohereStructureConfig +from griptape.config import CohereDriverConfig class TestCohereStructureConfig: @pytest.fixture() def config(self): - return CohereStructureConfig(api_key="api_key") + return CohereDriverConfig(api_key="api_key") def test_to_dict(self, config): assert config.to_dict() == { - "type": "CohereStructureConfig", - "image_generation_driver": {"type": "DummyImageGenerationDriver"}, - "image_query_driver": {"type": "DummyImageQueryDriver"}, - "conversation_memory_driver": None, - "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, - "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, - "prompt_driver": { + "type": "CohereDriverConfig", + "image_generation": {"type": "DummyImageGenerationDriver"}, + "image_query": {"type": "DummyImageQueryDriver"}, + "conversation_memory": None, + "text_to_speech": {"type": "DummyTextToSpeechDriver"}, + "audio_transcription": {"type": "DummyAudioTranscriptionDriver"}, + "prompt": { "type": "CoherePromptDriver", "temperature": 0.1, "max_tokens": None, @@ -25,12 +25,12 @@ def test_to_dict(self, config): "force_single_step": False, "use_native_tools": True, }, - "embedding_driver": { + "embedding": { "type": "CohereEmbeddingDriver", "model": "embed-english-v3.0", "input_type": "search_document", }, - "vector_store_driver": { + "vector_store": { "type": "LocalVectorStoreDriver", "embedding_driver": { "type": "CohereEmbeddingDriver", diff --git a/tests/unit/config/test_driver_config.py b/tests/unit/config/test_driver_config.py new file mode 100644 index 000000000..e5585de24 --- /dev/null +++ b/tests/unit/config/test_driver_config.py @@ -0,0 +1,39 @@ +import pytest + +from griptape.config import DriverConfig + + +class TestStructureConfig: + @pytest.fixture() + def config(self): + return DriverConfig() + + def test_to_dict(self, config): + assert config.to_dict() == { + "type": "DriverConfig", + "prompt": { + "type": "DummyPromptDriver", + "temperature": 0.1, + "max_tokens": None, + "stream": False, + "use_native_tools": False, + }, + "conversation_memory": None, + "embedding": {"type": "DummyEmbeddingDriver"}, + "image_generation": {"type": "DummyImageGenerationDriver"}, + "image_query": {"type": "DummyImageQueryDriver"}, + "vector_store": { + "embedding_driver": {"type": "DummyEmbeddingDriver"}, + "type": "DummyVectorStoreDriver", + }, + "text_to_speech": {"type": "DummyTextToSpeechDriver"}, + "audio_transcription": {"type": "DummyAudioTranscriptionDriver"}, + } + + def test_from_dict(self, config): + assert DriverConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() + + def test_dot_update(self, config): + config.prompt.max_tokens = 10 + + assert config.prompt.max_tokens == 10 diff --git a/tests/unit/config/test_google_structure_config.py b/tests/unit/config/test_google_driver_config.py similarity index 63% rename from tests/unit/config/test_google_structure_config.py rename to tests/unit/config/test_google_driver_config.py index e193cc983..53663caf0 100644 --- a/tests/unit/config/test_google_structure_config.py +++ b/tests/unit/config/test_google_driver_config.py @@ -1,6 +1,6 @@ import pytest -from griptape.config import GoogleStructureConfig +from griptape.config import GoogleDriverConfig class TestGoogleStructureConfig: @@ -10,12 +10,12 @@ def mock_openai(self, mocker): @pytest.fixture() def config(self): - return GoogleStructureConfig() + return GoogleDriverConfig() def test_to_dict(self, config): assert config.to_dict() == { - "type": "GoogleStructureConfig", - "prompt_driver": { + "type": "GoogleDriverConfig", + "prompt": { "type": "GooglePromptDriver", "temperature": 0.1, "max_tokens": None, @@ -26,15 +26,15 @@ def test_to_dict(self, config): "tool_choice": "auto", "use_native_tools": True, }, - "image_generation_driver": {"type": "DummyImageGenerationDriver"}, - "image_query_driver": {"type": "DummyImageQueryDriver"}, - "embedding_driver": { + "image_generation": {"type": "DummyImageGenerationDriver"}, + "image_query": {"type": "DummyImageQueryDriver"}, + "embedding": { "type": "GoogleEmbeddingDriver", "model": "models/embedding-001", "task_type": "retrieval_document", "title": None, }, - "vector_store_driver": { + "vector_store": { "type": "LocalVectorStoreDriver", "embedding_driver": { "type": "GoogleEmbeddingDriver", @@ -43,10 +43,10 @@ def test_to_dict(self, config): "title": None, }, }, - "conversation_memory_driver": None, - "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, - "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, + "conversation_memory": None, + "text_to_speech": {"type": "DummyTextToSpeechDriver"}, + "audio_transcription": {"type": "DummyAudioTranscriptionDriver"}, } def test_from_dict(self, config): - assert GoogleStructureConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() + assert GoogleDriverConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() diff --git a/tests/unit/config/test_openai_structure_config.py b/tests/unit/config/test_openai_driver_config.py similarity index 81% rename from tests/unit/config/test_openai_structure_config.py rename to tests/unit/config/test_openai_driver_config.py index 8969e0ad0..7af0b755a 100644 --- a/tests/unit/config/test_openai_structure_config.py +++ b/tests/unit/config/test_openai_driver_config.py @@ -1,6 +1,6 @@ import pytest -from griptape.config import OpenAiStructureConfig +from griptape.config import OpenAiDriverConfig class TestOpenAiStructureConfig: @@ -10,12 +10,12 @@ def mock_openai(self, mocker): @pytest.fixture() def config(self): - return OpenAiStructureConfig() + return OpenAiDriverConfig() def test_to_dict(self, config): assert config.to_dict() == { - "type": "OpenAiStructureConfig", - "prompt_driver": { + "type": "OpenAiDriverConfig", + "prompt": { "type": "OpenAiChatPromptDriver", "base_url": None, "model": "gpt-4o", @@ -28,14 +28,14 @@ def test_to_dict(self, config): "user": "", "use_native_tools": True, }, - "conversation_memory_driver": None, - "embedding_driver": { + "conversation_memory": None, + "embedding": { "base_url": None, "model": "text-embedding-3-small", "organization": None, "type": "OpenAiEmbeddingDriver", }, - "image_generation_driver": { + "image_generation": { "api_version": None, "base_url": None, "image_size": "512x512", @@ -46,7 +46,7 @@ def test_to_dict(self, config): "style": None, "type": "OpenAiImageGenerationDriver", }, - "image_query_driver": { + "image_query": { "api_version": None, "base_url": None, "image_quality": "auto", @@ -55,7 +55,7 @@ def test_to_dict(self, config): "organization": None, "type": "OpenAiImageQueryDriver", }, - "vector_store_driver": { + "vector_store": { "embedding_driver": { "base_url": None, "model": "text-embedding-3-small", @@ -64,7 +64,7 @@ def test_to_dict(self, config): }, "type": "LocalVectorStoreDriver", }, - "text_to_speech_driver": { + "text_to_speech": { "type": "OpenAiTextToSpeechDriver", "api_version": None, "base_url": None, @@ -73,7 +73,7 @@ def test_to_dict(self, config): "organization": None, "voice": "alloy", }, - "audio_transcription_driver": { + "audio_transcription": { "type": "OpenAiAudioTranscriptionDriver", "api_version": None, "base_url": None, @@ -83,4 +83,4 @@ def test_to_dict(self, config): } def test_from_dict(self, config): - assert OpenAiStructureConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() + assert OpenAiDriverConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() diff --git a/tests/unit/config/test_structure_config.py b/tests/unit/config/test_structure_config.py deleted file mode 100644 index cce97647e..000000000 --- a/tests/unit/config/test_structure_config.py +++ /dev/null @@ -1,39 +0,0 @@ -import pytest - -from griptape.config import StructureConfig - - -class TestStructureConfig: - @pytest.fixture() - def config(self): - return StructureConfig() - - def test_to_dict(self, config): - assert config.to_dict() == { - "type": "StructureConfig", - "prompt_driver": { - "type": "DummyPromptDriver", - "temperature": 0.1, - "max_tokens": None, - "stream": False, - "use_native_tools": False, - }, - "conversation_memory_driver": None, - "embedding_driver": {"type": "DummyEmbeddingDriver"}, - "image_generation_driver": {"type": "DummyImageGenerationDriver"}, - "image_query_driver": {"type": "DummyImageQueryDriver"}, - "vector_store_driver": { - "embedding_driver": {"type": "DummyEmbeddingDriver"}, - "type": "DummyVectorStoreDriver", - }, - "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, - "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, - } - - def test_from_dict(self, config): - assert StructureConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() - - def test_dot_update(self, config): - config.prompt_driver.max_tokens = 10 - - assert config.prompt_driver.max_tokens == 10 diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index e49de0021..9207bbc1c 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -2,7 +2,7 @@ from griptape.config import Config from griptape.events import EventBus -from tests.mocks.mock_structure_config import MockStructureConfig +from tests.mocks.mock_driver_config import MockDriverConfig @pytest.fixture(autouse=True) @@ -16,6 +16,6 @@ def event_bus(): @pytest.fixture(autouse=True) def mock_config(): - Config.drivers = MockStructureConfig() + Config.drivers = MockDriverConfig() return Config diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index 84fd0bed1..248c259e5 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -11,7 +11,7 @@ class TestBasePromptDriver: def test_run_via_pipeline_retries_success(self, mock_config): - mock_config.drivers.prompt_driver = MockPromptDriver(max_attempts=2) + mock_config.drivers.prompt = MockPromptDriver(max_attempts=2) pipeline = Pipeline() pipeline.add_task(PromptTask("test")) @@ -19,7 +19,7 @@ def test_run_via_pipeline_retries_success(self, mock_config): assert isinstance(pipeline.run().output_task.output, TextArtifact) def test_run_via_pipeline_retries_failure(self, mock_config): - mock_config.drivers.prompt_driver = MockFailingPromptDriver(max_failures=2, max_attempts=1) + mock_config.drivers.prompt = MockFailingPromptDriver(max_failures=2, max_attempts=1) pipeline = Pipeline() pipeline.add_task(PromptTask("test")) @@ -47,7 +47,7 @@ def test_run_with_stream(self): assert result.value == "mock output" def test_run_with_tools(self, mock_config): - mock_config.drivers.prompt_driver = MockPromptDriver(max_attempts=1, use_native_tools=True) + mock_config.drivers.prompt = MockPromptDriver(max_attempts=1, use_native_tools=True) pipeline = Pipeline() pipeline.add_task(ToolkitTask(tools=[MockTool()])) diff --git a/tests/unit/drivers/structure_run/test_local_structure_run_driver.py b/tests/unit/drivers/structure_run/test_local_structure_run_driver.py index b2e9c069b..c2bb45208 100644 --- a/tests/unit/drivers/structure_run/test_local_structure_run_driver.py +++ b/tests/unit/drivers/structure_run/test_local_structure_run_driver.py @@ -20,7 +20,7 @@ def test_run(self): def test_run_with_env(self, mock_config): pipeline = Pipeline() - mock_config.drivers.prompt_driver = MockPromptDriver(mock_output=lambda _: os.environ["KEY"]) + mock_config.drivers.prompt = MockPromptDriver(mock_output=lambda _: os.environ["KEY"]) agent = Agent() driver = LocalStructureRunDriver(structure_factory_fn=lambda: agent, env={"KEY": "value"}) task = StructureRunTask(driver=driver) diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index d2681877f..038cb4508 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -26,7 +26,7 @@ class TestEventListener: @pytest.fixture() def pipeline(self, mock_config): - mock_config.drivers.prompt_driver = MockPromptDriver(stream=True) + mock_config.drivers.prompt = MockPromptDriver(stream=True) task = ToolkitTask("test", tools=[MockTool(name="Tool1")]) pipeline = Pipeline() diff --git a/tests/unit/memory/structure/test_conversation_memory.py b/tests/unit/memory/structure/test_conversation_memory.py index 06e54e6c4..f0e4b0af3 100644 --- a/tests/unit/memory/structure/test_conversation_memory.py +++ b/tests/unit/memory/structure/test_conversation_memory.py @@ -97,7 +97,7 @@ def test_add_to_prompt_stack_autopruing_disabled(self): def test_add_to_prompt_stack_autopruning_enabled(self, mock_config): # All memory is pruned. - mock_config.drivers.prompt_driver = MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=0)) + mock_config.drivers.prompt = MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=0)) agent = Agent() memory = ConversationMemory( autoprune=True, @@ -119,9 +119,7 @@ def test_add_to_prompt_stack_autopruning_enabled(self, mock_config): assert len(prompt_stack.messages) == 3 # No memory is pruned. - mock_config.drivers.prompt_driver = MockPromptDriver( - tokenizer=MockTokenizer(model="foo", max_input_tokens=1000) - ) + mock_config.drivers.prompt = MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=1000)) agent = Agent() memory = ConversationMemory( autoprune=True, @@ -145,7 +143,7 @@ def test_add_to_prompt_stack_autopruning_enabled(self, mock_config): # One memory is pruned. # MockTokenizer's max_input_tokens set to one below the sum of memory + system prompt tokens # so that a single memory is pruned. - mock_config.drivers.prompt_driver = MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=160)) + mock_config.drivers.prompt = MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=160)) agent = Agent() memory = ConversationMemory( autoprune=True, diff --git a/tests/unit/tasks/test_json_extraction_task.py b/tests/unit/tasks/test_json_extraction_task.py index 3eef4eec3..8f9278c3c 100644 --- a/tests/unit/tasks/test_json_extraction_task.py +++ b/tests/unit/tasks/test_json_extraction_task.py @@ -13,9 +13,7 @@ def task(self): return JsonExtractionTask("foo", args={"template_schema": Schema({"foo": "bar"}).json_schema("TemplateSchema")}) def test_run(self, task, mock_config): - mock_config.drivers.prompt_driver.mock_output = ( - '[{"test_key_1": "test_value_1"}, {"test_key_2": "test_value_2"}]' - ) + mock_config.drivers.prompt.mock_output = '[{"test_key_1": "test_value_1"}, {"test_key_2": "test_value_2"}]' agent = Agent() agent.add_task(task) diff --git a/tests/unit/tasks/test_structure_run_task.py b/tests/unit/tasks/test_structure_run_task.py index 2c0dc1b28..d18d75d75 100644 --- a/tests/unit/tasks/test_structure_run_task.py +++ b/tests/unit/tasks/test_structure_run_task.py @@ -6,9 +6,9 @@ class TestStructureRunTask: def test_run(self, mock_config): - mock_config.drivers.prompt_driver = MockPromptDriver(mock_output="agent mock output") + mock_config.drivers.prompt = MockPromptDriver(mock_output="agent mock output") agent = Agent() - mock_config.drivers.prompt_driver = MockPromptDriver(mock_output="pipeline mock output") + mock_config.drivers.prompt = MockPromptDriver(mock_output="pipeline mock output") pipeline = Pipeline() driver = LocalStructureRunDriver(structure_factory_fn=lambda: agent) diff --git a/tests/unit/tasks/test_tool_task.py b/tests/unit/tasks/test_tool_task.py index 70ab05e12..18521632e 100644 --- a/tests/unit/tasks/test_tool_task.py +++ b/tests/unit/tasks/test_tool_task.py @@ -168,9 +168,7 @@ class TestToolTask: def agent(self, mock_config): output_dict = {"tag": "foo", "name": "MockTool", "path": "test", "input": {"values": {"test": "foobar"}}} - mock_config.drivers.prompt_driver = MockPromptDriver( - mock_output=f"```python foo bar\n{json.dumps(output_dict)}" - ) + mock_config.drivers.prompt = MockPromptDriver(mock_output=f"```python foo bar\n{json.dumps(output_dict)}") return Agent() diff --git a/tests/unit/tasks/test_toolkit_task.py b/tests/unit/tasks/test_toolkit_task.py index 15f5a59b1..c1b91b1ed 100644 --- a/tests/unit/tasks/test_toolkit_task.py +++ b/tests/unit/tasks/test_toolkit_task.py @@ -171,7 +171,7 @@ def test_init(self): def test_run(self, mock_config): output = """Answer: done""" - mock_config.drivers.prompt_driver.mock_output = output + mock_config.drivers.prompt.mock_output = output task = ToolkitTask("test", tools=[MockTool(name="Tool1"), MockTool(name="Tool2")]) agent = Agent() @@ -186,7 +186,7 @@ def test_run(self, mock_config): def test_run_max_subtasks(self, mock_config): output = 'Actions: [{"tag": "foo", "name": "Tool1", "path": "test", "input": {"values": {"test": "value"}}}]' - mock_config.drivers.prompt_driver.mock_output = output + mock_config.drivers.prompt.mock_output = output task = ToolkitTask("test", tools=[MockTool(name="Tool1")], max_subtasks=3) agent = Agent() @@ -200,7 +200,7 @@ def test_run_max_subtasks(self, mock_config): def test_run_invalid_react_prompt(self, mock_config): output = """foo bar""" - mock_config.drivers.prompt_driver.mock_output = output + mock_config.drivers.prompt.mock_output = output task = ToolkitTask("test", tools=[MockTool(name="Tool1")], max_subtasks=3) agent = Agent() diff --git a/tests/unit/utils/test_stream.py b/tests/unit/utils/test_stream.py index 48dbaae29..318f434c3 100644 --- a/tests/unit/utils/test_stream.py +++ b/tests/unit/utils/test_stream.py @@ -10,11 +10,11 @@ class TestStream: @pytest.fixture(params=[True, False]) def agent(self, request): - Config.drivers.prompt_driver.stream = request.param + Config.drivers.prompt.stream = request.param return Agent() def test_init(self, agent): - if Config.drivers.prompt_driver.stream: + if Config.drivers.prompt.stream: chat_stream = Stream(agent) assert chat_stream.structure == agent diff --git a/tests/utils/structure_tester.py b/tests/utils/structure_tester.py index 5b908065b..9fa5e559a 100644 --- a/tests/utils/structure_tester.py +++ b/tests/utils/structure_tester.py @@ -25,9 +25,7 @@ def get_enabled_prompt_drivers(prompt_drivers_options) -> list[BasePromptDriver]: return [ - prompt_driver_option.prompt_driver - for prompt_driver_option in prompt_drivers_options - if prompt_driver_option.enabled + prompt_driver_option.prompt for prompt_driver_option in prompt_drivers_options if prompt_driver_option.enabled ] From ef24d49211d9afd73da7111c3745dd3e551913fb Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 6 Aug 2024 12:37:43 -0700 Subject: [PATCH 07/63] Revert changelog update --- CHANGELOG.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c0720ca47..ea88983f3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -263,7 +263,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - **BREAKING**: Updated OpenAI-based image query drivers to remove Vision from the name. - **BREAKING**: `off_prompt` now defaults to `False` on all Tools, making Task Memory something that must be explicitly opted into. -- **BREAKING**: Removed `StructureConfig.drivers.global_drivers`. Pass Drivers directly to the Structure Config instead. +- **BREAKING**: Removed `StructureConfig.global_drivers`. Pass Drivers directly to the Structure Config instead. - **BREAKING**: Removed `StructureConfig.task_memory` in favor of configuring directly on the Structure. - **BREAKING**: Updated OpenAI-based image query drivers to remove Vision from the name. - **BREAKING**: `off_prompt` now defaults to `False` on all Tools, making Task Memory something that must be explicitly opted into. @@ -391,7 +391,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Deprecation warnings not displaying for `Structure.prompt_driver`, `Structure.embedding_driver`, and `Structure.stream`. - `DummyException` error message not fully displaying. -- `StructureConfig.task_memory` not defaulting to using `StructureConfig.drivers.global_drivers` by default. +- `StructureConfig.task_memory` not defaulting to using `StructureConfig.global_drivers` by default. ## [0.23.1] - 2024-03-07 From ab6578315abc8cf7607f770b83e9ed93b344e076 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 6 Aug 2024 12:39:51 -0700 Subject: [PATCH 08/63] Rename Structure Config to Driver Config --- CHANGELOG.md | 32 +++++++++---------- docs/griptape-framework/structures/config.md | 22 ++++++------- griptape/config/azure_openai_driver_config.py | 2 +- .../test_amazon_bedrock_driver_config.py | 2 +- .../config/test_anthropic_driver_config.py | 2 +- .../config/test_azure_openai_driver_config.py | 2 +- .../unit/config/test_cohere_driver_config.py | 2 +- tests/unit/config/test_driver_config.py | 2 +- .../unit/config/test_google_driver_config.py | 2 +- .../unit/config/test_openai_driver_config.py | 2 +- 10 files changed, 35 insertions(+), 35 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ea88983f3..785a5d083 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -112,7 +112,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `GoogleWebSearchDriver` to web search with the Google Customsearch API. - `DuckDuckGoWebSearchDriver` to web search with the DuckDuckGo search SDK. - `ProxyWebScraperDriver` to web scrape using proxies. -- Parameter `session` on `AmazonBedrockStructureConfig`. +- Parameter `session` on `AmazonBedrockDriverConfig`. - Parameter `meta` on `TextArtifact`. - `VectorStoreClient` improvements: - `VectorStoreClient.query_params` dict for custom query params. @@ -155,7 +155,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: removed `VectorStoreClient.top_n` and `VectorStoreClient.namespace` in favor of `VectorStoreClient.query_params`. - **BREAKING**: All `futures_executor` fields renamed to `futures_executor_fn` and now accept callables instead of futures; wrapped all future `submit` calls with the `with` block to address future executor shutdown issues. - `GriptapeCloudKnowledgeBaseClient` migrated to `/search` api. -- Default Prompt Driver model in `GoogleStructureConfig` to `gemini-1.5-pro`. +- Default Prompt Driver model in `GoogleDriverConfig` to `gemini-1.5-pro`. ### Fixed - `CoherePromptDriver` to properly handle empty history. @@ -175,7 +175,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - Base Tool schema so that `input` is optional when no Tool Activity schema is set. - Tool Task system prompt for better results with lower-end models. -- Default Prompt Driver model to Claude 3.5 Sonnet in `AnthropicStructureConfig` and `AmazonBedrockStructureConfig.` +- Default Prompt Driver model to Claude 3.5 Sonnet in `AnthropicDriverConfig` and `AmazonBedrockDriverConfig.` ## [0.27.0] - 2024-06-19 @@ -186,7 +186,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `BaseTask.add_parents()` to add multiple parent tasks to a child task. - `Structure.resolve_relationships()` to resolve asymmetrically defined parent/child relationships. In other words, if a parent declares a child, but the child does not declare the parent, the parent will automatically be added as a parent of the child when running this method. The method is invoked automatically by `Structure.before_run()`. - `CohereEmbeddingDriver` for using Cohere's embeddings API. -- `CohereStructureConfig` for providing Structures with quick Cohere configuration. +- `CohereDriverConfig` for providing Structures with quick Cohere configuration. - `AmazonSageMakerJumpstartPromptDriver.inference_component_name` for setting the `InferenceComponentName` parameter when invoking an endpoint. - `AmazonSageMakerJumpstartEmbeddingDriver.inference_component_name` for setting the `InferenceComponentName` parameter when invoking an endpoint. - `AmazonSageMakerJumpstartEmbeddingDriver.custom_attributes` for setting custom attributes when invoking an endpoint. @@ -252,7 +252,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [0.26.0] - 2024-06-04 ### Added -- `AzureOpenAiStructureConfig` for providing Structures with all Azure OpenAI Driver configuration. +- `AzureOpenAiDriverConfig` for providing Structures with all Azure OpenAI Driver configuration. - `AzureOpenAiVisionImageQueryDriver` to support queries on images using Azure's OpenAI Vision models. - `AudioLoader` for loading audio content into an `AudioArtifact`. - `AudioTranscriptionTask` and `AudioTranscriptionClient` for transcribing audio content in Structures. @@ -263,8 +263,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - **BREAKING**: Updated OpenAI-based image query drivers to remove Vision from the name. - **BREAKING**: `off_prompt` now defaults to `False` on all Tools, making Task Memory something that must be explicitly opted into. -- **BREAKING**: Removed `StructureConfig.global_drivers`. Pass Drivers directly to the Structure Config instead. -- **BREAKING**: Removed `StructureConfig.task_memory` in favor of configuring directly on the Structure. +- **BREAKING**: Removed `DriverConfig.global_drivers`. Pass Drivers directly to the Driver Config instead. +- **BREAKING**: Removed `DriverConfig.task_memory` in favor of configuring directly on the Structure. - **BREAKING**: Updated OpenAI-based image query drivers to remove Vision from the name. - **BREAKING**: `off_prompt` now defaults to `False` on all Tools, making Task Memory something that must be explicitly opted into. - **BREAKING**: `AmazonSageMakerPromptDriver.model` parameter, which gets passed to `SageMakerRuntime.Client.invoke_endpoint` as `EndpointName`, is now renamed to `AmazonSageMakerPromptDriver.endpoint`. @@ -293,7 +293,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - Default behavior of Event Listener Drivers to batch events. -- Default behavior of OpenAiStructureConfig to utilize `gpt-4o` for prompt_driver. +- Default behavior of OpenAiDriverConfig to utilize `gpt-4o` for prompt_driver. ## [0.25.0] - 2024-05-06 @@ -359,7 +359,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Support for `text-embedding-3-small` and `text-embedding-3-large` models. - `GooglePromptDriver` and `GoogleTokenizer` for use with `gemini-pro`. - `GoogleEmbeddingDriver` for use with `embedding-001`. -- `GoogleStructureConfig` for providing Structures with Google Prompt and Embedding Driver configuration. +- `GoogleDriverConfig` for providing Structures with Google Prompt and Embedding Driver configuration. - Support for `claude-3-opus`, `claude-3-sonnet`, and `claude-3-haiku` in `AnthropicPromptDriver`. - Support for `anthropic.claude-3-sonnet-20240229-v1:0` and `anthropic.claude-3-haiku-20240307-v1:0` in `BedrockClaudePromptModelDriver`. - `top_k` and `top_p` parameters in `AnthropicPromptDriver`. @@ -369,7 +369,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `TrafilaturaWebScraperDriver` for scraping text from web pages using trafilatura. - `MarkdownifyWebScraperDriver` for scraping text from web pages using playwright and converting to markdown using markdownify. - `VoyageAiEmbeddingDriver` for use with VoyageAi's embedding models. -- `AnthropicStructureConfig` for providing Structures with Anthropic Prompt and VoyageAi Embedding Driver configuration. +- `AnthropicDriverConfig` for providing Structures with Anthropic Prompt and VoyageAi Embedding Driver configuration. - `QdrantVectorStoreDriver` to integrate with Qdrant vector databases. ### Fixed @@ -380,9 +380,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: Removed `subtask_action_name`, `subtask_action_path`, and `subtask_action_input` in `BaseActionSubtaskEvent`. - **BREAKING**: `OpenAiVisionImageQueryDriver` field `model` no longer defaults to `gpt-4-vision-preview` and must be specified - Default model of `OpenAiEmbeddingDriver` to `text-embedding-3-small`. -- Default model of `OpenAiStructureConfig` to `text-embedding-3-small`. +- Default model of `OpenAiDriverConfig` to `text-embedding-3-small`. - `BaseTextLoader` to accept a `BaseChunker`. -- Default model of `AmazonBedrockStructureConfig` to `anthropic.claude-3-sonnet-20240229-v1:0`. +- Default model of `AmazonBedrockDriverConfig` to `anthropic.claude-3-sonnet-20240229-v1:0`. - `AnthropicPromptDriver` and `BedrockClaudePromptModelDriver` to use Anthropic's Messages API. - `OpenAiVisionImageQueryDriver` now has a required field `max_tokens` that defaults to 256 @@ -391,7 +391,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Deprecation warnings not displaying for `Structure.prompt_driver`, `Structure.embedding_driver`, and `Structure.stream`. - `DummyException` error message not fully displaying. -- `StructureConfig.task_memory` not defaulting to using `StructureConfig.global_drivers` by default. +- `DriverConfig.task_memory` not defaulting to using `DriverConfig.global_drivers` by default. ## [0.23.1] - 2024-03-07 @@ -408,9 +408,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `AzureMongoDbVectorStoreDriver` for using CosmosDB with MongoDB vCore API. - `vector_path` field on `MongoDbAtlasVectorStoreDriver`. - `LeonardoImageGenerationDriver` supports image to image generation. -- `OpenAiStructureConfig` for providing Structures with all OpenAi Driver configuration. -- `AmazonBedrockStructureConfig` for providing Structures with all Amazon Bedrock Driver configuration. -- `StructureConfig` for building your own Structure configuration. +- `OpenAiDriverConfig` for providing Structures with all OpenAi Driver configuration. +- `AmazonBedrockDriverConfig` for providing Structures with all Amazon Bedrock Driver configuration. +- `DriverConfig` for building your own Structure configuration. - `JsonExtractionTask` for convenience over using `ExtractionTask` with a `JsonExtractionEngine`. - `CsvExtractionTask` for convenience over using `ExtractionTask` with a `CsvExtractionEngine`. - `OpenAiVisionImageQueryDriver` to support queries on images using OpenAI's Vision model. diff --git a/docs/griptape-framework/structures/config.md b/docs/griptape-framework/structures/config.md index 17fb9e5da..917485837 100644 --- a/docs/griptape-framework/structures/config.md +++ b/docs/griptape-framework/structures/config.md @@ -5,15 +5,15 @@ search: ## Overview -The [StructureConfig](../../reference/griptape/config/driver_config.md) class allows for the customization of Structures within Griptape, enabling specific settings such as Drivers to be defined for Tasks. +The [DriverConfig](../../reference/griptape/config/driver_config.md) class allows for the customization of Structures within Griptape, enabling specific settings such as Drivers to be defined for Tasks. ### Premade Configs -Griptape provides predefined [StructureConfig](../../reference/griptape/config/driver_config.md)'s for widely used services that provide APIs for most Driver types Griptape offers. +Griptape provides predefined [DriverConfig](../../reference/griptape/config/driver_config.md)'s for widely used services that provide APIs for most Driver types Griptape offers. #### OpenAI -The [OpenAI Structure Config](../../reference/griptape/config/openai_driver_config.md) provides default Drivers for OpenAI's APIs. This is the default config for all Structures. +The [OpenAI Driver Config](../../reference/griptape/config/openai_driver_config.md) provides default Drivers for OpenAI's APIs. This is the default config for all Structures. ```python from griptape.structures import Agent @@ -28,7 +28,7 @@ agent = Agent() # This is equivalent to the above #### Azure OpenAI -The [Azure OpenAI Structure Config](../../reference/griptape/config/azure_openai_driver_config.md) provides default Drivers for Azure's OpenAI APIs. +The [Azure OpenAI Driver Config](../../reference/griptape/config/azure_openai_driver_config.md) provides default Drivers for Azure's OpenAI APIs. ```python import os @@ -48,7 +48,7 @@ agent = Agent( ``` #### Amazon Bedrock -The [Amazon Bedrock Structure Config](../../reference/griptape/config/amazon_bedrock_driver_config.md) provides default Drivers for Amazon Bedrock's APIs. +The [Amazon Bedrock Driver Config](../../reference/griptape/config/amazon_bedrock_driver_config.md) provides default Drivers for Amazon Bedrock's APIs. ```python import os @@ -68,7 +68,7 @@ agent = Agent( ``` #### Google -The [Google Structure Config](../../reference/griptape/config/google_driver_config.md) provides default Drivers for Google's Gemini APIs. +The [Google Driver Config](../../reference/griptape/config/google_driver_config.md) provides default Drivers for Google's Gemini APIs. ```python from griptape.structures import Agent @@ -81,11 +81,11 @@ agent = Agent( #### Anthropic -The [Anthropic Structure Config](../../reference/griptape/config/anthropic_driver_config.md) provides default Drivers for Anthropic's APIs. +The [Anthropic Driver Config](../../reference/griptape/config/anthropic_driver_config.md) provides default Drivers for Anthropic's APIs. !!! info Anthropic does not provide an embeddings API which means you will need to use another service for embeddings. - The `AnthropicStructureConfig` defaults to using `VoyageAiEmbeddingDriver` which integrates with [VoyageAI](https://www.voyageai.com/), the service used in Anthropic's [embeddings documentation](https://docs.anthropic.com/claude/docs/embeddings). + The `AnthropicDriverConfig` defaults to using `VoyageAiEmbeddingDriver` which integrates with [VoyageAI](https://www.voyageai.com/), the service used in Anthropic's [embeddings documentation](https://docs.anthropic.com/claude/docs/embeddings). To override the default embedding driver, see: [Override Default Structure Embedding Driver](../drivers/embedding-drivers.md#override-default-structure-embedding-driver). ```python @@ -99,7 +99,7 @@ agent = Agent( #### Cohere -The [Cohere Structure Config](../../reference/griptape/config/cohere_driver_config.md) provides default Drivers for Cohere's APIs. +The [Cohere Driver Config](../../reference/griptape/config/cohere_driver_config.md) provides default Drivers for Cohere's APIs. ```python import os @@ -111,8 +111,8 @@ agent = Agent(config=CohereDriverConfig(api_key=os.environ["COHERE_API_KEY"])) ### Custom Configs -You can create your own [StructureConfig](../../reference/griptape/config/driver_config.md) by overriding relevant Drivers. -The [StructureConfig](../../reference/griptape/config/driver_config.md) class includes "Dummy" Drivers for all types, which throw a [DummyError](../../reference/griptape/exceptions/dummy_exception.md) if invoked without being overridden. +You can create your own [DriverConfig](../../reference/griptape/config/driver_config.md) by overriding relevant Drivers. +The [DriverConfig](../../reference/griptape/config/driver_config.md) class includes "Dummy" Drivers for all types, which throw a [DummyError](../../reference/griptape/exceptions/dummy_exception.md) if invoked without being overridden. This approach ensures that you are informed through clear error messages if you attempt to use Structures without proper Driver configurations. ```python diff --git a/griptape/config/azure_openai_driver_config.py b/griptape/config/azure_openai_driver_config.py index ef965fa28..c987a31b5 100644 --- a/griptape/config/azure_openai_driver_config.py +++ b/griptape/config/azure_openai_driver_config.py @@ -21,7 +21,7 @@ @define class AzureOpenAiDriverConfig(DriverConfig): - """Azure OpenAI Structure Configuration. + """Azure OpenAI Driver Configuration. Attributes: azure_endpoint: The endpoint for the Azure OpenAI instance. diff --git a/tests/unit/config/test_amazon_bedrock_driver_config.py b/tests/unit/config/test_amazon_bedrock_driver_config.py index 4fdbfedbc..57a80809e 100644 --- a/tests/unit/config/test_amazon_bedrock_driver_config.py +++ b/tests/unit/config/test_amazon_bedrock_driver_config.py @@ -5,7 +5,7 @@ from tests.utils.aws import mock_aws_credentials -class TestAmazonBedrockStructureConfig: +class TestAmazonBedrockDriverConfig: @pytest.fixture(autouse=True) def _run_before_and_after_tests(self): mock_aws_credentials() diff --git a/tests/unit/config/test_anthropic_driver_config.py b/tests/unit/config/test_anthropic_driver_config.py index 654e7ddf3..a2ccbd25b 100644 --- a/tests/unit/config/test_anthropic_driver_config.py +++ b/tests/unit/config/test_anthropic_driver_config.py @@ -3,7 +3,7 @@ from griptape.config import AnthropicDriverConfig -class TestAnthropicStructureConfig: +class TestAnthropicDriverConfig: @pytest.fixture(autouse=True) def _mock_anthropic(self, mocker): mocker.patch("anthropic.Anthropic") diff --git a/tests/unit/config/test_azure_openai_driver_config.py b/tests/unit/config/test_azure_openai_driver_config.py index 5c43f3522..3c88b859d 100644 --- a/tests/unit/config/test_azure_openai_driver_config.py +++ b/tests/unit/config/test_azure_openai_driver_config.py @@ -3,7 +3,7 @@ from griptape.config import AzureOpenAiDriverConfig -class TestAzureOpenAiStructureConfig: +class TestAzureOpenAiDriverConfig: @pytest.fixture(autouse=True) def mock_openai(self, mocker): return mocker.patch("openai.AzureOpenAI") diff --git a/tests/unit/config/test_cohere_driver_config.py b/tests/unit/config/test_cohere_driver_config.py index c056cabeb..9e8407d84 100644 --- a/tests/unit/config/test_cohere_driver_config.py +++ b/tests/unit/config/test_cohere_driver_config.py @@ -3,7 +3,7 @@ from griptape.config import CohereDriverConfig -class TestCohereStructureConfig: +class TestCohereDriverConfig: @pytest.fixture() def config(self): return CohereDriverConfig(api_key="api_key") diff --git a/tests/unit/config/test_driver_config.py b/tests/unit/config/test_driver_config.py index e5585de24..dd3fd1a47 100644 --- a/tests/unit/config/test_driver_config.py +++ b/tests/unit/config/test_driver_config.py @@ -3,7 +3,7 @@ from griptape.config import DriverConfig -class TestStructureConfig: +class TestDriverConfig: @pytest.fixture() def config(self): return DriverConfig() diff --git a/tests/unit/config/test_google_driver_config.py b/tests/unit/config/test_google_driver_config.py index 53663caf0..fb6cd23b5 100644 --- a/tests/unit/config/test_google_driver_config.py +++ b/tests/unit/config/test_google_driver_config.py @@ -3,7 +3,7 @@ from griptape.config import GoogleDriverConfig -class TestGoogleStructureConfig: +class TestGoogleDriverConfig: @pytest.fixture(autouse=True) def mock_openai(self, mocker): return mocker.patch("google.generativeai.GenerativeModel") diff --git a/tests/unit/config/test_openai_driver_config.py b/tests/unit/config/test_openai_driver_config.py index 7af0b755a..55156730c 100644 --- a/tests/unit/config/test_openai_driver_config.py +++ b/tests/unit/config/test_openai_driver_config.py @@ -3,7 +3,7 @@ from griptape.config import OpenAiDriverConfig -class TestOpenAiStructureConfig: +class TestOpenAiDriverConfig: @pytest.fixture(autouse=True) def mock_openai(self, mocker): return mocker.patch("openai.OpenAI") From ecfa3583c0728c6f8822eb1269c37b7bc6596072 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 6 Aug 2024 12:49:45 -0700 Subject: [PATCH 09/63] Rename doc fields --- docs/examples/multiple-agent-shared-memory.md | 4 +-- .../drivers/embedding-drivers.md | 4 +-- .../drivers/event-listener-drivers.md | 2 +- .../drivers/prompt-drivers.md | 26 +++++++++---------- docs/griptape-framework/misc/events.md | 2 +- docs/griptape-framework/structures/config.md | 2 +- .../official-tools/rest-api-client.md | 2 +- tests/utils/structure_tester.py | 16 +++++++----- 8 files changed, 30 insertions(+), 28 deletions(-) diff --git a/docs/examples/multiple-agent-shared-memory.md b/docs/examples/multiple-agent-shared-memory.md index e6b092965..0fe589d7b 100644 --- a/docs/examples/multiple-agent-shared-memory.md +++ b/docs/examples/multiple-agent-shared-memory.md @@ -42,8 +42,8 @@ mongo_driver = AzureMongoDbVectorStoreDriver( config = AzureOpenAiDriverConfig( azure_endpoint=AZURE_OPENAI_ENDPOINT_1, - vector_store_driver=mongo_driver, - embedding_driver=embedding_driver, + vector_store=mongo_driver, + embedding=embedding_driver, ) loader = Agent( diff --git a/docs/griptape-framework/drivers/embedding-drivers.md b/docs/griptape-framework/drivers/embedding-drivers.md index de2f2d379..3c81cf8a9 100644 --- a/docs/griptape-framework/drivers/embedding-drivers.md +++ b/docs/griptape-framework/drivers/embedding-drivers.md @@ -225,8 +225,8 @@ from griptape.config import DriverConfig agent = Agent( tools=[WebScraper(off_prompt=True), TaskMemoryClient(off_prompt=False)], config=DriverConfig( - prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"), - embedding_driver=VoyageAiEmbeddingDriver(), + prompt=OpenAiChatPromptDriver(model="gpt-4o"), + embedding=VoyageAiEmbeddingDriver(), ), ) diff --git a/docs/griptape-framework/drivers/event-listener-drivers.md b/docs/griptape-framework/drivers/event-listener-drivers.md index 0adb0b10f..8d4f521aa 100644 --- a/docs/griptape-framework/drivers/event-listener-drivers.md +++ b/docs/griptape-framework/drivers/event-listener-drivers.md @@ -139,7 +139,7 @@ agent = Agent( ) ], config=DriverConfig( - prompt_driver=OpenAiChatPromptDriver( + prompt=OpenAiChatPromptDriver( model="gpt-3.5-turbo", temperature=0.7 ) ), diff --git a/docs/griptape-framework/drivers/prompt-drivers.md b/docs/griptape-framework/drivers/prompt-drivers.md index 8693cc6ff..367bff4ba 100644 --- a/docs/griptape-framework/drivers/prompt-drivers.md +++ b/docs/griptape-framework/drivers/prompt-drivers.md @@ -17,7 +17,7 @@ from griptape.config import DriverConfig agent = Agent( config=DriverConfig( - prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", temperature=0.3), + prompt=OpenAiChatPromptDriver(model="gpt-4o", temperature=0.3), ), input="You will be provided with a tweet, and your task is to classify its sentiment as positive, neutral, or negative. Tweet: {{ args[0] }}", rules=[ @@ -75,7 +75,7 @@ from griptape.config import DriverConfig agent = Agent( config=DriverConfig( - prompt_driver=OpenAiChatPromptDriver( + prompt=OpenAiChatPromptDriver( api_key=os.environ["OPENAI_API_KEY"], temperature=0.1, model="gpt-4o", @@ -110,7 +110,7 @@ from griptape.config import DriverConfig agent = Agent( config=DriverConfig( - prompt_driver=OpenAiChatPromptDriver( + prompt=OpenAiChatPromptDriver( base_url="http://127.0.0.1:1234/v1", model="lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF", stream=True ) @@ -138,7 +138,7 @@ from griptape.config import DriverConfig agent = Agent( config=DriverConfig( - prompt_driver=AzureOpenAiChatPromptDriver( + prompt=AzureOpenAiChatPromptDriver( api_key=os.environ["AZURE_OPENAI_API_KEY_1"], model="gpt-3.5-turbo", azure_deployment=os.environ["AZURE_OPENAI_35_TURBO_DEPLOYMENT_ID"], @@ -172,7 +172,7 @@ from griptape.config import DriverConfig agent = Agent( config=DriverConfig( - prompt_driver=CoherePromptDriver( + prompt=CoherePromptDriver( model="command-r", api_key=os.environ['COHERE_API_KEY'], ) @@ -198,7 +198,7 @@ from griptape.config import DriverConfig agent = Agent( config=DriverConfig( - prompt_driver=AnthropicPromptDriver( + prompt=AnthropicPromptDriver( model="claude-3-opus-20240229", api_key=os.environ['ANTHROPIC_API_KEY'], ) @@ -224,7 +224,7 @@ from griptape.config import DriverConfig agent = Agent( config=DriverConfig( - prompt_driver=GooglePromptDriver( + prompt=GooglePromptDriver( model="gemini-pro", api_key=os.environ['GOOGLE_API_KEY'], ) @@ -252,7 +252,7 @@ from griptape.config import DriverConfig agent = Agent( config=DriverConfig( - prompt_driver=AmazonBedrockPromptDriver( + prompt=AmazonBedrockPromptDriver( model="anthropic.claude-3-sonnet-20240229-v1:0", ) ), @@ -292,7 +292,7 @@ from griptape.structures import Agent agent = Agent( config=DriverConfig( - prompt_driver=OllamaPromptDriver( + prompt=OllamaPromptDriver( model="llama3.1", ), ), @@ -322,7 +322,7 @@ from griptape.config import DriverConfig agent = Agent( config=DriverConfig( - prompt_driver=HuggingFaceHubPromptDriver( + prompt=HuggingFaceHubPromptDriver( model="HuggingFaceH4/zephyr-7b-beta", api_token=os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"], ) @@ -356,7 +356,7 @@ from griptape.config import DriverConfig agent = Agent( config=DriverConfig( - prompt_driver=HuggingFaceHubPromptDriver( + prompt=HuggingFaceHubPromptDriver( model="http://127.0.0.1:8080", api_token=os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"], ), @@ -384,7 +384,7 @@ from griptape.config import DriverConfig agent = Agent( config=DriverConfig( - prompt_driver=HuggingFacePipelinePromptDriver( + prompt=HuggingFacePipelinePromptDriver( model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", ) ), @@ -425,7 +425,7 @@ from griptape.config import DriverConfig agent = Agent( config=DriverConfig( - prompt_driver=AmazonSageMakerJumpstartPromptDriver( + prompt=AmazonSageMakerJumpstartPromptDriver( endpoint=os.environ["SAGEMAKER_LLAMA_3_INSTRUCT_ENDPOINT_NAME"], model="meta-llama/Meta-Llama-3-8B-Instruct", ) diff --git a/docs/griptape-framework/misc/events.md b/docs/griptape-framework/misc/events.md index dfb6e2db3..b7f118d98 100644 --- a/docs/griptape-framework/misc/events.md +++ b/docs/griptape-framework/misc/events.md @@ -149,7 +149,7 @@ EventBus.event_listeners = [ pipeline = Pipeline( config=OpenAiDriverConfig( - prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", stream=True) + prompt=OpenAiChatPromptDriver(model="gpt-4o", stream=True) ), ) diff --git a/docs/griptape-framework/structures/config.md b/docs/griptape-framework/structures/config.md index 917485837..13b6c001a 100644 --- a/docs/griptape-framework/structures/config.md +++ b/docs/griptape-framework/structures/config.md @@ -123,7 +123,7 @@ from griptape.drivers import AnthropicPromptDriver agent = Agent( config=DriverConfig( - prompt_driver=AnthropicPromptDriver( + prompt=AnthropicPromptDriver( model="claude-3-sonnet-20240229", api_key=os.environ["ANTHROPIC_API_KEY"], ) diff --git a/docs/griptape-tools/official-tools/rest-api-client.md b/docs/griptape-tools/official-tools/rest-api-client.md index 304ec00ec..675f77b6e 100644 --- a/docs/griptape-tools/official-tools/rest-api-client.md +++ b/docs/griptape-tools/official-tools/rest-api-client.md @@ -118,7 +118,7 @@ posts_client = RestApiClient( pipeline = Pipeline( conversation_memory=ConversationMemory(), config=DriverConfig( - prompt_driver=OpenAiChatPromptDriver( + prompt=OpenAiChatPromptDriver( model="gpt-4o", temperature=0.1 ), diff --git a/tests/utils/structure_tester.py b/tests/utils/structure_tester.py index 9fa5e559a..2b9f83b81 100644 --- a/tests/utils/structure_tester.py +++ b/tests/utils/structure_tester.py @@ -226,6 +226,15 @@ def prompt_driver_id_fn(cls, prompt_driver) -> str: return f"{prompt_driver.__class__.__name__}-{prompt_driver.model}" def verify_structure_output(self, structure) -> dict: + from griptape.config import Config + + Config.drivers.prompt = AzureOpenAiChatPromptDriver( + api_key=os.environ["AZURE_OPENAI_API_KEY_1"], + model="gpt-4o", + azure_deployment=os.environ["AZURE_OPENAI_4_DEPLOYMENT_ID"], + azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_1"], + response_format="json_object", + ) output_schema = Schema( { Literal("correct", description="Whether the output was correct or not."): bool, @@ -263,13 +272,6 @@ def verify_structure_output(self, structure) -> dict: ], ), ], - prompt_driver=AzureOpenAiChatPromptDriver( - api_key=os.environ["AZURE_OPENAI_API_KEY_1"], - model="gpt-4o", - azure_deployment=os.environ["AZURE_OPENAI_4_DEPLOYMENT_ID"], - azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_1"], - response_format="json_object", - ), tasks=[ PromptTask( "\nTasks: {{ task_names }}" From 582f55c372c1a08611c8dc73f8f1072637940050 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 6 Aug 2024 13:18:34 -0700 Subject: [PATCH 10/63] Move events into config --- CHANGELOG.md | 4 ++-- griptape/config/base_config.py | 6 +++++- griptape/config/base_driver_config.py | 4 +++- griptape/config/config.py | 17 +++++++---------- 4 files changed, 17 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 785a5d083..1b224bdd4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -49,8 +49,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: `BaseVectorStoreDriver.load_artifacts` optional arguments are now keyword-only arguments. - **BREAKING**: `BaseVectorStoreDriver.upsert_vector` optional arguments are now keyword-only arguments. - **BREAKING**: `BaseVectorStoreDriver.query` optional arguments are now keyword-only arguments. -- **BREAKING**: `EventListener.publish_event`'s `flush` argument is now a keyword-only argument. -- **BREAKING**: `BaseEventListenerDriver.publish_event`'s `flush` argument is now a keyword-only argument. +- **BREAKING**: `EventListener.events.publish_event`'s `flush` argument is now a keyword-only argument. +- **BREAKING**: `BaseEventListenerDriver.events.publish_event`'s `flush` argument is now a keyword-only argument. - **BREAKING**: Renamed `DummyException` to `DummyError` for pep8 naming compliance. - **BREAKING**: Migrate to `sqlalchemy` 2.0. - **BREAKING**: Make `sqlalchemy` an optional dependency. diff --git a/griptape/config/base_config.py b/griptape/config/base_config.py index 241efadcd..a3f132ea9 100644 --- a/griptape/config/base_config.py +++ b/griptape/config/base_config.py @@ -2,8 +2,12 @@ from attrs import define +from griptape.config.base_driver_config import BaseDriverConfig +from griptape.config.events_config import EventsConfig from griptape.mixins.serializable_mixin import SerializableMixin @define -class BaseConfig(SerializableMixin, ABC): ... +class BaseConfig(SerializableMixin, ABC): + drivers: BaseDriverConfig + events: EventsConfig diff --git a/griptape/config/base_driver_config.py b/griptape/config/base_driver_config.py index 46ff181d3..df32d382e 100644 --- a/griptape/config/base_driver_config.py +++ b/griptape/config/base_driver_config.py @@ -5,6 +5,8 @@ from attrs import define, field +from griptape.mixins import SerializableMixin + if TYPE_CHECKING: from griptape.drivers import ( BaseAudioTranscriptionDriver, @@ -19,7 +21,7 @@ @define -class BaseDriverConfig(ABC): +class BaseDriverConfig(ABC, SerializableMixin): prompt: BasePromptDriver = field(kw_only=True, metadata={"serializable": True}) image_generation: BaseImageGenerationDriver = field(kw_only=True, metadata={"serializable": True}) image_query: BaseImageQueryDriver = field(kw_only=True, metadata={"serializable": True}) diff --git a/griptape/config/config.py b/griptape/config/config.py index f325d1265..71edef8e6 100644 --- a/griptape/config/config.py +++ b/griptape/config/config.py @@ -1,17 +1,14 @@ -from attrs import define - -from griptape.config.base_config import BaseConfig -from griptape.config.base_driver_config import BaseDriverConfig -from griptape.mixins.event_publisher_mixin import EventPublisherMixin +from attrs import Factory, define, field +from .base_config import BaseConfig +from .events_config import EventsConfig from .openai_driver_config import OpenAiDriverConfig @define -class _Config(BaseConfig, EventPublisherMixin): - drivers: BaseDriverConfig +class _Config(BaseConfig): + drivers: OpenAiDriverConfig = field(default=Factory(lambda: OpenAiDriverConfig()), kw_only=True) + events: EventsConfig = field(default=Factory(lambda: EventsConfig()), kw_only=True) -Config = _Config( - drivers=OpenAiDriverConfig(), -) +Config = _Config() From e80f78e3c0c95087e7a1dcec5917d387f19dcd7a Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 6 Aug 2024 13:22:14 -0700 Subject: [PATCH 11/63] Add back util fields for Agent --- griptape/structures/agent.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/griptape/structures/agent.py b/griptape/structures/agent.py index 31e0a424f..f31e9d2eb 100644 --- a/griptape/structures/agent.py +++ b/griptape/structures/agent.py @@ -2,16 +2,18 @@ from typing import TYPE_CHECKING, Callable, Optional -from attrs import Attribute, define, field +from attrs import Attribute, Factory, define, field from griptape.artifacts.text_artifact import TextArtifact from griptape.common import observable +from griptape.config import Config from griptape.memory.structure import Run from griptape.structures import Structure from griptape.tasks import PromptTask, ToolkitTask if TYPE_CHECKING: from griptape.artifacts import BaseArtifact + from griptape.drivers import BasePromptDriver from griptape.tasks import BaseTask from griptape.tools import BaseTool @@ -21,6 +23,8 @@ class Agent(Structure): input: str | list | tuple | BaseArtifact | Callable[[BaseTask], BaseArtifact] = field( default=lambda task: task.full_context["args"][0] if task.full_context["args"] else TextArtifact(value=""), ) + stream: bool = field(default=False, kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.drivers.prompt), kw_only=True) tools: list[BaseTool] = field(factory=list, kw_only=True) max_meta_memory_entries: Optional[int] = field(default=20, kw_only=True) fail_fast: bool = field(default=False, kw_only=True) @@ -33,11 +37,19 @@ def validate_fail_fast(self, _: Attribute, fail_fast: bool) -> None: # noqa: FB def __attrs_post_init__(self) -> None: super().__attrs_post_init__() + self.prompt_driver.stream = self.stream if len(self.tasks) == 0: if self.tools: - task = ToolkitTask(self.input, tools=self.tools, max_meta_memory_entries=self.max_meta_memory_entries) + task = ToolkitTask( + self.input, + prompt_driver=self.prompt_driver, + tools=self.tools, + max_meta_memory_entries=self.max_meta_memory_entries, + ) else: - task = PromptTask(self.input, max_meta_memory_entries=self.max_meta_memory_entries) + task = PromptTask( + self.input, prompt_driver=self.prompt_driver, max_meta_memory_entries=self.max_meta_memory_entries + ) self.add_task(task) From 91619a5b1031654bd596f66ac21ba495a366c095 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 6 Aug 2024 13:23:44 -0700 Subject: [PATCH 12/63] Revert changelog replaces --- CHANGELOG.md | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1b224bdd4..ea88983f3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -49,8 +49,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: `BaseVectorStoreDriver.load_artifacts` optional arguments are now keyword-only arguments. - **BREAKING**: `BaseVectorStoreDriver.upsert_vector` optional arguments are now keyword-only arguments. - **BREAKING**: `BaseVectorStoreDriver.query` optional arguments are now keyword-only arguments. -- **BREAKING**: `EventListener.events.publish_event`'s `flush` argument is now a keyword-only argument. -- **BREAKING**: `BaseEventListenerDriver.events.publish_event`'s `flush` argument is now a keyword-only argument. +- **BREAKING**: `EventListener.publish_event`'s `flush` argument is now a keyword-only argument. +- **BREAKING**: `BaseEventListenerDriver.publish_event`'s `flush` argument is now a keyword-only argument. - **BREAKING**: Renamed `DummyException` to `DummyError` for pep8 naming compliance. - **BREAKING**: Migrate to `sqlalchemy` 2.0. - **BREAKING**: Make `sqlalchemy` an optional dependency. @@ -112,7 +112,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `GoogleWebSearchDriver` to web search with the Google Customsearch API. - `DuckDuckGoWebSearchDriver` to web search with the DuckDuckGo search SDK. - `ProxyWebScraperDriver` to web scrape using proxies. -- Parameter `session` on `AmazonBedrockDriverConfig`. +- Parameter `session` on `AmazonBedrockStructureConfig`. - Parameter `meta` on `TextArtifact`. - `VectorStoreClient` improvements: - `VectorStoreClient.query_params` dict for custom query params. @@ -155,7 +155,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: removed `VectorStoreClient.top_n` and `VectorStoreClient.namespace` in favor of `VectorStoreClient.query_params`. - **BREAKING**: All `futures_executor` fields renamed to `futures_executor_fn` and now accept callables instead of futures; wrapped all future `submit` calls with the `with` block to address future executor shutdown issues. - `GriptapeCloudKnowledgeBaseClient` migrated to `/search` api. -- Default Prompt Driver model in `GoogleDriverConfig` to `gemini-1.5-pro`. +- Default Prompt Driver model in `GoogleStructureConfig` to `gemini-1.5-pro`. ### Fixed - `CoherePromptDriver` to properly handle empty history. @@ -175,7 +175,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - Base Tool schema so that `input` is optional when no Tool Activity schema is set. - Tool Task system prompt for better results with lower-end models. -- Default Prompt Driver model to Claude 3.5 Sonnet in `AnthropicDriverConfig` and `AmazonBedrockDriverConfig.` +- Default Prompt Driver model to Claude 3.5 Sonnet in `AnthropicStructureConfig` and `AmazonBedrockStructureConfig.` ## [0.27.0] - 2024-06-19 @@ -186,7 +186,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `BaseTask.add_parents()` to add multiple parent tasks to a child task. - `Structure.resolve_relationships()` to resolve asymmetrically defined parent/child relationships. In other words, if a parent declares a child, but the child does not declare the parent, the parent will automatically be added as a parent of the child when running this method. The method is invoked automatically by `Structure.before_run()`. - `CohereEmbeddingDriver` for using Cohere's embeddings API. -- `CohereDriverConfig` for providing Structures with quick Cohere configuration. +- `CohereStructureConfig` for providing Structures with quick Cohere configuration. - `AmazonSageMakerJumpstartPromptDriver.inference_component_name` for setting the `InferenceComponentName` parameter when invoking an endpoint. - `AmazonSageMakerJumpstartEmbeddingDriver.inference_component_name` for setting the `InferenceComponentName` parameter when invoking an endpoint. - `AmazonSageMakerJumpstartEmbeddingDriver.custom_attributes` for setting custom attributes when invoking an endpoint. @@ -252,7 +252,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [0.26.0] - 2024-06-04 ### Added -- `AzureOpenAiDriverConfig` for providing Structures with all Azure OpenAI Driver configuration. +- `AzureOpenAiStructureConfig` for providing Structures with all Azure OpenAI Driver configuration. - `AzureOpenAiVisionImageQueryDriver` to support queries on images using Azure's OpenAI Vision models. - `AudioLoader` for loading audio content into an `AudioArtifact`. - `AudioTranscriptionTask` and `AudioTranscriptionClient` for transcribing audio content in Structures. @@ -263,8 +263,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - **BREAKING**: Updated OpenAI-based image query drivers to remove Vision from the name. - **BREAKING**: `off_prompt` now defaults to `False` on all Tools, making Task Memory something that must be explicitly opted into. -- **BREAKING**: Removed `DriverConfig.global_drivers`. Pass Drivers directly to the Driver Config instead. -- **BREAKING**: Removed `DriverConfig.task_memory` in favor of configuring directly on the Structure. +- **BREAKING**: Removed `StructureConfig.global_drivers`. Pass Drivers directly to the Structure Config instead. +- **BREAKING**: Removed `StructureConfig.task_memory` in favor of configuring directly on the Structure. - **BREAKING**: Updated OpenAI-based image query drivers to remove Vision from the name. - **BREAKING**: `off_prompt` now defaults to `False` on all Tools, making Task Memory something that must be explicitly opted into. - **BREAKING**: `AmazonSageMakerPromptDriver.model` parameter, which gets passed to `SageMakerRuntime.Client.invoke_endpoint` as `EndpointName`, is now renamed to `AmazonSageMakerPromptDriver.endpoint`. @@ -293,7 +293,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - Default behavior of Event Listener Drivers to batch events. -- Default behavior of OpenAiDriverConfig to utilize `gpt-4o` for prompt_driver. +- Default behavior of OpenAiStructureConfig to utilize `gpt-4o` for prompt_driver. ## [0.25.0] - 2024-05-06 @@ -359,7 +359,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Support for `text-embedding-3-small` and `text-embedding-3-large` models. - `GooglePromptDriver` and `GoogleTokenizer` for use with `gemini-pro`. - `GoogleEmbeddingDriver` for use with `embedding-001`. -- `GoogleDriverConfig` for providing Structures with Google Prompt and Embedding Driver configuration. +- `GoogleStructureConfig` for providing Structures with Google Prompt and Embedding Driver configuration. - Support for `claude-3-opus`, `claude-3-sonnet`, and `claude-3-haiku` in `AnthropicPromptDriver`. - Support for `anthropic.claude-3-sonnet-20240229-v1:0` and `anthropic.claude-3-haiku-20240307-v1:0` in `BedrockClaudePromptModelDriver`. - `top_k` and `top_p` parameters in `AnthropicPromptDriver`. @@ -369,7 +369,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `TrafilaturaWebScraperDriver` for scraping text from web pages using trafilatura. - `MarkdownifyWebScraperDriver` for scraping text from web pages using playwright and converting to markdown using markdownify. - `VoyageAiEmbeddingDriver` for use with VoyageAi's embedding models. -- `AnthropicDriverConfig` for providing Structures with Anthropic Prompt and VoyageAi Embedding Driver configuration. +- `AnthropicStructureConfig` for providing Structures with Anthropic Prompt and VoyageAi Embedding Driver configuration. - `QdrantVectorStoreDriver` to integrate with Qdrant vector databases. ### Fixed @@ -380,9 +380,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: Removed `subtask_action_name`, `subtask_action_path`, and `subtask_action_input` in `BaseActionSubtaskEvent`. - **BREAKING**: `OpenAiVisionImageQueryDriver` field `model` no longer defaults to `gpt-4-vision-preview` and must be specified - Default model of `OpenAiEmbeddingDriver` to `text-embedding-3-small`. -- Default model of `OpenAiDriverConfig` to `text-embedding-3-small`. +- Default model of `OpenAiStructureConfig` to `text-embedding-3-small`. - `BaseTextLoader` to accept a `BaseChunker`. -- Default model of `AmazonBedrockDriverConfig` to `anthropic.claude-3-sonnet-20240229-v1:0`. +- Default model of `AmazonBedrockStructureConfig` to `anthropic.claude-3-sonnet-20240229-v1:0`. - `AnthropicPromptDriver` and `BedrockClaudePromptModelDriver` to use Anthropic's Messages API. - `OpenAiVisionImageQueryDriver` now has a required field `max_tokens` that defaults to 256 @@ -391,7 +391,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Deprecation warnings not displaying for `Structure.prompt_driver`, `Structure.embedding_driver`, and `Structure.stream`. - `DummyException` error message not fully displaying. -- `DriverConfig.task_memory` not defaulting to using `DriverConfig.global_drivers` by default. +- `StructureConfig.task_memory` not defaulting to using `StructureConfig.global_drivers` by default. ## [0.23.1] - 2024-03-07 @@ -408,9 +408,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `AzureMongoDbVectorStoreDriver` for using CosmosDB with MongoDB vCore API. - `vector_path` field on `MongoDbAtlasVectorStoreDriver`. - `LeonardoImageGenerationDriver` supports image to image generation. -- `OpenAiDriverConfig` for providing Structures with all OpenAi Driver configuration. -- `AmazonBedrockDriverConfig` for providing Structures with all Amazon Bedrock Driver configuration. -- `DriverConfig` for building your own Structure configuration. +- `OpenAiStructureConfig` for providing Structures with all OpenAi Driver configuration. +- `AmazonBedrockStructureConfig` for providing Structures with all Amazon Bedrock Driver configuration. +- `StructureConfig` for building your own Structure configuration. - `JsonExtractionTask` for convenience over using `ExtractionTask` with a `JsonExtractionEngine`. - `CsvExtractionTask` for convenience over using `ExtractionTask` with a `CsvExtractionEngine`. - `OpenAiVisionImageQueryDriver` to support queries on images using OpenAI's Vision model. From 9ebd44e51f321c69b1981b657f4bd433a038518c Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 6 Aug 2024 13:25:36 -0700 Subject: [PATCH 13/63] Fix type --- griptape/config/config.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/griptape/config/config.py b/griptape/config/config.py index 71edef8e6..3920bdbb0 100644 --- a/griptape/config/config.py +++ b/griptape/config/config.py @@ -1,13 +1,14 @@ from attrs import Factory, define, field from .base_config import BaseConfig +from .base_driver_config import BaseDriverConfig from .events_config import EventsConfig from .openai_driver_config import OpenAiDriverConfig @define class _Config(BaseConfig): - drivers: OpenAiDriverConfig = field(default=Factory(lambda: OpenAiDriverConfig()), kw_only=True) + drivers: BaseDriverConfig = field(default=Factory(lambda: OpenAiDriverConfig()), kw_only=True) events: EventsConfig = field(default=Factory(lambda: EventsConfig()), kw_only=True) From b7e1359d886a84ab8d47d80cae7380558d54e345 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 6 Aug 2024 14:58:53 -0700 Subject: [PATCH 14/63] Revert some of agent test --- tests/unit/structures/test_agent.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/tests/unit/structures/test_agent.py b/tests/unit/structures/test_agent.py index 15e1399b6..d82de015c 100644 --- a/tests/unit/structures/test_agent.py +++ b/tests/unit/structures/test_agent.py @@ -5,13 +5,16 @@ from griptape.rules import Rule, Ruleset from griptape.structures import Agent from griptape.tasks import BaseTask, PromptTask, ToolkitTask +from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_tool.tool import MockTool class TestAgent: def test_init(self): - agent = Agent(rulesets=[Ruleset("TestRuleset", [Rule("test")])]) + driver = MockPromptDriver() + agent = Agent(prompt_driver=driver, rulesets=[Ruleset("TestRuleset", [Rule("test")])]) + assert agent.prompt_driver is driver assert isinstance(agent.task, PromptTask) assert isinstance(agent.task, PromptTask) assert agent.rulesets[0].name == "TestRuleset" @@ -77,7 +80,7 @@ def test_without_default_task_memory(self): assert agent.tools[0].output_memory is None def test_with_memory(self): - agent = Agent(conversation_memory=ConversationMemory()) + agent = Agent(prompt_driver=MockPromptDriver(), conversation_memory=ConversationMemory()) assert agent.conversation_memory is not None assert len(agent.conversation_memory.runs) == 0 @@ -99,7 +102,7 @@ def test_tasks_initialization(self): assert agent.tasks[0] == task def test_add_task(self): - agent = Agent() + agent = Agent(prompt_driver=MockPromptDriver()) assert len(agent.tasks) == 1 @@ -127,7 +130,7 @@ def test_add_tasks(self): first_task = PromptTask("test1") second_task = PromptTask("test2") - agent = Agent() + agent = Agent(prompt_driver=MockPromptDriver()) try: agent.add_tasks(first_task, second_task) @@ -142,7 +145,7 @@ def test_add_tasks(self): assert True def test_prompt_stack_without_memory(self): - agent = Agent(conversation_memory=None, rules=[Rule("test")]) + agent = Agent(prompt_driver=MockPromptDriver(), conversation_memory=None, rules=[Rule("test")]) task1 = PromptTask("test") @@ -159,7 +162,7 @@ def test_prompt_stack_without_memory(self): assert len(task1.prompt_stack.messages) == 3 def test_prompt_stack_with_memory(self): - agent = Agent(conversation_memory=ConversationMemory(), rules=[Rule("test")]) + agent = Agent(prompt_driver=MockPromptDriver(), conversation_memory=ConversationMemory(), rules=[Rule("test")]) task1 = PromptTask("test") @@ -177,7 +180,7 @@ def test_prompt_stack_with_memory(self): def test_run(self): task = PromptTask("test") - agent = Agent() + agent = Agent(prompt_driver=MockPromptDriver()) agent.add_task(task) assert task.state == BaseTask.State.PENDING @@ -189,7 +192,7 @@ def test_run(self): def test_run_with_args(self): task = PromptTask("{{ args[0] }}-{{ args[1] }}") - agent = Agent() + agent = Agent(prompt_driver=MockPromptDriver()) agent.add_task(task) agent._execution_args = ("test1", "test2") @@ -202,7 +205,7 @@ def test_run_with_args(self): def test_context(self): task = PromptTask("test prompt") - agent = Agent() + agent = Agent(prompt_driver=MockPromptDriver()) agent.add_task(task) @@ -214,7 +217,7 @@ def test_context(self): def finished_tasks(self): task = PromptTask("test prompt") - agent = Agent() + agent = Agent(prompt_driver=MockPromptDriver()) agent.add_task(task) @@ -224,4 +227,4 @@ def finished_tasks(self): def test_fail_fast(self): with pytest.raises(ValueError): - Agent(fail_fast=True) + Agent(prompt_driver=MockPromptDriver(), fail_fast=True) From 97564a22d03d550373d2b02802e346979eaf1f5e Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 7 Aug 2024 11:42:27 -0700 Subject: [PATCH 15/63] Add logging module to config --- griptape/config/config.py | 2 ++ griptape/config/logging_config.py | 24 ++++++++++++++++++++ griptape/structures/structure.py | 22 ------------------ griptape/tasks/actions_subtask.py | 17 ++++++++------ griptape/tasks/base_audio_generation_task.py | 8 +++++-- griptape/tasks/base_audio_input_task.py | 8 +++++-- griptape/tasks/base_image_generation_task.py | 7 +++++- griptape/tasks/base_multi_text_input_task.py | 8 +++++-- griptape/tasks/base_task.py | 5 +++- griptape/tasks/base_text_input_task.py | 8 +++++-- griptape/tasks/prompt_task.py | 7 ++++-- 11 files changed, 75 insertions(+), 41 deletions(-) create mode 100644 griptape/config/logging_config.py diff --git a/griptape/config/config.py b/griptape/config/config.py index 3920bdbb0..8d29b5a0f 100644 --- a/griptape/config/config.py +++ b/griptape/config/config.py @@ -3,6 +3,7 @@ from .base_config import BaseConfig from .base_driver_config import BaseDriverConfig from .events_config import EventsConfig +from .logging_config import LoggingConfig from .openai_driver_config import OpenAiDriverConfig @@ -10,6 +11,7 @@ class _Config(BaseConfig): drivers: BaseDriverConfig = field(default=Factory(lambda: OpenAiDriverConfig()), kw_only=True) events: EventsConfig = field(default=Factory(lambda: EventsConfig()), kw_only=True) + logging: LoggingConfig = field(default=Factory(lambda: LoggingConfig()), kw_only=True) Config = _Config() diff --git a/griptape/config/logging_config.py b/griptape/config/logging_config.py new file mode 100644 index 000000000..0c0fcc020 --- /dev/null +++ b/griptape/config/logging_config.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +import logging + +from attrs import define, field +from rich.logging import RichHandler + + +@define +class LoggingConfig: + logger_name: str = field(default="griptape", kw_only=True) + logger_level: int = field( + default=logging.INFO, + kw_only=True, + on_setattr=lambda self, _, value: logging.getLogger(self.logger_name).setLevel(value), + ) + + def __attrs_post_init__(self) -> None: + logger = logging.getLogger(self.logger_name) + + logger.propagate = False + logger.setLevel(self.logger_level) + + logger.handlers = [RichHandler(show_time=True, show_path=False)] diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index 4db8228e3..49197592f 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -1,13 +1,10 @@ from __future__ import annotations -import logging import uuid from abc import ABC, abstractmethod -from logging import Logger from typing import TYPE_CHECKING, Any, Optional from attrs import Attribute, Factory, define, field -from rich.logging import RichHandler from griptape.artifacts import BaseArtifact, BlobArtifact, TextArtifact from griptape.common import observable @@ -35,14 +32,10 @@ @define class Structure(ABC): - LOGGER_NAME = "griptape" - id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True) rulesets: list[Ruleset] = field(factory=list, kw_only=True) rules: list[Rule] = field(factory=list, kw_only=True) tasks: list[BaseTask] = field(factory=list, kw_only=True) - custom_logger: Optional[Logger] = field(default=None, kw_only=True) - logger_level: int = field(default=logging.INFO, kw_only=True) conversation_memory: Optional[BaseConversationMemory] = field( default=Factory(lambda: ConversationMemory()), kw_only=True, @@ -55,7 +48,6 @@ class Structure(ABC): meta_memory: MetaMemory = field(default=Factory(lambda: MetaMemory()), kw_only=True) fail_fast: bool = field(default=True, kw_only=True) _execution_args: tuple = () - _logger: Optional[Logger] = None @rulesets.validator # pyright: ignore[reportAttributeAccessIssue] def validate_rulesets(self, _: Attribute, rulesets: list[Ruleset]) -> None: @@ -88,20 +80,6 @@ def __add__(self, other: BaseTask | list[BaseTask]) -> list[BaseTask]: def execution_args(self) -> tuple: return self._execution_args - @property - def logger(self) -> Logger: - if self.custom_logger: - return self.custom_logger - else: - if self._logger is None: - self._logger = logging.getLogger(self.LOGGER_NAME) - - self._logger.propagate = False - self._logger.level = self.logger_level - - self._logger.handlers = [RichHandler(show_time=True, show_path=False)] - return self._logger - @property def input_task(self) -> Optional[BaseTask]: return self.tasks[0] if self.tasks else None diff --git a/griptape/tasks/actions_subtask.py b/griptape/tasks/actions_subtask.py index 07f49f52a..e3c2aeb12 100644 --- a/griptape/tasks/actions_subtask.py +++ b/griptape/tasks/actions_subtask.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +import logging import re from typing import TYPE_CHECKING, Callable, Optional @@ -18,6 +19,8 @@ if TYPE_CHECKING: from griptape.memory import TaskMemory +logger = logging.getLogger(Config.logging.logger_name) + @define class ActionsSubtask(BaseTask): @@ -86,7 +89,7 @@ def attach_to(self, parent_task: BaseTask) -> None: else: self.__init_from_artifacts(self.input) except Exception as e: - self.structure.logger.error("Subtask %s\nError parsing tool action: %s", self.origin_task.id, e) + logger.error("Subtask %s\nError parsing tool action: %s", self.origin_task.id, e) self.output = ErrorArtifact(f"ToolAction input parsing error: {e}", exception=e) @@ -109,7 +112,7 @@ def before_run(self) -> None: *([f"\nThought: {self.thought}"] if self.thought is not None else []), f"\nActions: {self.actions_to_json()}", ] - self.structure.logger.info("".join(parts)) + logger.info("".join(parts)) def run(self) -> BaseArtifact: try: @@ -128,7 +131,7 @@ def run(self) -> BaseArtifact: actions_output.append(output) self.output = ListArtifact(actions_output) except Exception as e: - self.structure.logger.exception("Subtask %s\n%s", self.id, e) + logger.exception("Subtask %s\n%s", self.id, e) self.output = ErrorArtifact(str(e), exception=e) if self.output is not None: @@ -169,7 +172,7 @@ def after_run(self) -> None: subtask_actions=self.actions_to_dicts(), ), ) - self.structure.logger.info("Subtask %s\nResponse: %s", self.id, response) + logger.info("Subtask %s\nResponse: %s", self.id, response) def actions_to_dicts(self) -> list[dict]: json_list = [] @@ -257,7 +260,7 @@ def __parse_actions(self, actions_matches: list[str]) -> None: self.actions = [self.__process_action_object(action_object) for action_object in actions_list] except json.JSONDecodeError as e: - self.structure.logger.exception("Subtask %s\nInvalid actions JSON: %s", self.origin_task.id, e) + logger.exception("Subtask %s\nInvalid actions JSON: %s", self.origin_task.id, e) self.output = ErrorArtifact(f"Actions JSON decoding error: {e}", exception=e) @@ -314,10 +317,10 @@ def __validate_action(self, action: ToolAction) -> None: if activity_schema: activity_schema.validate(action.input) except schema.SchemaError as e: - self.structure.logger.exception("Subtask %s\nInvalid action JSON: %s", self.origin_task.id, e) + logger.exception("Subtask %s\nInvalid action JSON: %s", self.origin_task.id, e) action.output = ErrorArtifact(f"Activity input JSON validation error: {e}", exception=e) except SyntaxError as e: - self.structure.logger.exception("Subtask %s\nSyntax error: %s", self.origin_task.id, e) + logger.exception("Subtask %s\nSyntax error: %s", self.origin_task.id, e) action.output = ErrorArtifact(f"Syntax error: {e}", exception=e) diff --git a/griptape/tasks/base_audio_generation_task.py b/griptape/tasks/base_audio_generation_task.py index d2657561d..4d9d82362 100644 --- a/griptape/tasks/base_audio_generation_task.py +++ b/griptape/tasks/base_audio_generation_task.py @@ -1,21 +1,25 @@ from __future__ import annotations +import logging from abc import ABC from attrs import define +from griptape.config import Config from griptape.mixins import BlobArtifactFileOutputMixin, RuleMixin from griptape.tasks import BaseTask +logger = logging.getLogger(Config.logging.logger_name) + @define class BaseAudioGenerationTask(BlobArtifactFileOutputMixin, RuleMixin, BaseTask, ABC): def before_run(self) -> None: super().before_run() - self.structure.logger.info("%s %s\nInput: %s", self.__class__.__name__, self.id, self.input.to_text()) + logger.info("%s %s\nInput: %s", self.__class__.__name__, self.id, self.input.to_text()) def after_run(self) -> None: super().after_run() - self.structure.logger.info("%s %s\nOutput: %s", self.__class__.__name__, self.id, self.output.to_text()) + logger.info("%s %s\nOutput: %s", self.__class__.__name__, self.id, self.output.to_text()) diff --git a/griptape/tasks/base_audio_input_task.py b/griptape/tasks/base_audio_input_task.py index 517c03a15..febd3f508 100644 --- a/griptape/tasks/base_audio_input_task.py +++ b/griptape/tasks/base_audio_input_task.py @@ -1,14 +1,18 @@ from __future__ import annotations +import logging from abc import ABC from typing import Callable from attrs import define, field from griptape.artifacts.audio_artifact import AudioArtifact +from griptape.config.config import Config from griptape.mixins import RuleMixin from griptape.tasks import BaseTask +logger = logging.getLogger(Config.logging.logger_name) + @define class BaseAudioInputTask(RuleMixin, BaseTask, ABC): @@ -30,9 +34,9 @@ def input(self, value: AudioArtifact | Callable[[BaseTask], AudioArtifact]) -> N def before_run(self) -> None: super().before_run() - self.structure.logger.info("%s %s\nInput: %s", self.__class__.__name__, self.id, self.input.to_text()) + logger.info("%s %s\nInput: %s", self.__class__.__name__, self.id, self.input.to_text()) def after_run(self) -> None: super().after_run() - self.structure.logger.info("%s %s\nOutput: %s", self.__class__.__name__, self.id, self.output.to_text()) + logger.info("%s %s\nOutput: %s", self.__class__.__name__, self.id, self.output.to_text()) diff --git a/griptape/tasks/base_image_generation_task.py b/griptape/tasks/base_image_generation_task.py index d32e8f142..afbc2c05e 100644 --- a/griptape/tasks/base_image_generation_task.py +++ b/griptape/tasks/base_image_generation_task.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging import os from abc import ABC from pathlib import Path @@ -7,6 +8,7 @@ from attrs import Attribute, define, field +from griptape.config import Config from griptape.loaders import ImageLoader from griptape.mixins import BlobArtifactFileOutputMixin, RuleMixin from griptape.rules import Rule, Ruleset @@ -16,6 +18,9 @@ from griptape.artifacts import MediaArtifact +logger = logging.getLogger(Config.logging.logger_name) + + @define class BaseImageGenerationTask(BlobArtifactFileOutputMixin, RuleMixin, BaseTask, ABC): """Provides a base class for image generation-related tasks. @@ -60,5 +65,5 @@ def all_negative_rulesets(self) -> list[Ruleset]: return task_rulesets def _read_from_file(self, path: str) -> MediaArtifact: - self.structure.logger.info("Reading image from %s", os.path.abspath(path)) + logger.info("Reading image from %s", os.path.abspath(path)) return ImageLoader().load(Path(path).read_bytes()) diff --git a/griptape/tasks/base_multi_text_input_task.py b/griptape/tasks/base_multi_text_input_task.py index a0d8cb9ac..6962098ca 100644 --- a/griptape/tasks/base_multi_text_input_task.py +++ b/griptape/tasks/base_multi_text_input_task.py @@ -1,15 +1,19 @@ from __future__ import annotations +import logging from abc import ABC from typing import Callable from attrs import Factory, define, field from griptape.artifacts import ListArtifact, TextArtifact +from griptape.config import Config from griptape.mixins.rule_mixin import RuleMixin from griptape.tasks import BaseTask from griptape.utils import J2 +logger = logging.getLogger(Config.logging.logger_name) + @define class BaseMultiTextInputTask(RuleMixin, BaseTask, ABC): @@ -48,9 +52,9 @@ def before_run(self) -> None: super().before_run() joined_input = "\n".join([i.to_text() for i in self.input]) - self.structure.logger.info("%s %s\nInput: %s", self.__class__.__name__, self.id, joined_input) + logger.info("%s %s\nInput: %s", self.__class__.__name__, self.id, joined_input) def after_run(self) -> None: super().after_run() - self.structure.logger.info("%s %s\nOutput: %s", self.__class__.__name__, self.id, self.output.to_text()) + logger.info("%s %s\nOutput: %s", self.__class__.__name__, self.id, self.output.to_text()) diff --git a/griptape/tasks/base_task.py b/griptape/tasks/base_task.py index 9a8361e6c..cdaf1b032 100644 --- a/griptape/tasks/base_task.py +++ b/griptape/tasks/base_task.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging import uuid from abc import ABC, abstractmethod from concurrent import futures @@ -16,6 +17,8 @@ from griptape.memory.meta import BaseMetaEntry from griptape.structures import Structure +logger = logging.getLogger(Config.logging.logger_name) + @define class BaseTask(ABC): @@ -159,7 +162,7 @@ def execute(self) -> Optional[BaseArtifact]: self.after_run() except Exception as e: - self.structure.logger.exception("%s %s\n%s", self.__class__.__name__, self.id, e) + logger.exception("%s %s\n%s", self.__class__.__name__, self.id, e) self.output = ErrorArtifact(str(e), exception=e) finally: diff --git a/griptape/tasks/base_text_input_task.py b/griptape/tasks/base_text_input_task.py index 90f60efcd..16f8c705c 100644 --- a/griptape/tasks/base_text_input_task.py +++ b/griptape/tasks/base_text_input_task.py @@ -1,15 +1,19 @@ from __future__ import annotations +import logging from abc import ABC from typing import Callable from attrs import define, field from griptape.artifacts import TextArtifact +from griptape.config import Config from griptape.mixins.rule_mixin import RuleMixin from griptape.tasks import BaseTask from griptape.utils import J2 +logger = logging.getLogger(Config.logging.logger_name) + @define class BaseTextInputTask(RuleMixin, BaseTask, ABC): @@ -36,9 +40,9 @@ def input(self, value: str | TextArtifact | Callable[[BaseTask], TextArtifact]) def before_run(self) -> None: super().before_run() - self.structure.logger.info("%s %s\nInput: %s", self.__class__.__name__, self.id, self.input.to_text()) + logger.info("%s %s\nInput: %s", self.__class__.__name__, self.id, self.input.to_text()) def after_run(self) -> None: super().after_run() - self.structure.logger.info("%s %s\nOutput: %s", self.__class__.__name__, self.id, self.output.to_text()) + logger.info("%s %s\nOutput: %s", self.__class__.__name__, self.id, self.output.to_text()) diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 19580b642..3769f26dc 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging from typing import TYPE_CHECKING, Callable, Optional from attrs import Factory, define, field @@ -14,6 +15,8 @@ if TYPE_CHECKING: from griptape.drivers import BasePromptDriver +logger = logging.getLogger(Config.logging.logger_name) + @define class PromptTask(RuleMixin, BaseTask): @@ -65,12 +68,12 @@ def default_system_template_generator(self, _: PromptTask) -> str: def before_run(self) -> None: super().before_run() - self.structure.logger.info("%s %s\nInput: %s", self.__class__.__name__, self.id, self.input.to_text()) + logger.info("%s %s\nInput: %s", self.__class__.__name__, self.id, self.input.to_text()) def after_run(self) -> None: super().after_run() - self.structure.logger.info("%s %s\nOutput: %s", self.__class__.__name__, self.id, self.output.to_text()) + logger.info("%s %s\nOutput: %s", self.__class__.__name__, self.id, self.output.to_text()) def run(self) -> BaseArtifact: message = self.prompt_driver.run(self.prompt_stack) From 4220e0f1392f33463f1f43c82b2ea9f999f718e8 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 7 Aug 2024 16:56:48 -0700 Subject: [PATCH 16/63] Fix bad rebase --- griptape/config/base_config.py | 9 +++++---- griptape/config/config.py | 2 -- griptape/drivers/prompt/base_prompt_driver.py | 2 -- griptape/tasks/actions_subtask.py | 1 + griptape/tasks/base_task.py | 1 + griptape/utils/stream.py | 2 -- tests/unit/drivers/prompt/test_base_prompt_driver.py | 3 +-- tests/unit/tasks/test_base_task.py | 2 +- 8 files changed, 9 insertions(+), 13 deletions(-) diff --git a/griptape/config/base_config.py b/griptape/config/base_config.py index a3f132ea9..9209aa4a4 100644 --- a/griptape/config/base_config.py +++ b/griptape/config/base_config.py @@ -2,12 +2,13 @@ from attrs import define -from griptape.config.base_driver_config import BaseDriverConfig -from griptape.config.events_config import EventsConfig from griptape.mixins.serializable_mixin import SerializableMixin +from .base_driver_config import BaseDriverConfig +from .logging_config import LoggingConfig -@define + +@define(kw_only=True) class BaseConfig(SerializableMixin, ABC): drivers: BaseDriverConfig - events: EventsConfig + logging: LoggingConfig diff --git a/griptape/config/config.py b/griptape/config/config.py index 8d29b5a0f..d81a8974b 100644 --- a/griptape/config/config.py +++ b/griptape/config/config.py @@ -2,7 +2,6 @@ from .base_config import BaseConfig from .base_driver_config import BaseDriverConfig -from .events_config import EventsConfig from .logging_config import LoggingConfig from .openai_driver_config import OpenAiDriverConfig @@ -10,7 +9,6 @@ @define class _Config(BaseConfig): drivers: BaseDriverConfig = field(default=Factory(lambda: OpenAiDriverConfig()), kw_only=True) - events: EventsConfig = field(default=Factory(lambda: EventsConfig()), kw_only=True) logging: LoggingConfig = field(default=Factory(lambda: LoggingConfig()), kw_only=True) diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index b6c28560b..94e46e75d 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -113,8 +113,6 @@ def __process_run(self, prompt_stack: PromptStack) -> Message: return result def __process_stream(self, prompt_stack: PromptStack) -> Message: - from griptape.config import Config - delta_contents: dict[int, list[BaseDeltaMessageContent]] = {} usage = DeltaMessage.Usage() diff --git a/griptape/tasks/actions_subtask.py b/griptape/tasks/actions_subtask.py index e3c2aeb12..2f199e368 100644 --- a/griptape/tasks/actions_subtask.py +++ b/griptape/tasks/actions_subtask.py @@ -11,6 +11,7 @@ from griptape import utils from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact from griptape.common import ToolAction +from griptape.config import Config from griptape.events import EventBus, FinishActionsSubtaskEvent, StartActionsSubtaskEvent from griptape.mixins import ActionsSubtaskOriginMixin from griptape.tasks import BaseTask diff --git a/griptape/tasks/base_task.py b/griptape/tasks/base_task.py index cdaf1b032..c42f73629 100644 --- a/griptape/tasks/base_task.py +++ b/griptape/tasks/base_task.py @@ -10,6 +10,7 @@ from attrs import Factory, define, field from griptape.artifacts import ErrorArtifact +from griptape.config import Config from griptape.events import EventBus, FinishTaskEvent, StartTaskEvent if TYPE_CHECKING: diff --git a/griptape/utils/stream.py b/griptape/utils/stream.py index cb5266378..87cb9dec8 100644 --- a/griptape/utils/stream.py +++ b/griptape/utils/stream.py @@ -56,8 +56,6 @@ def run(self, *args) -> Iterator[TextArtifact]: t.join() def _run_structure(self, *args) -> None: - from griptape.config import Config - def event_handler(event: BaseEvent) -> None: self._event_queue.put(event) diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index 248c259e5..c30acdec4 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -28,8 +28,7 @@ def test_run_via_pipeline_retries_failure(self, mock_config): def test_run_via_pipeline_publishes_events(self, mocker): mock_publish_event = mocker.patch.object(_EventBus, "publish_event") - driver = MockPromptDriver() - pipeline = Pipeline(prompt_driver=driver) + pipeline = Pipeline() pipeline.add_task(PromptTask("test")) pipeline.run() diff --git a/tests/unit/tasks/test_base_task.py b/tests/unit/tasks/test_base_task.py index d22ef35f7..d4e0ce23d 100644 --- a/tests/unit/tasks/test_base_task.py +++ b/tests/unit/tasks/test_base_task.py @@ -18,7 +18,7 @@ def task(self): agent = Agent( tools=[MockTool()], ) - Config.event_listeners = [EventListener(handler=Mock())] + EventBus.event_listeners = [EventListener(handler=Mock())] agent.add_task(MockTask("foobar", max_meta_memory_entries=2)) From f8a616fcc93d35c4e9c10d5e8131bd64d57fb68c Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 8 Aug 2024 09:47:58 -0700 Subject: [PATCH 17/63] Update docs --- docs/examples/talk-to-a-video.md | 5 +- .../drivers/embedding-drivers.md | 11 +- .../drivers/event-listener-drivers.md | 13 +- docs/griptape-framework/structures/config.md | 123 ++++++++---------- .../structures/task-memory.md | 10 +- .../official-tools/rest-api-client.md | 15 ++- 6 files changed, 83 insertions(+), 94 deletions(-) diff --git a/docs/examples/talk-to-a-video.md b/docs/examples/talk-to-a-video.md index 310b6d407..cf41dea0f 100644 --- a/docs/examples/talk-to-a-video.md +++ b/docs/examples/talk-to-a-video.md @@ -7,9 +7,11 @@ import time from griptape.structures import Agent from griptape.tasks import PromptTask from griptape.artifacts import GenericArtifact, TextArtifact -from griptape.config import GoogleDriverConfig +from griptape.config import Config import google.generativeai as genai +Config.drivers = GoogleDriverConfig() + video_file = genai.upload_file(path="tests/resources/griptape-comfyui.mp4") while video_file.state.name == "PROCESSING": time.sleep(2) @@ -19,7 +21,6 @@ if video_file.state.name == "FAILED": raise ValueError(video_file.state.name) agent = Agent( - config=GoogleDriverConfig(), input=[ GenericArtifact(video_file), TextArtifact("Answer this question regarding the video: {{ args[0] }}"), diff --git a/docs/griptape-framework/drivers/embedding-drivers.md b/docs/griptape-framework/drivers/embedding-drivers.md index 3c81cf8a9..7a8fd96a1 100644 --- a/docs/griptape-framework/drivers/embedding-drivers.md +++ b/docs/griptape-framework/drivers/embedding-drivers.md @@ -220,14 +220,15 @@ from griptape.drivers import ( OpenAiChatPromptDriver, VoyageAiEmbeddingDriver, ) -from griptape.config import DriverConfig +from griptape.config import DriverConfig, Config -agent = Agent( - tools=[WebScraper(off_prompt=True), TaskMemoryClient(off_prompt=False)], - config=DriverConfig( +Config.drivers = DriverConfig( prompt=OpenAiChatPromptDriver(model="gpt-4o"), embedding=VoyageAiEmbeddingDriver(), - ), +) + +agent = Agent( + tools=[WebScraper(off_prompt=True), TaskMemoryClient(off_prompt=False)], ) agent.run("based on https://www.griptape.ai/, tell me what Griptape is") diff --git a/docs/griptape-framework/drivers/event-listener-drivers.md b/docs/griptape-framework/drivers/event-listener-drivers.md index 8d4f521aa..20ae045f4 100644 --- a/docs/griptape-framework/drivers/event-listener-drivers.md +++ b/docs/griptape-framework/drivers/event-listener-drivers.md @@ -123,7 +123,7 @@ The [AwsIotCoreEventListenerDriver](../../reference/griptape/drivers/event_liste ```python import os -from griptape.config import DriverConfig +from griptape.config import DriverConfig, Config from griptape.drivers import AwsIotCoreEventListenerDriver, OpenAiChatPromptDriver from griptape.events import ( EventListener, @@ -132,17 +132,18 @@ from griptape.events import ( from griptape.rules import Rule from griptape.structures import Agent +Config.drivers = DriverConfig( + prompt=OpenAiChatPromptDriver( + model="gpt-3.5-turbo", temperature=0.7 + ) +) + agent = Agent( rules=[ Rule( value="You will be provided with a text, and your task is to extract the airport codes from it." ) ], - config=DriverConfig( - prompt=OpenAiChatPromptDriver( - model="gpt-3.5-turbo", temperature=0.7 - ) - ), event_listeners=[ EventListener( event_types=[FinishStructureRunEvent], diff --git a/docs/griptape-framework/structures/config.md b/docs/griptape-framework/structures/config.md index 13b6c001a..b4c928ff7 100644 --- a/docs/griptape-framework/structures/config.md +++ b/docs/griptape-framework/structures/config.md @@ -17,13 +17,11 @@ The [OpenAI Driver Config](../../reference/griptape/config/openai_driver_config. ```python from griptape.structures import Agent -from griptape.config import OpenAiDriverConfig +from griptape.config import OpenAiDriverConfig, Config -agent = Agent( - config=OpenAiDriverConfig() -) +Config.drivers = OpenAiDriverConfig() -agent = Agent() # This is equivalent to the above +agent = Agent() ``` #### Azure OpenAI @@ -33,18 +31,14 @@ The [Azure OpenAI Driver Config](../../reference/griptape/config/azure_openai_dr ```python import os from griptape.structures import Agent -from griptape.config import AzureOpenAiDriverConfig - -agent = Agent( - config=AzureOpenAiDriverConfig( - azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_3"], - api_key=os.environ["AZURE_OPENAI_API_KEY_3"] - ).merge_config({ - "image_query": { - "azure_deployment": "gpt-4o", - }, - }), +from griptape.config import AzureOpenAiDriverConfig, Config + +Config.drivers = AzureOpenAiDriverConfig( + azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_3"], + api_key=os.environ["AZURE_OPENAI_API_KEY_3"] ) + +agent = Agent() ``` #### Amazon Bedrock @@ -54,17 +48,17 @@ The [Amazon Bedrock Driver Config](../../reference/griptape/config/amazon_bedroc import os import boto3 from griptape.structures import Agent -from griptape.config import AmazonBedrockDriverConfig - -agent = Agent( - config=AmazonBedrockDriverConfig( - session=boto3.Session( - region_name=os.environ["AWS_DEFAULT_REGION"], - aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"], - aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"], - ) +from griptape.config import AmazonBedrockDriverConfig, Config + +Config.drivers = AmazonBedrockDriverConfig( + session=boto3.Session( + region_name=os.environ["AWS_DEFAULT_REGION"], + aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"], + aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"], ) ) + +agent = Agent() ``` #### Google @@ -72,11 +66,11 @@ The [Google Driver Config](../../reference/griptape/config/google_driver_config. ```python from griptape.structures import Agent -from griptape.config import GoogleDriverConfig +from griptape.config import GoogleDriverConfig, Config -agent = Agent( - config=GoogleDriverConfig() -) +Config.drivers = GoogleDriverConfig() + +agent = Agent() ``` #### Anthropic @@ -90,11 +84,11 @@ The [Anthropic Driver Config](../../reference/griptape/config/anthropic_driver_c ```python from griptape.structures import Agent -from griptape.config import AnthropicDriverConfig +from griptape.config import AnthropicDriverConfig, Config -agent = Agent( - config=AnthropicDriverConfig() -) +Config.drivers = AnthropicDriverConfig() + +agent = Agent() ``` #### Cohere @@ -103,10 +97,12 @@ The [Cohere Driver Config](../../reference/griptape/config/cohere_driver_config. ```python import os -from griptape.config import CohereDriverConfig +from griptape.config import CohereDriverConfig, Config from griptape.structures import Agent -agent = Agent(config=CohereDriverConfig(api_key=os.environ["COHERE_API_KEY"])) +Config.drivers = CohereDriverConfig(api_key=os.environ["COHERE_API_KEY"]) + +agent = Agent() ``` ### Custom Configs @@ -118,49 +114,38 @@ This approach ensures that you are informed through clear error messages if you ```python import os from griptape.structures import Agent -from griptape.config import DriverConfig +from griptape.config import DriverConfig, Config from griptape.drivers import AnthropicPromptDriver -agent = Agent( - config=DriverConfig( - prompt=AnthropicPromptDriver( - model="claude-3-sonnet-20240229", - api_key=os.environ["ANTHROPIC_API_KEY"], - ) - ), +Config.drivers = DriverConfig( + prompt=AnthropicPromptDriver( + model="claude-3-sonnet-20240229", + api_key=os.environ["ANTHROPIC_API_KEY"], + ) ) + + +agent = Agent() ``` ### Loading/Saving Configs -Configuration classes in Griptape offer utility methods for loading, saving, and merging configurations, streamlining the management of complex setups. - ```python from griptape.structures import Agent -from griptape.config import AmazonBedrockDriverConfig -from griptape.drivers import AmazonBedrockCohereEmbeddingDriver +from griptape.config import AmazonBedrockDriverConfig, Config custom_config = AmazonBedrockDriverConfig() -custom_config.embedding_driver = AmazonBedrockCohereEmbeddingDriver() -custom_config.merge_config( - { - "embedding": { - "base_url": None, - "model": "text-embedding-3-small", - "organization": None, - "type": "OpenAiEmbeddingDriver", - }, - } -) -serialized_config = custom_config.to_json() -deserialized_config = AmazonBedrockDriverConfig.from_json(serialized_config) - -agent = Agent( - config=deserialized_config.merge_config({ - "prompt": { - "model": "anthropic.claude-3-sonnet-20240229-v1:0", - }, - }), -) +dict_config = custom_config.to_dict() +# Use OpenAi for embeddings +dict_config["embedding"] = { + "base_url": None, + "model": "text-embedding-3-small", + "organization": None, + "type": "OpenAiEmbeddingDriver", +} +custom_config = AmazonBedrockDriverConfig.from_dict(dict_config) + +Config.drivers = custom_config + +agent = Agent() ``` - diff --git a/docs/griptape-framework/structures/task-memory.md b/docs/griptape-framework/structures/task-memory.md index 3184c4096..49d6b28cf 100644 --- a/docs/griptape-framework/structures/task-memory.md +++ b/docs/griptape-framework/structures/task-memory.md @@ -206,7 +206,7 @@ In this example, GPT-4 _never_ sees the contents of the page, only that it was s ```python from griptape.artifacts import TextArtifact from griptape.config import ( - OpenAiDriverConfig, + Config, OpenAiDriverConfig, ) from griptape.drivers import ( LocalVectorStoreDriver, @@ -220,12 +220,13 @@ from griptape.memory.task.storage import TextArtifactStorage from griptape.structures import Agent from griptape.tools import FileManager, TaskMemoryClient, WebScraper +Config.drivers = OpenAiDriverConfig( + prompt_driver=OpenAiChatPromptDriver(model="gpt-4"), +) + vector_store_driver = LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver()) agent = Agent( - config=OpenAiDriverConfig( - prompt_driver=OpenAiChatPromptDriver(model="gpt-4"), - ), task_memory=TaskMemory( artifact_storages={ TextArtifact: TextArtifactStorage( @@ -233,7 +234,6 @@ agent = Agent( retrieval_stage=RetrievalRagStage( retrieval_modules=[ VectorStoreRetrievalRagModule( - vector_store_driver=vector_store_driver, query_params={ "namespace": "griptape", diff --git a/docs/griptape-tools/official-tools/rest-api-client.md b/docs/griptape-tools/official-tools/rest-api-client.md index 675f77b6e..a73f6fa57 100644 --- a/docs/griptape-tools/official-tools/rest-api-client.md +++ b/docs/griptape-tools/official-tools/rest-api-client.md @@ -14,7 +14,14 @@ from griptape.memory.structure import ConversationMemory from griptape.structures import Pipeline from griptape.tasks import ToolkitTask from griptape.tools import RestApiClient -from griptape.config import DriverConfig +from griptape.config import Config + +Config.drivers = DriverConfig( + prompt=OpenAiChatPromptDriver( + model="gpt-4o", + temperature=0.1 + ), +) posts_client = RestApiClient( base_url="https://jsonplaceholder.typicode.com", @@ -117,12 +124,6 @@ posts_client = RestApiClient( pipeline = Pipeline( conversation_memory=ConversationMemory(), - config=DriverConfig( - prompt=OpenAiChatPromptDriver( - model="gpt-4o", - temperature=0.1 - ), - ), ) pipeline.add_tasks( From 95d83b5634409b95a76480480213aa23b6a5451f Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 7 Aug 2024 16:23:25 -0700 Subject: [PATCH 18/63] Add global event bus --- CHANGELOG.md | 3 + docs/griptape-framework/misc/events.md | 61 ++++++++++--------- griptape/config/base_structure_config.py | 40 ------------ .../base_audio_transcription_driver.py | 10 +-- .../embedding/base_embedding_driver.py | 4 +- .../base_image_generation_driver.py | 10 +-- .../image_query/base_image_query_driver.py | 10 +-- .../base_conversation_memory_driver.py | 4 +- griptape/drivers/prompt/base_prompt_driver.py | 16 ++--- .../base_text_to_speech_driver.py | 9 +-- .../vector/base_vector_store_driver.py | 4 +- griptape/events/__init__.py | 2 + .../event_bus.py} | 5 +- griptape/mixins/__init__.py | 2 - griptape/structures/structure.py | 12 ++-- griptape/tasks/actions_subtask.py | 6 +- griptape/tasks/base_task.py | 6 +- griptape/utils/stream.py | 9 +-- tests/unit/config/test_structure_config.py | 35 ----------- tests/unit/conftest.py | 12 ++++ .../test_base_audio_transcription_driver.py | 4 +- .../test_base_image_generation_driver.py | 9 +-- .../test_base_image_query_driver.py | 4 +- .../drivers/prompt/test_base_prompt_driver.py | 7 +-- .../test_base_audio_transcription_driver.py | 4 +- tests/unit/events/test_event_bus.py | 45 ++++++++++++++ tests/unit/events/test_event_listener.py | 29 ++++----- tests/unit/mixins/test_events_mixin.py | 59 ------------------ tests/unit/tasks/test_base_task.py | 5 +- 29 files changed, 176 insertions(+), 250 deletions(-) rename griptape/{mixins/event_publisher_mixin.py => events/event_bus.py} (96%) create mode 100644 tests/unit/conftest.py create mode 100644 tests/unit/events/test_event_bus.py delete mode 100644 tests/unit/mixins/test_events_mixin.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 1d8bf2e72..6748299d0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,8 +11,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Ability to set custom schema properties on Tool Activities via `extra_schema_properties`. - Parameter `structure` to `BaseTask`. - Method `try_find_task` to `Structure`. +- Global event bus, `griptape.events.EventBus`, for publishing and subscribing to events. ### Changed +- **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `EventBus`. +- **BREAKING**: Removed `EventPublisherMixin`. - `BaseTask.add_parent/child` will now call `self.structure.add_task` if possible. ## [0.29.0] - 2024-07-30 diff --git a/docs/griptape-framework/misc/events.md b/docs/griptape-framework/misc/events.md index 1f50fd6d0..187321dc6 100644 --- a/docs/griptape-framework/misc/events.md +++ b/docs/griptape-framework/misc/events.md @@ -5,7 +5,7 @@ search: ## Overview -You can use [EventListener](../../reference/griptape/events/event_listener.md)s to listen for events during a Structure's execution. +You can configure the global [EventBus](../../reference/griptape/events/event_bus.md) with [EventListener](../../reference/griptape/events/event_listener.md)s to listen for various framework events. See [Event Listener Drivers](../drivers/event-listener-drivers.md) for examples on forwarding events to external services. ## Specific Event Types @@ -23,15 +23,14 @@ from griptape.events import ( StartPromptEvent, FinishPromptEvent, EventListener, + EventBus ) def handler(event: BaseEvent): print(event.__class__) - -agent = Agent( - event_listeners=[ +EventBus.event_listeners=[ EventListener( handler, event_types=[ @@ -44,7 +43,8 @@ agent = Agent( ], ) ] -) + +agent = Agent() agent.run("tell me about griptape") ``` @@ -69,7 +69,8 @@ Or listen to all events: ```python from griptape.structures import Agent -from griptape.events import BaseEvent, EventListener +from griptape.events import BaseEvent, EventListener, EventBus + def handler1(event: BaseEvent): @@ -79,13 +80,12 @@ def handler1(event: BaseEvent): def handler2(event: BaseEvent): print("Handler 2", event.__class__) - -agent = Agent( - event_listeners=[ +EventBus.event_listeners=[ EventListener(handler1), EventListener(handler2), ] -) + +agent = Agent() agent.run("tell me about griptape") ``` @@ -131,7 +131,7 @@ Handler 2 list: - return [ - self.prompt_driver, - self.image_generation_driver, - self.image_query_driver, - self.embedding_driver, - self.vector_store_driver, - self.conversation_memory_driver, - self.text_to_speech_driver, - self.audio_transcription_driver, - ] - - @property - def structure(self) -> Optional[Structure]: - return self._structure - - @structure.setter - def structure(self, structure: Structure) -> None: - if structure != self.structure: - event_publisher_drivers = [ - driver for driver in self.drivers if driver is not None and isinstance(driver, EventPublisherMixin) - ] - - for driver in event_publisher_drivers: - if self._event_listener is not None: - driver.remove_event_listener(self._event_listener) - - self._event_listener = EventListener(structure.publish_event) - for driver in event_publisher_drivers: - driver.add_event_listener(self._event_listener) - - self._structure = structure - def merge_config(self, config: dict) -> BaseStructureConfig: base_config = self.to_dict() merged_config = dict_merge(base_config, config) diff --git a/griptape/drivers/audio_transcription/base_audio_transcription_driver.py b/griptape/drivers/audio_transcription/base_audio_transcription_driver.py index c81ea1d5b..ae46c474c 100644 --- a/griptape/drivers/audio_transcription/base_audio_transcription_driver.py +++ b/griptape/drivers/audio_transcription/base_audio_transcription_driver.py @@ -5,22 +5,22 @@ from attrs import define, field -from griptape.events import FinishAudioTranscriptionEvent, StartAudioTranscriptionEvent -from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin +from griptape.events import EventBus, FinishAudioTranscriptionEvent, StartAudioTranscriptionEvent +from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: from griptape.artifacts import AudioArtifact, TextArtifact @define -class BaseAudioTranscriptionDriver(EventPublisherMixin, SerializableMixin, ExponentialBackoffMixin, ABC): +class BaseAudioTranscriptionDriver(SerializableMixin, ExponentialBackoffMixin, ABC): model: str = field(kw_only=True, metadata={"serializable": True}) def before_run(self) -> None: - self.publish_event(StartAudioTranscriptionEvent()) + EventBus.publish_event(StartAudioTranscriptionEvent()) def after_run(self) -> None: - self.publish_event(FinishAudioTranscriptionEvent()) + EventBus.publish_event(FinishAudioTranscriptionEvent()) def run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/embedding/base_embedding_driver.py b/griptape/drivers/embedding/base_embedding_driver.py index 690726060..8998f00e5 100644 --- a/griptape/drivers/embedding/base_embedding_driver.py +++ b/griptape/drivers/embedding/base_embedding_driver.py @@ -7,7 +7,7 @@ from attrs import define, field from griptape.chunkers import BaseChunker, TextChunker -from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin +from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: from griptape.artifacts import TextArtifact @@ -15,7 +15,7 @@ @define -class BaseEmbeddingDriver(EventPublisherMixin, SerializableMixin, ExponentialBackoffMixin, ABC): +class BaseEmbeddingDriver(SerializableMixin, ExponentialBackoffMixin, ABC): """Base Embedding Driver. Attributes: diff --git a/griptape/drivers/image_generation/base_image_generation_driver.py b/griptape/drivers/image_generation/base_image_generation_driver.py index f500d6d09..8dfca5945 100644 --- a/griptape/drivers/image_generation/base_image_generation_driver.py +++ b/griptape/drivers/image_generation/base_image_generation_driver.py @@ -5,22 +5,22 @@ from attrs import define, field -from griptape.events import FinishImageGenerationEvent, StartImageGenerationEvent -from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin +from griptape.events import EventBus, FinishImageGenerationEvent, StartImageGenerationEvent +from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: from griptape.artifacts import ImageArtifact @define -class BaseImageGenerationDriver(EventPublisherMixin, SerializableMixin, ExponentialBackoffMixin, ABC): +class BaseImageGenerationDriver(SerializableMixin, ExponentialBackoffMixin, ABC): model: str = field(kw_only=True, metadata={"serializable": True}) def before_run(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> None: - self.publish_event(StartImageGenerationEvent(prompts=prompts, negative_prompts=negative_prompts)) + EventBus.publish_event(StartImageGenerationEvent(prompts=prompts, negative_prompts=negative_prompts)) def after_run(self) -> None: - self.publish_event(FinishImageGenerationEvent()) + EventBus.publish_event(FinishImageGenerationEvent()) def run_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/image_query/base_image_query_driver.py b/griptape/drivers/image_query/base_image_query_driver.py index b39f198d4..28c571328 100644 --- a/griptape/drivers/image_query/base_image_query_driver.py +++ b/griptape/drivers/image_query/base_image_query_driver.py @@ -5,24 +5,24 @@ from attrs import define, field -from griptape.events import FinishImageQueryEvent, StartImageQueryEvent -from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin +from griptape.events import EventBus, FinishImageQueryEvent, StartImageQueryEvent +from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: from griptape.artifacts import ImageArtifact, TextArtifact @define -class BaseImageQueryDriver(EventPublisherMixin, SerializableMixin, ExponentialBackoffMixin, ABC): +class BaseImageQueryDriver(SerializableMixin, ExponentialBackoffMixin, ABC): max_tokens: int = field(default=256, kw_only=True, metadata={"serializable": True}) def before_run(self, query: str, images: list[ImageArtifact]) -> None: - self.publish_event( + EventBus.publish_event( StartImageQueryEvent(query=query, images_info=[image.to_text() for image in images]), ) def after_run(self, result: str) -> None: - self.publish_event(FinishImageQueryEvent(result=result)) + EventBus.publish_event(FinishImageQueryEvent(result=result)) def query(self, query: str, images: list[ImageArtifact]) -> TextArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/memory/conversation/base_conversation_memory_driver.py b/griptape/drivers/memory/conversation/base_conversation_memory_driver.py index f13b82c29..1caeb902f 100644 --- a/griptape/drivers/memory/conversation/base_conversation_memory_driver.py +++ b/griptape/drivers/memory/conversation/base_conversation_memory_driver.py @@ -3,13 +3,13 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Optional -from griptape.mixins import EventPublisherMixin, SerializableMixin +from griptape.mixins import SerializableMixin if TYPE_CHECKING: from griptape.memory.structure import BaseConversationMemory -class BaseConversationMemoryDriver(EventPublisherMixin, SerializableMixin, ABC): +class BaseConversationMemoryDriver(SerializableMixin, ABC): @abstractmethod def store(self, memory: BaseConversationMemory) -> None: ... diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index e5fd0408d..94e46e75d 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -16,8 +16,8 @@ TextMessageContent, observable, ) -from griptape.events import CompletionChunkEvent, FinishPromptEvent, StartPromptEvent -from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin +from griptape.events import CompletionChunkEvent, EventBus, FinishPromptEvent, StartPromptEvent +from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: from collections.abc import Iterator @@ -26,7 +26,7 @@ @define(kw_only=True) -class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, EventPublisherMixin, ABC): +class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): """Base class for the Prompt Drivers. Attributes: @@ -49,10 +49,10 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, EventPublishe use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True}) def before_run(self, prompt_stack: PromptStack) -> None: - self.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack)) + EventBus.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack)) def after_run(self, result: Message) -> None: - self.publish_event( + EventBus.publish_event( FinishPromptEvent( model=self.model, result=result.value, @@ -128,12 +128,12 @@ def __process_stream(self, prompt_stack: PromptStack) -> Message: else: delta_contents[content.index] = [content] if isinstance(content, TextDeltaMessageContent): - self.publish_event(CompletionChunkEvent(token=content.text)) + EventBus.publish_event(CompletionChunkEvent(token=content.text)) elif isinstance(content, ActionCallDeltaMessageContent): if content.tag is not None and content.name is not None and content.path is not None: - self.publish_event(CompletionChunkEvent(token=str(content))) + EventBus.publish_event(CompletionChunkEvent(token=str(content))) elif content.partial_input is not None: - self.publish_event(CompletionChunkEvent(token=content.partial_input)) + EventBus.publish_event(CompletionChunkEvent(token=content.partial_input)) # Build a complete content from the content deltas result = self.__build_message(list(delta_contents.values()), usage) diff --git a/griptape/drivers/text_to_speech/base_text_to_speech_driver.py b/griptape/drivers/text_to_speech/base_text_to_speech_driver.py index 788d92974..cb11cc498 100644 --- a/griptape/drivers/text_to_speech/base_text_to_speech_driver.py +++ b/griptape/drivers/text_to_speech/base_text_to_speech_driver.py @@ -5,23 +5,24 @@ from attrs import define, field +from griptape.events import EventBus from griptape.events.finish_text_to_speech_event import FinishTextToSpeechEvent from griptape.events.start_text_to_speech_event import StartTextToSpeechEvent -from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin +from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: from griptape.artifacts.audio_artifact import AudioArtifact @define -class BaseTextToSpeechDriver(SerializableMixin, ExponentialBackoffMixin, EventPublisherMixin, ABC): +class BaseTextToSpeechDriver(SerializableMixin, ExponentialBackoffMixin, ABC): model: str = field(kw_only=True, metadata={"serializable": True}) def before_run(self, prompts: list[str]) -> None: - self.publish_event(StartTextToSpeechEvent(prompts=prompts)) + EventBus.publish_event(StartTextToSpeechEvent(prompts=prompts)) def after_run(self) -> None: - self.publish_event(FinishTextToSpeechEvent()) + EventBus.publish_event(FinishTextToSpeechEvent()) def run_text_to_audio(self, prompts: list[str]) -> AudioArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/vector/base_vector_store_driver.py b/griptape/drivers/vector/base_vector_store_driver.py index d1da78188..ed1f2d589 100644 --- a/griptape/drivers/vector/base_vector_store_driver.py +++ b/griptape/drivers/vector/base_vector_store_driver.py @@ -10,14 +10,14 @@ from griptape import utils from griptape.artifacts import BaseArtifact, ListArtifact, TextArtifact -from griptape.mixins import EventPublisherMixin, SerializableMixin +from griptape.mixins import SerializableMixin if TYPE_CHECKING: from griptape.drivers import BaseEmbeddingDriver @define -class BaseVectorStoreDriver(EventPublisherMixin, SerializableMixin, ABC): +class BaseVectorStoreDriver(SerializableMixin, ABC): DEFAULT_QUERY_COUNT = 5 @dataclass diff --git a/griptape/events/__init__.py b/griptape/events/__init__.py index 944a309eb..b3e2f3a79 100644 --- a/griptape/events/__init__.py +++ b/griptape/events/__init__.py @@ -22,6 +22,7 @@ from .base_audio_transcription_event import BaseAudioTranscriptionEvent from .start_audio_transcription_event import StartAudioTranscriptionEvent from .finish_audio_transcription_event import FinishAudioTranscriptionEvent +from .event_bus import EventBus __all__ = [ "BaseEvent", @@ -48,4 +49,5 @@ "BaseAudioTranscriptionEvent", "StartAudioTranscriptionEvent", "FinishAudioTranscriptionEvent", + "EventBus", ] diff --git a/griptape/mixins/event_publisher_mixin.py b/griptape/events/event_bus.py similarity index 96% rename from griptape/mixins/event_publisher_mixin.py rename to griptape/events/event_bus.py index 67a302ed6..9239e66bd 100644 --- a/griptape/mixins/event_publisher_mixin.py +++ b/griptape/events/event_bus.py @@ -9,7 +9,7 @@ @define -class EventPublisherMixin: +class _EventBus: event_listeners: list[EventListener] = field(factory=list, kw_only=True) def add_event_listeners(self, event_listeners: list[EventListener]) -> list[EventListener]: @@ -32,3 +32,6 @@ def remove_event_listener(self, event_listener: EventListener) -> None: def publish_event(self, event: BaseEvent, *, flush: bool = False) -> None: for event_listener in self.event_listeners: event_listener.publish_event(event, flush=flush) + + +EventBus = _EventBus() diff --git a/griptape/mixins/__init__.py b/griptape/mixins/__init__.py index 944027c59..d9eea53c2 100644 --- a/griptape/mixins/__init__.py +++ b/griptape/mixins/__init__.py @@ -4,7 +4,6 @@ from .rule_mixin import RuleMixin from .serializable_mixin import SerializableMixin from .media_artifact_file_output_mixin import BlobArtifactFileOutputMixin -from .event_publisher_mixin import EventPublisherMixin __all__ = [ "ActivityMixin", @@ -13,5 +12,4 @@ "RuleMixin", "BlobArtifactFileOutputMixin", "SerializableMixin", - "EventPublisherMixin", ] diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index 079e0b741..df7113c23 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -28,13 +28,11 @@ VectorStoreRetrievalRagModule, ) from griptape.engines.rag.stages import ResponseRagStage, RetrievalRagStage -from griptape.events.finish_structure_run_event import FinishStructureRunEvent -from griptape.events.start_structure_run_event import StartStructureRunEvent +from griptape.events import EventBus, FinishStructureRunEvent, StartStructureRunEvent from griptape.memory import TaskMemory from griptape.memory.meta import MetaMemory from griptape.memory.structure import ConversationMemory from griptape.memory.task.storage import BlobArtifactStorage, TextArtifactStorage -from griptape.mixins import EventPublisherMixin from griptape.utils import deprecation_warn if TYPE_CHECKING: @@ -44,7 +42,7 @@ @define -class Structure(ABC, EventPublisherMixin): +class Structure(ABC): LOGGER_NAME = "griptape" id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True) @@ -97,8 +95,6 @@ def __attrs_post_init__(self) -> None: if self.conversation_memory is not None: self.conversation_memory.structure = self - self.config.structure = self - tasks = self.tasks.copy() self.tasks.clear() self.add_tasks(*tasks) @@ -261,7 +257,7 @@ def before_run(self, args: Any) -> None: [task.reset() for task in self.tasks] - self.publish_event( + EventBus.publish_event( StartStructureRunEvent( structure_id=self.id, input_task_input=self.input_task.input, @@ -273,7 +269,7 @@ def before_run(self, args: Any) -> None: @observable def after_run(self) -> None: - self.publish_event( + EventBus.publish_event( FinishStructureRunEvent( structure_id=self.id, output_task_input=self.output_task.input, diff --git a/griptape/tasks/actions_subtask.py b/griptape/tasks/actions_subtask.py index cde59d0ef..07f49f52a 100644 --- a/griptape/tasks/actions_subtask.py +++ b/griptape/tasks/actions_subtask.py @@ -10,7 +10,7 @@ from griptape import utils from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact from griptape.common import ToolAction -from griptape.events import FinishActionsSubtaskEvent, StartActionsSubtaskEvent +from griptape.events import EventBus, FinishActionsSubtaskEvent, StartActionsSubtaskEvent from griptape.mixins import ActionsSubtaskOriginMixin from griptape.tasks import BaseTask from griptape.utils import remove_null_values_in_dict_recursively @@ -91,7 +91,7 @@ def attach_to(self, parent_task: BaseTask) -> None: self.output = ErrorArtifact(f"ToolAction input parsing error: {e}", exception=e) def before_run(self) -> None: - self.structure.publish_event( + EventBus.publish_event( StartActionsSubtaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, @@ -157,7 +157,7 @@ def execute_action(self, action: ToolAction) -> tuple[str, BaseArtifact]: def after_run(self) -> None: response = self.output.to_text() if isinstance(self.output, BaseArtifact) else str(self.output) - self.structure.publish_event( + EventBus.publish_event( FinishActionsSubtaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, diff --git a/griptape/tasks/base_task.py b/griptape/tasks/base_task.py index 8c50e4df9..9a8361e6c 100644 --- a/griptape/tasks/base_task.py +++ b/griptape/tasks/base_task.py @@ -9,7 +9,7 @@ from attrs import Factory, define, field from griptape.artifacts import ErrorArtifact -from griptape.events import FinishTaskEvent, StartTaskEvent +from griptape.events import EventBus, FinishTaskEvent, StartTaskEvent if TYPE_CHECKING: from griptape.artifacts import BaseArtifact @@ -127,7 +127,7 @@ def is_executing(self) -> bool: def before_run(self) -> None: if self.structure is not None: - self.structure.publish_event( + EventBus.publish_event( StartTaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, @@ -139,7 +139,7 @@ def before_run(self) -> None: def after_run(self) -> None: if self.structure is not None: - self.structure.publish_event( + EventBus.publish_event( FinishTaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, diff --git a/griptape/utils/stream.py b/griptape/utils/stream.py index bf33e5df8..4a7899b2a 100644 --- a/griptape/utils/stream.py +++ b/griptape/utils/stream.py @@ -7,10 +7,7 @@ from attrs import Attribute, Factory, define, field from griptape.artifacts.text_artifact import TextArtifact -from griptape.events.completion_chunk_event import CompletionChunkEvent -from griptape.events.event_listener import EventListener -from griptape.events.finish_prompt_event import FinishPromptEvent -from griptape.events.finish_structure_run_event import FinishStructureRunEvent +from griptape.events import CompletionChunkEvent, EventBus, EventListener, FinishPromptEvent, FinishStructureRunEvent if TYPE_CHECKING: from collections.abc import Iterator @@ -64,8 +61,8 @@ def event_handler(event: BaseEvent) -> None: handler=event_handler, event_types=[CompletionChunkEvent, FinishPromptEvent, FinishStructureRunEvent], ) - self.structure.add_event_listener(stream_event_listener) + EventBus.add_event_listener(stream_event_listener) self.structure.run(*args) - self.structure.remove_event_listener(stream_event_listener) + EventBus.remove_event_listener(stream_event_listener) diff --git a/tests/unit/config/test_structure_config.py b/tests/unit/config/test_structure_config.py index b9e3477e4..96a68628f 100644 --- a/tests/unit/config/test_structure_config.py +++ b/tests/unit/config/test_structure_config.py @@ -1,7 +1,6 @@ import pytest from griptape.config import StructureConfig -from griptape.structures import Agent class TestStructureConfig: @@ -61,37 +60,3 @@ def test_dot_update(self, config): config.prompt_driver.max_tokens = 10 assert config.prompt_driver.max_tokens == 10 - - def test_drivers(self, config): - assert config.drivers == [ - config.prompt_driver, - config.image_generation_driver, - config.image_query_driver, - config.embedding_driver, - config.vector_store_driver, - config.conversation_memory_driver, - config.text_to_speech_driver, - config.audio_transcription_driver, - ] - - def test_structure(self, config): - structure_1 = Agent( - config=config, - ) - - assert config.structure == structure_1 - assert config._event_listener is not None - for driver in config.drivers: - if driver is not None: - assert config._event_listener in driver.event_listeners - assert len(driver.event_listeners) == 1 - - structure_2 = Agent( - config=config, - ) - assert config.structure == structure_2 - assert config._event_listener is not None - for driver in config.drivers: - if driver is not None: - assert config._event_listener in driver.event_listeners - assert len(driver.event_listeners) == 1 diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py new file mode 100644 index 000000000..0be2f9758 --- /dev/null +++ b/tests/unit/conftest.py @@ -0,0 +1,12 @@ +import pytest + +from griptape.events import EventBus + + +@pytest.fixture(autouse=True) +def event_bus(): + EventBus.event_listeners = [] + + yield EventBus + + EventBus.event_listeners = [] diff --git a/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py b/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py index 519e40f57..fc41837fd 100644 --- a/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py +++ b/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py @@ -3,7 +3,7 @@ import pytest from griptape.artifacts import AudioArtifact -from griptape.events.event_listener import EventListener +from griptape.events import EventBus, EventListener from tests.mocks.mock_audio_transcription_driver import MockAudioTranscriptionDriver @@ -14,7 +14,7 @@ def driver(self): def test_run_publish_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run( AudioArtifact( diff --git a/tests/unit/drivers/image_generation/test_base_image_generation_driver.py b/tests/unit/drivers/image_generation/test_base_image_generation_driver.py index 7447b2c08..96b615a58 100644 --- a/tests/unit/drivers/image_generation/test_base_image_generation_driver.py +++ b/tests/unit/drivers/image_generation/test_base_image_generation_driver.py @@ -3,6 +3,7 @@ import pytest from griptape.artifacts.image_artifact import ImageArtifact +from griptape.events import EventBus from griptape.events.event_listener import EventListener from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver @@ -14,7 +15,7 @@ def driver(self): def test_run_text_to_image_publish_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run_text_to_image( ["foo", "bar"], @@ -30,7 +31,7 @@ def test_run_text_to_image_publish_events(self, driver): def test_run_image_variation_publish_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_variation( ["foo", "bar"], @@ -52,7 +53,7 @@ def test_run_image_variation_publish_events(self, driver): def test_run_image_image_inpainting_publish_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_inpainting( ["foo", "bar"], @@ -80,7 +81,7 @@ def test_run_image_image_inpainting_publish_events(self, driver): def test_run_image_image_outpainting_publish_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_outpainting( ["foo", "bar"], diff --git a/tests/unit/drivers/image_query/test_base_image_query_driver.py b/tests/unit/drivers/image_query/test_base_image_query_driver.py index 14de15f2d..a77fb268e 100644 --- a/tests/unit/drivers/image_query/test_base_image_query_driver.py +++ b/tests/unit/drivers/image_query/test_base_image_query_driver.py @@ -2,7 +2,7 @@ import pytest -from griptape.events.event_listener import EventListener +from griptape.events import EventBus, EventListener from tests.mocks.mock_image_query_driver import MockImageQueryDriver @@ -13,7 +13,7 @@ def driver(self): def test_query_publishes_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.query("foo", []) diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index 2708b0a88..5b6b0c600 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -1,7 +1,7 @@ from griptape.artifacts import ErrorArtifact, TextArtifact from griptape.common import Message, PromptStack from griptape.events import FinishPromptEvent, StartPromptEvent -from griptape.mixins import EventPublisherMixin +from griptape.events.event_bus import _EventBus from griptape.structures import Pipeline from griptape.tasks import PromptTask, ToolkitTask from tests.mocks.mock_failing_prompt_driver import MockFailingPromptDriver @@ -27,7 +27,7 @@ def test_run_via_pipeline_retries_failure(self): assert isinstance(pipeline.run().output_task.output, ErrorArtifact) def test_run_via_pipeline_publishes_events(self, mocker): - mock_publish_event = mocker.patch.object(EventPublisherMixin, "publish_event") + mock_publish_event = mocker.patch.object(_EventBus, "publish_event") driver = MockPromptDriver() pipeline = Pipeline(prompt_driver=driver) pipeline.add_task(PromptTask("test")) @@ -42,8 +42,7 @@ def test_run(self): assert isinstance(MockPromptDriver().run(PromptStack(messages=[])), Message) def test_run_with_stream(self): - pipeline = Pipeline() - result = MockPromptDriver(stream=True, event_listeners=pipeline.event_listeners).run(PromptStack(messages=[])) + result = MockPromptDriver(stream=True).run(PromptStack(messages=[])) assert isinstance(result, Message) assert result.value == "mock output" diff --git a/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py b/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py index 8af5dc827..ab448c7c1 100644 --- a/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py +++ b/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py @@ -2,7 +2,7 @@ import pytest -from griptape.events.event_listener import EventListener +from griptape.events import EventBus, EventListener from tests.mocks.mock_text_to_speech_driver import MockTextToSpeechDriver @@ -13,7 +13,7 @@ def driver(self): def test_text_to_audio_publish_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run_text_to_audio( ["foo", "bar"], diff --git a/tests/unit/events/test_event_bus.py b/tests/unit/events/test_event_bus.py new file mode 100644 index 000000000..fd862913e --- /dev/null +++ b/tests/unit/events/test_event_bus.py @@ -0,0 +1,45 @@ +from unittest.mock import Mock + +from griptape.events import EventBus, EventListener +from tests.mocks.mock_event import MockEvent + + +class TestEventBus: + def test_add_event_listeners(self): + EventBus.add_event_listeners([EventListener(), EventListener()]) + assert len(EventBus.event_listeners) == 2 + + def test_remove_event_listeners(self): + listeners = [EventListener(), EventListener()] + EventBus.add_event_listeners(listeners) + EventBus.remove_event_listeners(listeners) + assert len(EventBus.event_listeners) == 0 + + def test_add_event_listener(self): + EventBus.add_event_listener(EventListener()) + EventBus.add_event_listener(EventListener()) + + assert len(EventBus.event_listeners) == 2 + + def test_remove_event_listener(self): + listener = EventListener() + EventBus.add_event_listener(listener) + EventBus.remove_event_listener(listener) + + assert len(EventBus.event_listeners) == 0 + + def test_remove_unknown_event_listener(self): + EventBus.remove_event_listener(EventListener()) + + def test_publish_event(self): + # Given + mock_handler = Mock() + mock_handler.return_value = None + EventBus.event_listeners = [EventListener(handler=mock_handler)] + mock_event = MockEvent() + + # When + EventBus.publish_event(mock_event) + + # Then + mock_handler.assert_called_once_with(mock_event) diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index b245c2be9..5601aef34 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -4,6 +4,7 @@ from griptape.events import ( CompletionChunkEvent, + EventBus, EventListener, FinishActionsSubtaskEvent, FinishPromptEvent, @@ -37,7 +38,7 @@ def test_untyped_listeners(self, pipeline): event_handler_1 = Mock() event_handler_2 = Mock() - pipeline.event_listeners = [EventListener(handler=event_handler_1), EventListener(handler=event_handler_2)] + EventBus.event_listeners = [EventListener(handler=event_handler_1), EventListener(handler=event_handler_2)] # can't mock subtask events, so must manually call pipeline.tasks[0].subtasks[0].before_run() @@ -58,7 +59,7 @@ def test_typed_listeners(self, pipeline): finish_structure_run_event_handler = Mock() completion_chunk_handler = Mock() - pipeline.event_listeners = [ + EventBus.event_listeners = [ EventListener(start_prompt_event_handler, event_types=[StartPromptEvent]), EventListener(finish_prompt_event_handler, event_types=[FinishPromptEvent]), EventListener(start_task_event_handler, event_types=[StartTaskEvent]), @@ -86,25 +87,25 @@ def test_typed_listeners(self, pipeline): completion_chunk_handler.assert_called_once() def test_add_remove_event_listener(self, pipeline): - pipeline.event_listeners = [] + EventBus.event_listeners = [] mock1 = Mock() mock2 = Mock() # duplicate event listeners will only get added once - event_listener_1 = pipeline.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) - pipeline.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) + event_listener_1 = EventBus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) + EventBus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) - event_listener_3 = pipeline.add_event_listener(EventListener(mock1, event_types=[FinishPromptEvent])) - event_listener_4 = pipeline.add_event_listener(EventListener(mock2, event_types=[StartPromptEvent])) + event_listener_3 = EventBus.add_event_listener(EventListener(mock1, event_types=[FinishPromptEvent])) + event_listener_4 = EventBus.add_event_listener(EventListener(mock2, event_types=[StartPromptEvent])) - event_listener_5 = pipeline.add_event_listener(EventListener(mock2)) + event_listener_5 = EventBus.add_event_listener(EventListener(mock2)) - assert len(pipeline.event_listeners) == 4 + assert len(EventBus.event_listeners) == 4 - pipeline.remove_event_listener(event_listener_1) - pipeline.remove_event_listener(event_listener_3) - pipeline.remove_event_listener(event_listener_4) - pipeline.remove_event_listener(event_listener_5) - assert len(pipeline.event_listeners) == 0 + EventBus.remove_event_listener(event_listener_1) + EventBus.remove_event_listener(event_listener_3) + EventBus.remove_event_listener(event_listener_4) + EventBus.remove_event_listener(event_listener_5) + assert len(EventBus.event_listeners) == 0 def test_publish_event(self): mock_event_listener_driver = Mock() diff --git a/tests/unit/mixins/test_events_mixin.py b/tests/unit/mixins/test_events_mixin.py deleted file mode 100644 index 99f5541ba..000000000 --- a/tests/unit/mixins/test_events_mixin.py +++ /dev/null @@ -1,59 +0,0 @@ -from unittest.mock import Mock - -from griptape.events import EventListener -from griptape.mixins import EventPublisherMixin -from tests.mocks.mock_event import MockEvent - - -class TestEventsMixin: - def test_init(self): - assert EventPublisherMixin() - - def test_add_event_listeners(self): - mixin = EventPublisherMixin() - - mixin.add_event_listeners([EventListener(), EventListener()]) - assert len(mixin.event_listeners) == 2 - - def test_remove_event_listeners(self): - mixin = EventPublisherMixin() - - listeners = [EventListener(), EventListener()] - mixin.add_event_listeners(listeners) - mixin.remove_event_listeners(listeners) - assert len(mixin.event_listeners) == 0 - - def test_add_event_listener(self): - mixin = EventPublisherMixin() - - mixin.add_event_listener(EventListener()) - mixin.add_event_listener(EventListener()) - - assert len(mixin.event_listeners) == 2 - - def test_remove_event_listener(self): - mixin = EventPublisherMixin() - - listener = EventListener() - mixin.add_event_listener(listener) - mixin.remove_event_listener(listener) - - assert len(mixin.event_listeners) == 0 - - def test_remove_unknown_event_listener(self): - mixin = EventPublisherMixin() - - mixin.remove_event_listener(EventListener()) - - def test_publish_event(self): - # Given - mock_handler = Mock() - mock_handler.return_value = None - mixin = EventPublisherMixin(event_listeners=[EventListener(handler=mock_handler)]) - mock_event = MockEvent() - - # When - mixin.publish_event(mock_event) - - # Then - mock_handler.assert_called_once_with(mock_event) diff --git a/tests/unit/tasks/test_base_task.py b/tests/unit/tasks/test_base_task.py index 4f4b43d40..636515106 100644 --- a/tests/unit/tasks/test_base_task.py +++ b/tests/unit/tasks/test_base_task.py @@ -3,6 +3,7 @@ import pytest from griptape.artifacts import TextArtifact +from griptape.events import EventBus from griptape.events.event_listener import EventListener from griptape.structures import Agent, Workflow from griptape.tasks import ActionsSubtask @@ -15,11 +16,11 @@ class TestBaseTask: @pytest.fixture() def task(self): + EventBus.event_listeners = [EventListener(handler=Mock())] agent = Agent( prompt_driver=MockPromptDriver(), embedding_driver=MockEmbeddingDriver(), tools=[MockTool()], - event_listeners=[EventListener(handler=Mock())], ) agent.add_task(MockTask("foobar", max_meta_memory_entries=2)) @@ -117,4 +118,4 @@ def test_children_property_no_structure(self, task): def test_execute_publish_events(self, task): task.execute() - assert task.structure.event_listeners[0].handler.call_count == 2 + assert EventBus.event_listeners[0].handler.call_count == 2 From 951a4ed1fb163a47c5330d26de2dcc7c704b0e1b Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 8 Aug 2024 09:11:00 -0700 Subject: [PATCH 19/63] Update docs --- .../drivers/event-listener-drivers.md | 89 +++++++++++-------- docs/griptape-framework/misc/events.md | 14 +-- 2 files changed, 59 insertions(+), 44 deletions(-) diff --git a/docs/griptape-framework/drivers/event-listener-drivers.md b/docs/griptape-framework/drivers/event-listener-drivers.md index 73453afb6..db02cd77a 100644 --- a/docs/griptape-framework/drivers/event-listener-drivers.md +++ b/docs/griptape-framework/drivers/event-listener-drivers.md @@ -14,26 +14,27 @@ import os from griptape.drivers import AmazonSqsEventListenerDriver from griptape.events import ( - EventListener, + EventListener, EventBus ) from griptape.rules import Rule from griptape.structures import Agent -agent = Agent( - rules=[ - Rule( - value="You will be provided with a block of text, and your task is to extract a list of keywords from it." - ) - ], - event_listeners=[ +EventBus.add_event_listeners( + [ EventListener( - handler=lambda event: { # You can optionally use the handler to transform the event payload before sending it to the Driver - "event": event.to_dict(), - }, driver=AmazonSqsEventListenerDriver( queue_url=os.environ["AMAZON_SQS_QUEUE_URL"], ), ), + ] +) + + +agent = Agent( + rules=[ + Rule( + value="You will be provided with a block of text, and your task is to extract a list of keywords from it." + ) ], ) @@ -83,23 +84,26 @@ import os from griptape.drivers import AmazonSqsEventListenerDriver from griptape.events import ( - EventListener, + EventListener, EventBus ) from griptape.rules import Rule from griptape.structures import Agent -agent = Agent( - rules=[ - Rule( - value="You will be provided with a block of text, and your task is to extract a list of keywords from it." - ) - ], - event_listeners=[ +EventBus.add_event_listeners( + [ EventListener( driver=AmazonSqsEventListenerDriver( queue_url=os.environ["AMAZON_SQS_QUEUE_URL"], ), ), + ] +) + +agent = Agent( + rules=[ + Rule( + value="You will be provided with a block of text, and your task is to extract a list of keywords from it." + ) ], ) @@ -128,10 +132,23 @@ from griptape.drivers import AwsIotCoreEventListenerDriver, OpenAiChatPromptDriv from griptape.events import ( EventListener, FinishStructureRunEvent, + EventBus ) from griptape.rules import Rule from griptape.structures import Agent +EventBus.add_event_listeners( + [ + EventListener( + event_types=[FinishStructureRunEvent], + driver=AwsIotCoreEventListenerDriver( + topic=os.environ["AWS_IOT_CORE_TOPIC"], + iot_endpoint=os.environ["AWS_IOT_CORE_ENDPOINT"], + ), + ), + ] +) + agent = Agent( rules=[ Rule( @@ -143,15 +160,6 @@ agent = Agent( model="gpt-3.5-turbo", temperature=0.7 ) ), - event_listeners=[ - EventListener( - event_types=[FinishStructureRunEvent], - driver=AwsIotCoreEventListenerDriver( - topic=os.environ["AWS_IOT_CORE_TOPIC"], - iot_endpoint=os.environ["AWS_IOT_CORE_ENDPOINT"], - ), - ), - ], ) agent.run("I want to fly from Orlando to Boston") @@ -171,18 +179,19 @@ from griptape.drivers import GriptapeCloudEventListenerDriver from griptape.events import ( EventListener, FinishStructureRunEvent, + EventBus ) from griptape.structures import Agent -agent = Agent( - event_listeners=[ +EventBus.add_event_listeners( + [ EventListener( event_types=[FinishStructureRunEvent], # By default, GriptapeCloudEventListenerDriver uses the api key provided # in the GT_CLOUD_API_KEY environment variable. driver=GriptapeCloudEventListenerDriver(), ), - ], + ] ) agent.run( @@ -201,20 +210,23 @@ from griptape.drivers import WebhookEventListenerDriver from griptape.events import ( EventListener, FinishStructureRunEvent, + EventBus ) from griptape.structures import Agent -agent = Agent( - event_listeners=[ +EventBus.add_event_listeners( + [ EventListener( event_types=[FinishStructureRunEvent], driver=WebhookEventListenerDriver( webhook_url=os.environ["WEBHOOK_URL"], ), ), - ], + ] ) +agent = Agent() + agent.run("Analyze the pros and cons of remote work vs. office work") ``` ### Pusher @@ -229,12 +241,13 @@ import os from griptape.drivers import PusherEventListenerDriver from griptape.events import ( EventListener, - FinishStructureRunEvent + FinishStructureRunEvent, + EventBus ) from griptape.structures import Agent -agent = Agent( - event_listeners=[ +EventBus.add_event_listeners( + [ EventListener( event_types=[FinishStructureRunEvent], driver=PusherEventListenerDriver( @@ -250,6 +263,8 @@ agent = Agent( ], ) +agent = Agent() + agent.run("Analyze the pros and cons of remote work vs. office work") ``` diff --git a/docs/griptape-framework/misc/events.md b/docs/griptape-framework/misc/events.md index 187321dc6..23ebcdc2a 100644 --- a/docs/griptape-framework/misc/events.md +++ b/docs/griptape-framework/misc/events.md @@ -30,7 +30,7 @@ from griptape.events import ( def handler(event: BaseEvent): print(event.__class__) -EventBus.event_listeners=[ +EventBus.add_event_listeners([ EventListener( handler, event_types=[ @@ -42,7 +42,7 @@ EventBus.event_listeners=[ FinishPromptEvent, ], ) - ] + ]) agent = Agent() @@ -140,12 +140,12 @@ from griptape.drivers import OpenAiChatPromptDriver -EventBus.event_listeners = [ +EventBus.add_event_listeners([ EventListener( lambda e: print(e.token, end="", flush=True), event_types=[CompletionChunkEvent], ) -] +]) pipeline = Pipeline( config=OpenAiStructureConfig( @@ -194,12 +194,12 @@ from griptape.structures import Agent token_counter = utils.TokenCounter() -EventBus.event_listeners = [ +EventBus.add_event_listeners([ EventListener( lambda e: token_counter.add_tokens(e.token_count), event_types=[StartPromptEvent, FinishPromptEvent], ) -] +]) def count_tokens(e: BaseEvent): if isinstance(e, StartPromptEvent) or isinstance(e, FinishPromptEvent): @@ -248,7 +248,7 @@ from griptape.structures import Agent from griptape.events import BaseEvent, StartPromptEvent, EventListener, EventBus -EventBus.event_listeners = [EventListener(handler=lambda e: print(e), event_types=[StartPromptEvent])] +EventBus.add_event_listeners([EventListener(handler=lambda e: print(e), event_types=[StartPromptEvent])]) def handler(event: BaseEvent): if isinstance(event, StartPromptEvent): From 025437b5f0f376318ce15fa5b9111a89f4608484 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 8 Aug 2024 10:03:17 -0700 Subject: [PATCH 20/63] Make event listeners private --- griptape/events/event_bus.py | 19 +++++++++++----- tests/unit/conftest.py | 4 ++-- tests/unit/events/test_event_bus.py | 2 +- tests/unit/events/test_event_listener.py | 28 +++++++++++++----------- tests/unit/tasks/test_base_task.py | 2 +- 5 files changed, 32 insertions(+), 23 deletions(-) diff --git a/griptape/events/event_bus.py b/griptape/events/event_bus.py index 9239e66bd..6ffd65550 100644 --- a/griptape/events/event_bus.py +++ b/griptape/events/event_bus.py @@ -10,7 +10,11 @@ @define class _EventBus: - event_listeners: list[EventListener] = field(factory=list, kw_only=True) + _event_listeners: list[EventListener] = field(factory=list, kw_only=True, alias="_event_listeners") + + @property + def event_listeners(self) -> list[EventListener]: + return self._event_listeners def add_event_listeners(self, event_listeners: list[EventListener]) -> list[EventListener]: return [self.add_event_listener(event_listener) for event_listener in event_listeners] @@ -20,18 +24,21 @@ def remove_event_listeners(self, event_listeners: list[EventListener]) -> None: self.remove_event_listener(event_listener) def add_event_listener(self, event_listener: EventListener) -> EventListener: - if event_listener not in self.event_listeners: - self.event_listeners.append(event_listener) + if event_listener not in self._event_listeners: + self._event_listeners.append(event_listener) return event_listener def remove_event_listener(self, event_listener: EventListener) -> None: - if event_listener in self.event_listeners: - self.event_listeners.remove(event_listener) + if event_listener in self._event_listeners: + self._event_listeners.remove(event_listener) def publish_event(self, event: BaseEvent, *, flush: bool = False) -> None: - for event_listener in self.event_listeners: + for event_listener in self._event_listeners: event_listener.publish_event(event, flush=flush) + def clear_event_listeners(self) -> None: + self._event_listeners.clear() + EventBus = _EventBus() diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 0be2f9758..7a73b041f 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -5,8 +5,8 @@ @pytest.fixture(autouse=True) def event_bus(): - EventBus.event_listeners = [] + EventBus.clear_event_listeners() yield EventBus - EventBus.event_listeners = [] + EventBus.clear_event_listeners() diff --git a/tests/unit/events/test_event_bus.py b/tests/unit/events/test_event_bus.py index fd862913e..d237bb3b4 100644 --- a/tests/unit/events/test_event_bus.py +++ b/tests/unit/events/test_event_bus.py @@ -35,7 +35,7 @@ def test_publish_event(self): # Given mock_handler = Mock() mock_handler.return_value = None - EventBus.event_listeners = [EventListener(handler=mock_handler)] + EventBus.add_event_listeners([EventListener(handler=mock_handler)]) mock_event = MockEvent() # When diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index 5601aef34..f3d9823d3 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -38,7 +38,7 @@ def test_untyped_listeners(self, pipeline): event_handler_1 = Mock() event_handler_2 = Mock() - EventBus.event_listeners = [EventListener(handler=event_handler_1), EventListener(handler=event_handler_2)] + EventBus.add_event_listeners([EventListener(handler=event_handler_1), EventListener(handler=event_handler_2)]) # can't mock subtask events, so must manually call pipeline.tasks[0].subtasks[0].before_run() @@ -59,17 +59,19 @@ def test_typed_listeners(self, pipeline): finish_structure_run_event_handler = Mock() completion_chunk_handler = Mock() - EventBus.event_listeners = [ - EventListener(start_prompt_event_handler, event_types=[StartPromptEvent]), - EventListener(finish_prompt_event_handler, event_types=[FinishPromptEvent]), - EventListener(start_task_event_handler, event_types=[StartTaskEvent]), - EventListener(finish_task_event_handler, event_types=[FinishTaskEvent]), - EventListener(start_subtask_event_handler, event_types=[StartActionsSubtaskEvent]), - EventListener(finish_subtask_event_handler, event_types=[FinishActionsSubtaskEvent]), - EventListener(start_structure_run_event_handler, event_types=[StartStructureRunEvent]), - EventListener(finish_structure_run_event_handler, event_types=[FinishStructureRunEvent]), - EventListener(completion_chunk_handler, event_types=[CompletionChunkEvent]), - ] + EventBus.add_event_listeners( + [ + EventListener(start_prompt_event_handler, event_types=[StartPromptEvent]), + EventListener(finish_prompt_event_handler, event_types=[FinishPromptEvent]), + EventListener(start_task_event_handler, event_types=[StartTaskEvent]), + EventListener(finish_task_event_handler, event_types=[FinishTaskEvent]), + EventListener(start_subtask_event_handler, event_types=[StartActionsSubtaskEvent]), + EventListener(finish_subtask_event_handler, event_types=[FinishActionsSubtaskEvent]), + EventListener(start_structure_run_event_handler, event_types=[StartStructureRunEvent]), + EventListener(finish_structure_run_event_handler, event_types=[FinishStructureRunEvent]), + EventListener(completion_chunk_handler, event_types=[CompletionChunkEvent]), + ] + ) # can't mock subtask events, so must manually call pipeline.tasks[0].subtasks[0].before_run() @@ -87,7 +89,7 @@ def test_typed_listeners(self, pipeline): completion_chunk_handler.assert_called_once() def test_add_remove_event_listener(self, pipeline): - EventBus.event_listeners = [] + EventBus.clear_event_listeners() mock1 = Mock() mock2 = Mock() # duplicate event listeners will only get added once diff --git a/tests/unit/tasks/test_base_task.py b/tests/unit/tasks/test_base_task.py index 636515106..d6e4da8b6 100644 --- a/tests/unit/tasks/test_base_task.py +++ b/tests/unit/tasks/test_base_task.py @@ -16,7 +16,7 @@ class TestBaseTask: @pytest.fixture() def task(self): - EventBus.event_listeners = [EventListener(handler=Mock())] + EventBus.add_event_listeners([EventListener(handler=Mock())]) agent = Agent( prompt_driver=MockPromptDriver(), embedding_driver=MockEmbeddingDriver(), From 0f193854cc4b4230301b22578d5e14b42c1e72f0 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 8 Aug 2024 10:17:34 -0700 Subject: [PATCH 21/63] Rename EventBus to event_bus --- CHANGELOG.md | 4 +-- .../drivers/event-listener-drivers.md | 24 +++++++-------- docs/griptape-framework/misc/events.md | 22 +++++++------- .../base_audio_transcription_driver.py | 6 ++-- .../base_image_generation_driver.py | 6 ++-- .../image_query/base_image_query_driver.py | 6 ++-- griptape/drivers/prompt/base_prompt_driver.py | 12 ++++---- .../base_text_to_speech_driver.py | 6 ++-- griptape/events/__init__.py | 4 +-- griptape/events/event_bus.py | 2 +- griptape/structures/structure.py | 6 ++-- griptape/tasks/actions_subtask.py | 6 ++-- griptape/tasks/base_task.py | 6 ++-- griptape/utils/stream.py | 6 ++-- tests/unit/conftest.py | 10 +++---- .../test_base_audio_transcription_driver.py | 4 +-- .../test_base_image_generation_driver.py | 10 +++---- .../test_base_image_query_driver.py | 4 +-- .../drivers/prompt/test_base_prompt_driver.py | 4 +-- .../test_base_audio_transcription_driver.py | 4 +-- tests/unit/events/test_event_bus.py | 30 +++++++++---------- tests/unit/events/test_event_listener.py | 30 +++++++++---------- tests/unit/tasks/test_base_task.py | 6 ++-- 23 files changed, 109 insertions(+), 109 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6748299d0..7a95701c2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,10 +11,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Ability to set custom schema properties on Tool Activities via `extra_schema_properties`. - Parameter `structure` to `BaseTask`. - Method `try_find_task` to `Structure`. -- Global event bus, `griptape.events.EventBus`, for publishing and subscribing to events. +- Global event bus, `griptape.events.event_bus`, for publishing and subscribing to events. ### Changed -- **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `EventBus`. +- **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `event_bus`. - **BREAKING**: Removed `EventPublisherMixin`. - `BaseTask.add_parent/child` will now call `self.structure.add_task` if possible. diff --git a/docs/griptape-framework/drivers/event-listener-drivers.md b/docs/griptape-framework/drivers/event-listener-drivers.md index db02cd77a..c3c92cfe1 100644 --- a/docs/griptape-framework/drivers/event-listener-drivers.md +++ b/docs/griptape-framework/drivers/event-listener-drivers.md @@ -14,12 +14,12 @@ import os from griptape.drivers import AmazonSqsEventListenerDriver from griptape.events import ( - EventListener, EventBus + EventListener, event_bus ) from griptape.rules import Rule from griptape.structures import Agent -EventBus.add_event_listeners( +event_bus.add_event_listeners( [ EventListener( driver=AmazonSqsEventListenerDriver( @@ -84,12 +84,12 @@ import os from griptape.drivers import AmazonSqsEventListenerDriver from griptape.events import ( - EventListener, EventBus + EventListener, event_bus ) from griptape.rules import Rule from griptape.structures import Agent -EventBus.add_event_listeners( +event_bus.add_event_listeners( [ EventListener( driver=AmazonSqsEventListenerDriver( @@ -132,12 +132,12 @@ from griptape.drivers import AwsIotCoreEventListenerDriver, OpenAiChatPromptDriv from griptape.events import ( EventListener, FinishStructureRunEvent, - EventBus + event_bus ) from griptape.rules import Rule from griptape.structures import Agent -EventBus.add_event_listeners( +event_bus.add_event_listeners( [ EventListener( event_types=[FinishStructureRunEvent], @@ -179,11 +179,11 @@ from griptape.drivers import GriptapeCloudEventListenerDriver from griptape.events import ( EventListener, FinishStructureRunEvent, - EventBus + event_bus ) from griptape.structures import Agent -EventBus.add_event_listeners( +event_bus.add_event_listeners( [ EventListener( event_types=[FinishStructureRunEvent], @@ -210,11 +210,11 @@ from griptape.drivers import WebhookEventListenerDriver from griptape.events import ( EventListener, FinishStructureRunEvent, - EventBus + event_bus ) from griptape.structures import Agent -EventBus.add_event_listeners( +event_bus.add_event_listeners( [ EventListener( event_types=[FinishStructureRunEvent], @@ -242,11 +242,11 @@ from griptape.drivers import PusherEventListenerDriver from griptape.events import ( EventListener, FinishStructureRunEvent, - EventBus + event_bus ) from griptape.structures import Agent -EventBus.add_event_listeners( +event_bus.add_event_listeners( [ EventListener( event_types=[FinishStructureRunEvent], diff --git a/docs/griptape-framework/misc/events.md b/docs/griptape-framework/misc/events.md index 23ebcdc2a..b3f4a77fd 100644 --- a/docs/griptape-framework/misc/events.md +++ b/docs/griptape-framework/misc/events.md @@ -5,7 +5,7 @@ search: ## Overview -You can configure the global [EventBus](../../reference/griptape/events/event_bus.md) with [EventListener](../../reference/griptape/events/event_listener.md)s to listen for various framework events. +You can configure the global [event_bus](../../reference/griptape/events/event_bus.md) with [EventListener](../../reference/griptape/events/event_listener.md)s to listen for various framework events. See [Event Listener Drivers](../drivers/event-listener-drivers.md) for examples on forwarding events to external services. ## Specific Event Types @@ -23,14 +23,14 @@ from griptape.events import ( StartPromptEvent, FinishPromptEvent, EventListener, - EventBus + event_bus ) def handler(event: BaseEvent): print(event.__class__) -EventBus.add_event_listeners([ +event_bus.add_event_listeners([ EventListener( handler, event_types=[ @@ -69,7 +69,7 @@ Or listen to all events: ```python from griptape.structures import Agent -from griptape.events import BaseEvent, EventListener, EventBus +from griptape.events import BaseEvent, EventListener, event_bus @@ -80,7 +80,7 @@ def handler1(event: BaseEvent): def handler2(event: BaseEvent): print("Handler 2", event.__class__) -EventBus.event_listeners=[ +event_bus.event_listeners=[ EventListener(handler1), EventListener(handler2), ] @@ -131,7 +131,7 @@ Handler 2 None: - EventBus.publish_event(StartAudioTranscriptionEvent()) + event_bus.publish_event(StartAudioTranscriptionEvent()) def after_run(self) -> None: - EventBus.publish_event(FinishAudioTranscriptionEvent()) + event_bus.publish_event(FinishAudioTranscriptionEvent()) def run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/image_generation/base_image_generation_driver.py b/griptape/drivers/image_generation/base_image_generation_driver.py index 8dfca5945..360fba8c9 100644 --- a/griptape/drivers/image_generation/base_image_generation_driver.py +++ b/griptape/drivers/image_generation/base_image_generation_driver.py @@ -5,7 +5,7 @@ from attrs import define, field -from griptape.events import EventBus, FinishImageGenerationEvent, StartImageGenerationEvent +from griptape.events import FinishImageGenerationEvent, StartImageGenerationEvent, event_bus from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: @@ -17,10 +17,10 @@ class BaseImageGenerationDriver(SerializableMixin, ExponentialBackoffMixin, ABC) model: str = field(kw_only=True, metadata={"serializable": True}) def before_run(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> None: - EventBus.publish_event(StartImageGenerationEvent(prompts=prompts, negative_prompts=negative_prompts)) + event_bus.publish_event(StartImageGenerationEvent(prompts=prompts, negative_prompts=negative_prompts)) def after_run(self) -> None: - EventBus.publish_event(FinishImageGenerationEvent()) + event_bus.publish_event(FinishImageGenerationEvent()) def run_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/image_query/base_image_query_driver.py b/griptape/drivers/image_query/base_image_query_driver.py index 28c571328..b1050b85c 100644 --- a/griptape/drivers/image_query/base_image_query_driver.py +++ b/griptape/drivers/image_query/base_image_query_driver.py @@ -5,7 +5,7 @@ from attrs import define, field -from griptape.events import EventBus, FinishImageQueryEvent, StartImageQueryEvent +from griptape.events import FinishImageQueryEvent, StartImageQueryEvent, event_bus from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: @@ -17,12 +17,12 @@ class BaseImageQueryDriver(SerializableMixin, ExponentialBackoffMixin, ABC): max_tokens: int = field(default=256, kw_only=True, metadata={"serializable": True}) def before_run(self, query: str, images: list[ImageArtifact]) -> None: - EventBus.publish_event( + event_bus.publish_event( StartImageQueryEvent(query=query, images_info=[image.to_text() for image in images]), ) def after_run(self, result: str) -> None: - EventBus.publish_event(FinishImageQueryEvent(result=result)) + event_bus.publish_event(FinishImageQueryEvent(result=result)) def query(self, query: str, images: list[ImageArtifact]) -> TextArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index 94e46e75d..8044469b5 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -16,7 +16,7 @@ TextMessageContent, observable, ) -from griptape.events import CompletionChunkEvent, EventBus, FinishPromptEvent, StartPromptEvent +from griptape.events import CompletionChunkEvent, FinishPromptEvent, StartPromptEvent, event_bus from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: @@ -49,10 +49,10 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True}) def before_run(self, prompt_stack: PromptStack) -> None: - EventBus.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack)) + event_bus.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack)) def after_run(self, result: Message) -> None: - EventBus.publish_event( + event_bus.publish_event( FinishPromptEvent( model=self.model, result=result.value, @@ -128,12 +128,12 @@ def __process_stream(self, prompt_stack: PromptStack) -> Message: else: delta_contents[content.index] = [content] if isinstance(content, TextDeltaMessageContent): - EventBus.publish_event(CompletionChunkEvent(token=content.text)) + event_bus.publish_event(CompletionChunkEvent(token=content.text)) elif isinstance(content, ActionCallDeltaMessageContent): if content.tag is not None and content.name is not None and content.path is not None: - EventBus.publish_event(CompletionChunkEvent(token=str(content))) + event_bus.publish_event(CompletionChunkEvent(token=str(content))) elif content.partial_input is not None: - EventBus.publish_event(CompletionChunkEvent(token=content.partial_input)) + event_bus.publish_event(CompletionChunkEvent(token=content.partial_input)) # Build a complete content from the content deltas result = self.__build_message(list(delta_contents.values()), usage) diff --git a/griptape/drivers/text_to_speech/base_text_to_speech_driver.py b/griptape/drivers/text_to_speech/base_text_to_speech_driver.py index cb11cc498..c74264dc1 100644 --- a/griptape/drivers/text_to_speech/base_text_to_speech_driver.py +++ b/griptape/drivers/text_to_speech/base_text_to_speech_driver.py @@ -5,7 +5,7 @@ from attrs import define, field -from griptape.events import EventBus +from griptape.events import event_bus from griptape.events.finish_text_to_speech_event import FinishTextToSpeechEvent from griptape.events.start_text_to_speech_event import StartTextToSpeechEvent from griptape.mixins import ExponentialBackoffMixin, SerializableMixin @@ -19,10 +19,10 @@ class BaseTextToSpeechDriver(SerializableMixin, ExponentialBackoffMixin, ABC): model: str = field(kw_only=True, metadata={"serializable": True}) def before_run(self, prompts: list[str]) -> None: - EventBus.publish_event(StartTextToSpeechEvent(prompts=prompts)) + event_bus.publish_event(StartTextToSpeechEvent(prompts=prompts)) def after_run(self) -> None: - EventBus.publish_event(FinishTextToSpeechEvent()) + event_bus.publish_event(FinishTextToSpeechEvent()) def run_text_to_audio(self, prompts: list[str]) -> AudioArtifact: for attempt in self.retrying(): diff --git a/griptape/events/__init__.py b/griptape/events/__init__.py index b3e2f3a79..431927663 100644 --- a/griptape/events/__init__.py +++ b/griptape/events/__init__.py @@ -22,7 +22,7 @@ from .base_audio_transcription_event import BaseAudioTranscriptionEvent from .start_audio_transcription_event import StartAudioTranscriptionEvent from .finish_audio_transcription_event import FinishAudioTranscriptionEvent -from .event_bus import EventBus +from .event_bus import event_bus __all__ = [ "BaseEvent", @@ -49,5 +49,5 @@ "BaseAudioTranscriptionEvent", "StartAudioTranscriptionEvent", "FinishAudioTranscriptionEvent", - "EventBus", + "event_bus", ] diff --git a/griptape/events/event_bus.py b/griptape/events/event_bus.py index 6ffd65550..a956f7deb 100644 --- a/griptape/events/event_bus.py +++ b/griptape/events/event_bus.py @@ -41,4 +41,4 @@ def clear_event_listeners(self) -> None: self._event_listeners.clear() -EventBus = _EventBus() +event_bus = _EventBus() diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index df7113c23..d68457ebc 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -28,7 +28,7 @@ VectorStoreRetrievalRagModule, ) from griptape.engines.rag.stages import ResponseRagStage, RetrievalRagStage -from griptape.events import EventBus, FinishStructureRunEvent, StartStructureRunEvent +from griptape.events import FinishStructureRunEvent, StartStructureRunEvent, event_bus from griptape.memory import TaskMemory from griptape.memory.meta import MetaMemory from griptape.memory.structure import ConversationMemory @@ -257,7 +257,7 @@ def before_run(self, args: Any) -> None: [task.reset() for task in self.tasks] - EventBus.publish_event( + event_bus.publish_event( StartStructureRunEvent( structure_id=self.id, input_task_input=self.input_task.input, @@ -269,7 +269,7 @@ def before_run(self, args: Any) -> None: @observable def after_run(self) -> None: - EventBus.publish_event( + event_bus.publish_event( FinishStructureRunEvent( structure_id=self.id, output_task_input=self.output_task.input, diff --git a/griptape/tasks/actions_subtask.py b/griptape/tasks/actions_subtask.py index 07f49f52a..d600c80a5 100644 --- a/griptape/tasks/actions_subtask.py +++ b/griptape/tasks/actions_subtask.py @@ -10,7 +10,7 @@ from griptape import utils from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact from griptape.common import ToolAction -from griptape.events import EventBus, FinishActionsSubtaskEvent, StartActionsSubtaskEvent +from griptape.events import FinishActionsSubtaskEvent, StartActionsSubtaskEvent, event_bus from griptape.mixins import ActionsSubtaskOriginMixin from griptape.tasks import BaseTask from griptape.utils import remove_null_values_in_dict_recursively @@ -91,7 +91,7 @@ def attach_to(self, parent_task: BaseTask) -> None: self.output = ErrorArtifact(f"ToolAction input parsing error: {e}", exception=e) def before_run(self) -> None: - EventBus.publish_event( + event_bus.publish_event( StartActionsSubtaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, @@ -157,7 +157,7 @@ def execute_action(self, action: ToolAction) -> tuple[str, BaseArtifact]: def after_run(self) -> None: response = self.output.to_text() if isinstance(self.output, BaseArtifact) else str(self.output) - EventBus.publish_event( + event_bus.publish_event( FinishActionsSubtaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, diff --git a/griptape/tasks/base_task.py b/griptape/tasks/base_task.py index 9a8361e6c..ade656f87 100644 --- a/griptape/tasks/base_task.py +++ b/griptape/tasks/base_task.py @@ -9,7 +9,7 @@ from attrs import Factory, define, field from griptape.artifacts import ErrorArtifact -from griptape.events import EventBus, FinishTaskEvent, StartTaskEvent +from griptape.events import FinishTaskEvent, StartTaskEvent, event_bus if TYPE_CHECKING: from griptape.artifacts import BaseArtifact @@ -127,7 +127,7 @@ def is_executing(self) -> bool: def before_run(self) -> None: if self.structure is not None: - EventBus.publish_event( + event_bus.publish_event( StartTaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, @@ -139,7 +139,7 @@ def before_run(self) -> None: def after_run(self) -> None: if self.structure is not None: - EventBus.publish_event( + event_bus.publish_event( FinishTaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, diff --git a/griptape/utils/stream.py b/griptape/utils/stream.py index 4a7899b2a..fd64a0f52 100644 --- a/griptape/utils/stream.py +++ b/griptape/utils/stream.py @@ -7,7 +7,7 @@ from attrs import Attribute, Factory, define, field from griptape.artifacts.text_artifact import TextArtifact -from griptape.events import CompletionChunkEvent, EventBus, EventListener, FinishPromptEvent, FinishStructureRunEvent +from griptape.events import CompletionChunkEvent, EventListener, FinishPromptEvent, FinishStructureRunEvent, event_bus if TYPE_CHECKING: from collections.abc import Iterator @@ -61,8 +61,8 @@ def event_handler(event: BaseEvent) -> None: handler=event_handler, event_types=[CompletionChunkEvent, FinishPromptEvent, FinishStructureRunEvent], ) - EventBus.add_event_listener(stream_event_listener) + event_bus.add_event_listener(stream_event_listener) self.structure.run(*args) - EventBus.remove_event_listener(stream_event_listener) + event_bus.remove_event_listener(stream_event_listener) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 7a73b041f..e462ede90 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,12 +1,12 @@ import pytest -from griptape.events import EventBus +from griptape.events import event_bus @pytest.fixture(autouse=True) -def event_bus(): - EventBus.clear_event_listeners() +def mock_event_bus(): + event_bus.clear_event_listeners() - yield EventBus + yield event_bus - EventBus.clear_event_listeners() + event_bus.clear_event_listeners() diff --git a/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py b/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py index fc41837fd..6fcab26e5 100644 --- a/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py +++ b/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py @@ -3,7 +3,7 @@ import pytest from griptape.artifacts import AudioArtifact -from griptape.events import EventBus, EventListener +from griptape.events import EventListener, event_bus from tests.mocks.mock_audio_transcription_driver import MockAudioTranscriptionDriver @@ -14,7 +14,7 @@ def driver(self): def test_run_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run( AudioArtifact( diff --git a/tests/unit/drivers/image_generation/test_base_image_generation_driver.py b/tests/unit/drivers/image_generation/test_base_image_generation_driver.py index 96b615a58..ab7b33ae8 100644 --- a/tests/unit/drivers/image_generation/test_base_image_generation_driver.py +++ b/tests/unit/drivers/image_generation/test_base_image_generation_driver.py @@ -3,7 +3,7 @@ import pytest from griptape.artifacts.image_artifact import ImageArtifact -from griptape.events import EventBus +from griptape.events import event_bus from griptape.events.event_listener import EventListener from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver @@ -15,7 +15,7 @@ def driver(self): def test_run_text_to_image_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run_text_to_image( ["foo", "bar"], @@ -31,7 +31,7 @@ def test_run_text_to_image_publish_events(self, driver): def test_run_image_variation_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_variation( ["foo", "bar"], @@ -53,7 +53,7 @@ def test_run_image_variation_publish_events(self, driver): def test_run_image_image_inpainting_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_inpainting( ["foo", "bar"], @@ -81,7 +81,7 @@ def test_run_image_image_inpainting_publish_events(self, driver): def test_run_image_image_outpainting_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_outpainting( ["foo", "bar"], diff --git a/tests/unit/drivers/image_query/test_base_image_query_driver.py b/tests/unit/drivers/image_query/test_base_image_query_driver.py index a77fb268e..d8ba6b60f 100644 --- a/tests/unit/drivers/image_query/test_base_image_query_driver.py +++ b/tests/unit/drivers/image_query/test_base_image_query_driver.py @@ -2,7 +2,7 @@ import pytest -from griptape.events import EventBus, EventListener +from griptape.events import EventListener, event_bus from tests.mocks.mock_image_query_driver import MockImageQueryDriver @@ -13,7 +13,7 @@ def driver(self): def test_query_publishes_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.query("foo", []) diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index 5b6b0c600..52b7d5c0d 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -1,7 +1,7 @@ from griptape.artifacts import ErrorArtifact, TextArtifact from griptape.common import Message, PromptStack from griptape.events import FinishPromptEvent, StartPromptEvent -from griptape.events.event_bus import _EventBus +from griptape.events.event_bus import _event_bus from griptape.structures import Pipeline from griptape.tasks import PromptTask, ToolkitTask from tests.mocks.mock_failing_prompt_driver import MockFailingPromptDriver @@ -27,7 +27,7 @@ def test_run_via_pipeline_retries_failure(self): assert isinstance(pipeline.run().output_task.output, ErrorArtifact) def test_run_via_pipeline_publishes_events(self, mocker): - mock_publish_event = mocker.patch.object(_EventBus, "publish_event") + mock_publish_event = mocker.patch.object(_event_bus, "publish_event") driver = MockPromptDriver() pipeline = Pipeline(prompt_driver=driver) pipeline.add_task(PromptTask("test")) diff --git a/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py b/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py index ab448c7c1..19493aa0f 100644 --- a/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py +++ b/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py @@ -2,7 +2,7 @@ import pytest -from griptape.events import EventBus, EventListener +from griptape.events import EventListener, event_bus from tests.mocks.mock_text_to_speech_driver import MockTextToSpeechDriver @@ -13,7 +13,7 @@ def driver(self): def test_text_to_audio_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run_text_to_audio( ["foo", "bar"], diff --git a/tests/unit/events/test_event_bus.py b/tests/unit/events/test_event_bus.py index d237bb3b4..7eb87036a 100644 --- a/tests/unit/events/test_event_bus.py +++ b/tests/unit/events/test_event_bus.py @@ -1,45 +1,45 @@ from unittest.mock import Mock -from griptape.events import EventBus, EventListener +from griptape.events import EventListener, event_bus from tests.mocks.mock_event import MockEvent class TestEventBus: def test_add_event_listeners(self): - EventBus.add_event_listeners([EventListener(), EventListener()]) - assert len(EventBus.event_listeners) == 2 + event_bus.add_event_listeners([EventListener(), EventListener()]) + assert len(event_bus.event_listeners) == 2 def test_remove_event_listeners(self): listeners = [EventListener(), EventListener()] - EventBus.add_event_listeners(listeners) - EventBus.remove_event_listeners(listeners) - assert len(EventBus.event_listeners) == 0 + event_bus.add_event_listeners(listeners) + event_bus.remove_event_listeners(listeners) + assert len(event_bus.event_listeners) == 0 def test_add_event_listener(self): - EventBus.add_event_listener(EventListener()) - EventBus.add_event_listener(EventListener()) + event_bus.add_event_listener(EventListener()) + event_bus.add_event_listener(EventListener()) - assert len(EventBus.event_listeners) == 2 + assert len(event_bus.event_listeners) == 2 def test_remove_event_listener(self): listener = EventListener() - EventBus.add_event_listener(listener) - EventBus.remove_event_listener(listener) + event_bus.add_event_listener(listener) + event_bus.remove_event_listener(listener) - assert len(EventBus.event_listeners) == 0 + assert len(event_bus.event_listeners) == 0 def test_remove_unknown_event_listener(self): - EventBus.remove_event_listener(EventListener()) + event_bus.remove_event_listener(EventListener()) def test_publish_event(self): # Given mock_handler = Mock() mock_handler.return_value = None - EventBus.add_event_listeners([EventListener(handler=mock_handler)]) + event_bus.add_event_listeners([EventListener(handler=mock_handler)]) mock_event = MockEvent() # When - EventBus.publish_event(mock_event) + event_bus.publish_event(mock_event) # Then mock_handler.assert_called_once_with(mock_event) diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index f3d9823d3..50763e0c3 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -4,7 +4,6 @@ from griptape.events import ( CompletionChunkEvent, - EventBus, EventListener, FinishActionsSubtaskEvent, FinishPromptEvent, @@ -14,6 +13,7 @@ StartPromptEvent, StartStructureRunEvent, StartTaskEvent, + event_bus, ) from griptape.events.base_event import BaseEvent from griptape.structures import Pipeline @@ -38,7 +38,7 @@ def test_untyped_listeners(self, pipeline): event_handler_1 = Mock() event_handler_2 = Mock() - EventBus.add_event_listeners([EventListener(handler=event_handler_1), EventListener(handler=event_handler_2)]) + event_bus.add_event_listeners([EventListener(handler=event_handler_1), EventListener(handler=event_handler_2)]) # can't mock subtask events, so must manually call pipeline.tasks[0].subtasks[0].before_run() @@ -59,7 +59,7 @@ def test_typed_listeners(self, pipeline): finish_structure_run_event_handler = Mock() completion_chunk_handler = Mock() - EventBus.add_event_listeners( + event_bus.add_event_listeners( [ EventListener(start_prompt_event_handler, event_types=[StartPromptEvent]), EventListener(finish_prompt_event_handler, event_types=[FinishPromptEvent]), @@ -89,25 +89,25 @@ def test_typed_listeners(self, pipeline): completion_chunk_handler.assert_called_once() def test_add_remove_event_listener(self, pipeline): - EventBus.clear_event_listeners() + event_bus.clear_event_listeners() mock1 = Mock() mock2 = Mock() # duplicate event listeners will only get added once - event_listener_1 = EventBus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) - EventBus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) + event_listener_1 = event_bus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) + event_bus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) - event_listener_3 = EventBus.add_event_listener(EventListener(mock1, event_types=[FinishPromptEvent])) - event_listener_4 = EventBus.add_event_listener(EventListener(mock2, event_types=[StartPromptEvent])) + event_listener_3 = event_bus.add_event_listener(EventListener(mock1, event_types=[FinishPromptEvent])) + event_listener_4 = event_bus.add_event_listener(EventListener(mock2, event_types=[StartPromptEvent])) - event_listener_5 = EventBus.add_event_listener(EventListener(mock2)) + event_listener_5 = event_bus.add_event_listener(EventListener(mock2)) - assert len(EventBus.event_listeners) == 4 + assert len(event_bus.event_listeners) == 4 - EventBus.remove_event_listener(event_listener_1) - EventBus.remove_event_listener(event_listener_3) - EventBus.remove_event_listener(event_listener_4) - EventBus.remove_event_listener(event_listener_5) - assert len(EventBus.event_listeners) == 0 + event_bus.remove_event_listener(event_listener_1) + event_bus.remove_event_listener(event_listener_3) + event_bus.remove_event_listener(event_listener_4) + event_bus.remove_event_listener(event_listener_5) + assert len(event_bus.event_listeners) == 0 def test_publish_event(self): mock_event_listener_driver = Mock() diff --git a/tests/unit/tasks/test_base_task.py b/tests/unit/tasks/test_base_task.py index d6e4da8b6..aa402bb48 100644 --- a/tests/unit/tasks/test_base_task.py +++ b/tests/unit/tasks/test_base_task.py @@ -3,7 +3,7 @@ import pytest from griptape.artifacts import TextArtifact -from griptape.events import EventBus +from griptape.events import event_bus from griptape.events.event_listener import EventListener from griptape.structures import Agent, Workflow from griptape.tasks import ActionsSubtask @@ -16,7 +16,7 @@ class TestBaseTask: @pytest.fixture() def task(self): - EventBus.add_event_listeners([EventListener(handler=Mock())]) + event_bus.add_event_listeners([EventListener(handler=Mock())]) agent = Agent( prompt_driver=MockPromptDriver(), embedding_driver=MockEmbeddingDriver(), @@ -118,4 +118,4 @@ def test_children_property_no_structure(self, task): def test_execute_publish_events(self, task): task.execute() - assert EventBus.event_listeners[0].handler.call_count == 2 + assert event_bus.event_listeners[0].handler.call_count == 2 From 729f3aad962e569321c0d263dc919b9972c6e2b2 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 8 Aug 2024 10:29:50 -0700 Subject: [PATCH 22/63] Fix doc --- docs/griptape-framework/misc/events.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/griptape-framework/misc/events.md b/docs/griptape-framework/misc/events.md index b3f4a77fd..ebab3c460 100644 --- a/docs/griptape-framework/misc/events.md +++ b/docs/griptape-framework/misc/events.md @@ -80,10 +80,11 @@ def handler1(event: BaseEvent): def handler2(event: BaseEvent): print("Handler 2", event.__class__) -event_bus.event_listeners=[ +event_bus.add_event_listeners([ EventListener(handler1), EventListener(handler2), ] +) agent = Agent() From c391b1459042f4ad9f621ced9df052a3e0d41f83 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 8 Aug 2024 11:05:52 -0700 Subject: [PATCH 23/63] Fix test --- tests/unit/drivers/prompt/test_base_prompt_driver.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index 52b7d5c0d..5b6b0c600 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -1,7 +1,7 @@ from griptape.artifacts import ErrorArtifact, TextArtifact from griptape.common import Message, PromptStack from griptape.events import FinishPromptEvent, StartPromptEvent -from griptape.events.event_bus import _event_bus +from griptape.events.event_bus import _EventBus from griptape.structures import Pipeline from griptape.tasks import PromptTask, ToolkitTask from tests.mocks.mock_failing_prompt_driver import MockFailingPromptDriver @@ -27,7 +27,7 @@ def test_run_via_pipeline_retries_failure(self): assert isinstance(pipeline.run().output_task.output, ErrorArtifact) def test_run_via_pipeline_publishes_events(self, mocker): - mock_publish_event = mocker.patch.object(_event_bus, "publish_event") + mock_publish_event = mocker.patch.object(_EventBus, "publish_event") driver = MockPromptDriver() pipeline = Pipeline(prompt_driver=driver) pipeline.add_task(PromptTask("test")) From 7baefacc6853fc6bc7acb405f615f57b74f4f4eb Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 8 Aug 2024 11:12:00 -0700 Subject: [PATCH 24/63] Rename event bus --- CHANGELOG.md | 4 +-- docs/griptape-framework/misc/events.md | 22 +++++++------- .../base_audio_transcription_driver.py | 6 ++-- .../base_image_generation_driver.py | 6 ++-- .../image_query/base_image_query_driver.py | 6 ++-- griptape/drivers/prompt/base_prompt_driver.py | 12 ++++---- .../base_text_to_speech_driver.py | 6 ++-- griptape/events/__init__.py | 4 +-- griptape/events/event_bus.py | 2 +- griptape/structures/structure.py | 6 ++-- griptape/tasks/actions_subtask.py | 6 ++-- griptape/tasks/base_task.py | 6 ++-- griptape/utils/stream.py | 6 ++-- tests/unit/conftest.py | 10 +++---- .../test_base_audio_transcription_driver.py | 4 +-- .../test_base_image_generation_driver.py | 10 +++---- .../test_base_image_query_driver.py | 4 +-- .../test_base_audio_transcription_driver.py | 4 +-- tests/unit/events/test_event_bus.py | 30 +++++++++---------- tests/unit/events/test_event_listener.py | 30 +++++++++---------- tests/unit/tasks/test_base_task.py | 8 ++--- 21 files changed, 96 insertions(+), 96 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ea88983f3..f338ff961 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,10 +11,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Ability to set custom schema properties on Tool Activities via `extra_schema_properties`. - Parameter `structure` to `BaseTask`. - Method `try_find_task` to `Structure`. -- Global event bus, `griptape.events.EventBus`, for publishing and subscribing to events. +- Global event bus, `griptape.events.event_bus`, for publishing and subscribing to events. ### Changed -- **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `EventBus`. +- **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `event_bus`. - **BREAKING**: Removed `EventPublisherMixin`. - `BaseTask.add_parent/child` will now call `self.structure.add_task` if possible. diff --git a/docs/griptape-framework/misc/events.md b/docs/griptape-framework/misc/events.md index b7f118d98..d37a73663 100644 --- a/docs/griptape-framework/misc/events.md +++ b/docs/griptape-framework/misc/events.md @@ -5,7 +5,7 @@ search: ## Overview -You can configure the global [EventBus](../../reference/griptape/events/event_bus.md) with [EventListener](../../reference/griptape/events/event_listener.md)s to listen for various framework events. +You can configure the global [event_bus](../../reference/griptape/events/event_bus.md) with [EventListener](../../reference/griptape/events/event_listener.md)s to listen for various framework events. See [Event Listener Drivers](../drivers/event-listener-drivers.md) for examples on forwarding events to external services. ## Specific Event Types @@ -23,14 +23,14 @@ from griptape.events import ( StartPromptEvent, FinishPromptEvent, EventListener, - EventBus + event_bus ) def handler(event: BaseEvent): print(event.__class__) -EventBus.event_listeners=[ +event_bus.event_listeners=[ EventListener( handler, event_types=[ @@ -69,7 +69,7 @@ Or listen to all events: ```python from griptape.structures import Agent -from griptape.events import BaseEvent, EventListener, EventBus +from griptape.events import BaseEvent, EventListener, event_bus @@ -80,7 +80,7 @@ def handler1(event: BaseEvent): def handler2(event: BaseEvent): print("Handler 2", event.__class__) -EventBus.event_listeners=[ +event_bus.event_listeners=[ EventListener(handler1), EventListener(handler2), ] @@ -131,7 +131,7 @@ Handler 2 None: - EventBus.publish_event(StartAudioTranscriptionEvent()) + event_bus.publish_event(StartAudioTranscriptionEvent()) def after_run(self) -> None: - EventBus.publish_event(FinishAudioTranscriptionEvent()) + event_bus.publish_event(FinishAudioTranscriptionEvent()) def run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/image_generation/base_image_generation_driver.py b/griptape/drivers/image_generation/base_image_generation_driver.py index 8dfca5945..360fba8c9 100644 --- a/griptape/drivers/image_generation/base_image_generation_driver.py +++ b/griptape/drivers/image_generation/base_image_generation_driver.py @@ -5,7 +5,7 @@ from attrs import define, field -from griptape.events import EventBus, FinishImageGenerationEvent, StartImageGenerationEvent +from griptape.events import FinishImageGenerationEvent, StartImageGenerationEvent, event_bus from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: @@ -17,10 +17,10 @@ class BaseImageGenerationDriver(SerializableMixin, ExponentialBackoffMixin, ABC) model: str = field(kw_only=True, metadata={"serializable": True}) def before_run(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> None: - EventBus.publish_event(StartImageGenerationEvent(prompts=prompts, negative_prompts=negative_prompts)) + event_bus.publish_event(StartImageGenerationEvent(prompts=prompts, negative_prompts=negative_prompts)) def after_run(self) -> None: - EventBus.publish_event(FinishImageGenerationEvent()) + event_bus.publish_event(FinishImageGenerationEvent()) def run_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/image_query/base_image_query_driver.py b/griptape/drivers/image_query/base_image_query_driver.py index 28c571328..b1050b85c 100644 --- a/griptape/drivers/image_query/base_image_query_driver.py +++ b/griptape/drivers/image_query/base_image_query_driver.py @@ -5,7 +5,7 @@ from attrs import define, field -from griptape.events import EventBus, FinishImageQueryEvent, StartImageQueryEvent +from griptape.events import FinishImageQueryEvent, StartImageQueryEvent, event_bus from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: @@ -17,12 +17,12 @@ class BaseImageQueryDriver(SerializableMixin, ExponentialBackoffMixin, ABC): max_tokens: int = field(default=256, kw_only=True, metadata={"serializable": True}) def before_run(self, query: str, images: list[ImageArtifact]) -> None: - EventBus.publish_event( + event_bus.publish_event( StartImageQueryEvent(query=query, images_info=[image.to_text() for image in images]), ) def after_run(self, result: str) -> None: - EventBus.publish_event(FinishImageQueryEvent(result=result)) + event_bus.publish_event(FinishImageQueryEvent(result=result)) def query(self, query: str, images: list[ImageArtifact]) -> TextArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index 94e46e75d..8044469b5 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -16,7 +16,7 @@ TextMessageContent, observable, ) -from griptape.events import CompletionChunkEvent, EventBus, FinishPromptEvent, StartPromptEvent +from griptape.events import CompletionChunkEvent, FinishPromptEvent, StartPromptEvent, event_bus from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: @@ -49,10 +49,10 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True}) def before_run(self, prompt_stack: PromptStack) -> None: - EventBus.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack)) + event_bus.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack)) def after_run(self, result: Message) -> None: - EventBus.publish_event( + event_bus.publish_event( FinishPromptEvent( model=self.model, result=result.value, @@ -128,12 +128,12 @@ def __process_stream(self, prompt_stack: PromptStack) -> Message: else: delta_contents[content.index] = [content] if isinstance(content, TextDeltaMessageContent): - EventBus.publish_event(CompletionChunkEvent(token=content.text)) + event_bus.publish_event(CompletionChunkEvent(token=content.text)) elif isinstance(content, ActionCallDeltaMessageContent): if content.tag is not None and content.name is not None and content.path is not None: - EventBus.publish_event(CompletionChunkEvent(token=str(content))) + event_bus.publish_event(CompletionChunkEvent(token=str(content))) elif content.partial_input is not None: - EventBus.publish_event(CompletionChunkEvent(token=content.partial_input)) + event_bus.publish_event(CompletionChunkEvent(token=content.partial_input)) # Build a complete content from the content deltas result = self.__build_message(list(delta_contents.values()), usage) diff --git a/griptape/drivers/text_to_speech/base_text_to_speech_driver.py b/griptape/drivers/text_to_speech/base_text_to_speech_driver.py index cb11cc498..c74264dc1 100644 --- a/griptape/drivers/text_to_speech/base_text_to_speech_driver.py +++ b/griptape/drivers/text_to_speech/base_text_to_speech_driver.py @@ -5,7 +5,7 @@ from attrs import define, field -from griptape.events import EventBus +from griptape.events import event_bus from griptape.events.finish_text_to_speech_event import FinishTextToSpeechEvent from griptape.events.start_text_to_speech_event import StartTextToSpeechEvent from griptape.mixins import ExponentialBackoffMixin, SerializableMixin @@ -19,10 +19,10 @@ class BaseTextToSpeechDriver(SerializableMixin, ExponentialBackoffMixin, ABC): model: str = field(kw_only=True, metadata={"serializable": True}) def before_run(self, prompts: list[str]) -> None: - EventBus.publish_event(StartTextToSpeechEvent(prompts=prompts)) + event_bus.publish_event(StartTextToSpeechEvent(prompts=prompts)) def after_run(self) -> None: - EventBus.publish_event(FinishTextToSpeechEvent()) + event_bus.publish_event(FinishTextToSpeechEvent()) def run_text_to_audio(self, prompts: list[str]) -> AudioArtifact: for attempt in self.retrying(): diff --git a/griptape/events/__init__.py b/griptape/events/__init__.py index b3e2f3a79..431927663 100644 --- a/griptape/events/__init__.py +++ b/griptape/events/__init__.py @@ -22,7 +22,7 @@ from .base_audio_transcription_event import BaseAudioTranscriptionEvent from .start_audio_transcription_event import StartAudioTranscriptionEvent from .finish_audio_transcription_event import FinishAudioTranscriptionEvent -from .event_bus import EventBus +from .event_bus import event_bus __all__ = [ "BaseEvent", @@ -49,5 +49,5 @@ "BaseAudioTranscriptionEvent", "StartAudioTranscriptionEvent", "FinishAudioTranscriptionEvent", - "EventBus", + "event_bus", ] diff --git a/griptape/events/event_bus.py b/griptape/events/event_bus.py index 9239e66bd..c0881503d 100644 --- a/griptape/events/event_bus.py +++ b/griptape/events/event_bus.py @@ -34,4 +34,4 @@ def publish_event(self, event: BaseEvent, *, flush: bool = False) -> None: event_listener.publish_event(event, flush=flush) -EventBus = _EventBus() +event_bus = _EventBus() diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index 49197592f..6fea4d2e6 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -18,7 +18,7 @@ VectorStoreRetrievalRagModule, ) from griptape.engines.rag.stages import ResponseRagStage, RetrievalRagStage -from griptape.events import EventBus, FinishStructureRunEvent, StartStructureRunEvent +from griptape.events import FinishStructureRunEvent, StartStructureRunEvent, event_bus from griptape.memory import TaskMemory from griptape.memory.meta import MetaMemory from griptape.memory.structure import ConversationMemory @@ -180,7 +180,7 @@ def before_run(self, args: Any) -> None: [task.reset() for task in self.tasks] - EventBus.publish_event( + event_bus.publish_event( StartStructureRunEvent( structure_id=self.id, input_task_input=self.input_task.input, @@ -192,7 +192,7 @@ def before_run(self, args: Any) -> None: @observable def after_run(self) -> None: - EventBus.publish_event( + event_bus.publish_event( FinishStructureRunEvent( structure_id=self.id, output_task_input=self.output_task.input, diff --git a/griptape/tasks/actions_subtask.py b/griptape/tasks/actions_subtask.py index 2f199e368..ccbf5dbb1 100644 --- a/griptape/tasks/actions_subtask.py +++ b/griptape/tasks/actions_subtask.py @@ -12,7 +12,7 @@ from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact from griptape.common import ToolAction from griptape.config import Config -from griptape.events import EventBus, FinishActionsSubtaskEvent, StartActionsSubtaskEvent +from griptape.events import FinishActionsSubtaskEvent, StartActionsSubtaskEvent, event_bus from griptape.mixins import ActionsSubtaskOriginMixin from griptape.tasks import BaseTask from griptape.utils import remove_null_values_in_dict_recursively @@ -95,7 +95,7 @@ def attach_to(self, parent_task: BaseTask) -> None: self.output = ErrorArtifact(f"ToolAction input parsing error: {e}", exception=e) def before_run(self) -> None: - EventBus.publish_event( + event_bus.publish_event( StartActionsSubtaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, @@ -161,7 +161,7 @@ def execute_action(self, action: ToolAction) -> tuple[str, BaseArtifact]: def after_run(self) -> None: response = self.output.to_text() if isinstance(self.output, BaseArtifact) else str(self.output) - EventBus.publish_event( + event_bus.publish_event( FinishActionsSubtaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, diff --git a/griptape/tasks/base_task.py b/griptape/tasks/base_task.py index c42f73629..2397fbfd0 100644 --- a/griptape/tasks/base_task.py +++ b/griptape/tasks/base_task.py @@ -11,7 +11,7 @@ from griptape.artifacts import ErrorArtifact from griptape.config import Config -from griptape.events import EventBus, FinishTaskEvent, StartTaskEvent +from griptape.events import FinishTaskEvent, StartTaskEvent, event_bus if TYPE_CHECKING: from griptape.artifacts import BaseArtifact @@ -131,7 +131,7 @@ def is_executing(self) -> bool: def before_run(self) -> None: if self.structure is not None: - EventBus.publish_event( + event_bus.publish_event( StartTaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, @@ -143,7 +143,7 @@ def before_run(self) -> None: def after_run(self) -> None: if self.structure is not None: - EventBus.publish_event( + event_bus.publish_event( FinishTaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, diff --git a/griptape/utils/stream.py b/griptape/utils/stream.py index 87cb9dec8..efca5c5b8 100644 --- a/griptape/utils/stream.py +++ b/griptape/utils/stream.py @@ -7,7 +7,7 @@ from attrs import Attribute, Factory, define, field from griptape.artifacts.text_artifact import TextArtifact -from griptape.events import CompletionChunkEvent, EventBus, EventListener, FinishPromptEvent, FinishStructureRunEvent +from griptape.events import CompletionChunkEvent, EventListener, FinishPromptEvent, FinishStructureRunEvent, event_bus if TYPE_CHECKING: from collections.abc import Iterator @@ -63,8 +63,8 @@ def event_handler(event: BaseEvent) -> None: handler=event_handler, event_types=[CompletionChunkEvent, FinishPromptEvent, FinishStructureRunEvent], ) - EventBus.add_event_listener(stream_event_listener) + event_bus.add_event_listener(stream_event_listener) self.structure.run(*args) - EventBus.remove_event_listener(stream_event_listener) + event_bus.remove_event_listener(stream_event_listener) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 9207bbc1c..01af02573 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,17 +1,17 @@ import pytest from griptape.config import Config -from griptape.events import EventBus +from griptape.events import event_bus from tests.mocks.mock_driver_config import MockDriverConfig @pytest.fixture(autouse=True) -def event_bus(): - EventBus.event_listeners = [] +def mock_event_bus(): + event_bus.event_listeners = [] - yield EventBus + yield event_bus - EventBus.event_listeners = [] + event_bus.event_listeners = [] @pytest.fixture(autouse=True) diff --git a/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py b/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py index 29aecfdf9..61ef3aa53 100644 --- a/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py +++ b/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py @@ -3,7 +3,7 @@ import pytest from griptape.artifacts import AudioArtifact -from griptape.events import EventBus, EventListener +from griptape.events import EventListener, event_bus from tests.mocks.mock_audio_transcription_driver import MockAudioTranscriptionDriver @@ -14,7 +14,7 @@ def driver(self): def test_run_publish_events(self, driver, mock_config): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run( AudioArtifact( diff --git a/tests/unit/drivers/image_generation/test_base_image_generation_driver.py b/tests/unit/drivers/image_generation/test_base_image_generation_driver.py index 96b615a58..ab7b33ae8 100644 --- a/tests/unit/drivers/image_generation/test_base_image_generation_driver.py +++ b/tests/unit/drivers/image_generation/test_base_image_generation_driver.py @@ -3,7 +3,7 @@ import pytest from griptape.artifacts.image_artifact import ImageArtifact -from griptape.events import EventBus +from griptape.events import event_bus from griptape.events.event_listener import EventListener from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver @@ -15,7 +15,7 @@ def driver(self): def test_run_text_to_image_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run_text_to_image( ["foo", "bar"], @@ -31,7 +31,7 @@ def test_run_text_to_image_publish_events(self, driver): def test_run_image_variation_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_variation( ["foo", "bar"], @@ -53,7 +53,7 @@ def test_run_image_variation_publish_events(self, driver): def test_run_image_image_inpainting_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_inpainting( ["foo", "bar"], @@ -81,7 +81,7 @@ def test_run_image_image_inpainting_publish_events(self, driver): def test_run_image_image_outpainting_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_outpainting( ["foo", "bar"], diff --git a/tests/unit/drivers/image_query/test_base_image_query_driver.py b/tests/unit/drivers/image_query/test_base_image_query_driver.py index a77fb268e..d8ba6b60f 100644 --- a/tests/unit/drivers/image_query/test_base_image_query_driver.py +++ b/tests/unit/drivers/image_query/test_base_image_query_driver.py @@ -2,7 +2,7 @@ import pytest -from griptape.events import EventBus, EventListener +from griptape.events import EventListener, event_bus from tests.mocks.mock_image_query_driver import MockImageQueryDriver @@ -13,7 +13,7 @@ def driver(self): def test_query_publishes_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.query("foo", []) diff --git a/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py b/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py index ab448c7c1..19493aa0f 100644 --- a/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py +++ b/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py @@ -2,7 +2,7 @@ import pytest -from griptape.events import EventBus, EventListener +from griptape.events import EventListener, event_bus from tests.mocks.mock_text_to_speech_driver import MockTextToSpeechDriver @@ -13,7 +13,7 @@ def driver(self): def test_text_to_audio_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run_text_to_audio( ["foo", "bar"], diff --git a/tests/unit/events/test_event_bus.py b/tests/unit/events/test_event_bus.py index fd862913e..97aaa239b 100644 --- a/tests/unit/events/test_event_bus.py +++ b/tests/unit/events/test_event_bus.py @@ -1,45 +1,45 @@ from unittest.mock import Mock -from griptape.events import EventBus, EventListener +from griptape.events import EventListener, event_bus from tests.mocks.mock_event import MockEvent class TestEventBus: def test_add_event_listeners(self): - EventBus.add_event_listeners([EventListener(), EventListener()]) - assert len(EventBus.event_listeners) == 2 + event_bus.add_event_listeners([EventListener(), EventListener()]) + assert len(event_bus.event_listeners) == 2 def test_remove_event_listeners(self): listeners = [EventListener(), EventListener()] - EventBus.add_event_listeners(listeners) - EventBus.remove_event_listeners(listeners) - assert len(EventBus.event_listeners) == 0 + event_bus.add_event_listeners(listeners) + event_bus.remove_event_listeners(listeners) + assert len(event_bus.event_listeners) == 0 def test_add_event_listener(self): - EventBus.add_event_listener(EventListener()) - EventBus.add_event_listener(EventListener()) + event_bus.add_event_listener(EventListener()) + event_bus.add_event_listener(EventListener()) - assert len(EventBus.event_listeners) == 2 + assert len(event_bus.event_listeners) == 2 def test_remove_event_listener(self): listener = EventListener() - EventBus.add_event_listener(listener) - EventBus.remove_event_listener(listener) + event_bus.add_event_listener(listener) + event_bus.remove_event_listener(listener) - assert len(EventBus.event_listeners) == 0 + assert len(event_bus.event_listeners) == 0 def test_remove_unknown_event_listener(self): - EventBus.remove_event_listener(EventListener()) + event_bus.remove_event_listener(EventListener()) def test_publish_event(self): # Given mock_handler = Mock() mock_handler.return_value = None - EventBus.event_listeners = [EventListener(handler=mock_handler)] + event_bus.event_listeners = [EventListener(handler=mock_handler)] mock_event = MockEvent() # When - EventBus.publish_event(mock_event) + event_bus.publish_event(mock_event) # Then mock_handler.assert_called_once_with(mock_event) diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index 038cb4508..713e5ce42 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -4,7 +4,6 @@ from griptape.events import ( CompletionChunkEvent, - EventBus, EventListener, FinishActionsSubtaskEvent, FinishPromptEvent, @@ -14,6 +13,7 @@ StartPromptEvent, StartStructureRunEvent, StartTaskEvent, + event_bus, ) from griptape.events.base_event import BaseEvent from griptape.structures import Pipeline @@ -39,7 +39,7 @@ def test_untyped_listeners(self, pipeline, mock_config): event_handler_1 = Mock() event_handler_2 = Mock() - EventBus.event_listeners = [EventListener(handler=event_handler_1), EventListener(handler=event_handler_2)] + event_bus.event_listeners = [EventListener(handler=event_handler_1), EventListener(handler=event_handler_2)] # can't mock subtask events, so must manually call pipeline.tasks[0].subtasks[0].before_run() @@ -60,7 +60,7 @@ def test_typed_listeners(self, pipeline, mock_config): finish_structure_run_event_handler = Mock() completion_chunk_handler = Mock() - EventBus.event_listeners = [ + event_bus.event_listeners = [ EventListener(start_prompt_event_handler, event_types=[StartPromptEvent]), EventListener(finish_prompt_event_handler, event_types=[FinishPromptEvent]), EventListener(start_task_event_handler, event_types=[StartTaskEvent]), @@ -88,25 +88,25 @@ def test_typed_listeners(self, pipeline, mock_config): completion_chunk_handler.assert_called_once() def test_add_remove_event_listener(self, pipeline): - EventBus.event_listeners = [] + event_bus.event_listeners = [] mock1 = Mock() mock2 = Mock() # duplicate event listeners will only get added once - event_listener_1 = EventBus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) - EventBus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) + event_listener_1 = event_bus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) + event_bus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) - event_listener_3 = EventBus.add_event_listener(EventListener(mock1, event_types=[FinishPromptEvent])) - event_listener_4 = EventBus.add_event_listener(EventListener(mock2, event_types=[StartPromptEvent])) + event_listener_3 = event_bus.add_event_listener(EventListener(mock1, event_types=[FinishPromptEvent])) + event_listener_4 = event_bus.add_event_listener(EventListener(mock2, event_types=[StartPromptEvent])) - event_listener_5 = EventBus.add_event_listener(EventListener(mock2)) + event_listener_5 = event_bus.add_event_listener(EventListener(mock2)) - assert len(EventBus.event_listeners) == 4 + assert len(event_bus.event_listeners) == 4 - EventBus.remove_event_listener(event_listener_1) - EventBus.remove_event_listener(event_listener_3) - EventBus.remove_event_listener(event_listener_4) - EventBus.remove_event_listener(event_listener_5) - assert len(EventBus.event_listeners) == 0 + event_bus.remove_event_listener(event_listener_1) + event_bus.remove_event_listener(event_listener_3) + event_bus.remove_event_listener(event_listener_4) + event_bus.remove_event_listener(event_listener_5) + assert len(event_bus.event_listeners) == 0 def test_publish_event(self): mock_event_listener_driver = Mock() diff --git a/tests/unit/tasks/test_base_task.py b/tests/unit/tasks/test_base_task.py index d4e0ce23d..4dfc890c9 100644 --- a/tests/unit/tasks/test_base_task.py +++ b/tests/unit/tasks/test_base_task.py @@ -3,7 +3,7 @@ import pytest from griptape.artifacts import TextArtifact -from griptape.events import EventBus +from griptape.events import event_bus from griptape.events.event_listener import EventListener from griptape.structures import Agent, Workflow from griptape.tasks import ActionsSubtask @@ -14,11 +14,11 @@ class TestBaseTask: @pytest.fixture() def task(self): - EventBus.event_listeners = [EventListener(handler=Mock())] + event_bus.event_listeners = [EventListener(handler=Mock())] agent = Agent( tools=[MockTool()], ) - EventBus.event_listeners = [EventListener(handler=Mock())] + event_bus.event_listeners = [EventListener(handler=Mock())] agent.add_task(MockTask("foobar", max_meta_memory_entries=2)) @@ -115,4 +115,4 @@ def test_children_property_no_structure(self, task): def test_execute_publish_events(self, task): task.execute() - assert EventBus.event_listeners[0].handler.call_count == 2 + assert event_bus.event_listeners[0].handler.call_count == 2 From feec94bc7bde4053db08bc19a0632cc89afc3393 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 8 Aug 2024 11:27:23 -0700 Subject: [PATCH 25/63] Rename Config to config, fix tests --- docs/examples/talk-to-a-video.md | 4 +- .../drivers/embedding-drivers.md | 4 +- .../drivers/event-listener-drivers.md | 4 +- docs/griptape-framework/structures/config.md | 44 +++++++++---------- .../structures/task-memory.md | 4 +- .../official-tools/rest-api-client.md | 4 +- griptape/config/__init__.py | 4 +- griptape/config/config.py | 2 +- .../audio/audio_transcription_engine.py | 4 +- .../engines/audio/text_to_speech_engine.py | 4 +- .../extraction/base_extraction_engine.py | 4 +- .../image/base_image_generation_engine.py | 4 +- .../engines/image_query/image_query_engine.py | 4 +- .../response/prompt_response_rag_module.py | 4 +- .../vector_store_retrieval_rag_module.py | 4 +- .../engines/summary/prompt_summary_engine.py | 4 +- .../structure/base_conversation_memory.py | 6 +-- .../structure/summary_conversation_memory.py | 4 +- .../task/storage/text_artifact_storage.py | 4 +- griptape/structures/agent.py | 4 +- griptape/structures/structure.py | 10 ++--- griptape/tasks/actions_subtask.py | 4 +- griptape/tasks/base_audio_generation_task.py | 4 +- griptape/tasks/base_audio_input_task.py | 4 +- griptape/tasks/base_image_generation_task.py | 4 +- griptape/tasks/base_multi_text_input_task.py | 4 +- griptape/tasks/base_task.py | 4 +- griptape/tasks/base_text_input_task.py | 4 +- griptape/tasks/prompt_task.py | 6 +-- griptape/utils/chat.py | 8 ++-- griptape/utils/stream.py | 4 +- tests/mocks/docker/fake_api.py | 8 ++-- tests/unit/conftest.py | 10 ++--- tests/unit/tasks/test_base_task.py | 2 +- tests/unit/utils/test_stream.py | 6 +-- tests/utils/structure_tester.py | 4 +- 36 files changed, 103 insertions(+), 103 deletions(-) diff --git a/docs/examples/talk-to-a-video.md b/docs/examples/talk-to-a-video.md index cf41dea0f..20ad9952e 100644 --- a/docs/examples/talk-to-a-video.md +++ b/docs/examples/talk-to-a-video.md @@ -7,10 +7,10 @@ import time from griptape.structures import Agent from griptape.tasks import PromptTask from griptape.artifacts import GenericArtifact, TextArtifact -from griptape.config import Config +from griptape.config import config import google.generativeai as genai -Config.drivers = GoogleDriverConfig() +config.drivers = GoogleDriverConfig() video_file = genai.upload_file(path="tests/resources/griptape-comfyui.mp4") while video_file.state.name == "PROCESSING": diff --git a/docs/griptape-framework/drivers/embedding-drivers.md b/docs/griptape-framework/drivers/embedding-drivers.md index 7a8fd96a1..f210f50b7 100644 --- a/docs/griptape-framework/drivers/embedding-drivers.md +++ b/docs/griptape-framework/drivers/embedding-drivers.md @@ -220,9 +220,9 @@ from griptape.drivers import ( OpenAiChatPromptDriver, VoyageAiEmbeddingDriver, ) -from griptape.config import DriverConfig, Config +from griptape.config import DriverConfig, config -Config.drivers = DriverConfig( +config.drivers = DriverConfig( prompt=OpenAiChatPromptDriver(model="gpt-4o"), embedding=VoyageAiEmbeddingDriver(), ) diff --git a/docs/griptape-framework/drivers/event-listener-drivers.md b/docs/griptape-framework/drivers/event-listener-drivers.md index 4f1eeb391..e4e815709 100644 --- a/docs/griptape-framework/drivers/event-listener-drivers.md +++ b/docs/griptape-framework/drivers/event-listener-drivers.md @@ -127,7 +127,7 @@ The [AwsIotCoreEventListenerDriver](../../reference/griptape/drivers/event_liste ```python import os -from griptape.config import DriverConfig, Config +from griptape.config import DriverConfig, config from griptape.drivers import AwsIotCoreEventListenerDriver, OpenAiChatPromptDriver from griptape.events import ( EventListener, @@ -137,7 +137,7 @@ from griptape.events import ( from griptape.rules import Rule from griptape.structures import Agent -Config.drivers = DriverConfig( +config.drivers = DriverConfig( prompt=OpenAiChatPromptDriver( model="gpt-3.5-turbo", temperature=0.7 ) diff --git a/docs/griptape-framework/structures/config.md b/docs/griptape-framework/structures/config.md index b4c928ff7..33fa27798 100644 --- a/docs/griptape-framework/structures/config.md +++ b/docs/griptape-framework/structures/config.md @@ -13,27 +13,27 @@ Griptape provides predefined [DriverConfig](../../reference/griptape/config/driv #### OpenAI -The [OpenAI Driver Config](../../reference/griptape/config/openai_driver_config.md) provides default Drivers for OpenAI's APIs. This is the default config for all Structures. +The [OpenAI Driver config](../../reference/griptape/config/openai_driver_config.md) provides default Drivers for OpenAI's APIs. This is the default config for all Structures. ```python from griptape.structures import Agent -from griptape.config import OpenAiDriverConfig, Config +from griptape.config import OpenAiDriverConfig, config -Config.drivers = OpenAiDriverConfig() +config.drivers = OpenAiDriverConfig() agent = Agent() ``` #### Azure OpenAI -The [Azure OpenAI Driver Config](../../reference/griptape/config/azure_openai_driver_config.md) provides default Drivers for Azure's OpenAI APIs. +The [Azure OpenAI Driver config](../../reference/griptape/config/azure_openai_driver_config.md) provides default Drivers for Azure's OpenAI APIs. ```python import os from griptape.structures import Agent -from griptape.config import AzureOpenAiDriverConfig, Config +from griptape.config import AzureOpenAiDriverConfig, config -Config.drivers = AzureOpenAiDriverConfig( +config.drivers = AzureOpenAiDriverConfig( azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_3"], api_key=os.environ["AZURE_OPENAI_API_KEY_3"] ) @@ -42,15 +42,15 @@ agent = Agent() ``` #### Amazon Bedrock -The [Amazon Bedrock Driver Config](../../reference/griptape/config/amazon_bedrock_driver_config.md) provides default Drivers for Amazon Bedrock's APIs. +The [Amazon Bedrock Driver config](../../reference/griptape/config/amazon_bedrock_driver_config.md) provides default Drivers for Amazon Bedrock's APIs. ```python import os import boto3 from griptape.structures import Agent -from griptape.config import AmazonBedrockDriverConfig, Config +from griptape.config import AmazonBedrockDriverConfig, config -Config.drivers = AmazonBedrockDriverConfig( +config.drivers = AmazonBedrockDriverConfig( session=boto3.Session( region_name=os.environ["AWS_DEFAULT_REGION"], aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"], @@ -62,20 +62,20 @@ agent = Agent() ``` #### Google -The [Google Driver Config](../../reference/griptape/config/google_driver_config.md) provides default Drivers for Google's Gemini APIs. +The [Google Driver config](../../reference/griptape/config/google_driver_config.md) provides default Drivers for Google's Gemini APIs. ```python from griptape.structures import Agent -from griptape.config import GoogleDriverConfig, Config +from griptape.config import GoogleDriverConfig, config -Config.drivers = GoogleDriverConfig() +config.drivers = GoogleDriverConfig() agent = Agent() ``` #### Anthropic -The [Anthropic Driver Config](../../reference/griptape/config/anthropic_driver_config.md) provides default Drivers for Anthropic's APIs. +The [Anthropic Driver config](../../reference/griptape/config/anthropic_driver_config.md) provides default Drivers for Anthropic's APIs. !!! info Anthropic does not provide an embeddings API which means you will need to use another service for embeddings. @@ -84,23 +84,23 @@ The [Anthropic Driver Config](../../reference/griptape/config/anthropic_driver_c ```python from griptape.structures import Agent -from griptape.config import AnthropicDriverConfig, Config +from griptape.config import AnthropicDriverConfig, config -Config.drivers = AnthropicDriverConfig() +config.drivers = AnthropicDriverConfig() agent = Agent() ``` #### Cohere -The [Cohere Driver Config](../../reference/griptape/config/cohere_driver_config.md) provides default Drivers for Cohere's APIs. +The [Cohere Driver config](../../reference/griptape/config/cohere_driver_config.md) provides default Drivers for Cohere's APIs. ```python import os -from griptape.config import CohereDriverConfig, Config +from griptape.config import CohereDriverConfig, config from griptape.structures import Agent -Config.drivers = CohereDriverConfig(api_key=os.environ["COHERE_API_KEY"]) +config.drivers = CohereDriverConfig(api_key=os.environ["COHERE_API_KEY"]) agent = Agent() ``` @@ -114,10 +114,10 @@ This approach ensures that you are informed through clear error messages if you ```python import os from griptape.structures import Agent -from griptape.config import DriverConfig, Config +from griptape.config import DriverConfig, config from griptape.drivers import AnthropicPromptDriver -Config.drivers = DriverConfig( +config.drivers = DriverConfig( prompt=AnthropicPromptDriver( model="claude-3-sonnet-20240229", api_key=os.environ["ANTHROPIC_API_KEY"], @@ -132,7 +132,7 @@ agent = Agent() ```python from griptape.structures import Agent -from griptape.config import AmazonBedrockDriverConfig, Config +from griptape.config import AmazonBedrockDriverConfig, config custom_config = AmazonBedrockDriverConfig() dict_config = custom_config.to_dict() @@ -145,7 +145,7 @@ dict_config["embedding"] = { } custom_config = AmazonBedrockDriverConfig.from_dict(dict_config) -Config.drivers = custom_config +config.drivers = custom_config agent = Agent() ``` diff --git a/docs/griptape-framework/structures/task-memory.md b/docs/griptape-framework/structures/task-memory.md index 49d6b28cf..1fad33856 100644 --- a/docs/griptape-framework/structures/task-memory.md +++ b/docs/griptape-framework/structures/task-memory.md @@ -206,7 +206,7 @@ In this example, GPT-4 _never_ sees the contents of the page, only that it was s ```python from griptape.artifacts import TextArtifact from griptape.config import ( - Config, OpenAiDriverConfig, + config, OpenAiDriverConfig, ) from griptape.drivers import ( LocalVectorStoreDriver, @@ -220,7 +220,7 @@ from griptape.memory.task.storage import TextArtifactStorage from griptape.structures import Agent from griptape.tools import FileManager, TaskMemoryClient, WebScraper -Config.drivers = OpenAiDriverConfig( +config.drivers = OpenAiDriverConfig( prompt_driver=OpenAiChatPromptDriver(model="gpt-4"), ) diff --git a/docs/griptape-tools/official-tools/rest-api-client.md b/docs/griptape-tools/official-tools/rest-api-client.md index a73f6fa57..0151c2efd 100644 --- a/docs/griptape-tools/official-tools/rest-api-client.md +++ b/docs/griptape-tools/official-tools/rest-api-client.md @@ -14,9 +14,9 @@ from griptape.memory.structure import ConversationMemory from griptape.structures import Pipeline from griptape.tasks import ToolkitTask from griptape.tools import RestApiClient -from griptape.config import Config +from griptape.config import config -Config.drivers = DriverConfig( +config.drivers = DriverConfig( prompt=OpenAiChatPromptDriver( model="gpt-4o", temperature=0.1 diff --git a/griptape/config/__init__.py b/griptape/config/__init__.py index 7450d7738..b242d80a7 100644 --- a/griptape/config/__init__.py +++ b/griptape/config/__init__.py @@ -9,7 +9,7 @@ from .anthropic_driver_config import AnthropicDriverConfig from .google_driver_config import GoogleDriverConfig from .cohere_driver_config import CohereDriverConfig -from .config import Config +from .config import config __all__ = [ @@ -22,5 +22,5 @@ "AnthropicDriverConfig", "GoogleDriverConfig", "CohereDriverConfig", - "Config", + "config", ] diff --git a/griptape/config/config.py b/griptape/config/config.py index d81a8974b..97d501abb 100644 --- a/griptape/config/config.py +++ b/griptape/config/config.py @@ -12,4 +12,4 @@ class _Config(BaseConfig): logging: LoggingConfig = field(default=Factory(lambda: LoggingConfig()), kw_only=True) -Config = _Config() +config = _Config() diff --git a/griptape/engines/audio/audio_transcription_engine.py b/griptape/engines/audio/audio_transcription_engine.py index 51022e47c..cad8287d5 100644 --- a/griptape/engines/audio/audio_transcription_engine.py +++ b/griptape/engines/audio/audio_transcription_engine.py @@ -1,14 +1,14 @@ from attrs import Factory, define, field from griptape.artifacts import AudioArtifact, TextArtifact -from griptape.config import Config +from griptape.config import config from griptape.drivers import BaseAudioTranscriptionDriver @define class AudioTranscriptionEngine: audio_transcription_driver: BaseAudioTranscriptionDriver = field( - default=Factory(lambda: Config.drivers.audio_transcription), kw_only=True + default=Factory(lambda: config.drivers.audio_transcription), kw_only=True ) def run(self, audio: AudioArtifact, *args, **kwargs) -> TextArtifact: diff --git a/griptape/engines/audio/text_to_speech_engine.py b/griptape/engines/audio/text_to_speech_engine.py index a163c36fd..aad45a10a 100644 --- a/griptape/engines/audio/text_to_speech_engine.py +++ b/griptape/engines/audio/text_to_speech_engine.py @@ -4,7 +4,7 @@ from attrs import Factory, define, field -from griptape.config import Config +from griptape.config import config if TYPE_CHECKING: from griptape.artifacts.audio_artifact import AudioArtifact @@ -14,7 +14,7 @@ @define class TextToSpeechEngine: text_to_speech_driver: BaseTextToSpeechDriver = field( - default=Factory(lambda: Config.drivers.text_to_speech), kw_only=True + default=Factory(lambda: config.drivers.text_to_speech), kw_only=True ) def run(self, prompts: list[str], *args, **kwargs) -> AudioArtifact: diff --git a/griptape/engines/extraction/base_extraction_engine.py b/griptape/engines/extraction/base_extraction_engine.py index a1bcbdee2..4b1184e5e 100644 --- a/griptape/engines/extraction/base_extraction_engine.py +++ b/griptape/engines/extraction/base_extraction_engine.py @@ -6,7 +6,7 @@ from attrs import Attribute, Factory, define, field from griptape.chunkers import BaseChunker, TextChunker -from griptape.config import Config +from griptape.config import config if TYPE_CHECKING: from griptape.artifacts import ErrorArtifact, ListArtifact @@ -18,7 +18,7 @@ class BaseExtractionEngine(ABC): max_token_multiplier: float = field(default=0.5, kw_only=True) chunk_joiner: str = field(default="\n\n", kw_only=True) - prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.drivers.prompt), kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.drivers.prompt), kw_only=True) chunker: BaseChunker = field( default=Factory( lambda self: TextChunker(tokenizer=self.prompt_driver.tokenizer, max_tokens=self.max_chunker_tokens), diff --git a/griptape/engines/image/base_image_generation_engine.py b/griptape/engines/image/base_image_generation_engine.py index 921d600c7..9bec68b91 100644 --- a/griptape/engines/image/base_image_generation_engine.py +++ b/griptape/engines/image/base_image_generation_engine.py @@ -5,7 +5,7 @@ from attrs import Factory, define, field -from griptape.config import Config +from griptape.config import config if TYPE_CHECKING: from griptape.artifacts import ImageArtifact @@ -16,7 +16,7 @@ @define class BaseImageGenerationEngine(ABC): image_generation_driver: BaseImageGenerationDriver = field( - kw_only=True, default=Factory(lambda: Config.drivers.image_generation) + kw_only=True, default=Factory(lambda: config.drivers.image_generation) ) @abstractmethod diff --git a/griptape/engines/image_query/image_query_engine.py b/griptape/engines/image_query/image_query_engine.py index d85e6012d..f2bd99544 100644 --- a/griptape/engines/image_query/image_query_engine.py +++ b/griptape/engines/image_query/image_query_engine.py @@ -4,7 +4,7 @@ from attrs import Factory, define, field -from griptape.config import Config +from griptape.config import config if TYPE_CHECKING: from griptape.artifacts import ImageArtifact, TextArtifact @@ -13,7 +13,7 @@ @define class ImageQueryEngine: - image_query_driver: BaseImageQueryDriver = field(default=Factory(lambda: Config.drivers.image_query), kw_only=True) + image_query_driver: BaseImageQueryDriver = field(default=Factory(lambda: config.drivers.image_query), kw_only=True) def run(self, query: str, images: list[ImageArtifact]) -> TextArtifact: return self.image_query_driver.query(query, images) diff --git a/griptape/engines/rag/modules/response/prompt_response_rag_module.py b/griptape/engines/rag/modules/response/prompt_response_rag_module.py index 8e421d792..9804404fc 100644 --- a/griptape/engines/rag/modules/response/prompt_response_rag_module.py +++ b/griptape/engines/rag/modules/response/prompt_response_rag_module.py @@ -5,7 +5,7 @@ from attrs import Factory, define, field from griptape.artifacts.text_artifact import TextArtifact -from griptape.config import Config +from griptape.config import config from griptape.engines.rag.modules import BaseResponseRagModule from griptape.utils import J2 @@ -17,7 +17,7 @@ @define(kw_only=True) class PromptResponseRagModule(BaseResponseRagModule): answer_token_offset: int = field(default=400) - prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.drivers.prompt), kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.drivers.prompt), kw_only=True) generate_system_template: Callable[[RagContext, list[TextArtifact]], str] = field( default=Factory(lambda self: self.default_system_template_generator, takes_self=True), ) diff --git a/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py b/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py index 4daa10e54..6ce235fa5 100644 --- a/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py +++ b/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py @@ -5,7 +5,7 @@ from attrs import Factory, define, field from griptape import utils -from griptape.config import Config +from griptape.config import config from griptape.engines.rag.modules import BaseRetrievalRagModule if TYPE_CHECKING: @@ -18,7 +18,7 @@ @define(kw_only=True) class VectorStoreRetrievalRagModule(BaseRetrievalRagModule): - vector_store_driver: BaseVectorStoreDriver = field(default=Factory(lambda: Config.drivers.vector_store)) + vector_store_driver: BaseVectorStoreDriver = field(default=Factory(lambda: config.drivers.vector_store)) query_params: dict[str, Any] = field(factory=dict) process_query_output_fn: Callable[[list[BaseVectorStoreDriver.Entry]], Sequence[TextArtifact]] = field( default=Factory(lambda: lambda es: [e.to_artifact() for e in es]), diff --git a/griptape/engines/summary/prompt_summary_engine.py b/griptape/engines/summary/prompt_summary_engine.py index 1c45fa5ea..82c33a0ad 100644 --- a/griptape/engines/summary/prompt_summary_engine.py +++ b/griptape/engines/summary/prompt_summary_engine.py @@ -7,7 +7,7 @@ from griptape.artifacts import ListArtifact, TextArtifact from griptape.chunkers import BaseChunker, TextChunker from griptape.common import Message, PromptStack -from griptape.config import Config +from griptape.config import config from griptape.engines import BaseSummaryEngine from griptape.utils import J2 @@ -22,7 +22,7 @@ class PromptSummaryEngine(BaseSummaryEngine): max_token_multiplier: float = field(default=0.5, kw_only=True) system_template_generator: J2 = field(default=Factory(lambda: J2("engines/summary/system.j2")), kw_only=True) user_template_generator: J2 = field(default=Factory(lambda: J2("engines/summary/user.j2")), kw_only=True) - prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.drivers.prompt), kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.drivers.prompt), kw_only=True) chunker: BaseChunker = field( default=Factory( lambda self: TextChunker(tokenizer=self.prompt_driver.tokenizer, max_tokens=self.max_chunker_tokens), diff --git a/griptape/memory/structure/base_conversation_memory.py b/griptape/memory/structure/base_conversation_memory.py index e7c8ed488..d6e3549af 100644 --- a/griptape/memory/structure/base_conversation_memory.py +++ b/griptape/memory/structure/base_conversation_memory.py @@ -6,7 +6,7 @@ from attrs import Factory, define, field from griptape.common import PromptStack -from griptape.config import Config +from griptape.config import config from griptape.mixins import SerializableMixin if TYPE_CHECKING: @@ -18,7 +18,7 @@ @define class BaseConversationMemory(SerializableMixin, ABC): driver: Optional[BaseConversationMemoryDriver] = field( - default=Factory(lambda: Config.drivers.conversation_memory), kw_only=True + default=Factory(lambda: config.drivers.conversation_memory), kw_only=True ) runs: list[Run] = field(factory=list, kw_only=True, metadata={"serializable": True}) structure: Structure = field(init=False) @@ -67,7 +67,7 @@ def add_to_prompt_stack(self, prompt_stack: PromptStack, index: Optional[int] = if self.autoprune and hasattr(self, "structure"): should_prune = True - prompt_driver = Config.drivers.prompt + prompt_driver = config.drivers.prompt temp_stack = PromptStack() # Try to determine how many Conversation Memory runs we can diff --git a/griptape/memory/structure/summary_conversation_memory.py b/griptape/memory/structure/summary_conversation_memory.py index 50be69a61..4263e61c8 100644 --- a/griptape/memory/structure/summary_conversation_memory.py +++ b/griptape/memory/structure/summary_conversation_memory.py @@ -6,7 +6,7 @@ from attrs import Factory, define, field from griptape.common import Message, PromptStack -from griptape.config import Config +from griptape.config import config from griptape.memory.structure import ConversationMemory from griptape.utils import J2 @@ -18,7 +18,7 @@ @define class SummaryConversationMemory(ConversationMemory): offset: int = field(default=1, kw_only=True, metadata={"serializable": True}) - prompt_driver: BasePromptDriver = field(kw_only=True, default=Factory(lambda: Config.drivers.prompt)) + prompt_driver: BasePromptDriver = field(kw_only=True, default=Factory(lambda: config.drivers.prompt)) summary: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) summary_index: int = field(default=0, kw_only=True, metadata={"serializable": True}) summary_template_generator: J2 = field(default=Factory(lambda: J2("memory/conversation/summary.j2")), kw_only=True) diff --git a/griptape/memory/task/storage/text_artifact_storage.py b/griptape/memory/task/storage/text_artifact_storage.py index 460581997..ded114213 100644 --- a/griptape/memory/task/storage/text_artifact_storage.py +++ b/griptape/memory/task/storage/text_artifact_storage.py @@ -5,7 +5,7 @@ from attrs import Attribute, Factory, define, field from griptape.artifacts import BaseArtifact, InfoArtifact, ListArtifact, TextArtifact -from griptape.config import Config +from griptape.config import config from griptape.engines.rag import RagContext, RagEngine from griptape.memory.task.storage import BaseArtifactStorage @@ -16,7 +16,7 @@ @define(kw_only=True) class TextArtifactStorage(BaseArtifactStorage): - vector_store_driver: BaseVectorStoreDriver = field(default=Factory(lambda: Config.drivers.vector_store)) + vector_store_driver: BaseVectorStoreDriver = field(default=Factory(lambda: config.drivers.vector_store)) rag_engine: Optional[RagEngine] = field(default=None) retrieval_rag_module_name: Optional[str] = field(default=None) summary_engine: Optional[BaseSummaryEngine] = field(default=None) diff --git a/griptape/structures/agent.py b/griptape/structures/agent.py index f31e9d2eb..59a865897 100644 --- a/griptape/structures/agent.py +++ b/griptape/structures/agent.py @@ -6,7 +6,7 @@ from griptape.artifacts.text_artifact import TextArtifact from griptape.common import observable -from griptape.config import Config +from griptape.config import config from griptape.memory.structure import Run from griptape.structures import Structure from griptape.tasks import PromptTask, ToolkitTask @@ -24,7 +24,7 @@ class Agent(Structure): default=lambda task: task.full_context["args"][0] if task.full_context["args"] else TextArtifact(value=""), ) stream: bool = field(default=False, kw_only=True) - prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.drivers.prompt), kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.drivers.prompt), kw_only=True) tools: list[BaseTool] = field(factory=list, kw_only=True) max_meta_memory_entries: Optional[int] = field(default=20, kw_only=True) fail_fast: bool = field(default=False, kw_only=True) diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index 6fea4d2e6..b7ca84c4f 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -8,7 +8,7 @@ from griptape.artifacts import BaseArtifact, BlobArtifact, TextArtifact from griptape.common import observable -from griptape.config import Config +from griptape.config import config from griptape.engines import CsvExtractionEngine, JsonExtractionEngine, PromptSummaryEngine from griptape.engines.rag import RagEngine from griptape.engines.rag.modules import ( @@ -118,10 +118,10 @@ def default_task_memory(self) -> TaskMemory: TextArtifact: TextArtifactStorage( rag_engine=self.rag_engine, retrieval_rag_module_name="VectorStoreRetrievalRagModule", - vector_store_driver=Config.drivers.vector_store, - summary_engine=PromptSummaryEngine(prompt_driver=Config.drivers.prompt), - csv_extraction_engine=CsvExtractionEngine(prompt_driver=Config.drivers.prompt), - json_extraction_engine=JsonExtractionEngine(prompt_driver=Config.drivers.prompt), + vector_store_driver=config.drivers.vector_store, + summary_engine=PromptSummaryEngine(prompt_driver=config.drivers.prompt), + csv_extraction_engine=CsvExtractionEngine(prompt_driver=config.drivers.prompt), + json_extraction_engine=JsonExtractionEngine(prompt_driver=config.drivers.prompt), ), BlobArtifact: BlobArtifactStorage(), }, diff --git a/griptape/tasks/actions_subtask.py b/griptape/tasks/actions_subtask.py index ccbf5dbb1..0f885d260 100644 --- a/griptape/tasks/actions_subtask.py +++ b/griptape/tasks/actions_subtask.py @@ -11,7 +11,7 @@ from griptape import utils from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact from griptape.common import ToolAction -from griptape.config import Config +from griptape.config import config from griptape.events import FinishActionsSubtaskEvent, StartActionsSubtaskEvent, event_bus from griptape.mixins import ActionsSubtaskOriginMixin from griptape.tasks import BaseTask @@ -20,7 +20,7 @@ if TYPE_CHECKING: from griptape.memory import TaskMemory -logger = logging.getLogger(Config.logging.logger_name) +logger = logging.getLogger(config.logging.logger_name) @define diff --git a/griptape/tasks/base_audio_generation_task.py b/griptape/tasks/base_audio_generation_task.py index 4d9d82362..00774e0a2 100644 --- a/griptape/tasks/base_audio_generation_task.py +++ b/griptape/tasks/base_audio_generation_task.py @@ -5,11 +5,11 @@ from attrs import define -from griptape.config import Config +from griptape.config import config from griptape.mixins import BlobArtifactFileOutputMixin, RuleMixin from griptape.tasks import BaseTask -logger = logging.getLogger(Config.logging.logger_name) +logger = logging.getLogger(config.logging.logger_name) @define diff --git a/griptape/tasks/base_audio_input_task.py b/griptape/tasks/base_audio_input_task.py index febd3f508..8a470bb85 100644 --- a/griptape/tasks/base_audio_input_task.py +++ b/griptape/tasks/base_audio_input_task.py @@ -7,11 +7,11 @@ from attrs import define, field from griptape.artifacts.audio_artifact import AudioArtifact -from griptape.config.config import Config +from griptape.config.config import config from griptape.mixins import RuleMixin from griptape.tasks import BaseTask -logger = logging.getLogger(Config.logging.logger_name) +logger = logging.getLogger(config.logging.logger_name) @define diff --git a/griptape/tasks/base_image_generation_task.py b/griptape/tasks/base_image_generation_task.py index afbc2c05e..f94ff8de2 100644 --- a/griptape/tasks/base_image_generation_task.py +++ b/griptape/tasks/base_image_generation_task.py @@ -8,7 +8,7 @@ from attrs import Attribute, define, field -from griptape.config import Config +from griptape.config import config from griptape.loaders import ImageLoader from griptape.mixins import BlobArtifactFileOutputMixin, RuleMixin from griptape.rules import Rule, Ruleset @@ -18,7 +18,7 @@ from griptape.artifacts import MediaArtifact -logger = logging.getLogger(Config.logging.logger_name) +logger = logging.getLogger(config.logging.logger_name) @define diff --git a/griptape/tasks/base_multi_text_input_task.py b/griptape/tasks/base_multi_text_input_task.py index 6962098ca..c688a1129 100644 --- a/griptape/tasks/base_multi_text_input_task.py +++ b/griptape/tasks/base_multi_text_input_task.py @@ -7,12 +7,12 @@ from attrs import Factory, define, field from griptape.artifacts import ListArtifact, TextArtifact -from griptape.config import Config +from griptape.config import config from griptape.mixins.rule_mixin import RuleMixin from griptape.tasks import BaseTask from griptape.utils import J2 -logger = logging.getLogger(Config.logging.logger_name) +logger = logging.getLogger(config.logging.logger_name) @define diff --git a/griptape/tasks/base_task.py b/griptape/tasks/base_task.py index 2397fbfd0..b3086bebb 100644 --- a/griptape/tasks/base_task.py +++ b/griptape/tasks/base_task.py @@ -10,7 +10,7 @@ from attrs import Factory, define, field from griptape.artifacts import ErrorArtifact -from griptape.config import Config +from griptape.config import config from griptape.events import FinishTaskEvent, StartTaskEvent, event_bus if TYPE_CHECKING: @@ -18,7 +18,7 @@ from griptape.memory.meta import BaseMetaEntry from griptape.structures import Structure -logger = logging.getLogger(Config.logging.logger_name) +logger = logging.getLogger(config.logging.logger_name) @define diff --git a/griptape/tasks/base_text_input_task.py b/griptape/tasks/base_text_input_task.py index 16f8c705c..1c9dfc023 100644 --- a/griptape/tasks/base_text_input_task.py +++ b/griptape/tasks/base_text_input_task.py @@ -7,12 +7,12 @@ from attrs import define, field from griptape.artifacts import TextArtifact -from griptape.config import Config +from griptape.config import config from griptape.mixins.rule_mixin import RuleMixin from griptape.tasks import BaseTask from griptape.utils import J2 -logger = logging.getLogger(Config.logging.logger_name) +logger = logging.getLogger(config.logging.logger_name) @define diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 3769f26dc..a8038832d 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -7,7 +7,7 @@ from griptape.artifacts import BaseArtifact, ListArtifact, TextArtifact from griptape.common import PromptStack -from griptape.config import Config +from griptape.config import config from griptape.mixins import RuleMixin from griptape.tasks import BaseTask from griptape.utils import J2 @@ -15,12 +15,12 @@ if TYPE_CHECKING: from griptape.drivers import BasePromptDriver -logger = logging.getLogger(Config.logging.logger_name) +logger = logging.getLogger(config.logging.logger_name) @define class PromptTask(RuleMixin, BaseTask): - prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.drivers.prompt), kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.drivers.prompt), kw_only=True) generate_system_template: Callable[[PromptTask], str] = field( default=Factory(lambda self: self.default_system_template_generator, takes_self=True), kw_only=True, diff --git a/griptape/utils/chat.py b/griptape/utils/chat.py index 99b5a7dc3..56b53c0ce 100644 --- a/griptape/utils/chat.py +++ b/griptape/utils/chat.py @@ -25,15 +25,15 @@ class Chat: ) def default_output_fn(self, text: str) -> None: - from griptape.config import Config + from griptape.config import config - if Config.drivers.prompt.stream: + if config.drivers.prompt.stream: print(text, end="", flush=True) # noqa: T201 else: print(text) # noqa: T201 def start(self) -> None: - from griptape.config import Config + from griptape.config import config if self.intro_text: self.output_fn(self.intro_text) @@ -44,7 +44,7 @@ def start(self) -> None: self.output_fn(self.exiting_text) break - if Config.drivers.prompt.stream: + if config.drivers.prompt.stream: self.output_fn(self.processing_text + "\n") stream = Stream(self.structure).run(question) first_chunk = next(stream) diff --git a/griptape/utils/stream.py b/griptape/utils/stream.py index efca5c5b8..c5545bc44 100644 --- a/griptape/utils/stream.py +++ b/griptape/utils/stream.py @@ -34,9 +34,9 @@ class Stream: @structure.validator # pyright: ignore[reportAttributeAccessIssue] def validate_structure(self, _: Attribute, structure: Structure) -> None: - from griptape.config import Config + from griptape.config import config - if not Config.drivers.prompt.stream: + if not config.drivers.prompt.stream: raise ValueError("prompt driver does not have streaming enabled, enable with stream=True") _event_queue: Queue[BaseEvent] = field(default=Factory(lambda: Queue())) diff --git a/tests/mocks/docker/fake_api.py b/tests/mocks/docker/fake_api.py index 881093057..00e750232 100644 --- a/tests/mocks/docker/fake_api.py +++ b/tests/mocks/docker/fake_api.py @@ -154,7 +154,7 @@ def get_fake_inspect_container(*, tty=False): status_code = 200 response = { "Id": FAKE_CONTAINER_ID, - "Config": {"Labels": {"foo": "bar"}, "Privileged": True, "Tty": tty}, + "config": {"Labels": {"foo": "bar"}, "Privileged": True, "Tty": tty}, "ID": FAKE_CONTAINER_ID, "Image": "busybox:latest", "Name": "foobar", @@ -166,7 +166,7 @@ def get_fake_inspect_container(*, tty=False): "StartedAt": "2013-09-25T14:01:18.869545111+02:00", "Ghost": False, }, - "HostConfig": {"LogConfig": {"Type": "json-file", "Config": {}}}, + "HostConfig": {"LogConfig": {"Type": "json-file", "config": {}}}, "MacAddress": "02:42:ac:11:00:0a", } return status_code, response @@ -179,7 +179,7 @@ def get_fake_inspect_image(): "Parent": "27cf784147099545", "Created": "2013-03-23T22:24:18.818426-07:00", "Container": FAKE_CONTAINER_ID, - "Config": {"Labels": {"bar": "foo"}}, + "config": {"Labels": {"bar": "foo"}}, "ContainerConfig": { "Hostname": "", "User": "", @@ -446,7 +446,7 @@ def get_fake_network_list(): "Driver": "bridge", "EnableIPv6": False, "Internal": False, - "IPAM": {"Driver": "default", "Config": [{"Subnet": "172.17.0.0/16"}]}, + "IPAM": {"Driver": "default", "config": [{"Subnet": "172.17.0.0/16"}]}, "Containers": { FAKE_CONTAINER_ID: { "EndpointID": "ed2419a97c1d99", diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 01af02573..8a37f6d28 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,21 +1,21 @@ import pytest -from griptape.config import Config +from griptape.config import config from griptape.events import event_bus from tests.mocks.mock_driver_config import MockDriverConfig @pytest.fixture(autouse=True) def mock_event_bus(): - event_bus.event_listeners = [] + event_bus.clear_event_listeners() yield event_bus - event_bus.event_listeners = [] + event_bus.clear_event_listeners() @pytest.fixture(autouse=True) def mock_config(): - Config.drivers = MockDriverConfig() + config.drivers = MockDriverConfig() - return Config + return config diff --git a/tests/unit/tasks/test_base_task.py b/tests/unit/tasks/test_base_task.py index 90e826d19..1b45b4e98 100644 --- a/tests/unit/tasks/test_base_task.py +++ b/tests/unit/tasks/test_base_task.py @@ -18,7 +18,7 @@ def task(self): agent = Agent( tools=[MockTool()], ) - event_bus.event_listeners = [EventListener(handler=Mock())] + event_bus.add_event_listeners([EventListener(handler=Mock())]) agent.add_task(MockTask("foobar", max_meta_memory_entries=2)) diff --git a/tests/unit/utils/test_stream.py b/tests/unit/utils/test_stream.py index 318f434c3..edd0258f2 100644 --- a/tests/unit/utils/test_stream.py +++ b/tests/unit/utils/test_stream.py @@ -2,7 +2,7 @@ import pytest -from griptape.config import Config +from griptape.config import config from griptape.structures import Agent from griptape.utils import Stream @@ -10,11 +10,11 @@ class TestStream: @pytest.fixture(params=[True, False]) def agent(self, request): - Config.drivers.prompt.stream = request.param + config.drivers.prompt.stream = request.param return Agent() def test_init(self, agent): - if Config.drivers.prompt.stream: + if config.drivers.prompt.stream: chat_stream = Stream(agent) assert chat_stream.structure == agent diff --git a/tests/utils/structure_tester.py b/tests/utils/structure_tester.py index 2b9f83b81..d87fc095e 100644 --- a/tests/utils/structure_tester.py +++ b/tests/utils/structure_tester.py @@ -226,9 +226,9 @@ def prompt_driver_id_fn(cls, prompt_driver) -> str: return f"{prompt_driver.__class__.__name__}-{prompt_driver.model}" def verify_structure_output(self, structure) -> dict: - from griptape.config import Config + from griptape.config import config - Config.drivers.prompt = AzureOpenAiChatPromptDriver( + config.drivers.prompt = AzureOpenAiChatPromptDriver( api_key=os.environ["AZURE_OPENAI_API_KEY_1"], model="gpt-4o", azure_deployment=os.environ["AZURE_OPENAI_4_DEPLOYMENT_ID"], From 13969c3a71c23ee733f73c5bf115d8d65bdd57e3 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 8 Aug 2024 11:30:26 -0700 Subject: [PATCH 26/63] Fix doc --- docs/examples/multiple-agent-shared-memory.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/examples/multiple-agent-shared-memory.md b/docs/examples/multiple-agent-shared-memory.md index 0fe589d7b..30ff03ecc 100644 --- a/docs/examples/multiple-agent-shared-memory.md +++ b/docs/examples/multiple-agent-shared-memory.md @@ -11,7 +11,7 @@ import os from griptape.tools import WebScraper, TaskMemoryClient from griptape.structures import Agent from griptape.drivers import AzureOpenAiEmbeddingDriver, AzureMongoDbVectorStoreDriver -from griptape.config import AzureOpenAiDriverConfig +from griptape.config import AzureOpenAiDriverConfig, config AZURE_OPENAI_ENDPOINT_1 = os.environ["AZURE_OPENAI_ENDPOINT_1"] AZURE_OPENAI_API_KEY_1 = os.environ["AZURE_OPENAI_API_KEY_1"] @@ -40,7 +40,7 @@ mongo_driver = AzureMongoDbVectorStoreDriver( vector_path=MONGODB_VECTOR_PATH, ) -config = AzureOpenAiDriverConfig( +config.drivers = AzureOpenAiDriverConfig( azure_endpoint=AZURE_OPENAI_ENDPOINT_1, vector_store=mongo_driver, embedding=embedding_driver, From f5c42e8944e65127243d94e1def66ceae43a3b3f Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 8 Aug 2024 12:05:50 -0700 Subject: [PATCH 27/63] Update changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7a95701c2..d4aaab1a9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,10 +12,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Parameter `structure` to `BaseTask`. - Method `try_find_task` to `Structure`. - Global event bus, `griptape.events.event_bus`, for publishing and subscribing to events. +- Global config, `griptape.config.config`, for setting global configuration defaults. ### Changed - **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `event_bus`. - **BREAKING**: Removed `EventPublisherMixin`. +- **BREAKING**: Removed `Workflow.prompt_driver` and `Workflow.prompt_driver`. `Agent.prompt_driver` has not been removed. +- **BREAKING**: Removed `Structure.embedding_driver`, set this via `griptape.config.config.drivers.embedding` instead. - `BaseTask.add_parent/child` will now call `self.structure.add_task` if possible. ## [0.29.0] - 2024-07-30 From dd23d895a665a28572b91168887bb600f85e7c4c Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 7 Aug 2024 16:23:25 -0700 Subject: [PATCH 28/63] Add global event bus --- CHANGELOG.md | 3 + docs/griptape-framework/misc/events.md | 61 ++++++++++--------- griptape/config/base_structure_config.py | 40 ------------ .../base_audio_transcription_driver.py | 10 +-- .../embedding/base_embedding_driver.py | 4 +- .../base_image_generation_driver.py | 10 +-- .../image_query/base_image_query_driver.py | 10 +-- .../base_conversation_memory_driver.py | 4 +- griptape/drivers/prompt/base_prompt_driver.py | 16 ++--- .../base_text_to_speech_driver.py | 9 +-- .../vector/base_vector_store_driver.py | 4 +- griptape/events/__init__.py | 2 + .../event_bus.py} | 5 +- griptape/mixins/__init__.py | 2 - griptape/structures/structure.py | 12 ++-- griptape/tasks/actions_subtask.py | 6 +- griptape/tasks/base_task.py | 6 +- griptape/utils/stream.py | 9 +-- tests/unit/config/test_structure_config.py | 35 ----------- tests/unit/conftest.py | 12 ++++ .../test_base_audio_transcription_driver.py | 4 +- .../test_base_image_generation_driver.py | 9 +-- .../test_base_image_query_driver.py | 4 +- .../drivers/prompt/test_base_prompt_driver.py | 7 +-- .../test_base_audio_transcription_driver.py | 4 +- tests/unit/events/test_event_bus.py | 45 ++++++++++++++ tests/unit/events/test_event_listener.py | 29 ++++----- tests/unit/mixins/test_events_mixin.py | 59 ------------------ tests/unit/tasks/test_base_task.py | 5 +- 29 files changed, 176 insertions(+), 250 deletions(-) rename griptape/{mixins/event_publisher_mixin.py => events/event_bus.py} (96%) create mode 100644 tests/unit/conftest.py create mode 100644 tests/unit/events/test_event_bus.py delete mode 100644 tests/unit/mixins/test_events_mixin.py diff --git a/CHANGELOG.md b/CHANGELOG.md index f9b2e72e8..76f705ddb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,8 +12,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Parameter `structure` to `BaseTask`. - Method `try_find_task` to `Structure`. - `TranslateQueryRagModule` `RagEngine` module for translating input queries. +- Global event bus, `griptape.events.EventBus`, for publishing and subscribing to events. ### Changed +- **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `EventBus`. +- **BREAKING**: Removed `EventPublisherMixin`. - `BaseTask.add_parent/child` will now call `self.structure.add_task` if possible. ## [0.29.0] - 2024-07-30 diff --git a/docs/griptape-framework/misc/events.md b/docs/griptape-framework/misc/events.md index 1f50fd6d0..187321dc6 100644 --- a/docs/griptape-framework/misc/events.md +++ b/docs/griptape-framework/misc/events.md @@ -5,7 +5,7 @@ search: ## Overview -You can use [EventListener](../../reference/griptape/events/event_listener.md)s to listen for events during a Structure's execution. +You can configure the global [EventBus](../../reference/griptape/events/event_bus.md) with [EventListener](../../reference/griptape/events/event_listener.md)s to listen for various framework events. See [Event Listener Drivers](../drivers/event-listener-drivers.md) for examples on forwarding events to external services. ## Specific Event Types @@ -23,15 +23,14 @@ from griptape.events import ( StartPromptEvent, FinishPromptEvent, EventListener, + EventBus ) def handler(event: BaseEvent): print(event.__class__) - -agent = Agent( - event_listeners=[ +EventBus.event_listeners=[ EventListener( handler, event_types=[ @@ -44,7 +43,8 @@ agent = Agent( ], ) ] -) + +agent = Agent() agent.run("tell me about griptape") ``` @@ -69,7 +69,8 @@ Or listen to all events: ```python from griptape.structures import Agent -from griptape.events import BaseEvent, EventListener +from griptape.events import BaseEvent, EventListener, EventBus + def handler1(event: BaseEvent): @@ -79,13 +80,12 @@ def handler1(event: BaseEvent): def handler2(event: BaseEvent): print("Handler 2", event.__class__) - -agent = Agent( - event_listeners=[ +EventBus.event_listeners=[ EventListener(handler1), EventListener(handler2), ] -) + +agent = Agent() agent.run("tell me about griptape") ``` @@ -131,7 +131,7 @@ Handler 2 list: - return [ - self.prompt_driver, - self.image_generation_driver, - self.image_query_driver, - self.embedding_driver, - self.vector_store_driver, - self.conversation_memory_driver, - self.text_to_speech_driver, - self.audio_transcription_driver, - ] - - @property - def structure(self) -> Optional[Structure]: - return self._structure - - @structure.setter - def structure(self, structure: Structure) -> None: - if structure != self.structure: - event_publisher_drivers = [ - driver for driver in self.drivers if driver is not None and isinstance(driver, EventPublisherMixin) - ] - - for driver in event_publisher_drivers: - if self._event_listener is not None: - driver.remove_event_listener(self._event_listener) - - self._event_listener = EventListener(structure.publish_event) - for driver in event_publisher_drivers: - driver.add_event_listener(self._event_listener) - - self._structure = structure - def merge_config(self, config: dict) -> BaseStructureConfig: base_config = self.to_dict() merged_config = dict_merge(base_config, config) diff --git a/griptape/drivers/audio_transcription/base_audio_transcription_driver.py b/griptape/drivers/audio_transcription/base_audio_transcription_driver.py index c81ea1d5b..ae46c474c 100644 --- a/griptape/drivers/audio_transcription/base_audio_transcription_driver.py +++ b/griptape/drivers/audio_transcription/base_audio_transcription_driver.py @@ -5,22 +5,22 @@ from attrs import define, field -from griptape.events import FinishAudioTranscriptionEvent, StartAudioTranscriptionEvent -from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin +from griptape.events import EventBus, FinishAudioTranscriptionEvent, StartAudioTranscriptionEvent +from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: from griptape.artifacts import AudioArtifact, TextArtifact @define -class BaseAudioTranscriptionDriver(EventPublisherMixin, SerializableMixin, ExponentialBackoffMixin, ABC): +class BaseAudioTranscriptionDriver(SerializableMixin, ExponentialBackoffMixin, ABC): model: str = field(kw_only=True, metadata={"serializable": True}) def before_run(self) -> None: - self.publish_event(StartAudioTranscriptionEvent()) + EventBus.publish_event(StartAudioTranscriptionEvent()) def after_run(self) -> None: - self.publish_event(FinishAudioTranscriptionEvent()) + EventBus.publish_event(FinishAudioTranscriptionEvent()) def run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/embedding/base_embedding_driver.py b/griptape/drivers/embedding/base_embedding_driver.py index 690726060..8998f00e5 100644 --- a/griptape/drivers/embedding/base_embedding_driver.py +++ b/griptape/drivers/embedding/base_embedding_driver.py @@ -7,7 +7,7 @@ from attrs import define, field from griptape.chunkers import BaseChunker, TextChunker -from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin +from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: from griptape.artifacts import TextArtifact @@ -15,7 +15,7 @@ @define -class BaseEmbeddingDriver(EventPublisherMixin, SerializableMixin, ExponentialBackoffMixin, ABC): +class BaseEmbeddingDriver(SerializableMixin, ExponentialBackoffMixin, ABC): """Base Embedding Driver. Attributes: diff --git a/griptape/drivers/image_generation/base_image_generation_driver.py b/griptape/drivers/image_generation/base_image_generation_driver.py index f500d6d09..8dfca5945 100644 --- a/griptape/drivers/image_generation/base_image_generation_driver.py +++ b/griptape/drivers/image_generation/base_image_generation_driver.py @@ -5,22 +5,22 @@ from attrs import define, field -from griptape.events import FinishImageGenerationEvent, StartImageGenerationEvent -from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin +from griptape.events import EventBus, FinishImageGenerationEvent, StartImageGenerationEvent +from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: from griptape.artifacts import ImageArtifact @define -class BaseImageGenerationDriver(EventPublisherMixin, SerializableMixin, ExponentialBackoffMixin, ABC): +class BaseImageGenerationDriver(SerializableMixin, ExponentialBackoffMixin, ABC): model: str = field(kw_only=True, metadata={"serializable": True}) def before_run(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> None: - self.publish_event(StartImageGenerationEvent(prompts=prompts, negative_prompts=negative_prompts)) + EventBus.publish_event(StartImageGenerationEvent(prompts=prompts, negative_prompts=negative_prompts)) def after_run(self) -> None: - self.publish_event(FinishImageGenerationEvent()) + EventBus.publish_event(FinishImageGenerationEvent()) def run_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/image_query/base_image_query_driver.py b/griptape/drivers/image_query/base_image_query_driver.py index b39f198d4..28c571328 100644 --- a/griptape/drivers/image_query/base_image_query_driver.py +++ b/griptape/drivers/image_query/base_image_query_driver.py @@ -5,24 +5,24 @@ from attrs import define, field -from griptape.events import FinishImageQueryEvent, StartImageQueryEvent -from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin +from griptape.events import EventBus, FinishImageQueryEvent, StartImageQueryEvent +from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: from griptape.artifacts import ImageArtifact, TextArtifact @define -class BaseImageQueryDriver(EventPublisherMixin, SerializableMixin, ExponentialBackoffMixin, ABC): +class BaseImageQueryDriver(SerializableMixin, ExponentialBackoffMixin, ABC): max_tokens: int = field(default=256, kw_only=True, metadata={"serializable": True}) def before_run(self, query: str, images: list[ImageArtifact]) -> None: - self.publish_event( + EventBus.publish_event( StartImageQueryEvent(query=query, images_info=[image.to_text() for image in images]), ) def after_run(self, result: str) -> None: - self.publish_event(FinishImageQueryEvent(result=result)) + EventBus.publish_event(FinishImageQueryEvent(result=result)) def query(self, query: str, images: list[ImageArtifact]) -> TextArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/memory/conversation/base_conversation_memory_driver.py b/griptape/drivers/memory/conversation/base_conversation_memory_driver.py index f13b82c29..1caeb902f 100644 --- a/griptape/drivers/memory/conversation/base_conversation_memory_driver.py +++ b/griptape/drivers/memory/conversation/base_conversation_memory_driver.py @@ -3,13 +3,13 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Optional -from griptape.mixins import EventPublisherMixin, SerializableMixin +from griptape.mixins import SerializableMixin if TYPE_CHECKING: from griptape.memory.structure import BaseConversationMemory -class BaseConversationMemoryDriver(EventPublisherMixin, SerializableMixin, ABC): +class BaseConversationMemoryDriver(SerializableMixin, ABC): @abstractmethod def store(self, memory: BaseConversationMemory) -> None: ... diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index e5fd0408d..94e46e75d 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -16,8 +16,8 @@ TextMessageContent, observable, ) -from griptape.events import CompletionChunkEvent, FinishPromptEvent, StartPromptEvent -from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin +from griptape.events import CompletionChunkEvent, EventBus, FinishPromptEvent, StartPromptEvent +from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: from collections.abc import Iterator @@ -26,7 +26,7 @@ @define(kw_only=True) -class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, EventPublisherMixin, ABC): +class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): """Base class for the Prompt Drivers. Attributes: @@ -49,10 +49,10 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, EventPublishe use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True}) def before_run(self, prompt_stack: PromptStack) -> None: - self.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack)) + EventBus.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack)) def after_run(self, result: Message) -> None: - self.publish_event( + EventBus.publish_event( FinishPromptEvent( model=self.model, result=result.value, @@ -128,12 +128,12 @@ def __process_stream(self, prompt_stack: PromptStack) -> Message: else: delta_contents[content.index] = [content] if isinstance(content, TextDeltaMessageContent): - self.publish_event(CompletionChunkEvent(token=content.text)) + EventBus.publish_event(CompletionChunkEvent(token=content.text)) elif isinstance(content, ActionCallDeltaMessageContent): if content.tag is not None and content.name is not None and content.path is not None: - self.publish_event(CompletionChunkEvent(token=str(content))) + EventBus.publish_event(CompletionChunkEvent(token=str(content))) elif content.partial_input is not None: - self.publish_event(CompletionChunkEvent(token=content.partial_input)) + EventBus.publish_event(CompletionChunkEvent(token=content.partial_input)) # Build a complete content from the content deltas result = self.__build_message(list(delta_contents.values()), usage) diff --git a/griptape/drivers/text_to_speech/base_text_to_speech_driver.py b/griptape/drivers/text_to_speech/base_text_to_speech_driver.py index 788d92974..cb11cc498 100644 --- a/griptape/drivers/text_to_speech/base_text_to_speech_driver.py +++ b/griptape/drivers/text_to_speech/base_text_to_speech_driver.py @@ -5,23 +5,24 @@ from attrs import define, field +from griptape.events import EventBus from griptape.events.finish_text_to_speech_event import FinishTextToSpeechEvent from griptape.events.start_text_to_speech_event import StartTextToSpeechEvent -from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin +from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: from griptape.artifacts.audio_artifact import AudioArtifact @define -class BaseTextToSpeechDriver(SerializableMixin, ExponentialBackoffMixin, EventPublisherMixin, ABC): +class BaseTextToSpeechDriver(SerializableMixin, ExponentialBackoffMixin, ABC): model: str = field(kw_only=True, metadata={"serializable": True}) def before_run(self, prompts: list[str]) -> None: - self.publish_event(StartTextToSpeechEvent(prompts=prompts)) + EventBus.publish_event(StartTextToSpeechEvent(prompts=prompts)) def after_run(self) -> None: - self.publish_event(FinishTextToSpeechEvent()) + EventBus.publish_event(FinishTextToSpeechEvent()) def run_text_to_audio(self, prompts: list[str]) -> AudioArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/vector/base_vector_store_driver.py b/griptape/drivers/vector/base_vector_store_driver.py index d1da78188..ed1f2d589 100644 --- a/griptape/drivers/vector/base_vector_store_driver.py +++ b/griptape/drivers/vector/base_vector_store_driver.py @@ -10,14 +10,14 @@ from griptape import utils from griptape.artifacts import BaseArtifact, ListArtifact, TextArtifact -from griptape.mixins import EventPublisherMixin, SerializableMixin +from griptape.mixins import SerializableMixin if TYPE_CHECKING: from griptape.drivers import BaseEmbeddingDriver @define -class BaseVectorStoreDriver(EventPublisherMixin, SerializableMixin, ABC): +class BaseVectorStoreDriver(SerializableMixin, ABC): DEFAULT_QUERY_COUNT = 5 @dataclass diff --git a/griptape/events/__init__.py b/griptape/events/__init__.py index 944a309eb..b3e2f3a79 100644 --- a/griptape/events/__init__.py +++ b/griptape/events/__init__.py @@ -22,6 +22,7 @@ from .base_audio_transcription_event import BaseAudioTranscriptionEvent from .start_audio_transcription_event import StartAudioTranscriptionEvent from .finish_audio_transcription_event import FinishAudioTranscriptionEvent +from .event_bus import EventBus __all__ = [ "BaseEvent", @@ -48,4 +49,5 @@ "BaseAudioTranscriptionEvent", "StartAudioTranscriptionEvent", "FinishAudioTranscriptionEvent", + "EventBus", ] diff --git a/griptape/mixins/event_publisher_mixin.py b/griptape/events/event_bus.py similarity index 96% rename from griptape/mixins/event_publisher_mixin.py rename to griptape/events/event_bus.py index 67a302ed6..9239e66bd 100644 --- a/griptape/mixins/event_publisher_mixin.py +++ b/griptape/events/event_bus.py @@ -9,7 +9,7 @@ @define -class EventPublisherMixin: +class _EventBus: event_listeners: list[EventListener] = field(factory=list, kw_only=True) def add_event_listeners(self, event_listeners: list[EventListener]) -> list[EventListener]: @@ -32,3 +32,6 @@ def remove_event_listener(self, event_listener: EventListener) -> None: def publish_event(self, event: BaseEvent, *, flush: bool = False) -> None: for event_listener in self.event_listeners: event_listener.publish_event(event, flush=flush) + + +EventBus = _EventBus() diff --git a/griptape/mixins/__init__.py b/griptape/mixins/__init__.py index 944027c59..d9eea53c2 100644 --- a/griptape/mixins/__init__.py +++ b/griptape/mixins/__init__.py @@ -4,7 +4,6 @@ from .rule_mixin import RuleMixin from .serializable_mixin import SerializableMixin from .media_artifact_file_output_mixin import BlobArtifactFileOutputMixin -from .event_publisher_mixin import EventPublisherMixin __all__ = [ "ActivityMixin", @@ -13,5 +12,4 @@ "RuleMixin", "BlobArtifactFileOutputMixin", "SerializableMixin", - "EventPublisherMixin", ] diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index 079e0b741..df7113c23 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -28,13 +28,11 @@ VectorStoreRetrievalRagModule, ) from griptape.engines.rag.stages import ResponseRagStage, RetrievalRagStage -from griptape.events.finish_structure_run_event import FinishStructureRunEvent -from griptape.events.start_structure_run_event import StartStructureRunEvent +from griptape.events import EventBus, FinishStructureRunEvent, StartStructureRunEvent from griptape.memory import TaskMemory from griptape.memory.meta import MetaMemory from griptape.memory.structure import ConversationMemory from griptape.memory.task.storage import BlobArtifactStorage, TextArtifactStorage -from griptape.mixins import EventPublisherMixin from griptape.utils import deprecation_warn if TYPE_CHECKING: @@ -44,7 +42,7 @@ @define -class Structure(ABC, EventPublisherMixin): +class Structure(ABC): LOGGER_NAME = "griptape" id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True) @@ -97,8 +95,6 @@ def __attrs_post_init__(self) -> None: if self.conversation_memory is not None: self.conversation_memory.structure = self - self.config.structure = self - tasks = self.tasks.copy() self.tasks.clear() self.add_tasks(*tasks) @@ -261,7 +257,7 @@ def before_run(self, args: Any) -> None: [task.reset() for task in self.tasks] - self.publish_event( + EventBus.publish_event( StartStructureRunEvent( structure_id=self.id, input_task_input=self.input_task.input, @@ -273,7 +269,7 @@ def before_run(self, args: Any) -> None: @observable def after_run(self) -> None: - self.publish_event( + EventBus.publish_event( FinishStructureRunEvent( structure_id=self.id, output_task_input=self.output_task.input, diff --git a/griptape/tasks/actions_subtask.py b/griptape/tasks/actions_subtask.py index cde59d0ef..07f49f52a 100644 --- a/griptape/tasks/actions_subtask.py +++ b/griptape/tasks/actions_subtask.py @@ -10,7 +10,7 @@ from griptape import utils from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact from griptape.common import ToolAction -from griptape.events import FinishActionsSubtaskEvent, StartActionsSubtaskEvent +from griptape.events import EventBus, FinishActionsSubtaskEvent, StartActionsSubtaskEvent from griptape.mixins import ActionsSubtaskOriginMixin from griptape.tasks import BaseTask from griptape.utils import remove_null_values_in_dict_recursively @@ -91,7 +91,7 @@ def attach_to(self, parent_task: BaseTask) -> None: self.output = ErrorArtifact(f"ToolAction input parsing error: {e}", exception=e) def before_run(self) -> None: - self.structure.publish_event( + EventBus.publish_event( StartActionsSubtaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, @@ -157,7 +157,7 @@ def execute_action(self, action: ToolAction) -> tuple[str, BaseArtifact]: def after_run(self) -> None: response = self.output.to_text() if isinstance(self.output, BaseArtifact) else str(self.output) - self.structure.publish_event( + EventBus.publish_event( FinishActionsSubtaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, diff --git a/griptape/tasks/base_task.py b/griptape/tasks/base_task.py index 8c50e4df9..9a8361e6c 100644 --- a/griptape/tasks/base_task.py +++ b/griptape/tasks/base_task.py @@ -9,7 +9,7 @@ from attrs import Factory, define, field from griptape.artifacts import ErrorArtifact -from griptape.events import FinishTaskEvent, StartTaskEvent +from griptape.events import EventBus, FinishTaskEvent, StartTaskEvent if TYPE_CHECKING: from griptape.artifacts import BaseArtifact @@ -127,7 +127,7 @@ def is_executing(self) -> bool: def before_run(self) -> None: if self.structure is not None: - self.structure.publish_event( + EventBus.publish_event( StartTaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, @@ -139,7 +139,7 @@ def before_run(self) -> None: def after_run(self) -> None: if self.structure is not None: - self.structure.publish_event( + EventBus.publish_event( FinishTaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, diff --git a/griptape/utils/stream.py b/griptape/utils/stream.py index bf33e5df8..4a7899b2a 100644 --- a/griptape/utils/stream.py +++ b/griptape/utils/stream.py @@ -7,10 +7,7 @@ from attrs import Attribute, Factory, define, field from griptape.artifacts.text_artifact import TextArtifact -from griptape.events.completion_chunk_event import CompletionChunkEvent -from griptape.events.event_listener import EventListener -from griptape.events.finish_prompt_event import FinishPromptEvent -from griptape.events.finish_structure_run_event import FinishStructureRunEvent +from griptape.events import CompletionChunkEvent, EventBus, EventListener, FinishPromptEvent, FinishStructureRunEvent if TYPE_CHECKING: from collections.abc import Iterator @@ -64,8 +61,8 @@ def event_handler(event: BaseEvent) -> None: handler=event_handler, event_types=[CompletionChunkEvent, FinishPromptEvent, FinishStructureRunEvent], ) - self.structure.add_event_listener(stream_event_listener) + EventBus.add_event_listener(stream_event_listener) self.structure.run(*args) - self.structure.remove_event_listener(stream_event_listener) + EventBus.remove_event_listener(stream_event_listener) diff --git a/tests/unit/config/test_structure_config.py b/tests/unit/config/test_structure_config.py index b9e3477e4..96a68628f 100644 --- a/tests/unit/config/test_structure_config.py +++ b/tests/unit/config/test_structure_config.py @@ -1,7 +1,6 @@ import pytest from griptape.config import StructureConfig -from griptape.structures import Agent class TestStructureConfig: @@ -61,37 +60,3 @@ def test_dot_update(self, config): config.prompt_driver.max_tokens = 10 assert config.prompt_driver.max_tokens == 10 - - def test_drivers(self, config): - assert config.drivers == [ - config.prompt_driver, - config.image_generation_driver, - config.image_query_driver, - config.embedding_driver, - config.vector_store_driver, - config.conversation_memory_driver, - config.text_to_speech_driver, - config.audio_transcription_driver, - ] - - def test_structure(self, config): - structure_1 = Agent( - config=config, - ) - - assert config.structure == structure_1 - assert config._event_listener is not None - for driver in config.drivers: - if driver is not None: - assert config._event_listener in driver.event_listeners - assert len(driver.event_listeners) == 1 - - structure_2 = Agent( - config=config, - ) - assert config.structure == structure_2 - assert config._event_listener is not None - for driver in config.drivers: - if driver is not None: - assert config._event_listener in driver.event_listeners - assert len(driver.event_listeners) == 1 diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py new file mode 100644 index 000000000..0be2f9758 --- /dev/null +++ b/tests/unit/conftest.py @@ -0,0 +1,12 @@ +import pytest + +from griptape.events import EventBus + + +@pytest.fixture(autouse=True) +def event_bus(): + EventBus.event_listeners = [] + + yield EventBus + + EventBus.event_listeners = [] diff --git a/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py b/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py index 519e40f57..fc41837fd 100644 --- a/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py +++ b/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py @@ -3,7 +3,7 @@ import pytest from griptape.artifacts import AudioArtifact -from griptape.events.event_listener import EventListener +from griptape.events import EventBus, EventListener from tests.mocks.mock_audio_transcription_driver import MockAudioTranscriptionDriver @@ -14,7 +14,7 @@ def driver(self): def test_run_publish_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run( AudioArtifact( diff --git a/tests/unit/drivers/image_generation/test_base_image_generation_driver.py b/tests/unit/drivers/image_generation/test_base_image_generation_driver.py index 7447b2c08..96b615a58 100644 --- a/tests/unit/drivers/image_generation/test_base_image_generation_driver.py +++ b/tests/unit/drivers/image_generation/test_base_image_generation_driver.py @@ -3,6 +3,7 @@ import pytest from griptape.artifacts.image_artifact import ImageArtifact +from griptape.events import EventBus from griptape.events.event_listener import EventListener from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver @@ -14,7 +15,7 @@ def driver(self): def test_run_text_to_image_publish_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run_text_to_image( ["foo", "bar"], @@ -30,7 +31,7 @@ def test_run_text_to_image_publish_events(self, driver): def test_run_image_variation_publish_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_variation( ["foo", "bar"], @@ -52,7 +53,7 @@ def test_run_image_variation_publish_events(self, driver): def test_run_image_image_inpainting_publish_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_inpainting( ["foo", "bar"], @@ -80,7 +81,7 @@ def test_run_image_image_inpainting_publish_events(self, driver): def test_run_image_image_outpainting_publish_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_outpainting( ["foo", "bar"], diff --git a/tests/unit/drivers/image_query/test_base_image_query_driver.py b/tests/unit/drivers/image_query/test_base_image_query_driver.py index 14de15f2d..a77fb268e 100644 --- a/tests/unit/drivers/image_query/test_base_image_query_driver.py +++ b/tests/unit/drivers/image_query/test_base_image_query_driver.py @@ -2,7 +2,7 @@ import pytest -from griptape.events.event_listener import EventListener +from griptape.events import EventBus, EventListener from tests.mocks.mock_image_query_driver import MockImageQueryDriver @@ -13,7 +13,7 @@ def driver(self): def test_query_publishes_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.query("foo", []) diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index 2708b0a88..5b6b0c600 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -1,7 +1,7 @@ from griptape.artifacts import ErrorArtifact, TextArtifact from griptape.common import Message, PromptStack from griptape.events import FinishPromptEvent, StartPromptEvent -from griptape.mixins import EventPublisherMixin +from griptape.events.event_bus import _EventBus from griptape.structures import Pipeline from griptape.tasks import PromptTask, ToolkitTask from tests.mocks.mock_failing_prompt_driver import MockFailingPromptDriver @@ -27,7 +27,7 @@ def test_run_via_pipeline_retries_failure(self): assert isinstance(pipeline.run().output_task.output, ErrorArtifact) def test_run_via_pipeline_publishes_events(self, mocker): - mock_publish_event = mocker.patch.object(EventPublisherMixin, "publish_event") + mock_publish_event = mocker.patch.object(_EventBus, "publish_event") driver = MockPromptDriver() pipeline = Pipeline(prompt_driver=driver) pipeline.add_task(PromptTask("test")) @@ -42,8 +42,7 @@ def test_run(self): assert isinstance(MockPromptDriver().run(PromptStack(messages=[])), Message) def test_run_with_stream(self): - pipeline = Pipeline() - result = MockPromptDriver(stream=True, event_listeners=pipeline.event_listeners).run(PromptStack(messages=[])) + result = MockPromptDriver(stream=True).run(PromptStack(messages=[])) assert isinstance(result, Message) assert result.value == "mock output" diff --git a/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py b/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py index 8af5dc827..ab448c7c1 100644 --- a/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py +++ b/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py @@ -2,7 +2,7 @@ import pytest -from griptape.events.event_listener import EventListener +from griptape.events import EventBus, EventListener from tests.mocks.mock_text_to_speech_driver import MockTextToSpeechDriver @@ -13,7 +13,7 @@ def driver(self): def test_text_to_audio_publish_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run_text_to_audio( ["foo", "bar"], diff --git a/tests/unit/events/test_event_bus.py b/tests/unit/events/test_event_bus.py new file mode 100644 index 000000000..fd862913e --- /dev/null +++ b/tests/unit/events/test_event_bus.py @@ -0,0 +1,45 @@ +from unittest.mock import Mock + +from griptape.events import EventBus, EventListener +from tests.mocks.mock_event import MockEvent + + +class TestEventBus: + def test_add_event_listeners(self): + EventBus.add_event_listeners([EventListener(), EventListener()]) + assert len(EventBus.event_listeners) == 2 + + def test_remove_event_listeners(self): + listeners = [EventListener(), EventListener()] + EventBus.add_event_listeners(listeners) + EventBus.remove_event_listeners(listeners) + assert len(EventBus.event_listeners) == 0 + + def test_add_event_listener(self): + EventBus.add_event_listener(EventListener()) + EventBus.add_event_listener(EventListener()) + + assert len(EventBus.event_listeners) == 2 + + def test_remove_event_listener(self): + listener = EventListener() + EventBus.add_event_listener(listener) + EventBus.remove_event_listener(listener) + + assert len(EventBus.event_listeners) == 0 + + def test_remove_unknown_event_listener(self): + EventBus.remove_event_listener(EventListener()) + + def test_publish_event(self): + # Given + mock_handler = Mock() + mock_handler.return_value = None + EventBus.event_listeners = [EventListener(handler=mock_handler)] + mock_event = MockEvent() + + # When + EventBus.publish_event(mock_event) + + # Then + mock_handler.assert_called_once_with(mock_event) diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index b245c2be9..5601aef34 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -4,6 +4,7 @@ from griptape.events import ( CompletionChunkEvent, + EventBus, EventListener, FinishActionsSubtaskEvent, FinishPromptEvent, @@ -37,7 +38,7 @@ def test_untyped_listeners(self, pipeline): event_handler_1 = Mock() event_handler_2 = Mock() - pipeline.event_listeners = [EventListener(handler=event_handler_1), EventListener(handler=event_handler_2)] + EventBus.event_listeners = [EventListener(handler=event_handler_1), EventListener(handler=event_handler_2)] # can't mock subtask events, so must manually call pipeline.tasks[0].subtasks[0].before_run() @@ -58,7 +59,7 @@ def test_typed_listeners(self, pipeline): finish_structure_run_event_handler = Mock() completion_chunk_handler = Mock() - pipeline.event_listeners = [ + EventBus.event_listeners = [ EventListener(start_prompt_event_handler, event_types=[StartPromptEvent]), EventListener(finish_prompt_event_handler, event_types=[FinishPromptEvent]), EventListener(start_task_event_handler, event_types=[StartTaskEvent]), @@ -86,25 +87,25 @@ def test_typed_listeners(self, pipeline): completion_chunk_handler.assert_called_once() def test_add_remove_event_listener(self, pipeline): - pipeline.event_listeners = [] + EventBus.event_listeners = [] mock1 = Mock() mock2 = Mock() # duplicate event listeners will only get added once - event_listener_1 = pipeline.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) - pipeline.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) + event_listener_1 = EventBus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) + EventBus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) - event_listener_3 = pipeline.add_event_listener(EventListener(mock1, event_types=[FinishPromptEvent])) - event_listener_4 = pipeline.add_event_listener(EventListener(mock2, event_types=[StartPromptEvent])) + event_listener_3 = EventBus.add_event_listener(EventListener(mock1, event_types=[FinishPromptEvent])) + event_listener_4 = EventBus.add_event_listener(EventListener(mock2, event_types=[StartPromptEvent])) - event_listener_5 = pipeline.add_event_listener(EventListener(mock2)) + event_listener_5 = EventBus.add_event_listener(EventListener(mock2)) - assert len(pipeline.event_listeners) == 4 + assert len(EventBus.event_listeners) == 4 - pipeline.remove_event_listener(event_listener_1) - pipeline.remove_event_listener(event_listener_3) - pipeline.remove_event_listener(event_listener_4) - pipeline.remove_event_listener(event_listener_5) - assert len(pipeline.event_listeners) == 0 + EventBus.remove_event_listener(event_listener_1) + EventBus.remove_event_listener(event_listener_3) + EventBus.remove_event_listener(event_listener_4) + EventBus.remove_event_listener(event_listener_5) + assert len(EventBus.event_listeners) == 0 def test_publish_event(self): mock_event_listener_driver = Mock() diff --git a/tests/unit/mixins/test_events_mixin.py b/tests/unit/mixins/test_events_mixin.py deleted file mode 100644 index 99f5541ba..000000000 --- a/tests/unit/mixins/test_events_mixin.py +++ /dev/null @@ -1,59 +0,0 @@ -from unittest.mock import Mock - -from griptape.events import EventListener -from griptape.mixins import EventPublisherMixin -from tests.mocks.mock_event import MockEvent - - -class TestEventsMixin: - def test_init(self): - assert EventPublisherMixin() - - def test_add_event_listeners(self): - mixin = EventPublisherMixin() - - mixin.add_event_listeners([EventListener(), EventListener()]) - assert len(mixin.event_listeners) == 2 - - def test_remove_event_listeners(self): - mixin = EventPublisherMixin() - - listeners = [EventListener(), EventListener()] - mixin.add_event_listeners(listeners) - mixin.remove_event_listeners(listeners) - assert len(mixin.event_listeners) == 0 - - def test_add_event_listener(self): - mixin = EventPublisherMixin() - - mixin.add_event_listener(EventListener()) - mixin.add_event_listener(EventListener()) - - assert len(mixin.event_listeners) == 2 - - def test_remove_event_listener(self): - mixin = EventPublisherMixin() - - listener = EventListener() - mixin.add_event_listener(listener) - mixin.remove_event_listener(listener) - - assert len(mixin.event_listeners) == 0 - - def test_remove_unknown_event_listener(self): - mixin = EventPublisherMixin() - - mixin.remove_event_listener(EventListener()) - - def test_publish_event(self): - # Given - mock_handler = Mock() - mock_handler.return_value = None - mixin = EventPublisherMixin(event_listeners=[EventListener(handler=mock_handler)]) - mock_event = MockEvent() - - # When - mixin.publish_event(mock_event) - - # Then - mock_handler.assert_called_once_with(mock_event) diff --git a/tests/unit/tasks/test_base_task.py b/tests/unit/tasks/test_base_task.py index 4f4b43d40..636515106 100644 --- a/tests/unit/tasks/test_base_task.py +++ b/tests/unit/tasks/test_base_task.py @@ -3,6 +3,7 @@ import pytest from griptape.artifacts import TextArtifact +from griptape.events import EventBus from griptape.events.event_listener import EventListener from griptape.structures import Agent, Workflow from griptape.tasks import ActionsSubtask @@ -15,11 +16,11 @@ class TestBaseTask: @pytest.fixture() def task(self): + EventBus.event_listeners = [EventListener(handler=Mock())] agent = Agent( prompt_driver=MockPromptDriver(), embedding_driver=MockEmbeddingDriver(), tools=[MockTool()], - event_listeners=[EventListener(handler=Mock())], ) agent.add_task(MockTask("foobar", max_meta_memory_entries=2)) @@ -117,4 +118,4 @@ def test_children_property_no_structure(self, task): def test_execute_publish_events(self, task): task.execute() - assert task.structure.event_listeners[0].handler.call_count == 2 + assert EventBus.event_listeners[0].handler.call_count == 2 From 39f75c4d897560b22a4c7a81a3192ee47d30ea35 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 8 Aug 2024 09:11:00 -0700 Subject: [PATCH 29/63] Update docs --- .../drivers/event-listener-drivers.md | 89 +++++++++++-------- docs/griptape-framework/misc/events.md | 14 +-- 2 files changed, 59 insertions(+), 44 deletions(-) diff --git a/docs/griptape-framework/drivers/event-listener-drivers.md b/docs/griptape-framework/drivers/event-listener-drivers.md index 73453afb6..db02cd77a 100644 --- a/docs/griptape-framework/drivers/event-listener-drivers.md +++ b/docs/griptape-framework/drivers/event-listener-drivers.md @@ -14,26 +14,27 @@ import os from griptape.drivers import AmazonSqsEventListenerDriver from griptape.events import ( - EventListener, + EventListener, EventBus ) from griptape.rules import Rule from griptape.structures import Agent -agent = Agent( - rules=[ - Rule( - value="You will be provided with a block of text, and your task is to extract a list of keywords from it." - ) - ], - event_listeners=[ +EventBus.add_event_listeners( + [ EventListener( - handler=lambda event: { # You can optionally use the handler to transform the event payload before sending it to the Driver - "event": event.to_dict(), - }, driver=AmazonSqsEventListenerDriver( queue_url=os.environ["AMAZON_SQS_QUEUE_URL"], ), ), + ] +) + + +agent = Agent( + rules=[ + Rule( + value="You will be provided with a block of text, and your task is to extract a list of keywords from it." + ) ], ) @@ -83,23 +84,26 @@ import os from griptape.drivers import AmazonSqsEventListenerDriver from griptape.events import ( - EventListener, + EventListener, EventBus ) from griptape.rules import Rule from griptape.structures import Agent -agent = Agent( - rules=[ - Rule( - value="You will be provided with a block of text, and your task is to extract a list of keywords from it." - ) - ], - event_listeners=[ +EventBus.add_event_listeners( + [ EventListener( driver=AmazonSqsEventListenerDriver( queue_url=os.environ["AMAZON_SQS_QUEUE_URL"], ), ), + ] +) + +agent = Agent( + rules=[ + Rule( + value="You will be provided with a block of text, and your task is to extract a list of keywords from it." + ) ], ) @@ -128,10 +132,23 @@ from griptape.drivers import AwsIotCoreEventListenerDriver, OpenAiChatPromptDriv from griptape.events import ( EventListener, FinishStructureRunEvent, + EventBus ) from griptape.rules import Rule from griptape.structures import Agent +EventBus.add_event_listeners( + [ + EventListener( + event_types=[FinishStructureRunEvent], + driver=AwsIotCoreEventListenerDriver( + topic=os.environ["AWS_IOT_CORE_TOPIC"], + iot_endpoint=os.environ["AWS_IOT_CORE_ENDPOINT"], + ), + ), + ] +) + agent = Agent( rules=[ Rule( @@ -143,15 +160,6 @@ agent = Agent( model="gpt-3.5-turbo", temperature=0.7 ) ), - event_listeners=[ - EventListener( - event_types=[FinishStructureRunEvent], - driver=AwsIotCoreEventListenerDriver( - topic=os.environ["AWS_IOT_CORE_TOPIC"], - iot_endpoint=os.environ["AWS_IOT_CORE_ENDPOINT"], - ), - ), - ], ) agent.run("I want to fly from Orlando to Boston") @@ -171,18 +179,19 @@ from griptape.drivers import GriptapeCloudEventListenerDriver from griptape.events import ( EventListener, FinishStructureRunEvent, + EventBus ) from griptape.structures import Agent -agent = Agent( - event_listeners=[ +EventBus.add_event_listeners( + [ EventListener( event_types=[FinishStructureRunEvent], # By default, GriptapeCloudEventListenerDriver uses the api key provided # in the GT_CLOUD_API_KEY environment variable. driver=GriptapeCloudEventListenerDriver(), ), - ], + ] ) agent.run( @@ -201,20 +210,23 @@ from griptape.drivers import WebhookEventListenerDriver from griptape.events import ( EventListener, FinishStructureRunEvent, + EventBus ) from griptape.structures import Agent -agent = Agent( - event_listeners=[ +EventBus.add_event_listeners( + [ EventListener( event_types=[FinishStructureRunEvent], driver=WebhookEventListenerDriver( webhook_url=os.environ["WEBHOOK_URL"], ), ), - ], + ] ) +agent = Agent() + agent.run("Analyze the pros and cons of remote work vs. office work") ``` ### Pusher @@ -229,12 +241,13 @@ import os from griptape.drivers import PusherEventListenerDriver from griptape.events import ( EventListener, - FinishStructureRunEvent + FinishStructureRunEvent, + EventBus ) from griptape.structures import Agent -agent = Agent( - event_listeners=[ +EventBus.add_event_listeners( + [ EventListener( event_types=[FinishStructureRunEvent], driver=PusherEventListenerDriver( @@ -250,6 +263,8 @@ agent = Agent( ], ) +agent = Agent() + agent.run("Analyze the pros and cons of remote work vs. office work") ``` diff --git a/docs/griptape-framework/misc/events.md b/docs/griptape-framework/misc/events.md index 187321dc6..23ebcdc2a 100644 --- a/docs/griptape-framework/misc/events.md +++ b/docs/griptape-framework/misc/events.md @@ -30,7 +30,7 @@ from griptape.events import ( def handler(event: BaseEvent): print(event.__class__) -EventBus.event_listeners=[ +EventBus.add_event_listeners([ EventListener( handler, event_types=[ @@ -42,7 +42,7 @@ EventBus.event_listeners=[ FinishPromptEvent, ], ) - ] + ]) agent = Agent() @@ -140,12 +140,12 @@ from griptape.drivers import OpenAiChatPromptDriver -EventBus.event_listeners = [ +EventBus.add_event_listeners([ EventListener( lambda e: print(e.token, end="", flush=True), event_types=[CompletionChunkEvent], ) -] +]) pipeline = Pipeline( config=OpenAiStructureConfig( @@ -194,12 +194,12 @@ from griptape.structures import Agent token_counter = utils.TokenCounter() -EventBus.event_listeners = [ +EventBus.add_event_listeners([ EventListener( lambda e: token_counter.add_tokens(e.token_count), event_types=[StartPromptEvent, FinishPromptEvent], ) -] +]) def count_tokens(e: BaseEvent): if isinstance(e, StartPromptEvent) or isinstance(e, FinishPromptEvent): @@ -248,7 +248,7 @@ from griptape.structures import Agent from griptape.events import BaseEvent, StartPromptEvent, EventListener, EventBus -EventBus.event_listeners = [EventListener(handler=lambda e: print(e), event_types=[StartPromptEvent])] +EventBus.add_event_listeners([EventListener(handler=lambda e: print(e), event_types=[StartPromptEvent])]) def handler(event: BaseEvent): if isinstance(event, StartPromptEvent): From 8a97313f9a224eb143975b9c670f90a874506141 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 8 Aug 2024 10:03:17 -0700 Subject: [PATCH 30/63] Make event listeners private --- griptape/events/event_bus.py | 19 +++++++++++----- tests/unit/conftest.py | 4 ++-- tests/unit/events/test_event_bus.py | 2 +- tests/unit/events/test_event_listener.py | 28 +++++++++++++----------- tests/unit/tasks/test_base_task.py | 2 +- 5 files changed, 32 insertions(+), 23 deletions(-) diff --git a/griptape/events/event_bus.py b/griptape/events/event_bus.py index 9239e66bd..6ffd65550 100644 --- a/griptape/events/event_bus.py +++ b/griptape/events/event_bus.py @@ -10,7 +10,11 @@ @define class _EventBus: - event_listeners: list[EventListener] = field(factory=list, kw_only=True) + _event_listeners: list[EventListener] = field(factory=list, kw_only=True, alias="_event_listeners") + + @property + def event_listeners(self) -> list[EventListener]: + return self._event_listeners def add_event_listeners(self, event_listeners: list[EventListener]) -> list[EventListener]: return [self.add_event_listener(event_listener) for event_listener in event_listeners] @@ -20,18 +24,21 @@ def remove_event_listeners(self, event_listeners: list[EventListener]) -> None: self.remove_event_listener(event_listener) def add_event_listener(self, event_listener: EventListener) -> EventListener: - if event_listener not in self.event_listeners: - self.event_listeners.append(event_listener) + if event_listener not in self._event_listeners: + self._event_listeners.append(event_listener) return event_listener def remove_event_listener(self, event_listener: EventListener) -> None: - if event_listener in self.event_listeners: - self.event_listeners.remove(event_listener) + if event_listener in self._event_listeners: + self._event_listeners.remove(event_listener) def publish_event(self, event: BaseEvent, *, flush: bool = False) -> None: - for event_listener in self.event_listeners: + for event_listener in self._event_listeners: event_listener.publish_event(event, flush=flush) + def clear_event_listeners(self) -> None: + self._event_listeners.clear() + EventBus = _EventBus() diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 0be2f9758..7a73b041f 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -5,8 +5,8 @@ @pytest.fixture(autouse=True) def event_bus(): - EventBus.event_listeners = [] + EventBus.clear_event_listeners() yield EventBus - EventBus.event_listeners = [] + EventBus.clear_event_listeners() diff --git a/tests/unit/events/test_event_bus.py b/tests/unit/events/test_event_bus.py index fd862913e..d237bb3b4 100644 --- a/tests/unit/events/test_event_bus.py +++ b/tests/unit/events/test_event_bus.py @@ -35,7 +35,7 @@ def test_publish_event(self): # Given mock_handler = Mock() mock_handler.return_value = None - EventBus.event_listeners = [EventListener(handler=mock_handler)] + EventBus.add_event_listeners([EventListener(handler=mock_handler)]) mock_event = MockEvent() # When diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index 5601aef34..f3d9823d3 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -38,7 +38,7 @@ def test_untyped_listeners(self, pipeline): event_handler_1 = Mock() event_handler_2 = Mock() - EventBus.event_listeners = [EventListener(handler=event_handler_1), EventListener(handler=event_handler_2)] + EventBus.add_event_listeners([EventListener(handler=event_handler_1), EventListener(handler=event_handler_2)]) # can't mock subtask events, so must manually call pipeline.tasks[0].subtasks[0].before_run() @@ -59,17 +59,19 @@ def test_typed_listeners(self, pipeline): finish_structure_run_event_handler = Mock() completion_chunk_handler = Mock() - EventBus.event_listeners = [ - EventListener(start_prompt_event_handler, event_types=[StartPromptEvent]), - EventListener(finish_prompt_event_handler, event_types=[FinishPromptEvent]), - EventListener(start_task_event_handler, event_types=[StartTaskEvent]), - EventListener(finish_task_event_handler, event_types=[FinishTaskEvent]), - EventListener(start_subtask_event_handler, event_types=[StartActionsSubtaskEvent]), - EventListener(finish_subtask_event_handler, event_types=[FinishActionsSubtaskEvent]), - EventListener(start_structure_run_event_handler, event_types=[StartStructureRunEvent]), - EventListener(finish_structure_run_event_handler, event_types=[FinishStructureRunEvent]), - EventListener(completion_chunk_handler, event_types=[CompletionChunkEvent]), - ] + EventBus.add_event_listeners( + [ + EventListener(start_prompt_event_handler, event_types=[StartPromptEvent]), + EventListener(finish_prompt_event_handler, event_types=[FinishPromptEvent]), + EventListener(start_task_event_handler, event_types=[StartTaskEvent]), + EventListener(finish_task_event_handler, event_types=[FinishTaskEvent]), + EventListener(start_subtask_event_handler, event_types=[StartActionsSubtaskEvent]), + EventListener(finish_subtask_event_handler, event_types=[FinishActionsSubtaskEvent]), + EventListener(start_structure_run_event_handler, event_types=[StartStructureRunEvent]), + EventListener(finish_structure_run_event_handler, event_types=[FinishStructureRunEvent]), + EventListener(completion_chunk_handler, event_types=[CompletionChunkEvent]), + ] + ) # can't mock subtask events, so must manually call pipeline.tasks[0].subtasks[0].before_run() @@ -87,7 +89,7 @@ def test_typed_listeners(self, pipeline): completion_chunk_handler.assert_called_once() def test_add_remove_event_listener(self, pipeline): - EventBus.event_listeners = [] + EventBus.clear_event_listeners() mock1 = Mock() mock2 = Mock() # duplicate event listeners will only get added once diff --git a/tests/unit/tasks/test_base_task.py b/tests/unit/tasks/test_base_task.py index 636515106..d6e4da8b6 100644 --- a/tests/unit/tasks/test_base_task.py +++ b/tests/unit/tasks/test_base_task.py @@ -16,7 +16,7 @@ class TestBaseTask: @pytest.fixture() def task(self): - EventBus.event_listeners = [EventListener(handler=Mock())] + EventBus.add_event_listeners([EventListener(handler=Mock())]) agent = Agent( prompt_driver=MockPromptDriver(), embedding_driver=MockEmbeddingDriver(), From 0e1d019d94d56c4ab59ef56f53c1f6fe5dc18678 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 8 Aug 2024 10:17:34 -0700 Subject: [PATCH 31/63] Rename EventBus to event_bus --- CHANGELOG.md | 4 +-- .../drivers/event-listener-drivers.md | 24 +++++++-------- docs/griptape-framework/misc/events.md | 22 +++++++------- .../base_audio_transcription_driver.py | 6 ++-- .../base_image_generation_driver.py | 6 ++-- .../image_query/base_image_query_driver.py | 6 ++-- griptape/drivers/prompt/base_prompt_driver.py | 12 ++++---- .../base_text_to_speech_driver.py | 6 ++-- griptape/events/__init__.py | 4 +-- griptape/events/event_bus.py | 2 +- griptape/structures/structure.py | 6 ++-- griptape/tasks/actions_subtask.py | 6 ++-- griptape/tasks/base_task.py | 6 ++-- griptape/utils/stream.py | 6 ++-- tests/unit/conftest.py | 10 +++---- .../test_base_audio_transcription_driver.py | 4 +-- .../test_base_image_generation_driver.py | 10 +++---- .../test_base_image_query_driver.py | 4 +-- .../drivers/prompt/test_base_prompt_driver.py | 4 +-- .../test_base_audio_transcription_driver.py | 4 +-- tests/unit/events/test_event_bus.py | 30 +++++++++---------- tests/unit/events/test_event_listener.py | 30 +++++++++---------- tests/unit/tasks/test_base_task.py | 6 ++-- 23 files changed, 109 insertions(+), 109 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 76f705ddb..9e016228c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,10 +12,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Parameter `structure` to `BaseTask`. - Method `try_find_task` to `Structure`. - `TranslateQueryRagModule` `RagEngine` module for translating input queries. -- Global event bus, `griptape.events.EventBus`, for publishing and subscribing to events. +- Global event bus, `griptape.events.event_bus`, for publishing and subscribing to events. ### Changed -- **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `EventBus`. +- **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `event_bus`. - **BREAKING**: Removed `EventPublisherMixin`. - `BaseTask.add_parent/child` will now call `self.structure.add_task` if possible. diff --git a/docs/griptape-framework/drivers/event-listener-drivers.md b/docs/griptape-framework/drivers/event-listener-drivers.md index db02cd77a..c3c92cfe1 100644 --- a/docs/griptape-framework/drivers/event-listener-drivers.md +++ b/docs/griptape-framework/drivers/event-listener-drivers.md @@ -14,12 +14,12 @@ import os from griptape.drivers import AmazonSqsEventListenerDriver from griptape.events import ( - EventListener, EventBus + EventListener, event_bus ) from griptape.rules import Rule from griptape.structures import Agent -EventBus.add_event_listeners( +event_bus.add_event_listeners( [ EventListener( driver=AmazonSqsEventListenerDriver( @@ -84,12 +84,12 @@ import os from griptape.drivers import AmazonSqsEventListenerDriver from griptape.events import ( - EventListener, EventBus + EventListener, event_bus ) from griptape.rules import Rule from griptape.structures import Agent -EventBus.add_event_listeners( +event_bus.add_event_listeners( [ EventListener( driver=AmazonSqsEventListenerDriver( @@ -132,12 +132,12 @@ from griptape.drivers import AwsIotCoreEventListenerDriver, OpenAiChatPromptDriv from griptape.events import ( EventListener, FinishStructureRunEvent, - EventBus + event_bus ) from griptape.rules import Rule from griptape.structures import Agent -EventBus.add_event_listeners( +event_bus.add_event_listeners( [ EventListener( event_types=[FinishStructureRunEvent], @@ -179,11 +179,11 @@ from griptape.drivers import GriptapeCloudEventListenerDriver from griptape.events import ( EventListener, FinishStructureRunEvent, - EventBus + event_bus ) from griptape.structures import Agent -EventBus.add_event_listeners( +event_bus.add_event_listeners( [ EventListener( event_types=[FinishStructureRunEvent], @@ -210,11 +210,11 @@ from griptape.drivers import WebhookEventListenerDriver from griptape.events import ( EventListener, FinishStructureRunEvent, - EventBus + event_bus ) from griptape.structures import Agent -EventBus.add_event_listeners( +event_bus.add_event_listeners( [ EventListener( event_types=[FinishStructureRunEvent], @@ -242,11 +242,11 @@ from griptape.drivers import PusherEventListenerDriver from griptape.events import ( EventListener, FinishStructureRunEvent, - EventBus + event_bus ) from griptape.structures import Agent -EventBus.add_event_listeners( +event_bus.add_event_listeners( [ EventListener( event_types=[FinishStructureRunEvent], diff --git a/docs/griptape-framework/misc/events.md b/docs/griptape-framework/misc/events.md index 23ebcdc2a..b3f4a77fd 100644 --- a/docs/griptape-framework/misc/events.md +++ b/docs/griptape-framework/misc/events.md @@ -5,7 +5,7 @@ search: ## Overview -You can configure the global [EventBus](../../reference/griptape/events/event_bus.md) with [EventListener](../../reference/griptape/events/event_listener.md)s to listen for various framework events. +You can configure the global [event_bus](../../reference/griptape/events/event_bus.md) with [EventListener](../../reference/griptape/events/event_listener.md)s to listen for various framework events. See [Event Listener Drivers](../drivers/event-listener-drivers.md) for examples on forwarding events to external services. ## Specific Event Types @@ -23,14 +23,14 @@ from griptape.events import ( StartPromptEvent, FinishPromptEvent, EventListener, - EventBus + event_bus ) def handler(event: BaseEvent): print(event.__class__) -EventBus.add_event_listeners([ +event_bus.add_event_listeners([ EventListener( handler, event_types=[ @@ -69,7 +69,7 @@ Or listen to all events: ```python from griptape.structures import Agent -from griptape.events import BaseEvent, EventListener, EventBus +from griptape.events import BaseEvent, EventListener, event_bus @@ -80,7 +80,7 @@ def handler1(event: BaseEvent): def handler2(event: BaseEvent): print("Handler 2", event.__class__) -EventBus.event_listeners=[ +event_bus.event_listeners=[ EventListener(handler1), EventListener(handler2), ] @@ -131,7 +131,7 @@ Handler 2 None: - EventBus.publish_event(StartAudioTranscriptionEvent()) + event_bus.publish_event(StartAudioTranscriptionEvent()) def after_run(self) -> None: - EventBus.publish_event(FinishAudioTranscriptionEvent()) + event_bus.publish_event(FinishAudioTranscriptionEvent()) def run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/image_generation/base_image_generation_driver.py b/griptape/drivers/image_generation/base_image_generation_driver.py index 8dfca5945..360fba8c9 100644 --- a/griptape/drivers/image_generation/base_image_generation_driver.py +++ b/griptape/drivers/image_generation/base_image_generation_driver.py @@ -5,7 +5,7 @@ from attrs import define, field -from griptape.events import EventBus, FinishImageGenerationEvent, StartImageGenerationEvent +from griptape.events import FinishImageGenerationEvent, StartImageGenerationEvent, event_bus from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: @@ -17,10 +17,10 @@ class BaseImageGenerationDriver(SerializableMixin, ExponentialBackoffMixin, ABC) model: str = field(kw_only=True, metadata={"serializable": True}) def before_run(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> None: - EventBus.publish_event(StartImageGenerationEvent(prompts=prompts, negative_prompts=negative_prompts)) + event_bus.publish_event(StartImageGenerationEvent(prompts=prompts, negative_prompts=negative_prompts)) def after_run(self) -> None: - EventBus.publish_event(FinishImageGenerationEvent()) + event_bus.publish_event(FinishImageGenerationEvent()) def run_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/image_query/base_image_query_driver.py b/griptape/drivers/image_query/base_image_query_driver.py index 28c571328..b1050b85c 100644 --- a/griptape/drivers/image_query/base_image_query_driver.py +++ b/griptape/drivers/image_query/base_image_query_driver.py @@ -5,7 +5,7 @@ from attrs import define, field -from griptape.events import EventBus, FinishImageQueryEvent, StartImageQueryEvent +from griptape.events import FinishImageQueryEvent, StartImageQueryEvent, event_bus from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: @@ -17,12 +17,12 @@ class BaseImageQueryDriver(SerializableMixin, ExponentialBackoffMixin, ABC): max_tokens: int = field(default=256, kw_only=True, metadata={"serializable": True}) def before_run(self, query: str, images: list[ImageArtifact]) -> None: - EventBus.publish_event( + event_bus.publish_event( StartImageQueryEvent(query=query, images_info=[image.to_text() for image in images]), ) def after_run(self, result: str) -> None: - EventBus.publish_event(FinishImageQueryEvent(result=result)) + event_bus.publish_event(FinishImageQueryEvent(result=result)) def query(self, query: str, images: list[ImageArtifact]) -> TextArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index 94e46e75d..8044469b5 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -16,7 +16,7 @@ TextMessageContent, observable, ) -from griptape.events import CompletionChunkEvent, EventBus, FinishPromptEvent, StartPromptEvent +from griptape.events import CompletionChunkEvent, FinishPromptEvent, StartPromptEvent, event_bus from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: @@ -49,10 +49,10 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True}) def before_run(self, prompt_stack: PromptStack) -> None: - EventBus.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack)) + event_bus.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack)) def after_run(self, result: Message) -> None: - EventBus.publish_event( + event_bus.publish_event( FinishPromptEvent( model=self.model, result=result.value, @@ -128,12 +128,12 @@ def __process_stream(self, prompt_stack: PromptStack) -> Message: else: delta_contents[content.index] = [content] if isinstance(content, TextDeltaMessageContent): - EventBus.publish_event(CompletionChunkEvent(token=content.text)) + event_bus.publish_event(CompletionChunkEvent(token=content.text)) elif isinstance(content, ActionCallDeltaMessageContent): if content.tag is not None and content.name is not None and content.path is not None: - EventBus.publish_event(CompletionChunkEvent(token=str(content))) + event_bus.publish_event(CompletionChunkEvent(token=str(content))) elif content.partial_input is not None: - EventBus.publish_event(CompletionChunkEvent(token=content.partial_input)) + event_bus.publish_event(CompletionChunkEvent(token=content.partial_input)) # Build a complete content from the content deltas result = self.__build_message(list(delta_contents.values()), usage) diff --git a/griptape/drivers/text_to_speech/base_text_to_speech_driver.py b/griptape/drivers/text_to_speech/base_text_to_speech_driver.py index cb11cc498..c74264dc1 100644 --- a/griptape/drivers/text_to_speech/base_text_to_speech_driver.py +++ b/griptape/drivers/text_to_speech/base_text_to_speech_driver.py @@ -5,7 +5,7 @@ from attrs import define, field -from griptape.events import EventBus +from griptape.events import event_bus from griptape.events.finish_text_to_speech_event import FinishTextToSpeechEvent from griptape.events.start_text_to_speech_event import StartTextToSpeechEvent from griptape.mixins import ExponentialBackoffMixin, SerializableMixin @@ -19,10 +19,10 @@ class BaseTextToSpeechDriver(SerializableMixin, ExponentialBackoffMixin, ABC): model: str = field(kw_only=True, metadata={"serializable": True}) def before_run(self, prompts: list[str]) -> None: - EventBus.publish_event(StartTextToSpeechEvent(prompts=prompts)) + event_bus.publish_event(StartTextToSpeechEvent(prompts=prompts)) def after_run(self) -> None: - EventBus.publish_event(FinishTextToSpeechEvent()) + event_bus.publish_event(FinishTextToSpeechEvent()) def run_text_to_audio(self, prompts: list[str]) -> AudioArtifact: for attempt in self.retrying(): diff --git a/griptape/events/__init__.py b/griptape/events/__init__.py index b3e2f3a79..431927663 100644 --- a/griptape/events/__init__.py +++ b/griptape/events/__init__.py @@ -22,7 +22,7 @@ from .base_audio_transcription_event import BaseAudioTranscriptionEvent from .start_audio_transcription_event import StartAudioTranscriptionEvent from .finish_audio_transcription_event import FinishAudioTranscriptionEvent -from .event_bus import EventBus +from .event_bus import event_bus __all__ = [ "BaseEvent", @@ -49,5 +49,5 @@ "BaseAudioTranscriptionEvent", "StartAudioTranscriptionEvent", "FinishAudioTranscriptionEvent", - "EventBus", + "event_bus", ] diff --git a/griptape/events/event_bus.py b/griptape/events/event_bus.py index 6ffd65550..a956f7deb 100644 --- a/griptape/events/event_bus.py +++ b/griptape/events/event_bus.py @@ -41,4 +41,4 @@ def clear_event_listeners(self) -> None: self._event_listeners.clear() -EventBus = _EventBus() +event_bus = _EventBus() diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index df7113c23..d68457ebc 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -28,7 +28,7 @@ VectorStoreRetrievalRagModule, ) from griptape.engines.rag.stages import ResponseRagStage, RetrievalRagStage -from griptape.events import EventBus, FinishStructureRunEvent, StartStructureRunEvent +from griptape.events import FinishStructureRunEvent, StartStructureRunEvent, event_bus from griptape.memory import TaskMemory from griptape.memory.meta import MetaMemory from griptape.memory.structure import ConversationMemory @@ -257,7 +257,7 @@ def before_run(self, args: Any) -> None: [task.reset() for task in self.tasks] - EventBus.publish_event( + event_bus.publish_event( StartStructureRunEvent( structure_id=self.id, input_task_input=self.input_task.input, @@ -269,7 +269,7 @@ def before_run(self, args: Any) -> None: @observable def after_run(self) -> None: - EventBus.publish_event( + event_bus.publish_event( FinishStructureRunEvent( structure_id=self.id, output_task_input=self.output_task.input, diff --git a/griptape/tasks/actions_subtask.py b/griptape/tasks/actions_subtask.py index 07f49f52a..d600c80a5 100644 --- a/griptape/tasks/actions_subtask.py +++ b/griptape/tasks/actions_subtask.py @@ -10,7 +10,7 @@ from griptape import utils from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact from griptape.common import ToolAction -from griptape.events import EventBus, FinishActionsSubtaskEvent, StartActionsSubtaskEvent +from griptape.events import FinishActionsSubtaskEvent, StartActionsSubtaskEvent, event_bus from griptape.mixins import ActionsSubtaskOriginMixin from griptape.tasks import BaseTask from griptape.utils import remove_null_values_in_dict_recursively @@ -91,7 +91,7 @@ def attach_to(self, parent_task: BaseTask) -> None: self.output = ErrorArtifact(f"ToolAction input parsing error: {e}", exception=e) def before_run(self) -> None: - EventBus.publish_event( + event_bus.publish_event( StartActionsSubtaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, @@ -157,7 +157,7 @@ def execute_action(self, action: ToolAction) -> tuple[str, BaseArtifact]: def after_run(self) -> None: response = self.output.to_text() if isinstance(self.output, BaseArtifact) else str(self.output) - EventBus.publish_event( + event_bus.publish_event( FinishActionsSubtaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, diff --git a/griptape/tasks/base_task.py b/griptape/tasks/base_task.py index 9a8361e6c..ade656f87 100644 --- a/griptape/tasks/base_task.py +++ b/griptape/tasks/base_task.py @@ -9,7 +9,7 @@ from attrs import Factory, define, field from griptape.artifacts import ErrorArtifact -from griptape.events import EventBus, FinishTaskEvent, StartTaskEvent +from griptape.events import FinishTaskEvent, StartTaskEvent, event_bus if TYPE_CHECKING: from griptape.artifacts import BaseArtifact @@ -127,7 +127,7 @@ def is_executing(self) -> bool: def before_run(self) -> None: if self.structure is not None: - EventBus.publish_event( + event_bus.publish_event( StartTaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, @@ -139,7 +139,7 @@ def before_run(self) -> None: def after_run(self) -> None: if self.structure is not None: - EventBus.publish_event( + event_bus.publish_event( FinishTaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, diff --git a/griptape/utils/stream.py b/griptape/utils/stream.py index 4a7899b2a..fd64a0f52 100644 --- a/griptape/utils/stream.py +++ b/griptape/utils/stream.py @@ -7,7 +7,7 @@ from attrs import Attribute, Factory, define, field from griptape.artifacts.text_artifact import TextArtifact -from griptape.events import CompletionChunkEvent, EventBus, EventListener, FinishPromptEvent, FinishStructureRunEvent +from griptape.events import CompletionChunkEvent, EventListener, FinishPromptEvent, FinishStructureRunEvent, event_bus if TYPE_CHECKING: from collections.abc import Iterator @@ -61,8 +61,8 @@ def event_handler(event: BaseEvent) -> None: handler=event_handler, event_types=[CompletionChunkEvent, FinishPromptEvent, FinishStructureRunEvent], ) - EventBus.add_event_listener(stream_event_listener) + event_bus.add_event_listener(stream_event_listener) self.structure.run(*args) - EventBus.remove_event_listener(stream_event_listener) + event_bus.remove_event_listener(stream_event_listener) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 7a73b041f..e462ede90 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,12 +1,12 @@ import pytest -from griptape.events import EventBus +from griptape.events import event_bus @pytest.fixture(autouse=True) -def event_bus(): - EventBus.clear_event_listeners() +def mock_event_bus(): + event_bus.clear_event_listeners() - yield EventBus + yield event_bus - EventBus.clear_event_listeners() + event_bus.clear_event_listeners() diff --git a/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py b/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py index fc41837fd..6fcab26e5 100644 --- a/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py +++ b/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py @@ -3,7 +3,7 @@ import pytest from griptape.artifacts import AudioArtifact -from griptape.events import EventBus, EventListener +from griptape.events import EventListener, event_bus from tests.mocks.mock_audio_transcription_driver import MockAudioTranscriptionDriver @@ -14,7 +14,7 @@ def driver(self): def test_run_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run( AudioArtifact( diff --git a/tests/unit/drivers/image_generation/test_base_image_generation_driver.py b/tests/unit/drivers/image_generation/test_base_image_generation_driver.py index 96b615a58..ab7b33ae8 100644 --- a/tests/unit/drivers/image_generation/test_base_image_generation_driver.py +++ b/tests/unit/drivers/image_generation/test_base_image_generation_driver.py @@ -3,7 +3,7 @@ import pytest from griptape.artifacts.image_artifact import ImageArtifact -from griptape.events import EventBus +from griptape.events import event_bus from griptape.events.event_listener import EventListener from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver @@ -15,7 +15,7 @@ def driver(self): def test_run_text_to_image_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run_text_to_image( ["foo", "bar"], @@ -31,7 +31,7 @@ def test_run_text_to_image_publish_events(self, driver): def test_run_image_variation_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_variation( ["foo", "bar"], @@ -53,7 +53,7 @@ def test_run_image_variation_publish_events(self, driver): def test_run_image_image_inpainting_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_inpainting( ["foo", "bar"], @@ -81,7 +81,7 @@ def test_run_image_image_inpainting_publish_events(self, driver): def test_run_image_image_outpainting_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_outpainting( ["foo", "bar"], diff --git a/tests/unit/drivers/image_query/test_base_image_query_driver.py b/tests/unit/drivers/image_query/test_base_image_query_driver.py index a77fb268e..d8ba6b60f 100644 --- a/tests/unit/drivers/image_query/test_base_image_query_driver.py +++ b/tests/unit/drivers/image_query/test_base_image_query_driver.py @@ -2,7 +2,7 @@ import pytest -from griptape.events import EventBus, EventListener +from griptape.events import EventListener, event_bus from tests.mocks.mock_image_query_driver import MockImageQueryDriver @@ -13,7 +13,7 @@ def driver(self): def test_query_publishes_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.query("foo", []) diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index 5b6b0c600..52b7d5c0d 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -1,7 +1,7 @@ from griptape.artifacts import ErrorArtifact, TextArtifact from griptape.common import Message, PromptStack from griptape.events import FinishPromptEvent, StartPromptEvent -from griptape.events.event_bus import _EventBus +from griptape.events.event_bus import _event_bus from griptape.structures import Pipeline from griptape.tasks import PromptTask, ToolkitTask from tests.mocks.mock_failing_prompt_driver import MockFailingPromptDriver @@ -27,7 +27,7 @@ def test_run_via_pipeline_retries_failure(self): assert isinstance(pipeline.run().output_task.output, ErrorArtifact) def test_run_via_pipeline_publishes_events(self, mocker): - mock_publish_event = mocker.patch.object(_EventBus, "publish_event") + mock_publish_event = mocker.patch.object(_event_bus, "publish_event") driver = MockPromptDriver() pipeline = Pipeline(prompt_driver=driver) pipeline.add_task(PromptTask("test")) diff --git a/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py b/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py index ab448c7c1..19493aa0f 100644 --- a/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py +++ b/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py @@ -2,7 +2,7 @@ import pytest -from griptape.events import EventBus, EventListener +from griptape.events import EventListener, event_bus from tests.mocks.mock_text_to_speech_driver import MockTextToSpeechDriver @@ -13,7 +13,7 @@ def driver(self): def test_text_to_audio_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run_text_to_audio( ["foo", "bar"], diff --git a/tests/unit/events/test_event_bus.py b/tests/unit/events/test_event_bus.py index d237bb3b4..7eb87036a 100644 --- a/tests/unit/events/test_event_bus.py +++ b/tests/unit/events/test_event_bus.py @@ -1,45 +1,45 @@ from unittest.mock import Mock -from griptape.events import EventBus, EventListener +from griptape.events import EventListener, event_bus from tests.mocks.mock_event import MockEvent class TestEventBus: def test_add_event_listeners(self): - EventBus.add_event_listeners([EventListener(), EventListener()]) - assert len(EventBus.event_listeners) == 2 + event_bus.add_event_listeners([EventListener(), EventListener()]) + assert len(event_bus.event_listeners) == 2 def test_remove_event_listeners(self): listeners = [EventListener(), EventListener()] - EventBus.add_event_listeners(listeners) - EventBus.remove_event_listeners(listeners) - assert len(EventBus.event_listeners) == 0 + event_bus.add_event_listeners(listeners) + event_bus.remove_event_listeners(listeners) + assert len(event_bus.event_listeners) == 0 def test_add_event_listener(self): - EventBus.add_event_listener(EventListener()) - EventBus.add_event_listener(EventListener()) + event_bus.add_event_listener(EventListener()) + event_bus.add_event_listener(EventListener()) - assert len(EventBus.event_listeners) == 2 + assert len(event_bus.event_listeners) == 2 def test_remove_event_listener(self): listener = EventListener() - EventBus.add_event_listener(listener) - EventBus.remove_event_listener(listener) + event_bus.add_event_listener(listener) + event_bus.remove_event_listener(listener) - assert len(EventBus.event_listeners) == 0 + assert len(event_bus.event_listeners) == 0 def test_remove_unknown_event_listener(self): - EventBus.remove_event_listener(EventListener()) + event_bus.remove_event_listener(EventListener()) def test_publish_event(self): # Given mock_handler = Mock() mock_handler.return_value = None - EventBus.add_event_listeners([EventListener(handler=mock_handler)]) + event_bus.add_event_listeners([EventListener(handler=mock_handler)]) mock_event = MockEvent() # When - EventBus.publish_event(mock_event) + event_bus.publish_event(mock_event) # Then mock_handler.assert_called_once_with(mock_event) diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index f3d9823d3..50763e0c3 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -4,7 +4,6 @@ from griptape.events import ( CompletionChunkEvent, - EventBus, EventListener, FinishActionsSubtaskEvent, FinishPromptEvent, @@ -14,6 +13,7 @@ StartPromptEvent, StartStructureRunEvent, StartTaskEvent, + event_bus, ) from griptape.events.base_event import BaseEvent from griptape.structures import Pipeline @@ -38,7 +38,7 @@ def test_untyped_listeners(self, pipeline): event_handler_1 = Mock() event_handler_2 = Mock() - EventBus.add_event_listeners([EventListener(handler=event_handler_1), EventListener(handler=event_handler_2)]) + event_bus.add_event_listeners([EventListener(handler=event_handler_1), EventListener(handler=event_handler_2)]) # can't mock subtask events, so must manually call pipeline.tasks[0].subtasks[0].before_run() @@ -59,7 +59,7 @@ def test_typed_listeners(self, pipeline): finish_structure_run_event_handler = Mock() completion_chunk_handler = Mock() - EventBus.add_event_listeners( + event_bus.add_event_listeners( [ EventListener(start_prompt_event_handler, event_types=[StartPromptEvent]), EventListener(finish_prompt_event_handler, event_types=[FinishPromptEvent]), @@ -89,25 +89,25 @@ def test_typed_listeners(self, pipeline): completion_chunk_handler.assert_called_once() def test_add_remove_event_listener(self, pipeline): - EventBus.clear_event_listeners() + event_bus.clear_event_listeners() mock1 = Mock() mock2 = Mock() # duplicate event listeners will only get added once - event_listener_1 = EventBus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) - EventBus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) + event_listener_1 = event_bus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) + event_bus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) - event_listener_3 = EventBus.add_event_listener(EventListener(mock1, event_types=[FinishPromptEvent])) - event_listener_4 = EventBus.add_event_listener(EventListener(mock2, event_types=[StartPromptEvent])) + event_listener_3 = event_bus.add_event_listener(EventListener(mock1, event_types=[FinishPromptEvent])) + event_listener_4 = event_bus.add_event_listener(EventListener(mock2, event_types=[StartPromptEvent])) - event_listener_5 = EventBus.add_event_listener(EventListener(mock2)) + event_listener_5 = event_bus.add_event_listener(EventListener(mock2)) - assert len(EventBus.event_listeners) == 4 + assert len(event_bus.event_listeners) == 4 - EventBus.remove_event_listener(event_listener_1) - EventBus.remove_event_listener(event_listener_3) - EventBus.remove_event_listener(event_listener_4) - EventBus.remove_event_listener(event_listener_5) - assert len(EventBus.event_listeners) == 0 + event_bus.remove_event_listener(event_listener_1) + event_bus.remove_event_listener(event_listener_3) + event_bus.remove_event_listener(event_listener_4) + event_bus.remove_event_listener(event_listener_5) + assert len(event_bus.event_listeners) == 0 def test_publish_event(self): mock_event_listener_driver = Mock() diff --git a/tests/unit/tasks/test_base_task.py b/tests/unit/tasks/test_base_task.py index d6e4da8b6..aa402bb48 100644 --- a/tests/unit/tasks/test_base_task.py +++ b/tests/unit/tasks/test_base_task.py @@ -3,7 +3,7 @@ import pytest from griptape.artifacts import TextArtifact -from griptape.events import EventBus +from griptape.events import event_bus from griptape.events.event_listener import EventListener from griptape.structures import Agent, Workflow from griptape.tasks import ActionsSubtask @@ -16,7 +16,7 @@ class TestBaseTask: @pytest.fixture() def task(self): - EventBus.add_event_listeners([EventListener(handler=Mock())]) + event_bus.add_event_listeners([EventListener(handler=Mock())]) agent = Agent( prompt_driver=MockPromptDriver(), embedding_driver=MockEmbeddingDriver(), @@ -118,4 +118,4 @@ def test_children_property_no_structure(self, task): def test_execute_publish_events(self, task): task.execute() - assert EventBus.event_listeners[0].handler.call_count == 2 + assert event_bus.event_listeners[0].handler.call_count == 2 From 4d491c234f9fed6abbd7fe00fcb8858b2ab49c26 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 8 Aug 2024 10:29:50 -0700 Subject: [PATCH 32/63] Fix doc --- docs/griptape-framework/misc/events.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/griptape-framework/misc/events.md b/docs/griptape-framework/misc/events.md index b3f4a77fd..ebab3c460 100644 --- a/docs/griptape-framework/misc/events.md +++ b/docs/griptape-framework/misc/events.md @@ -80,10 +80,11 @@ def handler1(event: BaseEvent): def handler2(event: BaseEvent): print("Handler 2", event.__class__) -event_bus.event_listeners=[ +event_bus.add_event_listeners([ EventListener(handler1), EventListener(handler2), ] +) agent = Agent() From 863bcde112e3bb360950448cef4a115de406fbb9 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 8 Aug 2024 11:05:52 -0700 Subject: [PATCH 33/63] Fix test --- tests/unit/drivers/prompt/test_base_prompt_driver.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index 52b7d5c0d..5b6b0c600 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -1,7 +1,7 @@ from griptape.artifacts import ErrorArtifact, TextArtifact from griptape.common import Message, PromptStack from griptape.events import FinishPromptEvent, StartPromptEvent -from griptape.events.event_bus import _event_bus +from griptape.events.event_bus import _EventBus from griptape.structures import Pipeline from griptape.tasks import PromptTask, ToolkitTask from tests.mocks.mock_failing_prompt_driver import MockFailingPromptDriver @@ -27,7 +27,7 @@ def test_run_via_pipeline_retries_failure(self): assert isinstance(pipeline.run().output_task.output, ErrorArtifact) def test_run_via_pipeline_publishes_events(self, mocker): - mock_publish_event = mocker.patch.object(_event_bus, "publish_event") + mock_publish_event = mocker.patch.object(_EventBus, "publish_event") driver = MockPromptDriver() pipeline = Pipeline(prompt_driver=driver) pipeline.add_task(PromptTask("test")) From 12efa677c7ffa413877d56d18fca5bd637c3e485 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 8 Aug 2024 12:11:31 -0700 Subject: [PATCH 34/63] Fix doc --- docs/griptape-framework/misc/events.md | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/docs/griptape-framework/misc/events.md b/docs/griptape-framework/misc/events.md index bfc8ee8ec..e70909074 100644 --- a/docs/griptape-framework/misc/events.md +++ b/docs/griptape-framework/misc/events.md @@ -84,10 +84,7 @@ event_bus.add_event_listeners([ EventListener(handler1), EventListener(handler2), ] - -agent = Agent() - -agent = Agent() +) agent = Agent() From 0d5ce93d6ceb5c6b6ba1ace203932e9179f2ea7b Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 8 Aug 2024 12:18:40 -0700 Subject: [PATCH 35/63] Update changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 04b592c4e..375773c3b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,7 +21,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: Removed `EventPublisherMixin`. - **BREAKING**: Removed `Workflow.prompt_driver` and `Workflow.prompt_driver`. `Agent.prompt_driver` has not been removed. - **BREAKING**: Removed `Structure.embedding_driver`, set this via `griptape.config.config.drivers.embedding` instead. +- **BREAKING**: Removed `Structure.custom_logger` and `Structure.logger_level`, set these via `griptape.config.config.logger` instead. - `BaseTask.add_parent/child` will now call `self.structure.add_task` if possible. +- All Task and Engines that previously required Drivers now pull from `griptape.config.config.drivers` by default. ## [0.29.0] - 2024-07-30 From daa171031416d36eb846a92b803e8f4b40d09939 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 8 Aug 2024 12:47:46 -0700 Subject: [PATCH 36/63] Update changelog --- CHANGELOG.md | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 375773c3b..a528d9f81 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,19 +11,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Ability to set custom schema properties on Tool Activities via `extra_schema_properties`. - Parameter `structure` to `BaseTask`. - Method `try_find_task` to `Structure`. -- Global event bus, `griptape.events.event_bus`, for publishing and subscribing to events. -- Global config, `griptape.config.config`, for setting global configuration defaults. - `TranslateQueryRagModule` `RagEngine` module for translating input queries. - Global event bus, `griptape.events.event_bus`, for publishing and subscribing to events. +- Global config, `griptape.config.config`, for setting global configuration defaults. ### Changed - **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `event_bus`. - **BREAKING**: Removed `EventPublisherMixin`. -- **BREAKING**: Removed `Workflow.prompt_driver` and `Workflow.prompt_driver`. `Agent.prompt_driver` has not been removed. +- **BREAKING**: Removed `Pipeline.prompt_driver` and `Workflow.prompt_driver`. `Agent.prompt_driver` has not been removed. +- **BREAKING**: Removed `Pipeline.stream` and `Workflow.stream`. `Agent.stream` has not been removed. - **BREAKING**: Removed `Structure.embedding_driver`, set this via `griptape.config.config.drivers.embedding` instead. - **BREAKING**: Removed `Structure.custom_logger` and `Structure.logger_level`, set these via `griptape.config.config.logger` instead. +- **BREAKING**: Removed `BaseStructureConfig.merge_config`. +- **BREAKING**: Renamed `StructureConfig` to `DriverConfig`, and renamed fields accordingly. +- Engines that previously required Drivers now pull from `griptape.config.config.drivers` by default. - `BaseTask.add_parent/child` will now call `self.structure.add_task` if possible. -- All Task and Engines that previously required Drivers now pull from `griptape.config.config.drivers` by default. ## [0.29.0] - 2024-07-30 From 11f9ac8da0701ae43fde15dcdec1ef62fc75fe83 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 9 Aug 2024 09:24:55 -0700 Subject: [PATCH 37/63] Fix type errors --- Makefile | 2 +- .../drivers/src/prompt_drivers_1.py | 5 +---- .../drivers/src/prompt_drivers_10.py | 7 ++----- .../drivers/src/prompt_drivers_11.py | 9 +++------ .../drivers/src/prompt_drivers_12.py | 9 +++------ .../drivers/src/prompt_drivers_13.py | 7 ++----- .../drivers/src/prompt_drivers_14.py | 9 +++------ .../drivers/src/prompt_drivers_3.py | 15 ++++++--------- .../drivers/src/prompt_drivers_4.py | 7 ++----- .../drivers/src/prompt_drivers_5.py | 13 +++++-------- .../drivers/src/prompt_drivers_6.py | 9 +++------ .../drivers/src/prompt_drivers_7.py | 9 +++------ .../drivers/src/prompt_drivers_8.py | 9 +++------ .../drivers/src/prompt_drivers_9.py | 7 ++----- docs/griptape-framework/misc/src/events_3.py | 12 +++++++----- 15 files changed, 46 insertions(+), 83 deletions(-) diff --git a/Makefile b/Makefile index f1db966f0..73175b7c5 100644 --- a/Makefile +++ b/Makefile @@ -68,7 +68,7 @@ check/lint: .PHONY: check/types check/types: - @poetry run pyright griptape/ docs/**/src/** + @poetry run pyright griptape $(shell find docs -type f -path "*/src/*") .PHONY: check/spell check/spell: diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_1.py b/docs/griptape-framework/drivers/src/prompt_drivers_1.py index 978435f2d..ab5273228 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_1.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_1.py @@ -1,12 +1,9 @@ -from griptape.config import StructureConfig from griptape.drivers import OpenAiChatPromptDriver from griptape.rules import Rule from griptape.structures import Agent agent = Agent( - config=StructureConfig( - prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", temperature=0.3), - ), + prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", temperature=0.3), input="You will be provided with a tweet, and your task is to classify its sentiment as positive, neutral, or negative. Tweet: {{ args[0] }}", rules=[Rule(value="Output only the sentiment.")], ) diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_10.py b/docs/griptape-framework/drivers/src/prompt_drivers_10.py index 02f083570..04e2acb35 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_10.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_10.py @@ -1,13 +1,10 @@ -from griptape.config import StructureConfig from griptape.drivers import OllamaPromptDriver from griptape.structures import Agent from griptape.tools import Calculator agent = Agent( - config=StructureConfig( - prompt_driver=OllamaPromptDriver( - model="llama3.1", - ), + prompt_driver=OllamaPromptDriver( + model="llama3.1", ), tools=[Calculator()], ) diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_11.py b/docs/griptape-framework/drivers/src/prompt_drivers_11.py index 1c81c4785..9e838473c 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_11.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_11.py @@ -1,16 +1,13 @@ import os -from griptape.config import StructureConfig from griptape.drivers import HuggingFaceHubPromptDriver from griptape.rules import Rule, Ruleset from griptape.structures import Agent agent = Agent( - config=StructureConfig( - prompt_driver=HuggingFaceHubPromptDriver( - model="HuggingFaceH4/zephyr-7b-beta", - api_token=os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"], - ) + prompt_driver=HuggingFaceHubPromptDriver( + model="HuggingFaceH4/zephyr-7b-beta", + api_token=os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"], ), rulesets=[ Ruleset( diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_12.py b/docs/griptape-framework/drivers/src/prompt_drivers_12.py index d6f59f96e..d555c32c9 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_12.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_12.py @@ -1,15 +1,12 @@ import os -from griptape.config import StructureConfig from griptape.drivers import HuggingFaceHubPromptDriver from griptape.structures import Agent agent = Agent( - config=StructureConfig( - prompt_driver=HuggingFaceHubPromptDriver( - model="http://127.0.0.1:8080", - api_token=os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"], - ), + prompt_driver=HuggingFaceHubPromptDriver( + model="http://127.0.0.1:8080", + api_token=os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"], ), ) diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_13.py b/docs/griptape-framework/drivers/src/prompt_drivers_13.py index e4fe5208c..d3ddd9093 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_13.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_13.py @@ -1,13 +1,10 @@ -from griptape.config import StructureConfig from griptape.drivers import HuggingFacePipelinePromptDriver from griptape.rules import Rule, Ruleset from griptape.structures import Agent agent = Agent( - config=StructureConfig( - prompt_driver=HuggingFacePipelinePromptDriver( - model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", - ) + prompt_driver=HuggingFacePipelinePromptDriver( + model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", ), rulesets=[ Ruleset( diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_14.py b/docs/griptape-framework/drivers/src/prompt_drivers_14.py index 85bd5216e..228a5f9b2 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_14.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_14.py @@ -1,17 +1,14 @@ import os -from griptape.config import StructureConfig from griptape.drivers import ( AmazonSageMakerJumpstartPromptDriver, ) from griptape.structures import Agent agent = Agent( - config=StructureConfig( - prompt_driver=AmazonSageMakerJumpstartPromptDriver( - endpoint=os.environ["SAGEMAKER_LLAMA_3_INSTRUCT_ENDPOINT_NAME"], - model="meta-llama/Meta-Llama-3-8B-Instruct", - ) + prompt_driver=AmazonSageMakerJumpstartPromptDriver( + endpoint=os.environ["SAGEMAKER_LLAMA_3_INSTRUCT_ENDPOINT_NAME"], + model="meta-llama/Meta-Llama-3-8B-Instruct", ) ) diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_3.py b/docs/griptape-framework/drivers/src/prompt_drivers_3.py index b92596aca..8e85ce887 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_3.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_3.py @@ -1,19 +1,16 @@ import os -from griptape.config import StructureConfig from griptape.drivers import OpenAiChatPromptDriver from griptape.rules import Rule from griptape.structures import Agent agent = Agent( - config=StructureConfig( - prompt_driver=OpenAiChatPromptDriver( - api_key=os.environ["OPENAI_API_KEY"], - temperature=0.1, - model="gpt-4o", - response_format="json_object", - seed=42, - ) + prompt_driver=OpenAiChatPromptDriver( + api_key=os.environ["OPENAI_API_KEY"], + temperature=0.1, + model="gpt-4o", + response_format="json_object", + seed=42, ), input="You will be provided with a description of a mood, and your task is to generate the CSS code for a color that matches it. Description: {{ args[0] }}", rules=[Rule(value='Write your output in json with a single key called "css_code".')], diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_4.py b/docs/griptape-framework/drivers/src/prompt_drivers_4.py index b024638b7..bcafb40de 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_4.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_4.py @@ -1,13 +1,10 @@ -from griptape.config import StructureConfig from griptape.drivers import OpenAiChatPromptDriver from griptape.rules import Rule from griptape.structures import Agent agent = Agent( - config=StructureConfig( - prompt_driver=OpenAiChatPromptDriver( - base_url="http://127.0.0.1:1234/v1", model="lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF", stream=True - ) + prompt_driver=OpenAiChatPromptDriver( + base_url="http://127.0.0.1:1234/v1", model="lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF", stream=True ), rules=[Rule(value="You are a helpful coding assistant.")], ) diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_5.py b/docs/griptape-framework/drivers/src/prompt_drivers_5.py index ffe9a4e0a..76301d8d9 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_5.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_5.py @@ -1,18 +1,15 @@ import os -from griptape.config import StructureConfig from griptape.drivers import AzureOpenAiChatPromptDriver from griptape.rules import Rule from griptape.structures import Agent agent = Agent( - config=StructureConfig( - prompt_driver=AzureOpenAiChatPromptDriver( - api_key=os.environ["AZURE_OPENAI_API_KEY_1"], - model="gpt-3.5-turbo", - azure_deployment=os.environ["AZURE_OPENAI_35_TURBO_DEPLOYMENT_ID"], - azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_1"], - ) + prompt_driver=AzureOpenAiChatPromptDriver( + api_key=os.environ["AZURE_OPENAI_API_KEY_1"], + model="gpt-3.5-turbo", + azure_deployment=os.environ["AZURE_OPENAI_35_TURBO_DEPLOYMENT_ID"], + azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_1"], ), rules=[ Rule( diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_6.py b/docs/griptape-framework/drivers/src/prompt_drivers_6.py index 2bd1b00fb..5e4d226a6 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_6.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_6.py @@ -1,15 +1,12 @@ import os -from griptape.config import StructureConfig from griptape.drivers import CoherePromptDriver from griptape.structures import Agent agent = Agent( - config=StructureConfig( - prompt_driver=CoherePromptDriver( - model="command-r", - api_key=os.environ["COHERE_API_KEY"], - ) + prompt_driver=CoherePromptDriver( + model="command-r", + api_key=os.environ["COHERE_API_KEY"], ) ) diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_7.py b/docs/griptape-framework/drivers/src/prompt_drivers_7.py index dd1c15370..23f3d0c35 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_7.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_7.py @@ -1,15 +1,12 @@ import os -from griptape.config import StructureConfig from griptape.drivers import AnthropicPromptDriver from griptape.structures import Agent agent = Agent( - config=StructureConfig( - prompt_driver=AnthropicPromptDriver( - model="claude-3-opus-20240229", - api_key=os.environ["ANTHROPIC_API_KEY"], - ) + prompt_driver=AnthropicPromptDriver( + model="claude-3-opus-20240229", + api_key=os.environ["ANTHROPIC_API_KEY"], ) ) diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_8.py b/docs/griptape-framework/drivers/src/prompt_drivers_8.py index 1bbf2848c..b6a1c109e 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_8.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_8.py @@ -1,15 +1,12 @@ import os -from griptape.config import StructureConfig from griptape.drivers import GooglePromptDriver from griptape.structures import Agent agent = Agent( - config=StructureConfig( - prompt_driver=GooglePromptDriver( - model="gemini-pro", - api_key=os.environ["GOOGLE_API_KEY"], - ) + prompt_driver=GooglePromptDriver( + model="gemini-pro", + api_key=os.environ["GOOGLE_API_KEY"], ) ) diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_9.py b/docs/griptape-framework/drivers/src/prompt_drivers_9.py index cdd0db82d..992dbecd2 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_9.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_9.py @@ -1,13 +1,10 @@ -from griptape.config import StructureConfig from griptape.drivers import AmazonBedrockPromptDriver from griptape.rules import Rule from griptape.structures import Agent agent = Agent( - config=StructureConfig( - prompt_driver=AmazonBedrockPromptDriver( - model="anthropic.claude-3-sonnet-20240229-v1:0", - ) + prompt_driver=AmazonBedrockPromptDriver( + model="anthropic.claude-3-sonnet-20240229-v1:0", ), rules=[ Rule( diff --git a/docs/griptape-framework/misc/src/events_3.py b/docs/griptape-framework/misc/src/events_3.py index a99a412eb..ab995e018 100644 --- a/docs/griptape-framework/misc/src/events_3.py +++ b/docs/griptape-framework/misc/src/events_3.py @@ -1,4 +1,6 @@ -from griptape.config import OpenAiDriverConfig +from typing import cast + +from griptape.config import OpenAiDriverConfig, config from griptape.drivers import OpenAiChatPromptDriver from griptape.events import CompletionChunkEvent, EventListener, event_bus from griptape.structures import Pipeline @@ -8,15 +10,15 @@ event_bus.add_event_listeners( [ EventListener( - lambda e: print(e.token, end="", flush=True), + lambda e: print(cast(CompletionChunkEvent, e).token, end="", flush=True), event_types=[CompletionChunkEvent], ) ] ) -pipeline = Pipeline( - config=OpenAiDriverConfig(prompt=OpenAiChatPromptDriver(model="gpt-4o", stream=True)), -) +config.drivers = OpenAiDriverConfig(prompt=OpenAiChatPromptDriver(model="gpt-4o", stream=True)) + +pipeline = Pipeline() pipeline.add_tasks( ToolkitTask( From 6364cb776e873390b3f0a78e76ff1536f69c3722 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 9 Aug 2024 10:37:14 -0700 Subject: [PATCH 38/63] Remove engine functionality from artifact storage, delete Task Memory --- .../structures/src/task_memory_6.py | 17 ------- .../task/storage/base_artifact_storage.py | 10 +--- .../task/storage/blob_artifact_storage.py | 10 +--- .../task/storage/text_artifact_storage.py | 50 ++----------------- griptape/memory/task/task_memory.py | 29 ++++------- griptape/structures/structure.py | 47 +---------------- griptape/tools/task_memory_client/tool.py | 25 +++------- griptape/utils/load_artifact_from_memory.py | 12 +++-- 8 files changed, 33 insertions(+), 167 deletions(-) diff --git a/docs/griptape-framework/structures/src/task_memory_6.py b/docs/griptape-framework/structures/src/task_memory_6.py index 88d10971d..5a81ee8cd 100644 --- a/docs/griptape-framework/structures/src/task_memory_6.py +++ b/docs/griptape-framework/structures/src/task_memory_6.py @@ -8,9 +8,6 @@ OpenAiChatPromptDriver, OpenAiEmbeddingDriver, ) -from griptape.engines.rag import RagEngine -from griptape.engines.rag.modules import PromptResponseRagModule, VectorStoreRetrievalRagModule -from griptape.engines.rag.stages import ResponseRagStage, RetrievalRagStage from griptape.memory import TaskMemory from griptape.memory.task.storage import TextArtifactStorage from griptape.structures import Agent @@ -26,20 +23,6 @@ task_memory=TaskMemory( artifact_storages={ TextArtifact: TextArtifactStorage( - rag_engine=RagEngine( - retrieval_stage=RetrievalRagStage( - retrieval_modules=[ - VectorStoreRetrievalRagModule( - vector_store_driver=vector_store_driver, - query_params={"namespace": "griptape", "count": 20}, - ) - ] - ), - response_stage=ResponseRagStage( - response_module=PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o")) - ), - ), - retrieval_rag_module_name="VectorStoreRetrievalRagModule", vector_store_driver=vector_store_driver, ) } diff --git a/griptape/memory/task/storage/base_artifact_storage.py b/griptape/memory/task/storage/base_artifact_storage.py index 866df19da..792f479bc 100644 --- a/griptape/memory/task/storage/base_artifact_storage.py +++ b/griptape/memory/task/storage/base_artifact_storage.py @@ -1,12 +1,12 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from attrs import define if TYPE_CHECKING: - from griptape.artifacts import BaseArtifact, InfoArtifact, ListArtifact, TextArtifact + from griptape.artifacts import BaseArtifact, ListArtifact @define @@ -19,9 +19,3 @@ def load_artifacts(self, namespace: str) -> ListArtifact: ... @abstractmethod def can_store(self, artifact: BaseArtifact) -> bool: ... - - @abstractmethod - def summarize(self, namespace: str) -> TextArtifact | InfoArtifact: ... - - @abstractmethod - def query(self, namespace: str, query: str, metadata: Any = None) -> BaseArtifact: ... diff --git a/griptape/memory/task/storage/blob_artifact_storage.py b/griptape/memory/task/storage/blob_artifact_storage.py index 6199dc3a3..3f4309481 100644 --- a/griptape/memory/task/storage/blob_artifact_storage.py +++ b/griptape/memory/task/storage/blob_artifact_storage.py @@ -1,10 +1,8 @@ from __future__ import annotations -from typing import Any - from attrs import define, field -from griptape.artifacts import BaseArtifact, BlobArtifact, InfoArtifact, ListArtifact +from griptape.artifacts import BaseArtifact, BlobArtifact, ListArtifact from griptape.memory.task.storage import BaseArtifactStorage @@ -26,9 +24,3 @@ def store_artifact(self, namespace: str, artifact: BaseArtifact) -> None: def load_artifacts(self, namespace: str) -> ListArtifact: return ListArtifact(next((blobs for key, blobs in self.blobs.items() if key == namespace), [])) - - def summarize(self, namespace: str) -> InfoArtifact: - return InfoArtifact("can't summarize artifacts") - - def query(self, namespace: str, query: str, metadata: Any = None) -> BaseArtifact: - return InfoArtifact("can't query artifacts") diff --git a/griptape/memory/task/storage/text_artifact_storage.py b/griptape/memory/task/storage/text_artifact_storage.py index ded114213..623c176ea 100644 --- a/griptape/memory/task/storage/text_artifact_storage.py +++ b/griptape/memory/task/storage/text_artifact_storage.py @@ -1,32 +1,20 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING -from attrs import Attribute, Factory, define, field +from attrs import Factory, define, field -from griptape.artifacts import BaseArtifact, InfoArtifact, ListArtifact, TextArtifact +from griptape.artifacts import BaseArtifact, ListArtifact, TextArtifact from griptape.config import config -from griptape.engines.rag import RagContext, RagEngine from griptape.memory.task.storage import BaseArtifactStorage if TYPE_CHECKING: from griptape.drivers import BaseVectorStoreDriver - from griptape.engines import BaseSummaryEngine, CsvExtractionEngine, JsonExtractionEngine @define(kw_only=True) class TextArtifactStorage(BaseArtifactStorage): vector_store_driver: BaseVectorStoreDriver = field(default=Factory(lambda: config.drivers.vector_store)) - rag_engine: Optional[RagEngine] = field(default=None) - retrieval_rag_module_name: Optional[str] = field(default=None) - summary_engine: Optional[BaseSummaryEngine] = field(default=None) - csv_extraction_engine: Optional[CsvExtractionEngine] = field(default=None) - json_extraction_engine: Optional[JsonExtractionEngine] = field(default=None) - - @rag_engine.validator # pyright: ignore[reportAttributeAccessIssue] - def validate_rag_engine(self, _: Attribute, rag_engine: str) -> None: - if rag_engine is not None and self.retrieval_rag_module_name is None: - raise ValueError("You have to set retrieval_rag_module_name if rag_engine is provided") def can_store(self, artifact: BaseArtifact) -> bool: return isinstance(artifact, TextArtifact) @@ -39,35 +27,3 @@ def store_artifact(self, namespace: str, artifact: BaseArtifact) -> None: def load_artifacts(self, namespace: str) -> ListArtifact: return self.vector_store_driver.load_artifacts(namespace=namespace) - - def summarize(self, namespace: str) -> TextArtifact: - if self.summary_engine is None: - raise ValueError("Summary engine is not set.") - - return self.summary_engine.summarize_artifacts(self.load_artifacts(namespace)) - - def query(self, namespace: str, query: str, metadata: Any = None) -> BaseArtifact: - if self.rag_engine is None: - raise ValueError("rag_engine is not set") - - if self.retrieval_rag_module_name is None: - raise ValueError("retrieval_rag_module_name is not set") - - result = self.rag_engine.process( - RagContext( - query=query, - module_configs={ - self.retrieval_rag_module_name: { - "query_params": { - "namespace": namespace, - "metadata": None if metadata is None else str(metadata), - }, - }, - }, - ), - ).output - - if result is None: - return InfoArtifact("Empty output") - else: - return result diff --git a/griptape/memory/task/task_memory.py b/griptape/memory/task/task_memory.py index e2131d1f0..c7f12b233 100644 --- a/griptape/memory/task/task_memory.py +++ b/griptape/memory/task/task_memory.py @@ -4,8 +4,9 @@ from attrs import Attribute, Factory, define, field -from griptape.artifacts import BaseArtifact, ErrorArtifact, InfoArtifact, ListArtifact, TextArtifact +from griptape.artifacts import BaseArtifact, BlobArtifact, ErrorArtifact, InfoArtifact, ListArtifact, TextArtifact from griptape.memory.meta import ActionSubtaskMetaEntry +from griptape.memory.task.storage import BlobArtifactStorage, TextArtifactStorage from griptape.mixins import ActivityMixin if TYPE_CHECKING: @@ -16,7 +17,15 @@ @define class TaskMemory(ActivityMixin): name: str = field(default=Factory(lambda self: self.__class__.__name__, takes_self=True), kw_only=True) - artifact_storages: dict[type, BaseArtifactStorage] = field(factory=dict, kw_only=True) + artifact_storages: dict[type, BaseArtifactStorage] = field( + default=Factory( + lambda: { + TextArtifact: TextArtifactStorage(), + BlobArtifact: BlobArtifactStorage(), + } + ), + kw_only=True, + ) namespace_storage: dict[str, BaseArtifactStorage] = field(factory=dict, kw_only=True) namespace_metadata: dict[str, Any] = field(factory=dict, kw_only=True) @@ -123,19 +132,3 @@ def find_input_memory(self, memory_name: str) -> Optional[TaskMemory]: return self else: return None - - def summarize_namespace(self, namespace: str) -> TextArtifact | InfoArtifact: - storage = self.namespace_storage.get(namespace) - - if storage: - return storage.summarize(namespace) - else: - return InfoArtifact("Can't find memory content") - - def query_namespace(self, namespace: str, query: str) -> BaseArtifact: - storage = self.namespace_storage.get(namespace) - - if storage: - return storage.query(namespace=namespace, query=query, metadata=self.namespace_metadata.get(namespace)) - else: - return InfoArtifact("Can't find memory content") diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index b7ca84c4f..a18e9d578 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -6,25 +6,14 @@ from attrs import Attribute, Factory, define, field -from griptape.artifacts import BaseArtifact, BlobArtifact, TextArtifact from griptape.common import observable -from griptape.config import config -from griptape.engines import CsvExtractionEngine, JsonExtractionEngine, PromptSummaryEngine -from griptape.engines.rag import RagEngine -from griptape.engines.rag.modules import ( - MetadataBeforeResponseRagModule, - PromptResponseRagModule, - RulesetsBeforeResponseRagModule, - VectorStoreRetrievalRagModule, -) -from griptape.engines.rag.stages import ResponseRagStage, RetrievalRagStage from griptape.events import FinishStructureRunEvent, StartStructureRunEvent, event_bus from griptape.memory import TaskMemory from griptape.memory.meta import MetaMemory from griptape.memory.structure import ConversationMemory -from griptape.memory.task.storage import BlobArtifactStorage, TextArtifactStorage if TYPE_CHECKING: + from griptape.artifacts import BaseArtifact from griptape.memory.structure import BaseConversationMemory from griptape.rules import Rule, Ruleset from griptape.tasks import BaseTask @@ -40,9 +29,8 @@ class Structure(ABC): default=Factory(lambda: ConversationMemory()), kw_only=True, ) - rag_engine: RagEngine = field(default=Factory(lambda self: self.default_rag_engine, takes_self=True), kw_only=True) task_memory: TaskMemory = field( - default=Factory(lambda self: self.default_task_memory, takes_self=True), + default=Factory(lambda self: TaskMemory(), takes_self=True), kw_only=True, ) meta_memory: MetaMemory = field(default=Factory(lambda: MetaMemory()), kw_only=True) @@ -96,37 +84,6 @@ def output(self) -> Optional[BaseArtifact]: def finished_tasks(self) -> list[BaseTask]: return [s for s in self.tasks if s.is_finished()] - @property - def default_rag_engine(self) -> RagEngine: - return RagEngine( - retrieval_stage=RetrievalRagStage( - retrieval_modules=[VectorStoreRetrievalRagModule()], - ), - response_stage=ResponseRagStage( - before_response_modules=[ - RulesetsBeforeResponseRagModule(rulesets=self.rulesets), - MetadataBeforeResponseRagModule(), - ], - response_module=PromptResponseRagModule(), - ), - ) - - @property - def default_task_memory(self) -> TaskMemory: - return TaskMemory( - artifact_storages={ - TextArtifact: TextArtifactStorage( - rag_engine=self.rag_engine, - retrieval_rag_module_name="VectorStoreRetrievalRagModule", - vector_store_driver=config.drivers.vector_store, - summary_engine=PromptSummaryEngine(prompt_driver=config.drivers.prompt), - csv_extraction_engine=CsvExtractionEngine(prompt_driver=config.drivers.prompt), - json_extraction_engine=JsonExtractionEngine(prompt_driver=config.drivers.prompt), - ), - BlobArtifact: BlobArtifactStorage(), - }, - ) - def is_finished(self) -> bool: return all(s.is_finished() for s in self.tasks) diff --git a/griptape/tools/task_memory_client/tool.py b/griptape/tools/task_memory_client/tool.py index 160a54d85..a20d63506 100644 --- a/griptape/tools/task_memory_client/tool.py +++ b/griptape/tools/task_memory_client/tool.py @@ -1,12 +1,16 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from attrs import define from schema import Literal, Schema -from griptape.artifacts import BaseArtifact, ErrorArtifact, InfoArtifact, TextArtifact from griptape.tools import BaseTool from griptape.utils.decorators import activity +if TYPE_CHECKING: + from griptape.artifacts import BaseArtifact, ErrorArtifact, InfoArtifact, TextArtifact + @define class TaskMemoryClient(BaseTool): @@ -16,14 +20,7 @@ class TaskMemoryClient(BaseTool): "schema": Schema({"memory_name": str, "artifact_namespace": str}), }, ) - def summarize(self, params: dict) -> TextArtifact | InfoArtifact | ErrorArtifact: - memory = self.find_input_memory(params["values"]["memory_name"]) - artifact_namespace = params["values"]["artifact_namespace"] - - if memory: - return memory.summarize_namespace(artifact_namespace) - else: - return ErrorArtifact("memory not found") + def summarize(self, params: dict) -> TextArtifact | InfoArtifact | ErrorArtifact: ... @activity( config={ @@ -41,12 +38,4 @@ def summarize(self, params: dict) -> TextArtifact | InfoArtifact | ErrorArtifact ), }, ) - def query(self, params: dict) -> BaseArtifact: - memory = self.find_input_memory(params["values"]["memory_name"]) - artifact_namespace = params["values"]["artifact_namespace"] - query = params["values"]["query"] - - if memory: - return memory.query_namespace(namespace=artifact_namespace, query=query) - else: - return ErrorArtifact("memory not found") + def query(self, params: dict) -> BaseArtifact: ... diff --git a/griptape/utils/load_artifact_from_memory.py b/griptape/utils/load_artifact_from_memory.py index a45a41dbd..ec260787a 100644 --- a/griptape/utils/load_artifact_from_memory.py +++ b/griptape/utils/load_artifact_from_memory.py @@ -1,5 +1,10 @@ -from griptape.artifacts import BaseArtifact -from griptape.memory import TaskMemory +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from griptape.artifacts import BaseArtifact + from griptape.memory import TaskMemory def load_artifact_from_memory( @@ -8,9 +13,6 @@ def load_artifact_from_memory( artifact_name: str, artifact_type: type, ) -> BaseArtifact: - if memory is None: - raise ValueError("memory not found") - artifacts = memory.load_artifacts(namespace=artifact_namespace) if len(artifacts) == 0: raise ValueError("no artifacts found in namespace") From b4f6d610bcae15866c9d2f1e5913c5b916841fa4 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 9 Aug 2024 10:46:28 -0700 Subject: [PATCH 39/63] Create Summary and Extraction Tools --- .../extraction/base_extraction_engine.py | 8 +- .../extraction/csv_extraction_engine.py | 31 ++----- .../extraction/json_extraction_engine.py | 27 ++---- griptape/tools/__init__.py | 4 + griptape/tools/extraction_client/__init__.py | 0 griptape/tools/extraction_client/manifest.yml | 5 ++ .../tools/extraction_client/requirements.txt | 0 griptape/tools/extraction_client/tool.py | 88 +++++++++++++++++++ .../tools/prompt_summary_client/__init__.py | 0 .../tools/prompt_summary_client/manifest.yml | 5 ++ .../prompt_summary_client/requirements.txt | 0 griptape/tools/prompt_summary_client/tool.py | 58 ++++++++++++ tests/unit/tools/test_extraction_client.py | 73 +++++++++++++++ .../unit/tools/test_prompt_summary_client.py | 29 ++++++ tests/utils/defaults.py | 7 -- 15 files changed, 281 insertions(+), 54 deletions(-) create mode 100644 griptape/tools/extraction_client/__init__.py create mode 100644 griptape/tools/extraction_client/manifest.yml create mode 100644 griptape/tools/extraction_client/requirements.txt create mode 100644 griptape/tools/extraction_client/tool.py create mode 100644 griptape/tools/prompt_summary_client/__init__.py create mode 100644 griptape/tools/prompt_summary_client/manifest.yml create mode 100644 griptape/tools/prompt_summary_client/requirements.txt create mode 100644 griptape/tools/prompt_summary_client/tool.py create mode 100644 tests/unit/tools/test_extraction_client.py create mode 100644 tests/unit/tools/test_prompt_summary_client.py diff --git a/griptape/engines/extraction/base_extraction_engine.py b/griptape/engines/extraction/base_extraction_engine.py index 4b1184e5e..f4fcd5d3a 100644 --- a/griptape/engines/extraction/base_extraction_engine.py +++ b/griptape/engines/extraction/base_extraction_engine.py @@ -1,21 +1,21 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from attrs import Attribute, Factory, define, field from griptape.chunkers import BaseChunker, TextChunker from griptape.config import config +from griptape.mixins.rule_mixin import RuleMixin if TYPE_CHECKING: from griptape.artifacts import ErrorArtifact, ListArtifact from griptape.drivers import BasePromptDriver - from griptape.rules import Ruleset @define -class BaseExtractionEngine(ABC): +class BaseExtractionEngine(ABC, RuleMixin): max_token_multiplier: float = field(default=0.5, kw_only=True) chunk_joiner: str = field(default="\n\n", kw_only=True) prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.drivers.prompt), kw_only=True) @@ -49,7 +49,5 @@ def min_response_tokens(self) -> int: def extract( self, text: str | ListArtifact, - *, - rulesets: Optional[list[Ruleset]] = None, **kwargs, ) -> ListArtifact | ErrorArtifact: ... diff --git a/griptape/engines/extraction/csv_extraction_engine.py b/griptape/engines/extraction/csv_extraction_engine.py index 3184654b1..9297d8377 100644 --- a/griptape/engines/extraction/csv_extraction_engine.py +++ b/griptape/engines/extraction/csv_extraction_engine.py @@ -2,41 +2,31 @@ import csv import io -from typing import TYPE_CHECKING, Optional, cast +from typing import cast from attrs import Factory, define, field from griptape.artifacts import CsvRowArtifact, ErrorArtifact, ListArtifact, TextArtifact -from griptape.common import PromptStack -from griptape.common.prompt_stack.messages.message import Message +from griptape.common import Message, PromptStack from griptape.engines import BaseExtractionEngine from griptape.utils import J2 -if TYPE_CHECKING: - from griptape.rules import Ruleset - @define class CsvExtractionEngine(BaseExtractionEngine): + column_names: list[str] = field(default=Factory(list), kw_only=True) template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/csv_extraction.j2")), kw_only=True) def extract( self, text: str | ListArtifact, - *, - rulesets: Optional[list[Ruleset]] = None, - column_names: Optional[list[str]] = None, **kwargs, ) -> ListArtifact | ErrorArtifact: - if column_names is None: - column_names = [] try: return ListArtifact( self._extract_rec( cast(list[TextArtifact], text.value) if isinstance(text, ListArtifact) else [TextArtifact(text)], - column_names, [], - rulesets=rulesets, ), item_separator="\n", ) @@ -55,22 +45,19 @@ def text_to_csv_rows(self, text: str, column_names: list[str]) -> list[CsvRowArt def _extract_rec( self, artifacts: list[TextArtifact], - column_names: list[str], rows: list[CsvRowArtifact], - rulesets: Optional[list[Ruleset]] = None, ) -> list[CsvRowArtifact]: artifacts_text = self.chunk_joiner.join([a.value for a in artifacts]) full_text = self.template_generator.render( - column_names=column_names, text=artifacts_text, - rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets), + rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.all_rulesets), ) if self.prompt_driver.tokenizer.count_input_tokens_left(full_text) >= self.min_response_tokens: rows.extend( self.text_to_csv_rows( self.prompt_driver.run(PromptStack(messages=[Message(full_text, role=Message.USER_ROLE)])).value, - column_names, + self.column_names, ), ) @@ -78,16 +65,16 @@ def _extract_rec( else: chunks = self.chunker.chunk(artifacts_text) partial_text = self.template_generator.render( - column_names=column_names, + column_names=self.column_names, text=chunks[0].value, - rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets), + rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.rulesets), ) rows.extend( self.text_to_csv_rows( self.prompt_driver.run(PromptStack(messages=[Message(partial_text, role=Message.USER_ROLE)])).value, - column_names, + self.column_names, ), ) - return self._extract_rec(chunks[1:], column_names, rows, rulesets=rulesets) + return self._extract_rec(chunks[1:], rows) diff --git a/griptape/engines/extraction/json_extraction_engine.py b/griptape/engines/extraction/json_extraction_engine.py index d9c7cd4aa..8bd46ed8b 100644 --- a/griptape/engines/extraction/json_extraction_engine.py +++ b/griptape/engines/extraction/json_extraction_engine.py @@ -1,7 +1,7 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING, Optional, cast +from typing import cast from attrs import Factory, define, field @@ -11,33 +11,22 @@ from griptape.engines import BaseExtractionEngine from griptape.utils import J2 -if TYPE_CHECKING: - from griptape.rules import Ruleset - @define class JsonExtractionEngine(BaseExtractionEngine): + template_schema: dict = field(default=Factory(dict), kw_only=True) template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/json_extraction.j2")), kw_only=True) def extract( self, text: str | ListArtifact, - *, - rulesets: Optional[list[Ruleset]] = None, - template_schema: Optional[dict | list[dict]] = None, **kwargs, ) -> ListArtifact | ErrorArtifact: - if template_schema is None: - template_schema = [] try: - json_schema = json.dumps(template_schema) - return ListArtifact( self._extract_rec( cast(list[TextArtifact], text.value) if isinstance(text, ListArtifact) else [TextArtifact(text)], - json_schema, [], - rulesets=rulesets, ), item_separator="\n", ) @@ -50,15 +39,13 @@ def json_to_text_artifacts(self, json_input: str) -> list[TextArtifact]: def _extract_rec( self, artifacts: list[TextArtifact], - json_template_schema: str, extractions: list[TextArtifact], - rulesets: Optional[list[Ruleset]] = None, ) -> list[TextArtifact]: artifacts_text = self.chunk_joiner.join([a.value for a in artifacts]) full_text = self.template_generator.render( - json_template_schema=json_template_schema, + json_template_schema=json.dumps(self.template_schema), text=artifacts_text, - rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets), + rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.all_rulesets), ) if self.prompt_driver.tokenizer.count_input_tokens_left(full_text) >= self.min_response_tokens: @@ -72,9 +59,9 @@ def _extract_rec( else: chunks = self.chunker.chunk(artifacts_text) partial_text = self.template_generator.render( - template_schema=json_template_schema, + template_schema=self.template_schema, text=chunks[0].value, - rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets), + rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.rulesets), ) extractions.extend( @@ -83,4 +70,4 @@ def _extract_rec( ), ) - return self._extract_rec(chunks[1:], json_template_schema, extractions, rulesets=rulesets) + return self._extract_rec(chunks[1:], extractions) diff --git a/griptape/tools/__init__.py b/griptape/tools/__init__.py index d99b63b6c..6c566d1eb 100644 --- a/griptape/tools/__init__.py +++ b/griptape/tools/__init__.py @@ -30,6 +30,8 @@ from .rag_client.tool import RagClient from .text_to_speech_client.tool import TextToSpeechClient from .audio_transcription_client.tool import AudioTranscriptionClient +from .extraction_client.tool import ExtractionClient +from .prompt_summary_client.tool import PromptSummaryClient __all__ = [ "BaseTool", @@ -64,4 +66,6 @@ "RagClient", "TextToSpeechClient", "AudioTranscriptionClient", + "ExtractionClient", + "PromptSummaryClient", ] diff --git a/griptape/tools/extraction_client/__init__.py b/griptape/tools/extraction_client/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/griptape/tools/extraction_client/manifest.yml b/griptape/tools/extraction_client/manifest.yml new file mode 100644 index 000000000..9c489d9f6 --- /dev/null +++ b/griptape/tools/extraction_client/manifest.yml @@ -0,0 +1,5 @@ +version: "v1" +name: Extraction Client +description: Tool for performing structured extractions on unstructured data. +contact_email: hello@griptape.ai +legal_info_url: https://www.griptape.ai/legal diff --git a/griptape/tools/extraction_client/requirements.txt b/griptape/tools/extraction_client/requirements.txt new file mode 100644 index 000000000..e69de29bb diff --git a/griptape/tools/extraction_client/tool.py b/griptape/tools/extraction_client/tool.py new file mode 100644 index 000000000..911810a70 --- /dev/null +++ b/griptape/tools/extraction_client/tool.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from attrs import define, field +from schema import Literal, Or, Schema + +from griptape.artifacts import ErrorArtifact +from griptape.engines import CsvExtractionEngine, JsonExtractionEngine +from griptape.tools import BaseTool +from griptape.utils.decorators import activity + +if TYPE_CHECKING: + from griptape.artifacts import InfoArtifact, ListArtifact + from griptape.engines import BaseExtractionEngine + + +@define(kw_only=True) +class ExtractionClient(BaseTool): + """Tool for using an Extraction Engine. + + Attributes: + extraction_engine: `ExtractionEngine`. + """ + + extraction_engine: BaseExtractionEngine = field() + + def __attrs_post_init__(self) -> None: + if isinstance(self.extraction_engine, CsvExtractionEngine): + self.allowlist = ["extract_csv"] + elif isinstance(self.extraction_engine, JsonExtractionEngine): + self.allowlist = ["extract_json"] + + @activity( + config={ + "description": "Can be used extract data in JSON format", + "schema": Schema( + { + Literal("data"): Or( + str, + Schema( + { + "memory_name": str, + "artifact_namespace": str, + } + ), + ), + } + ), + }, + ) + def extract_json(self, params: dict) -> ListArtifact | InfoArtifact | ErrorArtifact: + return self._extract(params) + + @activity( + config={ + "description": "Can be used extract data in CSV format", + "schema": Schema( + { + Literal("data"): Or( + str, + Schema( + { + "memory_name": str, + "artifact_namespace": str, + } + ), + ), + } + ), + }, + ) + def extract_csv(self, params: dict) -> ListArtifact | InfoArtifact | ErrorArtifact: + return self._extract(params) + + def _extract(self, params: dict) -> ListArtifact | InfoArtifact | ErrorArtifact: + data = params["values"]["data"] + + if isinstance(data, str): + return self.extraction_engine.extract(data) + else: + memory = self.find_input_memory(data["memory_name"]) + artifact_namespace = data["artifact_namespace"] + + if memory is not None: + return self.extraction_engine.extract(memory.load_artifacts(artifact_namespace)) + else: + return ErrorArtifact("memory not found") diff --git a/griptape/tools/prompt_summary_client/__init__.py b/griptape/tools/prompt_summary_client/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/griptape/tools/prompt_summary_client/manifest.yml b/griptape/tools/prompt_summary_client/manifest.yml new file mode 100644 index 000000000..a83ea4021 --- /dev/null +++ b/griptape/tools/prompt_summary_client/manifest.yml @@ -0,0 +1,5 @@ +version: "v1" +name: Prompt Summary Client +description: Tool for using a Prompt Summary Engine +contact_email: hello@griptape.ai +legal_info_url: https://www.griptape.ai/legal diff --git a/griptape/tools/prompt_summary_client/requirements.txt b/griptape/tools/prompt_summary_client/requirements.txt new file mode 100644 index 000000000..e69de29bb diff --git a/griptape/tools/prompt_summary_client/tool.py b/griptape/tools/prompt_summary_client/tool.py new file mode 100644 index 000000000..37b42a3ac --- /dev/null +++ b/griptape/tools/prompt_summary_client/tool.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from attrs import define, field +from schema import Literal, Or, Schema + +from griptape.artifacts import BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact +from griptape.tools import BaseTool +from griptape.utils.decorators import activity + +if TYPE_CHECKING: + from griptape.engines import PromptSummaryEngine + + +@define(kw_only=True) +class PromptSummaryClient(BaseTool): + """Tool for using a Prompt Summary Engine. + + Attributes: + prompt_summary_engine: `PromptSummaryEngine`. + """ + + prompt_summary_engine: PromptSummaryEngine = field(kw_only=True) + + @activity( + config={ + "description": "Can be used to summarize text content.", + "schema": Schema( + { + Literal("summary"): Or( + str, + Schema( + { + "memory_name": str, + "artifact_namespace": str, + } + ), + ), + } + ), + }, + ) + def summarize(self, params: dict) -> BaseArtifact: + summary = params["values"]["summary"] + + if isinstance(summary, str): + artifacts = ListArtifact([TextArtifact(summary)]) + else: + memory = self.find_input_memory(summary["memory_name"]) + artifact_namespace = summary["artifact_namespace"] + + if memory is not None: + artifacts = memory.load_artifacts(artifact_namespace) + else: + return ErrorArtifact("memory not found") + + return self.prompt_summary_engine.summarize_artifacts(artifacts) diff --git a/tests/unit/tools/test_extraction_client.py b/tests/unit/tools/test_extraction_client.py new file mode 100644 index 000000000..6285d5b09 --- /dev/null +++ b/tests/unit/tools/test_extraction_client.py @@ -0,0 +1,73 @@ +import json + +import pytest + +from griptape.artifacts import TextArtifact +from griptape.engines import CsvExtractionEngine, JsonExtractionEngine +from griptape.tools import ExtractionClient +from tests.mocks.mock_prompt_driver import MockPromptDriver +from tests.utils import defaults + + +class TestExtractionClient: + @pytest.fixture() + def json_tool(self): + return ExtractionClient( + input_memory=[defaults.text_task_memory("TestMemory")], + extraction_engine=JsonExtractionEngine( + prompt_driver=MockPromptDriver( + mock_output='[{"test_key_1": "test_value_1"}, {"test_key_2": "test_value_2"}]' + ), + template_schema={}, + ), + ) + + @pytest.fixture() + def csv_tool(self): + return ExtractionClient( + input_memory=[defaults.text_task_memory("TestMemory")], + extraction_engine=CsvExtractionEngine( + prompt_driver=MockPromptDriver(), + column_names=["test1"], + ), + ) + + def test_json_extract_artifacts(self, json_tool): + json_tool.input_memory[0].store_artifact("foo", TextArtifact(json.dumps({}))) + + result = json_tool.extract_json( + {"values": {"data": {"memory_name": json_tool.input_memory[0].name, "artifact_namespace": "foo"}}} + ) + + assert len(result.value) == 2 + assert result.value[0].value == '{"test_key_1": "test_value_1"}' + assert result.value[1].value == '{"test_key_2": "test_value_2"}' + + def test_json_extract_content(self, json_tool): + result = json_tool.extract_json({"values": {"data": "foo"}}) + + assert len(result.value) == 2 + assert result.value[0].value == '{"test_key_1": "test_value_1"}' + assert result.value[1].value == '{"test_key_2": "test_value_2"}' + + def test_csv_extract_artifacts(self, csv_tool): + csv_tool.input_memory[0].store_artifact("foo", TextArtifact("foo,bar\nbaz,maz")) + + result = csv_tool.extract_csv( + {"values": {"data": {"memory_name": csv_tool.input_memory[0].name, "artifact_namespace": "foo"}}} + ) + + assert len(result.value) == 1 + assert result.value[0].value == {"test1": "mock output"} + + def test_csv_extract_content(self, csv_tool): + result = csv_tool.extract_csv({"values": {"data": "foo"}}) + + assert len(result.value) == 1 + assert result.value[0].value == {"test1": "mock output"} + + def test_json_allowlist(self, json_tool): + assert json_tool.allowlist == ["extract_json"] + + def test_csv_allowlist(self, csv_tool): + assert csv_tool.allowlist == ["extract_csv"] diff --git a/tests/unit/tools/test_prompt_summary_client.py b/tests/unit/tools/test_prompt_summary_client.py new file mode 100644 index 000000000..a31f217dd --- /dev/null +++ b/tests/unit/tools/test_prompt_summary_client.py @@ -0,0 +1,29 @@ +import pytest + +from griptape.artifacts import TextArtifact +from griptape.engines import PromptSummaryEngine +from griptape.tools import PromptSummaryClient +from tests.mocks.mock_prompt_driver import MockPromptDriver +from tests.utils import defaults + + +class TestPromptSummaryClient: + @pytest.fixture() + def tool(self): + return PromptSummaryClient( + input_memory=[defaults.text_task_memory("TestMemory")], + prompt_summary_engine=PromptSummaryEngine(prompt_driver=MockPromptDriver()), + ) + + def test_summarize_artifacts(self, tool): + tool.input_memory[0].store_artifact("foo", TextArtifact("test")) + + assert ( + tool.summarize( + {"values": {"summary": {"memory_name": tool.input_memory[0].name, "artifact_namespace": "foo"}}} + ).value + == "mock output" + ) + + def test_summarize_content(self, tool): + assert tool.summarize({"values": {"summary": "test"}}).value == "mock output" diff --git a/tests/utils/defaults.py b/tests/utils/defaults.py index e3bcde29b..22ce869db 100644 --- a/tests/utils/defaults.py +++ b/tests/utils/defaults.py @@ -1,25 +1,18 @@ from griptape.artifacts import BlobArtifact, TextArtifact from griptape.drivers import LocalVectorStoreDriver -from griptape.engines import CsvExtractionEngine, JsonExtractionEngine, PromptSummaryEngine from griptape.engines.rag import RagEngine from griptape.engines.rag.modules import PromptResponseRagModule, VectorStoreRetrievalRagModule from griptape.engines.rag.stages import ResponseRagStage, RetrievalRagStage from griptape.memory import TaskMemory from griptape.memory.task.storage import BlobArtifactStorage, TextArtifactStorage from tests.mocks.mock_embedding_driver import MockEmbeddingDriver -from tests.mocks.mock_prompt_driver import MockPromptDriver def text_tool_artifact_storage(): vector_store_driver = LocalVectorStoreDriver(embedding_driver=MockEmbeddingDriver()) return TextArtifactStorage( - rag_engine=rag_engine(MockPromptDriver(), vector_store_driver), vector_store_driver=vector_store_driver, - retrieval_rag_module_name="VectorStoreRetrievalRagModule", - summary_engine=PromptSummaryEngine(), - csv_extraction_engine=CsvExtractionEngine(), - json_extraction_engine=JsonExtractionEngine(), ) From 9c5e5c8e2d418007d576919bf7cd6374ececcebf Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 9 Aug 2024 11:22:22 -0700 Subject: [PATCH 40/63] Clean up engines/tests --- .../drivers/prompt-drivers.md | 2 +- .../extraction/base_extraction_engine.py | 8 +++-- .../extraction/csv_extraction_engine.py | 15 ++++++--- .../extraction/json_extraction_engine.py | 15 ++++++--- griptape/tasks/__init__.py | 4 --- griptape/tasks/csv_extraction_task.py | 11 ------- griptape/tasks/extraction_task.py | 2 +- griptape/tasks/json_extraction_task.py | 11 ------- griptape/utils/load_artifact_from_memory.py | 7 +++-- ...st_griptape_cloud_event_listener_driver.py | 4 --- .../extraction/test_csv_extraction_engine.py | 4 +-- .../extraction/test_json_extraction_engine.py | 9 +++--- .../storage/test_blob_artifact_storage.py | 10 ------ .../storage/test_text_artifact_storage.py | 10 ------ tests/unit/memory/tool/test_task_memory.py | 10 ------ tests/unit/tasks/test_csv_extraction_task.py | 28 ----------------- tests/unit/tasks/test_extraction_task.py | 2 +- tests/unit/tasks/test_json_extraction_task.py | 31 ------------------- tests/unit/tools/test_task_memory_client.py | 29 ----------------- 19 files changed, 42 insertions(+), 170 deletions(-) delete mode 100644 griptape/tasks/csv_extraction_task.py delete mode 100644 griptape/tasks/json_extraction_task.py delete mode 100644 tests/unit/tasks/test_csv_extraction_task.py delete mode 100644 tests/unit/tasks/test_json_extraction_task.py delete mode 100644 tests/unit/tools/test_task_memory_client.py diff --git a/docs/griptape-framework/drivers/prompt-drivers.md b/docs/griptape-framework/drivers/prompt-drivers.md index 54230b999..0be304211 100644 --- a/docs/griptape-framework/drivers/prompt-drivers.md +++ b/docs/griptape-framework/drivers/prompt-drivers.md @@ -26,7 +26,7 @@ Griptape offers the following Prompt Drivers for interacting with LLMs. ### OpenAI Chat The [OpenAiChatPromptDriver](../../reference/griptape/drivers/prompt/openai_chat_prompt_driver.md) connects to the [OpenAI Chat](https://platform.openai.com/docs/guides/chat) API. -This driver uses [OpenAi function calling](https://platform.openai.com/docs/guides/function-calling) when using [Tools](../tools/index.md). +This driver uses [OpenAI function calling](https://platform.openai.com/docs/guides/function-calling) when using [Tools](../tools/index.md). ```python --8<-- "docs/griptape-framework/drivers/src/prompt_drivers_3.py" diff --git a/griptape/engines/extraction/base_extraction_engine.py b/griptape/engines/extraction/base_extraction_engine.py index f4fcd5d3a..4b1184e5e 100644 --- a/griptape/engines/extraction/base_extraction_engine.py +++ b/griptape/engines/extraction/base_extraction_engine.py @@ -1,21 +1,21 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from attrs import Attribute, Factory, define, field from griptape.chunkers import BaseChunker, TextChunker from griptape.config import config -from griptape.mixins.rule_mixin import RuleMixin if TYPE_CHECKING: from griptape.artifacts import ErrorArtifact, ListArtifact from griptape.drivers import BasePromptDriver + from griptape.rules import Ruleset @define -class BaseExtractionEngine(ABC, RuleMixin): +class BaseExtractionEngine(ABC): max_token_multiplier: float = field(default=0.5, kw_only=True) chunk_joiner: str = field(default="\n\n", kw_only=True) prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.drivers.prompt), kw_only=True) @@ -49,5 +49,7 @@ def min_response_tokens(self) -> int: def extract( self, text: str | ListArtifact, + *, + rulesets: Optional[list[Ruleset]] = None, **kwargs, ) -> ListArtifact | ErrorArtifact: ... diff --git a/griptape/engines/extraction/csv_extraction_engine.py b/griptape/engines/extraction/csv_extraction_engine.py index 9297d8377..1ddef52f1 100644 --- a/griptape/engines/extraction/csv_extraction_engine.py +++ b/griptape/engines/extraction/csv_extraction_engine.py @@ -2,7 +2,7 @@ import csv import io -from typing import cast +from typing import TYPE_CHECKING, Optional, cast from attrs import Factory, define, field @@ -11,6 +11,9 @@ from griptape.engines import BaseExtractionEngine from griptape.utils import J2 +if TYPE_CHECKING: + from griptape.rules import Ruleset + @define class CsvExtractionEngine(BaseExtractionEngine): @@ -20,6 +23,8 @@ class CsvExtractionEngine(BaseExtractionEngine): def extract( self, text: str | ListArtifact, + *, + rulesets: Optional[list[Ruleset]] = None, **kwargs, ) -> ListArtifact | ErrorArtifact: try: @@ -46,11 +51,13 @@ def _extract_rec( self, artifacts: list[TextArtifact], rows: list[CsvRowArtifact], + *, + rulesets: Optional[list[Ruleset]] = None, ) -> list[CsvRowArtifact]: artifacts_text = self.chunk_joiner.join([a.value for a in artifacts]) full_text = self.template_generator.render( text=artifacts_text, - rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.all_rulesets), + rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets), ) if self.prompt_driver.tokenizer.count_input_tokens_left(full_text) >= self.min_response_tokens: @@ -67,7 +74,7 @@ def _extract_rec( partial_text = self.template_generator.render( column_names=self.column_names, text=chunks[0].value, - rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.rulesets), + rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets), ) rows.extend( @@ -77,4 +84,4 @@ def _extract_rec( ), ) - return self._extract_rec(chunks[1:], rows) + return self._extract_rec(chunks[1:], rows, rulesets=rulesets) diff --git a/griptape/engines/extraction/json_extraction_engine.py b/griptape/engines/extraction/json_extraction_engine.py index 8bd46ed8b..7e1ac23ca 100644 --- a/griptape/engines/extraction/json_extraction_engine.py +++ b/griptape/engines/extraction/json_extraction_engine.py @@ -1,7 +1,7 @@ from __future__ import annotations import json -from typing import cast +from typing import TYPE_CHECKING, Optional, cast from attrs import Factory, define, field @@ -11,6 +11,9 @@ from griptape.engines import BaseExtractionEngine from griptape.utils import J2 +if TYPE_CHECKING: + from griptape.rules import Ruleset + @define class JsonExtractionEngine(BaseExtractionEngine): @@ -20,6 +23,8 @@ class JsonExtractionEngine(BaseExtractionEngine): def extract( self, text: str | ListArtifact, + *, + rulesets: Optional[list[Ruleset]] = None, **kwargs, ) -> ListArtifact | ErrorArtifact: try: @@ -40,12 +45,14 @@ def _extract_rec( self, artifacts: list[TextArtifact], extractions: list[TextArtifact], + *, + rulesets: Optional[list[Ruleset]] = None, ) -> list[TextArtifact]: artifacts_text = self.chunk_joiner.join([a.value for a in artifacts]) full_text = self.template_generator.render( json_template_schema=json.dumps(self.template_schema), text=artifacts_text, - rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.all_rulesets), + rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets), ) if self.prompt_driver.tokenizer.count_input_tokens_left(full_text) >= self.min_response_tokens: @@ -61,7 +68,7 @@ def _extract_rec( partial_text = self.template_generator.render( template_schema=self.template_schema, text=chunks[0].value, - rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.rulesets), + rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets), ) extractions.extend( @@ -70,4 +77,4 @@ def _extract_rec( ), ) - return self._extract_rec(chunks[1:], extractions) + return self._extract_rec(chunks[1:], extractions, rulesets=rulesets) diff --git a/griptape/tasks/__init__.py b/griptape/tasks/__init__.py index 764d1669a..7d08cf858 100644 --- a/griptape/tasks/__init__.py +++ b/griptape/tasks/__init__.py @@ -8,8 +8,6 @@ from .tool_task import ToolTask from .rag_task import RagTask from .extraction_task import ExtractionTask -from .csv_extraction_task import CsvExtractionTask -from .json_extraction_task import JsonExtractionTask from .base_image_generation_task import BaseImageGenerationTask from .code_execution_task import CodeExecutionTask from .prompt_image_generation_task import PromptImageGenerationTask @@ -33,8 +31,6 @@ "ToolTask", "RagTask", "ExtractionTask", - "CsvExtractionTask", - "JsonExtractionTask", "BaseImageGenerationTask", "CodeExecutionTask", "PromptImageGenerationTask", diff --git a/griptape/tasks/csv_extraction_task.py b/griptape/tasks/csv_extraction_task.py deleted file mode 100644 index c252893de..000000000 --- a/griptape/tasks/csv_extraction_task.py +++ /dev/null @@ -1,11 +0,0 @@ -from __future__ import annotations - -from attrs import Factory, define, field - -from griptape.engines import CsvExtractionEngine -from griptape.tasks import ExtractionTask - - -@define -class CsvExtractionTask(ExtractionTask): - extraction_engine: CsvExtractionEngine = field(default=Factory(lambda: CsvExtractionEngine()), kw_only=True) diff --git a/griptape/tasks/extraction_task.py b/griptape/tasks/extraction_task.py index a1c18eff0..c74c3ac49 100644 --- a/griptape/tasks/extraction_task.py +++ b/griptape/tasks/extraction_task.py @@ -14,7 +14,7 @@ @define class ExtractionTask(BaseTextInputTask): extraction_engine: BaseExtractionEngine = field(kw_only=True) - args: dict = field(kw_only=True) + args: dict = field(kw_only=True, factory=dict) def run(self) -> ListArtifact | ErrorArtifact: return self.extraction_engine.extract(self.input.to_text(), rulesets=self.all_rulesets, **self.args) diff --git a/griptape/tasks/json_extraction_task.py b/griptape/tasks/json_extraction_task.py deleted file mode 100644 index 94db187da..000000000 --- a/griptape/tasks/json_extraction_task.py +++ /dev/null @@ -1,11 +0,0 @@ -from __future__ import annotations - -from attrs import Factory, define, field - -from griptape.engines import JsonExtractionEngine -from griptape.tasks import ExtractionTask - - -@define -class JsonExtractionTask(ExtractionTask): - extraction_engine: JsonExtractionEngine = field(default=Factory(lambda: JsonExtractionEngine()), kw_only=True) diff --git a/griptape/utils/load_artifact_from_memory.py b/griptape/utils/load_artifact_from_memory.py index ec260787a..2d3f8bc86 100644 --- a/griptape/utils/load_artifact_from_memory.py +++ b/griptape/utils/load_artifact_from_memory.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: from griptape.artifacts import BaseArtifact @@ -8,11 +8,14 @@ def load_artifact_from_memory( - memory: TaskMemory, + memory: Optional[TaskMemory], artifact_namespace: str, artifact_name: str, artifact_type: type, ) -> BaseArtifact: + if memory is None: + raise ValueError("memory not found") + artifacts = memory.load_artifacts(namespace=artifact_namespace) if len(artifacts) == 0: raise ValueError("no artifacts found in namespace") diff --git a/tests/unit/drivers/event_listener/test_griptape_cloud_event_listener_driver.py b/tests/unit/drivers/event_listener/test_griptape_cloud_event_listener_driver.py index 1cd198756..0bf298870 100644 --- a/tests/unit/drivers/event_listener/test_griptape_cloud_event_listener_driver.py +++ b/tests/unit/drivers/event_listener/test_griptape_cloud_event_listener_driver.py @@ -87,7 +87,3 @@ def try_publish_event_payload_batch(self, mock_post, driver): json=event.to_dict(), headers={"Authorization": "Bearer foo bar"}, ) - - def test_no_structure_run_id(self): - with pytest.raises(ValueError): - GriptapeCloudEventListenerDriver(api_key="foo bar") diff --git a/tests/unit/engines/extraction/test_csv_extraction_engine.py b/tests/unit/engines/extraction/test_csv_extraction_engine.py index d84fc7cdd..893c21d60 100644 --- a/tests/unit/engines/extraction/test_csv_extraction_engine.py +++ b/tests/unit/engines/extraction/test_csv_extraction_engine.py @@ -6,10 +6,10 @@ class TestCsvExtractionEngine: @pytest.fixture() def engine(self): - return CsvExtractionEngine() + return CsvExtractionEngine(column_names=["test1"]) def test_extract(self, engine): - result = engine.extract("foo", column_names=["test1"]) + result = engine.extract("foo") assert len(result.value) == 1 assert result.value[0].value == {"test1": "mock output"} diff --git a/tests/unit/engines/extraction/test_json_extraction_engine.py b/tests/unit/engines/extraction/test_json_extraction_engine.py index d95adbb43..2d4626d3a 100644 --- a/tests/unit/engines/extraction/test_json_extraction_engine.py +++ b/tests/unit/engines/extraction/test_json_extraction_engine.py @@ -12,19 +12,20 @@ def engine(self): return JsonExtractionEngine( prompt_driver=MockPromptDriver( mock_output='[{"test_key_1": "test_value_1"}, {"test_key_2": "test_value_2"}]' - ) + ), + template_schema=Schema({"foo": "bar"}).json_schema("TemplateSchema"), ) def test_extract(self, engine): - json_schema = Schema({"foo": "bar"}).json_schema("TemplateSchema") - result = engine.extract("foo", template_schema=json_schema) + result = engine.extract("foo") assert len(result.value) == 2 assert result.value[0].value == '{"test_key_1": "test_value_1"}' assert result.value[1].value == '{"test_key_2": "test_value_2"}' def test_extract_error(self, engine): - assert isinstance(engine.extract("foo", template_schema=lambda: "non serializable"), ErrorArtifact) + engine.template_schema = lambda: "non serializable" + assert isinstance(engine.extract("foo"), ErrorArtifact) def test_json_to_text_artifacts(self, engine): assert [ diff --git a/tests/unit/memory/tool/storage/test_blob_artifact_storage.py b/tests/unit/memory/tool/storage/test_blob_artifact_storage.py index c7f2cfcbd..78d1c662d 100644 --- a/tests/unit/memory/tool/storage/test_blob_artifact_storage.py +++ b/tests/unit/memory/tool/storage/test_blob_artifact_storage.py @@ -25,13 +25,3 @@ def test_load_artifacts(self, storage): def test_can_store(self, storage): assert not storage.can_store(TextArtifact("foo")) assert storage.can_store(BlobArtifact(b"foo")) - - def test_summarize(self, storage): - storage.store_artifact("foo", BlobArtifact(b"test")) - - assert storage.summarize("foo").value == "can't summarize artifacts" - - def test_query(self, storage): - storage.store_artifact("foo", BlobArtifact(b"test")) - - assert storage.query("foo", "query").value == "can't query artifacts" diff --git a/tests/unit/memory/tool/storage/test_text_artifact_storage.py b/tests/unit/memory/tool/storage/test_text_artifact_storage.py index 64f44c581..2f49421d4 100644 --- a/tests/unit/memory/tool/storage/test_text_artifact_storage.py +++ b/tests/unit/memory/tool/storage/test_text_artifact_storage.py @@ -25,13 +25,3 @@ def test_load_artifacts(self, storage): def test_can_store(self, storage): assert storage.can_store(TextArtifact("foo")) assert not storage.can_store(BlobArtifact(b"foo")) - - def test_summarize(self, storage): - storage.store_artifact("foo", TextArtifact("test")) - - assert storage.summarize("foo").value == "mock output" - - def test_query(self, storage): - storage.store_artifact("foo", TextArtifact("test")) - - assert storage.query("foo", "query").value == "mock output" diff --git a/tests/unit/memory/tool/test_task_memory.py b/tests/unit/memory/tool/test_task_memory.py index 53e4703a6..2f6ffe1c9 100644 --- a/tests/unit/memory/tool/test_task_memory.py +++ b/tests/unit/memory/tool/test_task_memory.py @@ -96,13 +96,3 @@ def test_load_artifacts_for_blob_list_artifact(self, memory): ) assert len(memory.load_artifacts("test")) == 2 - - def test_summarize_namespace(self, memory): - memory.store_artifact("foo", TextArtifact("test")) - - assert memory.summarize_namespace("foo").value == "mock output" - - def test_query_namespace(self, memory): - memory.store_artifact("foo", TextArtifact("test")) - - assert memory.query_namespace("foo", "query").value == "mock output" diff --git a/tests/unit/tasks/test_csv_extraction_task.py b/tests/unit/tasks/test_csv_extraction_task.py deleted file mode 100644 index ec8f70b23..000000000 --- a/tests/unit/tasks/test_csv_extraction_task.py +++ /dev/null @@ -1,28 +0,0 @@ -import pytest - -from griptape.engines import CsvExtractionEngine -from griptape.structures import Agent -from griptape.tasks import CsvExtractionTask -from tests.mocks.mock_prompt_driver import MockPromptDriver - - -class TestCsvExtractionTask: - @pytest.fixture() - def task(self): - return CsvExtractionTask(args={"column_names": ["test1"]}) - - def test_run(self, task): - agent = Agent() - - agent.add_task(task) - - result = task.run() - - assert len(result.value) == 1 - assert result.value[0].value == {"test1": "mock output"} - - def test_config_extraction_engine(self, task): - Agent().add_task(task) - - assert isinstance(task.extraction_engine, CsvExtractionEngine) - assert isinstance(task.extraction_engine.prompt_driver, MockPromptDriver) diff --git a/tests/unit/tasks/test_extraction_task.py b/tests/unit/tasks/test_extraction_task.py index 76a4c3bd2..2d7ab442c 100644 --- a/tests/unit/tasks/test_extraction_task.py +++ b/tests/unit/tasks/test_extraction_task.py @@ -8,7 +8,7 @@ class TestExtractionTask: @pytest.fixture() def task(self): - return ExtractionTask(extraction_engine=CsvExtractionEngine(), args={"column_names": ["test1"]}) + return ExtractionTask(extraction_engine=CsvExtractionEngine(column_names=["test1"])) def test_run(self, task): agent = Agent() diff --git a/tests/unit/tasks/test_json_extraction_task.py b/tests/unit/tasks/test_json_extraction_task.py deleted file mode 100644 index 8f9278c3c..000000000 --- a/tests/unit/tasks/test_json_extraction_task.py +++ /dev/null @@ -1,31 +0,0 @@ -import pytest -from schema import Schema - -from griptape.engines import JsonExtractionEngine -from griptape.structures import Agent -from griptape.tasks import JsonExtractionTask -from tests.mocks.mock_prompt_driver import MockPromptDriver - - -class TestJsonExtractionTask: - @pytest.fixture() - def task(self): - return JsonExtractionTask("foo", args={"template_schema": Schema({"foo": "bar"}).json_schema("TemplateSchema")}) - - def test_run(self, task, mock_config): - mock_config.drivers.prompt.mock_output = '[{"test_key_1": "test_value_1"}, {"test_key_2": "test_value_2"}]' - agent = Agent() - - agent.add_task(task) - - result = task.run() - - assert len(result.value) == 2 - assert result.value[0].value == '{"test_key_1": "test_value_1"}' - assert result.value[1].value == '{"test_key_2": "test_value_2"}' - - def test_config_extraction_engine(self, task): - Agent().add_task(task) - - assert isinstance(task.extraction_engine, JsonExtractionEngine) - assert isinstance(task.extraction_engine.prompt_driver, MockPromptDriver) diff --git a/tests/unit/tools/test_task_memory_client.py b/tests/unit/tools/test_task_memory_client.py deleted file mode 100644 index 4276b89ec..000000000 --- a/tests/unit/tools/test_task_memory_client.py +++ /dev/null @@ -1,29 +0,0 @@ -import pytest - -from griptape.artifacts import TextArtifact -from griptape.tools import TaskMemoryClient -from tests.utils import defaults - - -class TestTaskMemoryClient: - @pytest.fixture() - def tool(self): - return TaskMemoryClient(off_prompt=True, input_memory=[defaults.text_task_memory("TestMemory")]) - - def test_summarize(self, tool): - tool.input_memory[0].store_artifact("foo", TextArtifact("test")) - - assert ( - tool.summarize({"values": {"memory_name": tool.input_memory[0].name, "artifact_namespace": "foo"}}).value - == "mock output" - ) - - def test_query(self, tool): - tool.input_memory[0].store_artifact("foo", TextArtifact("test")) - - assert ( - tool.query( - {"values": {"query": "foobar", "memory_name": tool.input_memory[0].name, "artifact_namespace": "foo"}} - ).value - == "mock output" - ) From 8322811dda267c5f962d8918a2ca705ad75fb27e Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 9 Aug 2024 11:39:21 -0700 Subject: [PATCH 41/63] Improve json/csv extraction performancek --- .../extraction/csv_extraction_engine.py | 37 ++++++++++---- .../extraction/json_extraction_engine.py | 48 +++++++++++++++---- .../engines/extraction/csv/system.j2 | 7 +++ .../templates/engines/extraction/csv/user.j2 | 4 ++ .../engines/extraction/csv_extraction.j2 | 11 ----- .../engines/extraction/json/system.j2 | 6 +++ .../{json_extraction.j2 => json/user.j2} | 9 +--- 7 files changed, 84 insertions(+), 38 deletions(-) create mode 100644 griptape/templates/engines/extraction/csv/system.j2 create mode 100644 griptape/templates/engines/extraction/csv/user.j2 delete mode 100644 griptape/templates/engines/extraction/csv_extraction.j2 create mode 100644 griptape/templates/engines/extraction/json/system.j2 rename griptape/templates/engines/extraction/{json_extraction.j2 => json/user.j2} (56%) diff --git a/griptape/engines/extraction/csv_extraction_engine.py b/griptape/engines/extraction/csv_extraction_engine.py index 1ddef52f1..c9c040f65 100644 --- a/griptape/engines/extraction/csv_extraction_engine.py +++ b/griptape/engines/extraction/csv_extraction_engine.py @@ -18,7 +18,8 @@ @define class CsvExtractionEngine(BaseExtractionEngine): column_names: list[str] = field(default=Factory(list), kw_only=True) - template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/csv_extraction.j2")), kw_only=True) + system_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/system.j2")), kw_only=True) + user_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/user.j2")), kw_only=True) def extract( self, @@ -55,15 +56,28 @@ def _extract_rec( rulesets: Optional[list[Ruleset]] = None, ) -> list[CsvRowArtifact]: artifacts_text = self.chunk_joiner.join([a.value for a in artifacts]) - full_text = self.template_generator.render( - text=artifacts_text, + system_prompt = self.system_template_generator.render( + column_names=self.column_names, rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets), ) + user_prompt = self.user_template_generator.render( + text=artifacts_text, + ) - if self.prompt_driver.tokenizer.count_input_tokens_left(full_text) >= self.min_response_tokens: + if ( + self.prompt_driver.tokenizer.count_input_tokens_left(system_prompt + user_prompt) + >= self.min_response_tokens + ): rows.extend( self.text_to_csv_rows( - self.prompt_driver.run(PromptStack(messages=[Message(full_text, role=Message.USER_ROLE)])).value, + self.prompt_driver.run( + PromptStack( + messages=[ + Message(system_prompt, role=Message.SYSTEM_ROLE), + Message(user_prompt, role=Message.USER_ROLE), + ] + ) + ).value, self.column_names, ), ) @@ -71,15 +85,20 @@ def _extract_rec( return rows else: chunks = self.chunker.chunk(artifacts_text) - partial_text = self.template_generator.render( - column_names=self.column_names, + partial_text = self.user_template_generator.render( text=chunks[0].value, - rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets), ) rows.extend( self.text_to_csv_rows( - self.prompt_driver.run(PromptStack(messages=[Message(partial_text, role=Message.USER_ROLE)])).value, + self.prompt_driver.run( + PromptStack( + messages=[ + Message(system_prompt, role=Message.SYSTEM_ROLE), + Message(partial_text, role=Message.USER_ROLE), + ] + ) + ).value, self.column_names, ), ) diff --git a/griptape/engines/extraction/json_extraction_engine.py b/griptape/engines/extraction/json_extraction_engine.py index 7e1ac23ca..56815bc06 100644 --- a/griptape/engines/extraction/json_extraction_engine.py +++ b/griptape/engines/extraction/json_extraction_engine.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +import re from typing import TYPE_CHECKING, Optional, cast from attrs import Factory, define, field @@ -17,8 +18,13 @@ @define class JsonExtractionEngine(BaseExtractionEngine): + JSON_PATTERN = r"(?s)[^\[]*(\[.*\])" + template_schema: dict = field(default=Factory(dict), kw_only=True) - template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/json_extraction.j2")), kw_only=True) + system_template_generator: J2 = field( + default=Factory(lambda: J2("engines/extraction/json/system.j2")), kw_only=True + ) + user_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/json/user.j2")), kw_only=True) def extract( self, @@ -39,7 +45,12 @@ def extract( return ErrorArtifact(f"error extracting JSON: {e}") def json_to_text_artifacts(self, json_input: str) -> list[TextArtifact]: - return [TextArtifact(json.dumps(e)) for e in json.loads(json_input)] + json_matches = re.findall(self.JSON_PATTERN, json_input, re.DOTALL) + + if json_matches: + return [TextArtifact(json.dumps(e)) for e in json.loads(json_matches[-1])] + else: + return [] def _extract_rec( self, @@ -49,31 +60,48 @@ def _extract_rec( rulesets: Optional[list[Ruleset]] = None, ) -> list[TextArtifact]: artifacts_text = self.chunk_joiner.join([a.value for a in artifacts]) - full_text = self.template_generator.render( + system_prompt = self.system_template_generator.render( json_template_schema=json.dumps(self.template_schema), - text=artifacts_text, rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets), ) + user_prompt = self.user_template_generator.render( + text=artifacts_text, + ) - if self.prompt_driver.tokenizer.count_input_tokens_left(full_text) >= self.min_response_tokens: + if ( + self.prompt_driver.tokenizer.count_input_tokens_left(user_prompt + system_prompt) + >= self.min_response_tokens + ): extractions.extend( self.json_to_text_artifacts( - self.prompt_driver.run(PromptStack(messages=[Message(full_text, role=Message.USER_ROLE)])).value, + self.prompt_driver.run( + PromptStack( + messages=[ + Message(system_prompt, role=Message.SYSTEM_ROLE), + Message(user_prompt, role=Message.USER_ROLE), + ] + ) + ).value ), ) return extractions else: chunks = self.chunker.chunk(artifacts_text) - partial_text = self.template_generator.render( - template_schema=self.template_schema, + partial_text = self.user_template_generator.render( text=chunks[0].value, - rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets), ) extractions.extend( self.json_to_text_artifacts( - self.prompt_driver.run(PromptStack(messages=[Message(partial_text, role=Message.USER_ROLE)])).value, + self.prompt_driver.run( + PromptStack( + messages=[ + Message(system_prompt, role=Message.SYSTEM_ROLE), + Message(partial_text, role=Message.USER_ROLE), + ] + ) + ).value, ), ) diff --git a/griptape/templates/engines/extraction/csv/system.j2 b/griptape/templates/engines/extraction/csv/system.j2 new file mode 100644 index 000000000..7c5776257 --- /dev/null +++ b/griptape/templates/engines/extraction/csv/system.j2 @@ -0,0 +1,7 @@ +Don't add the header row. Don't use markdown formatting for output. Fields containing line breaks (CRLF), double quotes, and commas should be enclosed in double-quotes. +Column Names: """{{ column_names }}""" + +{% if rulesets %} + +{{ rulesets }} +{% endif %} diff --git a/griptape/templates/engines/extraction/csv/user.j2 b/griptape/templates/engines/extraction/csv/user.j2 new file mode 100644 index 000000000..0f33dadc3 --- /dev/null +++ b/griptape/templates/engines/extraction/csv/user.j2 @@ -0,0 +1,4 @@ +Extract information from the Text based on the Column Names and output it as a CSV file. +Text: """{{ text }}""" + +Answer: diff --git a/griptape/templates/engines/extraction/csv_extraction.j2 b/griptape/templates/engines/extraction/csv_extraction.j2 deleted file mode 100644 index 6f9da346b..000000000 --- a/griptape/templates/engines/extraction/csv_extraction.j2 +++ /dev/null @@ -1,11 +0,0 @@ -Text: """{{ text }}""" - -Column Names: """{{ column_names }}""" - -Extract information from the Text based on the Column Names and output it as a CSV file. Don't add the header row. Don't use markdown formatting for output. Fields containing line breaks (CRLF), double quotes, and commas should be enclosed in double-quotes. -{% if rulesets %} - -{{ rulesets }} -{% endif %} - -Answer: diff --git a/griptape/templates/engines/extraction/json/system.j2 b/griptape/templates/engines/extraction/json/system.j2 new file mode 100644 index 000000000..987ff19a9 --- /dev/null +++ b/griptape/templates/engines/extraction/json/system.j2 @@ -0,0 +1,6 @@ +Extraction Template JSON Schema: """{{ json_template_schema }}""" + +{% if rulesets %} + +{{ rulesets }} +{% endif %} diff --git a/griptape/templates/engines/extraction/json_extraction.j2 b/griptape/templates/engines/extraction/json/user.j2 similarity index 56% rename from griptape/templates/engines/extraction/json_extraction.j2 rename to griptape/templates/engines/extraction/json/user.j2 index 85d95bef9..984977d9a 100644 --- a/griptape/templates/engines/extraction/json_extraction.j2 +++ b/griptape/templates/engines/extraction/json/user.j2 @@ -1,11 +1,4 @@ -Text: """{{ text }}""" - -Extraction Template JSON Schema: """{{ json_template_schema }}""" - Extract information from the Text based on the Extraction Template JSON Schema into an array of JSON objects. -{% if rulesets %} - -{{ rulesets }} -{% endif %} +Text: """{{ text }}""" JSON array: From 98b67b24a787a3fa56c4363afb1173b74733ca4f Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 9 Aug 2024 11:42:34 -0700 Subject: [PATCH 42/63] Support rules on extracton client --- griptape/tools/extraction_client/tool.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/griptape/tools/extraction_client/tool.py b/griptape/tools/extraction_client/tool.py index 911810a70..ce8c5b034 100644 --- a/griptape/tools/extraction_client/tool.py +++ b/griptape/tools/extraction_client/tool.py @@ -7,6 +7,7 @@ from griptape.artifacts import ErrorArtifact from griptape.engines import CsvExtractionEngine, JsonExtractionEngine +from griptape.mixins import RuleMixin from griptape.tools import BaseTool from griptape.utils.decorators import activity @@ -16,7 +17,7 @@ @define(kw_only=True) -class ExtractionClient(BaseTool): +class ExtractionClient(BaseTool, RuleMixin): """Tool for using an Extraction Engine. Attributes: @@ -77,7 +78,7 @@ def _extract(self, params: dict) -> ListArtifact | InfoArtifact | ErrorArtifact: data = params["values"]["data"] if isinstance(data, str): - return self.extraction_engine.extract(data) + return self.extraction_engine.extract(data, rulesets=self.rulesets) else: memory = self.find_input_memory(data["memory_name"]) artifact_namespace = data["artifact_namespace"] From 4e3dc5f245117df0749175289f454e9ccac164f4 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 9 Aug 2024 12:54:28 -0700 Subject: [PATCH 43/63] Add rules to tools --- griptape/tools/prompt_summary_client/tool.py | 13 ++-- griptape/tools/rag_client/tool.py | 70 ++++++++++++++++---- 2 files changed, 62 insertions(+), 21 deletions(-) diff --git a/griptape/tools/prompt_summary_client/tool.py b/griptape/tools/prompt_summary_client/tool.py index 37b42a3ac..f71b8f325 100644 --- a/griptape/tools/prompt_summary_client/tool.py +++ b/griptape/tools/prompt_summary_client/tool.py @@ -1,27 +1,24 @@ from __future__ import annotations -from typing import TYPE_CHECKING - from attrs import define, field from schema import Literal, Or, Schema from griptape.artifacts import BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact +from griptape.engines import PromptSummaryEngine +from griptape.mixins.rule_mixin import RuleMixin from griptape.tools import BaseTool from griptape.utils.decorators import activity -if TYPE_CHECKING: - from griptape.engines import PromptSummaryEngine - @define(kw_only=True) -class PromptSummaryClient(BaseTool): +class PromptSummaryClient(BaseTool, RuleMixin): """Tool for using a Prompt Summary Engine. Attributes: prompt_summary_engine: `PromptSummaryEngine`. """ - prompt_summary_engine: PromptSummaryEngine = field(kw_only=True) + prompt_summary_engine: PromptSummaryEngine = field(kw_only=True, default=PromptSummaryEngine()) @activity( config={ @@ -55,4 +52,4 @@ def summarize(self, params: dict) -> BaseArtifact: else: return ErrorArtifact("memory not found") - return self.prompt_summary_engine.summarize_artifacts(artifacts) + return self.prompt_summary_engine.summarize_artifacts(artifacts, rulesets=self.all_rulesets) diff --git a/griptape/tools/rag_client/tool.py b/griptape/tools/rag_client/tool.py index bbdef8159..a1a868e24 100644 --- a/griptape/tools/rag_client/tool.py +++ b/griptape/tools/rag_client/tool.py @@ -1,20 +1,24 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from attrs import Factory, define, field +from schema import Literal, Or, Schema -from attrs import define, field -from schema import Literal, Schema - -from griptape.artifacts import BaseArtifact, ErrorArtifact +from griptape.artifacts import BaseArtifact, ErrorArtifact, TextArtifact +from griptape.engines.rag import RagEngine +from griptape.engines.rag.modules import ( + MetadataBeforeResponseRagModule, + PromptResponseRagModule, + RulesetsBeforeResponseRagModule, +) +from griptape.engines.rag.rag_context import RagContext +from griptape.engines.rag.stages import ResponseRagStage +from griptape.mixins.rule_mixin import RuleMixin from griptape.tools import BaseTool from griptape.utils.decorators import activity -if TYPE_CHECKING: - from griptape.engines.rag import RagEngine - @define(kw_only=True) -class RagClient(BaseTool): +class RagClient(BaseTool, RuleMixin): """Tool for querying a RAG engine. Attributes: @@ -23,19 +27,59 @@ class RagClient(BaseTool): """ description: str = field() - rag_engine: RagEngine = field() + rag_engine: RagEngine = field( + default=Factory( + lambda self: RagEngine( + response_stage=ResponseRagStage( + before_response_modules=[ + RulesetsBeforeResponseRagModule(rulesets=self.all_rulesets), + MetadataBeforeResponseRagModule(), + ], + response_module=PromptResponseRagModule(), + ), + ), + takes_self=True, + ) + ) @activity( config={ - "description": "{{ _self.description }}", - "schema": Schema({Literal("query", description="A natural language search query"): str}), + "description": "Can be used to search content with the following description: {{ _self.description }}", + "schema": Schema( + { + Literal("query", description="A natural language search query"): str, + Literal("content"): Or( + str, + Schema( + { + "memory_name": str, + "artifact_namespace": str, + } + ), + ), + } + ), }, ) def search(self, params: dict) -> BaseArtifact: query = params["values"]["query"] + summary = params["values"]["content"] + + if isinstance(summary, str): + text_artifacts = [TextArtifact(summary)] + else: + memory = self.find_input_memory(summary["memory_name"]) + artifact_namespace = summary["artifact_namespace"] + + if memory is not None: + artifacts = memory.load_artifacts(artifact_namespace) + else: + return ErrorArtifact("memory not found") + + text_artifacts = [artifact for artifact in artifacts if isinstance(artifact, TextArtifact)] try: - result = self.rag_engine.process_query(query) + result = self.rag_engine.process(RagContext(query=query, text_chunks=text_artifacts)) if result.output is None: return ErrorArtifact("query output is empty") From d0fab25e896c64ed78920b14d8a47f9528341019 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Mon, 12 Aug 2024 11:51:22 -0700 Subject: [PATCH 44/63] Fix utilities checking for stream --- griptape/utils/chat.py | 7 +++++-- griptape/utils/stream.py | 9 ++++++--- tests/unit/utils/test_stream.py | 14 +++++++++----- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/griptape/utils/chat.py b/griptape/utils/chat.py index 56b53c0ce..07fea92d8 100644 --- a/griptape/utils/chat.py +++ b/griptape/utils/chat.py @@ -25,9 +25,12 @@ class Chat: ) def default_output_fn(self, text: str) -> None: - from griptape.config import config + from griptape.tasks.prompt_task import PromptTask - if config.drivers.prompt.stream: + streaming_tasks = [ + task for task in self.structure.tasks if isinstance(task, PromptTask) and task.prompt_driver.stream + ] + if streaming_tasks: print(text, end="", flush=True) # noqa: T201 else: print(text) # noqa: T201 diff --git a/griptape/utils/stream.py b/griptape/utils/stream.py index c5545bc44..6da58b9e6 100644 --- a/griptape/utils/stream.py +++ b/griptape/utils/stream.py @@ -34,10 +34,13 @@ class Stream: @structure.validator # pyright: ignore[reportAttributeAccessIssue] def validate_structure(self, _: Attribute, structure: Structure) -> None: - from griptape.config import config + from griptape.tasks import PromptTask - if not config.drivers.prompt.stream: - raise ValueError("prompt driver does not have streaming enabled, enable with stream=True") + streaming_tasks = [ + task for task in structure.tasks if isinstance(task, PromptTask) and task.prompt_driver.stream + ] + if not streaming_tasks: + raise ValueError("Structure does not have any streaming tasks, enable with stream=True") _event_queue: Queue[BaseEvent] = field(default=Factory(lambda: Queue())) diff --git a/tests/unit/utils/test_stream.py b/tests/unit/utils/test_stream.py index edd0258f2..caddbb1a3 100644 --- a/tests/unit/utils/test_stream.py +++ b/tests/unit/utils/test_stream.py @@ -2,19 +2,17 @@ import pytest -from griptape.config import config -from griptape.structures import Agent +from griptape.structures import Agent, Pipeline from griptape.utils import Stream class TestStream: @pytest.fixture(params=[True, False]) def agent(self, request): - config.drivers.prompt.stream = request.param - return Agent() + return Agent(stream=request.param) def test_init(self, agent): - if config.drivers.prompt.stream: + if agent.stream: chat_stream = Stream(agent) assert chat_stream.structure == agent @@ -29,3 +27,9 @@ def test_init(self, agent): else: with pytest.raises(ValueError): Stream(agent) + + def test_validate_structure_invalid(self): + pipeline = Pipeline(tasks=[]) + + with pytest.raises(ValueError): + Stream(pipeline) From 9db6797527fed375e6d101559fc0432fa1f0a397 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Mon, 12 Aug 2024 15:30:04 -0700 Subject: [PATCH 45/63] Remove TaskMemoryClient --- CHANGELOG.md | 17 ++ README.md | 73 ++++-- _typos.toml | 4 + docs/examples/multiple-agent-shared-memory.md | 4 +- docs/examples/src/multi_agent_workflow_1.py | 4 +- .../src/multiple_agent_shared_memory_1.py | 4 +- docs/examples/src/query_webpage_astra_db_1.py | 4 +- .../drivers/prompt-drivers.md | 2 +- .../drivers/src/embedding_drivers_10.py | 4 +- .../drivers/src/web_scraper_drivers_3.py | 3 +- .../drivers/src/web_search_drivers_2.py | 4 +- docs/griptape-framework/index.md | 109 +++++--- docs/griptape-framework/misc/src/events_3.py | 4 +- docs/griptape-framework/misc/src/events_4.py | 4 +- docs/griptape-framework/src/index_4.py | 4 +- .../structures/src/task_memory_3.py | 4 +- .../structures/src/task_memory_5.py | 4 +- .../structures/src/task_memory_6.py | 4 +- .../structures/src/task_memory_8.py | 4 +- .../structures/src/tasks_16.py | 4 +- .../structures/src/tasks_4.py | 4 +- .../structures/task-memory.md | 241 +++++++++++------- docs/griptape-framework/structures/tasks.md | 88 ++++--- docs/griptape-framework/tools/index.md | 101 +++++--- docs/griptape-framework/tools/src/index_1.py | 8 +- .../official-tools/aws-iam-client.md | 99 +++---- .../official-tools/aws-s3-client.md | 73 +++--- .../griptape-tools/official-tools/computer.md | 112 +++----- .../official-tools/sql-client.md | 81 +++--- .../official-tools/src/aws_s3_client_1.py | 4 +- .../official-tools/src/computer_1.py | 11 +- .../src/task_memory_client_1.py | 4 - .../src/vector_store_client_1.py | 4 +- .../official-tools/src/web_scraper_1.py | 4 +- .../official-tools/task-memory-client.md | 7 - .../official-tools/web-scraper.md | 146 ++++++----- griptape/tools/__init__.py | 4 +- .../__init__.py | 0 griptape/tools/query_client/manifest.yml | 5 + griptape/tools/query_client/requirements.txt | 0 griptape/tools/query_client/tool.py | 78 ++++++ griptape/tools/rag_client/tool.py | 66 +---- .../tools/task_memory_client/manifest.yml | 5 - griptape/tools/task_memory_client/tool.py | 41 --- mkdocs.yml | 1 - tests/integration/tasks/test_toolkit_task.py | 4 +- 46 files changed, 806 insertions(+), 649 deletions(-) delete mode 100644 docs/griptape-tools/official-tools/src/task_memory_client_1.py delete mode 100644 docs/griptape-tools/official-tools/task-memory-client.md rename griptape/tools/{task_memory_client => query_client}/__init__.py (100%) create mode 100644 griptape/tools/query_client/manifest.yml create mode 100644 griptape/tools/query_client/requirements.txt create mode 100644 griptape/tools/query_client/tool.py delete mode 100644 griptape/tools/task_memory_client/manifest.yml delete mode 100644 griptape/tools/task_memory_client/tool.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 1b7b08f36..6e92817b9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Global event bus, `griptape.events.event_bus`, for publishing and subscribing to events. - Global config, `griptape.config.config`, for setting global configuration defaults. - Unique name generation for all `RagEngine` modules. +- `ExtractionClient` Tool for having the LLM extract structured data from text. +- `PromptSummaryClient` Tool for having the LLM summarize text. ### Changed - **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `event_bus`. @@ -28,9 +30,24 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: `RagContext.output` was changed to `RagContext.outputs` to support multiple outputs. All relevant RAG modules were adjusted accordingly. - **BREAKING**: Removed before and after response modules from `ResponseRagStage`. - **BREAKING**: Moved ruleset and metadata ingestion from standalone modules to `PromptResponseRagModule`. +- **BREAKING**: Removed `TextArtifactStorage.query` and `TextArtifactStorage.summarize`. +- **BREAKING**: Removed `TextArtifactStorage.rag_engine`, and `TextArtifactStorage.retrieval_rag_module_name`. +- **BREAKING**: Removed `TextArtifactStorage.summary_engine`, `TextArtifactStorage.csv_extraction_engine`, and `TextArtifactStorage.json_extraction_engine`. +- **BREAKING**: Removed `TaskMemory.summarize_namespace` and `TaskMemory.query_namespace`. +- **BREAKING**: Removed `Structure.rag_engine`. +- **BREAKING**: Split `JsonExtractionEngine.template_generator` into `JsonExtractionEngine.system_template_generator` and `JsonExtractionEngine.user_template_generator`. +- **BREAKING**: Split `CsvExtractionEngine.template_generator` into `CsvExtractionEngine.system_template_generator` and `CsvExtractionEngine.user_template_generator`. +- **BREAKING**: Changed `JsonExtractionEngine.template_schema` from a `run` argument to a class attribute. +- **BREAKING**: Changed `CsvExtractionEngine.column_names` from a `run` argument to a class attribute. +- **BREAKING**: Removed `JsonExtractionTask`, and `CsvExtractionTask` use `ExtractionTask` instead. +- **BREAKING**: Removed `TaskMemoryClient`, use `RagClient`, `ExtractionClient`, or `PromptSummaryClient` instead. +- `RagClient` now can be used to search through Artifacts stored in Task Memory. - Engines that previously required Drivers now pull from `griptape.config.config.drivers` by default. - `BaseTask.add_parent/child` will now call `self.structure.add_task` if possible. +### Fixed +- `JsonExtractionEngine` failing to parse json when the LLM outputs more than just the json. + ## [0.29.1] - 2024-08-02 ### Changed diff --git a/README.md b/README.md index 42df8fbd6..555efd984 100644 --- a/README.md +++ b/README.md @@ -89,13 +89,13 @@ With Griptape, you can create Structures, such as Agents, Pipelines, and Workflo ```python from griptape.structures import Agent -from griptape.tools import WebScraper, FileManager, TaskMemoryClient +from griptape.tools import WebScraper, FileManager, PromptSummaryClient agent = Agent( input="Load {{ args[0] }}, summarize it, and store it in a file called {{ args[1] }}.", tools=[ WebScraper(off_prompt=True), - TaskMemoryClient(off_prompt=True), + PromptSummaryClient(off_prompt=True), FileManager() ] ) @@ -104,42 +104,63 @@ agent.run("https://griptape.ai", "griptape.txt") And here is the output: ``` -[04/02/24 13:51:09] INFO ToolkitTask 85700ec1b0594e1a9502c0efe7da6ef4 +[08/12/24 14:48:15] INFO ToolkitTask c90d263ec69046e8b30323c131ae4ba0 Input: Load https://griptape.ai, summarize it, and store it in a file called griptape.txt. -[04/02/24 13:51:15] INFO Subtask db6a3e7cb2f549128c358149d340f91c - Thought: First, I need to load the content of the website using the WebScraper action. Then, I will use the TaskMemoryClient action to - summarize the content. Finally, I will save the summarized content to a file using the FileManager action. +[08/12/24 14:48:16] INFO Subtask ebe23832cbe2464fb9ecde9fcee7c30f Actions: [ { + "tag": "call_62kBnkswnk9Y6GH6kn1GIKk6", "name": "WebScraper", "path": "get_content", "input": { "values": { "url": "https://griptape.ai" } - }, - "tag": "load_website_content" + } } ] -[04/02/24 13:51:16] INFO Subtask db6a3e7cb2f549128c358149d340f91c +[08/12/24 14:48:17] INFO Subtask ebe23832cbe2464fb9ecde9fcee7c30f Response: Output of "WebScraper.get_content" was stored in memory with memory_name "TaskMemory" and artifact_namespace - "752b38bb86da4baabdbd9f444eb4a0d1" -[04/02/24 13:51:19] INFO Subtask c3edba87ebf845d4b85e3a791f8fde8d - Thought: Now that the website content is loaded into memory, I need to summarize it using the TaskMemoryClient action. - Actions: [{"tag": "summarize_content", "name": "TaskMemoryClient", "path": "summarize", "input": {"values": {"memory_name": "TaskMemory", - "artifact_namespace": "752b38bb86da4baabdbd9f444eb4a0d1"}}}] -[04/02/24 13:51:25] INFO Subtask c3edba87ebf845d4b85e3a791f8fde8d - Response: Output of "TaskMemoryClient.summarize" was stored in memory with memory_name "TaskMemory" and artifact_namespace - "c4f131c201f147dcab07be3925b46294" -[04/02/24 13:51:33] INFO Subtask 06fe01ca64a744b38a8c08eb152aaacb - Thought: Now that the content has been summarized and stored in memory, I need to save this summarized content to a file named 'griptape.txt' - using the FileManager action. - Actions: [{"tag": "save_summarized_content", "name": "FileManager", "path": "save_memory_artifacts_to_disk", "input": {"values": {"dir_name": - ".", "file_name": "griptape.txt", "memory_name": "TaskMemory", "artifact_namespace": "c4f131c201f147dcab07be3925b46294"}}}] - INFO Subtask 06fe01ca64a744b38a8c08eb152aaacb - Response: saved successfully -[04/02/24 13:51:35] INFO ToolkitTask 85700ec1b0594e1a9502c0efe7da6ef4 - Output: The summarized content of the website https://griptape.ai has been successfully saved to a file named 'griptape.txt'. + "cecca28eb0c74bcd8c7119ed7f790c95" +[08/12/24 14:48:18] INFO Subtask dca04901436d49d2ade86cd6b4e1038a + Actions: [ + { + "tag": "call_o9F1taIxHty0mDlWLcAjTAAu", + "name": "PromptSummaryClient", + "path": "summarize", + "input": { + "values": { + "summary": { + "memory_name": "TaskMemory", + "artifact_namespace": "cecca28eb0c74bcd8c7119ed7f790c95" + } + } + } + } + ] +[08/12/24 14:48:21] INFO Subtask dca04901436d49d2ade86cd6b4e1038a + Response: Output of "PromptSummaryClient.summarize" was stored in memory with memory_name "TaskMemory" and artifact_namespace + "73765e32b8404e32927822250dc2ae8b" +[08/12/24 14:48:22] INFO Subtask c233853450fb4fd6a3e9c04c52b33bf6 + Actions: [ + { + "tag": "call_eKvIUIw45aRYKDBpT1gGKc9b", + "name": "FileManager", + "path": "save_memory_artifacts_to_disk", + "input": { + "values": { + "dir_name": ".", + "file_name": "griptape.txt", + "memory_name": "TaskMemory", + "artifact_namespace": "73765e32b8404e32927822250dc2ae8b" + } + } + } + ] + INFO Subtask c233853450fb4fd6a3e9c04c52b33bf6 + Response: Successfully saved memory artifacts to disk +[08/12/24 14:48:23] INFO ToolkitTask c90d263ec69046e8b30323c131ae4ba0 + Output: The content from https://griptape.ai has been summarized and stored in a file called `griptape.txt`. ``` During the run, the Griptape Agent loaded a webpage with a [Tool](https://docs.griptape.ai/stable/griptape-tools/), stored its full content in [Task Memory](https://docs.griptape.ai/stable/griptape-framework/structures/task-memory.md), queried it to answer the original question, and finally saved the answer to a file. diff --git a/_typos.toml b/_typos.toml index 1819b51ef..29f892536 100644 --- a/_typos.toml +++ b/_typos.toml @@ -1,3 +1,6 @@ +[default] +extend-ignore-re = ["call_[[:alnum:]]+"] + [default.extend-words] # Don't correct the state ND ND = "ND" @@ -6,3 +9,4 @@ mornin = "mornin" [files] extend-exclude = ["docs/assets", "tests/resources/"] + diff --git a/docs/examples/multiple-agent-shared-memory.md b/docs/examples/multiple-agent-shared-memory.md index e73ffe5d6..41bf8677a 100644 --- a/docs/examples/multiple-agent-shared-memory.md +++ b/docs/examples/multiple-agent-shared-memory.md @@ -1,6 +1,6 @@ -This example shows how to use one `Agent` to load content into `TaskMemory` and get that content from another `Agent` using `TaskMemoryClient`. +This example shows how to use one `Agent` to load content into `TaskMemory` and get that content from another `Agent` using `QueryClient`. -The first `Agent` uses a remote vector store (`MongoDbAtlasVectorStoreDriver` in this example) to handle memory operations. The second `Agent` uses the same instance of `TaskMemory` and the `TaskMemoryClient` with the same `MongoDbAtlasVectorStoreDriver` to get the data. +The first `Agent` uses a remote vector store (`MongoDbAtlasVectorStoreDriver` in this example) to handle memory operations. The second `Agent` uses the same instance of `TaskMemory` and the `QueryClient` with the same `MongoDbAtlasVectorStoreDriver` to get the data. The `MongoDbAtlasVectorStoreDriver` assumes that you have a vector index configured where the path to the content is called `vector`, and the number of dimensions set on the index is `1536` (this is a commonly used number of dimensions for embedding models). diff --git a/docs/examples/src/multi_agent_workflow_1.py b/docs/examples/src/multi_agent_workflow_1.py index 311fa22a2..6a3b31c19 100644 --- a/docs/examples/src/multi_agent_workflow_1.py +++ b/docs/examples/src/multi_agent_workflow_1.py @@ -5,7 +5,7 @@ from griptape.structures import Agent, Workflow from griptape.tasks import PromptTask, StructureRunTask from griptape.tools import ( - TaskMemoryClient, + PromptSummaryClient, WebScraper, WebSearch, ) @@ -38,7 +38,7 @@ def build_researcher() -> Agent: WebScraper( off_prompt=True, ), - TaskMemoryClient(off_prompt=False), + PromptSummaryClient(off_prompt=False), ], rulesets=[ Ruleset( diff --git a/docs/examples/src/multiple_agent_shared_memory_1.py b/docs/examples/src/multiple_agent_shared_memory_1.py index 118684d37..dafa50559 100644 --- a/docs/examples/src/multiple_agent_shared_memory_1.py +++ b/docs/examples/src/multiple_agent_shared_memory_1.py @@ -3,7 +3,7 @@ from griptape.config import AzureOpenAiDriverConfig, config from griptape.drivers import AzureMongoDbVectorStoreDriver, AzureOpenAiEmbeddingDriver from griptape.structures import Agent -from griptape.tools import TaskMemoryClient, WebScraper +from griptape.tools import QueryClient, WebScraper AZURE_OPENAI_ENDPOINT_1 = os.environ["AZURE_OPENAI_ENDPOINT_1"] AZURE_OPENAI_API_KEY_1 = os.environ["AZURE_OPENAI_API_KEY_1"] @@ -46,7 +46,7 @@ ) asker = Agent( tools=[ - TaskMemoryClient(off_prompt=False), + QueryClient(off_prompt=False), ], meta_memory=loader.meta_memory, task_memory=loader.task_memory, diff --git a/docs/examples/src/query_webpage_astra_db_1.py b/docs/examples/src/query_webpage_astra_db_1.py index b5d2b0a01..d31d3b53f 100644 --- a/docs/examples/src/query_webpage_astra_db_1.py +++ b/docs/examples/src/query_webpage_astra_db_1.py @@ -14,7 +14,7 @@ from griptape.engines.rag.stages import ResponseRagStage, RetrievalRagStage from griptape.loaders import WebLoader from griptape.structures import Agent -from griptape.tools import RagClient, TaskMemoryClient +from griptape.tools import QueryClient, RagClient namespace = "datastax_blog" input_blogpost = "www.datastax.com/blog/indexing-all-of-wikipedia-on-a-laptop" @@ -53,5 +53,5 @@ description="A DataStax blog post", rag_engine=engine, ) -agent = Agent(tools=[vector_store_tool, TaskMemoryClient(off_prompt=False)]) +agent = Agent(tools=[vector_store_tool, QueryClient(off_prompt=False)]) agent.run("What engine made possible to index such an amount of data, " "and what kind of tuning was required?") diff --git a/docs/griptape-framework/drivers/prompt-drivers.md b/docs/griptape-framework/drivers/prompt-drivers.md index 0be304211..54230b999 100644 --- a/docs/griptape-framework/drivers/prompt-drivers.md +++ b/docs/griptape-framework/drivers/prompt-drivers.md @@ -26,7 +26,7 @@ Griptape offers the following Prompt Drivers for interacting with LLMs. ### OpenAI Chat The [OpenAiChatPromptDriver](../../reference/griptape/drivers/prompt/openai_chat_prompt_driver.md) connects to the [OpenAI Chat](https://platform.openai.com/docs/guides/chat) API. -This driver uses [OpenAI function calling](https://platform.openai.com/docs/guides/function-calling) when using [Tools](../tools/index.md). +This driver uses [OpenAi function calling](https://platform.openai.com/docs/guides/function-calling) when using [Tools](../tools/index.md). ```python --8<-- "docs/griptape-framework/drivers/src/prompt_drivers_3.py" diff --git a/docs/griptape-framework/drivers/src/embedding_drivers_10.py b/docs/griptape-framework/drivers/src/embedding_drivers_10.py index 3ef816b29..ccd6fe710 100644 --- a/docs/griptape-framework/drivers/src/embedding_drivers_10.py +++ b/docs/griptape-framework/drivers/src/embedding_drivers_10.py @@ -4,7 +4,7 @@ VoyageAiEmbeddingDriver, ) from griptape.structures import Agent -from griptape.tools import TaskMemoryClient, WebScraper +from griptape.tools import PromptSummaryClient, WebScraper config.drivers = DriverConfig( prompt=OpenAiChatPromptDriver(model="gpt-4o"), @@ -12,7 +12,7 @@ ) agent = Agent( - tools=[WebScraper(off_prompt=True), TaskMemoryClient(off_prompt=False)], + tools=[WebScraper(off_prompt=True), PromptSummaryClient(off_prompt=False)], ) agent.run("based on https://www.griptape.ai/, tell me what Griptape is") diff --git a/docs/griptape-framework/drivers/src/web_scraper_drivers_3.py b/docs/griptape-framework/drivers/src/web_scraper_drivers_3.py index d9fe11e85..6ccb0dafd 100644 --- a/docs/griptape-framework/drivers/src/web_scraper_drivers_3.py +++ b/docs/griptape-framework/drivers/src/web_scraper_drivers_3.py @@ -1,7 +1,7 @@ from griptape.drivers import MarkdownifyWebScraperDriver from griptape.loaders import WebLoader from griptape.structures import Agent -from griptape.tools import TaskMemoryClient, WebScraper +from griptape.tools import WebScraper agent = Agent( tools=[ @@ -9,7 +9,6 @@ web_loader=WebLoader(web_scraper_driver=MarkdownifyWebScraperDriver(timeout=1000)), off_prompt=True, ), - TaskMemoryClient(off_prompt=False), ], ) agent.run("List all email addresses on griptape.ai in a flat numbered markdown list.") diff --git a/docs/griptape-framework/drivers/src/web_search_drivers_2.py b/docs/griptape-framework/drivers/src/web_search_drivers_2.py index 5cde1a9a8..4c2c469a3 100644 --- a/docs/griptape-framework/drivers/src/web_search_drivers_2.py +++ b/docs/griptape-framework/drivers/src/web_search_drivers_2.py @@ -2,7 +2,7 @@ from griptape.drivers import GoogleWebSearchDriver from griptape.structures import Agent -from griptape.tools import TaskMemoryClient, WebSearch +from griptape.tools import PromptSummaryClient, WebSearch agent = Agent( tools=[ @@ -12,7 +12,7 @@ search_id=os.environ["GOOGLE_API_SEARCH_ID"], ), ), - TaskMemoryClient(off_prompt=False), + PromptSummaryClient(off_prompt=False), ], ) agent.run("Give me some websites with information about AI frameworks.") diff --git a/docs/griptape-framework/index.md b/docs/griptape-framework/index.md index 52805897d..9fd4c18ac 100644 --- a/docs/griptape-framework/index.md +++ b/docs/griptape-framework/index.md @@ -128,44 +128,75 @@ Agents are great for getting started, but they are intentionally limited to a si ``` ``` -[09/08/23 10:02:34] INFO ToolkitTask 3c1d2f4a49384873820a9a8cd8acc983 +[08/12/24 14:50:28] INFO ToolkitTask 19dcf6020968468a91aa8a93c2a3f645 Input: Load https://www.griptape.ai, summarize it, and store it in griptape.txt -[09/08/23 10:02:44] INFO Subtask 42fd56ba100e45688401c5ce32b79a33 - Thought: To complete this task, I need to first load the webpage using the WebScraper tool's get_content - activity. Then, I will summarize the content using the TaskMemory tool's summarize activity. Finally, I will - store the summarized content in a file named griptape.txt using the FileManager tool's save_file_to_disk - activity. - - Action: {"name": "WebScraper", "path": "get_content", "input": {"values": {"url": - "https://www.griptape.ai"}}} -[09/08/23 10:02:45] INFO Subtask 42fd56ba100e45688401c5ce32b79a33 - Response: Output of "WebScraper.get_content" was stored in memory with memory_name "TaskMemory" and - artifact_namespace "39ca67bbe26b4e1584193b87ed82170d" -[09/08/23 10:02:53] INFO Subtask 8023e3d257274df29065b22e736faca8 - Thought: Now that the webpage content is stored in memory, I can use the TaskMemory tool's summarize activity - to summarize the content. - Action: {"name": "TaskMemoryClient", "path": "summarize", "input": {"values": {"memory_name": "TaskMemory", "artifact_namespace": "39ca67bbe26b4e1584193b87ed82170d"}}} -[09/08/23 10:02:57] INFO Subtask 8023e3d257274df29065b22e736faca8 - Response: Griptape is an open source framework that allows developers to build and deploy AI applications - using large language models (LLMs). It provides the ability to create conversational and event-driven apps that - can securely access and manipulate data. The framework enforces structures for predictability and creativity, - allowing developers to easily transition between the two. Griptape Cloud is a managed platform for deploying and - managing AI apps. -[09/08/23 10:03:06] INFO Subtask 7baae700239943c18b5b6b21873f0e13 - Thought: Now that I have the summarized content, I can store it in a file named griptape.txt using the - FileManager tool's save_file_to_disk activity. - Action: {"name": "FileManager", "path": "save_file_to_disk", "input": {"values": - {"memory_name": "TaskMemory", "artifact_namespace": "39ca67bbe26b4e1584193b87ed82170d", "path": - "griptape.txt"}}} - INFO Subtask 7baae700239943c18b5b6b21873f0e13 - Response: saved successfully -[09/08/23 10:03:14] INFO ToolkitTask 3c1d2f4a49384873820a9a8cd8acc983 - Output: The summarized content of the webpage https://www.griptape.ai has been successfully stored in the file - named griptape.txt. - INFO PromptTask 8635925ff23b46f28a740105bd11ca8f - Input: Say the following in spanish: The summarized content of the webpage https://www.griptape.ai has been - successfully stored in the file named griptape.txt. -[09/08/23 10:03:18] INFO PromptTask 8635925ff23b46f28a740105bd11ca8f - Output: El contenido resumido de la página web https://www.griptape.ai se ha almacenado con éxito en el archivo - llamado griptape.txt. +[08/12/24 14:50:30] INFO Subtask a685799379c5421b91768353fc219939 + Actions: [ + { + "tag": "call_YL5Ozd9WUtag4ykR5Agm12Ce", + "name": "WebScraper", + "path": "get_content", + "input": { + "values": { + "url": "https://www.griptape.ai" + } + } + } + ] +[08/12/24 14:50:31] INFO Subtask a685799379c5421b91768353fc219939 + Response: Output of "WebScraper.get_content" was stored in memory with memory_name "TaskMemory" and artifact_namespace + "6be3a2e0494841fda966b98bec9ffccb" +[08/12/24 14:50:33] INFO Subtask 1cf0c19843aa4fada5745c4a82eb4237 + Actions: [ + { + "tag": "call_ElTYTPeocOU62I0VjzRqmfoF", + "name": "PromptSummaryClient", + "path": "summarize", + "input": { + "values": { + "summary": { + "memory_name": "TaskMemory", + "artifact_namespace": "6be3a2e0494841fda966b98bec9ffccb" + } + } + } + } + ] +[08/12/24 14:50:35] INFO Subtask 1cf0c19843aa4fada5745c4a82eb4237 + Response: Griptape offers a comprehensive solution for building, deploying, and scaling AI applications in the cloud. It provides developers + with a framework and cloud services to create retrieval-driven AI-powered applications. The Griptape Framework allows developers to build + business logic using Python, ensuring better security, performance, and cost-efficiency. It simplifies the creation of Gen AI Agents, Systems of + Agents, Pipelines, Workflows, and RAG implementations without needing extensive knowledge of Gen AI or Prompt Engineering. + + Griptape Cloud handles infrastructure management, offering services like ETL pipelines for data preparation, Retrieval as a Service (RAG) for + generating answers and summaries, and a Structure Runtime (RUN) for building AI agents and workflows. This enables seamless scaling and + integration with client applications, catering to custom projects, turnkey SaaS offerings, and finished apps. +[08/12/24 14:50:38] INFO Subtask aaaeca1a089844d4915d065deb3c00cf + Actions: [ + { + "tag": "call_eKvIUIw45aRYKDBpT1gGKc9b", + "name": "FileManager", + "path": "save_content_to_file", + "input": { + "values": { + "path": "griptape.txt", + "content": "Griptape offers a comprehensive solution for building, deploying, and scaling AI applications in the cloud. It provides + developers with a framework and cloud services to create retrieval-driven AI-powered applications. The Griptape Framework allows developers to + build business logic using Python, ensuring better security, performance, and cost-efficiency. It simplifies the creation of Gen AI Agents, + Systems of Agents, Pipelines, Workflows, and RAG implementations without needing extensive knowledge of Gen AI or Prompt + Engineering.\n\nGriptape Cloud handles infrastructure management, offering services like ETL pipelines for data preparation, Retrieval as a + Service (RAG) for generating answers and summaries, and a Structure Runtime (RUN) for building AI agents and workflows. This enables seamless + scaling and integration with client applications, catering to custom projects, turnkey SaaS offerings, and finished apps." + } + } + } + ] + INFO Subtask aaaeca1a089844d4915d065deb3c00cf + Response: Successfully saved file +[08/12/24 14:50:39] INFO ToolkitTask 19dcf6020968468a91aa8a93c2a3f645 + Output: The content from https://www.griptape.ai has been summarized and stored in griptape.txt. + INFO PromptTask dbbb38f144f445db896dc12854f17ad3 + Input: Say the following in spanish: The content from https://www.griptape.ai has been summarized and stored in griptape.txt. +[08/12/24 14:50:42] INFO PromptTask dbbb38f144f445db896dc12854f17ad3 + Output: El contenido de https://www.griptape.ai ha sido resumido y almacenado en griptape.txt. ``` diff --git a/docs/griptape-framework/misc/src/events_3.py b/docs/griptape-framework/misc/src/events_3.py index ab995e018..8c49f8146 100644 --- a/docs/griptape-framework/misc/src/events_3.py +++ b/docs/griptape-framework/misc/src/events_3.py @@ -5,7 +5,7 @@ from griptape.events import CompletionChunkEvent, EventListener, event_bus from griptape.structures import Pipeline from griptape.tasks import ToolkitTask -from griptape.tools import TaskMemoryClient, WebScraper +from griptape.tools import PromptSummaryClient, WebScraper event_bus.add_event_listeners( [ @@ -23,7 +23,7 @@ pipeline.add_tasks( ToolkitTask( "Based on https://griptape.ai, tell me what griptape is.", - tools=[WebScraper(off_prompt=True), TaskMemoryClient(off_prompt=False)], + tools=[WebScraper(off_prompt=True), PromptSummaryClient(off_prompt=False)], ) ) diff --git a/docs/griptape-framework/misc/src/events_4.py b/docs/griptape-framework/misc/src/events_4.py index f5523cb11..a3fe44b78 100644 --- a/docs/griptape-framework/misc/src/events_4.py +++ b/docs/griptape-framework/misc/src/events_4.py @@ -1,13 +1,13 @@ from griptape.structures import Pipeline from griptape.tasks import ToolkitTask -from griptape.tools import TaskMemoryClient, WebScraper +from griptape.tools import PromptSummaryClient, WebScraper from griptape.utils import Stream pipeline = Pipeline() pipeline.add_tasks( ToolkitTask( "Based on https://griptape.ai, tell me what griptape is.", - tools=[WebScraper(off_prompt=True), TaskMemoryClient(off_prompt=False)], + tools=[WebScraper(off_prompt=True), PromptSummaryClient(off_prompt=False)], ) ) diff --git a/docs/griptape-framework/src/index_4.py b/docs/griptape-framework/src/index_4.py index 0bb345438..b1108d9e9 100644 --- a/docs/griptape-framework/src/index_4.py +++ b/docs/griptape-framework/src/index_4.py @@ -1,7 +1,7 @@ from griptape.memory.structure import ConversationMemory from griptape.structures import Pipeline from griptape.tasks import PromptTask, ToolkitTask -from griptape.tools import FileManager, TaskMemoryClient, WebScraper +from griptape.tools import FileManager, PromptSummaryClient, WebScraper # Pipelines represent sequences of tasks. pipeline = Pipeline(conversation_memory=ConversationMemory()) @@ -11,7 +11,7 @@ ToolkitTask( "{{ args[0] }}", # Add tools for web scraping, and file management - tools=[WebScraper(off_prompt=True), FileManager(off_prompt=True), TaskMemoryClient(off_prompt=False)], + tools=[WebScraper(off_prompt=True), FileManager(off_prompt=True), PromptSummaryClient(off_prompt=False)], ), # Augment `input` from the previous task. PromptTask("Say the following in spanish: {{ parent_output }}"), diff --git a/docs/griptape-framework/structures/src/task_memory_3.py b/docs/griptape-framework/structures/src/task_memory_3.py index cab4f4e3e..14649a222 100644 --- a/docs/griptape-framework/structures/src/task_memory_3.py +++ b/docs/griptape-framework/structures/src/task_memory_3.py @@ -1,7 +1,7 @@ from griptape.structures import Agent -from griptape.tools import Calculator, TaskMemoryClient +from griptape.tools import Calculator, PromptSummaryClient # Create an agent with the Calculator tool -agent = Agent(tools=[Calculator(off_prompt=True), TaskMemoryClient(off_prompt=False)]) +agent = Agent(tools=[Calculator(off_prompt=True), PromptSummaryClient(off_prompt=False)]) agent.run("What is the square root of 12345?") diff --git a/docs/griptape-framework/structures/src/task_memory_5.py b/docs/griptape-framework/structures/src/task_memory_5.py index a5d3995a9..255e72397 100644 --- a/docs/griptape-framework/structures/src/task_memory_5.py +++ b/docs/griptape-framework/structures/src/task_memory_5.py @@ -1,10 +1,10 @@ from griptape.structures import Agent -from griptape.tools import TaskMemoryClient, WebScraper +from griptape.tools import QueryClient, WebScraper agent = Agent( tools=[ WebScraper(off_prompt=True), - TaskMemoryClient(off_prompt=False), + QueryClient(off_prompt=False), ] ) diff --git a/docs/griptape-framework/structures/src/task_memory_6.py b/docs/griptape-framework/structures/src/task_memory_6.py index 5a81ee8cd..86c0b045a 100644 --- a/docs/griptape-framework/structures/src/task_memory_6.py +++ b/docs/griptape-framework/structures/src/task_memory_6.py @@ -11,7 +11,7 @@ from griptape.memory import TaskMemory from griptape.memory.task.storage import TextArtifactStorage from griptape.structures import Agent -from griptape.tools import FileManager, TaskMemoryClient, WebScraper +from griptape.tools import FileManager, QueryClient, WebScraper config.drivers = OpenAiDriverConfig( prompt=OpenAiChatPromptDriver(model="gpt-4"), @@ -29,7 +29,7 @@ ), tools=[ WebScraper(off_prompt=True), - TaskMemoryClient(off_prompt=True, allowlist=["query"]), + QueryClient(off_prompt=True), FileManager(off_prompt=True), ], ) diff --git a/docs/griptape-framework/structures/src/task_memory_8.py b/docs/griptape-framework/structures/src/task_memory_8.py index 9aba9516f..4f19a235f 100644 --- a/docs/griptape-framework/structures/src/task_memory_8.py +++ b/docs/griptape-framework/structures/src/task_memory_8.py @@ -1,10 +1,10 @@ from griptape.structures import Agent -from griptape.tools import TaskMemoryClient, WebScraper +from griptape.tools import PromptSummaryClient, WebScraper agent = Agent( tools=[ WebScraper(off_prompt=True), # This tool will store the data in Task Memory - TaskMemoryClient( + PromptSummaryClient( off_prompt=True ), # This tool will store the data back in Task Memory with no way to get it out ] diff --git a/docs/griptape-framework/structures/src/tasks_16.py b/docs/griptape-framework/structures/src/tasks_16.py index 796b836da..ac8d8d5b2 100644 --- a/docs/griptape-framework/structures/src/tasks_16.py +++ b/docs/griptape-framework/structures/src/tasks_16.py @@ -5,7 +5,7 @@ from griptape.structures import Agent, Pipeline from griptape.tasks import StructureRunTask from griptape.tools import ( - TaskMemoryClient, + PromptSummaryClient, WebScraper, WebSearch, ) @@ -23,7 +23,7 @@ def build_researcher() -> Agent: WebScraper( off_prompt=True, ), - TaskMemoryClient(off_prompt=False), + PromptSummaryClient(off_prompt=False), ], rulesets=[ Ruleset( diff --git a/docs/griptape-framework/structures/src/tasks_4.py b/docs/griptape-framework/structures/src/tasks_4.py index 43737980b..747a82c26 100644 --- a/docs/griptape-framework/structures/src/tasks_4.py +++ b/docs/griptape-framework/structures/src/tasks_4.py @@ -1,12 +1,12 @@ from griptape.structures import Agent from griptape.tasks import ToolkitTask -from griptape.tools import FileManager, TaskMemoryClient, WebScraper +from griptape.tools import FileManager, PromptSummaryClient, WebScraper agent = Agent() agent.add_task( ToolkitTask( "Load https://www.griptape.ai, summarize it, and store it in a file called griptape.txt", - tools=[WebScraper(off_prompt=True), FileManager(off_prompt=True), TaskMemoryClient(off_prompt=True)], + tools=[WebScraper(off_prompt=True), FileManager(off_prompt=True), PromptSummaryClient(off_prompt=True)], ), ) diff --git a/docs/griptape-framework/structures/task-memory.md b/docs/griptape-framework/structures/task-memory.md index 6b858ff5b..b58b99e94 100644 --- a/docs/griptape-framework/structures/task-memory.md +++ b/docs/griptape-framework/structures/task-memory.md @@ -69,18 +69,15 @@ Let's explore what happens when `off_prompt` is set to `True`: ``` When we set `off_prompt` to `True`, the Agent does not function as expected, even generating an error. This is because the Calculator output is being stored in Task Memory but the Agent has no way to access it. -To fix this, we need a [Tool that can read from Task Memory](#tools-that-can-read-from-task-memory) such as the `TaskMemoryClient`. +To fix this, we need a [Tool that can read from Task Memory](#tools-that-can-read-from-task-memory) such as the `PromptSummaryClient`. This is an example of [not providing a Task Memory compatible Tool](#not-providing-a-task-memory-compatible-tool). -## Task Memory Client +## Prompt Summary Client -The [TaskMemoryClient](../../griptape-tools/official-tools/task-memory-client.md) is a Tool that allows an Agent to interact with Task Memory. It has the following methods: +The [PromptSummaryClient](../../griptape-tools/official-tools/prompt-summary-client.md) is a Tool that allows an Agent to summarize the Artifacts in Task Memory. It has the following methods: -- `query`: Retrieve the content of an Artifact stored in Task Memory. -- `summarize`: Summarize the content of an Artifact stored in Task Memory. - -Let's add `TaskMemoryClient` to the Agent and run the same task. -Note that on the `TaskMemoryClient` we've set `off_prompt` to `False` so that the results of the query can be returned directly to the LLM. +Let's add `PromptSummaryClient` to the Agent and run the same task. +Note that on the `PromptSummaryClient` we've set `off_prompt` to `False` so that the results of the query can be returned directly to the LLM. If we had kept it as `True`, the results would have been stored back Task Memory which would've put us back to square one. See [Task Memory Looping](#task-memory-looping) for more information on this scenario. ```python @@ -88,22 +85,43 @@ If we had kept it as `True`, the results would have been stored back Task Memory ``` ``` -[04/26/24 13:13:01] INFO ToolkitTask 5b46f9ef677c4b31906b48aba3f45e2c +[08/12/24 14:54:04] INFO ToolkitTask f7ebd8acc3d64e3ca9db82ef9ec4e65f Input: What is the square root of 12345? -[04/26/24 13:13:07] INFO Subtask 611d98ea5576430fbc63259420577ab2 - Thought: To find the square root of 12345, I can use the Calculator action with the expression "12345 ** 0.5". - Actions: [{"name": "Calculator", "path": "calculate", "input": {"values": {"expression": "12345 ** 0.5"}}, "tag": "sqrt_12345"}] -[04/26/24 13:13:08] INFO Subtask 611d98ea5576430fbc63259420577ab2 +[08/12/24 14:54:05] INFO Subtask 777693d039e74ed288f663742fdde2ea + Actions: [ + { + "tag": "call_DXSs19G27VOV7EmP3PoRwGZI", + "name": "Calculator", + "path": "calculate", + "input": { + "values": { + "expression": "12345 ** 0.5" + } + } + } + ] + INFO Subtask 777693d039e74ed288f663742fdde2ea Response: Output of "Calculator.calculate" was stored in memory with memory_name "TaskMemory" and artifact_namespace - "7554b69e1d414a469b8882e2266dcea1" -[04/26/24 13:13:15] INFO Subtask 32b9163a15644212be60b8fba07bd23b - Thought: The square root of 12345 has been calculated and stored in memory. I can retrieve this value using the TaskMemoryClient action with - the query path, providing the memory_name and artifact_namespace as input. - Actions: [{"tag": "retrieve_sqrt", "name": "TaskMemoryClient", "path": "query", "input": {"values": {"memory_name": "TaskMemory", - "artifact_namespace": "7554b69e1d414a469b8882e2266dcea1", "query": "What is the result of the calculation?"}}}] -[04/26/24 13:13:16] INFO Subtask 32b9163a15644212be60b8fba07bd23b - Response: The result of the calculation is 111.1080555135405. -[04/26/24 13:13:17] INFO ToolkitTask 5b46f9ef677c4b31906b48aba3f45e2c + "370853a8937f4dd7a9e923254459cff2" +[08/12/24 14:54:06] INFO Subtask c8394ca51f1f4ae1b715618a2c5c8120 + Actions: [ + { + "tag": "call_qqpsWEvAUGIcPLrwAHGuH6o3", + "name": "PromptSummaryClient", + "path": "summarize", + "input": { + "values": { + "summary": { + "memory_name": "TaskMemory", + "artifact_namespace": "370853a8937f4dd7a9e923254459cff2" + } + } + } + } + ] +[08/12/24 14:54:07] INFO Subtask c8394ca51f1f4ae1b715618a2c5c8120 + Response: The text contains a single numerical value: 111.1080555135405. +[08/12/24 14:54:08] INFO ToolkitTask f7ebd8acc3d64e3ca9db82ef9ec4e65f Output: The square root of 12345 is approximately 111.108. ``` @@ -125,8 +143,8 @@ When running this example, we get the following error: Please reduce the length of the messages.", 'type': 'invalid_request_error', 'param': 'messages', 'code': 'context_length_exceeded'}} ``` -This is because the content of the webpage is too large to fit in the LLM's input token limit. We can fix this by storing the content in Task Memory, and then querying it with the `TaskMemoryClient`. -Note that we're setting `off_prompt` to `False` on the `TaskMemoryClient` so that the _queried_ content can be returned directly to the LLM. +This is because the content of the webpage is too large to fit in the LLM's input token limit. We can fix this by storing the content in Task Memory, and then querying it with the `QueryClient`. +Note that we're setting `off_prompt` to `False` on the `QueryClient` so that the _queried_ content can be returned directly to the LLM. ```python --8<-- "docs/griptape-framework/structures/src/task_memory_5.py" @@ -134,24 +152,46 @@ Note that we're setting `off_prompt` to `False` on the `TaskMemoryClient` so tha And now we get the expected output: ``` -[04/26/24 13:51:51] INFO ToolkitTask 7aca20f202df47a2b9848ed7025f9c21 +[08/12/24 14:56:18] INFO ToolkitTask d3ce58587dc944b0a30a205631b82944 Input: According to this page https://en.wikipedia.org/wiki/Elden_Ring, how many copies of Elden Ring have been sold? -[04/26/24 13:51:58] INFO Subtask 5b21d8ead32b4644abcd1e852bb5f512 - Thought: I need to scrape the content of the provided URL to find the information about how many copies of Elden Ring have been sold. - Actions: [{"name": "WebScraper", "path": "get_content", "input": {"values": {"url": "https://en.wikipedia.org/wiki/Elden_Ring"}}, "tag": - "scrape_elden_ring"}] -[04/26/24 13:52:04] INFO Subtask 5b21d8ead32b4644abcd1e852bb5f512 +[08/12/24 14:56:20] INFO Subtask 494850ec40fe474c83d48b5620c5dcbb + Actions: [ + { + "tag": "call_DGsOHC4AVxhV7RPVA7q3rATX", + "name": "WebScraper", + "path": "get_content", + "input": { + "values": { + "url": "https://en.wikipedia.org/wiki/Elden_Ring" + } + } + } + ] +[08/12/24 14:56:25] INFO Subtask 494850ec40fe474c83d48b5620c5dcbb Response: Output of "WebScraper.get_content" was stored in memory with memory_name "TaskMemory" and artifact_namespace - "2d4ebc7211074bb7be26613eb25d8fc1" -[04/26/24 13:52:11] INFO Subtask f12eb3d3b4924e4085808236b460b43d - Thought: Now that the webpage content is stored in memory, I need to query this memory to find the information about how many copies of Elden - Ring have been sold. - Actions: [{"tag": "query_sales", "name": "TaskMemoryClient", "path": "query", "input": {"values": {"memory_name": "TaskMemory", - "artifact_namespace": "2d4ebc7211074bb7be26613eb25d8fc1", "query": "How many copies of Elden Ring have been sold?"}}}] -[04/26/24 13:52:14] INFO Subtask f12eb3d3b4924e4085808236b460b43d - Response: Elden Ring sold 23 million copies by February 2024. -[04/26/24 13:52:15] INFO ToolkitTask 7aca20f202df47a2b9848ed7025f9c21 - Output: Elden Ring sold 23 million copies by February 2024. + "b9f53d6d9b35455aaf4d99719c1bfffa" +[08/12/24 14:56:26] INFO Subtask 8669ee523bb64550850566011bcd14e2 + Actions: [ + { + "tag": "call_DGsOHC4AVxhV7RPVA7q3rATX", + "name": "QueryClient", + "path": "search", + "input": { + "values": { + "query": "number of copies sold", + "content": { + "memory_name": "TaskMemory", + "artifact_namespace": "b9f53d6d9b35455aaf4d99719c1bfffa" + } + } + } + } + ] +[08/12/24 14:56:29] INFO Subtask 8669ee523bb64550850566011bcd14e2 + Response: "Elden Ring" sold 13.4 million copies worldwide by the end of March 2022 and 25 million by June 2024. The downloadable content (DLC) + "Shadow of the Erdtree" sold five million copies within three days of its release. +[08/12/24 14:56:30] INFO ToolkitTask d3ce58587dc944b0a30a205631b82944 + Output: Elden Ring sold 13.4 million copies worldwide by the end of March 2022 and 25 million by June 2024. ``` ## Sensitive Data @@ -166,61 +206,65 @@ In this example, GPT-4 _never_ sees the contents of the page, only that it was s ``` ``` -[06/21/24 16:00:01] INFO ToolkitTask 17f30ac14701490c8ef71508f420ea9f - Input: Use this page - https://en.wikipedia.org/wiki/Elden_Ring to find - how many copies of Elden Ring have been sold, and - then save the result to a file. -[06/21/24 16:00:05] INFO Subtask cb06889205334ec9afd7e97f7f231ab5 - Thought: First, I need to scrape the content of the - provided URL to find the information about how many - copies of Elden Ring have been sold. Then, I will - save this information to a file. - - Actions: [{"name": "WebScraper", "path": - "get_content", "input": {"values": {"url": - "https://en.wikipedia.org/wiki/Elden_Ring"}}, - "tag": "scrape_elden_ring"}] -[06/21/24 16:00:12] INFO Subtask cb06889205334ec9afd7e97f7f231ab5 - Response: Output of "WebScraper.get_content" was - stored in memory with memory_name "TaskMemory" and - artifact_namespace - "7e48bcff0da94ad3b06aa4e173f8f37b" -[06/21/24 16:00:17] INFO Subtask 56102d42475d413299ce52a0230506b7 - Thought: Now that the webpage content is stored in - memory, I need to query this memory to find the - information about how many copies of Elden Ring - have been sold. - Actions: [{"tag": "query_sales", "name": - "TaskMemoryClient", "path": "query", "input": - {"values": {"memory_name": "TaskMemory", - "artifact_namespace": - "7e48bcff0da94ad3b06aa4e173f8f37b", "query": "How - many copies of Elden Ring have been sold?"}}}] -[06/21/24 16:00:19] INFO Subtask 56102d42475d413299ce52a0230506b7 - Response: Output of "TaskMemoryClient.query" was - stored in memory with memory_name "TaskMemory" and - artifact_namespace - "9ecf4d7b7d0c46149dfc46ba236f178e" -[06/21/24 16:00:25] INFO Subtask ed2921791dcf46b68c9d8d2f8dbeddbd - Thought: Now that I have the sales information - stored in memory, I need to save this information - to a file. - Actions: [{"tag": "save_sales_info", "name": - "FileManager", "path": - "save_memory_artifacts_to_disk", "input": - {"values": {"dir_name": "sales_info", "file_name": - "elden_ring_sales.txt", "memory_name": - "TaskMemory", "artifact_namespace": - "9ecf4d7b7d0c46149dfc46ba236f178e"}}}] - INFO Subtask ed2921791dcf46b68c9d8d2f8dbeddbd - Response: Successfully saved memory artifacts to - disk -[06/21/24 16:00:27] INFO ToolkitTask 17f30ac14701490c8ef71508f420ea9f - Output: The information about how many copies of - Elden Ring have been sold has been successfully - saved to the file "elden_ring_sales.txt" in the - "sales_info" directory. +[08/12/24 14:55:21] INFO ToolkitTask 329b1abc760e4d30bbf23e349451d930 + Input: Use this page https://en.wikipedia.org/wiki/Elden_Ring to find how many copies of Elden Ring have been sold, and then save the result to + a file. +[08/12/24 14:55:23] INFO Subtask 26205b5623174424b618abafd886c4d8 + Actions: [ + { + "tag": "call_xMK0IyFZFbjlTapK7AA6kbNq", + "name": "WebScraper", + "path": "get_content", + "input": { + "values": { + "url": "https://en.wikipedia.org/wiki/Elden_Ring" + } + } + } + ] +[08/12/24 14:55:28] INFO Subtask 26205b5623174424b618abafd886c4d8 + Response: Output of "WebScraper.get_content" was stored in memory with memory_name "TaskMemory" and artifact_namespace + "44b8f230645148d0b8d44354c0f2df5b" +[08/12/24 14:55:31] INFO Subtask d8b4cf297a0d4d9db04e4f8e63b746c8 + Actions: [ + { + "tag": "call_Oiqq6oI20yqmdNrH9Mawb2fS", + "name": "QueryClient", + "path": "search", + "input": { + "values": { + "query": "copies sold", + "content": { + "memory_name": "TaskMemory", + "artifact_namespace": "44b8f230645148d0b8d44354c0f2df5b" + } + } + } + } + ] +[08/12/24 14:55:34] INFO Subtask d8b4cf297a0d4d9db04e4f8e63b746c8 + Response: Output of "QueryClient.search" was stored in memory with memory_name "TaskMemory" and artifact_namespace + "fd828ddd629e4974a7837f9dfde65954" +[08/12/24 14:55:38] INFO Subtask 7aafcb3fb0d845858e2fcf9b8dc8a7ec + Actions: [ + { + "tag": "call_nV1DIPAEhUEAVMCjXND0pKoS", + "name": "FileManager", + "path": "save_memory_artifacts_to_disk", + "input": { + "values": { + "dir_name": "results", + "file_name": "elden_ring_sales.txt", + "memory_name": "TaskMemory", + "artifact_namespace": "fd828ddd629e4974a7837f9dfde65954" + } + } + } + ] + INFO Subtask 7aafcb3fb0d845858e2fcf9b8dc8a7ec + Response: Successfully saved memory artifacts to disk +[08/12/24 14:55:40] INFO ToolkitTask 329b1abc760e4d30bbf23e349451d930 + Output: Successfully saved the number of copies sold of Elden Ring to a file named "elden_ring_sales.txt" in the "results" directory. ``` ## Tools That Can Read From Task Memory @@ -229,11 +273,10 @@ As seen in the previous example, certain Tools are designed to read directly fro Today, these include: -- [TaskMemoryClient](../../griptape-tools/official-tools/task-memory-client.md) +- [PromptSummaryClient](../../griptape-tools/official-tools/prompt-summary-client.md) +- [ExtractionClient](../../griptape-tools/official-tools/extraction-client.md) +- [RagClient](../../griptape-tools/official-tools/rag-client.md) - [FileManager](../../griptape-tools/official-tools/file-manager.md) -- [AwsS3Client](../../griptape-tools/official-tools/aws-s3-client.md) -- [GoogleDriveClient](../../griptape-tools/official-tools/google-drive-client.md) -- [GoogleDocsClient](../../griptape-tools/official-tools/google-docs-client.md) ## Task Memory Considerations diff --git a/docs/griptape-framework/structures/tasks.md b/docs/griptape-framework/structures/tasks.md index a57c1e604..cef7f653b 100644 --- a/docs/griptape-framework/structures/tasks.md +++ b/docs/griptape-framework/structures/tasks.md @@ -99,39 +99,63 @@ This Task takes in one or more Tools which the LLM will decide to use through Ch ``` ``` -[09/08/23 11:14:55] INFO ToolkitTask 22af656c6ad643e188fe80f9378dfff9 +[08/12/24 15:16:30] INFO ToolkitTask f5b44fe1dadc4e6688053df71d97e0de Input: Load https://www.griptape.ai, summarize it, and store it in a file called griptape.txt -[09/08/23 11:15:02] INFO Subtask 7a6356470e6a4b08b61edc5591b37f0c - Thought: The first step is to load the webpage using the WebScraper tool's get_content activity. - - Action: {"name": "WebScraper", "path": "get_content", "input": {"values": {"url": - "https://www.griptape.ai"}}} -[09/08/23 11:15:03] INFO Subtask 7a6356470e6a4b08b61edc5591b37f0c - Response: Output of "WebScraper.get_content" was stored in memory with memory_name "TaskMemory" and - artifact_namespace "2b50373849d140f698ba8071066437ee" -[09/08/23 11:15:11] INFO Subtask a22a7e4ebf594b4b895fcbe8a95c1dd3 - Thought: Now that the webpage content is stored in memory, I can use the TaskMemory tool's summarize activity - to summarize it. - Action: {"name": "TaskMemoryClient", "path": "summarize", "input": {"values": {"memory_name": "TaskMemory", "artifact_namespace": "2b50373849d140f698ba8071066437ee"}}} -[09/08/23 11:15:15] INFO Subtask a22a7e4ebf594b4b895fcbe8a95c1dd3 - Response: Griptape is an open source framework that allows developers to build and deploy AI applications - using large language models (LLMs). It provides the ability to create conversational and event-driven apps that - can access and manipulate data securely. Griptape enforces structures like sequential pipelines and DAG-based - workflows for predictability, while also allowing for creativity by safely prompting LLMs with external APIs and - data stores. The framework can be used to create AI systems that operate across both dimensions. Griptape Cloud - is a managed platform for deploying and managing AI apps, and it offers features like scheduling and connecting - to data stores and APIs. -[09/08/23 11:15:27] INFO Subtask 7afb3d44d0114b7f8ef2dac4314a8e90 - Thought: Now that I have the summary, I can use the FileManager tool's save_file_to_disk activity to store the - summary in a file named griptape.txt. - Action: {"name": "FileManager", "path": "save_file_to_disk", "input": {"values": - {"memory_name": "TaskMemory", "artifact_namespace": "2b50373849d140f698ba8071066437ee", "path": - "griptape.txt"}}} - INFO Subtask 7afb3d44d0114b7f8ef2dac4314a8e90 - Response: saved successfully -[09/08/23 11:15:31] INFO ToolkitTask 22af656c6ad643e188fe80f9378dfff9 - Output: The summary of the webpage https://www.griptape.ai has been successfully stored in a file named - griptape.txt. +[08/12/24 15:16:32] INFO Subtask a4483eddfbe84129b0f4c04ef0f5d695 + Actions: [ + { + "tag": "call_AFeOL9MGhZ4mPFCULcBEm4NQ", + "name": "WebScraper", + "path": "get_content", + "input": { + "values": { + "url": "https://www.griptape.ai" + } + } + } + ] + INFO Subtask a4483eddfbe84129b0f4c04ef0f5d695 + Response: Output of "WebScraper.get_content" was stored in memory with memory_name "TaskMemory" and artifact_namespace + "c6a6bcfc16f34481a068108aeaa6838e" +[08/12/24 15:16:33] INFO Subtask ee5f11666ded4dc39b94e4c59d18fbc7 + Actions: [ + { + "tag": "call_aT7DX0YSQPmOcnumWXrGoMNt", + "name": "PromptSummaryClient", + "path": "summarize", + "input": { + "values": { + "summary": { + "memory_name": "TaskMemory", + "artifact_namespace": "c6a6bcfc16f34481a068108aeaa6838e" + } + } + } + } + ] +[08/12/24 15:16:37] INFO Subtask ee5f11666ded4dc39b94e4c59d18fbc7 + Response: Output of "PromptSummaryClient.summarize" was stored in memory with memory_name "TaskMemory" and artifact_namespace + "669d29a704444176be93d09d014298df" +[08/12/24 15:16:38] INFO Subtask d9b2dd9f96d841f49f5d460e33905183 + Actions: [ + { + "tag": "call_QgMk1M1UuD6DAnxjfQz1MH6X", + "name": "FileManager", + "path": "save_memory_artifacts_to_disk", + "input": { + "values": { + "dir_name": ".", + "file_name": "griptape.txt", + "memory_name": "TaskMemory", + "artifact_namespace": "669d29a704444176be93d09d014298df" + } + } + } + ] + INFO Subtask d9b2dd9f96d841f49f5d460e33905183 + Response: Successfully saved memory artifacts to disk +[08/12/24 15:16:39] INFO ToolkitTask f5b44fe1dadc4e6688053df71d97e0de + Output: The content from https://www.griptape.ai has been summarized and stored in a file called `griptape.txt`. ``` ## Tool Task diff --git a/docs/griptape-framework/tools/index.md b/docs/griptape-framework/tools/index.md index c974b9658..727b8f78d 100644 --- a/docs/griptape-framework/tools/index.md +++ b/docs/griptape-framework/tools/index.md @@ -19,37 +19,72 @@ Here is an example of a Pipeline using Tools: ``` ``` -[09/08/23 10:53:56] INFO ToolkitTask 979d99f68766423ea05b367e951281bc - Input: Load https://www.griptape.ai, summarize it, and store it in a file called griptape.txt -[09/08/23 10:54:02] INFO Subtask 97bd154a71e14a1699f8152e50490a71 - Thought: The first step is to load the content of the webpage. I can use the WebScraper tool with the get_content - activity for this. - - Action: {"name": "WebScraper", "path": "get_content", "input": {"values": {"url": - "https://www.griptape.ai"}}} -[09/08/23 10:54:03] INFO Subtask 97bd154a71e14a1699f8152e50490a71 - Response: Output of "WebScraper.get_content" was stored in memory with memory_name "TaskMemory" and - artifact_namespace "9eb6f5828cf64356bf323f11d28be27e" -[09/08/23 10:54:09] INFO Subtask 7ee08458ce154e3d970711b7d3ed79ba - Thought: Now that the webpage content is stored in memory, I can use the TaskMemory tool with the summarize - activity to summarize the content. - Action: {"name": "TaskMemoryClient", "path": "summarize", "input": {"values": {"memory_name": "TaskMemory", "artifact_namespace": "9eb6f5828cf64356bf323f11d28be27e"}}} -[09/08/23 10:54:12] INFO Subtask 7ee08458ce154e3d970711b7d3ed79ba - Response: Griptape is an open source framework that allows developers to build and deploy AI applications - using large language models (LLMs). It provides the ability to create conversational and event-driven apps that - can access and manipulate data securely. Griptape enforces structures like sequential pipelines and workflows for - predictability, while also allowing for creativity by safely prompting LLMs with external APIs and data stores. - The framework can be used to create AI systems that operate across both predictability and creativity dimensions. - Griptape Cloud is a managed platform for deploying and managing AI apps. -[09/08/23 10:54:24] INFO Subtask a024949a9a134f058f2e6b7c379c8713 - Thought: Now that I have the summary, I can store it in a file called griptape.txt. I can use the FileManager - tool with the save_file_to_disk activity for this. - Action: {"name": "FileManager", "path": "save_file_to_disk", "input": {"values": - {"memory_name": "TaskMemory", "artifact_namespace": "9eb6f5828cf64356bf323f11d28be27e", "path": - "griptape.txt"}}} - INFO Subtask a024949a9a134f058f2e6b7c379c8713 - Response: saved successfully -[09/08/23 10:54:27] INFO ToolkitTask 979d99f68766423ea05b367e951281bc - Output: The summary of the webpage https://www.griptape.ai has been successfully stored in a file called - griptape.txt. +[08/12/24 15:18:19] INFO ToolkitTask 48ac0486e5374e1ea53e8d2b955e511f + Input: Load https://www.griptape.ai, summarize it, and store it in griptape.txt +[08/12/24 15:18:20] INFO Subtask 3b8365c077ae4a7e94087bfeff7a858c + Actions: [ + { + "tag": "call_P6vaURTXfiYBJZolTkUSRHRc", + "name": "WebScraper", + "path": "get_content", + "input": { + "values": { + "url": "https://www.griptape.ai" + } + } + } + ] + INFO Subtask 3b8365c077ae4a7e94087bfeff7a858c + Response: Output of "WebScraper.get_content" was stored in memory with memory_name "TaskMemory" and artifact_namespace + "301e546f4450489ea4680645297092a2" +[08/12/24 15:18:21] INFO Subtask 930e9ca52e4140a48cce1e47368d45be + Actions: [ + { + "tag": "call_0VOTEvinRer7rG4oEirBYcow", + "name": "PromptSummaryClient", + "path": "summarize", + "input": { + "values": { + "summary": { + "memory_name": "TaskMemory", + "artifact_namespace": "301e546f4450489ea4680645297092a2" + } + } + } + } + ] +[08/12/24 15:18:24] INFO Subtask 930e9ca52e4140a48cce1e47368d45be + Response: Griptape offers a comprehensive solution for building, deploying, and scaling AI applications in the cloud. It provides developers + with a framework and cloud services to create retrieval-driven AI-powered applications without needing extensive knowledge of AI or prompt + engineering. The Griptape Framework allows developers to build business logic using Python, ensuring better security, performance, and + cost-efficiency. Griptape Cloud handles infrastructure management, enabling seamless deployment and scaling of applications. Key features + include automated data preparation (ETL), retrieval as a service (RAG), and a structure runtime (RUN) for building AI agents, pipelines, and + workflows. Griptape also offers solutions for custom projects, turnkey SaaS offerings, and finished applications. +[08/12/24 15:18:27] INFO Subtask d0f22504f576401f8d7e8ea78270a376 + Actions: [ + { + "tag": "call_zdUe2vdR0DCfR6LKcxjI6ayb", + "name": "FileManager", + "path": "save_content_to_file", + "input": { + "values": { + "path": "griptape.txt", + "content": "Griptape offers a comprehensive solution for building, deploying, and scaling AI applications in the cloud. It provides + developers with a framework and cloud services to create retrieval-driven AI-powered applications without needing extensive knowledge of AI or + prompt engineering. The Griptape Framework allows developers to build business logic using Python, ensuring better security, performance, and + cost-efficiency. Griptape Cloud handles infrastructure management, enabling seamless deployment and scaling of applications. Key features + include automated data preparation (ETL), retrieval as a service (RAG), and a structure runtime (RUN) for building AI agents, pipelines, and + workflows. Griptape also offers solutions for custom projects, turnkey SaaS offerings, and finished applications." + } + } + } + ] + INFO Subtask d0f22504f576401f8d7e8ea78270a376 + Response: Successfully saved file +[08/12/24 15:18:28] INFO ToolkitTask 48ac0486e5374e1ea53e8d2b955e511f + Output: The content from https://www.griptape.ai has been summarized and stored in griptape.txt. + INFO PromptTask 4a9c59b1c06d4c549373d243a12f1285 + Input: Say the following in spanish: The content from https://www.griptape.ai has been summarized and stored in griptape.txt. + INFO PromptTask 4a9c59b1c06d4c549373d243a12f1285 + Output: El contenido de https://www.griptape.ai ha sido resumido y almacenado en griptape.txt. ``` diff --git a/docs/griptape-framework/tools/src/index_1.py b/docs/griptape-framework/tools/src/index_1.py index 52d8ac8e7..0366a314e 100644 --- a/docs/griptape-framework/tools/src/index_1.py +++ b/docs/griptape-framework/tools/src/index_1.py @@ -1,13 +1,17 @@ from griptape.structures import Pipeline from griptape.tasks import ToolkitTask -from griptape.tools import FileManager, TaskMemoryClient, WebScraper +from griptape.tools import FileManager, PromptSummaryClient, WebScraper pipeline = Pipeline() pipeline.add_tasks( ToolkitTask( "Load https://www.griptape.ai, summarize it, and store it in a file called griptape.txt", - tools=[WebScraper(off_prompt=True), FileManager(off_prompt=True), TaskMemoryClient(off_prompt=False)], + tools=[ + WebScraper(off_prompt=True), + FileManager(off_prompt=True), + PromptSummaryClient(off_prompt=False), + ], ), ) diff --git a/docs/griptape-tools/official-tools/aws-iam-client.md b/docs/griptape-tools/official-tools/aws-iam-client.md index 22b4115a3..a9b8a816e 100644 --- a/docs/griptape-tools/official-tools/aws-iam-client.md +++ b/docs/griptape-tools/official-tools/aws-iam-client.md @@ -6,50 +6,59 @@ This tool enables LLMs to make AWS IAM API requests. --8<-- "docs/griptape-tools/official-tools/src/aws_iam_client_1.py" ``` ``` -[09/11/23 16:45:45] INFO Task 890fcf77fb074c9490d5c91563e0c995 - Input: List all my IAM users -[09/11/23 16:45:51] INFO Subtask f2f0809ee10d4538972ed01fdd6a2fb8 - Thought: To list all IAM users, I can use the - AwsIamClient tool with the list_users activity. - This activity does not require any input. - - Action: {"name": "AwsIamClient", - "path": "list_users"} -[09/11/23 16:45:52] INFO Subtask f2f0809ee10d4538972ed01fdd6a2fb8 - Response: Output of "AwsIamClient.list_users" - was stored in memory with memory_name - "TaskMemory" and artifact_namespace - "51d22a018a434904a5da3bb8d4f763f7" -[09/11/23 16:45:59] INFO Subtask 8e0e918571544eeebf46de898466c48c - Thought: The output of the list_users activity is - stored in memory. I can retrieve this information - using the TaskMemory tool with the summarize - activity. - Action: {"name": "TaskMemoryClient", "path": - "summarize", "input": {"values": {"memory_name": - "TaskMemory", "artifact_namespace": - "51d22a018a434904a5da3bb8d4f763f7"}}} -[09/11/23 16:46:03] INFO Subtask 8e0e918571544eeebf46de898466c48c - Response: The text provides information about - two different users in an AWS IAM system. The first - user is named "example-user-1" and has a - user ID of "AIDASHBEHWJLQV2IOYDHM". The second user - is named "example-user-2" and - has a user ID of "AIDASHBEHWJLWHVS76C6X". Both - users have a path of "/", and their ARNs (Amazon - Resource Names) indicate their location in the IAM - system. The first user was created on July 18, - 2023, at 20:29:27 UTC, while the second user was - created on August 29, 2023, at 20:56:37 UTC. -[09/11/23 16:46:13] INFO Task 890fcf77fb074c9490d5c91563e0c995 - Output: There are two IAM users in your AWS - account: - - 1. User "example-user-1" with user ID - "AIDASHBEHWJLQV2IOYDHM", created on July 18, 2023, - at 20:29:27 UTC. - 2. User "example-user-2" with - user ID "AIDASHBEHWJLWHVS76C6X", created on August - 29, 2023, at 20:56:37 UTC. +[08/12/24 14:56:59] INFO ToolkitTask 12345abcd67890efghijk1112131415 + Input: List all my IAM users +[08/12/24 14:57:00] INFO Subtask 54321dcba09876fedcba1234567890ab + Actions: [ + { + "tag": "call_OxhQ9ITNIFq0WjkSnOCYAx8h", + "name": "AwsIamClient", + "path": "list_users", + "input": { + "values": {} + } + } + ] + INFO Subtask 54321dcba09876fedcba1234567890ab + Response: {'Path': '/', 'UserName': 'dummy-user-1', 'UserId': 'AIDAAAAAA1111AAAAAA1111', 'Arn': + 'arn:aws:iam::123456789012:user/dummy-user-1', 'CreateDate': datetime.datetime(2024, 8, 7, 15, 8, 7, tzinfo=tzutc())} + {'Path': '/', 'UserName': 'dummy-user-2', 'UserId': 'AIDBBBBBB2222BBBBBB2222', 'Arn': + 'arn:aws:iam::123456789012:user/dummy-user-2', 'CreateDate': datetime.datetime(2023, 7, 18, 20, 29, 27, tzinfo=tzutc())} + + {'Path': '/', 'UserName': 'dummy-user-3', 'UserId': 'AIDCCCCCC3333CCCCCC3333', 'Arn': + 'arn:aws:iam::123456789012:user/dummy-user-3', 'CreateDate': datetime.datetime(2024, 7, 15, 19, 39, 41, tzinfo=tzutc())} + + {'Path': '/', 'UserName': 'dummy-user-4', 'UserId': 'AIDDDDDDD4444DDDDDD4444', 'Arn': + 'arn:aws:iam::123456789012:user/dummy-user-4', 'CreateDate': datetime.datetime(2024, 8, 2, 19, 28, 31, tzinfo=tzutc())} + + {'Path': '/', 'UserName': 'dummy-user-5', 'UserId': 'AIDEEEEE5555EEEEE5555', 'Arn': + 'arn:aws:iam::123456789012:user/dummy-user-5', 'CreateDate': datetime.datetime(2023, 8, 29, 20, 56, 37, tzinfo=tzutc())} +[08/12/24 14:57:08] INFO ToolkitTask 12345abcd67890efghijk1112131415 + Output: Here are all your IAM users: + + 1. **Username:** dummy-user-1 + - **UserId:** AIDAAAAAA1111AAAAAA1111 + - **Arn:** arn:aws:iam::123456789012:user/dummy-user-1 + - **CreateDate:** 2024-08-07 + + 2. **Username:** dummy-user-2 + - **UserId:** AIDBBBBBB2222BBBBBB2222 + - **Arn:** arn:aws:iam::123456789012:user/dummy-user-2 + - **CreateDate:** 2023-07-18 + + 3. **Username:** dummy-user-3 + - **UserId:** AIDCCCCCC3333CCCCCC3333 + - **Arn:** arn:aws:iam::123456789012:user/dummy-user-3 + - **CreateDate:** 2024-07-15 + + 4. **Username:** dummy-user-4 + - **UserId:** AIDDDDDDD4444DDDDDD4444 + - **Arn:** arn:aws:iam::123456789012:user/dummy-user-4 + - **CreateDate:** 2024-08-02 + + 5. **Username:** dummy-user-5 + - **UserId:** AIDEEEEE5555EEEEE5555 + - **Arn:** arn:aws:iam::123456789012:user/dummy-user-5 + - **CreateDate:** 2023-08-29 ``` diff --git a/docs/griptape-tools/official-tools/aws-s3-client.md b/docs/griptape-tools/official-tools/aws-s3-client.md index 70ca79a20..12b292887 100644 --- a/docs/griptape-tools/official-tools/aws-s3-client.md +++ b/docs/griptape-tools/official-tools/aws-s3-client.md @@ -6,45 +6,36 @@ This tool enables LLMs to make AWS S3 API requests. --8<-- "docs/griptape-tools/official-tools/src/aws_s3_client_1.py" ``` ``` -[09/11/23 16:49:35] INFO Task 8bf7538e217a4b5a8472829f5eee75b9 - Input: List all my S3 buckets. -[09/11/23 16:49:41] INFO Subtask 9fc44f5c8e73447ba737283cb2ef7f5d - Thought: To list all S3 buckets, I can use the - "list_s3_buckets" activity of the "AwsS3Client" - tool. This activity doesn't require any input. - - Action: {"name": "AwsS3Client", - "path": "list_s3_buckets"} -[09/11/23 16:49:42] INFO Subtask 9fc44f5c8e73447ba737283cb2ef7f5d - Response: Output of - "AwsS3Client.list_s3_buckets" was stored in memory - with memory_name "TaskMemory" and - artifact_namespace - "f2592085fd4a430286a46770ea508cc9" -[09/11/23 16:49:50] INFO Subtask 0e9bb639a432431a92ef40a8c085ca0f - Thought: The output of the "list_s3_buckets" - activity is stored in memory. I can retrieve this - information using the "summarize" activity of the - "TaskMemory" tool. - Action: {"name": "TaskMemoryClient", "path": - "summarize", "input": {"values": {"memory_name": - "TaskMemory", "artifact_namespace": - "f2592085fd4a430286a46770ea508cc9"}}} -[09/11/23 16:49:52] INFO Subtask 0e9bb639a432431a92ef40a8c085ca0f - Response: The text consists of multiple - dictionaries, each containing a 'Name' and - 'CreationDate' key-value pair. The 'Name' - represents the name of a resource or bucket, while - the 'CreationDate' represents the date and time - when the resource or bucket was created. -[09/11/23 16:50:03] INFO Task 8bf7538e217a4b5a8472829f5eee75b9 - Output: The names of your S3 buckets are as - follows: - 1. Bucket Name: 'example-bucket-1', Creation Date: - '2022-01-01T00:00:00Z' - 2. Bucket Name: 'example-bucket-2', Creation Date: - '2022-01-02T00:00:00Z' - 3. Bucket Name: 'example-bucket-3', Creation Date: - '2022-01-03T00:00:00Z' - Please note that the creation dates are in UTC. +[08/12/24 14:51:36] INFO ToolkitTask bfc329ebc7d34497b429ab0d18ff7e7b + Input: List all my S3 buckets. +[08/12/24 14:51:37] INFO Subtask dfd07f9e204c4a3d8f55ca3eb9d37ec5 + Actions: [ + { + "tag": "call_pZQ05Zmm6lSbEcvPWt4XEDj6", + "name": "AwsS3Client", + "path": "list_s3_buckets", + "input": { + "values": {} + } + } + ] + INFO Subtask dfd07f9e204c4a3d8f55ca3eb9d37ec5 + Response: {'Name': 'dummy-bucket-1', 'CreationDate': datetime.datetime(2023, 9, 14, 15, 41, 46, + tzinfo=tzutc())} + + {'Name': 'dummy-bucket-2', 'CreationDate': datetime.datetime(2023, 9, 14, 15, 40, 33, tzinfo=tzutc())} + + {'Name': 'dummy-bucket-3', 'CreationDate': datetime.datetime(2023, 6, 23, 20, 19, 53, tzinfo=tzutc())} + + {'Name': 'dummy-bucket-4', 'CreationDate': datetime.datetime(2023, 8, 19, 17, 17, 13, tzinfo=tzutc())} + + {'Name': 'dummy-bucket-5', 'CreationDate': datetime.datetime(2024, 2, 15, 23, 17, 21, tzinfo=tzutc())} +[08/12/24 14:51:43] INFO ToolkitTask bfc329ebc7d34497b429ab0d18ff7e7b + Output: Here are all your S3 buckets: + + 1. dummy-bucket-1 (Created on 2023-09-14) + 2. dummy-bucket-2 (Created on 2023-09-14) + 3. dummy-bucket-3 (Created on 2023-06-23) + 4. dummy-bucket-4 (Created on 2023-08-19) + 5. dummy-bucket-5 (Created on 2024-02-15) ``` diff --git a/docs/griptape-tools/official-tools/computer.md b/docs/griptape-tools/official-tools/computer.md index 121224b20..b8d2489f0 100644 --- a/docs/griptape-tools/official-tools/computer.md +++ b/docs/griptape-tools/official-tools/computer.md @@ -8,80 +8,40 @@ You can specify a local working directory and environment variables during tool --8<-- "docs/griptape-tools/official-tools/src/computer_1.py" ``` ``` -[09/11/23 16:24:15] INFO Task d08009ee983c4286ba10f83bcf3080e6 - Input: Run this shell command for me: touch - my_new_file.txt -[09/11/23 16:24:21] INFO Subtask 1ec0f9ea528e44b89eb9d41da0e00856 - Thought: The user wants to create a new file named - "my_new_file.txt". I can do this by executing the - shell command "touch my_new_file.txt" using the - Computer tool with the execute_command activity. - - Action: {"name": "Computer", - "path": "execute_command", "input": {"values": - {"command": "touch my_new_file.txt"}}} -[09/11/23 16:24:22] INFO Subtask 1ec0f9ea528e44b89eb9d41da0e00856 - Response: Output of "Computer.execute_command" - was stored in memory with memory_name - "TaskMemory" and artifact_namespace - "54ffbc84a37a497480cab6ab4f904e7e" -[09/11/23 16:24:28] INFO Subtask f3e1b1d09b4c46babda27342680aa770 - Thought: The command has been executed and the - output is stored in memory. However, the "touch" - command does not produce any output when it - successfully creates a file. To confirm that the - file was created, I can list the files in the - current directory using the "ls" command. - - Action: {"name": "Computer", - "path": "execute_command", "input": {"values": - {"command": "ls"}}} -[09/11/23 16:24:29] INFO Subtask f3e1b1d09b4c46babda27342680aa770 - Response: Output of "Computer.execute_command" - was stored in memory with memory_name - "TaskMemory" and artifact_namespace - "82bc4937564e4901b7fc51fced45b643" -[09/11/23 16:24:40] INFO Subtask 36f0bbfdd7974e6cb52766ba21dc64e0 - Thought: The output of the "ls" command is stored - in memory. I need to retrieve this output to check - if "my_new_file.txt" is listed, which would confirm - that the file was created successfully. - Action: {"name": "TaskMemoryClient", "path": - "query", "input": {"values": {"memory_name": - "TaskMemory", "artifact_namespace": - "82bc4937564e4901b7fc51fced45b643", "query": "Is - my_new_file.txt in the list of files?"}}} -[09/11/23 16:24:41] INFO Subtask 36f0bbfdd7974e6cb52766ba21dc64e0 - Response: Yes. -[09/11/23 16:24:42] INFO Task d08009ee983c4286ba10f83bcf3080e6 - Output: The file "my_new_file.txt" has been - successfully created. - INFO Task d08009ee983c4286ba10f83bcf3080e6 - Input: Run this shell command for me: echo 'This is - the content of the file.' > my_new_file.txt -[09/11/23 16:24:53] INFO Subtask a0a3fb162d6d4f3398a98c6d3604a491 - Thought: The user wants to write the text 'This is - the content of the file.' into the file - 'my_new_file.txt'. I can achieve this by using the - 'execute_command' activity of the 'Computer' tool. - - Action: {"name": "Computer", - "path": "execute_command", "input": {"values": - {"command": "echo 'This is the content of the - file.' > my_new_file.txt"}}} - INFO Subtask a0a3fb162d6d4f3398a98c6d3604a491 - Response: Output of "Computer.execute_command" - was stored in memory with memory_name - "TaskMemory" and artifact_namespace - "ec20f2e7ec674e0286c8d1f05d528957" -[09/11/23 16:25:00] INFO Task d08009ee983c4286ba10f83bcf3080e6 - Output: The text 'This is the content of the file.' - has been successfully written into - 'my_new_file.txt'. - INFO Task d08009ee983c4286ba10f83bcf3080e6 - Input: Run this shell command for me: cat - my_new_file.txt -[09/11/23 16:25:10] INFO Task d08009ee983c4286ba10f83bcf3080e6 - Output: The content of the file 'my_new_file.txt' - is: 'This is the content of the file.' +❮ poetry run python src/docs/task-memory.py +[08/12/24 15:13:56] INFO ToolkitTask 203ee958d1934811afe0bb86fb246e86 + Input: Make 2 files and then list the files in the current directory +[08/12/24 15:13:58] INFO Subtask eb4e843b6f37498f9f0e85ada68114ac + Actions: [ + { + "tag": "call_S17vPQsMCqWY1Lt5x8NtDnTK", + "name": "Computer", + "path": "execute_command", + "input": { + "values": { + "command": "touch file1.txt file2.txt" + } + } + } + ] + INFO Subtask eb4e843b6f37498f9f0e85ada68114ac + Response: Tool returned an empty value +[08/12/24 15:13:59] INFO Subtask 032770e7697d44f6a0c8559bfea60420 + Actions: [ + { + "tag": "call_n61SVDYUGWTt681BaDSaHgt1", + "name": "Computer", + "path": "execute_command", + "input": { + "values": { + "command": "ls" + } + } + } + ] + INFO Subtask 032770e7697d44f6a0c8559bfea60420 + Response: file1.txt + file2.txt +[08/12/24 15:14:00] INFO ToolkitTask 203ee958d1934811afe0bb86fb246e86 + Output: file1.txt, file2.txt ``` diff --git a/docs/griptape-tools/official-tools/sql-client.md b/docs/griptape-tools/official-tools/sql-client.md index 1d0d7abb0..4d8a27fb1 100644 --- a/docs/griptape-tools/official-tools/sql-client.md +++ b/docs/griptape-tools/official-tools/sql-client.md @@ -6,38 +6,51 @@ This tool enables LLMs to execute SQL statements via [SQLAlchemy](https://www.sq --8<-- "docs/griptape-tools/official-tools/src/sql_client_1.py" ``` ``` -[09/11/23 17:02:55] INFO Task d8331f8705b64b4b9d9a88137ed73f3f - Input: SELECT * FROM people; -[09/11/23 17:03:02] INFO Subtask 46c2f8926ce9469e9ca6b1b3364e3e41 - Thought: The user wants to retrieve all records - from the 'people' table. I can use the SqlClient - tool to execute this query. - - Action: {"name": "SqlClient", - "path": "execute_query", "input": {"values": - {"sql_query": "SELECT * FROM people;"}}} -[09/11/23 17:03:03] INFO Subtask 46c2f8926ce9469e9ca6b1b3364e3e41 - Response: Output of "SqlClient.execute_query" - was stored in memory with memory_name - "TaskMemory" and artifact_namespace - "217715ba3e444e4985bee223df5716a8" -[09/11/23 17:03:11] INFO Subtask e51f05449647482caa3051378ab5cb8c - Thought: The output of the SQL query has been - stored in memory. I can retrieve this data using - the TaskMemory's 'summarize' activity. - Action: {"name": "TaskMemoryClient", "path": - "summarize", "input": {"values": {"memory_name": - "TaskMemory", "artifact_namespace": - "217715ba3e444e4985bee223df5716a8"}}} -[09/11/23 17:03:12] INFO Subtask e51f05449647482caa3051378ab5cb8c - Response: The text includes a list of employees - with their respective IDs, names, positions. There - are two employees named Tanya Cooley who are both - managers, and two employees named John Doe who are - both coders. -[09/11/23 17:03:17] INFO Task d8331f8705b64b4b9d9a88137ed73f3f - Output: The 'people' table contains records of - several employees. Notably, there are two employees - named Tanya Cooley who are both managers, and two - employees named John Doe who are both coders. +[08/12/24 14:59:31] INFO ToolkitTask e302f7315d1a4f939e0125103ff4f09f + Input: SELECT * FROM people; +[08/12/24 14:59:34] INFO Subtask 809d1a281b85447f90706d431b77b845 + Actions: [ + { + "tag": "call_dCxHWwPwgmDvDKVd3QeOzyuT", + "name": "SqlClient", + "path": "execute_query", + "input": { + "values": { + "sql_query": "SELECT * FROM people" + } + } + } + ] +[08/12/24 14:59:35] INFO Subtask 809d1a281b85447f90706d431b77b845 + Response: 1,Lee,Andrews,"Engineer, electrical" + + 2,Michael,Woods,"Therapist, art" + + 3,Joshua,Allen,"Therapist, sports" + + 4,Eric,Foster,English as a second language teacher + + 5,John,Daniels,Printmaker + + 6,Matthew,Barton,Podiatrist + + 7,Audrey,Wilson,IT technical support officer + + 8,Leah,Knox,"Social research officer, government" + + 9,David,Macdonald,Public relations account executive + + 10,Erica,Ramos,"Accountant, chartered public finance" +[08/12/24 14:59:43] INFO ToolkitTask e302f7315d1a4f939e0125103ff4f09f + Output: + 1. Lee Andrews - Engineer, electrical + 2. Michael Woods - Therapist, art + 3. Joshua Allen - Therapist, sports + 4. Eric Foster - English as a second language teacher + 5. John Daniels - Printmaker + 6. Matthew Barton - Podiatrist + 7. Audrey Wilson - IT technical support officer + 8. Leah Knox - Social research officer, government + 9. David Macdonald - Public relations account executive + 10. Erica Ramos - Accountant, chartered public finance ``` diff --git a/docs/griptape-tools/official-tools/src/aws_s3_client_1.py b/docs/griptape-tools/official-tools/src/aws_s3_client_1.py index e1ba42525..c24d283b0 100644 --- a/docs/griptape-tools/official-tools/src/aws_s3_client_1.py +++ b/docs/griptape-tools/official-tools/src/aws_s3_client_1.py @@ -1,13 +1,13 @@ import boto3 from griptape.structures import Agent -from griptape.tools import AwsS3Client, TaskMemoryClient +from griptape.tools import AwsS3Client # Initialize the AWS S3 client aws_s3_client = AwsS3Client(session=boto3.Session(), off_prompt=True) # Create an agent with the AWS S3 client tool -agent = Agent(tools=[aws_s3_client, TaskMemoryClient(off_prompt=False)]) +agent = Agent(tools=[aws_s3_client]) # Task to list all the AWS S3 buckets agent.run("List all my S3 buckets.") diff --git a/docs/griptape-tools/official-tools/src/computer_1.py b/docs/griptape-tools/official-tools/src/computer_1.py index 7fa22a46b..e7f136532 100644 --- a/docs/griptape-tools/official-tools/src/computer_1.py +++ b/docs/griptape-tools/official-tools/src/computer_1.py @@ -7,13 +7,4 @@ # Create an agent with the Computer tool agent = Agent(tools=[computer]) -# Create a file using the shell command -filename = "my_new_file.txt" -agent.run(f"Run this shell command for me: touch {filename}") - -# Add content to the file using the shell command -content = "This is the content of the file." -agent.run(f"Run this shell command for me: echo '{content}' > {filename}") - -# Output the contents of the file using the shell command -agent.run(f"Run this shell command for me: cat {filename}") +agent.run("Make 2 files and then list the files in the current directory") diff --git a/docs/griptape-tools/official-tools/src/task_memory_client_1.py b/docs/griptape-tools/official-tools/src/task_memory_client_1.py deleted file mode 100644 index e9c2562a1..000000000 --- a/docs/griptape-tools/official-tools/src/task_memory_client_1.py +++ /dev/null @@ -1,4 +0,0 @@ -from griptape.structures import Agent -from griptape.tools import TaskMemoryClient, WebScraper - -Agent(tools=[WebScraper(off_prompt=True), TaskMemoryClient(off_prompt=False)]) diff --git a/docs/griptape-tools/official-tools/src/vector_store_client_1.py b/docs/griptape-tools/official-tools/src/vector_store_client_1.py index df9117960..c4e0c5bd9 100644 --- a/docs/griptape-tools/official-tools/src/vector_store_client_1.py +++ b/docs/griptape-tools/official-tools/src/vector_store_client_1.py @@ -2,7 +2,7 @@ from griptape.drivers import LocalVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader from griptape.structures import Agent -from griptape.tools import TaskMemoryClient, VectorStoreClient +from griptape.tools import PromptSummaryClient, VectorStoreClient vector_store_driver = LocalVectorStoreDriver( embedding_driver=OpenAiEmbeddingDriver(), @@ -20,6 +20,6 @@ off_prompt=True, ) -agent = Agent(tools=[vector_db, TaskMemoryClient(off_prompt=False)]) +agent = Agent(tools=[vector_db, PromptSummaryClient()]) agent.run("what is Griptape?") diff --git a/docs/griptape-tools/official-tools/src/web_scraper_1.py b/docs/griptape-tools/official-tools/src/web_scraper_1.py index 138e8600f..f858d558d 100644 --- a/docs/griptape-tools/official-tools/src/web_scraper_1.py +++ b/docs/griptape-tools/official-tools/src/web_scraper_1.py @@ -1,6 +1,6 @@ from griptape.structures import Agent -from griptape.tools import TaskMemoryClient, WebScraper +from griptape.tools import PromptSummaryClient, WebScraper -agent = Agent(tools=[WebScraper(off_prompt=True), TaskMemoryClient(off_prompt=False)]) +agent = Agent(tools=[WebScraper(off_prompt=True), PromptSummaryClient(off_prompt=False)]) agent.run("Based on https://www.griptape.ai/, tell me what griptape is") diff --git a/docs/griptape-tools/official-tools/task-memory-client.md b/docs/griptape-tools/official-tools/task-memory-client.md deleted file mode 100644 index fa88c85b9..000000000 --- a/docs/griptape-tools/official-tools/task-memory-client.md +++ /dev/null @@ -1,7 +0,0 @@ -# TaskMemoryClient - -This tool enables LLMs to query and summarize task outputs that are stored in short-term tool memory. This tool uniquely requires the user to set the `off_prompt` property explicitly for usability reasons (Griptape doesn't provide the default `True` value). - -```python ---8<-- "docs/griptape-tools/official-tools/src/task_memory_client_1.py" -``` diff --git a/docs/griptape-tools/official-tools/web-scraper.md b/docs/griptape-tools/official-tools/web-scraper.md index 5d8e1fe27..a4ceebfc5 100644 --- a/docs/griptape-tools/official-tools/web-scraper.md +++ b/docs/griptape-tools/official-tools/web-scraper.md @@ -6,62 +6,92 @@ This tool enables LLMs to scrape web pages for full text, summaries, authors, ti --8<-- "docs/griptape-tools/official-tools/src/web_scraper_1.py" ``` ``` -[09/11/23 15:27:39] INFO Task dd9ad12c5c1e4280a6e20d7c116303ed - Input: Based on https://www.griptape.ai/, tell me - what griptape is -[09/11/23 15:27:47] INFO Subtask 4b34be74b06a47ba9cb3a4b62aa35907 - Thought: I need to find out what griptape is based - on the information provided on the website - https://www.griptape.ai/. I can use the WebScraper - tool with the get_content activity to load the - content of the website. - - Action: {"name": "WebScraper", - "path": "get_content", "input": {"values": - {"url": "https://www.griptape.ai/"}}} -[09/11/23 15:27:48] INFO Subtask 4b34be74b06a47ba9cb3a4b62aa35907 - Response: Output of "WebScraper.get_content" was - stored in memory with memory_name "TaskMemory" - and artifact_namespace - "02da5930b8d74f7ca30aecc3760a3318" -[09/11/23 15:27:59] INFO Subtask 5b255e3e98aa401295f77532bc779390 - Thought: The content of the website has been stored - in memory. I can use the TaskMemory tool with - the summarize activity to get a summary of the - content. - Action: {"name": "TaskMemoryClient", "path": - "summarize", "input": {"values": {"memory_name": - "TaskMemory", "artifact_namespace": - "02da5930b8d74f7ca30aecc3760a3318"}}} -[09/11/23 15:28:03] INFO Subtask 5b255e3e98aa401295f77532bc779390 - Response: Griptape is an open source framework - that allows developers to build and deploy AI - applications using large language models (LLMs). It - provides the ability to create conversational and - event-driven apps that can access and manipulate - data securely. Griptape enforces structures like - sequential pipelines and DAG-based workflows for - predictability, while also allowing for creativity - by safely prompting LLMs with external APIs and - data stores. The framework can be used to create AI - systems that operate across both dimensions. - Griptape Cloud is a managed platform for deploying - and managing AI apps, and it offers features like - scheduling and connecting to data stores and APIs. -[09/11/23 15:28:12] INFO Task dd9ad12c5c1e4280a6e20d7c116303ed - Output: Griptape is an open source framework that - enables developers to build and deploy AI - applications using large language models (LLMs). It - allows the creation of conversational and - event-driven apps that can securely access and - manipulate data. Griptape enforces structures like - sequential pipelines and DAG-based workflows for - predictability, while also allowing for creativity - by safely prompting LLMs with external APIs and - data stores. The framework can be used to create AI - systems that operate across both dimensions. - Additionally, Griptape Cloud is a managed platform - for deploying and managing AI apps, offering - features like scheduling and connecting to data - stores and APIs. +[08/12/24 15:32:08] INFO ToolkitTask b14a4305365f4b17a4dcf235f84397e2 + Input: Based on https://www.griptape.ai/, tell me what griptape is +[08/12/24 15:32:10] INFO Subtask bf396977ea634eb28f55388d3f828f5d + Actions: [ + { + "tag": "call_ExEzJDZuBfnsa9pZMSr6mtsS", + "name": "WebScraper", + "path": "get_content", + "input": { + "values": { + "url": "https://www.griptape.ai/" + } + } + } + ] + INFO Subtask bf396977ea634eb28f55388d3f828f5d + Response: Output of "WebScraper.get_content" was stored in memory with memory_name "TaskMemory" and artifact_namespace + "a55c85bf1aa944d5b69bbe8d61382179" +[08/12/24 15:32:11] INFO Subtask 31852039bd274b71bf46feaf22b68112 + Actions: [ + { + "tag": "call_6Dovx2GKE2GLjaYIuwXvBxVn", + "name": "PromptSummaryClient", + "path": "summarize", + "input": { + "values": { + "summary": { + "memory_name": "TaskMemory", + "artifact_namespace": "a55c85bf1aa944d5b69bbe8d61382179" + } + } + } + } + ] +[08/12/24 15:32:15] INFO Subtask 31852039bd274b71bf46feaf22b68112 + Response: Griptape offers a comprehensive solution for building, deploying, and scaling AI applications in the cloud. It provides developers + with a framework and cloud services to create retrieval-driven AI-powered applications without needing extensive knowledge in AI or prompt + engineering. + + **Griptape Framework:** + - Enables developers to build AI applications using Python. + - Offers better security, performance, and cost-efficiency with Off-Prompt™ technology. + - Facilitates the creation of Gen AI Agents, Systems of Agents, Pipelines, Workflows, and RAG implementations. + + **Griptape Cloud:** + - Simplifies deployment and execution of ETL, RAG, and other structures. + - Provides API abstractions and eliminates the need for infrastructure management. + - Supports seamless scaling to accommodate growing workloads. + + **Solutions & Applications:** + - Custom project development. + - Turnkey SaaS offerings for non-tech businesses. + - Ready-made apps and options to offer apps to customers. + + **Key Features:** + - Automated Data Prep (ETL): Connect, extract, transform, and load data into a vector database index. + - Retrieval as a Service (RAG): Generate answers, summaries, and details from your data using customizable retrieval patterns. + - Structure Runtime (RUN): Build and integrate AI agents, pipelines, and workflows into client applications. +[08/12/24 15:32:21] INFO ToolkitTask b14a4305365f4b17a4dcf235f84397e2 + Output: Griptape is a comprehensive solution designed to facilitate the building, deploying, and scaling of AI applications in the cloud. It + provides developers with a framework and cloud services that simplify the creation of retrieval-driven AI-powered applications, even for those + without extensive AI or prompt engineering expertise. + + ### Key Components of Griptape: + + 1. **Griptape Framework:** + - **Development:** Allows developers to build AI applications using Python. + - **Technology:** Utilizes Off-Prompt™ technology for enhanced security, performance, and cost-efficiency. + - **Capabilities:** Supports the creation of Gen AI Agents, Systems of Agents, Pipelines, Workflows, and Retrieval-Augmented Generation (RAG) + implementations. + + 2. **Griptape Cloud:** + - **Deployment:** Simplifies the deployment and execution of ETL (Extract, Transform, Load), RAG, and other structures. + - **API Abstractions:** Provides API abstractions to eliminate the need for infrastructure management. + - **Scalability:** Supports seamless scaling to accommodate growing workloads. + + ### Solutions & Applications: + - **Custom Projects:** Development of tailored AI solutions. + - **Turnkey SaaS:** Ready-to-use SaaS offerings for non-technical businesses. + - **Ready-made Apps:** Pre-built applications and options to offer apps to customers. + + ### Key Features: + - **Automated Data Prep (ETL):** Connects, extracts, transforms, and loads data into a vector database index. + - **Retrieval as a Service (RAG):** Generates answers, summaries, and details from data using customizable retrieval patterns. + - **Structure Runtime (RUN):** Facilitates the building and integration of AI agents, pipelines, and workflows into client applications. + + In summary, Griptape provides a robust platform for developing and managing AI applications, making it accessible for developers and businesses + to leverage AI technology effectively. ``` diff --git a/griptape/tools/__init__.py b/griptape/tools/__init__.py index 6c566d1eb..c87557225 100644 --- a/griptape/tools/__init__.py +++ b/griptape/tools/__init__.py @@ -9,7 +9,6 @@ from .file_manager.tool import FileManager from .vector_store_client.tool import VectorStoreClient from .date_time.tool import DateTime -from .task_memory_client.tool import TaskMemoryClient from .base_aws_client import BaseAwsClient from .aws_iam_client.tool import AwsIamClient from .aws_s3_client.tool import AwsS3Client @@ -32,6 +31,7 @@ from .audio_transcription_client.tool import AudioTranscriptionClient from .extraction_client.tool import ExtractionClient from .prompt_summary_client.tool import PromptSummaryClient +from .query_client.tool import QueryClient __all__ = [ "BaseTool", @@ -53,7 +53,6 @@ "FileManager", "VectorStoreClient", "DateTime", - "TaskMemoryClient", "Computer", "OpenWeatherClient", "PromptImageGenerationClient", @@ -68,4 +67,5 @@ "AudioTranscriptionClient", "ExtractionClient", "PromptSummaryClient", + "QueryClient", ] diff --git a/griptape/tools/task_memory_client/__init__.py b/griptape/tools/query_client/__init__.py similarity index 100% rename from griptape/tools/task_memory_client/__init__.py rename to griptape/tools/query_client/__init__.py diff --git a/griptape/tools/query_client/manifest.yml b/griptape/tools/query_client/manifest.yml new file mode 100644 index 000000000..086a86d5a --- /dev/null +++ b/griptape/tools/query_client/manifest.yml @@ -0,0 +1,5 @@ +version: "v1" +name: Query Client +description: Tool for performing a query against data. +contact_email: hello@griptape.ai +legal_info_url: https://www.griptape.ai/legal diff --git a/griptape/tools/query_client/requirements.txt b/griptape/tools/query_client/requirements.txt new file mode 100644 index 000000000..e69de29bb diff --git a/griptape/tools/query_client/tool.py b/griptape/tools/query_client/tool.py new file mode 100644 index 000000000..44ff65387 --- /dev/null +++ b/griptape/tools/query_client/tool.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from attrs import Factory, define, field +from schema import Literal, Or, Schema + +from griptape.artifacts import BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact +from griptape.config import config +from griptape.engines.rag import RagEngine +from griptape.engines.rag.modules import ( + PromptResponseRagModule, +) +from griptape.engines.rag.rag_context import RagContext +from griptape.engines.rag.stages import ResponseRagStage +from griptape.mixins.rule_mixin import RuleMixin +from griptape.tools.rag_client.tool import RagClient +from griptape.utils.decorators import activity + + +@define(kw_only=True) +class QueryClient(RagClient, RuleMixin): + """Tool for performing a query against data.""" + + description: str = field(init=False) + rag_engine: RagEngine = field( + default=Factory( + lambda self: RagEngine( + response_stage=ResponseRagStage( + response_modules=[ + PromptResponseRagModule(prompt_driver=config.drivers.prompt, rulesets=self.rulesets) + ], + ), + ), + takes_self=True, + ), + ) + + @activity( + config={ + "description": "Can be used to search through textual content.", + "schema": Schema( + { + Literal("query", description="A natural language search query"): str, + Literal("content"): Or( + str, + Schema( + { + "memory_name": str, + "artifact_namespace": str, + } + ), + ), + } + ), + }, + ) + def query(self, params: dict) -> BaseArtifact: + query = params["values"]["query"] + summary = params["values"]["content"] + + if isinstance(summary, str): + text_artifacts = [TextArtifact(summary)] + else: + memory = self.find_input_memory(summary["memory_name"]) + artifact_namespace = summary["artifact_namespace"] + + if memory is not None: + artifacts = memory.load_artifacts(artifact_namespace) + else: + return ErrorArtifact("memory not found") + + text_artifacts = [artifact for artifact in artifacts if isinstance(artifact, TextArtifact)] + + outputs = self.rag_engine.process(RagContext(query=query, text_chunks=text_artifacts)).outputs + + if len(outputs) > 0: + return ListArtifact(outputs) + else: + return ErrorArtifact("query output is empty") diff --git a/griptape/tools/rag_client/tool.py b/griptape/tools/rag_client/tool.py index 5affd49a0..613e254af 100644 --- a/griptape/tools/rag_client/tool.py +++ b/griptape/tools/rag_client/tool.py @@ -1,22 +1,20 @@ from __future__ import annotations -from attrs import Factory, define, field -from schema import Literal, Or, Schema +from typing import TYPE_CHECKING -from griptape.artifacts import BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact -from griptape.engines.rag import RagEngine -from griptape.engines.rag.modules import ( - PromptResponseRagModule, -) -from griptape.engines.rag.rag_context import RagContext -from griptape.engines.rag.stages import ResponseRagStage -from griptape.mixins.rule_mixin import RuleMixin +from attrs import define, field +from schema import Literal, Schema + +from griptape.artifacts import BaseArtifact, ErrorArtifact, ListArtifact from griptape.tools import BaseTool from griptape.utils.decorators import activity +if TYPE_CHECKING: + from griptape.engines.rag import RagEngine + @define(kw_only=True) -class RagClient(BaseTool, RuleMixin): +class RagClient(BaseTool): """Tool for querying a RAG engine. Attributes: @@ -25,57 +23,19 @@ class RagClient(BaseTool, RuleMixin): """ description: str = field() - rag_engine: RagEngine = field( - default=Factory( - lambda self: RagEngine( - response_stage=ResponseRagStage( - response_modules=[ - PromptResponseRagModule(prompt_driver=self.config.prompt_driver, rulesets=self.rulesets) - ], - ), - ), - takes_self=True, - ) - ) + rag_engine: RagEngine = field() @activity( config={ - "description": "Can be used to search content with the following description: {{ _self.description }}", - "schema": Schema( - { - Literal("query", description="A natural language search query"): str, - Literal("content"): Or( - str, - Schema( - { - "memory_name": str, - "artifact_namespace": str, - } - ), - ), - } - ), + "description": "{{ _self.description }}", + "schema": Schema({Literal("query", description="A natural language search query"): str}), }, ) def search(self, params: dict) -> BaseArtifact: query = params["values"]["query"] - summary = params["values"]["content"] - - if isinstance(summary, str): - text_artifacts = [TextArtifact(summary)] - else: - memory = self.find_input_memory(summary["memory_name"]) - artifact_namespace = summary["artifact_namespace"] - - if memory is not None: - artifacts = memory.load_artifacts(artifact_namespace) - else: - return ErrorArtifact("memory not found") - - text_artifacts = [artifact for artifact in artifacts if isinstance(artifact, TextArtifact)] try: - outputs = self.rag_engine.process(RagContext(query=query, text_chunks=text_artifacts)).outputs + outputs = self.rag_engine.process_query(query).outputs if len(outputs) > 0: return ListArtifact(outputs) diff --git a/griptape/tools/task_memory_client/manifest.yml b/griptape/tools/task_memory_client/manifest.yml deleted file mode 100644 index 0bff1af3d..000000000 --- a/griptape/tools/task_memory_client/manifest.yml +++ /dev/null @@ -1,5 +0,0 @@ -version: "v1" -name: Task Memory Client -description: Tool for summarizing and querying TaskMemory. -contact_email: hello@griptape.ai -legal_info_url: https://www.griptape.ai/legal \ No newline at end of file diff --git a/griptape/tools/task_memory_client/tool.py b/griptape/tools/task_memory_client/tool.py deleted file mode 100644 index a20d63506..000000000 --- a/griptape/tools/task_memory_client/tool.py +++ /dev/null @@ -1,41 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from attrs import define -from schema import Literal, Schema - -from griptape.tools import BaseTool -from griptape.utils.decorators import activity - -if TYPE_CHECKING: - from griptape.artifacts import BaseArtifact, ErrorArtifact, InfoArtifact, TextArtifact - - -@define -class TaskMemoryClient(BaseTool): - @activity( - config={ - "description": "Can be used to summarize memory content", - "schema": Schema({"memory_name": str, "artifact_namespace": str}), - }, - ) - def summarize(self, params: dict) -> TextArtifact | InfoArtifact | ErrorArtifact: ... - - @activity( - config={ - "description": "Can be used to search and query memory content", - "schema": Schema( - { - "memory_name": str, - "artifact_namespace": str, - Literal( - "query", - description="A natural language search query in the form of a question with enough " - "contextual information for another person to understand what the query is about", - ): str, - }, - ), - }, - ) - def query(self, params: dict) -> BaseArtifact: ... diff --git a/mkdocs.yml b/mkdocs.yml index 175918e87..bb65d99f3 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -144,7 +144,6 @@ nav: - OpenWeatherClient: "griptape-tools/official-tools/openweather-client.md" - RestApiClient: "griptape-tools/official-tools/rest-api-client.md" - SqlClient: "griptape-tools/official-tools/sql-client.md" - - TaskMemoryClient: "griptape-tools/official-tools/task-memory-client.md" - VectorStoreClient: "griptape-tools/official-tools/vector-store-client.md" - WebScraper: "griptape-tools/official-tools/web-scraper.md" - WebSearch: "griptape-tools/official-tools/web-search.md" diff --git a/tests/integration/tasks/test_toolkit_task.py b/tests/integration/tasks/test_toolkit_task.py index 8dfcfdc73..4b61cb9e5 100644 --- a/tests/integration/tasks/test_toolkit_task.py +++ b/tests/integration/tasks/test_toolkit_task.py @@ -14,7 +14,7 @@ def structure_tester(self, request): from griptape.drivers import GoogleWebSearchDriver from griptape.structures import Agent - from griptape.tools import TaskMemoryClient, WebScraper, WebSearch + from griptape.tools import PromptSummaryClient, WebScraper, WebSearch return StructureTester( Agent( @@ -25,7 +25,7 @@ def structure_tester(self, request): ) ), WebScraper(off_prompt=True), - TaskMemoryClient(off_prompt=False), + PromptSummaryClient(off_prompt=False), ], conversation_memory=None, prompt_driver=request.param, From e849e6e162decea7a925058906952750ad94c6b1 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Mon, 12 Aug 2024 15:59:03 -0700 Subject: [PATCH 46/63] Add docs --- CHANGELOG.md | 1 + .../official-tools/extraction-client.md | 55 +++++++++ .../official-tools/prompt-summary-client.md | 107 ++++++++++++++++++ .../official-tools/query-client.md | 87 ++++++++++++++ .../official-tools/src/extraction_client_1.py | 28 +++++ .../src/prompt_summary_client_1.py | 8 ++ .../official-tools/src/query_client_1.py | 6 + mkdocs.yml | 3 + poetry.lock | 12 +- pyproject.toml | 2 +- 10 files changed, 303 insertions(+), 6 deletions(-) create mode 100644 docs/griptape-tools/official-tools/extraction-client.md create mode 100644 docs/griptape-tools/official-tools/prompt-summary-client.md create mode 100644 docs/griptape-tools/official-tools/query-client.md create mode 100644 docs/griptape-tools/official-tools/src/extraction_client_1.py create mode 100644 docs/griptape-tools/official-tools/src/prompt_summary_client_1.py create mode 100644 docs/griptape-tools/official-tools/src/query_client_1.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 6e92817b9..05bb8b4c4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Unique name generation for all `RagEngine` modules. - `ExtractionClient` Tool for having the LLM extract structured data from text. - `PromptSummaryClient` Tool for having the LLM summarize text. +- `QueryClient` Tool for having hte LLM queyr text. ### Changed - **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `event_bus`. diff --git a/docs/griptape-tools/official-tools/extraction-client.md b/docs/griptape-tools/official-tools/extraction-client.md new file mode 100644 index 000000000..ad2617219 --- /dev/null +++ b/docs/griptape-tools/official-tools/extraction-client.md @@ -0,0 +1,55 @@ +The [RagClient](../../reference/griptape/tools/rag_client/tool.md) enables LLMs to query modular RAG engines. + +Here is an example of how it can be used with a local vector store driver: + +```python +--8<-- "docs/griptape-tools/official-tools/src/rag_client_1.py" +``` +``` +[08/12/24 15:58:03] INFO ToolkitTask 43b3d209a83c470d8371b7ef4af175b4 + Input: Load https://griptape.ai and extract key info +[08/12/24 15:58:05] INFO Subtask 6a9a63802faf4717bab24bbbea2cb49b + Actions: [ + { + "tag": "call_SgrmWdXaYTQ1Cz9iB0iIZSYD", + "name": "WebScraper", + "path": "get_content", + "input": { + "values": { + "url": "https://griptape.ai" + } + } + } + ] +[08/12/24 15:58:06] INFO Subtask 6a9a63802faf4717bab24bbbea2cb49b + Response: Output of "WebScraper.get_content" was stored in memory with memory_name "TaskMemory" and artifact_namespace + "bf1c865b82554c9e896cb514bb86844c" +[08/12/24 15:58:07] INFO Subtask c06388d6079541d5aaff25c30e322c51 + Actions: [ + { + "tag": "call_o3MrpM01OnhCfpxsMe85tpDF", + "name": "ExtractionClient", + "path": "extract_json", + "input": { + "values": { + "data": { + "memory_name": "TaskMemory", + "artifact_namespace": "bf1c865b82554c9e896cb514bb86844c" + } + } + } + } + ] +[08/12/24 15:58:11] INFO Subtask c06388d6079541d5aaff25c30e322c51 + Response: {"company_name": "Griptape", "industry": "AI Applications", "product_features": ["Turn any developer into an AI developer.", "Build + your business logic using predictable, programmable python.", "Off-Prompt\u2122 for better security, performance, and lower costs.", "Deploy and + run the ETL, RAG, and structures you developed.", "Simple API abstractions.", "Skip the infrastructure management.", "Scale seamlessly with + workload requirements.", "Clean and clear abstractions for building Gen AI Agents, Systems of Agents, Pipelines, Workflows, and RAG + implementations.", "Build ETL pipelines to prep data for secure LLM access.", "Compose retrieval patterns for fast, accurate, detailed + information.", "Write agents, pipelines, and workflows to integrate business logic.", "Automated Data Prep (ETL): Connect any data source, + extract, prep/transform, and load into a vector database index.", "Retrieval as a Service (RAG): Generate answers, summaries, and details from + your own data with ready-made or custom retrieval patterns.", "Structure Runtime (RUN): Build AI agents, pipelines, and workflows for real-time + interfaces, transactional processes, and batch workloads."]} +[08/12/24 15:58:14] INFO ToolkitTask 43b3d209a83c470d8371b7ef4af175b4 + Output: Extracted key information from Griptape's website. +``` diff --git a/docs/griptape-tools/official-tools/prompt-summary-client.md b/docs/griptape-tools/official-tools/prompt-summary-client.md new file mode 100644 index 000000000..87575d7bf --- /dev/null +++ b/docs/griptape-tools/official-tools/prompt-summary-client.md @@ -0,0 +1,107 @@ +The [RagClient](../../reference/griptape/tools/rag_client/tool.md) enables LLMs to query modular RAG engines. + +Here is an example of how it can be used with a local vector store driver: + +```python +--8<-- "docs/griptape-tools/official-tools/src/prompt_summary_client_1.py" +``` +``` +[08/12/24 15:54:46] INFO ToolkitTask 8be73eb542c44418ba880399044c017a + Input: How can I build Neovim from source for MacOS according to this https://github.com/neovim/neovim/blob/master/BUILD.md +[08/12/24 15:54:47] INFO Subtask cd362a149e1d400997be93c1342d1663 + Actions: [ + { + "tag": "call_DGsOHC4AVxhV7RPVA7q3rATX", + "name": "WebScraper", + "path": "get_content", + "input": { + "values": { + "url": "https://github.com/neovim/neovim/blob/master/BUILD.md" + } + } + } + ] +[08/12/24 15:54:49] INFO Subtask cd362a149e1d400997be93c1342d1663 + Response: Output of "WebScraper.get_content" was stored in memory with memory_name "TaskMemory" and artifact_namespace + "990b689c57de4581b8715963c0aecfe3" +[08/12/24 15:54:50] INFO Subtask 919a4a9eb900439ab9bfbf6e921feba3 + Actions: [ + { + "tag": "call_DK3a4MYoElJbaCrUJekBReIc", + "name": "PromptSummaryClient", + "path": "summarize", + "input": { + "values": { + "summary": { + "memory_name": "TaskMemory", + "artifact_namespace": "990b689c57de4581b8715963c0aecfe3" + } + } + } + } + ] +[08/12/24 15:54:56] INFO Subtask 919a4a9eb900439ab9bfbf6e921feba3 + Response: The text provides detailed instructions for building and installing Neovim from source on various operating systems, including Linux, + BSD, Windows, and macOS. Key points include: + + 1. **Pre-requisites**: Ensure all build dependencies are installed. + 2. **Cloning and Building**: + - Clone the Neovim repository. + - Use `make` with appropriate `CMAKE_BUILD_TYPE` (Release, Debug, RelWithDebInfo). + - For stable releases, checkout the stable branch. + - Use `gmake` on BSD and follow specific instructions for Windows using MSVC. + 3. **Installation**: + - Default install location is `/usr/local`. + - On Debian/Ubuntu, build a DEB package for clean installation. + - Custom install locations can be set using `CMAKE_INSTALL_PREFIX`. + 4. **Dependencies**: + - Third-party dependencies are automatically downloaded. + - Use `Ninja` and `ccache` for faster builds. + 5. **Platform-Specific Instructions**: + - Detailed steps for building on Windows using Visual Studio. + - Instructions for building on macOS using Homebrew or MacPorts. + - Specific commands for building on NixOS, Cygwin, and MSYS2. + 6. **Debugging and Development**: + - Use `RelWithDebInfo` for optimized builds with debug info. + - Instructions for using debuggers and handling build cache. + 7. **Localization**: + - Commands for checking and updating translations. + + The text also includes troubleshooting tips and links to additional resources for further guidance. +[08/12/24 15:55:01] INFO ToolkitTask 8be73eb542c44418ba880399044c017a + Output: To build Neovim from source on macOS, follow these steps: + + 1. **Install Dependencies**: + Ensure you have the necessary build dependencies installed. You can use Homebrew to install them: + ```sh + brew install ninja libtool automake cmake pkg-config gettext curl + ``` + + 2. **Clone the Neovim Repository**: + Clone the Neovim repository from GitHub: + ```sh + git clone https://github.com/neovim/neovim.git + cd neovim + ``` + + 3. **Checkout the Stable Branch (Optional)**: + If you want to build the stable release, checkout the stable branch: + ```sh + git checkout stable + ``` + + 4. **Build Neovim**: + Use `make` to build Neovim. You can specify the build type (Release, Debug, RelWithDebInfo): + ```sh + make CMAKE_BUILD_TYPE=Release + ``` + + 5. **Install Neovim**: + After building, install Neovim. The default install location is `/usr/local`: + ```sh + sudo make install + ``` + + By following these steps, you should be able to build and install Neovim from source on macOS. For more detailed instructions and + troubleshooting tips, refer to the [BUILD.md](https://github.com/neovim/neovim/blob/master/BUILD.md) file in the Neovim repository. +``` diff --git a/docs/griptape-tools/official-tools/query-client.md b/docs/griptape-tools/official-tools/query-client.md new file mode 100644 index 000000000..53b5b2f5a --- /dev/null +++ b/docs/griptape-tools/official-tools/query-client.md @@ -0,0 +1,87 @@ +The [RagClient](../../reference/griptape/tools/rag_client/tool.md) enables LLMs to query modular RAG engines. + +Here is an example of how it can be used with a local vector store driver: + +```python +--8<-- "docs/griptape-tools/official-tools/src/query_client_1.py" +``` +``` +[08/12/24 15:49:23] INFO ToolkitTask a88abda2e5324bdf81a3e2b99c26b9df + Input: Tell me about the architecture as described here: https://neovim.io/doc/user/vim_diff.html +[08/12/24 15:49:24] INFO Subtask 3dc9910bcac44c718b3aedd6222e372a + Actions: [ + { + "tag": "call_VY4r5YRc2QDjtBvn89z5PH8E", + "name": "WebScraper", + "path": "get_content", + "input": { + "values": { + "url": "https://neovim.io/doc/user/vim_diff.html" + } + } + } + ] +[08/12/24 15:49:25] INFO Subtask 3dc9910bcac44c718b3aedd6222e372a + Response: Output of "WebScraper.get_content" was stored in memory with memory_name "TaskMemory" and artifact_namespace + "bec6deeac5f84e369c41210e67905415" +[08/12/24 15:49:26] INFO Subtask f41d2189ecff4458acb8e6dadb5b13aa + Actions: [ + { + "tag": "call_GtBICZi6oIeL85Aj7q5szul9", + "name": "QueryClient", + "path": "query", + "input": { + "values": { + "query": "architecture", + "content": { + "memory_name": "TaskMemory", + "artifact_namespace": "bec6deeac5f84e369c41210e67905415" + } + } + } + } + ] +[08/12/24 15:49:33] INFO Subtask f41d2189ecff4458acb8e6dadb5b13aa + Response: The architecture of Neovim (Nvim) is designed to improve stability, performance, and extensibility. Here are some key points about + Nvim's architecture: + + 1. **Decoupled UI**: The Nvim UI is decoupled from the core editor. All UIs, including the built-in TUI (terminal user interface), are plugins + that connect to a Nvim server. Multiple Nvim UI clients can connect to the same Nvim editor server. + + 2. **External Plugins**: External plugins run in separate processes, which improves stability and allows those plugins to work without blocking + the editor. Even "legacy" Python and Ruby plugins, which use the old Vim interfaces, run out-of-process, so they cannot crash Nvim. + + 3. **Libuv**: Platform and I/O facilities are built upon libuv. Nvim benefits from libuv features and bug fixes, and other projects benefit from + improvements to libuv by Nvim developers. + + 4. **Robust API**: Nvim has a robust API, which is used instead of exposing internal test functions like Vim's `test_autochdir()`, + `test_settime()`, etc. + + 5. **Feature Inclusion**: Nvim always includes all features, in contrast to Vim, which ships various combinations of 100+ optional features. + This reduces the surface area for bugs and removes a common source of confusion and friction for users. + + 6. **External Plugins and Extensions**: Nvim avoids features that cannot be provided on all platforms, delegating those to external + plugins/extensions. + + These architectural decisions make Nvim more stable, extensible, and user-friendly compared to traditional Vim. +[08/12/24 15:49:37] INFO ToolkitTask a88abda2e5324bdf81a3e2b99c26b9df + Output: The architecture of Neovim (Nvim) is designed to enhance stability, performance, and extensibility. Here are the key points: + + 1. **Decoupled UI**: The user interface (UI) is separated from the core editor. All UIs, including the built-in terminal user interface (TUI), + are plugins that connect to a Nvim server. This allows multiple UI clients to connect to the same Nvim editor server. + + 2. **External Plugins**: Plugins run in separate processes, which improves stability and prevents them from blocking the editor. Even older + Python and Ruby plugins run out-of-process, ensuring they cannot crash Nvim. + + 3. **Libuv**: Nvim's platform and I/O facilities are built on libuv, benefiting from its features and bug fixes. Improvements made by Nvim + developers to libuv also benefit other projects. + + 4. **Robust API**: Nvim provides a robust API, avoiding the need to expose internal test functions like Vim does. + + 5. **Feature Inclusion**: Unlike Vim, which ships with various combinations of optional features, Nvim includes all features by default. This + reduces bugs and user confusion. + + 6. **External Plugins and Extensions**: Nvim delegates features that cannot be provided on all platforms to external plugins/extensions. + + These architectural choices make Nvim more stable, extensible, and user-friendly compared to traditional Vim. +``` diff --git a/docs/griptape-tools/official-tools/src/extraction_client_1.py b/docs/griptape-tools/official-tools/src/extraction_client_1.py new file mode 100644 index 000000000..e39cdbd79 --- /dev/null +++ b/docs/griptape-tools/official-tools/src/extraction_client_1.py @@ -0,0 +1,28 @@ +import schema + +from griptape.engines import JsonExtractionEngine +from griptape.structures import Agent +from griptape.tools import ExtractionClient, WebScraper + +agent = Agent( + input="Load {{ args[0] }} and extract key info", + tools=[ + WebScraper(off_prompt=True), + ExtractionClient( + off_prompt=False, + extraction_engine=JsonExtractionEngine( + template_schema=schema.Schema( + { + "company_name": str, + "industry": str, + schema.Literal( + "product_features", + description="List of key product features.", + ): list[str], + } + ).json_schema("Company Info"), + ), + ), + ], +) +agent.run("https://griptape.ai") diff --git a/docs/griptape-tools/official-tools/src/prompt_summary_client_1.py b/docs/griptape-tools/official-tools/src/prompt_summary_client_1.py new file mode 100644 index 000000000..6e57b4b00 --- /dev/null +++ b/docs/griptape-tools/official-tools/src/prompt_summary_client_1.py @@ -0,0 +1,8 @@ +from griptape.structures import Agent +from griptape.tools import PromptSummaryClient, WebScraper + +agent = Agent(tools=[WebScraper(off_prompt=True), PromptSummaryClient()]) + +agent.run( + "How can I build Neovim from source for MacOS according to this https://github.com/neovim/neovim/blob/master/BUILD.md" +) diff --git a/docs/griptape-tools/official-tools/src/query_client_1.py b/docs/griptape-tools/official-tools/src/query_client_1.py new file mode 100644 index 000000000..c917c9e04 --- /dev/null +++ b/docs/griptape-tools/official-tools/src/query_client_1.py @@ -0,0 +1,6 @@ +from griptape.structures import Agent +from griptape.tools import QueryClient, WebScraper + +agent = Agent(tools=[WebScraper(off_prompt=True), QueryClient()]) + +agent.run("Tell me about the architecture as described here: https://neovim.io/doc/user/vim_diff.html") diff --git a/mkdocs.yml b/mkdocs.yml index bb65d99f3..c769d11ec 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -156,6 +156,9 @@ nav: - AudioTranscriptionClient: "griptape-tools/official-tools/audio-transcription-client.md" - GriptapeCloudKnowledgeBaseClient: "griptape-tools/official-tools/griptape-cloud-knowledge-base-client.md" - RagClient: "griptape-tools/official-tools/rag-client.md" + - ExtractionClient: "griptape-tools/official-tools/extraction-client.md" + - QueryClient: "griptape-tools/official-tools/query-client.md" + - PromptSummaryClient: "griptape-tools/official-tools/prompt-summary-client.md" - Custom Tools: - Building Custom Tools: "griptape-tools/custom-tools/index.md" - Recipes: diff --git a/poetry.lock b/poetry.lock index d3ac878c7..c167ed21b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3124,23 +3124,25 @@ mkdocs = ">=1.2" [[package]] name = "mkdocstrings" -version = "0.23.0" +version = "0.25.2" description = "Automatic documentation from sources, for MkDocs." optional = false python-versions = ">=3.8" files = [ - {file = "mkdocstrings-0.23.0-py3-none-any.whl", hash = "sha256:051fa4014dfcd9ed90254ae91de2dbb4f24e166347dae7be9a997fe16316c65e"}, - {file = "mkdocstrings-0.23.0.tar.gz", hash = "sha256:d9c6a37ffbe7c14a7a54ef1258c70b8d394e6a33a1c80832bce40b9567138d1c"}, + {file = "mkdocstrings-0.25.2-py3-none-any.whl", hash = "sha256:9e2cda5e2e12db8bb98d21e3410f3f27f8faab685a24b03b06ba7daa5b92abfc"}, + {file = "mkdocstrings-0.25.2.tar.gz", hash = "sha256:5cf57ad7f61e8be3111a2458b4e49c2029c9cb35525393b179f9c916ca8042dc"}, ] [package.dependencies] +click = ">=7.0" importlib-metadata = {version = ">=4.6", markers = "python_version < \"3.10\""} Jinja2 = ">=2.11.1" Markdown = ">=3.3" MarkupSafe = ">=1.1" -mkdocs = ">=1.2" +mkdocs = ">=1.4" mkdocs-autorefs = ">=0.3.1" mkdocstrings-python = {version = ">=0.5.2", optional = true, markers = "extra == \"python\""} +platformdirs = ">=2.2.0" pymdown-extensions = ">=6.3" typing-extensions = {version = ">=4.1", markers = "python_version < \"3.10\""} @@ -6955,4 +6957,4 @@ loaders-sql = ["sqlalchemy"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "107e2bcaa620a1134fb649254bc6157304754eb7f617334c544ac5db70286fdd" +content-hash = "06a69b74f09aa3ff57b7bd87e9b59d6627c08c2aedbf6862fafc180f0a90fa57" diff --git a/pyproject.toml b/pyproject.toml index 8ca015c41..6db25e3d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -239,7 +239,7 @@ optional = true mkdocs = "^1.5.2" mkdocs-material = "^9.2.8" mkdocs-glightbox = "^0.3.4" -mkdocstrings = {extras = ["python"], version = "^0.23.0"} +mkdocstrings = {extras = ["python"], version = "^0.25.2"} mkdocs-gen-files = "^0.5.0" mkdocs-literate-nav = "^0.6.0" mkdocs-section-index = "^0.3.6" From 63d93ffe549fa87ed15018e706f3f45f77b1c49b Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Mon, 12 Aug 2024 16:13:00 -0700 Subject: [PATCH 47/63] Fix test --- CHANGELOG.md | 2 +- tests/unit/structures/test_agent.py | 18 +++--------------- 2 files changed, 4 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 05bb8b4c4..b51bb4e68 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,7 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Unique name generation for all `RagEngine` modules. - `ExtractionClient` Tool for having the LLM extract structured data from text. - `PromptSummaryClient` Tool for having the LLM summarize text. -- `QueryClient` Tool for having hte LLM queyr text. +- `QueryClient` Tool for having the LLM queyr text. ### Changed - **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `event_bus`. diff --git a/tests/unit/structures/test_agent.py b/tests/unit/structures/test_agent.py index 33bfdc5ee..ef5faeff1 100644 --- a/tests/unit/structures/test_agent.py +++ b/tests/unit/structures/test_agent.py @@ -1,13 +1,11 @@ import pytest -from griptape.engines import PromptSummaryEngine from griptape.memory import TaskMemory from griptape.memory.structure import ConversationMemory from griptape.memory.task.storage import TextArtifactStorage from griptape.rules import Rule, Ruleset from griptape.structures import Agent from griptape.tasks import BaseTask, PromptTask, ToolkitTask -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_tool.tool import MockTool @@ -218,23 +216,13 @@ def test_context(self): assert context["structure"] == agent - def test_task_memory_defaults(self): - prompt_driver = MockPromptDriver() - embedding_driver = MockEmbeddingDriver() - agent = Agent(prompt_driver=prompt_driver) + def test_task_memory_defaults(self, mock_config): + agent = Agent() storage = list(agent.task_memory.artifact_storages.values())[0] assert isinstance(storage, TextArtifactStorage) - assert storage.rag_engine.response_stage.response_modules[0].prompt_driver == prompt_driver - assert ( - storage.rag_engine.retrieval_stage.retrieval_modules[0].vector_store_driver.embedding_driver - == embedding_driver - ) - assert isinstance(storage.summary_engine, PromptSummaryEngine) - assert storage.summary_engine.prompt_driver == prompt_driver - assert storage.csv_extraction_engine.prompt_driver == prompt_driver - assert storage.json_extraction_engine.prompt_driver == prompt_driver + assert storage.vector_store_driver.embedding_driver == mock_config.drivers.embedding def finished_tasks(self): task = PromptTask("test prompt") From 1ff7e88f293b1d7999bc612afb32db289f5d3b18 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 13 Aug 2024 10:36:27 -0700 Subject: [PATCH 48/63] Clean up example --- docs/griptape-framework/misc/src/events_3.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/docs/griptape-framework/misc/src/events_3.py b/docs/griptape-framework/misc/src/events_3.py index ab995e018..bae8b8224 100644 --- a/docs/griptape-framework/misc/src/events_3.py +++ b/docs/griptape-framework/misc/src/events_3.py @@ -1,6 +1,5 @@ from typing import cast -from griptape.config import OpenAiDriverConfig, config from griptape.drivers import OpenAiChatPromptDriver from griptape.events import CompletionChunkEvent, EventListener, event_bus from griptape.structures import Pipeline @@ -16,13 +15,11 @@ ] ) -config.drivers = OpenAiDriverConfig(prompt=OpenAiChatPromptDriver(model="gpt-4o", stream=True)) - pipeline = Pipeline() - pipeline.add_tasks( ToolkitTask( "Based on https://griptape.ai, tell me what griptape is.", + prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", stream=True), tools=[WebScraper(off_prompt=True), TaskMemoryClient(off_prompt=False)], ) ) From 6944ba4d289147a1803682e230f53c3c8e2d9bd9 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 13 Aug 2024 10:37:50 -0700 Subject: [PATCH 49/63] Default stream to config value --- griptape/structures/agent.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/griptape/structures/agent.py b/griptape/structures/agent.py index 59a865897..1e79fab14 100644 --- a/griptape/structures/agent.py +++ b/griptape/structures/agent.py @@ -23,7 +23,7 @@ class Agent(Structure): input: str | list | tuple | BaseArtifact | Callable[[BaseTask], BaseArtifact] = field( default=lambda task: task.full_context["args"][0] if task.full_context["args"] else TextArtifact(value=""), ) - stream: bool = field(default=False, kw_only=True) + stream: bool = field(default=Factory(lambda: config.drivers.prompt.stream), kw_only=True) prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.drivers.prompt), kw_only=True) tools: list[BaseTool] = field(factory=list, kw_only=True) max_meta_memory_entries: Optional[int] = field(default=20, kw_only=True) @@ -37,7 +37,6 @@ def validate_fail_fast(self, _: Attribute, fail_fast: bool) -> None: # noqa: FB def __attrs_post_init__(self) -> None: super().__attrs_post_init__() - self.prompt_driver.stream = self.stream if len(self.tasks) == 0: if self.tools: task = ToolkitTask( From 8a0845bdec50ba1a31e6d664b308c1975c2aa865 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 13 Aug 2024 12:04:14 -0700 Subject: [PATCH 50/63] Rename new tools --- CHANGELOG.md | 8 ++++---- README.md | 4 ++-- docs/examples/multiple-agent-shared-memory.md | 4 ++-- docs/examples/src/multi_agent_workflow_1.py | 4 ++-- .../src/multiple_agent_shared_memory_1.py | 4 ++-- docs/examples/src/query_webpage_astra_db_1.py | 4 ++-- .../drivers/src/embedding_drivers_10.py | 4 ++-- .../drivers/src/web_search_drivers_2.py | 4 ++-- docs/griptape-framework/misc/src/events_3.py | 4 ++-- docs/griptape-framework/misc/src/events_4.py | 4 ++-- docs/griptape-framework/src/index_4.py | 4 ++-- .../structures/src/task_memory_3.py | 4 ++-- .../structures/src/task_memory_5.py | 4 ++-- .../structures/src/task_memory_6.py | 4 ++-- .../structures/src/task_memory_8.py | 4 ++-- .../structures/src/tasks_16.py | 4 ++-- .../griptape-framework/structures/src/tasks_4.py | 4 ++-- .../griptape-framework/structures/task-memory.md | 12 ++++++------ docs/griptape-framework/tools/src/index_1.py | 4 ++-- .../official-tools/extraction-client.md | 2 +- .../official-tools/image-query-client.md | 2 +- .../official-tools/query-client.md | 2 +- .../official-tools/src/extraction_client_1.py | 4 ++-- .../official-tools/src/image_query_client_1.py | 4 ++-- .../src/prompt_summary_client_1.py | 4 ++-- .../official-tools/src/query_client_1.py | 4 ++-- .../official-tools/src/vector_store_client_1.py | 4 ++-- .../official-tools/src/web_scraper_1.py | 4 ++-- griptape/tools/__init__.py | 16 ++++++++-------- .../__init__.py | 0 .../manifest.yml | 0 .../requirements.txt | 0 .../{extraction_client => extraction}/tool.py | 2 +- griptape/tools/image_query_client/tool.py | 2 +- .../__init__.py | 0 .../manifest.yml | 0 .../requirements.txt | 0 .../tool.py | 2 +- .../tools/{query_client => query}/__init__.py | 0 .../tools/{query_client => query}/manifest.yml | 0 .../{query_client => query}/requirements.txt | 0 griptape/tools/{query_client => query}/tool.py | 2 +- mkdocs.yml | 6 +++--- tests/integration/tasks/test_toolkit_task.py | 4 ++-- tests/unit/tools/test_extraction_client.py | 8 ++++---- tests/unit/tools/test_prompt_summary_client.py | 4 ++-- 46 files changed, 82 insertions(+), 82 deletions(-) rename griptape/tools/{extraction_client => extraction}/__init__.py (100%) rename griptape/tools/{extraction_client => extraction}/manifest.yml (100%) rename griptape/tools/{extraction_client => extraction}/requirements.txt (100%) rename griptape/tools/{extraction_client => extraction}/tool.py (98%) rename griptape/tools/{prompt_summary_client => prompt_summary}/__init__.py (100%) rename griptape/tools/{prompt_summary_client => prompt_summary}/manifest.yml (100%) rename griptape/tools/{prompt_summary_client => prompt_summary}/requirements.txt (100%) rename griptape/tools/{prompt_summary_client => prompt_summary}/tool.py (97%) rename griptape/tools/{query_client => query}/__init__.py (100%) rename griptape/tools/{query_client => query}/manifest.yml (100%) rename griptape/tools/{query_client => query}/requirements.txt (100%) rename griptape/tools/{query_client => query}/tool.py (98%) diff --git a/CHANGELOG.md b/CHANGELOG.md index b51bb4e68..76a2cdf74 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,9 +15,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Global event bus, `griptape.events.event_bus`, for publishing and subscribing to events. - Global config, `griptape.config.config`, for setting global configuration defaults. - Unique name generation for all `RagEngine` modules. -- `ExtractionClient` Tool for having the LLM extract structured data from text. +- `ExtractionTool` Tool for having the LLM extract structured data from text. - `PromptSummaryClient` Tool for having the LLM summarize text. -- `QueryClient` Tool for having the LLM queyr text. +- `QueryTool` Tool for having the LLM query text. ### Changed - **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `event_bus`. @@ -41,7 +41,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: Changed `JsonExtractionEngine.template_schema` from a `run` argument to a class attribute. - **BREAKING**: Changed `CsvExtractionEngine.column_names` from a `run` argument to a class attribute. - **BREAKING**: Removed `JsonExtractionTask`, and `CsvExtractionTask` use `ExtractionTask` instead. -- **BREAKING**: Removed `TaskMemoryClient`, use `RagClient`, `ExtractionClient`, or `PromptSummaryClient` instead. +- **BREAKING**: Removed `TaskMemoryClient`, use `RagClient`, `ExtractionTool`, or `PromptSummaryClient` instead. - `RagClient` now can be used to search through Artifacts stored in Task Memory. - Engines that previously required Drivers now pull from `griptape.config.config.drivers` by default. - `BaseTask.add_parent/child` will now call `self.structure.add_task` if possible. @@ -454,7 +454,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `JsonExtractionTask` for convenience over using `ExtractionTask` with a `JsonExtractionEngine`. - `CsvExtractionTask` for convenience over using `ExtractionTask` with a `CsvExtractionEngine`. - `OpenAiVisionImageQueryDriver` to support queries on images using OpenAI's Vision model. -- `ImageQueryClient` allowing an Agent to make queries on images on disk or in Task Memory. +- `ImageQueryTool` allowing an Agent to make queries on images on disk or in Task Memory. - `ImageQueryTask` and `ImageQueryEngine`. ### Fixed diff --git a/README.md b/README.md index 555efd984..e10c58f98 100644 --- a/README.md +++ b/README.md @@ -89,13 +89,13 @@ With Griptape, you can create Structures, such as Agents, Pipelines, and Workflo ```python from griptape.structures import Agent -from griptape.tools import WebScraper, FileManager, PromptSummaryClient +from griptape.tools import WebScraper, FileManager, PromptSummaryTool agent = Agent( input="Load {{ args[0] }}, summarize it, and store it in a file called {{ args[1] }}.", tools=[ WebScraper(off_prompt=True), - PromptSummaryClient(off_prompt=True), + PromptSummaryTool(off_prompt=True), FileManager() ] ) diff --git a/docs/examples/multiple-agent-shared-memory.md b/docs/examples/multiple-agent-shared-memory.md index 41bf8677a..bd91f977d 100644 --- a/docs/examples/multiple-agent-shared-memory.md +++ b/docs/examples/multiple-agent-shared-memory.md @@ -1,6 +1,6 @@ -This example shows how to use one `Agent` to load content into `TaskMemory` and get that content from another `Agent` using `QueryClient`. +This example shows how to use one `Agent` to load content into `TaskMemory` and get that content from another `Agent` using `QueryTool`. -The first `Agent` uses a remote vector store (`MongoDbAtlasVectorStoreDriver` in this example) to handle memory operations. The second `Agent` uses the same instance of `TaskMemory` and the `QueryClient` with the same `MongoDbAtlasVectorStoreDriver` to get the data. +The first `Agent` uses a remote vector store (`MongoDbAtlasVectorStoreDriver` in this example) to handle memory operations. The second `Agent` uses the same instance of `TaskMemory` and the `QueryTool` with the same `MongoDbAtlasVectorStoreDriver` to get the data. The `MongoDbAtlasVectorStoreDriver` assumes that you have a vector index configured where the path to the content is called `vector`, and the number of dimensions set on the index is `1536` (this is a commonly used number of dimensions for embedding models). diff --git a/docs/examples/src/multi_agent_workflow_1.py b/docs/examples/src/multi_agent_workflow_1.py index 6a3b31c19..de4c256e0 100644 --- a/docs/examples/src/multi_agent_workflow_1.py +++ b/docs/examples/src/multi_agent_workflow_1.py @@ -5,7 +5,7 @@ from griptape.structures import Agent, Workflow from griptape.tasks import PromptTask, StructureRunTask from griptape.tools import ( - PromptSummaryClient, + PromptSummaryTool, WebScraper, WebSearch, ) @@ -38,7 +38,7 @@ def build_researcher() -> Agent: WebScraper( off_prompt=True, ), - PromptSummaryClient(off_prompt=False), + PromptSummaryTool(off_prompt=False), ], rulesets=[ Ruleset( diff --git a/docs/examples/src/multiple_agent_shared_memory_1.py b/docs/examples/src/multiple_agent_shared_memory_1.py index dafa50559..2864dd04e 100644 --- a/docs/examples/src/multiple_agent_shared_memory_1.py +++ b/docs/examples/src/multiple_agent_shared_memory_1.py @@ -3,7 +3,7 @@ from griptape.config import AzureOpenAiDriverConfig, config from griptape.drivers import AzureMongoDbVectorStoreDriver, AzureOpenAiEmbeddingDriver from griptape.structures import Agent -from griptape.tools import QueryClient, WebScraper +from griptape.tools import QueryTool, WebScraper AZURE_OPENAI_ENDPOINT_1 = os.environ["AZURE_OPENAI_ENDPOINT_1"] AZURE_OPENAI_API_KEY_1 = os.environ["AZURE_OPENAI_API_KEY_1"] @@ -46,7 +46,7 @@ ) asker = Agent( tools=[ - QueryClient(off_prompt=False), + QueryTool(off_prompt=False), ], meta_memory=loader.meta_memory, task_memory=loader.task_memory, diff --git a/docs/examples/src/query_webpage_astra_db_1.py b/docs/examples/src/query_webpage_astra_db_1.py index d31d3b53f..38c458567 100644 --- a/docs/examples/src/query_webpage_astra_db_1.py +++ b/docs/examples/src/query_webpage_astra_db_1.py @@ -14,7 +14,7 @@ from griptape.engines.rag.stages import ResponseRagStage, RetrievalRagStage from griptape.loaders import WebLoader from griptape.structures import Agent -from griptape.tools import QueryClient, RagClient +from griptape.tools import QueryTool, RagClient namespace = "datastax_blog" input_blogpost = "www.datastax.com/blog/indexing-all-of-wikipedia-on-a-laptop" @@ -53,5 +53,5 @@ description="A DataStax blog post", rag_engine=engine, ) -agent = Agent(tools=[vector_store_tool, QueryClient(off_prompt=False)]) +agent = Agent(tools=[vector_store_tool, QueryTool(off_prompt=False)]) agent.run("What engine made possible to index such an amount of data, " "and what kind of tuning was required?") diff --git a/docs/griptape-framework/drivers/src/embedding_drivers_10.py b/docs/griptape-framework/drivers/src/embedding_drivers_10.py index a22af9ab9..b11871c67 100644 --- a/docs/griptape-framework/drivers/src/embedding_drivers_10.py +++ b/docs/griptape-framework/drivers/src/embedding_drivers_10.py @@ -4,7 +4,7 @@ VoyageAiEmbeddingDriver, ) from griptape.structures import Agent -from griptape.tools import PromptSummaryClient, WebScraper +from griptape.tools import PromptSummaryTool, WebScraper config.drivers = DriverConfig( prompt=OpenAiChatPromptDriver(model="gpt-4o"), @@ -17,7 +17,7 @@ ) agent = Agent( - tools=[WebScraper(off_prompt=True), PromptSummaryClient(off_prompt=False)], + tools=[WebScraper(off_prompt=True), PromptSummaryTool(off_prompt=False)], ) agent.run("based on https://www.griptape.ai/, tell me what Griptape is") diff --git a/docs/griptape-framework/drivers/src/web_search_drivers_2.py b/docs/griptape-framework/drivers/src/web_search_drivers_2.py index 4c2c469a3..2b92f0017 100644 --- a/docs/griptape-framework/drivers/src/web_search_drivers_2.py +++ b/docs/griptape-framework/drivers/src/web_search_drivers_2.py @@ -2,7 +2,7 @@ from griptape.drivers import GoogleWebSearchDriver from griptape.structures import Agent -from griptape.tools import PromptSummaryClient, WebSearch +from griptape.tools import PromptSummaryTool, WebSearch agent = Agent( tools=[ @@ -12,7 +12,7 @@ search_id=os.environ["GOOGLE_API_SEARCH_ID"], ), ), - PromptSummaryClient(off_prompt=False), + PromptSummaryTool(off_prompt=False), ], ) agent.run("Give me some websites with information about AI frameworks.") diff --git a/docs/griptape-framework/misc/src/events_3.py b/docs/griptape-framework/misc/src/events_3.py index 3e3d0ae62..291e6dd7c 100644 --- a/docs/griptape-framework/misc/src/events_3.py +++ b/docs/griptape-framework/misc/src/events_3.py @@ -4,7 +4,7 @@ from griptape.events import CompletionChunkEvent, EventListener, event_bus from griptape.structures import Pipeline from griptape.tasks import ToolkitTask -from griptape.tools import PromptSummaryClient, WebScraper +from griptape.tools import PromptSummaryTool, WebScraper event_bus.add_event_listeners( [ @@ -20,7 +20,7 @@ ToolkitTask( "Based on https://griptape.ai, tell me what griptape is.", prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", stream=True), - tools=[WebScraper(off_prompt=True), PromptSummaryClient(off_prompt=False)], + tools=[WebScraper(off_prompt=True), PromptSummaryTool(off_prompt=False)], ) ) diff --git a/docs/griptape-framework/misc/src/events_4.py b/docs/griptape-framework/misc/src/events_4.py index a3fe44b78..b20e86f60 100644 --- a/docs/griptape-framework/misc/src/events_4.py +++ b/docs/griptape-framework/misc/src/events_4.py @@ -1,13 +1,13 @@ from griptape.structures import Pipeline from griptape.tasks import ToolkitTask -from griptape.tools import PromptSummaryClient, WebScraper +from griptape.tools import PromptSummaryTool, WebScraper from griptape.utils import Stream pipeline = Pipeline() pipeline.add_tasks( ToolkitTask( "Based on https://griptape.ai, tell me what griptape is.", - tools=[WebScraper(off_prompt=True), PromptSummaryClient(off_prompt=False)], + tools=[WebScraper(off_prompt=True), PromptSummaryTool(off_prompt=False)], ) ) diff --git a/docs/griptape-framework/src/index_4.py b/docs/griptape-framework/src/index_4.py index b1108d9e9..48d602b17 100644 --- a/docs/griptape-framework/src/index_4.py +++ b/docs/griptape-framework/src/index_4.py @@ -1,7 +1,7 @@ from griptape.memory.structure import ConversationMemory from griptape.structures import Pipeline from griptape.tasks import PromptTask, ToolkitTask -from griptape.tools import FileManager, PromptSummaryClient, WebScraper +from griptape.tools import FileManager, PromptSummaryTool, WebScraper # Pipelines represent sequences of tasks. pipeline = Pipeline(conversation_memory=ConversationMemory()) @@ -11,7 +11,7 @@ ToolkitTask( "{{ args[0] }}", # Add tools for web scraping, and file management - tools=[WebScraper(off_prompt=True), FileManager(off_prompt=True), PromptSummaryClient(off_prompt=False)], + tools=[WebScraper(off_prompt=True), FileManager(off_prompt=True), PromptSummaryTool(off_prompt=False)], ), # Augment `input` from the previous task. PromptTask("Say the following in spanish: {{ parent_output }}"), diff --git a/docs/griptape-framework/structures/src/task_memory_3.py b/docs/griptape-framework/structures/src/task_memory_3.py index 14649a222..c7d0617fd 100644 --- a/docs/griptape-framework/structures/src/task_memory_3.py +++ b/docs/griptape-framework/structures/src/task_memory_3.py @@ -1,7 +1,7 @@ from griptape.structures import Agent -from griptape.tools import Calculator, PromptSummaryClient +from griptape.tools import Calculator, PromptSummaryTool # Create an agent with the Calculator tool -agent = Agent(tools=[Calculator(off_prompt=True), PromptSummaryClient(off_prompt=False)]) +agent = Agent(tools=[Calculator(off_prompt=True), PromptSummaryTool(off_prompt=False)]) agent.run("What is the square root of 12345?") diff --git a/docs/griptape-framework/structures/src/task_memory_5.py b/docs/griptape-framework/structures/src/task_memory_5.py index 255e72397..a3052ab5a 100644 --- a/docs/griptape-framework/structures/src/task_memory_5.py +++ b/docs/griptape-framework/structures/src/task_memory_5.py @@ -1,10 +1,10 @@ from griptape.structures import Agent -from griptape.tools import QueryClient, WebScraper +from griptape.tools import QueryTool, WebScraper agent = Agent( tools=[ WebScraper(off_prompt=True), - QueryClient(off_prompt=False), + QueryTool(off_prompt=False), ] ) diff --git a/docs/griptape-framework/structures/src/task_memory_6.py b/docs/griptape-framework/structures/src/task_memory_6.py index 0e915ba87..c92f30261 100644 --- a/docs/griptape-framework/structures/src/task_memory_6.py +++ b/docs/griptape-framework/structures/src/task_memory_6.py @@ -11,7 +11,7 @@ from griptape.memory import TaskMemory from griptape.memory.task.storage import TextArtifactStorage from griptape.structures import Agent -from griptape.tools import FileManager, QueryClient, WebScraper +from griptape.tools import FileManager, QueryTool, WebScraper config.drivers = OpenAiDriverConfig( prompt=OpenAiChatPromptDriver(model="gpt-4"), @@ -33,7 +33,7 @@ ), tools=[ WebScraper(off_prompt=True), - QueryClient(off_prompt=True), + QueryTool(off_prompt=True), FileManager(off_prompt=True), ], ) diff --git a/docs/griptape-framework/structures/src/task_memory_8.py b/docs/griptape-framework/structures/src/task_memory_8.py index 4f19a235f..a6e66308c 100644 --- a/docs/griptape-framework/structures/src/task_memory_8.py +++ b/docs/griptape-framework/structures/src/task_memory_8.py @@ -1,10 +1,10 @@ from griptape.structures import Agent -from griptape.tools import PromptSummaryClient, WebScraper +from griptape.tools import PromptSummaryTool, WebScraper agent = Agent( tools=[ WebScraper(off_prompt=True), # This tool will store the data in Task Memory - PromptSummaryClient( + PromptSummaryTool( off_prompt=True ), # This tool will store the data back in Task Memory with no way to get it out ] diff --git a/docs/griptape-framework/structures/src/tasks_16.py b/docs/griptape-framework/structures/src/tasks_16.py index ac8d8d5b2..6c6ab87f5 100644 --- a/docs/griptape-framework/structures/src/tasks_16.py +++ b/docs/griptape-framework/structures/src/tasks_16.py @@ -5,7 +5,7 @@ from griptape.structures import Agent, Pipeline from griptape.tasks import StructureRunTask from griptape.tools import ( - PromptSummaryClient, + PromptSummaryTool, WebScraper, WebSearch, ) @@ -23,7 +23,7 @@ def build_researcher() -> Agent: WebScraper( off_prompt=True, ), - PromptSummaryClient(off_prompt=False), + PromptSummaryTool(off_prompt=False), ], rulesets=[ Ruleset( diff --git a/docs/griptape-framework/structures/src/tasks_4.py b/docs/griptape-framework/structures/src/tasks_4.py index 747a82c26..02d21016a 100644 --- a/docs/griptape-framework/structures/src/tasks_4.py +++ b/docs/griptape-framework/structures/src/tasks_4.py @@ -1,12 +1,12 @@ from griptape.structures import Agent from griptape.tasks import ToolkitTask -from griptape.tools import FileManager, PromptSummaryClient, WebScraper +from griptape.tools import FileManager, PromptSummaryTool, WebScraper agent = Agent() agent.add_task( ToolkitTask( "Load https://www.griptape.ai, summarize it, and store it in a file called griptape.txt", - tools=[WebScraper(off_prompt=True), FileManager(off_prompt=True), PromptSummaryClient(off_prompt=True)], + tools=[WebScraper(off_prompt=True), FileManager(off_prompt=True), PromptSummaryTool(off_prompt=True)], ), ) diff --git a/docs/griptape-framework/structures/task-memory.md b/docs/griptape-framework/structures/task-memory.md index b58b99e94..094f20948 100644 --- a/docs/griptape-framework/structures/task-memory.md +++ b/docs/griptape-framework/structures/task-memory.md @@ -143,8 +143,8 @@ When running this example, we get the following error: Please reduce the length of the messages.", 'type': 'invalid_request_error', 'param': 'messages', 'code': 'context_length_exceeded'}} ``` -This is because the content of the webpage is too large to fit in the LLM's input token limit. We can fix this by storing the content in Task Memory, and then querying it with the `QueryClient`. -Note that we're setting `off_prompt` to `False` on the `QueryClient` so that the _queried_ content can be returned directly to the LLM. +This is because the content of the webpage is too large to fit in the LLM's input token limit. We can fix this by storing the content in Task Memory, and then querying it with the `QueryTool`. +Note that we're setting `off_prompt` to `False` on the `QueryTool` so that the _queried_ content can be returned directly to the LLM. ```python --8<-- "docs/griptape-framework/structures/src/task_memory_5.py" @@ -174,7 +174,7 @@ And now we get the expected output: Actions: [ { "tag": "call_DGsOHC4AVxhV7RPVA7q3rATX", - "name": "QueryClient", + "name": "QueryTool", "path": "search", "input": { "values": { @@ -229,7 +229,7 @@ In this example, GPT-4 _never_ sees the contents of the page, only that it was s Actions: [ { "tag": "call_Oiqq6oI20yqmdNrH9Mawb2fS", - "name": "QueryClient", + "name": "QueryTool", "path": "search", "input": { "values": { @@ -243,7 +243,7 @@ In this example, GPT-4 _never_ sees the contents of the page, only that it was s } ] [08/12/24 14:55:34] INFO Subtask d8b4cf297a0d4d9db04e4f8e63b746c8 - Response: Output of "QueryClient.search" was stored in memory with memory_name "TaskMemory" and artifact_namespace + Response: Output of "QueryTool.search" was stored in memory with memory_name "TaskMemory" and artifact_namespace "fd828ddd629e4974a7837f9dfde65954" [08/12/24 14:55:38] INFO Subtask 7aafcb3fb0d845858e2fcf9b8dc8a7ec Actions: [ @@ -274,7 +274,7 @@ As seen in the previous example, certain Tools are designed to read directly fro Today, these include: - [PromptSummaryClient](../../griptape-tools/official-tools/prompt-summary-client.md) -- [ExtractionClient](../../griptape-tools/official-tools/extraction-client.md) +- [ExtractionTool](../../griptape-tools/official-tools/extraction-client.md) - [RagClient](../../griptape-tools/official-tools/rag-client.md) - [FileManager](../../griptape-tools/official-tools/file-manager.md) diff --git a/docs/griptape-framework/tools/src/index_1.py b/docs/griptape-framework/tools/src/index_1.py index 0366a314e..1aa2ce0a4 100644 --- a/docs/griptape-framework/tools/src/index_1.py +++ b/docs/griptape-framework/tools/src/index_1.py @@ -1,6 +1,6 @@ from griptape.structures import Pipeline from griptape.tasks import ToolkitTask -from griptape.tools import FileManager, PromptSummaryClient, WebScraper +from griptape.tools import FileManager, PromptSummaryTool, WebScraper pipeline = Pipeline() @@ -10,7 +10,7 @@ tools=[ WebScraper(off_prompt=True), FileManager(off_prompt=True), - PromptSummaryClient(off_prompt=False), + PromptSummaryTool(off_prompt=False), ], ), ) diff --git a/docs/griptape-tools/official-tools/extraction-client.md b/docs/griptape-tools/official-tools/extraction-client.md index ad2617219..2812d2ad0 100644 --- a/docs/griptape-tools/official-tools/extraction-client.md +++ b/docs/griptape-tools/official-tools/extraction-client.md @@ -28,7 +28,7 @@ Here is an example of how it can be used with a local vector store driver: Actions: [ { "tag": "call_o3MrpM01OnhCfpxsMe85tpDF", - "name": "ExtractionClient", + "name": "ExtractionTool", "path": "extract_json", "input": { "values": { diff --git a/docs/griptape-tools/official-tools/image-query-client.md b/docs/griptape-tools/official-tools/image-query-client.md index ae02d127c..a633f975f 100644 --- a/docs/griptape-tools/official-tools/image-query-client.md +++ b/docs/griptape-tools/official-tools/image-query-client.md @@ -1,4 +1,4 @@ -# ImageQueryClient +# ImageQueryTool This tool allows Agents to execute natural language queries on the contents of images using multimodal models. diff --git a/docs/griptape-tools/official-tools/query-client.md b/docs/griptape-tools/official-tools/query-client.md index 53b5b2f5a..d85174ec4 100644 --- a/docs/griptape-tools/official-tools/query-client.md +++ b/docs/griptape-tools/official-tools/query-client.md @@ -28,7 +28,7 @@ Here is an example of how it can be used with a local vector store driver: Actions: [ { "tag": "call_GtBICZi6oIeL85Aj7q5szul9", - "name": "QueryClient", + "name": "QueryTool", "path": "query", "input": { "values": { diff --git a/docs/griptape-tools/official-tools/src/extraction_client_1.py b/docs/griptape-tools/official-tools/src/extraction_client_1.py index e39cdbd79..d2eccb213 100644 --- a/docs/griptape-tools/official-tools/src/extraction_client_1.py +++ b/docs/griptape-tools/official-tools/src/extraction_client_1.py @@ -2,13 +2,13 @@ from griptape.engines import JsonExtractionEngine from griptape.structures import Agent -from griptape.tools import ExtractionClient, WebScraper +from griptape.tools import ExtractionTool, WebScraper agent = Agent( input="Load {{ args[0] }} and extract key info", tools=[ WebScraper(off_prompt=True), - ExtractionClient( + ExtractionTool( off_prompt=False, extraction_engine=JsonExtractionEngine( template_schema=schema.Schema( diff --git a/docs/griptape-tools/official-tools/src/image_query_client_1.py b/docs/griptape-tools/official-tools/src/image_query_client_1.py index 177154d2d..032a6a4c2 100644 --- a/docs/griptape-tools/official-tools/src/image_query_client_1.py +++ b/docs/griptape-tools/official-tools/src/image_query_client_1.py @@ -1,7 +1,7 @@ from griptape.drivers import OpenAiImageQueryDriver from griptape.engines import ImageQueryEngine from griptape.structures import Agent -from griptape.tools import ImageQueryClient +from griptape.tools import ImageQueryTool # Create an Image Query Driver. driver = OpenAiImageQueryDriver(model="gpt-4o") @@ -12,7 +12,7 @@ ) # Create an Image Query Client configured to use the engine. -tool = ImageQueryClient( +tool = ImageQueryTool( image_query_engine=engine, ) diff --git a/docs/griptape-tools/official-tools/src/prompt_summary_client_1.py b/docs/griptape-tools/official-tools/src/prompt_summary_client_1.py index 6e57b4b00..e4740c944 100644 --- a/docs/griptape-tools/official-tools/src/prompt_summary_client_1.py +++ b/docs/griptape-tools/official-tools/src/prompt_summary_client_1.py @@ -1,7 +1,7 @@ from griptape.structures import Agent -from griptape.tools import PromptSummaryClient, WebScraper +from griptape.tools import PromptSummaryTool, WebScraper -agent = Agent(tools=[WebScraper(off_prompt=True), PromptSummaryClient()]) +agent = Agent(tools=[WebScraper(off_prompt=True), PromptSummaryTool()]) agent.run( "How can I build Neovim from source for MacOS according to this https://github.com/neovim/neovim/blob/master/BUILD.md" diff --git a/docs/griptape-tools/official-tools/src/query_client_1.py b/docs/griptape-tools/official-tools/src/query_client_1.py index c917c9e04..a6df7940d 100644 --- a/docs/griptape-tools/official-tools/src/query_client_1.py +++ b/docs/griptape-tools/official-tools/src/query_client_1.py @@ -1,6 +1,6 @@ from griptape.structures import Agent -from griptape.tools import QueryClient, WebScraper +from griptape.tools import QueryTool, WebScraper -agent = Agent(tools=[WebScraper(off_prompt=True), QueryClient()]) +agent = Agent(tools=[WebScraper(off_prompt=True), QueryTool()]) agent.run("Tell me about the architecture as described here: https://neovim.io/doc/user/vim_diff.html") diff --git a/docs/griptape-tools/official-tools/src/vector_store_client_1.py b/docs/griptape-tools/official-tools/src/vector_store_client_1.py index c4e0c5bd9..adc79ad12 100644 --- a/docs/griptape-tools/official-tools/src/vector_store_client_1.py +++ b/docs/griptape-tools/official-tools/src/vector_store_client_1.py @@ -2,7 +2,7 @@ from griptape.drivers import LocalVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader from griptape.structures import Agent -from griptape.tools import PromptSummaryClient, VectorStoreClient +from griptape.tools import PromptSummaryTool, VectorStoreClient vector_store_driver = LocalVectorStoreDriver( embedding_driver=OpenAiEmbeddingDriver(), @@ -20,6 +20,6 @@ off_prompt=True, ) -agent = Agent(tools=[vector_db, PromptSummaryClient()]) +agent = Agent(tools=[vector_db, PromptSummaryTool()]) agent.run("what is Griptape?") diff --git a/docs/griptape-tools/official-tools/src/web_scraper_1.py b/docs/griptape-tools/official-tools/src/web_scraper_1.py index f858d558d..14995d519 100644 --- a/docs/griptape-tools/official-tools/src/web_scraper_1.py +++ b/docs/griptape-tools/official-tools/src/web_scraper_1.py @@ -1,6 +1,6 @@ from griptape.structures import Agent -from griptape.tools import PromptSummaryClient, WebScraper +from griptape.tools import PromptSummaryTool, WebScraper -agent = Agent(tools=[WebScraper(off_prompt=True), PromptSummaryClient(off_prompt=False)]) +agent = Agent(tools=[WebScraper(off_prompt=True), PromptSummaryTool(off_prompt=False)]) agent.run("Based on https://www.griptape.ai/, tell me what griptape is") diff --git a/griptape/tools/__init__.py b/griptape/tools/__init__.py index c87557225..014f957fe 100644 --- a/griptape/tools/__init__.py +++ b/griptape/tools/__init__.py @@ -25,13 +25,13 @@ from .outpainting_image_generation_client.tool import OutpaintingImageGenerationClient from .griptape_cloud_knowledge_base_client.tool import GriptapeCloudKnowledgeBaseClient from .structure_run_client.tool import StructureRunClient -from .image_query_client.tool import ImageQueryClient +from .image_query_client.tool import ImageQueryTool from .rag_client.tool import RagClient from .text_to_speech_client.tool import TextToSpeechClient from .audio_transcription_client.tool import AudioTranscriptionClient -from .extraction_client.tool import ExtractionClient -from .prompt_summary_client.tool import PromptSummaryClient -from .query_client.tool import QueryClient +from .extraction.tool import ExtractionTool +from .prompt_summary.tool import PromptSummaryTool +from .query.tool import QueryTool __all__ = [ "BaseTool", @@ -61,11 +61,11 @@ "OutpaintingImageGenerationClient", "GriptapeCloudKnowledgeBaseClient", "StructureRunClient", - "ImageQueryClient", + "ImageQueryTool", "RagClient", "TextToSpeechClient", "AudioTranscriptionClient", - "ExtractionClient", - "PromptSummaryClient", - "QueryClient", + "ExtractionTool", + "PromptSummaryTool", + "QueryTool", ] diff --git a/griptape/tools/extraction_client/__init__.py b/griptape/tools/extraction/__init__.py similarity index 100% rename from griptape/tools/extraction_client/__init__.py rename to griptape/tools/extraction/__init__.py diff --git a/griptape/tools/extraction_client/manifest.yml b/griptape/tools/extraction/manifest.yml similarity index 100% rename from griptape/tools/extraction_client/manifest.yml rename to griptape/tools/extraction/manifest.yml diff --git a/griptape/tools/extraction_client/requirements.txt b/griptape/tools/extraction/requirements.txt similarity index 100% rename from griptape/tools/extraction_client/requirements.txt rename to griptape/tools/extraction/requirements.txt diff --git a/griptape/tools/extraction_client/tool.py b/griptape/tools/extraction/tool.py similarity index 98% rename from griptape/tools/extraction_client/tool.py rename to griptape/tools/extraction/tool.py index ce8c5b034..fca38cc5b 100644 --- a/griptape/tools/extraction_client/tool.py +++ b/griptape/tools/extraction/tool.py @@ -17,7 +17,7 @@ @define(kw_only=True) -class ExtractionClient(BaseTool, RuleMixin): +class ExtractionTool(BaseTool, RuleMixin): """Tool for using an Extraction Engine. Attributes: diff --git a/griptape/tools/image_query_client/tool.py b/griptape/tools/image_query_client/tool.py index a10929b13..97772d546 100644 --- a/griptape/tools/image_query_client/tool.py +++ b/griptape/tools/image_query_client/tool.py @@ -17,7 +17,7 @@ @define -class ImageQueryClient(BaseTool): +class ImageQueryTool(BaseTool): image_query_engine: ImageQueryEngine = field(kw_only=True) image_loader: ImageLoader = field(default=Factory(lambda: ImageLoader()), kw_only=True) diff --git a/griptape/tools/prompt_summary_client/__init__.py b/griptape/tools/prompt_summary/__init__.py similarity index 100% rename from griptape/tools/prompt_summary_client/__init__.py rename to griptape/tools/prompt_summary/__init__.py diff --git a/griptape/tools/prompt_summary_client/manifest.yml b/griptape/tools/prompt_summary/manifest.yml similarity index 100% rename from griptape/tools/prompt_summary_client/manifest.yml rename to griptape/tools/prompt_summary/manifest.yml diff --git a/griptape/tools/prompt_summary_client/requirements.txt b/griptape/tools/prompt_summary/requirements.txt similarity index 100% rename from griptape/tools/prompt_summary_client/requirements.txt rename to griptape/tools/prompt_summary/requirements.txt diff --git a/griptape/tools/prompt_summary_client/tool.py b/griptape/tools/prompt_summary/tool.py similarity index 97% rename from griptape/tools/prompt_summary_client/tool.py rename to griptape/tools/prompt_summary/tool.py index f71b8f325..88ec9d855 100644 --- a/griptape/tools/prompt_summary_client/tool.py +++ b/griptape/tools/prompt_summary/tool.py @@ -11,7 +11,7 @@ @define(kw_only=True) -class PromptSummaryClient(BaseTool, RuleMixin): +class PromptSummaryTool(BaseTool, RuleMixin): """Tool for using a Prompt Summary Engine. Attributes: diff --git a/griptape/tools/query_client/__init__.py b/griptape/tools/query/__init__.py similarity index 100% rename from griptape/tools/query_client/__init__.py rename to griptape/tools/query/__init__.py diff --git a/griptape/tools/query_client/manifest.yml b/griptape/tools/query/manifest.yml similarity index 100% rename from griptape/tools/query_client/manifest.yml rename to griptape/tools/query/manifest.yml diff --git a/griptape/tools/query_client/requirements.txt b/griptape/tools/query/requirements.txt similarity index 100% rename from griptape/tools/query_client/requirements.txt rename to griptape/tools/query/requirements.txt diff --git a/griptape/tools/query_client/tool.py b/griptape/tools/query/tool.py similarity index 98% rename from griptape/tools/query_client/tool.py rename to griptape/tools/query/tool.py index 44ff65387..00d5d1039 100644 --- a/griptape/tools/query_client/tool.py +++ b/griptape/tools/query/tool.py @@ -17,7 +17,7 @@ @define(kw_only=True) -class QueryClient(RagClient, RuleMixin): +class QueryTool(RagClient, RuleMixin): """Tool for performing a query against data.""" description: str = field(init=False) diff --git a/mkdocs.yml b/mkdocs.yml index c769d11ec..dd41eadc4 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -151,13 +151,13 @@ nav: - VariationImageGenerationClient: "griptape-tools/official-tools/variation-image-generation-client.md" - InpaintingImageGenerationClient: "griptape-tools/official-tools/inpainting-image-generation-client.md" - OutpaintingImageGenerationClient: "griptape-tools/official-tools/outpainting-image-generation-client.md" - - ImageQueryClient: "griptape-tools/official-tools/image-query-client.md" + - ImageQueryTool: "griptape-tools/official-tools/image-query-client.md" - TextToSpeechClient: "griptape-tools/official-tools/text-to-speech-client.md" - AudioTranscriptionClient: "griptape-tools/official-tools/audio-transcription-client.md" - GriptapeCloudKnowledgeBaseClient: "griptape-tools/official-tools/griptape-cloud-knowledge-base-client.md" - RagClient: "griptape-tools/official-tools/rag-client.md" - - ExtractionClient: "griptape-tools/official-tools/extraction-client.md" - - QueryClient: "griptape-tools/official-tools/query-client.md" + - ExtractionTool: "griptape-tools/official-tools/extraction-client.md" + - QueryTool: "griptape-tools/official-tools/query-client.md" - PromptSummaryClient: "griptape-tools/official-tools/prompt-summary-client.md" - Custom Tools: - Building Custom Tools: "griptape-tools/custom-tools/index.md" diff --git a/tests/integration/tasks/test_toolkit_task.py b/tests/integration/tasks/test_toolkit_task.py index 4b61cb9e5..20db7b27a 100644 --- a/tests/integration/tasks/test_toolkit_task.py +++ b/tests/integration/tasks/test_toolkit_task.py @@ -14,7 +14,7 @@ def structure_tester(self, request): from griptape.drivers import GoogleWebSearchDriver from griptape.structures import Agent - from griptape.tools import PromptSummaryClient, WebScraper, WebSearch + from griptape.tools import PromptSummaryTool, WebScraper, WebSearch return StructureTester( Agent( @@ -25,7 +25,7 @@ def structure_tester(self, request): ) ), WebScraper(off_prompt=True), - PromptSummaryClient(off_prompt=False), + PromptSummaryTool(off_prompt=False), ], conversation_memory=None, prompt_driver=request.param, diff --git a/tests/unit/tools/test_extraction_client.py b/tests/unit/tools/test_extraction_client.py index 6285d5b09..33598971c 100644 --- a/tests/unit/tools/test_extraction_client.py +++ b/tests/unit/tools/test_extraction_client.py @@ -4,15 +4,15 @@ from griptape.artifacts import TextArtifact from griptape.engines import CsvExtractionEngine, JsonExtractionEngine -from griptape.tools import ExtractionClient +from griptape.tools import ExtractionTool from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.utils import defaults -class TestExtractionClient: +class TestExtractionTool: @pytest.fixture() def json_tool(self): - return ExtractionClient( + return ExtractionTool( input_memory=[defaults.text_task_memory("TestMemory")], extraction_engine=JsonExtractionEngine( prompt_driver=MockPromptDriver( @@ -24,7 +24,7 @@ def json_tool(self): @pytest.fixture() def csv_tool(self): - return ExtractionClient( + return ExtractionTool( input_memory=[defaults.text_task_memory("TestMemory")], extraction_engine=CsvExtractionEngine( prompt_driver=MockPromptDriver(), diff --git a/tests/unit/tools/test_prompt_summary_client.py b/tests/unit/tools/test_prompt_summary_client.py index a31f217dd..0053e0645 100644 --- a/tests/unit/tools/test_prompt_summary_client.py +++ b/tests/unit/tools/test_prompt_summary_client.py @@ -2,7 +2,7 @@ from griptape.artifacts import TextArtifact from griptape.engines import PromptSummaryEngine -from griptape.tools import PromptSummaryClient +from griptape.tools import PromptSummaryTool from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.utils import defaults @@ -10,7 +10,7 @@ class TestPromptSummaryClient: @pytest.fixture() def tool(self): - return PromptSummaryClient( + return PromptSummaryTool( input_memory=[defaults.text_task_memory("TestMemory")], prompt_summary_engine=PromptSummaryEngine(prompt_driver=MockPromptDriver()), ) From e64e1c669fb9c39cc93ce254f2fb2454f67d758e Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 13 Aug 2024 14:07:08 -0700 Subject: [PATCH 51/63] Rename/structure QueryTool --- griptape/tools/query/tool.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/griptape/tools/query/tool.py b/griptape/tools/query/tool.py index 00d5d1039..5b4e29c29 100644 --- a/griptape/tools/query/tool.py +++ b/griptape/tools/query/tool.py @@ -12,16 +12,15 @@ from griptape.engines.rag.rag_context import RagContext from griptape.engines.rag.stages import ResponseRagStage from griptape.mixins.rule_mixin import RuleMixin -from griptape.tools.rag_client.tool import RagClient +from griptape.tools.base_tool import BaseTool from griptape.utils.decorators import activity @define(kw_only=True) -class QueryTool(RagClient, RuleMixin): +class QueryTool(BaseTool, RuleMixin): """Tool for performing a query against data.""" - description: str = field(init=False) - rag_engine: RagEngine = field( + _rag_engine: RagEngine = field( default=Factory( lambda self: RagEngine( response_stage=ResponseRagStage( @@ -32,6 +31,7 @@ class QueryTool(RagClient, RuleMixin): ), takes_self=True, ), + alias="_rag_engine", ) @activity( @@ -70,7 +70,7 @@ def query(self, params: dict) -> BaseArtifact: text_artifacts = [artifact for artifact in artifacts if isinstance(artifact, TextArtifact)] - outputs = self.rag_engine.process(RagContext(query=query, text_chunks=text_artifacts)).outputs + outputs = self._rag_engine.process(RagContext(query=query, text_chunks=text_artifacts)).outputs if len(outputs) > 0: return ListArtifact(outputs) From 0574f0d95c8801ea08dac65684e76090091d968c Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 13 Aug 2024 10:46:40 -0700 Subject: [PATCH 52/63] Rename all Tools for consistency --- CHANGELOG.md | 3 + README.md | 8 +- .../src/load_query_and_chat_marqo_1.py | 4 +- docs/examples/src/multi_agent_workflow_1.py | 12 +- .../src/multiple_agent_shared_memory_1.py | 6 +- docs/examples/src/query_webpage_astra_db_1.py | 6 +- docs/examples/src/talk_to_a_pdf_1.py | 6 +- docs/examples/src/talk_to_a_webpage_1.py | 6 +- docs/examples/src/talk_to_redshift_1.py | 6 +- docs/examples/talk-to-a-pdf.md | 2 +- docs/examples/talk-to-a-webpage.md | 2 +- .../src/audio_transcription_drivers_1.py | 4 +- .../drivers/src/embedding_drivers_10.py | 4 +- .../drivers/src/image_generation_drivers_1.py | 4 +- .../drivers/src/image_generation_drivers_2.py | 4 +- .../drivers/src/image_generation_drivers_3.py | 4 +- .../drivers/src/image_generation_drivers_4.py | 4 +- .../drivers/src/image_generation_drivers_5.py | 4 +- .../drivers/src/image_generation_drivers_6.py | 4 +- .../drivers/src/prompt_drivers_10.py | 4 +- .../drivers/src/text_to_speech_drivers_1.py | 4 +- .../drivers/src/text_to_speech_drivers_2.py | 4 +- .../drivers/src/web_scraper_drivers_3.py | 6 +- .../drivers/src/web_search_drivers_2.py | 6 +- .../drivers/structure-run-drivers.md | 2 +- docs/griptape-framework/index.md | 4 +- docs/griptape-framework/misc/src/events_3.py | 4 +- docs/griptape-framework/misc/src/events_4.py | 4 +- docs/griptape-framework/src/index_3.py | 4 +- docs/griptape-framework/src/index_4.py | 4 +- docs/griptape-framework/structures/agents.md | 2 +- .../structures/src/agents_1.py | 4 +- .../structures/src/task_memory_1.py | 6 +- .../structures/src/task_memory_2.py | 6 +- .../structures/src/task_memory_3.py | 6 +- .../structures/src/task_memory_4.py | 6 +- .../structures/src/task_memory_5.py | 6 +- .../structures/src/task_memory_6.py | 8 +- .../structures/src/task_memory_7.py | 4 +- .../structures/src/task_memory_8.py | 8 +- .../structures/src/task_memory_9.py | 4 +- .../structures/src/tasks_16.py | 12 +- .../structures/src/tasks_4.py | 4 +- .../structures/src/tasks_5.py | 4 +- .../structures/task-memory.md | 60 ++++----- docs/griptape-framework/structures/tasks.md | 4 +- docs/griptape-framework/tools/index.md | 2 +- docs/griptape-framework/tools/src/index_1.py | 4 +- ...-client.md => audio-transcription-tool.md} | 4 +- .../{aws-iam-client.md => aws-iam-tool.md} | 12 +- .../{aws-s3-client.md => aws-s3-tool.md} | 12 +- .../{calculator.md => calculator-tool.md} | 8 +- .../{computer.md => computer-tool.md} | 6 +- .../{date-time.md => date-time-tool.md} | 4 +- .../official-tools/email-client.md | 13 -- .../official-tools/email-tool.md | 13 ++ .../{file-manager.md => file-manager-tool.md} | 4 +- .../official-tools/google-cal-client.md | 8 -- .../official-tools/google-calendar-tool.md | 8 ++ ...gle-docs-client.md => google-docs-tool.md} | 10 +- ...e-drive-client.md => google-drive-tool.md} | 10 +- ...e-gmail-client.md => google-gmail-tool.md} | 10 +- .../griptape-cloud-knowledge-base-client.md | 9 -- .../griptape-cloud-knowledge-base-tool.md | 9 ++ ...ge-query-client.md => image-query-tool.md} | 4 +- ...md => inpainting-image-generation-tool.md} | 4 +- .../official-tools/openweather-client.md | 7 - .../official-tools/openweather-tool.md | 7 + ...d => outpainting-image-generation-tool.md} | 4 +- ...ent.md => prompt-image-generation-tool.md} | 4 +- .../{rag-client.md => rag-tool.md} | 8 +- .../{rest-api-client.md => rest-api-tool.md} | 6 +- .../{sql-client.md => sql-tool.md} | 12 +- ...ent_1.py => audio_transcription_tool_1.py} | 4 +- ...{aws_iam_client_1.py => aws_iam_tool_1.py} | 4 +- .../{aws_s3_client_1.py => aws_s3_tool_1.py} | 6 +- .../{calculator_1.py => calculator_tool_1.py} | 6 +- .../src/{computer_1.py => computer_tool_1.py} | 8 +- .../official-tools/src/date_time_1.py | 8 -- .../official-tools/src/date_time_tool_1.py | 8 ++ .../{email_client_1.py => email_tool_1.py} | 4 +- ...le_manager_1.py => file_manager_tool_1.py} | 6 +- ..._client_1.py => google_calendar_tool_1.py} | 10 +- ...docs_client_1.py => google_docs_tool_1.py} | 8 +- ...ive_client_1.py => google_drive_tool_1.py} | 8 +- ...ail_client_1.py => google_gmail_tool_1.py} | 8 +- ...> griptape_cloud_knowledge_base_tool_1.py} | 4 +- ...uery_client_1.py => image_query_tool_1.py} | 6 +- ... => inpainting_image_generation_tool_1.py} | 4 +- ...ther_client_1.py => openweather_tool_1.py} | 4 +- ...=> outpainting_image_generation_tool_1.py} | 4 +- ...1.py => prompt_image_generation_tool_1.py} | 4 +- .../src/{rag_client_1.py => rag_tool_1.py} | 6 +- ...est_api_client_1.py => rest_api_tool_1.py} | 4 +- .../src/{sql_client_1.py => sql_tool_1.py} | 4 +- ...un_client_1.py => structure_run_tool_1.py} | 6 +- .../src/task_memory_client_1.py | 4 - .../official-tools/src/task_memory_tool_1.py | 4 + ...h_client_1.py => text_to_speech_tool_1.py} | 4 +- ...y => variation_image_generation_tool_1.py} | 4 +- ...y => variation_image_generation_tool_2.py} | 6 +- ...ore_client_1.py => vector_store_tool_1.py} | 6 +- .../official-tools/src/web_scraper_1.py | 6 - .../official-tools/src/web_scraper_tool_1.py | 6 + .../{web_search_1.py => web_search_tool_1.py} | 8 +- .../{web_search_2.py => web_search_tool_2.py} | 4 +- ...re-run-client.md => structure-run-tool.md} | 10 +- ...k-memory-client.md => task-memory-tool.md} | 4 +- ...peech-client.md => text-to-speech-tool.md} | 4 +- ....md => variation-image-generation-tool.md} | 6 +- .../official-tools/vector-store-client.md | 7 - .../official-tools/vector-store-tool.md | 9 ++ .../{web-scraper.md => web-scraper-tool.md} | 6 +- .../{web-search.md => web-search-tool.md} | 6 +- griptape/tools/__init__.py | 124 +++++++++--------- .../__init__.py | 0 .../manifest.yml | 2 +- .../tool.py | 2 +- .../{aws_iam_client => aws_iam}/__init__.py | 0 .../{aws_iam_client => aws_iam}/manifest.yml | 2 +- .../tools/{aws_iam_client => aws_iam}/tool.py | 4 +- .../{aws_s3_client => aws_s3}/__init__.py | 0 .../{aws_s3_client => aws_s3}/manifest.yml | 2 +- .../tools/{aws_s3_client => aws_s3}/tool.py | 4 +- .../{base_aws_client.py => base_aws_tool.py} | 2 +- ...e_google_client.py => base_google_tool.py} | 2 +- ..._client.py => base_griptape_cloud_tool.py} | 2 +- ...lient.py => base_image_generation_tool.py} | 2 +- griptape/tools/calculator/manifest.yml | 4 +- griptape/tools/calculator/tool.py | 2 +- griptape/tools/computer/manifest.yml | 4 +- griptape/tools/computer/tool.py | 2 +- griptape/tools/date_time/manifest.yml | 4 +- griptape/tools/date_time/tool.py | 2 +- .../tools/{email_client => email}/__init__.py | 0 .../{email_client => email}/manifest.yml | 2 +- .../tools/{email_client => email}/tool.py | 2 +- griptape/tools/file_manager/manifest.yml | 4 +- griptape/tools/file_manager/tool.py | 4 +- .../__init__.py | 0 .../manifest.yml | 0 .../requirements.txt | 0 .../{google_cal => google_calendar}/tool.py | 4 +- griptape/tools/google_docs/tool.py | 4 +- griptape/tools/google_drive/tool.py | 4 +- griptape/tools/google_gmail/manifest.yml | 2 +- griptape/tools/google_gmail/tool.py | 4 +- .../__init__.py | 0 .../manifest.yml | 2 +- .../tool.py | 6 +- .../__init__.py | 0 .../manifest.yml | 2 +- .../tool.py | 2 +- .../__init__.py | 0 .../manifest.yml | 2 +- .../requirements.txt | 0 .../tool.py | 12 +- .../__init__.py | 0 .../manifest.yml | 2 +- .../tool.py | 2 +- .../__init__.py | 0 .../manifest.yml | 2 +- .../requirements.txt | 0 .../tool.py | 12 +- .../__init__.py | 0 .../manifest.yml | 2 +- .../requirements.txt | 0 .../tool.py | 8 +- .../tools/{rag_client => rag}/__init__.py | 0 .../tools/{rag_client => rag}/manifest.yml | 2 +- .../{rag_client => rag}/requirements.txt | 0 griptape/tools/{rag_client => rag}/tool.py | 2 +- .../{rest_api_client => rest_api}/__init__.py | 0 .../manifest.yml | 2 +- .../{rest_api_client => rest_api}/tool.py | 2 +- .../tools/{sql_client => sql}/__init__.py | 0 .../tools/{sql_client => sql}/manifest.yml | 2 +- griptape/tools/{sql_client => sql}/tool.py | 2 +- .../__init__.py | 0 .../manifest.yml | 2 +- .../tool.py | 2 +- .../__init__.py | 0 .../manifest.yml | 2 +- .../tool.py | 2 +- .../__init__.py | 0 .../manifest.yml | 2 +- .../tool.py | 2 +- .../__init__.py | 0 .../manifest.yml | 2 +- .../requirements.txt | 0 .../tool.py | 12 +- .../__init__.py | 0 .../manifest.yml | 2 +- .../requirements.txt | 0 .../tool.py | 2 +- griptape/tools/web_scraper/manifest.yml | 4 +- griptape/tools/web_scraper/tool.py | 2 +- griptape/tools/web_search/manifest.yml | 6 +- griptape/tools/web_search/tool.py | 2 +- mkdocs.yml | 56 ++++---- tests/integration/tasks/test_tool_task.py | 4 +- tests/integration/tasks/test_toolkit_task.py | 8 +- tests/integration/test_code_blocks.py | 2 +- ..._calculator.py => test_calculator_tool.py} | 4 +- ...e_manager.py => test_file_manager_tool.py} | 4 +- ...ocs_client.py => test_google_docs_tool.py} | 6 +- ...ve_client.py => test_google_drive_tool.py} | 6 +- .../{test_aws_iam.py => test_aws_iam_tool.py} | 12 +- .../{test_aws_s3.py => test_aws_s3_tool.py} | 25 ++-- tests/unit/tools/test_calculator.py | 4 +- tests/unit/tools/test_computer.py | 4 +- tests/unit/tools/test_date_time.py | 12 +- ...est_email_client.py => test_email_tool.py} | 12 +- tests/unit/tools/test_file_manager.py | 22 ++-- ...ocs_client.py => test_google_docs_tool.py} | 6 +- ...ve_client.py => test_google_drive_tool.py} | 18 +-- ...il_client.py => test_google_gmail_tool.py} | 6 +- ...est_griptape_cloud_knowledge_base_tool.py} | 20 +-- ... test_inpainting_image_generation_tool.py} | 14 +- ...her_client.py => test_openweather_tool.py} | 8 +- ... test_outpainting_image_variation_tool.py} | 14 +- ...y => test_prompt_image_generation_tool.py} | 12 +- .../{test_rag_client.py => test_rag_tool.py} | 6 +- ...st_api_client.py => test_rest_api_tool.py} | 4 +- .../{test_sql_client.py => test_sql_tool.py} | 8 +- ...n_client.py => test_structure_run_tool.py} | 6 +- ...ory_client.py => test_task_memory_tool.py} | 6 +- ..._client.py => test_text_to_speech_tool.py} | 12 +- ...n_client.py => test_transcription_tool.py} | 10 +- ...> test_variation_image_generation_tool.py} | 14 +- ...re_client.py => test_vector_store_tool.py} | 12 +- tests/unit/tools/test_web_scraper.py | 4 +- tests/unit/tools/test_web_search.py | 6 +- 233 files changed, 700 insertions(+), 696 deletions(-) rename docs/griptape-tools/official-tools/{audio-transcription-client.md => audio-transcription-tool.md} (90%) rename docs/griptape-tools/official-tools/{aws-iam-client.md => aws-iam-tool.md} (89%) rename docs/griptape-tools/official-tools/{aws-s3-client.md => aws-s3-tool.md} (90%) rename docs/griptape-tools/official-tools/{calculator.md => calculator-tool.md} (82%) rename docs/griptape-tools/official-tools/{computer.md => computer-tool.md} (97%) rename docs/griptape-tools/official-tools/{date-time.md => date-time-tool.md} (93%) delete mode 100644 docs/griptape-tools/official-tools/email-client.md create mode 100644 docs/griptape-tools/official-tools/email-tool.md rename docs/griptape-tools/official-tools/{file-manager.md => file-manager-tool.md} (94%) delete mode 100644 docs/griptape-tools/official-tools/google-cal-client.md create mode 100644 docs/griptape-tools/official-tools/google-calendar-tool.md rename docs/griptape-tools/official-tools/{google-docs-client.md => google-docs-tool.md} (79%) rename docs/griptape-tools/official-tools/{google-drive-client.md => google-drive-tool.md} (77%) rename docs/griptape-tools/official-tools/{google-gmail-client.md => google-gmail-tool.md} (78%) delete mode 100644 docs/griptape-tools/official-tools/griptape-cloud-knowledge-base-client.md create mode 100644 docs/griptape-tools/official-tools/griptape-cloud-knowledge-base-tool.md rename docs/griptape-tools/official-tools/{image-query-client.md => image-query-tool.md} (58%) rename docs/griptape-tools/official-tools/{inpainting-image-generation-client.md => inpainting-image-generation-tool.md} (87%) delete mode 100644 docs/griptape-tools/official-tools/openweather-client.md create mode 100644 docs/griptape-tools/official-tools/openweather-tool.md rename docs/griptape-tools/official-tools/{outpainting-image-generation-client.md => outpainting-image-generation-tool.md} (86%) rename docs/griptape-tools/official-tools/{prompt-image-generation-client.md => prompt-image-generation-tool.md} (73%) rename docs/griptape-tools/official-tools/{rag-client.md => rag-tool.md} (86%) rename docs/griptape-tools/official-tools/{rest-api-client.md => rest-api-tool.md} (50%) rename docs/griptape-tools/official-tools/{sql-client.md => sql-tool.md} (90%) rename docs/griptape-tools/official-tools/src/{audio_transcription_client_1.py => audio_transcription_tool_1.py} (79%) rename docs/griptape-tools/official-tools/src/{aws_iam_client_1.py => aws_iam_tool_1.py} (72%) rename docs/griptape-tools/official-tools/src/{aws_s3_client_1.py => aws_s3_tool_1.py} (50%) rename docs/griptape-tools/official-tools/src/{calculator_1.py => calculator_tool_1.py} (56%) rename docs/griptape-tools/official-tools/src/{computer_1.py => computer_tool_1.py} (77%) delete mode 100644 docs/griptape-tools/official-tools/src/date_time_1.py create mode 100644 docs/griptape-tools/official-tools/src/date_time_tool_1.py rename docs/griptape-tools/official-tools/src/{email_client_1.py => email_tool_1.py} (79%) rename docs/griptape-tools/official-tools/src/{file_manager_1.py => file_manager_tool_1.py} (73%) rename docs/griptape-tools/official-tools/src/{google_cal_client_1.py => google_calendar_tool_1.py} (79%) rename docs/griptape-tools/official-tools/src/{google_docs_client_1.py => google_docs_tool_1.py} (85%) rename docs/griptape-tools/official-tools/src/{google_drive_client_1.py => google_drive_tool_1.py} (84%) rename docs/griptape-tools/official-tools/src/{google_gmail_client_1.py => google_gmail_tool_1.py} (85%) rename docs/griptape-tools/official-tools/src/{griptape_cloud_knowledge_base_client_1.py => griptape_cloud_knowledge_base_tool_1.py} (75%) rename docs/griptape-tools/official-tools/src/{image_query_client_1.py => image_query_tool_1.py} (80%) rename docs/griptape-tools/official-tools/src/{inpainting_image_generation_client_1.py => inpainting_image_generation_tool_1.py} (90%) rename docs/griptape-tools/official-tools/src/{openweather_client_1.py => openweather_tool_1.py} (75%) rename docs/griptape-tools/official-tools/src/{outpainting_image_generation_client_1.py => outpainting_image_generation_tool_1.py} (89%) rename docs/griptape-tools/official-tools/src/{prompt_image_generation_client_1.py => prompt_image_generation_tool_1.py} (89%) rename docs/griptape-tools/official-tools/src/{rag_client_1.py => rag_tool_1.py} (93%) rename docs/griptape-tools/official-tools/src/{rest_api_client_1.py => rest_api_tool_1.py} (98%) rename docs/griptape-tools/official-tools/src/{sql_client_1.py => sql_tool_1.py} (91%) rename docs/griptape-tools/official-tools/src/{structure_run_client_1.py => structure_run_tool_1.py} (83%) delete mode 100644 docs/griptape-tools/official-tools/src/task_memory_client_1.py create mode 100644 docs/griptape-tools/official-tools/src/task_memory_tool_1.py rename docs/griptape-tools/official-tools/src/{text_to_speech_client_1.py => text_to_speech_tool_1.py} (81%) rename docs/griptape-tools/official-tools/src/{variation_image_generation_client_1.py => variation_image_generation_tool_1.py} (90%) rename docs/griptape-tools/official-tools/src/{variation_image_generation_client_2.py => variation_image_generation_tool_2.py} (89%) rename docs/griptape-tools/official-tools/src/{vector_store_client_1.py => vector_store_tool_1.py} (81%) delete mode 100644 docs/griptape-tools/official-tools/src/web_scraper_1.py create mode 100644 docs/griptape-tools/official-tools/src/web_scraper_tool_1.py rename docs/griptape-tools/official-tools/src/{web_search_1.py => web_search_tool_1.py} (71%) rename docs/griptape-tools/official-tools/src/{web_search_2.py => web_search_tool_2.py} (91%) rename docs/griptape-tools/official-tools/{structure-run-client.md => structure-run-tool.md} (90%) rename docs/griptape-tools/official-tools/{task-memory-client.md => task-memory-tool.md} (74%) rename docs/griptape-tools/official-tools/{text-to-speech-client.md => text-to-speech-tool.md} (72%) rename docs/griptape-tools/official-tools/{variation-image-generation-client.md => variation-image-generation-tool.md} (83%) delete mode 100644 docs/griptape-tools/official-tools/vector-store-client.md create mode 100644 docs/griptape-tools/official-tools/vector-store-tool.md rename docs/griptape-tools/official-tools/{web-scraper.md => web-scraper-tool.md} (96%) rename docs/griptape-tools/official-tools/{web-search.md => web-search-tool.md} (97%) rename griptape/tools/{audio_transcription_client => audio_transcription}/__init__.py (100%) rename griptape/tools/{audio_transcription_client => audio_transcription}/manifest.yml (84%) rename griptape/tools/{audio_transcription_client => audio_transcription}/tool.py (98%) rename griptape/tools/{aws_iam_client => aws_iam}/__init__.py (100%) rename griptape/tools/{aws_iam_client => aws_iam}/manifest.yml (86%) rename griptape/tools/{aws_iam_client => aws_iam}/tool.py (97%) rename griptape/tools/{aws_s3_client => aws_s3}/__init__.py (100%) rename griptape/tools/{aws_s3_client => aws_s3}/manifest.yml (86%) rename griptape/tools/{aws_s3_client => aws_s3}/tool.py (99%) rename griptape/tools/{base_aws_client.py => base_aws_tool.py} (95%) rename griptape/tools/{base_google_client.py => base_google_tool.py} (98%) rename griptape/tools/{base_griptape_cloud_client.py => base_griptape_cloud_tool.py} (93%) rename griptape/tools/{base_image_generation_client.py => base_image_generation_tool.py} (88%) rename griptape/tools/{email_client => email}/__init__.py (100%) rename griptape/tools/{email_client => email}/manifest.yml (87%) rename griptape/tools/{email_client => email}/tool.py (99%) rename griptape/tools/{google_cal => google_calendar}/__init__.py (100%) rename griptape/tools/{google_cal => google_calendar}/manifest.yml (100%) rename griptape/tools/{google_cal => google_calendar}/requirements.txt (100%) rename griptape/tools/{google_cal => google_calendar}/tool.py (98%) rename griptape/tools/{griptape_cloud_knowledge_base_client => griptape_cloud_knowledge_base}/__init__.py (100%) rename griptape/tools/{griptape_cloud_knowledge_base_client => griptape_cloud_knowledge_base}/manifest.yml (78%) rename griptape/tools/{griptape_cloud_knowledge_base_client => griptape_cloud_knowledge_base}/tool.py (91%) rename griptape/tools/{image_query_client => image_query}/__init__.py (100%) rename griptape/tools/{image_query_client => image_query}/manifest.yml (86%) rename griptape/tools/{image_query_client => image_query}/tool.py (99%) rename griptape/tools/{inpainting_image_generation_client => inpainting_image_generation}/__init__.py (100%) rename griptape/tools/{inpainting_image_generation_client => inpainting_image_generation}/manifest.yml (79%) rename griptape/tools/{inpainting_image_generation_client => inpainting_image_generation}/requirements.txt (100%) rename griptape/tools/{inpainting_image_generation_client => inpainting_image_generation}/tool.py (93%) rename griptape/tools/{openweather_client => openweather}/__init__.py (100%) rename griptape/tools/{openweather_client => openweather}/manifest.yml (86%) rename griptape/tools/{openweather_client => openweather}/tool.py (99%) rename griptape/tools/{outpainting_image_generation_client => outpainting_image_generation}/__init__.py (100%) rename griptape/tools/{outpainting_image_generation_client => outpainting_image_generation}/manifest.yml (79%) rename griptape/tools/{outpainting_image_generation_client => outpainting_image_generation}/requirements.txt (100%) rename griptape/tools/{outpainting_image_generation_client => outpainting_image_generation}/tool.py (93%) rename griptape/tools/{prompt_image_generation_client => prompt_image_generation}/__init__.py (100%) rename griptape/tools/{prompt_image_generation_client => prompt_image_generation}/manifest.yml (80%) rename griptape/tools/{prompt_image_generation_client => prompt_image_generation}/requirements.txt (100%) rename griptape/tools/{prompt_image_generation_client => prompt_image_generation}/tool.py (87%) rename griptape/tools/{rag_client => rag}/__init__.py (100%) rename griptape/tools/{rag_client => rag}/manifest.yml (88%) rename griptape/tools/{rag_client => rag}/requirements.txt (100%) rename griptape/tools/{rag_client => rag}/tool.py (97%) rename griptape/tools/{rest_api_client => rest_api}/__init__.py (100%) rename griptape/tools/{rest_api_client => rest_api}/manifest.yml (87%) rename griptape/tools/{rest_api_client => rest_api}/tool.py (99%) rename griptape/tools/{sql_client => sql}/__init__.py (100%) rename griptape/tools/{sql_client => sql}/manifest.yml (88%) rename griptape/tools/{sql_client => sql}/tool.py (98%) rename griptape/tools/{structure_run_client => structure_run}/__init__.py (100%) rename griptape/tools/{structure_run_client => structure_run}/manifest.yml (83%) rename griptape/tools/{structure_run_client => structure_run}/tool.py (97%) rename griptape/tools/{task_memory_client => task_memory}/__init__.py (100%) rename griptape/tools/{task_memory_client => task_memory}/manifest.yml (85%) rename griptape/tools/{task_memory_client => task_memory}/tool.py (98%) rename griptape/tools/{text_to_speech_client => text_to_speech}/__init__.py (100%) rename griptape/tools/{text_to_speech_client => text_to_speech}/manifest.yml (83%) rename griptape/tools/{text_to_speech_client => text_to_speech}/tool.py (95%) rename griptape/tools/{variation_image_generation_client => variation_image_generation}/__init__.py (100%) rename griptape/tools/{variation_image_generation_client => variation_image_generation}/manifest.yml (79%) rename griptape/tools/{variation_image_generation_client => variation_image_generation}/requirements.txt (100%) rename griptape/tools/{variation_image_generation_client => variation_image_generation}/tool.py (91%) rename griptape/tools/{vector_store_client => vector_store}/__init__.py (100%) rename griptape/tools/{vector_store_client => vector_store}/manifest.yml (85%) rename griptape/tools/{vector_store_client => vector_store}/requirements.txt (100%) rename griptape/tools/{vector_store_client => vector_store}/tool.py (98%) rename tests/integration/tools/{test_calculator.py => test_calculator_tool.py} (73%) rename tests/integration/tools/{test_file_manager.py => test_file_manager_tool.py} (79%) rename tests/integration/tools/{test_google_docs_client.py => test_google_docs_tool.py} (95%) rename tests/integration/tools/{test_google_drive_client.py => test_google_drive_tool.py} (94%) rename tests/unit/tools/{test_aws_iam.py => test_aws_iam_tool.py} (56%) rename tests/unit/tools/{test_aws_s3.py => test_aws_s3_tool.py} (58%) rename tests/unit/tools/{test_email_client.py => test_email_tool.py} (93%) rename tests/unit/tools/{test_google_docs_client.py => test_google_docs_tool.py} (84%) rename tests/unit/tools/{test_google_drive_client.py => test_google_drive_tool.py} (68%) rename tests/unit/tools/{test_google_gmail_client.py => test_google_gmail_tool.py} (61%) rename tests/unit/tools/{test_griptape_cloud_knowledge_base_client.py => test_griptape_cloud_knowledge_base_tool.py} (83%) rename tests/unit/tools/{test_inpainting_image_generation_client.py => test_inpainting_image_generation_tool.py} (87%) rename tests/unit/tools/{test_openweather_client.py => test_openweather_tool.py} (92%) rename tests/unit/tools/{test_outpainting_image_variation_client.py => test_outpainting_image_variation_tool.py} (87%) rename tests/unit/tools/{test_prompt_image_generation_client.py => test_prompt_image_generation_tool.py} (77%) rename tests/unit/tools/{test_rag_client.py => test_rag_tool.py} (72%) rename tests/unit/tools/{test_rest_api_client.py => test_rest_api_tool.py} (92%) rename tests/unit/tools/{test_sql_client.py => test_sql_tool.py} (87%) rename tests/unit/tools/{test_structure_run_client.py => test_structure_run_tool.py} (82%) rename tests/unit/tools/{test_task_memory_client.py => test_task_memory_tool.py} (80%) rename tests/unit/tools/{test_text_to_speech_client.py => test_text_to_speech_tool.py} (74%) rename tests/unit/tools/{test_transcription_client.py => test_transcription_tool.py} (83%) rename tests/unit/tools/{test_variation_image_generation_client.py => test_variation_image_generation_tool.py} (88%) rename tests/unit/tools/{test_vector_store_client.py => test_vector_store_tool.py} (78%) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1b7b08f36..ef91010ae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: `RagContext.output` was changed to `RagContext.outputs` to support multiple outputs. All relevant RAG modules were adjusted accordingly. - **BREAKING**: Removed before and after response modules from `ResponseRagStage`. - **BREAKING**: Moved ruleset and metadata ingestion from standalone modules to `PromptResponseRagModule`. +- **BREAKING**: Dropped `Client` from all Tool names for better naming consistency. +- **BREAKING**: Dropped `_client` suffix from all Tool packages. +- **BREAKING**: Added `Tool` suffix to all Tool names for better naming consistency. - Engines that previously required Drivers now pull from `griptape.config.config.drivers` by default. - `BaseTask.add_parent/child` will now call `self.structure.add_task` if possible. diff --git a/README.md b/README.md index 42df8fbd6..1f6d6d883 100644 --- a/README.md +++ b/README.md @@ -89,14 +89,14 @@ With Griptape, you can create Structures, such as Agents, Pipelines, and Workflo ```python from griptape.structures import Agent -from griptape.tools import WebScraper, FileManager, TaskMemoryClient +from griptape.tools import WebScraperTool, FileManagerTool, TaskMemoryTool agent = Agent( input="Load {{ args[0] }}, summarize it, and store it in a file called {{ args[1] }}.", tools=[ - WebScraper(off_prompt=True), - TaskMemoryClient(off_prompt=True), - FileManager() + WebScraperTool(off_prompt=True), + TaskMemoryTool(off_prompt=True), + FileManagerTool() ] ) agent.run("https://griptape.ai", "griptape.txt") diff --git a/docs/examples/src/load_query_and_chat_marqo_1.py b/docs/examples/src/load_query_and_chat_marqo_1.py index ac3ffca0d..013a0264f 100644 --- a/docs/examples/src/load_query_and_chat_marqo_1.py +++ b/docs/examples/src/load_query_and_chat_marqo_1.py @@ -5,7 +5,7 @@ from griptape.drivers import MarqoVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader from griptape.structures import Agent -from griptape.tools import VectorStoreClient +from griptape.tools import VectorStoreTool # Define the namespace namespace = "griptape-ai" @@ -19,7 +19,7 @@ ) # Initialize the knowledge base tool -vector_store_tool = VectorStoreClient( +vector_store_tool = VectorStoreTool( description="Contains information about the Griptape Framework from www.griptape.ai", vector_store_driver=vector_store, ) diff --git a/docs/examples/src/multi_agent_workflow_1.py b/docs/examples/src/multi_agent_workflow_1.py index 311fa22a2..137a5075f 100644 --- a/docs/examples/src/multi_agent_workflow_1.py +++ b/docs/examples/src/multi_agent_workflow_1.py @@ -5,9 +5,9 @@ from griptape.structures import Agent, Workflow from griptape.tasks import PromptTask, StructureRunTask from griptape.tools import ( - TaskMemoryClient, - WebScraper, - WebSearch, + TaskMemoryTool, + WebScraperTool, + WebSearchTool, ) WRITERS = [ @@ -29,16 +29,16 @@ def build_researcher() -> Agent: researcher = Agent( id="researcher", tools=[ - WebSearch( + WebSearchTool( web_search_driver=GoogleWebSearchDriver( api_key=os.environ["GOOGLE_API_KEY"], search_id=os.environ["GOOGLE_API_SEARCH_ID"], ), ), - WebScraper( + WebScraperTool( off_prompt=True, ), - TaskMemoryClient(off_prompt=False), + TaskMemoryTool(off_prompt=False), ], rulesets=[ Ruleset( diff --git a/docs/examples/src/multiple_agent_shared_memory_1.py b/docs/examples/src/multiple_agent_shared_memory_1.py index 118684d37..e156e531a 100644 --- a/docs/examples/src/multiple_agent_shared_memory_1.py +++ b/docs/examples/src/multiple_agent_shared_memory_1.py @@ -3,7 +3,7 @@ from griptape.config import AzureOpenAiDriverConfig, config from griptape.drivers import AzureMongoDbVectorStoreDriver, AzureOpenAiEmbeddingDriver from griptape.structures import Agent -from griptape.tools import TaskMemoryClient, WebScraper +from griptape.tools import TaskMemoryTool, WebScraperTool AZURE_OPENAI_ENDPOINT_1 = os.environ["AZURE_OPENAI_ENDPOINT_1"] AZURE_OPENAI_API_KEY_1 = os.environ["AZURE_OPENAI_API_KEY_1"] @@ -41,12 +41,12 @@ loader = Agent( tools=[ - WebScraper(off_prompt=True), + WebScraperTool(off_prompt=True), ], ) asker = Agent( tools=[ - TaskMemoryClient(off_prompt=False), + TaskMemoryTool(off_prompt=False), ], meta_memory=loader.meta_memory, task_memory=loader.task_memory, diff --git a/docs/examples/src/query_webpage_astra_db_1.py b/docs/examples/src/query_webpage_astra_db_1.py index b5d2b0a01..6a96a813e 100644 --- a/docs/examples/src/query_webpage_astra_db_1.py +++ b/docs/examples/src/query_webpage_astra_db_1.py @@ -14,7 +14,7 @@ from griptape.engines.rag.stages import ResponseRagStage, RetrievalRagStage from griptape.loaders import WebLoader from griptape.structures import Agent -from griptape.tools import RagClient, TaskMemoryClient +from griptape.tools import RagTool, TaskMemoryTool namespace = "datastax_blog" input_blogpost = "www.datastax.com/blog/indexing-all-of-wikipedia-on-a-laptop" @@ -49,9 +49,9 @@ raise Exception(artifacts.value) vector_store_driver.upsert_text_artifacts({namespace: artifacts}) -vector_store_tool = RagClient( +rag_tool = RagTool( description="A DataStax blog post", rag_engine=engine, ) -agent = Agent(tools=[vector_store_tool, TaskMemoryClient(off_prompt=False)]) +agent = Agent(tools=[rag_tool, TaskMemoryTool(off_prompt=False)]) agent.run("What engine made possible to index such an amount of data, " "and what kind of tuning was required?") diff --git a/docs/examples/src/talk_to_a_pdf_1.py b/docs/examples/src/talk_to_a_pdf_1.py index 2ac184a22..b4ab72029 100644 --- a/docs/examples/src/talk_to_a_pdf_1.py +++ b/docs/examples/src/talk_to_a_pdf_1.py @@ -7,7 +7,7 @@ from griptape.engines.rag.stages import ResponseRagStage, RetrievalRagStage from griptape.loaders import PdfLoader from griptape.structures import Agent -from griptape.tools import RagClient +from griptape.tools import RagTool from griptape.utils import Chat namespace = "attention" @@ -25,7 +25,7 @@ response_modules=[PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))] ), ) -vector_store_tool = RagClient( +rag_tool = RagTool( description="Contains information about the Attention Is All You Need paper. " "Use it to answer any related questions.", rag_engine=engine, @@ -37,6 +37,6 @@ vector_store.upsert_text_artifacts({namespace: artifacts}) -agent = Agent(tools=[vector_store_tool]) +agent = Agent(tools=[rag_tool]) Chat(agent).start() diff --git a/docs/examples/src/talk_to_a_webpage_1.py b/docs/examples/src/talk_to_a_webpage_1.py index d24eb9427..0412ed977 100644 --- a/docs/examples/src/talk_to_a_webpage_1.py +++ b/docs/examples/src/talk_to_a_webpage_1.py @@ -6,7 +6,7 @@ from griptape.loaders import WebLoader from griptape.rules import Rule, Ruleset from griptape.structures import Agent -from griptape.tools import RagClient +from griptape.tools import RagTool from griptape.utils import Chat namespace = "physics-wiki" @@ -33,7 +33,7 @@ vector_store_driver.upsert_text_artifacts({namespace: artifacts}) -vector_store_tool = RagClient( +rag_tool = RagTool( description="Contains information about physics. " "Use it to answer any physics-related questions.", rag_engine=engine, ) @@ -45,7 +45,7 @@ rules=[Rule("Always introduce yourself as a physics tutor"), Rule("Be truthful. Only discuss physics.")], ) ], - tools=[vector_store_tool], + tools=[rag_tool], ) Chat(agent).start() diff --git a/docs/examples/src/talk_to_redshift_1.py b/docs/examples/src/talk_to_redshift_1.py index 7354a77cc..bd4b57f4f 100644 --- a/docs/examples/src/talk_to_redshift_1.py +++ b/docs/examples/src/talk_to_redshift_1.py @@ -6,7 +6,7 @@ from griptape.loaders import SqlLoader from griptape.rules import Rule, Ruleset from griptape.structures import Agent -from griptape.tools import FileManager, SqlClient +from griptape.tools import FileManagerTool, SqlTool from griptape.utils import Chat session = boto3.Session() @@ -19,7 +19,7 @@ ) ) -sql_tool = SqlClient( +sql_tool = SqlTool( sql_loader=sql_loader, table_name="people", table_description="contains information about tech industry professionals", @@ -27,7 +27,7 @@ ) agent = Agent( - tools=[sql_tool, FileManager()], + tools=[sql_tool, FileManagerTool()], rulesets=[ Ruleset( name="HumansOrg Agent", diff --git a/docs/examples/talk-to-a-pdf.md b/docs/examples/talk-to-a-pdf.md index 2e0743ae4..0524359d5 100644 --- a/docs/examples/talk-to-a-pdf.md +++ b/docs/examples/talk-to-a-pdf.md @@ -1,4 +1,4 @@ -This example demonstrates how to vectorize a PDF of the [Attention Is All You Need](https://arxiv.org/pdf/1706.03762.pdf) paper and setup a Griptape agent with rules and the [VectorStoreClient](../reference/griptape/tools/vector_store_client/tool.md) tool to use it during conversations. +This example demonstrates how to vectorize a PDF of the [Attention Is All You Need](https://arxiv.org/pdf/1706.03762.pdf) paper and setup a Griptape agent with rules and the [VectorStoreTool](../reference/griptape/tools/vector_store/tool.md) tool to use it during conversations. ```python --8<-- "docs/examples/src/talk_to_a_pdf_1.py" diff --git a/docs/examples/talk-to-a-webpage.md b/docs/examples/talk-to-a-webpage.md index 2bee2c9ba..e4632d401 100644 --- a/docs/examples/talk-to-a-webpage.md +++ b/docs/examples/talk-to-a-webpage.md @@ -1,4 +1,4 @@ -This example demonstrates how to vectorize a webpage and setup a Griptape agent with rules and the [RagClient](../reference/griptape/tools/rag_client/tool.md) tool to use it during conversations. +This example demonstrates how to vectorize a webpage and setup a Griptape agent with rules and the [RagClient](../reference/griptape/tools/rag/tool.md) tool to use it during conversations. ```python --8<-- "docs/examples/src/talk_to_a_webpage_1.py" diff --git a/docs/griptape-framework/drivers/src/audio_transcription_drivers_1.py b/docs/griptape-framework/drivers/src/audio_transcription_drivers_1.py index e4415e1d4..16013638e 100644 --- a/docs/griptape-framework/drivers/src/audio_transcription_drivers_1.py +++ b/docs/griptape-framework/drivers/src/audio_transcription_drivers_1.py @@ -1,11 +1,11 @@ from griptape.drivers import OpenAiAudioTranscriptionDriver from griptape.engines import AudioTranscriptionEngine from griptape.structures import Agent -from griptape.tools.audio_transcription_client.tool import AudioTranscriptionClient +from griptape.tools.audio_transcription.tool import AudioTranscriptionTool driver = OpenAiAudioTranscriptionDriver(model="whisper-1") -tool = AudioTranscriptionClient( +tool = AudioTranscriptionTool( off_prompt=False, engine=AudioTranscriptionEngine( audio_transcription_driver=driver, diff --git a/docs/griptape-framework/drivers/src/embedding_drivers_10.py b/docs/griptape-framework/drivers/src/embedding_drivers_10.py index 3ef816b29..4f7560c99 100644 --- a/docs/griptape-framework/drivers/src/embedding_drivers_10.py +++ b/docs/griptape-framework/drivers/src/embedding_drivers_10.py @@ -4,7 +4,7 @@ VoyageAiEmbeddingDriver, ) from griptape.structures import Agent -from griptape.tools import TaskMemoryClient, WebScraper +from griptape.tools import TaskMemoryTool, WebScraperTool config.drivers = DriverConfig( prompt=OpenAiChatPromptDriver(model="gpt-4o"), @@ -12,7 +12,7 @@ ) agent = Agent( - tools=[WebScraper(off_prompt=True), TaskMemoryClient(off_prompt=False)], + tools=[WebScraperTool(off_prompt=True), TaskMemoryTool(off_prompt=False)], ) agent.run("based on https://www.griptape.ai/, tell me what Griptape is") diff --git a/docs/griptape-framework/drivers/src/image_generation_drivers_1.py b/docs/griptape-framework/drivers/src/image_generation_drivers_1.py index 637b6ec7a..b20a42265 100644 --- a/docs/griptape-framework/drivers/src/image_generation_drivers_1.py +++ b/docs/griptape-framework/drivers/src/image_generation_drivers_1.py @@ -1,7 +1,7 @@ from griptape.drivers import OpenAiImageGenerationDriver from griptape.engines import PromptImageGenerationEngine from griptape.structures import Agent -from griptape.tools import PromptImageGenerationClient +from griptape.tools import PromptImageGenerationTool driver = OpenAiImageGenerationDriver( model="dall-e-2", @@ -11,7 +11,7 @@ agent = Agent( tools=[ - PromptImageGenerationClient(engine=engine), + PromptImageGenerationTool(engine=engine), ] ) diff --git a/docs/griptape-framework/drivers/src/image_generation_drivers_2.py b/docs/griptape-framework/drivers/src/image_generation_drivers_2.py index 63663ae79..ab07fcb27 100644 --- a/docs/griptape-framework/drivers/src/image_generation_drivers_2.py +++ b/docs/griptape-framework/drivers/src/image_generation_drivers_2.py @@ -1,7 +1,7 @@ from griptape.drivers import AmazonBedrockImageGenerationDriver, BedrockStableDiffusionImageGenerationModelDriver from griptape.engines import PromptImageGenerationEngine from griptape.structures import Agent -from griptape.tools import PromptImageGenerationClient +from griptape.tools import PromptImageGenerationTool model_driver = BedrockStableDiffusionImageGenerationModelDriver( style_preset="pixel-art", @@ -16,7 +16,7 @@ agent = Agent( tools=[ - PromptImageGenerationClient(engine=engine), + PromptImageGenerationTool(engine=engine), ] ) diff --git a/docs/griptape-framework/drivers/src/image_generation_drivers_3.py b/docs/griptape-framework/drivers/src/image_generation_drivers_3.py index 630d80f20..b8c63589d 100644 --- a/docs/griptape-framework/drivers/src/image_generation_drivers_3.py +++ b/docs/griptape-framework/drivers/src/image_generation_drivers_3.py @@ -1,7 +1,7 @@ from griptape.drivers import AmazonBedrockImageGenerationDriver, BedrockTitanImageGenerationModelDriver from griptape.engines import PromptImageGenerationEngine from griptape.structures import Agent -from griptape.tools import PromptImageGenerationClient +from griptape.tools import PromptImageGenerationTool model_driver = BedrockTitanImageGenerationModelDriver() @@ -14,7 +14,7 @@ agent = Agent( tools=[ - PromptImageGenerationClient(engine=engine), + PromptImageGenerationTool(engine=engine), ] ) diff --git a/docs/griptape-framework/drivers/src/image_generation_drivers_4.py b/docs/griptape-framework/drivers/src/image_generation_drivers_4.py index 428f470cf..f1bc06200 100644 --- a/docs/griptape-framework/drivers/src/image_generation_drivers_4.py +++ b/docs/griptape-framework/drivers/src/image_generation_drivers_4.py @@ -3,7 +3,7 @@ from griptape.drivers import AzureOpenAiImageGenerationDriver from griptape.engines import PromptImageGenerationEngine from griptape.structures import Agent -from griptape.tools import PromptImageGenerationClient +from griptape.tools import PromptImageGenerationTool driver = AzureOpenAiImageGenerationDriver( model="dall-e-3", @@ -16,7 +16,7 @@ agent = Agent( tools=[ - PromptImageGenerationClient(engine=engine), + PromptImageGenerationTool(engine=engine), ] ) diff --git a/docs/griptape-framework/drivers/src/image_generation_drivers_5.py b/docs/griptape-framework/drivers/src/image_generation_drivers_5.py index 06157107f..46173a232 100644 --- a/docs/griptape-framework/drivers/src/image_generation_drivers_5.py +++ b/docs/griptape-framework/drivers/src/image_generation_drivers_5.py @@ -3,7 +3,7 @@ from griptape.drivers import LeonardoImageGenerationDriver from griptape.engines import PromptImageGenerationEngine from griptape.structures import Agent -from griptape.tools import PromptImageGenerationClient +from griptape.tools import PromptImageGenerationTool driver = LeonardoImageGenerationDriver( model=os.environ["LEONARDO_MODEL_ID"], @@ -16,7 +16,7 @@ agent = Agent( tools=[ - PromptImageGenerationClient(engine=engine), + PromptImageGenerationTool(engine=engine), ] ) diff --git a/docs/griptape-framework/drivers/src/image_generation_drivers_6.py b/docs/griptape-framework/drivers/src/image_generation_drivers_6.py index feb8a54d7..d295da4ff 100644 --- a/docs/griptape-framework/drivers/src/image_generation_drivers_6.py +++ b/docs/griptape-framework/drivers/src/image_generation_drivers_6.py @@ -1,7 +1,7 @@ from griptape.drivers import OpenAiImageGenerationDriver from griptape.engines import PromptImageGenerationEngine from griptape.structures import Agent -from griptape.tools import PromptImageGenerationClient +from griptape.tools import PromptImageGenerationTool driver = OpenAiImageGenerationDriver( model="dall-e-2", @@ -12,7 +12,7 @@ agent = Agent( tools=[ - PromptImageGenerationClient(engine=engine), + PromptImageGenerationTool(engine=engine), ] ) diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_10.py b/docs/griptape-framework/drivers/src/prompt_drivers_10.py index 04e2acb35..1d757668c 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_10.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_10.py @@ -1,11 +1,11 @@ from griptape.drivers import OllamaPromptDriver from griptape.structures import Agent -from griptape.tools import Calculator +from griptape.tools import CalculatorTool agent = Agent( prompt_driver=OllamaPromptDriver( model="llama3.1", ), - tools=[Calculator()], + tools=[CalculatorTool()], ) agent.run("What is (192 + 12) ^ 4") diff --git a/docs/griptape-framework/drivers/src/text_to_speech_drivers_1.py b/docs/griptape-framework/drivers/src/text_to_speech_drivers_1.py index c6a03b80d..376113d63 100644 --- a/docs/griptape-framework/drivers/src/text_to_speech_drivers_1.py +++ b/docs/griptape-framework/drivers/src/text_to_speech_drivers_1.py @@ -3,7 +3,7 @@ from griptape.drivers import ElevenLabsTextToSpeechDriver from griptape.engines import TextToSpeechEngine from griptape.structures import Agent -from griptape.tools.text_to_speech_client.tool import TextToSpeechClient +from griptape.tools.text_to_speech.tool import TextToSpeechTool driver = ElevenLabsTextToSpeechDriver( api_key=os.environ["ELEVEN_LABS_API_KEY"], @@ -11,7 +11,7 @@ voice="Matilda", ) -tool = TextToSpeechClient( +tool = TextToSpeechTool( engine=TextToSpeechEngine( text_to_speech_driver=driver, ), diff --git a/docs/griptape-framework/drivers/src/text_to_speech_drivers_2.py b/docs/griptape-framework/drivers/src/text_to_speech_drivers_2.py index 99927a390..4a6323b1b 100644 --- a/docs/griptape-framework/drivers/src/text_to_speech_drivers_2.py +++ b/docs/griptape-framework/drivers/src/text_to_speech_drivers_2.py @@ -1,11 +1,11 @@ from griptape.drivers import OpenAiTextToSpeechDriver from griptape.engines import TextToSpeechEngine from griptape.structures import Agent -from griptape.tools.text_to_speech_client.tool import TextToSpeechClient +from griptape.tools.text_to_speech.tool import TextToSpeechTool driver = OpenAiTextToSpeechDriver() -tool = TextToSpeechClient( +tool = TextToSpeechTool( engine=TextToSpeechEngine( text_to_speech_driver=driver, ), diff --git a/docs/griptape-framework/drivers/src/web_scraper_drivers_3.py b/docs/griptape-framework/drivers/src/web_scraper_drivers_3.py index d9fe11e85..d4c77c2ae 100644 --- a/docs/griptape-framework/drivers/src/web_scraper_drivers_3.py +++ b/docs/griptape-framework/drivers/src/web_scraper_drivers_3.py @@ -1,15 +1,15 @@ from griptape.drivers import MarkdownifyWebScraperDriver from griptape.loaders import WebLoader from griptape.structures import Agent -from griptape.tools import TaskMemoryClient, WebScraper +from griptape.tools import TaskMemoryTool, WebScraperTool agent = Agent( tools=[ - WebScraper( + WebScraperTool( web_loader=WebLoader(web_scraper_driver=MarkdownifyWebScraperDriver(timeout=1000)), off_prompt=True, ), - TaskMemoryClient(off_prompt=False), + TaskMemoryTool(off_prompt=False), ], ) agent.run("List all email addresses on griptape.ai in a flat numbered markdown list.") diff --git a/docs/griptape-framework/drivers/src/web_search_drivers_2.py b/docs/griptape-framework/drivers/src/web_search_drivers_2.py index 5cde1a9a8..a33b8c00c 100644 --- a/docs/griptape-framework/drivers/src/web_search_drivers_2.py +++ b/docs/griptape-framework/drivers/src/web_search_drivers_2.py @@ -2,17 +2,17 @@ from griptape.drivers import GoogleWebSearchDriver from griptape.structures import Agent -from griptape.tools import TaskMemoryClient, WebSearch +from griptape.tools import TaskMemoryTool, WebSearchTool agent = Agent( tools=[ - WebSearch( + WebSearchTool( web_search_driver=GoogleWebSearchDriver( api_key=os.environ["GOOGLE_API_KEY"], search_id=os.environ["GOOGLE_API_SEARCH_ID"], ), ), - TaskMemoryClient(off_prompt=False), + TaskMemoryTool(off_prompt=False), ], ) agent.run("Give me some websites with information about AI frameworks.") diff --git a/docs/griptape-framework/drivers/structure-run-drivers.md b/docs/griptape-framework/drivers/structure-run-drivers.md index 00890cb4a..1f57ff57e 100644 --- a/docs/griptape-framework/drivers/structure-run-drivers.md +++ b/docs/griptape-framework/drivers/structure-run-drivers.md @@ -5,7 +5,7 @@ search: ## Overview Structure Run Drivers can be used to run Griptape Structures in a variety of runtime environments. -When combined with the [Structure Run Task](../../griptape-framework/structures/tasks.md#structure-run-task) or [Structure Run Client](../../griptape-tools/official-tools/structure-run-client.md) you can create complex, multi-agent pipelines that span multiple runtime environments. +When combined with the [Structure Run Task](../../griptape-framework/structures/tasks.md#structure-run-task) or [Structure Run Tool](../../griptape-tools/official-tools/structure-run-tool.md) you can create complex, multi-agent pipelines that span multiple runtime environments. ## Structure Run Drivers diff --git a/docs/griptape-framework/index.md b/docs/griptape-framework/index.md index 52805897d..ff3670e6e 100644 --- a/docs/griptape-framework/index.md +++ b/docs/griptape-framework/index.md @@ -103,7 +103,7 @@ Here is the chain of thought from the Agent. Notice where it realizes it can use Actions: [ { "tag": "call_RTRm7JLFV0F73dCVPmoWVJqO", - "name": "Calculator", + "name": "CalculatorTool", "path": "calculate", "input": { "values": { @@ -144,7 +144,7 @@ Agents are great for getting started, but they are intentionally limited to a si [09/08/23 10:02:53] INFO Subtask 8023e3d257274df29065b22e736faca8 Thought: Now that the webpage content is stored in memory, I can use the TaskMemory tool's summarize activity to summarize the content. - Action: {"name": "TaskMemoryClient", "path": "summarize", "input": {"values": {"memory_name": "TaskMemory", "artifact_namespace": "39ca67bbe26b4e1584193b87ed82170d"}}} + Action: {"name": "TaskMemoryTool", "path": "summarize", "input": {"values": {"memory_name": "TaskMemory", "artifact_namespace": "39ca67bbe26b4e1584193b87ed82170d"}}} [09/08/23 10:02:57] INFO Subtask 8023e3d257274df29065b22e736faca8 Response: Griptape is an open source framework that allows developers to build and deploy AI applications using large language models (LLMs). It provides the ability to create conversational and event-driven apps that diff --git a/docs/griptape-framework/misc/src/events_3.py b/docs/griptape-framework/misc/src/events_3.py index bae8b8224..2a567debe 100644 --- a/docs/griptape-framework/misc/src/events_3.py +++ b/docs/griptape-framework/misc/src/events_3.py @@ -4,7 +4,7 @@ from griptape.events import CompletionChunkEvent, EventListener, event_bus from griptape.structures import Pipeline from griptape.tasks import ToolkitTask -from griptape.tools import TaskMemoryClient, WebScraper +from griptape.tools import TaskMemoryTool, WebScraperTool event_bus.add_event_listeners( [ @@ -20,7 +20,7 @@ ToolkitTask( "Based on https://griptape.ai, tell me what griptape is.", prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", stream=True), - tools=[WebScraper(off_prompt=True), TaskMemoryClient(off_prompt=False)], + tools=[WebScraperTool(off_prompt=True), TaskMemoryTool(off_prompt=False)], ) ) diff --git a/docs/griptape-framework/misc/src/events_4.py b/docs/griptape-framework/misc/src/events_4.py index f5523cb11..a66e77b1d 100644 --- a/docs/griptape-framework/misc/src/events_4.py +++ b/docs/griptape-framework/misc/src/events_4.py @@ -1,13 +1,13 @@ from griptape.structures import Pipeline from griptape.tasks import ToolkitTask -from griptape.tools import TaskMemoryClient, WebScraper +from griptape.tools import TaskMemoryTool, WebScraperTool from griptape.utils import Stream pipeline = Pipeline() pipeline.add_tasks( ToolkitTask( "Based on https://griptape.ai, tell me what griptape is.", - tools=[WebScraper(off_prompt=True), TaskMemoryClient(off_prompt=False)], + tools=[WebScraperTool(off_prompt=True), TaskMemoryTool(off_prompt=False)], ) ) diff --git a/docs/griptape-framework/src/index_3.py b/docs/griptape-framework/src/index_3.py index 043fc75f7..ac153b15f 100644 --- a/docs/griptape-framework/src/index_3.py +++ b/docs/griptape-framework/src/index_3.py @@ -1,7 +1,7 @@ from griptape.structures import Agent -from griptape.tools import Calculator +from griptape.tools import CalculatorTool -calculator = Calculator() +calculator = CalculatorTool() agent = Agent(tools=[calculator]) diff --git a/docs/griptape-framework/src/index_4.py b/docs/griptape-framework/src/index_4.py index 0bb345438..dd07280ce 100644 --- a/docs/griptape-framework/src/index_4.py +++ b/docs/griptape-framework/src/index_4.py @@ -1,7 +1,7 @@ from griptape.memory.structure import ConversationMemory from griptape.structures import Pipeline from griptape.tasks import PromptTask, ToolkitTask -from griptape.tools import FileManager, TaskMemoryClient, WebScraper +from griptape.tools import FileManagerTool, TaskMemoryTool, WebScraperTool # Pipelines represent sequences of tasks. pipeline = Pipeline(conversation_memory=ConversationMemory()) @@ -11,7 +11,7 @@ ToolkitTask( "{{ args[0] }}", # Add tools for web scraping, and file management - tools=[WebScraper(off_prompt=True), FileManager(off_prompt=True), TaskMemoryClient(off_prompt=False)], + tools=[WebScraperTool(off_prompt=True), FileManagerTool(off_prompt=True), TaskMemoryTool(off_prompt=False)], ), # Augment `input` from the previous task. PromptTask("Say the following in spanish: {{ parent_output }}"), diff --git a/docs/griptape-framework/structures/agents.md b/docs/griptape-framework/structures/agents.md index 2f28b336c..376e7288a 100644 --- a/docs/griptape-framework/structures/agents.md +++ b/docs/griptape-framework/structures/agents.md @@ -27,7 +27,7 @@ You can access the final output of the Agent by using the [output](../../referen Actions: [ { "tag": "call_ZSCH6vNoycOgtPJH2DL2U9ji", - "name": "Calculator", + "name": "CalculatorTool", "path": "calculate", "input": { "values": { diff --git a/docs/griptape-framework/structures/src/agents_1.py b/docs/griptape-framework/structures/src/agents_1.py index 20d04004d..9ce4aec22 100644 --- a/docs/griptape-framework/structures/src/agents_1.py +++ b/docs/griptape-framework/structures/src/agents_1.py @@ -1,7 +1,7 @@ from griptape.structures import Agent -from griptape.tools import Calculator +from griptape.tools import CalculatorTool -agent = Agent(input="Calculate the following: {{ args[0] }}", tools=[Calculator()]) +agent = Agent(input="Calculate the following: {{ args[0] }}", tools=[CalculatorTool()]) agent.run("what's 13^7?") print("Answer:", agent.output) diff --git a/docs/griptape-framework/structures/src/task_memory_1.py b/docs/griptape-framework/structures/src/task_memory_1.py index 90940a611..e8cfbd8ac 100644 --- a/docs/griptape-framework/structures/src/task_memory_1.py +++ b/docs/griptape-framework/structures/src/task_memory_1.py @@ -1,7 +1,7 @@ from griptape.structures import Agent -from griptape.tools import Calculator +from griptape.tools import CalculatorTool -# Create an agent with the Calculator tool -agent = Agent(tools=[Calculator(off_prompt=False)]) +# Create an agent with the CalculatorTool tool +agent = Agent(tools=[CalculatorTool(off_prompt=False)]) agent.run("What is 10 raised to the power of 5?") diff --git a/docs/griptape-framework/structures/src/task_memory_2.py b/docs/griptape-framework/structures/src/task_memory_2.py index dcc5b4d5c..9ff24e1ff 100644 --- a/docs/griptape-framework/structures/src/task_memory_2.py +++ b/docs/griptape-framework/structures/src/task_memory_2.py @@ -1,7 +1,7 @@ from griptape.structures import Agent -from griptape.tools import Calculator +from griptape.tools import CalculatorTool -# Create an agent with the Calculator tool -agent = Agent(tools=[Calculator(off_prompt=True)]) +# Create an agent with the CalculatorTool tool +agent = Agent(tools=[CalculatorTool(off_prompt=True)]) agent.run("What is 10 raised to the power of 5?") diff --git a/docs/griptape-framework/structures/src/task_memory_3.py b/docs/griptape-framework/structures/src/task_memory_3.py index cab4f4e3e..926a21cd3 100644 --- a/docs/griptape-framework/structures/src/task_memory_3.py +++ b/docs/griptape-framework/structures/src/task_memory_3.py @@ -1,7 +1,7 @@ from griptape.structures import Agent -from griptape.tools import Calculator, TaskMemoryClient +from griptape.tools import CalculatorTool, TaskMemoryTool -# Create an agent with the Calculator tool -agent = Agent(tools=[Calculator(off_prompt=True), TaskMemoryClient(off_prompt=False)]) +# Create an agent with the CalculatorTool tool +agent = Agent(tools=[CalculatorTool(off_prompt=True), TaskMemoryTool(off_prompt=False)]) agent.run("What is the square root of 12345?") diff --git a/docs/griptape-framework/structures/src/task_memory_4.py b/docs/griptape-framework/structures/src/task_memory_4.py index d192bebb7..cfd6d5711 100644 --- a/docs/griptape-framework/structures/src/task_memory_4.py +++ b/docs/griptape-framework/structures/src/task_memory_4.py @@ -1,8 +1,8 @@ from griptape.structures import Agent -from griptape.tools import WebScraper +from griptape.tools import WebScraperTool -# Create an agent with the WebScraper tool -agent = Agent(tools=[WebScraper()]) +# Create an agent with the WebScraperTool tool +agent = Agent(tools=[WebScraperTool()]) agent.run( "According to this page https://en.wikipedia.org/wiki/Elden_Ring, how many copies of Elden Ring have been sold?" diff --git a/docs/griptape-framework/structures/src/task_memory_5.py b/docs/griptape-framework/structures/src/task_memory_5.py index a5d3995a9..a53d106b4 100644 --- a/docs/griptape-framework/structures/src/task_memory_5.py +++ b/docs/griptape-framework/structures/src/task_memory_5.py @@ -1,10 +1,10 @@ from griptape.structures import Agent -from griptape.tools import TaskMemoryClient, WebScraper +from griptape.tools import TaskMemoryTool, WebScraperTool agent = Agent( tools=[ - WebScraper(off_prompt=True), - TaskMemoryClient(off_prompt=False), + WebScraperTool(off_prompt=True), + TaskMemoryTool(off_prompt=False), ] ) diff --git a/docs/griptape-framework/structures/src/task_memory_6.py b/docs/griptape-framework/structures/src/task_memory_6.py index fb5c3eabb..3f4d14b0a 100644 --- a/docs/griptape-framework/structures/src/task_memory_6.py +++ b/docs/griptape-framework/structures/src/task_memory_6.py @@ -14,7 +14,7 @@ from griptape.memory import TaskMemory from griptape.memory.task.storage import TextArtifactStorage from griptape.structures import Agent -from griptape.tools import FileManager, TaskMemoryClient, WebScraper +from griptape.tools import FileManagerTool, TaskMemoryTool, WebScraperTool config.drivers = OpenAiDriverConfig( prompt=OpenAiChatPromptDriver(model="gpt-4"), @@ -45,9 +45,9 @@ } ), tools=[ - WebScraper(off_prompt=True), - TaskMemoryClient(off_prompt=True, allowlist=["query"]), - FileManager(off_prompt=True), + WebScraperTool(off_prompt=True), + TaskMemoryTool(off_prompt=True, allowlist=["query"]), + FileManagerTool(off_prompt=True), ], ) diff --git a/docs/griptape-framework/structures/src/task_memory_7.py b/docs/griptape-framework/structures/src/task_memory_7.py index 1ffec43e8..d2f07466f 100644 --- a/docs/griptape-framework/structures/src/task_memory_7.py +++ b/docs/griptape-framework/structures/src/task_memory_7.py @@ -1,9 +1,9 @@ from griptape.structures import Agent -from griptape.tools import WebScraper +from griptape.tools import WebScraperTool agent = Agent( tools=[ - WebScraper(off_prompt=True) # `off_prompt=True` will store the data in Task Memory + WebScraperTool(off_prompt=True) # `off_prompt=True` will store the data in Task Memory # Missing a Tool that can read from Task Memory ] ) diff --git a/docs/griptape-framework/structures/src/task_memory_8.py b/docs/griptape-framework/structures/src/task_memory_8.py index 9aba9516f..846119228 100644 --- a/docs/griptape-framework/structures/src/task_memory_8.py +++ b/docs/griptape-framework/structures/src/task_memory_8.py @@ -1,12 +1,10 @@ from griptape.structures import Agent -from griptape.tools import TaskMemoryClient, WebScraper +from griptape.tools import TaskMemoryTool, WebScraperTool agent = Agent( tools=[ - WebScraper(off_prompt=True), # This tool will store the data in Task Memory - TaskMemoryClient( - off_prompt=True - ), # This tool will store the data back in Task Memory with no way to get it out + WebScraperTool(off_prompt=True), # This tool will store the data in Task Memory + TaskMemoryTool(off_prompt=True), # This tool will store the data back in Task Memory with no way to get it out ] ) agent.run( diff --git a/docs/griptape-framework/structures/src/task_memory_9.py b/docs/griptape-framework/structures/src/task_memory_9.py index c6d93ba5e..66bb562f0 100644 --- a/docs/griptape-framework/structures/src/task_memory_9.py +++ b/docs/griptape-framework/structures/src/task_memory_9.py @@ -1,9 +1,9 @@ from griptape.structures import Agent -from griptape.tools import Calculator +from griptape.tools import CalculatorTool agent = Agent( tools=[ - Calculator() # Default value of `off_prompt=False` will return the data directly to the LLM + CalculatorTool() # Default value of `off_prompt=False` will return the data directly to the LLM ] ) agent.run("What is 10 ^ 3, 55 / 23, and 12345 * 0.5?") diff --git a/docs/griptape-framework/structures/src/tasks_16.py b/docs/griptape-framework/structures/src/tasks_16.py index 796b836da..5f2e6b718 100644 --- a/docs/griptape-framework/structures/src/tasks_16.py +++ b/docs/griptape-framework/structures/src/tasks_16.py @@ -5,25 +5,25 @@ from griptape.structures import Agent, Pipeline from griptape.tasks import StructureRunTask from griptape.tools import ( - TaskMemoryClient, - WebScraper, - WebSearch, + TaskMemoryTool, + WebScraperTool, + WebSearchTool, ) def build_researcher() -> Agent: researcher = Agent( tools=[ - WebSearch( + WebSearchTool( web_search_driver=GoogleWebSearchDriver( api_key=os.environ["GOOGLE_API_KEY"], search_id=os.environ["GOOGLE_API_SEARCH_ID"], ), ), - WebScraper( + WebScraperTool( off_prompt=True, ), - TaskMemoryClient(off_prompt=False), + TaskMemoryTool(off_prompt=False), ], rulesets=[ Ruleset( diff --git a/docs/griptape-framework/structures/src/tasks_4.py b/docs/griptape-framework/structures/src/tasks_4.py index 43737980b..936b59f99 100644 --- a/docs/griptape-framework/structures/src/tasks_4.py +++ b/docs/griptape-framework/structures/src/tasks_4.py @@ -1,12 +1,12 @@ from griptape.structures import Agent from griptape.tasks import ToolkitTask -from griptape.tools import FileManager, TaskMemoryClient, WebScraper +from griptape.tools import FileManagerTool, TaskMemoryTool, WebScraperTool agent = Agent() agent.add_task( ToolkitTask( "Load https://www.griptape.ai, summarize it, and store it in a file called griptape.txt", - tools=[WebScraper(off_prompt=True), FileManager(off_prompt=True), TaskMemoryClient(off_prompt=True)], + tools=[WebScraperTool(off_prompt=True), FileManagerTool(off_prompt=True), TaskMemoryTool(off_prompt=True)], ), ) diff --git a/docs/griptape-framework/structures/src/tasks_5.py b/docs/griptape-framework/structures/src/tasks_5.py index 543543303..a0d537aa7 100644 --- a/docs/griptape-framework/structures/src/tasks_5.py +++ b/docs/griptape-framework/structures/src/tasks_5.py @@ -1,10 +1,10 @@ from griptape.structures import Agent from griptape.tasks import ToolTask -from griptape.tools import Calculator +from griptape.tools import CalculatorTool # Initialize the agent and add a task agent = Agent() -agent.add_task(ToolTask(tool=Calculator())) +agent.add_task(ToolTask(tool=CalculatorTool())) # Run the agent with a prompt agent.run("Give me the answer for 5*4.") diff --git a/docs/griptape-framework/structures/task-memory.md b/docs/griptape-framework/structures/task-memory.md index 6b858ff5b..ccc4dd1e7 100644 --- a/docs/griptape-framework/structures/task-memory.md +++ b/docs/griptape-framework/structures/task-memory.md @@ -29,15 +29,15 @@ Lets look at a simple example where `off_prompt` is set to `False`: [04/26/24 13:06:42] INFO ToolkitTask 36b9dea13b9d479fb752014f41dca54c Input: What is the square root of 12345? [04/26/24 13:06:48] INFO Subtask a88c0feeaef6493796a9148ed68c9caf - Thought: To find the square root of 12345, I can use the Calculator action with the expression "12345 ** 0.5". - Actions: [{"name": "Calculator", "path": "calculate", "input": {"values": {"expression": "12345 ** 0.5"}}, "tag": "sqrt_12345"}] + Thought: To find the square root of 12345, I can use the CalculatorTool action with the expression "12345 ** 0.5". + Actions: [{"name": "CalculatorTool", "path": "calculate", "input": {"values": {"expression": "12345 ** 0.5"}}, "tag": "sqrt_12345"}] INFO Subtask a88c0feeaef6493796a9148ed68c9caf Response: 111.1080555135405 [04/26/24 13:06:49] INFO ToolkitTask 36b9dea13b9d479fb752014f41dca54c Output: The square root of 12345 is approximately 111.108. ``` -Since the result of the Calculator Tool is neither sensitive nor too large, we can set `off_prompt` to `False` and not use Task Memory. +Since the result of the CalculatorTool Tool is neither sensitive nor too large, we can set `off_prompt` to `False` and not use Task Memory. Let's explore what happens when `off_prompt` is set to `True`: @@ -49,38 +49,38 @@ Let's explore what happens when `off_prompt` is set to `True`: [04/26/24 13:07:02] INFO ToolkitTask ecbb788d9830491ab72a8a2bbef5fb0a Input: What is the square root of 12345? [04/26/24 13:07:10] INFO Subtask 4700dc0c2e934d1a9af60a28bd770bc6 - Thought: To find the square root of a number, we can use the Calculator action with the expression "sqrt(12345)". However, the Calculator + Thought: To find the square root of a number, we can use the CalculatorTool action with the expression "sqrt(12345)". However, the CalculatorTool action only supports basic arithmetic operations and does not support the sqrt function. Therefore, we need to use the equivalent expression for square root which is raising the number to the power of 0.5. - Actions: [{"name": "Calculator", "path": "calculate", "input": {"values": {"expression": "12345**0.5"}}, "tag": "sqrt_calculation"}] + Actions: [{"name": "CalculatorTool", "path": "calculate", "input": {"values": {"expression": "12345**0.5"}}, "tag": "sqrt_calculation"}] INFO Subtask 4700dc0c2e934d1a9af60a28bd770bc6 - Response: Output of "Calculator.calculate" was stored in memory with memory_name "TaskMemory" and artifact_namespace + Response: Output of "CalculatorTool.calculate" was stored in memory with memory_name "TaskMemory" and artifact_namespace "6be74c5128024c0588eb9bee1fdb9aa5" [04/26/24 13:07:16] ERROR Subtask ecbb788d9830491ab72a8a2bbef5fb0a - Invalid action JSON: Or({Literal("name", description=""): 'Calculator', Literal("path", description="Can be used for computing simple + Invalid action JSON: Or({Literal("name", description=""): 'CalculatorTool', Literal("path", description="Can be used for computing simple numerical or algebraic calculations in Python"): 'calculate', Literal("input", description=""): {'values': Schema({Literal("expression", description="Arithmetic expression parsable in pure Python. Single line only. Don't use variables. Don't use any imports or external libraries"): })}, Literal("tag", description="Unique tag name for action execution."): }) did not validate {'name': 'Memory', 'path': 'get', 'input': {'memory_name': 'TaskMemory', 'artifact_namespace': '6be74c5128024c0588eb9bee1fdb9aa5'}, 'tag': 'get_sqrt_result'} Key 'name' error: - 'Calculator' does not match 'Memory' + 'CalculatorTool' does not match 'Memory' ...Output truncated for brevity... ``` -When we set `off_prompt` to `True`, the Agent does not function as expected, even generating an error. This is because the Calculator output is being stored in Task Memory but the Agent has no way to access it. -To fix this, we need a [Tool that can read from Task Memory](#tools-that-can-read-from-task-memory) such as the `TaskMemoryClient`. +When we set `off_prompt` to `True`, the Agent does not function as expected, even generating an error. This is because the CalculatorTool output is being stored in Task Memory but the Agent has no way to access it. +To fix this, we need a [Tool that can read from Task Memory](#tools-that-can-read-from-task-memory) such as the `TaskMemoryTool`. This is an example of [not providing a Task Memory compatible Tool](#not-providing-a-task-memory-compatible-tool). -## Task Memory Client +## Task Memory Tool -The [TaskMemoryClient](../../griptape-tools/official-tools/task-memory-client.md) is a Tool that allows an Agent to interact with Task Memory. It has the following methods: +The [TaskMemoryTool](../../griptape-tools/official-tools/task-memory-tool.md) is a Tool that allows an Agent to interact with Task Memory. It has the following methods: - `query`: Retrieve the content of an Artifact stored in Task Memory. - `summarize`: Summarize the content of an Artifact stored in Task Memory. -Let's add `TaskMemoryClient` to the Agent and run the same task. -Note that on the `TaskMemoryClient` we've set `off_prompt` to `False` so that the results of the query can be returned directly to the LLM. +Let's add `TaskMemoryTool` to the Agent and run the same task. +Note that on the `TaskMemoryTool` we've set `off_prompt` to `False` so that the results of the query can be returned directly to the LLM. If we had kept it as `True`, the results would have been stored back Task Memory which would've put us back to square one. See [Task Memory Looping](#task-memory-looping) for more information on this scenario. ```python @@ -91,15 +91,15 @@ If we had kept it as `True`, the results would have been stored back Task Memory [04/26/24 13:13:01] INFO ToolkitTask 5b46f9ef677c4b31906b48aba3f45e2c Input: What is the square root of 12345? [04/26/24 13:13:07] INFO Subtask 611d98ea5576430fbc63259420577ab2 - Thought: To find the square root of 12345, I can use the Calculator action with the expression "12345 ** 0.5". - Actions: [{"name": "Calculator", "path": "calculate", "input": {"values": {"expression": "12345 ** 0.5"}}, "tag": "sqrt_12345"}] + Thought: To find the square root of 12345, I can use the CalculatorTool action with the expression "12345 ** 0.5". + Actions: [{"name": "CalculatorTool", "path": "calculate", "input": {"values": {"expression": "12345 ** 0.5"}}, "tag": "sqrt_12345"}] [04/26/24 13:13:08] INFO Subtask 611d98ea5576430fbc63259420577ab2 - Response: Output of "Calculator.calculate" was stored in memory with memory_name "TaskMemory" and artifact_namespace + Response: Output of "CalculatorTool.calculate" was stored in memory with memory_name "TaskMemory" and artifact_namespace "7554b69e1d414a469b8882e2266dcea1" [04/26/24 13:13:15] INFO Subtask 32b9163a15644212be60b8fba07bd23b - Thought: The square root of 12345 has been calculated and stored in memory. I can retrieve this value using the TaskMemoryClient action with + Thought: The square root of 12345 has been calculated and stored in memory. I can retrieve this value using the TaskMemoryTool action with the query path, providing the memory_name and artifact_namespace as input. - Actions: [{"tag": "retrieve_sqrt", "name": "TaskMemoryClient", "path": "query", "input": {"values": {"memory_name": "TaskMemory", + Actions: [{"tag": "retrieve_sqrt", "name": "TaskMemoryTool", "path": "query", "input": {"values": {"memory_name": "TaskMemory", "artifact_namespace": "7554b69e1d414a469b8882e2266dcea1", "query": "What is the result of the calculation?"}}}] [04/26/24 13:13:16] INFO Subtask 32b9163a15644212be60b8fba07bd23b Response: The result of the calculation is 111.1080555135405. @@ -107,7 +107,7 @@ If we had kept it as `True`, the results would have been stored back Task Memory Output: The square root of 12345 is approximately 111.108. ``` -While this fixed the problem, it took a handful more steps than when we just had `Calculator()`. Something like a basic calculation is an instance of where [Task Memory may not be necessary](#task-memory-may-not-be-necessary). +While this fixed the problem, it took a handful more steps than when we just had `CalculatorTool()`. Something like a basic calculation is an instance of where [Task Memory may not be necessary](#task-memory-may-not-be-necessary). Let's look at a more complex example where Task Memory shines. ## Large Data @@ -125,8 +125,8 @@ When running this example, we get the following error: Please reduce the length of the messages.", 'type': 'invalid_request_error', 'param': 'messages', 'code': 'context_length_exceeded'}} ``` -This is because the content of the webpage is too large to fit in the LLM's input token limit. We can fix this by storing the content in Task Memory, and then querying it with the `TaskMemoryClient`. -Note that we're setting `off_prompt` to `False` on the `TaskMemoryClient` so that the _queried_ content can be returned directly to the LLM. +This is because the content of the webpage is too large to fit in the LLM's input token limit. We can fix this by storing the content in Task Memory, and then querying it with the `TaskMemoryTool`. +Note that we're setting `off_prompt` to `False` on the `TaskMemoryTool` so that the _queried_ content can be returned directly to the LLM. ```python --8<-- "docs/griptape-framework/structures/src/task_memory_5.py" @@ -146,7 +146,7 @@ And now we get the expected output: [04/26/24 13:52:11] INFO Subtask f12eb3d3b4924e4085808236b460b43d Thought: Now that the webpage content is stored in memory, I need to query this memory to find the information about how many copies of Elden Ring have been sold. - Actions: [{"tag": "query_sales", "name": "TaskMemoryClient", "path": "query", "input": {"values": {"memory_name": "TaskMemory", + Actions: [{"tag": "query_sales", "name": "TaskMemoryTool", "path": "query", "input": {"values": {"memory_name": "TaskMemory", "artifact_namespace": "2d4ebc7211074bb7be26613eb25d8fc1", "query": "How many copies of Elden Ring have been sold?"}}}] [04/26/24 13:52:14] INFO Subtask f12eb3d3b4924e4085808236b460b43d Response: Elden Ring sold 23 million copies by February 2024. @@ -192,13 +192,13 @@ In this example, GPT-4 _never_ sees the contents of the page, only that it was s information about how many copies of Elden Ring have been sold. Actions: [{"tag": "query_sales", "name": - "TaskMemoryClient", "path": "query", "input": + "TaskMemoryTool", "path": "query", "input": {"values": {"memory_name": "TaskMemory", "artifact_namespace": "7e48bcff0da94ad3b06aa4e173f8f37b", "query": "How many copies of Elden Ring have been sold?"}}}] [06/21/24 16:00:19] INFO Subtask 56102d42475d413299ce52a0230506b7 - Response: Output of "TaskMemoryClient.query" was + Response: Output of "TaskMemoryTool.query" was stored in memory with memory_name "TaskMemory" and artifact_namespace "9ecf4d7b7d0c46149dfc46ba236f178e" @@ -229,11 +229,11 @@ As seen in the previous example, certain Tools are designed to read directly fro Today, these include: -- [TaskMemoryClient](../../griptape-tools/official-tools/task-memory-client.md) -- [FileManager](../../griptape-tools/official-tools/file-manager.md) -- [AwsS3Client](../../griptape-tools/official-tools/aws-s3-client.md) -- [GoogleDriveClient](../../griptape-tools/official-tools/google-drive-client.md) -- [GoogleDocsClient](../../griptape-tools/official-tools/google-docs-client.md) +- [TaskMemoryTool](../../griptape-tools/official-tools/task-memory-tool.md) +- [FileManager](../../griptape-tools/official-tools/file-manager-tool.md) +- [AwsS3Tool](../../griptape-tools/official-tools/aws-s3-tool.md) +- [GoogleDriveTool](../../griptape-tools/official-tools/google-drive-tool.md) +- [GoogleDocsTool](../../griptape-tools/official-tools/google-docs-tool.md) ## Task Memory Considerations diff --git a/docs/griptape-framework/structures/tasks.md b/docs/griptape-framework/structures/tasks.md index a57c1e604..477ce7b64 100644 --- a/docs/griptape-framework/structures/tasks.md +++ b/docs/griptape-framework/structures/tasks.md @@ -112,7 +112,7 @@ This Task takes in one or more Tools which the LLM will decide to use through Ch [09/08/23 11:15:11] INFO Subtask a22a7e4ebf594b4b895fcbe8a95c1dd3 Thought: Now that the webpage content is stored in memory, I can use the TaskMemory tool's summarize activity to summarize it. - Action: {"name": "TaskMemoryClient", "path": "summarize", "input": {"values": {"memory_name": "TaskMemory", "artifact_namespace": "2b50373849d140f698ba8071066437ee"}}} + Action: {"name": "TaskMemoryTool", "path": "summarize", "input": {"values": {"memory_name": "TaskMemory", "artifact_namespace": "2b50373849d140f698ba8071066437ee"}}} [09/08/23 11:15:15] INFO Subtask a22a7e4ebf594b4b895fcbe8a95c1dd3 Response: Griptape is an open source framework that allows developers to build and deploy AI applications using large language models (LLMs). It provides the ability to create conversational and event-driven apps that @@ -149,7 +149,7 @@ This Task takes in a single Tool which the LLM will use without Chain of Thought [10/20/23 14:20:29] INFO Subtask a9a9ad7be2bf465fa82bd350116fabe4 Action: { - "name": "Calculator", + "name": "CalculatorTool", "path": "calculate", "input": { "values": { diff --git a/docs/griptape-framework/tools/index.md b/docs/griptape-framework/tools/index.md index c974b9658..808849dae 100644 --- a/docs/griptape-framework/tools/index.md +++ b/docs/griptape-framework/tools/index.md @@ -33,7 +33,7 @@ Here is an example of a Pipeline using Tools: [09/08/23 10:54:09] INFO Subtask 7ee08458ce154e3d970711b7d3ed79ba Thought: Now that the webpage content is stored in memory, I can use the TaskMemory tool with the summarize activity to summarize the content. - Action: {"name": "TaskMemoryClient", "path": "summarize", "input": {"values": {"memory_name": "TaskMemory", "artifact_namespace": "9eb6f5828cf64356bf323f11d28be27e"}}} + Action: {"name": "TaskMemoryTool", "path": "summarize", "input": {"values": {"memory_name": "TaskMemory", "artifact_namespace": "9eb6f5828cf64356bf323f11d28be27e"}}} [09/08/23 10:54:12] INFO Subtask 7ee08458ce154e3d970711b7d3ed79ba Response: Griptape is an open source framework that allows developers to build and deploy AI applications using large language models (LLMs). It provides the ability to create conversational and event-driven apps that diff --git a/docs/griptape-framework/tools/src/index_1.py b/docs/griptape-framework/tools/src/index_1.py index 52d8ac8e7..488dcfbc8 100644 --- a/docs/griptape-framework/tools/src/index_1.py +++ b/docs/griptape-framework/tools/src/index_1.py @@ -1,13 +1,13 @@ from griptape.structures import Pipeline from griptape.tasks import ToolkitTask -from griptape.tools import FileManager, TaskMemoryClient, WebScraper +from griptape.tools import FileManagerTool, TaskMemoryTool, WebScraperTool pipeline = Pipeline() pipeline.add_tasks( ToolkitTask( "Load https://www.griptape.ai, summarize it, and store it in a file called griptape.txt", - tools=[WebScraper(off_prompt=True), FileManager(off_prompt=True), TaskMemoryClient(off_prompt=False)], + tools=[WebScraperTool(off_prompt=True), FileManagerTool(off_prompt=True), TaskMemoryTool(off_prompt=False)], ), ) diff --git a/docs/griptape-tools/official-tools/audio-transcription-client.md b/docs/griptape-tools/official-tools/audio-transcription-tool.md similarity index 90% rename from docs/griptape-tools/official-tools/audio-transcription-client.md rename to docs/griptape-tools/official-tools/audio-transcription-tool.md index 271144d5b..ad8eeaa9b 100644 --- a/docs/griptape-tools/official-tools/audio-transcription-client.md +++ b/docs/griptape-tools/official-tools/audio-transcription-tool.md @@ -1,7 +1,7 @@ -# AudioTranscriptionClient +# Audio Transcription Tool This Tool enables [Agents](../../griptape-framework/structures/agents.md) to transcribe speech from text using [Audio Transcription Engines](../../reference/griptape/engines/audio/audio_transcription_engine.md) and [Audio Transcription Drivers](../../reference/griptape/drivers/audio_transcription/index.md). ```python ---8<-- "docs/griptape-tools/official-tools/src/audio_transcription_client_1.py" +--8<-- "docs/griptape-tools/official-tools/src/audio_transcription_tool_1.py" ``` diff --git a/docs/griptape-tools/official-tools/aws-iam-client.md b/docs/griptape-tools/official-tools/aws-iam-tool.md similarity index 89% rename from docs/griptape-tools/official-tools/aws-iam-client.md rename to docs/griptape-tools/official-tools/aws-iam-tool.md index 22b4115a3..1f594cff0 100644 --- a/docs/griptape-tools/official-tools/aws-iam-client.md +++ b/docs/griptape-tools/official-tools/aws-iam-tool.md @@ -1,22 +1,22 @@ -# AwsIamClient +# Aws Iam Tool This tool enables LLMs to make AWS IAM API requests. ```python ---8<-- "docs/griptape-tools/official-tools/src/aws_iam_client_1.py" +--8<-- "docs/griptape-tools/official-tools/src/aws_iam_tool_1.py" ``` ``` [09/11/23 16:45:45] INFO Task 890fcf77fb074c9490d5c91563e0c995 Input: List all my IAM users [09/11/23 16:45:51] INFO Subtask f2f0809ee10d4538972ed01fdd6a2fb8 Thought: To list all IAM users, I can use the - AwsIamClient tool with the list_users activity. + AwsIamTool tool with the list_users activity. This activity does not require any input. - Action: {"name": "AwsIamClient", + Action: {"name": "AwsIamTool", "path": "list_users"} [09/11/23 16:45:52] INFO Subtask f2f0809ee10d4538972ed01fdd6a2fb8 - Response: Output of "AwsIamClient.list_users" + Response: Output of "AwsIamTool.list_users" was stored in memory with memory_name "TaskMemory" and artifact_namespace "51d22a018a434904a5da3bb8d4f763f7" @@ -25,7 +25,7 @@ This tool enables LLMs to make AWS IAM API requests. stored in memory. I can retrieve this information using the TaskMemory tool with the summarize activity. - Action: {"name": "TaskMemoryClient", "path": + Action: {"name": "TaskMemoryTool", "path": "summarize", "input": {"values": {"memory_name": "TaskMemory", "artifact_namespace": "51d22a018a434904a5da3bb8d4f763f7"}}} diff --git a/docs/griptape-tools/official-tools/aws-s3-client.md b/docs/griptape-tools/official-tools/aws-s3-tool.md similarity index 90% rename from docs/griptape-tools/official-tools/aws-s3-client.md rename to docs/griptape-tools/official-tools/aws-s3-tool.md index 70ca79a20..9594861a9 100644 --- a/docs/griptape-tools/official-tools/aws-s3-client.md +++ b/docs/griptape-tools/official-tools/aws-s3-tool.md @@ -1,23 +1,23 @@ -# AwsS3Client +# Aws S3 Tool This tool enables LLMs to make AWS S3 API requests. ```python ---8<-- "docs/griptape-tools/official-tools/src/aws_s3_client_1.py" +--8<-- "docs/griptape-tools/official-tools/src/aws_s3_tool_1.py" ``` ``` [09/11/23 16:49:35] INFO Task 8bf7538e217a4b5a8472829f5eee75b9 Input: List all my S3 buckets. [09/11/23 16:49:41] INFO Subtask 9fc44f5c8e73447ba737283cb2ef7f5d Thought: To list all S3 buckets, I can use the - "list_s3_buckets" activity of the "AwsS3Client" + "list_s3_buckets" activity of the "AwsS3Tool" tool. This activity doesn't require any input. - Action: {"name": "AwsS3Client", + Action: {"name": "AwsS3Tool", "path": "list_s3_buckets"} [09/11/23 16:49:42] INFO Subtask 9fc44f5c8e73447ba737283cb2ef7f5d Response: Output of - "AwsS3Client.list_s3_buckets" was stored in memory + "AwsS3Tool.list_s3_buckets" was stored in memory with memory_name "TaskMemory" and artifact_namespace "f2592085fd4a430286a46770ea508cc9" @@ -26,7 +26,7 @@ This tool enables LLMs to make AWS S3 API requests. activity is stored in memory. I can retrieve this information using the "summarize" activity of the "TaskMemory" tool. - Action: {"name": "TaskMemoryClient", "path": + Action: {"name": "TaskMemoryTool", "path": "summarize", "input": {"values": {"memory_name": "TaskMemory", "artifact_namespace": "f2592085fd4a430286a46770ea508cc9"}}} diff --git a/docs/griptape-tools/official-tools/calculator.md b/docs/griptape-tools/official-tools/calculator-tool.md similarity index 82% rename from docs/griptape-tools/official-tools/calculator.md rename to docs/griptape-tools/official-tools/calculator-tool.md index 55ac039db..afe17a364 100644 --- a/docs/griptape-tools/official-tools/calculator.md +++ b/docs/griptape-tools/official-tools/calculator-tool.md @@ -1,9 +1,9 @@ -# Calculator +# Calculator Tool This tool enables LLMs to make simple calculations. ```python ---8<-- "docs/griptape-tools/official-tools/src/calculator_1.py" +--8<-- "docs/griptape-tools/official-tools/src/calculator_tool_1.py" ``` ``` [09/08/23 14:23:51] INFO Task bbc6002a5e5b4655bb52b6a550a1b2a5 @@ -12,9 +12,9 @@ This tool enables LLMs to make simple calculations. Thought: The question is asking for the result of 10 raised to the power of 5. This is a mathematical operation that can be performed using the - Calculator tool. + CalculatorTool tool. - Action: {"name": "Calculator", + Action: {"name": "CalculatorTool", "path": "calculate", "input": {"values": {"expression": "10**5"}}} INFO Subtask 3e9211a0f44c4277812ae410c43adbc9 diff --git a/docs/griptape-tools/official-tools/computer.md b/docs/griptape-tools/official-tools/computer-tool.md similarity index 97% rename from docs/griptape-tools/official-tools/computer.md rename to docs/griptape-tools/official-tools/computer-tool.md index 121224b20..758dd8714 100644 --- a/docs/griptape-tools/official-tools/computer.md +++ b/docs/griptape-tools/official-tools/computer-tool.md @@ -1,11 +1,11 @@ -# Computer +# Computer Tool This tool enables LLMs to execute Python code and run shell commands inside a Docker container. You have to have the Docker daemon running in order for this tool to work. You can specify a local working directory and environment variables during tool initialization: ```python ---8<-- "docs/griptape-tools/official-tools/src/computer_1.py" +--8<-- "docs/griptape-tools/official-tools/src/computer_tool_1.py" ``` ``` [09/11/23 16:24:15] INFO Task d08009ee983c4286ba10f83bcf3080e6 @@ -46,7 +46,7 @@ You can specify a local working directory and environment variables during tool in memory. I need to retrieve this output to check if "my_new_file.txt" is listed, which would confirm that the file was created successfully. - Action: {"name": "TaskMemoryClient", "path": + Action: {"name": "TaskMemoryTool", "path": "query", "input": {"values": {"memory_name": "TaskMemory", "artifact_namespace": "82bc4937564e4901b7fc51fced45b643", "query": "Is diff --git a/docs/griptape-tools/official-tools/date-time.md b/docs/griptape-tools/official-tools/date-time-tool.md similarity index 93% rename from docs/griptape-tools/official-tools/date-time.md rename to docs/griptape-tools/official-tools/date-time-tool.md index 76e453f39..bdc5ccbf4 100644 --- a/docs/griptape-tools/official-tools/date-time.md +++ b/docs/griptape-tools/official-tools/date-time-tool.md @@ -1,9 +1,9 @@ -# DateTime +# Date Time Tool This tool enables LLMs to get current date and time. ```python ---8<-- "docs/griptape-tools/official-tools/src/date_time_1.py" +--8<-- "docs/griptape-tools/official-tools/src/date_time_tool_1.py" ``` ``` [09/11/23 15:26:02] INFO Task d0bf49dacd8849e695494578a333f6cc diff --git a/docs/griptape-tools/official-tools/email-client.md b/docs/griptape-tools/official-tools/email-client.md deleted file mode 100644 index 66decf820..000000000 --- a/docs/griptape-tools/official-tools/email-client.md +++ /dev/null @@ -1,13 +0,0 @@ -# EmailClient - -The [EmailClient](../../reference/griptape/tools/email_client/tool.md) enables LLMs to send emails. - -```python ---8<-- "docs/griptape-tools/official-tools/src/email_client_1.py" -``` - -For debugging purposes, you can run a local SMTP server that the LLM can send emails to: - -```shell -python -m smtpd -c DebuggingServer -n localhost:1025 -``` diff --git a/docs/griptape-tools/official-tools/email-tool.md b/docs/griptape-tools/official-tools/email-tool.md new file mode 100644 index 000000000..91ba6f19b --- /dev/null +++ b/docs/griptape-tools/official-tools/email-tool.md @@ -0,0 +1,13 @@ +# Email Tool + +The [EmailTool](../../reference/griptape/tools/email/tool.md) enables LLMs to send emails. + +```python +--8<-- "docs/griptape-tools/official-tools/src/email_tool_1.py" +``` + +For debugging purposes, you can run a local SMTP server that the LLM can send emails to: + +```shell +python -m smtpd -c DebuggingServer -n localhost:1025 +``` diff --git a/docs/griptape-tools/official-tools/file-manager.md b/docs/griptape-tools/official-tools/file-manager-tool.md similarity index 94% rename from docs/griptape-tools/official-tools/file-manager.md rename to docs/griptape-tools/official-tools/file-manager-tool.md index 539c4fff4..5f27b8da5 100644 --- a/docs/griptape-tools/official-tools/file-manager.md +++ b/docs/griptape-tools/official-tools/file-manager-tool.md @@ -1,9 +1,9 @@ -# FileManager +# File Manager Tool This tool enables LLMs to save and load files. ```python ---8<-- "docs/griptape-tools/official-tools/src/file_manager_1.py" +--8<-- "docs/griptape-tools/official-tools/src/file_manager_tool_1.py" ``` ``` [09/12/23 12:07:56] INFO Task 16a1ce1847284ae3805485bad7d99116 diff --git a/docs/griptape-tools/official-tools/google-cal-client.md b/docs/griptape-tools/official-tools/google-cal-client.md deleted file mode 100644 index 9757bdcbf..000000000 --- a/docs/griptape-tools/official-tools/google-cal-client.md +++ /dev/null @@ -1,8 +0,0 @@ -# GoogleCalendarClient - -The GoogleCalendarClient tool allows you to interact with Google Calendar. - - -```python ---8<-- "docs/griptape-tools/official-tools/src/google_cal_client_1.py" -``` diff --git a/docs/griptape-tools/official-tools/google-calendar-tool.md b/docs/griptape-tools/official-tools/google-calendar-tool.md new file mode 100644 index 000000000..e0b5d9cdc --- /dev/null +++ b/docs/griptape-tools/official-tools/google-calendar-tool.md @@ -0,0 +1,8 @@ +# Google Calendar Tool + +The [GoogleCalendarTool](../../reference/griptape/tools/google_calendar/tool.md) tool allows you to interact with Google Calendar. + + +```python +--8<-- "docs/griptape-tools/official-tools/src/google_calendar_tool_1.py" +``` diff --git a/docs/griptape-tools/official-tools/google-docs-client.md b/docs/griptape-tools/official-tools/google-docs-tool.md similarity index 79% rename from docs/griptape-tools/official-tools/google-docs-client.md rename to docs/griptape-tools/official-tools/google-docs-tool.md index ff23d8e89..1f02196b9 100644 --- a/docs/griptape-tools/official-tools/google-docs-client.md +++ b/docs/griptape-tools/official-tools/google-docs-tool.md @@ -1,9 +1,9 @@ -# GoogleDocsClient +# Google Docs Tool -The GoogleDocsClient tool provides a way to interact with the Google Docs API. It can be used to create new documents, save content to existing documents, and more. +The [GoogleDocsTool](../../reference/griptape/tools/google_docs/tool.md) tool provides a way to interact with the Google Docs API. It can be used to create new documents, save content to existing documents, and more. ```python ---8<-- "docs/griptape-tools/official-tools/src/google_docs_client_1.py" +--8<-- "docs/griptape-tools/official-tools/src/google_docs_tool_1.py" ``` ``` [10/05/23 12:56:19] INFO ToolkitTask 90721b7478a74618a63d852d35be3b18 @@ -14,10 +14,10 @@ The GoogleDocsClient tool provides a way to interact with the Google Docs API. I named 'test_creation' in a folder named 'test' with the content 'Hey, Tony.'. I can use the 'save_content_to_google_doc' activity of the - GoogleDocsClient tool to achieve this. + GoogleDocsTool tool to achieve this. Action: {"name": - "GoogleDocsClient", "path": + "GoogleDocsTool", "path": "save_content_to_google_doc", "input": {"values": {"file_path": "test_creation", "content": "Hey, Tony.", "folder_path": "test"}}} diff --git a/docs/griptape-tools/official-tools/google-drive-client.md b/docs/griptape-tools/official-tools/google-drive-tool.md similarity index 77% rename from docs/griptape-tools/official-tools/google-drive-client.md rename to docs/griptape-tools/official-tools/google-drive-tool.md index 9ad3daccc..18e10ec08 100644 --- a/docs/griptape-tools/official-tools/google-drive-client.md +++ b/docs/griptape-tools/official-tools/google-drive-tool.md @@ -1,9 +1,9 @@ -# GoogleDriveClient +# Google Drive Tool -The GoogleDriveClient tool provides a way to interact with the Google Drive API. It can be used to save content on Drive, list files, and more. +The [GoogleDriveTool](../../reference/griptape/tools/google_drive/tool.md) tool provides a way to interact with the Google Drive API. It can be used to save content on Drive, list files, and more. ```python ---8<-- "docs/griptape-tools/official-tools/src/google_drive_client_1.py" +--8<-- "docs/griptape-tools/official-tools/src/google_drive_tool_1.py" ``` ``` [10/05/23 10:49:14] INFO ToolkitTask 2ae3bb7e828744f3a2631c29c6fce001 @@ -13,11 +13,11 @@ The GoogleDriveClient tool provides a way to interact with the Google Drive API. Thought: The user wants to save the content 'Hi this is Tony' in a file named 'hello.txt' to Google Drive. I can use the 'save_content_to_drive' - activity of the GoogleDriveClient tool to + activity of the GoogleDriveTool tool to accomplish this. Action: {"name": - "GoogleDriveClient", "path": + "GoogleDriveTool", "path": "save_content_to_drive", "input": {"values": {"path": "hello.txt", "content": "Hi this is Tony"}}} diff --git a/docs/griptape-tools/official-tools/google-gmail-client.md b/docs/griptape-tools/official-tools/google-gmail-tool.md similarity index 78% rename from docs/griptape-tools/official-tools/google-gmail-client.md rename to docs/griptape-tools/official-tools/google-gmail-tool.md index fe6952dc3..1a9e6ea47 100644 --- a/docs/griptape-tools/official-tools/google-gmail-client.md +++ b/docs/griptape-tools/official-tools/google-gmail-tool.md @@ -1,9 +1,9 @@ -# GoogleGmailClient +# Google Gmail Tool -The GoogleGmailClient tool provides a way to interact with the Gmail API. It can be used to create draft emails, send emails, and more. +The [GoogleGmailTool](../../reference/griptape/tools/google_gmail/tool.md) tool provides a way to interact with the Gmail API. It can be used to create draft emails, send emails, and more. ```python ---8<-- "docs/griptape-tools/official-tools/src/google_gmail_client_1.py" +--8<-- "docs/griptape-tools/official-tools/src/google_gmail_tool_1.py" ``` ``` [10/05/23 13:24:05] INFO ToolkitTask 1f190f823d584053bfe9942f41b6cb2d @@ -12,13 +12,13 @@ The GoogleGmailClient tool provides a way to interact with the Gmail API. It can the body 'This is a test draft email.' [10/05/23 13:24:15] INFO Subtask 7f2cce7e5b0e425ba696531561697b96 Thought: The user wants to create a draft email in - Gmail. I can use the GoogleGmailClient tool with + Gmail. I can use the GoogleGmailTool tool with the create_draft_email activity to accomplish this. I will need to provide the 'to', 'subject', and 'body' values as input. Action: {"name": - "GoogleGmailClient", "path": + "GoogleGmailTool", "path": "create_draft_email", "input": {"values": {"to": "example@email.com", "subject": "Test Draft", "body": "This is a test draft email."}}} diff --git a/docs/griptape-tools/official-tools/griptape-cloud-knowledge-base-client.md b/docs/griptape-tools/official-tools/griptape-cloud-knowledge-base-client.md deleted file mode 100644 index 28e89b4bb..000000000 --- a/docs/griptape-tools/official-tools/griptape-cloud-knowledge-base-client.md +++ /dev/null @@ -1,9 +0,0 @@ -## Overview - -The `GriptapeCloudKnowledgeBaseClient` is a lightweight Tool to retrieve data from a RAG pipeline and vector store hosted in [Griptape Cloud](https://cloud.griptape.ai). It enables searching across a centralized [Knowledge Base](https://cloud.griptape.ai/knowledge-bases) that can consist of various data sources such as Confluence, Google Docs, and web pages. - -**Note:** This tool requires a [Knowledge Base](https://cloud.griptape.ai/knowledge-bases) hosted in Griptape Cloud and an [API Key](https://cloud.griptape.ai/account/api-keys) for access. - -```python ---8<-- "docs/griptape-tools/official-tools/src/griptape_cloud_knowledge_base_client_1.py" -``` diff --git a/docs/griptape-tools/official-tools/griptape-cloud-knowledge-base-tool.md b/docs/griptape-tools/official-tools/griptape-cloud-knowledge-base-tool.md new file mode 100644 index 000000000..96af51782 --- /dev/null +++ b/docs/griptape-tools/official-tools/griptape-cloud-knowledge-base-tool.md @@ -0,0 +1,9 @@ +# Griptape Cloud Knowledge Base Tool + +The [GriptapeCloudKnowledgeBaseTool](../../reference/griptape/tools/griptape_cloud_knowledge_base/tool.md) is a lightweight Tool to retrieve data from a RAG pipeline and vector store hosted in [Griptape Cloud](https://cloud.griptape.ai). It enables searching across a centralized [Knowledge Base](https://cloud.griptape.ai/knowledge-bases) that can consist of various data sources such as Confluence, Google Docs, and web pages. + +**Note:** This tool requires a [Knowledge Base](https://cloud.griptape.ai/knowledge-bases) hosted in Griptape Cloud and an [API Key](https://cloud.griptape.ai/account/api-keys) for access. + +```python +--8<-- "docs/griptape-tools/official-tools/src/griptape_cloud_knowledge_base_tool_1.py" +``` diff --git a/docs/griptape-tools/official-tools/image-query-client.md b/docs/griptape-tools/official-tools/image-query-tool.md similarity index 58% rename from docs/griptape-tools/official-tools/image-query-client.md rename to docs/griptape-tools/official-tools/image-query-tool.md index ae02d127c..781a279a8 100644 --- a/docs/griptape-tools/official-tools/image-query-client.md +++ b/docs/griptape-tools/official-tools/image-query-tool.md @@ -1,7 +1,7 @@ -# ImageQueryClient +# Image Query Tool This tool allows Agents to execute natural language queries on the contents of images using multimodal models. ```python ---8<-- "docs/griptape-tools/official-tools/src/image_query_client_1.py" +--8<-- "docs/griptape-tools/official-tools/src/image_query_tool_1.py" ``` diff --git a/docs/griptape-tools/official-tools/inpainting-image-generation-client.md b/docs/griptape-tools/official-tools/inpainting-image-generation-tool.md similarity index 87% rename from docs/griptape-tools/official-tools/inpainting-image-generation-client.md rename to docs/griptape-tools/official-tools/inpainting-image-generation-tool.md index 65de566ad..7abdc6238 100644 --- a/docs/griptape-tools/official-tools/inpainting-image-generation-client.md +++ b/docs/griptape-tools/official-tools/inpainting-image-generation-tool.md @@ -1,7 +1,7 @@ -# InpaintingImageGenerationClient +# Inpainting Image Generation Tool This tool allows LLMs to generate images using inpainting, where an input image is altered within the area specified by a mask image according to a prompt. The input and mask images can be provided either by their file path or by their [Task Memory](../../griptape-framework/structures/task-memory.md) references. ```python ---8<-- "docs/griptape-tools/official-tools/src/inpainting_image_generation_client_1.py" +--8<-- "docs/griptape-tools/official-tools/src/inpainting_image_generation_tool_1.py" ``` diff --git a/docs/griptape-tools/official-tools/openweather-client.md b/docs/griptape-tools/official-tools/openweather-client.md deleted file mode 100644 index 3521733c5..000000000 --- a/docs/griptape-tools/official-tools/openweather-client.md +++ /dev/null @@ -1,7 +0,0 @@ -# OpenWeatherClient - -The [OpenWeatherClient](../../reference/griptape/tools/openweather_client/tool.md) enables LLMs to use [OpenWeatherMap](https://openweathermap.org/). - -```python ---8<-- "docs/griptape-tools/official-tools/src/openweather_client_1.py" -``` diff --git a/docs/griptape-tools/official-tools/openweather-tool.md b/docs/griptape-tools/official-tools/openweather-tool.md new file mode 100644 index 000000000..be1ed3972 --- /dev/null +++ b/docs/griptape-tools/official-tools/openweather-tool.md @@ -0,0 +1,7 @@ +# Open Weather Tool + +The [OpenWeatherTool](../../reference/griptape/tools/openweather/tool.md) enables LLMs to use [OpenWeatherMap](https://openweathermap.org/). + +```python +--8<-- "docs/griptape-tools/official-tools/src/openweather_tool_1.py" +``` diff --git a/docs/griptape-tools/official-tools/outpainting-image-generation-client.md b/docs/griptape-tools/official-tools/outpainting-image-generation-tool.md similarity index 86% rename from docs/griptape-tools/official-tools/outpainting-image-generation-client.md rename to docs/griptape-tools/official-tools/outpainting-image-generation-tool.md index e62a40a73..ce97798bc 100644 --- a/docs/griptape-tools/official-tools/outpainting-image-generation-client.md +++ b/docs/griptape-tools/official-tools/outpainting-image-generation-tool.md @@ -1,7 +1,7 @@ -# OutpaintingImageGenerationClient +# Outpainting Image Generation Tool This tool allows LLMs to generate images using outpainting, where an input image is altered outside of the area specified by a mask image according to a prompt. The input and mask images can be provided either by their file path or by their [Task Memory](../../griptape-framework/structures/task-memory.md) references. ```python ---8<-- "docs/griptape-tools/official-tools/src/outpainting_image_generation_client_1.py" +--8<-- "docs/griptape-tools/official-tools/src/outpainting_image_generation_tool_1.py" ``` diff --git a/docs/griptape-tools/official-tools/prompt-image-generation-client.md b/docs/griptape-tools/official-tools/prompt-image-generation-tool.md similarity index 73% rename from docs/griptape-tools/official-tools/prompt-image-generation-client.md rename to docs/griptape-tools/official-tools/prompt-image-generation-tool.md index d91b154be..c764791dc 100644 --- a/docs/griptape-tools/official-tools/prompt-image-generation-client.md +++ b/docs/griptape-tools/official-tools/prompt-image-generation-tool.md @@ -1,7 +1,7 @@ -# PromptImageGenerationClient +# Prompt Image Generation Tool This tool allows LLMs to generate images from a text prompt. ```python ---8<-- "docs/griptape-tools/official-tools/src/prompt_image_generation_client_1.py" +--8<-- "docs/griptape-tools/official-tools/src/prompt_image_generation_tool_1.py" ``` diff --git a/docs/griptape-tools/official-tools/rag-client.md b/docs/griptape-tools/official-tools/rag-tool.md similarity index 86% rename from docs/griptape-tools/official-tools/rag-client.md rename to docs/griptape-tools/official-tools/rag-tool.md index c90762946..71613beab 100644 --- a/docs/griptape-tools/official-tools/rag-client.md +++ b/docs/griptape-tools/official-tools/rag-tool.md @@ -1,9 +1,11 @@ -The [RagClient](../../reference/griptape/tools/rag_client/tool.md) enables LLMs to query modular RAG engines. +# Rag Tool + +The [RagTool](../../reference/griptape/tools/rag/tool.md) enables LLMs to query modular RAG engines. Here is an example of how it can be used with a local vector store driver: ```python ---8<-- "docs/griptape-tools/official-tools/src/rag_client_1.py" +--8<-- "docs/griptape-tools/official-tools/src/rag_tool_1.py" ``` ``` [07/11/24 13:30:43] INFO ToolkitTask a6d057d5c71d4e9cb6863a2adb64b76c @@ -12,7 +14,7 @@ Here is an example of how it can be used with a local vector store driver: Actions: [ { "tag": "call_4MaDzOuKnWAs2gmhK3KJhtjI", - "name": "RagClient", + "name": "RagTool", "path": "search", "input": { "values": { diff --git a/docs/griptape-tools/official-tools/rest-api-client.md b/docs/griptape-tools/official-tools/rest-api-tool.md similarity index 50% rename from docs/griptape-tools/official-tools/rest-api-client.md rename to docs/griptape-tools/official-tools/rest-api-tool.md index 8ec4804c2..345f0589b 100644 --- a/docs/griptape-tools/official-tools/rest-api-client.md +++ b/docs/griptape-tools/official-tools/rest-api-tool.md @@ -1,12 +1,12 @@ -# RestApiClient +# Rest Api Tool This tool enables LLMs to call REST APIs. -The [RestApiClient](../../reference/griptape/tools/rest_api_client/tool.md) tool uses the following parameters: +The [RestApiTool](../../reference/griptape/tools/rest_api/tool.md) tool uses the following parameters: ### Example The following example is built using [https://jsonplaceholder.typicode.com/guide/](https://jsonplaceholder.typicode.com/guide/). ```python ---8<-- "docs/griptape-tools/official-tools/src/rest_api_client_1.py" +--8<-- "docs/griptape-tools/official-tools/src/rest_api_tool_1.py" ``` diff --git a/docs/griptape-tools/official-tools/sql-client.md b/docs/griptape-tools/official-tools/sql-tool.md similarity index 90% rename from docs/griptape-tools/official-tools/sql-client.md rename to docs/griptape-tools/official-tools/sql-tool.md index 1d0d7abb0..b20327013 100644 --- a/docs/griptape-tools/official-tools/sql-client.md +++ b/docs/griptape-tools/official-tools/sql-tool.md @@ -1,23 +1,23 @@ -# SqlClient +# Sql Tool This tool enables LLMs to execute SQL statements via [SQLAlchemy](https://www.sqlalchemy.org/). Depending on your underlying SQL engine, [configure](https://docs.sqlalchemy.org/en/20/core/engines.html) your `engine_url` and give the LLM a hint about what engine you are using via `engine_name`, so that it can create engine-specific statements. ```python ---8<-- "docs/griptape-tools/official-tools/src/sql_client_1.py" +--8<-- "docs/griptape-tools/official-tools/src/sql_tool_1.py" ``` ``` [09/11/23 17:02:55] INFO Task d8331f8705b64b4b9d9a88137ed73f3f Input: SELECT * FROM people; [09/11/23 17:03:02] INFO Subtask 46c2f8926ce9469e9ca6b1b3364e3e41 Thought: The user wants to retrieve all records - from the 'people' table. I can use the SqlClient + from the 'people' table. I can use the SqlTool tool to execute this query. - Action: {"name": "SqlClient", + Action: {"name": "SqlTool", "path": "execute_query", "input": {"values": {"sql_query": "SELECT * FROM people;"}}} [09/11/23 17:03:03] INFO Subtask 46c2f8926ce9469e9ca6b1b3364e3e41 - Response: Output of "SqlClient.execute_query" + Response: Output of "SqlTool.execute_query" was stored in memory with memory_name "TaskMemory" and artifact_namespace "217715ba3e444e4985bee223df5716a8" @@ -25,7 +25,7 @@ This tool enables LLMs to execute SQL statements via [SQLAlchemy](https://www.sq Thought: The output of the SQL query has been stored in memory. I can retrieve this data using the TaskMemory's 'summarize' activity. - Action: {"name": "TaskMemoryClient", "path": + Action: {"name": "TaskMemoryTool", "path": "summarize", "input": {"values": {"memory_name": "TaskMemory", "artifact_namespace": "217715ba3e444e4985bee223df5716a8"}}} diff --git a/docs/griptape-tools/official-tools/src/audio_transcription_client_1.py b/docs/griptape-tools/official-tools/src/audio_transcription_tool_1.py similarity index 79% rename from docs/griptape-tools/official-tools/src/audio_transcription_client_1.py rename to docs/griptape-tools/official-tools/src/audio_transcription_tool_1.py index d2c54e0c9..bc25fd1fa 100644 --- a/docs/griptape-tools/official-tools/src/audio_transcription_client_1.py +++ b/docs/griptape-tools/official-tools/src/audio_transcription_tool_1.py @@ -1,11 +1,11 @@ from griptape.drivers import OpenAiAudioTranscriptionDriver from griptape.engines import AudioTranscriptionEngine from griptape.structures import Agent -from griptape.tools.audio_transcription_client.tool import AudioTranscriptionClient +from griptape.tools.audio_transcription.tool import AudioTranscriptionTool driver = OpenAiAudioTranscriptionDriver(model="whisper-1") -tool = AudioTranscriptionClient( +tool = AudioTranscriptionTool( off_prompt=False, engine=AudioTranscriptionEngine( audio_transcription_driver=driver, diff --git a/docs/griptape-tools/official-tools/src/aws_iam_client_1.py b/docs/griptape-tools/official-tools/src/aws_iam_tool_1.py similarity index 72% rename from docs/griptape-tools/official-tools/src/aws_iam_client_1.py rename to docs/griptape-tools/official-tools/src/aws_iam_tool_1.py index 8fda31553..89718010f 100644 --- a/docs/griptape-tools/official-tools/src/aws_iam_client_1.py +++ b/docs/griptape-tools/official-tools/src/aws_iam_tool_1.py @@ -1,10 +1,10 @@ import boto3 from griptape.structures import Agent -from griptape.tools import AwsIamClient +from griptape.tools import AwsIamTool # Initialize the AWS IAM client -aws_iam_client = AwsIamClient(session=boto3.Session()) +aws_iam_client = AwsIamTool(session=boto3.Session()) # Create an agent with the AWS IAM client tool agent = Agent(tools=[aws_iam_client]) diff --git a/docs/griptape-tools/official-tools/src/aws_s3_client_1.py b/docs/griptape-tools/official-tools/src/aws_s3_tool_1.py similarity index 50% rename from docs/griptape-tools/official-tools/src/aws_s3_client_1.py rename to docs/griptape-tools/official-tools/src/aws_s3_tool_1.py index e1ba42525..973a0b881 100644 --- a/docs/griptape-tools/official-tools/src/aws_s3_client_1.py +++ b/docs/griptape-tools/official-tools/src/aws_s3_tool_1.py @@ -1,13 +1,13 @@ import boto3 from griptape.structures import Agent -from griptape.tools import AwsS3Client, TaskMemoryClient +from griptape.tools import AwsS3Tool, TaskMemoryTool # Initialize the AWS S3 client -aws_s3_client = AwsS3Client(session=boto3.Session(), off_prompt=True) +aws_s3_client = AwsS3Tool(session=boto3.Session(), off_prompt=True) # Create an agent with the AWS S3 client tool -agent = Agent(tools=[aws_s3_client, TaskMemoryClient(off_prompt=False)]) +agent = Agent(tools=[aws_s3_client, TaskMemoryTool(off_prompt=False)]) # Task to list all the AWS S3 buckets agent.run("List all my S3 buckets.") diff --git a/docs/griptape-tools/official-tools/src/calculator_1.py b/docs/griptape-tools/official-tools/src/calculator_tool_1.py similarity index 56% rename from docs/griptape-tools/official-tools/src/calculator_1.py rename to docs/griptape-tools/official-tools/src/calculator_tool_1.py index b28bb92b2..1263cad45 100644 --- a/docs/griptape-tools/official-tools/src/calculator_1.py +++ b/docs/griptape-tools/official-tools/src/calculator_tool_1.py @@ -1,8 +1,8 @@ from griptape.structures import Agent -from griptape.tools import Calculator +from griptape.tools import CalculatorTool -# Create an agent with the Calculator tool -agent = Agent(tools=[Calculator()]) +# Create an agent with the CalculatorTool tool +agent = Agent(tools=[CalculatorTool()]) # Run the agent with a task to perform the arithmetic calculation of \(10^5\) agent.run("What is 10 raised to the power of 5?") diff --git a/docs/griptape-tools/official-tools/src/computer_1.py b/docs/griptape-tools/official-tools/src/computer_tool_1.py similarity index 77% rename from docs/griptape-tools/official-tools/src/computer_1.py rename to docs/griptape-tools/official-tools/src/computer_tool_1.py index 7fa22a46b..b0892d425 100644 --- a/docs/griptape-tools/official-tools/src/computer_1.py +++ b/docs/griptape-tools/official-tools/src/computer_tool_1.py @@ -1,10 +1,10 @@ from griptape.structures import Agent -from griptape.tools import Computer +from griptape.tools import ComputerTool -# Initialize the Computer tool -computer = Computer() +# Initialize the ComputerTool tool +computer = ComputerTool() -# Create an agent with the Computer tool +# Create an agent with the ComputerTool tool agent = Agent(tools=[computer]) # Create a file using the shell command diff --git a/docs/griptape-tools/official-tools/src/date_time_1.py b/docs/griptape-tools/official-tools/src/date_time_1.py deleted file mode 100644 index 735b77307..000000000 --- a/docs/griptape-tools/official-tools/src/date_time_1.py +++ /dev/null @@ -1,8 +0,0 @@ -from griptape.structures import Agent -from griptape.tools import DateTime - -# Create an agent with the DateTime tool -agent = Agent(tools=[DateTime()]) - -# Fetch the current date and time -agent.run("What is the current date and time?") diff --git a/docs/griptape-tools/official-tools/src/date_time_tool_1.py b/docs/griptape-tools/official-tools/src/date_time_tool_1.py new file mode 100644 index 000000000..f806e5091 --- /dev/null +++ b/docs/griptape-tools/official-tools/src/date_time_tool_1.py @@ -0,0 +1,8 @@ +from griptape.structures import Agent +from griptape.tools import DateTimeTool + +# Create an agent with the DateTimeTool tool +agent = Agent(tools=[DateTimeTool()]) + +# Fetch the current date and time +agent.run("What is the current date and time?") diff --git a/docs/griptape-tools/official-tools/src/email_client_1.py b/docs/griptape-tools/official-tools/src/email_tool_1.py similarity index 79% rename from docs/griptape-tools/official-tools/src/email_client_1.py rename to docs/griptape-tools/official-tools/src/email_tool_1.py index e93a74f34..e9d3b3cee 100644 --- a/docs/griptape-tools/official-tools/src/email_client_1.py +++ b/docs/griptape-tools/official-tools/src/email_tool_1.py @@ -1,8 +1,8 @@ import os -from griptape.tools import EmailClient +from griptape.tools import EmailTool -email_client = EmailClient( +email_tool = EmailTool( smtp_host=os.environ.get("SMTP_HOST"), smtp_port=int(os.environ.get("SMTP_PORT", 465)), smtp_password=os.environ.get("SMTP_PASSWORD"), diff --git a/docs/griptape-tools/official-tools/src/file_manager_1.py b/docs/griptape-tools/official-tools/src/file_manager_tool_1.py similarity index 73% rename from docs/griptape-tools/official-tools/src/file_manager_1.py rename to docs/griptape-tools/official-tools/src/file_manager_tool_1.py index 16adf669c..0b5596d2b 100644 --- a/docs/griptape-tools/official-tools/src/file_manager_1.py +++ b/docs/griptape-tools/official-tools/src/file_manager_tool_1.py @@ -1,10 +1,10 @@ from pathlib import Path from griptape.structures import Agent -from griptape.tools import FileManager +from griptape.tools import FileManagerTool -# Initialize the FileManager tool with the current directory as its base -file_manager_tool = FileManager() +# Initialize the FileManagerTool tool with the current directory as its base +file_manager_tool = FileManagerTool() # Add the tool to the Agent agent = Agent(tools=[file_manager_tool]) diff --git a/docs/griptape-tools/official-tools/src/google_cal_client_1.py b/docs/griptape-tools/official-tools/src/google_calendar_tool_1.py similarity index 79% rename from docs/griptape-tools/official-tools/src/google_cal_client_1.py rename to docs/griptape-tools/official-tools/src/google_calendar_tool_1.py index 1b99d9ec4..afbb20c9f 100644 --- a/docs/griptape-tools/official-tools/src/google_cal_client_1.py +++ b/docs/griptape-tools/official-tools/src/google_calendar_tool_1.py @@ -1,10 +1,10 @@ import os from griptape.structures import Agent -from griptape.tools import GoogleCalendarClient +from griptape.tools import GoogleCalendarTool -# Create the GoogleCalendarClient tool -google_calendar_tool = GoogleCalendarClient( +# Create the GoogleCalendarTool tool +google_calendarendar_tool = GoogleCalendarTool( service_account_credentials={ "type": os.environ["GOOGLE_ACCOUNT_TYPE"], "project_id": os.environ["GOOGLE_PROJECT_ID"], @@ -20,8 +20,8 @@ owner_email=os.environ["GOOGLE_OWNER_EMAIL"], ) -# Set up an agent using the GoogleCalendarClient tool -agent = Agent(tools=[google_calendar_tool]) +# Set up an agent using the GoogleCalendarTool tool +agent = Agent(tools=[google_calendarendar_tool]) # Task: Get upcoming events from a Google calendar agent.run( diff --git a/docs/griptape-tools/official-tools/src/google_docs_client_1.py b/docs/griptape-tools/official-tools/src/google_docs_tool_1.py similarity index 85% rename from docs/griptape-tools/official-tools/src/google_docs_client_1.py rename to docs/griptape-tools/official-tools/src/google_docs_tool_1.py index 473bfbfd8..0d8e8a3cb 100644 --- a/docs/griptape-tools/official-tools/src/google_docs_client_1.py +++ b/docs/griptape-tools/official-tools/src/google_docs_tool_1.py @@ -1,10 +1,10 @@ import os from griptape.structures import Agent -from griptape.tools import GoogleDocsClient +from griptape.tools import GoogleDocsTool -# Create the GoogleDocsClient tool -google_docs_tool = GoogleDocsClient( +# Create the GoogleDocsTool tool +google_docs_tool = GoogleDocsTool( service_account_credentials={ "type": os.environ["GOOGLE_ACCOUNT_TYPE"], "project_id": os.environ["GOOGLE_PROJECT_ID"], @@ -20,7 +20,7 @@ owner_email=os.environ["GOOGLE_OWNER_EMAIL"], ) -# Set up an agent using the GoogleDocsClient tool +# Set up an agent using the GoogleDocsTool tool agent = Agent(tools=[google_docs_tool]) # Task: Create a new Google Doc and save content to it diff --git a/docs/griptape-tools/official-tools/src/google_drive_client_1.py b/docs/griptape-tools/official-tools/src/google_drive_tool_1.py similarity index 84% rename from docs/griptape-tools/official-tools/src/google_drive_client_1.py rename to docs/griptape-tools/official-tools/src/google_drive_tool_1.py index a020b1a96..d8e43a6db 100644 --- a/docs/griptape-tools/official-tools/src/google_drive_client_1.py +++ b/docs/griptape-tools/official-tools/src/google_drive_tool_1.py @@ -1,10 +1,10 @@ import os from griptape.structures import Agent -from griptape.tools import GoogleDriveClient +from griptape.tools import GoogleDriveTool -# Create the GoogleDriveClient tool -google_drive_tool = GoogleDriveClient( +# Create the GoogleDriveTool tool +google_drive_tool = GoogleDriveTool( service_account_credentials={ "type": os.environ["GOOGLE_ACCOUNT_TYPE"], "project_id": os.environ["GOOGLE_PROJECT_ID"], @@ -20,7 +20,7 @@ owner_email=os.environ["GOOGLE_OWNER_EMAIL"], ) -# Set up an agent using the GoogleDriveClient tool +# Set up an agent using the GoogleDriveTool tool agent = Agent(tools=[google_drive_tool]) # Task: Save content to my Google Drive (default directory is root) diff --git a/docs/griptape-tools/official-tools/src/google_gmail_client_1.py b/docs/griptape-tools/official-tools/src/google_gmail_tool_1.py similarity index 85% rename from docs/griptape-tools/official-tools/src/google_gmail_client_1.py rename to docs/griptape-tools/official-tools/src/google_gmail_tool_1.py index e9a075fa8..44e0ceb39 100644 --- a/docs/griptape-tools/official-tools/src/google_gmail_client_1.py +++ b/docs/griptape-tools/official-tools/src/google_gmail_tool_1.py @@ -1,10 +1,10 @@ import os from griptape.structures import Agent -from griptape.tools import GoogleGmailClient +from griptape.tools import GoogleGmailTool -# Create the GoogleGmailClient tool -gmail_tool = GoogleGmailClient( +# Create the GoogleGmailTool tool +gmail_tool = GoogleGmailTool( service_account_credentials={ "type": os.environ["GOOGLE_ACCOUNT_TYPE"], "project_id": os.environ["GOOGLE_PROJECT_ID"], @@ -20,7 +20,7 @@ owner_email=os.environ["GOOGLE_OWNER_EMAIL"], ) -# Set up an agent using the GoogleGmailClient tool +# Set up an agent using the GoogleGmailTool tool agent = Agent(tools=[gmail_tool]) # Task: Create a draft email in GMail diff --git a/docs/griptape-tools/official-tools/src/griptape_cloud_knowledge_base_client_1.py b/docs/griptape-tools/official-tools/src/griptape_cloud_knowledge_base_tool_1.py similarity index 75% rename from docs/griptape-tools/official-tools/src/griptape_cloud_knowledge_base_client_1.py rename to docs/griptape-tools/official-tools/src/griptape_cloud_knowledge_base_tool_1.py index 9cfd09a22..b8c294f6b 100644 --- a/docs/griptape-tools/official-tools/src/griptape_cloud_knowledge_base_client_1.py +++ b/docs/griptape-tools/official-tools/src/griptape_cloud_knowledge_base_tool_1.py @@ -1,9 +1,9 @@ import os from griptape.structures import Agent -from griptape.tools import GriptapeCloudKnowledgeBaseClient +from griptape.tools import GriptapeCloudKnowledgeBaseTool -knowledge_base_client = GriptapeCloudKnowledgeBaseClient( +knowledge_base_client = GriptapeCloudKnowledgeBaseTool( description="Contains information about the company and its operations", api_key=os.environ["GRIPTAPE_CLOUD_API_KEY"], knowledge_base_id=os.environ["GRIPTAPE_CLOUD_KB_ID"], diff --git a/docs/griptape-tools/official-tools/src/image_query_client_1.py b/docs/griptape-tools/official-tools/src/image_query_tool_1.py similarity index 80% rename from docs/griptape-tools/official-tools/src/image_query_client_1.py rename to docs/griptape-tools/official-tools/src/image_query_tool_1.py index 177154d2d..a4d69eafb 100644 --- a/docs/griptape-tools/official-tools/src/image_query_client_1.py +++ b/docs/griptape-tools/official-tools/src/image_query_tool_1.py @@ -1,7 +1,7 @@ from griptape.drivers import OpenAiImageQueryDriver from griptape.engines import ImageQueryEngine from griptape.structures import Agent -from griptape.tools import ImageQueryClient +from griptape.tools import ImageQueryTool # Create an Image Query Driver. driver = OpenAiImageQueryDriver(model="gpt-4o") @@ -11,8 +11,8 @@ image_query_driver=driver, ) -# Create an Image Query Client configured to use the engine. -tool = ImageQueryClient( +# Create an Image Query Tool configured to use the engine. +tool = ImageQueryTool( image_query_engine=engine, ) diff --git a/docs/griptape-tools/official-tools/src/inpainting_image_generation_client_1.py b/docs/griptape-tools/official-tools/src/inpainting_image_generation_tool_1.py similarity index 90% rename from docs/griptape-tools/official-tools/src/inpainting_image_generation_client_1.py rename to docs/griptape-tools/official-tools/src/inpainting_image_generation_tool_1.py index 1042a4567..5821e1b40 100644 --- a/docs/griptape-tools/official-tools/src/inpainting_image_generation_client_1.py +++ b/docs/griptape-tools/official-tools/src/inpainting_image_generation_tool_1.py @@ -1,7 +1,7 @@ from griptape.drivers import AmazonBedrockImageGenerationDriver, BedrockStableDiffusionImageGenerationModelDriver from griptape.engines import InpaintingImageGenerationEngine from griptape.structures import Agent -from griptape.tools import InpaintingImageGenerationClient +from griptape.tools import InpaintingImageGenerationTool # Create a driver configured to use Stable Diffusion via Bedrock. driver = AmazonBedrockImageGenerationDriver( @@ -15,7 +15,7 @@ ) # Create a tool configured to use the engine. -tool = InpaintingImageGenerationClient( +tool = InpaintingImageGenerationTool( engine=engine, ) diff --git a/docs/griptape-tools/official-tools/src/openweather_client_1.py b/docs/griptape-tools/official-tools/src/openweather_tool_1.py similarity index 75% rename from docs/griptape-tools/official-tools/src/openweather_client_1.py rename to docs/griptape-tools/official-tools/src/openweather_tool_1.py index 2156e24da..b592620fa 100644 --- a/docs/griptape-tools/official-tools/src/openweather_client_1.py +++ b/docs/griptape-tools/official-tools/src/openweather_tool_1.py @@ -1,11 +1,11 @@ import os from griptape.structures import Agent -from griptape.tools import OpenWeatherClient +from griptape.tools import OpenWeatherTool agent = Agent( tools=[ - OpenWeatherClient( + OpenWeatherTool( api_key=os.environ["OPENWEATHER_API_KEY"], ), ] diff --git a/docs/griptape-tools/official-tools/src/outpainting_image_generation_client_1.py b/docs/griptape-tools/official-tools/src/outpainting_image_generation_tool_1.py similarity index 89% rename from docs/griptape-tools/official-tools/src/outpainting_image_generation_client_1.py rename to docs/griptape-tools/official-tools/src/outpainting_image_generation_tool_1.py index bc7eb8585..79606a965 100644 --- a/docs/griptape-tools/official-tools/src/outpainting_image_generation_client_1.py +++ b/docs/griptape-tools/official-tools/src/outpainting_image_generation_tool_1.py @@ -1,7 +1,7 @@ from griptape.drivers import AmazonBedrockImageGenerationDriver, BedrockStableDiffusionImageGenerationModelDriver from griptape.engines import OutpaintingImageGenerationEngine from griptape.structures import Agent -from griptape.tools import OutpaintingImageGenerationClient +from griptape.tools import OutpaintingImageGenerationTool # Create a driver configured to use Stable Diffusion via Bedrock. driver = AmazonBedrockImageGenerationDriver( @@ -15,7 +15,7 @@ ) # Create a tool configured to use the engine. -tool = OutpaintingImageGenerationClient( +tool = OutpaintingImageGenerationTool( engine=engine, ) diff --git a/docs/griptape-tools/official-tools/src/prompt_image_generation_client_1.py b/docs/griptape-tools/official-tools/src/prompt_image_generation_tool_1.py similarity index 89% rename from docs/griptape-tools/official-tools/src/prompt_image_generation_client_1.py rename to docs/griptape-tools/official-tools/src/prompt_image_generation_tool_1.py index f75f904b6..0173cc185 100644 --- a/docs/griptape-tools/official-tools/src/prompt_image_generation_client_1.py +++ b/docs/griptape-tools/official-tools/src/prompt_image_generation_tool_1.py @@ -1,7 +1,7 @@ from griptape.drivers import AmazonBedrockImageGenerationDriver, BedrockStableDiffusionImageGenerationModelDriver from griptape.engines import PromptImageGenerationEngine from griptape.structures import Agent -from griptape.tools import PromptImageGenerationClient +from griptape.tools import PromptImageGenerationTool # Create a driver configured to use Stable Diffusion via Bedrock. driver = AmazonBedrockImageGenerationDriver( @@ -15,7 +15,7 @@ ) # Create a tool configured to use the engine. -tool = PromptImageGenerationClient( +tool = PromptImageGenerationTool( engine=engine, ) diff --git a/docs/griptape-tools/official-tools/src/rag_client_1.py b/docs/griptape-tools/official-tools/src/rag_tool_1.py similarity index 93% rename from docs/griptape-tools/official-tools/src/rag_client_1.py rename to docs/griptape-tools/official-tools/src/rag_tool_1.py index 01e71e253..7cefd065b 100644 --- a/docs/griptape-tools/official-tools/src/rag_client_1.py +++ b/docs/griptape-tools/official-tools/src/rag_tool_1.py @@ -4,7 +4,7 @@ from griptape.engines.rag.modules import PromptResponseRagModule, VectorStoreRetrievalRagModule from griptape.engines.rag.stages import ResponseRagStage, RetrievalRagStage from griptape.structures import Agent -from griptape.tools import RagClient +from griptape.tools import RagTool vector_store_driver = LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver()) @@ -15,7 +15,7 @@ vector_store_driver.upsert_text_artifact(artifact=artifact, namespace="griptape") -rag_client = RagClient( +rag_tool = RagTool( description="Contains information about Griptape", off_prompt=False, rag_engine=RagEngine( @@ -32,6 +32,6 @@ ), ) -agent = Agent(tools=[rag_client]) +agent = Agent(tools=[rag_tool]) agent.run("what is Griptape?") diff --git a/docs/griptape-tools/official-tools/src/rest_api_client_1.py b/docs/griptape-tools/official-tools/src/rest_api_tool_1.py similarity index 98% rename from docs/griptape-tools/official-tools/src/rest_api_client_1.py rename to docs/griptape-tools/official-tools/src/rest_api_tool_1.py index 026874283..2093163b7 100644 --- a/docs/griptape-tools/official-tools/src/rest_api_client_1.py +++ b/docs/griptape-tools/official-tools/src/rest_api_tool_1.py @@ -5,13 +5,13 @@ from griptape.memory.structure import ConversationMemory from griptape.structures import Pipeline from griptape.tasks import ToolkitTask -from griptape.tools import RestApiClient +from griptape.tools import RestApiTool config.drivers = DriverConfig( prompt=OpenAiChatPromptDriver(model="gpt-4o", temperature=0.1), ) -posts_client = RestApiClient( +posts_client = RestApiTool( base_url="https://jsonplaceholder.typicode.com", path="posts", description="Allows for creating, updating, deleting, patching, and getting posts.", diff --git a/docs/griptape-tools/official-tools/src/sql_client_1.py b/docs/griptape-tools/official-tools/src/sql_tool_1.py similarity index 91% rename from docs/griptape-tools/official-tools/src/sql_client_1.py rename to docs/griptape-tools/official-tools/src/sql_tool_1.py index 3e89d6096..f7630891f 100644 --- a/docs/griptape-tools/official-tools/src/sql_client_1.py +++ b/docs/griptape-tools/official-tools/src/sql_tool_1.py @@ -5,7 +5,7 @@ from griptape.drivers import AmazonRedshiftSqlDriver from griptape.loaders import SqlLoader from griptape.structures import Agent -from griptape.tools import SqlClient +from griptape.tools import SqlTool session = boto3.Session() @@ -17,7 +17,7 @@ ) ) -sql_tool = SqlClient( +sql_tool = SqlTool( sql_loader=sql_loader, table_name="people", table_description="contains information about tech industry professionals", diff --git a/docs/griptape-tools/official-tools/src/structure_run_client_1.py b/docs/griptape-tools/official-tools/src/structure_run_tool_1.py similarity index 83% rename from docs/griptape-tools/official-tools/src/structure_run_client_1.py rename to docs/griptape-tools/official-tools/src/structure_run_tool_1.py index 10f48b80d..575092ce6 100644 --- a/docs/griptape-tools/official-tools/src/structure_run_client_1.py +++ b/docs/griptape-tools/official-tools/src/structure_run_tool_1.py @@ -2,13 +2,13 @@ from griptape.drivers import GriptapeCloudStructureRunDriver from griptape.structures import Agent -from griptape.tools import StructureRunClient +from griptape.tools import StructureRunTool base_url = os.environ["GRIPTAPE_CLOUD_BASE_URL"] api_key = os.environ["GRIPTAPE_CLOUD_API_KEY"] structure_id = os.environ["GRIPTAPE_CLOUD_STRUCTURE_ID"] -structure_run_tool = StructureRunClient( +structure_run_tool = StructureRunTool( description="RAG Expert Agent - Structure to invoke with natural language queries about the topic of Retrieval Augmented Generation", driver=GriptapeCloudStructureRunDriver( base_url=base_url, @@ -17,7 +17,7 @@ ), ) -# Set up an agent using the StructureRunClient tool +# Set up an agent using the StructureRunTool tool agent = Agent(tools=[structure_run_tool]) # Task: Ask the Griptape Cloud Hosted Structure about modular RAG diff --git a/docs/griptape-tools/official-tools/src/task_memory_client_1.py b/docs/griptape-tools/official-tools/src/task_memory_client_1.py deleted file mode 100644 index e9c2562a1..000000000 --- a/docs/griptape-tools/official-tools/src/task_memory_client_1.py +++ /dev/null @@ -1,4 +0,0 @@ -from griptape.structures import Agent -from griptape.tools import TaskMemoryClient, WebScraper - -Agent(tools=[WebScraper(off_prompt=True), TaskMemoryClient(off_prompt=False)]) diff --git a/docs/griptape-tools/official-tools/src/task_memory_tool_1.py b/docs/griptape-tools/official-tools/src/task_memory_tool_1.py new file mode 100644 index 000000000..f5c2c487f --- /dev/null +++ b/docs/griptape-tools/official-tools/src/task_memory_tool_1.py @@ -0,0 +1,4 @@ +from griptape.structures import Agent +from griptape.tools import TaskMemoryTool, WebScraperTool + +Agent(tools=[WebScraperTool(off_prompt=True), TaskMemoryTool(off_prompt=False)]) diff --git a/docs/griptape-tools/official-tools/src/text_to_speech_client_1.py b/docs/griptape-tools/official-tools/src/text_to_speech_tool_1.py similarity index 81% rename from docs/griptape-tools/official-tools/src/text_to_speech_client_1.py rename to docs/griptape-tools/official-tools/src/text_to_speech_tool_1.py index c6a03b80d..376113d63 100644 --- a/docs/griptape-tools/official-tools/src/text_to_speech_client_1.py +++ b/docs/griptape-tools/official-tools/src/text_to_speech_tool_1.py @@ -3,7 +3,7 @@ from griptape.drivers import ElevenLabsTextToSpeechDriver from griptape.engines import TextToSpeechEngine from griptape.structures import Agent -from griptape.tools.text_to_speech_client.tool import TextToSpeechClient +from griptape.tools.text_to_speech.tool import TextToSpeechTool driver = ElevenLabsTextToSpeechDriver( api_key=os.environ["ELEVEN_LABS_API_KEY"], @@ -11,7 +11,7 @@ voice="Matilda", ) -tool = TextToSpeechClient( +tool = TextToSpeechTool( engine=TextToSpeechEngine( text_to_speech_driver=driver, ), diff --git a/docs/griptape-tools/official-tools/src/variation_image_generation_client_1.py b/docs/griptape-tools/official-tools/src/variation_image_generation_tool_1.py similarity index 90% rename from docs/griptape-tools/official-tools/src/variation_image_generation_client_1.py rename to docs/griptape-tools/official-tools/src/variation_image_generation_tool_1.py index 6c4432d52..209d97a7b 100644 --- a/docs/griptape-tools/official-tools/src/variation_image_generation_client_1.py +++ b/docs/griptape-tools/official-tools/src/variation_image_generation_tool_1.py @@ -1,7 +1,7 @@ from griptape.drivers import AmazonBedrockImageGenerationDriver, BedrockStableDiffusionImageGenerationModelDriver from griptape.engines import VariationImageGenerationEngine from griptape.structures import Agent -from griptape.tools import VariationImageGenerationClient +from griptape.tools import VariationImageGenerationTool # Create a driver configured to use Stable Diffusion via Bedrock. driver = AmazonBedrockImageGenerationDriver( @@ -17,7 +17,7 @@ ) # Create a tool configured to use the engine. -tool = VariationImageGenerationClient( +tool = VariationImageGenerationTool( engine=engine, ) diff --git a/docs/griptape-tools/official-tools/src/variation_image_generation_client_2.py b/docs/griptape-tools/official-tools/src/variation_image_generation_tool_2.py similarity index 89% rename from docs/griptape-tools/official-tools/src/variation_image_generation_client_2.py rename to docs/griptape-tools/official-tools/src/variation_image_generation_tool_2.py index d98aec199..036b75d48 100644 --- a/docs/griptape-tools/official-tools/src/variation_image_generation_client_2.py +++ b/docs/griptape-tools/official-tools/src/variation_image_generation_tool_2.py @@ -1,7 +1,7 @@ from griptape.drivers import AmazonBedrockImageGenerationDriver, BedrockStableDiffusionImageGenerationModelDriver from griptape.engines import PromptImageGenerationEngine, VariationImageGenerationEngine from griptape.structures import Agent -from griptape.tools import PromptImageGenerationClient, VariationImageGenerationClient +from griptape.tools import PromptImageGenerationTool, VariationImageGenerationTool # Create a driver configured to use Stable Diffusion via Bedrock. driver = AmazonBedrockImageGenerationDriver( @@ -17,7 +17,7 @@ ) # Create a prompt image generation client configured to use the engine. -prompt_tool = PromptImageGenerationClient( +prompt_tool = PromptImageGenerationTool( engine=prompt_engine, ) @@ -27,7 +27,7 @@ ) # Create a variation image generation client configured to use the engine. -variation_tool = VariationImageGenerationClient( +variation_tool = VariationImageGenerationTool( engine=variation_engine, ) diff --git a/docs/griptape-tools/official-tools/src/vector_store_client_1.py b/docs/griptape-tools/official-tools/src/vector_store_tool_1.py similarity index 81% rename from docs/griptape-tools/official-tools/src/vector_store_client_1.py rename to docs/griptape-tools/official-tools/src/vector_store_tool_1.py index df9117960..57ca7b25b 100644 --- a/docs/griptape-tools/official-tools/src/vector_store_client_1.py +++ b/docs/griptape-tools/official-tools/src/vector_store_tool_1.py @@ -2,7 +2,7 @@ from griptape.drivers import LocalVectorStoreDriver, OpenAiEmbeddingDriver from griptape.loaders import WebLoader from griptape.structures import Agent -from griptape.tools import TaskMemoryClient, VectorStoreClient +from griptape.tools import TaskMemoryTool, VectorStoreTool vector_store_driver = LocalVectorStoreDriver( embedding_driver=OpenAiEmbeddingDriver(), @@ -13,13 +13,13 @@ raise Exception(artifacts.value) vector_store_driver.upsert_text_artifacts({"griptape": artifacts}) -vector_db = VectorStoreClient( +vector_db = VectorStoreTool( description="This DB has information about the Griptape Python framework", vector_store_driver=vector_store_driver, query_params={"namespace": "griptape"}, off_prompt=True, ) -agent = Agent(tools=[vector_db, TaskMemoryClient(off_prompt=False)]) +agent = Agent(tools=[vector_db, TaskMemoryTool(off_prompt=False)]) agent.run("what is Griptape?") diff --git a/docs/griptape-tools/official-tools/src/web_scraper_1.py b/docs/griptape-tools/official-tools/src/web_scraper_1.py deleted file mode 100644 index 138e8600f..000000000 --- a/docs/griptape-tools/official-tools/src/web_scraper_1.py +++ /dev/null @@ -1,6 +0,0 @@ -from griptape.structures import Agent -from griptape.tools import TaskMemoryClient, WebScraper - -agent = Agent(tools=[WebScraper(off_prompt=True), TaskMemoryClient(off_prompt=False)]) - -agent.run("Based on https://www.griptape.ai/, tell me what griptape is") diff --git a/docs/griptape-tools/official-tools/src/web_scraper_tool_1.py b/docs/griptape-tools/official-tools/src/web_scraper_tool_1.py new file mode 100644 index 000000000..7716acde4 --- /dev/null +++ b/docs/griptape-tools/official-tools/src/web_scraper_tool_1.py @@ -0,0 +1,6 @@ +from griptape.structures import Agent +from griptape.tools import TaskMemoryTool, WebScraperTool + +agent = Agent(tools=[WebScraperTool(off_prompt=True), TaskMemoryTool(off_prompt=False)]) + +agent.run("Based on https://www.griptape.ai/, tell me what griptape is") diff --git a/docs/griptape-tools/official-tools/src/web_search_1.py b/docs/griptape-tools/official-tools/src/web_search_tool_1.py similarity index 71% rename from docs/griptape-tools/official-tools/src/web_search_1.py rename to docs/griptape-tools/official-tools/src/web_search_tool_1.py index 70603a693..3469ad7f9 100644 --- a/docs/griptape-tools/official-tools/src/web_search_1.py +++ b/docs/griptape-tools/official-tools/src/web_search_tool_1.py @@ -2,10 +2,10 @@ from griptape.drivers import GoogleWebSearchDriver from griptape.structures import Agent -from griptape.tools import WebSearch +from griptape.tools import WebSearchTool -# Initialize the WebSearch tool with necessary parameters -web_search_tool = WebSearch( +# Initialize the WebSearchTool tool with necessary parameters +web_search_tool = WebSearchTool( web_search_driver=GoogleWebSearchDriver( api_key=os.environ["GOOGLE_API_KEY"], search_id=os.environ["GOOGLE_API_SEARCH_ID"], @@ -15,7 +15,7 @@ ), ) -# Set up an agent using the WebSearch tool +# Set up an agent using the WebSearchTool tool agent = Agent(tools=[web_search_tool]) # Task: Search the web for a specific query diff --git a/docs/griptape-tools/official-tools/src/web_search_2.py b/docs/griptape-tools/official-tools/src/web_search_tool_2.py similarity index 91% rename from docs/griptape-tools/official-tools/src/web_search_2.py rename to docs/griptape-tools/official-tools/src/web_search_tool_2.py index b40eac094..dd9c32655 100644 --- a/docs/griptape-tools/official-tools/src/web_search_2.py +++ b/docs/griptape-tools/official-tools/src/web_search_tool_2.py @@ -4,11 +4,11 @@ from griptape.drivers import GoogleWebSearchDriver from griptape.structures import Agent -from griptape.tools import WebSearch +from griptape.tools import WebSearchTool agent = Agent( tools=[ - WebSearch( + WebSearchTool( web_search_driver=GoogleWebSearchDriver( api_key=os.environ["GOOGLE_API_KEY"], search_id=os.environ["GOOGLE_API_SEARCH_ID"], diff --git a/docs/griptape-tools/official-tools/structure-run-client.md b/docs/griptape-tools/official-tools/structure-run-tool.md similarity index 90% rename from docs/griptape-tools/official-tools/structure-run-client.md rename to docs/griptape-tools/official-tools/structure-run-tool.md index 139633059..7b73e5b52 100644 --- a/docs/griptape-tools/official-tools/structure-run-client.md +++ b/docs/griptape-tools/official-tools/structure-run-tool.md @@ -1,20 +1,20 @@ -# StructureRunClient +# Structure Run Tool -The StructureRunClient Tool provides a way to run Structures via a Tool. +The [StructureRunTool](../../reference/griptape/tools/structure_run/tool.md) Tool provides a way to run Structures via a Tool. It requires you to provide a [Structure Run Driver](../../griptape-framework/drivers/structure-run-drivers.md) to run the Structure in the desired environment. ```python ---8<-- "docs/griptape-tools/official-tools/src/structure_run_client_1.py" +--8<-- "docs/griptape-tools/official-tools/src/structure_run_tool_1.py" ``` ``` [05/02/24 13:50:03] INFO ToolkitTask 4e9458375bda4fbcadb77a94624ed64c Input: what is modular RAG? [05/02/24 13:50:10] INFO Subtask 5ef2d72028fc495aa7faf6f46825b004 - Thought: To answer this question, I need to run a search for the term "modular RAG". I will use the StructureRunClient action to execute a + Thought: To answer this question, I need to run a search for the term "modular RAG". I will use the StructureRunTool action to execute a search structure. Actions: [ { - "name": "StructureRunClient", + "name": "StructureRunTool", "path": "run_structure", "input": { "values": { diff --git a/docs/griptape-tools/official-tools/task-memory-client.md b/docs/griptape-tools/official-tools/task-memory-tool.md similarity index 74% rename from docs/griptape-tools/official-tools/task-memory-client.md rename to docs/griptape-tools/official-tools/task-memory-tool.md index fa88c85b9..a988c3a64 100644 --- a/docs/griptape-tools/official-tools/task-memory-client.md +++ b/docs/griptape-tools/official-tools/task-memory-tool.md @@ -1,7 +1,7 @@ -# TaskMemoryClient +# Task Memory Tool Tool This tool enables LLMs to query and summarize task outputs that are stored in short-term tool memory. This tool uniquely requires the user to set the `off_prompt` property explicitly for usability reasons (Griptape doesn't provide the default `True` value). ```python ---8<-- "docs/griptape-tools/official-tools/src/task_memory_client_1.py" +--8<-- "docs/griptape-tools/official-tools/src/task_memory_tool_1.py" ``` diff --git a/docs/griptape-tools/official-tools/text-to-speech-client.md b/docs/griptape-tools/official-tools/text-to-speech-tool.md similarity index 72% rename from docs/griptape-tools/official-tools/text-to-speech-client.md rename to docs/griptape-tools/official-tools/text-to-speech-tool.md index d7fa043a7..ac3f54f8e 100644 --- a/docs/griptape-tools/official-tools/text-to-speech-client.md +++ b/docs/griptape-tools/official-tools/text-to-speech-tool.md @@ -1,7 +1,7 @@ -# TextToSpeechClient +# Text To Speech Tool This Tool enables LLMs to synthesize speech from text using [Text to Speech Engines](../../reference/griptape/engines/audio/text_to_speech_engine.md) and [Text to Speech Drivers](../../reference/griptape/drivers/text_to_speech/index.md). ```python ---8<-- "docs/griptape-tools/official-tools/src/text_to_speech_client_1.py" +--8<-- "docs/griptape-tools/official-tools/src/text_to_speech_tool_1.py" ``` diff --git a/docs/griptape-tools/official-tools/variation-image-generation-client.md b/docs/griptape-tools/official-tools/variation-image-generation-tool.md similarity index 83% rename from docs/griptape-tools/official-tools/variation-image-generation-client.md rename to docs/griptape-tools/official-tools/variation-image-generation-tool.md index 4b5880ef8..bcc8c3f61 100644 --- a/docs/griptape-tools/official-tools/variation-image-generation-client.md +++ b/docs/griptape-tools/official-tools/variation-image-generation-tool.md @@ -1,15 +1,15 @@ -# VariationImageGenerationEngine +# Variation Image Generation Engine Tool This Tool allows LLMs to generate variations of an input image from a text prompt. The input image can be provided either by its file path or by its [Task Memory](../../griptape-framework/structures/task-memory.md) reference. ## Referencing an Image by File Path ```python ---8<-- "docs/griptape-tools/official-tools/src/variation_image_generation_client_1.py" +--8<-- "docs/griptape-tools/official-tools/src/variation_image_generation_tool_1.py" ``` ## Referencing an Image in Task Memory ```python ---8<-- "docs/griptape-tools/official-tools/src/variation_image_generation_client_2.py" +--8<-- "docs/griptape-tools/official-tools/src/variation_image_generation_client_tool_2.py" ``` diff --git a/docs/griptape-tools/official-tools/vector-store-client.md b/docs/griptape-tools/official-tools/vector-store-client.md deleted file mode 100644 index 8fd280ec9..000000000 --- a/docs/griptape-tools/official-tools/vector-store-client.md +++ /dev/null @@ -1,7 +0,0 @@ -The [VectorStoreClient](../../reference/griptape/tools/vector_store_client/tool.md) enables LLMs to query vector stores. - -Here is an example of how it can be used with a local vector store driver: - -```python ---8<-- "docs/griptape-tools/official-tools/src/vector_store_client_1.py" -``` diff --git a/docs/griptape-tools/official-tools/vector-store-tool.md b/docs/griptape-tools/official-tools/vector-store-tool.md new file mode 100644 index 000000000..7317c25db --- /dev/null +++ b/docs/griptape-tools/official-tools/vector-store-tool.md @@ -0,0 +1,9 @@ +# Vector Store Tool + +The [VectorStoreTool](../../reference/griptape/tools/vector_store/tool.md) enables LLMs to query vector stores. + +Here is an example of how it can be used with a local vector store driver: + +```python +--8<-- "docs/griptape-tools/official-tools/src/vector_store_tool_1.py" +``` diff --git a/docs/griptape-tools/official-tools/web-scraper.md b/docs/griptape-tools/official-tools/web-scraper-tool.md similarity index 96% rename from docs/griptape-tools/official-tools/web-scraper.md rename to docs/griptape-tools/official-tools/web-scraper-tool.md index 5d8e1fe27..1261d955d 100644 --- a/docs/griptape-tools/official-tools/web-scraper.md +++ b/docs/griptape-tools/official-tools/web-scraper-tool.md @@ -1,9 +1,9 @@ -# WebScraper +# Web Scraper Tool This tool enables LLMs to scrape web pages for full text, summaries, authors, titles, and keywords. It can also execute search queries to answer specific questions about the page. This tool uses OpenAI APIs for some of its activities, so in order to use it provide a valid API key in `openai_api_key`. ```python ---8<-- "docs/griptape-tools/official-tools/src/web_scraper_1.py" +--8<-- "docs/griptape-tools/official-tools/src/web_scraper_tool_1.py" ``` ``` [09/11/23 15:27:39] INFO Task dd9ad12c5c1e4280a6e20d7c116303ed @@ -29,7 +29,7 @@ This tool enables LLMs to scrape web pages for full text, summaries, authors, ti in memory. I can use the TaskMemory tool with the summarize activity to get a summary of the content. - Action: {"name": "TaskMemoryClient", "path": + Action: {"name": "TaskMemoryTool", "path": "summarize", "input": {"values": {"memory_name": "TaskMemory", "artifact_namespace": "02da5930b8d74f7ca30aecc3760a3318"}}} diff --git a/docs/griptape-tools/official-tools/web-search.md b/docs/griptape-tools/official-tools/web-search-tool.md similarity index 97% rename from docs/griptape-tools/official-tools/web-search.md rename to docs/griptape-tools/official-tools/web-search-tool.md index 3d0495229..3f31fd4fd 100644 --- a/docs/griptape-tools/official-tools/web-search.md +++ b/docs/griptape-tools/official-tools/web-search-tool.md @@ -1,9 +1,9 @@ -# WebSearch +# Web Search Tool This tool enables LLMs to search the web. ```python ---8<-- "docs/griptape-tools/official-tools/src/web_search_1.py" +--8<-- "docs/griptape-tools/official-tools/src/web_search_tool_1.py" ``` ``` [09/08/23 15:37:25] INFO Task 2cf557f7f7cd4a20a7fa2f0c46af2f71 @@ -93,5 +93,5 @@ Extra schema properties can be added to the Tool to allow for more customization In this example, we add a `sort` property to the `search` Activity which will be added as a [Google custom search query parameter](https://developers.google.com/custom-search/v1/reference/rest/v1/cse/list). ```python ---8<-- "docs/griptape-tools/official-tools/src/web_search_2.py" +--8<-- "docs/griptape-tools/official-tools/src/web_search_tool_2.py" ``` diff --git a/griptape/tools/__init__.py b/griptape/tools/__init__.py index d99b63b6c..c33b9f0b0 100644 --- a/griptape/tools/__init__.py +++ b/griptape/tools/__init__.py @@ -1,67 +1,67 @@ from .base_tool import BaseTool -from .base_image_generation_client import BaseImageGenerationClient -from .calculator.tool import Calculator -from .web_search.tool import WebSearch -from .web_scraper.tool import WebScraper -from .sql_client.tool import SqlClient -from .email_client.tool import EmailClient -from .rest_api_client.tool import RestApiClient -from .file_manager.tool import FileManager -from .vector_store_client.tool import VectorStoreClient -from .date_time.tool import DateTime -from .task_memory_client.tool import TaskMemoryClient -from .base_aws_client import BaseAwsClient -from .aws_iam_client.tool import AwsIamClient -from .aws_s3_client.tool import AwsS3Client -from .computer.tool import Computer -from .base_google_client import BaseGoogleClient -from .google_gmail.tool import GoogleGmailClient -from .google_cal.tool import GoogleCalendarClient -from .google_docs.tool import GoogleDocsClient -from .google_drive.tool import GoogleDriveClient -from .openweather_client.tool import OpenWeatherClient -from .prompt_image_generation_client.tool import PromptImageGenerationClient -from .variation_image_generation_client.tool import VariationImageGenerationClient -from .inpainting_image_generation_client.tool import InpaintingImageGenerationClient -from .outpainting_image_generation_client.tool import OutpaintingImageGenerationClient -from .griptape_cloud_knowledge_base_client.tool import GriptapeCloudKnowledgeBaseClient -from .structure_run_client.tool import StructureRunClient -from .image_query_client.tool import ImageQueryClient -from .rag_client.tool import RagClient -from .text_to_speech_client.tool import TextToSpeechClient -from .audio_transcription_client.tool import AudioTranscriptionClient +from .base_image_generation_tool import BaseImageGenerationTool +from .calculator.tool import CalculatorTool +from .web_search.tool import WebSearchTool +from .web_scraper.tool import WebScraperTool +from .sql.tool import SqlTool +from .email.tool import EmailTool +from .rest_api.tool import RestApiTool +from .file_manager.tool import FileManagerTool +from .vector_store.tool import VectorStoreTool +from .date_time.tool import DateTimeTool +from .task_memory.tool import TaskMemoryTool +from .base_aws_tool import BaseAwsTool +from .aws_iam.tool import AwsIamTool +from .aws_s3.tool import AwsS3Tool +from .computer.tool import ComputerTool +from .base_google_tool import BaseGoogleTool +from .google_gmail.tool import GoogleGmailTool +from .google_calendar.tool import GoogleCalendarTool +from .google_docs.tool import GoogleDocsTool +from .google_drive.tool import GoogleDriveTool +from .openweather.tool import OpenWeatherTool +from .prompt_image_generation.tool import PromptImageGenerationTool +from .variation_image_generation.tool import VariationImageGenerationTool +from .inpainting_image_generation.tool import InpaintingImageGenerationTool +from .outpainting_image_generation.tool import OutpaintingImageGenerationTool +from .griptape_cloud_knowledge_base.tool import GriptapeCloudKnowledgeBaseTool +from .structure_run.tool import StructureRunTool +from .image_query.tool import ImageQueryTool +from .rag.tool import RagTool +from .text_to_speech.tool import TextToSpeechTool +from .audio_transcription.tool import AudioTranscriptionTool __all__ = [ "BaseTool", - "BaseImageGenerationClient", - "BaseAwsClient", - "AwsIamClient", - "AwsS3Client", - "BaseGoogleClient", - "GoogleGmailClient", - "GoogleDocsClient", - "GoogleCalendarClient", - "GoogleDriveClient", - "Calculator", - "WebSearch", - "WebScraper", - "SqlClient", - "EmailClient", - "RestApiClient", - "FileManager", - "VectorStoreClient", - "DateTime", - "TaskMemoryClient", - "Computer", - "OpenWeatherClient", - "PromptImageGenerationClient", - "VariationImageGenerationClient", - "InpaintingImageGenerationClient", - "OutpaintingImageGenerationClient", - "GriptapeCloudKnowledgeBaseClient", - "StructureRunClient", - "ImageQueryClient", - "RagClient", - "TextToSpeechClient", - "AudioTranscriptionClient", + "BaseImageGenerationTool", + "BaseAwsTool", + "AwsIamTool", + "AwsS3Tool", + "BaseGoogleTool", + "GoogleGmailTool", + "GoogleDocsTool", + "GoogleCalendarTool", + "GoogleDriveTool", + "CalculatorTool", + "WebSearchTool", + "WebScraperTool", + "SqlTool", + "EmailTool", + "RestApiTool", + "FileManagerTool", + "VectorStoreTool", + "DateTimeTool", + "TaskMemoryTool", + "ComputerTool", + "OpenWeatherTool", + "PromptImageGenerationTool", + "VariationImageGenerationTool", + "InpaintingImageGenerationTool", + "OutpaintingImageGenerationTool", + "GriptapeCloudKnowledgeBaseTool", + "StructureRunTool", + "ImageQueryTool", + "RagTool", + "TextToSpeechTool", + "AudioTranscriptionTool", ] diff --git a/griptape/tools/audio_transcription_client/__init__.py b/griptape/tools/audio_transcription/__init__.py similarity index 100% rename from griptape/tools/audio_transcription_client/__init__.py rename to griptape/tools/audio_transcription/__init__.py diff --git a/griptape/tools/audio_transcription_client/manifest.yml b/griptape/tools/audio_transcription/manifest.yml similarity index 84% rename from griptape/tools/audio_transcription_client/manifest.yml rename to griptape/tools/audio_transcription/manifest.yml index 6bbe4a21a..32b017c55 100644 --- a/griptape/tools/audio_transcription_client/manifest.yml +++ b/griptape/tools/audio_transcription/manifest.yml @@ -1,5 +1,5 @@ version: "v1" -name: Transcription Client +name: Transcription Tool description: A tool for generating transcription of audio. contact_email: hello@griptape.ai legal_info_url: https://www.griptape.ai/legal diff --git a/griptape/tools/audio_transcription_client/tool.py b/griptape/tools/audio_transcription/tool.py similarity index 98% rename from griptape/tools/audio_transcription_client/tool.py rename to griptape/tools/audio_transcription/tool.py index 62cd9e7a5..4174db209 100644 --- a/griptape/tools/audio_transcription_client/tool.py +++ b/griptape/tools/audio_transcription/tool.py @@ -17,7 +17,7 @@ @define -class AudioTranscriptionClient(BaseTool): +class AudioTranscriptionTool(BaseTool): """A tool that can be used to generate transcriptions from input audio.""" engine: AudioTranscriptionEngine = field(kw_only=True) diff --git a/griptape/tools/aws_iam_client/__init__.py b/griptape/tools/aws_iam/__init__.py similarity index 100% rename from griptape/tools/aws_iam_client/__init__.py rename to griptape/tools/aws_iam/__init__.py diff --git a/griptape/tools/aws_iam_client/manifest.yml b/griptape/tools/aws_iam/manifest.yml similarity index 86% rename from griptape/tools/aws_iam_client/manifest.yml rename to griptape/tools/aws_iam/manifest.yml index ea825527f..072d4f92e 100644 --- a/griptape/tools/aws_iam_client/manifest.yml +++ b/griptape/tools/aws_iam/manifest.yml @@ -1,5 +1,5 @@ version: "v1" -name: AWS IAM Client +name: AWS IAM Tool description: Tool for the IAM boto3 API. contact_email: hello@griptape.ai legal_info_url: https://www.griptape.ai/legal \ No newline at end of file diff --git a/griptape/tools/aws_iam_client/tool.py b/griptape/tools/aws_iam/tool.py similarity index 97% rename from griptape/tools/aws_iam_client/tool.py rename to griptape/tools/aws_iam/tool.py index 1be0251f0..8d22dd3c9 100644 --- a/griptape/tools/aws_iam_client/tool.py +++ b/griptape/tools/aws_iam/tool.py @@ -6,7 +6,7 @@ from schema import Literal, Schema from griptape.artifacts import ErrorArtifact, ListArtifact, TextArtifact -from griptape.tools import BaseAwsClient +from griptape.tools import BaseAwsTool from griptape.utils.decorators import activity if TYPE_CHECKING: @@ -14,7 +14,7 @@ @define -class AwsIamClient(BaseAwsClient): +class AwsIamTool(BaseAwsTool): iam_client: Client = field(default=Factory(lambda self: self.session.client("iam"), takes_self=True), kw_only=True) @activity( diff --git a/griptape/tools/aws_s3_client/__init__.py b/griptape/tools/aws_s3/__init__.py similarity index 100% rename from griptape/tools/aws_s3_client/__init__.py rename to griptape/tools/aws_s3/__init__.py diff --git a/griptape/tools/aws_s3_client/manifest.yml b/griptape/tools/aws_s3/manifest.yml similarity index 86% rename from griptape/tools/aws_s3_client/manifest.yml rename to griptape/tools/aws_s3/manifest.yml index 642b6c588..a48169f0c 100644 --- a/griptape/tools/aws_s3_client/manifest.yml +++ b/griptape/tools/aws_s3/manifest.yml @@ -1,5 +1,5 @@ version: "v1" -name: AWS S3 Client +name: AWS S3 Tool description: Tool for the S3 boto3 API. contact_email: hello@griptape.ai legal_info_url: https://www.griptape.ai/legal \ No newline at end of file diff --git a/griptape/tools/aws_s3_client/tool.py b/griptape/tools/aws_s3/tool.py similarity index 99% rename from griptape/tools/aws_s3_client/tool.py rename to griptape/tools/aws_s3/tool.py index 8f67195a1..24d091d71 100644 --- a/griptape/tools/aws_s3_client/tool.py +++ b/griptape/tools/aws_s3/tool.py @@ -7,7 +7,7 @@ from schema import Literal, Schema from griptape.artifacts import BlobArtifact, ErrorArtifact, InfoArtifact, ListArtifact, TextArtifact -from griptape.tools import BaseAwsClient +from griptape.tools import BaseAwsTool from griptape.utils.decorators import activity if TYPE_CHECKING: @@ -15,7 +15,7 @@ @define -class AwsS3Client(BaseAwsClient): +class AwsS3Tool(BaseAwsTool): s3_client: Client = field(default=Factory(lambda self: self.session.client("s3"), takes_self=True), kw_only=True) @activity( diff --git a/griptape/tools/base_aws_client.py b/griptape/tools/base_aws_tool.py similarity index 95% rename from griptape/tools/base_aws_client.py rename to griptape/tools/base_aws_tool.py index 8c6d02e2b..72fc54583 100644 --- a/griptape/tools/base_aws_client.py +++ b/griptape/tools/base_aws_tool.py @@ -14,7 +14,7 @@ @define -class BaseAwsClient(BaseTool, ABC): +class BaseAwsTool(BaseTool, ABC): session: boto3.Session = field(kw_only=True) @activity(config={"description": "Can be used to get current AWS account and IAM principal."}) diff --git a/griptape/tools/base_google_client.py b/griptape/tools/base_google_tool.py similarity index 98% rename from griptape/tools/base_google_client.py rename to griptape/tools/base_google_tool.py index 2a38d8ffe..c40a583cf 100644 --- a/griptape/tools/base_google_client.py +++ b/griptape/tools/base_google_tool.py @@ -9,7 +9,7 @@ @define -class BaseGoogleClient(BaseTool, ABC): +class BaseGoogleTool(BaseTool, ABC): DRIVE_FILE_SCOPES = ["https://www.googleapis.com/auth/drive.file"] DRIVE_AUTH_SCOPES = ["https://www.googleapis.com/auth/drive"] diff --git a/griptape/tools/base_griptape_cloud_client.py b/griptape/tools/base_griptape_cloud_tool.py similarity index 93% rename from griptape/tools/base_griptape_cloud_client.py rename to griptape/tools/base_griptape_cloud_tool.py index 4f5692957..7ee8f2dfc 100644 --- a/griptape/tools/base_griptape_cloud_client.py +++ b/griptape/tools/base_griptape_cloud_tool.py @@ -8,7 +8,7 @@ @define -class BaseGriptapeCloudClient(BaseTool, ABC): +class BaseGriptapeCloudTool(BaseTool, ABC): """Base class for Griptape Cloud clients. Attributes: diff --git a/griptape/tools/base_image_generation_client.py b/griptape/tools/base_image_generation_tool.py similarity index 88% rename from griptape/tools/base_image_generation_client.py rename to griptape/tools/base_image_generation_tool.py index e85336d23..487c6d1ba 100644 --- a/griptape/tools/base_image_generation_client.py +++ b/griptape/tools/base_image_generation_tool.py @@ -5,7 +5,7 @@ @define -class BaseImageGenerationClient(BlobArtifactFileOutputMixin, BaseTool): +class BaseImageGenerationTool(BlobArtifactFileOutputMixin, BaseTool): """A base class for tools that generate images from text prompts.""" PROMPT_DESCRIPTION = "Features and qualities to include in the generated image, descriptive and succinct." diff --git a/griptape/tools/calculator/manifest.yml b/griptape/tools/calculator/manifest.yml index 717313495..dd902c616 100644 --- a/griptape/tools/calculator/manifest.yml +++ b/griptape/tools/calculator/manifest.yml @@ -1,5 +1,5 @@ version: "v1" -name: Calculator +name: Calculator Tool description: Tool for making simple calculations in Python. contact_email: hello@griptape.ai -legal_info_url: https://www.griptape.ai/legal \ No newline at end of file +legal_info_url: https://www.griptape.ai/legal diff --git a/griptape/tools/calculator/tool.py b/griptape/tools/calculator/tool.py index e8fcb1ed4..ed128987f 100644 --- a/griptape/tools/calculator/tool.py +++ b/griptape/tools/calculator/tool.py @@ -5,7 +5,7 @@ from griptape.utils.decorators import activity -class Calculator(BaseTool): +class CalculatorTool(BaseTool): @activity( config={ "description": "Can be used for computing simple numerical or algebraic calculations in Python", diff --git a/griptape/tools/computer/manifest.yml b/griptape/tools/computer/manifest.yml index 706c32b5b..4c5d30495 100644 --- a/griptape/tools/computer/manifest.yml +++ b/griptape/tools/computer/manifest.yml @@ -1,5 +1,5 @@ version: "v1" -name: Computer +name: Computer Tool description: Tool that allows LLMs to run Python code and access the shell contact_email: hello@griptape.ai -legal_info_url: https://www.griptape.ai/legal \ No newline at end of file +legal_info_url: https://www.griptape.ai/legal diff --git a/griptape/tools/computer/tool.py b/griptape/tools/computer/tool.py index 4e996c63c..e2da2d9f8 100644 --- a/griptape/tools/computer/tool.py +++ b/griptape/tools/computer/tool.py @@ -23,7 +23,7 @@ @define -class Computer(BaseTool): +class ComputerTool(BaseTool): local_workdir: Optional[str] = field(default=None, kw_only=True) container_workdir: str = field(default="/griptape", kw_only=True) env_vars: dict = field(factory=dict, kw_only=True) diff --git a/griptape/tools/date_time/manifest.yml b/griptape/tools/date_time/manifest.yml index c50b46ed8..da8e553a5 100644 --- a/griptape/tools/date_time/manifest.yml +++ b/griptape/tools/date_time/manifest.yml @@ -1,5 +1,5 @@ version: "v1" -name: Date Time +name: Date Time Tool description: Tool that allows LLMs to retrieve the current date & time contact_email: hello@griptape.ai -legal_info_url: https://www.griptape.ai/legal \ No newline at end of file +legal_info_url: https://www.griptape.ai/legal diff --git a/griptape/tools/date_time/tool.py b/griptape/tools/date_time/tool.py index 728a3449a..5181dbe3e 100644 --- a/griptape/tools/date_time/tool.py +++ b/griptape/tools/date_time/tool.py @@ -7,7 +7,7 @@ from griptape.utils.decorators import activity -class DateTime(BaseTool): +class DateTimeTool(BaseTool): @activity(config={"description": "Can be used to return current date and time."}) def get_current_datetime(self, _: dict) -> BaseArtifact: try: diff --git a/griptape/tools/email_client/__init__.py b/griptape/tools/email/__init__.py similarity index 100% rename from griptape/tools/email_client/__init__.py rename to griptape/tools/email/__init__.py diff --git a/griptape/tools/email_client/manifest.yml b/griptape/tools/email/manifest.yml similarity index 87% rename from griptape/tools/email_client/manifest.yml rename to griptape/tools/email/manifest.yml index c1e04b226..08009292d 100644 --- a/griptape/tools/email_client/manifest.yml +++ b/griptape/tools/email/manifest.yml @@ -1,5 +1,5 @@ version: "v1" -name: Email Client +name: Email Tool description: Tool for working with email. contact_email: hello@griptape.ai legal_info_url: https://www.griptape.ai/legal \ No newline at end of file diff --git a/griptape/tools/email_client/tool.py b/griptape/tools/email/tool.py similarity index 99% rename from griptape/tools/email_client/tool.py rename to griptape/tools/email/tool.py index 26c13bb9f..f5e7f0247 100644 --- a/griptape/tools/email_client/tool.py +++ b/griptape/tools/email/tool.py @@ -16,7 +16,7 @@ @define -class EmailClient(BaseTool): +class EmailTool(BaseTool): """Tool for working with email. Attributes: diff --git a/griptape/tools/file_manager/manifest.yml b/griptape/tools/file_manager/manifest.yml index 8778098fb..132a03327 100644 --- a/griptape/tools/file_manager/manifest.yml +++ b/griptape/tools/file_manager/manifest.yml @@ -1,5 +1,5 @@ version: "v1" -name: File Manager +name: File Manager Tool description: Tool for managing files in the local environment. contact_email: hello@griptape.ai -legal_info_url: https://www.griptape.ai/legal \ No newline at end of file +legal_info_url: https://www.griptape.ai/legal diff --git a/griptape/tools/file_manager/tool.py b/griptape/tools/file_manager/tool.py index c0cb691ea..a3e17a8c4 100644 --- a/griptape/tools/file_manager/tool.py +++ b/griptape/tools/file_manager/tool.py @@ -12,8 +12,8 @@ @define -class FileManager(BaseTool): - """FileManager is a tool that can be used to list, load, and save files. +class FileManagerTool(BaseTool): + """FileManagerTool is a tool that can be used to list, load, and save files. Attributes: file_manager_driver: File Manager Driver to use to list, load, and save files. diff --git a/griptape/tools/google_cal/__init__.py b/griptape/tools/google_calendar/__init__.py similarity index 100% rename from griptape/tools/google_cal/__init__.py rename to griptape/tools/google_calendar/__init__.py diff --git a/griptape/tools/google_cal/manifest.yml b/griptape/tools/google_calendar/manifest.yml similarity index 100% rename from griptape/tools/google_cal/manifest.yml rename to griptape/tools/google_calendar/manifest.yml diff --git a/griptape/tools/google_cal/requirements.txt b/griptape/tools/google_calendar/requirements.txt similarity index 100% rename from griptape/tools/google_cal/requirements.txt rename to griptape/tools/google_calendar/requirements.txt diff --git a/griptape/tools/google_cal/tool.py b/griptape/tools/google_calendar/tool.py similarity index 98% rename from griptape/tools/google_cal/tool.py rename to griptape/tools/google_calendar/tool.py index 70f685605..de9c4e8e1 100644 --- a/griptape/tools/google_cal/tool.py +++ b/griptape/tools/google_calendar/tool.py @@ -7,12 +7,12 @@ from schema import Literal, Optional, Schema from griptape.artifacts import ErrorArtifact, InfoArtifact, ListArtifact, TextArtifact -from griptape.tools import BaseGoogleClient +from griptape.tools import BaseGoogleTool from griptape.utils.decorators import activity @define -class GoogleCalendarClient(BaseGoogleClient): +class GoogleCalendarTool(BaseGoogleTool): CREATE_EVENT_SCOPES = ["https://www.googleapis.com/auth/calendar"] GET_UPCOMING_EVENTS_SCOPES = ["https://www.googleapis.com/auth/calendar"] diff --git a/griptape/tools/google_docs/tool.py b/griptape/tools/google_docs/tool.py index b3564b9b2..be40b09da 100644 --- a/griptape/tools/google_docs/tool.py +++ b/griptape/tools/google_docs/tool.py @@ -6,12 +6,12 @@ from schema import Literal, Optional, Schema from griptape.artifacts import ErrorArtifact, InfoArtifact -from griptape.tools import BaseGoogleClient +from griptape.tools import BaseGoogleTool from griptape.utils.decorators import activity @define -class GoogleDocsClient(BaseGoogleClient): +class GoogleDocsTool(BaseGoogleTool): DOCS_SCOPES = ["https://www.googleapis.com/auth/documents"] DEFAULT_FOLDER_PATH = "root" diff --git a/griptape/tools/google_drive/tool.py b/griptape/tools/google_drive/tool.py index 37122a56b..1642ebaf7 100644 --- a/griptape/tools/google_drive/tool.py +++ b/griptape/tools/google_drive/tool.py @@ -9,12 +9,12 @@ from schema import Literal, Or, Schema from griptape.artifacts import BlobArtifact, ErrorArtifact, InfoArtifact, ListArtifact, TextArtifact -from griptape.tools import BaseGoogleClient +from griptape.tools import BaseGoogleTool from griptape.utils.decorators import activity @define -class GoogleDriveClient(BaseGoogleClient): +class GoogleDriveTool(BaseGoogleTool): LIST_FILES_SCOPES = ["https://www.googleapis.com/auth/drive.readonly"] GOOGLE_EXPORT_MIME_MAPPING = { diff --git a/griptape/tools/google_gmail/manifest.yml b/griptape/tools/google_gmail/manifest.yml index 262e3a6f8..869575166 100644 --- a/griptape/tools/google_gmail/manifest.yml +++ b/griptape/tools/google_gmail/manifest.yml @@ -1,5 +1,5 @@ version: "v1" -name: Google Gmail Client +name: Google Gmail Tool description: Tool for working with Google Gmail. contact_email: hello@griptape.ai legal_info_url: https://www.griptape.ai/legal \ No newline at end of file diff --git a/griptape/tools/google_gmail/tool.py b/griptape/tools/google_gmail/tool.py index 853b8850f..2cc959168 100644 --- a/griptape/tools/google_gmail/tool.py +++ b/griptape/tools/google_gmail/tool.py @@ -8,12 +8,12 @@ from schema import Literal, Schema from griptape.artifacts import ErrorArtifact, InfoArtifact -from griptape.tools import BaseGoogleClient +from griptape.tools import BaseGoogleTool from griptape.utils.decorators import activity @define -class GoogleGmailClient(BaseGoogleClient): +class GoogleGmailTool(BaseGoogleTool): CREATE_DRAFT_EMAIL_SCOPES = ["https://www.googleapis.com/auth/gmail.compose"] owner_email: str = field(kw_only=True) diff --git a/griptape/tools/griptape_cloud_knowledge_base_client/__init__.py b/griptape/tools/griptape_cloud_knowledge_base/__init__.py similarity index 100% rename from griptape/tools/griptape_cloud_knowledge_base_client/__init__.py rename to griptape/tools/griptape_cloud_knowledge_base/__init__.py diff --git a/griptape/tools/griptape_cloud_knowledge_base_client/manifest.yml b/griptape/tools/griptape_cloud_knowledge_base/manifest.yml similarity index 78% rename from griptape/tools/griptape_cloud_knowledge_base_client/manifest.yml rename to griptape/tools/griptape_cloud_knowledge_base/manifest.yml index 89b7d2fe3..7262964c3 100644 --- a/griptape/tools/griptape_cloud_knowledge_base_client/manifest.yml +++ b/griptape/tools/griptape_cloud_knowledge_base/manifest.yml @@ -1,5 +1,5 @@ version: "v1" -name: Griptape Cloud Knowledge Base Client +name: Griptape Cloud Knowledge Base Tool description: Tool for using the Griptape Cloud Knowledge Base API. contact_email: hello@griptape.ai legal_info_url: https://www.griptape.ai/legal diff --git a/griptape/tools/griptape_cloud_knowledge_base_client/tool.py b/griptape/tools/griptape_cloud_knowledge_base/tool.py similarity index 91% rename from griptape/tools/griptape_cloud_knowledge_base_client/tool.py rename to griptape/tools/griptape_cloud_knowledge_base/tool.py index 0c544524d..13ff76baa 100644 --- a/griptape/tools/griptape_cloud_knowledge_base_client/tool.py +++ b/griptape/tools/griptape_cloud_knowledge_base/tool.py @@ -7,12 +7,12 @@ from schema import Literal, Schema from griptape.artifacts import ErrorArtifact, TextArtifact -from griptape.tools.base_griptape_cloud_client import BaseGriptapeCloudClient +from griptape.tools.base_griptape_cloud_tool import BaseGriptapeCloudTool from griptape.utils.decorators import activity @define -class GriptapeCloudKnowledgeBaseClient(BaseGriptapeCloudClient): +class GriptapeCloudKnowledgeBaseTool(BaseGriptapeCloudTool): """Tool for querying a Griptape Cloud Knowledge Base. Attributes: @@ -64,7 +64,7 @@ def _get_knowledge_base_description(self) -> str: return response_body["description"] else: raise ValueError( - f"No description found for Knowledge Base {self.knowledge_base_id}. Please set a description, or manually set the `GriptapeCloudKnowledgeBaseClient.description` attribute.", + f"No description found for Knowledge Base {self.knowledge_base_id}. Please set a description, or manually set the `GriptapeCloudKnowledgeBaseTool.description` attribute.", ) else: raise ValueError(f"Error accessing Knowledge Base {self.knowledge_base_id}.") diff --git a/griptape/tools/image_query_client/__init__.py b/griptape/tools/image_query/__init__.py similarity index 100% rename from griptape/tools/image_query_client/__init__.py rename to griptape/tools/image_query/__init__.py diff --git a/griptape/tools/image_query_client/manifest.yml b/griptape/tools/image_query/manifest.yml similarity index 86% rename from griptape/tools/image_query_client/manifest.yml rename to griptape/tools/image_query/manifest.yml index b73027f6a..504543fca 100644 --- a/griptape/tools/image_query_client/manifest.yml +++ b/griptape/tools/image_query/manifest.yml @@ -1,5 +1,5 @@ version: "v1" -name: Image Query Client +name: Image Query Tool description: Tool for executing a natural language query on images. contact_email: hello@griptape.ai legal_info_url: https://www.griptape.ai/legal diff --git a/griptape/tools/image_query_client/tool.py b/griptape/tools/image_query/tool.py similarity index 99% rename from griptape/tools/image_query_client/tool.py rename to griptape/tools/image_query/tool.py index a10929b13..97772d546 100644 --- a/griptape/tools/image_query_client/tool.py +++ b/griptape/tools/image_query/tool.py @@ -17,7 +17,7 @@ @define -class ImageQueryClient(BaseTool): +class ImageQueryTool(BaseTool): image_query_engine: ImageQueryEngine = field(kw_only=True) image_loader: ImageLoader = field(default=Factory(lambda: ImageLoader()), kw_only=True) diff --git a/griptape/tools/inpainting_image_generation_client/__init__.py b/griptape/tools/inpainting_image_generation/__init__.py similarity index 100% rename from griptape/tools/inpainting_image_generation_client/__init__.py rename to griptape/tools/inpainting_image_generation/__init__.py diff --git a/griptape/tools/inpainting_image_generation_client/manifest.yml b/griptape/tools/inpainting_image_generation/manifest.yml similarity index 79% rename from griptape/tools/inpainting_image_generation_client/manifest.yml rename to griptape/tools/inpainting_image_generation/manifest.yml index 575c0630d..d6592b741 100644 --- a/griptape/tools/inpainting_image_generation_client/manifest.yml +++ b/griptape/tools/inpainting_image_generation/manifest.yml @@ -1,5 +1,5 @@ version: "v1" -name: Inpainting Image Generation Client +name: Inpainting Image Generation Tool description: Tool for generating images through image inpainting. contact_email: hello@griptape.ai legal_info_url: https://www.griptape.ai/legal diff --git a/griptape/tools/inpainting_image_generation_client/requirements.txt b/griptape/tools/inpainting_image_generation/requirements.txt similarity index 100% rename from griptape/tools/inpainting_image_generation_client/requirements.txt rename to griptape/tools/inpainting_image_generation/requirements.txt diff --git a/griptape/tools/inpainting_image_generation_client/tool.py b/griptape/tools/inpainting_image_generation/tool.py similarity index 93% rename from griptape/tools/inpainting_image_generation_client/tool.py rename to griptape/tools/inpainting_image_generation/tool.py index e8979efb0..d32f481d9 100644 --- a/griptape/tools/inpainting_image_generation_client/tool.py +++ b/griptape/tools/inpainting_image_generation/tool.py @@ -8,7 +8,7 @@ from griptape.artifacts import ErrorArtifact, ImageArtifact from griptape.loaders import ImageLoader -from griptape.tools.base_image_generation_client import BaseImageGenerationClient +from griptape.tools.base_image_generation_tool import BaseImageGenerationTool from griptape.utils.decorators import activity from griptape.utils.load_artifact_from_memory import load_artifact_from_memory @@ -17,7 +17,7 @@ @define -class InpaintingImageGenerationClient(BaseImageGenerationClient): +class InpaintingImageGenerationTool(BaseImageGenerationTool): """A tool that can be used to generate prompted inpaintings of an image. Attributes: @@ -34,8 +34,8 @@ class InpaintingImageGenerationClient(BaseImageGenerationClient): "description": "Modifies an image within a specified mask area using image and mask files.", "schema": Schema( { - Literal("prompt", description=BaseImageGenerationClient.PROMPT_DESCRIPTION): str, - Literal("negative_prompt", description=BaseImageGenerationClient.NEGATIVE_PROMPT_DESCRIPTION): str, + Literal("prompt", description=BaseImageGenerationTool.PROMPT_DESCRIPTION): str, + Literal("negative_prompt", description=BaseImageGenerationTool.NEGATIVE_PROMPT_DESCRIPTION): str, Literal( "image_file", description="The path to an image file to be used as a base to generate variations from.", @@ -63,8 +63,8 @@ def image_inpainting_from_file(self, params: dict[str, dict[str, str]]) -> Image "description": "Modifies an image within a specified mask area using image and mask artifacts in memory.", "schema": Schema( { - Literal("prompt", description=BaseImageGenerationClient.PROMPT_DESCRIPTION): str, - Literal("negative_prompt", description=BaseImageGenerationClient.NEGATIVE_PROMPT_DESCRIPTION): str, + Literal("prompt", description=BaseImageGenerationTool.PROMPT_DESCRIPTION): str, + Literal("negative_prompt", description=BaseImageGenerationTool.NEGATIVE_PROMPT_DESCRIPTION): str, "memory_name": str, "image_artifact_namespace": str, "image_artifact_name": str, diff --git a/griptape/tools/openweather_client/__init__.py b/griptape/tools/openweather/__init__.py similarity index 100% rename from griptape/tools/openweather_client/__init__.py rename to griptape/tools/openweather/__init__.py diff --git a/griptape/tools/openweather_client/manifest.yml b/griptape/tools/openweather/manifest.yml similarity index 86% rename from griptape/tools/openweather_client/manifest.yml rename to griptape/tools/openweather/manifest.yml index 66efae262..315143ea2 100644 --- a/griptape/tools/openweather_client/manifest.yml +++ b/griptape/tools/openweather/manifest.yml @@ -1,5 +1,5 @@ version: "v1" -name: OpenWeather Client +name: OpenWeather Tool description: Tool for using OpenWeather to retrieve weather information contact_email: hello@griptape.ai legal_info_url: https://www.griptape.ai/legal diff --git a/griptape/tools/openweather_client/tool.py b/griptape/tools/openweather/tool.py similarity index 99% rename from griptape/tools/openweather_client/tool.py rename to griptape/tools/openweather/tool.py index 4a7edb0f6..311db733b 100644 --- a/griptape/tools/openweather_client/tool.py +++ b/griptape/tools/openweather/tool.py @@ -13,7 +13,7 @@ @define -class OpenWeatherClient(BaseTool): +class OpenWeatherTool(BaseTool): BASE_URL = "https://api.openweathermap.org/data/3.0/onecall" GEOCODING_URL = "https://api.openweathermap.org/geo/1.0/direct" US_STATE_CODES = [ diff --git a/griptape/tools/outpainting_image_generation_client/__init__.py b/griptape/tools/outpainting_image_generation/__init__.py similarity index 100% rename from griptape/tools/outpainting_image_generation_client/__init__.py rename to griptape/tools/outpainting_image_generation/__init__.py diff --git a/griptape/tools/outpainting_image_generation_client/manifest.yml b/griptape/tools/outpainting_image_generation/manifest.yml similarity index 79% rename from griptape/tools/outpainting_image_generation_client/manifest.yml rename to griptape/tools/outpainting_image_generation/manifest.yml index 54c84668e..8b7ca14a1 100644 --- a/griptape/tools/outpainting_image_generation_client/manifest.yml +++ b/griptape/tools/outpainting_image_generation/manifest.yml @@ -1,5 +1,5 @@ version: "v1" -name: Outpainting Image Generation Client +name: Outpainting Image Generation Tool description: Tool for generating images through image outpainting. contact_email: hello@griptape.ai legal_info_url: https://www.griptape.ai/legal diff --git a/griptape/tools/outpainting_image_generation_client/requirements.txt b/griptape/tools/outpainting_image_generation/requirements.txt similarity index 100% rename from griptape/tools/outpainting_image_generation_client/requirements.txt rename to griptape/tools/outpainting_image_generation/requirements.txt diff --git a/griptape/tools/outpainting_image_generation_client/tool.py b/griptape/tools/outpainting_image_generation/tool.py similarity index 93% rename from griptape/tools/outpainting_image_generation_client/tool.py rename to griptape/tools/outpainting_image_generation/tool.py index 800d88e70..afa39e178 100644 --- a/griptape/tools/outpainting_image_generation_client/tool.py +++ b/griptape/tools/outpainting_image_generation/tool.py @@ -8,7 +8,7 @@ from griptape.artifacts import ErrorArtifact, ImageArtifact from griptape.loaders import ImageLoader -from griptape.tools import BaseImageGenerationClient +from griptape.tools import BaseImageGenerationTool from griptape.utils.decorators import activity from griptape.utils.load_artifact_from_memory import load_artifact_from_memory @@ -17,7 +17,7 @@ @define -class OutpaintingImageGenerationClient(BaseImageGenerationClient): +class OutpaintingImageGenerationTool(BaseImageGenerationTool): """A tool that can be used to generate prompted outpaintings of an image. Attributes: @@ -34,8 +34,8 @@ class OutpaintingImageGenerationClient(BaseImageGenerationClient): "description": "Modifies an image outside a specified mask area using image and mask files.", "schema": Schema( { - Literal("prompt", description=BaseImageGenerationClient.PROMPT_DESCRIPTION): str, - Literal("negative_prompt", description=BaseImageGenerationClient.NEGATIVE_PROMPT_DESCRIPTION): str, + Literal("prompt", description=BaseImageGenerationTool.PROMPT_DESCRIPTION): str, + Literal("negative_prompt", description=BaseImageGenerationTool.NEGATIVE_PROMPT_DESCRIPTION): str, Literal( "image_file", description="The path to an image file to be used as a base to generate variations from.", @@ -61,8 +61,8 @@ def image_outpainting_from_file(self, params: dict[str, dict[str, str]]) -> Imag "description": "Modifies an image outside a specified mask area using image and mask artifacts in memory.", "schema": Schema( { - Literal("prompt", description=BaseImageGenerationClient.PROMPT_DESCRIPTION): str, - Literal("negative_prompt", description=BaseImageGenerationClient.NEGATIVE_PROMPT_DESCRIPTION): str, + Literal("prompt", description=BaseImageGenerationTool.PROMPT_DESCRIPTION): str, + Literal("negative_prompt", description=BaseImageGenerationTool.NEGATIVE_PROMPT_DESCRIPTION): str, "memory_name": str, "image_artifact_namespace": str, "mask_artifact_namespace": str, diff --git a/griptape/tools/prompt_image_generation_client/__init__.py b/griptape/tools/prompt_image_generation/__init__.py similarity index 100% rename from griptape/tools/prompt_image_generation_client/__init__.py rename to griptape/tools/prompt_image_generation/__init__.py diff --git a/griptape/tools/prompt_image_generation_client/manifest.yml b/griptape/tools/prompt_image_generation/manifest.yml similarity index 80% rename from griptape/tools/prompt_image_generation_client/manifest.yml rename to griptape/tools/prompt_image_generation/manifest.yml index 665a24444..091cc14d7 100644 --- a/griptape/tools/prompt_image_generation_client/manifest.yml +++ b/griptape/tools/prompt_image_generation/manifest.yml @@ -1,5 +1,5 @@ version: "v1" -name: Prompt Image Generation Client +name: Prompt Image Generation Tool description: Tool for generating images from text prompts. contact_email: hello@griptape.ai legal_info_url: https://www.griptape.ai/legal diff --git a/griptape/tools/prompt_image_generation_client/requirements.txt b/griptape/tools/prompt_image_generation/requirements.txt similarity index 100% rename from griptape/tools/prompt_image_generation_client/requirements.txt rename to griptape/tools/prompt_image_generation/requirements.txt diff --git a/griptape/tools/prompt_image_generation_client/tool.py b/griptape/tools/prompt_image_generation/tool.py similarity index 87% rename from griptape/tools/prompt_image_generation_client/tool.py rename to griptape/tools/prompt_image_generation/tool.py index 771b4e41d..6cd6ac560 100644 --- a/griptape/tools/prompt_image_generation_client/tool.py +++ b/griptape/tools/prompt_image_generation/tool.py @@ -5,7 +5,7 @@ from attrs import define, field from schema import Literal, Schema -from griptape.tools import BaseImageGenerationClient +from griptape.tools import BaseImageGenerationTool from griptape.utils.decorators import activity if TYPE_CHECKING: @@ -14,7 +14,7 @@ @define -class PromptImageGenerationClient(BaseImageGenerationClient): +class PromptImageGenerationTool(BaseImageGenerationTool): """A tool that can be used to generate an image from a text prompt. Attributes: @@ -30,8 +30,8 @@ class PromptImageGenerationClient(BaseImageGenerationClient): "description": "Generates an image from text prompts.", "schema": Schema( { - Literal("prompt", description=BaseImageGenerationClient.PROMPT_DESCRIPTION): str, - Literal("negative_prompt", description=BaseImageGenerationClient.NEGATIVE_PROMPT_DESCRIPTION): str, + Literal("prompt", description=BaseImageGenerationTool.PROMPT_DESCRIPTION): str, + Literal("negative_prompt", description=BaseImageGenerationTool.NEGATIVE_PROMPT_DESCRIPTION): str, } ), }, diff --git a/griptape/tools/rag_client/__init__.py b/griptape/tools/rag/__init__.py similarity index 100% rename from griptape/tools/rag_client/__init__.py rename to griptape/tools/rag/__init__.py diff --git a/griptape/tools/rag_client/manifest.yml b/griptape/tools/rag/manifest.yml similarity index 88% rename from griptape/tools/rag_client/manifest.yml rename to griptape/tools/rag/manifest.yml index 86998feb4..7a3d49c65 100644 --- a/griptape/tools/rag_client/manifest.yml +++ b/griptape/tools/rag/manifest.yml @@ -1,5 +1,5 @@ version: "v1" -name: RAG Client +name: RAG Tool description: Tool for querying RAG engines contact_email: hello@griptape.ai legal_info_url: https://www.griptape.ai/legal \ No newline at end of file diff --git a/griptape/tools/rag_client/requirements.txt b/griptape/tools/rag/requirements.txt similarity index 100% rename from griptape/tools/rag_client/requirements.txt rename to griptape/tools/rag/requirements.txt diff --git a/griptape/tools/rag_client/tool.py b/griptape/tools/rag/tool.py similarity index 97% rename from griptape/tools/rag_client/tool.py rename to griptape/tools/rag/tool.py index 613e254af..8608493d1 100644 --- a/griptape/tools/rag_client/tool.py +++ b/griptape/tools/rag/tool.py @@ -14,7 +14,7 @@ @define(kw_only=True) -class RagClient(BaseTool): +class RagTool(BaseTool): """Tool for querying a RAG engine. Attributes: diff --git a/griptape/tools/rest_api_client/__init__.py b/griptape/tools/rest_api/__init__.py similarity index 100% rename from griptape/tools/rest_api_client/__init__.py rename to griptape/tools/rest_api/__init__.py diff --git a/griptape/tools/rest_api_client/manifest.yml b/griptape/tools/rest_api/manifest.yml similarity index 87% rename from griptape/tools/rest_api_client/manifest.yml rename to griptape/tools/rest_api/manifest.yml index 7a881d037..01816e483 100644 --- a/griptape/tools/rest_api_client/manifest.yml +++ b/griptape/tools/rest_api/manifest.yml @@ -1,5 +1,5 @@ version: "v1" -name: Rest Api +name: Rest Api Tool description: Tool for calling rest apis. contact_email: hello@griptape.ai legal_info_url: https://www.griptape.ai/legal diff --git a/griptape/tools/rest_api_client/tool.py b/griptape/tools/rest_api/tool.py similarity index 99% rename from griptape/tools/rest_api_client/tool.py rename to griptape/tools/rest_api/tool.py index b27beda0e..24ab4c93e 100644 --- a/griptape/tools/rest_api_client/tool.py +++ b/griptape/tools/rest_api/tool.py @@ -14,7 +14,7 @@ @define -class RestApiClient(BaseTool): +class RestApiTool(BaseTool): """A tool for making REST API requests. Attributes: diff --git a/griptape/tools/sql_client/__init__.py b/griptape/tools/sql/__init__.py similarity index 100% rename from griptape/tools/sql_client/__init__.py rename to griptape/tools/sql/__init__.py diff --git a/griptape/tools/sql_client/manifest.yml b/griptape/tools/sql/manifest.yml similarity index 88% rename from griptape/tools/sql_client/manifest.yml rename to griptape/tools/sql/manifest.yml index 22d0f4be2..2e1459a0d 100644 --- a/griptape/tools/sql_client/manifest.yml +++ b/griptape/tools/sql/manifest.yml @@ -1,5 +1,5 @@ version: "v1" -name: SQL Client +name: SQL Tool description: Tool for executing SQL queries. contact_email: hello@griptape.ai legal_info_url: https://www.griptape.ai/legal \ No newline at end of file diff --git a/griptape/tools/sql_client/tool.py b/griptape/tools/sql/tool.py similarity index 98% rename from griptape/tools/sql_client/tool.py rename to griptape/tools/sql/tool.py index 2de598c6b..a84bb87be 100644 --- a/griptape/tools/sql_client/tool.py +++ b/griptape/tools/sql/tool.py @@ -14,7 +14,7 @@ @define -class SqlClient(BaseTool): +class SqlTool(BaseTool): sql_loader: SqlLoader = field(kw_only=True) schema_name: Optional[str] = field(default=None, kw_only=True) table_name: str = field(kw_only=True) diff --git a/griptape/tools/structure_run_client/__init__.py b/griptape/tools/structure_run/__init__.py similarity index 100% rename from griptape/tools/structure_run_client/__init__.py rename to griptape/tools/structure_run/__init__.py diff --git a/griptape/tools/structure_run_client/manifest.yml b/griptape/tools/structure_run/manifest.yml similarity index 83% rename from griptape/tools/structure_run_client/manifest.yml rename to griptape/tools/structure_run/manifest.yml index 5f53158d8..b5feb835a 100644 --- a/griptape/tools/structure_run_client/manifest.yml +++ b/griptape/tools/structure_run/manifest.yml @@ -1,5 +1,5 @@ version: "v1" -name: Structure Run Client +name: Structure Run Tool description: Tool for running a Structure. contact_email: hello@griptape.ai legal_info_url: https://www.griptape.ai/legal diff --git a/griptape/tools/structure_run_client/tool.py b/griptape/tools/structure_run/tool.py similarity index 97% rename from griptape/tools/structure_run_client/tool.py rename to griptape/tools/structure_run/tool.py index f4f6c3786..cda4f0b35 100644 --- a/griptape/tools/structure_run_client/tool.py +++ b/griptape/tools/structure_run/tool.py @@ -14,7 +14,7 @@ @define -class StructureRunClient(BaseTool): +class StructureRunTool(BaseTool): """Tool for running a Structure. Attributes: diff --git a/griptape/tools/task_memory_client/__init__.py b/griptape/tools/task_memory/__init__.py similarity index 100% rename from griptape/tools/task_memory_client/__init__.py rename to griptape/tools/task_memory/__init__.py diff --git a/griptape/tools/task_memory_client/manifest.yml b/griptape/tools/task_memory/manifest.yml similarity index 85% rename from griptape/tools/task_memory_client/manifest.yml rename to griptape/tools/task_memory/manifest.yml index 0bff1af3d..5d40a1e68 100644 --- a/griptape/tools/task_memory_client/manifest.yml +++ b/griptape/tools/task_memory/manifest.yml @@ -1,5 +1,5 @@ version: "v1" -name: Task Memory Client +name: Task Memory Tool description: Tool for summarizing and querying TaskMemory. contact_email: hello@griptape.ai legal_info_url: https://www.griptape.ai/legal \ No newline at end of file diff --git a/griptape/tools/task_memory_client/tool.py b/griptape/tools/task_memory/tool.py similarity index 98% rename from griptape/tools/task_memory_client/tool.py rename to griptape/tools/task_memory/tool.py index 160a54d85..dfd6c7c4b 100644 --- a/griptape/tools/task_memory_client/tool.py +++ b/griptape/tools/task_memory/tool.py @@ -9,7 +9,7 @@ @define -class TaskMemoryClient(BaseTool): +class TaskMemoryTool(BaseTool): @activity( config={ "description": "Can be used to summarize memory content", diff --git a/griptape/tools/text_to_speech_client/__init__.py b/griptape/tools/text_to_speech/__init__.py similarity index 100% rename from griptape/tools/text_to_speech_client/__init__.py rename to griptape/tools/text_to_speech/__init__.py diff --git a/griptape/tools/text_to_speech_client/manifest.yml b/griptape/tools/text_to_speech/manifest.yml similarity index 83% rename from griptape/tools/text_to_speech_client/manifest.yml rename to griptape/tools/text_to_speech/manifest.yml index 73062bb13..875e04576 100644 --- a/griptape/tools/text_to_speech_client/manifest.yml +++ b/griptape/tools/text_to_speech/manifest.yml @@ -1,5 +1,5 @@ version: "v1" -name: Text to Speech Client +name: Text to Speech Tool description: A tool for generating speech from text. contact_email: hello@griptape.ai legal_info_url: https://www.griptape.ai/legal diff --git a/griptape/tools/text_to_speech_client/tool.py b/griptape/tools/text_to_speech/tool.py similarity index 95% rename from griptape/tools/text_to_speech_client/tool.py rename to griptape/tools/text_to_speech/tool.py index 295641fd3..95a42d0ae 100644 --- a/griptape/tools/text_to_speech_client/tool.py +++ b/griptape/tools/text_to_speech/tool.py @@ -15,7 +15,7 @@ @define -class TextToSpeechClient(BlobArtifactFileOutputMixin, BaseTool): +class TextToSpeechTool(BlobArtifactFileOutputMixin, BaseTool): """A tool that can be used to generate speech from input text. Attributes: diff --git a/griptape/tools/variation_image_generation_client/__init__.py b/griptape/tools/variation_image_generation/__init__.py similarity index 100% rename from griptape/tools/variation_image_generation_client/__init__.py rename to griptape/tools/variation_image_generation/__init__.py diff --git a/griptape/tools/variation_image_generation_client/manifest.yml b/griptape/tools/variation_image_generation/manifest.yml similarity index 79% rename from griptape/tools/variation_image_generation_client/manifest.yml rename to griptape/tools/variation_image_generation/manifest.yml index eb9371016..1f3eb28e8 100644 --- a/griptape/tools/variation_image_generation_client/manifest.yml +++ b/griptape/tools/variation_image_generation/manifest.yml @@ -1,5 +1,5 @@ version: "v1" -name: Variation Image Generation Client +name: Variation Image Generation Tool description: Tool for generating variations of existing images. contact_email: hello@griptape.ai legal_info_url: https://www.griptape.ai/legal diff --git a/griptape/tools/variation_image_generation_client/requirements.txt b/griptape/tools/variation_image_generation/requirements.txt similarity index 100% rename from griptape/tools/variation_image_generation_client/requirements.txt rename to griptape/tools/variation_image_generation/requirements.txt diff --git a/griptape/tools/variation_image_generation_client/tool.py b/griptape/tools/variation_image_generation/tool.py similarity index 91% rename from griptape/tools/variation_image_generation_client/tool.py rename to griptape/tools/variation_image_generation/tool.py index 5f836c5b1..9691f6206 100644 --- a/griptape/tools/variation_image_generation_client/tool.py +++ b/griptape/tools/variation_image_generation/tool.py @@ -8,7 +8,7 @@ from griptape.artifacts import ErrorArtifact, ImageArtifact from griptape.loaders import ImageLoader -from griptape.tools.base_image_generation_client import BaseImageGenerationClient +from griptape.tools.base_image_generation_tool import BaseImageGenerationTool from griptape.utils.decorators import activity from griptape.utils.load_artifact_from_memory import load_artifact_from_memory @@ -17,7 +17,7 @@ @define -class VariationImageGenerationClient(BaseImageGenerationClient): +class VariationImageGenerationTool(BaseImageGenerationTool): """A tool that can be used to generate prompted variations of an image. Attributes: @@ -34,8 +34,8 @@ class VariationImageGenerationClient(BaseImageGenerationClient): "description": "Generates a variation of a given input image file.", "schema": Schema( { - Literal("prompt", description=BaseImageGenerationClient.PROMPT_DESCRIPTION): str, - Literal("negative_prompt", description=BaseImageGenerationClient.NEGATIVE_PROMPT_DESCRIPTION): str, + Literal("prompt", description=BaseImageGenerationTool.PROMPT_DESCRIPTION): str, + Literal("negative_prompt", description=BaseImageGenerationTool.NEGATIVE_PROMPT_DESCRIPTION): str, Literal( "image_file", description="The path to an image file to be used as a base to generate variations from.", @@ -61,8 +61,8 @@ def image_variation_from_file(self, params: dict[str, dict[str, str]]) -> ImageA "description": "Generates a variation of a given input image artifact in memory.", "schema": Schema( { - Literal("prompt", description=BaseImageGenerationClient.PROMPT_DESCRIPTION): str, - Literal("negative_prompt", description=BaseImageGenerationClient.NEGATIVE_PROMPT_DESCRIPTION): str, + Literal("prompt", description=BaseImageGenerationTool.PROMPT_DESCRIPTION): str, + Literal("negative_prompt", description=BaseImageGenerationTool.NEGATIVE_PROMPT_DESCRIPTION): str, "memory_name": str, "artifact_namespace": str, "artifact_name": str, diff --git a/griptape/tools/vector_store_client/__init__.py b/griptape/tools/vector_store/__init__.py similarity index 100% rename from griptape/tools/vector_store_client/__init__.py rename to griptape/tools/vector_store/__init__.py diff --git a/griptape/tools/vector_store_client/manifest.yml b/griptape/tools/vector_store/manifest.yml similarity index 85% rename from griptape/tools/vector_store_client/manifest.yml rename to griptape/tools/vector_store/manifest.yml index a1a1d1d0c..d1fab7ce5 100644 --- a/griptape/tools/vector_store_client/manifest.yml +++ b/griptape/tools/vector_store/manifest.yml @@ -1,5 +1,5 @@ version: "v1" -name: Vector Store Client +name: Vector Store Tool description: Tool for storing and accessing data in vector stores contact_email: hello@griptape.ai legal_info_url: https://www.griptape.ai/legal \ No newline at end of file diff --git a/griptape/tools/vector_store_client/requirements.txt b/griptape/tools/vector_store/requirements.txt similarity index 100% rename from griptape/tools/vector_store_client/requirements.txt rename to griptape/tools/vector_store/requirements.txt diff --git a/griptape/tools/vector_store_client/tool.py b/griptape/tools/vector_store/tool.py similarity index 98% rename from griptape/tools/vector_store_client/tool.py rename to griptape/tools/vector_store/tool.py index a0c638eef..71902b1c7 100644 --- a/griptape/tools/vector_store_client/tool.py +++ b/griptape/tools/vector_store/tool.py @@ -14,7 +14,7 @@ @define(kw_only=True) -class VectorStoreClient(BaseTool): +class VectorStoreTool(BaseTool): """A tool for querying a vector database. Attributes: diff --git a/griptape/tools/web_scraper/manifest.yml b/griptape/tools/web_scraper/manifest.yml index e2d0597ec..ec9d3db25 100644 --- a/griptape/tools/web_scraper/manifest.yml +++ b/griptape/tools/web_scraper/manifest.yml @@ -1,5 +1,5 @@ version: "v1" -name: Web Scraper +name: Web Scraper Tool description: Tool for scraping web pages for content, titles, authors, and keywords. contact_email: hello@griptape.ai -legal_info_url: https://www.griptape.ai/legal \ No newline at end of file +legal_info_url: https://www.griptape.ai/legal diff --git a/griptape/tools/web_scraper/tool.py b/griptape/tools/web_scraper/tool.py index 782e85d37..c27aaa066 100644 --- a/griptape/tools/web_scraper/tool.py +++ b/griptape/tools/web_scraper/tool.py @@ -10,7 +10,7 @@ @define -class WebScraper(BaseTool): +class WebScraperTool(BaseTool): web_loader: WebLoader = field(default=Factory(lambda: WebLoader()), kw_only=True) @activity( diff --git a/griptape/tools/web_search/manifest.yml b/griptape/tools/web_search/manifest.yml index 4bb2a82c8..c06db4f20 100644 --- a/griptape/tools/web_search/manifest.yml +++ b/griptape/tools/web_search/manifest.yml @@ -1,5 +1,5 @@ version: "v1" -name: Google Search -description: Tool for making making web searches on Google. +name: Web Search Tool +description: Tool for making making web searches. contact_email: hello@griptape.ai -legal_info_url: https://www.griptape.ai/legal \ No newline at end of file +legal_info_url: https://www.griptape.ai/legal diff --git a/griptape/tools/web_search/tool.py b/griptape/tools/web_search/tool.py index 43f975acd..557c26a52 100644 --- a/griptape/tools/web_search/tool.py +++ b/griptape/tools/web_search/tool.py @@ -14,7 +14,7 @@ @define -class WebSearch(BaseTool): +class WebSearchTool(BaseTool): web_search_driver: BaseWebSearchDriver = field(kw_only=True) @activity( diff --git a/mkdocs.yml b/mkdocs.yml index 175918e87..ce1772754 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -129,34 +129,34 @@ nav: - Tools: - Overview: "griptape-tools/index.md" - Official Tools: - - AwsIamClient: "griptape-tools/official-tools/aws-iam-client.md" - - AwsS3Client: "griptape-tools/official-tools/aws-s3-client.md" - - Calculator: "griptape-tools/official-tools/calculator.md" - - Computer: "griptape-tools/official-tools/computer.md" - - DateTime: "griptape-tools/official-tools/date-time.md" - - EmailClient: "griptape-tools/official-tools/email-client.md" - - FileManager: "griptape-tools/official-tools/file-manager.md" - - GoogleCalendarClient: "griptape-tools/official-tools/google-cal-client.md" - - GoogleGmailClient: "griptape-tools/official-tools/google-gmail-client.md" - - GoogleDriveClient: "griptape-tools/official-tools/google-drive-client.md" - - GoogleDocsClient: "griptape-tools/official-tools/google-docs-client.md" - - StructureRunClient: "griptape-tools/official-tools/structure-run-client.md" - - OpenWeatherClient: "griptape-tools/official-tools/openweather-client.md" - - RestApiClient: "griptape-tools/official-tools/rest-api-client.md" - - SqlClient: "griptape-tools/official-tools/sql-client.md" - - TaskMemoryClient: "griptape-tools/official-tools/task-memory-client.md" - - VectorStoreClient: "griptape-tools/official-tools/vector-store-client.md" - - WebScraper: "griptape-tools/official-tools/web-scraper.md" - - WebSearch: "griptape-tools/official-tools/web-search.md" - - PromptImageGenerationClient: "griptape-tools/official-tools/prompt-image-generation-client.md" - - VariationImageGenerationClient: "griptape-tools/official-tools/variation-image-generation-client.md" - - InpaintingImageGenerationClient: "griptape-tools/official-tools/inpainting-image-generation-client.md" - - OutpaintingImageGenerationClient: "griptape-tools/official-tools/outpainting-image-generation-client.md" - - ImageQueryClient: "griptape-tools/official-tools/image-query-client.md" - - TextToSpeechClient: "griptape-tools/official-tools/text-to-speech-client.md" - - AudioTranscriptionClient: "griptape-tools/official-tools/audio-transcription-client.md" - - GriptapeCloudKnowledgeBaseClient: "griptape-tools/official-tools/griptape-cloud-knowledge-base-client.md" - - RagClient: "griptape-tools/official-tools/rag-client.md" + - Aws Iam: "griptape-tools/official-tools/aws-iam-tool.md" + - Aws S3: "griptape-tools/official-tools/aws-s3-tool.md" + - Calculator: "griptape-tools/official-tools/calculator-tool.md" + - Computer: "griptape-tools/official-tools/computer-tool.md" + - Date Time: "griptape-tools/official-tools/date-time-tool.md" + - Email: "griptape-tools/official-tools/email-tool.md" + - File Manager: "griptape-tools/official-tools/file-manager-tool.md" + - Google Calendar: "griptape-tools/official-tools/google-calendar-tool.md" + - Google Gmail: "griptape-tools/official-tools/google-gmail-tool.md" + - Google Drive: "griptape-tools/official-tools/google-drive-tool.md" + - Google Docs: "griptape-tools/official-tools/google-docs-tool.md" + - Structure Run Client: "griptape-tools/official-tools/structure-run-tool.md" + - Open Weather: "griptape-tools/official-tools/openweather-tool.md" + - Rest Api Client: "griptape-tools/official-tools/rest-api-tool.md" + - Sql: "griptape-tools/official-tools/sql-tool.md" + - Task Memory: "griptape-tools/official-tools/task-memory-tool.md" + - Vector Store Tool: "griptape-tools/official-tools/vector-store-tool.md" + - Web Scraper: "griptape-tools/official-tools/web-scraper-tool.md" + - Web Search: "griptape-tools/official-tools/web-search-tool.md" + - Prompt Image Generation: "griptape-tools/official-tools/prompt-image-generation-tool.md" + - Variation ImageGeneration: "griptape-tools/official-tools/variation-image-generation-tool.md" + - Inpainting ImageGeneration: "griptape-tools/official-tools/inpainting-image-generation-tool.md" + - Outpainting ImageGeneration: "griptape-tools/official-tools/outpainting-image-generation-tool.md" + - Image Query: "griptape-tools/official-tools/image-query-tool.md" + - Text To Speech: "griptape-tools/official-tools/text-to-speech-tool.md" + - Audio Transcription: "griptape-tools/official-tools/audio-transcription-tool.md" + - Griptape Cloud Knowledge Base: "griptape-tools/official-tools/griptape-cloud-knowledge-base-tool.md" + - Rag: "griptape-tools/official-tools/rag-tool.md" - Custom Tools: - Building Custom Tools: "griptape-tools/custom-tools/index.md" - Recipes: diff --git a/tests/integration/tasks/test_tool_task.py b/tests/integration/tasks/test_tool_task.py index aee0af110..426dde995 100644 --- a/tests/integration/tasks/test_tool_task.py +++ b/tests/integration/tasks/test_tool_task.py @@ -10,10 +10,10 @@ class TestToolTask: def structure_tester(self, request): from griptape.structures import Agent from griptape.tasks import ToolTask - from griptape.tools import Calculator + from griptape.tools import CalculatorTool return StructureTester( - Agent(tasks=[ToolTask(tool=Calculator())], conversation_memory=None, prompt_driver=request.param) + Agent(tasks=[ToolTask(tool=CalculatorTool())], conversation_memory=None, prompt_driver=request.param) ) def test_tool_task(self, structure_tester): diff --git a/tests/integration/tasks/test_toolkit_task.py b/tests/integration/tasks/test_toolkit_task.py index 8dfcfdc73..5cb1aa0dc 100644 --- a/tests/integration/tasks/test_toolkit_task.py +++ b/tests/integration/tasks/test_toolkit_task.py @@ -14,18 +14,18 @@ def structure_tester(self, request): from griptape.drivers import GoogleWebSearchDriver from griptape.structures import Agent - from griptape.tools import TaskMemoryClient, WebScraper, WebSearch + from griptape.tools import TaskMemoryTool, WebScraperTool, WebSearchTool return StructureTester( Agent( tools=[ - WebSearch( + WebSearchTool( web_search_driver=GoogleWebSearchDriver( api_key=os.environ["GOOGLE_API_KEY"], search_id=os.environ["GOOGLE_API_SEARCH_ID"] ) ), - WebScraper(off_prompt=True), - TaskMemoryClient(off_prompt=False), + WebScraperTool(off_prompt=True), + TaskMemoryTool(off_prompt=False), ], conversation_memory=None, prompt_driver=request.param, diff --git a/tests/integration/test_code_blocks.py b/tests/integration/test_code_blocks.py index 2e9aac01f..c9c666b0e 100644 --- a/tests/integration/test_code_blocks.py +++ b/tests/integration/test_code_blocks.py @@ -4,7 +4,7 @@ import pytest SKIP_FILES = [ - "docs/griptape-tools/official-tools/src/computer_1.py", + "docs/griptape-tools/official-tools/src/computer_tool_1.py", "docs/examples/src/load_query_and_chat_marqo_1.py", "docs/griptape-framework/drivers/src/embedding_drivers_2.py", "docs/griptape-framework/drivers/src/embedding_drivers_6.py", diff --git a/tests/integration/tools/test_calculator.py b/tests/integration/tools/test_calculator_tool.py similarity index 73% rename from tests/integration/tools/test_calculator.py rename to tests/integration/tools/test_calculator_tool.py index 2547b947d..c209a9a2c 100644 --- a/tests/integration/tools/test_calculator.py +++ b/tests/integration/tools/test_calculator_tool.py @@ -11,9 +11,9 @@ class TestCalculator: ) def structure_tester(self, request): from griptape.structures import Agent - from griptape.tools import Calculator + from griptape.tools import CalculatorTool - return StructureTester(Agent(tools=[Calculator()], conversation_memory=None, prompt_driver=request.param)) + return StructureTester(Agent(tools=[CalculatorTool()], conversation_memory=None, prompt_driver=request.param)) def test_calculate(self, structure_tester): structure_tester.run("What is 7 times 3 divided by 5 plus 10.") diff --git a/tests/integration/tools/test_file_manager.py b/tests/integration/tools/test_file_manager_tool.py similarity index 79% rename from tests/integration/tools/test_file_manager.py rename to tests/integration/tools/test_file_manager_tool.py index 8a283c6e8..4b5299175 100644 --- a/tests/integration/tools/test_file_manager.py +++ b/tests/integration/tools/test_file_manager_tool.py @@ -11,9 +11,9 @@ class TestFileManager: ) def structure_tester(self, request): from griptape.structures import Agent - from griptape.tools import FileManager + from griptape.tools import FileManagerTool - return StructureTester(Agent(tools=[FileManager()], conversation_memory=None, prompt_driver=request.param)) + return StructureTester(Agent(tools=[FileManagerTool()], conversation_memory=None, prompt_driver=request.param)) def test_save_content_to_disk(self, structure_tester): structure_tester.run('Write the content "Hello World!" to a file called "poem.txt".') diff --git a/tests/integration/tools/test_google_docs_client.py b/tests/integration/tools/test_google_docs_tool.py similarity index 95% rename from tests/integration/tools/test_google_docs_client.py rename to tests/integration/tools/test_google_docs_tool.py index 4d70aac17..7c8828dd3 100644 --- a/tests/integration/tools/test_google_docs_client.py +++ b/tests/integration/tools/test_google_docs_tool.py @@ -5,7 +5,7 @@ from tests.utils.structure_tester import StructureTester -class TestGoogleDocsClient: +class TestGoogleDocsTool: @pytest.fixture( autouse=True, params=StructureTester.TOOLKIT_TASK_CAPABLE_PROMPT_DRIVERS, @@ -13,12 +13,12 @@ class TestGoogleDocsClient: ) def structure_tester(self, request): from griptape.structures import Agent - from griptape.tools import GoogleDocsClient + from griptape.tools import GoogleDocsTool return StructureTester( Agent( tools=[ - GoogleDocsClient( + GoogleDocsTool( service_account_credentials={ "type": os.environ["GOOGLE_ACCOUNT_TYPE"], "project_id": os.environ["GOOGLE_PROJECT_ID"], diff --git a/tests/integration/tools/test_google_drive_client.py b/tests/integration/tools/test_google_drive_tool.py similarity index 94% rename from tests/integration/tools/test_google_drive_client.py rename to tests/integration/tools/test_google_drive_tool.py index 23ebb1b32..7fd8b9047 100644 --- a/tests/integration/tools/test_google_drive_client.py +++ b/tests/integration/tools/test_google_drive_tool.py @@ -5,7 +5,7 @@ from tests.utils.structure_tester import StructureTester -class TestGoogleDriveClient: +class TestGoogleDriveTool: @pytest.fixture( autouse=True, params=StructureTester.TOOLKIT_TASK_CAPABLE_PROMPT_DRIVERS, @@ -13,12 +13,12 @@ class TestGoogleDriveClient: ) def structure_tester(self, request): from griptape.structures import Agent - from griptape.tools import GoogleDriveClient + from griptape.tools import GoogleDriveTool return StructureTester( Agent( tools=[ - GoogleDriveClient( + GoogleDriveTool( service_account_credentials={ "type": os.environ["GOOGLE_ACCOUNT_TYPE"], "project_id": os.environ["GOOGLE_PROJECT_ID"], diff --git a/tests/unit/tools/test_aws_iam.py b/tests/unit/tools/test_aws_iam_tool.py similarity index 56% rename from tests/unit/tools/test_aws_iam.py rename to tests/unit/tools/test_aws_iam_tool.py index 54dbaa5fb..fb2b1e381 100644 --- a/tests/unit/tools/test_aws_iam.py +++ b/tests/unit/tools/test_aws_iam_tool.py @@ -1,11 +1,11 @@ import boto3 import pytest -from griptape.tools import AwsIamClient +from griptape.tools import AwsIamTool from tests.utils.aws import mock_aws_credentials -class TestAwsIamClient: +class TestAwsIamTool: @pytest.fixture(autouse=True) def _run_before_and_after_tests(self): mock_aws_credentials() @@ -14,18 +14,18 @@ def test_get_user_policy(self): value = {"user_name": "test_user", "policy_name": "test_policy"} assert ( "error returning policy document" - in AwsIamClient(session=boto3.Session()).get_user_policy({"values": value}).value + in AwsIamTool(session=boto3.Session()).get_user_policy({"values": value}).value ) def test_list_mfa_devices(self): - assert "error listing mfa devices" in AwsIamClient(session=boto3.Session()).list_mfa_devices({}).value + assert "error listing mfa devices" in AwsIamTool(session=boto3.Session()).list_mfa_devices({}).value def test_list_user_policies(self): value = {"user_name": "test_user"} assert ( "error listing iam user policies" - in AwsIamClient(session=boto3.Session()).list_user_policies({"values": value}).value + in AwsIamTool(session=boto3.Session()).list_user_policies({"values": value}).value ) def test_list_users(self): - assert "error listing s3 users" in AwsIamClient(session=boto3.Session()).list_users({}).value + assert "error listing s3 users" in AwsIamTool(session=boto3.Session()).list_users({}).value diff --git a/tests/unit/tools/test_aws_s3.py b/tests/unit/tools/test_aws_s3_tool.py similarity index 58% rename from tests/unit/tools/test_aws_s3.py rename to tests/unit/tools/test_aws_s3_tool.py index 5c6a4c151..9c4c34e0b 100644 --- a/tests/unit/tools/test_aws_s3.py +++ b/tests/unit/tools/test_aws_s3_tool.py @@ -1,42 +1,38 @@ import boto3 import pytest -from griptape.tools import AwsS3Client +from griptape.tools import AwsS3Tool from tests.utils.aws import mock_aws_credentials -class TestAwsS3Client: +class TestAwsS3Tool: @pytest.fixture(autouse=True) def _run_before_and_after_tests(self): mock_aws_credentials() def test_get_bucket_acl(self): value = {"bucket_name": "bucket_test"} - assert ( - "error getting bucket acl" in AwsS3Client(session=boto3.Session()).get_bucket_acl({"values": value}).value - ) + assert "error getting bucket acl" in AwsS3Tool(session=boto3.Session()).get_bucket_acl({"values": value}).value def test_get_bucket_policy(self): value = {"bucket_name": "bucket_test"} assert ( "error getting bucket policy" - in AwsS3Client(session=boto3.Session()).get_bucket_policy({"values": value}).value + in AwsS3Tool(session=boto3.Session()).get_bucket_policy({"values": value}).value ) def test_get_object_acl(self): value = {"bucket_name": "bucket_test", "object_key": "key_test"} - assert ( - "error getting object acl" in AwsS3Client(session=boto3.Session()).get_object_acl({"values": value}).value - ) + assert "error getting object acl" in AwsS3Tool(session=boto3.Session()).get_object_acl({"values": value}).value def test_list_s3_buckets(self): - assert "error listing s3 buckets" in AwsS3Client(session=boto3.Session()).list_s3_buckets({}).value + assert "error listing s3 buckets" in AwsS3Tool(session=boto3.Session()).list_s3_buckets({}).value def test_list_objects(self): value = {"bucket_name": "bucket_test"} assert ( "error listing objects in bucket" - in AwsS3Client(session=boto3.Session()).list_objects({"values": value}).value + in AwsS3Tool(session=boto3.Session()).list_objects({"values": value}).value ) def test_upload_memory_artifacts_to_s3(self): @@ -48,7 +44,7 @@ def test_upload_memory_artifacts_to_s3(self): } assert ( "memory not found" - in AwsS3Client(session=boto3.Session()).upload_memory_artifacts_to_s3({"values": value}).value + in AwsS3Tool(session=boto3.Session()).upload_memory_artifacts_to_s3({"values": value}).value ) def test_upload_content_to_s3(self): @@ -56,13 +52,12 @@ def test_upload_content_to_s3(self): assert ( "error uploading objects" - in AwsS3Client(session=boto3.Session()).upload_content_to_s3({"values": value}).value + in AwsS3Tool(session=boto3.Session()).upload_content_to_s3({"values": value}).value ) def test_download_objects(self): value = {"objects": {"bucket_name": "bucket_test", "object_key": "test.txt"}} assert ( - "error downloading objects" - in AwsS3Client(session=boto3.Session()).download_objects({"values": value}).value + "error downloading objects" in AwsS3Tool(session=boto3.Session()).download_objects({"values": value}).value ) diff --git a/tests/unit/tools/test_calculator.py b/tests/unit/tools/test_calculator.py index 72a525210..e598867f9 100644 --- a/tests/unit/tools/test_calculator.py +++ b/tests/unit/tools/test_calculator.py @@ -1,6 +1,6 @@ -from griptape.tools import Calculator +from griptape.tools import CalculatorTool class TestCalculator: def test_calculate(self): - assert Calculator().calculate({"values": {"expression": "5 * 5"}}).value == "25" + assert CalculatorTool().calculate({"values": {"expression": "5 * 5"}}).value == "25" diff --git a/tests/unit/tools/test_computer.py b/tests/unit/tools/test_computer.py index 95de18ae3..1f6e5c7a6 100644 --- a/tests/unit/tools/test_computer.py +++ b/tests/unit/tools/test_computer.py @@ -1,13 +1,13 @@ import pytest -from griptape.tools import Computer +from griptape.tools import ComputerTool from tests.mocks.docker.fake_api_client import make_fake_client class TestComputer: @pytest.fixture() def computer(self): - return Computer(docker_client=make_fake_client(), install_dependencies_on_init=False) + return ComputerTool(docker_client=make_fake_client(), install_dependencies_on_init=False) def test_execute_code(self, computer): assert computer.execute_code({"values": {"code": "print(1)", "filename": "foo.py"}}).value == "hello world" diff --git a/tests/unit/tools/test_date_time.py b/tests/unit/tools/test_date_time.py index c534ae69b..9fa2ce4bb 100644 --- a/tests/unit/tools/test_date_time.py +++ b/tests/unit/tools/test_date_time.py @@ -1,28 +1,28 @@ from datetime import datetime -from griptape.tools import DateTime +from griptape.tools import DateTimeTool class TestDateTime: def test_get_current_datetime(self): - result = DateTime().get_current_datetime({}) + result = DateTimeTool().get_current_datetime({}) time_delta = datetime.strptime(result.value, "%Y-%m-%d %H:%M:%S.%f") - datetime.now() assert abs(time_delta.total_seconds()) <= 1000 def test_get_past_relative_datetime(self): - result = DateTime().get_relative_datetime({"values": {"relative_date_string": "5 min ago"}}) + result = DateTimeTool().get_relative_datetime({"values": {"relative_date_string": "5 min ago"}}) time_delta = datetime.strptime(result.value, "%Y-%m-%d %H:%M:%S.%f") - datetime.now() assert abs(time_delta.total_seconds()) <= 1000 - result = DateTime().get_relative_datetime({"values": {"relative_date_string": "2 min ago, 12 seconds"}}) + result = DateTimeTool().get_relative_datetime({"values": {"relative_date_string": "2 min ago, 12 seconds"}}) time_delta = datetime.strptime(result.value, "%Y-%m-%d %H:%M:%S.%f") - datetime.now() assert abs(time_delta.total_seconds()) <= 1000 def test_get_future_relative_datetime(self): - result = DateTime().get_relative_datetime({"values": {"relative_date_string": "in 1 min, 36 seconds"}}) + result = DateTimeTool().get_relative_datetime({"values": {"relative_date_string": "in 1 min, 36 seconds"}}) time_delta = datetime.strptime(result.value, "%Y-%m-%d %H:%M:%S.%f") - datetime.now() assert abs(time_delta.total_seconds()) <= 1000 def test_get_invalid_relative_datetime(self): - result = DateTime().get_relative_datetime({"values": {"relative_date_string": "3 days from now"}}) + result = DateTimeTool().get_relative_datetime({"values": {"relative_date_string": "3 days from now"}}) assert result.type == "ErrorArtifact" diff --git a/tests/unit/tools/test_email_client.py b/tests/unit/tools/test_email_tool.py similarity index 93% rename from tests/unit/tools/test_email_client.py rename to tests/unit/tools/test_email_tool.py index cf99009b8..6c0f7cbd7 100644 --- a/tests/unit/tools/test_email_client.py +++ b/tests/unit/tools/test_email_tool.py @@ -2,14 +2,14 @@ from griptape.artifacts import ErrorArtifact, InfoArtifact, ListArtifact, TextArtifact from griptape.loaders.email_loader import EmailLoader -from griptape.tools import EmailClient +from griptape.tools import EmailTool -class TestEmailClient: +class TestEmailTool: @pytest.fixture(autouse=True) def mock_email_loader(self, mocker): mock_email_loader = mocker.patch( - "griptape.tools.email_client.tool.EmailLoader", + "griptape.tools.email.tool.EmailLoader", EmailQuery=EmailLoader.EmailQuery, # Prevents mocking the nested EmailQuery class ).return_value mock_email_loader.load.return_value = ListArtifact([TextArtifact("fake-email-content")]) @@ -29,7 +29,7 @@ def mock_smtp_ssl(self, mocker): @pytest.fixture() def client(self): - return EmailClient( + return EmailTool( username="fake-username", password="fake-password", smtp_host="foobar.com", @@ -63,7 +63,7 @@ def test_retrieve(self, client, mock_email_loader, values, query): def test_retrieve_when_email_max_retrieve_count_set(self, mock_email_loader): # Given - client = EmailClient(email_max_retrieve_count=84, mailboxes={"INBOX": "default mailbox for incoming email"}) + client = EmailTool(email_max_retrieve_count=84, mailboxes={"INBOX": "default mailbox for incoming email"}) # When client.retrieve({"values": {"label": "fake-label"}}) @@ -91,7 +91,7 @@ def test_send(self, client, send_params): def test_send_when_smtp_overrides_set(self, send_params): # Given - client = EmailClient( + client = EmailTool( smtp_host="smtp-host", smtp_port=86, smtp_use_ssl=False, diff --git a/tests/unit/tools/test_file_manager.py b/tests/unit/tools/test_file_manager.py index 57dd2c83e..dccf2f1a2 100644 --- a/tests/unit/tools/test_file_manager.py +++ b/tests/unit/tools/test_file_manager.py @@ -9,14 +9,14 @@ from griptape.artifacts.error_artifact import ErrorArtifact from griptape.drivers.file_manager.local_file_manager_driver import LocalFileManagerDriver from griptape.loaders.text_loader import TextLoader -from griptape.tools import FileManager +from griptape.tools import FileManagerTool from tests.utils import defaults class TestFileManager: @pytest.fixture() def file_manager(self): - return FileManager( + return FileManagerTool( input_memory=[defaults.text_task_memory("Memory1")], file_manager_driver=LocalFileManagerDriver(workdir=os.path.abspath(os.path.dirname(__file__))), ) @@ -47,7 +47,7 @@ def test_load_files_from_disk_with_encoding(self, file_manager): assert isinstance(result.value[0], TextArtifact) def test_load_files_from_disk_with_encoding_failure(self): - file_manager = FileManager( + file_manager = FileManagerTool( file_manager_driver=LocalFileManagerDriver( default_loader=TextLoader(encoding="utf-8"), loaders={}, @@ -65,7 +65,9 @@ def test_save_memory_artifacts_to_disk_for_one_artifact(self, temp_dir): memory.store_artifact("foobar", artifact) - file_manager = FileManager(input_memory=[memory], file_manager_driver=LocalFileManagerDriver(workdir=temp_dir)) + file_manager = FileManagerTool( + input_memory=[memory], file_manager_driver=LocalFileManagerDriver(workdir=temp_dir) + ) result = file_manager.save_memory_artifacts_to_disk( { "values": { @@ -88,7 +90,9 @@ def test_save_memory_artifacts_to_disk_for_multiple_artifacts(self, temp_dir): for a in artifacts: memory.store_artifact("foobar", a) - file_manager = FileManager(input_memory=[memory], file_manager_driver=LocalFileManagerDriver(workdir=temp_dir)) + file_manager = FileManagerTool( + input_memory=[memory], file_manager_driver=LocalFileManagerDriver(workdir=temp_dir) + ) result = file_manager.save_memory_artifacts_to_disk( { "values": { @@ -105,7 +109,7 @@ def test_save_memory_artifacts_to_disk_for_multiple_artifacts(self, temp_dir): assert result.value == "Successfully saved memory artifacts to disk" def test_save_content_to_file(self, temp_dir): - file_manager = FileManager(file_manager_driver=LocalFileManagerDriver(workdir=temp_dir)) + file_manager = FileManagerTool(file_manager_driver=LocalFileManagerDriver(workdir=temp_dir)) result = file_manager.save_content_to_file( {"values": {"path": os.path.join("test", "foobar.txt"), "content": "foobar"}} ) @@ -114,7 +118,7 @@ def test_save_content_to_file(self, temp_dir): assert result.value == "Successfully saved file" def test_save_content_to_file_with_encoding(self, temp_dir): - file_manager = FileManager( + file_manager = FileManagerTool( file_manager_driver=LocalFileManagerDriver(default_loader=TextLoader(encoding="utf-8"), workdir=temp_dir) ) result = file_manager.save_content_to_file( @@ -125,7 +129,7 @@ def test_save_content_to_file_with_encoding(self, temp_dir): assert result.value == "Successfully saved file" def test_save_and_load_content_to_file_with_encoding(self, temp_dir): - file_manager = FileManager( + file_manager = FileManagerTool( file_manager_driver=LocalFileManagerDriver(loaders={"txt": TextLoader(encoding="ascii")}, workdir=temp_dir) ) result = file_manager.save_content_to_file( @@ -135,7 +139,7 @@ def test_save_and_load_content_to_file_with_encoding(self, temp_dir): assert Path(os.path.join(temp_dir, "test", "foobar.txt")).read_text() == "foobar" assert result.value == "Successfully saved file" - file_manager = FileManager( + file_manager = FileManagerTool( file_manager_driver=LocalFileManagerDriver( default_loader=TextLoader(encoding="ascii"), loaders={}, workdir=temp_dir ) diff --git a/tests/unit/tools/test_google_docs_client.py b/tests/unit/tools/test_google_docs_tool.py similarity index 84% rename from tests/unit/tools/test_google_docs_client.py rename to tests/unit/tools/test_google_docs_tool.py index a42fddda3..516961c61 100644 --- a/tests/unit/tools/test_google_docs_client.py +++ b/tests/unit/tools/test_google_docs_tool.py @@ -1,12 +1,12 @@ import pytest -class TestGoogleDocsClient: +class TestGoogleDocsTool: @pytest.fixture() def mock_docs_client(self): - from griptape.tools import GoogleDocsClient + from griptape.tools import GoogleDocsTool - return GoogleDocsClient(owner_email="tony@griptape.ai", service_account_credentials={}) + return GoogleDocsTool(owner_email="tony@griptape.ai", service_account_credentials={}) def test_append_text(self, mock_docs_client): params = {"file_path": "test_folder/test_document", "text": "Appending this text"} diff --git a/tests/unit/tools/test_google_drive_client.py b/tests/unit/tools/test_google_drive_tool.py similarity index 68% rename from tests/unit/tools/test_google_drive_client.py rename to tests/unit/tools/test_google_drive_tool.py index 55f3c168f..55eae2267 100644 --- a/tests/unit/tools/test_google_drive_client.py +++ b/tests/unit/tools/test_google_drive_tool.py @@ -1,11 +1,11 @@ from griptape.artifacts import ErrorArtifact -from griptape.tools import GoogleDriveClient +from griptape.tools import GoogleDriveTool -class TestGoogleDriveClient: +class TestGoogleDriveTool: def test_list_files(self): value = {"folder_path": "root"} # This can be any folder path you want to test - result = GoogleDriveClient(owner_email="tony@griptape.ai", service_account_credentials={}).list_files( + result = GoogleDriveTool(owner_email="tony@griptape.ai", service_account_credentials={}).list_files( {"values": value} ) @@ -14,16 +14,16 @@ def test_list_files(self): def test_save_content_to_drive(self): value = {"path": "/path/to/your/file.txt", "content": "Sample content for the file."} - result = GoogleDriveClient( - owner_email="tony@griptape.ai", service_account_credentials={} - ).save_content_to_drive({"values": value}) + result = GoogleDriveTool(owner_email="tony@griptape.ai", service_account_credentials={}).save_content_to_drive( + {"values": value} + ) assert isinstance(result, ErrorArtifact) assert "error saving file to Google Drive" in result.value def test_download_files(self): value = {"file_paths": ["example_folder/example_file.txt"]} - result = GoogleDriveClient(owner_email="tony@griptape.ai", service_account_credentials={}).download_files( + result = GoogleDriveTool(owner_email="tony@griptape.ai", service_account_credentials={}).download_files( {"values": value} ) @@ -33,7 +33,7 @@ def test_download_files(self): def test_search_files(self): value = {"search_mode": "name", "file_name": "search_file_name.txt"} - result = GoogleDriveClient(owner_email="tony@griptape.ai", service_account_credentials={}).search_files( + result = GoogleDriveTool(owner_email="tony@griptape.ai", service_account_credentials={}).search_files( {"values": value} ) @@ -43,7 +43,7 @@ def test_search_files(self): def test_share_file(self): value = {"file_path": "/path/to/your/file.txt", "email_address": "sample_email@example.com", "role": "reader"} - result = GoogleDriveClient(owner_email="tony@griptape.ai", service_account_credentials={}).share_file( + result = GoogleDriveTool(owner_email="tony@griptape.ai", service_account_credentials={}).share_file( {"values": value} ) diff --git a/tests/unit/tools/test_google_gmail_client.py b/tests/unit/tools/test_google_gmail_tool.py similarity index 61% rename from tests/unit/tools/test_google_gmail_client.py rename to tests/unit/tools/test_google_gmail_tool.py index 7dcf1de38..ace7ef0ba 100644 --- a/tests/unit/tools/test_google_gmail_client.py +++ b/tests/unit/tools/test_google_gmail_tool.py @@ -1,12 +1,12 @@ -from griptape.tools import GoogleGmailClient +from griptape.tools import GoogleGmailTool -class TestGoogleGmailClient: +class TestGoogleGmailTool: def test_create_draft_email(self): value = {"subject": "stacey's mom", "from": "test@test.com", "body": "got it going on"} assert ( "error creating draft email" - in GoogleGmailClient(service_account_credentials={}, owner_email="tony@griptape.ai") + in GoogleGmailTool(service_account_credentials={}, owner_email="tony@griptape.ai") .create_draft_email({"values": value}) .value ) diff --git a/tests/unit/tools/test_griptape_cloud_knowledge_base_client.py b/tests/unit/tools/test_griptape_cloud_knowledge_base_tool.py similarity index 83% rename from tests/unit/tools/test_griptape_cloud_knowledge_base_client.py rename to tests/unit/tools/test_griptape_cloud_knowledge_base_tool.py index 7d75d8670..b98713273 100644 --- a/tests/unit/tools/test_griptape_cloud_knowledge_base_client.py +++ b/tests/unit/tools/test_griptape_cloud_knowledge_base_tool.py @@ -4,10 +4,10 @@ from griptape.artifacts import ErrorArtifact, TextArtifact -class TestGriptapeCloudKnowledgeBaseClient: +class TestGriptapeCloudKnowledgeBaseTool: @pytest.fixture() def client(self, mocker): - from griptape.tools import GriptapeCloudKnowledgeBaseClient + from griptape.tools import GriptapeCloudKnowledgeBaseTool mock_response = mocker.Mock() mock_response.status_code = 201 @@ -19,45 +19,45 @@ def client(self, mocker): mock_response.json.return_value = {"description": "fizz buzz"} mocker.patch("requests.get", return_value=mock_response) - return GriptapeCloudKnowledgeBaseClient( + return GriptapeCloudKnowledgeBaseTool( base_url="https://api.griptape.ai", api_key="foo bar", knowledge_base_id="1" ) @pytest.fixture() def client_no_description(self, mocker): - from griptape.tools import GriptapeCloudKnowledgeBaseClient + from griptape.tools import GriptapeCloudKnowledgeBaseTool mock_response = mocker.Mock() mock_response.json.return_value = {} mock_response.status_code = 200 mocker.patch("requests.get", return_value=mock_response) - return GriptapeCloudKnowledgeBaseClient( + return GriptapeCloudKnowledgeBaseTool( base_url="https://api.griptape.ai", api_key="foo bar", knowledge_base_id="1" ) @pytest.fixture() def client_kb_not_found(self, mocker): - from griptape.tools import GriptapeCloudKnowledgeBaseClient + from griptape.tools import GriptapeCloudKnowledgeBaseTool mock_response = mocker.Mock() mock_response.json.return_value = {} mock_response.status_code = 404 mocker.patch("requests.get", return_value=mock_response) - return GriptapeCloudKnowledgeBaseClient( + return GriptapeCloudKnowledgeBaseTool( base_url="https://api.griptape.ai", api_key="foo bar", knowledge_base_id="1" ) @pytest.fixture() def client_kb_error(self, mocker): - from griptape.tools import GriptapeCloudKnowledgeBaseClient + from griptape.tools import GriptapeCloudKnowledgeBaseTool mock_response = mocker.Mock() mock_response.status_code = 500 mocker.patch("requests.post", return_value=mock_response, side_effect=exceptions.RequestException("error")) - return GriptapeCloudKnowledgeBaseClient( + return GriptapeCloudKnowledgeBaseTool( base_url="https://api.griptape.ai", api_key="foo bar", knowledge_base_id="1" ) @@ -75,7 +75,7 @@ def test_get_knowledge_base_description(self, client): assert client._get_knowledge_base_description() == "foo bar" def test_get_knowledge_base_description_error(self, client_no_description): - exception_match_text = f"No description found for Knowledge Base {client_no_description.knowledge_base_id}. Please set a description, or manually set the `GriptapeCloudKnowledgeBaseClient.description` attribute." + exception_match_text = f"No description found for Knowledge Base {client_no_description.knowledge_base_id}. Please set a description, or manually set the `GriptapeCloudKnowledgeBaseTool.description` attribute." with pytest.raises(ValueError, match=exception_match_text): client_no_description._get_knowledge_base_description() diff --git a/tests/unit/tools/test_inpainting_image_generation_client.py b/tests/unit/tools/test_inpainting_image_generation_tool.py similarity index 87% rename from tests/unit/tools/test_inpainting_image_generation_client.py rename to tests/unit/tools/test_inpainting_image_generation_tool.py index 0c5e49f9a..45afcbc63 100644 --- a/tests/unit/tools/test_inpainting_image_generation_client.py +++ b/tests/unit/tools/test_inpainting_image_generation_tool.py @@ -6,10 +6,10 @@ import pytest from griptape.artifacts import ImageArtifact -from griptape.tools import InpaintingImageGenerationClient +from griptape.tools import InpaintingImageGenerationTool -class TestInpaintingImageGenerationClient: +class TestInpaintingImageGenerationTool: @pytest.fixture() def image_artifact(self) -> ImageArtifact: return ImageArtifact(value=b"image_data", format="png", width=512, height=512, name="name") @@ -26,12 +26,12 @@ def image_loader(self) -> Mock: return loader @pytest.fixture() - def image_generator(self, image_generation_engine, image_loader) -> InpaintingImageGenerationClient: - return InpaintingImageGenerationClient(engine=image_generation_engine, image_loader=image_loader) + def image_generator(self, image_generation_engine, image_loader) -> InpaintingImageGenerationTool: + return InpaintingImageGenerationTool(engine=image_generation_engine, image_loader=image_loader) def test_validate_output_configs(self, image_generation_engine) -> None: with pytest.raises(ValueError): - InpaintingImageGenerationClient(engine=image_generation_engine, output_dir="test", output_file="test") + InpaintingImageGenerationTool(engine=image_generation_engine, output_dir="test", output_file="test") def test_image_inpainting(self, image_generator, path_from_resource_path) -> None: image_generator.engine.run.return_value = Mock( @@ -55,7 +55,7 @@ def test_image_inpainting_with_outfile( self, image_generation_engine, image_loader, path_from_resource_path ) -> None: outfile = f"{tempfile.gettempdir()}/{str(uuid.uuid4())}.png" - image_generator = InpaintingImageGenerationClient( + image_generator = InpaintingImageGenerationTool( engine=image_generation_engine, output_file=outfile, image_loader=image_loader ) @@ -78,7 +78,7 @@ def test_image_inpainting_with_outfile( assert os.path.exists(outfile) def test_image_inpainting_from_memory(self, image_generation_engine, image_artifact): - image_generator = InpaintingImageGenerationClient(engine=image_generation_engine) + image_generator = InpaintingImageGenerationTool(engine=image_generation_engine) memory = Mock() memory.load_artifacts = Mock(return_value=[image_artifact]) image_generator.find_input_memory = Mock(return_value=memory) diff --git a/tests/unit/tools/test_openweather_client.py b/tests/unit/tools/test_openweather_tool.py similarity index 92% rename from tests/unit/tools/test_openweather_client.py rename to tests/unit/tools/test_openweather_tool.py index 89b80e164..44acaf571 100644 --- a/tests/unit/tools/test_openweather_client.py +++ b/tests/unit/tools/test_openweather_tool.py @@ -3,12 +3,12 @@ import pytest from griptape.artifacts import ErrorArtifact -from griptape.tools import OpenWeatherClient +from griptape.tools import OpenWeatherTool @pytest.fixture() def client(): - return OpenWeatherClient(api_key="YOUR_API_KEY") + return OpenWeatherTool(api_key="YOUR_API_KEY") class MockResponse: @@ -21,9 +21,9 @@ def json(self): def mock_requests_get(*args, **kwargs): - if args[0] == OpenWeatherClient.GEOCODING_URL: + if args[0] == OpenWeatherTool.GEOCODING_URL: return MockResponse([{"lat": 40.7128, "lon": -74.0061}], 200) - elif args[0] == OpenWeatherClient.BASE_URL: + elif args[0] == OpenWeatherTool.BASE_URL: return MockResponse({"weather": "sunny"}, 200) return MockResponse(None, 404) diff --git a/tests/unit/tools/test_outpainting_image_variation_client.py b/tests/unit/tools/test_outpainting_image_variation_tool.py similarity index 87% rename from tests/unit/tools/test_outpainting_image_variation_client.py rename to tests/unit/tools/test_outpainting_image_variation_tool.py index 13d8df082..4fbcbe8d4 100644 --- a/tests/unit/tools/test_outpainting_image_variation_client.py +++ b/tests/unit/tools/test_outpainting_image_variation_tool.py @@ -6,10 +6,10 @@ import pytest from griptape.artifacts import ImageArtifact -from griptape.tools import OutpaintingImageGenerationClient +from griptape.tools import OutpaintingImageGenerationTool -class TestOutpaintingImageGenerationClient: +class TestOutpaintingImageGenerationTool: @pytest.fixture() def image_artifact(self) -> ImageArtifact: return ImageArtifact(value=b"image_data", format="png", width=512, height=512, name="name") @@ -26,12 +26,12 @@ def image_loader(self, image_artifact) -> Mock: return loader @pytest.fixture() - def image_generator(self, image_generation_engine, image_loader) -> OutpaintingImageGenerationClient: - return OutpaintingImageGenerationClient(engine=image_generation_engine, image_loader=image_loader) + def image_generator(self, image_generation_engine, image_loader) -> OutpaintingImageGenerationTool: + return OutpaintingImageGenerationTool(engine=image_generation_engine, image_loader=image_loader) def test_validate_output_configs(self, image_generation_engine) -> None: with pytest.raises(ValueError): - OutpaintingImageGenerationClient(engine=image_generation_engine, output_dir="test", output_file="test") + OutpaintingImageGenerationTool(engine=image_generation_engine, output_dir="test", output_file="test") def test_image_outpainting(self, image_generator, path_from_resource_path) -> None: image_generator.engine.run.return_value = Mock( @@ -55,7 +55,7 @@ def test_image_outpainting_with_outfile( self, image_generation_engine, image_loader, path_from_resource_path ) -> None: outfile = f"{tempfile.gettempdir()}/{str(uuid.uuid4())}.png" - image_generator = OutpaintingImageGenerationClient( + image_generator = OutpaintingImageGenerationTool( engine=image_generation_engine, output_file=outfile, image_loader=image_loader ) @@ -78,7 +78,7 @@ def test_image_outpainting_with_outfile( assert os.path.exists(outfile) def test_image_outpainting_from_memory(self, image_generation_engine, image_artifact): - image_generator = OutpaintingImageGenerationClient(engine=image_generation_engine) + image_generator = OutpaintingImageGenerationTool(engine=image_generation_engine) memory = Mock() memory.load_artifacts = Mock(return_value=[image_artifact]) image_generator.find_input_memory = Mock(return_value=memory) diff --git a/tests/unit/tools/test_prompt_image_generation_client.py b/tests/unit/tools/test_prompt_image_generation_tool.py similarity index 77% rename from tests/unit/tools/test_prompt_image_generation_client.py rename to tests/unit/tools/test_prompt_image_generation_tool.py index 276e33473..a0c5c7037 100644 --- a/tests/unit/tools/test_prompt_image_generation_client.py +++ b/tests/unit/tools/test_prompt_image_generation_tool.py @@ -5,21 +5,21 @@ import pytest -from griptape.tools import PromptImageGenerationClient +from griptape.tools import PromptImageGenerationTool -class TestPromptImageGenerationClient: +class TestPromptImageGenerationTool: @pytest.fixture() def image_generation_engine(self) -> Mock: return Mock() @pytest.fixture() - def image_generator(self, image_generation_engine) -> PromptImageGenerationClient: - return PromptImageGenerationClient(engine=image_generation_engine) + def image_generator(self, image_generation_engine) -> PromptImageGenerationTool: + return PromptImageGenerationTool(engine=image_generation_engine) def test_validate_output_configs(self, image_generation_engine) -> None: with pytest.raises(ValueError): - PromptImageGenerationClient(engine=image_generation_engine, output_dir="test", output_file="test") + PromptImageGenerationTool(engine=image_generation_engine, output_dir="test", output_file="test") def test_generate_image(self, image_generator) -> None: image_generator.engine.run.return_value = Mock( @@ -34,7 +34,7 @@ def test_generate_image(self, image_generator) -> None: def test_generate_image_with_outfile(self, image_generation_engine) -> None: outfile = f"{tempfile.gettempdir()}/{str(uuid.uuid4())}.png" - image_generator = PromptImageGenerationClient(engine=image_generation_engine, output_file=outfile) + image_generator = PromptImageGenerationTool(engine=image_generation_engine, output_file=outfile) image_generator.engine.run.return_value = Mock( # pyright: ignore[reportFunctionMemberAccess] value=b"image data", format="png", width=512, height=512, model="test model", prompt="test prompt" diff --git a/tests/unit/tools/test_rag_client.py b/tests/unit/tools/test_rag_tool.py similarity index 72% rename from tests/unit/tools/test_rag_client.py rename to tests/unit/tools/test_rag_tool.py index 60a0df722..eb1c00e4c 100644 --- a/tests/unit/tools/test_rag_client.py +++ b/tests/unit/tools/test_rag_tool.py @@ -1,13 +1,13 @@ from griptape.drivers import LocalVectorStoreDriver -from griptape.tools import RagClient +from griptape.tools import RagTool from tests.mocks.mock_embedding_driver import MockEmbeddingDriver from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.utils.defaults import rag_engine -class TestRagClient: +class TestRagTool: def test_search(self): vector_store_driver = LocalVectorStoreDriver(embedding_driver=MockEmbeddingDriver()) - tool = RagClient(description="Test", rag_engine=rag_engine(MockPromptDriver(), vector_store_driver)) + tool = RagTool(description="Test", rag_engine=rag_engine(MockPromptDriver(), vector_store_driver)) assert tool.search({"values": {"query": "test"}}).value[0].value == "mock output" diff --git a/tests/unit/tools/test_rest_api_client.py b/tests/unit/tools/test_rest_api_tool.py similarity index 92% rename from tests/unit/tools/test_rest_api_client.py rename to tests/unit/tools/test_rest_api_tool.py index 58f21d1f1..70d63478e 100644 --- a/tests/unit/tools/test_rest_api_client.py +++ b/tests/unit/tools/test_rest_api_tool.py @@ -6,9 +6,9 @@ class TestRestApi: @pytest.fixture() def client(self): - from griptape.tools import RestApiClient + from griptape.tools import RestApiTool - return RestApiClient(base_url="http://www.griptape.ai", description="Griptape website.") + return RestApiTool(base_url="http://www.griptape.ai", description="Griptape website.") def test_put(self, client): assert isinstance(client.post({"values": {"body": {}}}), BaseArtifact) diff --git a/tests/unit/tools/test_sql_client.py b/tests/unit/tools/test_sql_tool.py similarity index 87% rename from tests/unit/tools/test_sql_client.py rename to tests/unit/tools/test_sql_tool.py index 8ab61fc8f..2ef50ff54 100644 --- a/tests/unit/tools/test_sql_client.py +++ b/tests/unit/tools/test_sql_tool.py @@ -4,10 +4,10 @@ from griptape.drivers import SqlDriver from griptape.loaders import SqlLoader -from griptape.tools import SqlClient +from griptape.tools import SqlTool -class TestSqlClient: +class TestSqlTool: @pytest.fixture() def driver(self): new_driver = SqlDriver(engine_url="sqlite:///:memory:") @@ -22,14 +22,14 @@ def driver(self): def test_execute_query(self, driver): with sqlite3.connect(":memory:"): - client = SqlClient(sql_loader=SqlLoader(sql_driver=driver), table_name="test_table", engine_name="sqlite") + client = SqlTool(sql_loader=SqlLoader(sql_driver=driver), table_name="test_table", engine_name="sqlite") result = client.execute_query({"values": {"sql_query": "SELECT * from test_table;"}}) assert len(result.value) == 1 assert result.value[0].value == {"id": 1, "name": "Alice", "age": 25, "city": "New York"} def test_execute_query_description(self, driver): - client = SqlClient( + client = SqlTool( sql_loader=SqlLoader(sql_driver=driver), table_name="test_table", table_description="foobar", diff --git a/tests/unit/tools/test_structure_run_client.py b/tests/unit/tools/test_structure_run_tool.py similarity index 82% rename from tests/unit/tools/test_structure_run_client.py rename to tests/unit/tools/test_structure_run_tool.py index ee76d4da1..f62cdeea7 100644 --- a/tests/unit/tools/test_structure_run_client.py +++ b/tests/unit/tools/test_structure_run_tool.py @@ -2,15 +2,15 @@ from griptape.drivers.structure_run.local_structure_run_driver import LocalStructureRunDriver from griptape.structures import Agent -from griptape.tools import StructureRunClient +from griptape.tools import StructureRunTool -class TestStructureRunClient: +class TestStructureRunTool: @pytest.fixture() def client(self): agent = Agent() - return StructureRunClient( + return StructureRunTool( description="foo bar", driver=LocalStructureRunDriver(structure_factory_fn=lambda: agent) ) diff --git a/tests/unit/tools/test_task_memory_client.py b/tests/unit/tools/test_task_memory_tool.py similarity index 80% rename from tests/unit/tools/test_task_memory_client.py rename to tests/unit/tools/test_task_memory_tool.py index 4276b89ec..238001d55 100644 --- a/tests/unit/tools/test_task_memory_client.py +++ b/tests/unit/tools/test_task_memory_tool.py @@ -1,14 +1,14 @@ import pytest from griptape.artifacts import TextArtifact -from griptape.tools import TaskMemoryClient +from griptape.tools import TaskMemoryTool from tests.utils import defaults -class TestTaskMemoryClient: +class TestTaskMemoryTool: @pytest.fixture() def tool(self): - return TaskMemoryClient(off_prompt=True, input_memory=[defaults.text_task_memory("TestMemory")]) + return TaskMemoryTool(off_prompt=True, input_memory=[defaults.text_task_memory("TestMemory")]) def test_summarize(self, tool): tool.input_memory[0].store_artifact("foo", TextArtifact("test")) diff --git a/tests/unit/tools/test_text_to_speech_client.py b/tests/unit/tools/test_text_to_speech_tool.py similarity index 74% rename from tests/unit/tools/test_text_to_speech_client.py rename to tests/unit/tools/test_text_to_speech_tool.py index 0b9061aa6..8821d48fc 100644 --- a/tests/unit/tools/test_text_to_speech_client.py +++ b/tests/unit/tools/test_text_to_speech_tool.py @@ -5,21 +5,21 @@ import pytest -from griptape.tools.text_to_speech_client.tool import TextToSpeechClient +from griptape.tools.text_to_speech.tool import TextToSpeechTool -class TestTextToSpeechClient: +class TestTextToSpeechTool: @pytest.fixture() def text_to_speech_engine(self) -> Mock: return Mock() @pytest.fixture() - def text_to_speech_client(self, text_to_speech_engine) -> TextToSpeechClient: - return TextToSpeechClient(engine=text_to_speech_engine) + def text_to_speech_client(self, text_to_speech_engine) -> TextToSpeechTool: + return TextToSpeechTool(engine=text_to_speech_engine) def test_validate_output_configs(self, text_to_speech_engine) -> None: with pytest.raises(ValueError): - TextToSpeechClient(engine=text_to_speech_engine, output_dir="test", output_file="test") + TextToSpeechTool(engine=text_to_speech_engine, output_dir="test", output_file="test") def test_text_to_speech(self, text_to_speech_client) -> None: text_to_speech_client.engine.run.return_value = Mock(value=b"audio data", format="mp3") @@ -30,7 +30,7 @@ def test_text_to_speech(self, text_to_speech_client) -> None: def test_text_to_speech_with_outfile(self, text_to_speech_engine) -> None: outfile = f"{tempfile.gettempdir()}/{str(uuid.uuid4())}.mp3" - text_to_speech_client = TextToSpeechClient(engine=text_to_speech_engine, output_file=outfile) + text_to_speech_client = TextToSpeechTool(engine=text_to_speech_engine, output_file=outfile) text_to_speech_client.engine.run.return_value = Mock(value=b"audio data", format="mp3") # pyright: ignore[reportFunctionMemberAccess] diff --git a/tests/unit/tools/test_transcription_client.py b/tests/unit/tools/test_transcription_tool.py similarity index 83% rename from tests/unit/tools/test_transcription_client.py rename to tests/unit/tools/test_transcription_tool.py index 8b54e891b..07368495f 100644 --- a/tests/unit/tools/test_transcription_client.py +++ b/tests/unit/tools/test_transcription_tool.py @@ -3,10 +3,10 @@ import pytest from griptape.artifacts import AudioArtifact -from griptape.tools.audio_transcription_client.tool import AudioTranscriptionClient +from griptape.tools.audio_transcription.tool import AudioTranscriptionTool -class TestTranscriptionClient: +class TestTranscriptionTool: @pytest.fixture() def transcription_engine(self) -> Mock: return Mock() @@ -27,11 +27,11 @@ def mock_path(self, mocker) -> Mock: return mocker def test_init_transcription_client(self, transcription_engine, audio_loader) -> None: - assert AudioTranscriptionClient(engine=transcription_engine, audio_loader=audio_loader) + assert AudioTranscriptionTool(engine=transcription_engine, audio_loader=audio_loader) @patch("builtins.open", mock_open(read_data=b"audio data")) def test_transcribe_audio_from_disk(self, transcription_engine, audio_loader) -> None: - client = AudioTranscriptionClient(engine=transcription_engine, audio_loader=audio_loader) + client = AudioTranscriptionTool(engine=transcription_engine, audio_loader=audio_loader) client.engine.run.return_value = Mock(value="transcription") # pyright: ignore[reportFunctionMemberAccess] text_artifact = client.transcribe_audio_from_disk(params={"values": {"path": "audio.wav"}}) @@ -40,7 +40,7 @@ def test_transcribe_audio_from_disk(self, transcription_engine, audio_loader) -> assert text_artifact.value == "transcription" def test_transcribe_audio_from_memory(self, transcription_engine, audio_loader) -> None: - client = AudioTranscriptionClient(engine=transcription_engine, audio_loader=audio_loader) + client = AudioTranscriptionTool(engine=transcription_engine, audio_loader=audio_loader) memory = Mock() memory.load_artifacts = Mock(return_value=[AudioArtifact(value=b"audio data", format="wav", name="name")]) client.find_input_memory = Mock(return_value=memory) diff --git a/tests/unit/tools/test_variation_image_generation_client.py b/tests/unit/tools/test_variation_image_generation_tool.py similarity index 88% rename from tests/unit/tools/test_variation_image_generation_client.py rename to tests/unit/tools/test_variation_image_generation_tool.py index 0db454f92..c4528a044 100644 --- a/tests/unit/tools/test_variation_image_generation_client.py +++ b/tests/unit/tools/test_variation_image_generation_tool.py @@ -6,10 +6,10 @@ import pytest from griptape.artifacts import ImageArtifact -from griptape.tools import VariationImageGenerationClient +from griptape.tools import VariationImageGenerationTool -class TestVariationImageGenerationClient: +class TestVariationImageGenerationTool: @pytest.fixture() def image_artifact(self) -> ImageArtifact: return ImageArtifact(value=b"image_data", format="png", width=512, height=512, name="name") @@ -26,12 +26,12 @@ def image_loader(self) -> Mock: return loader @pytest.fixture() - def image_generator(self, image_generation_engine, image_loader) -> VariationImageGenerationClient: - return VariationImageGenerationClient(engine=image_generation_engine, image_loader=image_loader) + def image_generator(self, image_generation_engine, image_loader) -> VariationImageGenerationTool: + return VariationImageGenerationTool(engine=image_generation_engine, image_loader=image_loader) def test_validate_output_configs(self, image_generation_engine, image_loader) -> None: with pytest.raises(ValueError): - VariationImageGenerationClient( + VariationImageGenerationTool( engine=image_generation_engine, output_dir="test", output_file="test", image_loader=image_loader ) @@ -54,7 +54,7 @@ def test_image_variation(self, image_generator, path_from_resource_path) -> None def test_image_variation_with_outfile(self, image_generation_engine, image_loader, path_from_resource_path) -> None: outfile = f"{tempfile.gettempdir()}/{str(uuid.uuid4())}.png" - image_generator = VariationImageGenerationClient( + image_generator = VariationImageGenerationTool( engine=image_generation_engine, output_file=outfile, image_loader=image_loader ) @@ -76,7 +76,7 @@ def test_image_variation_with_outfile(self, image_generation_engine, image_loade assert os.path.exists(outfile) def test_image_variation_from_memory(self, image_generation_engine, image_artifact): - image_generator = VariationImageGenerationClient(engine=image_generation_engine) + image_generator = VariationImageGenerationTool(engine=image_generation_engine) memory = Mock() memory.load_artifacts = Mock(return_value=[image_artifact]) image_generator.find_input_memory = Mock(return_value=memory) diff --git a/tests/unit/tools/test_vector_store_client.py b/tests/unit/tools/test_vector_store_tool.py similarity index 78% rename from tests/unit/tools/test_vector_store_client.py rename to tests/unit/tools/test_vector_store_tool.py index b02dda226..ea52a13ea 100644 --- a/tests/unit/tools/test_vector_store_client.py +++ b/tests/unit/tools/test_vector_store_tool.py @@ -2,18 +2,18 @@ from griptape.artifacts import ListArtifact, TextArtifact from griptape.drivers import LocalVectorStoreDriver -from griptape.tools import VectorStoreClient +from griptape.tools import VectorStoreTool from tests.mocks.mock_embedding_driver import MockEmbeddingDriver -class TestVectorStoreClient: +class TestVectorStoreTool: @pytest.fixture(autouse=True) def _mock_try_run(self, mocker): mocker.patch("griptape.drivers.OpenAiEmbeddingDriver.try_embed_chunk", return_value=[0, 1]) def test_search(self): driver = LocalVectorStoreDriver(embedding_driver=MockEmbeddingDriver()) - tool = VectorStoreClient(description="Test", vector_store_driver=driver) + tool = VectorStoreTool(description="Test", vector_store_driver=driver) driver.upsert_text_artifacts({"test": [TextArtifact("foo"), TextArtifact("bar")]}) @@ -21,8 +21,8 @@ def test_search(self): def test_search_with_namespace(self): driver = LocalVectorStoreDriver(embedding_driver=MockEmbeddingDriver()) - tool1 = VectorStoreClient(description="Test", vector_store_driver=driver, query_params={"namespace": "test"}) - tool2 = VectorStoreClient(description="Test", vector_store_driver=driver, query_params={"namespace": "test2"}) + tool1 = VectorStoreTool(description="Test", vector_store_driver=driver, query_params={"namespace": "test"}) + tool2 = VectorStoreTool(description="Test", vector_store_driver=driver, query_params={"namespace": "test2"}) driver.upsert_text_artifacts({"test": [TextArtifact("foo"), TextArtifact("bar")]}) @@ -31,7 +31,7 @@ def test_search_with_namespace(self): def test_custom_process_query_output_fn(self): driver = LocalVectorStoreDriver(embedding_driver=MockEmbeddingDriver()) - tool1 = VectorStoreClient( + tool1 = VectorStoreTool( description="Test", vector_store_driver=driver, process_query_output_fn=lambda es: ListArtifact([e.vector for e in es]), diff --git a/tests/unit/tools/test_web_scraper.py b/tests/unit/tools/test_web_scraper.py index 30362ce65..0fdc761b4 100644 --- a/tests/unit/tools/test_web_scraper.py +++ b/tests/unit/tools/test_web_scraper.py @@ -6,9 +6,9 @@ class TestWebScraper: @pytest.fixture() def scraper(self): - from griptape.tools import WebScraper + from griptape.tools import WebScraperTool - return WebScraper() + return WebScraperTool() def test_get_content(self, scraper): assert isinstance( diff --git a/tests/unit/tools/test_web_search.py b/tests/unit/tools/test_web_search.py index 17ff610e0..c1f9555ea 100644 --- a/tests/unit/tools/test_web_search.py +++ b/tests/unit/tools/test_web_search.py @@ -1,7 +1,7 @@ import pytest from griptape.artifacts import BaseArtifact, ErrorArtifact, TextArtifact -from griptape.tools import WebSearch +from griptape.tools import WebSearchTool class TestWebSearch: @@ -11,7 +11,7 @@ def websearch_tool(self, mocker): driver = mocker.Mock() mocker.patch.object(driver, "search", return_value=mock_response) - return WebSearch(web_search_driver=driver) + return WebSearchTool(web_search_driver=driver) @pytest.fixture() def websearch_tool_with_error(self, mocker): @@ -19,7 +19,7 @@ def websearch_tool_with_error(self, mocker): driver = mocker.Mock() mocker.patch.object(driver, "search", side_effect=mock_response) - return WebSearch(web_search_driver=driver) + return WebSearchTool(web_search_driver=driver) def test_search(self, websearch_tool): assert isinstance(websearch_tool.search({"values": {"query": "foo bar"}}), BaseArtifact) From 684eeece574c3afbd92a3f0d55f748c0990fd7c1 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 13 Aug 2024 16:13:50 -0700 Subject: [PATCH 53/63] Fix docs, add tests --- docs/griptape-framework/data/artifacts.md | 4 +-- docs/griptape-framework/data/loaders.md | 4 +-- .../structures/task-memory.md | 8 ++--- docs/griptape-framework/structures/tasks.md | 4 +-- .../structures/workflows.md | 2 +- .../official-tools/extraction-client.md | 6 ++-- .../official-tools/prompt-summary-client.md | 4 +-- griptape/tools/query/tool.py | 10 +++--- mkdocs.yml | 6 ++-- ...tion_client.py => test_extraction_tool.py} | 0 ..._client.py => test_prompt_summary_tool.py} | 0 tests/unit/tools/test_query_tool.py | 31 +++++++++++++++++++ 12 files changed, 53 insertions(+), 26 deletions(-) rename tests/unit/tools/{test_extraction_client.py => test_extraction_tool.py} (100%) rename tests/unit/tools/{test_prompt_summary_client.py => test_prompt_summary_tool.py} (100%) create mode 100644 tests/unit/tools/test_query_tool.py diff --git a/docs/griptape-framework/data/artifacts.md b/docs/griptape-framework/data/artifacts.md index 065e36123..7cff2c21d 100644 --- a/docs/griptape-framework/data/artifacts.md +++ b/docs/griptape-framework/data/artifacts.md @@ -40,11 +40,11 @@ Each blob has a [name](../../reference/griptape/artifacts/base_artifact.md#gript ## Image -An [ImageArtifact](../../reference/griptape/artifacts/image_artifact.md) is used for passing images back to the LLM. In addition to binary image data, an Image Artifact includes image metadata like MIME type, dimensions, and prompt and model information for images returned by [image generation Drivers](../drivers/image-generation-drivers.md). It inherits from [BlobArtifact](#blobartifact). +An [ImageArtifact](../../reference/griptape/artifacts/image_artifact.md) is used for passing images back to the LLM. In addition to binary image data, an Image Artifact includes image metadata like MIME type, dimensions, and prompt and model information for images returned by [image generation Drivers](../drivers/image-generation-drivers.md). It inherits from [BlobArtifact](#blob). ## Audio -An [AudioArtifact](../../reference/griptape/artifacts/audio_artifact.md) allows the Framework to interact with audio content. An Audio Artifact includes binary audio content as well as metadata like format, duration, and prompt and model information for audio returned generative models. It inherits from [BlobArtifact](#blobartifact). +An [AudioArtifact](../../reference/griptape/artifacts/audio_artifact.md) allows the Framework to interact with audio content. An Audio Artifact includes binary audio content as well as metadata like format, duration, and prompt and model information for audio returned generative models. It inherits from [BlobArtifact](#blob). ## Boolean diff --git a/docs/griptape-framework/data/loaders.md b/docs/griptape-framework/data/loaders.md index 82d97f494..914fdee2a 100644 --- a/docs/griptape-framework/data/loaders.md +++ b/docs/griptape-framework/data/loaders.md @@ -75,7 +75,7 @@ Inherits from the [TextLoader](../../reference/griptape/loaders/text_loader.md) !!! info This driver requires the `loaders-image` [extra](../index.md#extras). -The Image Loader is used to load an image as an [ImageArtifact](./artifacts.md#imageartifact). The Loader operates on image bytes that can be sourced from files on disk, downloaded images, or images in memory. +The Image Loader is used to load an image as an [ImageArtifact](./artifacts.md#image). The Loader operates on image bytes that can be sourced from files on disk, downloaded images, or images in memory. ```python --8<-- "docs/griptape-framework/data/src/loaders_7.py" @@ -104,7 +104,7 @@ Can be used to load email from an imap server: !!! info This driver requires the `loaders-audio` [extra](../index.md#extras). -The [Audio Loader](../../reference/griptape/loaders/audio_loader.md) is used to load audio content as an [AudioArtifact](./artifacts.md#audioartifact). The Loader operates on audio bytes that can be sourced from files on disk, downloaded audio, or audio in memory. +The [Audio Loader](../../reference/griptape/loaders/audio_loader.md) is used to load audio content as an [AudioArtifact](./artifacts.md#audio). The Loader operates on audio bytes that can be sourced from files on disk, downloaded audio, or audio in memory. The Loader will load audio in its native format and populates the resulting Artifact's `format` field by making a best-effort guess of the underlying audio format using the `filetype` package. diff --git a/docs/griptape-framework/structures/task-memory.md b/docs/griptape-framework/structures/task-memory.md index 0d479a36b..cc50be322 100644 --- a/docs/griptape-framework/structures/task-memory.md +++ b/docs/griptape-framework/structures/task-memory.md @@ -74,7 +74,7 @@ This is an example of [not providing a Task Memory compatible Tool](#not-providi ## Prompt Summary Client -The [PromptSummaryTool](../../griptape-tools/official-tools/prompt-summary-client.md) is a Tool that allows an Agent to summarize the Artifacts in Task Memory. It has the following methods: +The [PromptSummaryTool](../../griptape-tools/official-tools/prompt-summary-tool.md) is a Tool that allows an Agent to summarize the Artifacts in Task Memory. It has the following methods: Let's add `PromptSummaryTool` to the Agent and run the same task. Note that on the `PromptSummaryTool` we've set `off_prompt` to `False` so that the results of the query can be returned directly to the LLM. @@ -273,9 +273,9 @@ As seen in the previous example, certain Tools are designed to read directly fro Today, these include: -- [PromptSummaryTool](../../griptape-tools/official-tools/prompt-summary-client.md) -- [ExtractionTool](../../griptape-tools/official-tools/extraction-client.md) -- [RagClient](../../griptape-tools/official-tools/rag-client.md) +- [PromptSummaryTool](../../griptape-tools/official-tools/prompt-summary-tool.md) +- [ExtractionTool](../../griptape-tools/official-tools/extraction-tool.md) +- [RagClient](../../griptape-tools/official-tools/rag-tool.md) - [FileManagerTool](../../griptape-tools/official-tools/file-manager.md) ## Task Memory Considerations diff --git a/docs/griptape-framework/structures/tasks.md b/docs/griptape-framework/structures/tasks.md index 0cf601220..40dff8f8d 100644 --- a/docs/griptape-framework/structures/tasks.md +++ b/docs/griptape-framework/structures/tasks.md @@ -304,7 +304,7 @@ This task takes a python function, and authors can elect to return a custom arti To generate an image, use one of the following [Image Generation Tasks](../../reference/griptape/tasks/index.md). All Image Generation Tasks accept an [Image Generation Engine](../engines/image-generation-engines.md) configured to use an [Image Generation Driver](../drivers/image-generation-drivers.md). -All successful Image Generation Tasks will always output an [Image Artifact](../data/artifacts.md#imageartifact). Each task can be configured to additionally write the generated image to disk by providing either the `output_file` or `output_dir` field. The `output_file` field supports file names in the current directory (`my_image.png`), relative directory prefixes (`images/my_image.png`), or absolute paths (`/usr/var/my_image.png`). By setting `output_dir`, the task will generate a file name and place the image in the requested directory. +All successful Image Generation Tasks will always output an [Image Artifact](../data/artifacts.md#image). Each task can be configured to additionally write the generated image to disk by providing either the `output_file` or `output_dir` field. The `output_file` field supports file names in the current directory (`my_image.png`), relative directory prefixes (`images/my_image.png`), or absolute paths (`/usr/var/my_image.png`). By setting `output_dir`, the task will generate a file name and place the image in the requested directory. ### Prompt Image Generation Task @@ -342,7 +342,7 @@ The [Outpainting Image Generation Task](../../reference/griptape/tasks/outpainti The [Image Query Task](../../reference/griptape/tasks/image_query_task.md) performs a natural language query on one or more input images. This Task uses an [Image Query Engine](../engines/image-query-engines.md) configured with an [Image Query Driver](../drivers/image-query-drivers.md) to perform the query. The functionality provided by this Task depend on the capabilities of the model provided by the Driver. -This Task accepts two inputs: a query (represented by either a string or a [Text Artifact](../data/artifacts.md#textartifact)) and a list of [Image Artifacts](../data/artifacts.md#imageartifact) or a Callable returning these two values. +This Task accepts two inputs: a query (represented by either a string or a [Text Artifact](../data/artifacts.md#textartifact)) and a list of [Image Artifacts](../data/artifacts.md#image) or a Callable returning these two values. ```python --8<-- "docs/griptape-framework/structures/src/tasks_15.py" diff --git a/docs/griptape-framework/structures/workflows.md b/docs/griptape-framework/structures/workflows.md index 5f7c271fc..a346639f9 100644 --- a/docs/griptape-framework/structures/workflows.md +++ b/docs/griptape-framework/structures/workflows.md @@ -7,7 +7,7 @@ search: A [Workflow](../../reference/griptape/structures/workflow.md) is a non-sequential DAG that can be used for complex concurrent scenarios with tasks having multiple inputs. -You can access the final output of the Workflow by using the [output](../../reference/griptape/structures/agent.md#griptape.structures.structure.Structure.output) attribute. +You can access the final output of the Workflow by using the [output](../../reference/griptape/structures/structure.md#griptape.structures.structure.Structure.output) attribute. ## Context diff --git a/docs/griptape-tools/official-tools/extraction-client.md b/docs/griptape-tools/official-tools/extraction-client.md index d4dc8aa31..9779bfc5b 100644 --- a/docs/griptape-tools/official-tools/extraction-client.md +++ b/docs/griptape-tools/official-tools/extraction-client.md @@ -1,9 +1,7 @@ -The [RagClient](../../reference/griptape/tools/rag_client/tool.md) enables LLMs to query modular RAG engines. - -Here is an example of how it can be used with a local vector store driver: +The [ExractionTool](../../reference/griptape/tools/extraction/tool.md) enables LLMs to extract structured data from unstructured data. ```python ---8<-- "docs/griptape-tools/official-tools/src/rag_client_1.py" +--8<-- "docs/griptape-tools/official-tools/src/rag_tool_1.py" ``` ``` [08/12/24 15:58:03] INFO ToolkitTask 43b3d209a83c470d8371b7ef4af175b4 diff --git a/docs/griptape-tools/official-tools/prompt-summary-client.md b/docs/griptape-tools/official-tools/prompt-summary-client.md index 2315114eb..7afecf57b 100644 --- a/docs/griptape-tools/official-tools/prompt-summary-client.md +++ b/docs/griptape-tools/official-tools/prompt-summary-client.md @@ -1,6 +1,4 @@ -The [RagClient](../../reference/griptape/tools/rag_client/tool.md) enables LLMs to query modular RAG engines. - -Here is an example of how it can be used with a local vector store driver: +The [PromptSummaryTool](../../reference/griptape/tools/prompt_summary/tool.md) enables LLMs summarize text data. ```python --8<-- "docs/griptape-tools/official-tools/src/prompt_summary_tool_1.py" diff --git a/griptape/tools/query/tool.py b/griptape/tools/query/tool.py index 5b4e29c29..3ecc63bca 100644 --- a/griptape/tools/query/tool.py +++ b/griptape/tools/query/tool.py @@ -55,13 +55,13 @@ class QueryTool(BaseTool, RuleMixin): ) def query(self, params: dict) -> BaseArtifact: query = params["values"]["query"] - summary = params["values"]["content"] + content = params["values"]["content"] - if isinstance(summary, str): - text_artifacts = [TextArtifact(summary)] + if isinstance(content, str): + text_artifacts = [TextArtifact(content)] else: - memory = self.find_input_memory(summary["memory_name"]) - artifact_namespace = summary["artifact_namespace"] + memory = self.find_input_memory(content["memory_name"]) + artifact_namespace = content["artifact_namespace"] if memory is not None: artifacts = memory.load_artifacts(artifact_namespace) diff --git a/mkdocs.yml b/mkdocs.yml index 68f5c2dfb..da68fb29c 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -157,9 +157,9 @@ nav: - Audio Transcription: "griptape-tools/official-tools/audio-transcription-tool.md" - Griptape Cloud Knowledge Base: "griptape-tools/official-tools/griptape-cloud-knowledge-base-tool.md" - Rag: "griptape-tools/official-tools/rag-tool.md" - - Extraction: "griptape-tools/official-tools/extraction-client.md" - - Query: "griptape-tools/official-tools/query-client.md" - - Prompt Summary: "griptape-tools/official-tools/prompt-summary-client.md" + - Extraction: "griptape-tools/official-tools/extraction-tool.md" + - Query: "griptape-tools/official-tools/query-tool.md" + - Prompt Summary: "griptape-tools/official-tools/prompt-summary-tool.md" - Custom Tools: - Building Custom Tools: "griptape-tools/custom-tools/index.md" - Recipes: diff --git a/tests/unit/tools/test_extraction_client.py b/tests/unit/tools/test_extraction_tool.py similarity index 100% rename from tests/unit/tools/test_extraction_client.py rename to tests/unit/tools/test_extraction_tool.py diff --git a/tests/unit/tools/test_prompt_summary_client.py b/tests/unit/tools/test_prompt_summary_tool.py similarity index 100% rename from tests/unit/tools/test_prompt_summary_client.py rename to tests/unit/tools/test_prompt_summary_tool.py diff --git a/tests/unit/tools/test_query_tool.py b/tests/unit/tools/test_query_tool.py new file mode 100644 index 000000000..dcbee16cf --- /dev/null +++ b/tests/unit/tools/test_query_tool.py @@ -0,0 +1,31 @@ +import pytest + +from griptape.tools.query.tool import QueryTool +from tests.utils import defaults + + +class TestQueryTool: + @pytest.fixture() + def tool(self): + return QueryTool(input_memory=[defaults.text_task_memory("TestMemory")]) + + def test_query_str(self, tool): + assert tool.query({"values": {"query": "test", "content": "foo"}}).value[0].value == "mock output" + + def test_query_artifacts(self, tool): + assert ( + tool.query( + { + "values": { + "query": "test", + "content": { + "memory_name": tool.input_memory[0].name, + "artifact_namespace": "test", + }, + } + } + ) + .value[0] + .value + == "mock output" + ) From 75f6bfab1127db71deb8eacf829298c52f2e9187 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 13 Aug 2024 16:49:09 -0700 Subject: [PATCH 54/63] Increase test coverage --- tests/unit/engines/extraction/test_json_extraction_engine.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/unit/engines/extraction/test_json_extraction_engine.py b/tests/unit/engines/extraction/test_json_extraction_engine.py index 2d4626d3a..48430f1e5 100644 --- a/tests/unit/engines/extraction/test_json_extraction_engine.py +++ b/tests/unit/engines/extraction/test_json_extraction_engine.py @@ -32,3 +32,6 @@ def test_json_to_text_artifacts(self, engine): a.value for a in engine.json_to_text_artifacts('[{"test_key_1": "test_value_1"}, {"test_key_2": "test_value_2"}]') ] == ['{"test_key_1": "test_value_1"}', '{"test_key_2": "test_value_2"}'] + + def test_json_to_text_artifacts_no_matches(self, engine): + assert engine.json_to_text_artifacts("asdfasdfasdf") == [] From 068682d79b10969657f175bd5fdfcc0eca3d12f0 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 14 Aug 2024 08:56:16 -0700 Subject: [PATCH 55/63] Fix docs --- docs/griptape-framework/structures/agents.md | 2 +- docs/griptape-framework/structures/pipelines.md | 2 +- docs/griptape-framework/structures/task-memory.md | 2 +- docs/griptape-framework/structures/tasks.md | 2 +- .../official-tools/{extraction-client.md => extraction-tool.md} | 0 .../{prompt-summary-client.md => prompt-summary-tool.md} | 0 .../official-tools/{query-client.md => query-tool.md} | 2 +- mkdocs.yml | 1 - 8 files changed, 5 insertions(+), 6 deletions(-) rename docs/griptape-tools/official-tools/{extraction-client.md => extraction-tool.md} (100%) rename docs/griptape-tools/official-tools/{prompt-summary-client.md => prompt-summary-tool.md} (100%) rename docs/griptape-tools/official-tools/{query-client.md => query-tool.md} (98%) diff --git a/docs/griptape-framework/structures/agents.md b/docs/griptape-framework/structures/agents.md index 376e7288a..1b40fad2b 100644 --- a/docs/griptape-framework/structures/agents.md +++ b/docs/griptape-framework/structures/agents.md @@ -12,7 +12,7 @@ directly, which the agent uses to dynamically determine whether to use a [Prompt If [tools](../../reference/griptape/structures/agent.md#griptape.structures.agent.Agent.tools) are passed provided to the Agent, a [Toolkit Task](./tasks.md#toolkit-task) will be used. If no [tools](../../reference/griptape/structures/agent.md#griptape.structures.agent.Agent.tools) are provided, a [Prompt Task](./tasks.md#prompt-task) will be used. -You can access the final output of the Agent by using the [output](../../reference/griptape/structures/agent.md#griptape.structures.structure.Structure.output) attribute. +You can access the final output of the Agent by using the [output](../../reference/griptape/structures/structure.md#griptape.structures.structure.Structure.output) attribute. ## Toolkit Task Agent diff --git a/docs/griptape-framework/structures/pipelines.md b/docs/griptape-framework/structures/pipelines.md index fc5046196..7bcfc1348 100644 --- a/docs/griptape-framework/structures/pipelines.md +++ b/docs/griptape-framework/structures/pipelines.md @@ -6,7 +6,7 @@ search: ## Overview A [Pipeline](../../reference/griptape/structures/pipeline.md) is very similar to an [Agent](../../reference/griptape/structures/agent.md), but allows for multiple tasks. -You can access the final output of the Pipeline by using the [output](../../reference/griptape/structures/agent.md#griptape.structures.structure.Structure.output) attribute. +You can access the final output of the Pipeline by using the [output](../../reference/griptape/structures/structure.md#griptape.structures.structure.Structure.output) attribute. ## Context diff --git a/docs/griptape-framework/structures/task-memory.md b/docs/griptape-framework/structures/task-memory.md index cc50be322..1cbc8e6ed 100644 --- a/docs/griptape-framework/structures/task-memory.md +++ b/docs/griptape-framework/structures/task-memory.md @@ -276,7 +276,7 @@ Today, these include: - [PromptSummaryTool](../../griptape-tools/official-tools/prompt-summary-tool.md) - [ExtractionTool](../../griptape-tools/official-tools/extraction-tool.md) - [RagClient](../../griptape-tools/official-tools/rag-tool.md) -- [FileManagerTool](../../griptape-tools/official-tools/file-manager.md) +- [FileManagerTool](../../griptape-tools/official-tools/file-manager-tool.md) ## Task Memory Considerations diff --git a/docs/griptape-framework/structures/tasks.md b/docs/griptape-framework/structures/tasks.md index 40dff8f8d..f91937ec0 100644 --- a/docs/griptape-framework/structures/tasks.md +++ b/docs/griptape-framework/structures/tasks.md @@ -342,7 +342,7 @@ The [Outpainting Image Generation Task](../../reference/griptape/tasks/outpainti The [Image Query Task](../../reference/griptape/tasks/image_query_task.md) performs a natural language query on one or more input images. This Task uses an [Image Query Engine](../engines/image-query-engines.md) configured with an [Image Query Driver](../drivers/image-query-drivers.md) to perform the query. The functionality provided by this Task depend on the capabilities of the model provided by the Driver. -This Task accepts two inputs: a query (represented by either a string or a [Text Artifact](../data/artifacts.md#textartifact)) and a list of [Image Artifacts](../data/artifacts.md#image) or a Callable returning these two values. +This Task accepts two inputs: a query (represented by either a string or a [Text Artifact](../data/artifacts.md#text)) and a list of [Image Artifacts](../data/artifacts.md#image) or a Callable returning these two values. ```python --8<-- "docs/griptape-framework/structures/src/tasks_15.py" diff --git a/docs/griptape-tools/official-tools/extraction-client.md b/docs/griptape-tools/official-tools/extraction-tool.md similarity index 100% rename from docs/griptape-tools/official-tools/extraction-client.md rename to docs/griptape-tools/official-tools/extraction-tool.md diff --git a/docs/griptape-tools/official-tools/prompt-summary-client.md b/docs/griptape-tools/official-tools/prompt-summary-tool.md similarity index 100% rename from docs/griptape-tools/official-tools/prompt-summary-client.md rename to docs/griptape-tools/official-tools/prompt-summary-tool.md diff --git a/docs/griptape-tools/official-tools/query-client.md b/docs/griptape-tools/official-tools/query-tool.md similarity index 98% rename from docs/griptape-tools/official-tools/query-client.md rename to docs/griptape-tools/official-tools/query-tool.md index 5728806a2..5c5754e3c 100644 --- a/docs/griptape-tools/official-tools/query-client.md +++ b/docs/griptape-tools/official-tools/query-tool.md @@ -1,4 +1,4 @@ -The [RagClient](../../reference/griptape/tools/rag_client/tool.md) enables LLMs to query modular RAG engines. +The [RagClient](../../reference/griptape/tools/rag/tool.md) enables LLMs to query modular RAG engines. Here is an example of how it can be used with a local vector store driver: diff --git a/mkdocs.yml b/mkdocs.yml index da68fb29c..7e3806264 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -144,7 +144,6 @@ nav: - Open Weather: "griptape-tools/official-tools/openweather-tool.md" - Rest Api Client: "griptape-tools/official-tools/rest-api-tool.md" - Sql: "griptape-tools/official-tools/sql-tool.md" - - Task Memory: "griptape-tools/official-tools/task-memory-tool.md" - Vector Store Tool: "griptape-tools/official-tools/vector-store-tool.md" - Web Scraper: "griptape-tools/official-tools/web-scraper-tool.md" - Web Search: "griptape-tools/official-tools/web-search-tool.md" From 2a61f43c9952406a295483b3b7e5e826dbfd5176 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 14 Aug 2024 10:27:22 -0700 Subject: [PATCH 56/63] Fix extraction --- griptape/engines/extraction/json_extraction_engine.py | 2 +- griptape/templates/engines/extraction/json/user.j2 | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/griptape/engines/extraction/json_extraction_engine.py b/griptape/engines/extraction/json_extraction_engine.py index 56815bc06..5fc64f838 100644 --- a/griptape/engines/extraction/json_extraction_engine.py +++ b/griptape/engines/extraction/json_extraction_engine.py @@ -18,7 +18,7 @@ @define class JsonExtractionEngine(BaseExtractionEngine): - JSON_PATTERN = r"(?s)[^\[]*(\[.*\])" + JSON_PATTERN = r"(?s)(\{.*?\}|\[.*?\])" template_schema: dict = field(default=Factory(dict), kw_only=True) system_template_generator: J2 = field( diff --git a/griptape/templates/engines/extraction/json/user.j2 b/griptape/templates/engines/extraction/json/user.j2 index 984977d9a..00e162d15 100644 --- a/griptape/templates/engines/extraction/json/user.j2 +++ b/griptape/templates/engines/extraction/json/user.j2 @@ -1,4 +1,4 @@ -Extract information from the Text based on the Extraction Template JSON Schema into an array of JSON objects. +Extract information from the Text based on the Extraction Template JSON Schema into valid JSON. Text: """{{ text }}""" -JSON array: +JSON: From 1ab61b9dad0baff99d2dac5bcec4a493700592ef Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 15 Aug 2024 10:12:25 -0700 Subject: [PATCH 57/63] Trigger build From deb8ff5562ceb1244dbf15fc3e36175e9d688c13 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 15 Aug 2024 10:54:51 -0700 Subject: [PATCH 58/63] Clean up tool --- CHANGELOG.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3563aa9ac..ef42d9094 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,9 +15,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Global event bus, `griptape.events.event_bus`, for publishing and subscribing to events. - Global config, `griptape.config.config`, for setting global configuration defaults. - Unique name generation for all `RagEngine` modules. -- `ExtractionTool` Tool for having the LLM extract structured data from text. -- `PromptSummaryTool` Tool for having the LLM summarize text. -- `QueryTool` Tool for having the LLM query text. +- `ExtractionTool` for having the LLM extract structured data from text. +- `PromptSummaryTool` for having the LLM summarize text. +- `QueryTool` for having the LLM query text. - Support for bitshift composition in `BaseTask` for adding parent/child tasks. ### Changed From 9008a639e601a70d779c53b16d9b3a8b2139d80f Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 15 Aug 2024 10:55:55 -0700 Subject: [PATCH 59/63] Fix header --- docs/griptape-framework/structures/task-memory.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/griptape-framework/structures/task-memory.md b/docs/griptape-framework/structures/task-memory.md index 1cbc8e6ed..81334b1cb 100644 --- a/docs/griptape-framework/structures/task-memory.md +++ b/docs/griptape-framework/structures/task-memory.md @@ -72,7 +72,7 @@ When we set `off_prompt` to `True`, the Agent does not function as expected, eve To fix this, we need a [Tool that can read from Task Memory](#tools-that-can-read-from-task-memory) such as the `PromptSummaryTool`. This is an example of [not providing a Task Memory compatible Tool](#not-providing-a-task-memory-compatible-tool). -## Prompt Summary Client +## Prompt Summary Tool The [PromptSummaryTool](../../griptape-tools/official-tools/prompt-summary-tool.md) is a Tool that allows an Agent to summarize the Artifacts in Task Memory. It has the following methods: From ed9b074db525e08edf6f04a319fa9b28185a7f53 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 15 Aug 2024 12:09:10 -0700 Subject: [PATCH 60/63] Fix doc links --- docs/griptape-tools/official-tools/extraction-tool.md | 4 ++-- docs/griptape-tools/official-tools/query-tool.md | 6 ++---- .../src/{extraction_client_1.py => extraction_tool_1.py} | 0 ...{prompt_summary_client_1.py => prompt_summary_tool_1.py} | 0 .../src/{query_client_1.py => query_tool_1.py} | 0 5 files changed, 4 insertions(+), 6 deletions(-) rename docs/griptape-tools/official-tools/src/{extraction_client_1.py => extraction_tool_1.py} (100%) rename docs/griptape-tools/official-tools/src/{prompt_summary_client_1.py => prompt_summary_tool_1.py} (100%) rename docs/griptape-tools/official-tools/src/{query_client_1.py => query_tool_1.py} (100%) diff --git a/docs/griptape-tools/official-tools/extraction-tool.md b/docs/griptape-tools/official-tools/extraction-tool.md index 9779bfc5b..5b0486ffd 100644 --- a/docs/griptape-tools/official-tools/extraction-tool.md +++ b/docs/griptape-tools/official-tools/extraction-tool.md @@ -1,7 +1,7 @@ -The [ExractionTool](../../reference/griptape/tools/extraction/tool.md) enables LLMs to extract structured data from unstructured data. +The [ExractionTool](../../reference/griptape/tools/extraction/tool.md) enables LLMs to extract structured text from unstructured data. ```python ---8<-- "docs/griptape-tools/official-tools/src/rag_tool_1.py" +--8<-- "docs/griptape-tools/official-tools/src/extraction_tool_1.py" ``` ``` [08/12/24 15:58:03] INFO ToolkitTask 43b3d209a83c470d8371b7ef4af175b4 diff --git a/docs/griptape-tools/official-tools/query-tool.md b/docs/griptape-tools/official-tools/query-tool.md index 5c5754e3c..4a4f2bf33 100644 --- a/docs/griptape-tools/official-tools/query-tool.md +++ b/docs/griptape-tools/official-tools/query-tool.md @@ -1,9 +1,7 @@ -The [RagClient](../../reference/griptape/tools/rag/tool.md) enables LLMs to query modular RAG engines. - -Here is an example of how it can be used with a local vector store driver: +The [QueryTool](../../reference/griptape/tools/query/tool.md) enables Agents to query unstructured data for specific information. ```python ---8<-- "docs/griptape-tools/official-tools/src/query_client_1.py" +--8<-- "docs/griptape-tools/official-tools/src/query_tool_1.py" ``` ``` [08/12/24 15:49:23] INFO ToolkitTask a88abda2e5324bdf81a3e2b99c26b9df diff --git a/docs/griptape-tools/official-tools/src/extraction_client_1.py b/docs/griptape-tools/official-tools/src/extraction_tool_1.py similarity index 100% rename from docs/griptape-tools/official-tools/src/extraction_client_1.py rename to docs/griptape-tools/official-tools/src/extraction_tool_1.py diff --git a/docs/griptape-tools/official-tools/src/prompt_summary_client_1.py b/docs/griptape-tools/official-tools/src/prompt_summary_tool_1.py similarity index 100% rename from docs/griptape-tools/official-tools/src/prompt_summary_client_1.py rename to docs/griptape-tools/official-tools/src/prompt_summary_tool_1.py diff --git a/docs/griptape-tools/official-tools/src/query_client_1.py b/docs/griptape-tools/official-tools/src/query_tool_1.py similarity index 100% rename from docs/griptape-tools/official-tools/src/query_client_1.py rename to docs/griptape-tools/official-tools/src/query_tool_1.py From 8f59c7c65d3bb65adcb503afdf0143d5151ef214 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 15 Aug 2024 12:10:46 -0700 Subject: [PATCH 61/63] Fix extraction --- griptape/engines/extraction/json_extraction_engine.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/griptape/engines/extraction/json_extraction_engine.py b/griptape/engines/extraction/json_extraction_engine.py index 5fc64f838..dcba8b433 100644 --- a/griptape/engines/extraction/json_extraction_engine.py +++ b/griptape/engines/extraction/json_extraction_engine.py @@ -18,7 +18,7 @@ @define class JsonExtractionEngine(BaseExtractionEngine): - JSON_PATTERN = r"(?s)(\{.*?\}|\[.*?\])" + JSON_PATTERN = r"(?s)[^{]*({.*})" template_schema: dict = field(default=Factory(dict), kw_only=True) system_template_generator: J2 = field( @@ -38,6 +38,7 @@ def extract( self._extract_rec( cast(list[TextArtifact], text.value) if isinstance(text, ListArtifact) else [TextArtifact(text)], [], + rulesets=rulesets, ), item_separator="\n", ) @@ -48,7 +49,7 @@ def json_to_text_artifacts(self, json_input: str) -> list[TextArtifact]: json_matches = re.findall(self.JSON_PATTERN, json_input, re.DOTALL) if json_matches: - return [TextArtifact(json.dumps(e)) for e in json.loads(json_matches[-1])] + return [TextArtifact(json.dumps(json.loads(e))) for e in json_matches] else: return [] From fe1f736512f50ce2407685f1acb37622ea584867 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 15 Aug 2024 12:11:00 -0700 Subject: [PATCH 62/63] Merge extraction tool activities --- griptape/tools/extraction/tool.py | 45 +++++------------------- tests/unit/tools/test_extraction_tool.py | 14 +++----- 2 files changed, 12 insertions(+), 47 deletions(-) diff --git a/griptape/tools/extraction/tool.py b/griptape/tools/extraction/tool.py index fca38cc5b..1f6d06b80 100644 --- a/griptape/tools/extraction/tool.py +++ b/griptape/tools/extraction/tool.py @@ -5,14 +5,13 @@ from attrs import define, field from schema import Literal, Or, Schema -from griptape.artifacts import ErrorArtifact -from griptape.engines import CsvExtractionEngine, JsonExtractionEngine +from griptape.artifacts import ErrorArtifact, ListArtifact, TextArtifact from griptape.mixins import RuleMixin from griptape.tools import BaseTool from griptape.utils.decorators import activity if TYPE_CHECKING: - from griptape.artifacts import InfoArtifact, ListArtifact + from griptape.artifacts import InfoArtifact from griptape.engines import BaseExtractionEngine @@ -26,15 +25,9 @@ class ExtractionTool(BaseTool, RuleMixin): extraction_engine: BaseExtractionEngine = field() - def __attrs_post_init__(self) -> None: - if isinstance(self.extraction_engine, CsvExtractionEngine): - self.allowlist = ["extract_csv"] - elif isinstance(self.extraction_engine, JsonExtractionEngine): - self.allowlist = ["extract_json"] - @activity( config={ - "description": "Can be used extract data in JSON format", + "description": "Can be used extract structured text from data.", "schema": Schema( { Literal("data"): Or( @@ -50,40 +43,18 @@ def __attrs_post_init__(self) -> None: ), }, ) - def extract_json(self, params: dict) -> ListArtifact | InfoArtifact | ErrorArtifact: - return self._extract(params) - - @activity( - config={ - "description": "Can be used extract data in CSV format", - "schema": Schema( - { - Literal("data"): Or( - str, - Schema( - { - "memory_name": str, - "artifact_namespace": str, - } - ), - ), - } - ), - }, - ) - def extract_csv(self, params: dict) -> ListArtifact | InfoArtifact | ErrorArtifact: - return self._extract(params) - - def _extract(self, params: dict) -> ListArtifact | InfoArtifact | ErrorArtifact: + def extract(self, params: dict) -> ListArtifact | InfoArtifact | ErrorArtifact: data = params["values"]["data"] if isinstance(data, str): - return self.extraction_engine.extract(data, rulesets=self.rulesets) + artifacts = ListArtifact([TextArtifact(data)]) else: memory = self.find_input_memory(data["memory_name"]) artifact_namespace = data["artifact_namespace"] if memory is not None: - return self.extraction_engine.extract(memory.load_artifacts(artifact_namespace)) + artifacts = memory.load_artifacts(artifact_namespace) else: return ErrorArtifact("memory not found") + + return self.extraction_engine.extract(artifacts) diff --git a/tests/unit/tools/test_extraction_tool.py b/tests/unit/tools/test_extraction_tool.py index 33598971c..1219da373 100644 --- a/tests/unit/tools/test_extraction_tool.py +++ b/tests/unit/tools/test_extraction_tool.py @@ -35,7 +35,7 @@ def csv_tool(self): def test_json_extract_artifacts(self, json_tool): json_tool.input_memory[0].store_artifact("foo", TextArtifact(json.dumps({}))) - result = json_tool.extract_json( + result = json_tool.extract( {"values": {"data": {"memory_name": json_tool.input_memory[0].name, "artifact_namespace": "foo"}}} ) @@ -44,7 +44,7 @@ def test_json_extract_artifacts(self, json_tool): assert result.value[1].value == '{"test_key_2": "test_value_2"}' def test_json_extract_content(self, json_tool): - result = json_tool.extract_json({"values": {"data": "foo"}}) + result = json_tool.extract({"values": {"data": "foo"}}) assert len(result.value) == 2 assert result.value[0].value == '{"test_key_1": "test_value_1"}' @@ -53,7 +53,7 @@ def test_json_extract_content(self, json_tool): def test_csv_extract_artifacts(self, csv_tool): csv_tool.input_memory[0].store_artifact("foo", TextArtifact("foo,bar\nbaz,maz")) - result = csv_tool.extract_csv( + result = csv_tool.extract( {"values": {"data": {"memory_name": csv_tool.input_memory[0].name, "artifact_namespace": "foo"}}} ) @@ -61,13 +61,7 @@ def test_csv_extract_artifacts(self, csv_tool): assert result.value[0].value == {"test1": "mock output"} def test_csv_extract_content(self, csv_tool): - result = csv_tool.extract_csv({"values": {"data": "foo"}}) + result = csv_tool.extract({"values": {"data": "foo"}}) assert len(result.value) == 1 assert result.value[0].value == {"test1": "mock output"} - - def test_json_allowlist(self, json_tool): - assert json_tool.allowlist == ["extract_json"] - - def test_csv_allowlist(self, csv_tool): - assert csv_tool.allowlist == ["extract_csv"] From 2ca887fdd4b5ed09fc18fe500532a5773b927b1e Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 15 Aug 2024 12:17:56 -0700 Subject: [PATCH 63/63] Revert json extraction logic to use arrays --- griptape/engines/extraction/json_extraction_engine.py | 4 ++-- griptape/templates/engines/extraction/json/user.j2 | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/griptape/engines/extraction/json_extraction_engine.py b/griptape/engines/extraction/json_extraction_engine.py index dcba8b433..8f2f4a3fe 100644 --- a/griptape/engines/extraction/json_extraction_engine.py +++ b/griptape/engines/extraction/json_extraction_engine.py @@ -18,7 +18,7 @@ @define class JsonExtractionEngine(BaseExtractionEngine): - JSON_PATTERN = r"(?s)[^{]*({.*})" + JSON_PATTERN = r"(?s)[^\[]*(\[.*\])" template_schema: dict = field(default=Factory(dict), kw_only=True) system_template_generator: J2 = field( @@ -49,7 +49,7 @@ def json_to_text_artifacts(self, json_input: str) -> list[TextArtifact]: json_matches = re.findall(self.JSON_PATTERN, json_input, re.DOTALL) if json_matches: - return [TextArtifact(json.dumps(json.loads(e))) for e in json_matches] + return [TextArtifact(json.dumps(e)) for e in json.loads(json_matches[-1])] else: return [] diff --git a/griptape/templates/engines/extraction/json/user.j2 b/griptape/templates/engines/extraction/json/user.j2 index 00e162d15..984977d9a 100644 --- a/griptape/templates/engines/extraction/json/user.j2 +++ b/griptape/templates/engines/extraction/json/user.j2 @@ -1,4 +1,4 @@ -Extract information from the Text based on the Extraction Template JSON Schema into valid JSON. +Extract information from the Text based on the Extraction Template JSON Schema into an array of JSON objects. Text: """{{ text }}""" -JSON: +JSON array: