From 2cdad9cd71be2c820f726ea6985487db20373690 Mon Sep 17 00:00:00 2001 From: Claire Lin Date: Fri, 10 Nov 2023 13:11:16 -0800 Subject: [PATCH] add option to disable same ordering --- .../definitions/time_window_partitions.py | 54 +------------------ .../dagster/dagster/_serdes/serdes.py | 34 ++++++++---- .../test_time_window_partitions.py | 5 +- 3 files changed, 26 insertions(+), 67 deletions(-) diff --git a/python_modules/dagster/dagster/_core/definitions/time_window_partitions.py b/python_modules/dagster/dagster/_core/definitions/time_window_partitions.py index c30101cf62989..5ff7cda197175 100644 --- a/python_modules/dagster/dagster/_core/definitions/time_window_partitions.py +++ b/python_modules/dagster/dagster/_core/definitions/time_window_partitions.py @@ -31,7 +31,6 @@ from dagster._serdes import whitelist_for_serdes from dagster._serdes.serdes import FieldSerializer from dagster._utils import utc_datetime_from_timestamp -from dagster._utils.cached_method import cached_method from dagster._utils.partitions import DEFAULT_HOURLY_FORMAT_WITHOUT_TIMEZONE from dagster._utils.schedules import ( cron_string_iterator, @@ -84,50 +83,10 @@ class TimeWindow(NamedTuple): end: PublicAttr[datetime] -# Unfortunately we can't use @whitelist_for_serdes on TimeWindowPartitionsDefinition -# because args to __new__ are a different order than the fields in the NamedTuple, and we can't -# reorder them because it's a public API. Until TimeWindowPartitionsDefinition can decorated, -# this class is used to serialize it. @whitelist_for_serdes( - field_serializers={"start": DatetimeFieldSerializer, "end": DatetimeFieldSerializer} + field_serializers={"start": DatetimeFieldSerializer, "end": DatetimeFieldSerializer}, + require_args_to_match_field_ordering=False, ) -class SerializableTimeWindowPartitionsDefinition( - NamedTuple( - "_SerializableTimeWindowPartitionsDefinition", - [ - ("start", PublicAttr[datetime]), - ("fmt", PublicAttr[str]), - ("timezone", PublicAttr[str]), - ("end", PublicAttr[Optional[datetime]]), - ("end_offset", PublicAttr[int]), - ("cron_schedule", PublicAttr[str]), - ], - ) -): - def __new__( - cls, - start: datetime, - fmt: str, - timezone: str, - end: Optional[datetime], - end_offset: int, - cron_schedule: str, - ): - return super(SerializableTimeWindowPartitionsDefinition, cls).__new__( - cls, start, fmt, timezone, end, end_offset, cron_schedule - ) - - def to_time_window_partitions_def(self) -> "TimeWindowPartitionsDefinition": - return TimeWindowPartitionsDefinition( - self.start, - self.fmt, - self.end, - timezone=self.timezone, - end_offset=self.end_offset, - cron_schedule=self.cron_schedule, - ) - - class TimeWindowPartitionsDefinition( PartitionsDefinition, NamedTuple( @@ -226,19 +185,10 @@ def __new__( " TimeWindowPartitionsDefinition." ) - # When adding new fields to the NamedTuple, update the SerializableTimeWindowPartitionsDefinition - # class with the same fields. return super(TimeWindowPartitionsDefinition, cls).__new__( cls, start_dt, timezone, end_dt, fmt, end_offset, cron_schedule ) - def to_serializable_time_window_partitions_def( - self, - ) -> SerializableTimeWindowPartitionsDefinition: - return SerializableTimeWindowPartitionsDefinition( - self.start, self.fmt, self.timezone, self.end, self.end_offset, self.cron_schedule - ) - def get_current_timestamp(self, current_time: Optional[datetime] = None) -> float: return ( pendulum.instance(current_time, tz=self.timezone) diff --git a/python_modules/dagster/dagster/_serdes/serdes.py b/python_modules/dagster/dagster/_serdes/serdes.py index 6a6c1561dd375..7e249801a60a2 100644 --- a/python_modules/dagster/dagster/_serdes/serdes.py +++ b/python_modules/dagster/dagster/_serdes/serdes.py @@ -208,6 +208,7 @@ def whitelist_for_serdes( old_fields: Optional[Mapping[str, JsonSerializableValue]] = ..., skip_when_empty_fields: Optional[AbstractSet[str]] = ..., field_serializers: Optional[Mapping[str, Type["FieldSerializer"]]] = None, + require_args_to_match_field_ordering: bool = True, ) -> Callable[[T_Type], T_Type]: ... @@ -222,6 +223,7 @@ def whitelist_for_serdes( old_fields: Optional[Mapping[str, JsonSerializableValue]] = None, skip_when_empty_fields: Optional[AbstractSet[str]] = None, field_serializers: Optional[Mapping[str, Type["FieldSerializer"]]] = None, + require_args_to_match_field_ordering: bool = True, ) -> Union[T_Type, Callable[[T_Type], T_Type]]: """Decorator to whitelist a NamedTuple or Enum subclass to be serializable. Various arguments can be passed to alter serialization behavior for backcompat purposes. @@ -276,7 +278,10 @@ def whitelist_for_serdes( ) if __cls is not None: # decorator invoked directly on class check.class_param(__cls, "__cls") - return _whitelist_for_serdes(whitelist_map=_WHITELIST_MAP)(__cls) + return _whitelist_for_serdes( + whitelist_map=_WHITELIST_MAP, + require_args_to_match_field_ordering=require_args_to_match_field_ordering, + )(__cls) else: # decorator passed params check.opt_class_param(serializer, "serializer", superclass=Serializer) return _whitelist_for_serdes( @@ -288,6 +293,7 @@ def whitelist_for_serdes( old_fields=old_fields, skip_when_empty_fields=skip_when_empty_fields, field_serializers=field_serializers, + require_args_to_match_field_ordering=require_args_to_match_field_ordering, ) @@ -300,6 +306,7 @@ def _whitelist_for_serdes( old_fields: Optional[Mapping[str, JsonSerializableValue]] = None, skip_when_empty_fields: Optional[AbstractSet[str]] = None, field_serializers: Optional[Mapping[str, Type["FieldSerializer"]]] = None, + require_args_to_match_field_ordering: bool = True, ) -> Callable[[T_Type], T_Type]: def __whitelist_for_serdes(klass: T_Type) -> T_Type: if issubclass(klass, Enum) and ( @@ -316,7 +323,9 @@ def __whitelist_for_serdes(klass: T_Type) -> T_Type: elif is_named_tuple_subclass(klass) and ( serializer is None or issubclass(serializer, NamedTupleSerializer) ): - _check_serdes_tuple_class_invariants(klass) + _check_serdes_tuple_class_invariants( + klass, require_args_to_match_field_ordering=require_args_to_match_field_ordering + ) whitelist_map.register_tuple( klass.__name__, klass, @@ -932,7 +941,9 @@ def _unpack_value( ################################################################################################### -def _check_serdes_tuple_class_invariants(klass: Type[NamedTuple]) -> None: +def _check_serdes_tuple_class_invariants( + klass: Type[NamedTuple], require_args_to_match_field_ordering: bool = True +) -> None: sig_params = signature(klass.__new__).parameters dunder_new_params = list(sig_params.values()) @@ -960,14 +971,15 @@ def _with_header(msg: str) -> str: raise SerdesUsageError(_with_header(error_msg)) - value_param = value_params[index] - if value_param.name != field: - error_msg = ( - "Params to __new__ must match the order of field declaration in the namedtuple. " - f'Declared field number {index + 1} in the namedtuple is "{field}". ' - f'Parameter {index + 1} in __new__ method is "{value_param.name}".' - ) - raise SerdesUsageError(_with_header(error_msg)) + if require_args_to_match_field_ordering: + value_param = value_params[index] + if value_param.name != field: + error_msg = ( + "Params to __new__ must match the order of field declaration in the namedtuple. " + f'Declared field number {index + 1} in the namedtuple is "{field}". ' + f'Parameter {index + 1} in __new__ method is "{value_param.name}".' + ) + raise SerdesUsageError(_with_header(error_msg)) if len(value_params) > len(klass._fields): # Ensure that remaining parameters have default values diff --git a/python_modules/dagster/dagster_tests/definitions_tests/test_time_window_partitions.py b/python_modules/dagster/dagster_tests/definitions_tests/test_time_window_partitions.py index 9c04cc10cc5b9..1f166365f4dcc 100644 --- a/python_modules/dagster/dagster_tests/definitions_tests/test_time_window_partitions.py +++ b/python_modules/dagster/dagster_tests/definitions_tests/test_time_window_partitions.py @@ -1330,8 +1330,5 @@ def test_time_window_partitions_def_serialization(partitions_def): end_offset=partitions_def.end_offset, ) assert ( - deserialize_value( - serialize_value(time_window_partitions_def.to_serializable_time_window_partitions_def()) - ).to_time_window_partitions_def() - == time_window_partitions_def + deserialize_value(serialize_value(time_window_partitions_def)) == time_window_partitions_def )