Skip to content

Commit

Permalink
[external-assets] refactor AssetGraph
Browse files Browse the repository at this point in the history
[INTERNAL_BRANCH=sean/external-assets-refactor-asset-graph]
  • Loading branch information
smackesey committed Feb 23, 2024
1 parent ecba104 commit 9532c18
Show file tree
Hide file tree
Showing 19 changed files with 533 additions and 469 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ def spec(self) -> AssetCheckSpec:
def specs(self) -> Iterable[AssetCheckSpec]:
return self._specs_by_output_name.values()

@property
def keys(self) -> Iterable[AssetCheckKey]:
return self._specs_by_handle.keys()

@property
def specs_by_output_name(self) -> Mapping[str, AssetCheckSpec]:
return self._specs_by_output_name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def evaluate_asset(
"""
# convert the legacy AutoMaterializePolicy to an Evaluator
asset_condition = check.not_none(
self.asset_graph.auto_materialize_policies_by_key.get(asset_key)
self.asset_graph.get_auto_materialize_policy(asset_key)
).to_asset_condition()

asset_cursor = self.cursor.get_previous_evaluation_state(asset_key)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def backcompat_deserialize_asset_daemon_cursor_str(

previous_evaluation_state = []
cursor_keys = (
asset_graph.auto_materialize_policies_by_key.keys()
asset_graph.materializable_asset_keys
if asset_graph
else latest_evaluation_by_asset_key.keys()
)
Expand Down
306 changes: 122 additions & 184 deletions python_modules/dagster/dagster/_core/definitions/asset_graph.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -686,9 +686,10 @@ def resolve_inner(self, asset_graph: AssetGraph) -> AbstractSet[AssetKey]:
else asset_graph.materializable_asset_keys
)
return {
asset_key
for asset_key, group in asset_graph.group_names_by_key.items()
if group in self.selected_groups and asset_key in base_set
key
for group in self.selected_groups
for key in asset_graph.asset_keys_for_group(group)
if key in base_set
}

def to_serializable_asset_selection(self, asset_graph: AssetGraph) -> "AssetSelection":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ def get_minutes_overdue(
asset_key: AssetKey,
evaluation_time: datetime.datetime,
) -> Optional[FreshnessMinutes]:
freshness_policy = self.asset_graph.freshness_policies_by_key.get(asset_key)
freshness_policy = self.asset_graph.get_freshness_policy(asset_key)
if freshness_policy is None:
raise DagsterInvariantViolationError(
"Cannot calculate minutes late for asset without a FreshnessPolicy"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ def get_implicit_job_def_for_assets(

def get_assets_def(self, key: CoercibleToAssetKey) -> AssetsDefinition:
asset_key = AssetKey.from_coercible(key)
for assets_def in self.get_asset_graph().assets:
for assets_def in self.get_asset_graph().assets_defs:
if asset_key in assets_def.keys:
return assets_def

Expand Down
336 changes: 206 additions & 130 deletions python_modules/dagster/dagster/_core/definitions/external_asset_graph.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def freshness_evaluation_results_for_asset_key(
execution_period,
evaluation_data,
) = get_execution_period_and_evaluation_data_for_policies(
local_policy=context.asset_graph.freshness_policies_by_key.get(asset_key),
local_policy=context.asset_graph.get_freshness_policy(asset_key),
policies=context.asset_graph.get_downstream_freshness_policies(asset_key=asset_key),
effective_data_time=effective_data_time,
current_time=current_time,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def _wrapped_fn(context: SensorEvaluationContext):

minutes_late_by_key: Dict[AssetKey, Optional[float]] = {}
for asset_key in monitored_keys:
freshness_policy = asset_graph.freshness_policies_by_key.get(asset_key)
freshness_policy = asset_graph.get_freshness_policy(asset_key)
if freshness_policy is None:
continue

Expand Down
261 changes: 154 additions & 107 deletions python_modules/dagster/dagster/_core/definitions/internal_asset_graph.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from typing import AbstractSet, Dict, Iterable, List, Mapping, Optional, Sequence, Union
from typing import AbstractSet, Iterable, List, Mapping, Optional, Sequence, Union

from dagster._core.definitions.asset_check_spec import AssetCheckKey
from dagster._core.definitions.asset_checks import AssetChecksDefinition
from dagster._core.definitions.asset_graph import AssetGraph, AssetKeyOrCheckKey
from dagster._core.definitions.asset_spec import AssetExecutionType
from dagster._core.definitions.assets import AssetsDefinition
from dagster._core.definitions.auto_materialize_policy import AutoMaterializePolicy
from dagster._core.definitions.backfill_policy import BackfillPolicy
Expand All @@ -13,53 +12,41 @@
from dagster._core.definitions.partition_mapping import PartitionMapping
from dagster._core.definitions.source_asset import SourceAsset
from dagster._core.selector.subset_selector import DependencyGraph, generate_asset_dep_graph
from dagster._utils.cached_method import cached_method


class InternalAssetGraph(AssetGraph):
def __init__(
self,
asset_dep_graph: DependencyGraph[AssetKey],
source_asset_keys: AbstractSet[AssetKey],
partitions_defs_by_key: Mapping[AssetKey, Optional[PartitionsDefinition]],
partition_mappings_by_key: Mapping[AssetKey, Optional[Mapping[AssetKey, PartitionMapping]]],
group_names_by_key: Mapping[AssetKey, Optional[str]],
freshness_policies_by_key: Mapping[AssetKey, Optional[FreshnessPolicy]],
auto_materialize_policies_by_key: Mapping[AssetKey, Optional[AutoMaterializePolicy]],
backfill_policies_by_key: Mapping[AssetKey, Optional[BackfillPolicy]],
assets: Sequence[AssetsDefinition],
assets_defs: Sequence[AssetsDefinition],
source_assets: Sequence[SourceAsset],
asset_checks: Sequence[AssetChecksDefinition],
code_versions_by_key: Mapping[AssetKey, Optional[str]],
is_observable_by_key: Mapping[AssetKey, bool],
auto_observe_interval_minutes_by_key: Mapping[AssetKey, Optional[float]],
required_assets_and_checks_by_key: Mapping[
AssetKeyOrCheckKey, AbstractSet[AssetKeyOrCheckKey]
],
asset_checks_defs: Sequence[AssetChecksDefinition],
):
super().__init__(
asset_dep_graph=asset_dep_graph,
source_asset_keys=source_asset_keys,
partitions_defs_by_key=partitions_defs_by_key,
partition_mappings_by_key=partition_mappings_by_key,
group_names_by_key=group_names_by_key,
freshness_policies_by_key=freshness_policies_by_key,
auto_materialize_policies_by_key=auto_materialize_policies_by_key,
backfill_policies_by_key=backfill_policies_by_key,
code_versions_by_key=code_versions_by_key,
is_observable_by_key=is_observable_by_key,
auto_observe_interval_minutes_by_key=auto_observe_interval_minutes_by_key,
required_assets_and_checks_by_key=required_assets_and_checks_by_key,
)
self._assets = assets
self._source_assets = source_assets
self._asset_checks = asset_checks
from dagster._core.definitions.external_asset import create_external_asset_from_source_asset

asset_check_keys = set()
for asset_check in asset_checks:
asset_check_keys.update([spec.key for spec in asset_check.specs])
for asset in assets:
asset_check_keys.update([spec.key for spec in asset.check_specs])
self._asset_check_keys = asset_check_keys
# Temporarily preserved until all source asset access is removed
self._source_assets = source_assets
self._orig_assets_defs = assets_defs

self._assets_defs = [
*assets_defs,
*(create_external_asset_from_source_asset(sa) for sa in source_assets),
]
self._assets_defs_by_key = {key: asset for asset in self._assets_defs for key in asset.keys}
self._assets_defs_by_check_key = {
**{check_key: asset for asset in assets_defs for check_key in asset.check_keys},
**{
check_key: self._assets_defs_by_key[check.asset_key]
for check in asset_checks_defs
for check_key in check.keys
if check.asset_key in self._assets_defs_by_key
},
}

self._asset_checks_defs = asset_checks_defs
self._asset_checks_defs_by_key = {
key: check for check in asset_checks_defs for key in check.keys
}

@staticmethod
def from_assets(
Expand All @@ -70,91 +57,54 @@ def from_assets(

assets_defs: List[AssetsDefinition] = []
source_assets: List[SourceAsset] = []
partitions_defs_by_key: Dict[AssetKey, Optional[PartitionsDefinition]] = {}
partition_mappings_by_key: Dict[
AssetKey, Optional[Mapping[AssetKey, PartitionMapping]]
] = {}
group_names_by_key: Dict[AssetKey, Optional[str]] = {}
freshness_policies_by_key: Dict[AssetKey, Optional[FreshnessPolicy]] = {}
auto_materialize_policies_by_key: Dict[AssetKey, Optional[AutoMaterializePolicy]] = {}
backfill_policies_by_key: Dict[AssetKey, Optional[BackfillPolicy]] = {}
code_versions_by_key: Dict[AssetKey, Optional[str]] = {}
is_observable_by_key: Dict[AssetKey, bool] = {}
auto_observe_interval_minutes_by_key: Dict[AssetKey, Optional[float]] = {}
required_assets_and_checks_by_key: Dict[
AssetKeyOrCheckKey, AbstractSet[AssetKeyOrCheckKey]
] = {}

for asset in all_assets:
if isinstance(asset, SourceAsset):
source_assets.append(asset)
partitions_defs_by_key[asset.key] = asset.partitions_def
group_names_by_key[asset.key] = asset.group_name
is_observable_by_key[asset.key] = asset.is_observable
auto_observe_interval_minutes_by_key[
asset.key
] = asset.auto_observe_interval_minutes
else: # AssetsDefinition
assets_defs.append(asset)
partition_mappings_by_key.update(
{key: asset.partition_mappings for key in asset.keys}
)
partitions_defs_by_key.update({key: asset.partitions_def for key in asset.keys})
group_names_by_key.update(asset.group_names_by_key)
freshness_policies_by_key.update(asset.freshness_policies_by_key)
auto_materialize_policies_by_key.update(asset.auto_materialize_policies_by_key)
backfill_policies_by_key.update({key: asset.backfill_policy for key in asset.keys})
code_versions_by_key.update(asset.code_versions_by_key)

is_observable = asset.execution_type == AssetExecutionType.OBSERVATION
is_observable_by_key.update({key: is_observable for key in asset.keys})

# Set auto_observe_interval_minutes for external observable assets
# This can be removed when/if we have a a solution for mapping
# `auto_observe_interval_minutes` to an AutoMaterialzePolicy
auto_observe_interval_minutes_by_key.update(
{key: asset.auto_observe_interval_minutes for key in asset.keys}
)

if not asset.can_subset:
all_required_keys = {*asset.check_keys, *asset.keys}
if len(all_required_keys) > 1:
for key in all_required_keys:
required_assets_and_checks_by_key[key] = all_required_keys

return InternalAssetGraph(
asset_dep_graph=generate_asset_dep_graph(assets_defs, source_assets),
source_asset_keys={source_asset.key for source_asset in source_assets},
partitions_defs_by_key=partitions_defs_by_key,
partition_mappings_by_key=partition_mappings_by_key,
group_names_by_key=group_names_by_key,
freshness_policies_by_key=freshness_policies_by_key,
auto_materialize_policies_by_key=auto_materialize_policies_by_key,
backfill_policies_by_key=backfill_policies_by_key,
assets=assets_defs,
asset_checks=asset_checks or [],
assets_defs=assets_defs,
asset_checks_defs=asset_checks or [],
source_assets=source_assets,
code_versions_by_key=code_versions_by_key,
is_observable_by_key=is_observable_by_key,
auto_observe_interval_minutes_by_key=auto_observe_interval_minutes_by_key,
required_assets_and_checks_by_key=required_assets_and_checks_by_key,
)

@property
def asset_check_keys(self) -> AbstractSet[AssetCheckKey]:
return self._asset_check_keys
return {
*(key for check in self._asset_checks_defs for key in check.keys),
*(key for asset in self._assets_defs for key in asset.check_keys),
}

@property
def assets(self) -> Sequence[AssetsDefinition]:
return self._assets
def assets_defs(self) -> Sequence[AssetsDefinition]:
# Temporarily return the original set of assets defs passed in until source assets are
# elimintaed
return self._orig_assets_defs

@property
def source_assets(self) -> Sequence[SourceAsset]:
return self._source_assets

def get_assets_def(self, asset_key: AssetKey) -> AssetsDefinition:
return self._assets_defs_by_key[asset_key]

def has_asset(self, asset_key: AssetKey) -> bool:
return asset_key in self._assets_defs_by_key

def get_assets_def_for_check(
self, asset_check_key: AssetCheckKey
) -> Optional[AssetsDefinition]:
return self._assets_defs_by_check_key.get(asset_check_key)

@property
def asset_checks(self) -> Sequence[AssetChecksDefinition]:
return self._asset_checks
def asset_checks_defs(self) -> Sequence[AssetChecksDefinition]:
return self._asset_checks_defs

def get_asset_checks_def(self, asset_check_key: AssetCheckKey) -> AssetChecksDefinition:
return self._asset_checks_defs_by_key[asset_check_key]

def has_asset_check(self, asset_check_key: AssetCheckKey) -> bool:
return asset_check_key in self._asset_checks_defs_by_key

def includes_materializable_and_source_assets(self, asset_keys: AbstractSet[AssetKey]) -> bool:
"""Returns true if the given asset keys contains at least one materializable asset and
Expand All @@ -163,3 +113,100 @@ def includes_materializable_and_source_assets(self, asset_keys: AbstractSet[Asse
selected_source_assets = self.source_asset_keys & asset_keys
selected_regular_assets = asset_keys - self.source_asset_keys
return len(selected_source_assets) > 0 and len(selected_regular_assets) > 0

@property
@cached_method
def asset_dep_graph(self) -> DependencyGraph[AssetKey]:
return generate_asset_dep_graph(self._assets_defs, self._source_assets)

@property
def all_asset_keys(self) -> AbstractSet[AssetKey]:
return {key for ad in self._assets_defs for key in ad.keys}

@property
def source_asset_keys(self) -> AbstractSet[AssetKey]:
return {sa.key for sa in self.source_assets}

@property
def materializable_asset_keys(self) -> AbstractSet[AssetKey]:
return {key for ad in self._assets_defs if ad.is_materializable for key in ad.keys}

def is_materializable(self, asset_key: AssetKey) -> bool:
return self.get_assets_def(asset_key).is_materializable

@property
def observable_asset_keys(self) -> AbstractSet[AssetKey]:
return {key for ad in self._assets_defs if ad.is_observable for key in ad.keys}

def is_observable(self, asset_key: AssetKey) -> bool:
# Performing an existence check temporarily until we change callsites
return self.has_asset(asset_key) and self.get_assets_def(asset_key).is_observable

@property
def external_asset_keys(self) -> AbstractSet[AssetKey]:
return {key for ad in self._assets_defs if ad.is_external for key in ad.keys}

def is_external(self, asset_key: AssetKey) -> bool:
# Preserving non-standard behavior of returning True for non-existent keys until callsites
# can be updated
return asset_key not in self.materializable_asset_keys

@property
def executable_asset_keys(self) -> AbstractSet[AssetKey]:
return {key for ad in self._assets_defs if ad.is_executable for key in ad.keys}

def is_executable(self, asset_key: AssetKey) -> bool:
return self.get_assets_def(asset_key).is_executable

def asset_keys_for_group(self, group_name: str) -> AbstractSet[AssetKey]:
return {
key
for ad in self._assets_defs
for key in ad.keys
if ad.group_names_by_key[key] == group_name
}

def get_required_multi_asset_keys(self, asset_key: AssetKey) -> AbstractSet[AssetKey]:
asset = self.get_assets_def(asset_key)
return set() if len(asset.keys) <= 1 or asset.can_subset else asset.keys

def get_required_asset_and_check_keys(
self, asset_or_check_key: AssetKeyOrCheckKey
) -> AbstractSet[AssetKeyOrCheckKey]:
if isinstance(asset_or_check_key, AssetKey):
asset = self.get_assets_def(asset_or_check_key)
else: # AssetCheckKey
# only checks emitted by AssetsDefinition have required keys
if self.has_asset_check(asset_or_check_key):
return set()
else:
asset = self.get_assets_def_for_check(asset_or_check_key)
if asset is None or asset_or_check_key not in asset.check_keys:
return set()
has_checks = len(asset.check_keys) > 0
if asset.can_subset or len(asset.keys) <= 1 and not has_checks:
return set()
else:
return {*asset.keys, *asset.check_keys}

def get_partitions_def(self, asset_key: AssetKey) -> Optional[PartitionsDefinition]:
# Performing an existence check temporarily until we change callsites
return self.get_assets_def(asset_key).partitions_def if self.has_asset(asset_key) else None

def get_partition_mappings(self, asset_key: AssetKey) -> Mapping[AssetKey, PartitionMapping]:
return self.get_assets_def(asset_key).partition_mappings

def get_freshness_policy(self, asset_key: AssetKey) -> Optional[FreshnessPolicy]:
return self.get_assets_def(asset_key).freshness_policies_by_key.get(asset_key)

def get_auto_materialize_policy(self, asset_key: AssetKey) -> Optional[AutoMaterializePolicy]:
return self.get_assets_def(asset_key).auto_materialize_policies_by_key.get(asset_key)

def get_auto_observe_interval_minutes(self, asset_key: AssetKey) -> Optional[float]:
return self.get_assets_def(asset_key).auto_observe_interval_minutes

def get_backfill_policy(self, asset_key: AssetKey) -> Optional[BackfillPolicy]:
return self.get_assets_def(asset_key).backfill_policy

def get_code_version(self, asset_key: AssetKey) -> Optional[str]:
return self.get_assets_def(asset_key).code_versions_by_key.get(asset_key)
Loading

0 comments on commit 9532c18

Please sign in to comment.