Skip to content

Commit

Permalink
[external-assets] Build base asset jobs using AssetGraph
Browse files Browse the repository at this point in the history
[INTERNAL_BRANCH=sean/external-assets-asset-graph-nodes]
  • Loading branch information
smackesey committed Mar 4, 2024
1 parent d61c5bf commit 7ff96a5
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 57 deletions.
18 changes: 18 additions & 0 deletions python_modules/dagster/dagster/_core/definitions/asset_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,9 +287,20 @@ def toposorted_asset_keys_by_level(self) -> Sequence[AbstractSet[AssetKey]]:
{key for key in level} for level in toposort.toposort(self.asset_dep_graph["upstream"])
]

@property
@cached_method
def unpartitioned_asset_keys(self) -> AbstractSet[AssetKey]:
return {node.key for node in self.asset_nodes if not node.is_partitioned}

def asset_keys_for_group(self, group_name: str) -> AbstractSet[AssetKey]:
return {node.key for node in self.asset_nodes if node.group_name == group_name}

@cached_method
def asset_keys_for_partitions_def(
self, *, partitions_def: PartitionsDefinition
) -> AbstractSet[AssetKey]:
return {node.key for node in self.asset_nodes if node.partitions_def == partitions_def}

@property
@cached_method
def root_materializable_asset_keys(self) -> AbstractSet[AssetKey]:
Expand All @@ -311,6 +322,13 @@ def root_executable_asset_keys(self) -> AbstractSet[AssetKey]:
def asset_check_keys(self) -> AbstractSet[AssetCheckKey]:
return {key for asset in self.asset_nodes for key in asset.check_keys}

@property
@cached_method
def all_partitions_defs(self) -> Sequence[PartitionsDefinition]:
return sorted(
set(node.partitions_def for node in self.asset_nodes if node.partitions_def), key=repr
)

@property
@cached_method
def all_group_names(self) -> AbstractSet[str]:
Expand Down
55 changes: 17 additions & 38 deletions python_modules/dagster/dagster/_core/definitions/assets_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
Any,
Dict,
Iterable,
List,
Mapping,
Optional,
Sequence,
Expand All @@ -18,6 +17,7 @@

import dagster._check as check
from dagster._core.definitions.hook_definition import HookDefinition
from dagster._core.definitions.internal_asset_graph import InternalAssetGraph
from dagster._core.definitions.policy import RetryPolicy
from dagster._core.errors import DagsterInvalidDefinitionError
from dagster._core.selector.subset_selector import AssetSelectionData
Expand Down Expand Up @@ -61,58 +61,37 @@ def is_base_asset_job_name(name: str) -> bool:


def get_base_asset_jobs(
assets: Sequence[AssetsDefinition],
asset_checks: Sequence[AssetChecksDefinition],
asset_graph: InternalAssetGraph,
resource_defs: Optional[Mapping[str, ResourceDefinition]],
executor_def: Optional[ExecutorDefinition],
) -> Sequence[JobDefinition]:
executable_assets = [a for a in assets if a.is_executable]
unexecutable_assets = [a for a in assets if not a.is_executable]

executable_assets_by_partitions_def: Dict[
Optional[PartitionsDefinition], List[AssetsDefinition]
] = defaultdict(list)
for asset in executable_assets:
executable_assets_by_partitions_def[asset.partitions_def].append(asset)
# sort to ensure some stability in the ordering
all_partitions_defs = sorted(
[p for p in executable_assets_by_partitions_def.keys() if p], key=repr
)

if len(all_partitions_defs) == 0:
if len(asset_graph.all_partitions_defs) == 0:
executable_asset_keys = asset_graph.executable_asset_keys
loadable_asset_keys = asset_graph.all_asset_keys - executable_asset_keys
return [
build_assets_job(
name=ASSET_BASE_JOB_PREFIX,
executable_assets=executable_assets,
loadable_assets=unexecutable_assets,
asset_checks=asset_checks,
executable_assets=asset_graph.assets_defs_for_keys(executable_asset_keys),
loadable_assets=asset_graph.assets_defs_for_keys(loadable_asset_keys),
asset_checks=asset_graph.asset_checks_defs,
executor_def=executor_def,
resource_defs=resource_defs,
)
]
else:
unpartitioned_executable_assets = executable_assets_by_partitions_def.get(None, [])
jobs = []

