Skip to content

Commit

Permalink
first stab
Browse files Browse the repository at this point in the history
continue

time window partitions subset changes

asset backfill serialization

partition mapping update

continue refactor

fix more tests

more test fixes

fix partition mapping tests

adjust test

fix more tests

add tests
  • Loading branch information
clairelin135 committed Nov 3, 2023
1 parent 263cd3b commit c4e4110
Show file tree
Hide file tree
Showing 30 changed files with 413 additions and 266 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,13 @@
DagsterEventType,
DagsterInstance,
EventRecordsFilter,
MultiPartitionKey,
MultiPartitionsDefinition,
_check as check,
)
from dagster._core.definitions.data_time import CachingDataTimeResolver
from dagster._core.definitions.external_asset_graph import ExternalAssetGraph
from dagster._core.definitions.multi_dimensional_partitions import (
MultiPartitionsSubset,
)
from dagster._core.definitions.partition import (
CachingDynamicPartitionsLoader,
DefaultPartitionsSubset,
PartitionsDefinition,
PartitionsSubset,
)
Expand Down Expand Up @@ -420,6 +415,7 @@ def build_partition_statuses(
materialized_partitions_subset: Optional[PartitionsSubset],
failed_partitions_subset: Optional[PartitionsSubset],
in_progress_partitions_subset: Optional[PartitionsSubset],
partitions_def: Optional[PartitionsDefinition],
) -> Union[
"GrapheneTimePartitionStatuses",
"GrapheneDefaultPartitionStatuses",
Expand Down Expand Up @@ -469,7 +465,7 @@ def build_partition_statuses(
graphene_ranges = []
for r in ranges:
partition_key_range = cast(
TimeWindowPartitionsDefinition, materialized_partitions_subset.partitions_def
TimeWindowPartitionsDefinition, partitions_def
).get_partition_key_range_for_time_window(r.time_window)
graphene_ranges.append(
GrapheneTimePartitionRangeStatus(
Expand All @@ -481,14 +477,15 @@ def build_partition_statuses(
)
)
return GrapheneTimePartitionStatuses(ranges=graphene_ranges)
elif isinstance(materialized_partitions_subset, MultiPartitionsSubset):
elif isinstance(partitions_def, MultiPartitionsDefinition):
return get_2d_run_length_encoded_partitions(
dynamic_partitions_store,
materialized_partitions_subset,
failed_partitions_subset,
in_progress_partitions_subset,
partitions_def,
)
elif isinstance(materialized_partitions_subset, DefaultPartitionsSubset):
elif partitions_def:
materialized_keys = materialized_partitions_subset.get_partition_keys()
failed_keys = failed_partitions_subset.get_partition_keys()
in_progress_keys = in_progress_partitions_subset.get_partition_keys()
Expand All @@ -499,7 +496,7 @@ def build_partition_statuses(
- set(in_progress_keys),
failedPartitions=failed_keys,
unmaterializedPartitions=materialized_partitions_subset.get_partition_keys_not_in_subset(
dynamic_partitions_store=dynamic_partitions_store
partitions_def=partitions_def, dynamic_partitions_store=dynamic_partitions_store
),
materializingPartitions=in_progress_keys,
)
Expand All @@ -512,57 +509,51 @@ def get_2d_run_length_encoded_partitions(
materialized_partitions_subset: PartitionsSubset,
failed_partitions_subset: PartitionsSubset,
in_progress_partitions_subset: PartitionsSubset,
partitions_def: MultiPartitionsDefinition,
) -> "GrapheneMultiPartitionStatuses":
from ..schema.pipelines.pipeline import (
GrapheneMultiPartitionRangeStatuses,
GrapheneMultiPartitionStatuses,
)

if (
not isinstance(materialized_partitions_subset.partitions_def, MultiPartitionsDefinition)
or not isinstance(failed_partitions_subset.partitions_def, MultiPartitionsDefinition)
or not isinstance(in_progress_partitions_subset.partitions_def, MultiPartitionsDefinition)
):
if not isinstance(partitions_def, MultiPartitionsDefinition):
check.failed("Can only fetch 2D run length encoded partitions for multipartitioned assets")

primary_dim = materialized_partitions_subset.partitions_def.primary_dimension
secondary_dim = materialized_partitions_subset.partitions_def.secondary_dimension
primary_dim = partitions_def.primary_dimension
secondary_dim = partitions_def.secondary_dimension

dim2_materialized_partition_subset_by_dim1: Dict[str, PartitionsSubset] = defaultdict(
lambda: secondary_dim.partitions_def.empty_subset()
)
for partition_key in cast(
Sequence[MultiPartitionKey], materialized_partitions_subset.get_partition_keys()
):
for partition_key in materialized_partitions_subset.get_partition_keys():
multipartition_key = partitions_def.get_partition_key_from_str(partition_key)
dim2_materialized_partition_subset_by_dim1[
partition_key.keys_by_dimension[primary_dim.name]
multipartition_key.keys_by_dimension[primary_dim.name]
] = dim2_materialized_partition_subset_by_dim1[
partition_key.keys_by_dimension[primary_dim.name]
].with_partition_keys([partition_key.keys_by_dimension[secondary_dim.name]])
multipartition_key.keys_by_dimension[primary_dim.name]
].with_partition_keys([multipartition_key.keys_by_dimension[secondary_dim.name]])

dim2_failed_partition_subset_by_dim1: Dict[str, PartitionsSubset] = defaultdict(
lambda: secondary_dim.partitions_def.empty_subset()
)
for partition_key in cast(
Sequence[MultiPartitionKey], failed_partitions_subset.get_partition_keys()
):
for partition_key in failed_partitions_subset.get_partition_keys():
multipartition_key = partitions_def.get_partition_key_from_str(partition_key)
dim2_failed_partition_subset_by_dim1[
partition_key.keys_by_dimension[primary_dim.name]
multipartition_key.keys_by_dimension[primary_dim.name]
] = dim2_failed_partition_subset_by_dim1[
partition_key.keys_by_dimension[primary_dim.name]
].with_partition_keys([partition_key.keys_by_dimension[secondary_dim.name]])
multipartition_key.keys_by_dimension[primary_dim.name]
].with_partition_keys([multipartition_key.keys_by_dimension[secondary_dim.name]])

dim2_in_progress_partition_subset_by_dim1: Dict[str, PartitionsSubset] = defaultdict(
lambda: secondary_dim.partitions_def.empty_subset()
)
for partition_key in cast(
Sequence[MultiPartitionKey], in_progress_partitions_subset.get_partition_keys()
):
for partition_key in in_progress_partitions_subset.get_partition_keys():
multipartition_key = partitions_def.get_partition_key_from_str(partition_key)
dim2_in_progress_partition_subset_by_dim1[
partition_key.keys_by_dimension[primary_dim.name]
multipartition_key.keys_by_dimension[primary_dim.name]
] = dim2_in_progress_partition_subset_by_dim1[
partition_key.keys_by_dimension[primary_dim.name]
].with_partition_keys([partition_key.keys_by_dimension[secondary_dim.name]])
multipartition_key.keys_by_dimension[primary_dim.name]
].with_partition_keys([multipartition_key.keys_by_dimension[secondary_dim.name]])

