Skip to content

Commit

Permalink
Move client-side concurrency behind orchestration flag
Browse files Browse the repository at this point in the history
  • Loading branch information
abrookins committed Jul 23, 2024
1 parent 0d7b445 commit a206fbb
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 172 deletions.
28 changes: 0 additions & 28 deletions docs/3.0rc/api-ref/rest-api/server/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -1032,15 +1032,6 @@
"type": "string",
"title": "X-Prefect-Api-Version"
}
},
{
"name": "user-agent",
"in": "header",
"required": false,
"schema": {
"type": "string",
"title": "User-Agent"
}
}
],
"requestBody": {
Expand Down Expand Up @@ -1932,15 +1923,6 @@
"type": "string",
"title": "X-Prefect-Api-Version"
}
},
{
"name": "user-agent",
"in": "header",
"required": false,
"schema": {
"type": "string",
"title": "User-Agent"
}
}
],
"requestBody": {
Expand Down Expand Up @@ -22295,16 +22277,6 @@
"title": "Prefect Experimental Enable Client Side Task Orchestration",
"default": false
},
"PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_CONCURRENCY": {
"type": "boolean",
"title": "Prefect Experimental Enable Client Side Task Concurrency",
"default": true
},
"PREFECT_EXPERIMENTAL_WARN_CLIENT_SIDE_TASK_CONCURRENCY": {
"type": "boolean",
"title": "Prefect Experimental Warn Client Side Task Concurrency",
"default": false
},
"PREFECT_RUNNER_PROCESS_LIMIT": {
"type": "integer",
"title": "Prefect Runner Process Limit",
Expand Down
17 changes: 1 addition & 16 deletions src/prefect/server/api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from uuid import UUID

from fastapi import Body, Depends, Header, HTTPException, status
from packaging.version import InvalidVersion, Version
from packaging.version import Version
from starlette.requests import Request

from prefect.server import schemas
Expand All @@ -33,21 +33,6 @@ def provide_request_api_version(x_prefect_api_version: str = Header(None)):
return Version(x_prefect_api_version)


def provide_request_client_version(user_agent: str = Header(None)):
if not user_agent:
return

# Try to parse a Prefect version from the user agent
try:
client_version, api_version = user_agent.split(" ", 1)
client_version = client_version.split("/")[1]
version = Version(client_version)
except (ValueError, IndexError, InvalidVersion):
return

return version


class EnforceMinimumAPIVersion:
"""
FastAPI Dependency used to check compatibility between the version of the api
Expand Down
3 changes: 2 additions & 1 deletion src/prefect/server/api/task_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from prefect.server.database.dependencies import provide_database_interface
from prefect.server.database.interface import PrefectDBInterface
from prefect.server.orchestration import dependencies as orchestration_dependencies
from prefect.server.orchestration.core_policy import CoreTaskPolicy
from prefect.server.orchestration.policies import BaseOrchestrationPolicy
from prefect.server.schemas.responses import OrchestrationResult
from prefect.server.task_queue import MultiQueue, TaskQueue
Expand Down Expand Up @@ -261,7 +262,7 @@ async def set_task_run_state(
state
), # convert to a full State object
force=force,
task_policy=task_policy,
task_policy=CoreTaskPolicy,
orchestration_parameters=orchestration_parameters,
)

Expand Down
24 changes: 2 additions & 22 deletions src/prefect/server/orchestration/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,6 @@

from contextlib import contextmanager

from fastapi import Depends

from prefect.server.api.dependencies import provide_request_client_version
from prefect.settings import PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_CONCURRENCY

ORCHESTRATION_DEPENDENCIES = {
"task_policy_provider": None,
"flow_policy_provider": None,
Expand All @@ -17,26 +12,11 @@
}


async def provide_task_policy(client_version=Depends(provide_request_client_version)):
async def provide_task_policy():
policy_provider = ORCHESTRATION_DEPENDENCIES.get("task_policy_provider")

if policy_provider is None:
from prefect.server.orchestration.core_policy import (
ClientSideTaskOrchestrationPolicy,
CoreTaskPolicy,
)

if (
PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_CONCURRENCY.value()
and client_version
and (
# Clients older than 3.0.0rc11 do not support client-side task concurrency.
client_version.major == 3
and client_version.pre
and client_version.pre[1] >= 9
)
):
return ClientSideTaskOrchestrationPolicy
from prefect.server.orchestration.core_policy import CoreTaskPolicy

return CoreTaskPolicy

Expand Down
11 changes: 0 additions & 11 deletions src/prefect/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -1321,17 +1321,6 @@ def default_cloud_ui_url(settings, value):
Whether or not to enable experimental client side task run orchestration.
"""

PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_CONCURRENCY = Setting(bool, default=True)
"""
Whether or not to enable experimental client-side management of task concurrency limits.
"""

PREFECT_EXPERIMENTAL_WARN_CLIENT_SIDE_TASK_CONCURRENCY = Setting(bool, default=False)
"""
Whether or not to warn when experimental client-side management of task
concurrency limits is used.
"""

# Prefect Events feature flags

PREFECT_RUNNER_PROCESS_LIMIT = Setting(int, default=5)
Expand Down
5 changes: 2 additions & 3 deletions src/prefect/task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@
from prefect.results import BaseResult, ResultFactory, _format_user_supplied_storage_key
from prefect.settings import (
PREFECT_DEBUG_MODE,
PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_CONCURRENCY,
PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION,
PREFECT_TASKS_REFRESH_CACHE,
)
Expand Down Expand Up @@ -767,7 +766,7 @@ async def _call_task_fn():
if transaction.is_committed():
result = transaction.read()
else:
if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_CONCURRENCY.value():
if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION.value():
# Acquire a concurrency slot for each tag, but only if a limit
# matching the tag already exists.
async with aconcurrency(
Expand All @@ -786,7 +785,7 @@ async def _call_task_fn():
if transaction.is_committed():
result = transaction.read()
else:
if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_CONCURRENCY.value():
if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION.value():
# Acquire a concurrency slot for each tag, but only if a limit
# matching the tag already exists.
with concurrency(
Expand Down
40 changes: 10 additions & 30 deletions tests/server/orchestration/api/test_task_runs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import uuid
from unittest import mock
from uuid import uuid4

import pendulum
Expand All @@ -8,27 +7,27 @@

from prefect.client.orchestration import PrefectClient
from prefect.client.schemas.objects import State
from prefect.events.worker import EventsWorker
from prefect.server import models, schemas
from prefect.server.database.orm_models import TaskRun
from prefect.server.schemas import responses, states
from prefect.server.schemas.responses import OrchestrationResult
from prefect.settings import (
PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_CONCURRENCY,
PREFECT_EXPERIMENTAL_WARN_CLIENT_SIDE_TASK_CONCURRENCY,
PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION,
temporary_settings,
)
from prefect.states import Pending, Running
from prefect.states import Pending


@pytest.fixture
def enable_client_side_concurrency():
@pytest.fixture(autouse=True, params=[False, True])
def enable_client_side_task_run_orchestration(
request, asserting_events_worker: EventsWorker
):
enabled = request.param
with temporary_settings(
updates={
PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_CONCURRENCY: True,
PREFECT_EXPERIMENTAL_WARN_CLIENT_SIDE_TASK_CONCURRENCY: False,
}
{PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION: enabled}
):
yield
yield enabled


class TestCreateTaskRun:
Expand Down Expand Up @@ -593,25 +592,6 @@ async def test_autonomous_task_run_aborts_if_enters_pending_from_disallowed_stat

assert response_2.status == responses.SetStateStatus.ABORT

async def test_set_task_run_state_uses_client_orchestration_policy(
self, task_run, flow_run, prefect_client, enable_client_side_concurrency
):
await prefect_client.set_flow_run_state(
flow_run_id=flow_run.id, state=Running()
)
await prefect_client.set_task_run_state(
task_run_id=task_run.id, state=Pending(), force=True
)

with mock.patch(
"prefect.server.orchestration.core_policy.SecureTaskConcurrencySlots.before_transition",
) as mock_slot_transition:
response = await prefect_client.set_task_run_state(
task_run_id=task_run.id, state=Running()
)
assert response.status == responses.SetStateStatus.ACCEPT
mock_slot_transition.assert_not_called()


class TestTaskRunHistory:
async def test_history_interval_must_be_one_second_or_larger(self, client):
Expand Down
92 changes: 31 additions & 61 deletions tests/test_task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@
from prefect.logging import get_run_logger
from prefect.results import PersistedResult, ResultFactory, UnpersistedResult
from prefect.settings import (
PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_CONCURRENCY,
PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION,
PREFECT_EXPERIMENTAL_WARN_CLIENT_SIDE_TASK_CONCURRENCY,
PREFECT_TASK_DEFAULT_RETRIES,
temporary_settings,
)
Expand Down Expand Up @@ -183,28 +181,6 @@ async def foo():
return 42


@pytest.fixture
def enable_client_side_concurrency():
with temporary_settings(
updates={
PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_CONCURRENCY: True,
PREFECT_EXPERIMENTAL_WARN_CLIENT_SIDE_TASK_CONCURRENCY: False,
}
):
yield


@pytest.fixture
def disable_client_side_concurrency():
with temporary_settings(
updates={
PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_CONCURRENCY: False,
PREFECT_EXPERIMENTAL_WARN_CLIENT_SIDE_TASK_CONCURRENCY: False,
}
):
yield


class TestTaskRunEngine:
async def test_basic_init(self):
engine = TaskRunEngine(task=foo)
Expand Down Expand Up @@ -2108,7 +2084,7 @@ async def g():


class TestTaskConcurrencyLimits:
async def test_tag_concurrency(self, enable_client_side_concurrency):
async def test_tag_concurrency(self, enable_client_side_task_run_orchestration):
@task(tags=["limit-tag"])
async def bar():
return 42
Expand All @@ -2123,16 +2099,19 @@ async def bar():
) as release_spy:
await bar()

acquire_spy.assert_called_once_with(
["limit-tag"], 1, timeout_seconds=None, create_if_missing=False
)
if enable_client_side_task_run_orchestration:
acquire_spy.assert_called_once_with(
["limit-tag"], 1, timeout_seconds=None, create_if_missing=False
)

names, occupy, occupy_seconds = release_spy.call_args[0]
assert names == ["limit-tag"]
assert occupy == 1
assert occupy_seconds > 0
names, occupy, occupy_seconds = release_spy.call_args[0]
assert names == ["limit-tag"]
assert occupy == 1
assert occupy_seconds > 0
else:
assert acquire_spy.call_count == 0

def test_tag_concurrency_sync(self, enable_client_side_concurrency):
def test_tag_concurrency_sync(self, enable_client_side_task_run_orchestration):
@task(tags=["limit-tag"])
def bar():
return 42
Expand All @@ -2147,17 +2126,20 @@ def bar():
) as release_spy:
bar()

acquire_spy.assert_called_once_with(
["limit-tag"], 1, timeout_seconds=None, create_if_missing=False
)
if enable_client_side_task_run_orchestration:
acquire_spy.assert_called_once_with(
["limit-tag"], 1, timeout_seconds=None, create_if_missing=False
)

names, occupy, occupy_seconds = release_spy.call_args[0]
assert names == ["limit-tag"]
assert occupy == 1
assert occupy_seconds > 0
names, occupy, occupy_seconds = release_spy.call_args[0]
assert names == ["limit-tag"]
assert occupy == 1
assert occupy_seconds > 0
else:
assert acquire_spy.call_count == 0

async def test_tag_concurrency_does_not_create_limits(
self, enable_client_side_concurrency, prefect_client
self, enable_client_side_task_run_orchestration, prefect_client
):
@task(tags=["limit-tag"])
async def bar():
Expand All @@ -2169,27 +2151,15 @@ async def bar():
) as acquire_spy:
await bar()

acquire_spy.assert_called_once_with(
["limit-tag"], 1, timeout_seconds=None, create_if_missing=False
)

limits = await prefect_client.read_concurrency_limits(10, 0)
assert len(limits) == 0

def test_does_not_use_concurrency_limit_if_experiment_is_disabled(
self, disable_client_side_concurrency
):
@task(tags=["limit-tag"])
def bar():
return 42

with mock.patch(
"prefect.concurrency.sync._acquire_concurrency_slots",
wraps=_acquire_concurrency_slots,
) as acquire_spy:
bar()
if enable_client_side_task_run_orchestration:
acquire_spy.assert_called_once_with(
["limit-tag"], 1, timeout_seconds=None, create_if_missing=False
)

acquire_spy.assert_not_called()
limits = await prefect_client.read_concurrency_limits(10, 0)
assert len(limits) == 0
else:
assert acquire_spy.call_count == 0


class TestRunStateIsDenormalized:
Expand Down

0 comments on commit a206fbb

Please sign in to comment.