Skip to content

Commit

Permalink
separate configured limit from slot count
Browse files Browse the repository at this point in the history
  • Loading branch information
prha committed Jan 17, 2025
1 parent 647d5df commit 222005b
Show file tree
Hide file tree
Showing 7 changed files with 222 additions and 69 deletions.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions js_modules/dagster-ui/packages/ui-core/src/graphql/types.ts

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ class GrapheneConcurrencyKeyInfo(graphene.ObjectType):
pendingStepRunIds = non_null_list(graphene.String)
assignedStepCount = graphene.NonNull(graphene.Int)
assignedStepRunIds = non_null_list(graphene.String)
configuredLimit = graphene.Int()

class Meta:
name = "ConcurrencyKeyInfo"
Expand Down Expand Up @@ -193,6 +194,9 @@ def resolve_assignedStepCount(self, graphene_info: ResolveInfo):
def resolve_assignedStepRunIds(self, graphene_info: ResolveInfo):
    """Resolve the ids of runs that currently hold assigned steps for this key."""
    key_info = self._get_concurrency_key_info(graphene_info)
    # assigned_run_ids is a set; GraphQL needs an ordered list type.
    return [*key_info.assigned_run_ids]

def resolve_configuredLimit(self, graphene_info: ResolveInfo):
    """Resolve the explicitly configured limit for this key (may be None)."""
    key_info = self._get_concurrency_key_info(graphene_info)
    return key_info.configured_limit


class GrapheneRunQueueConfig(graphene.ObjectType):
maxConcurrentRuns = graphene.NonNull(graphene.Int)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,35 @@ def _sqlite_asset_instance():

return MarkedManager(_sqlite_asset_instance, [Marks.asset_aware_instance])

@staticmethod
def default_concurrency_sqlite_instance():
    """MarkedManager yielding a sqlite-backed test instance configured with a
    queued run coordinator and a default op concurrency limit of 1.
    """

    @contextmanager
    def _instance_cm():
        # Single `with` statement: the temp dir and the test instance share
        # a lifetime, so tearing down one tears down the other.
        with tempfile.TemporaryDirectory() as temp_dir:
            overrides = {
                "scheduler": {
                    "module": "dagster.utils.test",
                    "class": "FilesystemTestScheduler",
                    "config": {"base_dir": temp_dir},
                },
                "run_coordinator": {
                    "module": "dagster._core.run_coordinator.queued_run_coordinator",
                    "class": "QueuedRunCoordinator",
                },
                "concurrency": {
                    "default_op_concurrency_limit": 1,
                },
            }
            with instance_for_test(temp_dir=temp_dir, overrides=overrides) as instance:
                yield instance

    return MarkedManager(
        _instance_cm,
        [Marks.sqlite_instance, Marks.queued_run_coordinator],
    )


class EnvironmentManagers:
@staticmethod
Expand Down Expand Up @@ -556,6 +585,16 @@ def sqlite_with_default_run_launcher_code_server_cli_env(
test_id="sqlite_with_default_run_launcher_code_server_cli_env",
)

@staticmethod
def sqlite_with_default_concurrency_managed_grpc_env(
    target=None, location_name="test_location"
):
    """Context variant: sqlite instance with a default op concurrency limit,
    served through a managed gRPC environment.
    """
    instance_manager = InstanceManagers.default_concurrency_sqlite_instance()
    environment_manager = EnvironmentManagers.managed_grpc(target, location_name)
    return GraphQLContextVariant(
        instance_manager,
        environment_manager,
        test_id="sqlite_with_default_concurrency_managed_grpc_env",
    )

@staticmethod
def postgres_with_default_run_launcher_managed_grpc_env(
target=None, location_name="test_location"
Expand Down Expand Up @@ -662,6 +701,7 @@ def all_variants():
GraphQLContextVariant.non_launchable_postgres_instance_managed_grpc_env(),
GraphQLContextVariant.non_launchable_postgres_instance_lazy_repository(),
GraphQLContextVariant.consolidated_sqlite_instance_managed_grpc_env(),
GraphQLContextVariant.sqlite_with_default_concurrency_managed_grpc_env(),
]

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,33 @@
"""

GET_CONCURRENCY_LIMITS_QUERY = """
query InstanceConcurrencyLimitsQuery {
query InstanceConcurrencyLimitsQuery($concurrencyKey: String!) {
instance {
concurrencyLimit(concurrencyKey: $concurrencyKey) {
concurrencyKey
slotCount
activeSlotCount
activeRunIds
claimedSlots {
runId
stepKey
}
pendingSteps {
runId
stepKey
enqueuedTimestamp
assignedTimestamp
priority
}
configuredLimit
}
}
}
"""

ALL_CONCURRENCY_LIMITS_QUERY = """
query AllConcurrencyLimitsQuery {
instance {
concurrencyLimits {
concurrencyKey
Expand All @@ -37,6 +63,7 @@
assignedTimestamp
priority
}
configuredLimit
}
}
}
Expand Down Expand Up @@ -67,6 +94,40 @@
)


