From e5bd1074db5e20a3b2689e30eaee956f880dcf46 Mon Sep 17 00:00:00 2001 From: Claire Lin Date: Tue, 7 Nov 2023 17:01:07 -0800 Subject: [PATCH] fix more tests --- .../definitions/time_window_partitions.py | 22 ++++++++++++++++--- .../dagster/dagster/_serdes/serdes.py | 6 ++++- .../test_partitions_subset.py | 20 +++++++++++++++++ 3 files changed, 44 insertions(+), 4 deletions(-) diff --git a/python_modules/dagster/dagster/_core/definitions/time_window_partitions.py b/python_modules/dagster/dagster/_core/definitions/time_window_partitions.py index 992f8a3d92a7c..10e899e5df0eb 100644 --- a/python_modules/dagster/dagster/_core/definitions/time_window_partitions.py +++ b/python_modules/dagster/dagster/_core/definitions/time_window_partitions.py @@ -31,7 +31,12 @@ from dagster._serdes import ( whitelist_for_serdes, ) -from dagster._serdes.serdes import FieldSerializer, deserialize_value, serialize_value +from dagster._serdes.serdes import ( + FieldSerializer, + NamedTupleSerializer, + deserialize_value, + serialize_value, +) from dagster._utils import utc_datetime_from_timestamp from dagster._utils.partitions import DEFAULT_HOURLY_FORMAT_WITHOUT_TIMEZONE from dagster._utils.schedules import ( @@ -1871,10 +1876,22 @@ def unpack( ).to_time_window_partitions_def() +class TimeWindowPartitionsSubsetSerializer(NamedTupleSerializer): + # TimeWindowPartitionsSubsets have custom logic to delay calculating num_partitions until it + # is needed to improve performance. When serializing, we want to serialize the number of + # partitions, so we force calculation. 
+ def before_pack(self, value: "TimeWindowPartitionsSubset") -> "TimeWindowPartitionsSubset": + if value._asdict()["num_partitions"] is None: + return value._replace(num_partitions=value.num_partitions) + return value + + @whitelist_for_serdes( - field_serializers={"partitions_def": TimeWindowPartitionsDefinitionSerializer} + field_serializers={"partitions_def": TimeWindowPartitionsDefinitionSerializer}, + serializer=TimeWindowPartitionsSubsetSerializer, ) class TimeWindowPartitionsSubset( + BaseTimeWindowPartitionsSubset, NamedTuple( "_TimeWindowPartitionsSubset", [ @@ -1883,7 +1900,6 @@ class TimeWindowPartitionsSubset( ("included_time_windows", Sequence[TimeWindow]), ], ), - BaseTimeWindowPartitionsSubset, ): def __new__( cls, diff --git a/python_modules/dagster/dagster/_serdes/serdes.py b/python_modules/dagster/dagster/_serdes/serdes.py index 6a6c1561dd375..ec241ca79fa39 100644 --- a/python_modules/dagster/dagster/_serdes/serdes.py +++ b/python_modules/dagster/dagster/_serdes/serdes.py @@ -509,7 +509,7 @@ def pack( ) -> Dict[str, JsonSerializableValue]: packed: Dict[str, JsonSerializableValue] = {} packed["__class__"] = self.get_storage_name() - for key, inner_value in value._asdict().items(): + for key, inner_value in self.before_pack(value)._asdict().items(): if key in self.skip_when_empty_fields and inner_value in EMPTY_VALUES_TO_SKIP: continue storage_key = self.storage_field_names.get(key, key) @@ -531,6 +531,10 @@ def pack( packed = self.after_pack(**packed) return packed + # Hook: Modify the contents of the named tuple before packing + def before_pack(self, value: T_NamedTuple) -> T_NamedTuple: + return value + # Hook: Modify the contents of the packed, json-serializable dict before it is converted to a # string. 
def after_pack(self, **packed_dict: JsonSerializableValue) -> Dict[str, JsonSerializableValue]: diff --git a/python_modules/dagster/dagster_tests/asset_defs_tests/test_partitions_subset.py b/python_modules/dagster/dagster_tests/asset_defs_tests/test_partitions_subset.py index 9b94d70ce7d31..02f9ed7d81d22 100644 --- a/python_modules/dagster/dagster_tests/asset_defs_tests/test_partitions_subset.py +++ b/python_modules/dagster/dagster_tests/asset_defs_tests/test_partitions_subset.py @@ -110,3 +110,23 @@ def test_time_window_partitions_subset_serialization_deserialization( ) assert deserialized == subset assert deserialized.get_partition_keys() == ["2023-01-01"] + + +def test_time_window_partitions_subset_num_partitions_serialization(): + daily_partitions_def = DailyPartitionsDefinition("2023-01-01") + time_partitions_def = TimeWindowPartitionsDefinition( + start=daily_partitions_def.start, + end=daily_partitions_def.end, + cron_schedule="0 0 * * *", + fmt="%Y-%m-%d", + timezone=daily_partitions_def.timezone, + end_offset=daily_partitions_def.end_offset, + ) + + tw = time_partitions_def.time_window_for_partition_key("2023-01-01") + + subset = TimeWindowPartitionsSubset( + time_partitions_def, num_partitions=None, included_time_windows=[tw] + ) + deserialized = deserialize_value(serialize_value(subset), TimeWindowPartitionsSubset) + assert deserialized._asdict()["num_partitions"] is not None