Skip to content

Commit

Permalink
make DatetimeFieldSerializer serialize timezone
Browse files Browse the repository at this point in the history
  • Loading branch information
clairelin135 committed Nov 16, 2023
1 parent 31ef4da commit cd6e384
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@
from dagster._serdes.serdes import (
FieldSerializer,
NamedTupleSerializer,
UnpackContext,
WhitelistMap,
pack_value,
unpack_value,
)
from dagster._utils import utc_datetime_from_timestamp
from dagster._utils.partitions import DEFAULT_HOURLY_FORMAT_WITHOUT_TIMEZONE
Expand All @@ -59,22 +63,54 @@
from .partition_key_range import PartitionKeyRange


# UTCTimestampWithTimezone preserves timezone information when serializing a datetime.
# We can't store datetime.isoformat() because it preserves only the UTC offset, not the timezone
# itself, and the offset for a given timezone varies with daylight saving time.
@whitelist_for_serdes
class UTCTimestampWithTimezone(NamedTuple):
# Timestamp of the datetime, as returned by datetime.timestamp().
datetime_float: float
# Name of the datetime's timezone (DatetimeFieldSerializer.pack stores pendulum's timezone.name).
timezone: str


# NOTE(review): this block is a rendered diff with indentation stripped; the pre-change and
# post-change lines of pack/unpack appear back to back without +/- markers (see the two
# consecutive `def pack(` headers below). Confirm against the repository before editing code.
class DatetimeFieldSerializer(FieldSerializer):
"""Serializes timezone-aware datetime objects to and from a packed UTCTimestampWithTimezone.

The packed value carries the datetime's timestamp together with its timezone name, and
unpack returns a datetime converted into that timezone. None is passed through unchanged.
"""

def pack(self, datetime: Optional[datetime], **_kwargs) -> Optional[float]:
def pack(
self, datetime: Optional[datetime], whitelist_map: WhitelistMap, descent_path: str
) -> Optional[Mapping[str, Any]]:
if datetime:
# Only timezone-aware datetimes are accepted; a missing tzinfo fails this invariant.
check.invariant(datetime.tzinfo is not None)
# Wrap in a pendulum datetime so the timezone's name can be read for storage.
pendulum_datetime = pendulum.instance(datetime, tz=datetime.tzinfo)
return pack_value(
UTCTimestampWithTimezone(
datetime.timestamp(), str(pendulum_datetime.timezone.name)
),
whitelist_map,
descent_path,
)

# Get the timestamp in UTC
return datetime.timestamp() if datetime else None
return None

def unpack(
self,
datetime_float: Optional[float],
**_kwargs,
serialized_datetime_with_timezone: Optional[Mapping[str, Any]],
whitelist_map: WhitelistMap,
context: UnpackContext,
) -> Optional[datetime]:
return utc_datetime_from_timestamp(datetime_float) if datetime_float else None
if serialized_datetime_with_timezone:
# Restore the UTCTimestampWithTimezone tuple that pack stored.
unpacked = unpack_value(
serialized_datetime_with_timezone,
UTCTimestampWithTimezone,
whitelist_map,
context,
)
# Rebuild the datetime from the stored timestamp, then convert it into the stored timezone.
unpacked_datetime = pendulum.instance(
utc_datetime_from_timestamp(unpacked.datetime_float), tz=unpacked.timezone
).in_tz(tz=unpacked.timezone)
check.invariant(unpacked_datetime.tzinfo is not None)
return unpacked_datetime

return None


@whitelist_for_serdes(
Expand Down Expand Up @@ -1836,6 +1872,9 @@ class TimeWindowPartitionsSubsetSerializer(NamedTupleSerializer):
# is needed to improve performance. When serializing, we want to serialize the number of
# partitions, so we force calculation.
def before_pack(self, value: "TimeWindowPartitionsSubset") -> "TimeWindowPartitionsSubset":
# value.num_partitions will calculate the number of partitions if the field is None.
# Read the raw field via _asdict() to check whether it is None and, if so, replace it with
# the calculated value for serialization.
if value._asdict()["num_partitions"] is None:
return TimeWindowPartitionsSubset(
partitions_def=value.partitions_def,
Expand Down Expand Up @@ -1867,23 +1906,15 @@ def __new__(
num_partitions: Optional[int],
included_time_windows: Sequence[TimeWindow],
):
check.sequence_param(included_time_windows, "included_time_windows", of_type=TimeWindow)

time_windows_with_timezone = [
TimeWindow(
start=pendulum.instance(tw.start).in_tz(tz=partitions_def.timezone),
end=pendulum.instance(tw.end).in_tz(tz=partitions_def.timezone),
)
for tw in included_time_windows
]

return super(TimeWindowPartitionsSubset, cls).__new__(
cls,
partitions_def=check.inst_param(
partitions_def, "partitions_def", TimeWindowPartitionsDefinition
),
num_partitions=check.opt_int_param(num_partitions, "num_partitions"),
included_time_windows=time_windows_with_timezone,
included_time_windows=check.sequence_param(
included_time_windows, "included_time_windows", of_type=TimeWindow
),
)

@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,22 @@ def test_time_window_partitions_subset_serialization_deserialization(
timezone=partitions_def.timezone,
end_offset=partitions_def.end_offset,
)
subset = TimeWindowPartitionsSubset.empty_subset(
time_window_partitions_def
).with_partition_keys(["2023-01-01"])
subset = cast(
TimeWindowPartitionsSubset,
TimeWindowPartitionsSubset.empty_subset(time_window_partitions_def).with_partition_keys(
["2023-01-01"]
),
)

deserialized = deserialize_value(
serialize_value(cast(TimeWindowPartitionsSubset, subset)), TimeWindowPartitionsSubset
)
assert deserialized == subset
assert deserialized.get_partition_keys() == ["2023-01-01"]
assert (
deserialized.included_time_windows[0].start.tzinfo
== subset.included_time_windows[0].start.tzinfo
)


def test_time_window_partitions_subset_num_partitions_serialization():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1330,9 +1330,9 @@ def test_time_window_partitions_def_serialization(partitions_def):
timezone=partitions_def.timezone,
end_offset=partitions_def.end_offset,
)
assert (
deserialize_value(serialize_value(time_window_partitions_def)) == time_window_partitions_def
)
deserialized = deserialize_value(serialize_value(time_window_partitions_def))
assert deserialized == time_window_partitions_def
assert deserialized.start.tzinfo == time_window_partitions_def.start.tzinfo


def test_cannot_pickle_time_window_partitions_def():
Expand Down

0 comments on commit cd6e384

Please sign in to comment.