diff --git a/python_modules/dagster/dagster/_core/execution/context/compute.py b/python_modules/dagster/dagster/_core/execution/context/compute.py index edaccaa2f041c..dd5dd44e178a5 100644 --- a/python_modules/dagster/dagster/_core/execution/context/compute.py +++ b/python_modules/dagster/dagster/_core/execution/context/compute.py @@ -1,5 +1,8 @@ from abc import ABC, ABCMeta, abstractmethod from inspect import _empty as EmptyAnnotation +from abc import ABC, abstractmethod +from contextlib import contextmanager +from contextvars import ContextVar from typing import ( AbstractSet, Any, @@ -126,7 +129,9 @@ def __instancecheck__(cls, instance) -> bool: return super().__instancecheck__(instance) -class OpExecutionContext(AbstractComputeExecutionContext, metaclass=OpExecutionContextMetaClass): +class OpExecutionContext( + AbstractComputeExecutionContext, metaclass=OpExecutionContextMetaClass +): """The ``context`` object that can be made available as the first argument to the function used for computing an op or asset. @@ -308,7 +313,10 @@ def my_asset(context: AssetExecutionContext): """ return self._step_execution_context.partition_key - @deprecated(breaking_version="2.0", additional_warn_text="Use `partition_key_range` instead.") + @deprecated( + breaking_version="2.0", + additional_warn_text="Use `partition_key_range` instead.", + ) @public @property def asset_partition_key_range(self) -> PartitionKeyRange: @@ -432,10 +440,14 @@ def log_materialization(context): DagsterEvent.asset_materialization(self._step_execution_context, event) ) elif isinstance(event, AssetObservation): - self._events.append(DagsterEvent.asset_observation(self._step_execution_context, event)) + self._events.append( + DagsterEvent.asset_observation(self._step_execution_context, event) + ) elif isinstance(event, ExpectationResult): self._events.append( - DagsterEvent.step_expectation_result(self._step_execution_context, event) + DagsterEvent.step_expectation_result( + self._step_execution_context, event + ) ) else: check.failed(f"Unexpected event {event}") @@ -568,7 +580,10 @@ def has_asset_checks_def(self) -> bool: Returns: bool: True if there is a backing AssetChecksDefinition for the current execution, otherwise False. """ - return self.job_def.asset_layer.asset_checks_def_for_node(self.node_handle) is not None + return ( + self.job_def.asset_layer.asset_checks_def_for_node(self.node_handle) + is not None + ) @public @property @@ -579,7 +594,9 @@ def asset_checks_def(self) -> AssetChecksDefinition: Returns: AssetChecksDefinition. """ - asset_checks_def = self.job_def.asset_layer.asset_checks_def_for_node(self.node_handle) + asset_checks_def = self.job_def.asset_layer.asset_checks_def_for_node( + self.node_handle + ) if asset_checks_def is None: raise DagsterInvalidPropertyError( f"Op '{self.op.name}' does not have an asset checks definition." @@ -594,7 +611,9 @@ def selected_asset_check_keys(self) -> AbstractSet[AssetCheckKey]: return self.assets_def.check_keys if self.has_asset_checks_def: - check.failed("Subset selection is not yet supported within an AssetChecksDefinition") + check.failed( + "Subset selection is not yet supported within an AssetChecksDefinition" + ) return set() @@ -635,7 +654,9 @@ def asset_key_for_output(self, output_name: str = "result") -> AssetKey: @public def output_for_asset_key(self, asset_key: AssetKey) -> str: """Return the output name for the corresponding asset key.""" - node_output_handle = self.job_def.asset_layer.node_output_handle_for_asset(asset_key) + node_output_handle = self.job_def.asset_layer.node_output_handle_for_asset( + asset_key + ) if node_output_handle is None: check.failed(f"Asset key '{asset_key}' has no output") else: @@ -710,7 +731,9 @@ def self_dependent_asset(context: AssetExecutionContext, self_dependent_asset): return self._step_execution_context.asset_partition_key_for_output(output_name) @public - def asset_partitions_time_window_for_output(self, output_name: str = "result") -> TimeWindow: + def asset_partitions_time_window_for_output( + self, output_name: str = "result" + ) -> TimeWindow: """The time window for the partitions of the output asset. If you want to write your asset to support running a backfill of several partitions in a single run, @@ -782,7 +805,9 @@ def self_dependent_asset(context: AssetExecutionContext, self_dependent_asset): # TimeWindow("2023-08-21", "2023-08-26") """ - return self._step_execution_context.asset_partitions_time_window_for_output(output_name) + return self._step_execution_context.asset_partitions_time_window_for_output( + output_name + ) @public def asset_partition_key_range_for_output( @@ -845,7 +870,9 @@ def self_dependent_asset(context: AssetExecutionContext, self_dependent_asset): # PartitionKeyRange(start="2023-08-21", end="2023-08-25") """ - return self._step_execution_context.asset_partition_key_range_for_output(output_name) + return self._step_execution_context.asset_partition_key_range_for_output( + output_name + ) @public def asset_partition_key_range_for_input(self, input_name: str) -> PartitionKeyRange: @@ -908,7 +935,9 @@ def self_dependent_asset(context: AssetExecutionContext, self_dependent_asset): """ - return self._step_execution_context.asset_partition_key_range_for_input(input_name) + return self._step_execution_context.asset_partition_key_range_for_input( + input_name + ) @public def asset_partition_key_for_input(self, input_name: str) -> str: @@ -954,7 +983,9 @@ def self_dependent_asset(context: AssetExecutionContext, self_dependent_asset): return self._step_execution_context.asset_partition_key_for_input(input_name) @public - def asset_partitions_def_for_output(self, output_name: str = "result") -> PartitionsDefinition: + def asset_partitions_def_for_output( + self, output_name: str = "result" + ) -> PartitionsDefinition: """The PartitionsDefinition on the asset corresponding to this output. Args: @@ -994,8 +1025,10 @@ def a_multi_asset(context: AssetExecutionContext): """ asset_key = self.asset_key_for_output(output_name) - result = self._step_execution_context.job_def.asset_layer.partitions_def_for_asset( - asset_key + result = ( + self._step_execution_context.job_def.asset_layer.partitions_def_for_asset( + asset_key + ) ) if result is None: raise DagsterInvariantViolationError( @@ -1034,8 +1067,10 @@ def upstream_asset(context: AssetExecutionContext, upstream_asset): """ asset_key = self.asset_key_for_input(input_name) - result = self._step_execution_context.job_def.asset_layer.partitions_def_for_asset( - asset_key + result = ( + self._step_execution_context.job_def.asset_layer.partitions_def_for_asset( + asset_key + ) ) if result is None: raise DagsterInvariantViolationError( @@ -1046,7 +1081,9 @@ def upstream_asset(context: AssetExecutionContext, upstream_asset): return result @public - def asset_partition_keys_for_output(self, output_name: str = "result") -> Sequence[str]: + def asset_partition_keys_for_output( + self, output_name: str = "result" + ) -> Sequence[str]: """Returns a list of the partition keys for the given output. If you want to write your asset to support running a backfill of several partitions in a single run, @@ -1103,8 +1140,12 @@ def self_dependent_asset(context: AssetExecutionContext, self_dependent_asset): # running a backfill of the 2023-08-21 through 2023-08-25 partitions of this asset will log: # ["2023-08-21", "2023-08-22", "2023-08-23", "2023-08-24", "2023-08-25"] """ - return self.asset_partitions_def_for_output(output_name).get_partition_keys_in_range( - self._step_execution_context.asset_partition_key_range_for_output(output_name), + return self.asset_partitions_def_for_output( + output_name + ).get_partition_keys_in_range( + self._step_execution_context.asset_partition_key_range_for_output( + output_name + ), dynamic_partitions_store=self.instance, ) @@ -1174,7 +1215,9 @@ def self_dependent_asset(context: AssetExecutionContext, self_dependent_asset): ) @public - def asset_partitions_time_window_for_input(self, input_name: str = "result") -> TimeWindow: + def asset_partitions_time_window_for_input( + self, input_name: str = "result" + ) -> TimeWindow: """The time window for the partitions of the input asset. If you want to write your asset to support running a backfill of several partitions in a single run, @@ -1247,7 +1290,9 @@ def self_dependent_asset(context: AssetExecutionContext, self_dependent_asset): # TimeWindow("2023-08-20", "2023-08-25") """ - return self._step_execution_context.asset_partitions_time_window_for_input(input_name) + return self._step_execution_context.asset_partitions_time_window_for_input( + input_name + ) @public @experimental @@ -1265,7 +1310,9 @@ def get_asset_provenance(self, asset_key: AssetKey) -> Optional[DataProvenance]: record = self.instance.get_latest_data_version_record(asset_key) return ( - None if record is None else extract_data_provenance_from_entry(record.event_log_entry) + None + if record is None + else extract_data_provenance_from_entry(record.event_log_entry) ) def set_data_version(self, asset_key: AssetKey, data_version: DataVersion) -> None: @@ -1297,8 +1344,12 @@ def requires_typed_event_stream(self) -> bool: def typed_event_stream_error_message(self) -> Optional[str]: return self._step_execution_context.typed_event_stream_error_message - def set_requires_typed_event_stream(self, *, error_message: Optional[str] = None) -> None: - self._step_execution_context.set_requires_typed_event_stream(error_message=error_message) + def set_requires_typed_event_stream( + self, *, error_message: Optional[str] = None + ) -> None: + self._step_execution_context.set_requires_typed_event_stream( + error_message=error_message + ) class AssetExecutionContext(OpExecutionContext): @@ -1372,3 +1423,25 @@ def build_execution_context( if context_annotation is AssetExecutionContext: return AssetExecutionContext(step_context) return OpExecutionContext(step_context) + + +_current_context: ContextVar[Optional[OpExecutionContext]] = ContextVar( + "execution_context", default=None +) + + +@contextmanager +def enter_execution_context( + step_context: StepExecutionContext, +) -> Iterator[OpExecutionContext]: + ctx = build_execution_context(step_context) + + token = _current_context.set(ctx) + try: + yield ctx + finally: + _current_context.reset(token) + + +def get_execution_context(): + return _current_context.get() diff --git a/python_modules/dagster/dagster/_core/execution/plan/compute.py b/python_modules/dagster/dagster/_core/execution/plan/compute.py index 8b19e99eda55e..1bc72b74edbd6 100644 --- a/python_modules/dagster/dagster/_core/execution/plan/compute.py +++ b/python_modules/dagster/dagster/_core/execution/plan/compute.py @@ -31,9 +31,12 @@ from dagster._core.definitions.asset_layer import AssetLayer from dagster._core.definitions.op_definition import OpComputeFunction from dagster._core.definitions.result import MaterializeResult -from dagster._core.errors import DagsterExecutionStepExecutionError, DagsterInvariantViolationError +from dagster._core.errors import ( + DagsterExecutionStepExecutionError, + DagsterInvariantViolationError, +) from dagster._core.events import DagsterEvent -from dagster._core.execution.context.compute import build_execution_context +from dagster._core.execution.context.compute import enter_execution_context from dagster._core.execution.context.system import StepExecutionContext from dagster._core.system_config.objects import ResolvedRunConfig from dagster._utils import iterate_with_context @@ -57,7 +60,10 @@ def create_step_outputs( - node: Node, handle: NodeHandle, resolved_run_config: ResolvedRunConfig, asset_layer: AssetLayer + node: Node, + handle: NodeHandle, + resolved_run_config: ResolvedRunConfig, + asset_layer: AssetLayer, ) -> Sequence[StepOutput]: check.inst_param(node, "node", Node) check.inst_param(handle, "handle", NodeHandle) @@ -84,9 +90,15 @@ def create_step_outputs( is_dynamic=output_def.is_dynamic, is_asset=asset_info is not None, should_materialize=output_def.name in config_output_names, - asset_key=asset_info.key if asset_info and asset_info.is_required else None, - is_asset_partitioned=bool(asset_info.partitions_def) if asset_info else False, - asset_check_key=asset_layer.asset_check_key_for_output(handle, name), + asset_key=asset_info.key + if asset_info and asset_info.is_required + else None, + is_asset_partitioned=bool(asset_info.partitions_def) + if asset_info + else False, + asset_check_key=asset_layer.asset_check_key_for_output( + handle, name + ), ), ) ) @@ -147,50 +159,52 @@ def _yield_compute_results( ) -> Iterator[OpOutputUnion]: check.inst_param(step_context, "step_context", StepExecutionContext) - context = build_execution_context(step_context) - user_event_generator = compute_fn(context, inputs) + with enter_execution_context(step_context) as context: + user_event_generator = compute_fn(context, inputs) - if isinstance(user_event_generator, Output): - raise DagsterInvariantViolationError( - ( - "Compute function for {described_node} returned an Output rather than " - "yielding it. The compute_fn of the {node_type} must yield " - "its results" - ).format( - described_node=step_context.describe_op(), - node_type=step_context.op_def.node_type_str, + if isinstance(user_event_generator, Output): + raise DagsterInvariantViolationError( + ( + "Compute function for {described_node} returned an Output rather than " + "yielding it. The compute_fn of the {node_type} must yield " + "its results" + ).format( + described_node=step_context.describe_op(), + node_type=step_context.op_def.node_type_str, + ) ) - ) - if user_event_generator is None: - return + if user_event_generator is None: + return - if inspect.isasyncgen(user_event_generator): - user_event_generator = gen_from_async_gen(user_event_generator) + if inspect.isasyncgen(user_event_generator): + user_event_generator = gen_from_async_gen(user_event_generator) - op_label = step_context.describe_op() + op_label = step_context.describe_op() + + for event in iterate_with_context( + lambda: op_execution_error_boundary( + DagsterExecutionStepExecutionError, + msg_fn=lambda: f"Error occurred while executing {op_label}:", + step_context=step_context, + step_key=step_context.step.key, + op_def_name=step_context.op_def.name, + op_name=step_context.op.name, + ), + user_event_generator, + ): + if context.has_events(): + yield from context.consume_events() + yield _validate_event(event, step_context) - for event in iterate_with_context( - lambda: op_execution_error_boundary( - DagsterExecutionStepExecutionError, - msg_fn=lambda: f"Error occurred while executing {op_label}:", - step_context=step_context, - step_key=step_context.step.key, - op_def_name=step_context.op_def.name, - op_name=step_context.op.name, - ), - user_event_generator, - ): if context.has_events(): yield from context.consume_events() - yield _validate_event(event, step_context) - - if context.has_events(): - yield from context.consume_events() def execute_core_compute( - step_context: StepExecutionContext, inputs: Mapping[str, Any], compute_fn: OpComputeFunction + step_context: StepExecutionContext, + inputs: Mapping[str, Any], + compute_fn: OpComputeFunction, ) -> Iterator[OpOutputUnion]: """Execute the user-specified compute for the op. Wrap in an error boundary and do all relevant logging and metrics tracking. @@ -208,29 +222,43 @@ def execute_core_compute( elif isinstance(step_output, MaterializeResult): asset_key = ( step_output.asset_key - or step_context.job_def.asset_layer.asset_key_for_node(step_context.node_handle) + or step_context.job_def.asset_layer.asset_key_for_node( + step_context.node_handle + ) ) emitted_result_names.add( - step_context.job_def.asset_layer.node_output_handle_for_asset(asset_key).output_name + step_context.job_def.asset_layer.node_output_handle_for_asset( + asset_key + ).output_name ) # Check results embedded in MaterializeResult are counted for check_result in step_output.check_results or []: - handle = check_result.to_asset_check_evaluation(step_context).asset_check_key - output_name = step_context.job_def.asset_layer.get_output_name_for_asset_check( - handle + handle = check_result.to_asset_check_evaluation( + step_context + ).asset_check_key + output_name = ( + step_context.job_def.asset_layer.get_output_name_for_asset_check( + handle + ) ) emitted_result_names.add(output_name) elif isinstance(step_output, AssetCheckEvaluation): - output_name = step_context.job_def.asset_layer.get_output_name_for_asset_check( - step_output.asset_check_key + output_name = ( + step_context.job_def.asset_layer.get_output_name_for_asset_check( + step_output.asset_check_key + ) ) emitted_result_names.add(output_name) elif isinstance(step_output, AssetCheckResult): if step_output.asset_key and step_output.check_name: handle = AssetCheckKey(step_output.asset_key, step_output.check_name) else: - handle = step_output.to_asset_check_evaluation(step_context).asset_check_key - output_name = step_context.job_def.asset_layer.get_output_name_for_asset_check(handle) + handle = step_output.to_asset_check_evaluation( + step_context + ).asset_check_key + output_name = ( + step_context.job_def.asset_layer.get_output_name_for_asset_check(handle) + ) emitted_result_names.add(output_name) expected_op_output_names = { diff --git a/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_context.py b/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_context.py index c0214049f5d31..8603ff2a00217 100644 --- a/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_context.py +++ b/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_context.py @@ -24,6 +24,7 @@ from dagster._core.definitions.job_definition import JobDefinition from dagster._core.definitions.op_definition import OpDefinition from dagster._core.errors import DagsterInvalidDefinitionError +from dagster._core.execution.context.compute import get_execution_context from dagster._core.storage.dagster_run import DagsterRun @@ -149,21 +150,27 @@ def op_annotation_job(): def test_context_provided_to_multi_asset(): - @multi_asset(outs={"out1": AssetOut(dagster_type=None), "out2": AssetOut(dagster_type=None)}) + @multi_asset( + outs={"out1": AssetOut(dagster_type=None), "out2": AssetOut(dagster_type=None)} + ) def no_annotation(context): assert isinstance(context, AssetExecutionContext) return None, None materialize([no_annotation]) - @multi_asset(outs={"out1": AssetOut(dagster_type=None), "out2": AssetOut(dagster_type=None)}) + @multi_asset( + outs={"out1": AssetOut(dagster_type=None), "out2": AssetOut(dagster_type=None)} + ) def asset_annotation(context: AssetExecutionContext): assert isinstance(context, AssetExecutionContext) return None, None materialize([asset_annotation]) - @multi_asset(outs={"out1": AssetOut(dagster_type=None), "out2": AssetOut(dagster_type=None)}) + @multi_asset( + outs={"out1": AssetOut(dagster_type=None), "out2": AssetOut(dagster_type=None)} + ) def op_annotation(context: OpExecutionContext): assert isinstance(context, OpExecutionContext) # AssetExecutionContext is an instance of OpExecutionContext, so add this additional check @@ -279,14 +286,18 @@ def no_annotation(context, *args): yield Output(1) no_annotation_op = OpDefinition(compute_fn=no_annotation, name="no_annotation_op") - no_annotation_graph = GraphDefinition(name="no_annotation_graph", node_defs=[no_annotation_op]) + no_annotation_graph = GraphDefinition( + name="no_annotation_graph", node_defs=[no_annotation_op] + ) no_annotation_graph.to_job(name="no_annotation_job").execute_in_process() def asset_annotation(context: AssetExecutionContext, *args): assert False, "Test should error during context creation" - asset_annotation_op = OpDefinition(compute_fn=asset_annotation, name="asset_annotation_op") + asset_annotation_op = OpDefinition( + compute_fn=asset_annotation, name="asset_annotation_op" + ) asset_annotation_graph = GraphDefinition( name="asset_annotation_graph", node_defs=[asset_annotation_op] ) @@ -304,7 +315,9 @@ def op_annotation(context: OpExecutionContext, *args): yield Output(1) op_annotation_op = OpDefinition(compute_fn=op_annotation, name="op_annotation_op") - op_annotation_graph = GraphDefinition(name="op_annotation_graph", node_defs=[op_annotation_op]) + op_annotation_graph = GraphDefinition( + name="op_annotation_graph", node_defs=[op_annotation_op] + ) op_annotation_graph.to_job(name="op_annotation_job").execute_in_process() @@ -312,10 +325,14 @@ def op_annotation(context: OpExecutionContext, *args): def test_context_provided_to_asset_check(): instance = DagsterInstance.ephemeral() - def execute_assets_and_checks(assets=None, asset_checks=None, raise_on_error: bool = True): + def execute_assets_and_checks( + assets=None, asset_checks=None, raise_on_error: bool = True + ): defs = Definitions(assets=assets, asset_checks=asset_checks) job_def = defs.get_implicit_global_asset_job_def() - return job_def.execute_in_process(raise_on_error=raise_on_error, instance=instance) + return job_def.execute_in_process( + raise_on_error=raise_on_error, instance=instance + ) @asset def to_check(): @@ -345,10 +362,14 @@ def op_annotation(context: OpExecutionContext): def test_context_provided_to_blocking_asset_check(): instance = DagsterInstance.ephemeral() - def execute_assets_and_checks(assets=None, asset_checks=None, raise_on_error: bool = True): + def execute_assets_and_checks( + assets=None, asset_checks=None, raise_on_error: bool = True + ): defs = Definitions(assets=assets, asset_checks=asset_checks) job_def = defs.get_implicit_global_asset_job_def() - return job_def.execute_in_process(raise_on_error=raise_on_error, instance=instance) + return job_def.execute_in_process( + raise_on_error=raise_on_error, instance=instance + ) @asset def to_check(): @@ -405,3 +426,17 @@ def the_op(context: int): @asset def the_asset(context: int): pass + + +def test_get_context(): + assert get_execution_context() is None + + @op + def o(context): + assert context == get_execution_context() + + @job + def j(): + o() + + assert j.execute_in_process().success