Skip to content

Commit

Permalink
[external-assets] Update AssetGraph accessor callsites to fetch nodes
Browse files Browse the repository at this point in the history
[INTERNAL_BRANCH=sean/external-assets-asset-graph-nodes]
  • Loading branch information
smackesey committed Mar 5, 2024
1 parent eeef985 commit 0a7bffe
Show file tree
Hide file tree
Showing 19 changed files with 84 additions and 118 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def get_implicit_auto_materialize_policy(
asset_key: AssetKey, asset_graph: AssetGraph
) -> Optional[AutoMaterializePolicy]:
"""For backcompat with pre-auto materialize policy graphs, assume a default scope of 1 day."""
auto_materialize_policy = asset_graph.get_auto_materialize_policy(asset_key)
auto_materialize_policy = asset_graph.get(asset_key).auto_materialize_policy
if auto_materialize_policy is None:
time_partitions_def = get_time_partitions_def(asset_graph.get_partitions_def(asset_key))
if time_partitions_def is None:
Expand Down Expand Up @@ -147,9 +147,7 @@ def auto_materialize_asset_keys_and_parents(self) -> AbstractSet[AssetKey]:
@property
def asset_records_to_prefetch(self) -> Sequence[AssetKey]:
return [
key
for key in self.auto_materialize_asset_keys_and_parents
if self.asset_graph.has_asset(key)
key for key in self.auto_materialize_asset_keys_and_parents if self.asset_graph.has(key)
]

@property
Expand Down Expand Up @@ -192,7 +190,7 @@ def evaluate_asset(
"""
# convert the legacy AutoMaterializePolicy to an Evaluator
asset_condition = check.not_none(
self.asset_graph.get_auto_materialize_policy(asset_key)
self.asset_graph.get(asset_key).auto_materialize_policy
).to_asset_condition()

asset_cursor = self.cursor.get_previous_evaluation_state(asset_key)
Expand Down Expand Up @@ -434,7 +432,7 @@ def build_run_requests_with_backfill_policies(
run_requests.append(RunRequest(asset_selection=list(asset_keys), tags=tags))
else:
backfill_policies = {
check.not_none(asset_graph.get_backfill_policy(asset_key))
check.not_none(asset_graph.get(asset_key).backfill_policy)
for asset_key in asset_keys
}
if len(backfill_policies) == 1:
Expand All @@ -453,7 +451,7 @@ def build_run_requests_with_backfill_policies(
else:
# if backfill policies are different, we need to backfill them separately
for asset_key in asset_keys:
backfill_policy = asset_graph.get_backfill_policy(asset_key)
backfill_policy = asset_graph.get(asset_key).backfill_policy
run_requests.extend(
_build_run_requests_with_backfill_policy(
[asset_key],
Expand Down Expand Up @@ -567,7 +565,7 @@ def get_auto_observe_run_requests(
assets_to_auto_observe: Set[AssetKey] = set()
for asset_key in auto_observe_asset_keys:
last_observe_request_timestamp = last_observe_request_timestamp_by_asset_key.get(asset_key)
auto_observe_interval_minutes = asset_graph.get_auto_observe_interval_minutes(asset_key)
auto_observe_interval_minutes = asset_graph.get(asset_key).auto_observe_interval_minutes

if auto_observe_interval_minutes and (
last_observe_request_timestamp is None
Expand Down
38 changes: 2 additions & 36 deletions python_modules/dagster/dagster/_core/definitions/asset_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,10 +212,6 @@ def asset_nodes(self) -> Iterable[T_AssetNode]:
def has(self, asset_key: AssetKey) -> bool:
return asset_key in self.asset_nodes_by_key

# To be removed in upstack PR and callsites replaced with `has`
def has_asset(self, asset_key: AssetKey) -> bool:
return self.has(asset_key)

def get(self, asset_key: AssetKey) -> T_AssetNode:
return self.asset_nodes_by_key[asset_key]

Expand All @@ -239,33 +235,21 @@ def all_asset_keys(self) -> AbstractSet[AssetKey]:
def materializable_asset_keys(self) -> AbstractSet[AssetKey]:
return {node.key for node in self.asset_nodes if node.is_materializable}

def is_materializable(self, key: AssetKey) -> bool:
return self.get(key).is_materializable

@property
@cached_method
def observable_asset_keys(self) -> AbstractSet[AssetKey]:
return {node.key for node in self.asset_nodes if node.is_observable}

def is_observable(self, key: AssetKey) -> bool:
return self.get(key).is_observable

@property
@cached_method
def external_asset_keys(self) -> AbstractSet[AssetKey]:
return {node.key for node in self.asset_nodes if node.is_external}

def is_external(self, key: AssetKey) -> bool:
return self.get(key).is_external

@property
@cached_method
def executable_asset_keys(self) -> AbstractSet[AssetKey]:
return {node.key for node in self.asset_nodes if node.is_executable}

def is_executable(self, key: AssetKey) -> bool:
return self.get(key).is_executable

@property
@cached_method
def toposorted_asset_keys(self) -> Sequence[AssetKey]:
Expand Down Expand Up @@ -344,24 +328,6 @@ def get_partitions_in_range(
def is_partitioned(self, asset_key: AssetKey) -> bool:
return self.get_partitions_def(asset_key) is not None

def get_group_name(self, asset_key: AssetKey) -> Optional[str]:
return self.get(asset_key).group_name

def get_freshness_policy(self, asset_key: AssetKey) -> Optional[FreshnessPolicy]:
return self.get(asset_key).freshness_policy

def get_auto_materialize_policy(self, asset_key: AssetKey) -> Optional[AutoMaterializePolicy]:
return self.get(asset_key).auto_materialize_policy

def get_auto_observe_interval_minutes(self, asset_key: AssetKey) -> Optional[float]:
return self.get(asset_key).auto_observe_interval_minutes

def get_backfill_policy(self, asset_key: AssetKey) -> Optional[BackfillPolicy]:
return self.get(asset_key).backfill_policy

def get_code_version(self, asset_key: AssetKey) -> Optional[str]:
return self.get(asset_key).code_version

def have_same_partitioning(self, asset_key1: AssetKey, asset_key2: AssetKey) -> bool:
"""Returns whether the given assets have the same partitions definition."""
return self.get(asset_key1).partitions_def == self.get(asset_key2).partitions_def
Expand Down Expand Up @@ -644,7 +610,7 @@ def has_materializable_parents(self, asset_key: AssetKey) -> bool:
)

def get_materializable_roots(self, asset_key: AssetKey) -> AbstractSet[AssetKey]:
"""Returns all assets upstream of the given asset which do not consume any other
"""Returns all assets upstream of the giget which do not consume any other
materializable assets.
"""
if not self.has_materializable_parents(asset_key):
Expand All @@ -658,7 +624,7 @@ def get_materializable_roots(self, asset_key: AssetKey) -> AbstractSet[AssetKey]
}

