Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add indirect execution context access #14954

Merged
merged 1 commit into from
Nov 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 51 additions & 15 deletions python_modules/dagster/dagster/_core/execution/context/compute.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from abc import ABC, ABCMeta, abstractmethod
from contextlib import contextmanager
from contextvars import ContextVar
from inspect import _empty as EmptyAnnotation
from typing import (
AbstractSet,
Expand Down Expand Up @@ -48,6 +50,7 @@
from dagster._core.instance import DagsterInstance
from dagster._core.log_manager import DagsterLogManager
from dagster._core.storage.dagster_run import DagsterRun
from dagster._utils.cached_method import cached_method
from dagster._utils.forked_pdb import ForkedPdb
from dagster._utils.warnings import (
deprecation_warning,
Expand Down Expand Up @@ -1334,15 +1337,34 @@ def typed_event_stream_error_message(self) -> Optional[str]:
def set_requires_typed_event_stream(self, *, error_message: Optional[str] = None) -> None:
self._step_execution_context.set_requires_typed_event_stream(error_message=error_message)

@staticmethod
def get() -> "OpExecutionContext":
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i updated these to just get now that they raise + slopps feedback

ctx = _current_asset_execution_context.get()
if ctx is None:
raise DagsterInvariantViolationError("No current OpExecutionContext in scope.")
return ctx.get_op_execution_context()


class AssetExecutionContext(OpExecutionContext):
def __init__(self, step_execution_context: StepExecutionContext):
super().__init__(step_execution_context=step_execution_context)

@staticmethod
def get() -> "AssetExecutionContext":
ctx = _current_asset_execution_context.get()
if ctx is None:
raise DagsterInvariantViolationError("No current AssetExecutionContext in scope.")
return ctx

@cached_method
def get_op_execution_context(self) -> "OpExecutionContext":
return OpExecutionContext(self._step_execution_context)

def build_execution_context(

@contextmanager
def enter_execution_context(
step_context: StepExecutionContext,
) -> Union[OpExecutionContext, AssetExecutionContext]:
) -> Iterator[Union[OpExecutionContext, AssetExecutionContext]]:
"""Get the correct context based on the type of step (op or asset) and the user provided context
type annotation. Follows these rules.

Expand Down Expand Up @@ -1393,16 +1415,30 @@ def build_execution_context(
" OpExecutionContext, or left blank."
)

if context_annotation is EmptyAnnotation:
# if no type hint has been given, default to:
# * AssetExecutionContext for sda steps not in graph-backed assets, and asset_checks
# * OpExecutionContext for non sda steps
# * OpExecutionContext for ops in graph-backed assets
if is_asset_check:
return AssetExecutionContext(step_context)
if is_op_in_graph_asset or not is_sda_step:
return OpExecutionContext(step_context)
return AssetExecutionContext(step_context)
if context_annotation is AssetExecutionContext:
return AssetExecutionContext(step_context)
return OpExecutionContext(step_context)
# Structured assuming upcoming changes to make AssetExecutionContext contain an OpExecutionContext
asset_ctx = AssetExecutionContext(step_context)
asset_token = _current_asset_execution_context.set(asset_ctx)

try:
if context_annotation is EmptyAnnotation:
# if no type hint has been given, default to:
# * AssetExecutionContext for sda steps not in graph-backed assets, and asset_checks
# * OpExecutionContext for non sda steps
# * OpExecutionContext for ops in graph-backed assets
if is_asset_check:
yield asset_ctx
elif is_op_in_graph_asset or not is_sda_step:
yield asset_ctx.get_op_execution_context()
else:
yield asset_ctx
elif context_annotation is AssetExecutionContext:
yield asset_ctx
else:
yield asset_ctx.get_op_execution_context()
finally:
_current_asset_execution_context.reset(asset_token)
Comment on lines +1419 to +1439
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍🏻 thanks much more comfortable with this



_current_asset_execution_context: ContextVar[Optional[AssetExecutionContext]] = ContextVar(
"_current_asset_execution_context", default=None
)
44 changes: 26 additions & 18 deletions python_modules/dagster/dagster/_core/execution/plan/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import (
Any,
AsyncIterator,
Callable,
Iterator,
List,
Mapping,
Expand Down Expand Up @@ -31,9 +30,15 @@
from dagster._core.definitions.asset_layer import AssetLayer
from dagster._core.definitions.op_definition import OpComputeFunction
from dagster._core.definitions.result import MaterializeResult
from dagster._core.errors import DagsterExecutionStepExecutionError, DagsterInvariantViolationError
from dagster._core.errors import (
DagsterExecutionStepExecutionError,
DagsterInvariantViolationError,
)
from dagster._core.events import DagsterEvent
from dagster._core.execution.context.compute import build_execution_context
from dagster._core.execution.context.compute import (
AssetExecutionContext,
OpExecutionContext,
)
from dagster._core.execution.context.system import StepExecutionContext
from dagster._core.system_config.objects import ResolvedRunConfig
from dagster._utils import iterate_with_context
Expand All @@ -57,7 +62,10 @@


def create_step_outputs(
node: Node, handle: NodeHandle, resolved_run_config: ResolvedRunConfig, asset_layer: AssetLayer
node: Node,
handle: NodeHandle,
resolved_run_config: ResolvedRunConfig,
asset_layer: AssetLayer,
) -> Sequence[StepOutput]:
check.inst_param(node, "node", Node)
check.inst_param(handle, "handle", NodeHandle)
Expand Down Expand Up @@ -143,12 +151,12 @@ def gen_from_async_gen(async_gen: AsyncIterator[T]) -> Iterator[T]:


def _yield_compute_results(
step_context: StepExecutionContext, inputs: Mapping[str, Any], compute_fn: Callable
step_context: StepExecutionContext,
inputs: Mapping[str, Any],
compute_fn: OpComputeFunction,
compute_context: Union[OpExecutionContext, AssetExecutionContext],
) -> Iterator[OpOutputUnion]:
check.inst_param(step_context, "step_context", StepExecutionContext)

context = build_execution_context(step_context)
user_event_generator = compute_fn(context, inputs)
user_event_generator = compute_fn(compute_context, inputs)

if isinstance(user_event_generator, Output):
raise DagsterInvariantViolationError(
Expand Down Expand Up @@ -181,27 +189,27 @@ def _yield_compute_results(
),
user_event_generator,
):
if context.has_events():
yield from context.consume_events()
if compute_context.has_events():
yield from compute_context.consume_events()
yield _validate_event(event, step_context)

if context.has_events():
yield from context.consume_events()
if compute_context.has_events():
yield from compute_context.consume_events()


def execute_core_compute(
step_context: StepExecutionContext, inputs: Mapping[str, Any], compute_fn: OpComputeFunction
step_context: StepExecutionContext,
inputs: Mapping[str, Any],
compute_fn: OpComputeFunction,
compute_context: Union[OpExecutionContext, AssetExecutionContext],
) -> Iterator[OpOutputUnion]:
"""Execute the user-specified compute for the op. Wrap in an error boundary and do
all relevant logging and metrics tracking.
"""
check.inst_param(step_context, "step_context", StepExecutionContext)
check.mapping_param(inputs, "inputs", key_type=str)

step = step_context.step

emitted_result_names = set()
for step_output in _yield_compute_results(step_context, inputs, compute_fn):
for step_output in _yield_compute_results(step_context, inputs, compute_fn, compute_context):
yield step_output
if isinstance(step_output, (DynamicOutput, Output)):
emitted_result_names.add(step_output.output_name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
user_code_error_boundary,
)
from dagster._core.events import DagsterEvent
from dagster._core.execution.context.compute import enter_execution_context
from dagster._core.execution.context.output import OutputContext
from dagster._core.execution.context.system import StepExecutionContext, TypeCheckContext
from dagster._core.execution.plan.compute import execute_core_compute
Expand Down Expand Up @@ -462,13 +463,14 @@ def core_dagster_event_sequence_for_step(
else:
core_gen = step_context.op_def.compute_fn

with time_execution_scope() as timer_result:
user_event_sequence = check.generator(
execute_core_compute(
step_context,
inputs,
core_gen,
)
with time_execution_scope() as timer_result, enter_execution_context(
step_context
) as compute_context:
Comment on lines +466 to +468
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is hoisted up here to the top of the iterator call stack since context managers + active generators get goofy and if you raise an exception based on a yielded value in a frame above where the context manager is opened it does not get closed

user_event_sequence = execute_core_compute(
step_context,
inputs,
core_gen,
compute_context,
)

# It is important for this loop to be indented within the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from dagster._core.definitions.asset_checks import build_asset_with_blocking_check
from dagster._core.definitions.job_definition import JobDefinition
from dagster._core.definitions.op_definition import OpDefinition
from dagster._core.errors import DagsterInvalidDefinitionError
from dagster._core.errors import DagsterInvalidDefinitionError, DagsterInvariantViolationError
from dagster._core.storage.dagster_run import DagsterRun


Expand Down Expand Up @@ -405,3 +405,24 @@ def the_op(context: int):
@asset
def the_asset(context: int):
pass


def test_get_context():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

might be worth testing thread behavior? I guess ContextVar should handle it fine

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ya feel pretty good about thread & coroutine concurrency
https://docs.python.org/3/library/contextvars.html

with pytest.raises(DagsterInvariantViolationError):
OpExecutionContext.get()

@op
def o(context):
assert context == OpExecutionContext.get()

@job
def j():
o()

assert j.execute_in_process().success

@asset
def a(context: AssetExecutionContext):
assert context == AssetExecutionContext.get()

assert materialize([a]).success