From 584ac0b9ff8b62aab01786236396c28ffae70c50 Mon Sep 17 00:00:00 2001 From: Claire Lin Date: Fri, 3 Nov 2023 16:48:43 -0700 Subject: [PATCH] first stab continue time window partitions subset changes asset backfill serialization partition mapping update continue refactor fix more tests more test fixes fix partition mapping tests adjust test fix more tests add tests --- .../implementation/fetch_assets.py | 31 +++- .../_core/definitions/asset_daemon_cursor.py | 2 +- .../_core/definitions/asset_graph_subset.py | 4 +- .../dagster/_core/definitions/data_time.py | 2 +- .../dagster/_core/definitions/partition.py | 24 ++- .../_core/definitions/partition_mapping.py | 14 +- .../time_window_partition_mapping.py | 10 +- .../definitions/time_window_partitions.py | 172 ++++++++++-------- .../dagster/_core/execution/context/input.py | 5 +- .../test_partitions_subset.py | 31 +++- .../test_time_window_partitions.py | 12 +- 11 files changed, 184 insertions(+), 123 deletions(-) diff --git a/python_modules/dagster-graphql/dagster_graphql/implementation/fetch_assets.py b/python_modules/dagster-graphql/dagster_graphql/implementation/fetch_assets.py index cab8902d1a677..3eeae87dc6de2 100644 --- a/python_modules/dagster-graphql/dagster_graphql/implementation/fetch_assets.py +++ b/python_modules/dagster-graphql/dagster_graphql/implementation/fetch_assets.py @@ -469,7 +469,7 @@ def build_partition_statuses( graphene_ranges = [] for r in ranges: partition_key_range = cast( - TimeWindowPartitionsDefinition, materialized_partitions_subset.partitions_def + TimeWindowPartitionsDefinition, materialized_partitions_subset.get_partitions_def() ).get_partition_key_range_for_time_window(r.time_window) graphene_ranges.append( GrapheneTimePartitionRangeStatus( @@ -518,15 +518,28 @@ def get_2d_run_length_encoded_partitions( GrapheneMultiPartitionStatuses, ) - if ( - not isinstance(materialized_partitions_subset.partitions_def, MultiPartitionsDefinition) - or not isinstance(failed_partitions_subset.partitions_def, MultiPartitionsDefinition) - or not isinstance(in_progress_partitions_subset.partitions_def, MultiPartitionsDefinition) - ): - check.failed("Can only fetch 2D run length encoded partitions for multipartitioned assets") + partitions_defs = set( + [ + subset.get_partitions_def() + for subset in [ + materialized_partitions_subset, + failed_partitions_subset, + in_progress_partitions_subset, + ] + ] + ) + check.invariant( + len(partitions_defs) == 1, "All subsets should have the same partitions definition" + ) + + partitions_def = cast(MultiPartitionsDefinition, next(iter(partitions_defs))) + check.invariant( + isinstance(partitions_def, MultiPartitionsDefinition), + "Partitions definition should be multipartitioned", + ) - primary_dim = materialized_partitions_subset.partitions_def.primary_dimension - secondary_dim = materialized_partitions_subset.partitions_def.secondary_dimension + primary_dim = partitions_def.primary_dimension + secondary_dim = partitions_def.secondary_dimension dim2_materialized_partition_subset_by_dim1: Dict[str, PartitionsSubset] = defaultdict( lambda: secondary_dim.partitions_def.empty_subset() diff --git a/python_modules/dagster/dagster/_core/definitions/asset_daemon_cursor.py b/python_modules/dagster/dagster/_core/definitions/asset_daemon_cursor.py index 3414f92c75b49..10c90b6569440 100644 --- a/python_modules/dagster/dagster/_core/definitions/asset_daemon_cursor.py +++ b/python_modules/dagster/dagster/_core/definitions/asset_daemon_cursor.py @@ -231,7 +231,7 @@ def from_serialized(cls, cursor: str, asset_graph: AssetGraph) -> "AssetDaemonCu and isinstance(partitions_def, TimeWindowPartitionsDefinition) and any( time_window.start < partitions_def.start - for time_window in subset.included_time_windows + for time_window in subset.get_included_time_windows() ) ): subset = partitions_def.empty_subset() diff --git a/python_modules/dagster/dagster/_core/definitions/asset_graph_subset.py b/python_modules/dagster/dagster/_core/definitions/asset_graph_subset.py index 773ac4b159361..bdf4539770d2e 100644 --- a/python_modules/dagster/dagster/_core/definitions/asset_graph_subset.py +++ b/python_modules/dagster/dagster/_core/definitions/asset_graph_subset.py @@ -94,13 +94,13 @@ def to_storage_dict( for key, value in self.partitions_subsets_by_asset_key.items() }, "serializable_partitions_def_ids_by_asset_key": { - key.to_user_string(): value.partitions_def.get_serializable_unique_identifier( + key.to_user_string(): value.get_partitions_def().get_serializable_unique_identifier( dynamic_partitions_store=dynamic_partitions_store ) for key, value in self.partitions_subsets_by_asset_key.items() }, "partitions_def_class_names_by_asset_key": { - key.to_user_string(): value.partitions_def.__class__.__name__ + key.to_user_string(): value.get_partitions_def().__class__.__name__ for key, value in self.partitions_subsets_by_asset_key.items() }, "non_partitioned_asset_keys": [ diff --git a/python_modules/dagster/dagster/_core/definitions/data_time.py b/python_modules/dagster/dagster/_core/definitions/data_time.py index ecd969308d8c5..8c913c2a8c457 100644 --- a/python_modules/dagster/dagster/_core/definitions/data_time.py +++ b/python_modules/dagster/dagster/_core/definitions/data_time.py @@ -92,7 +92,7 @@ def _calculate_data_time_partitioned( if not isinstance(partition_subset, BaseTimeWindowPartitionsSubset): check.failed(f"Invalid partition subset {type(partition_subset)}") - sorted_time_windows = sorted(partition_subset.included_time_windows) + sorted_time_windows = sorted(partition_subset.get_included_time_windows()) # no time windows, no data if len(sorted_time_windows) == 0: return None diff --git a/python_modules/dagster/dagster/_core/definitions/partition.py b/python_modules/dagster/dagster/_core/definitions/partition.py index 1d8e181408335..5f267bbefa45b 100644 --- a/python_modules/dagster/dagster/_core/definitions/partition.py +++ b/python_modules/dagster/dagster/_core/definitions/partition.py @@ -977,7 +977,7 @@ def with_partition_key_range( dynamic_partitions_store: Optional[DynamicPartitionsStore] = None, ) -> "PartitionsSubset[T_str]": return self.with_partition_keys( - self.partitions_def.get_partition_keys_in_range( + self.get_partitions_def().get_partition_keys_in_range( partition_key_range, dynamic_partitions_store=dynamic_partitions_store ) ) @@ -989,16 +989,22 @@ def __or__(self, other: "PartitionsSubset") -> "PartitionsSubset[T_str]": def __sub__(self, other: "PartitionsSubset") -> "PartitionsSubset[T_str]": if self is other: - return self.partitions_def.empty_subset() - return self.partitions_def.empty_subset().with_partition_keys( - set(self.get_partition_keys()).difference(set(other.get_partition_keys())) + return self.get_partitions_def().empty_subset() + return ( + self.get_partitions_def() + .empty_subset() + .with_partition_keys( + set(self.get_partition_keys()).difference(set(other.get_partition_keys())) + ) ) def __and__(self, other: "PartitionsSubset") -> "PartitionsSubset[T_str]": if self is other: return self - return self.partitions_def.empty_subset().with_partition_keys( - set(self.get_partition_keys()) & set(other.get_partition_keys()) + return ( + self.get_partitions_def() + .empty_subset() + .with_partition_keys(set(self.get_partition_keys()) & set(other.get_partition_keys())) ) @abstractmethod @@ -1023,9 +1029,8 @@ def can_deserialize( ) -> bool: ... - @property @abstractmethod - def partitions_def(self) -> PartitionsDefinition[T_str]: + def get_partitions_def(self) -> PartitionsDefinition[T_str]: ... @abstractmethod @@ -1186,8 +1191,7 @@ def can_deserialize( data.get("subset") is not None and data.get("version") == cls.SERIALIZATION_VERSION ) - @property - def partitions_def(self) -> PartitionsDefinition[T_str]: + def get_partitions_def(self) -> PartitionsDefinition[T_str]: return self._partitions_def def __eq__(self, other: object) -> bool: diff --git a/python_modules/dagster/dagster/_core/definitions/partition_mapping.py b/python_modules/dagster/dagster/_core/definitions/partition_mapping.py index 5a5dddbbc6017..f317c11fb3eb0 100644 --- a/python_modules/dagster/dagster/_core/definitions/partition_mapping.py +++ b/python_modules/dagster/dagster/_core/definitions/partition_mapping.py @@ -114,7 +114,7 @@ def get_upstream_mapped_partitions_result_for_partitions( if downstream_partitions_subset is None: check.failed("downstream asset is not partitioned") - if downstream_partitions_subset.partitions_def == upstream_partitions_def: + if downstream_partitions_subset.get_partitions_def() == upstream_partitions_def: return UpstreamPartitionsResult(downstream_partitions_subset, []) upstream_partition_keys = set( @@ -141,7 +141,7 @@ def get_downstream_partitions_for_partitions( if upstream_partitions_subset is None: check.failed("upstream asset is not partitioned") - if upstream_partitions_subset.partitions_def == downstream_partitions_def: + if upstream_partitions_subset.get_partitions_def() == downstream_partitions_def: return upstream_partitions_subset upstream_partition_keys = set(upstream_partitions_subset.get_partition_keys()) @@ -218,8 +218,10 @@ def get_downstream_partitions_for_partitions( current_time: Optional[datetime] = None, dynamic_partitions_store: Optional[DynamicPartitionsStore] = None, ) -> PartitionsSubset: - last_upstream_partition = upstream_partitions_subset.partitions_def.get_last_partition_key( - current_time=current_time, dynamic_partitions_store=dynamic_partitions_store + last_upstream_partition = ( + upstream_partitions_subset.get_partitions_def().get_last_partition_key( + current_time=current_time, dynamic_partitions_store=dynamic_partitions_store + ) ) if last_upstream_partition and last_upstream_partition in upstream_partitions_subset: return downstream_partitions_def.subset_with_all_partitions( @@ -495,7 +497,7 @@ def get_upstream_mapped_partitions_result_for_partitions( check.failed("downstream asset is not partitioned") result = self._get_dependency_partitions_subset( - cast(MultiPartitionsDefinition, downstream_partitions_subset.partitions_def), + cast(MultiPartitionsDefinition, downstream_partitions_subset.get_partitions_def()), downstream_partitions_subset, cast(MultiPartitionsDefinition, upstream_partitions_def), a_upstream_of_b=False, @@ -519,7 +521,7 @@ def get_downstream_partitions_for_partitions( check.failed("upstream asset is not partitioned") result = self._get_dependency_partitions_subset( - cast(MultiPartitionsDefinition, upstream_partitions_subset.partitions_def), + cast(MultiPartitionsDefinition, upstream_partitions_subset.get_partitions_def()), upstream_partitions_subset, cast(MultiPartitionsDefinition, downstream_partitions_def), a_upstream_of_b=True, diff --git a/python_modules/dagster/dagster/_core/definitions/time_window_partition_mapping.py b/python_modules/dagster/dagster/_core/definitions/time_window_partition_mapping.py index 9d4f0d7ef1dd8..5fb1c33fc45ed 100644 --- a/python_modules/dagster/dagster/_core/definitions/time_window_partition_mapping.py +++ b/python_modules/dagster/dagster/_core/definitions/time_window_partition_mapping.py @@ -116,7 +116,7 @@ def get_upstream_mapped_partitions_result_for_partitions( check.failed("downstream_partitions_subset must be a BaseTimeWindowPartitionsSubset") return self._map_partitions( - downstream_partitions_subset.partitions_def, + downstream_partitions_subset.get_partitions_def(), upstream_partitions_def, downstream_partitions_subset, start_offset=self.start_offset, @@ -137,7 +137,7 @@ def get_downstream_partitions_for_partitions( if not provided. """ return self._map_partitions( - upstream_partitions_subset.partitions_def, + upstream_partitions_subset.get_partitions_def(), downstream_partitions_def, upstream_partitions_subset, end_offset=-self.start_offset, @@ -200,7 +200,7 @@ def _map_partitions( last_window = to_partitions_def.get_last_partition_window(current_time=current_time) time_windows = [] - for from_partition_time_window in from_partitions_subset.included_time_windows: + for from_partition_time_window in from_partitions_subset.get_included_time_windows(): from_start_dt, from_end_dt = from_partition_time_window offsetted_start_dt = _offsetted_datetime( @@ -363,7 +363,7 @@ def _do_cheap_partition_mapping_if_possible( TimeWindowPartitionsSubset( partitions_def=to_partitions_def, num_partitions=None, - included_time_windows=from_partitions_subset.included_time_windows, + included_time_windows=from_partitions_subset.get_included_time_windows(), ), [], ) @@ -377,7 +377,7 @@ def _do_cheap_partition_mapping_if_possible( else: required_but_nonexistent_partition_keys = [ pk - for time_window in from_partitions_subset.included_time_windows + for time_window in from_partitions_subset.get_included_time_windows() for pk in to_partitions_def.get_partition_keys_in_time_window( time_window=time_window ) diff --git a/python_modules/dagster/dagster/_core/definitions/time_window_partitions.py b/python_modules/dagster/dagster/_core/definitions/time_window_partitions.py index 3f0091a0fa61d..a1f2e9632d5df 100644 --- a/python_modules/dagster/dagster/_core/definitions/time_window_partitions.py +++ b/python_modules/dagster/dagster/_core/definitions/time_window_partitions.py @@ -28,6 +28,9 @@ from dagster._annotations import PublicAttr, public from dagster._core.instance import DynamicPartitionsStore from dagster._serdes import whitelist_for_serdes +from dagster._serdes import ( + whitelist_for_serdes, +) from dagster._serdes.serdes import FieldSerializer from dagster._utils import utc_datetime_from_timestamp from dagster._utils.cached_method import cached_method @@ -71,6 +74,9 @@ def unpack( return utc_datetime_from_timestamp(datetime_float) if datetime_float else None +@whitelist_for_serdes( + field_serializers={"start": DatetimeFieldSerializer, "end": DatetimeFieldSerializer} +) class TimeWindow(NamedTuple): """An interval that is closed at the start and open at the end. @@ -1457,13 +1463,16 @@ class BaseTimeWindowPartitionsSubset(PartitionsSubset): # This will ensure that we can gracefully degrade when deserializing old data. SERIALIZATION_VERSION = 1 - @property @abstractmethod - def included_time_windows(self) -> Sequence[TimeWindow]: + def get_included_time_windows(self) -> Sequence[TimeWindow]: ... - @abstractproperty - def num_partitions(self) -> int: + @abstractmethod + def get_num_partitions(self) -> int: + ... + + @abstractmethod + def get_partitions_def(self) -> TimeWindowPartitionsDefinition: ... def _get_partition_time_windows_not_in_subset( @@ -1474,43 +1483,45 @@ def _get_partition_time_windows_not_in_subset( Each time window is a single partition. """ first_tw = cast( - TimeWindowPartitionsDefinition, self.partitions_def + TimeWindowPartitionsDefinition, self.get_partitions_def() ).get_first_partition_window(current_time=current_time) last_tw = cast( - TimeWindowPartitionsDefinition, self.partitions_def + TimeWindowPartitionsDefinition, self.get_partitions_def() ).get_last_partition_window(current_time=current_time) if not first_tw or not last_tw: check.failed("No partitions found") - if len(self.included_time_windows) == 0: + if len(self.get_included_time_windows()) == 0: return [TimeWindow(first_tw.start, last_tw.end)] time_windows = [] - if first_tw.start < self.included_time_windows[0].start: - time_windows.append(TimeWindow(first_tw.start, self.included_time_windows[0].start)) + if first_tw.start < self.get_included_time_windows()[0].start: + time_windows.append( + TimeWindow(first_tw.start, self.get_included_time_windows()[0].start) + ) - for i in range(len(self.included_time_windows) - 1): - if self.included_time_windows[i].start >= last_tw.end: + for i in range(len(self.get_included_time_windows()) - 1): + if self.get_included_time_windows()[i].start >= last_tw.end: break - if self.included_time_windows[i].end < last_tw.end: - if self.included_time_windows[i + 1].start <= last_tw.end: + if self.get_included_time_windows()[i].end < last_tw.end: + if self.get_included_time_windows()[i + 1].start <= last_tw.end: time_windows.append( TimeWindow( - self.included_time_windows[i].end, - self.included_time_windows[i + 1].start, + self.get_included_time_windows()[i].end, + self.get_included_time_windows()[i + 1].start, ) ) else: time_windows.append( TimeWindow( - self.included_time_windows[i].end, + self.get_included_time_windows()[i].end, last_tw.end, ) ) - if last_tw.end > self.included_time_windows[-1].end: - time_windows.append(TimeWindow(self.included_time_windows[-1].end, last_tw.end)) + if last_tw.end > self.get_included_time_windows()[-1].end: + time_windows.append(TimeWindow(self.get_included_time_windows()[-1].end, last_tw.end)) return time_windows @@ -1523,7 +1534,7 @@ def get_partition_keys_not_in_subset( for tw in self._get_partition_time_windows_not_in_subset(current_time): partition_keys.extend( cast( - TimeWindowPartitionsDefinition, self.partitions_def + TimeWindowPartitionsDefinition, self.get_partitions_def() ).get_partition_keys_in_time_window(tw) ) return partition_keys @@ -1553,9 +1564,9 @@ def get_partition_key_ranges( ) -> Sequence[PartitionKeyRange]: return [ cast( - TimeWindowPartitionsDefinition, self.partitions_def + TimeWindowPartitionsDefinition, self.get_partitions_def() ).get_partition_key_range_for_time_window(window) - for window in self.included_time_windows + for window in self.get_included_time_windows() ] def _add_partitions_to_time_windows( @@ -1569,7 +1580,7 @@ def _add_partitions_to_time_windows( """ result_windows = [*initial_windows] time_windows = cast( - TimeWindowPartitionsDefinition, self.partitions_def + TimeWindowPartitionsDefinition, self.get_partitions_def() ).time_windows_for_partition_keys(frozenset(partition_keys), validate=validate) num_added_partitions = 0 for window in sorted(time_windows): @@ -1621,9 +1632,9 @@ def serialize(self) -> str: # stable serialization between identical subsets "time_windows": [ (window.start.timestamp(), window.end.timestamp()) - for window in self.included_time_windows + for window in self.get_included_time_windows() ], - "num_partitions": self.num_partitions, + "num_partitions": self.get_num_partitions(), } ) @@ -1699,24 +1710,24 @@ def can_deserialize( ) def __len__(self) -> int: - return self.num_partitions + return self.get_num_partitions() def __contains__(self, partition_key: str) -> bool: time_window = cast( - TimeWindowPartitionsDefinition, self.partitions_def + TimeWindowPartitionsDefinition, self.get_partitions_def() ).time_window_for_partition_key(partition_key) return any( time_window.start >= included_time_window.start and time_window.start < included_time_window.end - for included_time_window in self.included_time_windows + for included_time_window in self.get_included_time_windows() ) def __eq__(self, other): return ( isinstance(other, BaseTimeWindowPartitionsSubset) - and self.partitions_def == other.partitions_def - and self.included_time_windows == other.included_time_windows + and self.get_partitions_def() == other.get_partitions_def() + and self.get_included_time_windows() == other.get_included_time_windows() ) @@ -1733,6 +1744,9 @@ def __init__( included_partition_keys, "included_partition_keys", of_type=str ) + def get_partitions_def(self) -> TimeWindowPartitionsDefinition: + return self._partitions_def + def with_partition_keys( self, partition_keys: Iterable[str] ) -> "BaseTimeWindowPartitionsSubset": @@ -1742,9 +1756,8 @@ def with_partition_keys( included_partition_keys=new_partitions, ) - @property @cached_method - def included_time_windows(self) -> Sequence[TimeWindow]: + def get_included_time_windows(self) -> Sequence[TimeWindow]: result_time_windows, _ = self._add_partitions_to_time_windows( initial_windows=[], partition_keys=list(check.not_none(self._included_partition_keys)), @@ -1752,13 +1765,8 @@ def included_time_windows(self) -> Sequence[TimeWindow]: ) return result_time_windows - @property - def partitions_def(self) -> TimeWindowPartitionsDefinition: - return self._partitions_def - - @property @cached_method - def num_partitions(self) -> int: + def get_num_partitions(self) -> int: return len(self._included_partition_keys) @public @@ -1774,11 +1782,11 @@ def first_start(self) -> datetime: next(iter(self._included_partition_keys)) ) else: - if len(self.included_time_windows) == 0: + if len(self.get_included_time_windows()) == 0: check.failed( f"Empty subset. self._included_partition_keys: {self._included_partition_keys}" ) - return self.included_time_windows[0].start + return self.get_included_time_windows()[0].start @property def is_empty(self) -> bool: @@ -1818,7 +1826,7 @@ def __contains__(self, partition_key: str) -> bool: def __eq__(self, other): return ( isinstance(other, TimePartitionKeyPartitionsSubset) - and self.partitions_def == other.partitions_def + and self._partitions_def == other._partitions_def and self._included_partition_keys == other._included_partition_keys ) or super(TimePartitionKeyPartitionsSubset, self).__eq__(other) @@ -1845,48 +1853,59 @@ def with_partitions_def( def resolve(self) -> "TimeWindowPartitionsSubset": return TimeWindowPartitionsSubset( partitions_def=self._partitions_def, - num_partitions=self.num_partitions, - included_time_windows=self.included_time_windows, + num_partitions=self.get_num_partitions(), + included_time_windows=self.get_included_time_windows(), ) -class TimeWindowPartitionsSubset(BaseTimeWindowPartitionsSubset): - def __init__( - self, +@whitelist_for_serdes +class TimeWindowPartitionsSubset( + BaseTimeWindowPartitionsSubset, + NamedTuple( + "_TimeWindowPartitionsSubset", + [ + ("partitions_def", TimeWindowPartitionsDefinition), + ("num_partitions", int), + ("included_time_windows", Sequence[TimeWindow]), + ], + ), +): + def __new__( + cls, partitions_def: TimeWindowPartitionsDefinition, num_partitions: Optional[int] = None, included_time_windows: Sequence[TimeWindow] = [], ): check.opt_int_param(num_partitions, "num_partitions") - self._partitions_def = check.inst_param( - partitions_def, "partitions_def", TimeWindowPartitionsDefinition - ) - self._num_partitions = ( - num_partitions - if num_partitions - else self._num_partitions_from_time_windows(partitions_def, included_time_windows) - ) - self._included_time_windows = check.sequence_param( - included_time_windows, "included_time_windows", of_type=TimeWindow + return super(TimeWindowPartitionsSubset, cls).__new__( + cls, + partitions_def=partitions_def, + num_partitions=( + num_partitions + if num_partitions + else cls._num_partitions_from_time_windows(partitions_def, included_time_windows) + ), + included_time_windows=check.sequence_param( + included_time_windows, "included_time_windows", of_type=TimeWindow + ), ) def get_included_time_windows(self) -> Sequence[TimeWindow]: - return self._included_time_windows + return self.included_time_windows - @property - def partitions_def(self) -> TimeWindowPartitionsDefinition: - return self._partitions_def + def get_partitions_def(self) -> TimeWindowPartitionsDefinition: + return self.partitions_def @property def first_start(self) -> datetime: """The start datetime of the earliest partition in the subset.""" - if len(self._included_time_windows) == 0: + if len(self.included_time_windows) == 0: check.failed("Empty subset") - return self._included_time_windows[0].start + return self.included_time_windows[0].start @property def is_empty(self) -> bool: - return len(self._included_time_windows) == 0 + return len(self.included_time_windows) == 0 def cheap_ends_before(self, dt: datetime, dt_cron_schedule: str) -> bool: """Performs a cheap calculation that checks whether the latest window in this subset ends @@ -1897,18 +1916,13 @@ def cheap_ends_before(self, dt: datetime, dt_cron_schedule: str) -> bool: Args: dt_cron_schedule (str): A cron schedule that dt is on one of the ticks of. """ - if self._included_time_windows is not None: - return self._included_time_windows[-1].end <= dt + if self.included_time_windows is not None: + return self.included_time_windows[-1].end <= dt return False - @property - def num_partitions(self) -> int: - return self._num_partitions - - @property - def included_time_windows(self) -> Sequence[TimeWindow]: - return self._included_time_windows + def get_num_partitions(self) -> int: + return self.num_partitions @classmethod def _num_partitions_from_time_windows( @@ -1923,18 +1937,18 @@ def _num_partitions_from_time_windows( def get_partition_keys(self, current_time: Optional[datetime] = None) -> Iterable[str]: return [ pk - for time_window in self._included_time_windows + for time_window in self.included_time_windows for pk in self.partitions_def.get_partition_keys_in_time_window(time_window) ] def with_partition_keys(self, partition_keys: Iterable[str]) -> "TimeWindowPartitionsSubset": result_windows, added_partitions = self._add_partitions_to_time_windows( - self._included_time_windows, list(partition_keys) + self.included_time_windows, list(partition_keys) ) return TimeWindowPartitionsSubset( - self._partitions_def, - num_partitions=self._num_partitions + added_partitions, + self.partitions_def, + num_partitions=self.num_partitions + added_partitions, included_time_windows=result_windows, ) @@ -1951,14 +1965,14 @@ def with_partitions_def( self, partitions_def: TimeWindowPartitionsDefinition ) -> "BaseTimeWindowPartitionsSubset": check.invariant( - partitions_def.cron_schedule == self._partitions_def.cron_schedule, + partitions_def.cron_schedule == self.partitions_def.cron_schedule, "num_partitions would become inaccurate if the partitions_defs had different cron" " schedules", ) return TimeWindowPartitionsSubset( partitions_def=partitions_def, - num_partitions=self._num_partitions, - included_time_windows=self._included_time_windows, + num_partitions=self.num_partitions, + included_time_windows=self.included_time_windows, ) def __repr__(self) -> str: @@ -2088,7 +2102,7 @@ def fetch_flattened_time_window_ranges( flattened_time_window_statuses = [] for status, subset in prioritized_subsets: subset_time_window_statuses = [ - PartitionTimeWindowStatus(tw, status) for tw in subset.included_time_windows + PartitionTimeWindowStatus(tw, status) for tw in subset.get_included_time_windows() ] flattened_time_window_statuses = _flatten( flattened_time_window_statuses, subset_time_window_statuses diff --git a/python_modules/dagster/dagster/_core/execution/context/input.py b/python_modules/dagster/dagster/_core/execution/context/input.py index e821e95fdb194..42d5318cb3faf 100644 --- a/python_modules/dagster/dagster/_core/execution/context/input.py +++ b/python_modules/dagster/dagster/_core/execution/context/input.py @@ -411,7 +411,7 @@ def asset_partitions_time_window(self) -> TimeWindow: " with time windows.", ) - time_windows = subset.included_time_windows + time_windows = subset.get_included_time_windows() if len(time_windows) != 1: check.failed( "Tried to access asset_partitions_time_window, but there are " @@ -674,8 +674,7 @@ def with_partition_key_range( def serialize(self) -> str: raise NotImplementedError() - @property - def partitions_def(self) -> "PartitionsDefinition": + def get_partitions_def(self) -> "PartitionsDefinition": raise NotImplementedError() def __len__(self) -> int: diff --git a/python_modules/dagster/dagster_tests/asset_defs_tests/test_partitions_subset.py b/python_modules/dagster/dagster_tests/asset_defs_tests/test_partitions_subset.py index 104215fa886f1..2e088e4d208ed 100644 --- a/python_modules/dagster/dagster_tests/asset_defs_tests/test_partitions_subset.py +++ b/python_modules/dagster/dagster_tests/asset_defs_tests/test_partitions_subset.py @@ -1,12 +1,15 @@ +from typing import cast + import pytest from dagster import DailyPartitionsDefinition, MultiPartitionsDefinition, StaticPartitionsDefinition -from dagster._core.definitions.multi_dimensional_partitions import MultiPartitionsSubset from dagster._core.definitions.partition import DefaultPartitionsSubset from dagster._core.definitions.time_window_partitions import ( TimePartitionKeyPartitionsSubset, + TimeWindowPartitionsDefinition, TimeWindowPartitionsSubset, ) from dagster._core.errors import DagsterInvalidDeserializationVersionError +from dagster._serdes import deserialize_value, serialize_value def test_default_subset_cannot_deserialize_invalid_version(): @@ -76,6 +79,30 @@ def test_get_subset_type(): def test_empty_subsets(): - assert type(composite.empty_subset()) is MultiPartitionsSubset assert type(static_partitions.empty_subset()) is DefaultPartitionsSubset assert type(time_window_partitions.empty_subset()) is TimePartitionKeyPartitionsSubset + + +@pytest.mark.parametrize( + "partitions_def", + [ + (DailyPartitionsDefinition("2023-01-01", timezone="America/New_York")), + (DailyPartitionsDefinition("2023-01-01", timezone="America/New_York")), + ], +) +def test_time_window_partitions_subset_serialization_deserialization( + partitions_def: DailyPartitionsDefinition, +): + time_window_partitions_def = TimeWindowPartitionsDefinition( + start=partitions_def.start, + end=partitions_def.end, + cron_schedule="0 0 * * *", + fmt="%Y-%m-%d", + timezone=partitions_def.timezone, + end_offset=partitions_def.end_offset, + ) + subset = TimeWindowPartitionsSubset.empty_subset( + time_window_partitions_def + ).with_partition_keys(["2023-01-01"]) + + assert deserialize_value(serialize_value(cast(TimeWindowPartitionsSubset, subset))) == subset diff --git a/python_modules/dagster/dagster_tests/definitions_tests/test_time_window_partitions.py b/python_modules/dagster/dagster_tests/definitions_tests/test_time_window_partitions.py index 099e8ca8432e2..9c03ecffbcf1e 100644 --- a/python_modules/dagster/dagster_tests/definitions_tests/test_time_window_partitions.py +++ b/python_modules/dagster/dagster_tests/definitions_tests/test_time_window_partitions.py @@ -787,19 +787,19 @@ def test_partition_subset_get_partition_keys_not_in_subset( assert partition_key in subset assert ( subset.get_partition_keys_not_in_subset( - current_time=partitions_def.end_time_for_partition_key(full_set_keys[-1]) + current_time=partitions_def.end_time_for_partition_key(full_set_keys[-1]), ) == expected_keys_not_in_subset ) assert ( cast( TimeWindowPartitionsSubset, partitions_def.deserialize_subset(subset.serialize()) - ).included_time_windows - == subset.included_time_windows + ).get_included_time_windows() + == subset.get_included_time_windows() ) expected_range_count = case_str.count("-+") + (1 if case_str[0] == "+" else 0) - assert len(subset.included_time_windows) == expected_range_count, case_str + assert len(subset.get_included_time_windows()) == expected_range_count, case_str assert len(subset) == case_str.count("+") @@ -936,7 +936,9 @@ def test_partition_subset_with_partition_keys( expected_range_count = updated_subset_str.count("-+") + ( 1 if updated_subset_str[0] == "+" else 0 ) - assert len(updated_subset.included_time_windows) == expected_range_count, updated_subset_str + assert ( + len(updated_subset.get_included_time_windows()) == expected_range_count + ), updated_subset_str assert len(updated_subset) == updated_subset_str.count("+")