diff --git a/docs/griptape-framework/misc/src/events_3.py b/docs/griptape-framework/misc/src/events_3.py index 7adac812f7..81b218a295 100644 --- a/docs/griptape-framework/misc/src/events_3.py +++ b/docs/griptape-framework/misc/src/events_3.py @@ -1,7 +1,7 @@ from typing import cast from griptape.drivers import OpenAiChatPromptDriver -from griptape.events import CompletionChunkEvent, EventBus, EventListener +from griptape.events import BaseChunkEvent, EventBus, EventListener from griptape.structures import Pipeline from griptape.tasks import ToolkitTask from griptape.tools import PromptSummaryTool, WebScraperTool @@ -9,8 +9,8 @@ EventBus.add_event_listeners( [ EventListener( - lambda e: print(cast(CompletionChunkEvent, e).token, end="", flush=True), - event_types=[CompletionChunkEvent], + lambda e: print(cast(BaseChunkEvent, e).token, end="", flush=True), + event_types=[BaseChunkEvent], ) ] ) diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index 7b003e7058..91d0cccc89 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -17,11 +17,11 @@ observable, ) from griptape.events import ( - ActionCallCompletionChunkEvent, + ActionChunkEvent, EventBus, FinishPromptEvent, StartPromptEvent, - TextCompletionChunkEvent, + TextChunkEvent, ) from griptape.mixins.exponential_backoff_mixin import ExponentialBackoffMixin from griptape.mixins.serializable_mixin import SerializableMixin @@ -133,12 +133,10 @@ def __process_stream(self, prompt_stack: PromptStack) -> Message: else: delta_contents[content.index] = [content] if isinstance(content, TextDeltaMessageContent): - EventBus.publish_event(TextCompletionChunkEvent(token=content.text)) - elif isinstance(content, ActionCallDeltaMessageContent) and content.partial_input is not None: + EventBus.publish_event(TextChunkEvent.from_delta_message_content(content)) + elif isinstance(content, ActionCallDeltaMessageContent): EventBus.publish_event( - ActionCallCompletionChunkEvent( - token=content.partial_input, tag=content.tag, name=content.name, path=content.path - ) + ActionChunkEvent.from_delta_message_content(content), ) # Build a complete content from the content deltas diff --git a/griptape/events/__init__.py b/griptape/events/__init__.py index b2af1a25cc..e8a14d7506 100644 --- a/griptape/events/__init__.py +++ b/griptape/events/__init__.py @@ -10,9 +10,9 @@ from .finish_prompt_event import FinishPromptEvent from .start_structure_run_event import StartStructureRunEvent from .finish_structure_run_event import FinishStructureRunEvent -from .completion_chunk_event import CompletionChunkEvent -from .text_completion_chunk_event import TextCompletionChunkEvent -from .action_call_completion_chunk_event import ActionCallCompletionChunkEvent +from .base_chunk_event import BaseChunkEvent +from .text_chunk_event import TextChunkEvent +from .action_chunk_event import ActionChunkEvent from .event_listener import EventListener from .start_image_generation_event import StartImageGenerationEvent from .finish_image_generation_event import FinishImageGenerationEvent @@ -39,9 +39,9 @@ "FinishPromptEvent", "StartStructureRunEvent", "FinishStructureRunEvent", - "CompletionChunkEvent", - "TextCompletionChunkEvent", - "ActionCallCompletionChunkEvent", + "BaseChunkEvent", + "TextChunkEvent", + "ActionChunkEvent", "EventListener", "StartImageGenerationEvent", "FinishImageGenerationEvent", diff --git a/griptape/events/action_call_completion_chunk_event.py b/griptape/events/action_call_completion_chunk_event.py deleted file mode 100644 index df9195893e..0000000000 --- a/griptape/events/action_call_completion_chunk_event.py +++ /dev/null @@ -1,14 +0,0 @@ -from __future__ import annotations - -from typing import Optional - -from attrs import define, field - -from griptape.events.completion_chunk_event import CompletionChunkEvent - - -@define -class ActionCallCompletionChunkEvent(CompletionChunkEvent): - tag: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) - name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) - path: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) diff --git a/griptape/events/action_chunk_event.py b/griptape/events/action_chunk_event.py new file mode 100644 index 0000000000..7bee32c7cd --- /dev/null +++ b/griptape/events/action_chunk_event.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +from attrs import define, field + +from griptape.events.base_chunk_event import BaseChunkEvent + +if TYPE_CHECKING: + from griptape.common import BaseDeltaMessageContent + + +@define +class ActionChunkEvent(BaseChunkEvent): + tag: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) + name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) + path: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) + + @classmethod + def from_delta_message_content(cls, content: BaseDeltaMessageContent) -> ActionChunkEvent: + from griptape.common import ActionCallDeltaMessageContent + + if isinstance(content, ActionCallDeltaMessageContent): + return cls( + token=content.partial_input if content.partial_input is not None else "", + index=content.index, + tag=content.tag, + name=content.name, + path=content.path, + ) + + raise ValueError(f"Content is not an instance of ActionCallDeltaMessageContent: {content.__class__.__name__}") diff --git a/griptape/events/base_chunk_event.py b/griptape/events/base_chunk_event.py new file mode 100644 index 0000000000..2fedfa0777 --- /dev/null +++ b/griptape/events/base_chunk_event.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from attrs import define, field + +from griptape.events.base_event import BaseEvent + +if TYPE_CHECKING: + from griptape.common import BaseDeltaMessageContent + + +@define +class BaseChunkEvent(BaseEvent): + token: str = field(metadata={"serializable": True}) + index: int = field(default=0, metadata={"serializable": True}) + + @classmethod + def from_delta_message_content(cls, content: BaseDeltaMessageContent) -> BaseChunkEvent: ... diff --git a/griptape/events/completion_chunk_event.py b/griptape/events/completion_chunk_event.py deleted file mode 100644 index 48b479625b..0000000000 --- a/griptape/events/completion_chunk_event.py +++ /dev/null @@ -1,8 +0,0 @@ -from attrs import define, field - -from griptape.events.base_event import BaseEvent - - -@define -class CompletionChunkEvent(BaseEvent): - token: str = field(kw_only=True, metadata={"serializable": True}) diff --git a/griptape/events/text_chunk_event.py b/griptape/events/text_chunk_event.py new file mode 100644 index 0000000000..54666c52ab --- /dev/null +++ b/griptape/events/text_chunk_event.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from attrs import define, field + +from griptape.events.base_chunk_event import BaseChunkEvent + +if TYPE_CHECKING: + from griptape.common import BaseDeltaMessageContent + + +@define +class TextChunkEvent(BaseChunkEvent): + token: str = field(metadata={"serializable": True}) + + @classmethod + def from_delta_message_content(cls, content: BaseDeltaMessageContent) -> TextChunkEvent: + from griptape.common import TextDeltaMessageContent + + if isinstance(content, TextDeltaMessageContent): + return cls(token=content.text, index=content.index) + + raise ValueError(f"Content is not an instance of TextDeltaMessageContent: {content.__class__.__name__}") diff --git a/griptape/events/text_completion_chunk_event.py b/griptape/events/text_completion_chunk_event.py deleted file mode 100644 index 99e4baafa1..0000000000 --- a/griptape/events/text_completion_chunk_event.py +++ /dev/null @@ -1,8 +0,0 @@ -from attrs import define - -from griptape.events.completion_chunk_event import CompletionChunkEvent - - -@define -class TextCompletionChunkEvent(CompletionChunkEvent): - pass diff --git a/griptape/utils/stream.py b/griptape/utils/stream.py index 8a764e85a1..588cdeed90 100644 --- a/griptape/utils/stream.py +++ b/griptape/utils/stream.py @@ -7,7 +7,7 @@ from attrs import Attribute, Factory, define, field from griptape.artifacts.text_artifact import TextArtifact -from griptape.events import CompletionChunkEvent, EventBus, EventListener, FinishPromptEvent, FinishStructureRunEvent +from griptape.events import BaseChunkEvent, EventBus, EventListener, FinishPromptEvent, FinishStructureRunEvent if TYPE_CHECKING: from collections.abc import Iterator @@ -54,7 +54,7 @@ def run(self, *args) -> Iterator[TextArtifact]: break elif isinstance(event, FinishPromptEvent): yield TextArtifact(value="\n") - elif isinstance(event, CompletionChunkEvent): + elif isinstance(event, BaseChunkEvent): yield TextArtifact(value=event.token) t.join() @@ -64,7 +64,7 @@ def event_handler(event: BaseEvent) -> None: stream_event_listener = EventListener( handler=event_handler, - event_types=[CompletionChunkEvent, FinishPromptEvent, FinishStructureRunEvent], + event_types=[BaseChunkEvent, FinishPromptEvent, FinishStructureRunEvent], ) EventBus.add_event_listener(stream_event_listener)