Skip to content

Commit

Permalink
add option to disable same ordering
Browse files Browse the repository at this point in the history
  • Loading branch information
clairelin135 committed Nov 10, 2023
1 parent c746132 commit 2cdad9c
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from dagster._serdes import whitelist_for_serdes
from dagster._serdes.serdes import FieldSerializer
from dagster._utils import utc_datetime_from_timestamp
from dagster._utils.cached_method import cached_method
from dagster._utils.partitions import DEFAULT_HOURLY_FORMAT_WITHOUT_TIMEZONE
from dagster._utils.schedules import (
cron_string_iterator,
Expand Down Expand Up @@ -84,50 +83,10 @@ class TimeWindow(NamedTuple):
end: PublicAttr[datetime]


# Unfortunately we can't use @whitelist_for_serdes on TimeWindowPartitionsDefinition
# because the args to __new__ are in a different order than the fields in the NamedTuple, and we
# can't reorder them because it's a public API. Until TimeWindowPartitionsDefinition can be
# decorated, this class is used to serialize it.
@whitelist_for_serdes(
field_serializers={"start": DatetimeFieldSerializer, "end": DatetimeFieldSerializer}
field_serializers={"start": DatetimeFieldSerializer, "end": DatetimeFieldSerializer},
require_args_to_match_field_ordering=False,
)
class SerializableTimeWindowPartitionsDefinition(
NamedTuple(
"_SerializableTimeWindowPartitionsDefinition",
[
("start", PublicAttr[datetime]),
("fmt", PublicAttr[str]),
("timezone", PublicAttr[str]),
("end", PublicAttr[Optional[datetime]]),
("end_offset", PublicAttr[int]),
("cron_schedule", PublicAttr[str]),
],
)
):
def __new__(
cls,
start: datetime,
fmt: str,
timezone: str,
end: Optional[datetime],
end_offset: int,
cron_schedule: str,
):
return super(SerializableTimeWindowPartitionsDefinition, cls).__new__(
cls, start, fmt, timezone, end, end_offset, cron_schedule
)

def to_time_window_partitions_def(self) -> "TimeWindowPartitionsDefinition":
return TimeWindowPartitionsDefinition(
self.start,
self.fmt,
self.end,
timezone=self.timezone,
end_offset=self.end_offset,
cron_schedule=self.cron_schedule,
)


class TimeWindowPartitionsDefinition(
PartitionsDefinition,
NamedTuple(
Expand Down Expand Up @@ -226,19 +185,10 @@ def __new__(
" TimeWindowPartitionsDefinition."
)

# When adding new fields to the NamedTuple, update the SerializableTimeWindowPartitionsDefinition
# class with the same fields.
return super(TimeWindowPartitionsDefinition, cls).__new__(
cls, start_dt, timezone, end_dt, fmt, end_offset, cron_schedule
)

def to_serializable_time_window_partitions_def(
self,
) -> SerializableTimeWindowPartitionsDefinition:
return SerializableTimeWindowPartitionsDefinition(
self.start, self.fmt, self.timezone, self.end, self.end_offset, self.cron_schedule
)

def get_current_timestamp(self, current_time: Optional[datetime] = None) -> float:
return (
pendulum.instance(current_time, tz=self.timezone)
Expand Down
34 changes: 23 additions & 11 deletions python_modules/dagster/dagster/_serdes/serdes.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ def whitelist_for_serdes(
old_fields: Optional[Mapping[str, JsonSerializableValue]] = ...,
skip_when_empty_fields: Optional[AbstractSet[str]] = ...,
field_serializers: Optional[Mapping[str, Type["FieldSerializer"]]] = None,
require_args_to_match_field_ordering: bool = True,
) -> Callable[[T_Type], T_Type]:
...

Expand All @@ -222,6 +223,7 @@ def whitelist_for_serdes(
old_fields: Optional[Mapping[str, JsonSerializableValue]] = None,
skip_when_empty_fields: Optional[AbstractSet[str]] = None,
field_serializers: Optional[Mapping[str, Type["FieldSerializer"]]] = None,
require_args_to_match_field_ordering: bool = True,
) -> Union[T_Type, Callable[[T_Type], T_Type]]:
"""Decorator to whitelist a NamedTuple or Enum subclass to be serializable. Various arguments can be passed
to alter serialization behavior for backcompat purposes.
Expand Down Expand Up @@ -276,7 +278,10 @@ def whitelist_for_serdes(
)
if __cls is not None: # decorator invoked directly on class
check.class_param(__cls, "__cls")
return _whitelist_for_serdes(whitelist_map=_WHITELIST_MAP)(__cls)
return _whitelist_for_serdes(
whitelist_map=_WHITELIST_MAP,
require_args_to_match_field_ordering=require_args_to_match_field_ordering,
)(__cls)
else: # decorator passed params
check.opt_class_param(serializer, "serializer", superclass=Serializer)
return _whitelist_for_serdes(
Expand All @@ -288,6 +293,7 @@ def whitelist_for_serdes(
old_fields=old_fields,
skip_when_empty_fields=skip_when_empty_fields,
field_serializers=field_serializers,
require_args_to_match_field_ordering=require_args_to_match_field_ordering,
)


Expand All @@ -300,6 +306,7 @@ def _whitelist_for_serdes(
old_fields: Optional[Mapping[str, JsonSerializableValue]] = None,
skip_when_empty_fields: Optional[AbstractSet[str]] = None,
field_serializers: Optional[Mapping[str, Type["FieldSerializer"]]] = None,
require_args_to_match_field_ordering: bool = True,
) -> Callable[[T_Type], T_Type]:
def __whitelist_for_serdes(klass: T_Type) -> T_Type:
if issubclass(klass, Enum) and (
Expand All @@ -316,7 +323,9 @@ def __whitelist_for_serdes(klass: T_Type) -> T_Type:
elif is_named_tuple_subclass(klass) and (
serializer is None or issubclass(serializer, NamedTupleSerializer)
):
_check_serdes_tuple_class_invariants(klass)
_check_serdes_tuple_class_invariants(
klass, require_args_to_match_field_ordering=require_args_to_match_field_ordering
)
whitelist_map.register_tuple(
klass.__name__,
klass,
Expand Down Expand Up @@ -932,7 +941,9 @@ def _unpack_value(
###################################################################################################


def _check_serdes_tuple_class_invariants(klass: Type[NamedTuple]) -> None:
def _check_serdes_tuple_class_invariants(
klass: Type[NamedTuple], require_args_to_match_field_ordering: bool = True
) -> None:
sig_params = signature(klass.__new__).parameters
dunder_new_params = list(sig_params.values())

Expand Down Expand Up @@ -960,14 +971,15 @@ def _with_header(msg: str) -> str:

raise SerdesUsageError(_with_header(error_msg))

value_param = value_params[index]
if value_param.name != field:
error_msg = (
"Params to __new__ must match the order of field declaration in the namedtuple. "
f'Declared field number {index + 1} in the namedtuple is "{field}". '
f'Parameter {index + 1} in __new__ method is "{value_param.name}".'
)
raise SerdesUsageError(_with_header(error_msg))
if require_args_to_match_field_ordering:
value_param = value_params[index]
if value_param.name != field:
error_msg = (
"Params to __new__ must match the order of field declaration in the namedtuple. "
f'Declared field number {index + 1} in the namedtuple is "{field}". '
f'Parameter {index + 1} in __new__ method is "{value_param.name}".'
)
raise SerdesUsageError(_with_header(error_msg))

if len(value_params) > len(klass._fields):
# Ensure that remaining parameters have default values
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1330,8 +1330,5 @@ def test_time_window_partitions_def_serialization(partitions_def):
end_offset=partitions_def.end_offset,
)
assert (
deserialize_value(
serialize_value(time_window_partitions_def.to_serializable_time_window_partitions_def())
).to_time_window_partitions_def()
== time_window_partitions_def
deserialize_value(serialize_value(time_window_partitions_def)) == time_window_partitions_def
)

0 comments on commit 2cdad9c

Please sign in to comment.