Skip to content

Commit

Permalink
add indirect execution context access
Browse files Browse the repository at this point in the history
  • Loading branch information
alangenfeld committed Oct 24, 2023
1 parent 4f91e20 commit b6ba500
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 41 deletions.
67 changes: 51 additions & 16 deletions python_modules/dagster/dagster/_core/execution/context/compute.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from abc import ABC, ABCMeta, abstractmethod
from contextlib import contextmanager
from contextvars import ContextVar
from inspect import _empty as EmptyAnnotation
from typing import (
AbstractSet,
Expand Down Expand Up @@ -308,7 +310,10 @@ def my_asset(context: AssetExecutionContext):
"""
return self._step_execution_context.partition_key

@deprecated(breaking_version="2.0", additional_warn_text="Use `partition_key_range` instead.")
@deprecated(
breaking_version="2.0",
additional_warn_text="Use `partition_key_range` instead.",
)
@public
@property
def asset_partition_key_range(self) -> PartitionKeyRange:
Expand Down Expand Up @@ -1300,15 +1305,24 @@ def typed_event_stream_error_message(self) -> Optional[str]:
def set_requires_typed_event_stream(self, *, error_message: Optional[str] = None) -> None:
self._step_execution_context.set_requires_typed_event_stream(error_message=error_message)

@staticmethod
def get_current() -> Optional["OpExecutionContext"]:
return _current_op_execution_context.get()


class AssetExecutionContext(OpExecutionContext):
def __init__(self, step_execution_context: StepExecutionContext):
super().__init__(step_execution_context=step_execution_context)

@staticmethod
def get_current() -> Optional["AssetExecutionContext"]:
return _current_asset_execution_context.get()


def build_execution_context(
@contextmanager
def enter_execution_context(
step_context: StepExecutionContext,
) -> Union[OpExecutionContext, AssetExecutionContext]:
) -> Iterator[Union[OpExecutionContext, AssetExecutionContext]]:
"""Get the correct context based on the type of step (op or asset) and the user provided context
type annotation. Follows these rules.
Expand Down Expand Up @@ -1359,16 +1373,37 @@ def build_execution_context(
" OpExecutionContext, or left blank."
)

if context_annotation is EmptyAnnotation:
# if no type hint has been given, default to:
# * AssetExecutionContext for sda steps not in graph-backed assets, and asset_checks
# * OpExecutionContext for non sda steps
# * OpExecutionContext for ops in graph-backed assets
if is_asset_check:
return AssetExecutionContext(step_context)
if is_op_in_graph_asset or not is_sda_step:
return OpExecutionContext(step_context)
return AssetExecutionContext(step_context)
if context_annotation is AssetExecutionContext:
return AssetExecutionContext(step_context)
return OpExecutionContext(step_context)
asset_ctx = AssetExecutionContext(step_context)
op_ctx = OpExecutionContext(step_context)

asset_token = _current_asset_execution_context.set(asset_ctx)
op_token = _current_op_execution_context.set(op_ctx)

try:
if context_annotation is EmptyAnnotation:
# if no type hint has been given, default to:
# * AssetExecutionContext for sda steps not in graph-backed assets, and asset_checks
# * OpExecutionContext for non sda steps
# * OpExecutionContext for ops in graph-backed assets
if is_asset_check:
yield asset_ctx
elif is_op_in_graph_asset or not is_sda_step:
yield op_ctx
else:
yield asset_ctx
elif context_annotation is AssetExecutionContext:
yield asset_ctx
else:
yield op_ctx
finally:
_current_asset_execution_context.reset(asset_token)
_current_op_execution_context.reset(op_token)


_current_op_execution_context: ContextVar[Optional[OpExecutionContext]] = ContextVar(
"_current_op_execution_context", default=None
)

_current_asset_execution_context: ContextVar[Optional[AssetExecutionContext]] = ContextVar(
"_current_asset_execution_context", default=None
)
44 changes: 26 additions & 18 deletions python_modules/dagster/dagster/_core/execution/plan/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import (
Any,
AsyncIterator,
Callable,
Iterator,
List,
Mapping,
Expand Down Expand Up @@ -31,9 +30,15 @@
from dagster._core.definitions.asset_layer import AssetLayer
from dagster._core.definitions.op_definition import OpComputeFunction
from dagster._core.definitions.result import MaterializeResult
from dagster._core.errors import DagsterExecutionStepExecutionError, DagsterInvariantViolationError
from dagster._core.errors import (
DagsterExecutionStepExecutionError,
DagsterInvariantViolationError,
)
from dagster._core.events import DagsterEvent
from dagster._core.execution.context.compute import build_execution_context
from dagster._core.execution.context.compute import (
AssetExecutionContext,
OpExecutionContext,
)
from dagster._core.execution.context.system import StepExecutionContext
from dagster._core.system_config.objects import ResolvedRunConfig
from dagster._utils import iterate_with_context
Expand All @@ -57,7 +62,10 @@


def create_step_outputs(
node: Node, handle: NodeHandle, resolved_run_config: ResolvedRunConfig, asset_layer: AssetLayer
node: Node,
handle: NodeHandle,
resolved_run_config: ResolvedRunConfig,
asset_layer: AssetLayer,
) -> Sequence[StepOutput]:
check.inst_param(node, "node", Node)
check.inst_param(handle, "handle", NodeHandle)
Expand Down Expand Up @@ -143,12 +151,12 @@ def gen_from_async_gen(async_gen: AsyncIterator[T]) -> Iterator[T]:


def _yield_compute_results(
step_context: StepExecutionContext, inputs: Mapping[str, Any], compute_fn: Callable
step_context: StepExecutionContext,
inputs: Mapping[str, Any],
compute_fn: OpComputeFunction,
compute_context: Union[OpExecutionContext, AssetExecutionContext],
) -> Iterator[OpOutputUnion]:
check.inst_param(step_context, "step_context", StepExecutionContext)

context = build_execution_context(step_context)
user_event_generator = compute_fn(context, inputs)
user_event_generator = compute_fn(compute_context, inputs)

if isinstance(user_event_generator, Output):
raise DagsterInvariantViolationError(
Expand Down Expand Up @@ -181,27 +189,27 @@ def _yield_compute_results(
),
user_event_generator,
):
if context.has_events():
yield from context.consume_events()
if compute_context.has_events():
yield from compute_context.consume_events()
yield _validate_event(event, step_context)

if context.has_events():
yield from context.consume_events()
if compute_context.has_events():
yield from compute_context.consume_events()


def execute_core_compute(
step_context: StepExecutionContext, inputs: Mapping[str, Any], compute_fn: OpComputeFunction
step_context: StepExecutionContext,
inputs: Mapping[str, Any],
compute_fn: OpComputeFunction,
compute_context: Union[OpExecutionContext, AssetExecutionContext],
) -> Iterator[OpOutputUnion]:
"""Execute the user-specified compute for the op. Wrap in an error boundary and do
all relevant logging and metrics tracking.
"""
check.inst_param(step_context, "step_context", StepExecutionContext)
check.mapping_param(inputs, "inputs", key_type=str)

step = step_context.step

emitted_result_names = set()
for step_output in _yield_compute_results(step_context, inputs, compute_fn):
for step_output in _yield_compute_results(step_context, inputs, compute_fn, compute_context):
yield step_output
if isinstance(step_output, (DynamicOutput, Output)):
emitted_result_names.add(step_output.output_name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
user_code_error_boundary,
)
from dagster._core.events import DagsterEvent
from dagster._core.execution.context.compute import enter_execution_context
from dagster._core.execution.context.output import OutputContext
from dagster._core.execution.context.system import StepExecutionContext, TypeCheckContext
from dagster._core.execution.plan.compute import execute_core_compute
Expand Down Expand Up @@ -462,13 +463,14 @@ def core_dagster_event_sequence_for_step(
else:
core_gen = step_context.op_def.compute_fn

with time_execution_scope() as timer_result:
user_event_sequence = check.generator(
execute_core_compute(
step_context,
inputs,
core_gen,
)
with time_execution_scope() as timer_result, enter_execution_context(
step_context
) as compute_context:
user_event_sequence = execute_core_compute(
step_context,
inputs,
core_gen,
compute_context,
)

# It is important for this loop to be indented within the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -405,3 +405,18 @@ def the_op(context: int):
@asset
def the_asset(context: int):
pass


def test_get_context():
ctx = OpExecutionContext.get_current()
assert ctx is None, ctx.job_name

@op
def o(context):
assert context == OpExecutionContext.get_current()

@job
def j():
o()

assert j.execute_in_process().success

0 comments on commit b6ba500

Please sign in to comment.