Skip to content

Commit

Permalink
add indirect execution context access
Browse files Browse the repository at this point in the history
  • Loading branch information
alangenfeld committed Oct 24, 2023
1 parent 4f91e20 commit d13f3fc
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 37 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from abc import ABC, ABCMeta, abstractmethod
from contextlib import contextmanager
from contextvars import ContextVar
from inspect import _empty as EmptyAnnotation
from typing import (
AbstractSet,
Expand Down Expand Up @@ -308,7 +310,10 @@ def my_asset(context: AssetExecutionContext):
"""
return self._step_execution_context.partition_key

@deprecated(breaking_version="2.0", additional_warn_text="Use `partition_key_range` instead.")
@deprecated(
breaking_version="2.0",
additional_warn_text="Use `partition_key_range` instead.",
)
@public
@property
def asset_partition_key_range(self) -> PartitionKeyRange:
Expand Down Expand Up @@ -1372,3 +1377,25 @@ def build_execution_context(
if context_annotation is AssetExecutionContext:
return AssetExecutionContext(step_context)
return OpExecutionContext(step_context)


_current_context: ContextVar[Optional[OpExecutionContext]] = ContextVar(
"execution_context", default=None
)


@contextmanager
def enter_execution_context(
step_context: StepExecutionContext,
) -> Iterator[OpExecutionContext]:
ctx = build_execution_context(step_context)

token = _current_context.set(ctx)
try:
yield ctx
finally:
_current_context.reset(token)


def get_execution_context():
return _current_context.get()
80 changes: 44 additions & 36 deletions python_modules/dagster/dagster/_core/execution/plan/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,12 @@
from dagster._core.definitions.asset_layer import AssetLayer
from dagster._core.definitions.op_definition import OpComputeFunction
from dagster._core.definitions.result import MaterializeResult
from dagster._core.errors import DagsterExecutionStepExecutionError, DagsterInvariantViolationError
from dagster._core.errors import (
DagsterExecutionStepExecutionError,
DagsterInvariantViolationError,
)
from dagster._core.events import DagsterEvent
from dagster._core.execution.context.compute import build_execution_context
from dagster._core.execution.context.compute import enter_execution_context
from dagster._core.execution.context.system import StepExecutionContext
from dagster._core.system_config.objects import ResolvedRunConfig
from dagster._utils import iterate_with_context
Expand All @@ -57,7 +60,10 @@


def create_step_outputs(
node: Node, handle: NodeHandle, resolved_run_config: ResolvedRunConfig, asset_layer: AssetLayer
node: Node,
handle: NodeHandle,
resolved_run_config: ResolvedRunConfig,
asset_layer: AssetLayer,
) -> Sequence[StepOutput]:
check.inst_param(node, "node", Node)
check.inst_param(handle, "handle", NodeHandle)
Expand Down Expand Up @@ -147,50 +153,52 @@ def _yield_compute_results(
) -> Iterator[OpOutputUnion]:
check.inst_param(step_context, "step_context", StepExecutionContext)

context = build_execution_context(step_context)
user_event_generator = compute_fn(context, inputs)
with enter_execution_context(step_context) as context:
user_event_generator = compute_fn(context, inputs)

if isinstance(user_event_generator, Output):
raise DagsterInvariantViolationError(
(
"Compute function for {described_node} returned an Output rather than "
"yielding it. The compute_fn of the {node_type} must yield "
"its results"
).format(
described_node=step_context.describe_op(),
node_type=step_context.op_def.node_type_str,
if isinstance(user_event_generator, Output):
raise DagsterInvariantViolationError(
(
"Compute function for {described_node} returned an Output rather than "
"yielding it. The compute_fn of the {node_type} must yield "
"its results"
).format(
described_node=step_context.describe_op(),
node_type=step_context.op_def.node_type_str,
)
)
)

if user_event_generator is None:
return
if user_event_generator is None:
return

if inspect.isasyncgen(user_event_generator):
user_event_generator = gen_from_async_gen(user_event_generator)
if inspect.isasyncgen(user_event_generator):
user_event_generator = gen_from_async_gen(user_event_generator)

op_label = step_context.describe_op()
op_label = step_context.describe_op()

for event in iterate_with_context(
lambda: op_execution_error_boundary(
DagsterExecutionStepExecutionError,
msg_fn=lambda: f"Error occurred while executing {op_label}:",
step_context=step_context,
step_key=step_context.step.key,
op_def_name=step_context.op_def.name,
op_name=step_context.op.name,
),
user_event_generator,
):
if context.has_events():
yield from context.consume_events()
yield _validate_event(event, step_context)

for event in iterate_with_context(
lambda: op_execution_error_boundary(
DagsterExecutionStepExecutionError,
msg_fn=lambda: f"Error occurred while executing {op_label}:",
step_context=step_context,
step_key=step_context.step.key,
op_def_name=step_context.op_def.name,
op_name=step_context.op.name,
),
user_event_generator,
):
if context.has_events():
yield from context.consume_events()
yield _validate_event(event, step_context)

if context.has_events():
yield from context.consume_events()


def execute_core_compute(
step_context: StepExecutionContext, inputs: Mapping[str, Any], compute_fn: OpComputeFunction
step_context: StepExecutionContext,
inputs: Mapping[str, Any],
compute_fn: OpComputeFunction,
) -> Iterator[OpOutputUnion]:
"""Execute the user-specified compute for the op. Wrap in an error boundary and do
all relevant logging and metrics tracking.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from dagster._core.definitions.job_definition import JobDefinition
from dagster._core.definitions.op_definition import OpDefinition
from dagster._core.errors import DagsterInvalidDefinitionError
from dagster._core.execution.context.compute import get_execution_context
from dagster._core.storage.dagster_run import DagsterRun


Expand Down Expand Up @@ -405,3 +406,17 @@ def the_op(context: int):
@asset
def the_asset(context: int):
pass


def test_get_context():
assert get_execution_context() is None

@op
def o(context):
assert context == get_execution_context()

@job
def j():
o()

assert j.execute_in_process().success

0 comments on commit d13f3fc

Please sign in to comment.