Skip to content

Commit

Permalink
add default support for backing off of checking concurrency claims (#…
Browse files Browse the repository at this point in the history
…17109)

## Summary & Motivation
Instead of checking every second, we back off exponentially (up to roughly every 10–15 seconds) when retrying to grab concurrency slots. This relieves DB pressure caused by repeatedly checking contentious rows, and it shouldn't worsen starvation, since priority checks are still in place.

## How I Tested These Changes
BK
  • Loading branch information
prha authored Oct 12, 2023
1 parent 7ee29e4 commit d670828
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@

from dagster._core.instance import DagsterInstance

DEFAULT_CONCURRENCY_CLAIM_BLOCKED_INTERVAL = 1
# Seconds to wait before the first re-check of a blocked concurrency claim.
INITIAL_INTERVAL_VALUE = 1
# Base of the exponential backoff term (STEP_UP_BASE ** attempt_count - 1) added
# to the initial interval on each successive failed claim attempt.
STEP_UP_BASE = 1.1
# Upper bound (seconds) on the computed backoff interval between claim checks.
MAX_CONCURRENCY_CLAIM_BLOCKED_INTERVAL = 15


class InstanceConcurrencyContext:
Expand All @@ -34,6 +36,7 @@ def __init__(self, instance: DagsterInstance, run_id: str):
self._run_id = run_id
self._global_concurrency_keys = None
self._pending_timeouts = defaultdict(float)
self._pending_claim_counts = defaultdict(int)
self._pending_claims = set()
self._claims = set()

Expand All @@ -54,6 +57,7 @@ def __exit__(

for step_key in to_clear:
del self._pending_timeouts[step_key]
del self._pending_claim_counts[step_key]
self._pending_claims.remove(step_key)

self._context_guard = False
Expand Down Expand Up @@ -89,12 +93,11 @@ def claim(self, concurrency_key: str, step_key: str, priority: int = 0):
)

