add support for run-granularity op concurrency
prha committed Jan 10, 2025
1 parent 6927f22 commit 8782d96
Showing 10 changed files with 255 additions and 65 deletions.
7 changes: 7 additions & 0 deletions python_modules/dagster/dagster/_core/instance/config.py
@@ -161,6 +161,13 @@ def validate_concurrency_config(dagster_config_dict: Mapping[str, Any]):
             [],
             None,
         )
+    granularity = concurrency_config.get("pools", {}).get("granularity")
+    if granularity and granularity not in ["run", "op"]:
+        raise DagsterInvalidConfigError(
+            f"Found value `{granularity}` for `granularity`. Expected value 'run' or 'op'.",
+            [],
+            None,
+        )
 
     if "run_queue" in dagster_config_dict:
         verify_config_match(
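For context, the new `granularity` setting rides along in the instance's concurrency config. A minimal sketch of the accepted shape, inferred from the `concurrency_config.get("pools", {}).get("granularity")` lookup above (the top-level `concurrency` key is an assumption based on the `run_queue` check in the same validator):

```python
# Hypothetical instance-config fragments; only "run" and "op" pass validation.
valid = {"concurrency": {"pools": {"granularity": "run"}}}
also_valid = {"concurrency": {"pools": {"granularity": "op"}}}

# Anything else now fails fast, e.g.:
#   validate_concurrency_config({"concurrency": {"pools": {"granularity": "step"}}})
#   -> DagsterInvalidConfigError: Found value `step` for `granularity`. ...
```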
190 changes: 134 additions & 56 deletions python_modules/dagster/dagster/_core/op_concurrency_limits_counter.py
@@ -1,18 +1,23 @@
 import os
 from collections import defaultdict
 from collections.abc import Mapping, Sequence
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
 from dagster._core.instance import DagsterInstance
+from dagster._core.run_coordinator.queued_run_coordinator import PoolGranularity
 from dagster._core.snap.execution_plan_snapshot import ExecutionPlanSnapshot
 from dagster._core.storage.dagster_run import (
     IN_PROGRESS_RUN_STATUSES,
     DagsterRun,
     DagsterRunStatus,
     RunOpConcurrency,
     RunRecord,
 )
 from dagster._time import get_current_timestamp
 
+if TYPE_CHECKING:
+    from dagster._utils.concurrency import ConcurrencyKeyInfo
+
 
 def compute_run_op_concurrency_info_for_snapshot(
     plan_snapshot: ExecutionPlanSnapshot,
@@ -23,21 +28,26 @@ def compute_run_op_concurrency_info_for_snapshot(
     root_step_keys = set(
         [step_key for step_key, deps in plan_snapshot.step_deps.items() if not deps]
     )
-    pool_counts: Mapping[str, int] = defaultdict(int)
+    root_pool_counts: Mapping[str, int] = defaultdict(int)
+    all_pools: set[str] = set()
     has_unconstrained_root_nodes = False
     for step in plan_snapshot.steps:
-        if step.key not in root_step_keys:
-            continue
-        if step.pool is None:
+        if step.pool is None and step.key in root_step_keys:
             has_unconstrained_root_nodes = True
+        elif step.pool is None:
+            continue
+        elif step.key in root_step_keys:
+            root_pool_counts[step.pool] += 1
+            all_pools.add(step.pool)
         else:
-            pool_counts[step.pool] += 1
+            all_pools.add(step.pool)
 
-    if len(pool_counts) == 0:
+    if len(all_pools) == 0:
         return None
 
     return RunOpConcurrency(
-        root_key_counts=dict(pool_counts),
+        all_pools=all_pools,
+        root_key_counts=dict(root_pool_counts),
         has_unconstrained_root_nodes=has_unconstrained_root_nodes,
     )
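To make the new bookkeeping concrete, here is a standalone sketch of the classification loop above (invented step keys and pools; plain tuples stand in for the snapshot's step objects):

```python
from collections import defaultdict

# (step_key, pool, is_root) tuples standing in for plan_snapshot.steps
steps = [("load", "db", True), ("train", "gpu", False), ("report", None, True)]

root_pool_counts = defaultdict(int)
all_pools = set()
has_unconstrained_root_nodes = False
for key, pool, is_root in steps:
    if pool is None and is_root:
        has_unconstrained_root_nodes = True  # root step with no pool
    elif pool is None:
        continue  # unconstrained non-root step: irrelevant either way
    elif is_root:
        root_pool_counts[pool] += 1  # feeds root_key_counts
        all_pools.add(pool)
    else:
        all_pools.add(pool)  # non-root pools still matter at run granularity

assert dict(root_pool_counts) == {"db": 1}
assert all_pools == {"db", "gpu"}
assert has_unconstrained_root_nodes  # "report" is a root step with no pool
```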

@@ -49,12 +59,14 @@ def __init__(
         runs: Sequence[DagsterRun],
         in_progress_run_records: Sequence[RunRecord],
         slot_count_offset: int = 0,
+        pool_granularity: PoolGranularity = PoolGranularity.OP,
     ):
         self._root_pools_by_run = {}
-        self._concurrency_info_by_pool = {}
+        self._concurrency_info_by_key: dict[str, "ConcurrencyKeyInfo"] = {}
         self._launched_pool_counts = defaultdict(int)
         self._in_progress_pool_counts = defaultdict(int)
         self._slot_count_offset = slot_count_offset
+        self._pool_granularity = pool_granularity
         self._in_progress_run_ids: set[str] = set(
             [record.dagster_run.run_id for record in in_progress_run_records]
         )
@@ -66,89 +78,158 @@ def __init__(
         # priority order
         self._fetch_concurrency_info(instance, runs)
 
-        # fetch all the outstanding concurrency keys for in-progress runs
+        # fetch all the outstanding pools for in-progress runs
         self._process_in_progress_runs(in_progress_run_records)
 
     def _fetch_concurrency_info(self, instance: DagsterInstance, queued_runs: Sequence[DagsterRun]):
-        # fetch all the concurrency slot information for the root concurrency keys of all the queued
-        # runs
-        all_run_pools = set()
+        # fetch all the concurrency slot information for all the queued runs
+        all_pools = set()
 
         configured_pools = instance.event_log_storage.get_concurrency_keys()
         for run in queued_runs:
             if run.run_op_concurrency:
-                all_run_pools.update(run.run_op_concurrency.root_key_counts.keys())
-
-        for key in all_run_pools:
-            if key is None:
+                # if using run granularity, consider all the concurrency keys required by the run
+                # if using op granularity, consider only the root keys
+                run_pools = (
+                    run.run_op_concurrency.root_key_counts.keys()
+                    if self._pool_granularity == PoolGranularity.OP
+                    else run.run_op_concurrency.all_pools or []
+                )
+                all_pools.update(run_pools)
+
+        for pool in all_pools:
+            if pool is None:
                 continue
 
-            if key not in configured_pools:
-                instance.event_log_storage.initialize_concurrency_limit_to_default(key)
+            if pool not in configured_pools:
+                instance.event_log_storage.initialize_concurrency_limit_to_default(pool)
 
-            self._concurrency_info_by_pool[key] = instance.event_log_storage.get_concurrency_info(
-                key
+            self._concurrency_info_by_key[pool] = instance.event_log_storage.get_concurrency_info(
+                pool
             )
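The granularity switch therefore changes which pools get their limits fetched and initialized. A small sketch with invented values, assuming a queued run whose `RunOpConcurrency` recorded `root_key_counts={"db": 2}` and `all_pools={"db", "gpu"}`:

```python
root_key_counts = {"db": 2}  # pools on root steps, with step counts
all_pools = {"db", "gpu"}    # every pool appearing anywhere in the plan

def pools_to_fetch(granularity: str) -> set:
    # mirrors the run_pools conditional above
    return set(root_key_counts) if granularity == "op" else set(all_pools)

assert pools_to_fetch("op") == {"db"}          # only root pools gate dequeuing
assert pools_to_fetch("run") == {"db", "gpu"}  # every pool must have a free slot
```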

-    def _should_allocate_slots_for_root_pools(self, record: RunRecord):
+    def _should_allocate_slots_for_in_progress_run(self, record: RunRecord):
         if not record.dagster_run.run_op_concurrency:
             return False
 
         status = record.dagster_run.status
         if status not in IN_PROGRESS_RUN_STATUSES:
             return False
 
+        if self._pool_granularity == PoolGranularity.RUN:
+            return True
+
         if status == DagsterRunStatus.STARTING:
             return True
 
         if status != DagsterRunStatus.STARTED or not record.start_time:
             return False
 
         time_elapsed = get_current_timestamp() - record.start_time
         if time_elapsed < self._started_run_pools_allotted_seconds:
             return True
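Spelling out that branch order as a table of outcomes (a sketch: plain strings stand in for the status enum, `record.start_time` handling is simplified, and the grace-window rationale in the comment is my reading rather than stated in the diff):

```python
IN_PROGRESS = {"STARTING", "STARTED", "CANCELING"}  # stand-in for IN_PROGRESS_RUN_STATUSES

def should_allocate(granularity: str, status: str, seconds_running: float, allotted: float) -> bool:
    if status not in IN_PROGRESS:
        return False
    if granularity == "run":
        return True  # run granularity: pool slots are held for the run's whole lifetime
    if status == "STARTING":
        return True
    # op granularity: count a started run's root pools only during a grace window,
    # after which its steps' own pending slot claims are expected to account for it
    return status == "STARTED" and seconds_running < allotted

assert should_allocate("run", "STARTED", 3600.0, allotted=5.0)     # still counted
assert not should_allocate("op", "STARTED", 3600.0, allotted=5.0)  # window elapsed
assert should_allocate("op", "STARTING", 3600.0, allotted=5.0)
```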

+    def _slot_counts_for_run(self, run: DagsterRun) -> Mapping[str, int]:
+        if not run.run_op_concurrency:
+            return {}
+
+        if self._pool_granularity == PoolGranularity.OP:
+            return {**run.run_op_concurrency.root_key_counts}
+
+        else:
+            assert self._pool_granularity == PoolGranularity.RUN
+            return {pool: 1 for pool in run.run_op_concurrency.all_pools or []}
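The slot math this helper encodes, with invented counts: op granularity reserves one slot per root step, run granularity exactly one slot per pool the run touches:

```python
# Invented RunOpConcurrency payload:
root_key_counts = {"db": 2, "gpu": 1}
all_pools = {"db", "gpu", "io"}

op_slots = {**root_key_counts}               # op granularity: per root step
run_slots = {pool: 1 for pool in all_pools}  # run granularity: one per pool

assert op_slots == {"db": 2, "gpu": 1}
assert run_slots == {"db": 1, "gpu": 1, "io": 1}
```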

     def _process_in_progress_runs(self, in_progress_records: Sequence[RunRecord]):
         for record in in_progress_records:
-            if (
-                self._should_allocate_slots_for_root_pools(record)
-                and record.dagster_run.run_op_concurrency
-            ):
-                for (
-                    pool,
-                    count,
-                ) in record.dagster_run.run_op_concurrency.root_key_counts.items():
-                    self._in_progress_pool_counts[pool] += count
+            if not self._should_allocate_slots_for_in_progress_run(record):
+                continue
+
+            for pool, count in self._slot_counts_for_run(record.dagster_run).items():
+                self._in_progress_pool_counts[pool] += count
 
     def is_blocked(self, run: DagsterRun) -> bool:
-        # if any of the ops in the run can make progress (not blocked by concurrency keys), we
-        # should dequeue
-        if not run.run_op_concurrency or run.run_op_concurrency.has_unconstrained_root_nodes:
+        # if there exists a root node that is not concurrency blocked, we should dequeue.
+        if not run.run_op_concurrency:
             return False
 
-        for pool in run.run_op_concurrency.root_key_counts.keys():
-            if pool not in self._concurrency_info_by_pool:
-                # there is no concurrency limit set for this key, we should dequeue
-                return False
-
-            key_info = self._concurrency_info_by_pool[pool]
-            available_count = (
-                key_info.slot_count
-                - len(key_info.pending_steps)
-                - self._launched_pool_counts[pool]
-                - self._in_progress_pool_counts[pool]
-            )
-            if available_count > -1 * self._slot_count_offset:
-                # there exists a root concurrency key that is not blocked, we should dequeue
-                return False
+        if (
+            self._pool_granularity == PoolGranularity.OP
+            and run.run_op_concurrency.has_unconstrained_root_nodes
+        ):
+            # if the granularity is at the op level and there exists a root node that is not
+            # concurrency blocked, we should dequeue.
+            return False
 
-        # if we reached here, then every root concurrency key is blocked, so we should not dequeue
-        return True
+        if self._pool_granularity == PoolGranularity.OP:
+            # we just need to check all of the root concurrency keys, instead of all the concurrency keys
+            # in the run
+            for pool in run.run_op_concurrency.root_key_counts.keys():
+                if pool not in self._concurrency_info_by_key:
+                    # there is no concurrency limit set for this key, we should dequeue
+                    return False
+
+                key_info = self._concurrency_info_by_key[pool]
+                unaccounted_occupied_slots = [
+                    pending_step
+                    for pending_step in key_info.pending_steps
+                    if pending_step.run_id not in self._in_progress_run_ids
+                ]
+                available_count = (
+                    key_info.slot_count
+                    - len(unaccounted_occupied_slots)
+                    - self._launched_pool_counts[pool]
+                    - self._in_progress_pool_counts[pool]
+                )
+                if available_count + self._slot_count_offset > 0:
+                    # there exists a root concurrency key that is not blocked, we should dequeue
+                    return False
+
+            # if we reached here, then every root concurrency key is blocked, so we should not dequeue
+            return True
+
+        else:
+            assert self._pool_granularity == PoolGranularity.RUN
+
+            # if the granularity is at the run level, we should check if any of the concurrency
+            # keys are blocked
+            for pool in run.run_op_concurrency.all_pools or []:
+                if pool not in self._concurrency_info_by_key:
+                    # there is no concurrency limit set for this key
+                    continue
+
+                key_info = self._concurrency_info_by_key[pool]
+                unaccounted_occupied_slots = [
+                    pending_step
+                    for pending_step in key_info.pending_steps
+                    if pending_step.run_id not in self._in_progress_run_ids
+                ]
+                available_count = (
+                    key_info.slot_count
+                    - len(unaccounted_occupied_slots)
+                    - self._launched_pool_counts[pool]
+                    - self._in_progress_pool_counts[pool]
+                )
+                if available_count + self._slot_count_offset <= 0:
+                    return True
+
+            # if we reached here then there is at least one available slot for every single concurrency key
+            # required by this run, so we should dequeue
+            return False
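A numeric sketch of the availability check (invented numbers). Note the new `unaccounted_occupied_slots` filter: pending steps belonging to runs already counted in `self._in_progress_pool_counts` are excluded, presumably so the same run is not double-counted:

```python
# Pool "db", invented numbers:
slot_count = 2            # configured limit for the pool
unaccounted_occupied = 1  # pending steps from runs *not* in the in-progress set
launched = 1              # slots claimed by runs dequeued earlier in this pass
in_progress = 0           # slots held per the in-progress allocation rules above
slot_count_offset = 0

available = slot_count - unaccounted_occupied - launched - in_progress  # == 0

# op granularity: dequeue only if some root pool has available + offset > 0
# run granularity: block if any required pool has available + offset <= 0
assert not (available + slot_count_offset > 0)  # "db" is blocked either way
```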

     def get_blocked_run_debug_info(self, run: DagsterRun) -> Mapping:
         if not run.run_op_concurrency:
             return {}
 
         log_info = {}
         for pool in run.run_op_concurrency.root_key_counts.keys():
-            concurrency_info = self._concurrency_info_by_pool.get(pool)
+            concurrency_info = self._concurrency_info_by_key.get(pool)
             if not concurrency_info:
                 continue
 
             log_info[pool] = {
+                "granularity": self._pool_granularity.value,
                 "slot_count": concurrency_info.slot_count,
                 "pending_step_count": len(concurrency_info.pending_steps),
                 "pending_step_run_ids": list(
Expand All @@ -160,8 +241,5 @@ def get_blocked_run_debug_info(self, run: DagsterRun) -> Mapping:
         return log_info
 
     def update_counters_with_launched_item(self, run: DagsterRun):
-        if not run.run_op_concurrency:
-            return
-        for pool, count in run.run_op_concurrency.root_key_counts.items():
-            if pool:
-                self._launched_pool_counts[pool] += count
+        for pool, count in self._slot_counts_for_run(run).items():
+            self._launched_pool_counts[pool] += count
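With those counters in place, `get_blocked_run_debug_info` now tags each pool entry with the active granularity. A hypothetical payload (all values invented; the run-id list is truncated in the diff above and left elided here):

```python
# Hypothetical blocked-run debug payload:
log_info = {
    "db": {
        "granularity": "run",  # self._pool_granularity.value
        "slot_count": 2,
        "pending_step_count": 2,
        "pending_step_run_ids": ["..."],
    }
}
```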
