diff --git a/python_modules/dagster/dagster/_core/definitions/asset_graph.py b/python_modules/dagster/dagster/_core/definitions/asset_graph.py index fbda6375a9dda..295f3faea6adf 100644 --- a/python_modules/dagster/dagster/_core/definitions/asset_graph.py +++ b/python_modules/dagster/dagster/_core/definitions/asset_graph.py @@ -22,6 +22,7 @@ import toposort import dagster._check as check +from dagster._core.definitions.asset_subset import ValidAssetSubset from dagster._core.definitions.auto_materialize_policy import AutoMaterializePolicy from dagster._core.errors import DagsterInvalidInvocationError from dagster._core.instance import DynamicPartitionsStore @@ -305,6 +306,71 @@ def get_ancestors( ancestors.add(asset_key) return ancestors + def get_parent_asset_subset( + self, + child_asset_subset: ValidAssetSubset, + parent_asset_key: AssetKey, + dynamic_partitions_store: DynamicPartitionsStore, + current_time: datetime, + ) -> ValidAssetSubset: + """Given a child AssetSubset, returns the corresponding parent AssetSubset, based on the + relevant PartitionMapping. + """ + child_asset_key = child_asset_subset.asset_key + child_partitions_def = self.get_partitions_def(child_asset_key) + parent_partitions_def = self.get_partitions_def(parent_asset_key) + + if parent_partitions_def is None: + return ValidAssetSubset(parent_asset_key, value=child_asset_subset.size > 0) + + partition_mapping = self.get_partition_mapping(child_asset_key, parent_asset_key) + parent_partitions_subset = ( + partition_mapping.get_upstream_mapped_partitions_result_for_partitions( + child_asset_subset.subset_value if child_partitions_def is not None else None, + downstream_partitions_def=child_partitions_def, + upstream_partitions_def=parent_partitions_def, + dynamic_partitions_store=dynamic_partitions_store, + current_time=current_time, + ) + ).partitions_subset + + return ValidAssetSubset(parent_asset_key, value=parent_partitions_subset) + + def get_child_asset_subset( + self, + parent_asset_subset: ValidAssetSubset, + child_asset_key: AssetKey, + dynamic_partitions_store: DynamicPartitionsStore, + current_time: datetime, + ) -> ValidAssetSubset: + """Given a parent AssetSubset, returns the corresponding child AssetSubset, based on the + relevant PartitionMapping. + """ + parent_asset_key = parent_asset_subset.asset_key + parent_partitions_def = self.get_partitions_def(parent_asset_key) + child_partitions_def = self.get_partitions_def(child_asset_key) + + if parent_partitions_def is None: + if parent_asset_subset.size > 0: + return ValidAssetSubset.all( + child_asset_key, child_partitions_def, dynamic_partitions_store, current_time + ) + else: + return ValidAssetSubset.empty(child_asset_key, child_partitions_def) + + if child_partitions_def is None: + return ValidAssetSubset(child_asset_key, value=parent_asset_subset.size > 0) + else: + partition_mapping = self.get_partition_mapping(child_asset_key, parent_asset_key) + child_partitions_subset = partition_mapping.get_downstream_partitions_for_partitions( + parent_asset_subset.subset_value, + parent_partitions_def, + downstream_partitions_def=child_partitions_def, + dynamic_partitions_store=dynamic_partitions_store, + current_time=current_time, + ) + return ValidAssetSubset(child_asset_key, value=child_partitions_subset) + def get_children_partitions( self, dynamic_partitions_store: DynamicPartitionsStore, diff --git a/python_modules/dagster/dagster_tests/asset_defs_tests/test_asset_graph.py b/python_modules/dagster/dagster_tests/asset_defs_tests/test_asset_graph.py index b53d1a4768656..171ac03e5213d 100644 --- a/python_modules/dagster/dagster_tests/asset_defs_tests/test_asset_graph.py +++ b/python_modules/dagster/dagster_tests/asset_defs_tests/test_asset_graph.py @@ -26,6 +26,7 @@ from dagster._core.definitions.asset_check_spec import AssetCheckSpec from dagster._core.definitions.asset_graph import AssetGraph from dagster._core.definitions.asset_graph_subset import AssetGraphSubset +from dagster._core.definitions.asset_subset import AssetSubset from dagster._core.definitions.decorators.asset_check_decorator import asset_check from dagster._core.definitions.events import AssetKeyPartitionKey from dagster._core.definitions.external_asset_graph import ExternalAssetGraph @@ -99,7 +100,9 @@ def asset3(asset1, asset2): assert asset_graph.get_code_version(asset1.key) is None -def test_get_children_partitions_unpartitioned_parent_partitioned_child(asset_graph_from_assets): +def test_get_children_partitions_unpartitioned_parent_partitioned_child( + asset_graph_from_assets, +) -> None: @asset def parent(): ... @@ -112,12 +115,29 @@ def child(parent): current_time = pendulum.now("UTC") asset_graph = asset_graph_from_assets([parent, child]) - assert asset_graph.get_children_partitions(instance, current_time, parent.key) == set( - [AssetKeyPartitionKey(child.key, "a"), AssetKeyPartitionKey(child.key, "b")] + + expected_asset_partitions = { + AssetKeyPartitionKey(child.key, "a"), + AssetKeyPartitionKey(child.key, "b"), + } + assert ( + asset_graph.get_children_partitions(instance, current_time, parent.key) + == expected_asset_partitions + ) + assert ( + asset_graph.get_child_asset_subset( + AssetSubset.all(parent.key, parent.partitions_def), + child.key, + instance, + current_time, + ).asset_partitions + == expected_asset_partitions ) -def test_get_parent_partitions_unpartitioned_child_partitioned_parent(asset_graph_from_assets): +def test_get_parent_partitions_unpartitioned_child_partitioned_parent( + asset_graph_from_assets: Callable[..., AssetGraph], +): @asset(partitions_def=StaticPartitionsDefinition(["a", "b"])) def parent(): ... @@ -130,14 +150,27 @@ def child(parent): current_time = pendulum.now("UTC") asset_graph = asset_graph_from_assets([parent, child]) - assert asset_graph.get_parents_partitions( - instance, current_time, child.key - ).parent_partitions == set( - [AssetKeyPartitionKey(parent.key, "a"), AssetKeyPartitionKey(parent.key, "b")] + expected_asset_partitions = { + AssetKeyPartitionKey(parent.key, "a"), + AssetKeyPartitionKey(parent.key, "b"), + } + assert ( + asset_graph.get_parents_partitions(instance, current_time, child.key).parent_partitions + == expected_asset_partitions + ) + + assert ( + asset_graph.get_parent_asset_subset( + AssetSubset.all(child.key, child.partitions_def), + parent.key, + instance, + current_time, + ).asset_partitions + == expected_asset_partitions ) -def test_get_children_partitions_fan_out(asset_graph_from_assets): +def test_get_children_partitions_fan_out(asset_graph_from_assets: Callable[..., AssetGraph]): @asset(partitions_def=DailyPartitionsDefinition(start_date="2022-01-01")) def parent(): ... @@ -150,17 +183,30 @@ def child(parent): with instance_for_test() as instance: current_time = pendulum.now("UTC") - assert asset_graph.get_children_partitions( - instance, current_time, parent.key, "2022-01-03" - ) == set( - [ - AssetKeyPartitionKey(child.key, f"2022-01-03-{str(hour).zfill(2)}:00") - for hour in range(24) - ] + expected_asset_partitions = { + AssetKeyPartitionKey(child.key, f"2022-01-03-{str(hour).zfill(2)}:00") + for hour in range(24) + } + parent_asset_partition = AssetKeyPartitionKey(parent.key, "2022-01-03") + + assert ( + asset_graph.get_children_partitions(instance, current_time, parent.key, "2022-01-03") + == expected_asset_partitions + ) + assert ( + asset_graph.get_child_asset_subset( + AssetSubset.from_asset_partitions_set( + parent.key, parent.partitions_def, {parent_asset_partition} + ), + child.key, + instance, + current_time, + ).asset_partitions + == expected_asset_partitions ) -def test_get_parent_partitions_fan_in(asset_graph_from_assets): +def test_get_parent_partitions_fan_in(asset_graph_from_assets: Callable[..., AssetGraph]) -> None: @asset(partitions_def=HourlyPartitionsDefinition(start_date="2022-01-01-00:00")) def parent(): ... @@ -174,17 +220,34 @@ def child(parent): with instance_for_test() as instance: current_time = pendulum.now("UTC") - assert asset_graph.get_parents_partitions( - instance, current_time, child.key, "2022-01-03" - ).parent_partitions == set( - [ - AssetKeyPartitionKey(parent.key, f"2022-01-03-{str(hour).zfill(2)}:00") - for hour in range(24) - ] + expected_asset_partitions = { + AssetKeyPartitionKey(parent.key, f"2022-01-03-{str(hour).zfill(2)}:00") + for hour in range(24) + } + child_asset_partition = AssetKeyPartitionKey(child.key, "2022-01-03") + + assert ( + asset_graph.get_parents_partitions( + instance, current_time, child.key, child_asset_partition.partition_key + ).parent_partitions + == expected_asset_partitions + ) + assert ( + asset_graph.get_parent_asset_subset( + AssetSubset.from_asset_partitions_set( + child.key, child.partitions_def, {child_asset_partition} + ), + parent.key, + instance, + current_time, + ).asset_partitions + == expected_asset_partitions ) -def test_get_parent_partitions_non_default_partition_mapping(asset_graph_from_assets): +def test_get_parent_partitions_non_default_partition_mapping( + asset_graph_from_assets: Callable[..., AssetGraph], +): @asset(partitions_def=DailyPartitionsDefinition(start_date="2022-01-01")) def parent(): ... @@ -199,14 +262,23 @@ def child(parent): with instance_for_test() as instance: current_time = pendulum.now("UTC") + expected_asset_partitions = {AssetKeyPartitionKey(parent.key, "2022-01-02")} mapped_partitions_result = asset_graph.get_parents_partitions( instance, current_time, child.key ) - assert mapped_partitions_result.parent_partitions == { - AssetKeyPartitionKey(parent.key, "2022-01-02") - } + assert mapped_partitions_result.parent_partitions == expected_asset_partitions assert mapped_partitions_result.required_but_nonexistent_parents_partitions == set() + assert ( + asset_graph.get_parent_asset_subset( + AssetSubset.all(child.key, child.partitions_def), + parent.key, + instance, + current_time, + ).asset_partitions + == expected_asset_partitions + ) + def test_custom_unsupported_partition_mapping(): class TrailingWindowPartitionMapping(PartitionMapping): @@ -283,8 +355,10 @@ def child(parent): ).parent_partitions == {AssetKeyPartitionKey(parent.key, "2")} -def test_required_multi_asset_sets_non_subsettable_multi_asset(asset_graph_from_assets): - @multi_asset(outs={"a": AssetOut(dagster_type=None), "b": AssetOut(dagster_type=None)}) +def test_required_multi_asset_sets_non_subsettable_multi_asset( + asset_graph_from_assets: Callable[..., AssetGraph], +): + @multi_asset(outs={"a": AssetOut(), "b": AssetOut()}) def non_subsettable_multi_asset(): ... @@ -295,9 +369,12 @@ def non_subsettable_multi_asset(): ) -def test_required_multi_asset_sets_subsettable_multi_asset(asset_graph_from_assets): +def test_required_multi_asset_sets_subsettable_multi_asset( + asset_graph_from_assets: Callable[..., AssetGraph], +): @multi_asset( - outs={"a": AssetOut(dagster_type=None), "b": AssetOut(dagster_type=None)}, can_subset=True + outs={"a": AssetOut(), "b": AssetOut()}, + can_subset=True, ) def subsettable_multi_asset(): ... @@ -307,7 +384,9 @@ def subsettable_multi_asset(): assert asset_graph.get_required_multi_asset_keys(asset_key) == set() -def test_required_multi_asset_sets_graph_backed_multi_asset(asset_graph_from_assets): +def test_required_multi_asset_sets_graph_backed_multi_asset( + asset_graph_from_assets: Callable[..., AssetGraph], +): @op def op1(): return 1 @@ -328,7 +407,9 @@ def graph1(): assert asset_graph.get_required_multi_asset_keys(asset_key) == graph_backed_multi_asset.keys -def test_required_multi_asset_sets_same_op_in_different_assets(asset_graph_from_assets): +def test_required_multi_asset_sets_same_op_in_different_assets( + asset_graph_from_assets: Callable[..., AssetGraph], +): @op def op1(): ... @@ -342,7 +423,7 @@ def op1(): assert asset_graph.get_required_multi_asset_keys(asset_def.key) == set() -def test_get_non_source_roots_missing_source(asset_graph_from_assets): +def test_get_non_source_roots_missing_source(asset_graph_from_assets: Callable[..., AssetGraph]): @asset def foo(): pass @@ -357,7 +438,7 @@ def bar(foo): assert asset_graph.get_non_source_roots(AssetKey("bar")) == {AssetKey("foo")} -def test_partitioned_source_asset(asset_graph_from_assets): +def test_partitioned_source_asset(asset_graph_from_assets: Callable[..., AssetGraph]): partitions_def = DailyPartitionsDefinition(start_date="2022-01-01") partitioned_source = SourceAsset( @@ -378,7 +459,7 @@ def downstream_of_partitioned_source(): assert asset_graph.is_partitioned(AssetKey("downstream_of_partitioned_source")) -def test_bfs_filter_asset_subsets(asset_graph_from_assets): +def test_bfs_filter_asset_subsets(asset_graph_from_assets: Callable[..., AssetGraph]): daily_partitions_def = DailyPartitionsDefinition(start_date="2022-01-01") @asset(partitions_def=daily_partitions_def) @@ -408,6 +489,7 @@ def include_all(asset_key, partitions_subset): initial_asset1_subset = AssetGraphSubset( partitions_subsets_by_asset_key={asset1.key: initial_partitions_subset} ) + assert asset3.partitions_def is not None corresponding_asset3_subset = AssetGraphSubset( partitions_subsets_by_asset_key={ asset3.key: asset3.partitions_def.empty_subset().with_partition_key_range( @@ -470,7 +552,9 @@ def exclude_asset2(asset_key, partitions_subset): ) -def test_bfs_filter_asset_subsets_different_mappings(asset_graph_from_assets): +def test_bfs_filter_asset_subsets_different_mappings( + asset_graph_from_assets: Callable[..., AssetGraph], +): daily_partitions_def = DailyPartitionsDefinition(start_date="2022-01-01") @asset(partitions_def=daily_partitions_def) @@ -530,7 +614,7 @@ def include_all(asset_key, partitions_subset): ) -def test_asset_graph_subset_contains(asset_graph_from_assets) -> None: +def test_asset_graph_subset_contains(asset_graph_from_assets: Callable[..., AssetGraph]) -> None: daily_partitions_def = DailyPartitionsDefinition(start_date="2022-01-01") @asset(partitions_def=daily_partitions_def) @@ -568,7 +652,7 @@ def unpartitioned2(): assert AssetKeyPartitionKey(partitioned2.key, "2022-01-01") not in asset_graph_subset -def test_asset_graph_difference(asset_graph_from_assets): +def test_asset_graph_difference(asset_graph_from_assets: Callable[..., AssetGraph]): daily_partitions_def = DailyPartitionsDefinition(start_date="2022-01-01") @asset(partitions_def=daily_partitions_def) @@ -631,7 +715,7 @@ def unpartitioned2(): ) -def test_asset_graph_partial_deserialization(asset_graph_from_assets): +def test_asset_graph_partial_deserialization(asset_graph_from_assets: Callable[..., AssetGraph]): daily_partitions_def = DailyPartitionsDefinition(start_date="2022-01-01") static_partitions_def = StaticPartitionsDefinition(["a", "b", "c"]) @@ -694,7 +778,7 @@ def unpartitioned3(): AssetKey("unpartitioned1"), AssetKey("unpartitioned2"), }, - ).to_storage_dict(dynamic_partitions_store=None, asset_graph=get_ag1()) + ).to_storage_dict(dynamic_partitions_store=None, asset_graph=get_ag1()) # type: ignore asset_graph2 = get_ag2() assert not AssetGraphSubset.can_deserialize(ag1_storage_dict, asset_graph2) @@ -717,7 +801,9 @@ def unpartitioned3(): ) -def test_required_assets_and_checks_by_key_check_decorator(asset_graph_from_assets): +def test_required_assets_and_checks_by_key_check_decorator( + asset_graph_from_assets: Callable[..., AssetGraph], +): @asset def asset0(): ... @@ -731,7 +817,9 @@ def check0(): assert asset_graph.get_required_asset_and_check_keys(check0.spec.key) == set() -def test_required_assets_and_checks_by_key_asset_decorator(asset_graph_from_assets): +def test_required_assets_and_checks_by_key_asset_decorator( + asset_graph_from_assets: Callable[..., AssetGraph], +): foo_check = AssetCheckSpec(name="foo", asset="asset0") bar_check = AssetCheckSpec(name="bar", asset="asset0") @@ -752,7 +840,9 @@ def check0(): assert asset_graph.get_required_asset_and_check_keys(check0.spec.key) == set() -def test_required_assets_and_checks_by_key_multi_asset(asset_graph_from_assets): +def test_required_assets_and_checks_by_key_multi_asset( + asset_graph_from_assets: Callable[..., AssetGraph], +): foo_check = AssetCheckSpec(name="foo", asset="asset0") bar_check = AssetCheckSpec(name="bar", asset="asset1") @@ -793,7 +883,7 @@ def subsettable_asset_fn(): def test_required_assets_and_checks_by_key_multi_asset_single_asset( - asset_graph_from_assets, + asset_graph_from_assets: Callable[..., AssetGraph], ): foo_check = AssetCheckSpec(name="foo", asset="asset0") bar_check = AssetCheckSpec(name="bar", asset="asset0")