From ed8925c5e09af9f0765c09ac03a7ee631ec5fbac Mon Sep 17 00:00:00 2001 From: Claire Lin Date: Thu, 2 Nov 2023 13:22:09 -0700 Subject: [PATCH 1/5] add whitelist for serdes --- .../definitions/time_window_partitions.py | 39 ++++++++++++++++--- .../test_time_window_partitions.py | 23 +++++++++++ 2 files changed, 57 insertions(+), 5 deletions(-) diff --git a/python_modules/dagster/dagster/_core/definitions/time_window_partitions.py b/python_modules/dagster/dagster/_core/definitions/time_window_partitions.py index be8ebad01d807..794eba3a413ed 100644 --- a/python_modules/dagster/dagster/_core/definitions/time_window_partitions.py +++ b/python_modules/dagster/dagster/_core/definitions/time_window_partitions.py @@ -28,6 +28,9 @@ import dagster._check as check from dagster._annotations import PublicAttr, public from dagster._core.instance import DynamicPartitionsStore +from dagster._serdes import whitelist_for_serdes +from dagster._serdes.serdes import FieldSerializer +from dagster._utils import utc_datetime_from_timestamp from dagster._utils.partitions import DEFAULT_HOURLY_FORMAT_WITHOUT_TIMEZONE from dagster._utils.schedules import ( cron_string_iterator, @@ -50,6 +53,24 @@ from .partition_key_range import PartitionKeyRange +class DatetimeFieldSerializer(FieldSerializer): + """Serializes datetime objects to and from floats.""" + + def pack(self, datetime: Optional[datetime], **_kwargs) -> Optional[float]: + if datetime: + check.invariant(datetime.tzinfo is not None) + + # Get the timestamp in UTC + return datetime.timestamp() if datetime else None + + def unpack( + self, + datetime_float: Optional[float], + **_kwargs, + ) -> Optional[datetime]: + return utc_datetime_from_timestamp(datetime_float) if datetime_float else None + + class TimeWindow(NamedTuple): """An interval that is closed at the start and open at the end. @@ -62,15 +83,18 @@ class TimeWindow(NamedTuple): end: PublicAttr[datetime] +@whitelist_for_serdes( + field_serializers={"start": DatetimeFieldSerializer, "end": DatetimeFieldSerializer} +) class TimeWindowPartitionsDefinition( PartitionsDefinition, NamedTuple( "_TimeWindowPartitionsDefinition", [ ("start", PublicAttr[datetime]), + ("fmt", PublicAttr[str]), ("timezone", PublicAttr[str]), ("end", PublicAttr[Optional[datetime]]), - ("fmt", PublicAttr[str]), ("end_offset", PublicAttr[int]), ("cron_schedule", PublicAttr[str]), ], @@ -108,20 +132,25 @@ def __new__( cls, start: Union[datetime, str], fmt: str, - end: Union[datetime, str, None] = None, - schedule_type: Optional[ScheduleType] = None, timezone: Optional[str] = None, + end: Union[datetime, str, None] = None, end_offset: int = 0, + cron_schedule: Optional[str] = None, + schedule_type: Optional[ScheduleType] = None, minute_offset: Optional[int] = None, hour_offset: Optional[int] = None, day_offset: Optional[int] = None, - cron_schedule: Optional[str] = None, ): check.opt_str_param(timezone, "timezone") timezone = timezone or "UTC" if isinstance(start, datetime): start_dt = pendulum.instance(start, tz=timezone) + + if start.tzinfo: + # Pendulum.instance does not override the timezone of the datetime object, + # so we convert it to the provided timezone + start_dt = start_dt.in_tz(tz=timezone) else: start_dt = pendulum.instance(datetime.strptime(start, fmt), tz=timezone) @@ -156,7 +185,7 @@ def __new__( ) return super(TimeWindowPartitionsDefinition, cls).__new__( - cls, start_dt, timezone, end_dt, fmt, end_offset, cron_schedule + cls, start_dt, fmt, timezone, end_dt, end_offset, cron_schedule ) def get_current_timestamp(self, current_time: Optional[datetime] = None) -> float: diff --git a/python_modules/dagster/dagster_tests/definitions_tests/test_time_window_partitions.py b/python_modules/dagster/dagster_tests/definitions_tests/test_time_window_partitions.py index e843706c99808..b87338d0d237f 100644 --- a/python_modules/dagster/dagster_tests/definitions_tests/test_time_window_partitions.py +++ b/python_modules/dagster/dagster_tests/definitions_tests/test_time_window_partitions.py @@ -25,6 +25,7 @@ TimeWindow, TimeWindowPartitionsSubset, ) +from dagster._serdes import deserialize_value, serialize_value from dagster._seven.compat.pendulum import create_pendulum_time from dagster._utils.partitions import DEFAULT_HOURLY_FORMAT_WITHOUT_TIMEZONE @@ -1310,3 +1311,25 @@ def test_partition_with_end_date( assert partitions_def.has_partition_key(first_partition_window[0]) assert partitions_def.has_partition_key(last_partition_window[0]) assert not partitions_def.has_partition_key(last_partition_window[1]) + + +@pytest.mark.parametrize( + "partitions_def", + [ + (DailyPartitionsDefinition("2023-01-01", timezone="America/New_York")), + (DailyPartitionsDefinition("2023-01-01", timezone="America/New_York")), + ], +) +def test_time_window_partitions_def_serialization(partitions_def): + time_window_partitions_def = TimeWindowPartitionsDefinition( + start=partitions_def.start, + end=partitions_def.end, + cron_schedule="0 0 * * *", + fmt="%Y-%m-%d", + timezone=partitions_def.timezone, + end_offset=partitions_def.end_offset, + ) + deserialized = deserialize_value( + serialize_value(time_window_partitions_def), TimeWindowPartitionsDefinition + ) + assert deserialized == time_window_partitions_def From ee23877a4b6385ad2a0f8e4f0001ea2e07c47c25 Mon Sep 17 00:00:00 2001 From: Claire Lin Date: Mon, 6 Nov 2023 10:15:41 -0800 Subject: [PATCH 2/5] create serializable time window partitions def named tuple --- .../definitions/time_window_partitions.py | 55 +++++++++++++++++-- .../test_time_window_partitions.py | 8 ++- 2 files changed, 55 insertions(+), 8 deletions(-) diff --git a/python_modules/dagster/dagster/_core/definitions/time_window_partitions.py b/python_modules/dagster/dagster/_core/definitions/time_window_partitions.py index 794eba3a413ed..6262b87392143 100644 --- a/python_modules/dagster/dagster/_core/definitions/time_window_partitions.py +++ b/python_modules/dagster/dagster/_core/definitions/time_window_partitions.py @@ -31,6 +31,7 @@ from dagster._serdes import whitelist_for_serdes from dagster._serdes.serdes import FieldSerializer from dagster._utils import utc_datetime_from_timestamp +from dagster._utils.cached_method import cached_method from dagster._utils.partitions import DEFAULT_HOURLY_FORMAT_WITHOUT_TIMEZONE from dagster._utils.schedules import ( cron_string_iterator, @@ -86,15 +87,52 @@ class TimeWindow(NamedTuple): @whitelist_for_serdes( field_serializers={"start": DatetimeFieldSerializer, "end": DatetimeFieldSerializer} ) +class SerializableTimeWindowPartitionsDefinition( + NamedTuple( + "_SerializableTimeWindowPartitionsDefinition", + [ + ("start", PublicAttr[datetime]), + ("fmt", PublicAttr[str]), + ("timezone", PublicAttr[str]), + ("end", PublicAttr[Optional[datetime]]), + ("end_offset", PublicAttr[int]), + ("cron_schedule", PublicAttr[str]), + ], + ) +): + def __new__( + cls, + start: datetime, + fmt: str, + timezone: str, + end: Optional[datetime], + end_offset: int, + cron_schedule: str, + ): + return super(SerializableTimeWindowPartitionsDefinition, cls).__new__( + cls, start, fmt, timezone, end, end_offset, cron_schedule + ) + + def to_time_window_partitions_def(self) -> "TimeWindowPartitionsDefinition": + return TimeWindowPartitionsDefinition( + self.start, + self.fmt, + self.end, + timezone=self.timezone, + end_offset=self.end_offset, + cron_schedule=self.cron_schedule, + ) + + class TimeWindowPartitionsDefinition( PartitionsDefinition, NamedTuple( "_TimeWindowPartitionsDefinition", [ ("start", PublicAttr[datetime]), - ("fmt", PublicAttr[str]), ("timezone", PublicAttr[str]), ("end", PublicAttr[Optional[datetime]]), + ("fmt", PublicAttr[str]), ("end_offset", PublicAttr[int]), ("cron_schedule", PublicAttr[str]), ], @@ -132,14 +170,14 @@ def __new__( cls, start: Union[datetime, str], fmt: str, - timezone: Optional[str] = None, end: Union[datetime, str, None] = None, - end_offset: int = 0, - cron_schedule: Optional[str] = None, schedule_type: Optional[ScheduleType] = None, + timezone: Optional[str] = None, + end_offset: int = 0, minute_offset: Optional[int] = None, hour_offset: Optional[int] = None, day_offset: Optional[int] = None, + cron_schedule: Optional[str] = None, ): check.opt_str_param(timezone, "timezone") timezone = timezone or "UTC" @@ -185,7 +223,14 @@ def __new__( ) return super(TimeWindowPartitionsDefinition, cls).__new__( - cls, start_dt, fmt, timezone, end_dt, end_offset, cron_schedule + cls, start_dt, timezone, end_dt, fmt, end_offset, cron_schedule + ) + + def to_serializable_time_window_partitions_def( + self, + ) -> SerializableTimeWindowPartitionsDefinition: + return SerializableTimeWindowPartitionsDefinition( + self.start, self.fmt, self.timezone, self.end, self.end_offset, self.cron_schedule ) def get_current_timestamp(self, current_time: Optional[datetime] = None) -> float: diff --git a/python_modules/dagster/dagster_tests/definitions_tests/test_time_window_partitions.py b/python_modules/dagster/dagster_tests/definitions_tests/test_time_window_partitions.py index b87338d0d237f..050aeb56444d2 100644 --- a/python_modules/dagster/dagster_tests/definitions_tests/test_time_window_partitions.py +++ b/python_modules/dagster/dagster_tests/definitions_tests/test_time_window_partitions.py @@ -1329,7 +1329,9 @@ def test_time_window_partitions_def_serialization(partitions_def): timezone=partitions_def.timezone, end_offset=partitions_def.end_offset, ) - deserialized = deserialize_value( - serialize_value(time_window_partitions_def), TimeWindowPartitionsDefinition + assert ( + deserialize_value( + serialize_value(time_window_partitions_def.to_serializable_time_window_partitions_def()) + ).to_time_window_partitions_def() + == time_window_partitions_def ) - assert deserialized == time_window_partitions_def From eb16480e9c9ab59eece0496383790204d52335b7 Mon Sep 17 00:00:00 2001 From: Claire Lin Date: Mon, 6 Nov 2023 12:44:19 -0800 Subject: [PATCH 3/5] update tests and comments --- .../dagster/_core/definitions/time_window_partitions.py | 6 ++++++ .../definitions_tests/test_time_window_partitions.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/python_modules/dagster/dagster/_core/definitions/time_window_partitions.py b/python_modules/dagster/dagster/_core/definitions/time_window_partitions.py index 6262b87392143..c30101cf62989 100644 --- a/python_modules/dagster/dagster/_core/definitions/time_window_partitions.py +++ b/python_modules/dagster/dagster/_core/definitions/time_window_partitions.py @@ -84,6 +84,10 @@ class TimeWindow(NamedTuple): end: PublicAttr[datetime] +# Unfortunately we can't use @whitelist_for_serdes on TimeWindowPartitionsDefinition +# because args to __new__ are a different order than the fields in the NamedTuple, and we can't +# reorder them because it's a public API. Until TimeWindowPartitionsDefinition can decorated, +# this class is used to serialize it. @whitelist_for_serdes( field_serializers={"start": DatetimeFieldSerializer, "end": DatetimeFieldSerializer} ) @@ -222,6 +226,8 @@ def __new__( " TimeWindowPartitionsDefinition." ) + # When adding new fields to the NamedTuple, update the SerializableTimeWindowPartitionsDefinition + # class with the same fields. return super(TimeWindowPartitionsDefinition, cls).__new__( cls, start_dt, timezone, end_dt, fmt, end_offset, cron_schedule ) diff --git a/python_modules/dagster/dagster_tests/definitions_tests/test_time_window_partitions.py b/python_modules/dagster/dagster_tests/definitions_tests/test_time_window_partitions.py index 050aeb56444d2..9c04cc10cc5b9 100644 --- a/python_modules/dagster/dagster_tests/definitions_tests/test_time_window_partitions.py +++ b/python_modules/dagster/dagster_tests/definitions_tests/test_time_window_partitions.py @@ -1317,7 +1317,7 @@ def test_partition_with_end_date( "partitions_def", [ (DailyPartitionsDefinition("2023-01-01", timezone="America/New_York")), - (DailyPartitionsDefinition("2023-01-01", timezone="America/New_York")), + (DailyPartitionsDefinition("2023-01-01")), ], ) def test_time_window_partitions_def_serialization(partitions_def): From c15fda66b844af35ac38886be3503e7bf141b523 Mon Sep 17 00:00:00 2001 From: Claire Lin Date: Fri, 10 Nov 2023 13:11:16 -0800 Subject: [PATCH 4/5] add option to disable same ordering --- .../definitions/time_window_partitions.py | 54 +------------------ .../dagster/dagster/_serdes/serdes.py | 34 ++++++++---- .../test_time_window_partitions.py | 5 +- 3 files changed, 26 insertions(+), 67 deletions(-) diff --git a/python_modules/dagster/dagster/_core/definitions/time_window_partitions.py b/python_modules/dagster/dagster/_core/definitions/time_window_partitions.py index c30101cf62989..5ff7cda197175 100644 --- a/python_modules/dagster/dagster/_core/definitions/time_window_partitions.py +++ b/python_modules/dagster/dagster/_core/definitions/time_window_partitions.py @@ -31,7 +31,6 @@ from dagster._serdes import whitelist_for_serdes from dagster._serdes.serdes import FieldSerializer from dagster._utils import utc_datetime_from_timestamp -from dagster._utils.cached_method import cached_method from dagster._utils.partitions import DEFAULT_HOURLY_FORMAT_WITHOUT_TIMEZONE from dagster._utils.schedules import ( cron_string_iterator, @@ -84,50 +83,10 @@ class TimeWindow(NamedTuple): end: PublicAttr[datetime] -# Unfortunately we can't use @whitelist_for_serdes on TimeWindowPartitionsDefinition -# because args to __new__ are a different order than the fields in the NamedTuple, and we can't -# reorder them because it's a public API. Until TimeWindowPartitionsDefinition can decorated, -# this class is used to serialize it. @whitelist_for_serdes( - field_serializers={"start": DatetimeFieldSerializer, "end": DatetimeFieldSerializer} + field_serializers={"start": DatetimeFieldSerializer, "end": DatetimeFieldSerializer}, + require_args_to_match_field_ordering=False, ) -class SerializableTimeWindowPartitionsDefinition( - NamedTuple( - "_SerializableTimeWindowPartitionsDefinition", - [ - ("start", PublicAttr[datetime]), - ("fmt", PublicAttr[str]), - ("timezone", PublicAttr[str]), - ("end", PublicAttr[Optional[datetime]]), - ("end_offset", PublicAttr[int]), - ("cron_schedule", PublicAttr[str]), - ], - ) -): - def __new__( - cls, - start: datetime, - fmt: str, - timezone: str, - end: Optional[datetime], - end_offset: int, - cron_schedule: str, - ): - return super(SerializableTimeWindowPartitionsDefinition, cls).__new__( - cls, start, fmt, timezone, end, end_offset, cron_schedule - ) - - def to_time_window_partitions_def(self) -> "TimeWindowPartitionsDefinition": - return TimeWindowPartitionsDefinition( - self.start, - self.fmt, - self.end, - timezone=self.timezone, - end_offset=self.end_offset, - cron_schedule=self.cron_schedule, - ) - - class TimeWindowPartitionsDefinition( PartitionsDefinition, NamedTuple( @@ -226,19 +185,10 @@ def __new__( " TimeWindowPartitionsDefinition." ) - # When adding new fields to the NamedTuple, update the SerializableTimeWindowPartitionsDefinition - # class with the same fields. return super(TimeWindowPartitionsDefinition, cls).__new__( cls, start_dt, timezone, end_dt, fmt, end_offset, cron_schedule ) - def to_serializable_time_window_partitions_def( - self, - ) -> SerializableTimeWindowPartitionsDefinition: - return SerializableTimeWindowPartitionsDefinition( - self.start, self.fmt, self.timezone, self.end, self.end_offset, self.cron_schedule - ) - def get_current_timestamp(self, current_time: Optional[datetime] = None) -> float: return ( pendulum.instance(current_time, tz=self.timezone) diff --git a/python_modules/dagster/dagster/_serdes/serdes.py b/python_modules/dagster/dagster/_serdes/serdes.py index 6a6c1561dd375..7e249801a60a2 100644 --- a/python_modules/dagster/dagster/_serdes/serdes.py +++ b/python_modules/dagster/dagster/_serdes/serdes.py @@ -208,6 +208,7 @@ def whitelist_for_serdes( old_fields: Optional[Mapping[str, JsonSerializableValue]] = ..., skip_when_empty_fields: Optional[AbstractSet[str]] = ..., field_serializers: Optional[Mapping[str, Type["FieldSerializer"]]] = None, + require_args_to_match_field_ordering: bool = True, ) -> Callable[[T_Type], T_Type]: ... @@ -222,6 +223,7 @@ def whitelist_for_serdes( old_fields: Optional[Mapping[str, JsonSerializableValue]] = None, skip_when_empty_fields: Optional[AbstractSet[str]] = None, field_serializers: Optional[Mapping[str, Type["FieldSerializer"]]] = None, + require_args_to_match_field_ordering: bool = True, ) -> Union[T_Type, Callable[[T_Type], T_Type]]: """Decorator to whitelist a NamedTuple or Enum subclass to be serializable. Various arguments can be passed to alter serialization behavior for backcompat purposes. @@ -276,7 +278,10 @@ def whitelist_for_serdes( ) if __cls is not None: # decorator invoked directly on class check.class_param(__cls, "__cls") - return _whitelist_for_serdes(whitelist_map=_WHITELIST_MAP)(__cls) + return _whitelist_for_serdes( + whitelist_map=_WHITELIST_MAP, + require_args_to_match_field_ordering=require_args_to_match_field_ordering, + )(__cls) else: # decorator passed params check.opt_class_param(serializer, "serializer", superclass=Serializer) return _whitelist_for_serdes( @@ -288,6 +293,7 @@ def whitelist_for_serdes( old_fields=old_fields, skip_when_empty_fields=skip_when_empty_fields, field_serializers=field_serializers, + require_args_to_match_field_ordering=require_args_to_match_field_ordering, ) @@ -300,6 +306,7 @@ def _whitelist_for_serdes( old_fields: Optional[Mapping[str, JsonSerializableValue]] = None, skip_when_empty_fields: Optional[AbstractSet[str]] = None, field_serializers: Optional[Mapping[str, Type["FieldSerializer"]]] = None, + require_args_to_match_field_ordering: bool = True, ) -> Callable[[T_Type], T_Type]: def __whitelist_for_serdes(klass: T_Type) -> T_Type: if issubclass(klass, Enum) and ( @@ -316,7 +323,9 @@ def __whitelist_for_serdes(klass: T_Type) -> T_Type: elif is_named_tuple_subclass(klass) and ( serializer is None or issubclass(serializer, NamedTupleSerializer) ): - _check_serdes_tuple_class_invariants(klass) + _check_serdes_tuple_class_invariants( + klass, require_args_to_match_field_ordering=require_args_to_match_field_ordering + ) whitelist_map.register_tuple( klass.__name__, klass, @@ -932,7 +941,9 @@ def _unpack_value( ################################################################################################### -def _check_serdes_tuple_class_invariants(klass: Type[NamedTuple]) -> None: +def _check_serdes_tuple_class_invariants( + klass: Type[NamedTuple], require_args_to_match_field_ordering: bool = True +) -> None: sig_params = signature(klass.__new__).parameters dunder_new_params = list(sig_params.values()) @@ -960,14 +971,15 @@ def _with_header(msg: str) -> str: raise SerdesUsageError(_with_header(error_msg)) - value_param = value_params[index] - if value_param.name != field: - error_msg = ( - "Params to __new__ must match the order of field declaration in the namedtuple. " - f'Declared field number {index + 1} in the namedtuple is "{field}". ' - f'Parameter {index + 1} in __new__ method is "{value_param.name}".' - ) - raise SerdesUsageError(_with_header(error_msg)) + if require_args_to_match_field_ordering: + value_param = value_params[index] + if value_param.name != field: + error_msg = ( + "Params to __new__ must match the order of field declaration in the namedtuple. " + f'Declared field number {index + 1} in the namedtuple is "{field}". ' + f'Parameter {index + 1} in __new__ method is "{value_param.name}".' + ) + raise SerdesUsageError(_with_header(error_msg)) if len(value_params) > len(klass._fields): # Ensure that remaining parameters have default values diff --git a/python_modules/dagster/dagster_tests/definitions_tests/test_time_window_partitions.py b/python_modules/dagster/dagster_tests/definitions_tests/test_time_window_partitions.py index 9c04cc10cc5b9..1f166365f4dcc 100644 --- a/python_modules/dagster/dagster_tests/definitions_tests/test_time_window_partitions.py +++ b/python_modules/dagster/dagster_tests/definitions_tests/test_time_window_partitions.py @@ -1330,8 +1330,5 @@ def test_time_window_partitions_def_serialization(partitions_def): end_offset=partitions_def.end_offset, ) assert ( - deserialize_value( - serialize_value(time_window_partitions_def.to_serializable_time_window_partitions_def()) - ).to_time_window_partitions_def() - == time_window_partitions_def + deserialize_value(serialize_value(time_window_partitions_def)) == time_window_partitions_def ) From e97b5b24d79055dd4207de95afc839b56f2d9d17 Mon Sep 17 00:00:00 2001 From: Claire Lin Date: Mon, 13 Nov 2023 12:53:24 -0800 Subject: [PATCH 5/5] prevent TimeWindowPartitionsDefinition from being pickled --- .../definitions/time_window_partitions.py | 10 +++++++++- .../dagster/dagster/_serdes/serdes.py | 18 ++++++++---------- .../test_time_window_partitions.py | 13 +++++++++++++ 3 files changed, 30 insertions(+), 11 deletions(-) diff --git a/python_modules/dagster/dagster/_core/definitions/time_window_partitions.py b/python_modules/dagster/dagster/_core/definitions/time_window_partitions.py index 5ff7cda197175..11c1334a7f209 100644 --- a/python_modules/dagster/dagster/_core/definitions/time_window_partitions.py +++ b/python_modules/dagster/dagster/_core/definitions/time_window_partitions.py @@ -27,6 +27,7 @@ import dagster._check as check from dagster._annotations import PublicAttr, public +from dagster._core.errors import DagsterInvariantViolationError from dagster._core.instance import DynamicPartitionsStore from dagster._serdes import whitelist_for_serdes from dagster._serdes.serdes import FieldSerializer @@ -85,7 +86,7 @@ class TimeWindow(NamedTuple): @whitelist_for_serdes( field_serializers={"start": DatetimeFieldSerializer, "end": DatetimeFieldSerializer}, - require_args_to_match_field_ordering=False, + is_pickleable=False, ) class TimeWindowPartitionsDefinition( PartitionsDefinition, @@ -318,6 +319,13 @@ def __repr__(self): def __hash__(self): return hash(tuple(self.__repr__())) + def __getstate__(self): + # Only namedtuples where the ordering of fields matches the ordering of __new__ args + # are pickleable. This does not apply for TimeWindowPartitionsDefinition, so we + # override __getstate__ to raise an error when attempting to pickle. + # https://github.com/dagster-io/dagster/issues/2372 + raise DagsterInvariantViolationError("TimeWindowPartitionsDefinition is not pickleable") + @functools.lru_cache(maxsize=100) def time_window_for_partition_key(self, partition_key: str) -> TimeWindow: partition_key_dt = pendulum.instance( diff --git a/python_modules/dagster/dagster/_serdes/serdes.py b/python_modules/dagster/dagster/_serdes/serdes.py index 7e249801a60a2..b7f5f74834fea 100644 --- a/python_modules/dagster/dagster/_serdes/serdes.py +++ b/python_modules/dagster/dagster/_serdes/serdes.py @@ -208,7 +208,7 @@ def whitelist_for_serdes( old_fields: Optional[Mapping[str, JsonSerializableValue]] = ..., skip_when_empty_fields: Optional[AbstractSet[str]] = ..., field_serializers: Optional[Mapping[str, Type["FieldSerializer"]]] = None, - require_args_to_match_field_ordering: bool = True, + is_pickleable: bool = True, ) -> Callable[[T_Type], T_Type]: ... @@ -223,7 +223,7 @@ def whitelist_for_serdes( old_fields: Optional[Mapping[str, JsonSerializableValue]] = None, skip_when_empty_fields: Optional[AbstractSet[str]] = None, field_serializers: Optional[Mapping[str, Type["FieldSerializer"]]] = None, - require_args_to_match_field_ordering: bool = True, + is_pickleable: bool = True, ) -> Union[T_Type, Callable[[T_Type], T_Type]]: """Decorator to whitelist a NamedTuple or Enum subclass to be serializable. Various arguments can be passed to alter serialization behavior for backcompat purposes. @@ -280,7 +280,7 @@ def whitelist_for_serdes( check.class_param(__cls, "__cls") return _whitelist_for_serdes( whitelist_map=_WHITELIST_MAP, - require_args_to_match_field_ordering=require_args_to_match_field_ordering, + is_pickleable=is_pickleable, )(__cls) else: # decorator passed params check.opt_class_param(serializer, "serializer", superclass=Serializer) @@ -293,7 +293,7 @@ def whitelist_for_serdes( old_fields=old_fields, skip_when_empty_fields=skip_when_empty_fields, field_serializers=field_serializers, - require_args_to_match_field_ordering=require_args_to_match_field_ordering, + is_pickleable=is_pickleable, ) @@ -306,7 +306,7 @@ def _whitelist_for_serdes( old_fields: Optional[Mapping[str, JsonSerializableValue]] = None, skip_when_empty_fields: Optional[AbstractSet[str]] = None, field_serializers: Optional[Mapping[str, Type["FieldSerializer"]]] = None, - require_args_to_match_field_ordering: bool = True, + is_pickleable: bool = True, ) -> Callable[[T_Type], T_Type]: def __whitelist_for_serdes(klass: T_Type) -> T_Type: if issubclass(klass, Enum) and ( @@ -323,9 +323,7 @@ def __whitelist_for_serdes(klass: T_Type) -> T_Type: elif is_named_tuple_subclass(klass) and ( serializer is None or issubclass(serializer, NamedTupleSerializer) ): - _check_serdes_tuple_class_invariants( - klass, require_args_to_match_field_ordering=require_args_to_match_field_ordering - ) + _check_serdes_tuple_class_invariants(klass, is_pickleable=is_pickleable) whitelist_map.register_tuple( klass.__name__, klass, @@ -942,7 +940,7 @@ def _unpack_value( def _check_serdes_tuple_class_invariants( - klass: Type[NamedTuple], require_args_to_match_field_ordering: bool = True + klass: Type[NamedTuple], is_pickleable: bool = True ) -> None: sig_params = signature(klass.__new__).parameters dunder_new_params = list(sig_params.values()) @@ -971,7 +969,7 @@ def _with_header(msg: str) -> str: raise SerdesUsageError(_with_header(error_msg)) - if require_args_to_match_field_ordering: + if is_pickleable: value_param = value_params[index] if value_param.name != field: error_msg = ( diff --git a/python_modules/dagster/dagster_tests/definitions_tests/test_time_window_partitions.py b/python_modules/dagster/dagster_tests/definitions_tests/test_time_window_partitions.py index 1f166365f4dcc..93587f8beda27 100644 --- a/python_modules/dagster/dagster_tests/definitions_tests/test_time_window_partitions.py +++ b/python_modules/dagster/dagster_tests/definitions_tests/test_time_window_partitions.py @@ -1,3 +1,4 @@ +import pickle import random from datetime import datetime from typing import Optional, Sequence, cast @@ -25,6 +26,7 @@ TimeWindow, TimeWindowPartitionsSubset, ) +from dagster._core.errors import DagsterInvariantViolationError from dagster._serdes import deserialize_value, serialize_value from dagster._seven.compat.pendulum import create_pendulum_time from dagster._utils.partitions import DEFAULT_HOURLY_FORMAT_WITHOUT_TIMEZONE @@ -1332,3 +1334,14 @@ def test_time_window_partitions_def_serialization(partitions_def): assert ( deserialize_value(serialize_value(time_window_partitions_def)) == time_window_partitions_def ) + + +def test_cannot_pickle_time_window_partitions_def(): + import datetime + + partitions_def = TimeWindowPartitionsDefinition( + datetime.datetime(2021, 1, 1), "America/Los_Angeles", cron_schedule="0 0 * * *" + ) + + with pytest.raises(DagsterInvariantViolationError, match="not pickleable"): + pickle.loads(pickle.dumps(partitions_def))