coerce tag-based concurrency keys into pools
prha committed Jan 15, 2025
1 parent b3eaae4 commit a120802
Showing 5 changed files with 49 additions and 57 deletions.
@@ -248,8 +248,7 @@ def asset(
             e.g. `team:finops`.
         kinds (Optional[Set[str]]): A list of strings representing the kinds of the asset. These
             will be made visible in the Dagster UI.
-        pool (Optional[str]): A string that identifies the concurrency limit group that governs
-            this asset's execution.
+        pool (Optional[str]): A string that identifies the concurrency pool that governs this asset's execution.
         non_argument_deps (Optional[Union[Set[AssetKey], Set[str]]]): Deprecated, use deps instead.
             Set of asset keys that are upstream dependencies, but do not pass an input to the asset.
             Hidden parameter not exposed in the decorator signature, but passed in kwargs.
@@ -610,8 +609,8 @@ def multi_asset(
             by this function.
         check_specs (Optional[Sequence[AssetCheckSpec]]): Specs for asset checks that
             execute in the decorated function after materializing the assets.
-        pool (Optional[str]): A string that identifies the concurrency limit group that
-            governs this multi-asset's execution.
+        pool (Optional[str]): A string that identifies the concurrency pool that governs this
+            multi-asset's execution.
         non_argument_deps (Optional[Union[Set[AssetKey], Set[str]]]): Deprecated, use deps instead.
             Set of asset keys that are upstream dependencies, but do not pass an input to the
             multi_asset.
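
For context, a minimal sketch of the documented parameter in use; the asset name and the "database" pool name are illustrative, and this assumes a "database" pool limit has been configured on the instance:

import dagster as dg

# Executions of this asset claim a slot from the "database" pool; the pool's
# configured slot count caps how many such steps run at once.
@dg.asset(pool="database")
def my_table() -> None:
    ...
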
@@ -33,6 +33,7 @@
     DagsterInvalidInvocationError,
     DagsterInvariantViolationError,
 )
+from dagster._core.storage.tags import GLOBAL_CONCURRENCY_TAG
 from dagster._core.types.dagster_type import DagsterType, DagsterTypeKind
 from dagster._utils import IHasInternalInit
 from dagster._utils.warnings import normalize_renamed_param, preview_warning
@@ -295,12 +296,13 @@ def with_retry_policy(self, retry_policy: RetryPolicy) -> "PendingNodeInvocation
 
     @property
     def pool(self) -> Optional[str]:
-        """Optional[str]: The concurrency group for this op."""
-        return self._pool
+        """Optional[str]: The concurrency pool for this op."""
+        # fallback to fetching from tags for backwards compatibility
+        return self._pool if self._pool else self.tags.get(GLOBAL_CONCURRENCY_TAG)
 
     @property
     def pools(self) -> set[str]:
-        """Optional[str]: The concurrency group for this op."""
+        """Set[str]: The concurrency pools for this op node."""
         return {self._pool} if self._pool else set()
 
     def is_from_decorator(self) -> bool:
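
The fallback keeps tag-based definitions working unchanged. A sketch under two assumptions flagged here: that GLOBAL_CONCURRENCY_TAG is the `dagster/concurrency_key` tag, and that `@op` accepts a `pool` argument like `@asset` does; the op names and pool name are illustrative:

import dagster as dg

@dg.op(tags={"dagster/concurrency_key": "redshift"})  # legacy tag-based key (assumed tag value)
def legacy_op():
    pass

@dg.op(pool="redshift")  # new-style pool argument (assumed to mirror @asset's parameter)
def pool_op():
    pass

# the property returns self._pool when set, else falls back to the legacy tag,
# so both ops resolve to the same pool
assert legacy_op.pool == pool_op.pool == "redshift"
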
9 changes: 4 additions & 5 deletions python_modules/dagster/dagster/_core/execution/plan/active.py
@@ -23,7 +23,7 @@
 from dagster._core.execution.plan.state import KnownExecutionState
 from dagster._core.execution.plan.step import ExecutionStep
 from dagster._core.execution.retries import RetryMode, RetryState
-from dagster._core.storage.tags import GLOBAL_CONCURRENCY_TAG, PRIORITY_TAG
+from dagster._core.storage.tags import PRIORITY_TAG
 from dagster._utils.interrupts import pop_captured_interrupt
 from dagster._utils.tags import TagConcurrencyLimitsCounter
@@ -339,14 +339,13 @@ def get_steps_to_execute(
                 run_scoped_concurrency_limits_counter.update_counters_with_launched_item(step)
 
-            # fallback to fetching from tags for backwards compatibility
-            pool = step.pool if step.pool else step.tags.get(GLOBAL_CONCURRENCY_TAG)
-            if pool and self._instance_concurrency_context:
+            if step.pool and self._instance_concurrency_context:
                 try:
                     step_priority = int(step.tags.get(PRIORITY_TAG, 0))
                 except ValueError:
                     step_priority = 0
 
-                if not self._instance_concurrency_context.claim(pool, step.key, step_priority):
+                if not self._instance_concurrency_context.claim(step.pool, step.key, step_priority):
                     continue
 
             batch.append(step)
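
With the read-time tag fallback gone, the claim is keyed directly on `step.pool`. The priority lookup stays defensive: a malformed priority tag degrades to 0 instead of failing the scheduling pass. A standalone sketch of that parsing, assuming the imported constant's value is `dagster/priority`:

PRIORITY_TAG = "dagster/priority"  # assumed value of the imported constant

def step_priority_from_tags(tags: dict) -> int:
    # mirrors the try/except above: non-integer values degrade to priority 0
    try:
        return int(tags.get(PRIORITY_TAG, 0))
    except ValueError:
        return 0

assert step_priority_from_tags({PRIORITY_TAG: "3"}) == 3
assert step_priority_from_tags({PRIORITY_TAG: "high"}) == 0
assert step_priority_from_tags({}) == 0
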
@@ -643,7 +642,7 @@ def concurrency_event_iterator(
             ):
                 step = self.get_step_by_key(step_key)
                 step_context = plan_context.for_step(step)
-                pool = cast(str, step.pool)
+                pool = check.inst(step.pool, str)
                 self._messaged_concurrency_slots[step_key] = time.time()
                 is_initial_message = last_messaged_timestamp is None
                 yield DagsterEvent.step_concurrency_blocked(
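
Replacing `cast` with `check.inst` turns a static-only annotation into a runtime assertion: `cast` never inspects the value, while `check.inst` raises if the pool is unexpectedly `None`. A sketch of the difference, assuming dagster's `_check.inst` behaves as a standard isinstance assertion:

from typing import cast

import dagster._check as check

value = None
s = cast(str, value)  # no runtime effect: s is still None, failure surfaces later

try:
    check.inst(value, str)  # raises immediately, at the point of the bad assumption
except check.CheckError as e:
    print(f"caught: {e}")
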
@@ -23,21 +23,21 @@ def compute_run_op_concurrency_info_for_snapshot(
     root_step_keys = set(
         [step_key for step_key, deps in plan_snapshot.step_deps.items() if not deps]
     )
-    concurrency_key_counts: Mapping[str, int] = defaultdict(int)
+    pool_counts: Mapping[str, int] = defaultdict(int)
     has_unconstrained_root_nodes = False
     for step in plan_snapshot.steps:
         if step.key not in root_step_keys:
             continue
-        if step.concurrency_key is None:
+        if step.pool is None:
             has_unconstrained_root_nodes = True
         else:
-            concurrency_key_counts[step.concurrency_key] += 1
+            pool_counts[step.pool] += 1
 
-    if len(concurrency_key_counts) == 0:
+    if len(pool_counts) == 0:
         return None
 
     return RunOpConcurrency(
-        root_key_counts=dict(concurrency_key_counts),
+        root_key_counts=dict(pool_counts),
         has_unconstrained_root_nodes=has_unconstrained_root_nodes,
     )
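
To make the tallying concrete, a self-contained rerun of the loop above on a hypothetical plan with three root steps, two sharing a pool and one unconstrained:

from collections import defaultdict

# (step_key, upstream_deps, pool) for a hypothetical plan snapshot
steps = [
    ("load_a", [], "database"),
    ("load_b", [], "database"),
    ("load_c", [], None),                        # unconstrained root step
    ("join", ["load_a", "load_b"], "database"),  # has deps, so not a root step
]

root_step_keys = {key for key, deps, _ in steps if not deps}
pool_counts = defaultdict(int)
has_unconstrained_root_nodes = False
for key, _, pool in steps:
    if key not in root_step_keys:
        continue
    if pool is None:
        has_unconstrained_root_nodes = True
    else:
        pool_counts[pool] += 1

assert dict(pool_counts) == {"database": 2}
assert has_unconstrained_root_nodes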

@@ -50,12 +50,15 @@ def __init__(
         in_progress_run_records: Sequence[RunRecord],
         slot_count_offset: int = 0,
     ):
-        self._root_concurrency_keys_by_run = {}
-        self._concurrency_info_by_key = {}
-        self._launched_concurrency_key_counts = defaultdict(int)
-        self._in_progress_concurrency_key_counts = defaultdict(int)
+        self._root_pools_by_run = {}
+        self._concurrency_info_by_pool = {}
+        self._launched_pool_counts = defaultdict(int)
+        self._in_progress_pool_counts = defaultdict(int)
         self._slot_count_offset = slot_count_offset
-        self._started_run_concurrency_keys_allotted_seconds = int(
+        self._in_progress_run_ids: set[str] = set(
+            [record.dagster_run.run_id for record in in_progress_run_records]
+        )
+        self._started_run_pools_allotted_seconds = int(
             os.getenv("DAGSTER_OP_CONCURRENCY_KEYS_ALLOTTED_FOR_STARTED_RUN_SECONDS", "5")
         )
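
One thing the rename does not touch: the environment variable keeps its legacy CONCURRENCY_KEYS name, so deployments that tune this grace window (how long a just-started run's root pools still count against slots) need no change. A minimal read of the setting, mirroring the line above:

import os

# grace window, in seconds, during which a started run's root pools are still counted
allotted_seconds = int(
    os.getenv("DAGSTER_OP_CONCURRENCY_KEYS_ALLOTTED_FOR_STARTED_RUN_SECONDS", "5")
)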

@@ -69,45 +72,45 @@ def __init__(
     def _fetch_concurrency_info(self, instance: DagsterInstance, queued_runs: Sequence[DagsterRun]):
         # fetch all the concurrency slot information for the root concurrency keys of all the queued
         # runs
-        all_run_concurrency_keys = set()
+        all_run_pools = set()
 
-        configured_concurrency_keys = instance.event_log_storage.get_concurrency_keys()
+        configured_pools = instance.event_log_storage.get_concurrency_keys()
         for run in queued_runs:
             if run.run_op_concurrency:
-                all_run_concurrency_keys.update(run.run_op_concurrency.root_key_counts.keys())
+                all_run_pools.update(run.run_op_concurrency.root_key_counts.keys())
 
-        for key in all_run_concurrency_keys:
+        for key in all_run_pools:
             if key is None:
                 continue
 
-            if key not in configured_concurrency_keys:
+            if key not in configured_pools:
                 instance.event_log_storage.initialize_concurrency_limit_to_default(key)
 
-            self._concurrency_info_by_key[key] = instance.event_log_storage.get_concurrency_info(
+            self._concurrency_info_by_pool[key] = instance.event_log_storage.get_concurrency_info(
                 key
             )
 
-    def _should_allocate_slots_for_root_concurrency_keys(self, record: RunRecord):
+    def _should_allocate_slots_for_root_pools(self, record: RunRecord):
         status = record.dagster_run.status
         if status == DagsterRunStatus.STARTING:
             return True
         if status != DagsterRunStatus.STARTED or not record.start_time:
             return False
         time_elapsed = get_current_timestamp() - record.start_time
-        if time_elapsed < self._started_run_concurrency_keys_allotted_seconds:
+        if time_elapsed < self._started_run_pools_allotted_seconds:
             return True
 
     def _process_in_progress_runs(self, in_progress_records: Sequence[RunRecord]):
         for record in in_progress_records:
             if (
-                self._should_allocate_slots_for_root_concurrency_keys(record)
+                self._should_allocate_slots_for_root_pools(record)
                 and record.dagster_run.run_op_concurrency
             ):
                 for (
-                    concurrency_key,
+                    pool,
                     count,
                 ) in record.dagster_run.run_op_concurrency.root_key_counts.items():
-                    self._in_progress_concurrency_key_counts[concurrency_key] += count
+                    self._in_progress_pool_counts[pool] += count
 
     def is_blocked(self, run: DagsterRun) -> bool:
         # if any of the ops in the run can make progress (not blocked by concurrency keys), we
@@ -116,17 +119,17 @@ def is_blocked(self, run: DagsterRun) -> bool:
             # if there exists a root node that is not concurrency blocked, we should dequeue.
             return False
 
-        for concurrency_key in run.run_op_concurrency.root_key_counts.keys():
-            if concurrency_key not in self._concurrency_info_by_key:
+        for pool in run.run_op_concurrency.root_key_counts.keys():
+            if pool not in self._concurrency_info_by_pool:
                 # there is no concurrency limit set for this key, we should dequeue
                 return False
 
-            key_info = self._concurrency_info_by_key[concurrency_key]
+            key_info = self._concurrency_info_by_pool[pool]
             available_count = (
                 key_info.slot_count
                 - len(key_info.pending_steps)
-                - self._launched_concurrency_key_counts[concurrency_key]
-                - self._in_progress_concurrency_key_counts[concurrency_key]
+                - self._launched_pool_counts[pool]
+                - self._in_progress_pool_counts[pool]
             )
             if available_count > -1 * self._slot_count_offset:
                 # there exists a root concurrency key that is not blocked, we should dequeue
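
A worked instance of the dequeue arithmetic, with hypothetical numbers: 4 slots, 3 pending steps, 1 slot counted for runs launched this pass, and 1 for in-progress runs leaves available_count at -1, so the pool blocks unless slot_count_offset loosens the check:

slot_count = 4
pending = 3
launched = 1
in_progress = 1

available_count = slot_count - pending - launched - in_progress  # -1

slot_count_offset = 0
assert not (available_count > -1 * slot_count_offset)  # blocked

slot_count_offset = 2
assert available_count > -1 * slot_count_offset  # dequeues once the offset applies
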
@@ -140,25 +143,25 @@ def get_blocked_run_debug_info(self, run: DagsterRun) -> Mapping:
             return {}
 
         log_info = {}
-        for concurrency_key in run.run_op_concurrency.root_key_counts.keys():
-            concurrency_info = self._concurrency_info_by_key.get(concurrency_key)
+        for pool in run.run_op_concurrency.root_key_counts.keys():
+            concurrency_info = self._concurrency_info_by_pool.get(pool)
             if not concurrency_info:
                 continue
 
-            log_info[concurrency_key] = {
+            log_info[pool] = {
                 "slot_count": concurrency_info.slot_count,
                 "pending_step_count": len(concurrency_info.pending_steps),
                 "pending_step_run_ids": list(
                     {step.run_id for step in concurrency_info.pending_steps}
                 ),
-                "launched_count": self._launched_concurrency_key_counts[concurrency_key],
-                "in_progress_count": self._in_progress_concurrency_key_counts[concurrency_key],
+                "launched_count": self._launched_pool_counts[pool],
+                "in_progress_count": self._in_progress_pool_counts[pool],
             }
         return log_info
 
     def update_counters_with_launched_item(self, run: DagsterRun):
         if not run.run_op_concurrency:
             return
-        for concurrency_key, count in run.run_op_concurrency.root_key_counts.items():
-            if concurrency_key:
-                self._launched_concurrency_key_counts[concurrency_key] += count
+        for pool, count in run.run_op_concurrency.root_key_counts.items():
+            if pool:
+                self._launched_pool_counts[pool] += count
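
The launched counter keeps the picture current within a single dequeue pass, so two queued runs rooted in the same pool cannot both be dequeued into one free slot. A sketch of the method's effect on a bare mapping; note that falsy pool keys are skipped:

from collections import defaultdict

launched_pool_counts = defaultdict(int)

def update_counters_with_launched_item(root_key_counts):
    # mirror of the method above, minus the RunOpConcurrency wrapper
    for pool, count in root_key_counts.items():
        if pool:
            launched_pool_counts[pool] += count

update_counters_with_launched_item({"database": 2, None: 1})
assert dict(launched_pool_counts) == {"database": 2}
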
@@ -21,7 +21,6 @@
     UnresolvedCollectExecutionStep,
     UnresolvedMappedExecutionStep,
 )
-from dagster._core.storage.tags import GLOBAL_CONCURRENCY_TAG
 from dagster._serdes import create_snapshot_id, whitelist_for_serdes
 from dagster._utils.error import SerializableErrorInfo

@@ -184,16 +183,6 @@ def __new__(
             pool=check.opt_str_param(pool, "pool"),
         )
 
-    @property
-    def concurrency_key(self):
-        # Separate property in case the snapshot was created before pool was added as
-        # a separate argument from tags
-        if self.pool:
-            return self.pool
-        if not self.tags:
-            return None
-        return self.tags.get(GLOBAL_CONCURRENCY_TAG)
-
 
 @whitelist_for_serdes
 class ExecutionStepInputSnap(
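
The deleted property was the last read-time fallback in the snapshot layer; after this commit the coercion lives upstream (see the `pool` property change above), and snapshot consumers read `step.pool` directly. For reference, the removed behavior reduces to this, assuming GLOBAL_CONCURRENCY_TAG is `dagster/concurrency_key`:

from typing import Mapping, Optional

GLOBAL_CONCURRENCY_TAG = "dagster/concurrency_key"  # assumed constant value

def concurrency_key(pool: Optional[str], tags: Optional[Mapping[str, str]]) -> Optional[str]:
    # the explicit pool wins; otherwise fall back to the legacy tag
    if pool:
        return pool
    if not tags:
        return None
    return tags.get(GLOBAL_CONCURRENCY_TAG)

assert concurrency_key("database", None) == "database"
assert concurrency_key(None, {GLOBAL_CONCURRENCY_TAG: "database"}) == "database"
assert concurrency_key(None, None) is None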
