Skip to content

Commit

Permalink
Optimize querying with latest_run_required_tags (dagster-io#20333)
Browse files Browse the repository at this point in the history
## Summary & Motivation

Previously, if you had many upstream updated partitions, we'd need to do
an individual "latest_materialization_record" query for each of them.

Now, we take advantage of the fact that we do a single batched fetch of
the latest materialization storage ids for each upstream partition,
meaning we know the storage ids of all the records we need to fetch.
From there, we can do a single call (per upstream asset, assuming there
aren't more than 10,000 updated partitions) to fetch the records for
each of those storage ids.

After doing some perf-checking, this is not actually a silver bullet, as
if there is only a single upstream partition that needs to be checked,
the perf gain is pretty unimpressive. It's basically just that the query
time no longer scales very quickly in relation to the number of upstream
updated partitions. So before if there were 100 upstream updated
partitions that needed to get queried, that'd take like 20 seconds, and
now it'll still take ~1 second.

## How I Tested These Changes
  • Loading branch information
OwenKephart authored and nikomancy committed May 1, 2024
1 parent af1bdac commit 72cac0f
Showing 1 changed file with 31 additions and 8 deletions.
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import datetime
import os
from abc import ABC, abstractmethod, abstractproperty
from collections import defaultdict
from typing import (
TYPE_CHECKING,
AbstractSet,
Dict,
Iterable,
List,
Mapping,
NamedTuple,
Optional,
Expand Down Expand Up @@ -35,6 +37,7 @@
get_time_partitions_def,
)
from dagster._core.errors import DagsterInvariantViolationError
from dagster._core.event_api import AssetRecordsFilter
from dagster._core.storage.dagster_run import IN_PROGRESS_RUN_STATUSES, RunsFilter
from dagster._core.storage.tags import AUTO_MATERIALIZE_TAG
from dagster._serdes.serdes import (
Expand Down Expand Up @@ -445,22 +448,42 @@ def passes(
return asset_partitions

will_update_asset_partitions: Set[AssetKeyPartitionKey] = set()
storage_ids_to_fetch_by_key: Dict[AssetKey, List[int]] = defaultdict(list)

asset_partitions_by_latest_run_id: Dict[str, Set[AssetKeyPartitionKey]] = defaultdict(set)
for asset_partition in asset_partitions:
if context.will_update_asset_partition(asset_partition):
will_update_asset_partitions.add(asset_partition)
else:
record = context.instance_queryer.get_latest_materialization_or_observation_record(
asset_partition
latest_storage_id = (
context.instance_queryer.get_latest_materialization_or_observation_storage_id(
asset_partition=asset_partition
)
)
if latest_storage_id is not None:
storage_ids_to_fetch_by_key[asset_partition.asset_key].append(latest_storage_id)

if record is None:
raise RuntimeError(
f"No materialization record found for asset partition {asset_partition}"
)
asset_partitions_by_latest_run_id: Dict[str, Set[AssetKeyPartitionKey]] = defaultdict(set)

asset_partitions_by_latest_run_id[record.run_id].add(asset_partition)
step = int(os.getenv("DAGSTER_ASSET_DAEMON_RUN_TAGS_EVENT_FETCH_LIMIT", "1000"))

for asset_key, storage_ids_to_fetch in storage_ids_to_fetch_by_key.items():
for i in range(0, len(storage_ids_to_fetch), step):
storage_ids = storage_ids_to_fetch[i : i + step]
fetch_records = (
context.instance_queryer.instance.fetch_observations
if context.asset_graph.get(asset_key).is_observable
else context.instance_queryer.instance.fetch_materializations
)
for record in fetch_records(
records_filter=AssetRecordsFilter(
asset_key=asset_key,
storage_ids=storage_ids,
),
limit=step,
).records:
asset_partitions_by_latest_run_id[record.run_id].add(
AssetKeyPartitionKey(asset_key, record.partition_key)
)

if len(asset_partitions_by_latest_run_id) > 0:
run_ids_with_required_tags = context.instance_queryer.instance.get_run_ids(
Expand Down

0 comments on commit 72cac0f

Please sign in to comment.