Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(tracer): core additions #10716

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions ddtrace/_trace/trace_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __getattribute__(self, name):
def _get_parameters_for_new_span_directly_from_context(ctx: core.ExecutionContext) -> Dict[str, str]:
span_kwargs = {}
for parameter_name in {"span_type", "resource", "service", "child_of", "activate"}:
parameter_value = ctx.get_item(parameter_name, traverse=False)
parameter_value = ctx.get_local_item(parameter_name)
if parameter_value:
span_kwargs[parameter_name] = parameter_value
return span_kwargs
Expand All @@ -111,7 +111,7 @@ def _start_span(ctx: core.ExecutionContext, call_trace: bool = True, **kwargs) -
trace_utils.activate_distributed_headers(
tracer, int_config=distributed_headers_config, request_headers=ctx["distributed_headers"]
)
distributed_context = ctx.get_item("distributed_context", traverse=True)
distributed_context = ctx.get_item("distributed_context")
if distributed_context and not call_trace:
span_kwargs["child_of"] = distributed_context
span_kwargs.update(kwargs)
Expand Down Expand Up @@ -335,9 +335,11 @@ def _on_request_complete(ctx, closing_iterable, app_is_iterator):
# start flask.response span. This span will be finished after iter(result) is closed.
# start_span(child_of=...) is used to ensure correct parenting.
resp_span = middleware.tracer.start_span(
middleware._response_call_name
if hasattr(middleware, "_response_call_name")
else middleware._response_span_name,
(
middleware._response_call_name
if hasattr(middleware, "_response_call_name")
else middleware._response_span_name
),
child_of=req_span,
activate=True,
)
Expand Down
2 changes: 1 addition & 1 deletion ddtrace/appsec/_iast/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def on_span_finish(self, span: Span):

span.set_metric(IAST.ENABLED, 1.0)

report_data: IastSpanReporter = core.get_item(IAST.CONTEXT_KEY, span=span) # type: ignore
report_data: IastSpanReporter = core.get_item(IAST.CONTEXT_KEY, span=span)

if report_data:
report_data.build_and_scrub_value_parts()
Expand Down
94 changes: 56 additions & 38 deletions ddtrace/internal/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def _on_jsonify_context_started_flask(ctx):

