[external-assets] Update AssetGraph accessor callsites to fetch nodes
[INTERNAL_BRANCH=sean/external-assets-asset-graph-nodes]
smackesey committed Mar 1, 2024
1 parent f211e23 commit 991eef7
Showing 15 changed files with 71 additions and 97 deletions.
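
The pattern applied across all 15 files: per-attribute accessor methods on `AssetGraph` (`get_code_version`, `is_external`, and similar) are removed, and each callsite instead fetches the asset node once with `get(key)` and reads attributes off the returned object. A minimal sketch of the before/after shapes (`has` and `get` match the names used in the diff; the classes below are simplified stand-ins, not Dagster's implementation):

```python
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass(frozen=True)
class AssetNode:
    """Simplified stand-in for the node object returned by AssetGraph.get()."""

    key: str
    code_version: Optional[str]
    is_external: bool


class AssetGraph:
    """Simplified stand-in; only the accessors relevant to this commit."""

    def __init__(self, nodes: Dict[str, AssetNode]) -> None:
        self._nodes = nodes

    def has(self, key: str) -> bool:
        return key in self._nodes

    def get(self, key: str) -> AssetNode:
        return self._nodes[key]

    # Old style removed by this commit: one graph-level method per attribute.
    def get_code_version(self, key: str) -> Optional[str]:
        return self.get(key).code_version


graph = AssetGraph({"my_asset": AssetNode("my_asset", "v2", is_external=False)})

# Before: a separate graph-level call for each attribute read.
old_style = graph.get_code_version("my_asset")

# After: fetch the node once, then read as many attributes as needed.
node = graph.get("my_asset")
assert node.code_version == old_style and not node.is_external
```
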
34 changes: 2 additions & 32 deletions python_modules/dagster/dagster/_core/definitions/asset_graph.py
@@ -248,33 +248,21 @@ def all_asset_keys(self) -> AbstractSet[AssetKey]:
def materializable_asset_keys(self) -> AbstractSet[AssetKey]:
return {node.key for node in self.asset_nodes if node.is_materializable}

def is_materializable(self, key: AssetKey) -> bool:
return self.get(key).is_materializable

@property
@cached_method
def observable_asset_keys(self) -> AbstractSet[AssetKey]:
return {node.key for node in self.asset_nodes if node.is_observable}

def is_observable(self, key: AssetKey) -> bool:
return self.get(key).is_observable

@property
@cached_method
def external_asset_keys(self) -> AbstractSet[AssetKey]:
return {node.key for node in self.asset_nodes if node.is_external}

def is_external(self, key: AssetKey) -> bool:
return self.get(key).is_external

@property
@cached_method
def executable_asset_keys(self) -> AbstractSet[AssetKey]:
return {node.key for node in self.asset_nodes if node.is_executable}

def is_executable(self, key: AssetKey) -> bool:
return self.get(key).is_executable

@property
@cached_method
def toposorted_asset_keys(self) -> Sequence[AssetKey]:
@@ -353,24 +341,6 @@ def get_partitions_in_range(
def is_partitioned(self, asset_key: AssetKey) -> bool:
return self.get_partitions_def(asset_key) is not None

def get_group_name(self, asset_key: AssetKey) -> Optional[str]:
return self.get(asset_key).group_name

def get_freshness_policy(self, asset_key: AssetKey) -> Optional[FreshnessPolicy]:
return self.get(asset_key).freshness_policy

def get_auto_materialize_policy(self, asset_key: AssetKey) -> Optional[AutoMaterializePolicy]:
return self.get(asset_key).auto_materialize_policy

def get_auto_observe_interval_minutes(self, asset_key: AssetKey) -> Optional[float]:
return self.get(asset_key).auto_observe_interval_minutes

def get_backfill_policy(self, asset_key: AssetKey) -> Optional[BackfillPolicy]:
return self.get(asset_key).backfill_policy

def get_code_version(self, asset_key: AssetKey) -> Optional[str]:
return self.get(asset_key).code_version

def have_same_partitioning(self, asset_key1: AssetKey, asset_key2: AssetKey) -> bool:
"""Returns whether the given assets have the same partitions definition."""
return self.get(asset_key1).partitions_def == self.get(asset_key2).partitions_def
@@ -653,7 +623,7 @@ def has_materializable_parents(self, asset_key: AssetKey) -> bool:
)

def get_materializable_roots(self, asset_key: AssetKey) -> AbstractSet[AssetKey]:
"""Returns all assets upstream of the given asset which do not consume any other
"""Returns all assets upstream of the giget which do not consume any other
materializable assets.
"""
if not self.has_materializable_parents(asset_key):
@@ -667,7 +637,7 @@ def get_materializable_roots(self, asset_key: AssetKey) -> AbstractSet[AssetKey]:
}

def upstream_key_iterator(self, asset_key: AssetKey) -> Iterator[AssetKey]:
"""Iterates through all asset keys which are upstream of the given key."""
"""Iterates tgetl asset keys which are upstream of the given key."""
visited: Set[AssetKey] = set()
queue = deque([asset_key])
while queue:
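
Each helper deleted above maps one-to-one onto an attribute of the fetched node. For downstream code still calling the old names, a thin adapter can bridge the migration; a hypothetical shim, not part of this commit:

```python
from typing import Optional


class LegacyAssetGraphAccessors:
    """Hypothetical migration shim, not part of this commit: restores a few of
    the removed per-attribute helpers on top of the node-based API."""

    def __init__(self, asset_graph) -> None:  # asset_graph: AssetGraph
        self._asset_graph = asset_graph

    def is_external(self, key) -> bool:
        return self._asset_graph.get(key).is_external

    def get_group_name(self, key) -> Optional[str]:
        return self._asset_graph.get(key).group_name

    def get_freshness_policy(self, key):
        return self._asset_graph.get(key).freshness_policy

    def get_code_version(self, key) -> Optional[str]:
        return self._asset_graph.get(key).code_version
```
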
@@ -122,9 +122,10 @@ def _compare_base_and_branch_assets(self, asset_key: "AssetKey") -> Sequence[ChangeReason]:
return [ChangeReason.NEW]

changes = []
if self.branch_asset_graph.get_code_version(
asset_key
) != self.base_asset_graph.get_code_version(asset_key):
if (
self.branch_asset_graph.get(asset_key).code_version
!= self.base_asset_graph.get(asset_key).code_version
):
changes.append(ChangeReason.CODE_VERSION)

if self.branch_asset_graph.get_parents(asset_key) != self.base_asset_graph.get_parents(
@@ -759,8 +759,8 @@ def evaluate_for_asset(
# ignore missing or unexecutable assets, which will never have a materialization or
# observation
if not (
context.asset_graph.has_asset(parent.asset_key)
and context.asset_graph.is_executable(parent.asset_key)
context.asset_graph.has(parent.asset_key)
and context.asset_graph.get(parent.asset_key).is_executable
):
continue
if not context.instance_queryer.asset_partition_has_materialization_or_observation(
13 changes: 6 additions & 7 deletions python_modules/dagster/dagster/_core/definitions/data_time.py
@@ -177,8 +177,7 @@ def _upstream_records_by_key(

for parent_key in self.asset_graph.get_parents(asset_key):
if not (
self.asset_graph.has_asset(parent_key)
and self.asset_graph.is_executable(parent_key)
self.asset_graph.has(parent_key) and self.asset_graph.get(parent_key).is_executable
):
continue

@@ -319,7 +318,7 @@ def _calculate_data_time_by_key(
cursor=record_id,
partitions_def=partitions_def,
)
elif self.asset_graph.is_observable(asset_key):
elif self.asset_graph.get(asset_key).is_observable:
return self._calculate_data_time_by_key_observable_source(
asset_key=asset_key,
record_id=record_id,
@@ -533,18 +532,18 @@ def get_minutes_overdue(
asset_key: AssetKey,
evaluation_time: datetime.datetime,
) -> Optional[FreshnessMinutes]:
freshness_policy = self.asset_graph.get_freshness_policy(asset_key)
if freshness_policy is None:
asset = self.asset_graph.get(asset_key)
if asset.freshness_policy is None:
raise DagsterInvariantViolationError(
"Cannot calculate minutes late for asset without a FreshnessPolicy"
)

if self.asset_graph.is_external(asset_key):
if asset.is_external:
current_data_time = self._get_source_data_time(asset_key, current_time=evaluation_time)
else:
current_data_time = self.get_current_data_time(asset_key, current_time=evaluation_time)

return freshness_policy.minutes_overdue(
return asset.freshness_policy.minutes_overdue(
data_time=current_data_time,
evaluation_time=evaluation_time,
)
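
The `get_minutes_overdue` hunk above also shows the secondary benefit of node fetching: a single `get(asset_key)` now serves both the `freshness_policy` and `is_external` reads that previously took two graph-level calls. A small sketch of that hoisting pattern, reusing the simplified stand-ins from the sketch near the top of this page:

```python
def describe_asset(graph: "AssetGraph", key: str) -> str:
    # Fetch the node once...
    node = graph.get(key)
    # ...then read multiple attributes without further graph lookups.
    kind = "external" if node.is_external else "materializable"
    version = node.code_version or "<none>"
    return f"{key}: {kind}, code_version={version}"
```
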
40 changes: 23 additions & 17 deletions python_modules/dagster/dagster/_core/definitions/data_version.py
@@ -423,13 +423,14 @@ def get_current_data_version(
def _get_status(self, key: "AssetKeyPartitionKey") -> StaleStatus:
# The status loader does not support querying for the stale status of a
# partitioned asset without specifying a partition, so we return here.
if self.asset_graph.is_partitioned(key.asset_key) and not key.partition_key:
asset = self.asset_graph.get(key.asset_key)
if asset.is_partitioned and not key.partition_key:
return StaleStatus.FRESH
else:
current_version = self._get_current_data_version(key=key)
if current_version == NULL_DATA_VERSION:
return StaleStatus.MISSING
elif self.asset_graph.is_external(key.asset_key):
elif asset.is_external:
return StaleStatus.FRESH
else:
causes = self._get_stale_causes(key=key)
@@ -440,9 +441,10 @@ def _get_stale_causes(self, key: "AssetKeyPartitionKey") -> Sequence[StaleCause]:
# Querying for the stale status of a partitioned asset without specifying a partition key
# is strictly speaking undefined, but we return an empty list here (from which FRESH status
# is inferred) for backcompat.
if self.asset_graph.is_partitioned(key.asset_key) and not key.partition_key:
asset = self.asset_graph.get(key.asset_key)
if asset.is_partitioned and not key.partition_key:
return []
elif self.asset_graph.is_external(key.asset_key):
elif asset.is_external:
return []
else:
current_version = self._get_current_data_version(key=key)
@@ -454,21 +456,22 @@ def _get_stale_causes(self, key: "AssetKeyPartitionKey") -> Sequence[StaleCause]:
)

def _is_dep_updated(self, provenance: DataProvenance, dep_key: "AssetKeyPartitionKey") -> bool:
dep_asset = self.asset_graph.get(dep_key.asset_key)
if dep_key.partition_key is None:
current_data_version = self._get_current_data_version(key=dep_key)
return provenance.input_data_versions[dep_key.asset_key] != current_data_version
else:
cursor = provenance.input_storage_ids[dep_key.asset_key]
updated_record = self._instance.get_latest_data_version_record(
dep_key.asset_key,
self.asset_graph.is_external(dep_key.asset_key),
dep_asset.is_external,
dep_key.partition_key,
after_cursor=cursor,
)
if updated_record:
previous_record = self._instance.get_latest_data_version_record(
dep_key.asset_key,
self.asset_graph.is_external(dep_key.asset_key),
dep_asset.is_external,
dep_key.partition_key,
before_cursor=cursor + 1 if cursor else None,
)
@@ -485,7 +488,7 @@ def _is_dep_updated(self, provenance: DataProvenance, dep_key: "AssetKeyPartitionKey") -> bool:
def _get_stale_causes_materialized(self, key: "AssetKeyPartitionKey") -> Iterator[StaleCause]:
from dagster._core.definitions.events import AssetKeyPartitionKey

code_version = self.asset_graph.get_code_version(key.asset_key)
code_version = self.asset_graph.get(key.asset_key).code_version
provenance = self._get_current_data_provenance(key=key)

asset_deps = self.asset_graph.get_parents(key.asset_key)
@@ -513,6 +516,7 @@ def _get_stale_causes_materialized(self, key: "AssetKeyPartitionKey") -> Iterator[StaleCause]:
# partition counts.
partition_deps = self._get_partition_dependencies(key=key)
for dep_key in sorted(partition_deps):
dep_asset = self.asset_graph.get(dep_key.asset_key)
if self._get_status(key=dep_key) == StaleStatus.STALE:
yield StaleCause(
key,
@@ -532,9 +536,10 @@ def _get_stale_causes_materialized(self, key: "AssetKeyPartitionKey") -> Iterator[StaleCause]:
# Currently we exclude assets downstream of AllPartitionMappings from stale
# status logic due to potentially huge numbers of dependencies.
elif self._is_dep_updated(provenance, dep_key):
report_data_version = self.asset_graph.get_code_version(
dep_key.asset_key
) is not None or self._is_current_data_version_user_provided(key=dep_key)
report_data_version = (
dep_asset.code_version is not None
or self._is_current_data_version_user_provided(key=dep_key)
)
yield StaleCause(
key,
StaleCauseCategory.DATA,
@@ -563,7 +568,7 @@ def _get_stale_causes_materialized(self, key: "AssetKeyPartitionKey") -> Iterator[StaleCause]:
# timestamps instead of versions this should be removable eventually since
# provenance is on all newer materializations. If dep is a source, then we'll never
# provide a stale reason here.
elif not self.asset_graph.is_external(dep_key.asset_key):
elif not dep_asset.is_external:
dep_materialization = self._get_latest_data_version_record(key=dep_key)
if dep_materialization is None:
# The input must be new if it has no materialization
@@ -622,7 +627,7 @@ def _get_current_data_version(self, *, key: "AssetKeyPartitionKey") -> DataVersion:
# Currently we can only use asset records, which are fetched in one shot, for non-source
# assets. This is because the most recent AssetObservation is not stored on the AssetRecord.
record = self._get_latest_data_version_record(key=key)
if self.asset_graph.is_external(key.asset_key) and record is None:
if self.asset_graph.get(key.asset_key).is_external and record is None:
return DEFAULT_DATA_VERSION
elif record is None:
return NULL_DATA_VERSION
Expand All @@ -632,7 +637,7 @@ def _get_current_data_version(self, *, key: "AssetKeyPartitionKey") -> DataVersi

@cached_method
def _is_current_data_version_user_provided(self, *, key: "AssetKeyPartitionKey") -> bool:
if self.asset_graph.is_external(key.asset_key):
if self.asset_graph.get(key.asset_key).is_external:
return True
else:
provenance = self._get_current_data_provenance(key=key)
@@ -654,10 +659,11 @@ def _get_current_data_provenance(
# are at the root of the graph (have no dependencies) or are downstream of a volatile asset.
@cached_method
def _is_volatile(self, *, key: "AssetKey") -> bool:
if self.asset_graph.is_external(key):
return self.asset_graph.is_observable(key)
asset = self.asset_graph.get(key)
if asset.is_external:
return asset.is_observable
else:
deps = self.asset_graph.get_parents(key)
return len(deps) == 0 or any(self._is_volatile(key=dep_key) for dep_key in deps)

@cached_method
@@ -678,7 +684,7 @@ def _get_latest_data_version_record(
# If an asset record is cached, all of its ancestors have already been cached.
if (
key.partition_key is None
and not self.asset_graph.is_external(key.asset_key)
and not self.asset_graph.get(key.asset_key).is_external
and not self.instance_queryer.has_cached_asset_record(key.asset_key)
):
ancestors = self.asset_graph.get_ancestors(key.asset_key, include_self=True)
@@ -207,7 +207,7 @@ def freshness_evaluation_results_for_asset_key(
execution_period,
evaluation_data,
) = get_execution_period_and_evaluation_data_for_policies(
local_policy=context.asset_graph.get_freshness_policy(asset_key),
local_policy=context.asset_graph.get(asset_key).freshness_policy,
policies=context.asset_graph.get_downstream_freshness_policies(asset_key=asset_key),
effective_data_time=effective_data_time,
current_time=current_time,
@@ -258,7 +258,7 @@ def _wrapped_fn(context: SensorEvaluationContext):

minutes_late_by_key: Dict[AssetKey, Optional[float]] = {}
for asset_key in monitored_keys:
freshness_policy = asset_graph.get_freshness_policy(asset_key)
freshness_policy = asset_graph.get(asset_key).freshness_policy
if freshness_policy is None:
continue

@@ -1255,7 +1255,7 @@ def execute_asset_backfill_iteration_inner(

# check if all assets have backfill policies if any of them do, otherwise, raise error
asset_backfill_policies = [
asset_graph.get_backfill_policy(asset_key)
asset_graph.get(asset_key).backfill_policy
for asset_key in {
asset_partition.asset_key for asset_partition in asset_partitions_to_request
}
@@ -212,15 +212,14 @@ def _external_sensors(self) -> Dict[str, "ExternalSensor"]:
default_sensor_asset_keys = set()

for asset_key in asset_graph.materializable_asset_keys:
policy = asset_graph.get_auto_materialize_policy(asset_key)
if not policy:
if not asset_graph.get(asset_key).auto_materialize_policy:
continue

if asset_key not in covered_asset_keys:
default_sensor_asset_keys.add(asset_key)

for asset_key in asset_graph.observable_asset_keys:
if asset_graph.get_auto_observe_interval_minutes(asset_key) is None:
if asset_graph.get(asset_key).auto_observe_interval_minutes is None:
continue

has_any_auto_observe_source_assets = True
4 changes: 2 additions & 2 deletions python_modules/dagster/dagster/_daemon/asset_daemon.py
@@ -682,14 +682,14 @@ def _process_auto_materialize_tick_generator(
auto_materialize_asset_keys = {
target_key
for target_key in eligible_keys
if asset_graph.get_auto_materialize_policy(target_key) is not None
if asset_graph.get(target_key).auto_materialize_policy is not None
}
num_target_assets = len(auto_materialize_asset_keys)

auto_observe_asset_keys = {
key
for key in eligible_keys
if asset_graph.get_auto_observe_interval_minutes(key) is not None
if asset_graph.get(key).auto_observe_interval_minutes is not None
}
num_auto_observe_assets = len(auto_observe_asset_keys)
