From 7ff96a52b7316c8595ee4025e059ef6aded9e2ff Mon Sep 17 00:00:00 2001 From: Sean Mackesey Date: Wed, 21 Feb 2024 17:17:54 -0500 Subject: [PATCH] [external-assets] Build base asset jobs using AssetGraph [INTERNAL_BRANCH=sean/external-assets-asset-graph-nodes] --- .../dagster/_core/definitions/asset_graph.py | 18 ++++++ .../dagster/_core/definitions/assets_job.py | 55 ++++++------------- .../_core/definitions/internal_asset_graph.py | 3 + .../repository_data_builder.py | 9 ++- .../valid_definitions.py | 2 + .../asset_defs_tests/test_assets_job.py | 30 +++++----- .../test_asset_check_decorator.py | 2 +- .../general_tests/test_pending_repository.py | 4 +- 8 files changed, 66 insertions(+), 57 deletions(-) diff --git a/python_modules/dagster/dagster/_core/definitions/asset_graph.py b/python_modules/dagster/dagster/_core/definitions/asset_graph.py index df0ba2a382e7f..ddab949fc1a04 100644 --- a/python_modules/dagster/dagster/_core/definitions/asset_graph.py +++ b/python_modules/dagster/dagster/_core/definitions/asset_graph.py @@ -287,9 +287,20 @@ def toposorted_asset_keys_by_level(self) -> Sequence[AbstractSet[AssetKey]]: {key for key in level} for level in toposort.toposort(self.asset_dep_graph["upstream"]) ] + @property + @cached_method + def unpartitioned_asset_keys(self) -> AbstractSet[AssetKey]: + return {node.key for node in self.asset_nodes if not node.is_partitioned} + def asset_keys_for_group(self, group_name: str) -> AbstractSet[AssetKey]: return {node.key for node in self.asset_nodes if node.group_name == group_name} + @cached_method + def asset_keys_for_partitions_def( + self, *, partitions_def: PartitionsDefinition + ) -> AbstractSet[AssetKey]: + return {node.key for node in self.asset_nodes if node.partitions_def == partitions_def} + @property @cached_method def root_materializable_asset_keys(self) -> AbstractSet[AssetKey]: @@ -311,6 +322,13 @@ def root_executable_asset_keys(self) -> AbstractSet[AssetKey]: def asset_check_keys(self) -> AbstractSet[AssetCheckKey]: return {key for asset in self.asset_nodes for key in asset.check_keys} + @property + @cached_method + def all_partitions_defs(self) -> Sequence[PartitionsDefinition]: + return sorted( + set(node.partitions_def for node in self.asset_nodes if node.partitions_def), key=repr + ) + @property @cached_method def all_group_names(self) -> AbstractSet[str]: diff --git a/python_modules/dagster/dagster/_core/definitions/assets_job.py b/python_modules/dagster/dagster/_core/definitions/assets_job.py index 5a9883e307f35..3d9a16473ff49 100644 --- a/python_modules/dagster/dagster/_core/definitions/assets_job.py +++ b/python_modules/dagster/dagster/_core/definitions/assets_job.py @@ -5,7 +5,6 @@ Any, Dict, Iterable, - List, Mapping, Optional, Sequence, @@ -18,6 +17,7 @@ import dagster._check as check from dagster._core.definitions.hook_definition import HookDefinition +from dagster._core.definitions.internal_asset_graph import InternalAssetGraph from dagster._core.definitions.policy import RetryPolicy from dagster._core.errors import DagsterInvalidDefinitionError from dagster._core.selector.subset_selector import AssetSelectionData @@ -61,58 +61,37 @@ def is_base_asset_job_name(name: str) -> bool: def get_base_asset_jobs( - assets: Sequence[AssetsDefinition], - asset_checks: Sequence[AssetChecksDefinition], + asset_graph: InternalAssetGraph, resource_defs: Optional[Mapping[str, ResourceDefinition]], executor_def: Optional[ExecutorDefinition], ) -> Sequence[JobDefinition]: - executable_assets = [a for a in assets if a.is_executable] - 
unexecutable_assets = [a for a in assets if not a.is_executable] - - executable_assets_by_partitions_def: Dict[ - Optional[PartitionsDefinition], List[AssetsDefinition] - ] = defaultdict(list) - for asset in executable_assets: - executable_assets_by_partitions_def[asset.partitions_def].append(asset) - # sort to ensure some stability in the ordering - all_partitions_defs = sorted( - [p for p in executable_assets_by_partitions_def.keys() if p], key=repr - ) - - if len(all_partitions_defs) == 0: + if len(asset_graph.all_partitions_defs) == 0: + executable_asset_keys = asset_graph.executable_asset_keys + loadable_asset_keys = asset_graph.all_asset_keys - executable_asset_keys return [ build_assets_job( name=ASSET_BASE_JOB_PREFIX, - executable_assets=executable_assets, - loadable_assets=unexecutable_assets, - asset_checks=asset_checks, + executable_assets=asset_graph.assets_defs_for_keys(executable_asset_keys), + loadable_assets=asset_graph.assets_defs_for_keys(loadable_asset_keys), + asset_checks=asset_graph.asset_checks_defs, executor_def=executor_def, resource_defs=resource_defs, ) ] else: - unpartitioned_executable_assets = executable_assets_by_partitions_def.get(None, []) jobs = [] - - for i, partitions_def in enumerate(all_partitions_defs): - # all base jobs contain all unpartitioned assets - executable_assets_for_job = [ - *executable_assets_by_partitions_def[partitions_def], - *unpartitioned_executable_assets, - ] + for i, partitions_def in enumerate(asset_graph.all_partitions_defs): + executable_asset_keys = asset_graph.executable_asset_keys & { + *asset_graph.asset_keys_for_partitions_def(partitions_def=partitions_def), + *asset_graph.unpartitioned_asset_keys, + } + loadable_asset_keys = asset_graph.all_asset_keys - executable_asset_keys jobs.append( build_assets_job( f"{ASSET_BASE_JOB_PREFIX}_{i}", - executable_assets=executable_assets_for_job, - loadable_assets=[ - *( - asset - for asset in executable_assets - if asset not in executable_assets_for_job - ), - *unexecutable_assets, - ], - asset_checks=asset_checks, + executable_assets=asset_graph.assets_defs_for_keys(executable_asset_keys), + loadable_assets=asset_graph.assets_defs_for_keys(loadable_asset_keys), + asset_checks=asset_graph.asset_checks_defs, resource_defs=resource_defs, executor_def=executor_def, partitions_def=partitions_def, diff --git a/python_modules/dagster/dagster/_core/definitions/internal_asset_graph.py b/python_modules/dagster/dagster/_core/definitions/internal_asset_graph.py index 48b646ec33904..7234975cd13e6 100644 --- a/python_modules/dagster/dagster/_core/definitions/internal_asset_graph.py +++ b/python_modules/dagster/dagster/_core/definitions/internal_asset_graph.py @@ -243,6 +243,9 @@ def get_execution_unit_asset_and_check_keys( def assets_defs(self) -> Sequence[AssetsDefinition]: return list(dict.fromkeys(asset.assets_def for asset in self.asset_nodes)) + def assets_defs_for_keys(self, keys: Iterable[AssetKey]) -> Sequence[AssetsDefinition]: + return list(dict.fromkeys([self.get(key).assets_def for key in keys])) + @property def asset_checks_defs(self) -> Sequence[AssetChecksDefinition]: return list(dict.fromkeys(self._asset_checks_defs_by_key.values())) diff --git a/python_modules/dagster/dagster/_core/definitions/repository_definition/repository_data_builder.py b/python_modules/dagster/dagster/_core/definitions/repository_definition/repository_data_builder.py index 02a613a4062e5..05462721ac5f7 100644 --- 
a/python_modules/dagster/dagster/_core/definitions/repository_definition/repository_data_builder.py +++ b/python_modules/dagster/dagster/_core/definitions/repository_definition/repository_data_builder.py @@ -50,6 +50,7 @@ from .valid_definitions import VALID_REPOSITORY_DATA_DICT_KEYS, RepositoryListDefinition if TYPE_CHECKING: + from dagster._core.definitions.asset_check_spec import AssetCheckKey from dagster._core.definitions.events import AssetKey @@ -162,6 +163,7 @@ def build_caching_repository_data_from_list( sensors: Dict[str, SensorDefinition] = {} assets_defs: List[AssetsDefinition] = [] asset_keys: Set[AssetKey] = set() + asset_check_keys: Set[AssetCheckKey] = set() source_assets: List[SourceAsset] = [] asset_checks_defs: List[AssetChecksDefinition] = [] for definition in repository_definitions: @@ -228,6 +230,10 @@ def build_caching_repository_data_from_list( source_assets.append(definition) asset_keys.add(definition.key) elif isinstance(definition, AssetChecksDefinition): + for key in definition.keys: + if key in asset_check_keys: + raise DagsterInvalidDefinitionError(f"Duplicate asset check key: {key}") + asset_check_keys.update(definition.keys) asset_checks_defs.append(definition) else: check.failed(f"Unexpected repository entry {definition}") @@ -237,10 +243,9 @@ def build_caching_repository_data_from_list( ) if assets_defs or source_assets or asset_checks_defs: for job_def in get_base_asset_jobs( - assets=asset_graph.assets_defs, + asset_graph=asset_graph, executor_def=default_executor_def, resource_defs=top_level_resources, - asset_checks=asset_checks_defs, ): jobs[job_def.name] = job_def diff --git a/python_modules/dagster/dagster/_core/definitions/repository_definition/valid_definitions.py b/python_modules/dagster/dagster/_core/definitions/repository_definition/valid_definitions.py index e767a695a2d2e..8258d87619d4d 100644 --- a/python_modules/dagster/dagster/_core/definitions/repository_definition/valid_definitions.py +++ b/python_modules/dagster/dagster/_core/definitions/repository_definition/valid_definitions.py @@ -2,6 +2,7 @@ from typing_extensions import TypeAlias +from dagster._core.definitions.asset_checks import AssetChecksDefinition from dagster._core.definitions.graph_definition import GraphDefinition from dagster._core.definitions.job_definition import JobDefinition from dagster._core.definitions.schedule_definition import ScheduleDefinition @@ -33,6 +34,7 @@ RepositoryListDefinition: TypeAlias = Union[ "AssetsDefinition", + AssetChecksDefinition, GraphDefinition, JobDefinition, ScheduleDefinition, diff --git a/python_modules/dagster/dagster_tests/asset_defs_tests/test_assets_job.py b/python_modules/dagster/dagster_tests/asset_defs_tests/test_assets_job.py index 290e135670f6f..8e5aa8d4567e8 100644 --- a/python_modules/dagster/dagster_tests/asset_defs_tests/test_assets_job.py +++ b/python_modules/dagster/dagster_tests/asset_defs_tests/test_assets_job.py @@ -1959,16 +1959,17 @@ def unpartitioned_asset(): ... jobs = get_base_asset_jobs( - assets=[ - daily_asset, - daily_asset2, - daily_asset_different_start_date, - hourly_asset, - unpartitioned_asset, - ], + asset_graph=InternalAssetGraph.from_assets( + [ + daily_asset, + daily_asset2, + daily_asset_different_start_date, + hourly_asset, + unpartitioned_asset, + ] + ), executor_def=None, resource_defs={}, - asset_checks=[], ) assert len(jobs) == 3 assert {job_def.name for job_def in jobs} == { @@ -2007,14 +2008,15 @@ def asset_x(asset_b: B): ... 
jobs = get_base_asset_jobs( - assets=[ - asset_x, - create_external_asset_from_source_asset(asset_a), - create_external_asset_from_source_asset(asset_b), - ], + asset_graph=InternalAssetGraph.from_assets( + [ + asset_x, + create_external_asset_from_source_asset(asset_a), + create_external_asset_from_source_asset(asset_b), + ] + ), executor_def=None, resource_defs={}, - asset_checks=[], ) assert len(jobs) == 2 assert {job_def.name for job_def in jobs} == { diff --git a/python_modules/dagster/dagster_tests/definitions_tests/decorators_tests/test_asset_check_decorator.py b/python_modules/dagster/dagster_tests/definitions_tests/decorators_tests/test_asset_check_decorator.py index 0b22f1561985a..921673651110f 100644 --- a/python_modules/dagster/dagster_tests/definitions_tests/decorators_tests/test_asset_check_decorator.py +++ b/python_modules/dagster/dagster_tests/definitions_tests/decorators_tests/test_asset_check_decorator.py @@ -391,7 +391,7 @@ def check1(context: AssetExecutionContext): with pytest.raises( DagsterInvalidDefinitionError, - match='Detected conflicting node definitions with the same name "asset1_check1"', + match="Duplicate asset check key.+asset1.+check1", ): Definitions(asset_checks=[make_check(), make_check()]) diff --git a/python_modules/dagster/dagster_tests/general_tests/test_pending_repository.py b/python_modules/dagster/dagster_tests/general_tests/test_pending_repository.py index b65305b606842..ccc4c08f0137c 100644 --- a/python_modules/dagster/dagster_tests/general_tests/test_pending_repository.py +++ b/python_modules/dagster/dagster_tests/general_tests/test_pending_repository.py @@ -160,11 +160,11 @@ def _op(context, res_upstream): for cd in data ] - @asset(required_resource_keys={"foo"}) + @asset def res_upstream(context): return context.resources.foo - @asset(required_resource_keys={"foo"}) + @asset def res_downstream(context, res_midstream): return res_midstream + context.resources.foo
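
Usage sketch (illustrative, not part of the patch): the updated test above builds an InternalAssetGraph and hands it to get_base_asset_jobs instead of passing flat assets/asset_checks lists. Below is a minimal sketch of that call pattern, assuming the internal import paths shown in the diff; the daily/hourly/unpartitioned asset definitions are stand-ins mirroring the test fixtures.

from dagster import DailyPartitionsDefinition, HourlyPartitionsDefinition, asset
from dagster._core.definitions.assets_job import get_base_asset_jobs
from dagster._core.definitions.internal_asset_graph import InternalAssetGraph


@asset(partitions_def=DailyPartitionsDefinition(start_date="2024-01-01"))
def daily_asset(): ...


@asset(partitions_def=HourlyPartitionsDefinition(start_date="2024-01-01-00:00"))
def hourly_asset(): ...


@asset
def unpartitioned_asset(): ...


# The base-job builder now derives its partitioned/unpartitioned groupings
# from the asset graph rather than re-bucketing raw AssetsDefinition objects.
asset_graph = InternalAssetGraph.from_assets(
    [daily_asset, hourly_asset, unpartitioned_asset]
)

# Graph-level helpers introduced by this patch.
assert len(asset_graph.all_partitions_defs) == 2
assert unpartitioned_asset.key in asset_graph.unpartitioned_asset_keys
assert daily_asset.key in asset_graph.asset_keys_for_partitions_def(
    partitions_def=daily_asset.partitions_def
)

# One base job per PartitionsDefinition; unpartitioned executable assets are
# included in every base job, and everything else is attached as loadable.
jobs = get_base_asset_jobs(asset_graph=asset_graph, resource_defs={}, executor_def=None)
assert len(jobs) == 2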
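
The repository_data_builder change above replaces the old "conflicting node definitions" error with an explicit duplicate-key check for asset checks. A sketch of the behavior the updated test expects, using the public Definitions / asset_check APIs; make_check is an illustrative helper mirroring the test in the diff.

import pytest

from dagster import AssetCheckResult, Definitions, asset, asset_check
from dagster._core.errors import DagsterInvalidDefinitionError


@asset
def asset1(): ...


def make_check():
    # Each call produces a new AssetChecksDefinition carrying the same
    # AssetCheckKey (asset1, check1), which the new guard rejects.
    @asset_check(asset=asset1)
    def check1():
        return AssetCheckResult(passed=True)

    return check1


def test_duplicate_asset_check_key():
    with pytest.raises(DagsterInvalidDefinitionError, match="Duplicate asset check key"):
        Definitions(asset_checks=[make_check(), make_check()])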