Skip to content

Commit

Permalink
Make AssetGraphSubset and AssetBackfillData serializable
Browse files Browse the repository at this point in the history
  • Loading branch information
clairelin135 committed Nov 9, 2023
1 parent d038030 commit 05c6643
Show file tree
Hide file tree
Showing 11 changed files with 240 additions and 136 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,6 @@ def test_launch_asset_backfill_with_upstream_anchor_asset():
target_subset = asset_backfill_data.target_subset
asset_graph = target_subset.asset_graph
assert target_subset == AssetGraphSubset(
target_subset.asset_graph,
partitions_subsets_by_asset_key={
AssetKey("hourly"): asset_graph.get_partitions_def(
AssetKey("hourly")
Expand Down Expand Up @@ -607,7 +606,6 @@ def test_launch_asset_backfill_with_two_anchor_assets():
target_subset = asset_backfill_data.target_subset
asset_graph = target_subset.asset_graph
assert target_subset == AssetGraphSubset(
target_subset.asset_graph,
partitions_subsets_by_asset_key={
AssetKey("hourly1"): asset_graph.get_partitions_def(
AssetKey("hourly1")
Expand Down Expand Up @@ -665,7 +663,6 @@ def test_launch_asset_backfill_with_upstream_anchor_asset_and_non_partitioned_as
target_subset = asset_backfill_data.target_subset
asset_graph = target_subset.asset_graph
assert target_subset == AssetGraphSubset(
target_subset.asset_graph,
non_partitioned_asset_keys={AssetKey("non_partitioned")},
partitions_subsets_by_asset_key={
AssetKey("hourly"): (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ def _execute_asset_backfill_iteration_no_side_effects(
updated_backfill = backfill.with_asset_backfill_data(
cast(AssetBackfillIterationResult, result).backfill_data,
dynamic_partitions_store=graphql_context.instance,
asset_graph=asset_graph,
)
graphql_context.instance.update_backfill(updated_backfill)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -579,15 +579,14 @@ def bfs_filter_subsets(
else None
),
}
result = AssetGraphSubset(self)
result = AssetGraphSubset()

while len(queue) > 0:
asset_key = queue.popleft()
partitions_subset = queued_subsets_by_asset_key.get(asset_key)

if condition_fn(asset_key, partitions_subset):
result |= AssetGraphSubset(
self,
non_partitioned_asset_keys={asset_key} if partitions_subset is None else set(),
partitions_subsets_by_asset_key=(
{asset_key: partitions_subset} if partitions_subset is not None else {}
Expand Down
190 changes: 124 additions & 66 deletions python_modules/dagster/dagster/_core/definitions/asset_graph_subset.py
Original file line number Diff line number Diff line change
@@ -1,70 +1,122 @@
import operator
from collections import defaultdict
from datetime import datetime
from typing import AbstractSet, Any, Callable, Dict, Iterable, Mapping, Optional, Set, Union, cast
from typing import (
AbstractSet,
Any,
Callable,
Dict,
Iterable,
Mapping,
NamedTuple,
Optional,
Set,
Union,
cast,
)

from dagster import _check as check
from dagster._core.definitions.partition import (
PartitionsDefinition,
PartitionsSubset,
)
from dagster._core.definitions.time_window_partitions import (
TimePartitionKeyPartitionsSubset,
TimeWindowPartitionsSubset,
)
from dagster._core.errors import (
DagsterDefinitionChangedDeserializationError,
)
from dagster._core.instance import DynamicPartitionsStore
from dagster._serdes.serdes import (
FieldSerializer,
deserialize_value,
serialize_value,
whitelist_for_serdes,
)

from .asset_graph import AssetGraph
from .events import AssetKey, AssetKeyPartitionKey


class AssetGraphSubset:
def __init__(
class PartitionsSubsetByAssetKeySerializer(FieldSerializer):
    """Field serializer for a mapping from AssetKey to PartitionsSubset.

    JSON object keys must be a str, int, float, bool, or None, so each AssetKey is serialized
    into a str when packing and deserialized back into an AssetKey when unpacking.

    While packing, every TimePartitionKeyPartitionsSubset value is first converted into a
    TimeWindowPartitionsSubset; values of any other type are serialized as they are.
    """

    def pack(self, mapping: Mapping[AssetKey, Any], **_kwargs) -> Mapping[str, Any]:
        """Serialize both the keys and the values of ``mapping``.

        Args:
            mapping: The AssetKey -> subset mapping to pack.
            **_kwargs: Accepted and ignored.

        Returns:
            A dict whose keys are the serialized asset keys and whose values are the
            serialized subsets.
        """
        packed = {}
        for asset_key, subset in mapping.items():
            # Serialize the key before touching the value, as a dict display would.
            packed_key = serialize_value(asset_key)
            if isinstance(subset, TimePartitionKeyPartitionsSubset):
                subset = subset.to_time_window_partitions_subset()
            packed[packed_key] = serialize_value(subset)
        return packed

    def unpack(
        self,
        asset_graph: AssetGraph,
        mapping: Mapping[str, Any],
        **_kwargs,
    ) -> Mapping[AssetKey, Any]:
        """Deserialize the keys and the values of a packed mapping.

        Args:
            asset_graph: Not used by this method.
            mapping: The packed mapping, keyed by serialized asset key.
            **_kwargs: Accepted and ignored.

        Returns:
            A dict keyed by AssetKey whose values are TimeWindowPartitionsSubset objects.
        """
        # NOTE(review): asset_graph is never read here -- confirm the serdes framework really
        # passes an asset graph positionally ahead of the packed mapping.
        # NOTE(review): every value is deserialized as a TimeWindowPartitionsSubset, whereas
        # pack() serializes other subset types unchanged -- confirm no other type is packed.
        unpacked = {}
        for packed_key, packed_subset in mapping.items():
            asset_key = deserialize_value(packed_key, AssetKey)
            unpacked[asset_key] = deserialize_value(packed_subset, TimeWindowPartitionsSubset)
        return unpacked


@whitelist_for_serdes(
field_serializers={"partitions_subsets_by_asset_key": PartitionsSubsetByAssetKeySerializer}
)
class AssetGraphSubset(
NamedTuple(
"_AssetGraphSubset",
[
("partitions_subsets_by_asset_key", Mapping[AssetKey, PartitionsSubset]),
("non_partitioned_asset_keys", AbstractSet[AssetKey]),
],
)
):
def __new__(
cls,
partitions_subsets_by_asset_key: Optional[Mapping[AssetKey, PartitionsSubset]] = None,
non_partitioned_asset_keys: Optional[AbstractSet[AssetKey]] = None,
):
self._asset_graph = asset_graph
self._partitions_subsets_by_asset_key = partitions_subsets_by_asset_key or {}
self._non_partitioned_asset_keys = non_partitioned_asset_keys or set()

@property
def asset_graph(self) -> AssetGraph:
return self._asset_graph

@property
def partitions_subsets_by_asset_key(self) -> Mapping[AssetKey, PartitionsSubset]:
return self._partitions_subsets_by_asset_key

@property
def non_partitioned_asset_keys(self) -> AbstractSet[AssetKey]:
return self._non_partitioned_asset_keys
return super(AssetGraphSubset, cls).__new__(
cls,
partitions_subsets_by_asset_key=partitions_subsets_by_asset_key or {},
non_partitioned_asset_keys=non_partitioned_asset_keys or set(),
)

@property
def asset_keys(self) -> AbstractSet[AssetKey]:
return {
key for key, subset in self.partitions_subsets_by_asset_key.items() if len(subset) > 0
} | self._non_partitioned_asset_keys
} | self.non_partitioned_asset_keys

@property
def num_partitions_and_non_partitioned_assets(self):
return len(self._non_partitioned_asset_keys) + sum(
len(subset) for subset in self._partitions_subsets_by_asset_key.values()
return len(self.non_partitioned_asset_keys) + sum(
len(subset) for subset in self.partitions_subsets_by_asset_key.values()
)

def get_partitions_subset(self, asset_key: AssetKey) -> PartitionsSubset:
partitions_def = self.asset_graph.get_partitions_def(asset_key)
if partitions_def is None:
check.failed("Can only call get_partitions_subset on a partitioned asset")
# partitions_def = asset_graph.get_partitions_def(asset_key)
# if partitions_def is None:
# check.failed("Can only call get_partitions_subset on a partitioned asset")

return self.partitions_subsets_by_asset_key.get(asset_key, partitions_def.empty_subset())
return self.partitions_subsets_by_asset_key[asset_key]

def iterate_asset_partitions(self) -> Iterable[AssetKeyPartitionKey]:
for asset_key, partitions_subset in self.partitions_subsets_by_asset_key.items():
for partition_key in partitions_subset.get_partition_keys():
yield AssetKeyPartitionKey(asset_key, partition_key)

for asset_key in self._non_partitioned_asset_keys:
for asset_key in self.non_partitioned_asset_keys:
yield AssetKeyPartitionKey(asset_key, None)

def __contains__(self, asset: Union[AssetKey, AssetKeyPartitionKey]) -> bool:
Expand All @@ -74,19 +126,18 @@ def __contains__(self, asset: Union[AssetKey, AssetKeyPartitionKey]) -> bool:
"""
if isinstance(asset, AssetKey):
# check if any keys are in the subset
if self.asset_graph.is_partitioned(asset):
partitions_subset = self.partitions_subsets_by_asset_key.get(asset)
return partitions_subset is not None and len(partitions_subset) > 0
else:
return asset in self._non_partitioned_asset_keys
partitions_subset = self.partitions_subsets_by_asset_key.get(asset)
return (partitions_subset is not None and len(partitions_subset) > 0) or (
asset in self.non_partitioned_asset_keys
)
elif asset.partition_key is None:
return asset.asset_key in self._non_partitioned_asset_keys
return asset.asset_key in self.non_partitioned_asset_keys
else:
partitions_subset = self.partitions_subsets_by_asset_key.get(asset.asset_key)
return partitions_subset is not None and asset.partition_key in partitions_subset

def to_storage_dict(
self, dynamic_partitions_store: DynamicPartitionsStore
self, dynamic_partitions_store: DynamicPartitionsStore, asset_graph: AssetGraph
) -> Mapping[str, object]:
return {
"partitions_subsets_by_asset_key": {
Expand All @@ -95,36 +146,34 @@ def to_storage_dict(
},
"serializable_partitions_def_ids_by_asset_key": {
key.to_user_string(): check.not_none(
self._asset_graph.get_partitions_def(key)
asset_graph.get_partitions_def(key)
).get_serializable_unique_identifier(
dynamic_partitions_store=dynamic_partitions_store
)
for key, _ in self.partitions_subsets_by_asset_key.items()
},
"partitions_def_class_names_by_asset_key": {
key.to_user_string(): check.not_none(
self._asset_graph.get_partitions_def(key)
asset_graph.get_partitions_def(key)
).__class__.__name__
for key, _ in self.partitions_subsets_by_asset_key.items()
},
"non_partitioned_asset_keys": [
key.to_user_string() for key in self._non_partitioned_asset_keys
key.to_user_string() for key in self.non_partitioned_asset_keys
],
}

def _oper(
self, other: Union["AssetGraphSubset", AbstractSet[AssetKeyPartitionKey]], oper: Callable
) -> "AssetGraphSubset":
def _oper(self, other: "AssetGraphSubset", oper: Callable) -> "AssetGraphSubset":
"""Returns the AssetGraphSubset that results from applying the given operator to the set of
asset partitions in self and other.
Note: Not all operators are supported on the underlying PartitionsSubset objects.
"""
result_partition_subsets_by_asset_key = {**self.partitions_subsets_by_asset_key}
result_non_partitioned_asset_keys = set(self._non_partitioned_asset_keys)
result_non_partitioned_asset_keys = set(self.non_partitioned_asset_keys)

if not isinstance(other, AssetGraphSubset):
other = AssetGraphSubset.from_asset_partition_set(other, self.asset_graph)
# if not isinstance(other, AssetGraphSubset):
# other = AssetGraphSubset.from_asset_partition_set(other, self.asset_graph)

for asset_key in other.asset_keys:
if asset_key in other.non_partitioned_asset_keys:
Expand All @@ -133,48 +182,62 @@ def _oper(
result_non_partitioned_asset_keys, {asset_key}
)
else:
subset = self.get_partitions_subset(asset_key)
check.invariant(asset_key not in self.non_partitioned_asset_keys)
result_partition_subsets_by_asset_key[asset_key] = oper(
subset, other.get_partitions_subset(asset_key)
subset = (
self.get_partitions_subset(asset_key)
if asset_key in self.partitions_subsets_by_asset_key
else None
)

other_subset = other.get_partitions_subset(asset_key)
if other_subset is None and subset is None:
pass
if subset is None and other_subset is not None:
if oper == operator.or_:
result_partition_subsets_by_asset_key[asset_key] = other_subset
elif oper == operator.sub:
pass
elif oper == operator.and_:
pass
else:
check.failed(f"Unsupported operator {oper}")
elif subset is not None and other_subset is None:
if oper == operator.or_:
pass
elif oper == operator.sub:
pass
elif oper == operator.and_:
del result_partition_subsets_by_asset_key[asset_key]
else:
result_partition_subsets_by_asset_key[asset_key] = oper(subset, other_subset)

return AssetGraphSubset(
self.asset_graph,
result_partition_subsets_by_asset_key,
result_non_partitioned_asset_keys,
)

def __or__(
self, other: Union["AssetGraphSubset", AbstractSet[AssetKeyPartitionKey]]
) -> "AssetGraphSubset":
def __or__(self, other: "AssetGraphSubset") -> "AssetGraphSubset":
return self._oper(other, operator.or_)

def __sub__(
self, other: Union["AssetGraphSubset", AbstractSet[AssetKeyPartitionKey]]
) -> "AssetGraphSubset":
def __sub__(self, other: "AssetGraphSubset") -> "AssetGraphSubset":
return self._oper(other, operator.sub)

def __and__(
self, other: Union["AssetGraphSubset", AbstractSet[AssetKeyPartitionKey]]
) -> "AssetGraphSubset":
def __and__(self, other: "AssetGraphSubset") -> "AssetGraphSubset":
return self._oper(other, operator.and_)

def filter_asset_keys(self, asset_keys: AbstractSet[AssetKey]) -> "AssetGraphSubset":
return AssetGraphSubset(
self.asset_graph,
{
asset_key: subset
for asset_key, subset in self.partitions_subsets_by_asset_key.items()
if asset_key in asset_keys
},
self._non_partitioned_asset_keys & asset_keys,
self.non_partitioned_asset_keys & asset_keys,
)

def __eq__(self, other) -> bool:
return (
isinstance(other, AssetGraphSubset)
and self.asset_graph == other.asset_graph
and self.partitions_subsets_by_asset_key == other.partitions_subsets_by_asset_key
and self.non_partitioned_asset_keys == other.non_partitioned_asset_keys
)
Expand Down Expand Up @@ -209,7 +272,6 @@ def from_asset_partition_set(
for asset_key, partition_keys in partitions_by_asset_key.items()
},
non_partitioned_asset_keys=non_partitioned_asset_keys,
asset_graph=asset_graph,
)

@classmethod
Expand Down Expand Up @@ -297,9 +359,7 @@ def from_storage_dict(
AssetKey.from_user_string(key) for key in serialized_dict["non_partitioned_asset_keys"]
} & asset_graph.all_asset_keys

return AssetGraphSubset(
asset_graph, partitions_subsets_by_asset_key, non_partitioned_asset_keys
)
return AssetGraphSubset(partitions_subsets_by_asset_key, non_partitioned_asset_keys)

@classmethod
def all(
Expand Down Expand Up @@ -339,6 +399,4 @@ def from_asset_keys(
else:
non_partitioned_asset_keys.add(asset_key)

return AssetGraphSubset(
asset_graph, partitions_subsets_by_asset_key, non_partitioned_asset_keys
)
return AssetGraphSubset(partitions_subsets_by_asset_key, non_partitioned_asset_keys)
Original file line number Diff line number Diff line change
Expand Up @@ -1862,6 +1862,11 @@ def with_partitions_def(
included_partition_keys=self._included_partition_keys,
)

def to_time_window_partitions_subset(self) -> "TimeWindowPartitionsSubset":
    """Return a TimeWindowPartitionsSubset built from this subset's own state.

    The new subset is constructed from this object's ``partitions_def``,
    ``num_partitions`` and ``included_time_windows``, passed in that order.
    """
    partitions_def = self.partitions_def
    partition_count = self.num_partitions
    time_windows = self.included_time_windows
    return TimeWindowPartitionsSubset(partitions_def, partition_count, time_windows)


class TimeWindowPartitionsDefinitionSerializer(FieldSerializer):
"""Serializes a TimeWindowPartitionsDefinition by converting it to a SerializableTimeWindowPartitionsDefinition."""
Expand Down
Loading

0 comments on commit 05c6643

Please sign in to comment.