From 746bbe6a620939135c69615d58e13133d03e55c8 Mon Sep 17 00:00:00 2001 From: Jean Luciano Date: Mon, 9 Sep 2024 13:22:14 -0500 Subject: [PATCH] `Runner` enforces `Deployment.concurrency_limit` (#15085) Co-authored-by: Andrew Brookins --- src/prefect/deployments/runner.py | 1 + src/prefect/runner/runner.py | 75 ++++++++++++++++-- tests/runner/test_runner.py | 127 ++++++++++++++++++++++++++++++ 3 files changed, 198 insertions(+), 5 deletions(-) diff --git a/src/prefect/deployments/runner.py b/src/prefect/deployments/runner.py index f49586f406ba..30f06976713f 100644 --- a/src/prefect/deployments/runner.py +++ b/src/prefect/deployments/runner.py @@ -462,6 +462,7 @@ def from_flow( paused: Whether or not to set this deployment as paused. schedules: A list of schedule objects defining when to execute runs of this deployment. Used to define multiple schedules or additional scheduling options like `timezone`. + concurrency_limit: The maximum number of concurrent runs this deployment will allow. triggers: A list of triggers that should kick of a run of this flow. parameters: A dictionary of default parameter values to pass to runs of this flow. description: A description for the created deployment. Defaults to the flow's diff --git a/src/prefect/runner/runner.py b/src/prefect/runner/runner.py index d4e83f90c185..44e0030bc5b6 100644 --- a/src/prefect/runner/runner.py +++ b/src/prefect/runner/runner.py @@ -66,6 +66,11 @@ def fast_flow(): ) from prefect.client.schemas.objects import Flow as APIFlow from prefect.client.schemas.objects import FlowRun, State, StateType +from prefect.concurrency.asyncio import ( + AcquireConcurrencySlotTimeoutError, + ConcurrencySlotAcquisitionError, + concurrency, +) from prefect.events import DeploymentTriggerTypes, TriggerTypes from prefect.events.related import tags_as_related_resources from prefect.events.schemas.events import RelatedResource @@ -81,7 +86,12 @@ def fast_flow(): PREFECT_RUNNER_SERVER_ENABLE, get_current_settings, ) -from prefect.states import Crashed, Pending, exception_to_failed_state +from prefect.states import ( + AwaitingConcurrencySlot, + Crashed, + Pending, + exception_to_failed_state, +) from prefect.types.entrypoint import EntrypointType from prefect.utilities.asyncutils import ( asyncnullcontext, @@ -226,6 +236,7 @@ async def add_flow( rrule: Optional[Union[Iterable[str], str]] = None, paused: Optional[bool] = None, schedules: Optional["FlexibleScheduleList"] = None, + concurrency_limit: Optional[int] = None, parameters: Optional[dict] = None, triggers: Optional[List[Union[DeploymentTriggerTypes, TriggerTypes]]] = None, description: Optional[str] = None, @@ -248,6 +259,10 @@ async def add_flow( or a timedelta object. If a number is given, it will be interpreted as seconds. cron: A cron schedule of when to execute runs of this flow. rrule: An rrule schedule of when to execute runs of this flow. + paused: Whether or not to set the created deployment as paused. + schedules: A list of schedule objects defining when to execute runs of this flow. + Used to define multiple schedules or additional scheduling options like `timezone`. + concurrency_limit: The maximum number of concurrent runs of this flow to allow. triggers: A list of triggers that should kick of a run of this flow. parameters: A dictionary of default parameter values to pass to runs of this flow. description: A description for the created deployment. Defaults to the flow's @@ -280,6 +295,7 @@ async def add_flow( version=version, enforce_parameter_schema=enforce_parameter_schema, entrypoint_type=entrypoint_type, + concurrency_limit=concurrency_limit, ) return await self.add_deployment(deployment) @@ -959,6 +975,7 @@ async def _submit_scheduled_flow_runs( """ submittable_flow_runs = flow_run_response submittable_flow_runs.sort(key=lambda run: run.next_scheduled_start_time) + for i, flow_run in enumerate(submittable_flow_runs): if flow_run.id in self._submitting_flow_run_ids: continue @@ -1025,12 +1042,40 @@ async def _submit_run_and_capture_errors( ) -> Union[Optional[int], Exception]: run_logger = self._get_flow_run_logger(flow_run) + if flow_run.deployment_id: + deployment = await self._client.read_deployment(flow_run.deployment_id) + if deployment and deployment.concurrency_limit: + limit_name = f"deployment:{deployment.id}" + concurrency_ctx = concurrency + else: + limit_name = None + concurrency_ctx = asyncnullcontext + try: - status_code = await self._run_process( - flow_run=flow_run, - task_status=task_status, - entrypoint=entrypoint, + async with concurrency_ctx( + limit_name, occupy=deployment.concurrency_limit, max_retries=0 + ): + status_code = await self._run_process( + flow_run=flow_run, + task_status=task_status, + entrypoint=entrypoint, + ) + except ( + AcquireConcurrencySlotTimeoutError, + ConcurrencySlotAcquisitionError, + ) as exc: + self._logger.info( + ( + "Deployment %s reached its concurrency limit when attempting to execute flow run %s. Will attempt to execute later." + ), + flow_run.deployment_id, + flow_run.name, ) + await self._propose_scheduled_state(flow_run) + + if not task_status._future.done(): + task_status.started(exc) + return exc except Exception as exc: if not task_status._future.done(): # This flow run was being submitted and did not start successfully @@ -1116,6 +1161,26 @@ async def _propose_failed_state(self, flow_run: "FlowRun", exc: Exception) -> No exc_info=True, ) + async def _propose_scheduled_state(self, flow_run: "FlowRun") -> None: + run_logger = self._get_flow_run_logger(flow_run) + try: + state = await propose_state( + self._client, + AwaitingConcurrencySlot(), + flow_run_id=flow_run.id, + ) + self._logger.info(f"Flow run {flow_run.id} now has state {state.name}") + except Abort as exc: + run_logger.info( + ( + f"Aborted rescheduling of flow run '{flow_run.id}'. " + f"Server sent an abort signal: {exc}" + ), + ) + pass + except Exception: + run_logger.exception(f"Failed to update state of flow run '{flow_run.id}'") + async def _propose_crashed_state(self, flow_run: "FlowRun", message: str) -> None: run_logger = self._get_flow_run_logger(flow_run) try: diff --git a/tests/runner/test_runner.py b/tests/runner/test_runner.py index 2839bd293e21..5be55b7bb652 100644 --- a/tests/runner/test_runner.py +++ b/tests/runner/test_runner.py @@ -26,6 +26,11 @@ from prefect.client.schemas.actions import DeploymentScheduleCreate from prefect.client.schemas.objects import StateType from prefect.client.schemas.schedules import CronSchedule, IntervalSchedule +from prefect.concurrency.asyncio import ( + AcquireConcurrencySlotTimeoutError, + _acquire_concurrency_slots, + _release_concurrency_slots, +) from prefect.deployments.runner import ( DeploymentApplyError, EntrypointType, @@ -631,6 +636,128 @@ async def test_runner_respects_set_limit( flow_run = await prefect_client.read_flow_run(flow_run_id=bad_run.id) assert flow_run.state.is_completed() + @pytest.mark.usefixtures("use_hosted_api_server") + async def test_runner_enforces_deployment_concurrency_limits( + self, prefect_client: PrefectClient, caplog + ): + async def test(*args, **kwargs): + return 0 + + with mock.patch( + "prefect.concurrency.asyncio._acquire_concurrency_slots", + wraps=_acquire_concurrency_slots, + ) as acquire_spy: + with mock.patch( + "prefect.concurrency.asyncio._release_concurrency_slots", + wraps=_release_concurrency_slots, + ) as release_spy: + async with Runner(pause_on_shutdown=False) as runner: + deployment = RunnerDeployment.from_flow( + flow=dummy_flow_1, + name=__file__, + concurrency_limit=1, + ) + + deployment_id = await runner.add_deployment(deployment) + + flow_run = await prefect_client.create_flow_run_from_deployment( + deployment_id=deployment_id + ) + + assert flow_run.state.is_scheduled() + + runner.run = test # simulate running a flow + + await runner._get_and_submit_flow_runs() + + acquire_spy.assert_called_once_with( + [f"deployment:{deployment_id}"], + 1, + timeout_seconds=None, + create_if_missing=True, + max_retries=0, + ) + + names, occupy, occupy_seconds = release_spy.call_args[0] + assert names == [f"deployment:{deployment_id}"] + assert occupy == 1 + assert occupy_seconds > 0 + + @pytest.mark.usefixtures("use_hosted_api_server") + async def test_runner_proposes_awaiting_concurrency_limit_state_name( + self, prefect_client: PrefectClient, caplog + ): + async def test(*args, **kwargs): + return 0 + + with mock.patch( + "prefect.concurrency.asyncio._acquire_concurrency_slots", + wraps=_acquire_concurrency_slots, + ) as acquire_spy: + # Simulate a Locked response from the API + acquire_spy.side_effect = AcquireConcurrencySlotTimeoutError + + async with Runner(pause_on_shutdown=False) as runner: + deployment = RunnerDeployment.from_flow( + flow=dummy_flow_1, + name=__file__, + concurrency_limit=1, + ) + + deployment_id = await runner.add_deployment(deployment) + + flow_run = await prefect_client.create_flow_run_from_deployment( + deployment_id=deployment_id + ) + + assert flow_run.state.is_scheduled() + + runner.run = test # simulate running a flow + + await runner._get_and_submit_flow_runs() + + acquire_spy.assert_called_once_with( + [f"deployment:{deployment_id}"], + 1, + timeout_seconds=None, + create_if_missing=True, + max_retries=0, + ) + + flow_run = await prefect_client.read_flow_run(flow_run.id) + assert flow_run.state.name == "AwaitingConcurrencySlot" + + @pytest.mark.usefixtures("use_hosted_api_server") + async def test_runner_does_not_attempt_to_acquire_limit_if_deployment_has_no_concurrency_limit( + self, prefect_client: PrefectClient, caplog + ): + async def test(*args, **kwargs): + return 0 + + with mock.patch( + "prefect.concurrency.asyncio._acquire_concurrency_slots", + wraps=_acquire_concurrency_slots, + ) as acquire_spy: + async with Runner(pause_on_shutdown=False) as runner: + deployment = RunnerDeployment.from_flow( + flow=dummy_flow_1, + name=__file__, + ) + + deployment_id = await runner.add_deployment(deployment) + + flow_run = await prefect_client.create_flow_run_from_deployment( + deployment_id=deployment_id + ) + + assert flow_run.state.is_scheduled() + + runner.run = test # simulate running a flow + + await runner._get_and_submit_flow_runs() + + acquire_spy.assert_not_called() + async def test_handles_spaces_in_sys_executable(self, monkeypatch, prefect_client): """ Regression test for https://github.com/PrefectHQ/prefect/issues/10820