Skip to content

Commit

Permalink
add indirect execution context access
Browse files Browse the repository at this point in the history
  • Loading branch information
alangenfeld committed Oct 23, 2023
1 parent 6d850d8 commit 1ea5b5c
Show file tree
Hide file tree
Showing 3 changed files with 219 additions and 83 deletions.
123 changes: 98 additions & 25 deletions python_modules/dagster/dagster/_core/execution/context/compute.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from abc import ABC, ABCMeta, abstractmethod
from inspect import _empty as EmptyAnnotation
from abc import ABC, abstractmethod
from contextlib import contextmanager
from contextvars import ContextVar
from typing import (
AbstractSet,
Any,
Expand Down Expand Up @@ -126,7 +129,9 @@ def __instancecheck__(cls, instance) -> bool:
return super().__instancecheck__(instance)


class OpExecutionContext(AbstractComputeExecutionContext, metaclass=OpExecutionContextMetaClass):
class OpExecutionContext(
AbstractComputeExecutionContext, metaclass=OpExecutionContextMetaClass
):
"""The ``context`` object that can be made available as the first argument to the function
used for computing an op or asset.
Expand Down Expand Up @@ -308,7 +313,10 @@ def my_asset(context: AssetExecutionContext):
"""
return self._step_execution_context.partition_key

@deprecated(breaking_version="2.0", additional_warn_text="Use `partition_key_range` instead.")
@deprecated(
breaking_version="2.0",
additional_warn_text="Use `partition_key_range` instead.",
)
@public
@property
def asset_partition_key_range(self) -> PartitionKeyRange:
Expand Down Expand Up @@ -432,10 +440,14 @@ def log_materialization(context):
DagsterEvent.asset_materialization(self._step_execution_context, event)
)
elif isinstance(event, AssetObservation):
self._events.append(DagsterEvent.asset_observation(self._step_execution_context, event))
self._events.append(
DagsterEvent.asset_observation(self._step_execution_context, event)
)
elif isinstance(event, ExpectationResult):
self._events.append(
DagsterEvent.step_expectation_result(self._step_execution_context, event)
DagsterEvent.step_expectation_result(
self._step_execution_context, event
)
)
else:
check.failed(f"Unexpected event {event}")
Expand Down Expand Up @@ -568,7 +580,10 @@ def has_asset_checks_def(self) -> bool:
Returns:
bool: True if there is a backing AssetChecksDefinition for the current execution, otherwise False.
"""
return self.job_def.asset_layer.asset_checks_def_for_node(self.node_handle) is not None
return (
self.job_def.asset_layer.asset_checks_def_for_node(self.node_handle)
is not None
)

@public
@property
Expand All @@ -579,7 +594,9 @@ def asset_checks_def(self) -> AssetChecksDefinition:
Returns:
AssetChecksDefinition.
"""
asset_checks_def = self.job_def.asset_layer.asset_checks_def_for_node(self.node_handle)
asset_checks_def = self.job_def.asset_layer.asset_checks_def_for_node(
self.node_handle
)
if asset_checks_def is None:
raise DagsterInvalidPropertyError(
f"Op '{self.op.name}' does not have an asset checks definition."
Expand All @@ -594,7 +611,9 @@ def selected_asset_check_keys(self) -> AbstractSet[AssetCheckKey]:
return self.assets_def.check_keys

if self.has_asset_checks_def:
check.failed("Subset selection is not yet supported within an AssetChecksDefinition")
check.failed(
"Subset selection is not yet supported within an AssetChecksDefinition"
)

return set()

Expand Down Expand Up @@ -635,7 +654,9 @@ def asset_key_for_output(self, output_name: str = "result") -> AssetKey:
@public
def output_for_asset_key(self, asset_key: AssetKey) -> str:
"""Return the output name for the corresponding asset key."""
node_output_handle = self.job_def.asset_layer.node_output_handle_for_asset(asset_key)
node_output_handle = self.job_def.asset_layer.node_output_handle_for_asset(
asset_key
)
if node_output_handle is None:
check.failed(f"Asset key '{asset_key}' has no output")
else:
Expand Down Expand Up @@ -710,7 +731,9 @@ def self_dependent_asset(context: AssetExecutionContext, self_dependent_asset):
return self._step_execution_context.asset_partition_key_for_output(output_name)

@public
def asset_partitions_time_window_for_output(self, output_name: str = "result") -> TimeWindow:
def asset_partitions_time_window_for_output(
self, output_name: str = "result"
) -> TimeWindow:
"""The time window for the partitions of the output asset.
If you want to write your asset to support running a backfill of several partitions in a single run,
Expand Down Expand Up @@ -782,7 +805,9 @@ def self_dependent_asset(context: AssetExecutionContext, self_dependent_asset):
# TimeWindow("2023-08-21", "2023-08-26")
"""
return self._step_execution_context.asset_partitions_time_window_for_output(output_name)
return self._step_execution_context.asset_partitions_time_window_for_output(
output_name
)

@public
def asset_partition_key_range_for_output(
Expand Down Expand Up @@ -845,7 +870,9 @@ def self_dependent_asset(context: AssetExecutionContext, self_dependent_asset):
# PartitionKeyRange(start="2023-08-21", end="2023-08-25")
"""
return self._step_execution_context.asset_partition_key_range_for_output(output_name)
return self._step_execution_context.asset_partition_key_range_for_output(
output_name
)

@public
def asset_partition_key_range_for_input(self, input_name: str) -> PartitionKeyRange:
Expand Down Expand Up @@ -908,7 +935,9 @@ def self_dependent_asset(context: AssetExecutionContext, self_dependent_asset):
"""
return self._step_execution_context.asset_partition_key_range_for_input(input_name)
return self._step_execution_context.asset_partition_key_range_for_input(
input_name
)

@public
def asset_partition_key_for_input(self, input_name: str) -> str:
Expand Down Expand Up @@ -954,7 +983,9 @@ def self_dependent_asset(context: AssetExecutionContext, self_dependent_asset):
return self._step_execution_context.asset_partition_key_for_input(input_name)

@public
def asset_partitions_def_for_output(self, output_name: str = "result") -> PartitionsDefinition:
def asset_partitions_def_for_output(
self, output_name: str = "result"
) -> PartitionsDefinition:
"""The PartitionsDefinition on the asset corresponding to this output.
Args:
Expand Down Expand Up @@ -994,8 +1025,10 @@ def a_multi_asset(context: AssetExecutionContext):
"""
asset_key = self.asset_key_for_output(output_name)
result = self._step_execution_context.job_def.asset_layer.partitions_def_for_asset(
asset_key
result = (
self._step_execution_context.job_def.asset_layer.partitions_def_for_asset(
asset_key
)
)
if result is None:
raise DagsterInvariantViolationError(
Expand Down Expand Up @@ -1034,8 +1067,10 @@ def upstream_asset(context: AssetExecutionContext, upstream_asset):
"""
asset_key = self.asset_key_for_input(input_name)
result = self._step_execution_context.job_def.asset_layer.partitions_def_for_asset(
asset_key
result = (
self._step_execution_context.job_def.asset_layer.partitions_def_for_asset(
asset_key
)
)
if result is None:
raise DagsterInvariantViolationError(
Expand All @@ -1046,7 +1081,9 @@ def upstream_asset(context: AssetExecutionContext, upstream_asset):
return result

@public
def asset_partition_keys_for_output(self, output_name: str = "result") -> Sequence[str]:
def asset_partition_keys_for_output(
self, output_name: str = "result"
) -> Sequence[str]:
"""Returns a list of the partition keys for the given output.
If you want to write your asset to support running a backfill of several partitions in a single run,
Expand Down Expand Up @@ -1103,8 +1140,12 @@ def self_dependent_asset(context: AssetExecutionContext, self_dependent_asset):
# running a backfill of the 2023-08-21 through 2023-08-25 partitions of this asset will log:
# ["2023-08-21", "2023-08-22", "2023-08-23", "2023-08-24", "2023-08-25"]
"""
return self.asset_partitions_def_for_output(output_name).get_partition_keys_in_range(
self._step_execution_context.asset_partition_key_range_for_output(output_name),
return self.asset_partitions_def_for_output(
output_name
).get_partition_keys_in_range(
self._step_execution_context.asset_partition_key_range_for_output(
output_name
),
dynamic_partitions_store=self.instance,
)

Expand Down Expand Up @@ -1174,7 +1215,9 @@ def self_dependent_asset(context: AssetExecutionContext, self_dependent_asset):
)

@public
def asset_partitions_time_window_for_input(self, input_name: str = "result") -> TimeWindow:
def asset_partitions_time_window_for_input(
self, input_name: str = "result"
) -> TimeWindow:
"""The time window for the partitions of the input asset.
If you want to write your asset to support running a backfill of several partitions in a single run,
Expand Down Expand Up @@ -1247,7 +1290,9 @@ def self_dependent_asset(context: AssetExecutionContext, self_dependent_asset):
# TimeWindow("2023-08-20", "2023-08-25")
"""
return self._step_execution_context.asset_partitions_time_window_for_input(input_name)
return self._step_execution_context.asset_partitions_time_window_for_input(
input_name
)

@public
@experimental
Expand All @@ -1265,7 +1310,9 @@ def get_asset_provenance(self, asset_key: AssetKey) -> Optional[DataProvenance]:
record = self.instance.get_latest_data_version_record(asset_key)

return (
None if record is None else extract_data_provenance_from_entry(record.event_log_entry)
None
if record is None
else extract_data_provenance_from_entry(record.event_log_entry)
)

def set_data_version(self, asset_key: AssetKey, data_version: DataVersion) -> None:
Expand Down Expand Up @@ -1297,8 +1344,12 @@ def requires_typed_event_stream(self) -> bool:
def typed_event_stream_error_message(self) -> Optional[str]:
return self._step_execution_context.typed_event_stream_error_message

def set_requires_typed_event_stream(self, *, error_message: Optional[str] = None) -> None:
self._step_execution_context.set_requires_typed_event_stream(error_message=error_message)
def set_requires_typed_event_stream(
self, *, error_message: Optional[str] = None
) -> None:
self._step_execution_context.set_requires_typed_event_stream(
error_message=error_message
)


class AssetExecutionContext(OpExecutionContext):
Expand Down Expand Up @@ -1372,3 +1423,25 @@ def build_execution_context(
if context_annotation is AssetExecutionContext:
return AssetExecutionContext(step_context)
return OpExecutionContext(step_context)


# Module-level ContextVar holding the execution context most recently installed
# by enter_execution_context, or None (the declared default) when nothing has
# been set. It is read back through get_execution_context.
_current_context: ContextVar[Optional[OpExecutionContext]] = ContextVar(
"execution_context", default=None
)


@contextmanager
def enter_execution_context(
    step_context: StepExecutionContext,
) -> Iterator[OpExecutionContext]:
    """Build an execution context for ``step_context`` and expose it indirectly.

    The context produced by ``build_execution_context`` is stored in the
    module-level ``_current_context`` ContextVar for the duration of the
    ``with`` block, so that ``get_execution_context`` can return it without it
    being passed explicitly.

    Args:
        step_context: The step execution context to build the execution
            context from.

    Yields:
        The execution context that was built and installed.
    """
    execution_context = build_execution_context(step_context)
    reset_token = _current_context.set(execution_context)
    try:
        yield execution_context
    finally:
        # Restore whatever value was held before entering, even if the body
        # of the ``with`` block raised.
        _current_context.reset(reset_token)


def get_execution_context() -> Optional[OpExecutionContext]:
    """Return the execution context currently stored in ``_current_context``.

    Returns:
        The context installed by an enclosing ``enter_execution_context``
        block, or ``None`` (the ContextVar's default) when no context has
        been set. Callers must handle the ``None`` case.
    """
    return _current_context.get()
Loading

0 comments on commit 1ea5b5c

Please sign in to comment.