
Commit

[external-assets] Update AssetGraph accessor callsites to fetch nodes
[INTERNAL_BRANCH=sean/external-assets-asset-graph-nodes]
smackesey committed Mar 7, 2024
1 parent b3ba639 commit 620d0c2
Showing 18 changed files with 87 additions and 82 deletions.
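The refactor is mechanical across call sites: attribute-specific accessors on AssetGraph (get_auto_materialize_policy, get_backfill_policy, get_freshness_policy, is_external, and friends) are replaced by fetching the asset's node once with asset_graph.get(asset_key) and reading plain properties off it, and has_asset(asset_key) becomes has(asset_key). Below is a minimal sketch of the before/after pattern; the AssetNode and AssetGraph classes here are simplified stand-ins for illustration, not the real internal types.

```python
from typing import Dict, Optional


class AssetNode:
    """Simplified stand-in for the per-asset node object returned by get()."""

    def __init__(self, key: str, auto_materialize_policy: Optional[str] = None):
        self.key = key
        self.auto_materialize_policy = auto_materialize_policy


class AssetGraph:
    """Simplified stand-in exposing the node-based accessors used in this commit."""

    def __init__(self, nodes: Dict[str, AssetNode]):
        self._nodes = nodes

    def has(self, key: str) -> bool:  # replaces has_asset(key)
        return key in self._nodes

    def get(self, key: str) -> AssetNode:  # replaces per-attribute getters
        return self._nodes[key]


graph = AssetGraph({"my_asset": AssetNode("my_asset", "eager")})

# Before: policy = asset_graph.get_auto_materialize_policy(asset_key)
# After: fetch the node once, then read an ordinary property.
if graph.has("my_asset"):
    policy = graph.get("my_asset").auto_materialize_policy
```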
(file path not shown)
@@ -53,7 +53,7 @@ def get_implicit_auto_materialize_policy(
     asset_key: AssetKey, asset_graph: BaseAssetGraph
 ) -> Optional[AutoMaterializePolicy]:
     """For backcompat with pre-auto materialize policy graphs, assume a default scope of 1 day."""
-    auto_materialize_policy = asset_graph.get_auto_materialize_policy(asset_key)
+    auto_materialize_policy = asset_graph.get(asset_key).auto_materialize_policy
     if auto_materialize_policy is None:
         time_partitions_def = get_time_partitions_def(asset_graph.get_partitions_def(asset_key))
         if time_partitions_def is None:
@@ -147,9 +147,7 @@ def auto_materialize_asset_keys_and_parents(self) -> AbstractSet[AssetKey]:
     @property
     def asset_records_to_prefetch(self) -> Sequence[AssetKey]:
         return [
-            key
-            for key in self.auto_materialize_asset_keys_and_parents
-            if self.asset_graph.has_asset(key)
+            key for key in self.auto_materialize_asset_keys_and_parents if self.asset_graph.has(key)
         ]

     @property
@@ -192,7 +190,7 @@ def evaluate_asset(
         """
         # convert the legacy AutoMaterializePolicy to an Evaluator
         asset_condition = check.not_none(
-            self.asset_graph.get_auto_materialize_policy(asset_key)
+            self.asset_graph.get(asset_key).auto_materialize_policy
         ).to_asset_condition()

         asset_cursor = self.cursor.get_previous_evaluation_state(asset_key)
@@ -434,7 +432,7 @@ def build_run_requests_with_backfill_policies(
             run_requests.append(RunRequest(asset_selection=list(asset_keys), tags=tags))
         else:
             backfill_policies = {
-                check.not_none(asset_graph.get_backfill_policy(asset_key))
+                check.not_none(asset_graph.get(asset_key).backfill_policy)
                 for asset_key in asset_keys
             }
             if len(backfill_policies) == 1:
@@ -453,7 +451,7 @@
             else:
                 # if backfill policies are different, we need to backfill them separately
                 for asset_key in asset_keys:
-                    backfill_policy = asset_graph.get_backfill_policy(asset_key)
+                    backfill_policy = asset_graph.get(asset_key).backfill_policy
                     run_requests.extend(
                         _build_run_requests_with_backfill_policy(
                             [asset_key],
@@ -567,7 +565,7 @@ def get_auto_observe_run_requests(
     assets_to_auto_observe: Set[AssetKey] = set()
     for asset_key in auto_observe_asset_keys:
         last_observe_request_timestamp = last_observe_request_timestamp_by_asset_key.get(asset_key)
-        auto_observe_interval_minutes = asset_graph.get_auto_observe_interval_minutes(asset_key)
+        auto_observe_interval_minutes = asset_graph.get(asset_key).auto_observe_interval_minutes

         if auto_observe_interval_minutes and (
             last_observe_request_timestamp is None
(file path not shown)
@@ -16,6 +16,7 @@
 )
 from dagster._core.definitions.events import AssetKey
 from dagster._core.definitions.freshness_policy import FreshnessPolicy
+from dagster._core.definitions.metadata import ArbitraryMetadataMapping
 from dagster._core.definitions.partition import PartitionsDefinition
 from dagster._core.definitions.partition_mapping import PartitionMapping
 from dagster._core.definitions.resolved_asset_deps import ResolvedAssetDependencies
@@ -62,6 +63,10 @@ def is_external(self) -> bool:
     def is_executable(self) -> bool:
         return self.assets_def.is_executable

+    @property
+    def metadata(self) -> ArbitraryMetadataMapping:
+        return self.assets_def.metadata_by_key.get(self.key, {})
+
     @property
     def is_partitioned(self) -> bool:
         return self.assets_def.partitions_def is not None
@@ -246,9 +251,6 @@ def get_execution_set_asset_and_check_keys(
             asset_unit_keys if asset_or_check_key in asset_unit_keys else {asset_or_check_key}
         )

-    def get_assets_def(self, asset_key: AssetKey) -> AssetsDefinition:
-        return self.get(asset_key).assets_def
-
     @property
     @cached_method
     def assets_defs(self) -> Sequence[AssetsDefinition]:
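The hunks above add a metadata property to the asset node and drop the get_assets_def convenience wrapper, since callers can now reach the definition through the node itself. A rough sketch of the delegation pattern follows, using simplified stand-in types rather than the real AssetsDefinition:

```python
from typing import Any, Dict, Mapping


class AssetsDefinition:
    """Simplified stand-in; the real class carries much more state."""

    def __init__(self, metadata_by_key: Dict[str, Mapping[str, Any]]):
        self.metadata_by_key = metadata_by_key


class AssetNode:
    def __init__(self, key: str, assets_def: AssetsDefinition):
        self.key = key
        self.assets_def = assets_def  # callers now use graph.get(key).assets_def

    @property
    def metadata(self) -> Mapping[str, Any]:
        # Falls back to an empty mapping when no metadata is registered for
        # this key, mirroring metadata_by_key.get(self.key, {}) in the diff.
        return self.assets_def.metadata_by_key.get(self.key, {})


node = AssetNode("my_asset", AssetsDefinition({"my_asset": {"owner": "data-eng"}}))
assert node.metadata == {"owner": "data-eng"}
```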
(file path not shown)
@@ -122,9 +122,10 @@ def _compare_base_and_branch_assets(self, asset_key: "AssetKey") -> Sequence[ChangeReason]:
             return [ChangeReason.NEW]

         changes = []
-        if self.branch_asset_graph.get_code_version(
-            asset_key
-        ) != self.base_asset_graph.get_code_version(asset_key):
+        if (
+            self.branch_asset_graph.get(asset_key).code_version
+            != self.base_asset_graph.get(asset_key).code_version
+        ):
             changes.append(ChangeReason.CODE_VERSION)

         if self.branch_asset_graph.get_parents(asset_key) != self.base_asset_graph.get_parents(
(file path not shown)
@@ -769,8 +769,8 @@ def evaluate_for_asset(
             # ignore missing or unexecutable assets, which will never have a materialization or
             # observation
             if not (
-                context.asset_graph.has_asset(parent.asset_key)
-                and context.asset_graph.is_executable(parent.asset_key)
+                context.asset_graph.has(parent.asset_key)
+                and context.asset_graph.get(parent.asset_key).is_executable
             ):
                 continue
             if not context.instance_queryer.asset_partition_has_materialization_or_observation(
13 changes: 6 additions & 7 deletions python_modules/dagster/dagster/_core/definitions/data_time.py
@@ -177,8 +177,7 @@ def _upstream_records_by_key(

         for parent_key in self.asset_graph.get_parents(asset_key):
             if not (
-                self.asset_graph.has_asset(parent_key)
-                and self.asset_graph.is_executable(parent_key)
+                self.asset_graph.has(parent_key) and self.asset_graph.get(parent_key).is_executable
             ):
                 continue

@@ -319,7 +318,7 @@ def _calculate_data_time_by_key(
                 cursor=record_id,
                 partitions_def=partitions_def,
             )
-        elif self.asset_graph.is_observable(asset_key):
+        elif self.asset_graph.get(asset_key).is_observable:
             return self._calculate_data_time_by_key_observable_source(
                 asset_key=asset_key,
                 record_id=record_id,
@@ -533,18 +532,18 @@ def get_minutes_overdue(
         asset_key: AssetKey,
         evaluation_time: datetime.datetime,
     ) -> Optional[FreshnessMinutes]:
-        freshness_policy = self.asset_graph.get_freshness_policy(asset_key)
-        if freshness_policy is None:
+        asset = self.asset_graph.get(asset_key)
+        if asset.freshness_policy is None:
             raise DagsterInvariantViolationError(
                 "Cannot calculate minutes late for asset without a FreshnessPolicy"
             )

-        if self.asset_graph.is_observable(asset_key):
+        if asset.is_observable:
             current_data_time = self._get_source_data_time(asset_key, current_time=evaluation_time)
         else:
             current_data_time = self.get_current_data_time(asset_key, current_time=evaluation_time)

-        return freshness_policy.minutes_overdue(
+        return asset.freshness_policy.minutes_overdue(
             data_time=current_data_time,
             evaluation_time=evaluation_time,
         )
40 changes: 23 additions & 17 deletions python_modules/dagster/dagster/_core/definitions/data_version.py
@@ -423,13 +423,14 @@ def get_current_data_version(
     def _get_status(self, key: "AssetKeyPartitionKey") -> StaleStatus:
         # The status loader does not support querying for the stale status of a
         # partitioned asset without specifying a partition, so we return here.
-        if self.asset_graph.is_partitioned(key.asset_key) and not key.partition_key:
+        asset = self.asset_graph.get(key.asset_key)
+        if asset.is_partitioned and not key.partition_key:
             return StaleStatus.FRESH
         else:
             current_version = self._get_current_data_version(key=key)
             if current_version == NULL_DATA_VERSION:
                 return StaleStatus.MISSING
-            elif self.asset_graph.is_external(key.asset_key):
+            elif asset.is_external:
                 return StaleStatus.FRESH
             else:
                 causes = self._get_stale_causes(key=key)
@@ -440,9 +441,10 @@ def _get_stale_causes(self, key: "AssetKeyPartitionKey") -> Sequence[StaleCause]:
         # Querying for the stale status of a partitioned asset without specifying a partition key
         # is strictly speaking undefined, but we return an empty list here (from which FRESH status
         # is inferred) for backcompat.
-        if self.asset_graph.is_partitioned(key.asset_key) and not key.partition_key:
+        asset = self.asset_graph.get(key.asset_key)
+        if asset.is_partitioned and not key.partition_key:
             return []
-        elif self.asset_graph.is_external(key.asset_key):
+        elif asset.is_external:
             return []
         else:
             current_version = self._get_current_data_version(key=key)
@@ -454,21 +456,22 @@ def _get_stale_causes(self, key: "AssetKeyPartitionKey") -> Sequence[StaleCause]:
             )

     def _is_dep_updated(self, provenance: DataProvenance, dep_key: "AssetKeyPartitionKey") -> bool:
+        dep_asset = self.asset_graph.get(dep_key.asset_key)
         if dep_key.partition_key is None:
             current_data_version = self._get_current_data_version(key=dep_key)
             return provenance.input_data_versions[dep_key.asset_key] != current_data_version
         else:
             cursor = provenance.input_storage_ids[dep_key.asset_key]
             updated_record = self._instance.get_latest_data_version_record(
                 dep_key.asset_key,
-                self.asset_graph.is_external(dep_key.asset_key),
+                dep_asset.is_external,
                 dep_key.partition_key,
                 after_cursor=cursor,
             )
             if updated_record:
                 previous_record = self._instance.get_latest_data_version_record(
                     dep_key.asset_key,
-                    self.asset_graph.is_external(dep_key.asset_key),
+                    dep_asset.is_external,
                     dep_key.partition_key,
                     before_cursor=cursor + 1 if cursor else None,
                 )
@@ -485,7 +488,7 @@ def _is_dep_updated(self, provenance: DataProvenance, dep_key: "AssetKeyPartitionKey") -> bool:
     def _get_stale_causes_materialized(self, key: "AssetKeyPartitionKey") -> Iterator[StaleCause]:
         from dagster._core.definitions.events import AssetKeyPartitionKey

-        code_version = self.asset_graph.get_code_version(key.asset_key)
+        code_version = self.asset_graph.get(key.asset_key).code_version
         provenance = self._get_current_data_provenance(key=key)

         asset_deps = self.asset_graph.get_parents(key.asset_key)
@@ -513,6 +516,7 @@ def _get_stale_causes_materialized(self, key: "AssetKeyPartitionKey") -> Iterator[StaleCause]:
         # partition counts.
         partition_deps = self._get_partition_dependencies(key=key)
         for dep_key in sorted(partition_deps):
+            dep_asset = self.asset_graph.get(dep_key.asset_key)
             if self._get_status(key=dep_key) == StaleStatus.STALE:
                 yield StaleCause(
                     key,
@@ -532,9 +536,10 @@ def _get_stale_causes_materialized(self, key: "AssetKeyPartitionKey") -> Iterator[StaleCause]:
             # Currently we exclude assets downstream of AllPartitionMappings from stale
             # status logic due to potentially huge numbers of dependencies.
             elif self._is_dep_updated(provenance, dep_key):
-                report_data_version = self.asset_graph.get_code_version(
-                    dep_key.asset_key
-                ) is not None or self._is_current_data_version_user_provided(key=dep_key)
+                report_data_version = (
+                    dep_asset.code_version is not None
+                    or self._is_current_data_version_user_provided(key=dep_key)
+                )
                 yield StaleCause(
                     key,
                     StaleCauseCategory.DATA,
@@ -563,7 +568,7 @@ def _get_stale_causes_materialized(self, key: "AssetKeyPartitionKey") -> Iterator[StaleCause]:
             # timestamps instead of versions this should be removable eventually since
             # provenance is on all newer materializations. If dep is a source, then we'll never
             # provide a stale reason here.
-            elif not self.asset_graph.is_external(dep_key.asset_key):
+            elif not dep_asset.is_external:
                 dep_materialization = self._get_latest_data_version_record(key=dep_key)
                 if dep_materialization is None:
                     # The input must be new if it has no materialization
@@ -622,7 +627,7 @@ def _get_current_data_version(self, *, key: "AssetKeyPartitionKey") -> DataVersion:
         # Currently we can only use asset records, which are fetched in one shot, for non-source
         # assets. This is because the most recent AssetObservation is not stored on the AssetRecord.
         record = self._get_latest_data_version_record(key=key)
-        if self.asset_graph.is_external(key.asset_key) and record is None:
+        if self.asset_graph.get(key.asset_key).is_external and record is None:
             return DEFAULT_DATA_VERSION
         elif record is None:
             return NULL_DATA_VERSION
@@ -632,7 +637,7 @@ def _get_current_data_version(self, *, key: "AssetKeyPartitionKey") -> DataVersion:

     @cached_method
     def _is_current_data_version_user_provided(self, *, key: "AssetKeyPartitionKey") -> bool:
-        if self.asset_graph.is_external(key.asset_key):
+        if self.asset_graph.get(key.asset_key).is_external:
             return True
         else:
             provenance = self._get_current_data_provenance(key=key)
@@ -654,10 +659,11 @@ def _get_current_data_provenance(
     # are at the root of the graph (have no dependencies) or are downstream of a volatile asset.
     @cached_method
     def _is_volatile(self, *, key: "AssetKey") -> bool:
-        if self.asset_graph.is_external(key):
-            return self.asset_graph.is_observable(key)
+        asset = self.asset_graph.get(key)
+        if asset.is_external:
+            return asset.is_observable
         else:
-            deps = self.asset_graph.get_parents(key)
+            deps = asset.get_parents(key)
             return len(deps) == 0 or any(self._is_volatile(key=dep_key) for dep_key in deps)

     @cached_method
@@ -678,7 +684,7 @@ def _get_latest_data_version_record(
         # If an asset record is cached, all of its ancestors have already been cached.
         if (
             key.partition_key is None
-            and not self.asset_graph.is_external(key.asset_key)
+            and not self.asset_graph.get(key.asset_key).is_external
             and not self.instance_queryer.has_cached_asset_record(key.asset_key)
         ):
             ancestors = self.asset_graph.get_ancestors(key.asset_key, include_self=True)
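A pattern worth noting throughout the data_version.py hunks: when several attributes of the same asset are needed, the node is bound to a local (asset = self.asset_graph.get(key.asset_key)) and each branch then reads a plain property instead of issuing a separate graph call per attribute. A small sketch of that shape, with hypothetical names standing in for the real objects:

```python
def stale_status_sketch(asset_graph, key) -> str:
    # One node fetch up front...
    asset = asset_graph.get(key)
    # ...then cheap property reads on each branch, replacing repeated
    # asset_graph.is_partitioned(key) / asset_graph.is_external(key) calls.
    if asset.is_partitioned:
        return "fresh"
    elif asset.is_external:
        return "fresh"
    return "check-stale-causes"
```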
(file path not shown)
@@ -208,7 +208,7 @@ def freshness_evaluation_results_for_asset_key(
         execution_period,
         evaluation_data,
     ) = get_execution_period_and_evaluation_data_for_policies(
-        local_policy=context.asset_graph.get_freshness_policy(asset_key),
+        local_policy=context.asset_graph.get(asset_key).freshness_policy,
         policies=context.asset_graph.get_downstream_freshness_policies(asset_key=asset_key),
         effective_data_time=effective_data_time,
         current_time=current_time,
(file path not shown)
@@ -258,7 +258,7 @@ def _wrapped_fn(context: SensorEvaluationContext):

         minutes_late_by_key: Dict[AssetKey, Optional[float]] = {}
         for asset_key in monitored_keys:
-            freshness_policy = asset_graph.get_freshness_policy(asset_key)
+            freshness_policy = asset_graph.get(asset_key).freshness_policy
             if freshness_policy is None:
                 continue

(file path not shown)
@@ -259,7 +259,7 @@ def __init__(
         asset_graph = self._repository_def.asset_graph
         for asset_key in self._monitored_asset_keys:
             assets_def = (
-                asset_graph.get_assets_def(asset_key) if asset_graph.has_asset(asset_key) else None
+                asset_graph.get(asset_key).assets_def if asset_graph.has(asset_key) else None
             )
             self._assets_by_key[asset_key] = assets_def

@@ -683,8 +683,8 @@ def _get_asset(self, asset_key: AssetKey, fn_name: str) -> AssetsDefinition:
                 )
             else:
                 return asset_def
-        elif repo_def.asset_graph.has_asset(asset_key):
-            return repo_def.asset_graph.get_assets_def(asset_key)
+        elif repo_def.asset_graph.has(asset_key):
+            return repo_def.asset_graph.get(asset_key).assets_def
         else:
             raise DagsterInvalidInvocationError(
                 f"Asset key {asset_key} not monitored in sensor and does not exist in target jobs"
(file path not shown)
@@ -1255,7 +1255,7 @@ def execute_asset_backfill_iteration_inner(

     # check if all assets have backfill policies if any of them do, otherwise, raise error
     asset_backfill_policies = [
-        asset_graph.get_backfill_policy(asset_key)
+        asset_graph.get(asset_key).backfill_policy
         for asset_key in {
             asset_partition.asset_key for asset_partition in asset_partitions_to_request
         }
@@ -1318,9 +1318,9 @@ def can_run_with_parent(
     this tick.
     """
     parent_target_subset = target_subset.get_asset_subset(parent.asset_key, asset_graph)
-    parent_backfill_policy = asset_graph.get_backfill_policy(parent.asset_key)
+    parent_backfill_policy = asset_graph.get(parent.asset_key).backfill_policy
     candidate_target_subset = target_subset.get_asset_subset(candidate.asset_key, asset_graph)
-    candidate_backfill_policy = asset_graph.get_backfill_policy(candidate.asset_key)
+    candidate_backfill_policy = asset_graph.get(candidate.asset_key).backfill_policy
     partition_mapping = asset_graph.get_partition_mapping(
         candidate.asset_key, in_asset_key=parent.asset_key
     )
(file path not shown)
@@ -212,15 +212,14 @@ def _external_sensors(self) -> Dict[str, "ExternalSensor"]:
         default_sensor_asset_keys = set()

         for asset_key in asset_graph.materializable_asset_keys:
-            policy = asset_graph.get_auto_materialize_policy(asset_key)
-            if not policy:
+            if not asset_graph.get(asset_key).auto_materialize_policy:
                 continue

             if asset_key not in covered_asset_keys:
                 default_sensor_asset_keys.add(asset_key)

         for asset_key in asset_graph.observable_asset_keys:
-            if asset_graph.get_auto_observe_interval_minutes(asset_key) is None:
+            if asset_graph.get(asset_key).auto_observe_interval_minutes is None:
                 continue

             has_any_auto_observe_source_assets = True
4 changes: 2 additions & 2 deletions python_modules/dagster/dagster/_daemon/asset_daemon.py
@@ -683,14 +683,14 @@ def _process_auto_materialize_tick_generator(
         auto_materialize_asset_keys = {
             target_key
             for target_key in eligible_keys
-            if asset_graph.get_auto_materialize_policy(target_key) is not None
+            if asset_graph.get(target_key).auto_materialize_policy is not None
         }
         num_target_assets = len(auto_materialize_asset_keys)

         auto_observe_asset_keys = {
             key
             for key in eligible_keys
-            if asset_graph.get_auto_observe_interval_minutes(key) is not None
+            if asset_graph.get(key).auto_observe_interval_minutes is not None
         }
         num_auto_observe_assets = len(auto_observe_asset_keys)

(Remaining changed files not shown.)
