Commit 9a28f89: adjust to use mapping keyed by asset key

clairelin135 committed Nov 21, 2023
1 parent 5682c61 commit 9a28f89
Showing 4 changed files with 131 additions and 35 deletions.
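The change in brief: AssetBackfillData's partitions_ids_by_serialized_asset_key field, a mapping keyed by AssetKey.to_string(), becomes partitions_def_ids_by_asset_key, keyed directly by AssetKey and wrapped in SerializableNonScalarKeyMapping so that serdes can round-trip the non-string keys. A minimal sketch of the before/after shape (def_ids here is a hypothetical AssetKey-to-identifier mapping standing in for what the real code derives from the asset graph):

    # Before: keys are stringified asset keys, safe as plain JSON object keys
    {asset_key.to_string(): def_id for asset_key, def_id in def_ids.items()}

    # After: keys are AssetKey objects, wrapped so serialization preserves them
    SerializableNonScalarKeyMapping(
        {asset_key: def_id for asset_key, def_id in def_ids.items()}
    )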
@@ -1,5 +1,6 @@
 from typing import Optional, Tuple

+import mock
 from dagster import (
     AssetKey,
     DailyPartitionsDefinition,
@@ -421,6 +422,71 @@ def test_launch_asset_backfill():
     )


+def test_remove_partitions_defs_after_backfill_backcompat():
+    repo = get_repo()
+    all_asset_keys = repo.asset_graph.materializable_asset_keys
+
+    with instance_for_test() as instance:
+        with define_out_of_process_context(__file__, "get_repo", instance) as context:
+            launch_backfill_result = execute_dagster_graphql(
+                context,
+                LAUNCH_PARTITION_BACKFILL_MUTATION,
+                variables={
+                    "backfillParams": {
+                        "partitionNames": ["a", "b"],
+                        "assetSelection": [key.to_graphql_input() for key in all_asset_keys],
+                    }
+                },
+            )
+            backfill_id, asset_backfill_data = _get_backfill_data(
+                launch_backfill_result, instance, repo
+            )
+            assert asset_backfill_data.target_subset.asset_keys == all_asset_keys
+
+        # Replace the asset backfill data with the backcompat serialization
+        backfill = instance.get_backfills()[0]
+        backcompat_backfill = backfill._replace(
+            asset_backfill_data=None,
+            serialized_asset_backfill_data=backfill.asset_backfill_data.serialize(
+                instance, asset_graph=repo.asset_graph
+            ),
+        )
+
+        with mock.patch(
+            "dagster._core.instance.DagsterInstance.get_backfills",
+            return_value=[backcompat_backfill],
+        ):
+            # When the partitions defs are unchanged, the backfill data can be fetched
+            with define_out_of_process_context(__file__, "get_repo", instance) as context:
+                get_backfills_result = execute_dagster_graphql(
+                    context, GET_PARTITION_BACKFILLS_QUERY, variables={}
+                )
+                assert not get_backfills_result.errors
+                assert get_backfills_result.data
+
+                backfill_results = get_backfills_result.data["partitionBackfillsOrError"]["results"]
+                assert len(backfill_results) == 1
+                assert backfill_results[0]["numPartitions"] == 2
+                assert backfill_results[0]["id"] == backfill_id
+                assert set(backfill_results[0]["partitionNames"]) == {"a", "b"}
+
+            # When the partitions defs are changed, the backfill data cannot be fetched
+            with define_out_of_process_context(
+                __file__, "get_repo_with_non_partitioned_asset", instance
+            ) as context:
+                get_backfills_result = execute_dagster_graphql(
+                    context, GET_PARTITION_BACKFILLS_QUERY, variables={}
+                )
+                assert not get_backfills_result.errors
+                assert get_backfills_result.data
+
+                backfill_results = get_backfills_result.data["partitionBackfillsOrError"]["results"]
+                assert len(backfill_results) == 1
+                assert backfill_results[0]["numPartitions"] == 0
+                assert backfill_results[0]["id"] == backfill_id
+                assert set(backfill_results[0]["partitionNames"]) == set()
+
+
 def test_remove_partitions_defs_after_backfill():
     repo = get_repo()
     all_asset_keys = repo.asset_graph.materializable_asset_keys
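The test above exercises the backcompat path: older backfills persisted only serialized_asset_backfill_data, and AssetBackfillData.from_serialized (changed below) rebuilds them with partitions_def_ids_by_asset_key=None, since legacy payloads carry no recorded partitions-def identifiers. While the current partitions defs still match, the stored subsets deserialize and the backfill reports its two partitions; once a def changes, the payload can no longer be hydrated against the new asset graph and the query falls back to zero partitions. A sketch of the round-trip being simulated (argument order for from_serialized is assumed, as its full signature is collapsed in the diff):

    legacy = backfill._replace(
        asset_backfill_data=None,
        serialized_asset_backfill_data=backfill.asset_backfill_data.serialize(
            instance, asset_graph=repo.asset_graph
        ),
    )
    # Rebuilding from the legacy payload leaves the def-ID mapping unset.
    rebuilt = AssetBackfillData.from_serialized(
        legacy.serialized_asset_backfill_data, asset_graph, backfill_start_timestamp
    )
    assert rebuilt.partitions_def_ids_by_asset_key is None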
93 changes: 61 additions & 32 deletions python_modules/dagster/dagster/_core/execution/asset_backfill.py
@@ -68,7 +68,7 @@
     IWorkspaceProcessContext,
 )
 from dagster._core.workspace.workspace import IWorkspace
-from dagster._serdes import whitelist_for_serdes
+from dagster._serdes import SerializableNonScalarKeyMapping, whitelist_for_serdes
 from dagster._utils import hash_collection, utc_datetime_from_timestamp
 from dagster._utils.caching_instance_queryer import CachingInstanceQueryer

@@ -133,22 +133,53 @@ def __new__(cls, asset_key: AssetKey, asset_backfill_status: Optional[AssetBackf


 @whitelist_for_serdes(field_serializers={"backfill_start_time": DatetimeFieldSerializer})
-class AssetBackfillData(NamedTuple):
-    """Has custom serialization instead of standard Dagster NamedTuple serialization because the
-    asset graph is required to build the AssetGraphSubset objects.
-    """
-
-    target_subset: AssetGraphSubset
-    requested_runs_for_target_roots: bool
-    latest_storage_id: Optional[int]
-    materialized_subset: AssetGraphSubset
-    requested_subset: AssetGraphSubset
-    failed_and_downstream_subset: AssetGraphSubset
-    backfill_start_time: datetime
-    partitions_ids_by_serialized_asset_key: Optional[Mapping[str, str]]
-
-    # TODO add __new__ that asserts that partitions_ids_by_serialized_asset_key
-    # contains all keys for target subset
+class AssetBackfillData(
+    NamedTuple(
+        "_AssetBackfillData",
+        [
+            ("target_subset", AssetGraphSubset),
+            ("requested_runs_for_target_roots", bool),
+            ("latest_storage_id", Optional[int]),
+            ("materialized_subset", AssetGraphSubset),
+            ("requested_subset", AssetGraphSubset),
+            ("failed_and_downstream_subset", AssetGraphSubset),
+            ("backfill_start_time", datetime),
+            ("partitions_def_ids_by_asset_key", Optional[Mapping[AssetKey, str]]),
+        ],
+    )
+):
+    def __new__(
+        cls,
+        target_subset: AssetGraphSubset,
+        requested_runs_for_target_roots: bool,
+        latest_storage_id: Optional[int],
+        materialized_subset: AssetGraphSubset,
+        requested_subset: AssetGraphSubset,
+        failed_and_downstream_subset: AssetGraphSubset,
+        backfill_start_time: datetime,
+        partitions_def_ids_by_asset_key: Optional[Mapping[AssetKey, str]],
+    ):
+        check.opt_mapping_param(
+            partitions_def_ids_by_asset_key,
+            "partitions_def_ids_by_asset_key",
+            key_type=AssetKey,
+            value_type=str,
+        )
+        return super(AssetBackfillData, cls).__new__(
+            cls,
+            check.inst_param(target_subset, "target_subset", AssetGraphSubset),
+            check.bool_param(requested_runs_for_target_roots, "requested_runs_for_target_roots"),
+            check.opt_int_param(latest_storage_id, "latest_storage_id"),
+            check.inst_param(materialized_subset, "materialized_subset", AssetGraphSubset),
+            check.inst_param(requested_subset, "requested_subset", AssetGraphSubset),
+            check.inst_param(
+                failed_and_downstream_subset, "failed_and_downstream_subset", AssetGraphSubset
+            ),
+            check.inst_param(backfill_start_time, "backfill_start_time", datetime),
+            SerializableNonScalarKeyMapping(partitions_def_ids_by_asset_key)
+            if partitions_def_ids_by_asset_key
+            else None,
+        )

     def replace_requested_subset(self, requested_subset: AssetGraphSubset) -> "AssetBackfillData":
         return AssetBackfillData(
@@ -159,7 +190,7 @@ def replace_requested_subset(self, requested_subset: AssetGraphSubset) -> "Asset
             failed_and_downstream_subset=self.failed_and_downstream_subset,
             requested_subset=requested_subset,
             backfill_start_time=self.backfill_start_time,
-            partitions_ids_by_serialized_asset_key=self.partitions_ids_by_serialized_asset_key,
+            partitions_def_ids_by_asset_key=self.partitions_def_ids_by_asset_key,
         )

     def is_complete(self) -> bool:
@@ -400,8 +431,8 @@ def empty(
         asset_graph: AssetGraph,
         dynamic_partitions_store: DynamicPartitionsStore,
     ) -> "AssetBackfillData":
-        partition_ids_by_serialized_asset_key = {
-            asset_key.to_string(): check.not_none(
+        partition_ids_by_asset_key: Mapping[AssetKey, str] = {
+            asset_key: check.not_none(
                 asset_graph.get_partitions_def(asset_key)
             ).get_serializable_unique_identifier(dynamic_partitions_store)
             for asset_key in target_subset.partitions_subsets_by_asset_key.keys()
@@ -414,7 +445,7 @@
             failed_and_downstream_subset=AssetGraphSubset(),
             latest_storage_id=None,
             backfill_start_time=backfill_start_time,
-            partitions_ids_by_serialized_asset_key=partition_ids_by_serialized_asset_key,
+            partitions_def_ids_by_asset_key=partition_ids_by_asset_key,
         )

     @classmethod
@@ -430,12 +461,10 @@ def from_serialized(
     ) -> "AssetBackfillData":
         storage_dict = json.loads(serialized)

-        target_subset = AssetGraphSubset.from_storage_dict(
-            storage_dict["serialized_target_subset"], asset_graph
-        )
-
         return cls(
-            target_subset=target_subset,
+            target_subset=AssetGraphSubset.from_storage_dict(
+                storage_dict["serialized_target_subset"], asset_graph
+            ),
             requested_runs_for_target_roots=storage_dict["requested_runs_for_target_roots"],
             requested_subset=AssetGraphSubset.from_storage_dict(
                 storage_dict["serialized_requested_subset"], asset_graph
@@ -448,7 +477,7 @@
             ),
             latest_storage_id=storage_dict["latest_storage_id"],
            backfill_start_time=utc_datetime_from_timestamp(backfill_start_timestamp),
-            partitions_ids_by_serialized_asset_key=None,
+            partitions_def_ids_by_asset_key=None,
         )

     @classmethod
@@ -817,14 +846,14 @@ def _check_and_deserialize_asset_backfill_data(
         )
     elif backfill.asset_backfill_data:
         asset_backfill_data = backfill.asset_backfill_data
-        partitions_ids_by_serialized_asset_key = check.not_none(
-            asset_backfill_data.partitions_ids_by_serialized_asset_key
+        partitions_def_ids_by_asset_key = check.not_none(
+            asset_backfill_data.partitions_def_ids_by_asset_key
         )
         for asset_key in asset_backfill_data.target_subset.asset_keys:
             _check_no_partitions_def_changes_to_asset(
                 asset_key,
                 asset_graph,
-                partitions_ids_by_serialized_asset_key.get(asset_key.to_string()),
+                partitions_def_ids_by_asset_key.get(asset_key),
                 instance_queryer,
             )
     else:
@@ -1030,7 +1059,7 @@ def get_canceling_asset_backfill_iteration_data(
         failed_and_downstream_subset=failed_and_downstream_subset,
         requested_subset=asset_backfill_data.requested_subset,
         backfill_start_time=backfill_start_time,
-        partitions_ids_by_serialized_asset_key=asset_backfill_data.partitions_ids_by_serialized_asset_key,
+        partitions_def_ids_by_asset_key=asset_backfill_data.partitions_def_ids_by_asset_key,
     )

     yield updated_backfill_data
@@ -1314,7 +1343,7 @@ def execute_asset_backfill_iteration_inner(
         requested_subset=asset_backfill_data.requested_subset
         | AssetGraphSubset.from_asset_partition_set(set(asset_partitions_to_request), asset_graph),
         backfill_start_time=backfill_start_time,
-        partitions_ids_by_serialized_asset_key=asset_backfill_data.partitions_ids_by_serialized_asset_key,
+        partitions_def_ids_by_asset_key=asset_backfill_data.partitions_def_ids_by_asset_key,
     )
     yield AssetBackfillIterationResult(run_requests, updated_asset_backfill_data)
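Why the wrapper is needed: serdes packs a plain dict as a JSON object, and JSON object keys must be strings, so a mapping keyed by AssetKey cannot be packed directly. Wrapping it in SerializableNonScalarKeyMapping marks it for a representation that preserves non-scalar keys. A self-contained sketch of the idea, assuming entries are stored as key/value pairs; this is illustrative, not dagster's actual implementation:

    import json
    from typing import Mapping, Tuple

    # JSON objects cannot have non-string keys, so store entries as
    # [key, value] pairs. An AssetKey is modeled as a tuple of path parts.
    def pack_mapping(mapping: Mapping[Tuple[str, ...], str]) -> str:
        return json.dumps([[list(key), value] for key, value in mapping.items()])

    def unpack_mapping(serialized: str) -> Mapping[Tuple[str, ...], str]:
        return {tuple(key): value for key, value in json.loads(serialized)}

    ids = {("my_asset",): "def_id_1", ("prefix", "other_asset"): "def_id_2"}
    assert unpack_mapping(pack_mapping(ids)) == ids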
1 change: 1 addition & 0 deletions python_modules/dagster/dagster/_serdes/__init__.py
@@ -6,6 +6,7 @@
 from .serdes import (
     EnumSerializer as EnumSerializer,
     NamedTupleSerializer as NamedTupleSerializer,
+    SerializableNonScalarKeyMapping as SerializableNonScalarKeyMapping,
     WhitelistMap as WhitelistMap,
     deserialize_value as deserialize_value,
     pack_value as pack_value,
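With this re-export in place, call sites can import the wrapper from the package root, exactly as asset_backfill.py does above:

    from dagster._serdes import SerializableNonScalarKeyMapping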
@@ -319,8 +319,8 @@ def repo():
         partitions_subsets_by_asset_key={},
         non_partitioned_asset_keys=set(),
     )
-    partition_ids_by_serialized_asset_key = {
-        asset_key.to_string(): check.not_none(
+    partitions_def_ids_by_asset_key = {
+        asset_key: check.not_none(
             repo.asset_graph.get_partitions_def(asset_key)
         ).get_serializable_unique_identifier(dynamic_partitions_store=instance)
         for asset_key in target_subset.partitions_subsets_by_asset_key.keys()
@@ -333,7 +333,7 @@
         requested_subset=empty_subset,
         failed_and_downstream_subset=empty_subset,
         backfill_start_time=test_time,
-        partitions_ids_by_serialized_asset_key=partition_ids_by_serialized_asset_key,
+        partitions_def_ids_by_asset_key=partitions_def_ids_by_asset_key,
     )
     backfill = PartitionBackfill(
         backfill_id=f"backfill{i}",
