Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[2/n subset refactor] Make TimeWindowPartitionsDefinition serializable #17660

Merged
merged 5 commits into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@

import dagster._check as check
from dagster._annotations import PublicAttr, public
from dagster._core.errors import DagsterInvariantViolationError
from dagster._core.instance import DynamicPartitionsStore
from dagster._serdes import whitelist_for_serdes
from dagster._serdes.serdes import FieldSerializer
from dagster._utils import utc_datetime_from_timestamp
from dagster._utils.partitions import DEFAULT_HOURLY_FORMAT_WITHOUT_TIMEZONE
from dagster._utils.schedules import (
cron_string_iterator,
Expand All @@ -50,6 +54,24 @@
from .partition_key_range import PartitionKeyRange


class DatetimeFieldSerializer(FieldSerializer):
    """Serializes timezone-aware datetime objects to and from UTC timestamp floats."""

    def pack(self, datetime: Optional[datetime], **_kwargs) -> Optional[float]:
        """Convert an aware datetime to its POSIX timestamp, passing through None.

        Raises via check.invariant if the datetime is naive, since a naive
        datetime's timestamp would depend on the local machine's timezone.
        """
        if datetime is None:
            return None

        check.invariant(datetime.tzinfo is not None)

        # .timestamp() on an aware datetime yields seconds since the epoch in UTC.
        return datetime.timestamp()

    def unpack(
        self,
        datetime_float: Optional[float],
        **_kwargs,
    ) -> Optional[datetime]:
        """Convert a POSIX timestamp back to a UTC datetime, passing through None.

        Compares against None rather than relying on truthiness: a timestamp of
        0.0 (the Unix epoch, 1970-01-01T00:00:00Z) is falsy but valid, and must
        not round-trip to None.
        """
        return utc_datetime_from_timestamp(datetime_float) if datetime_float is not None else None


class TimeWindow(NamedTuple):
"""An interval that is closed at the start and open at the end.

Expand All @@ -62,6 +84,10 @@ class TimeWindow(NamedTuple):
end: PublicAttr[datetime]


@whitelist_for_serdes(
field_serializers={"start": DatetimeFieldSerializer, "end": DatetimeFieldSerializer},
is_pickleable=False,
)
class TimeWindowPartitionsDefinition(
PartitionsDefinition,
NamedTuple(
Expand Down Expand Up @@ -122,6 +148,11 @@ def __new__(

if isinstance(start, datetime):
start_dt = pendulum.instance(start, tz=timezone)

if start.tzinfo:
# Pendulum.instance does not override the timezone of the datetime object,
# so we convert it to the provided timezone
start_dt = start_dt.in_tz(tz=timezone)
else:
start_dt = pendulum.instance(datetime.strptime(start, fmt), tz=timezone)

Expand Down Expand Up @@ -288,6 +319,13 @@ def __repr__(self):
def __hash__(self):
    # NOTE(review): this hashes a tuple of the individual characters of
    # repr(self), which is effectively equivalent to hashing the repr string
    # itself — presumably chosen so the hash derives from the full printable
    # state rather than the raw namedtuple fields. Confirm repr output is
    # stable if these hashes are ever compared across processes.
    return hash(tuple(self.__repr__()))

def __getstate__(self):
    """Block pickling of this class by raising on state extraction.

    Only namedtuples whose field declaration order matches the order of their
    __new__ arguments pickle correctly; TimeWindowPartitionsDefinition does not
    satisfy that, so pickling is explicitly disallowed.
    See https://github.com/dagster-io/dagster/issues/2372.
    """
    raise DagsterInvariantViolationError("TimeWindowPartitionsDefinition is not pickleable")

@functools.lru_cache(maxsize=100)
def time_window_for_partition_key(self, partition_key: str) -> TimeWindow:
partition_key_dt = pendulum.instance(
Expand Down
32 changes: 21 additions & 11 deletions python_modules/dagster/dagster/_serdes/serdes.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ def whitelist_for_serdes(
old_fields: Optional[Mapping[str, JsonSerializableValue]] = ...,
skip_when_empty_fields: Optional[AbstractSet[str]] = ...,
field_serializers: Optional[Mapping[str, Type["FieldSerializer"]]] = None,
is_pickleable: bool = True,
) -> Callable[[T_Type], T_Type]:
...

Expand All @@ -222,6 +223,7 @@ def whitelist_for_serdes(
old_fields: Optional[Mapping[str, JsonSerializableValue]] = None,
skip_when_empty_fields: Optional[AbstractSet[str]] = None,
field_serializers: Optional[Mapping[str, Type["FieldSerializer"]]] = None,
is_pickleable: bool = True,
) -> Union[T_Type, Callable[[T_Type], T_Type]]:
"""Decorator to whitelist a NamedTuple or Enum subclass to be serializable. Various arguments can be passed
to alter serialization behavior for backcompat purposes.
Expand Down Expand Up @@ -276,7 +278,10 @@ def whitelist_for_serdes(
)
if __cls is not None: # decorator invoked directly on class
check.class_param(__cls, "__cls")
return _whitelist_for_serdes(whitelist_map=_WHITELIST_MAP)(__cls)
return _whitelist_for_serdes(
whitelist_map=_WHITELIST_MAP,
is_pickleable=is_pickleable,
)(__cls)
else: # decorator passed params
check.opt_class_param(serializer, "serializer", superclass=Serializer)
return _whitelist_for_serdes(
Expand All @@ -288,6 +293,7 @@ def whitelist_for_serdes(
old_fields=old_fields,
skip_when_empty_fields=skip_when_empty_fields,
field_serializers=field_serializers,
is_pickleable=is_pickleable,
)


Expand All @@ -300,6 +306,7 @@ def _whitelist_for_serdes(
old_fields: Optional[Mapping[str, JsonSerializableValue]] = None,
skip_when_empty_fields: Optional[AbstractSet[str]] = None,
field_serializers: Optional[Mapping[str, Type["FieldSerializer"]]] = None,
is_pickleable: bool = True,
) -> Callable[[T_Type], T_Type]:
def __whitelist_for_serdes(klass: T_Type) -> T_Type:
if issubclass(klass, Enum) and (
Expand All @@ -316,7 +323,7 @@ def __whitelist_for_serdes(klass: T_Type) -> T_Type:
elif is_named_tuple_subclass(klass) and (
serializer is None or issubclass(serializer, NamedTupleSerializer)
):
_check_serdes_tuple_class_invariants(klass)
_check_serdes_tuple_class_invariants(klass, is_pickleable=is_pickleable)
whitelist_map.register_tuple(
klass.__name__,
klass,
Expand Down Expand Up @@ -932,7 +939,9 @@ def _unpack_value(
###################################################################################################


def _check_serdes_tuple_class_invariants(klass: Type[NamedTuple]) -> None:
def _check_serdes_tuple_class_invariants(
klass: Type[NamedTuple], is_pickleable: bool = True
) -> None:
sig_params = signature(klass.__new__).parameters
dunder_new_params = list(sig_params.values())

Expand Down Expand Up @@ -960,14 +969,15 @@ def _with_header(msg: str) -> str:

raise SerdesUsageError(_with_header(error_msg))

value_param = value_params[index]
if value_param.name != field:
error_msg = (
"Params to __new__ must match the order of field declaration in the namedtuple. "
f'Declared field number {index + 1} in the namedtuple is "{field}". '
f'Parameter {index + 1} in __new__ method is "{value_param.name}".'
)
raise SerdesUsageError(_with_header(error_msg))
if is_pickleable:
value_param = value_params[index]
if value_param.name != field:
error_msg = (
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@alangenfeld do you know if we can remove this check altogether? It looks like when we rebuild named tuples during deserialization we use kwargs, so this check might not be needed anymore.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here's the history I could dig up; unfortunately, we are locked out of our Phabricator content now.
52be7e8
#2372

I think PipelineRun was getting pickled when passed around in multiprocessing contexts.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • might be better to rename the toggling bool to something along the lines of cant_pickle to more clearly advertise what opting out of the protection represents
  • probably want to put pickle round tripping under test to demonstrate that it is known to fail in whatever way it does (if possible make it fail clearly)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the context here.

I renamed the param to is_pickleable. I also overrode a __getstate__ method on TimeWindowPartitionsDefinition that enables raising an error when attempting to pickle:

https://docs.python.org/3/library/pickle.html#handling-stateful-objects

"Params to __new__ must match the order of field declaration in the namedtuple. "
f'Declared field number {index + 1} in the namedtuple is "{field}". '
f'Parameter {index + 1} in __new__ method is "{value_param.name}".'
)
raise SerdesUsageError(_with_header(error_msg))

if len(value_params) > len(klass._fields):
# Ensure that remaining parameters have default values
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pickle
import random
from datetime import datetime
from typing import Optional, Sequence, cast
Expand Down Expand Up @@ -25,6 +26,8 @@
TimeWindow,
TimeWindowPartitionsSubset,
)
from dagster._core.errors import DagsterInvariantViolationError
from dagster._serdes import deserialize_value, serialize_value
from dagster._seven.compat.pendulum import create_pendulum_time
from dagster._utils.partitions import DEFAULT_HOURLY_FORMAT_WITHOUT_TIMEZONE

Expand Down Expand Up @@ -1310,3 +1313,35 @@ def test_partition_with_end_date(
assert partitions_def.has_partition_key(first_partition_window[0])
assert partitions_def.has_partition_key(last_partition_window[0])
assert not partitions_def.has_partition_key(last_partition_window[1])


@pytest.mark.parametrize(
    "partitions_def",
    [
        DailyPartitionsDefinition("2023-01-01", timezone="America/New_York"),
        DailyPartitionsDefinition("2023-01-01"),
    ],
)
def test_time_window_partitions_def_serialization(partitions_def):
    """A TimeWindowPartitionsDefinition survives a serdes round trip unchanged."""
    original = TimeWindowPartitionsDefinition(
        start=partitions_def.start,
        end=partitions_def.end,
        cron_schedule="0 0 * * *",
        fmt="%Y-%m-%d",
        timezone=partitions_def.timezone,
        end_offset=partitions_def.end_offset,
    )
    round_tripped = deserialize_value(serialize_value(original))
    assert round_tripped == original


def test_cannot_pickle_time_window_partitions_def():
    """Pickling a TimeWindowPartitionsDefinition raises a clear, explicit error."""
    import datetime

    time_window_def = TimeWindowPartitionsDefinition(
        datetime.datetime(2021, 1, 1), "America/Los_Angeles", cron_schedule="0 0 * * *"
    )

    with pytest.raises(DagsterInvariantViolationError, match="not pickleable"):
        serialized = pickle.dumps(time_window_def)
        pickle.loads(serialized)