Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor some concurrency utilities to be sync_compatible #15273

Merged
merged 4 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/prefect/concurrency/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from prefect.client.orchestration import get_client
from prefect.client.schemas.responses import MinimalConcurrencyLimitResponse
from prefect.utilities.asyncutils import sync_compatible

from .context import ConcurrencyContext
from .events import (
Expand Down Expand Up @@ -134,6 +135,7 @@ async def rate_limit(
_emit_concurrency_acquisition_events(limits, occupy)


@sync_compatible
async def _acquire_concurrency_slots(
names: List[str],
slots: int,
Expand Down Expand Up @@ -161,6 +163,7 @@ async def _acquire_concurrency_slots(
return _response_to_minimal_concurrency_limit_response(response_or_exception)


@sync_compatible
async def _release_concurrency_slots(
names: List[str], slots: int, occupancy_seconds: float
) -> List[MinimalConcurrencyLimitResponse]:
Expand Down
29 changes: 6 additions & 23 deletions src/prefect/concurrency/sync.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
from contextlib import contextmanager
from typing import (
Any,
Awaitable,
Callable,
Generator,
List,
Optional,
Expand All @@ -19,8 +16,6 @@
# pendulum < 3
from pendulum.period import Period as Interval # type: ignore

from prefect._internal.concurrency.api import create_call, from_sync
from prefect._internal.concurrency.event_loop import get_running_loop
from prefect.client.schemas.responses import MinimalConcurrencyLimitResponse

from .asyncio import (
Expand Down Expand Up @@ -76,13 +71,13 @@ def main():

names = names if isinstance(names, list) else [names]

limits: List[MinimalConcurrencyLimitResponse] = _call_async_function_from_sync(
_acquire_concurrency_slots,
limits: List[MinimalConcurrencyLimitResponse] = _acquire_concurrency_slots(
names,
occupy,
timeout_seconds=timeout_seconds,
create_if_missing=create_if_missing,
max_retries=max_retries,
_sync=True,
)
acquisition_time = pendulum.now("UTC")
emitted_events = _emit_concurrency_acquisition_events(limits, occupy)
Expand All @@ -91,11 +86,11 @@ def main():
yield
finally:
occupancy_period = cast(Interval, pendulum.now("UTC") - acquisition_time)
_call_async_function_from_sync(
_release_concurrency_slots,
_release_concurrency_slots(
names,
occupy,
occupancy_period.total_seconds(),
_sync=True,
)
_emit_concurrency_release_events(limits, occupy, emitted_events)

Expand All @@ -122,24 +117,12 @@ def rate_limit(

names = names if isinstance(names, list) else [names]

limits = _call_async_function_from_sync(
_acquire_concurrency_slots,
limits = _acquire_concurrency_slots(
names,
occupy,
mode="rate_limit",
timeout_seconds=timeout_seconds,
create_if_missing=create_if_missing,
_sync=True,
)
_emit_concurrency_acquisition_events(limits, occupy)


def _call_async_function_from_sync(
fn: Callable[..., Awaitable[T]], *args: Any, **kwargs: Any
) -> T:
loop = get_running_loop()
call = create_call(fn, *args, **kwargs)

if loop is not None:
return from_sync.call_soon_in_loop_thread(call).result()
else:
return call() # type: ignore [return-value]
3 changes: 3 additions & 0 deletions src/prefect/concurrency/v1/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from pendulum.period import Period as Interval # type: ignore

from prefect.client.orchestration import get_client
from prefect.utilities.asyncutils import sync_compatible

from .context import ConcurrencyContext
from .events import (
Expand Down Expand Up @@ -98,6 +99,7 @@ async def main():
_emit_concurrency_release_events(limits, emitted_events, task_run_id)


@sync_compatible
async def _acquire_concurrency_slots(
names: List[str],
task_run_id: UUID,
Expand All @@ -120,6 +122,7 @@ async def _acquire_concurrency_slots(
return _response_to_concurrency_limit_response(response_or_exception)


@sync_compatible
async def _release_concurrency_slots(
names: List[str],
task_run_id: UUID,
Expand Down
9 changes: 4 additions & 5 deletions src/prefect/concurrency/v1/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import pendulum

from ...client.schemas.responses import MinimalConcurrencyLimitResponse
from ..sync import _call_async_function_from_sync

try:
from pendulum import Interval
Expand Down Expand Up @@ -70,11 +69,11 @@ def main():

names = names if isinstance(names, list) else [names]

limits: List[MinimalConcurrencyLimitResponse] = _call_async_function_from_sync(
_acquire_concurrency_slots,
limits: List[MinimalConcurrencyLimitResponse] = _acquire_concurrency_slots(
names,
timeout_seconds=timeout_seconds,
task_run_id=task_run_id,
_sync=True,
)
acquisition_time = pendulum.now("UTC")
emitted_events = _emit_concurrency_acquisition_events(limits, task_run_id)
Expand All @@ -83,10 +82,10 @@ def main():
yield
finally:
occupancy_period = cast(Interval, pendulum.now("UTC") - acquisition_time)
_call_async_function_from_sync(
_release_concurrency_slots,
_release_concurrency_slots(
names,
task_run_id,
occupancy_period.total_seconds(),
_sync=True,
)
_emit_concurrency_release_events(limits, emitted_events, task_run_id)
50 changes: 0 additions & 50 deletions tests/concurrency/test_concurrency_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,31 +79,6 @@ async def my_flow():
assert executed


@pytest.mark.skip(
reason="New engine does not support calling async from sync",
)
def test_concurrency_mixed_sync_async(
concurrency_limit: ConcurrencyLimitV2,
):
executed = False

@task
async def resource_heavy():
nonlocal executed
async with concurrency("test", occupy=1):
executed = True

@flow
def my_flow():
resource_heavy()

assert not executed

my_flow()

assert executed


async def test_concurrency_emits_events(
concurrency_limit: ConcurrencyLimitV2,
other_concurrency_limit: ConcurrencyLimitV2,
Expand Down Expand Up @@ -275,31 +250,6 @@ async def my_flow():
assert executed


@pytest.mark.skip(
reason="New engine does not support calling async from sync",
)
def test_rate_limit_mixed_sync_async(
concurrency_limit_with_decay: ConcurrencyLimitV2,
):
executed = False

@task
async def resource_heavy():
nonlocal executed
await rate_limit("test", occupy=1)
executed = True

@flow
def my_flow():
resource_heavy()

assert not executed

my_flow()

assert executed


async def test_rate_limit_emits_events(
concurrency_limit_with_decay: ConcurrencyLimitV2,
other_concurrency_limit_with_decay: ConcurrencyLimitV2,
Expand Down
2 changes: 2 additions & 0 deletions tests/concurrency/test_concurrency_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def resource_heavy():
timeout_seconds=None,
create_if_missing=True,
max_retries=None,
_sync=True,
)

# On release we calculate how many seconds the slots were occupied
Expand Down Expand Up @@ -253,6 +254,7 @@ def resource_heavy():
mode="rate_limit",
timeout_seconds=None,
create_if_missing=True,
_sync=True,
)

# When used as a rate limit concurrency slots are not explicitly
Expand Down
2 changes: 1 addition & 1 deletion tests/concurrency/v1/test_concurrency_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def resource_heavy():
resource_heavy()

acquire_spy.assert_called_once_with(
["test"], timeout_seconds=None, task_run_id=task_run_id
["test"], timeout_seconds=None, task_run_id=task_run_id, _sync=True
)

names, _task_run_id, occupy_seconds = release_spy.call_args[0]
Expand Down
5 changes: 4 additions & 1 deletion tests/test_task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2315,7 +2315,10 @@ def bar():
bar()

acquire_spy.assert_called_once_with(
["limit-tag"], task_run_id=task_run_id, timeout_seconds=None
["limit-tag"],
task_run_id=task_run_id,
timeout_seconds=None,
_sync=True,
)

names, _task_run_id, occupy_seconds = release_spy.call_args[0]
Expand Down
Loading