for i, partitions_def in enumerate(all_partitions_defs):
# all base jobs contain all unpartitioned assets
executable_assets_for_job = [
*executable_assets_by_partitions_def[partitions_def],
*unpartitioned_executable_assets,
]
for i, partitions_def in enumerate(asset_graph.all_partitions_defs):
executable_asset_keys = asset_graph.executable_asset_keys & {
*asset_graph.asset_keys_for_partitions_def(partitions_def=partitions_def),
*asset_graph.unpartitioned_asset_keys,
}
loadable_asset_keys = asset_graph.all_asset_keys - executable_asset_keys
jobs.append(
build_assets_job(
f"{ASSET_BASE_JOB_PREFIX}_{i}",
executable_assets=executable_assets_for_job,
loadable_assets=[
*(
asset
for asset in executable_assets
if asset not in executable_assets_for_job
),
*unexecutable_assets,
],
asset_checks=asset_checks,
executable_assets=asset_graph.assets_defs_for_keys(executable_asset_keys),
loadable_assets=asset_graph.assets_defs_for_keys(loadable_asset_keys),
asset_checks=asset_graph.asset_checks_defs,
resource_defs=resource_defs,
executor_def=executor_def,
partitions_def=partitions_def,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,9 @@ def get_execution_unit_asset_and_check_keys(
def assets_defs(self) -> Sequence[AssetsDefinition]:
return list(dict.fromkeys(asset.assets_def for asset in self.asset_nodes))

def assets_defs_for_keys(self, keys: Iterable[AssetKey]) -> Sequence[AssetsDefinition]:
return list(dict.fromkeys([self.get(key).assets_def for key in keys]))

@property
def asset_checks_defs(self) -> Sequence[AssetChecksDefinition]:
return list(dict.fromkeys(self._asset_checks_defs_by_key.values()))
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from .valid_definitions import VALID_REPOSITORY_DATA_DICT_KEYS, RepositoryListDefinition

if TYPE_CHECKING:
from dagster._core.definitions.asset_check_spec import AssetCheckKey
from dagster._core.definitions.events import AssetKey


Expand Down Expand Up @@ -162,6 +163,7 @@ def build_caching_repository_data_from_list(
sensors: Dict[str, SensorDefinition] = {}
assets_defs: List[AssetsDefinition] = []
asset_keys: Set[AssetKey] = set()
asset_check_keys: Set[AssetCheckKey] = set()
source_assets: List[SourceAsset] = []
asset_checks_defs: List[AssetChecksDefinition] = []
for definition in repository_definitions:
Expand Down Expand Up @@ -228,6 +230,10 @@ def build_caching_repository_data_from_list(
source_assets.append(definition)
asset_keys.add(definition.key)
elif isinstance(definition, AssetChecksDefinition):
for key in definition.keys:
if key in asset_check_keys:
raise DagsterInvalidDefinitionError(f"Duplicate asset check key: {key}")
asset_check_keys.update(definition.keys)
asset_checks_defs.append(definition)
else:
check.failed(f"Unexpected repository entry {definition}")
Expand All @@ -237,10 +243,9 @@ def build_caching_repository_data_from_list(
)
if assets_defs or source_assets or asset_checks_defs:
for job_def in get_base_asset_jobs(
assets=asset_graph.assets_defs,
asset_graph=asset_graph,
executor_def=default_executor_def,
resource_defs=top_level_resources,
asset_checks=asset_checks_defs,
):
jobs[job_def.name] = job_def

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing_extensions import TypeAlias

from dagster._core.definitions.asset_checks import AssetChecksDefinition
from dagster._core.definitions.graph_definition import GraphDefinition
from dagster._core.definitions.job_definition import JobDefinition
from dagster._core.definitions.schedule_definition import ScheduleDefinition
Expand Down Expand Up @@ -33,6 +34,7 @@

RepositoryListDefinition: TypeAlias = Union[
"AssetsDefinition",
AssetChecksDefinition,
GraphDefinition,
JobDefinition,
ScheduleDefinition,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1959,16 +1959,17 @@ def unpartitioned_asset():
...

jobs = get_base_asset_jobs(
assets=[
daily_asset,
daily_asset2,
daily_asset_different_start_date,
hourly_asset,
unpartitioned_asset,
],
asset_graph=InternalAssetGraph.from_assets(
[
daily_asset,
daily_asset2,
daily_asset_different_start_date,
hourly_asset,
unpartitioned_asset,
]
),
executor_def=None,
resource_defs={},
asset_checks=[],
)
assert len(jobs) == 3
assert {job_def.name for job_def in jobs} == {
Expand Down Expand Up @@ -2007,14 +2008,15 @@ def asset_x(asset_b: B):
...

jobs = get_base_asset_jobs(
assets=[
asset_x,
create_external_asset_from_source_asset(asset_a),
create_external_asset_from_source_asset(asset_b),
],
asset_graph=InternalAssetGraph.from_assets(
[
asset_x,
create_external_asset_from_source_asset(asset_a),
create_external_asset_from_source_asset(asset_b),
]
),
executor_def=None,
resource_defs={},
asset_checks=[],
)
assert len(jobs) == 2
assert {job_def.name for job_def in jobs} == {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ def check1(context: AssetExecutionContext):

with pytest.raises(
DagsterInvalidDefinitionError,
match='Detected conflicting node definitions with the same name "asset1_check1"',
match="Duplicate asset check key.+asset1.+check1",
):
Definitions(asset_checks=[make_check(), make_check()])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,11 @@ def _op(context, res_upstream):
for cd in data
]

@asset(required_resource_keys={"foo"})
@asset
def res_upstream(context):
return context.resources.foo

@asset(required_resource_keys={"foo"})
@asset
def res_downstream(context, res_midstream):
return res_midstream + context.resources.foo

Expand Down

0 comments on commit 7ff96a5

Please sign in to comment.