Skip to content

Commit

Permalink
asset graph subset changes
Browse files Browse the repository at this point in the history
  • Loading branch information
clairelin135 committed Nov 14, 2023
1 parent 3f091bc commit fad9b01
Showing 1 changed file with 52 additions and 49 deletions.
101 changes: 52 additions & 49 deletions python_modules/dagster/dagster/_core/definitions/asset_graph_subset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import operator
from collections import defaultdict
from datetime import datetime
from functools import cached_property
from typing import (
AbstractSet,
Any,
Expand All @@ -20,78 +21,74 @@
PartitionsDefinition,
PartitionsSubset,
)
from dagster._core.definitions.time_window_partitions import (
PartitionKeysTimeWindowPartitionsSubset,
TimeWindowPartitionsSubset,
)
from dagster._core.definitions.time_window_partitions import PartitionKeysTimeWindowPartitionsSubset
from dagster._core.errors import (
DagsterDefinitionChangedDeserializationError,
)
from dagster._core.instance import DynamicPartitionsStore
from dagster._serdes.serdes import (
FieldSerializer,
deserialize_value,
serialize_value,
whitelist_for_serdes,
)
from dagster._serdes.serdes import NamedTupleSerializer, whitelist_for_serdes

from .asset_graph import AssetGraph
from .events import AssetKey, AssetKeyPartitionKey


class PartitionsSubsetByAssetKeySerializer(FieldSerializer):
"""Packs and unpacks a mapping from AssetKey to PartitionsSubset.
In JSON, a key must be a str, int, float, bool, or None. This serializer packs the AssetKey
into a str, and unpacks it back into an AssetKey.
It also converts PartitionKeysTimeWindowPartitionsSubset into serializable TimeWindowPartitionsSubsets.
"""

def pack(self, mapping: Mapping[AssetKey, Any], **_kwargs) -> Mapping[str, Any]:
return {
serialize_value(key): serialize_value(
value.to_time_window_partitions_subset()
if isinstance(value, PartitionKeysTimeWindowPartitionsSubset)
else value
)
for key, value in mapping.items()
}
class AssetGraphSubsetSerializer(NamedTupleSerializer):
def before_pack(self, value: "AssetGraphSubset") -> "AssetGraphSubset":
converted_partitions_subsets_by_serialized_asset_key = {}
for k, v in value.partitions_subsets_by_serialized_asset_key.items():
if isinstance(v, PartitionKeysTimeWindowPartitionsSubset):
converted_partitions_subsets_by_serialized_asset_key[
k
] = v.to_time_window_partitions_subset()
else:
converted_partitions_subsets_by_serialized_asset_key[k] = v

def unpack(
self,
mapping: Mapping[str, Any],
**_kwargs,
) -> Mapping[AssetKey, Any]:
return {
deserialize_value(key, AssetKey): deserialize_value(value, TimeWindowPartitionsSubset)
for key, value in mapping.items()
}
return value._replace(
partitions_subsets_by_serialized_asset_key=converted_partitions_subsets_by_serialized_asset_key
)


@whitelist_for_serdes(
field_serializers={"partitions_subsets_by_asset_key": PartitionsSubsetByAssetKeySerializer}
)
@whitelist_for_serdes(serializer=AssetGraphSubsetSerializer)
class AssetGraphSubset(
NamedTuple(
"_AssetGraphSubset",
[
("partitions_subsets_by_asset_key", Mapping[AssetKey, PartitionsSubset]),
("partitions_subsets_by_serialized_asset_key", Mapping[str, PartitionsSubset]),
("non_partitioned_asset_keys", AbstractSet[AssetKey]),
],
)
):
def __new__(
cls,
partitions_subsets_by_asset_key: Optional[Mapping[AssetKey, PartitionsSubset]] = None,
partitions_subsets_by_serialized_asset_key: Optional[Mapping[str, PartitionsSubset]] = None,
non_partitioned_asset_keys: Optional[AbstractSet[AssetKey]] = None,
partitions_subsets_by_asset_key: Optional[Mapping[AssetKey, PartitionsSubset]] = None,
):
check.invariant(
not (partitions_subsets_by_serialized_asset_key and partitions_subsets_by_asset_key),
"Cannot provide both partitions_subsets_by_serialized_asset_key and partitions_subsets_by_asset_key",
)

if partitions_subsets_by_asset_key:
partitions_subsets_by_serialized_asset_key = {
key.to_user_string(): value
for key, value in partitions_subsets_by_asset_key.items()
}

return super(AssetGraphSubset, cls).__new__(
cls,
partitions_subsets_by_asset_key=partitions_subsets_by_asset_key or {},
partitions_subsets_by_serialized_asset_key=partitions_subsets_by_serialized_asset_key
or {},
non_partitioned_asset_keys=non_partitioned_asset_keys or set(),
)

@cached_property
def partitions_subsets_by_asset_key(self) -> Mapping[AssetKey, PartitionsSubset]:
return {
AssetKey.from_user_string(key): value
for key, value in self.partitions_subsets_by_serialized_asset_key.items()
}

@property
def asset_keys(self) -> AbstractSet[AssetKey]:
return {
Expand Down Expand Up @@ -222,8 +219,8 @@ def _oper(self, other: "AssetGraphSubset", oper: Callable) -> "AssetGraphSubset"
result_partition_subsets_by_asset_key[asset_key] = oper(subset, other_subset)

return AssetGraphSubset(
result_partition_subsets_by_asset_key,
result_non_partitioned_asset_keys,
partitions_subsets_by_asset_key=result_partition_subsets_by_asset_key,
non_partitioned_asset_keys=result_non_partitioned_asset_keys,
)

def __or__(self, other: "AssetGraphSubset") -> "AssetGraphSubset":
Expand All @@ -237,12 +234,12 @@ def __and__(self, other: "AssetGraphSubset") -> "AssetGraphSubset":

def filter_asset_keys(self, asset_keys: AbstractSet[AssetKey]) -> "AssetGraphSubset":
return AssetGraphSubset(
{
partitions_subsets_by_asset_key={
asset_key: subset
for asset_key, subset in self.partitions_subsets_by_asset_key.items()
if asset_key in asset_keys
},
self.non_partitioned_asset_keys & asset_keys,
non_partitioned_asset_keys=self.non_partitioned_asset_keys & asset_keys,
)

def __eq__(self, other) -> bool:
Expand Down Expand Up @@ -371,7 +368,10 @@ def from_storage_dict(
AssetKey.from_user_string(key) for key in serialized_dict["non_partitioned_asset_keys"]
} & asset_graph.all_asset_keys

return AssetGraphSubset(partitions_subsets_by_asset_key, non_partitioned_asset_keys)
return AssetGraphSubset(
partitions_subsets_by_asset_key=partitions_subsets_by_asset_key,
non_partitioned_asset_keys=non_partitioned_asset_keys,
)

@classmethod
def all(
Expand Down Expand Up @@ -412,4 +412,7 @@ def from_asset_keys(
else:
non_partitioned_asset_keys.add(asset_key)

return AssetGraphSubset(partitions_subsets_by_asset_key, non_partitioned_asset_keys)
return AssetGraphSubset(
partitions_subsets_by_asset_key=partitions_subsets_by_asset_key,
non_partitioned_asset_keys=non_partitioned_asset_keys,
)

0 comments on commit fad9b01

Please sign in to comment.