def fetch_concurrency_limit(graphql_context, key: str):
    """Fetch the concurrency info for a single key via the GraphQL API.

    Returns the `concurrencyLimit` payload dict for *key* (the server may
    return a payload with zero slots for keys that were never configured).
    """
    results = execute_dagster_graphql(
        graphql_context,
        GET_CONCURRENCY_LIMITS_QUERY,
        # pass by keyword for consistency with set_concurrency_limit
        variables={"concurrencyKey": key},
    )
    assert results.data
    assert "instance" in results.data
    assert "concurrencyLimit" in results.data["instance"]
    return results.data["instance"]["concurrencyLimit"]


def set_concurrency_limit(graphql_context, key: str, limit: int):
    """Set the concurrency slot limit for *key* via the GraphQL mutation.

    Returns the raw GraphQL execution result so callers can inspect it
    (previously the result was silently discarded); existing callers that
    ignore the return value are unaffected.
    """
    return execute_dagster_graphql(
        graphql_context,
        SET_CONCURRENCY_LIMITS_MUTATION,
        variables={
            "concurrencyKey": key,
            "limit": limit,
        },
    )


def fetch_all_concurrency_limits(graphql_context):
    """Fetch the concurrency info for every key known to the instance.

    Returns the `concurrencyLimits` payload as a plain list.
    """
    results = execute_dagster_graphql(
        graphql_context,
        ALL_CONCURRENCY_LIMITS_QUERY,
    )
    assert results.data
    assert "instance" in results.data
    assert "concurrencyLimits" in results.data["instance"]
    # list(...) instead of an identity comprehension (ruff PERF402)
    return list(results.data["instance"]["concurrencyLimits"])


class TestInstanceSettings(BaseTestSuite):
def test_instance_settings(self, graphql_context):
results = execute_dagster_graphql(graphql_context, INSTANCE_QUERY)
Expand All @@ -81,78 +142,56 @@ def test_instance_settings(self, graphql_context):
def test_concurrency_limits(self, graphql_context):
instance = graphql_context.instance

def _fetch_limits(key: str):
results = execute_dagster_graphql(
graphql_context,
GET_CONCURRENCY_LIMITS_QUERY,
)
assert results.data
assert "instance" in results.data
assert "concurrencyLimits" in results.data["instance"]
limit_info = results.data["instance"]["concurrencyLimits"]
return next(iter([info for info in limit_info if info["concurrencyKey"] == key]), None)

def _set_limits(key: str, limit: int):
execute_dagster_graphql(
graphql_context,
SET_CONCURRENCY_LIMITS_MUTATION,
variables={
"concurrencyKey": key,
"limit": limit,
},
)

# default limits are empty
assert _fetch_limits("foo") is None
all_limits = fetch_all_concurrency_limits(graphql_context)
assert len(all_limits) == 0

# set a limit
_set_limits("foo", 10)
foo = _fetch_limits("foo")
assert foo["concurrencyKey"] == "foo" # pyright: ignore[reportOptionalSubscript]
assert foo["slotCount"] == 10 # pyright: ignore[reportOptionalSubscript]
assert foo["activeSlotCount"] == 0 # pyright: ignore[reportOptionalSubscript]
assert foo["activeRunIds"] == [] # pyright: ignore[reportOptionalSubscript]
assert foo["claimedSlots"] == [] # pyright: ignore[reportOptionalSubscript]
assert foo["pendingSteps"] == [] # pyright: ignore[reportOptionalSubscript]
set_concurrency_limit(graphql_context, "foo", 10)
foo = fetch_concurrency_limit(graphql_context, "foo")
assert foo["concurrencyKey"] == "foo"
assert foo["slotCount"] == 10
assert foo["activeSlotCount"] == 0
assert foo["activeRunIds"] == []
assert foo["claimedSlots"] == []
assert foo["pendingSteps"] == []

