Skip to content

Commit

Permalink
make input partition methods based on asset key
Browse files — browse the repository at this point in the history
  • Loading branch information
jamiedemaria committed Jan 4, 2024
1 parent 310782a commit 75ea1c6
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 73 deletions.
18 changes: 13 additions & 5 deletions python_modules/dagster/dagster/_core/execution/context/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -949,7 +949,10 @@ def self_dependent_asset(context: AssetExecutionContext, self_dependent_asset):
"""
return self._step_execution_context.asset_partition_key_range_for_input(input_name)
upstream_asset_key = self.asset_key_for_input(input_name)
return self._step_execution_context.asset_partition_key_range_for_upstream(
upstream_asset_key
)

@public
def asset_partition_key_for_input(self, input_name: str) -> str:
Expand Down Expand Up @@ -992,7 +995,8 @@ def self_dependent_asset(context: AssetExecutionContext, self_dependent_asset):
# "2023-08-20"
"""
return self._step_execution_context.asset_partition_key_for_input(input_name)
upstream_asset_key = self.asset_key_for_input(input_name)
return self._step_execution_context.asset_partition_key_for_upstream(upstream_asset_key)

@public
def asset_partitions_def_for_output(self, output_name: str = "result") -> PartitionsDefinition:
Expand Down Expand Up @@ -1208,9 +1212,10 @@ def self_dependent_asset(context: AssetExecutionContext, self_dependent_asset):
# running a backfill of the 2023-08-21 through 2023-08-25 partitions of this asset will log:
# ["2023-08-20", "2023-08-21", "2023-08-22", "2023-08-23", "2023-08-24"]
"""
upstream_asset_key = self.asset_key_for_input(input_name)
return list(
self._step_execution_context.asset_partitions_subset_for_input(
input_name
self._step_execution_context.asset_partitions_subset_for_upstream(
upstream_asset_key
).get_partition_keys()
)

Expand Down Expand Up @@ -1288,7 +1293,10 @@ def self_dependent_asset(context: AssetExecutionContext, self_dependent_asset):
# TimeWindow("2023-08-20", "2023-08-25")
"""
return self._step_execution_context.asset_partitions_time_window_for_input(input_name)
upstream_asset_key = self.asset_key_for_input(input_name)
return self._step_execution_context.asset_partitions_time_window_for_upstream(
upstream_asset_key
)

@public
@experimental
Expand Down
125 changes: 57 additions & 68 deletions python_modules/dagster/dagster/_core/execution/context/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,8 +700,8 @@ def for_input_manager(
node_handle=self.node_handle, input_name=name
)
asset_partitions_subset = (
self.asset_partitions_subset_for_input(name)
if self.has_asset_partitions_for_input(name)
self.asset_partitions_subset_for_upstream(asset_key)
if asset_key and self.has_asset_partitions_for_upstream(asset_key)
else None
)

Expand Down Expand Up @@ -1006,11 +1006,13 @@ def _fetch_input_asset_materialization_and_version_info(self, key: AssetKey) ->
)
storage_id = event.storage_id
# Input name will be none if this is an internal dep
input_name = self.job_def.asset_layer.input_for_asset_key(self.node_handle, key)
input_name = self.job_def.asset_layer.input_for_asset_key(
self.node_handle, key
) # TODO - see if we can check for internal dep another way
# Exclude AllPartitionMapping for now to avoid huge queries
if input_name and self.has_asset_partitions_for_input(input_name):
subset = self.asset_partitions_subset_for_input(
input_name, require_valid_partitions=False
if input_name and self.has_asset_partitions_for_upstream(key):
subset = self.asset_partitions_subset_for_upstream(
key, require_valid_partitions=False
)
input_keys = list(subset.get_partition_keys())

Expand All @@ -1034,6 +1036,7 @@ def _fetch_input_asset_materialization_and_version_info(self, key: AssetKey) ->
storage_id, data_version, event.run_id, event.timestamp
)

# TODO - here
def partition_mapping_for_input(self, input_name: str) -> Optional[PartitionMapping]:
asset_layer = self.job_def.asset_layer
upstream_asset_key = asset_layer.asset_key_for_input(self.node_handle, input_name)
Expand Down Expand Up @@ -1109,22 +1112,16 @@ def get_output_asset_keys(self) -> AbstractSet[AssetKey]:
output_keys.add(asset_info.key)
return output_keys

def has_asset_partitions_for_input(self, input_name: str) -> bool:
def has_asset_partitions_for_upstream(self, upstream_asset_key: AssetKey) -> bool:
asset_layer = self.job_def.asset_layer
upstream_asset_key = asset_layer.asset_key_for_input(self.node_handle, input_name)

return (
upstream_asset_key is not None
and asset_layer.partitions_def_for_asset(upstream_asset_key) is not None
)
return asset_layer.partitions_def_for_asset(upstream_asset_key) is not None

def asset_partition_key_range_for_input(self, input_name: str) -> PartitionKeyRange:
subset = self.asset_partitions_subset_for_input(input_name)
def asset_partition_key_range_for_upstream(
self, upstream_asset_key: AssetKey
) -> PartitionKeyRange:
subset = self.asset_partitions_subset_for_upstream(upstream_asset_key)

asset_layer = self.job_def.asset_layer
upstream_asset_key = check.not_none(
asset_layer.asset_key_for_input(self.node_handle, input_name)
)
upstream_asset_partitions_def = check.not_none(
asset_layer.partitions_def_for_asset(upstream_asset_key)
)
Expand All @@ -1142,65 +1139,61 @@ def asset_partition_key_range_for_input(self, input_name: str) -> PartitionKeyRa

return partition_key_ranges[0]

