diff --git a/logfire/_internal/main.py b/logfire/_internal/main.py index d3c17752..2a9ac5d4 100644 --- a/logfire/_internal/main.py +++ b/logfire/_internal/main.py @@ -1948,13 +1948,13 @@ def __init__(self, span: trace_api.Span) -> None: self._token = context_api.attach(trace_api.set_span_in_context(self._span)) def __enter__(self) -> FastLogfireSpan: + self._span.__enter__() return self @handle_internal_errors() def __exit__(self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: Any) -> None: context_api.detach(self._token) - _exit_span(self._span, exc_value) - self._span.end() + self._span.__exit__(exc_type, exc_value, traceback) # Changes to this class may need to be reflected in `FastLogfireSpan` and `NoopSpan` as well. @@ -1990,6 +1990,7 @@ def __enter__(self) -> LogfireSpan: attributes=self._otlp_attributes, links=self._links, ) + self._span.__enter__() if self._token is None: # pragma: no branch self._token = context_api.attach(trace_api.set_span_in_context(self._span)) @@ -1999,14 +2000,17 @@ def __enter__(self) -> LogfireSpan: def __exit__(self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: Any) -> None: if self._token is None: # pragma: no cover return + assert self._span is not None context_api.detach(self._token) self._token = None - - assert self._span is not None - _exit_span(self._span, exc_value) - - self.end() + if self._span.is_recording(): + with handle_internal_errors(): + if self._added_attributes: + self._span.set_attribute( + ATTRIBUTES_JSON_SCHEMA_KEY, attributes_json_schema(self._json_schema_properties) + ) + self._span.__exit__(exc_type, exc_value, traceback) @property def message_template(self) -> str | None: # pragma: no cover @@ -2032,26 +2036,6 @@ def message(self) -> str: def message(self, message: str): self._set_attribute(ATTRIBUTES_MESSAGE_KEY, message) - def end(self, end_time: int | None = None) -> None: - """Sets the current time as the span's end time. - - The span's end time is the wall time at which the operation finished. - - Only the first call to this method is recorded, further calls are ignored so you - can call this within the span's context manager to end it before the context manager - exits. - """ - if self._span is None: # pragma: no cover - raise RuntimeError('Span has not been started') - if self._span.is_recording(): - with handle_internal_errors(): - if self._added_attributes: - self._span.set_attribute( - ATTRIBUTES_JSON_SCHEMA_KEY, attributes_json_schema(self._json_schema_properties) - ) - - self._span.end(end_time) - @handle_internal_errors() def set_attribute(self, key: str, value: Any) -> None: """Sets an attribute on the span. @@ -2183,16 +2167,6 @@ def is_recording(self) -> bool: return False -def _exit_span(span: trace_api.Span, exception: BaseException | None) -> None: - if not span.is_recording(): - return - - # record exception if present - # isinstance is to ignore BaseException - if isinstance(exception, Exception): - record_exception(span, exception, escaped=True) - - AttributesValueType = TypeVar('AttributesValueType', bound=Union[Any, otel_types.AttributeValue]) diff --git a/logfire/_internal/tracer.py b/logfire/_internal/tracer.py index 5d7ad841..1b8b9fc4 100644 --- a/logfire/_internal/tracer.py +++ b/logfire/_internal/tracer.py @@ -171,6 +171,12 @@ def record_exception( timestamp = timestamp or self.ns_timestamp_generator() record_exception(self.span, exception, attributes=attributes, timestamp=timestamp, escaped=escaped) + def __exit__(self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: Any) -> None: + if self.is_recording(): + if isinstance(exc_value, BaseException): + self.record_exception(exc_value, escaped=True) + self.end() + if not TYPE_CHECKING: # pragma: no branch # for ReadableSpan def __getattr__(self, name: str) -> Any: diff --git a/tests/test_logfire.py b/tests/test_logfire.py index 420e8d39..be83dc12 100644 --- a/tests/test_logfire.py +++ b/tests/test_logfire.py @@ -9,6 +9,7 @@ from functools import partial from logging import getLogger from typing import Any, Callable +from unittest.mock import patch import pytest from dirty_equals import IsInt, IsJson, IsStr @@ -36,6 +37,7 @@ ) from logfire._internal.formatter import FormattingFailedWarning, InspectArgumentsFailedWarning from logfire._internal.main import NoopSpan +from logfire._internal.tracer import record_exception from logfire._internal.utils import is_instrumentation_suppressed from logfire.integrations.logging import LogfireLoggingHandler from logfire.testing import TestExporter @@ -3171,3 +3173,22 @@ def test_suppress_scopes(exporter: TestExporter, metrics_reader: InMemoryMetricR } ] ) + + +def test_logfire_span_records_exceptions_once(): + n_calls_to_record_exception = 0 + + def patched_record_exception(*args: Any, **kwargs: Any) -> Any: + nonlocal n_calls_to_record_exception + n_calls_to_record_exception += 1 + + return record_exception(*args, **kwargs) + + with patch('logfire._internal.tracer.record_exception', patched_record_exception), patch( + 'logfire._internal.main.record_exception', patched_record_exception + ): + with pytest.raises(RuntimeError): + with logfire.span('foo'): + raise RuntimeError('error') + + assert n_calls_to_record_exception == 1