Skip to content

Commit

Permalink
fix more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
clairelin135 committed Nov 10, 2023
1 parent 9b4bbd3 commit 76e112f
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 228 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from abc import abstractmethod, abstractproperty
from datetime import datetime
from enum import Enum
from functools import cached_property
from typing import (
AbstractSet,
Any,
Expand All @@ -30,7 +31,12 @@
from dagster._serdes import (
whitelist_for_serdes,
)
from dagster._serdes.serdes import FieldSerializer, deserialize_value, serialize_value
from dagster._serdes.serdes import (
FieldSerializer,
NamedTupleSerializer,
deserialize_value,
serialize_value,
)
from dagster._utils import utc_datetime_from_timestamp
from dagster._utils.partitions import DEFAULT_HOURLY_FORMAT_WITHOUT_TIMEZONE
from dagster._utils.schedules import (
Expand Down Expand Up @@ -1814,26 +1820,25 @@ def with_partitions_def(
)


class TimeWindowPartitionsDefinitionSerializer(FieldSerializer):
"""Serializes a TimeWindowPartitionsDefinition by converting it to a SerializableTimeWindowPartitionsDefinition."""

def pack(self, partitions_def: TimeWindowPartitionsDefinition, **_kwargs) -> str:
return serialize_value(partitions_def.to_serializable_time_window_partitions_def())

def unpack(
self,
serialized_time_window_partitions_def: str,
**_kwargs,
) -> TimeWindowPartitionsDefinition:
return deserialize_value(
serialized_time_window_partitions_def, SerializableTimeWindowPartitionsDefinition
).to_time_window_partitions_def()
class TimeWindowPartitionsSubsetSerializer(NamedTupleSerializer):
# TimeWindowPartitionsSubsets have custom logic to delay calculating num_partitions until it
# is needed to improve performance. When serializing, we want to serialize the number of
# partitions, so we force calculatation.
def before_pack(self, value: "TimeWindowPartitionsSubset") -> "TimeWindowPartitionsSubset":
if value._asdict()["num_partitions"] is None:
return TimeWindowPartitionsSubset(
partitions_def=value.partitions_def,
num_partitions=value.num_partitions,
included_time_windows=value.included_time_windows,
)
return value


@whitelist_for_serdes(
field_serializers={"partitions_def": TimeWindowPartitionsDefinitionSerializer}
serializer=TimeWindowPartitionsSubsetSerializer,
)
class TimeWindowPartitionsSubset(
BaseTimeWindowPartitionsSubset,
NamedTuple(
"_TimeWindowPartitionsSubset",
[
Expand All @@ -1842,11 +1847,11 @@ class TimeWindowPartitionsSubset(
("included_time_windows", Sequence[TimeWindow]),
],
),
BaseTimeWindowPartitionsSubset,
):
"""A PartitionsSubset for a TimeWindowPartitionsDefinition, which internally represents the
included partitions using TimeWindows.
"""

def __new__(
cls,
partitions_def: TimeWindowPartitionsDefinition,
Expand Down
6 changes: 5 additions & 1 deletion python_modules/dagster/dagster/_serdes/serdes.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ def pack(
) -> Dict[str, JsonSerializableValue]:
packed: Dict[str, JsonSerializableValue] = {}
packed["__class__"] = self.get_storage_name()
for key, inner_value in value._asdict().items():
for key, inner_value in self.before_pack(value)._asdict().items():
if key in self.skip_when_empty_fields and inner_value in EMPTY_VALUES_TO_SKIP:
continue
storage_key = self.storage_field_names.get(key, key)
Expand All @@ -540,6 +540,10 @@ def pack(
packed = self.after_pack(**packed)
return packed

# Hook: Modify the contents of the named tuple before packing
def before_pack(self, value: T_NamedTuple) -> T_NamedTuple:
return value

# Hook: Modify the contents of the packed, json-serializable dict before it is converted to a
# string.
def after_pack(self, **packed_dict: JsonSerializableValue) -> Dict[str, JsonSerializableValue]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,23 @@ def test_time_window_partitions_subset_serialization_deserialization(
)
assert deserialized == subset
assert deserialized.get_partition_keys() == ["2023-01-01"]


def test_time_window_partitions_subset_num_partitions_serialization():
daily_partitions_def = DailyPartitionsDefinition("2023-01-01")
time_partitions_def = TimeWindowPartitionsDefinition(
start=daily_partitions_def.start,
end=daily_partitions_def.end,
cron_schedule="0 0 * * *",
fmt="%Y-%m-%d",
timezone=daily_partitions_def.timezone,
end_offset=daily_partitions_def.end_offset,
)

tw = time_partitions_def.time_window_for_partition_key("2023-01-01")

subset = TimeWindowPartitionsSubset(
time_partitions_def, num_partitions=None, included_time_windows=[tw]
)
deserialized = deserialize_value(serialize_value(subset), TimeWindowPartitionsSubset)
assert deserialized._asdict()["num_partitions"] is not None
Loading

0 comments on commit 76e112f

Please sign in to comment.