From 75ea1c610bb639480de2edafd51ad90f29a88c74 Mon Sep 17 00:00:00 2001 From: JamieDeMaria Date: Thu, 4 Jan 2024 12:02:35 -0800 Subject: [PATCH] make input partition methods based on asset key --- .../_core/execution/context/compute.py | 18 ++- .../dagster/_core/execution/context/system.py | 125 ++++++++---------- 2 files changed, 70 insertions(+), 73 deletions(-) diff --git a/python_modules/dagster/dagster/_core/execution/context/compute.py b/python_modules/dagster/dagster/_core/execution/context/compute.py index 4d2e23e394fe8..17ee47e324fa3 100644 --- a/python_modules/dagster/dagster/_core/execution/context/compute.py +++ b/python_modules/dagster/dagster/_core/execution/context/compute.py @@ -949,7 +949,10 @@ def self_dependent_asset(context: AssetExecutionContext, self_dependent_asset): """ - return self._step_execution_context.asset_partition_key_range_for_input(input_name) + upstream_asset_key = self.asset_key_for_input(input_name) + return self._step_execution_context.asset_partition_key_range_for_upstream( + upstream_asset_key + ) @public def asset_partition_key_for_input(self, input_name: str) -> str: @@ -992,7 +995,8 @@ def self_dependent_asset(context: AssetExecutionContext, self_dependent_asset): # "2023-08-20" """ - return self._step_execution_context.asset_partition_key_for_input(input_name) + upstream_asset_key = self.asset_key_for_input(input_name) + return self._step_execution_context.asset_partition_key_for_upstream(upstream_asset_key) @public def asset_partitions_def_for_output(self, output_name: str = "result") -> PartitionsDefinition: @@ -1208,9 +1212,10 @@ def self_dependent_asset(context: AssetExecutionContext, self_dependent_asset): # running a backfill of the 2023-08-21 through 2023-08-25 partitions of this asset will log: # ["2023-08-20", "2023-08-21", "2023-08-22", "2023-08-23", "2023-08-24"] """ + upstream_asset_key = self.asset_key_for_input(input_name) return list( - self._step_execution_context.asset_partitions_subset_for_input( - input_name + self._step_execution_context.asset_partitions_subset_for_upstream( + upstream_asset_key ).get_partition_keys() ) @@ -1288,7 +1293,10 @@ def self_dependent_asset(context: AssetExecutionContext, self_dependent_asset): # TimeWindow("2023-08-20", "2023-08-25") """ - return self._step_execution_context.asset_partitions_time_window_for_input(input_name) + upstream_asset_key = self.asset_key_for_input(input_name) + return self._step_execution_context.asset_partitions_time_window_for_upstream( + upstream_asset_key + ) @public @experimental diff --git a/python_modules/dagster/dagster/_core/execution/context/system.py b/python_modules/dagster/dagster/_core/execution/context/system.py index bbf09a89fd4a1..c63dea3d8dcbb 100644 --- a/python_modules/dagster/dagster/_core/execution/context/system.py +++ b/python_modules/dagster/dagster/_core/execution/context/system.py @@ -700,8 +700,8 @@ def for_input_manager( node_handle=self.node_handle, input_name=name ) asset_partitions_subset = ( - self.asset_partitions_subset_for_input(name) - if self.has_asset_partitions_for_input(name) + self.asset_partitions_subset_for_upstream(asset_key) + if asset_key and self.has_asset_partitions_for_upstream(asset_key) else None ) @@ -1006,11 +1006,13 @@ def _fetch_input_asset_materialization_and_version_info(self, key: AssetKey) -> ) storage_id = event.storage_id # Input name will be none if this is an internal dep - input_name = self.job_def.asset_layer.input_for_asset_key(self.node_handle, key) + input_name = self.job_def.asset_layer.input_for_asset_key( + self.node_handle, key + ) # TODO - see if we can check for internal dep another way # Exclude AllPartitionMapping for now to avoid huge queries - if input_name and self.has_asset_partitions_for_input(input_name): - subset = self.asset_partitions_subset_for_input( - input_name, require_valid_partitions=False + if input_name and self.has_asset_partitions_for_upstream(key): + subset = self.asset_partitions_subset_for_upstream( + key, require_valid_partitions=False ) input_keys = list(subset.get_partition_keys()) @@ -1034,6 +1036,7 @@ def _fetch_input_asset_materialization_and_version_info(self, key: AssetKey) -> storage_id, data_version, event.run_id, event.timestamp ) + # TODO - here def partition_mapping_for_input(self, input_name: str) -> Optional[PartitionMapping]: asset_layer = self.job_def.asset_layer upstream_asset_key = asset_layer.asset_key_for_input(self.node_handle, input_name) @@ -1109,22 +1112,16 @@ def get_output_asset_keys(self) -> AbstractSet[AssetKey]: output_keys.add(asset_info.key) return output_keys - def has_asset_partitions_for_input(self, input_name: str) -> bool: + def has_asset_partitions_for_upstream(self, upstream_asset_key: AssetKey) -> bool: asset_layer = self.job_def.asset_layer - upstream_asset_key = asset_layer.asset_key_for_input(self.node_handle, input_name) - - return ( - upstream_asset_key is not None - and asset_layer.partitions_def_for_asset(upstream_asset_key) is not None - ) + return asset_layer.partitions_def_for_asset(upstream_asset_key) is not None - def asset_partition_key_range_for_input(self, input_name: str) -> PartitionKeyRange: - subset = self.asset_partitions_subset_for_input(input_name) + def asset_partition_key_range_for_upstream( + self, upstream_asset_key: AssetKey + ) -> PartitionKeyRange: + subset = self.asset_partitions_subset_for_upstream(upstream_asset_key) asset_layer = self.job_def.asset_layer - upstream_asset_key = check.not_none( - asset_layer.asset_key_for_input(self.node_handle, input_name) - ) upstream_asset_partitions_def = check.not_none( asset_layer.partitions_def_for_asset(upstream_asset_key) ) @@ -1142,65 +1139,61 @@ def asset_partition_key_range_for_input(self, input_name: str) -> PartitionKeyRa return partition_key_ranges[0] - def asset_partitions_subset_for_input( - self, input_name: str, *, require_valid_partitions: bool = True + def asset_partitions_subset_for_upstream( + self, upstream_asset_key: AssetKey, *, require_valid_partitions: bool = True ) -> PartitionsSubset: asset_layer = self.job_def.asset_layer assets_def = asset_layer.assets_def_for_node(self.node_handle) - upstream_asset_key = asset_layer.asset_key_for_input(self.node_handle, input_name) - if upstream_asset_key is not None: - upstream_asset_partitions_def = asset_layer.partitions_def_for_asset(upstream_asset_key) + upstream_asset_partitions_def = asset_layer.partitions_def_for_asset(upstream_asset_key) - if upstream_asset_partitions_def is not None: - partitions_def = assets_def.partitions_def if assets_def else None - partitions_subset = ( - partitions_def.empty_subset().with_partition_key_range( - partitions_def, - self.asset_partition_key_range, - dynamic_partitions_store=self.instance, - ) - if partitions_def - else None + if upstream_asset_partitions_def is not None: + partitions_def = assets_def.partitions_def if assets_def else None + partitions_subset = ( + partitions_def.empty_subset().with_partition_key_range( + partitions_def, + self.asset_partition_key_range, + dynamic_partitions_store=self.instance, ) - partition_mapping = infer_partition_mapping( - asset_layer.partition_mapping_for_node_input( - self.node_handle, upstream_asset_key - ), + if partitions_def + else None + ) + partition_mapping = infer_partition_mapping( + asset_layer.partition_mapping_for_node_input(self.node_handle, upstream_asset_key), + partitions_def, + upstream_asset_partitions_def, + ) + mapped_partitions_result = ( + partition_mapping.get_upstream_mapped_partitions_result_for_partitions( + partitions_subset, partitions_def, upstream_asset_partitions_def, + dynamic_partitions_store=self.instance, ) - mapped_partitions_result = ( - partition_mapping.get_upstream_mapped_partitions_result_for_partitions( - partitions_subset, - partitions_def, - upstream_asset_partitions_def, - dynamic_partitions_store=self.instance, - ) - ) + ) - if ( - require_valid_partitions - and mapped_partitions_result.required_but_nonexistent_partition_keys - ): - raise DagsterInvariantViolationError( - f"Partition key range {self.asset_partition_key_range} in" - f" {self.node_handle.name} depends on invalid partition keys" - f" {mapped_partitions_result.required_but_nonexistent_partition_keys} in" - f" upstream asset {upstream_asset_key}" - ) + if ( + require_valid_partitions + and mapped_partitions_result.required_but_nonexistent_partition_keys + ): + raise DagsterInvariantViolationError( + f"Partition key range {self.asset_partition_key_range} in" + f" {self.node_handle.name} depends on invalid partition keys" + f" {mapped_partitions_result.required_but_nonexistent_partition_keys} in" + f" upstream asset {upstream_asset_key}" + ) - return mapped_partitions_result.partitions_subset + return mapped_partitions_result.partitions_subset - check.failed("The input has no asset partitions") + check.failed(f"The asset {upstream_asset_key.to_user_string()} has no asset partitions") - def asset_partition_key_for_input(self, input_name: str) -> str: - start, end = self.asset_partition_key_range_for_input(input_name) + def asset_partition_key_for_upstream(self, upstream_asset_key: AssetKey) -> str: + start, end = self.asset_partition_key_range_for_upstream(upstream_asset_key) if start == end: return start else: check.failed( - f"Tried to access partition key for input '{input_name}' of step '{self.step.key}'," + f"Tried to access partition key for '{upstream_asset_key.to_user_string()}' of step '{self.step.key}'," f" but the step input has a partition range: '{start}' to '{end}'." ) @@ -1268,8 +1261,8 @@ def asset_partitions_time_window_for_output(self, output_name: str) -> TimeWindo partitions_def.time_window_for_partition_key(partition_key_range.end).end, ) - def asset_partitions_time_window_for_input(self, input_name: str) -> TimeWindow: - """The time window for the partitions of the asset correponding to the given input. + def asset_partitions_time_window_for_upstream(self, upstream_asset_key: AssetKey) -> TimeWindow: + """The time window for the partitions of the asset corresponding to the given input. Raises an error if either of the following are true: - The input asset has no partitioning. @@ -1277,10 +1270,6 @@ def asset_partitions_time_window_for_input(self, input_name: str) -> TimeWindow: MultiPartitionsDefinition with one time-partitioned dimension. """ asset_layer = self.job_def.asset_layer - upstream_asset_key = asset_layer.asset_key_for_input(self.node_handle, input_name) - - if upstream_asset_key is None: - raise ValueError("The input has no corresponding asset") upstream_asset_partitions_def = asset_layer.partitions_def_for_asset(upstream_asset_key) @@ -1292,7 +1281,7 @@ def asset_partitions_time_window_for_input(self, input_name: str) -> TimeWindow: if not has_one_dimension_time_window_partitioning(upstream_asset_partitions_def): raise ValueError( - "Tried to get asset partitions for an input that correponds to a partitioned " + "Tried to get asset partitions for an input that corresponds to a partitioned " "asset that is not time-partitioned." ) @@ -1300,7 +1289,7 @@ def asset_partitions_time_window_for_input(self, input_name: str) -> TimeWindow: Union[TimeWindowPartitionsDefinition, MultiPartitionsDefinition], upstream_asset_partitions_def, ) - partition_key_range = self.asset_partition_key_range_for_input(input_name) + partition_key_range = self.asset_partition_key_range_for_upstream(upstream_asset_key) return TimeWindow( upstream_asset_partitions_def.time_window_for_partition_key(