Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(tracer): core additions #10716

Merged
merged 11 commits into from
Sep 24, 2024
80 changes: 43 additions & 37 deletions ddtrace/internal/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def _on_jsonify_context_started_flask(ctx):

The names of these events follow the pattern ``context.[started|ended].<context_name>``.
"""

from contextlib import contextmanager
import logging
import sys
Expand All @@ -111,6 +112,7 @@ def _on_jsonify_context_started_flask(ctx):
from typing import Optional # noqa:F401
from typing import Tuple # noqa:F401

from ddtrace._trace.span import Span # noqa:F401
from ddtrace.vendor.debtcollector import deprecate

from ..utils.deprecations import DDTraceDeprecationWarning
Expand All @@ -124,10 +126,6 @@ def _on_jsonify_context_started_flask(ctx):
from .event_hub import reset as reset_listeners # noqa:F401


if TYPE_CHECKING:
christophe-papazian marked this conversation as resolved.
Show resolved Hide resolved
from ddtrace._trace.span import Span # noqa:F401


try:
import contextvars
except ImportError:
Expand All @@ -137,7 +135,6 @@ def _on_jsonify_context_started_flask(ctx):
log = logging.getLogger(__name__)


_CURRENT_CONTEXT = None
ROOT_CONTEXT_ID = "__root"
SPAN_DEPRECATION_MESSAGE = (
"The 'span' keyword argument on ExecutionContext methods is deprecated and will be removed in a future version."
Expand Down Expand Up @@ -168,15 +165,17 @@ class ExecutionContext:
__slots__ = ["identifier", "_data", "_parents", "_span", "_token"]

def __init__(self, identifier, parent=None, span=None, **kwargs):
_deprecate_span_kwarg(span)
if span is not None:
christophe-papazian marked this conversation as resolved.
Show resolved Hide resolved
_deprecate_span_kwarg(span)
self.identifier = identifier
self._data = {}
self._parents = []
self._span = span
if parent is not None:
self.addParent(parent)
self._data.update(kwargs)
if self._span is None and _CURRENT_CONTEXT is not None:

if self._span is None and "_CURRENT_CONTEXT" in globals():
self._token = _CURRENT_CONTEXT.set(self)
dispatch("context.started.%s" % self.identifier, (self,))
dispatch("context.started.start_span.%s" % self.identifier, (self,))
Expand Down Expand Up @@ -225,7 +224,7 @@ def context_with_data(cls, identifier, parent=None, span=None, **kwargs):
finally:
new_context.end()

def get_item(self, data_key: str, default: Optional[Any] = None, traverse: Optional[bool] = True) -> Any:
def get_item(self, data_key: str, default: Optional[Any] = None, traverse: bool = True) -> Any:
# NB mimic the behavior of `ddtrace.internal._context` by doing lazy inheritance
current = self
while current is not None:
Expand All @@ -242,25 +241,32 @@ def __getitem__(self, key: str):
raise KeyError
return value

def get_items(self, data_keys):
# type: (List[str]) -> Optional[Any]
def get_items(self, data_keys: List[str]) -> List[Optional[Any]]:
return [self.get_item(key) for key in data_keys]

def set_item(self, data_key, data_value):
# type: (str, Optional[Any]) -> None
def set_item(self, data_key: str, data_value: Optional[Any]) -> None:
self._data[data_key] = data_value

def set_safe(self, data_key, data_value):
# type: (str, Optional[Any]) -> None
def set_safe(self, data_key: str, data_value: Optional[Any]) -> None:
if data_key in self._data:
raise ValueError("Cannot overwrite ExecutionContext data key '%s'", data_key)
return self.set_item(data_key, data_value)

def set_items(self, keys_values):
# type: (Dict[str, Optional[Any]]) -> None
def set_items(self, keys_values: Dict[str, Optional[Any]]) -> None:
for data_key, data_value in keys_values.items():
self.set_item(data_key, data_value)

def discard_item(self, data_key: str, traverse: bool = True) -> None:
# NB mimic the behavior of `ddtrace.internal._context` by doing lazy inheritance
current = self
while current is not None:
if data_key in current._data:
del current._data[data_key]
return
if not traverse:
return
current = current.parent

def root(self):
if self.identifier == ROOT_CONTEXT_ID:
return self
Expand All @@ -276,56 +282,56 @@ def __getattr__(name):
raise AttributeError


def _reset_context():
global _CURRENT_CONTEXT
_CURRENT_CONTEXT = contextvars.ContextVar("ExecutionContext_var", default=ExecutionContext(ROOT_CONTEXT_ID))


_reset_context()
_CURRENT_CONTEXT = contextvars.ContextVar("ExecutionContext_var", default=ExecutionContext(ROOT_CONTEXT_ID))
_CONTEXT_CLASS = ExecutionContext


def context_with_data(identifier, parent=None, **kwargs):
return _CONTEXT_CLASS.context_with_data(identifier, parent=(parent or _CURRENT_CONTEXT.get()), **kwargs)


def get_item(data_key, span=None):
# type: (str, Optional[Span]) -> Optional[Any]
def get_root() -> ExecutionContext:
christophe-papazian marked this conversation as resolved.
Show resolved Hide resolved
return _CURRENT_CONTEXT.get().root()


def get_item(
data_key: str, span: Optional[Span] = None, *, default: Optional[Any] = None, traverse: bool = True
) -> Any:
_deprecate_span_kwarg(span)
if span is not None and span._local_root is not None:
return span._local_root._get_ctx_item(data_key)
else:
return _CURRENT_CONTEXT.get().get_item(data_key) # type: ignore
return _CURRENT_CONTEXT.get().get_item(data_key, default=default, traverse=traverse)


def get_items(data_keys, span=None):
# type: (List[str], Optional[Span]) -> Optional[Any]
def get_items(data_keys: List[str], span: Optional[Span] = None) -> List[Optional[Any]]:
_deprecate_span_kwarg(span)
if span is not None and span._local_root is not None:
return [span._local_root._get_ctx_item(key) for key in data_keys]
else:
return _CURRENT_CONTEXT.get().get_items(data_keys) # type: ignore
return _CURRENT_CONTEXT.get().get_items(data_keys)


def set_safe(data_key, data_value):
# type: (str, Optional[Any]) -> None
_CURRENT_CONTEXT.get().set_safe(data_key, data_value) # type: ignore
def set_safe(data_key: str, data_value: Optional[Any]) -> None:
_CURRENT_CONTEXT.get().set_safe(data_key, data_value)


# NB Don't call these set_* functions from `ddtrace.contrib`, only from product code!
def set_item(data_key, data_value, span=None):
# type: (str, Optional[Any], Optional[Span]) -> None
def set_item(data_key: str, data_value: Optional[Any], span: Optional[Span] = None) -> None:
_deprecate_span_kwarg(span)
if span is not None and span._local_root is not None:
span._local_root._set_ctx_item(data_key, data_value)
else:
_CURRENT_CONTEXT.get().set_item(data_key, data_value) # type: ignore
_CURRENT_CONTEXT.get().set_item(data_key, data_value)


def set_items(keys_values, span=None):
# type: (Dict[str, Optional[Any]], Optional[Span]) -> None
def set_items(keys_values: Dict[str, Optional[Any]], span: Optional[Span] = None) -> None:
_deprecate_span_kwarg(span)
if span is not None and span._local_root is not None:
span._local_root._set_ctx_items(keys_values)
else:
_CURRENT_CONTEXT.get().set_items(keys_values) # type: ignore
_CURRENT_CONTEXT.get().set_items(keys_values)


def discard_item(data_key: str, traverse: bool = True) -> None:
_CURRENT_CONTEXT.get().discard_item(data_key, traverse=traverse)
41 changes: 41 additions & 0 deletions tests/internal/test_context_events_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,3 +437,44 @@ def make_another_context():
task2.join()

assert results[data_key] == "right"


@pytest.mark.parametrize("nb_threads", [2, 16, 256])
def test_core_in_threads(nb_threads):
"""Test nested contexts in multiple threads with set/get/discard and global values in main context."""
import asyncio
import random

witness = object()

async def get_set_isolated(value: str):
with core.context_with_data(value):
core.set_item("key", value)
with core.context_with_data(value):
v = f"in {value}"
core.set_item("key", v)
await asyncio.sleep(random.random())
christophe-papazian marked this conversation as resolved.
Show resolved Hide resolved
assert core.get_item("key") == v
core.discard_item("key")
assert core.get_item("key", default=witness, traverse=False) is witness
assert core.get_item("key") == value
core.get_item("global_counter")["value"] += 1
assert core.get_item("key") == value
core.discard_item("key")
assert core.get_item("key", default=witness) is witness
return witness

async def create_tasks_func():
tasks = [loop.create_task(get_set_isolated(str(i))) for i in range(nb_threads)]
await asyncio.wait(tasks)
return [task.result() for task in tasks]

with core.context_with_data("main"):
core.set_item("global_counter", {"value": 0})
loop = asyncio.new_event_loop()
assert not loop.is_closed()
res = loop.run_until_complete(create_tasks_func())
assert isinstance(res, list)
assert all(s is witness for s in res)
loop.close()
assert core.get_item("global_counter")["value"] == nb_threads
Loading