From 828e5a49bcf58273214a1a40995461a60037acaa Mon Sep 17 00:00:00 2001
From: Claire Lin
Date: Thu, 2 Nov 2023 13:22:09 -0700
Subject: [PATCH] add whitelist for serdes

---
 .../definitions/time_window_partitions.py     | 39 ++++++++++++++++---
 .../test_time_window_partitions.py            | 23 +++++++++++
 2 files changed, 57 insertions(+), 5 deletions(-)

diff --git a/python_modules/dagster/dagster/_core/definitions/time_window_partitions.py b/python_modules/dagster/dagster/_core/definitions/time_window_partitions.py
index d886a127c2d96..6f6cc415aa787 100644
--- a/python_modules/dagster/dagster/_core/definitions/time_window_partitions.py
+++ b/python_modules/dagster/dagster/_core/definitions/time_window_partitions.py
@@ -28,6 +28,9 @@
 from dagster._annotations import PublicAttr, public
 from dagster._core.instance import DynamicPartitionsStore
 from dagster._utils.cached_method import cached_method
+from dagster._serdes import whitelist_for_serdes
+from dagster._serdes.serdes import FieldSerializer
+from dagster._utils import utc_datetime_from_timestamp
 from dagster._utils.partitions import DEFAULT_HOURLY_FORMAT_WITHOUT_TIMEZONE
 from dagster._utils.schedules import (
     cron_string_iterator,
@@ -50,6 +53,24 @@
 from .partition_key_range import PartitionKeyRange
 
 
+class DatetimeFieldSerializer(FieldSerializer):
+    """Serializes timezone-aware datetimes to and from POSIX timestamp floats."""
+
+    def pack(self, datetime: Optional[datetime], **_kwargs) -> Optional[float]:
+        if datetime is not None:
+            check.invariant(datetime.tzinfo is not None)
+
+        # Get the timestamp in UTC
+        return datetime.timestamp() if datetime is not None else None
+
+    def unpack(
+        self,
+        datetime_float: Optional[float],
+        **_kwargs,
+    ) -> Optional[datetime]:
+        return utc_datetime_from_timestamp(datetime_float) if datetime_float is not None else None
+
+
 class TimeWindow(NamedTuple):
     """An interval that is closed at the start and open at the end.
 
@@ -62,15 +83,18 @@ class TimeWindow(NamedTuple):
     end: PublicAttr[datetime]
 
 
+@whitelist_for_serdes(
+    field_serializers={"start": DatetimeFieldSerializer, "end": DatetimeFieldSerializer}
+)
 class TimeWindowPartitionsDefinition(
     PartitionsDefinition,
     NamedTuple(
         "_TimeWindowPartitionsDefinition",
         [
             ("start", PublicAttr[datetime]),
+            ("fmt", PublicAttr[str]),
             ("timezone", PublicAttr[str]),
             ("end", PublicAttr[Optional[datetime]]),
-            ("fmt", PublicAttr[str]),
             ("end_offset", PublicAttr[int]),
             ("cron_schedule", PublicAttr[str]),
         ],
@@ -108,20 +132,25 @@ def __new__(
         cls,
         start: Union[datetime, str],
         fmt: str,
-        end: Union[datetime, str, None] = None,
-        schedule_type: Optional[ScheduleType] = None,
         timezone: Optional[str] = None,
+        end: Union[datetime, str, None] = None,
         end_offset: int = 0,
+        cron_schedule: Optional[str] = None,
+        schedule_type: Optional[ScheduleType] = None,
         minute_offset: Optional[int] = None,
         hour_offset: Optional[int] = None,
         day_offset: Optional[int] = None,
-        cron_schedule: Optional[str] = None,
     ):
         check.opt_str_param(timezone, "timezone")
         timezone = timezone or "UTC"
 
         if isinstance(start, datetime):
             start_dt = pendulum.instance(start, tz=timezone)
+
+            if start.tzinfo:
+                # Pendulum.instance does not override the timezone of the datetime object,
+                # so we convert it to the provided timezone
+                start_dt = start_dt.in_tz(tz=timezone)
         else:
             start_dt = pendulum.instance(datetime.strptime(start, fmt), tz=timezone)
 
@@ -156,7 +185,7 @@ def __new__(
             )
 
         return super(TimeWindowPartitionsDefinition, cls).__new__(
-            cls, start_dt, timezone, end_dt, fmt, end_offset, cron_schedule
+            cls, start_dt, fmt, timezone, end_dt, end_offset, cron_schedule
         )
 
     def get_current_timestamp(self, current_time: Optional[datetime] = None) -> float:
diff --git a/python_modules/dagster/dagster_tests/definitions_tests/test_time_window_partitions.py b/python_modules/dagster/dagster_tests/definitions_tests/test_time_window_partitions.py
index db418ad82f189..a8356bbf52fab 100644
--- a/python_modules/dagster/dagster_tests/definitions_tests/test_time_window_partitions.py
+++ b/python_modules/dagster/dagster_tests/definitions_tests/test_time_window_partitions.py
@@ -25,6 +25,7 @@
     TimeWindowPartitionsSubset,
     UnresolvedTimeWindowPartitionsSubset,
 )
+from dagster._serdes import deserialize_value, serialize_value
 from dagster._seven.compat.pendulum import create_pendulum_time
 from dagster._utils.partitions import DEFAULT_HOURLY_FORMAT_WITHOUT_TIMEZONE
 
@@ -1287,3 +1288,25 @@ def test_partition_with_end_date(
     assert partitions_def.has_partition_key(first_partition_window[0])
     assert partitions_def.has_partition_key(last_partition_window[0])
     assert not partitions_def.has_partition_key(last_partition_window[1])
+
+
+@pytest.mark.parametrize(
+    "partitions_def",
+    [
+        (DailyPartitionsDefinition("2023-01-01", timezone="America/New_York")),
+        (DailyPartitionsDefinition("2023-01-01")),
+    ],
+)
+def test_time_window_partitions_def_serialization(partitions_def):
+    time_window_partitions_def = TimeWindowPartitionsDefinition(
+        start=partitions_def.start,
+        end=partitions_def.end,
+        cron_schedule="0 0 * * *",
+        fmt="%Y-%m-%d",
+        timezone=partitions_def.timezone,
+        end_offset=partitions_def.end_offset,
+    )
+    deserialized = deserialize_value(
+        serialize_value(time_window_partitions_def), TimeWindowPartitionsDefinition
+    )
+    assert deserialized == time_window_partitions_def