Skip to content

Commit

Permalink
[external-assets] Update AssetGraph accessor callsites to fetch nodes (
Browse files Browse the repository at this point in the history
…#20132)

## Summary & Motivation

Internal companion PR: dagster-io/internal#8462

Part 2 of `AssetGraph` node implementation. Here we remove the old
property accessor methods on `AssetGraph` and change the callsites to
access nodes and call the accessor on the node objects instead.

## How I Tested These Changes

Existing test suite.
  • Loading branch information
smackesey authored and PedramNavid committed Mar 28, 2024
1 parent fcc4c9b commit b8d6f48
Show file tree
Hide file tree
Showing 19 changed files with 94 additions and 123 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def get_implicit_auto_materialize_policy(
asset_key: AssetKey, asset_graph: BaseAssetGraph
) -> Optional[AutoMaterializePolicy]:
"""For backcompat with pre-auto materialize policy graphs, assume a default scope of 1 day."""
auto_materialize_policy = asset_graph.get_auto_materialize_policy(asset_key)
auto_materialize_policy = asset_graph.get(asset_key).auto_materialize_policy
if auto_materialize_policy is None:
time_partitions_def = get_time_partitions_def(asset_graph.get_partitions_def(asset_key))
if time_partitions_def is None:
Expand Down Expand Up @@ -147,9 +147,7 @@ def auto_materialize_asset_keys_and_parents(self) -> AbstractSet[AssetKey]:
@property
def asset_records_to_prefetch(self) -> Sequence[AssetKey]:
return [
key
for key in self.auto_materialize_asset_keys_and_parents
if self.asset_graph.has_asset(key)
key for key in self.auto_materialize_asset_keys_and_parents if self.asset_graph.has(key)
]

@property
Expand Down Expand Up @@ -192,7 +190,7 @@ def evaluate_asset(
"""
# convert the legacy AutoMaterializePolicy to an Evaluator
asset_condition = check.not_none(
self.asset_graph.get_auto_materialize_policy(asset_key)
self.asset_graph.get(asset_key).auto_materialize_policy
).to_asset_condition()

asset_cursor = self.cursor.get_previous_evaluation_state(asset_key)
Expand Down Expand Up @@ -434,7 +432,7 @@ def build_run_requests_with_backfill_policies(
run_requests.append(RunRequest(asset_selection=list(asset_keys), tags=tags))
else:
backfill_policies = {
check.not_none(asset_graph.get_backfill_policy(asset_key))
check.not_none(asset_graph.get(asset_key).backfill_policy)
for asset_key in asset_keys
}
if len(backfill_policies) == 1:
Expand All @@ -453,7 +451,7 @@ def build_run_requests_with_backfill_policies(
else:
# if backfill policies are different, we need to backfill them separately
for asset_key in asset_keys:
backfill_policy = asset_graph.get_backfill_policy(asset_key)
backfill_policy = asset_graph.get(asset_key).backfill_policy
run_requests.extend(
_build_run_requests_with_backfill_policy(
[asset_key],
Expand Down Expand Up @@ -567,7 +565,7 @@ def get_auto_observe_run_requests(
assets_to_auto_observe: Set[AssetKey] = set()
for asset_key in auto_observe_asset_keys:
last_observe_request_timestamp = last_observe_request_timestamp_by_asset_key.get(asset_key)
auto_observe_interval_minutes = asset_graph.get_auto_observe_interval_minutes(asset_key)
auto_observe_interval_minutes = asset_graph.get(asset_key).auto_observe_interval_minutes

if auto_observe_interval_minutes and (
last_observe_request_timestamp is None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
)
from dagster._core.definitions.events import AssetKey
from dagster._core.definitions.freshness_policy import FreshnessPolicy
from dagster._core.definitions.metadata import ArbitraryMetadataMapping
from dagster._core.definitions.partition import PartitionsDefinition
from dagster._core.definitions.partition_mapping import PartitionMapping
from dagster._core.definitions.resolved_asset_deps import ResolvedAssetDependencies
Expand Down Expand Up @@ -61,6 +62,10 @@ def is_external(self) -> bool:
def is_executable(self) -> bool:
return self.assets_def.is_executable

@property
def metadata(self) -> ArbitraryMetadataMapping:
return self.assets_def.metadata_by_key.get(self.key, {})

@property
def is_partitioned(self) -> bool:
return self.assets_def.partitions_def is not None
Expand Down Expand Up @@ -245,9 +250,6 @@ def get_execution_set_asset_and_check_keys(
asset_unit_keys if asset_or_check_key in asset_unit_keys else {asset_or_check_key}
)

def get_assets_def(self, asset_key: AssetKey) -> AssetsDefinition:
return self.get(asset_key).assets_def

@cached_property
def assets_defs(self) -> Sequence[AssetsDefinition]:
return list(dict.fromkeys(asset.assets_def for asset in self.asset_nodes))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,10 @@ def _compare_base_and_branch_assets(self, asset_key: "AssetKey") -> Sequence[Cha
return [ChangeReason.NEW]

changes = []
if self.branch_asset_graph.get_code_version(
asset_key
) != self.base_asset_graph.get_code_version(asset_key):
if (
self.branch_asset_graph.get(asset_key).code_version
!= self.base_asset_graph.get(asset_key).code_version
):
changes.append(ChangeReason.CODE_VERSION)

if self.branch_asset_graph.get_parents(asset_key) != self.base_asset_graph.get_parents(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -769,8 +769,8 @@ def evaluate_for_asset(
# ignore missing or unexecutable assets, which will never have a materialization or
# observation
if not (
context.asset_graph.has_asset(parent.asset_key)
and context.asset_graph.is_executable(parent.asset_key)
context.asset_graph.has(parent.asset_key)
and context.asset_graph.get(parent.asset_key).is_executable
):
continue
if not context.instance_queryer.asset_partition_has_materialization_or_observation(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,6 @@ def asset_nodes(self) -> Iterable[T_AssetNode]:
def has(self, asset_key: AssetKey) -> bool:
return asset_key in self._asset_nodes_by_key

# To be removed in upstack PR and callsites replaced with `has`
def has_asset(self, asset_key: AssetKey) -> bool:
return self.has(asset_key)

def get(self, asset_key: AssetKey) -> T_AssetNode:
return self._asset_nodes_by_key[asset_key]

Expand All @@ -193,30 +189,18 @@ def all_asset_keys(self) -> AbstractSet[AssetKey]:
def materializable_asset_keys(self) -> AbstractSet[AssetKey]:
return {node.key for node in self.asset_nodes if node.is_materializable}

def is_materializable(self, key: AssetKey) -> bool:
return self.get(key).is_materializable

@cached_property
def observable_asset_keys(self) -> AbstractSet[AssetKey]:
return {node.key for node in self.asset_nodes if node.is_observable}

def is_observable(self, key: AssetKey) -> bool:
return self.get(key).is_observable

@cached_property
def external_asset_keys(self) -> AbstractSet[AssetKey]:
return {node.key for node in self.asset_nodes if node.is_external}

def is_external(self, key: AssetKey) -> bool:
return self.get(key).is_external

@cached_property
def executable_asset_keys(self) -> AbstractSet[AssetKey]:
return {node.key for node in self.asset_nodes if node.is_executable}

def is_executable(self, key: AssetKey) -> bool:
return self.get(key).is_executable

@cached_property
def toposorted_asset_keys(self) -> Sequence[AssetKey]:
"""Return topologically sorted asset keys in graph. Keys with the same topological level are
Expand Down Expand Up @@ -287,24 +271,6 @@ def get_partitions_in_range(
def is_partitioned(self, asset_key: AssetKey) -> bool:
return self.get_partitions_def(asset_key) is not None

def get_group_name(self, asset_key: AssetKey) -> Optional[str]:
return self.get(asset_key).group_name

def get_freshness_policy(self, asset_key: AssetKey) -> Optional[FreshnessPolicy]:
return self.get(asset_key).freshness_policy

def get_auto_materialize_policy(self, asset_key: AssetKey) -> Optional[AutoMaterializePolicy]:
return self.get(asset_key).auto_materialize_policy

def get_auto_observe_interval_minutes(self, asset_key: AssetKey) -> Optional[float]:
return self.get(asset_key).auto_observe_interval_minutes

def get_backfill_policy(self, asset_key: AssetKey) -> Optional[BackfillPolicy]:
return self.get(asset_key).backfill_policy

def get_code_version(self, asset_key: AssetKey) -> Optional[str]:
return self.get(asset_key).code_version

def have_same_partitioning(self, asset_key1: AssetKey, asset_key2: AssetKey) -> bool:
"""Returns whether the given assets have the same partitions definition."""
return self.get(asset_key1).partitions_def == self.get(asset_key2).partitions_def
Expand Down Expand Up @@ -336,7 +302,7 @@ def get_parents(self, asset_key: AssetKey) -> AbstractSet[AssetKey]:
"""Returns all asset keys that are direct dependencies on the given asset key."""
return self.get(asset_key).parent_keys

def get_ancestors(
def get_ancestor_asset_keys(
self, asset_key: AssetKey, include_self: bool = False
) -> AbstractSet[AssetKey]:
"""Returns all nth-order dependencies of an asset."""
Expand Down Expand Up @@ -643,7 +609,7 @@ def get_downstream_freshness_policies(
downstream_policies = set().union(
*(
self.get_downstream_freshness_policies(asset_key=child_key)
for child_key in self.get_children(asset_key)
for child_key in self.get(asset_key).child_keys
if child_key != asset_key
)
)
Expand Down
13 changes: 6 additions & 7 deletions python_modules/dagster/dagster/_core/definitions/data_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,7 @@ def _upstream_records_by_key(

for parent_key in self.asset_graph.get_parents(asset_key):
if not (
self.asset_graph.has_asset(parent_key)
and self.asset_graph.is_executable(parent_key)
self.asset_graph.has(parent_key) and self.asset_graph.get(parent_key).is_executable
):
continue

Expand Down Expand Up @@ -319,7 +318,7 @@ def _calculate_data_time_by_key(
cursor=record_id,
partitions_def=partitions_def,
)
elif self.asset_graph.is_observable(asset_key):
elif self.asset_graph.get(asset_key).is_observable:
return self._calculate_data_time_by_key_observable_source(
asset_key=asset_key,
record_id=record_id,
Expand Down Expand Up @@ -533,18 +532,18 @@ def get_minutes_overdue(
asset_key: AssetKey,
evaluation_time: datetime.datetime,
) -> Optional[FreshnessMinutes]:
freshness_policy = self.asset_graph.get_freshness_policy(asset_key)
if freshness_policy is None:
asset = self.asset_graph.get(asset_key)
if asset.freshness_policy is None:
raise DagsterInvariantViolationError(
"Cannot calculate minutes late for asset without a FreshnessPolicy"
)

if self.asset_graph.is_observable(asset_key):
if asset.is_observable:
current_data_time = self._get_source_data_time(asset_key, current_time=evaluation_time)
else:
current_data_time = self.get_current_data_time(asset_key, current_time=evaluation_time)

return freshness_policy.minutes_overdue(
return asset.freshness_policy.minutes_overdue(
data_time=current_data_time,
evaluation_time=evaluation_time,
)
42 changes: 24 additions & 18 deletions python_modules/dagster/dagster/_core/definitions/data_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,13 +423,14 @@ def get_current_data_version(
def _get_status(self, key: "AssetKeyPartitionKey") -> StaleStatus:
# The status loader does not support querying for the stale status of a
# partitioned asset without specifying a partition, so we return here.
if self.asset_graph.is_partitioned(key.asset_key) and not key.partition_key:
asset = self.asset_graph.get(key.asset_key)
if asset.is_partitioned and not key.partition_key:
return StaleStatus.FRESH
else:
current_version = self._get_current_data_version(key=key)
if current_version == NULL_DATA_VERSION:
return StaleStatus.MISSING
elif self.asset_graph.is_external(key.asset_key):
elif asset.is_external:
return StaleStatus.FRESH
else:
causes = self._get_stale_causes(key=key)
Expand All @@ -440,9 +441,10 @@ def _get_stale_causes(self, key: "AssetKeyPartitionKey") -> Sequence[StaleCause]
# Querying for the stale status of a partitioned asset without specifying a partition key
# is strictly speaking undefined, but we return an empty list here (from which FRESH status
# is inferred) for backcompat.
if self.asset_graph.is_partitioned(key.asset_key) and not key.partition_key:
asset = self.asset_graph.get(key.asset_key)
if asset.is_partitioned and not key.partition_key:
return []
elif self.asset_graph.is_external(key.asset_key):
elif asset.is_external:
return []
else:
current_version = self._get_current_data_version(key=key)
Expand All @@ -454,21 +456,22 @@ def _get_stale_causes(self, key: "AssetKeyPartitionKey") -> Sequence[StaleCause]
)

def _is_dep_updated(self, provenance: DataProvenance, dep_key: "AssetKeyPartitionKey") -> bool:
dep_asset = self.asset_graph.get(dep_key.asset_key)
if dep_key.partition_key is None:
current_data_version = self._get_current_data_version(key=dep_key)
return provenance.input_data_versions[dep_key.asset_key] != current_data_version
else:
cursor = provenance.input_storage_ids[dep_key.asset_key]
updated_record = self._instance.get_latest_data_version_record(
dep_key.asset_key,
self.asset_graph.is_external(dep_key.asset_key),
dep_asset.is_external,
dep_key.partition_key,
after_cursor=cursor,
)
if updated_record:
previous_record = self._instance.get_latest_data_version_record(
dep_key.asset_key,
self.asset_graph.is_external(dep_key.asset_key),
dep_asset.is_external,
dep_key.partition_key,
before_cursor=cursor + 1 if cursor else None,
)
Expand All @@ -485,7 +488,7 @@ def _is_dep_updated(self, provenance: DataProvenance, dep_key: "AssetKeyPartitio
def _get_stale_causes_materialized(self, key: "AssetKeyPartitionKey") -> Iterator[StaleCause]:
from dagster._core.definitions.events import AssetKeyPartitionKey

code_version = self.asset_graph.get_code_version(key.asset_key)
code_version = self.asset_graph.get(key.asset_key).code_version
provenance = self._get_current_data_provenance(key=key)

asset_deps = self.asset_graph.get_parents(key.asset_key)
Expand Down Expand Up @@ -513,6 +516,7 @@ def _get_stale_causes_materialized(self, key: "AssetKeyPartitionKey") -> Iterato
# partition counts.
partition_deps = self._get_partition_dependencies(key=key)
for dep_key in sorted(partition_deps):
dep_asset = self.asset_graph.get(dep_key.asset_key)
if self._get_status(key=dep_key) == StaleStatus.STALE:
yield StaleCause(
key,
Expand All @@ -532,9 +536,10 @@ def _get_stale_causes_materialized(self, key: "AssetKeyPartitionKey") -> Iterato
# Currently we exclude assets downstream of AllPartitionMappings from stale
# status logic due to potentially huge numbers of dependencies.
elif self._is_dep_updated(provenance, dep_key):
report_data_version = self.asset_graph.get_code_version(
dep_key.asset_key
) is not None or self._is_current_data_version_user_provided(key=dep_key)
report_data_version = (
dep_asset.code_version is not None
or self._is_current_data_version_user_provided(key=dep_key)
)
yield StaleCause(
key,
StaleCauseCategory.DATA,
Expand Down Expand Up @@ -563,7 +568,7 @@ def _get_stale_causes_materialized(self, key: "AssetKeyPartitionKey") -> Iterato
# timestamps instead of versions this should be removable eventually since
# provenance is on all newer materializations. If dep is a source, then we'll never
# provide a stale reason here.
elif not self.asset_graph.is_external(dep_key.asset_key):
elif not dep_asset.is_external:
dep_materialization = self._get_latest_data_version_record(key=dep_key)
if dep_materialization is None:
# The input must be new if it has no materialization
Expand Down Expand Up @@ -622,7 +627,7 @@ def _get_current_data_version(self, *, key: "AssetKeyPartitionKey") -> DataVersi
# Currently we can only use asset records, which are fetched in one shot, for non-source
# assets. This is because the most recent AssetObservation is not stored on the AssetRecord.
record = self._get_latest_data_version_record(key=key)
if self.asset_graph.is_external(key.asset_key) and record is None:
if self.asset_graph.get(key.asset_key).is_external and record is None:
return DEFAULT_DATA_VERSION
elif record is None:
return NULL_DATA_VERSION
Expand All @@ -632,7 +637,7 @@ def _get_current_data_version(self, *, key: "AssetKeyPartitionKey") -> DataVersi

@cached_method
def _is_current_data_version_user_provided(self, *, key: "AssetKeyPartitionKey") -> bool:
if self.asset_graph.is_external(key.asset_key):
if self.asset_graph.get(key.asset_key).is_external:
return True
else:
provenance = self._get_current_data_provenance(key=key)
Expand All @@ -654,10 +659,11 @@ def _get_current_data_provenance(
# are at the root of the graph (have no dependencies) or are downstream of a volatile asset.
@cached_method
def _is_volatile(self, *, key: "AssetKey") -> bool:
if self.asset_graph.is_external(key):
return self.asset_graph.is_observable(key)
asset = self.asset_graph.get(key)
if asset.is_external:
return asset.is_observable
else:
deps = self.asset_graph.get_parents(key)
deps = asset.get_parents(key)
return len(deps) == 0 or any(self._is_volatile(key=dep_key) for dep_key in deps)

@cached_method
Expand All @@ -678,10 +684,10 @@ def _get_latest_data_version_record(
# If an asset record is cached, all of its ancestors have already been cached.
if (
key.partition_key is None
and not self.asset_graph.is_external(key.asset_key)
and not self.asset_graph.get(key.asset_key).is_external
and not self.instance_queryer.has_cached_asset_record(key.asset_key)
):
ancestors = self.asset_graph.get_ancestors(key.asset_key, include_self=True)
ancestors = self.asset_graph.get_ancestor_asset_keys(key.asset_key, include_self=True)
self.instance_queryer.prefetch_asset_records(ancestors)
return self.instance_queryer.get_latest_materialization_or_observation_record(
asset_partition=key
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def freshness_evaluation_results_for_asset_key(
execution_period,
evaluation_data,
) = get_execution_period_and_evaluation_data_for_policies(
local_policy=context.asset_graph.get_freshness_policy(asset_key),
local_policy=context.asset_graph.get(asset_key).freshness_policy,
policies=context.asset_graph.get_downstream_freshness_policies(asset_key=asset_key),
effective_data_time=effective_data_time,
current_time=current_time,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def _wrapped_fn(context: SensorEvaluationContext):

minutes_late_by_key: Dict[AssetKey, Optional[float]] = {}
for asset_key in monitored_keys:
freshness_policy = asset_graph.get_freshness_policy(asset_key)
freshness_policy = asset_graph.get(asset_key).freshness_policy
if freshness_policy is None:
continue

Expand Down
Loading

0 comments on commit b8d6f48

Please sign in to comment.