materialized_2d_ranges = []

Expand Down Expand Up @@ -626,6 +617,7 @@ def get_2d_run_length_encoded_partitions(
dim2_materialized_partition_subset_by_dim1[start_key],
dim2_failed_partition_subset_by_dim1[start_key],
dim2_in_progress_partition_subset_by_dim1[start_key],
secondary_dim.partitions_def,
),
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -934,6 +934,12 @@ def resolve_assetPartitionStatuses(
if not self._dynamic_partitions_loader:
check.failed("dynamic_partitions_loader must be provided to get partition keys")

partitions_def = (
self._external_asset_node.partitions_def_data.get_partitions_definition()
if self._external_asset_node.partitions_def_data
else None
)

(
materialized_partition_subset,
failed_partition_subset,
Expand All @@ -942,18 +948,15 @@ def resolve_assetPartitionStatuses(
graphene_info.context.instance,
asset_key,
self._dynamic_partitions_loader,
(
self._external_asset_node.partitions_def_data.get_partitions_definition()
if self._external_asset_node.partitions_def_data
else None
),
partitions_def,
)

return build_partition_statuses(
self._dynamic_partitions_loader,
materialized_partition_subset,
failed_partition_subset,
in_progress_subset,
partitions_def,
)

def resolve_partitionStats(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,9 @@ def __init__(self, partition_subset: PartitionsSubset):
if isinstance(partition_subset, BaseTimeWindowPartitionsSubset):
ranges = [
GraphenePartitionKeyRange(start, end)
for start, end in partition_subset.get_partition_key_ranges()
for start, end in partition_subset.get_partition_key_ranges(
partition_subset.get_partitions_def()
)
]
partition_keys = None
else: # Default partitions subset
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -714,7 +714,7 @@ def _build_run_requests_with_backfill_policy(
) -> Sequence[RunRequest]:
run_requests = []
partition_subset = partitions_def.subset_with_partition_keys(partition_keys)
partition_key_ranges = partition_subset.get_partition_key_ranges()
partition_key_ranges = partition_subset.get_partition_key_ranges(partitions_def)
for partition_key_range in partition_key_ranges:
# We might resolve more than one partition key range for the given partition keys.
# We can only apply chunking on individual partition key ranges.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def get_unhandled_partitions(
)

return handled_subset.get_partition_keys_not_in_subset(
partitions_def=partitions_def,
current_time=current_time,
dynamic_partitions_store=dynamic_partitions_store,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,7 @@ def get_child_partition_keys_of_parent(
partition_mapping = self.get_partition_mapping(child_asset_key, parent_asset_key)
child_partitions_subset = partition_mapping.get_downstream_partitions_for_partitions(
parent_partitions_def.empty_subset().with_partition_keys([parent_partition_key]),
parent_partitions_def,
downstream_partitions_def=child_partitions_def,
dynamic_partitions_store=dynamic_partitions_store,
current_time=current_time,
Expand Down Expand Up @@ -437,7 +438,7 @@ def get_parent_partition_keys_for_child(
"""
partition_key = check.opt_str_param(partition_key, "partition_key")

child_partitions_def = self.get_partitions_def(child_asset_key)
child_partitions_def = cast(PartitionsDefinition, self.get_partitions_def(child_asset_key))
parent_partitions_def = self.get_partitions_def(parent_asset_key)

if parent_partitions_def is None:
Expand All @@ -449,12 +450,11 @@ def get_parent_partition_keys_for_child(

return partition_mapping.get_upstream_mapped_partitions_result_for_partitions(
(
cast(PartitionsDefinition, child_partitions_def).subset_with_partition_keys(
[partition_key]
)
child_partitions_def.subset_with_partition_keys([partition_key])
if partition_key
else None
),
downstream_partitions_def=child_partitions_def,
upstream_partitions_def=parent_partitions_def,
dynamic_partitions_store=dynamic_partitions_store,
current_time=current_time,
Expand Down Expand Up @@ -611,6 +611,7 @@ def bfs_filter_subsets(
child_partitions_subset = (
partition_mapping.get_downstream_partitions_for_partitions(
partitions_subset,
check.not_none(self.get_partitions_def(asset_key)),
downstream_partitions_def=child_partitions_def,
dynamic_partitions_store=dynamic_partitions_store,
current_time=current_time,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,14 +94,18 @@ def to_storage_dict(
for key, value in self.partitions_subsets_by_asset_key.items()
},
"serializable_partitions_def_ids_by_asset_key": {
key.to_user_string(): value.partitions_def.get_serializable_unique_identifier(
key.to_user_string(): check.not_none(
self._asset_graph.get_partitions_def(key)
).get_serializable_unique_identifier(
dynamic_partitions_store=dynamic_partitions_store
)
for key, value in self.partitions_subsets_by_asset_key.items()
for key, _ in self.partitions_subsets_by_asset_key.items()
},
"partitions_def_class_names_by_asset_key": {
key.to_user_string(): value.partitions_def.__class__.__name__
for key, value in self.partitions_subsets_by_asset_key.items()
key.to_user_string(): check.not_none(
self._asset_graph.get_partitions_def(key)
).__class__.__name__
for key, _ in self.partitions_subsets_by_asset_key.items()
},
"non_partitioned_asset_keys": [
key.to_user_string() for key in self._non_partitioned_asset_keys
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,7 @@ def get_downstream_partition_keys(
downstream_partition_key_subset = (
partition_mapping.get_downstream_partitions_for_partitions(
from_asset.partitions_def.empty_subset().with_partition_keys([partition_key]),
from_asset.partitions_def,
downstream_partitions_def=to_partitions_def,
dynamic_partitions_store=self.instance,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from functools import lru_cache, reduce
from typing import (
Dict,
Iterable,
List,
Mapping,
NamedTuple,
Expand Down Expand Up @@ -216,7 +215,7 @@ def __init__(self, partitions_defs: Mapping[str, PartitionsDefinition]):

@property
def partitions_subset_class(self) -> Type["PartitionsSubset"]:
return MultiPartitionsSubset
return DefaultPartitionsSubset

def get_serializable_unique_identifier(
self, dynamic_partitions_store: Optional[DynamicPartitionsStore] = None
Expand Down Expand Up @@ -509,33 +508,6 @@ def get_num_partitions(
return reduce(lambda x, y: x * y, dimension_counts, 1)


class MultiPartitionsSubset(DefaultPartitionsSubset):
    """A ``DefaultPartitionsSubset`` whose keys belong to a ``MultiPartitionsDefinition``.

    Incoming string keys are converted with
    ``MultiPartitionsDefinition.get_partition_key_from_str`` before being handed
    to the parent class, so the stored keys are whatever that method returns
    rather than the raw strings.
    """

    def __init__(
        self,
        partitions_def: MultiPartitionsDefinition,
        subset: Optional[Set[str]] = None,
    ):
        """Build the subset from string partition keys.

        Args:
            partitions_def: The multi-partitions definition the keys belong to.
                Validated with ``check.inst_param``.
            subset: String partition keys to include. ``None`` or an empty set
                produces an empty subset.
        """
        check.inst_param(partitions_def, "partitions_def", MultiPartitionsDefinition)
        # Keys that do not contain the multipartition delimiter are silently
        # dropped rather than rejected.
        parsed_keys = (
            {
                partitions_def.get_partition_key_from_str(key)
                for key in subset
                if MULTIPARTITION_KEY_DELIMITER in key
            }
            if subset
            else set()
        )
        super().__init__(partitions_def, parsed_keys)

    def with_partition_keys(self, partition_keys: Iterable[str]) -> "MultiPartitionsSubset":
        """Return a new subset holding this subset's keys plus ``partition_keys``.

        The combined keys go back through ``__init__``, so any new key lacking
        the multipartition delimiter is dropped there.
        """
        return MultiPartitionsSubset(
            cast(MultiPartitionsDefinition, self._partitions_def),
            self._subset | set(partition_keys),
        )


def get_tags_from_multi_partition_key(multi_partition_key: MultiPartitionKey) -> Mapping[str, str]:
check.inst_param(multi_partition_key, "multi_partition_key", MultiPartitionKey)

Expand Down
Loading

0 comments on commit c4e4110

Please sign in to comment.