From 4a67b54c2d7eefdb4b15235f43ee5fb708f834a8 Mon Sep 17 00:00:00 2001 From: matt Date: Wed, 16 Oct 2024 17:16:46 -0500 Subject: [PATCH] Add `CompletionChunkEvent` subtypes --- griptape/drivers/prompt/base_prompt_driver.py | 22 ++++++++++++------- griptape/events/__init__.py | 4 ++++ .../action_call_completion_chunk_event.py | 14 ++++++++++++ .../events/text_completion_chunk_event.py | 8 +++++++ 4 files changed, 40 insertions(+), 8 deletions(-) create mode 100644 griptape/events/action_call_completion_chunk_event.py create mode 100644 griptape/events/text_completion_chunk_event.py diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index b06f1f0151..7b003e7058 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -16,7 +16,13 @@ TextMessageContent, observable, ) -from griptape.events import CompletionChunkEvent, EventBus, FinishPromptEvent, StartPromptEvent +from griptape.events import ( + ActionCallCompletionChunkEvent, + EventBus, + FinishPromptEvent, + StartPromptEvent, + TextCompletionChunkEvent, +) from griptape.mixins.exponential_backoff_mixin import ExponentialBackoffMixin from griptape.mixins.serializable_mixin import SerializableMixin @@ -127,13 +133,13 @@ def __process_stream(self, prompt_stack: PromptStack) -> Message: else: delta_contents[content.index] = [content] if isinstance(content, TextDeltaMessageContent): - EventBus.publish_event(CompletionChunkEvent(token=content.text, meta={"type": "text"})) - elif isinstance(content, ActionCallDeltaMessageContent): - meta = {"type": "action"} - if content.tag is not None and content.name is not None and content.path is not None: - EventBus.publish_event(CompletionChunkEvent(token=str(content), meta=meta)) - elif content.partial_input is not None: - EventBus.publish_event(CompletionChunkEvent(token=content.partial_input, meta=meta)) + EventBus.publish_event(TextCompletionChunkEvent(token=content.text)) + elif isinstance(content, ActionCallDeltaMessageContent) and content.partial_input is not None: + EventBus.publish_event( + ActionCallCompletionChunkEvent( + token=content.partial_input, tag=content.tag, name=content.name, path=content.path + ) + ) # Build a complete content from the content deltas return self.__build_message(list(delta_contents.values()), usage) diff --git a/griptape/events/__init__.py b/griptape/events/__init__.py index b3e2f3a795..b2af1a25cc 100644 --- a/griptape/events/__init__.py +++ b/griptape/events/__init__.py @@ -11,6 +11,8 @@ from .start_structure_run_event import StartStructureRunEvent from .finish_structure_run_event import FinishStructureRunEvent from .completion_chunk_event import CompletionChunkEvent +from .text_completion_chunk_event import TextCompletionChunkEvent +from .action_call_completion_chunk_event import ActionCallCompletionChunkEvent from .event_listener import EventListener from .start_image_generation_event import StartImageGenerationEvent from .finish_image_generation_event import FinishImageGenerationEvent @@ -38,6 +40,8 @@ "StartStructureRunEvent", "FinishStructureRunEvent", "CompletionChunkEvent", + "TextCompletionChunkEvent", + "ActionCallCompletionChunkEvent", "EventListener", "StartImageGenerationEvent", "FinishImageGenerationEvent", diff --git a/griptape/events/action_call_completion_chunk_event.py b/griptape/events/action_call_completion_chunk_event.py new file mode 100644 index 0000000000..df9195893e --- /dev/null +++ b/griptape/events/action_call_completion_chunk_event.py @@ -0,0 +1,14 @@ +from __future__ import annotations + +from typing import Optional + +from attrs import define, field + +from griptape.events.completion_chunk_event import CompletionChunkEvent + + +@define +class ActionCallCompletionChunkEvent(CompletionChunkEvent): + tag: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) + name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) + path: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) diff --git a/griptape/events/text_completion_chunk_event.py b/griptape/events/text_completion_chunk_event.py new file mode 100644 index 0000000000..99e4baafa1 --- /dev/null +++ b/griptape/events/text_completion_chunk_event.py @@ -0,0 +1,8 @@ +from attrs import define + +from griptape.events.completion_chunk_event import CompletionChunkEvent + + +@define +class TextCompletionChunkEvent(CompletionChunkEvent): + pass