Skip to content

Commit

Permalink
Set default service_name for GriptapeCloudObservabilityDriver + ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanholmes committed Jul 16, 2024
1 parent a30bae9 commit 9b33397
Show file tree
Hide file tree
Showing 11 changed files with 56 additions and 34 deletions.
4 changes: 1 addition & 3 deletions docs/griptape-framework/drivers/observability-drivers.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ from griptape.rules import Rule
from griptape.structures import Agent
from griptape.observability import Observability

observability_driver = GriptapeCloudObservabilityDriver(
service_name="my-gt-app",
)
observability_driver = GriptapeCloudObservabilityDriver()

with Observability(observability_driver=observability_driver):
agent = Agent(rules=[Rule("Output one word")])
Expand Down
4 changes: 2 additions & 2 deletions docs/griptape-framework/structures/observability.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ from griptape.drivers import GriptapeCloudObservabilityDriver
from griptape.structures import Agent
from griptape.observability import Observability

observability_driver = GriptapeCloudObservabilityDriver(service_name="hot-fire")
observability_driver = GriptapeCloudObservabilityDriver()

with Observability(observability_driver=observability_driver):
# Important! Only code within this block is subject to observability
Expand Down Expand Up @@ -47,7 +47,7 @@ class MyClass:
my_function()
time.sleep(2)

observability_driver = GriptapeCloudObservabilityDriver(service_name="my-app")
observability_driver = GriptapeCloudObservabilityDriver()

# When invoking the instrumented code from within the Observability context manager, the
# telemetry for the custom code will be sent to the destination specified by the driver.
Expand Down
8 changes: 4 additions & 4 deletions griptape/common/observable.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class Call:
decorator_args: tuple[Any, ...] = field(default=Factory(tuple), kw_only=True)
decorator_kwargs: dict[str, Any] = field(default=Factory(dict), kw_only=True)

def __call__(self):
def __call__(self) -> Any:
# If self.func has a __self__ attribute, it is a bound method and we do not need to pass the instance.
args = (self.instance, *self.args) if self.instance and not hasattr(self.func, "__self__") else self.args
return self.func(*args, **self.kwargs)
Expand All @@ -32,7 +32,7 @@ def __call__(self):
def tags(self) -> Optional[list[str]]:
return self.decorator_kwargs.get("tags")

def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
self._instance = None
if len(args) == 1 and len(kwargs) == 0 and isfunction(args[0]):
# Parameterless call. In otherwords, the `@observable` annotation
Expand All @@ -49,11 +49,11 @@ def __init__(self, *args, **kwargs):
self.decorator_args = args
self.decorator_kwargs = kwargs

def __get__(self, obj, objtype=None):
def __get__(self, obj: Any, objtype: Any = None) -> Observable:
self._instance = obj
return self

def __call__(self, *args, **kwargs):
def __call__(self, *args, **kwargs) -> Any:
if self._func:
# Parameterless call (self._func was a set in __init__)
from griptape.observability.observability import Observability
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def validate_run_id(self, _: Attribute, structure_run_id: str) -> None:
"structure_run_id must be set either in the constructor or as an environment variable (GT_CLOUD_STRUCTURE_RUN_ID).",
)

def publish_event(self, event: BaseEvent | dict, flush: bool = False) -> None:
def publish_event(self, event: BaseEvent | dict, *, flush: bool = False) -> None:
from griptape.observability.observability import Observability

event_payload = event.to_dict() if isinstance(event, BaseEvent) else event
Expand All @@ -48,7 +48,7 @@ def publish_event(self, event: BaseEvent | dict, flush: bool = False) -> None:
if span_id is not None:
event_payload["span_id"] = span_id

super().publish_event(event_payload, flush)
super().publish_event(event_payload, flush=flush)

def try_publish_event_payload(self, event_payload: dict) -> None:
url = urljoin(self.base_url.strip("/"), f"/api/structure-runs/{self.structure_run_id}/events")
Expand Down
10 changes: 7 additions & 3 deletions griptape/drivers/observability/base_observability_driver.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from types import TracebackType
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Optional

from attrs import define

from griptape.common import Observable
if TYPE_CHECKING:
from types import TracebackType

from griptape.common import Observable


@define
Expand Down
Original file line number Diff line number Diff line change
@@ -1,23 +1,28 @@
from __future__ import annotations

import os
from collections.abc import Sequence
from typing import Optional
from typing import TYPE_CHECKING, Optional
from urllib.parse import urljoin
from uuid import UUID

import requests
from attrs import Factory, define, field
from opentelemetry.sdk.trace import ReadableSpan, SpanProcessor
from attrs import Attribute, Factory, define, field
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanExporter, SpanExportResult
from opentelemetry.sdk.util import ns_to_iso_str
from opentelemetry.trace import INVALID_SPAN, get_current_span

from griptape.drivers.observability.open_telemetry_observability_driver import OpenTelemetryObservabilityDriver

if TYPE_CHECKING:
from collections.abc import Sequence

from opentelemetry.sdk.trace import ReadableSpan, SpanProcessor


