Skip to content

Commit

Permalink
Feature/observable tags (#954)
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanholmes committed Jul 16, 2024
1 parent 5d4c6d6 commit a30bae9
Show file tree
Hide file tree
Showing 23 changed files with 104 additions and 28 deletions.
2 changes: 1 addition & 1 deletion docs/griptape-framework/drivers/observability-drivers.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ with Observability(observability_driver=observability_driver):
agent.run("Name an animal")
```

Ouput (only relevant because of use of `ConsoleSpanExporter`):
Output (only relevant because of use of `ConsoleSpanExporter`):
```
[06/18/24 06:57:22] INFO PromptTask 2d8ef95bf817480188ae2f74e754308a
Input: Name an animal
Expand Down
7 changes: 5 additions & 2 deletions griptape/common/observable.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from __future__ import annotations

import functools

from attrs import define, field, Factory
from inspect import isfunction
from typing import Any, Callable, Optional, TypeVar, cast

from attrs import Factory, define, field

T = TypeVar("T", bound=Callable)

Expand All @@ -29,6 +28,10 @@ def __call__(self):
args = (self.instance, *self.args) if self.instance and not hasattr(self.func, "__self__") else self.args
return self.func(*args, **self.kwargs)

@property
def tags(self) -> Optional[list[str]]:
return self.decorator_kwargs.get("tags")

def __init__(self, *args, **kwargs):
self._instance = None
if len(args) == 1 and len(kwargs) == 0 and isfunction(args[0]):
Expand Down
8 changes: 5 additions & 3 deletions griptape/drivers/observability/base_observability_driver.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from abc import ABC, abstractmethod
from attrs import define
from griptape.common import Observable
from types import TracebackType
from typing import Any, Optional

from attrs import define

from griptape.common import Observable


@define
class BaseObservabilityDriver(ABC):
def __enter__(self) -> None:
def __enter__(self) -> None: # noqa: B027
pass

def __exit__(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
from __future__ import annotations

import os
from collections.abc import Sequence
from typing import Optional
import requests

from urllib.parse import urljoin
from uuid import UUID
from collections.abc import Sequence

from attrs import define, Factory, field
from griptape.drivers.observability.open_telemetry_observability_driver import OpenTelemetryObservabilityDriver
from opentelemetry.trace import get_current_span, INVALID_SPAN
from opentelemetry.sdk.trace import SpanProcessor, ReadableSpan
import requests
from attrs import Factory, define, field
from opentelemetry.sdk.trace import ReadableSpan, SpanProcessor
from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanExporter, SpanExportResult
from opentelemetry.sdk.util import ns_to_iso_str
from urllib.parse import urljoin
from opentelemetry.trace import INVALID_SPAN, get_current_span

from griptape.drivers.observability.open_telemetry_observability_driver import OpenTelemetryObservabilityDriver


@define
Expand Down
4 changes: 3 additions & 1 deletion griptape/drivers/observability/no_op_observability_driver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import Any, Optional

from attrs import define

from griptape.common import Observable
from griptape.drivers import BaseObservabilityDriver
from typing import Any, Optional


@define
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from attrs import define, Factory, field
from griptape.common import Observable
from griptape.drivers import BaseObservabilityDriver
from opentelemetry.trace import format_span_id, get_current_span, get_tracer, INVALID_SPAN
from opentelemetry.instrumentation.threading import ThreadingInstrumentor
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider, SpanProcessor
from opentelemetry.trace import Tracer, Status, StatusCode
from types import TracebackType
from typing import Any, Optional

from attrs import Factory, define, field
from opentelemetry.instrumentation.threading import ThreadingInstrumentor
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import SpanProcessor, TracerProvider
from opentelemetry.trace import INVALID_SPAN, Status, StatusCode, Tracer, format_span_id, get_current_span, get_tracer

from griptape.common import Observable
from griptape.drivers import BaseObservabilityDriver


@define
class OpenTelemetryObservabilityDriver(BaseObservabilityDriver):
Expand Down Expand Up @@ -56,10 +57,14 @@ def __exit__(
def observe(self, call: Observable.Call) -> Any:
func = call.func
instance = call.instance
tags = call.tags

class_name = f"{instance.__class__.__name__}." if instance else ""
span_name = f"{class_name}{func.__name__}()"
with self._tracer.start_as_current_span(span_name) as span: # pyright: ignore[reportCallIssue]
if tags is not None:
span.set_attribute("tags", tags)

try:
result = call()
span.set_status(Status(StatusCode.OK))
Expand Down
3 changes: 3 additions & 0 deletions griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
TextDeltaMessageContent,
TextMessageContent,
ToolAction,
observable,
)
from griptape.drivers import BasePromptDriver
from griptape.tokenizers import AmazonBedrockTokenizer, BaseTokenizer
Expand Down Expand Up @@ -55,6 +56,7 @@ class AmazonBedrockPromptDriver(BasePromptDriver):
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
tool_choice: dict = field(default=Factory(lambda: {"auto": {}}), kw_only=True, metadata={"serializable": True})

@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
response = self.bedrock_client.converse(**self._base_params(prompt_stack))

Expand All @@ -67,6 +69,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message:
usage=Message.Usage(input_tokens=usage["inputTokens"], output_tokens=usage["outputTokens"]),
)

@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
response = self.bedrock_client.converse_stream(**self._base_params(prompt_stack))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from attrs import Attribute, Factory, define, field

from griptape.artifacts import TextArtifact
from griptape.common import DeltaMessage, Message, PromptStack, TextMessageContent
from griptape.common import DeltaMessage, Message, PromptStack, TextMessageContent, observable
from griptape.drivers import BasePromptDriver
from griptape.tokenizers import HuggingFaceTokenizer
from griptape.utils import import_optional_dependency
Expand Down Expand Up @@ -44,6 +44,7 @@ def validate_stream(self, _: Attribute, stream: bool) -> None: # noqa: FBT001
if stream:
raise ValueError("streaming is not supported")

@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
payload = {
"inputs": self.prompt_stack_to_string(prompt_stack),
Expand Down Expand Up @@ -81,6 +82,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message:
usage=Message.Usage(input_tokens=input_tokens, output_tokens=output_tokens),
)

@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
raise NotImplementedError("streaming is not supported")

Expand Down
3 changes: 3 additions & 0 deletions griptape/drivers/prompt/anthropic_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
TextDeltaMessageContent,
TextMessageContent,
ToolAction,
observable,
)
from griptape.drivers import BasePromptDriver
from griptape.tokenizers import AnthropicTokenizer, BaseTokenizer
Expand Down Expand Up @@ -70,6 +71,7 @@ class AnthropicPromptDriver(BasePromptDriver):
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
max_tokens: int = field(default=1000, kw_only=True, metadata={"serializable": True})