# claim a slot
run_id = make_new_run_id()
instance.event_log_storage.claim_concurrency_slot("foo", run_id, "fake_step_key")
foo = _fetch_limits("foo")
assert foo["concurrencyKey"] == "foo" # pyright: ignore[reportOptionalSubscript]
assert foo["slotCount"] == 10 # pyright: ignore[reportOptionalSubscript]
assert foo["activeSlotCount"] == 1 # pyright: ignore[reportOptionalSubscript]
assert foo["activeRunIds"] == [run_id] # pyright: ignore[reportOptionalSubscript]
assert foo["claimedSlots"] == [{"runId": run_id, "stepKey": "fake_step_key"}] # pyright: ignore[reportOptionalSubscript]
assert len(foo["pendingSteps"]) == 1 # pyright: ignore[reportOptionalSubscript]
assert foo["pendingSteps"][0]["runId"] == run_id # pyright: ignore[reportOptionalSubscript]
assert foo["pendingSteps"][0]["stepKey"] == "fake_step_key" # pyright: ignore[reportOptionalSubscript]
assert foo["pendingSteps"][0]["assignedTimestamp"] is not None # pyright: ignore[reportOptionalSubscript]
assert foo["pendingSteps"][0]["priority"] == 0 # pyright: ignore[reportOptionalSubscript]

# set a new limit
_set_limits("foo", 5)
foo = _fetch_limits("foo")
assert foo["concurrencyKey"] == "foo" # pyright: ignore[reportOptionalSubscript]
assert foo["slotCount"] == 5 # pyright: ignore[reportOptionalSubscript]
assert foo["activeSlotCount"] == 1 # pyright: ignore[reportOptionalSubscript]
assert foo["activeRunIds"] == [run_id] # pyright: ignore[reportOptionalSubscript]
assert foo["claimedSlots"] == [{"runId": run_id, "stepKey": "fake_step_key"}] # pyright: ignore[reportOptionalSubscript]
assert len(foo["pendingSteps"]) == 1 # pyright: ignore[reportOptionalSubscript]
assert foo["pendingSteps"][0]["runId"] == run_id # pyright: ignore[reportOptionalSubscript]
assert foo["pendingSteps"][0]["stepKey"] == "fake_step_key" # pyright: ignore[reportOptionalSubscript]
assert foo["pendingSteps"][0]["assignedTimestamp"] is not None # pyright: ignore[reportOptionalSubscript]
assert foo["pendingSteps"][0]["priority"] == 0 # pyright: ignore[reportOptionalSubscript]

# free a slot
foo = fetch_concurrency_limit(graphql_context, "foo")
assert foo["concurrencyKey"] == "foo"
assert foo["slotCount"] == 10
assert foo["activeSlotCount"] == 1
assert foo["activeRunIds"] == [run_id]
assert foo["claimedSlots"] == [{"runId": run_id, "stepKey": "fake_step_key"}]
assert len(foo["pendingSteps"]) == 1
assert foo["pendingSteps"][0]["runId"] == run_id
assert foo["pendingSteps"][0]["stepKey"] == "fake_step_key"
assert foo["pendingSteps"][0]["assignedTimestamp"] is not None
assert foo["pendingSteps"][0]["priority"] == 0

set_concurrency_limit(graphql_context, "foo", 5)
foo = fetch_concurrency_limit(graphql_context, "foo")
assert foo["concurrencyKey"] == "foo"
assert foo["slotCount"] == 5
assert foo["activeSlotCount"] == 1
assert foo["activeRunIds"] == [run_id]
assert foo["claimedSlots"] == [{"runId": run_id, "stepKey": "fake_step_key"}]
assert len(foo["pendingSteps"]) == 1
assert foo["pendingSteps"][0]["runId"] == run_id
assert foo["pendingSteps"][0]["stepKey"] == "fake_step_key"
assert foo["pendingSteps"][0]["assignedTimestamp"] is not None
assert foo["pendingSteps"][0]["priority"] == 0

