Skip to content

Commit

Permalink
[external-assets] Build base asset jobs using AssetGraph
Browse files: browse the repository at this point in the history
  • Loading branch information
smackesey committed Mar 11, 2024
1 parent 7408446 commit 8ce31aa
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,9 @@ def get_execution_set_asset_and_check_keys(
def assets_defs(self) -> Sequence[AssetsDefinition]:
return list(dict.fromkeys(asset.assets_def for asset in self.asset_nodes))

def assets_defs_for_keys(self, keys: Iterable[AssetKey]) -> Sequence[AssetsDefinition]:
    """Return the distinct ``assets_def`` objects behind ``keys``.

    Every key is resolved through ``self.get``. Duplicate definitions are
    dropped, and the result keeps the order in which each definition was
    first encountered while walking ``keys``.
    """
    # Resolve everything up front, then de-duplicate.
    resolved = [self.get(asset_key).assets_def for asset_key in keys]
    unique_defs = {}
    for assets_def in resolved:
        # A dict's keys act as an insertion-ordered set.
        unique_defs.setdefault(assets_def, None)
    return list(unique_defs)

@property
def asset_checks_defs(self) -> Sequence[AssetChecksDefinition]:
    """Distinct values of ``self._asset_checks_defs_by_key``, in first-seen order."""
    checks_by_key = self._asset_checks_defs_by_key
    # Dict keys give order-preserving de-duplication of the mapped values.
    unique_defs = {checks_def: None for checks_def in checks_by_key.values()}
    return [*unique_defs]
55 changes: 17 additions & 38 deletions python_modules/dagster/dagster/_core/definitions/assets_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
Any,
Dict,
Iterable,
List,
Mapping,
Optional,
Sequence,
Expand All @@ -17,6 +16,7 @@
from toposort import CircularDependencyError, toposort

import dagster._check as check
from dagster._core.definitions.asset_graph import AssetGraph
from dagster._core.definitions.hook_definition import HookDefinition
from dagster._core.definitions.policy import RetryPolicy
from dagster._core.errors import DagsterInvalidDefinitionError
Expand Down Expand Up @@ -61,58 +61,37 @@ def is_base_asset_job_name(name: str) -> bool:


def get_base_asset_jobs(
assets: Sequence[AssetsDefinition],
asset_checks: Sequence[AssetChecksDefinition],
asset_graph: AssetGraph,
resource_defs: Optional[Mapping[str, ResourceDefinition]],
executor_def: Optional[ExecutorDefinition],
) -> Sequence[JobDefinition]:
executable_assets = [a for a in assets if a.is_executable]
unexecutable_assets = [a for a in assets if not a.is_executable]

executable_assets_by_partitions_def: Dict[
Optional[PartitionsDefinition], List[AssetsDefinition]
] = defaultdict(list)
for asset in executable_assets:
executable_assets_by_partitions_def[asset.partitions_def].append(asset)
# sort to ensure some stability in the ordering
all_partitions_defs = sorted(
[p for p in executable_assets_by_partitions_def.keys() if p], key=repr
)

if len(all_partitions_defs) == 0:
if len(asset_graph.all_partitions_defs) == 0:
executable_asset_keys = asset_graph.executable_asset_keys
loadable_asset_keys = asset_graph.all_asset_keys - executable_asset_keys
return [
build_assets_job(
name=ASSET_BASE_JOB_PREFIX,
executable_assets=executable_assets,
loadable_assets=unexecutable_assets,
asset_checks=asset_checks,
executable_assets=asset_graph.assets_defs_for_keys(executable_asset_keys),
loadable_assets=asset_graph.assets_defs_for_keys(loadable_asset_keys),
asset_checks=asset_graph.asset_checks_defs,
executor_def=executor_def,
resource_defs=resource_defs,
)
]
else:
unpartitioned_executable_assets = executable_assets_by_partitions_def.get(None, [])
jobs = []

for i, partitions_def in enumerate(all_partitions_defs):
# all base jobs contain all unpartitioned assets
executable_assets_for_job = [
*executable_assets_by_partitions_def[partitions_def],
*unpartitioned_executable_assets,
]
for i, partitions_def in enumerate(asset_graph.all_partitions_defs):
executable_asset_keys = asset_graph.executable_asset_keys & {
*asset_graph.asset_keys_for_partitions_def(partitions_def=partitions_def),
*asset_graph.unpartitioned_asset_keys,
}
loadable_asset_keys = asset_graph.all_asset_keys - executable_asset_keys
jobs.append(
build_assets_job(
f"{ASSET_BASE_JOB_PREFIX}_{i}",
executable_assets=executable_assets_for_job,
loadable_assets=[
*(
asset
for asset in executable_assets
if asset not in executable_assets_for_job
),
*unexecutable_assets,
],
asset_checks=asset_checks,
executable_assets=asset_graph.assets_defs_for_keys(executable_asset_keys),
loadable_assets=asset_graph.assets_defs_for_keys(loadable_asset_keys),
asset_checks=asset_graph.asset_checks_defs,
resource_defs=resource_defs,
executor_def=executor_def,
partitions_def=partitions_def,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,19 @@ def toposorted_asset_keys_by_level(self) -> Sequence[AbstractSet[AssetKey]]:
"""
return [set(level) for level in toposort.toposort(self.asset_dep_graph["upstream"])]

@cached_property
def unpartitioned_asset_keys(self) -> AbstractSet[AssetKey]:
    """Keys of every node in ``self.asset_nodes`` whose ``is_partitioned`` is falsy."""
    keys = set()
    for asset_node in self.asset_nodes:
        if asset_node.is_partitioned:
            continue
        keys.add(asset_node.key)
    return keys

def asset_keys_for_group(self, group_name: str) -> AbstractSet[AssetKey]:
    """Keys of the nodes in ``self.asset_nodes`` whose ``group_name`` equals ``group_name``."""
    keys = set()
    for asset_node in self.asset_nodes:
        if asset_node.group_name == group_name:
            keys.add(asset_node.key)
    return keys

@cached_method
def asset_keys_for_partitions_def(
    self, partitions_def: PartitionsDefinition
) -> AbstractSet[AssetKey]:
    """Keys of the nodes in ``self.asset_nodes`` whose ``partitions_def`` equals the argument.

    NOTE(review): wrapped in ``cached_method``, which is defined outside this
    view; presumably results are memoized per argument -- confirm.
    """
    selected = set()
    for asset_node in self.asset_nodes:
        if asset_node.partitions_def == partitions_def:
            selected.add(asset_node.key)
    return selected

@cached_property
def root_materializable_asset_keys(self) -> AbstractSet[AssetKey]:
"""Materializable asset keys that have no materializable parents."""
Expand All @@ -236,6 +246,12 @@ def root_executable_asset_keys(self) -> AbstractSet[AssetKey]:
def asset_check_keys(self) -> AbstractSet[AssetCheckKey]:
return {key for asset in self.asset_nodes for key in asset.check_keys}

@cached_property
def all_partitions_defs(self) -> Sequence[PartitionsDefinition]:
    """Distinct truthy ``partitions_def`` values across ``self.asset_nodes``.

    The result is ordered by ``repr`` so the sequence is stable from call to
    call for the same set of definitions.
    """
    distinct = set()
    for asset_node in self.asset_nodes:
        # Nodes without a partitions definition are left out.
        if asset_node.partitions_def:
            distinct.add(asset_node.partitions_def)
    ordered = list(distinct)
    ordered.sort(key=repr)
    return ordered

@cached_property
def all_group_names(self) -> AbstractSet[str]:
    """Every non-``None`` ``group_name`` found on the nodes in ``self.asset_nodes``."""
    names = set()
    for asset_node in self.asset_nodes:
        if asset_node.group_name is None:
            continue
        names.add(asset_node.group_name)
    return names
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,10 +243,9 @@ def build_caching_repository_data_from_list(
)
if assets_defs or source_assets or asset_checks_defs:
for job_def in get_base_asset_jobs(
assets=asset_graph.assets_defs,
asset_graph=asset_graph,
executor_def=default_executor_def,
resource_defs=top_level_resources,
asset_checks=asset_checks_defs,
):
jobs[job_def.name] = job_def

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1951,16 +1951,17 @@ def hourly_asset(): ...
def unpartitioned_asset(): ...

jobs = get_base_asset_jobs(
assets=[
daily_asset,
daily_asset2,
daily_asset_different_start_date,
hourly_asset,
unpartitioned_asset,
],
asset_graph=AssetGraph.from_assets(
[
daily_asset,
daily_asset2,
daily_asset_different_start_date,
hourly_asset,
unpartitioned_asset,
]
),
executor_def=None,
resource_defs={},
asset_checks=[],
)
assert len(jobs) == 3
assert {job_def.name for job_def in jobs} == {
Expand Down Expand Up @@ -1995,14 +1996,15 @@ def asset_b(): ...
def asset_x(asset_b: B): ...

jobs = get_base_asset_jobs(
assets=[
asset_x,
create_external_asset_from_source_asset(asset_a),
create_external_asset_from_source_asset(asset_b),
],
asset_graph=AssetGraph.from_assets(
[
asset_x,
create_external_asset_from_source_asset(asset_a),
create_external_asset_from_source_asset(asset_b),
]
),
executor_def=None,
resource_defs={},
asset_checks=[],
)
assert len(jobs) == 2
assert {job_def.name for job_def in jobs} == {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def test_resolve_wrong_data():
)


def define_resource_dependent_cacheable_and_uncacheable_assets():
def define_uncacheable_and_resource_dependent_cacheable_assets():
class ResourceDependentCacheableAsset(CacheableAssetsDefinition):
def __init__(self):
super().__init__("res_downstream")
Expand All @@ -160,11 +160,11 @@ def _op(context, res_upstream):
for cd in data
]

@asset(required_resource_keys={"foo"})
@asset
def res_upstream(context):
return context.resources.foo

@asset(required_resource_keys={"foo"})
@asset
def res_downstream(context, res_midstream):
return res_midstream + context.resources.foo

Expand All @@ -181,7 +181,7 @@ def test_resolve_no_resources():
@repository
def resource_dependent_repo_no_resources():
return [
define_resource_dependent_cacheable_and_uncacheable_assets(),
define_uncacheable_and_resource_dependent_cacheable_assets(),
define_asset_job(
"all_asset_job",
),
Expand All @@ -207,7 +207,7 @@ def foo_resource():
def resource_dependent_repo_with_resources():
return [
with_resources(
define_resource_dependent_cacheable_and_uncacheable_assets(), {"foo": foo_resource}
define_uncacheable_and_resource_dependent_cacheable_assets(), {"foo": foo_resource}
),
define_asset_job(
"all_asset_job",
Expand Down Expand Up @@ -289,7 +289,7 @@ def foo_resource():
for x in with_resources(
[
x.with_attributes(group_names_by_key={AssetKey("res_midstream"): "my_cool_group"})
for x in define_resource_dependent_cacheable_and_uncacheable_assets()
for x in define_uncacheable_and_resource_dependent_cacheable_assets()
],
{"foo": foo_resource},
)
Expand Down

0 comments on commit 8ce31aa

Please sign in to comment.