From 8ce31aa666e21cc67055c0a3c6add165e9b7cc4e Mon Sep 17 00:00:00 2001 From: Sean Mackesey Date: Wed, 21 Feb 2024 17:17:54 -0500 Subject: [PATCH] [external-assets] Build base asset jobs using AssetGraph --- .../dagster/_core/definitions/asset_graph.py | 3 + .../dagster/_core/definitions/assets_job.py | 55 ++++++------------- .../_core/definitions/base_asset_graph.py | 16 ++++++ .../repository_data_builder.py | 3 +- .../asset_defs_tests/test_assets_job.py | 30 +++++----- .../general_tests/test_pending_repository.py | 12 ++-- 6 files changed, 59 insertions(+), 60 deletions(-) diff --git a/python_modules/dagster/dagster/_core/definitions/asset_graph.py b/python_modules/dagster/dagster/_core/definitions/asset_graph.py index dda6c4fcdbad4..800ade27b64e2 100644 --- a/python_modules/dagster/dagster/_core/definitions/asset_graph.py +++ b/python_modules/dagster/dagster/_core/definitions/asset_graph.py @@ -254,6 +254,9 @@ def get_execution_set_asset_and_check_keys( def assets_defs(self) -> Sequence[AssetsDefinition]: return list(dict.fromkeys(asset.assets_def for asset in self.asset_nodes)) + def assets_defs_for_keys(self, keys: Iterable[AssetKey]) -> Sequence[AssetsDefinition]: + return list(dict.fromkeys([self.get(key).assets_def for key in keys])) + @property def asset_checks_defs(self) -> Sequence[AssetChecksDefinition]: return list(dict.fromkeys(self._asset_checks_defs_by_key.values())) diff --git a/python_modules/dagster/dagster/_core/definitions/assets_job.py b/python_modules/dagster/dagster/_core/definitions/assets_job.py index a6ac727722166..e12c98dc1040d 100644 --- a/python_modules/dagster/dagster/_core/definitions/assets_job.py +++ b/python_modules/dagster/dagster/_core/definitions/assets_job.py @@ -5,7 +5,6 @@ Any, Dict, Iterable, - List, Mapping, Optional, Sequence, @@ -17,6 +16,7 @@ from toposort import CircularDependencyError, toposort import dagster._check as check +from dagster._core.definitions.asset_graph import AssetGraph from dagster._core.definitions.hook_definition import HookDefinition from dagster._core.definitions.policy import RetryPolicy from dagster._core.errors import DagsterInvalidDefinitionError @@ -61,58 +61,37 @@ def is_base_asset_job_name(name: str) -> bool: def get_base_asset_jobs( - assets: Sequence[AssetsDefinition], - asset_checks: Sequence[AssetChecksDefinition], + asset_graph: AssetGraph, resource_defs: Optional[Mapping[str, ResourceDefinition]], executor_def: Optional[ExecutorDefinition], ) -> Sequence[JobDefinition]: - executable_assets = [a for a in assets if a.is_executable] - unexecutable_assets = [a for a in assets if not a.is_executable] - - executable_assets_by_partitions_def: Dict[ - Optional[PartitionsDefinition], List[AssetsDefinition] - ] = defaultdict(list) - for asset in executable_assets: - executable_assets_by_partitions_def[asset.partitions_def].append(asset) - # sort to ensure some stability in the ordering - all_partitions_defs = sorted( - [p for p in executable_assets_by_partitions_def.keys() if p], key=repr - ) - - if len(all_partitions_defs) == 0: + if len(asset_graph.all_partitions_defs) == 0: + executable_asset_keys = asset_graph.executable_asset_keys + loadable_asset_keys = asset_graph.all_asset_keys - executable_asset_keys return [ build_assets_job( name=ASSET_BASE_JOB_PREFIX, - executable_assets=executable_assets, - loadable_assets=unexecutable_assets, - asset_checks=asset_checks, + executable_assets=asset_graph.assets_defs_for_keys(executable_asset_keys), + loadable_assets=asset_graph.assets_defs_for_keys(loadable_asset_keys), + asset_checks=asset_graph.asset_checks_defs, executor_def=executor_def, resource_defs=resource_defs, ) ] else: - unpartitioned_executable_assets = executable_assets_by_partitions_def.get(None, []) jobs = [] - - for i, partitions_def in enumerate(all_partitions_defs): - # all base jobs contain all unpartitioned assets - executable_assets_for_job = [ - *executable_assets_by_partitions_def[partitions_def], - *unpartitioned_executable_assets, - ] + for i, partitions_def in enumerate(asset_graph.all_partitions_defs): + executable_asset_keys = asset_graph.executable_asset_keys & { + *asset_graph.asset_keys_for_partitions_def(partitions_def=partitions_def), + *asset_graph.unpartitioned_asset_keys, + } + loadable_asset_keys = asset_graph.all_asset_keys - executable_asset_keys jobs.append( build_assets_job( f"{ASSET_BASE_JOB_PREFIX}_{i}", - executable_assets=executable_assets_for_job, - loadable_assets=[ - *( - asset - for asset in executable_assets - if asset not in executable_assets_for_job - ), - *unexecutable_assets, - ], - asset_checks=asset_checks, + executable_assets=asset_graph.assets_defs_for_keys(executable_asset_keys), + loadable_assets=asset_graph.assets_defs_for_keys(loadable_asset_keys), + asset_checks=asset_graph.asset_checks_defs, resource_defs=resource_defs, executor_def=executor_def, partitions_def=partitions_def, diff --git a/python_modules/dagster/dagster/_core/definitions/base_asset_graph.py b/python_modules/dagster/dagster/_core/definitions/base_asset_graph.py index 7262a3af9c1a9..883da2bfe012a 100644 --- a/python_modules/dagster/dagster/_core/definitions/base_asset_graph.py +++ b/python_modules/dagster/dagster/_core/definitions/base_asset_graph.py @@ -215,9 +215,19 @@ def toposorted_asset_keys_by_level(self) -> Sequence[AbstractSet[AssetKey]]: """ return [set(level) for level in toposort.toposort(self.asset_dep_graph["upstream"])] + @cached_property + def unpartitioned_asset_keys(self) -> AbstractSet[AssetKey]: + return {node.key for node in self.asset_nodes if not node.is_partitioned} + def asset_keys_for_group(self, group_name: str) -> AbstractSet[AssetKey]: return {node.key for node in self.asset_nodes if node.group_name == group_name} + @cached_method + def asset_keys_for_partitions_def( + self, partitions_def: PartitionsDefinition + ) -> AbstractSet[AssetKey]: + return {node.key for node in self.asset_nodes if node.partitions_def == partitions_def} + @cached_property def root_materializable_asset_keys(self) -> AbstractSet[AssetKey]: """Materializable asset keys that have no materializable parents.""" @@ -236,6 +246,12 @@ def root_executable_asset_keys(self) -> AbstractSet[AssetKey]: def asset_check_keys(self) -> AbstractSet[AssetCheckKey]: return {key for asset in self.asset_nodes for key in asset.check_keys} + @cached_property + def all_partitions_defs(self) -> Sequence[PartitionsDefinition]: + return sorted( + set(node.partitions_def for node in self.asset_nodes if node.partitions_def), key=repr + ) + @cached_property def all_group_names(self) -> AbstractSet[str]: return {a.group_name for a in self.asset_nodes if a.group_name is not None} diff --git a/python_modules/dagster/dagster/_core/definitions/repository_definition/repository_data_builder.py b/python_modules/dagster/dagster/_core/definitions/repository_definition/repository_data_builder.py index 3c1237ae85fe4..a5ea2fb939f86 100644 --- a/python_modules/dagster/dagster/_core/definitions/repository_definition/repository_data_builder.py +++ b/python_modules/dagster/dagster/_core/definitions/repository_definition/repository_data_builder.py @@ -243,10 +243,9 @@ def build_caching_repository_data_from_list( ) if assets_defs or source_assets or asset_checks_defs: for job_def in get_base_asset_jobs( - assets=asset_graph.assets_defs, + asset_graph=asset_graph, executor_def=default_executor_def, resource_defs=top_level_resources, - asset_checks=asset_checks_defs, ): jobs[job_def.name] = job_def diff --git a/python_modules/dagster/dagster_tests/asset_defs_tests/test_assets_job.py b/python_modules/dagster/dagster_tests/asset_defs_tests/test_assets_job.py index a7151544da66c..99a525d02042d 100644 --- a/python_modules/dagster/dagster_tests/asset_defs_tests/test_assets_job.py +++ b/python_modules/dagster/dagster_tests/asset_defs_tests/test_assets_job.py @@ -1951,16 +1951,17 @@ def hourly_asset(): ... def unpartitioned_asset(): ... jobs = get_base_asset_jobs( - assets=[ - daily_asset, - daily_asset2, - daily_asset_different_start_date, - hourly_asset, - unpartitioned_asset, - ], + asset_graph=AssetGraph.from_assets( + [ + daily_asset, + daily_asset2, + daily_asset_different_start_date, + hourly_asset, + unpartitioned_asset, + ] + ), executor_def=None, resource_defs={}, - asset_checks=[], ) assert len(jobs) == 3 assert {job_def.name for job_def in jobs} == { @@ -1995,14 +1996,15 @@ def asset_b(): ... def asset_x(asset_b: B): ... jobs = get_base_asset_jobs( - assets=[ - asset_x, - create_external_asset_from_source_asset(asset_a), - create_external_asset_from_source_asset(asset_b), - ], + asset_graph=AssetGraph.from_assets( + [ + asset_x, + create_external_asset_from_source_asset(asset_a), + create_external_asset_from_source_asset(asset_b), + ] + ), executor_def=None, resource_defs={}, - asset_checks=[], ) assert len(jobs) == 2 assert {job_def.name for job_def in jobs} == { diff --git a/python_modules/dagster/dagster_tests/general_tests/test_pending_repository.py b/python_modules/dagster/dagster_tests/general_tests/test_pending_repository.py index b65305b606842..af60e83fa1867 100644 --- a/python_modules/dagster/dagster_tests/general_tests/test_pending_repository.py +++ b/python_modules/dagster/dagster_tests/general_tests/test_pending_repository.py @@ -133,7 +133,7 @@ def test_resolve_wrong_data(): ) -def define_resource_dependent_cacheable_and_uncacheable_assets(): +def define_uncacheable_and_resource_dependent_cacheable_assets(): class ResourceDependentCacheableAsset(CacheableAssetsDefinition): def __init__(self): super().__init__("res_downstream") @@ -160,11 +160,11 @@ def _op(context, res_upstream): for cd in data ] - @asset(required_resource_keys={"foo"}) + @asset def res_upstream(context): return context.resources.foo - @asset(required_resource_keys={"foo"}) + @asset def res_downstream(context, res_midstream): return res_midstream + context.resources.foo @@ -181,7 +181,7 @@ def test_resolve_no_resources(): @repository def resource_dependent_repo_no_resources(): return [ - define_resource_dependent_cacheable_and_uncacheable_assets(), + define_uncacheable_and_resource_dependent_cacheable_assets(), define_asset_job( "all_asset_job", ), @@ -207,7 +207,7 @@ def foo_resource(): def resource_dependent_repo_with_resources(): return [ with_resources( - define_resource_dependent_cacheable_and_uncacheable_assets(), {"foo": foo_resource} + define_uncacheable_and_resource_dependent_cacheable_assets(), {"foo": foo_resource} ), define_asset_job( "all_asset_job", @@ -289,7 +289,7 @@ def foo_resource(): for x in with_resources( [ x.with_attributes(group_names_by_key={AssetKey("res_midstream"): "my_cool_group"}) - for x in define_resource_dependent_cacheable_and_uncacheable_assets() + for x in define_uncacheable_and_resource_dependent_cacheable_assets() ], {"foo": foo_resource}, )