diff --git a/python_modules/dagster/dagster/_core/execution/context/compute.py b/python_modules/dagster/dagster/_core/execution/context/compute.py index edaccaa2f041c..251bf4b2c8625 100644 --- a/python_modules/dagster/dagster/_core/execution/context/compute.py +++ b/python_modules/dagster/dagster/_core/execution/context/compute.py @@ -1,4 +1,6 @@ from abc import ABC, ABCMeta, abstractmethod +from contextlib import contextmanager +from contextvars import ContextVar from inspect import _empty as EmptyAnnotation from typing import ( AbstractSet, @@ -308,7 +310,10 @@ def my_asset(context: AssetExecutionContext): """ return self._step_execution_context.partition_key - @deprecated(breaking_version="2.0", additional_warn_text="Use `partition_key_range` instead.") + @deprecated( + breaking_version="2.0", + additional_warn_text="Use `partition_key_range` instead.", + ) @public @property def asset_partition_key_range(self) -> PartitionKeyRange: @@ -1300,15 +1305,24 @@ def typed_event_stream_error_message(self) -> Optional[str]: def set_requires_typed_event_stream(self, *, error_message: Optional[str] = None) -> None: self._step_execution_context.set_requires_typed_event_stream(error_message=error_message) + @staticmethod + def get_current() -> Optional["OpExecutionContext"]: + return _current_op_execution_context.get() + class AssetExecutionContext(OpExecutionContext): def __init__(self, step_execution_context: StepExecutionContext): super().__init__(step_execution_context=step_execution_context) + @staticmethod + def get_current() -> Optional["AssetExecutionContext"]: + return _current_asset_execution_context.get() + -def build_execution_context( +@contextmanager +def enter_execution_context( step_context: StepExecutionContext, -) -> Union[OpExecutionContext, AssetExecutionContext]: +) -> Iterator[Union[OpExecutionContext, AssetExecutionContext]]: """Get the correct context based on the type of step (op or asset) and the user provided context type annotation. Follows these rules. @@ -1359,16 +1373,37 @@ def build_execution_context( " OpExecutionContext, or left blank." ) - if context_annotation is EmptyAnnotation: - # if no type hint has been given, default to: - # * AssetExecutionContext for sda steps not in graph-backed assets, and asset_checks - # * OpExecutionContext for non sda steps - # * OpExecutionContext for ops in graph-backed assets - if is_asset_check: - return AssetExecutionContext(step_context) - if is_op_in_graph_asset or not is_sda_step: - return OpExecutionContext(step_context) - return AssetExecutionContext(step_context) - if context_annotation is AssetExecutionContext: - return AssetExecutionContext(step_context) - return OpExecutionContext(step_context) + asset_ctx = AssetExecutionContext(step_context) + op_ctx = OpExecutionContext(step_context) + + asset_token = _current_asset_execution_context.set(asset_ctx) + op_token = _current_op_execution_context.set(op_ctx) + + try: + if context_annotation is EmptyAnnotation: + # if no type hint has been given, default to: + # * AssetExecutionContext for sda steps not in graph-backed assets, and asset_checks + # * OpExecutionContext for non sda steps + # * OpExecutionContext for ops in graph-backed assets + if is_asset_check: + yield asset_ctx + elif is_op_in_graph_asset or not is_sda_step: + yield op_ctx + else: + yield asset_ctx + elif context_annotation is AssetExecutionContext: + yield asset_ctx + else: + yield op_ctx + finally: + _current_asset_execution_context.reset(asset_token) + _current_op_execution_context.reset(op_token) + + +_current_op_execution_context: ContextVar[Optional[OpExecutionContext]] = ContextVar( + "_current_op_execution_context", default=None +) + +_current_asset_execution_context: ContextVar[Optional[AssetExecutionContext]] = ContextVar( + "_current_asset_execution_context", default=None +) diff --git a/python_modules/dagster/dagster/_core/execution/plan/compute.py b/python_modules/dagster/dagster/_core/execution/plan/compute.py index 8b19e99eda55e..2d67fb38bab2e 100644 --- a/python_modules/dagster/dagster/_core/execution/plan/compute.py +++ b/python_modules/dagster/dagster/_core/execution/plan/compute.py @@ -3,7 +3,6 @@ from typing import ( Any, AsyncIterator, - Callable, Iterator, List, Mapping, @@ -31,9 +30,15 @@ from dagster._core.definitions.asset_layer import AssetLayer from dagster._core.definitions.op_definition import OpComputeFunction from dagster._core.definitions.result import MaterializeResult -from dagster._core.errors import DagsterExecutionStepExecutionError, DagsterInvariantViolationError +from dagster._core.errors import ( + DagsterExecutionStepExecutionError, + DagsterInvariantViolationError, +) from dagster._core.events import DagsterEvent -from dagster._core.execution.context.compute import build_execution_context +from dagster._core.execution.context.compute import ( + AssetExecutionContext, + OpExecutionContext, +) from dagster._core.execution.context.system import StepExecutionContext from dagster._core.system_config.objects import ResolvedRunConfig from dagster._utils import iterate_with_context @@ -57,7 +62,10 @@ def create_step_outputs( - node: Node, handle: NodeHandle, resolved_run_config: ResolvedRunConfig, asset_layer: AssetLayer + node: Node, + handle: NodeHandle, + resolved_run_config: ResolvedRunConfig, + asset_layer: AssetLayer, ) -> Sequence[StepOutput]: check.inst_param(node, "node", Node) check.inst_param(handle, "handle", NodeHandle) @@ -143,12 +151,12 @@ def gen_from_async_gen(async_gen: AsyncIterator[T]) -> Iterator[T]: def _yield_compute_results( - step_context: StepExecutionContext, inputs: Mapping[str, Any], compute_fn: Callable + step_context: StepExecutionContext, + inputs: Mapping[str, Any], + compute_fn: OpComputeFunction, + compute_context: Union[OpExecutionContext, AssetExecutionContext], ) -> Iterator[OpOutputUnion]: - check.inst_param(step_context, "step_context", StepExecutionContext) - - context = build_execution_context(step_context) - user_event_generator = compute_fn(context, inputs) + user_event_generator = compute_fn(compute_context, inputs) if isinstance(user_event_generator, Output): raise DagsterInvariantViolationError( @@ -181,27 +189,27 @@ def _yield_compute_results( ), user_event_generator, ): - if context.has_events(): - yield from context.consume_events() + if compute_context.has_events(): + yield from compute_context.consume_events() yield _validate_event(event, step_context) - if context.has_events(): - yield from context.consume_events() + if compute_context.has_events(): + yield from compute_context.consume_events() def execute_core_compute( - step_context: StepExecutionContext, inputs: Mapping[str, Any], compute_fn: OpComputeFunction + step_context: StepExecutionContext, + inputs: Mapping[str, Any], + compute_fn: OpComputeFunction, + compute_context: Union[OpExecutionContext, AssetExecutionContext], ) -> Iterator[OpOutputUnion]: """Execute the user-specified compute for the op. Wrap in an error boundary and do all relevant logging and metrics tracking. """ - check.inst_param(step_context, "step_context", StepExecutionContext) - check.mapping_param(inputs, "inputs", key_type=str) - step = step_context.step emitted_result_names = set() - for step_output in _yield_compute_results(step_context, inputs, compute_fn): + for step_output in _yield_compute_results(step_context, inputs, compute_fn, compute_context): yield step_output if isinstance(step_output, (DynamicOutput, Output)): emitted_result_names.add(step_output.output_name) diff --git a/python_modules/dagster/dagster/_core/execution/plan/execute_step.py b/python_modules/dagster/dagster/_core/execution/plan/execute_step.py index cceb87ee738aa..78dd4fa3ee561 100644 --- a/python_modules/dagster/dagster/_core/execution/plan/execute_step.py +++ b/python_modules/dagster/dagster/_core/execution/plan/execute_step.py @@ -57,6 +57,7 @@ user_code_error_boundary, ) from dagster._core.events import DagsterEvent +from dagster._core.execution.context.compute import enter_execution_context from dagster._core.execution.context.output import OutputContext from dagster._core.execution.context.system import StepExecutionContext, TypeCheckContext from dagster._core.execution.plan.compute import execute_core_compute @@ -462,13 +463,14 @@ def core_dagster_event_sequence_for_step( else: core_gen = step_context.op_def.compute_fn - with time_execution_scope() as timer_result: - user_event_sequence = check.generator( - execute_core_compute( - step_context, - inputs, - core_gen, - ) + with time_execution_scope() as timer_result, enter_execution_context( + step_context + ) as compute_context: + user_event_sequence = execute_core_compute( + step_context, + inputs, + core_gen, + compute_context, ) # It is important for this loop to be indented within the diff --git a/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_context.py b/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_context.py index c0214049f5d31..5d7d661098b39 100644 --- a/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_context.py +++ b/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_context.py @@ -405,3 +405,18 @@ def the_op(context: int): @asset def the_asset(context: int): pass + + +def test_get_context(): + ctx = OpExecutionContext.get_current() + assert ctx is None, ctx.job_name + + @op + def o(context): + assert context == OpExecutionContext.get_current() + + @job + def j(): + o() + + assert j.execute_in_process().success