Skip to content

Commit

Permalink
add indirect execution context access
Browse files Browse the repository at this point in the history
  • Loading branch information
alangenfeld committed Oct 25, 2023
1 parent 98b3f32 commit 9889bd2
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 42 deletions.
71 changes: 55 additions & 16 deletions python_modules/dagster/dagster/_core/execution/context/compute.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from abc import ABC, ABCMeta, abstractmethod
from contextlib import contextmanager
from contextvars import ContextVar
from inspect import _empty as EmptyAnnotation
from typing import (
AbstractSet,
Expand Down Expand Up @@ -48,6 +50,7 @@
from dagster._core.instance import DagsterInstance
from dagster._core.log_manager import DagsterLogManager
from dagster._core.storage.dagster_run import DagsterRun
from dagster._utils.cached_method import cached_method
from dagster._utils.forked_pdb import ForkedPdb
from dagster._utils.warnings import (
deprecation_warning,
Expand Down Expand Up @@ -308,7 +311,10 @@ def my_asset(context: AssetExecutionContext):
"""
return self._step_execution_context.partition_key

@deprecated(breaking_version="2.0", additional_warn_text="Use `partition_key_range` instead.")
@deprecated(
breaking_version="2.0",
additional_warn_text="Use `partition_key_range` instead.",
)
@public
@property
def asset_partition_key_range(self) -> PartitionKeyRange:
Expand Down Expand Up @@ -1300,15 +1306,34 @@ def typed_event_stream_error_message(self) -> Optional[str]:
def set_requires_typed_event_stream(self, *, error_message: Optional[str] = None) -> None:
self._step_execution_context.set_requires_typed_event_stream(error_message=error_message)

@staticmethod
def get() -> "OpExecutionContext":
ctx = _current_asset_execution_context.get()
if ctx is None:
raise DagsterInvariantViolationError("No current OpExecutionContext in scope.")
return ctx.get_op_execution_context()


class AssetExecutionContext(OpExecutionContext):
def __init__(self, step_execution_context: StepExecutionContext):
super().__init__(step_execution_context=step_execution_context)

@staticmethod
def get() -> "AssetExecutionContext":
ctx = _current_asset_execution_context.get()
if ctx is None:
raise DagsterInvariantViolationError("No current AssetExecutionContext in scope.")
return ctx

@cached_method
def get_op_execution_context(self) -> "OpExecutionContext":
return OpExecutionContext(self._step_execution_context)


def build_execution_context(
@contextmanager
def enter_execution_context(
step_context: StepExecutionContext,
) -> Union[OpExecutionContext, AssetExecutionContext]:
) -> Iterator[Union[OpExecutionContext, AssetExecutionContext]]:
"""Get the correct context based on the type of step (op or asset) and the user provided context
type annotation. Follows these rules.
Expand Down Expand Up @@ -1359,16 +1384,30 @@ def build_execution_context(
" OpExecutionContext, or left blank."
)

if context_annotation is EmptyAnnotation:
# if no type hint has been given, default to:
# * AssetExecutionContext for sda steps not in graph-backed assets, and asset_checks
# * OpExecutionContext for non sda steps
# * OpExecutionContext for ops in graph-backed assets
if is_asset_check:
return AssetExecutionContext(step_context)
if is_op_in_graph_asset or not is_sda_step:
return OpExecutionContext(step_context)
return AssetExecutionContext(step_context)
if context_annotation is AssetExecutionContext:
return AssetExecutionContext(step_context)
return OpExecutionContext(step_context)
# Structured assuming upcoming changes to make AssetExecutionContext contain an OpExecutionContext
asset_ctx = AssetExecutionContext(step_context)
asset_token = _current_asset_execution_context.set(asset_ctx)

try:
if context_annotation is EmptyAnnotation:
# if no type hint has been given, default to:
# * AssetExecutionContext for sda steps not in graph-backed assets, and asset_checks
# * OpExecutionContext for non sda steps
# * OpExecutionContext for ops in graph-backed assets
if is_asset_check:
yield asset_ctx
elif is_op_in_graph_asset or not is_sda_step:
yield asset_ctx.get_op_execution_context()
else:
yield asset_ctx
elif context_annotation is AssetExecutionContext:
yield asset_ctx
else:
yield asset_ctx.get_op_execution_context()
finally:
_current_asset_execution_context.reset(asset_token)


_current_asset_execution_context: ContextVar[Optional[AssetExecutionContext]] = ContextVar(
"_current_asset_execution_context", default=None
)
44 changes: 26 additions & 18 deletions python_modules/dagster/dagster/_core/execution/plan/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import (
Any,
AsyncIterator,
Callable,
Iterator,
List,
Mapping,
Expand Down Expand Up @@ -31,9 +30,15 @@
from dagster._core.definitions.asset_layer import AssetLayer
from dagster._core.definitions.op_definition import OpComputeFunction
from dagster._core.definitions.result import MaterializeResult
from dagster._core.errors import DagsterExecutionStepExecutionError, DagsterInvariantViolationError
from dagster._core.errors import (
DagsterExecutionStepExecutionError,
DagsterInvariantViolationError,
)
from dagster._core.events import DagsterEvent
from dagster._core.execution.context.compute import build_execution_context
from dagster._core.execution.context.compute import (
AssetExecutionContext,
OpExecutionContext,
)
from dagster._core.execution.context.system import StepExecutionContext
from dagster._core.system_config.objects import ResolvedRunConfig
from dagster._utils import iterate_with_context
Expand All @@ -57,7 +62,10 @@


def create_step_outputs(
node: Node, handle: NodeHandle, resolved_run_config: ResolvedRunConfig, asset_layer: AssetLayer
node: Node,
handle: NodeHandle,
resolved_run_config: ResolvedRunConfig,
asset_layer: AssetLayer,
) -> Sequence[StepOutput]:
check.inst_param(node, "node", Node)
check.inst_param(handle, "handle", NodeHandle)
Expand Down Expand Up @@ -143,12 +151,12 @@ def gen_from_async_gen(async_gen: AsyncIterator[T]) -> Iterator[T]:


def _yield_compute_results(
step_context: StepExecutionContext, inputs: Mapping[str, Any], compute_fn: Callable
step_context: StepExecutionContext,
inputs: Mapping[str, Any],
compute_fn: OpComputeFunction,
compute_context: Union[OpExecutionContext, AssetExecutionContext],
) -> Iterator[OpOutputUnion]:
check.inst_param(step_context, "step_context", StepExecutionContext)

context = build_execution_context(step_context)
user_event_generator = compute_fn(context, inputs)
user_event_generator = compute_fn(compute_context, inputs)

if isinstance(user_event_generator, Output):
raise DagsterInvariantViolationError(
Expand Down Expand Up @@ -181,27 +189,27 @@ def _yield_compute_results(
),
user_event_generator,
):
if context.has_events():
yield from context.consume_events()
if compute_context.has_events():
yield from compute_context.consume_events()
yield _validate_event(event, step_context)

if context.has_events():
yield from context.consume_events()
if compute_context.has_events():
yield from compute_context.consume_events()


def execute_core_compute(
step_context: StepExecutionContext, inputs: Mapping[str, Any], compute_fn: OpComputeFunction
step_context: StepExecutionContext,
inputs: Mapping[str, Any],
compute_fn: OpComputeFunction,
compute_context: Union[OpExecutionContext, AssetExecutionContext],
) -> Iterator[OpOutputUnion]:
"""Execute the user-specified compute for the op. Wrap in an error boundary and do
all relevant logging and metrics tracking.
"""
check.inst_param(step_context, "step_context", StepExecutionContext)
check.mapping_param(inputs, "inputs", key_type=str)

step = step_context.step

emitted_result_names = set()
for step_output in _yield_compute_results(step_context, inputs, compute_fn):
for step_output in _yield_compute_results(step_context, inputs, compute_fn, compute_context):
yield step_output
if isinstance(step_output, (DynamicOutput, Output)):
emitted_result_names.add(step_output.output_name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
user_code_error_boundary,
)
from dagster._core.events import DagsterEvent
from dagster._core.execution.context.compute import enter_execution_context
from dagster._core.execution.context.output import OutputContext
from dagster._core.execution.context.system import StepExecutionContext, TypeCheckContext
from dagster._core.execution.plan.compute import execute_core_compute
Expand Down Expand Up @@ -462,13 +463,14 @@ def core_dagster_event_sequence_for_step(
else:
core_gen = step_context.op_def.compute_fn

with time_execution_scope() as timer_result:
user_event_sequence = check.generator(
execute_core_compute(
step_context,
inputs,
core_gen,
)
with time_execution_scope() as timer_result, enter_execution_context(
step_context
) as compute_context:
user_event_sequence = execute_core_compute(
step_context,
inputs,
core_gen,
compute_context,
)

# It is important for this loop to be indented within the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from dagster._core.definitions.asset_checks import build_asset_with_blocking_check
from dagster._core.definitions.job_definition import JobDefinition
from dagster._core.definitions.op_definition import OpDefinition
from dagster._core.errors import DagsterInvalidDefinitionError
from dagster._core.errors import DagsterInvalidDefinitionError, DagsterInvariantViolationError
from dagster._core.storage.dagster_run import DagsterRun


Expand Down Expand Up @@ -405,3 +405,24 @@ def the_op(context: int):
@asset
def the_asset(context: int):
pass


def test_get_context():
with pytest.raises(DagsterInvariantViolationError):
OpExecutionContext.get()

@op
def o(context):
assert context == OpExecutionContext.get()

@job
def j():
o()

assert j.execute_in_process().success

@asset
def a(context: AssetExecutionContext):
assert context == AssetExecutionContext.get()

assert materialize([a]).success

0 comments on commit 9889bd2

Please sign in to comment.