def asset_partitions_subset_for_input(
self, input_name: str, *, require_valid_partitions: bool = True
def asset_partitions_subset_for_upstream(
self, upstream_asset_key: AssetKey, *, require_valid_partitions: bool = True
) -> PartitionsSubset:
asset_layer = self.job_def.asset_layer
assets_def = asset_layer.assets_def_for_node(self.node_handle)
upstream_asset_key = asset_layer.asset_key_for_input(self.node_handle, input_name)

if upstream_asset_key is not None:
upstream_asset_partitions_def = asset_layer.partitions_def_for_asset(upstream_asset_key)
upstream_asset_partitions_def = asset_layer.partitions_def_for_asset(upstream_asset_key)

if upstream_asset_partitions_def is not None:
partitions_def = assets_def.partitions_def if assets_def else None
partitions_subset = (
partitions_def.empty_subset().with_partition_key_range(
partitions_def,
self.asset_partition_key_range,
dynamic_partitions_store=self.instance,
)
if partitions_def
else None
if upstream_asset_partitions_def is not None:
partitions_def = assets_def.partitions_def if assets_def else None
partitions_subset = (
partitions_def.empty_subset().with_partition_key_range(
partitions_def,
self.asset_partition_key_range,
dynamic_partitions_store=self.instance,
)
partition_mapping = infer_partition_mapping(
asset_layer.partition_mapping_for_node_input(
self.node_handle, upstream_asset_key
),
if partitions_def
else None
)
partition_mapping = infer_partition_mapping(
asset_layer.partition_mapping_for_node_input(self.node_handle, upstream_asset_key),
partitions_def,
upstream_asset_partitions_def,
)
mapped_partitions_result = (
partition_mapping.get_upstream_mapped_partitions_result_for_partitions(
partitions_subset,
partitions_def,
upstream_asset_partitions_def,
dynamic_partitions_store=self.instance,
)
mapped_partitions_result = (
partition_mapping.get_upstream_mapped_partitions_result_for_partitions(
partitions_subset,
partitions_def,
upstream_asset_partitions_def,
dynamic_partitions_store=self.instance,
)
)
)

if (
require_valid_partitions
and mapped_partitions_result.required_but_nonexistent_partition_keys
):
raise DagsterInvariantViolationError(
f"Partition key range {self.asset_partition_key_range} in"
f" {self.node_handle.name} depends on invalid partition keys"
f" {mapped_partitions_result.required_but_nonexistent_partition_keys} in"
f" upstream asset {upstream_asset_key}"
)
if (
require_valid_partitions
and mapped_partitions_result.required_but_nonexistent_partition_keys
):
raise DagsterInvariantViolationError(
f"Partition key range {self.asset_partition_key_range} in"
f" {self.node_handle.name} depends on invalid partition keys"
f" {mapped_partitions_result.required_but_nonexistent_partition_keys} in"
f" upstream asset {upstream_asset_key}"
)

return mapped_partitions_result.partitions_subset
return mapped_partitions_result.partitions_subset

check.failed("The input has no asset partitions")
check.failed(f"The asset {upstream_asset_key.to_user_string()} has no asset partitions")

def asset_partition_key_for_input(self, input_name: str) -> str:
start, end = self.asset_partition_key_range_for_input(input_name)
def asset_partition_key_for_upstream(self, upstream_asset_key: AssetKey) -> str:
start, end = self.asset_partition_key_range_for_upstream(upstream_asset_key)
if start == end:
return start
else:
check.failed(
f"Tried to access partition key for input '{input_name}' of step '{self.step.key}',"
f"Tried to access partition key for '{upstream_asset_key.to_user_string()}' of step '{self.step.key}',"
f" but the step input has a partition range: '{start}' to '{end}'."
)

Expand Down Expand Up @@ -1268,19 +1261,15 @@ def asset_partitions_time_window_for_output(self, output_name: str) -> TimeWindo
partitions_def.time_window_for_partition_key(partition_key_range.end).end,
)

def asset_partitions_time_window_for_input(self, input_name: str) -> TimeWindow:
"""The time window for the partitions of the asset correponding to the given input.
def asset_partitions_time_window_for_upstream(self, upstream_asset_key: AssetKey) -> TimeWindow:
"""The time window for the partitions of the asset corresponding to the given input.
Raises an error if either of the following are true:
- The input asset has no partitioning.
- The input asset is not partitioned with a TimeWindowPartitionsDefinition or a
MultiPartitionsDefinition with one time-partitioned dimension.
"""
asset_layer = self.job_def.asset_layer
upstream_asset_key = asset_layer.asset_key_for_input(self.node_handle, input_name)

if upstream_asset_key is None:
raise ValueError("The input has no corresponding asset")

upstream_asset_partitions_def = asset_layer.partitions_def_for_asset(upstream_asset_key)

Expand All @@ -1292,15 +1281,15 @@ def asset_partitions_time_window_for_input(self, input_name: str) -> TimeWindow:

if not has_one_dimension_time_window_partitioning(upstream_asset_partitions_def):
raise ValueError(
"Tried to get asset partitions for an input that correponds to a partitioned "
"Tried to get asset partitions for an input that corresponds to a partitioned "
"asset that is not time-partitioned."
)

upstream_asset_partitions_def = cast(
Union[TimeWindowPartitionsDefinition, MultiPartitionsDefinition],
upstream_asset_partitions_def,
)
partition_key_range = self.asset_partition_key_range_for_input(input_name)
partition_key_range = self.asset_partition_key_range_for_upstream(upstream_asset_key)

return TimeWindow(
upstream_asset_partitions_def.time_window_for_partition_key(
Expand Down

0 comments on commit 75ea1c6

Please sign in to comment.