Commit 02c11a8

[4/n subset refactor] Add whitelist_for_serdes to DefaultPartitionsSubset (#17703)

This PR makes the `DefaultPartitionsSubset` serializable by making it a
`NamedTuple` and removing `partitions_def` from it (as these partitions
defs cannot be serialized).

This causes a cascading set of changes (see the sketch after this list for the
resulting call pattern):
- `PartitionsSubset` methods such as `get_partition_keys_in_range` and
`get_partition_keys_not_in_subset` must now accept a `partitions_def` argument
- `PartitionMapping` methods must now accept the partitions def corresponding to
a partitions subset, since the partitions def is otherwise inaccessible
- Subclassing named tuples doesn't work well since you can't override methods,
so this PR removes `MultiPartitionsSubset` and modifies callsites to convert
partition keys to `MultiPartitionKey`s where needed
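
A minimal sketch of the call pattern at this point in the refactor series. The
definitions and variable names (`letters`, `multi`, etc.) are illustrative, not
taken from this PR, and signatures may shift in later PRs of the series:

```python
from dagster import MultiPartitionsDefinition, StaticPartitionsDefinition

# A subset no longer carries its partitions definition, so the definition
# must be passed back in when querying the subset.
letters = StaticPartitionsDefinition(["a", "b", "c"])
subset = letters.empty_subset().with_partition_keys(["a"])
unmaterialized = subset.get_partition_keys_not_in_subset(partitions_def=letters)

# With MultiPartitionsSubset removed, multipartitioned subsets hold plain key
# strings; callsites convert them back to MultiPartitionKeys when needed.
multi = MultiPartitionsDefinition(
    {"letter": letters, "number": StaticPartitionsDefinition(["1", "2"])}
)
multi_subset = multi.subset_with_partition_keys(multi.get_partition_keys())
for key in multi_subset.get_partition_keys():
    multi_key = multi.get_partition_key_from_str(key)
    print(multi_key.keys_by_dimension["letter"], multi_key.keys_by_dimension["number"])
```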
clairelin135 authored Nov 16, 2023
1 parent a0656fc commit 02c11a8
Showing 30 changed files with 461 additions and 273 deletions.
(first changed file)

@@ -20,18 +20,13 @@
DagsterEventType,
DagsterInstance,
EventRecordsFilter,
MultiPartitionKey,
MultiPartitionsDefinition,
_check as check,
)
from dagster._core.definitions.data_time import CachingDataTimeResolver
from dagster._core.definitions.external_asset_graph import ExternalAssetGraph
from dagster._core.definitions.multi_dimensional_partitions import (
MultiPartitionsSubset,
)
from dagster._core.definitions.partition import (
CachingDynamicPartitionsLoader,
DefaultPartitionsSubset,
PartitionsDefinition,
PartitionsSubset,
)
@@ -420,6 +415,7 @@ def build_partition_statuses(
materialized_partitions_subset: Optional[PartitionsSubset],
failed_partitions_subset: Optional[PartitionsSubset],
in_progress_partitions_subset: Optional[PartitionsSubset],
partitions_def: Optional[PartitionsDefinition],
) -> Union[
"GrapheneTimePartitionStatuses",
"GrapheneDefaultPartitionStatuses",
@@ -481,14 +477,15 @@ def build_partition_statuses(
)
)
return GrapheneTimePartitionStatuses(ranges=graphene_ranges)
elif isinstance(materialized_partitions_subset, MultiPartitionsSubset):
elif isinstance(partitions_def, MultiPartitionsDefinition):
return get_2d_run_length_encoded_partitions(
dynamic_partitions_store,
materialized_partitions_subset,
failed_partitions_subset,
in_progress_partitions_subset,
partitions_def,
)
elif isinstance(materialized_partitions_subset, DefaultPartitionsSubset):
elif partitions_def:
materialized_keys = materialized_partitions_subset.get_partition_keys()
failed_keys = failed_partitions_subset.get_partition_keys()
in_progress_keys = in_progress_partitions_subset.get_partition_keys()
@@ -499,7 +496,7 @@ def build_partition_statuses(
- set(in_progress_keys),
failedPartitions=failed_keys,
unmaterializedPartitions=materialized_partitions_subset.get_partition_keys_not_in_subset(
dynamic_partitions_store=dynamic_partitions_store
partitions_def=partitions_def, dynamic_partitions_store=dynamic_partitions_store
),
materializingPartitions=in_progress_keys,
)
@@ -512,57 +509,53 @@ def get_2d_run_length_encoded_partitions(
materialized_partitions_subset: PartitionsSubset,
failed_partitions_subset: PartitionsSubset,
in_progress_partitions_subset: PartitionsSubset,
partitions_def: MultiPartitionsDefinition,
) -> "GrapheneMultiPartitionStatuses":
from ..schema.pipelines.pipeline import (
GrapheneMultiPartitionRangeStatuses,
GrapheneMultiPartitionStatuses,
)

if (
not isinstance(materialized_partitions_subset.partitions_def, MultiPartitionsDefinition)
or not isinstance(failed_partitions_subset.partitions_def, MultiPartitionsDefinition)
or not isinstance(in_progress_partitions_subset.partitions_def, MultiPartitionsDefinition)
):
check.failed("Can only fetch 2D run length encoded partitions for multipartitioned assets")
check.invariant(
isinstance(partitions_def, MultiPartitionsDefinition),
"Partitions definition should be multipartitioned",
)

primary_dim = materialized_partitions_subset.partitions_def.primary_dimension
secondary_dim = materialized_partitions_subset.partitions_def.secondary_dimension
primary_dim = partitions_def.primary_dimension
secondary_dim = partitions_def.secondary_dimension

dim2_materialized_partition_subset_by_dim1: Dict[str, PartitionsSubset] = defaultdict(
lambda: secondary_dim.partitions_def.empty_subset()
)
for partition_key in cast(
Sequence[MultiPartitionKey], materialized_partitions_subset.get_partition_keys()
):
for partition_key in materialized_partitions_subset.get_partition_keys():
multipartition_key = partitions_def.get_partition_key_from_str(partition_key)
dim2_materialized_partition_subset_by_dim1[
partition_key.keys_by_dimension[primary_dim.name]
multipartition_key.keys_by_dimension[primary_dim.name]
] = dim2_materialized_partition_subset_by_dim1[
partition_key.keys_by_dimension[primary_dim.name]
].with_partition_keys([partition_key.keys_by_dimension[secondary_dim.name]])
multipartition_key.keys_by_dimension[primary_dim.name]
].with_partition_keys([multipartition_key.keys_by_dimension[secondary_dim.name]])

dim2_failed_partition_subset_by_dim1: Dict[str, PartitionsSubset] = defaultdict(
lambda: secondary_dim.partitions_def.empty_subset()
)
for partition_key in cast(
Sequence[MultiPartitionKey], failed_partitions_subset.get_partition_keys()
):
for partition_key in failed_partitions_subset.get_partition_keys():
multipartition_key = partitions_def.get_partition_key_from_str(partition_key)
dim2_failed_partition_subset_by_dim1[
partition_key.keys_by_dimension[primary_dim.name]
multipartition_key.keys_by_dimension[primary_dim.name]
] = dim2_failed_partition_subset_by_dim1[
partition_key.keys_by_dimension[primary_dim.name]
].with_partition_keys([partition_key.keys_by_dimension[secondary_dim.name]])
multipartition_key.keys_by_dimension[primary_dim.name]
].with_partition_keys([multipartition_key.keys_by_dimension[secondary_dim.name]])

dim2_in_progress_partition_subset_by_dim1: Dict[str, PartitionsSubset] = defaultdict(
lambda: secondary_dim.partitions_def.empty_subset()
)
for partition_key in cast(
Sequence[MultiPartitionKey], in_progress_partitions_subset.get_partition_keys()
):
for partition_key in in_progress_partitions_subset.get_partition_keys():
multipartition_key = partitions_def.get_partition_key_from_str(partition_key)
dim2_in_progress_partition_subset_by_dim1[
partition_key.keys_by_dimension[primary_dim.name]
multipartition_key.keys_by_dimension[primary_dim.name]
] = dim2_in_progress_partition_subset_by_dim1[
partition_key.keys_by_dimension[primary_dim.name]
].with_partition_keys([partition_key.keys_by_dimension[secondary_dim.name]])
multipartition_key.keys_by_dimension[primary_dim.name]
].with_partition_keys([multipartition_key.keys_by_dimension[secondary_dim.name]])

materialized_2d_ranges = []

@@ -626,6 +619,7 @@ def get_2d_run_length_encoded_partitions(
dim2_materialized_partition_subset_by_dim1[start_key],
dim2_failed_partition_subset_by_dim1[start_key],
dim2_in_progress_partition_subset_by_dim1[start_key],
secondary_dim.partitions_def,
),
)
)
(next changed file)

@@ -936,6 +936,12 @@ def resolve_assetPartitionStatuses(
if not self._dynamic_partitions_loader:
check.failed("dynamic_partitions_loader must be provided to get partition keys")

partitions_def = (
self._external_asset_node.partitions_def_data.get_partitions_definition()
if self._external_asset_node.partitions_def_data
else None
)

(
materialized_partition_subset,
failed_partition_subset,
@@ -944,18 +950,15 @@
graphene_info.context.instance,
asset_key,
self._dynamic_partitions_loader,
(
self._external_asset_node.partitions_def_data.get_partitions_definition()
if self._external_asset_node.partitions_def_data
else None
),
partitions_def,
)

return build_partition_statuses(
self._dynamic_partitions_loader,
materialized_partition_subset,
failed_partition_subset,
in_progress_subset,
partitions_def,
)

def resolve_partitionStats(
(next changed file)

@@ -133,7 +133,9 @@ def __init__(self, partition_subset: PartitionsSubset):
if isinstance(partition_subset, BaseTimeWindowPartitionsSubset):
ranges = [
GraphenePartitionKeyRange(start, end)
for start, end in partition_subset.get_partition_key_ranges()
for start, end in partition_subset.get_partition_key_ranges(
partition_subset.partitions_def
)
]
partition_keys = None
else: # Default partitions subset
(next changed file)

@@ -684,7 +684,7 @@ def _build_run_requests_with_backfill_policy(
run_requests = []
partition_subset = partitions_def.subset_with_partition_keys(partition_keys)
partition_key_ranges = partition_subset.get_partition_key_ranges(
dynamic_partitions_store=dynamic_partitions_store
partitions_def, dynamic_partitions_store=dynamic_partitions_store
)
for partition_key_range in partition_key_ranges:
# We might resolve more than one partition key range for the given partition keys.
(next changed file)

@@ -80,6 +80,7 @@ def get_unhandled_partitions(
)

return handled_subset.get_partition_keys_not_in_subset(
partitions_def=partitions_def,
current_time=current_time,
dynamic_partitions_store=dynamic_partitions_store,
)
(next changed file)

@@ -363,6 +363,7 @@ def get_child_partition_keys_of_parent(
partition_mapping = self.get_partition_mapping(child_asset_key, parent_asset_key)
child_partitions_subset = partition_mapping.get_downstream_partitions_for_partitions(
parent_partitions_def.empty_subset().with_partition_keys([parent_partition_key]),
parent_partitions_def,
downstream_partitions_def=child_partitions_def,
dynamic_partitions_store=dynamic_partitions_store,
current_time=current_time,
@@ -436,7 +437,7 @@ def get_parent_partition_keys_for_child(
"""
partition_key = check.opt_str_param(partition_key, "partition_key")

child_partitions_def = self.get_partitions_def(child_asset_key)
child_partitions_def = cast(PartitionsDefinition, self.get_partitions_def(child_asset_key))
parent_partitions_def = self.get_partitions_def(parent_asset_key)

if parent_partitions_def is None:
@@ -448,12 +449,11 @@

return partition_mapping.get_upstream_mapped_partitions_result_for_partitions(
(
cast(PartitionsDefinition, child_partitions_def).subset_with_partition_keys(
[partition_key]
)
child_partitions_def.subset_with_partition_keys([partition_key])
if partition_key
else None
),
downstream_partitions_def=child_partitions_def,
upstream_partitions_def=parent_partitions_def,
dynamic_partitions_store=dynamic_partitions_store,
current_time=current_time,
@@ -610,6 +610,7 @@ def bfs_filter_subsets(
child_partitions_subset = (
partition_mapping.get_downstream_partitions_for_partitions(
partitions_subset,
check.not_none(self.get_partitions_def(asset_key)),
downstream_partitions_def=child_partitions_def,
dynamic_partitions_store=dynamic_partitions_store,
current_time=current_time,
(next changed file)

@@ -94,14 +94,18 @@ def to_storage_dict(
for key, value in self.partitions_subsets_by_asset_key.items()
},
"serializable_partitions_def_ids_by_asset_key": {
key.to_user_string(): value.partitions_def.get_serializable_unique_identifier(
key.to_user_string(): check.not_none(
self._asset_graph.get_partitions_def(key)
).get_serializable_unique_identifier(
dynamic_partitions_store=dynamic_partitions_store
)
for key, value in self.partitions_subsets_by_asset_key.items()
for key, _ in self.partitions_subsets_by_asset_key.items()
},
"partitions_def_class_names_by_asset_key": {
key.to_user_string(): value.partitions_def.__class__.__name__
for key, value in self.partitions_subsets_by_asset_key.items()
key.to_user_string(): check.not_none(
self._asset_graph.get_partitions_def(key)
).__class__.__name__
for key, _ in self.partitions_subsets_by_asset_key.items()
},
"non_partitioned_asset_keys": [
key.to_user_string() for key in self._non_partitioned_asset_keys
(next changed file)

@@ -718,6 +718,7 @@ def get_downstream_partition_keys(
downstream_partition_key_subset = (
partition_mapping.get_downstream_partitions_for_partitions(
from_asset.partitions_def.empty_subset().with_partition_keys([partition_key]),
from_asset.partitions_def,
downstream_partitions_def=to_partitions_def,
dynamic_partitions_store=self.instance,
)
(next changed file)

@@ -4,7 +4,6 @@
from functools import lru_cache, reduce
from typing import (
Dict,
Iterable,
List,
Mapping,
NamedTuple,
@@ -216,7 +215,7 @@ def __init__(self, partitions_defs: Mapping[str, PartitionsDefinition]):

@property
def partitions_subset_class(self) -> Type["PartitionsSubset"]:
return MultiPartitionsSubset
return DefaultPartitionsSubset

def get_serializable_unique_identifier(
self, dynamic_partitions_store: Optional[DynamicPartitionsStore] = None
@@ -509,33 +508,6 @@ def get_num_partitions(
return reduce(lambda x, y: x * y, dimension_counts, 1)


class MultiPartitionsSubset(DefaultPartitionsSubset):
def __init__(
self,
partitions_def: MultiPartitionsDefinition,
subset: Optional[Set[str]] = None,
):
check.inst_param(partitions_def, "partitions_def", MultiPartitionsDefinition)
subset = (
set(
[
partitions_def.get_partition_key_from_str(key)
for key in subset
if MULTIPARTITION_KEY_DELIMITER in key
]
)
if subset
else set()
)
super(MultiPartitionsSubset, self).__init__(partitions_def, subset)

def with_partition_keys(self, partition_keys: Iterable[str]) -> "MultiPartitionsSubset":
return MultiPartitionsSubset(
cast(MultiPartitionsDefinition, self._partitions_def),
self._subset | set(partition_keys),
)


def get_tags_from_multi_partition_key(multi_partition_key: MultiPartitionKey) -> Mapping[str, str]:
check.inst_param(multi_partition_key, "multi_partition_key", MultiPartitionKey)

(remaining changed files not shown)