From 8ebb262e2d0e5901f37aa08f006f3f5156ef0067 Mon Sep 17 00:00:00 2001 From: dylanholmes <4370153+dylanholmes@users.noreply.github.com> Date: Tue, 16 Jul 2024 15:02:59 -0700 Subject: [PATCH] Set default service_name for GriptapeCloudObservabilityDriver (#988) --- .../drivers/observability-drivers.md | 4 +--- .../structures/observability.md | 4 ++-- griptape/common/observable.py | 8 ++++---- .../griptape_cloud_event_listener_driver.py | 4 ++-- .../observability/base_observability_driver.py | 10 +++++++--- .../griptape_cloud_observability_driver.py | 18 ++++++++++++------ .../no_op_observability_driver.py | 8 ++++++-- .../open_telemetry_observability_driver.py | 13 +++++++++---- griptape/observability/observability.py | 14 ++++++++++---- griptape/tools/calculator/tool.py | 2 +- ...test_griptape_cloud_observability_driver.py | 5 ++--- 11 files changed, 56 insertions(+), 34 deletions(-) diff --git a/docs/griptape-framework/drivers/observability-drivers.md b/docs/griptape-framework/drivers/observability-drivers.md index e2ea6c9e6..41cbf26af 100644 --- a/docs/griptape-framework/drivers/observability-drivers.md +++ b/docs/griptape-framework/drivers/observability-drivers.md @@ -23,9 +23,7 @@ from griptape.rules import Rule from griptape.structures import Agent from griptape.observability import Observability -observability_driver = GriptapeCloudObservabilityDriver( - service_name="my-gt-app", -) +observability_driver = GriptapeCloudObservabilityDriver() with Observability(observability_driver=observability_driver): agent = Agent(rules=[Rule("Output one word")]) diff --git a/docs/griptape-framework/structures/observability.md b/docs/griptape-framework/structures/observability.md index 59ff21881..5a9e9c51c 100644 --- a/docs/griptape-framework/structures/observability.md +++ b/docs/griptape-framework/structures/observability.md @@ -9,7 +9,7 @@ from griptape.drivers import GriptapeCloudObservabilityDriver from griptape.structures import Agent from griptape.observability import Observability -observability_driver = GriptapeCloudObservabilityDriver(service_name="hot-fire") +observability_driver = GriptapeCloudObservabilityDriver() with Observability(observability_driver=observability_driver): # Important! Only code within this block is subject to observability @@ -47,7 +47,7 @@ class MyClass: my_function() time.sleep(2) -observability_driver = GriptapeCloudObservabilityDriver(service_name="my-app") +observability_driver = GriptapeCloudObservabilityDriver() # When invoking the instrumented code from within the Observability context manager, the # telemetry for the custom code will be sent to the destination specified by the driver. diff --git a/griptape/common/observable.py b/griptape/common/observable.py index 2558ee6a0..aa675dfbe 100644 --- a/griptape/common/observable.py +++ b/griptape/common/observable.py @@ -23,7 +23,7 @@ class Call: decorator_args: tuple[Any, ...] = field(default=Factory(tuple), kw_only=True) decorator_kwargs: dict[str, Any] = field(default=Factory(dict), kw_only=True) - def __call__(self): + def __call__(self) -> Any: # If self.func has a __self__ attribute, it is a bound method and we do not need to pass the instance. args = (self.instance, *self.args) if self.instance and not hasattr(self.func, "__self__") else self.args return self.func(*args, **self.kwargs) @@ -32,7 +32,7 @@ def __call__(self): def tags(self) -> Optional[list[str]]: return self.decorator_kwargs.get("tags") - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: self._instance = None if len(args) == 1 and len(kwargs) == 0 and isfunction(args[0]): # Parameterless call. In otherwords, the `@observable` annotation @@ -49,11 +49,11 @@ def __init__(self, *args, **kwargs): self.decorator_args = args self.decorator_kwargs = kwargs - def __get__(self, obj, objtype=None): + def __get__(self, obj: Any, objtype: Any = None) -> Observable: self._instance = obj return self - def __call__(self, *args, **kwargs): + def __call__(self, *args, **kwargs) -> Any: if self._func: # Parameterless call (self._func was a set in __init__) from griptape.observability.observability import Observability diff --git a/griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py b/griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py index 14dd44154..4b379fc5b 100644 --- a/griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py +++ b/griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py @@ -39,7 +39,7 @@ def validate_run_id(self, _: Attribute, structure_run_id: str) -> None: "structure_run_id must be set either in the constructor or as an environment variable (GT_CLOUD_STRUCTURE_RUN_ID).", ) - def publish_event(self, event: BaseEvent | dict, flush: bool = False) -> None: + def publish_event(self, event: BaseEvent | dict, *, flush: bool = False) -> None: from griptape.observability.observability import Observability event_payload = event.to_dict() if isinstance(event, BaseEvent) else event @@ -48,7 +48,7 @@ def publish_event(self, event: BaseEvent | dict, flush: bool = False) -> None: if span_id is not None: event_payload["span_id"] = span_id - super().publish_event(event_payload, flush) + super().publish_event(event_payload, flush=flush) def try_publish_event_payload(self, event_payload: dict) -> None: url = urljoin(self.base_url.strip("/"), f"/api/structure-runs/{self.structure_run_id}/events") diff --git a/griptape/drivers/observability/base_observability_driver.py b/griptape/drivers/observability/base_observability_driver.py index c957486bf..41b779578 100644 --- a/griptape/drivers/observability/base_observability_driver.py +++ b/griptape/drivers/observability/base_observability_driver.py @@ -1,10 +1,14 @@ +from __future__ import annotations + from abc import ABC, abstractmethod -from types import TracebackType -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional from attrs import define -from griptape.common import Observable +if TYPE_CHECKING: + from types import TracebackType + + from griptape.common import Observable @define diff --git a/griptape/drivers/observability/griptape_cloud_observability_driver.py b/griptape/drivers/observability/griptape_cloud_observability_driver.py index 1758be569..c6df9cd35 100644 --- a/griptape/drivers/observability/griptape_cloud_observability_driver.py +++ b/griptape/drivers/observability/griptape_cloud_observability_driver.py @@ -1,23 +1,28 @@ from __future__ import annotations import os -from collections.abc import Sequence -from typing import Optional +from typing import TYPE_CHECKING, Optional from urllib.parse import urljoin from uuid import UUID import requests -from attrs import Factory, define, field -from opentelemetry.sdk.trace import ReadableSpan, SpanProcessor +from attrs import Attribute, Factory, define, field +from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanExporter, SpanExportResult from opentelemetry.sdk.util import ns_to_iso_str from opentelemetry.trace import INVALID_SPAN, get_current_span from griptape.drivers.observability.open_telemetry_observability_driver import OpenTelemetryObservabilityDriver +if TYPE_CHECKING: + from collections.abc import Sequence + + from opentelemetry.sdk.trace import ReadableSpan, SpanProcessor + @define class GriptapeCloudObservabilityDriver(OpenTelemetryObservabilityDriver): + service_name: str = field(default="griptape-cloud", kw_only=True) base_url: str = field( default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")), kw_only=True ) @@ -40,6 +45,7 @@ class GriptapeCloudObservabilityDriver(OpenTelemetryObservabilityDriver): ), kw_only=True, ) + trace_provider: TracerProvider = field(default=Factory(lambda: TracerProvider()), kw_only=True) @staticmethod def format_trace_id(trace_id: int) -> str: @@ -84,8 +90,8 @@ def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult: response = requests.post(url=url, json=payload, headers=self.headers) return SpanExportResult.SUCCESS if response.status_code == 200 else SpanExportResult.FAILURE - @structure_run_id.validator # pyright: ignore - def validate_run_id(self, _, structure_run_id: str): + @structure_run_id.validator # pyright: ignore[reportAttributeAccessIssue] + def validate_run_id(self, _: Attribute, structure_run_id: str) -> None: if structure_run_id is None: raise ValueError( "structure_run_id must be set either in the constructor or as an environment variable (GT_CLOUD_STRUCTURE_RUN_ID)." diff --git a/griptape/drivers/observability/no_op_observability_driver.py b/griptape/drivers/observability/no_op_observability_driver.py index bd3ce4fd6..c0fc9bfcf 100644 --- a/griptape/drivers/observability/no_op_observability_driver.py +++ b/griptape/drivers/observability/no_op_observability_driver.py @@ -1,10 +1,14 @@ -from typing import Any, Optional +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Optional from attrs import define -from griptape.common import Observable from griptape.drivers import BaseObservabilityDriver +if TYPE_CHECKING: + from griptape.common import Observable + @define class NoOpObservabilityDriver(BaseObservabilityDriver): diff --git a/griptape/drivers/observability/open_telemetry_observability_driver.py b/griptape/drivers/observability/open_telemetry_observability_driver.py index e5707b4c2..fec067a4b 100644 --- a/griptape/drivers/observability/open_telemetry_observability_driver.py +++ b/griptape/drivers/observability/open_telemetry_observability_driver.py @@ -1,5 +1,6 @@ -from types import TracebackType -from typing import Any, Optional +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Optional from attrs import Factory, define, field from opentelemetry.instrumentation.threading import ThreadingInstrumentor @@ -7,9 +8,13 @@ from opentelemetry.sdk.trace import SpanProcessor, TracerProvider from opentelemetry.trace import INVALID_SPAN, Status, StatusCode, Tracer, format_span_id, get_current_span, get_tracer -from griptape.common import Observable from griptape.drivers import BaseObservabilityDriver +if TYPE_CHECKING: + from types import TracebackType + + from griptape.common import Observable + @define class OpenTelemetryObservabilityDriver(BaseObservabilityDriver): @@ -25,7 +30,7 @@ class OpenTelemetryObservabilityDriver(BaseObservabilityDriver): _tracer: Optional[Tracer] = None _root_span_context_manager: Any = None - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: self.trace_provider.add_span_processor(self.span_processor) self._tracer = get_tracer(self.service_name, tracer_provider=self.trace_provider) diff --git a/griptape/observability/observability.py b/griptape/observability/observability.py index 5397e74d2..1cbc589eb 100644 --- a/griptape/observability/observability.py +++ b/griptape/observability/observability.py @@ -1,5 +1,6 @@ -from types import TracebackType -from typing import Any, Optional +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Optional from attrs import define, field @@ -9,6 +10,11 @@ _no_op_observability_driver = NoOpObservabilityDriver() _global_observability_driver: Optional[BaseObservabilityDriver] = None +if TYPE_CHECKING: + from types import TracebackType + + from griptape.common import Observable + @define class Observability: @@ -20,7 +26,7 @@ def get_global_driver() -> Optional[BaseObservabilityDriver]: return _global_observability_driver @staticmethod - def set_global_driver(driver: Optional[BaseObservabilityDriver]): + def set_global_driver(driver: Optional[BaseObservabilityDriver]) -> None: global _global_observability_driver _global_observability_driver = driver @@ -34,7 +40,7 @@ def get_span_id() -> Optional[str]: driver = Observability.get_global_driver() or _no_op_observability_driver return driver.get_span_id() - def __enter__(self): + def __enter__(self) -> None: if Observability.get_global_driver() is not None: raise ValueError("Observability driver already set.") Observability.set_global_driver(self.observability_driver) diff --git a/griptape/tools/calculator/tool.py b/griptape/tools/calculator/tool.py index 141d1a333..e8fcb1ed4 100644 --- a/griptape/tools/calculator/tool.py +++ b/griptape/tools/calculator/tool.py @@ -21,7 +21,7 @@ class Calculator(BaseTool): }, ) def calculate(self, params: dict) -> BaseArtifact: - import numexpr + import numexpr # pyright: ignore[reportMissingImports] try: expression = params["values"]["expression"] diff --git a/tests/unit/drivers/observability/test_griptape_cloud_observability_driver.py b/tests/unit/drivers/observability/test_griptape_cloud_observability_driver.py index fec2ae800..bfb72f998 100644 --- a/tests/unit/drivers/observability/test_griptape_cloud_observability_driver.py +++ b/tests/unit/drivers/observability/test_griptape_cloud_observability_driver.py @@ -1,4 +1,3 @@ -from datetime import datetime import pytest from griptape.common import Observable from griptape.drivers import GriptapeCloudObservabilityDriver @@ -11,7 +10,7 @@ class TestGriptapeCloudObservabilityDriver: @pytest.fixture def driver(self): return GriptapeCloudObservabilityDriver( - service_name="test", base_url="http://base-url:1234", api_key="api-key", structure_run_id="structure-run-id" + base_url="http://base-url:1234", api_key="api-key", structure_run_id="structure-run-id" ) @pytest.fixture(autouse=True) @@ -26,7 +25,7 @@ def mock_span_exporter(self, mock_span_exporter_class): def test_init(self, mock_span_exporter_class, mock_span_exporter): GriptapeCloudObservabilityDriver( - service_name="test", base_url="http://base-url:1234", api_key="api-key", structure_run_id="structure-run-id" + base_url="http://base-url:1234", api_key="api-key", structure_run_id="structure-run-id" ) assert mock_span_exporter_class.call_count == 1