From 2a97398b4014966aaf77858a3db6a9a7fc0128af Mon Sep 17 00:00:00 2001
From: prha
Date: Mon, 30 Dec 2024 16:31:48 -0800
Subject: [PATCH] separate configured limit from slot count

---
 .../ui-core/src/graphql/schema.graphql        |   1 +
 .../packages/ui-core/src/graphql/types.ts     |   3 +
 .../dagster_graphql/schema/instance.py        |   4 +
 .../graphql/graphql_context_test_suite.py     |  40 ++++
 .../graphql/test_instance.py                  | 197 ++++++++++++------
 .../_core/storage/event_log/sql_event_log.py  |  22 +-
 .../dagster/dagster/_utils/concurrency.py     |  24 ++-
 7 files changed, 222 insertions(+), 69 deletions(-)

diff --git a/js_modules/dagster-ui/packages/ui-core/src/graphql/schema.graphql b/js_modules/dagster-ui/packages/ui-core/src/graphql/schema.graphql
index c9685c7341aa6..0250aef608144 100644
--- a/js_modules/dagster-ui/packages/ui-core/src/graphql/schema.graphql
+++ b/js_modules/dagster-ui/packages/ui-core/src/graphql/schema.graphql
@@ -2497,6 +2497,7 @@ type ConcurrencyKeyInfo {
   pendingStepRunIds: [String!]!
   assignedStepCount: Int!
   assignedStepRunIds: [String!]!
+  configuredLimit: Int
 }
 
 type ClaimedConcurrencySlot {
diff --git a/js_modules/dagster-ui/packages/ui-core/src/graphql/types.ts b/js_modules/dagster-ui/packages/ui-core/src/graphql/types.ts
index 52c3cf5d643df..6286cbb8cb219 100644
--- a/js_modules/dagster-ui/packages/ui-core/src/graphql/types.ts
+++ b/js_modules/dagster-ui/packages/ui-core/src/graphql/types.ts
@@ -846,6 +846,7 @@ export type ConcurrencyKeyInfo = {
   assignedStepRunIds: Array<Scalars['String']['output']>;
   claimedSlots: Array<ClaimedConcurrencySlot>;
   concurrencyKey: Scalars['String']['output'];
+  configuredLimit: Maybe<Scalars['Int']['output']>;
   pendingStepCount: Scalars['Int']['output'];
   pendingStepRunIds: Array<Scalars['String']['output']>;
   pendingSteps: Array<PendingConcurrencyStep>;
@@ -7348,6 +7349,8 @@ export const buildConcurrencyKeyInfo = (
       overrides && overrides.hasOwnProperty('claimedSlots') ? overrides.claimedSlots! : [],
     concurrencyKey:
       overrides && overrides.hasOwnProperty('concurrencyKey') ? overrides.concurrencyKey! : 'quasi',
+    configuredLimit:
+      overrides && overrides.hasOwnProperty('configuredLimit') ? overrides.configuredLimit! : 9480,
     pendingStepCount:
       overrides && overrides.hasOwnProperty('pendingStepCount') ? overrides.pendingStepCount! : 370,
     pendingStepRunIds:
diff --git a/python_modules/dagster-graphql/dagster_graphql/schema/instance.py b/python_modules/dagster-graphql/dagster_graphql/schema/instance.py
index 715f15b6b41cb..60d594bc85792 100644
--- a/python_modules/dagster-graphql/dagster_graphql/schema/instance.py
+++ b/python_modules/dagster-graphql/dagster_graphql/schema/instance.py
@@ -142,6 +142,7 @@ class GrapheneConcurrencyKeyInfo(graphene.ObjectType):
     pendingStepRunIds = non_null_list(graphene.String)
     assignedStepCount = graphene.NonNull(graphene.Int)
     assignedStepRunIds = non_null_list(graphene.String)
+    configuredLimit = graphene.Int()
 
     class Meta:
         name = "ConcurrencyKeyInfo"
@@ -193,6 +194,9 @@ def resolve_assignedStepCount(self, graphene_info: ResolveInfo):
     def resolve_assignedStepRunIds(self, graphene_info: ResolveInfo):
         return list(self._get_concurrency_key_info(graphene_info).assigned_run_ids)
 
+    def resolve_configuredLimit(self, graphene_info: ResolveInfo):
+        return self._get_concurrency_key_info(graphene_info).configured_limit
+
 
 class GrapheneRunQueueConfig(graphene.ObjectType):
     maxConcurrentRuns = graphene.NonNull(graphene.Int)
diff --git a/python_modules/dagster-graphql/dagster_graphql_tests/graphql/graphql_context_test_suite.py b/python_modules/dagster-graphql/dagster_graphql_tests/graphql/graphql_context_test_suite.py
index d8667cf266e9b..e449db428e025 100644
--- a/python_modules/dagster-graphql/dagster_graphql_tests/graphql/graphql_context_test_suite.py
+++ b/python_modules/dagster-graphql/dagster_graphql_tests/graphql/graphql_context_test_suite.py
@@ -267,6 +267,35 @@ def _sqlite_asset_instance():
 
         return MarkedManager(_sqlite_asset_instance, [Marks.asset_aware_instance])
 
+    @staticmethod
+    def default_concurrency_sqlite_instance():
+        @contextmanager
+        def _sqlite_with_default_concurrency_instance():
+            with tempfile.TemporaryDirectory() as temp_dir:
+                with instance_for_test(
+                    temp_dir=temp_dir,
+                    overrides={
+                        "scheduler": {
+                            "module": "dagster.utils.test",
+                            "class": "FilesystemTestScheduler",
+                            "config": {"base_dir": temp_dir},
+                        },
+                        "run_coordinator": {
+                            "module": "dagster._core.run_coordinator.queued_run_coordinator",
+                            "class": "QueuedRunCoordinator",
+                        },
+                        "concurrency": {
+                            "default_op_concurrency_limit": 1,
+                        },
+                    },
+                ) as instance:
+                    yield instance
+
+        return MarkedManager(
+            _sqlite_with_default_concurrency_instance,
+            [Marks.sqlite_instance, Marks.queued_run_coordinator],
+        )
+
 
 class EnvironmentManagers:
     @staticmethod
@@ -556,6 +585,16 @@ def sqlite_with_default_run_launcher_code_server_cli_env(
             test_id="sqlite_with_default_run_launcher_code_server_cli_env",
         )
 
+    @staticmethod
+    def sqlite_with_default_concurrency_managed_grpc_env(
+        target=None, location_name="test_location"
+    ):
+        return GraphQLContextVariant(
+            InstanceManagers.default_concurrency_sqlite_instance(),
+            EnvironmentManagers.managed_grpc(target, location_name),
+            test_id="sqlite_with_default_concurrency_managed_grpc_env",
+        )
+
     @staticmethod
     def postgres_with_default_run_launcher_managed_grpc_env(
         target=None, location_name="test_location"
@@ -662,6 +701,7 @@ def all_variants():
             GraphQLContextVariant.non_launchable_postgres_instance_managed_grpc_env(),
             GraphQLContextVariant.non_launchable_postgres_instance_lazy_repository(),
             GraphQLContextVariant.consolidated_sqlite_instance_managed_grpc_env(),
+            GraphQLContextVariant.sqlite_with_default_concurrency_managed_grpc_env(),
         ]
 
     @staticmethod
diff --git a/python_modules/dagster-graphql/dagster_graphql_tests/graphql/test_instance.py b/python_modules/dagster-graphql/dagster_graphql_tests/graphql/test_instance.py
index 813b156c5eebb..3c8e2313ccc7f 100644
--- a/python_modules/dagster-graphql/dagster_graphql_tests/graphql/test_instance.py
+++ b/python_modules/dagster-graphql/dagster_graphql_tests/graphql/test_instance.py
@@ -19,7 +19,33 @@
 """
 
 GET_CONCURRENCY_LIMITS_QUERY = """
-query InstanceConcurrencyLimitsQuery {
+query InstanceConcurrencyLimitsQuery($concurrencyKey: String!) {
+  instance {
+    concurrencyLimit(concurrencyKey: $concurrencyKey) {
+      concurrencyKey
+      slotCount
+      activeSlotCount
+      activeRunIds
+      claimedSlots {
+        runId
+        stepKey
+      }
+      pendingSteps {
+        runId
+        stepKey
+        enqueuedTimestamp
+        assignedTimestamp
+        priority
+      }
+      configuredLimit
+    }
+  }
+}
+
+"""
+
+ALL_CONCURRENCY_LIMITS_QUERY = """
+query AllConcurrencyLimitsQuery {
   instance {
     concurrencyLimits {
       concurrencyKey
@@ -37,6 +63,7 @@
         assignedTimestamp
         priority
       }
+      configuredLimit
     }
   }
 }
@@ -67,6 +94,40 @@
 )
 
 
+def fetch_concurrency_limit(graphql_context, key: str):
+    results = execute_dagster_graphql(
+        graphql_context,
+        GET_CONCURRENCY_LIMITS_QUERY,
+        {"concurrencyKey": key},
+    )
+    assert results.data
+    assert "instance" in results.data
+    assert "concurrencyLimit" in results.data["instance"]
+    return results.data["instance"]["concurrencyLimit"]
+
+
+def set_concurrency_limit(graphql_context, key: str, limit: int):
+    execute_dagster_graphql(
+        graphql_context,
+        SET_CONCURRENCY_LIMITS_MUTATION,
+        variables={
+            "concurrencyKey": key,
+            "limit": limit,
+        },
+    )
+
+
+def fetch_all_concurrency_limits(graphql_context):
+    results = execute_dagster_graphql(
+        graphql_context,
+        ALL_CONCURRENCY_LIMITS_QUERY,
+    )
+    assert results.data
+    assert "instance" in results.data
+    assert "concurrencyLimits" in results.data["instance"]
+    return [limit for limit in results.data["instance"]["concurrencyLimits"]]
+
+
 class TestInstanceSettings(BaseTestSuite):
     def test_instance_settings(self, graphql_context):
         results = execute_dagster_graphql(graphql_context, INSTANCE_QUERY)
@@ -81,78 +142,56 @@ def test_instance_settings(self, graphql_context):
     def test_concurrency_limits(self, graphql_context):
         instance = graphql_context.instance
 
-        def _fetch_limits(key: str):
-            results = execute_dagster_graphql(
-                graphql_context,
-                GET_CONCURRENCY_LIMITS_QUERY,
-            )
-            assert results.data
-            assert "instance" in results.data
-            assert "concurrencyLimits" in results.data["instance"]
-            limit_info = results.data["instance"]["concurrencyLimits"]
-            return next(iter([info for info in limit_info if info["concurrencyKey"] == key]), None)
-
-        def _set_limits(key: str, limit: int):
-            execute_dagster_graphql(
-                graphql_context,
-                SET_CONCURRENCY_LIMITS_MUTATION,
-                variables={
-                    "concurrencyKey": key,
-                    "limit": limit,
-                },
-            )
-
         # default limits are empty
-        assert _fetch_limits("foo") is None
+        all_limits = fetch_all_concurrency_limits(graphql_context)
+        assert len(all_limits) == 0
 
         # set a limit
-        _set_limits("foo", 10)
-        foo = _fetch_limits("foo")
-        assert foo["concurrencyKey"] == "foo"  # pyright: ignore[reportOptionalSubscript]
-        assert foo["slotCount"] == 10  # pyright: ignore[reportOptionalSubscript]
-        assert foo["activeSlotCount"] == 0  # pyright: ignore[reportOptionalSubscript]
-        assert foo["activeRunIds"] == []  # pyright: ignore[reportOptionalSubscript]
-        assert foo["claimedSlots"] == []  # pyright: ignore[reportOptionalSubscript]
-        assert foo["pendingSteps"] == []  # pyright: ignore[reportOptionalSubscript]
+        set_concurrency_limit(graphql_context, "foo", 10)
+        foo = fetch_concurrency_limit(graphql_context, "foo")
+        assert foo["concurrencyKey"] == "foo"
+        assert foo["slotCount"] == 10
+        assert foo["activeSlotCount"] == 0
+        assert foo["activeRunIds"] == []
+        assert foo["claimedSlots"] == []
+        assert foo["pendingSteps"] == []
 
         # claim a slot
         run_id = make_new_run_id()
         instance.event_log_storage.claim_concurrency_slot("foo", run_id, "fake_step_key")
-        foo = _fetch_limits("foo")
-        assert foo["concurrencyKey"] == "foo"  # pyright: ignore[reportOptionalSubscript]
-        assert foo["slotCount"] == 10  # pyright: ignore[reportOptionalSubscript]
-        assert foo["activeSlotCount"] == 1  # pyright: ignore[reportOptionalSubscript]
-        assert foo["activeRunIds"] == [run_id]  # pyright: ignore[reportOptionalSubscript]
-        assert foo["claimedSlots"] == [{"runId": run_id, "stepKey": "fake_step_key"}]  # pyright: ignore[reportOptionalSubscript]
-        assert len(foo["pendingSteps"]) == 1  # pyright: ignore[reportOptionalSubscript]
-        assert foo["pendingSteps"][0]["runId"] == run_id  # pyright: ignore[reportOptionalSubscript]
-        assert foo["pendingSteps"][0]["stepKey"] == "fake_step_key"  # pyright: ignore[reportOptionalSubscript]
-        assert foo["pendingSteps"][0]["assignedTimestamp"] is not None  # pyright: ignore[reportOptionalSubscript]
-        assert foo["pendingSteps"][0]["priority"] == 0  # pyright: ignore[reportOptionalSubscript]
-
-        # set a new limit
-        _set_limits("foo", 5)
-        foo = _fetch_limits("foo")
-        assert foo["concurrencyKey"] == "foo"  # pyright: ignore[reportOptionalSubscript]
-        assert foo["slotCount"] == 5  # pyright: ignore[reportOptionalSubscript]
-        assert foo["activeSlotCount"] == 1  # pyright: ignore[reportOptionalSubscript]
-        assert foo["activeRunIds"] == [run_id]  # pyright: ignore[reportOptionalSubscript]
-        assert foo["claimedSlots"] == [{"runId": run_id, "stepKey": "fake_step_key"}]  # pyright: ignore[reportOptionalSubscript]
-        assert len(foo["pendingSteps"]) == 1  # pyright: ignore[reportOptionalSubscript]
-        assert foo["pendingSteps"][0]["runId"] == run_id  # pyright: ignore[reportOptionalSubscript]
-        assert foo["pendingSteps"][0]["stepKey"] == "fake_step_key"  # pyright: ignore[reportOptionalSubscript]
-        assert foo["pendingSteps"][0]["assignedTimestamp"] is not None  # pyright: ignore[reportOptionalSubscript]
-        assert foo["pendingSteps"][0]["priority"] == 0  # pyright: ignore[reportOptionalSubscript]
-
-        # free a slot
+        foo = fetch_concurrency_limit(graphql_context, "foo")
+        assert foo["concurrencyKey"] == "foo"
+        assert foo["slotCount"] == 10
+        assert foo["activeSlotCount"] == 1
+        assert foo["activeRunIds"] == [run_id]
+        assert foo["claimedSlots"] == [{"runId": run_id, "stepKey": "fake_step_key"}]
+        assert len(foo["pendingSteps"]) == 1
+        assert foo["pendingSteps"][0]["runId"] == run_id
+        assert foo["pendingSteps"][0]["stepKey"] == "fake_step_key"
+        assert foo["pendingSteps"][0]["assignedTimestamp"] is not None
+        assert foo["pendingSteps"][0]["priority"] == 0
+
+        set_concurrency_limit(graphql_context, "foo", 5)
+        foo = fetch_concurrency_limit(graphql_context, "foo")
+        assert foo["concurrencyKey"] == "foo"
+        assert foo["slotCount"] == 5
+        assert foo["activeSlotCount"] == 1
+        assert foo["activeRunIds"] == [run_id]
+        assert foo["claimedSlots"] == [{"runId": run_id, "stepKey": "fake_step_key"}]
+        assert len(foo["pendingSteps"]) == 1
+        assert foo["pendingSteps"][0]["runId"] == run_id
+        assert foo["pendingSteps"][0]["stepKey"] == "fake_step_key"
+        assert foo["pendingSteps"][0]["assignedTimestamp"] is not None
+        assert foo["pendingSteps"][0]["priority"] == 0
+
         instance.event_log_storage.free_concurrency_slots_for_run(run_id)
-        foo = _fetch_limits("foo")
-        assert foo["concurrencyKey"] == "foo"  # pyright: ignore[reportOptionalSubscript]
-        assert foo["slotCount"] == 5  # pyright: ignore[reportOptionalSubscript]
-        assert foo["activeSlotCount"] == 0  # pyright: ignore[reportOptionalSubscript]
-        assert foo["activeRunIds"] == []  # pyright: ignore[reportOptionalSubscript]
-        assert foo["claimedSlots"] == []  # pyright: ignore[reportOptionalSubscript]
-        assert foo["pendingSteps"] == []  # pyright: ignore[reportOptionalSubscript]
+        foo = fetch_concurrency_limit(graphql_context, "foo")
+        assert foo["concurrencyKey"] == "foo"
+        assert foo["slotCount"] == 5
+        assert foo["activeSlotCount"] == 0
+        assert foo["activeRunIds"] == []
+        assert foo["claimedSlots"] == []
+        assert foo["pendingSteps"] == []
 
     def test_concurrency_free(self, graphql_context):
         storage = graphql_context.instance.event_log_storage
@@ -243,3 +282,31 @@ def test_concurrency_free_run(self, graphql_context):
         assert foo_info.pending_run_ids == set()
         assert foo_info.assigned_step_count == 1
         assert foo_info.assigned_run_ids == {run_id_2}
+
+
+ConcurrencyTestSuite: Any = make_graphql_context_test_suite(
+    context_variants=[
+        GraphQLContextVariant.sqlite_with_default_concurrency_managed_grpc_env(),
+    ]
+)
+
+
+class TestConcurrencyInstanceSettings(ConcurrencyTestSuite):
+    def test_default_concurrency(self, graphql_context):
+        # no limits
+        all_limits = fetch_all_concurrency_limits(graphql_context)
+        assert len(all_limits) == 0
+
+        # default limits are empty
+        limit = fetch_concurrency_limit(graphql_context, "foo")
+        assert limit is not None
+        assert limit["slotCount"] == 0
+        assert limit["configuredLimit"] == 1
+
+        # set a limit
+        set_concurrency_limit(graphql_context, "foo", 0)
+
+        limit = fetch_concurrency_limit(graphql_context, "foo")
+        assert limit is not None
+        assert limit["slotCount"] == 0
+        assert limit["configuredLimit"] == 0
diff --git a/python_modules/dagster/dagster/_core/storage/event_log/sql_event_log.py b/python_modules/dagster/dagster/_core/storage/event_log/sql_event_log.py
index e8b67b15c3768..d66ff6a7bc1a5 100644
--- a/python_modules/dagster/dagster/_core/storage/event_log/sql_event_log.py
+++ b/python_modules/dagster/dagster/_core/storage/event_log/sql_event_log.py
@@ -2680,6 +2680,18 @@ def get_concurrency_info(self, concurrency_key: str) -> ConcurrencyKeyInfo:
                 .where(ConcurrencySlotsTable.c.concurrency_key == concurrency_key)
             )
             slot_rows = db_fetch_mappings(conn, slot_query)
+            slot_count = len([slot_row for slot_row in slot_rows if not slot_row["deleted"]])
+
+            if self.has_concurrency_limits_table:
+                limit_row = conn.execute(
+                    db_select([ConcurrencyLimitsTable.c.limit]).where(
+                        ConcurrencyLimitsTable.c.concurrency_key == concurrency_key
+                    )
+                ).fetchone()
+                configured_limit = cast(int, limit_row[0]) if limit_row else None
+            else:
+                configured_limit = slot_count
+
             pending_query = (
                 db_select(
                     [
@@ -2695,9 +2707,16 @@ def get_concurrency_info(self, concurrency_key: str) -> ConcurrencyKeyInfo:
             )
             pending_rows = db_fetch_mappings(conn, pending_query)
 
+            if (
+                configured_limit is None
+                and self.has_instance
+                and self._instance.global_op_concurrency_default_limit
+            ):
+                configured_limit = self._instance.global_op_concurrency_default_limit
+
             return ConcurrencyKeyInfo(
                 concurrency_key=concurrency_key,
-                slot_count=len([slot_row for slot_row in slot_rows if not slot_row["deleted"]]),
+                slot_count=slot_count,
                 claimed_slots=[
                     ClaimedSlotInfo(run_id=slot_row["run_id"], step_key=slot_row["step_key"])
                     for slot_row in slot_rows
@@ -2715,6 +2734,7 @@ def get_concurrency_info(self, concurrency_key: str) -> ConcurrencyKeyInfo:
                     )
                     for row in pending_rows
                 ],
+                configured_limit=configured_limit,
             )
 
     def get_concurrency_run_ids(self) -> set[str]:
diff --git a/python_modules/dagster/dagster/_utils/concurrency.py b/python_modules/dagster/dagster/_utils/concurrency.py
index 9b70d9875b3ea..db60868773fd9 100644
--- a/python_modules/dagster/dagster/_utils/concurrency.py
+++ b/python_modules/dagster/dagster/_utils/concurrency.py
@@ -3,7 +3,7 @@
 from enum import Enum
 from typing import Optional
 
-from dagster._record import record
+from dagster._record import IHaveNew, record, record_custom
 
 
 def get_max_concurrency_limit_value() -> int:
@@ -68,12 +68,30 @@ class ClaimedSlotInfo:
     step_key: str
 
 
-@record
-class ConcurrencyKeyInfo:
+@record_custom
+class ConcurrencyKeyInfo(IHaveNew):
     concurrency_key: str
     slot_count: int
     claimed_slots: list[ClaimedSlotInfo]
     pending_steps: list[PendingStepInfo]
+    configured_limit: Optional[int]
+
+    def __new__(
+        cls,
+        concurrency_key: str,
+        slot_count: int,
+        claimed_slots: list[ClaimedSlotInfo],
+        pending_steps: list[PendingStepInfo],
+        configured_limit: Optional[int] = None,
+    ):
+        return super().__new__(
+            cls,
+            concurrency_key=concurrency_key,
+            slot_count=slot_count,
+            claimed_slots=claimed_slots,
+            pending_steps=pending_steps,
+            configured_limit=configured_limit,
+        )
 
     ###################################################
     # Fields that we need to keep around for backcompat
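
A minimal usage sketch (not part of the patch) of what the change separates out: `slot_count` only counts slots that have actually been allocated for a concurrency key, while the new `configured_limit` reports the limit that applies to the key, falling back to the instance-wide `concurrency.default_op_concurrency_limit` setting when no per-key limit has been set. The snippet below is an assumption-laden illustration: it uses dagster's internal `instance_for_test` test utility and an arbitrary key name "foo", and mirrors the behavior asserted in `test_default_concurrency` above.

    from dagster._core.test_utils import instance_for_test

    # Instance-wide default op concurrency limit of 1, matching the
    # `default_op_concurrency_limit` override added to the test suite in this patch.
    with instance_for_test(
        overrides={"concurrency": {"default_op_concurrency_limit": 1}}
    ) as instance:
        info = instance.event_log_storage.get_concurrency_info("foo")
        # No explicit limit has been set for "foo", so no slots exist yet...
        assert info.slot_count == 0
        # ...but the configured limit falls back to the instance-wide default,
        # which is what the new `configuredLimit` GraphQL field surfaces.
        assert info.configured_limit == 1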