The names of these events follow the pattern ``context.[started|ended].<context_name>``.
"""

from contextlib import contextmanager
import logging
import sys
Expand All @@ -111,6 +112,7 @@ def _on_jsonify_context_started_flask(ctx):
from typing import Optional # noqa:F401
from typing import Tuple # noqa:F401

from ddtrace._trace.span import Span # noqa:F401
from ddtrace.vendor.debtcollector import deprecate

from ..utils.deprecations import DDTraceDeprecationWarning
Expand All @@ -124,10 +126,6 @@ def _on_jsonify_context_started_flask(ctx):
from .event_hub import reset as reset_listeners # noqa:F401


if TYPE_CHECKING:
from ddtrace._trace.span import Span # noqa:F401


try:
import contextvars
except ImportError:
Expand All @@ -137,7 +135,6 @@ def _on_jsonify_context_started_flask(ctx):
log = logging.getLogger(__name__)


_CURRENT_CONTEXT = None
ROOT_CONTEXT_ID = "__root"
SPAN_DEPRECATION_MESSAGE = (
"The 'span' keyword argument on ExecutionContext methods is deprecated and will be removed in a future version."
Expand Down Expand Up @@ -168,15 +165,17 @@ class ExecutionContext:
__slots__ = ["identifier", "_data", "_parents", "_span", "_token"]

def __init__(self, identifier, parent=None, span=None, **kwargs):
_deprecate_span_kwarg(span)
if span is not None:
_deprecate_span_kwarg(span)
self.identifier = identifier
self._data = {}
self._parents = []
self._span = span
if parent is not None:
self.addParent(parent)
self._data.update(kwargs)
if self._span is None and _CURRENT_CONTEXT is not None:

if self._span is None and "_CURRENT_CONTEXT" in globals():
self._token = _CURRENT_CONTEXT.set(self)
dispatch("context.started.%s" % self.identifier, (self,))
dispatch("context.started.start_span.%s" % self.identifier, (self,))
Expand Down Expand Up @@ -225,42 +224,49 @@ def context_with_data(cls, identifier, parent=None, span=None, **kwargs):
finally:
new_context.end()

def get_item(self, data_key: str, default: Optional[Any] = None, traverse: Optional[bool] = True) -> Any:
def get_item(current, data_key: str, default: Optional[Any] = None) -> Any:
christophe-papazian marked this conversation as resolved.
Show resolved Hide resolved
# NB mimic the behavior of `ddtrace.internal._context` by doing lazy inheritance
current = self
while current is not None:
if data_key in current._data:
return current._data.get(data_key)
if not traverse:
break
current = current.parent
return default

def get_local_item(self, data_key: str, default: Optional[Any] = None) -> Any:
return self._data.get(data_key, default)

def __getitem__(self, key: str):
value = self.get_item(key)
if value is None and key not in self._data:
raise KeyError
return value

def get_items(self, data_keys):
# type: (List[str]) -> Optional[Any]
def get_items(self, data_keys: List[str]) -> List[Optional[Any]]:
return [self.get_item(key) for key in data_keys]

def set_item(self, data_key, data_value):
# type: (str, Optional[Any]) -> None
def set_item(self, data_key: str, data_value: Optional[Any]) -> None:
self._data[data_key] = data_value

def set_safe(self, data_key, data_value):
# type: (str, Optional[Any]) -> None
def set_safe(self, data_key: str, data_value: Optional[Any]) -> None:
if data_key in self._data:
raise ValueError("Cannot overwrite ExecutionContext data key '%s'", data_key)
return self.set_item(data_key, data_value)

def set_items(self, keys_values):
# type: (Dict[str, Optional[Any]]) -> None
def set_items(self, keys_values: Dict[str, Optional[Any]]) -> None:
for data_key, data_value in keys_values.items():
self.set_item(data_key, data_value)

def discard_item(current, data_key: str) -> None:
christophe-papazian marked this conversation as resolved.
Show resolved Hide resolved
# NB mimic the behavior of `ddtrace.internal._context` by doing lazy inheritance
while current is not None:
if data_key in current._data:
del current._data[data_key]
return
current = current.parent

def discard_local_item(self, data_key: str) -> None:
self._data.pop(data_key, None)

def root(self):
if self.identifier == ROOT_CONTEXT_ID:
return self
Expand All @@ -276,56 +282,68 @@ def __getattr__(name):
raise AttributeError


_CURRENT_CONTEXT = contextvars.ContextVar("ExecutionContext_var", default=ExecutionContext(ROOT_CONTEXT_ID))
_CONTEXT_CLASS = ExecutionContext


def _reset_context():
"""private function to reset the context. Only used in testing"""
global _CURRENT_CONTEXT
_CURRENT_CONTEXT = contextvars.ContextVar("ExecutionContext_var", default=ExecutionContext(ROOT_CONTEXT_ID))


_reset_context()
_CONTEXT_CLASS = ExecutionContext


def context_with_data(identifier, parent=None, **kwargs):
return _CONTEXT_CLASS.context_with_data(identifier, parent=(parent or _CURRENT_CONTEXT.get()), **kwargs)


def get_item(data_key, span=None):
# type: (str, Optional[Span]) -> Optional[Any]
def get_root() -> ExecutionContext:
return _CURRENT_CONTEXT.get().root()


def get_item(data_key: str, span: Optional[Span] = None) -> Any:
_deprecate_span_kwarg(span)
if span is not None and span._local_root is not None:
return span._local_root._get_ctx_item(data_key)
else:
return _CURRENT_CONTEXT.get().get_item(data_key) # type: ignore
return _CURRENT_CONTEXT.get().get_item(data_key)


def get_items(data_keys, span=None):
# type: (List[str], Optional[Span]) -> Optional[Any]
def get_local_item(data_key: str, span: Optional[Span] = None) -> Any:
return _CURRENT_CONTEXT.get().get_local_item(data_key)


def get_items(data_keys: List[str], span: Optional[Span] = None) -> List[Optional[Any]]:
_deprecate_span_kwarg(span)
if span is not None and span._local_root is not None:
return [span._local_root._get_ctx_item(key) for key in data_keys]
else:
return _CURRENT_CONTEXT.get().get_items(data_keys) # type: ignore
return _CURRENT_CONTEXT.get().get_items(data_keys)


def set_safe(data_key, data_value):
# type: (str, Optional[Any]) -> None
_CURRENT_CONTEXT.get().set_safe(data_key, data_value) # type: ignore
def set_safe(data_key: str, data_value: Optional[Any]) -> None:
_CURRENT_CONTEXT.get().set_safe(data_key, data_value)


# NB Don't call these set_* functions from `ddtrace.contrib`, only from product code!
def set_item(data_key, data_value, span=None):
# type: (str, Optional[Any], Optional[Span]) -> None
def set_item(data_key: str, data_value: Optional[Any], span: Optional[Span] = None) -> None:
_deprecate_span_kwarg(span)
if span is not None and span._local_root is not None:
span._local_root._set_ctx_item(data_key, data_value)
else:
_CURRENT_CONTEXT.get().set_item(data_key, data_value) # type: ignore
_CURRENT_CONTEXT.get().set_item(data_key, data_value)


def set_items(keys_values, span=None):
# type: (Dict[str, Optional[Any]], Optional[Span]) -> None
def set_items(keys_values: Dict[str, Optional[Any]], span: Optional[Span] = None) -> None:
_deprecate_span_kwarg(span)
if span is not None and span._local_root is not None:
span._local_root._set_ctx_items(keys_values)
else:
_CURRENT_CONTEXT.get().set_items(keys_values) # type: ignore
_CURRENT_CONTEXT.get().set_items(keys_values)


def discard_item(data_key: str) -> None:
_CURRENT_CONTEXT.get().discard_item(data_key)


def discard_local_item(data_key: str) -> None:
_CURRENT_CONTEXT.get().discard_local_item(data_key)
41 changes: 41 additions & 0 deletions tests/internal/test_context_events_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,3 +437,44 @@ def make_another_context():
task2.join()

assert results[data_key] == "right"


@pytest.mark.parametrize("nb_threads", [2, 16, 256])
def test_core_in_threads(nb_threads):
"""Test nested contexts in multiple threads with set/get/discard and global values in main context."""
import asyncio
import random

witness = object()

async def get_set_isolated(value: str):
with core.context_with_data(value):
core.set_item("key", value)
with core.context_with_data(value):
v = f"in {value}"
core.set_item("key", v)
await asyncio.sleep(random.random())
christophe-papazian marked this conversation as resolved.
Show resolved Hide resolved
assert core.get_item("key") == v
core.discard_item("key")
assert core.get_local_item("key") is None
assert core.get_item("key") == value
core.get_item("global_counter")["value"] += 1
assert core.get_item("key") == value
core.discard_item("key")
assert core.get_item("key") is None
return witness

async def create_tasks_func():
tasks = [loop.create_task(get_set_isolated(str(i))) for i in range(nb_threads)]
await asyncio.wait(tasks)
return [task.result() for task in tasks]

with core.context_with_data("main"):
core.set_item("global_counter", {"value": 0})
loop = asyncio.new_event_loop()
assert not loop.is_closed()
res = loop.run_until_complete(create_tasks_func())
assert isinstance(res, list)
assert all(s is witness for s in res)
loop.close()
assert core.get_item("global_counter")["value"] == nb_threads
Loading