@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
response = self.client.messages.create(**self._base_params(prompt_stack))

Expand All @@ -79,6 +81,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message:
usage=Message.Usage(input_tokens=response.usage.input_tokens, output_tokens=response.usage.output_tokens),
)

@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
events = self.client.messages.create(**self._base_params(prompt_stack), stream=True)

Expand Down
2 changes: 2 additions & 0 deletions griptape/drivers/prompt/base_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
PromptStack,
TextDeltaMessageContent,
TextMessageContent,
observable,
)
from griptape.events import CompletionChunkEvent, FinishPromptEvent, StartPromptEvent
from griptape.mixins import ExponentialBackoffMixin, SerializableMixin
Expand Down Expand Up @@ -65,6 +66,7 @@ def after_run(self, result: Message) -> None:
),
)

@observable(tags=["PromptDriver.run()"])
def run(self, prompt_stack: PromptStack) -> Message:
for attempt in self.retrying():
with attempt:
Expand Down
3 changes: 3 additions & 0 deletions griptape/drivers/prompt/cohere_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
TextDeltaMessageContent,
TextMessageContent,
ToolAction,
observable,
)
from griptape.common.prompt_stack.contents.action_call_delta_message_content import ActionCallDeltaMessageContent
from griptape.drivers import BasePromptDriver
Expand Down Expand Up @@ -53,6 +54,7 @@ class CoherePromptDriver(BasePromptDriver):
force_single_step: bool = field(default=False, kw_only=True, metadata={"serializable": True})
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})

@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
result = self.client.chat(**self._base_params(prompt_stack))
usage = result.meta.tokens
Expand All @@ -63,6 +65,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message:
usage=Message.Usage(input_tokens=usage.input_tokens, output_tokens=usage.output_tokens),
)

@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
result = self.client.chat_stream(**self._base_params(prompt_stack))

Expand Down
3 changes: 3 additions & 0 deletions griptape/drivers/prompt/dummy_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from attrs import Factory, define, field

from griptape.common import observable
from griptape.drivers import BasePromptDriver
from griptape.exceptions import DummyException
from griptape.tokenizers import DummyTokenizer
Expand All @@ -19,8 +20,10 @@ class DummyPromptDriver(BasePromptDriver):
model: None = field(init=False, default=None, kw_only=True)
tokenizer: DummyTokenizer = field(default=Factory(lambda: DummyTokenizer()), kw_only=True)

