Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
vachillo committed Oct 17, 2024
1 parent 49aeb1e commit c88058c
Show file tree
Hide file tree
Showing 10 changed files with 92 additions and 49 deletions.
6 changes: 3 additions & 3 deletions docs/griptape-framework/misc/src/events_3.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from typing import cast

from griptape.drivers import OpenAiChatPromptDriver
from griptape.events import CompletionChunkEvent, EventBus, EventListener
from griptape.events import BaseChunkEvent, EventBus, EventListener
from griptape.structures import Pipeline
from griptape.tasks import ToolkitTask
from griptape.tools import PromptSummaryTool, WebScraperTool

EventBus.add_event_listeners(
[
EventListener(
lambda e: print(cast(CompletionChunkEvent, e).token, end="", flush=True),
event_types=[CompletionChunkEvent],
lambda e: print(cast(BaseChunkEvent, e).token, end="", flush=True),
event_types=[BaseChunkEvent],
)
]
)
Expand Down
12 changes: 5 additions & 7 deletions griptape/drivers/prompt/base_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
observable,
)
from griptape.events import (
ActionCallCompletionChunkEvent,
ActionChunkEvent,
EventBus,
FinishPromptEvent,
StartPromptEvent,
TextCompletionChunkEvent,
TextChunkEvent,
)
from griptape.mixins.exponential_backoff_mixin import ExponentialBackoffMixin
from griptape.mixins.serializable_mixin import SerializableMixin
Expand Down Expand Up @@ -133,12 +133,10 @@ def __process_stream(self, prompt_stack: PromptStack) -> Message:
else:
delta_contents[content.index] = [content]
if isinstance(content, TextDeltaMessageContent):
EventBus.publish_event(TextCompletionChunkEvent(token=content.text))
elif isinstance(content, ActionCallDeltaMessageContent) and content.partial_input is not None:
EventBus.publish_event(TextChunkEvent.from_delta_message_content(content))
elif isinstance(content, ActionCallDeltaMessageContent):
EventBus.publish_event(
ActionCallCompletionChunkEvent(
token=content.partial_input, tag=content.tag, name=content.name, path=content.path
)
ActionChunkEvent.from_delta_message_content(content),
)

# Build a complete content from the content deltas
Expand Down
12 changes: 6 additions & 6 deletions griptape/events/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
from .finish_prompt_event import FinishPromptEvent
from .start_structure_run_event import StartStructureRunEvent
from .finish_structure_run_event import FinishStructureRunEvent
from .completion_chunk_event import CompletionChunkEvent
from .text_completion_chunk_event import TextCompletionChunkEvent
from .action_call_completion_chunk_event import ActionCallCompletionChunkEvent
from .base_chunk_event import BaseChunkEvent
from .text_chunk_event import TextChunkEvent
from .action_chunk_event import ActionChunkEvent
from .event_listener import EventListener
from .start_image_generation_event import StartImageGenerationEvent
from .finish_image_generation_event import FinishImageGenerationEvent
Expand All @@ -39,9 +39,9 @@
"FinishPromptEvent",
"StartStructureRunEvent",
"FinishStructureRunEvent",
"CompletionChunkEvent",
"TextCompletionChunkEvent",
"ActionCallCompletionChunkEvent",
"BaseChunkEvent",
"TextChunkEvent",
"ActionChunkEvent",
"EventListener",
"StartImageGenerationEvent",
"FinishImageGenerationEvent",
Expand Down
14 changes: 0 additions & 14 deletions griptape/events/action_call_completion_chunk_event.py

This file was deleted.

32 changes: 32 additions & 0 deletions griptape/events/action_chunk_event.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Optional

from attrs import define, field

from griptape.events.base_chunk_event import BaseChunkEvent

if TYPE_CHECKING:
from griptape.common import BaseDeltaMessageContent


@define
class ActionChunkEvent(BaseChunkEvent):
tag: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
path: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})

@classmethod
def from_delta_message_content(cls, content: BaseDeltaMessageContent) -> ActionChunkEvent:
from griptape.common import ActionCallDeltaMessageContent

if isinstance(content, ActionCallDeltaMessageContent):
return cls(
token=content.partial_input if content.partial_input is not None else "",
index=content.index,
tag=content.tag,
name=content.name,
path=content.path,
)

raise ValueError(f"Content is not an instance of ActionCallDeltaMessageContent: {content.__class__.__name__}")
19 changes: 19 additions & 0 deletions griptape/events/base_chunk_event.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from attrs import define, field

from griptape.events.base_event import BaseEvent

if TYPE_CHECKING:
from griptape.common import BaseDeltaMessageContent


@define
class BaseChunkEvent(BaseEvent):
token: str = field(metadata={"serializable": True})
index: int = field(default=0, metadata={"serializable": True})

@classmethod
def from_delta_message_content(cls, content: BaseDeltaMessageContent) -> BaseChunkEvent: ...
8 changes: 0 additions & 8 deletions griptape/events/completion_chunk_event.py

This file was deleted.

24 changes: 24 additions & 0 deletions griptape/events/text_chunk_event.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from attrs import define, field

from griptape.events.base_chunk_event import BaseChunkEvent

if TYPE_CHECKING:
from griptape.common import BaseDeltaMessageContent


@define
class TextChunkEvent(BaseChunkEvent):
token: str = field(metadata={"serializable": True})

@classmethod
def from_delta_message_content(cls, content: BaseDeltaMessageContent) -> TextChunkEvent:
from griptape.common import TextDeltaMessageContent

if isinstance(content, TextDeltaMessageContent):
return cls(token=content.text, index=content.index)

raise ValueError(f"Content is not an instance of TextDeltaMessageContent: {content.__class__.__name__}")
8 changes: 0 additions & 8 deletions griptape/events/text_completion_chunk_event.py

This file was deleted.

6 changes: 3 additions & 3 deletions griptape/utils/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from attrs import Attribute, Factory, define, field

from griptape.artifacts.text_artifact import TextArtifact
from griptape.events import CompletionChunkEvent, EventBus, EventListener, FinishPromptEvent, FinishStructureRunEvent
from griptape.events import BaseChunkEvent, EventBus, EventListener, FinishPromptEvent, FinishStructureRunEvent

if TYPE_CHECKING:
from collections.abc import Iterator
Expand Down Expand Up @@ -54,7 +54,7 @@ def run(self, *args) -> Iterator[TextArtifact]:
break
elif isinstance(event, FinishPromptEvent):
yield TextArtifact(value="\n")
elif isinstance(event, CompletionChunkEvent):
elif isinstance(event, BaseChunkEvent):
yield TextArtifact(value=event.token)
t.join()

Expand All @@ -64,7 +64,7 @@ def event_handler(event: BaseEvent) -> None:

stream_event_listener = EventListener(
handler=event_handler,
event_types=[CompletionChunkEvent, FinishPromptEvent, FinishStructureRunEvent],
event_types=[BaseChunkEvent, FinishPromptEvent, FinishStructureRunEvent],
)
EventBus.add_event_listener(stream_event_listener)

Expand Down

0 comments on commit c88058c

Please sign in to comment.