@define
class GriptapeCloudObservabilityDriver(OpenTelemetryObservabilityDriver):
service_name: str = field(default="griptape-cloud", kw_only=True)
base_url: str = field(
default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")), kw_only=True
)
Expand All @@ -40,6 +45,7 @@ class GriptapeCloudObservabilityDriver(OpenTelemetryObservabilityDriver):
),
kw_only=True,
)
trace_provider: TracerProvider = field(default=Factory(lambda: TracerProvider()), kw_only=True)

@staticmethod
def format_trace_id(trace_id: int) -> str:
Expand Down Expand Up @@ -84,8 +90,8 @@ def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult:
response = requests.post(url=url, json=payload, headers=self.headers)
return SpanExportResult.SUCCESS if response.status_code == 200 else SpanExportResult.FAILURE

@structure_run_id.validator # pyright: ignore
def validate_run_id(self, _, structure_run_id: str):
@structure_run_id.validator # pyright: ignore[reportAttributeAccessIssue]
def validate_run_id(self, _: Attribute, structure_run_id: str) -> None:
if structure_run_id is None:
raise ValueError(
"structure_run_id must be set either in the constructor or as an environment variable (GT_CLOUD_STRUCTURE_RUN_ID)."
Expand Down
8 changes: 6 additions & 2 deletions griptape/drivers/observability/no_op_observability_driver.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from typing import Any, Optional
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Optional

from attrs import define

from griptape.common import Observable
from griptape.drivers import BaseObservabilityDriver

if TYPE_CHECKING:
from griptape.common import Observable


@define
class NoOpObservabilityDriver(BaseObservabilityDriver):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
from types import TracebackType
from typing import Any, Optional
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Optional

from attrs import Factory, define, field
from opentelemetry.instrumentation.threading import ThreadingInstrumentor
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import SpanProcessor, TracerProvider
from opentelemetry.trace import INVALID_SPAN, Status, StatusCode, Tracer, format_span_id, get_current_span, get_tracer

from griptape.common import Observable
from griptape.drivers import BaseObservabilityDriver

if TYPE_CHECKING:
from types import TracebackType

from griptape.common import Observable


@define
class OpenTelemetryObservabilityDriver(BaseObservabilityDriver):
Expand All @@ -25,7 +30,7 @@ class OpenTelemetryObservabilityDriver(BaseObservabilityDriver):
_tracer: Optional[Tracer] = None
_root_span_context_manager: Any = None

def __attrs_post_init__(self):
def __attrs_post_init__(self) -> None:
self.trace_provider.add_span_processor(self.span_processor)
self._tracer = get_tracer(self.service_name, tracer_provider=self.trace_provider)

Expand Down
14 changes: 10 additions & 4 deletions griptape/observability/observability.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from types import TracebackType
from typing import Any, Optional
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Optional

from attrs import define, field

Expand All @@ -9,6 +10,11 @@
_no_op_observability_driver = NoOpObservabilityDriver()
_global_observability_driver: Optional[BaseObservabilityDriver] = None

if TYPE_CHECKING:
from types import TracebackType

from griptape.common import Observable


@define
class Observability:
Expand All @@ -20,7 +26,7 @@ def get_global_driver() -> Optional[BaseObservabilityDriver]:
return _global_observability_driver

@staticmethod
def set_global_driver(driver: Optional[BaseObservabilityDriver]):
def set_global_driver(driver: Optional[BaseObservabilityDriver]) -> None:
global _global_observability_driver
_global_observability_driver = driver

Expand All @@ -34,7 +40,7 @@ def get_span_id() -> Optional[str]:
driver = Observability.get_global_driver() or _no_op_observability_driver
return driver.get_span_id()

def __enter__(self):
def __enter__(self) -> None:
if Observability.get_global_driver() is not None:
raise ValueError("Observability driver already set.")
Observability.set_global_driver(self.observability_driver)
Expand Down
2 changes: 1 addition & 1 deletion griptape/tools/calculator/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class Calculator(BaseTool):
},
)
def calculate(self, params: dict) -> BaseArtifact:
import numexpr
import numexpr # pyright: ignore[reportMissingImports]

try:
expression = params["values"]["expression"]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from datetime import datetime
import pytest
from griptape.common import Observable
from griptape.drivers import GriptapeCloudObservabilityDriver
Expand All @@ -11,7 +10,7 @@ class TestGriptapeCloudObservabilityDriver:
@pytest.fixture
def driver(self):
return GriptapeCloudObservabilityDriver(
service_name="test", base_url="http://base-url:1234", api_key="api-key", structure_run_id="structure-run-id"
base_url="http://base-url:1234", api_key="api-key", structure_run_id="structure-run-id"
)

@pytest.fixture(autouse=True)
Expand All @@ -26,7 +25,7 @@ def mock_span_exporter(self, mock_span_exporter_class):

def test_init(self, mock_span_exporter_class, mock_span_exporter):
GriptapeCloudObservabilityDriver(
service_name="test", base_url="http://base-url:1234", api_key="api-key", structure_run_id="structure-run-id"
base_url="http://base-url:1234", api_key="api-key", structure_run_id="structure-run-id"
)

assert mock_span_exporter_class.call_count == 1
Expand Down

0 comments on commit 9b33397

Please sign in to comment.