Commit bc81704

Make the asset status cache do no additional DB work when the cursor is up to date and there are no in-progress runs materializing the asset (#21194)

Summary:
This PR speeds up the asset partition status cache in the (hopefully common) case where the cache is up to date and there are no in-progress runs currently materializing the asset. By leveraging the `last_planned_materialization_storage_id` field on `AssetEntry`, which some storages populate and others do not, we can add checks that short-circuit any further DB queries using only the information already fetched on the `AssetEntry`.
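
To illustrate the short-circuit (a minimal sketch under assumed names, not the PR's exact code): the cached status stays authoritative whenever its cursor is at or past the newest planned materialization, and in that case neither the event log nor the run store needs to be queried again.

```python
# Minimal sketch of the short-circuit; the helper and parameter names are
# assumptions for illustration, not the PR's actual code.
from typing import Optional


def can_skip_status_queries(
    cache_cursor: Optional[int],
    last_planned_storage_id: Optional[int],
) -> bool:
    """True when the cached status already covers every planned
    materialization, so no further DB queries are needed."""
    if last_planned_storage_id is None:
        # The storage does not populate the field, so there is no
        # short-circuit information available; fall back to querying.
        return False
    # Nothing was planned after the cache cursor, so the cached
    # failed/in-progress subsets are still valid as-is.
    return cache_cursor is not None and cache_cursor >= last_planned_storage_id
```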

## Summary & Motivation

## How I Tested These Changes
  • Loading branch information
gibsondan authored May 1, 2024
1 parent e3d3ff9 commit bc81704
Showing 14 changed files with 215 additions and 94 deletions.
@@ -42,11 +42,12 @@
 from dagster._core.remote_representation.code_location import CodeLocation
 from dagster._core.remote_representation.external import ExternalRepository
 from dagster._core.remote_representation.external_data import ExternalAssetNode
-from dagster._core.storage.event_log.base import AssetRecord
+from dagster._core.storage.batch_asset_record_loader import BatchAssetRecordLoader
 from dagster._core.storage.event_log.sql_event_log import get_max_event_records_limit
 from dagster._core.storage.partition_status_cache import (
     build_failed_and_in_progress_partition_subset,
     get_and_update_asset_status_cache_value,
+    get_last_planned_storage_id,
     get_materialized_multipartitions,
     get_validated_partition_keys,
     is_cacheable_partition_type,
@@ -414,7 +415,7 @@ def get_partition_subsets(
     instance: DagsterInstance,
     asset_key: AssetKey,
     dynamic_partitions_loader: DynamicPartitionsStore,
-    asset_record: Optional[AssetRecord],
+    batch_asset_record_loader: Optional[BatchAssetRecordLoader],
     partitions_def: Optional[PartitionsDefinition] = None,
 ) -> Tuple[Optional[PartitionsSubset], Optional[PartitionsSubset], Optional[PartitionsSubset]]:
     """Returns a tuple of PartitionSubset objects: the first is the materialized partitions,
@@ -431,7 +432,7 @@
             asset_key,
             partitions_def,
             dynamic_partitions_loader,
-            asset_record,
+            batch_asset_record_loader,
         )
         materialized_subset = (
             updated_cache_value.deserialize_materialized_partition_subsets(partitions_def)
@@ -470,8 +471,19 @@
             else partitions_def.empty_subset()
         )

+        if batch_asset_record_loader:
+            asset_record = batch_asset_record_loader.get_asset_record(asset_key)
+        else:
+            asset_record = next(iter(instance.get_asset_records(asset_keys=[asset_key])), None)
+
         failed_subset, in_progress_subset, _ = build_failed_and_in_progress_partition_subset(
-            instance, asset_key, partitions_def, dynamic_partitions_loader
+            instance,
+            asset_key,
+            partitions_def,
+            dynamic_partitions_loader,
+            last_planned_materialization_storage_id=get_last_planned_storage_id(
+                instance, asset_key, asset_record
+            ),
         )

         return materialized_subset, failed_subset, in_progress_subset
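
The body of `get_last_planned_storage_id` is not shown in this excerpt; a hedged sketch of its likely contract, inferred from this call site and the new storage property added further down in this commit:

```python
# Plausible shape of get_last_planned_storage_id; inferred, not the real body.
def get_last_planned_storage_id(instance, asset_key, asset_record):
    storage = instance.event_log_storage
    if storage.asset_records_have_last_planned_materialization_storage_id:
        # The storage populates the field, so the already-fetched record
        # is enough and no extra query is issued.
        return (
            asset_record.asset_entry.last_planned_materialization_storage_id
            if asset_record
            else None
        )
    # Otherwise a single event-log query for the latest planned
    # materialization event is still required (elided here).
    ...
```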
@@ -26,7 +26,7 @@
 from .external import ensure_valid_config, get_external_job_or_raise

 if TYPE_CHECKING:
-    from dagster._core.workspace.batch_asset_record_loader import BatchAssetRecordLoader
+    from dagster._core.storage.batch_asset_record_loader import BatchAssetRecordLoader

     from ..schema.asset_graph import GrapheneAssetLatestInfo
     from ..schema.errors import GrapheneRunNotFoundError
@@ -34,8 +34,8 @@
     ExternalTimeWindowPartitionsDefinitionData,
 )
 from dagster._core.snap.node import GraphDefSnap, OpDefSnap
+from dagster._core.storage.batch_asset_record_loader import BatchAssetRecordLoader
 from dagster._core.utils import is_valid_email
-from dagster._core.workspace.batch_asset_record_loader import BatchAssetRecordLoader
 from dagster._core.workspace.permissions import Permissions
 from dagster._utils.caching_instance_queryer import CachingInstanceQueryer

@@ -1140,11 +1140,7 @@ def resolve_assetPartitionStatuses(
             graphene_info.context.instance,
             asset_key,
             self._dynamic_partitions_loader,
-            (
-                self._asset_record_loader.get_asset_record(asset_key)
-                if self._asset_record_loader
-                else None
-            ),
+            self._asset_record_loader,
             partitions_def,
         )

@@ -1174,11 +1170,7 @@ def resolve_partitionStats(
             graphene_info.context.instance,
             asset_key,
             self._dynamic_partitions_loader,
-            (
-                self._asset_record_loader.get_asset_record(self._external_asset_node.asset_key)
-                if self._asset_record_loader
-                else None
-            ),
+            self._asset_record_loader,
             (
                 self._external_asset_node.partitions_def_data.get_partitions_definition()
                 if self._external_asset_node.partitions_def_data
@@ -1366,6 +1366,8 @@ def test_default_partitions(self, graphql_context: WorkspaceRequestContext) -> None:
         # Test that when partition a is materialized that the materialized partitions are a
         _create_partitioned_run(graphql_context, "partition_materialization_job", partition_key="a")

+        graphql_context.asset_record_loader.clear_cache()
+
         selector = infer_job_selector(graphql_context, "partition_materialization_job")
         result = execute_dagster_graphql(
             graphql_context,
@@ -1386,6 +1388,8 @@
         # Test that when partition c is materialized that the materialized partitions are a, c
         _create_partitioned_run(graphql_context, "partition_materialization_job", partition_key="c")

+        graphql_context.asset_record_loader.clear_cache()
+
         result = execute_dagster_graphql(
             graphql_context,
             GET_1D_ASSET_PARTITIONS,
@@ -1428,6 +1432,8 @@ def test_partition_stats(self, graphql_context: WorkspaceRequestContext):
tags={"fail": "true"},
)

graphql_context.asset_record_loader.clear_cache()

result = execute_dagster_graphql(
graphql_context,
GET_PARTITION_STATS,
@@ -1449,6 +1455,8 @@ def test_partition_stats(self, graphql_context: WorkspaceRequestContext):
tags={"fail": "true"},
)

graphql_context.asset_record_loader.clear_cache()

result = execute_dagster_graphql(
graphql_context,
GET_PARTITION_STATS,
@@ -1480,6 +1488,8 @@ def test_partition_stats(self, graphql_context: WorkspaceRequestContext):
         assert not result.errors
         assert result.data

+        graphql_context.asset_record_loader.clear_cache()
+
         stats_result = execute_dagster_graphql(
             graphql_context,
             GET_PARTITION_STATS,
@@ -1592,6 +1602,8 @@ def _get_datetime_float(dt_str):
graphql_context, "time_partitioned_assets_job", partition_key=time_0
)

graphql_context.asset_record_loader.clear_cache()

selector = infer_job_selector(graphql_context, "time_partitioned_assets_job")
result = execute_dagster_graphql(
graphql_context,
@@ -1616,6 +1628,8 @@ def _get_datetime_float(dt_str):
graphql_context, "time_partitioned_assets_job", partition_key=time_2
)

graphql_context.asset_record_loader.clear_cache()

result = execute_dagster_graphql(
graphql_context,
GET_1D_ASSET_PARTITIONS,
@@ -1642,6 +1656,8 @@ def _get_datetime_float(dt_str):
graphql_context, "time_partitioned_assets_job", partition_key=time_1
)

graphql_context.asset_record_loader.clear_cache()

result = execute_dagster_graphql(
graphql_context,
GET_1D_ASSET_PARTITIONS,
@@ -1760,11 +1776,11 @@ def get_response_by_asset(response):
assert result["asset_1"]["latestRun"] is None
assert result["asset_1"]["latestMaterialization"] is None

graphql_context.asset_record_loader.clear_cache()

# Test with 1 run on all assets
first_run_id = _create_run(graphql_context, "failure_assets_job")

graphql_context.asset_record_loader.clear_cache()

result = execute_dagster_graphql(
graphql_context,
GET_ASSET_LATEST_RUN_STATS,
@@ -1789,15 +1805,15 @@ def get_response_by_asset(response):
assert result["asset_3"]["latestRun"]["id"] == first_run_id
assert result["asset_3"]["latestMaterialization"] is None

graphql_context.asset_record_loader.clear_cache()

# Confirm that asset selection is respected
run_id = _create_run(
graphql_context,
"failure_assets_job",
asset_selection=[{"path": ["asset_3"]}],
)

graphql_context.asset_record_loader.clear_cache()

result = execute_dagster_graphql(
graphql_context,
GET_ASSET_LATEST_RUN_STATS,
@@ -2145,6 +2161,9 @@ def _get_date_float(dt_str):
MultiPartitionKey({"date": partition_field[0], "ab": partition_field[1]}),
asset_selection=[AssetKey("multipartitions_1")],
)

graphql_context.asset_record_loader.clear_cache()

result = execute_dagster_graphql(
graphql_context,
GET_2D_ASSET_PARTITIONS,
@@ -2262,6 +2281,9 @@ def _get_date_float(dt_str):
MultiPartitionKey({"date": partition_field[0], "ab": partition_field[1]}),
tags={"fail": "true"},
)

graphql_context.asset_record_loader.clear_cache()

result = execute_dagster_graphql(
graphql_context,
GET_2D_ASSET_PARTITIONS,
@@ -2308,6 +2330,9 @@ def _get_date_float(dt_str):
MultiPartitionKey({"date": partition_field[0], "ab": partition_field[1]}),
tags={"fail": "true"},
)

graphql_context.asset_record_loader.clear_cache()

result = execute_dagster_graphql(
graphql_context,
GET_2D_ASSET_PARTITIONS,
@@ -2337,6 +2362,9 @@ def _get_date_float(dt_str):
"multipartitions_fail_job",
MultiPartitionKey({"date": partition_field[0], "ab": partition_field[1]}),
)

graphql_context.asset_record_loader.clear_cache()

result = execute_dagster_graphql(
graphql_context,
GET_2D_ASSET_PARTITIONS,
@@ -2375,6 +2403,9 @@ def test_dynamic_dim_in_multipartitions_def(self, graphql_context: WorkspaceRequestContext):
"dynamic_in_multipartitions_success_job",
MultiPartitionKey({"dynamic": "1", "static": "a"}),
)

graphql_context.asset_record_loader.clear_cache()

counter = Counter()
traced_counter.set(counter)
result = execute_dagster_graphql(
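A note on the repeated `clear_cache()` additions in these tests: the request context now serves asset records from a shared `BatchAssetRecordLoader`, so a test that launches a run mid-test must clear the loader before its next GraphQL query so that fresh records are fetched rather than stale, pre-run ones being served from the cache.
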
@@ -2,8 +2,7 @@
 from typing import NamedTuple, Optional

 import dagster._check as check
-from dagster import EventLogEntry
-from dagster._core.events import DagsterEventType
+from dagster._core.events.log import DagsterEventType, EventLogEntry
 from dagster._serdes.serdes import deserialize_value
 from dagster._utils import datetime_as_float

@@ -1,12 +1,12 @@
-from typing import Iterable, Mapping, Optional, Sequence, Set
+from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Set

-from dagster import (
-    DagsterInstance,
-    _check as check,
-)
+import dagster._check as check
 from dagster._core.definitions.events import AssetKey
 from dagster._core.events.log import EventLogEntry
-from dagster._core.storage.event_log.base import AssetRecord
+from dagster._core.instance import DagsterInstance

+if TYPE_CHECKING:
+    from dagster._core.storage.event_log.base import AssetRecord
+

 class BatchAssetRecordLoader:
@@ -17,21 +17,21 @@ class BatchAssetRecordLoader:
     def __init__(self, instance: DagsterInstance, asset_keys: Iterable[AssetKey]):
         self._instance = instance
         self._unfetched_asset_keys: Set[AssetKey] = set(asset_keys)
-        self._asset_records: Mapping[AssetKey, Optional[AssetRecord]] = {}
+        self._asset_records: Mapping[AssetKey, Optional["AssetRecord"]] = {}

     def add_asset_keys(self, asset_keys: Iterable[AssetKey]):
         unfetched_asset_keys = set(asset_keys).difference(self._asset_records.keys())
         self._unfetched_asset_keys = self._unfetched_asset_keys.union(unfetched_asset_keys)

-    def get_asset_record(self, asset_key: AssetKey) -> Optional[AssetRecord]:
+    def get_asset_record(self, asset_key: AssetKey) -> Optional["AssetRecord"]:
         if asset_key not in self._asset_records and asset_key not in self._unfetched_asset_keys:
             check.failed(
                 f"Asset key {asset_key} not recognized for this loader. Expected one of:"
                 f" {self._unfetched_asset_keys.union(self._asset_records.keys())}"
             )

         if asset_key in self._unfetched_asset_keys:
-            self._fetch()
+            self.fetch()

         return self._asset_records.get(asset_key)

@@ -40,7 +40,10 @@ def clear_cache(self):
         self._unfetched_asset_keys = self._unfetched_asset_keys.union(self._asset_records.keys())
         self._asset_records = {}

-    def get_asset_records(self, asset_keys: Sequence[AssetKey]) -> Sequence[AssetRecord]:
+    def has_cached_asset_record(self, asset_key: AssetKey):
+        return asset_key in self._asset_records
+
+    def get_asset_records(self, asset_keys: Sequence[AssetKey]) -> Sequence["AssetRecord"]:
         records = [self.get_asset_record(asset_key) for asset_key in asset_keys]
         return [record for record in records if record]

@@ -65,7 +68,7 @@ def get_latest_observation_for_asset_key(self, asset_key: AssetKey) -> Optional[

         return asset_record.asset_entry.last_observation

-    def _fetch(self) -> None:
+    def fetch(self) -> None:
         if not self._unfetched_asset_keys:
             return

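With `fetch` now public, callers can warm the loader's cache explicitly before resolving. A small usage sketch (the instance and asset keys are illustrative):

```python
# Usage sketch; the instance and asset keys are illustrative.
from dagster import AssetKey, DagsterInstance
from dagster._core.storage.batch_asset_record_loader import BatchAssetRecordLoader

instance = DagsterInstance.get()
loader = BatchAssetRecordLoader(instance, asset_keys=[AssetKey("a"), AssetKey("b")])

loader.fetch()  # one batched get_asset_records query for all unfetched keys
record = loader.get_asset_record(AssetKey("a"))  # now served from the cache
assert loader.has_cached_asset_record(AssetKey("a"))

loader.clear_cache()  # e.g. after launching runs; keys become unfetched again
```
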
16 changes: 16 additions & 0 deletions python_modules/dagster/dagster/_core/storage/event_log/base.py
@@ -60,6 +60,8 @@ class AssetEntry(
             # This is an optional field which can be used for more performant last observation
             # queries if the underlying storage supports it
             ("last_observation_record", Optional[EventLogRecord]),
+            ("last_planned_materialization_storage_id", Optional[int]),
+            ("last_planned_materialization_run_id", Optional[str]),
         ],
     )
 ):
@@ -71,6 +73,8 @@ def __new__(
         asset_details: Optional[AssetDetails] = None,
         cached_status: Optional["AssetStatusCacheValue"] = None,
         last_observation_record: Optional[EventLogRecord] = None,
+        last_planned_materialization_storage_id: Optional[int] = None,
+        last_planned_materialization_run_id: Optional[str] = None,
     ):
         from dagster._core.storage.partition_status_cache import AssetStatusCacheValue

@@ -88,6 +92,14 @@ def __new__(
             last_observation_record=check.opt_inst_param(
                 last_observation_record, "last_observation_record", EventLogRecord
             ),
+            last_planned_materialization_storage_id=check.opt_int_param(
+                last_planned_materialization_storage_id,
+                "last_planned_materialization_storage_id",
+            ),
+            last_planned_materialization_run_id=check.opt_str_param(
+                last_planned_materialization_run_id,
+                "last_planned_materialization_run_id",
+            ),
         )

     @property
@@ -293,6 +305,10 @@ def get_asset_records(
     ) -> Sequence[AssetRecord]:
         pass

+    @property
+    def asset_records_have_last_planned_materialization_storage_id(self) -> bool:
+        return False
+
     @abstractmethod
     def has_asset_key(self, asset_key: AssetKey) -> bool:
         pass
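
The property defaults to `False`, so storages that do not persist the new column keep their current behavior. A hedged sketch of how a storage that does persist it would opt in (the subclass here is hypothetical):

```python
# Hypothetical storage subclass; shows only the opt-in flag.
from dagster._core.storage.event_log.sql_event_log import SqlEventLogStorage


class MyEventLogStorage(SqlEventLogStorage):
    @property
    def asset_records_have_last_planned_materialization_storage_id(self) -> bool:
        # This storage fills AssetEntry.last_planned_materialization_storage_id
        # when constructing asset records, so callers can rely on it.
        return True
```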
@@ -1191,6 +1191,7 @@ def _construct_asset_record_from_row(
                         if can_cache_asset_status_data
                         else None
                     ),
+                    last_planned_materialization_storage_id=None,
                 ),
             )
         else: