Skip to content

Commit

Permalink
add whitelist for serdes
Browse files Browse the repository at this point in the history
  • Loading branch information
clairelin135 committed Nov 2, 2023
1 parent 04a91d7 commit 8a95de4
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
import dagster._check as check
from dagster._annotations import PublicAttr, public
from dagster._core.instance import DynamicPartitionsStore
from dagster._serdes import whitelist_for_serdes
from dagster._serdes.serdes import FieldSerializer
from dagster._utils import utc_datetime_from_timestamp
from dagster._utils.partitions import DEFAULT_HOURLY_FORMAT_WITHOUT_TIMEZONE
from dagster._utils.schedules import (
cron_string_iterator,
Expand All @@ -48,6 +51,24 @@
from .partition_key_range import PartitionKeyRange


class DatetimeFieldSerializer(FieldSerializer):
"""Serializes datetime objects to and from floats."""

def pack(self, datetime: Optional[datetime], **_kwargs) -> Optional[float]:
if datetime:
check.invariant(datetime.tzinfo is not None)

# Get the timestamp in UTC
return datetime.timestamp() if datetime else None

def unpack(
self,
datetime_float: Optional[float],
**_kwargs,
) -> Optional[datetime]:
return utc_datetime_from_timestamp(datetime_float) if datetime_float else None


class TimeWindow(NamedTuple):
"""An interval that is closed at the start and open at the end.
Expand All @@ -60,15 +81,18 @@ class TimeWindow(NamedTuple):
end: PublicAttr[datetime]


@whitelist_for_serdes(
field_serializers={"start": DatetimeFieldSerializer, "end": DatetimeFieldSerializer}
)
class TimeWindowPartitionsDefinition(
PartitionsDefinition,
NamedTuple(
"_TimeWindowPartitionsDefinition",
[
("start", PublicAttr[datetime]),
("fmt", PublicAttr[str]),
("timezone", PublicAttr[str]),
("end", PublicAttr[Optional[datetime]]),
("fmt", PublicAttr[str]),
("end_offset", PublicAttr[int]),
("cron_schedule", PublicAttr[str]),
],
Expand Down Expand Up @@ -106,20 +130,25 @@ def __new__(
cls,
start: Union[datetime, str],
fmt: str,
end: Union[datetime, str, None] = None,
schedule_type: Optional[ScheduleType] = None,
timezone: Optional[str] = None,
end: Union[datetime, str, None] = None,
end_offset: int = 0,
cron_schedule: Optional[str] = None,
schedule_type: Optional[ScheduleType] = None,
minute_offset: Optional[int] = None,
hour_offset: Optional[int] = None,
day_offset: Optional[int] = None,
cron_schedule: Optional[str] = None,
):
check.opt_str_param(timezone, "timezone")
timezone = timezone or "UTC"

if isinstance(start, datetime):
start_dt = pendulum.instance(start, tz=timezone)

if start.tzinfo:
# Pendulum.instance does not override the timezone of the datetime object,
# so we convert it to the provided timezone
start_dt = start_dt.in_tz(tz=timezone)
else:
start_dt = pendulum.instance(datetime.strptime(start, fmt), tz=timezone)

Expand Down Expand Up @@ -154,7 +183,7 @@ def __new__(
)

return super(TimeWindowPartitionsDefinition, cls).__new__(
cls, start_dt, timezone, end_dt, fmt, end_offset, cron_schedule
cls, start_dt, fmt, timezone, end_dt, end_offset, cron_schedule
)

def get_current_timestamp(self, current_time: Optional[datetime] = None) -> float:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
TimeWindow,
TimeWindowPartitionsSubset,
)
from dagster._serdes import deserialize_value, serialize_value
from dagster._seven.compat.pendulum import create_pendulum_time
from dagster._utils.partitions import DEFAULT_HOURLY_FORMAT_WITHOUT_TIMEZONE

Expand Down Expand Up @@ -1268,3 +1269,25 @@ def test_partition_with_end_date(
assert partitions_def.has_partition_key(first_partition_window[0])
assert partitions_def.has_partition_key(last_partition_window[0])
assert not partitions_def.has_partition_key(last_partition_window[1])


@pytest.mark.parametrize(
"partitions_def",
[
(DailyPartitionsDefinition("2023-01-01", timezone="America/New_York")),
(DailyPartitionsDefinition("2023-01-01", timezone="America/New_York")),
],
)
def test_time_window_partitions_def_serialization(partitions_def):
time_window_partitions_def = TimeWindowPartitionsDefinition(
start=partitions_def.start,
end=partitions_def.end,
cron_schedule="0 0 * * *",
fmt="%Y-%m-%d",
timezone=partitions_def.timezone,
end_offset=partitions_def.end_offset,
)
deserialized = deserialize_value(
serialize_value(time_window_partitions_def), TimeWindowPartitionsDefinition
)
assert deserialized == time_window_partitions_def

0 comments on commit 8a95de4

Please sign in to comment.