diff --git a/python_modules/dagster/dagster/_core/execution/context/compute.py b/python_modules/dagster/dagster/_core/execution/context/compute.py index fb1aa01ac2d0f..143a4ff28d8de 100644 --- a/python_modules/dagster/dagster/_core/execution/context/compute.py +++ b/python_modules/dagster/dagster/_core/execution/context/compute.py @@ -1,4 +1,6 @@ from abc import ABC, ABCMeta, abstractmethod +from contextlib import contextmanager +from contextvars import ContextVar from inspect import _empty as EmptyAnnotation from typing import ( AbstractSet, @@ -48,6 +50,7 @@ from dagster._core.instance import DagsterInstance from dagster._core.log_manager import DagsterLogManager from dagster._core.storage.dagster_run import DagsterRun +from dagster._utils.cached_method import cached_method from dagster._utils.forked_pdb import ForkedPdb from dagster._utils.warnings import ( deprecation_warning, @@ -1334,15 +1337,34 @@ def typed_event_stream_error_message(self) -> Optional[str]: def set_requires_typed_event_stream(self, *, error_message: Optional[str] = None) -> None: self._step_execution_context.set_requires_typed_event_stream(error_message=error_message) + @staticmethod + def get() -> "OpExecutionContext": + ctx = _current_asset_execution_context.get() + if ctx is None: + raise DagsterInvariantViolationError("No current OpExecutionContext in scope.") + return ctx.get_op_execution_context() + class AssetExecutionContext(OpExecutionContext): def __init__(self, step_execution_context: StepExecutionContext): super().__init__(step_execution_context=step_execution_context) + @staticmethod + def get() -> "AssetExecutionContext": + ctx = _current_asset_execution_context.get() + if ctx is None: + raise DagsterInvariantViolationError("No current AssetExecutionContext in scope.") + return ctx + + @cached_method + def get_op_execution_context(self) -> "OpExecutionContext": + return OpExecutionContext(self._step_execution_context) -def build_execution_context( + +@contextmanager +def enter_execution_context( step_context: StepExecutionContext, -) -> Union[OpExecutionContext, AssetExecutionContext]: +) -> Iterator[Union[OpExecutionContext, AssetExecutionContext]]: """Get the correct context based on the type of step (op or asset) and the user provided context type annotation. Follows these rules. @@ -1393,16 +1415,30 @@ def build_execution_context( " OpExecutionContext, or left blank." ) - if context_annotation is EmptyAnnotation: - # if no type hint has been given, default to: - # * AssetExecutionContext for sda steps not in graph-backed assets, and asset_checks - # * OpExecutionContext for non sda steps - # * OpExecutionContext for ops in graph-backed assets - if is_asset_check: - return AssetExecutionContext(step_context) - if is_op_in_graph_asset or not is_sda_step: - return OpExecutionContext(step_context) - return AssetExecutionContext(step_context) - if context_annotation is AssetExecutionContext: - return AssetExecutionContext(step_context) - return OpExecutionContext(step_context) + # Structured assuming upcoming changes to make AssetExecutionContext contain an OpExecutionContext + asset_ctx = AssetExecutionContext(step_context) + asset_token = _current_asset_execution_context.set(asset_ctx) + + try: + if context_annotation is EmptyAnnotation: + # if no type hint has been given, default to: + # * AssetExecutionContext for sda steps not in graph-backed assets, and asset_checks + # * OpExecutionContext for non sda steps + # * OpExecutionContext for ops in graph-backed assets + if is_asset_check: + yield asset_ctx + elif is_op_in_graph_asset or not is_sda_step: + yield asset_ctx.get_op_execution_context() + else: + yield asset_ctx + elif context_annotation is AssetExecutionContext: + yield asset_ctx + else: + yield asset_ctx.get_op_execution_context() + finally: + _current_asset_execution_context.reset(asset_token) + + +_current_asset_execution_context: ContextVar[Optional[AssetExecutionContext]] = ContextVar( + "_current_asset_execution_context", default=None +) diff --git a/python_modules/dagster/dagster/_core/execution/plan/compute.py b/python_modules/dagster/dagster/_core/execution/plan/compute.py index 8b19e99eda55e..2d67fb38bab2e 100644 --- a/python_modules/dagster/dagster/_core/execution/plan/compute.py +++ b/python_modules/dagster/dagster/_core/execution/plan/compute.py @@ -3,7 +3,6 @@ from typing import ( Any, AsyncIterator, - Callable, Iterator, List, Mapping, @@ -31,9 +30,15 @@ from dagster._core.definitions.asset_layer import AssetLayer from dagster._core.definitions.op_definition import OpComputeFunction from dagster._core.definitions.result import MaterializeResult -from dagster._core.errors import DagsterExecutionStepExecutionError, DagsterInvariantViolationError +from dagster._core.errors import ( + DagsterExecutionStepExecutionError, + DagsterInvariantViolationError, +) from dagster._core.events import DagsterEvent -from dagster._core.execution.context.compute import build_execution_context +from dagster._core.execution.context.compute import ( + AssetExecutionContext, + OpExecutionContext, +) from dagster._core.execution.context.system import StepExecutionContext from dagster._core.system_config.objects import ResolvedRunConfig from dagster._utils import iterate_with_context @@ -57,7 +62,10 @@ def create_step_outputs( - node: Node, handle: NodeHandle, resolved_run_config: ResolvedRunConfig, asset_layer: AssetLayer + node: Node, + handle: NodeHandle, + resolved_run_config: ResolvedRunConfig, + asset_layer: AssetLayer, ) -> Sequence[StepOutput]: check.inst_param(node, "node", Node) check.inst_param(handle, "handle", NodeHandle) @@ -143,12 +151,12 @@ def gen_from_async_gen(async_gen: AsyncIterator[T]) -> Iterator[T]: def _yield_compute_results( - step_context: StepExecutionContext, inputs: Mapping[str, Any], compute_fn: Callable + step_context: StepExecutionContext, + inputs: Mapping[str, Any], + compute_fn: OpComputeFunction, + compute_context: Union[OpExecutionContext, AssetExecutionContext], ) -> Iterator[OpOutputUnion]: - check.inst_param(step_context, "step_context", StepExecutionContext) - - context = build_execution_context(step_context) - user_event_generator = compute_fn(context, inputs) + user_event_generator = compute_fn(compute_context, inputs) if isinstance(user_event_generator, Output): raise DagsterInvariantViolationError( @@ -181,27 +189,27 @@ def _yield_compute_results( ), user_event_generator, ): - if context.has_events(): - yield from context.consume_events() + if compute_context.has_events(): + yield from compute_context.consume_events() yield _validate_event(event, step_context) - if context.has_events(): - yield from context.consume_events() + if compute_context.has_events(): + yield from compute_context.consume_events() def execute_core_compute( - step_context: StepExecutionContext, inputs: Mapping[str, Any], compute_fn: OpComputeFunction + step_context: StepExecutionContext, + inputs: Mapping[str, Any], + compute_fn: OpComputeFunction, + compute_context: Union[OpExecutionContext, AssetExecutionContext], ) -> Iterator[OpOutputUnion]: """Execute the user-specified compute for the op. Wrap in an error boundary and do all relevant logging and metrics tracking. """ - check.inst_param(step_context, "step_context", StepExecutionContext) - check.mapping_param(inputs, "inputs", key_type=str) - step = step_context.step emitted_result_names = set() - for step_output in _yield_compute_results(step_context, inputs, compute_fn): + for step_output in _yield_compute_results(step_context, inputs, compute_fn, compute_context): yield step_output if isinstance(step_output, (DynamicOutput, Output)): emitted_result_names.add(step_output.output_name) diff --git a/python_modules/dagster/dagster/_core/execution/plan/execute_step.py b/python_modules/dagster/dagster/_core/execution/plan/execute_step.py index cceb87ee738aa..78dd4fa3ee561 100644 --- a/python_modules/dagster/dagster/_core/execution/plan/execute_step.py +++ b/python_modules/dagster/dagster/_core/execution/plan/execute_step.py @@ -57,6 +57,7 @@ user_code_error_boundary, ) from dagster._core.events import DagsterEvent +from dagster._core.execution.context.compute import enter_execution_context from dagster._core.execution.context.output import OutputContext from dagster._core.execution.context.system import StepExecutionContext, TypeCheckContext from dagster._core.execution.plan.compute import execute_core_compute @@ -462,13 +463,14 @@ def core_dagster_event_sequence_for_step( else: core_gen = step_context.op_def.compute_fn - with time_execution_scope() as timer_result: - user_event_sequence = check.generator( - execute_core_compute( - step_context, - inputs, - core_gen, - ) + with time_execution_scope() as timer_result, enter_execution_context( + step_context + ) as compute_context: + user_event_sequence = execute_core_compute( + step_context, + inputs, + core_gen, + compute_context, ) # It is important for this loop to be indented within the diff --git a/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_context.py b/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_context.py index c0214049f5d31..6bd4da58fb3d2 100644 --- a/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_context.py +++ b/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_context.py @@ -23,7 +23,7 @@ from dagster._core.definitions.asset_checks import build_asset_with_blocking_check from dagster._core.definitions.job_definition import JobDefinition from dagster._core.definitions.op_definition import OpDefinition -from dagster._core.errors import DagsterInvalidDefinitionError +from dagster._core.errors import DagsterInvalidDefinitionError, DagsterInvariantViolationError from dagster._core.storage.dagster_run import DagsterRun @@ -405,3 +405,24 @@ def the_op(context: int): @asset def the_asset(context: int): pass + + +def test_get_context(): + with pytest.raises(DagsterInvariantViolationError): + OpExecutionContext.get() + + @op + def o(context): + assert context == OpExecutionContext.get() + + @job + def j(): + o() + + assert j.execute_in_process().success + + @asset + def a(context: AssetExecutionContext): + assert context == AssetExecutionContext.get() + + assert materialize([a]).success