From cd6e384d9b90e1ac35f8627a50b56a01fce61a00 Mon Sep 17 00:00:00 2001 From: Claire Lin Date: Wed, 15 Nov 2023 15:20:21 -0800 Subject: [PATCH] make DatetimeFieldSerializer serialize timezone --- .../definitions/time_window_partitions.py | 65 ++++++++++++++----- .../test_partitions_subset.py | 13 +++- .../test_time_window_partitions.py | 6 +- 3 files changed, 61 insertions(+), 23 deletions(-) diff --git a/python_modules/dagster/dagster/_core/definitions/time_window_partitions.py b/python_modules/dagster/dagster/_core/definitions/time_window_partitions.py index 0b341c426e34e..bbaa29acc7e71 100644 --- a/python_modules/dagster/dagster/_core/definitions/time_window_partitions.py +++ b/python_modules/dagster/dagster/_core/definitions/time_window_partitions.py @@ -35,6 +35,10 @@ from dagster._serdes.serdes import ( FieldSerializer, NamedTupleSerializer, + UnpackContext, + WhitelistMap, + pack_value, + unpack_value, ) from dagster._utils import utc_datetime_from_timestamp from dagster._utils.partitions import DEFAULT_HOURLY_FORMAT_WITHOUT_TIMEZONE @@ -59,22 +63,54 @@ from .partition_key_range import PartitionKeyRange +# UTCTimestampWithTimezone is used to preserve timezone information when serializing. +# We can't store datetime.isoformat() because it only preserves UTC offsets, which vary depending on +# daylight saving time. 
+@whitelist_for_serdes +class UTCTimestampWithTimezone(NamedTuple): + datetime_float: float + timezone: str + + class DatetimeFieldSerializer(FieldSerializer): """Serializes datetime objects to and from floats.""" - def pack(self, datetime: Optional[datetime], **_kwargs) -> Optional[float]: + def pack( + self, datetime: Optional[datetime], whitelist_map: WhitelistMap, descent_path: str + ) -> Optional[Mapping[str, Any]]: if datetime: check.invariant(datetime.tzinfo is not None) + pendulum_datetime = pendulum.instance(datetime, tz=datetime.tzinfo) + return pack_value( + UTCTimestampWithTimezone( + datetime.timestamp(), str(pendulum_datetime.timezone.name) + ), + whitelist_map, + descent_path, + ) - # Get the timestamp in UTC - return datetime.timestamp() if datetime else None + return None def unpack( self, - datetime_float: Optional[float], - **_kwargs, + serialized_datetime_with_timezone: Optional[Mapping[str, Any]], + whitelist_map: WhitelistMap, + context: UnpackContext, ) -> Optional[datetime]: - return utc_datetime_from_timestamp(datetime_float) if datetime_float else None + if serialized_datetime_with_timezone: + unpacked = unpack_value( + serialized_datetime_with_timezone, + UTCTimestampWithTimezone, + whitelist_map, + context, + ) + unpacked_datetime = pendulum.instance( + utc_datetime_from_timestamp(unpacked.datetime_float), tz=unpacked.timezone + ).in_tz(tz=unpacked.timezone) + check.invariant(unpacked_datetime.tzinfo is not None) + return unpacked_datetime + + return None @whitelist_for_serdes( @@ -1836,6 +1872,9 @@ class TimeWindowPartitionsSubsetSerializer(NamedTupleSerializer): # is needed to improve performance. When serializing, we want to serialize the number of # partitions, so we force calculation. 
def before_pack(self, value: "TimeWindowPartitionsSubset") -> "TimeWindowPartitionsSubset": + # value.num_partitions will calculate the number of partitions if the field is None + # We want to check if the field is None and replace the value with the calculated value + # for serialization if value._asdict()["num_partitions"] is None: return TimeWindowPartitionsSubset( partitions_def=value.partitions_def, @@ -1867,23 +1906,15 @@ def __new__( num_partitions: Optional[int], included_time_windows: Sequence[TimeWindow], ): - check.sequence_param(included_time_windows, "included_time_windows", of_type=TimeWindow) - - time_windows_with_timezone = [ - TimeWindow( - start=pendulum.instance(tw.start).in_tz(tz=partitions_def.timezone), - end=pendulum.instance(tw.end).in_tz(tz=partitions_def.timezone), - ) - for tw in included_time_windows - ] - return super(TimeWindowPartitionsSubset, cls).__new__( cls, partitions_def=check.inst_param( partitions_def, "partitions_def", TimeWindowPartitionsDefinition ), num_partitions=check.opt_int_param(num_partitions, "num_partitions"), - included_time_windows=time_windows_with_timezone, + included_time_windows=check.sequence_param( + included_time_windows, "included_time_windows", of_type=TimeWindow + ), ) @property diff --git a/python_modules/dagster/dagster_tests/asset_defs_tests/test_partitions_subset.py b/python_modules/dagster/dagster_tests/asset_defs_tests/test_partitions_subset.py index dca4797644170..e8477ed035888 100644 --- a/python_modules/dagster/dagster_tests/asset_defs_tests/test_partitions_subset.py +++ b/python_modules/dagster/dagster_tests/asset_defs_tests/test_partitions_subset.py @@ -105,15 +105,22 @@ def test_time_window_partitions_subset_serialization_deserialization( timezone=partitions_def.timezone, end_offset=partitions_def.end_offset, ) - subset = TimeWindowPartitionsSubset.empty_subset( - time_window_partitions_def - ).with_partition_keys(["2023-01-01"]) + subset = cast( + TimeWindowPartitionsSubset, + 
TimeWindowPartitionsSubset.empty_subset(time_window_partitions_def).with_partition_keys( + ["2023-01-01"] + ), + ) deserialized = deserialize_value( serialize_value(cast(TimeWindowPartitionsSubset, subset)), TimeWindowPartitionsSubset ) assert deserialized == subset assert deserialized.get_partition_keys() == ["2023-01-01"] + assert ( + deserialized.included_time_windows[0].start.tzinfo + == subset.included_time_windows[0].start.tzinfo + ) def test_time_window_partitions_subset_num_partitions_serialization(): diff --git a/python_modules/dagster/dagster_tests/definitions_tests/test_time_window_partitions.py b/python_modules/dagster/dagster_tests/definitions_tests/test_time_window_partitions.py index 2dc6e65cf2df3..5d0d29ff0dd42 100644 --- a/python_modules/dagster/dagster_tests/definitions_tests/test_time_window_partitions.py +++ b/python_modules/dagster/dagster_tests/definitions_tests/test_time_window_partitions.py @@ -1330,9 +1330,9 @@ def test_time_window_partitions_def_serialization(partitions_def): timezone=partitions_def.timezone, end_offset=partitions_def.end_offset, ) - assert ( - deserialize_value(serialize_value(time_window_partitions_def)) == time_window_partitions_def - ) + deserialized = deserialize_value(serialize_value(time_window_partitions_def)) + assert deserialized == time_window_partitions_def + assert deserialized.start.tzinfo == time_window_partitions_def.start.tzinfo def test_cannot_pickle_time_window_partitions_def():