From d0752c81ccc92ccfd3fc152dcfbd90901ec416bf Mon Sep 17 00:00:00 2001 From: alangenfeld Date: Mon, 26 Jun 2023 12:08:02 -0500 Subject: [PATCH] add indirect execution context access --- .../_core/execution/context/compute.py | 28 ++++++++ .../dagster/_core/execution/plan/compute.py | 72 +++++++++---------- .../execution_tests/test_context.py | 15 ++++ 3 files changed, 79 insertions(+), 36 deletions(-) diff --git a/python_modules/dagster/dagster/_core/execution/context/compute.py b/python_modules/dagster/dagster/_core/execution/context/compute.py index 7b5d7d746f174..d6bd5ac40f86f 100644 --- a/python_modules/dagster/dagster/_core/execution/context/compute.py +++ b/python_modules/dagster/dagster/_core/execution/context/compute.py @@ -1,4 +1,6 @@ from abc import ABC, abstractmethod +from contextlib import contextmanager +from contextvars import ContextVar from typing import ( AbstractSet, Any, @@ -48,6 +50,8 @@ from .system import StepExecutionContext +8 + class AbstractComputeExecutionContext(ABC): """Base class for op context implemented by OpExecutionContext and DagstermillExecutionContext.""" @@ -657,3 +661,27 @@ def set_data_version(self, asset_key: AssetKey, data_version: DataVersion) -> No # * having ops in a graph that form a graph backed asset # so we have a single type that users can call by their preferred name where appropriate AssetExecutionContext: TypeAlias = OpExecutionContext + +_current_context: ContextVar[Optional[OpExecutionContext]] = ContextVar( + "execution_context", default=None +) + + +@contextmanager +def enter_execution_context( + step_context: StepExecutionContext, +) -> Iterator[OpExecutionContext]: + if step_context.is_sda_step: + ctx = AssetExecutionContext(step_context) + else: + ctx = OpExecutionContext(step_context) + + token = _current_context.set(ctx) + try: + yield ctx + finally: + _current_context.reset(token) + + +def get_execution_context(): + return _current_context.get() diff --git a/python_modules/dagster/dagster/_core/execution/plan/compute.py b/python_modules/dagster/dagster/_core/execution/plan/compute.py index 726cc858e7505..b95a272e9e949 100644 --- a/python_modules/dagster/dagster/_core/execution/plan/compute.py +++ b/python_modules/dagster/dagster/_core/execution/plan/compute.py @@ -29,7 +29,7 @@ from dagster._core.definitions.op_definition import OpComputeFunction from dagster._core.errors import DagsterExecutionStepExecutionError, DagsterInvariantViolationError from dagster._core.events import DagsterEvent -from dagster._core.execution.context.compute import OpExecutionContext +from dagster._core.execution.context.compute import enter_execution_context from dagster._core.execution.context.system import StepExecutionContext from dagster._core.system_config.objects import ResolvedRunConfig from dagster._utils import iterate_with_context @@ -135,46 +135,46 @@ def _yield_compute_results( ) -> Iterator[OpOutputUnion]: check.inst_param(step_context, "step_context", StepExecutionContext) - context = OpExecutionContext(step_context) - user_event_generator = compute_fn(context, inputs) - - if isinstance(user_event_generator, Output): - raise DagsterInvariantViolationError( - ( - "Compute function for {described_node} returned an Output rather than " - "yielding it. The compute_fn of the {node_type} must yield " - "its results" - ).format( - described_node=step_context.describe_op(), - node_type=step_context.op_def.node_type_str, + with enter_execution_context(step_context) as context: + user_event_generator = compute_fn(context, inputs) + + if isinstance(user_event_generator, Output): + raise DagsterInvariantViolationError( + ( + "Compute function for {described_node} returned an Output rather than " + "yielding it. The compute_fn of the {node_type} must yield " + "its results" + ).format( + described_node=step_context.describe_op(), + node_type=step_context.op_def.node_type_str, + ) ) - ) - - if user_event_generator is None: - return - - if inspect.isasyncgen(user_event_generator): - user_event_generator = gen_from_async_gen(user_event_generator) - op_label = step_context.describe_op() + if user_event_generator is None: + return + + if inspect.isasyncgen(user_event_generator): + user_event_generator = gen_from_async_gen(user_event_generator) + + op_label = step_context.describe_op() + + for event in iterate_with_context( + lambda: op_execution_error_boundary( + DagsterExecutionStepExecutionError, + msg_fn=lambda: f"Error occurred while executing {op_label}:", + step_context=step_context, + step_key=step_context.step.key, + op_def_name=step_context.op_def.name, + op_name=step_context.op.name, + ), + user_event_generator, + ): + if context.has_events(): + yield from context.consume_events() + yield _validate_event(event, step_context) - for event in iterate_with_context( - lambda: op_execution_error_boundary( - DagsterExecutionStepExecutionError, - msg_fn=lambda: f"Error occurred while executing {op_label}:", - step_context=step_context, - step_key=step_context.step.key, - op_def_name=step_context.op_def.name, - op_name=step_context.op.name, - ), - user_event_generator, - ): if context.has_events(): yield from context.consume_events() - yield _validate_event(event, step_context) - - if context.has_events(): - yield from context.consume_events() def execute_core_compute( diff --git a/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_context.py b/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_context.py index bd4d523dfa6d5..c92c825c1c2c5 100644 --- a/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_context.py +++ b/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_context.py @@ -2,6 +2,7 @@ from dagster import OpExecutionContext, job, op from dagster._core.definitions.job_definition import JobDefinition from dagster._core.definitions.op_definition import OpDefinition +from dagster._core.execution.context.compute import get_execution_context from dagster._core.storage.dagster_run import DagsterRun @@ -20,3 +21,17 @@ def foo(): ctx_op() assert foo.execute_in_process().success + + +def test_get_context(): + assert get_execution_context() is None + + @op + def o(context): + assert context == get_execution_context() + + @job + def j(): + o() + + assert j.execute_in_process().success