Skip to content

Commit

Permalink
add indirect execution context access
Browse files Browse the repository at this point in the history
  • Loading branch information
alangenfeld committed Aug 18, 2023
1 parent 47735b7 commit d0752c8
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 36 deletions.
28 changes: 28 additions & 0 deletions python_modules/dagster/dagster/_core/execution/context/compute.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from contextvars import ContextVar
from typing import (
AbstractSet,
Any,
Expand Down Expand Up @@ -48,6 +50,8 @@

from .system import StepExecutionContext

8


class AbstractComputeExecutionContext(ABC):
"""Base class for op context implemented by OpExecutionContext and DagstermillExecutionContext."""
Expand Down Expand Up @@ -657,3 +661,27 @@ def set_data_version(self, asset_key: AssetKey, data_version: DataVersion) -> No
# * having ops in a graph that form a graph backed asset
# so we have a single type that users can call by their preferred name where appropriate
AssetExecutionContext: TypeAlias = OpExecutionContext

_current_context: ContextVar[Optional[OpExecutionContext]] = ContextVar(
"execution_context", default=None
)


@contextmanager
def enter_execution_context(
step_context: StepExecutionContext,
) -> Iterator[OpExecutionContext]:
if step_context.is_sda_step:
ctx = AssetExecutionContext(step_context)
else:
ctx = OpExecutionContext(step_context)

token = _current_context.set(ctx)
try:
yield ctx
finally:
_current_context.reset(token)


def get_execution_context():
return _current_context.get()
72 changes: 36 additions & 36 deletions python_modules/dagster/dagster/_core/execution/plan/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from dagster._core.definitions.op_definition import OpComputeFunction
from dagster._core.errors import DagsterExecutionStepExecutionError, DagsterInvariantViolationError
from dagster._core.events import DagsterEvent
from dagster._core.execution.context.compute import OpExecutionContext
from dagster._core.execution.context.compute import enter_execution_context
from dagster._core.execution.context.system import StepExecutionContext
from dagster._core.system_config.objects import ResolvedRunConfig
from dagster._utils import iterate_with_context
Expand Down Expand Up @@ -135,46 +135,46 @@ def _yield_compute_results(
) -> Iterator[OpOutputUnion]:
check.inst_param(step_context, "step_context", StepExecutionContext)

context = OpExecutionContext(step_context)
user_event_generator = compute_fn(context, inputs)

if isinstance(user_event_generator, Output):
raise DagsterInvariantViolationError(
(
"Compute function for {described_node} returned an Output rather than "
"yielding it. The compute_fn of the {node_type} must yield "
"its results"
).format(
described_node=step_context.describe_op(),
node_type=step_context.op_def.node_type_str,
with enter_execution_context(step_context) as context:
user_event_generator = compute_fn(context, inputs)

if isinstance(user_event_generator, Output):
raise DagsterInvariantViolationError(
(
"Compute function for {described_node} returned an Output rather than "
"yielding it. The compute_fn of the {node_type} must yield "
"its results"
).format(
described_node=step_context.describe_op(),
node_type=step_context.op_def.node_type_str,
)
)
)

if user_event_generator is None:
return

if inspect.isasyncgen(user_event_generator):
user_event_generator = gen_from_async_gen(user_event_generator)

op_label = step_context.describe_op()
if user_event_generator is None:
return

if inspect.isasyncgen(user_event_generator):
user_event_generator = gen_from_async_gen(user_event_generator)

op_label = step_context.describe_op()

for event in iterate_with_context(
lambda: op_execution_error_boundary(
DagsterExecutionStepExecutionError,
msg_fn=lambda: f"Error occurred while executing {op_label}:",
step_context=step_context,
step_key=step_context.step.key,
op_def_name=step_context.op_def.name,
op_name=step_context.op.name,
),
user_event_generator,
):
if context.has_events():
yield from context.consume_events()
yield _validate_event(event, step_context)

for event in iterate_with_context(
lambda: op_execution_error_boundary(
DagsterExecutionStepExecutionError,
msg_fn=lambda: f"Error occurred while executing {op_label}:",
step_context=step_context,
step_key=step_context.step.key,
op_def_name=step_context.op_def.name,
op_name=step_context.op.name,
),
user_event_generator,
):
if context.has_events():
yield from context.consume_events()
yield _validate_event(event, step_context)

if context.has_events():
yield from context.consume_events()


def execute_core_compute(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from dagster import OpExecutionContext, job, op
from dagster._core.definitions.job_definition import JobDefinition
from dagster._core.definitions.op_definition import OpDefinition
from dagster._core.execution.context.compute import get_execution_context
from dagster._core.storage.dagster_run import DagsterRun


Expand All @@ -20,3 +21,17 @@ def foo():
ctx_op()

assert foo.execute_in_process().success


def test_get_context():
assert get_execution_context() is None

@op
def o(context):
assert context == get_execution_context()

@job
def j():
o()

assert j.execute_in_process().success

0 comments on commit d0752c8

Please sign in to comment.