diff --git a/python_modules/dagster/dagster/_core/execution/context/compute.py b/python_modules/dagster/dagster/_core/execution/context/compute.py index edaccaa2f041c..f607eea6ca56b 100644 --- a/python_modules/dagster/dagster/_core/execution/context/compute.py +++ b/python_modules/dagster/dagster/_core/execution/context/compute.py @@ -1,4 +1,6 @@ from abc import ABC, ABCMeta, abstractmethod +from contextlib import contextmanager +from contextvars import ContextVar from inspect import _empty as EmptyAnnotation from typing import ( AbstractSet, @@ -308,7 +310,10 @@ def my_asset(context: AssetExecutionContext): """ return self._step_execution_context.partition_key - @deprecated(breaking_version="2.0", additional_warn_text="Use `partition_key_range` instead.") + @deprecated( + breaking_version="2.0", + additional_warn_text="Use `partition_key_range` instead.", + ) @public @property def asset_partition_key_range(self) -> PartitionKeyRange: @@ -1372,3 +1377,25 @@ def build_execution_context( if context_annotation is AssetExecutionContext: return AssetExecutionContext(step_context) return OpExecutionContext(step_context) + + +_current_context: ContextVar[Optional[OpExecutionContext]] = ContextVar( + "execution_context", default=None +) + + +@contextmanager +def enter_execution_context( + step_context: StepExecutionContext, +) -> Iterator[OpExecutionContext]: + ctx = build_execution_context(step_context) + + token = _current_context.set(ctx) + try: + yield ctx + finally: + _current_context.reset(token) + + +def get_execution_context(): + return _current_context.get() diff --git a/python_modules/dagster/dagster/_core/execution/plan/compute.py b/python_modules/dagster/dagster/_core/execution/plan/compute.py index 8b19e99eda55e..8aa2131262f42 100644 --- a/python_modules/dagster/dagster/_core/execution/plan/compute.py +++ b/python_modules/dagster/dagster/_core/execution/plan/compute.py @@ -31,9 +31,12 @@ from dagster._core.definitions.asset_layer import AssetLayer from dagster._core.definitions.op_definition import OpComputeFunction from dagster._core.definitions.result import MaterializeResult -from dagster._core.errors import DagsterExecutionStepExecutionError, DagsterInvariantViolationError +from dagster._core.errors import ( + DagsterExecutionStepExecutionError, + DagsterInvariantViolationError, +) from dagster._core.events import DagsterEvent -from dagster._core.execution.context.compute import build_execution_context +from dagster._core.execution.context.compute import enter_execution_context from dagster._core.execution.context.system import StepExecutionContext from dagster._core.system_config.objects import ResolvedRunConfig from dagster._utils import iterate_with_context @@ -57,7 +60,10 @@ def create_step_outputs( - node: Node, handle: NodeHandle, resolved_run_config: ResolvedRunConfig, asset_layer: AssetLayer + node: Node, + handle: NodeHandle, + resolved_run_config: ResolvedRunConfig, + asset_layer: AssetLayer, ) -> Sequence[StepOutput]: check.inst_param(node, "node", Node) check.inst_param(handle, "handle", NodeHandle) @@ -147,50 +153,52 @@ def _yield_compute_results( ) -> Iterator[OpOutputUnion]: check.inst_param(step_context, "step_context", StepExecutionContext) - context = build_execution_context(step_context) - user_event_generator = compute_fn(context, inputs) + with enter_execution_context(step_context) as context: + user_event_generator = compute_fn(context, inputs) - if isinstance(user_event_generator, Output): - raise DagsterInvariantViolationError( - ( - "Compute function for {described_node} returned an Output rather than " - "yielding it. The compute_fn of the {node_type} must yield " - "its results" - ).format( - described_node=step_context.describe_op(), - node_type=step_context.op_def.node_type_str, + if isinstance(user_event_generator, Output): + raise DagsterInvariantViolationError( + ( + "Compute function for {described_node} returned an Output rather than " + "yielding it. The compute_fn of the {node_type} must yield " + "its results" + ).format( + described_node=step_context.describe_op(), + node_type=step_context.op_def.node_type_str, + ) ) - ) - if user_event_generator is None: - return + if user_event_generator is None: + return - if inspect.isasyncgen(user_event_generator): - user_event_generator = gen_from_async_gen(user_event_generator) + if inspect.isasyncgen(user_event_generator): + user_event_generator = gen_from_async_gen(user_event_generator) - op_label = step_context.describe_op() + op_label = step_context.describe_op() + + for event in iterate_with_context( + lambda: op_execution_error_boundary( + DagsterExecutionStepExecutionError, + msg_fn=lambda: f"Error occurred while executing {op_label}:", + step_context=step_context, + step_key=step_context.step.key, + op_def_name=step_context.op_def.name, + op_name=step_context.op.name, + ), + user_event_generator, + ): + if context.has_events(): + yield from context.consume_events() + yield _validate_event(event, step_context) - for event in iterate_with_context( - lambda: op_execution_error_boundary( - DagsterExecutionStepExecutionError, - msg_fn=lambda: f"Error occurred while executing {op_label}:", - step_context=step_context, - step_key=step_context.step.key, - op_def_name=step_context.op_def.name, - op_name=step_context.op.name, - ), - user_event_generator, - ): if context.has_events(): yield from context.consume_events() - yield _validate_event(event, step_context) - - if context.has_events(): - yield from context.consume_events() def execute_core_compute( - step_context: StepExecutionContext, inputs: Mapping[str, Any], compute_fn: OpComputeFunction + step_context: StepExecutionContext, + inputs: Mapping[str, Any], + compute_fn: OpComputeFunction, ) -> Iterator[OpOutputUnion]: """Execute the user-specified compute for the op. Wrap in an error boundary and do all relevant logging and metrics tracking. diff --git a/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_context.py b/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_context.py index c0214049f5d31..cf344ca4474d2 100644 --- a/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_context.py +++ b/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_context.py @@ -24,6 +24,7 @@ from dagster._core.definitions.job_definition import JobDefinition from dagster._core.definitions.op_definition import OpDefinition from dagster._core.errors import DagsterInvalidDefinitionError +from dagster._core.execution.context.compute import get_execution_context from dagster._core.storage.dagster_run import DagsterRun @@ -405,3 +406,17 @@ def the_op(context: int): @asset def the_asset(context: int): pass + + +def test_get_context(): + assert get_execution_context() is None + + @op + def o(context): + assert context == get_execution_context() + + @job + def j(): + o() + + assert j.execute_in_process().success