instance.event_log_storage.free_concurrency_slots_for_run(run_id)
foo = _fetch_limits("foo")
assert foo["concurrencyKey"] == "foo" # pyright: ignore[reportOptionalSubscript]
assert foo["slotCount"] == 5 # pyright: ignore[reportOptionalSubscript]
assert foo["activeSlotCount"] == 0 # pyright: ignore[reportOptionalSubscript]
assert foo["activeRunIds"] == [] # pyright: ignore[reportOptionalSubscript]
assert foo["claimedSlots"] == [] # pyright: ignore[reportOptionalSubscript]
assert foo["pendingSteps"] == [] # pyright: ignore[reportOptionalSubscript]
foo = fetch_concurrency_limit(graphql_context, "foo")
assert foo["concurrencyKey"] == "foo"
assert foo["slotCount"] == 5
assert foo["activeSlotCount"] == 0
assert foo["activeRunIds"] == []
assert foo["claimedSlots"] == []
assert foo["pendingSteps"] == []

def test_concurrency_free(self, graphql_context):
storage = graphql_context.instance.event_log_storage
Expand Down Expand Up @@ -243,3 +282,31 @@ def test_concurrency_free_run(self, graphql_context):
assert foo_info.pending_run_ids == set()
assert foo_info.assigned_step_count == 1
assert foo_info.assigned_run_ids == {run_id_2}


# Test suite pinned to the one context variant whose instance configures a
# default op concurrency limit, so the configured-limit fallback is exercised.
ConcurrencyTestSuite: Any = make_graphql_context_test_suite(
    context_variants=[
        GraphQLContextVariant.sqlite_with_default_concurrency_managed_grpc_env(),
    ]
)


class TestConcurrencyInstanceSettings(ConcurrencyTestSuite):
    def test_default_concurrency(self, graphql_context):
        """An unconfigured key reports slotCount 0 while surfacing the
        instance-wide default as its configuredLimit.
        """
        # No keys have been configured, so the listing is empty.
        assert fetch_all_concurrency_limits(graphql_context) == []

        # Fetching an unknown key still resolves: zero slots allocated,
        # but the instance default (1) appears as the configured limit.
        foo_info = fetch_concurrency_limit(graphql_context, "foo")
        assert foo_info is not None
        assert foo_info["slotCount"] == 0
        assert foo_info["configuredLimit"] == 1

        # Explicitly setting the limit to 0 overrides the instance default.
        set_concurrency_limit(graphql_context, "foo", 0)
        foo_info = fetch_concurrency_limit(graphql_context, "foo")
        assert foo_info is not None
        assert foo_info["slotCount"] == 0
        assert foo_info["configuredLimit"] == 0
Original file line number Diff line number Diff line change
Expand Up @@ -2680,6 +2680,18 @@ def get_concurrency_info(self, concurrency_key: str) -> ConcurrencyKeyInfo:
.where(ConcurrencySlotsTable.c.concurrency_key == concurrency_key)
)
slot_rows = db_fetch_mappings(conn, slot_query)
slot_count = len([slot_row for slot_row in slot_rows if not slot_row["deleted"]])

if self.has_concurrency_limits_table:
limit_row = conn.execute(
db_select([ConcurrencyLimitsTable.c.limit]).where(
ConcurrencyLimitsTable.c.concurrency_key == concurrency_key
)
).fetchone()
configured_limit = cast(int, limit_row[0]) if limit_row else None
else:
configured_limit = slot_count

pending_query = (
db_select(
[
Expand All @@ -2695,9 +2707,16 @@ def get_concurrency_info(self, concurrency_key: str) -> ConcurrencyKeyInfo:
)
pending_rows = db_fetch_mappings(conn, pending_query)

if (
configured_limit is None
and self.has_instance
and self._instance.global_op_concurrency_default_limit
):
configured_limit = self._instance.global_op_concurrency_default_limit

return ConcurrencyKeyInfo(
concurrency_key=concurrency_key,
slot_count=len([slot_row for slot_row in slot_rows if not slot_row["deleted"]]),
slot_count=slot_count,
claimed_slots=[
ClaimedSlotInfo(run_id=slot_row["run_id"], step_key=slot_row["step_key"])
for slot_row in slot_rows
Expand All @@ -2715,6 +2734,7 @@ def get_concurrency_info(self, concurrency_key: str) -> ConcurrencyKeyInfo:
)
for row in pending_rows
],
configured_limit=configured_limit,
)

def get_concurrency_run_ids(self) -> set[str]:
Expand Down
Loading

0 comments on commit 222005b

Please sign in to comment.