@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
raise DummyException(__class__.__name__, "try_run")

@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
raise DummyException(__class__.__name__, "try_stream")
3 changes: 3 additions & 0 deletions griptape/drivers/prompt/google_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
TextDeltaMessageContent,
TextMessageContent,
ToolAction,
observable,
)
from griptape.drivers import BasePromptDriver
from griptape.tokenizers import BaseTokenizer, GoogleTokenizer
Expand Down Expand Up @@ -62,6 +63,7 @@ class GooglePromptDriver(BasePromptDriver):
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": True})

@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
messages = self.__to_google_messages(prompt_stack)
response: GenerateContentResponse = self.model_client.generate_content(
Expand All @@ -80,6 +82,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message:
),
)

@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
messages = self.__to_google_messages(prompt_stack)
response: GenerateContentResponse = self.model_client.generate_content(
Expand Down
4 changes: 3 additions & 1 deletion griptape/drivers/prompt/huggingface_hub_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from attrs import Factory, define, field

from griptape.common import DeltaMessage, Message, PromptStack, TextDeltaMessageContent
from griptape.common import DeltaMessage, Message, PromptStack, TextDeltaMessageContent, observable
from griptape.drivers import BasePromptDriver
from griptape.tokenizers import HuggingFaceTokenizer
from griptape.utils import import_optional_dependency
Expand Down Expand Up @@ -50,6 +50,7 @@ class HuggingFaceHubPromptDriver(BasePromptDriver):
kw_only=True,
)

@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
prompt = self.prompt_stack_to_string(prompt_stack)

Expand All @@ -68,6 +69,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message:
usage=Message.Usage(input_tokens=input_tokens, output_tokens=output_tokens),
)

@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
prompt = self.prompt_stack_to_string(prompt_stack)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from attrs import Factory, define, field

from griptape.artifacts import TextArtifact
from griptape.common import DeltaMessage, Message, PromptStack, TextMessageContent
from griptape.common import DeltaMessage, Message, PromptStack, TextMessageContent, observable
from griptape.drivers import BasePromptDriver
from griptape.tokenizers import HuggingFaceTokenizer
from griptape.utils import import_optional_dependency
Expand Down Expand Up @@ -47,6 +47,7 @@ class HuggingFacePipelinePromptDriver(BasePromptDriver):
),
)

@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
messages = self._prompt_stack_to_messages(prompt_stack)

Expand Down Expand Up @@ -75,6 +76,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message:
else:
raise Exception("invalid output format")

@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
raise NotImplementedError("streaming is not supported")

Expand Down
3 changes: 3 additions & 0 deletions griptape/drivers/prompt/ollama_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
PromptStack,
TextDeltaMessageContent,
TextMessageContent,
observable,
)
from griptape.drivers import BasePromptDriver
from griptape.tokenizers import SimpleTokenizer
Expand Down Expand Up @@ -61,6 +62,7 @@ class OllamaPromptDriver(BasePromptDriver):
kw_only=True,
)

@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
response = self.client.chat(**self._base_params(prompt_stack))

Expand All @@ -72,6 +74,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message:
else:
raise Exception("invalid model response")

@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
stream = self.client.chat(**self._base_params(prompt_stack), stream=True)

Expand Down
3 changes: 3 additions & 0 deletions griptape/drivers/prompt/openai_chat_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
TextDeltaMessageContent,
TextMessageContent,
ToolAction,
observable,
)
from griptape.drivers import BasePromptDriver
from griptape.tokenizers import BaseTokenizer, OpenAiTokenizer
Expand Down Expand Up @@ -88,6 +89,7 @@ class OpenAiChatPromptDriver(BasePromptDriver):
kw_only=True,
)

@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
result = self.client.chat.completions.create(**self._base_params(prompt_stack))

Expand All @@ -105,6 +107,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message:
else:
raise Exception("Completion with more than one choice is not supported yet.")

@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
result = self.client.chat.completions.create(**self._base_params(prompt_stack), stream=True)

Expand Down
2 changes: 2 additions & 0 deletions griptape/observability/observability.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from types import TracebackType
from typing import Any, Optional

from attrs import define, field

from griptape.common import Observable
from griptape.drivers import BaseObservabilityDriver, NoOpObservabilityDriver

Expand Down
Loading

0 comments on commit a30bae9

Please sign in to comment.