def upstream_key_iterator(self, asset_key: AssetKey) -> Iterator[AssetKey]:
"""Iterates through all asset keys which are upstream of the given key."""
"""Iterates tgetl asset keys which are upstream of the given key."""
visited: Set[AssetKey] = set()
queue = deque([asset_key])
while queue:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,10 @@ def _compare_base_and_branch_assets(self, asset_key: "AssetKey") -> Sequence[Cha
return [ChangeReason.NEW]

changes = []
if self.branch_asset_graph.get_code_version(
asset_key
) != self.base_asset_graph.get_code_version(asset_key):
if (
self.branch_asset_graph.get(asset_key).code_version
!= self.base_asset_graph.get(asset_key).code_version
):
changes.append(ChangeReason.CODE_VERSION)

if self.branch_asset_graph.get_parents(asset_key) != self.base_asset_graph.get_parents(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -759,8 +759,8 @@ def evaluate_for_asset(
# ignore missing or unexecutable assets, which will never have a materialization or
# observation
if not (
context.asset_graph.has_asset(parent.asset_key)
and context.asset_graph.is_executable(parent.asset_key)
context.asset_graph.has(parent.asset_key)
and context.asset_graph.get(parent.asset_key).is_executable
):
continue
if not context.instance_queryer.asset_partition_has_materialization_or_observation(
Expand Down
13 changes: 6 additions & 7 deletions python_modules/dagster/dagster/_core/definitions/data_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,7 @@ def _upstream_records_by_key(

for parent_key in self.asset_graph.get_parents(asset_key):
if not (
self.asset_graph.has_asset(parent_key)
and self.asset_graph.is_executable(parent_key)
self.asset_graph.has(parent_key) and self.asset_graph.get(parent_key).is_executable
):
continue

Expand Down Expand Up @@ -319,7 +318,7 @@ def _calculate_data_time_by_key(
cursor=record_id,
partitions_def=partitions_def,
)
elif self.asset_graph.is_observable(asset_key):
elif self.asset_graph.get(asset_key).is_observable:
return self._calculate_data_time_by_key_observable_source(
asset_key=asset_key,
record_id=record_id,
Expand Down Expand Up @@ -533,18 +532,18 @@ def get_minutes_overdue(
asset_key: AssetKey,
evaluation_time: datetime.datetime,
) -> Optional[FreshnessMinutes]:
freshness_policy = self.asset_graph.get_freshness_policy(asset_key)
if freshness_policy is None:
asset = self.asset_graph.get(asset_key)
if asset.freshness_policy is None:
raise DagsterInvariantViolationError(
"Cannot calculate minutes late for asset without a FreshnessPolicy"
)

if self.asset_graph.is_observable(asset_key):
if asset.is_observable:
current_data_time = self._get_source_data_time(asset_key, current_time=evaluation_time)
else:
current_data_time = self.get_current_data_time(asset_key, current_time=evaluation_time)

return freshness_policy.minutes_overdue(
return asset.freshness_policy.minutes_overdue(
data_time=current_data_time,
evaluation_time=evaluation_time,
)
40 changes: 23 additions & 17 deletions python_modules/dagster/dagster/_core/definitions/data_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,13 +423,14 @@ def get_current_data_version(
def _get_status(self, key: "AssetKeyPartitionKey") -> StaleStatus:
# The status loader does not support querying for the stale status of a
# partitioned asset without specifying a partition, so we return here.
if self.asset_graph.is_partitioned(key.asset_key) and not key.partition_key:
asset = self.asset_graph.get(key.asset_key)
if asset.is_partitioned and not key.partition_key:
return StaleStatus.FRESH
else:
current_version = self._get_current_data_version(key=key)
if current_version == NULL_DATA_VERSION:
return StaleStatus.MISSING
elif self.asset_graph.is_external(key.asset_key):
elif asset.is_external:
return StaleStatus.FRESH
else:
causes = self._get_stale_causes(key=key)
Expand All @@ -440,9 +441,10 @@ def _get_stale_causes(self, key: "AssetKeyPartitionKey") -> Sequence[StaleCause]
# Querying for the stale status of a partitioned asset without specifying a partition key
# is strictly speaking undefined, but we return an empty list here (from which FRESH status
# is inferred) for backcompat.
if self.asset_graph.is_partitioned(key.asset_key) and not key.partition_key:
asset = self.asset_graph.get(key.asset_key)
if asset.is_partitioned and not key.partition_key:
return []
elif self.asset_graph.is_external(key.asset_key):
elif asset.is_external:
return []
else:
current_version = self._get_current_data_version(key=key)
Expand All @@ -454,21 +456,22 @@ def _get_stale_causes(self, key: "AssetKeyPartitionKey") -> Sequence[StaleCause]
)

def _is_dep_updated(self, provenance: DataProvenance, dep_key: "AssetKeyPartitionKey") -> bool:
dep_asset = self.asset_graph.get(dep_key.asset_key)
if dep_key.partition_key is None:
current_data_version = self._get_current_data_version(key=dep_key)
return provenance.input_data_versions[dep_key.asset_key] != current_data_version
else:
cursor = provenance.input_storage_ids[dep_key.asset_key]
updated_record = self._instance.get_latest_data_version_record(
dep_key.asset_key,
self.asset_graph.is_external(dep_key.asset_key),
dep_asset.is_external,
dep_key.partition_key,
after_cursor=cursor,
)
if updated_record:
previous_record = self._instance.get_latest_data_version_record(
dep_key.asset_key,
self.asset_graph.is_external(dep_key.asset_key),
dep_asset.is_external,
dep_key.partition_key,
before_cursor=cursor + 1 if cursor else None,
)
Expand All @@ -485,7 +488,7 @@ def _is_dep_updated(self, provenance: DataProvenance, dep_key: "AssetKeyPartitio
def _get_stale_causes_materialized(self, key: "AssetKeyPartitionKey") -> Iterator[StaleCause]:
from dagster._core.definitions.events import AssetKeyPartitionKey

code_version = self.asset_graph.get_code_version(key.asset_key)
code_version = self.asset_graph.get(key.asset_key).code_version
provenance = self._get_current_data_provenance(key=key)

asset_deps = self.asset_graph.get_parents(key.asset_key)
Expand Down Expand Up @@ -513,6 +516,7 @@ def _get_stale_causes_materialized(self, key: "AssetKeyPartitionKey") -> Iterato
# partition counts.
partition_deps = self._get_partition_dependencies(key=key)
for dep_key in sorted(partition_deps):
dep_asset = self.asset_graph.get(dep_key.asset_key)
if self._get_status(key=dep_key) == StaleStatus.STALE:
yield StaleCause(
key,
Expand All @@ -532,9 +536,10 @@ def _get_stale_causes_materialized(self, key: "AssetKeyPartitionKey") -> Iterato
# Currently we exclude assets downstream of AllPartitionMappings from stale
# status logic due to potentially huge numbers of dependencies.
elif self._is_dep_updated(provenance, dep_key):
report_data_version = self.asset_graph.get_code_version(
dep_key.asset_key
) is not None or self._is_current_data_version_user_provided(key=dep_key)
report_data_version = (
dep_asset.code_version is not None
or self._is_current_data_version_user_provided(key=dep_key)
)
yield StaleCause(
key,
StaleCauseCategory.DATA,
Expand Down Expand Up @@ -563,7 +568,7 @@ def _get_stale_causes_materialized(self, key: "AssetKeyPartitionKey") -> Iterato
# timestamps instead of versions this should be removable eventually since
# provenance is on all newer materializations. If dep is a source, then we'll never
# provide a stale reason here.
elif not self.asset_graph.is_external(dep_key.asset_key):
elif not dep_asset.is_external:
dep_materialization = self._get_latest_data_version_record(key=dep_key)
if dep_materialization is None:
# The input must be new if it has no materialization
Expand Down Expand Up @@ -622,7 +627,7 @@ def _get_current_data_version(self, *, key: "AssetKeyPartitionKey") -> DataVersi
# Currently we can only use asset records, which are fetched in one shot, for non-source
# assets. This is because the most recent AssetObservation is not stored on the AssetRecord.
record = self._get_latest_data_version_record(key=key)
if self.asset_graph.is_external(key.asset_key) and record is None:
if self.asset_graph.get(key.asset_key).is_external and record is None:
return DEFAULT_DATA_VERSION
elif record is None:
return NULL_DATA_VERSION
Expand All @@ -632,7 +637,7 @@ def _get_current_data_version(self, *, key: "AssetKeyPartitionKey") -> DataVersi

@cached_method
def _is_current_data_version_user_provided(self, *, key: "AssetKeyPartitionKey") -> bool:
if self.asset_graph.is_external(key.asset_key):
if self.asset_graph.get(key.asset_key).is_external:
return True
else:
provenance = self._get_current_data_provenance(key=key)
Expand All @@ -654,10 +659,11 @@ def _get_current_data_provenance(
# are at the root of the graph (have no dependencies) or are downstream of a volatile asset.
@cached_method
def _is_volatile(self, *, key: "AssetKey") -> bool:
if self.asset_graph.is_external(key):
return self.asset_graph.is_observable(key)
asset = self.asset_graph.get(key)
if asset.is_external:
return asset.is_observable
else:
deps = self.asset_graph.get_parents(key)
deps = asset.get_parents(key)
return len(deps) == 0 or any(self._is_volatile(key=dep_key) for dep_key in deps)

@cached_method
Expand All @@ -678,7 +684,7 @@ def _get_latest_data_version_record(
# If an asset record is cached, all of its ancestors have already been cached.
if (
key.partition_key is None
and not self.asset_graph.is_external(key.asset_key)
and not self.asset_graph.get(key.asset_key).is_external
and not self.instance_queryer.has_cached_asset_record(key.asset_key)
):
ancestors = self.asset_graph.get_ancestors(key.asset_key, include_self=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def freshness_evaluation_results_for_asset_key(
execution_period,
evaluation_data,
) = get_execution_period_and_evaluation_data_for_policies(
local_policy=context.asset_graph.get_freshness_policy(asset_key),
local_policy=context.asset_graph.get(asset_key).freshness_policy,
policies=context.asset_graph.get_downstream_freshness_policies(asset_key=asset_key),
effective_data_time=effective_data_time,
current_time=current_time,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def _wrapped_fn(context: SensorEvaluationContext):

minutes_late_by_key: Dict[AssetKey, Optional[float]] = {}
for asset_key in monitored_keys:
freshness_policy = asset_graph.get_freshness_policy(asset_key)
freshness_policy = asset_graph.get(asset_key).freshness_policy
if freshness_policy is None:
continue

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,6 @@ def get_execution_unit_asset_and_check_keys(

##### INTERNAL-SPECIFIC INTERFACE

def get_assets_def(self, asset_key: AssetKey) -> AssetsDefinition:
return self.get(asset_key).assets_def

@property
@cached_method
def assets_defs(self) -> Sequence[AssetsDefinition]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def __init__(
asset_graph = self._repository_def.asset_graph
for asset_key in self._monitored_asset_keys:
assets_def = (
asset_graph.get_assets_def(asset_key) if asset_graph.has_asset(asset_key) else None
asset_graph.get(asset_key).assets_def if asset_graph.has(asset_key) else None
)
self._assets_by_key[asset_key] = assets_def

Expand Down Expand Up @@ -683,8 +683,8 @@ def _get_asset(self, asset_key: AssetKey, fn_name: str) -> AssetsDefinition:
)
else:
return asset_def
elif repo_def.asset_graph.has_asset(asset_key):
return repo_def.asset_graph.get_assets_def(asset_key)
elif repo_def.asset_graph.has(asset_key):
return repo_def.asset_graph.get(asset_key).assets_def
else:
raise DagsterInvalidInvocationError(
f"Asset key {asset_key} not monitored in sensor and does not exist in target jobs"
Expand Down
Loading

0 comments on commit 0a7bffe

Please sign in to comment.