if not claim_status.is_claimed:
interval = (
claim_status.sleep_interval
if claim_status.sleep_interval
else DEFAULT_CONCURRENCY_CLAIM_BLOCKED_INTERVAL
interval = _calculate_timeout_interval(
claim_status.sleep_interval, self._pending_claim_counts[step_key]
)
self._pending_timeouts[step_key] = time.time() + interval
self._pending_claim_counts[step_key] += 1
return False

if step_key in self._pending_claims:
Expand Down Expand Up @@ -122,3 +125,17 @@ def free_step(self, step_key) -> None:

self._instance.event_log_storage.free_concurrency_slot_for_step(self._run_id, step_key)
self._claims.remove(step_key)


def _calculate_timeout_interval(sleep_interval: Optional[float], pending_claim_count: int) -> float:
    """Return the number of seconds to wait before re-checking a blocked claim.

    A storage-supplied ``sleep_interval`` always takes precedence. Otherwise the
    wait grows exponentially with the number of failed claim attempts, starting
    at ``INITIAL_INTERVAL_VALUE`` and capped at
    ``MAX_CONCURRENCY_CLAIM_BLOCKED_INTERVAL`` seconds.
    """
    if sleep_interval is not None:
        # The event log storage dictated an explicit interval — honor it as-is.
        return sleep_interval

    if pending_claim_count > 30:
        # With the current constants the exponential curve has already reached
        # the cap by the 30th attempt, so skip the arithmetic entirely (and
        # avoid computing ever-larger exponents).
        return MAX_CONCURRENCY_CLAIM_BLOCKED_INTERVAL

    # Exponential backoff: the step-up term is 0 on the first attempt
    # (STEP_UP_BASE ** 0 - 1), so the very first wait equals
    # INITIAL_INTERVAL_VALUE; it then ramps up toward the cap.
    backoff = INITIAL_INTERVAL_VALUE + (STEP_UP_BASE**pending_claim_count - 1)
    return min(backoff, MAX_CONCURRENCY_CLAIM_BLOCKED_INTERVAL)
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import pytest
from dagster._core.execution.plan.instance_concurrency_context import (
DEFAULT_CONCURRENCY_CLAIM_BLOCKED_INTERVAL,
INITIAL_INTERVAL_VALUE,
STEP_UP_BASE,
InstanceConcurrencyContext,
)
from dagster._core.utils import make_new_run_id
Expand Down Expand Up @@ -82,11 +83,37 @@ def test_default_interval(concurrency_instance):
# we have not waited long enough to query the db again
assert concurrency_instance.event_log_storage.get_check_calls("b") == call_count

time.sleep(DEFAULT_CONCURRENCY_CLAIM_BLOCKED_INTERVAL)
time.sleep(INITIAL_INTERVAL_VALUE)
context.claim("foo", "b")
assert concurrency_instance.event_log_storage.get_check_calls("b") == call_count + 1


def test_backoff_interval(concurrency_instance):
    """Verify that blocked claim checks back off exponentially rather than
    re-querying the event log storage on a fixed interval.

    NOTE(review): this test is wall-clock sensitive — the sleeps below must
    straddle the exact backoff thresholds (1s after the first blocked claim,
    1.1s after the second), so a slow CI machine could flake here.
    """
    run_id = make_new_run_id()
    # only one slot: the second claimant ("b") is guaranteed to block
    concurrency_instance.event_log_storage.set_concurrency_slots("foo", 1)

    with InstanceConcurrencyContext(concurrency_instance, run_id) as context:
        assert context.claim("foo", "a")
        assert not context.claim("foo", "b")
        # baseline count of storage check calls made for step "b" so far
        call_count = concurrency_instance.event_log_storage.get_check_calls("b")

        context.claim("foo", "b")
        # we have not waited long enough to query the db again
        assert concurrency_instance.event_log_storage.get_check_calls("b") == call_count

        # first backoff interval is the initial value (1s); after sleeping it,
        # the next claim attempt should hit storage again
        time.sleep(INITIAL_INTERVAL_VALUE)
        context.claim("foo", "b")
        assert concurrency_instance.event_log_storage.get_check_calls("b") == call_count + 1

        # sleeping another second will not incur another check call, there's an exponential backoff
        # (the second interval is INITIAL + STEP_UP_BASE**1 - 1 = 1.1s)
        time.sleep(INITIAL_INTERVAL_VALUE)
        context.claim("foo", "b")
        assert concurrency_instance.event_log_storage.get_check_calls("b") == call_count + 1
        # top up the remaining 0.1s to cross the 1.1s threshold; now the claim
        # should query storage once more
        time.sleep(STEP_UP_BASE - INITIAL_INTERVAL_VALUE)
        context.claim("foo", "b")
        assert concurrency_instance.event_log_storage.get_check_calls("b") == call_count + 2


def test_custom_interval(concurrency_custom_sleep_instance):
run_id = make_new_run_id()
storage = concurrency_custom_sleep_instance.event_log_storage
Expand All @@ -101,13 +128,13 @@ def test_custom_interval(concurrency_custom_sleep_instance):
# we have not waited long enough to query the db again
assert storage.get_check_calls("b") == call_count

assert DEFAULT_CONCURRENCY_CLAIM_BLOCKED_INTERVAL < CUSTOM_SLEEP_INTERVAL
time.sleep(DEFAULT_CONCURRENCY_CLAIM_BLOCKED_INTERVAL)
assert INITIAL_INTERVAL_VALUE < CUSTOM_SLEEP_INTERVAL
time.sleep(INITIAL_INTERVAL_VALUE)
context.claim("foo", "b")
# we have waited the default interval, but not the custom interval
assert storage.get_check_calls("b") == call_count

interval_to_custom = CUSTOM_SLEEP_INTERVAL - DEFAULT_CONCURRENCY_CLAIM_BLOCKED_INTERVAL
interval_to_custom = CUSTOM_SLEEP_INTERVAL - INITIAL_INTERVAL_VALUE
time.sleep(interval_to_custom)
context.claim("foo", "b")
assert storage.get_check_calls("b") == call_count + 1

0 comments on commit d670828

Please sign in to comment.