Skip to content

Commit

Permalink
Adds as_completed utility for PrefectFuture (#14641)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeanluciano authored Jul 26, 2024
1 parent 530521d commit 44de8f4
Show file tree
Hide file tree
Showing 6 changed files with 257 additions and 3 deletions.
20 changes: 20 additions & 0 deletions src/integrations/prefect-dask/tests/test_task_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from prefect_dask import DaskTaskRunner

from prefect import flow, task
from prefect.futures import as_completed
from prefect.server.schemas.states import StateType
from prefect.states import State
from prefect.testing.fixtures import ( # noqa: F401
Expand Down Expand Up @@ -248,6 +249,25 @@ def test_flow():
assert bx.type == StateType.PENDING
assert cx.type == StateType.COMPLETED

def test_as_completed_yields_correct_order(self, task_runner):
    """Tasks submitted slowest-first must be yielded fastest-first."""

    @task
    def sleep_task(seconds):
        time.sleep(seconds)
        return seconds

    timings = [1, 5, 10]
    with task_runner:
        # Submit in reverse so that submission order and completion order
        # are opposites; only completion order should drive the results.
        futures = []
        for seconds in reversed(timings):
            submitted = task_runner.submit(
                sleep_task, parameters={"seconds": seconds}, wait_for=[]
            )
            futures.append(submitted)

        done_futures = [
            future.result() for future in as_completed(futures=futures)
        ]
        assert done_futures == timings

async def test_wait_captures_exceptions_as_crashed_state(self, task_runner):
"""
Dask wraps the exception, interrupts will result in "Cancelled" tasks
Expand Down
11 changes: 11 additions & 0 deletions src/integrations/prefect-ray/prefect_ray/task_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,17 @@ def result(
_result = run_coro_as_sync(_result)
return _result

def add_done_callback(self, fn):
    """
    Run `fn(self)` once this future is done.

    If a final state is already recorded the callback runs immediately;
    otherwise it is registered through the wrapped future's
    `_on_completed` hook.

    Args:
        fn: A callable invoked with this future as its only argument.
    """
    if self._final_state:
        fn(self)
        return

    def _notify(_completed):
        # The hook passes its own argument; callers expect the Prefect
        # future itself, so that argument is ignored and `self` is passed.
        fn(self)

    self._wrapped_future._on_completed(_notify)

def __del__(self):
if self._final_state:
return
Expand Down
19 changes: 19 additions & 0 deletions src/integrations/prefect-ray/tests/test_task_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import prefect.task_engine
import tests
from prefect import flow, task
from prefect.futures import as_completed
from prefect.states import State, StateType
from prefect.testing.fixtures import ( # noqa: F401
hosted_api_server,
Expand Down Expand Up @@ -371,6 +372,24 @@ def test_flow():
assert bx.type == StateType.PENDING
assert cx.type == StateType.COMPLETED

def test_as_completed_yields_correct_order(self, task_runner):
    """The slowest task must be the last future yielded by as_completed."""
    timings = [1, 5, 10]

    @task
    def task_a(seconds):
        time.sleep(seconds)
        return seconds

    @flow(version="test", task_runner=task_runner)
    def test_flow():
        # Submit slowest-first so submission order differs from finish order.
        futures = [task_a.submit(seconds) for seconds in reversed(timings)]
        done_futures = [
            future.result() for future in as_completed(futures=futures)
        ]
        # Only the final element is checked here, not the full ordering.
        assert done_futures[-1] == timings[-1]

    test_flow()

def get_sleep_time(self) -> float:
"""
Return an amount of time to sleep for concurrency tests.
Expand Down
86 changes: 84 additions & 2 deletions src/prefect/futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
import collections
import concurrent.futures
import inspect
import threading
import uuid
from collections.abc import Iterator
from collections.abc import Generator, Iterator
from functools import partial
from typing import Any, Generic, List, Optional, Set, Union, cast
from typing import Any, Callable, Generic, List, Optional, Set, Union, cast

from typing_extensions import TypeVar

Expand Down Expand Up @@ -91,6 +92,16 @@ def result(
The result of the task run.
"""

@abc.abstractmethod
def add_done_callback(self, fn):
    """
    Add a callback to be run when the future completes or is cancelled.

    Args:
        fn: A callable that will be called with this future as its only
            argument when the future completes or is cancelled.
    """
    ...


class PrefectWrappedFuture(PrefectFuture, abc.ABC, Generic[R, F]):
"""
Expand All @@ -106,6 +117,17 @@ def wrapped_future(self) -> F:
"""The underlying future object wrapped by this Prefect future"""
return self._wrapped_future

def add_done_callback(self, fn: Callable[[PrefectFuture], None]):
    """
    Run `fn(self)` once this future is done.

    If a final state is already recorded the callback runs immediately;
    otherwise it is attached to the wrapped future via its own
    `add_done_callback`.

    Args:
        fn: A callable invoked with this Prefect future as its only argument.
    """
    if self._final_state:
        fn(self)
        return

    def _notify(_wrapped):
        # The wrapped future passes itself to its callbacks; hand the
        # Prefect future to `fn` instead so callers can match it against
        # the futures they are tracking.
        fn(self)

    self._wrapped_future.add_done_callback(_notify)


class PrefectConcurrentFuture(PrefectWrappedFuture[R, concurrent.futures.Future]):
"""
Expand Down Expand Up @@ -138,6 +160,7 @@ def result(

if isinstance(future_result, State):
self._final_state = future_result

else:
return future_result

Expand Down Expand Up @@ -172,6 +195,9 @@ class PrefectDistributedFuture(PrefectFuture[R]):
any task run scheduled in Prefect's API.
"""

# NOTE(review): both names are class-level attributes. `done_callbacks` is a
# mutable list created once on the class, so every instance that does not
# rebind it shares the same list object. Neither attribute is read by the
# methods visible in this part of the file -- confirm whether they are needed.
done_callbacks: List[Callable[[PrefectFuture], None]] = []
waiter = None

@deprecated_async_method
def wait(self, timeout: Optional[float] = None) -> None:
    """Synchronous entry point that runs `wait_async` via `run_coro_as_sync`."""
    pending_wait = self.wait_async(timeout=timeout)
    return run_coro_as_sync(pending_wait)
Expand Down Expand Up @@ -235,11 +261,27 @@ async def result_async(
raise_on_failure=raise_on_failure, fetch=True
)

def add_done_callback(self, fn: Callable[[PrefectFuture], None]):
    """
    Run `fn(self)` once the task run behind this future is final.

    The callback runs immediately when a final state is already cached or
    when the API reports a final state; otherwise it is registered with
    `TaskRunWaiter` under this future's task run ID.

    Args:
        fn: A callable invoked with this future as its only argument.
    """
    if self._final_state:
        fn(self)
        return

    # The return value is unused; presumably this call makes sure the waiter
    # exists before the state is read -- TODO confirm.
    TaskRunWaiter.instance()
    with get_client(sync_client=True) as client:
        current_state = client.read_task_run(task_run_id=self._task_run_id).state
        if not current_state.is_final():
            TaskRunWaiter.add_done_callback(self._task_run_id, partial(fn, self))
            return
        # Cache the final state so later calls take the fast path above.
        self._final_state = current_state
        fn(self)

def __eq__(self, other):
    """Two distributed futures are equal when they share a task run ID."""
    if isinstance(other, PrefectDistributedFuture):
        return self.task_run_id == other.task_run_id
    return False

def __hash__(self):
    """Hash on the task run ID, the same attribute `__eq__` compares."""
    run_id = self.task_run_id
    return hash(run_id)


class PrefectFutureList(list, Iterator, Generic[F]):
"""
Expand Down Expand Up @@ -292,6 +334,46 @@ def result(
) from exc


def as_completed(
    futures: List[PrefectFuture], timeout: Optional[float] = None
) -> Generator[PrefectFuture, None, None]:
    """
    Yield futures as they finish.

    Futures that already have a final state are yielded first (in set order,
    which is arbitrary). The remaining futures are yielded as their
    done-callbacks fire. Duplicate futures in the input are collapsed, so each
    distinct future is yielded at most once.

    Args:
        futures: The futures to watch.
        timeout: Forwarded to `timeout_context`, which wraps the whole
            iteration, including the yields.

    Yields:
        Each distinct future once it is done.

    Raises:
        TimeoutError: If the timeout context expires first. The message
            reports how many futures were still pending.
    """
    unique_futures: Set[PrefectFuture] = set(futures)
    total_futures = len(unique_futures)
    # Bind `pending` before the try block so the TimeoutError handler can
    # always report a count, even if the timeout fires before the
    # done/pending split below has run.
    pending: Set[PrefectFuture] = set(unique_futures)
    try:
        with timeout_context(timeout):
            done = {f for f in unique_futures if f._final_state}
            pending = unique_futures - done
            yield from done

            finished_event = threading.Event()
            finished_lock = threading.Lock()
            finished_futures: List[PrefectFuture] = []

            def add_to_done(future):
                # Called by each future's done-callback, possibly from
                # another thread, so the shared list is guarded by the lock.
                with finished_lock:
                    finished_futures.append(future)
                    finished_event.set()

            for future in pending:
                future.add_done_callback(add_to_done)

            while pending:
                finished_event.wait()
                with finished_lock:
                    # Take the current batch and reset the event under the
                    # same lock so no completion is lost between the two.
                    newly_done = list(finished_futures)
                    finished_futures.clear()
                    finished_event.clear()

                for future in newly_done:
                    if future not in pending:
                        # A callback fired more than once for this future;
                        # it has already been yielded.
                        continue
                    pending.remove(future)
                    yield future

    except TimeoutError:
        raise TimeoutError(
            f"{len(pending)} (of {total_futures}) futures unfinished"
        )


# Pair of future collections: those that finished and those that did not.
DoneAndNotDoneFutures = collections.namedtuple(
    "DoneAndNotDoneFutures", ["done", "not_done"]
)


Expand Down
25 changes: 24 additions & 1 deletion src/prefect/task_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import atexit
import threading
import uuid
from typing import Dict, Optional
from typing import Callable, Dict, Optional

import anyio
from cachetools import TTLCache
Expand Down Expand Up @@ -74,6 +74,7 @@ def __init__(self):
maxsize=10000, ttl=600
)
self._completion_events: Dict[uuid.UUID, asyncio.Event] = {}
self._completion_callbacks: Dict[uuid.UUID, Callable] = {}
self._loop: Optional[asyncio.AbstractEventLoop] = None
self._observed_completed_task_runs_lock = threading.Lock()
self._completion_events_lock = threading.Lock()
Expand Down Expand Up @@ -135,6 +136,8 @@ async def _consume_events(self, consumer_started: asyncio.Event):
# so the waiter can wake up the waiting coroutine
if task_run_id in self._completion_events:
self._completion_events[task_run_id].set()
if task_run_id in self._completion_callbacks:
self._completion_callbacks[task_run_id]()
except Exception as exc:
self.logger.error(f"Error processing event: {exc}")

Expand Down Expand Up @@ -195,6 +198,26 @@ async def wait_for_task_run(
# Remove the event from the cache after it has been waited on
instance._completion_events.pop(task_run_id, None)

@classmethod
def add_done_callback(
    cls, task_run_id: uuid.UUID, callback: Callable[[], None]
) -> None:
    """
    Add a callback to be called when a task run finishes.

    If the task run has already been observed as completed, the callback is
    called immediately. Otherwise it is stored under the task run ID for the
    event consumer to call later.

    Args:
        task_run_id: The ID of the task run to wait for.
        callback: A zero-argument callable to call when the task run finishes.
    """
    instance = cls.instance()
    with instance._observed_completed_task_runs_lock:
        already_completed = task_run_id in instance._observed_completed_task_runs

    if already_completed:
        # Call outside the lock: it is a plain (non-reentrant) lock, so a
        # callback that touched the waiter again would otherwise deadlock.
        callback()
        return

    with instance._completion_events_lock:
        # Store the callback for the task run ID so the consumer can call it
        # when the completion event is received.
        # NOTE(review): one slot per task run ID -- registering a second
        # callback for the same ID replaces the first.
        instance._completion_callbacks[task_run_id] = callback

@classmethod
def instance(cls):
"""
Expand Down
99 changes: 99 additions & 0 deletions tests/test_futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
PrefectFuture,
PrefectFutureList,
PrefectWrappedFuture,
as_completed,
resolve_futures_to_states,
wait,
)
from prefect.states import Completed, Failed
from prefect.task_engine import run_task_async, run_task_sync
from prefect.task_runners import ThreadPoolTaskRunner


class MockFuture(PrefectWrappedFuture):
Expand Down Expand Up @@ -55,6 +57,103 @@ def test_wait_with_timeout(self):
futures = wait(mock_futures, timeout=0.01)
assert futures.not_done == {mock_futures[-1]}

def test_as_completed(self):
    """Every future yielded by as_completed is in a completed state."""
    mock_futures = [MockFuture(data=value) for value in range(5)]
    assert all(
        future.state.is_completed() for future in as_completed(mock_futures)
    )

@pytest.mark.timeout(method="thread")
def test_as_completed_with_timeout(self):
    """A future that never resolves makes as_completed raise TimeoutError."""
    mock_futures = [MockFuture(data=value) for value in range(5)]
    # A bare concurrent Future that nobody resolves stays pending.
    mock_futures += [PrefectConcurrentFuture(uuid.uuid4(), Future())]

    with pytest.raises(TimeoutError) as exc_info:
        for future in as_completed(mock_futures, timeout=0.01):
            assert future.state.is_completed()

    expected_message = f"1 (of {len(mock_futures)}) futures unfinished"
    assert exc_info.value.args[0] == expected_message

# @pytest.mark.timeout(method="thread")
@pytest.mark.usefixtures("use_hosted_api_server")
def test_as_completed_yields_correct_order(self):
    """Results arrive in completion order, not submission order."""

    @task
    def my_test_task(seconds):
        import time

        time.sleep(seconds)
        return seconds

    timings = [1, 5, 10]
    with ThreadPoolTaskRunner() as runner:
        futures = []
        # Submit slowest-first so the two orders are opposites.
        for seconds in reversed(timings):
            parameters = {"seconds": seconds}
            submitted = runner.submit(my_test_task, parameters)
            submitted.parameters = parameters
            futures.append(submitted)

        results = [future.result() for future in as_completed(futures)]
        assert results == timings

def test_as_completed_timeout(self, caplog):
    """A 5 second timeout leaves the 5s and 10s tasks reported as unfinished."""

    @task
    def my_test_task(seconds):
        import time

        time.sleep(seconds)
        return seconds

    timings = [1, 5, 10]
    with ThreadPoolTaskRunner() as runner:
        futures = []
        for seconds in reversed(timings):
            parameters = {"seconds": seconds}
            submitted = runner.submit(my_test_task, parameters)
            submitted.parameters = parameters
            futures.append(submitted)

        results = []
        with pytest.raises(TimeoutError) as exc_info:
            for future in as_completed(futures, timeout=5):
                results.append(future.result())

        expected_message = f"2 (of {len(timings)}) futures unfinished"
        assert exc_info.value.args[0] == expected_message

async def test_as_completed_yields_correct_order_dist(self, task_run):
    # Exercises as_completed() with PrefectDistributedFuture objects whose
    # task runs are executed as asyncio tasks in this test's event loop.
    @task
    async def my_task(seconds):
        import time

        # NOTE(review): time.sleep() inside an async task blocks the event
        # loop while it sleeps -- confirm this is intended.
        time.sleep(seconds)
        return seconds

    futures = []
    timings = [1, 5, 10]
    # Create the runs slowest-first so creation order differs from finish order.
    for i in reversed(timings):
        # NOTE(review): this rebinds `task_run`, shadowing the fixture argument.
        task_run = await my_task.create_run(parameters={"seconds": i})
        future = PrefectDistributedFuture(task_run_id=task_run.id)

        futures.append(future)
        # Start the run in the background; the task handle is not kept.
        asyncio.create_task(
            run_task_async(
                task=my_task,
                task_run_id=future.task_run_id,
                task_run=task_run,
                parameters={"seconds": i},
                return_type="state",
            )
        )
    results = []
    with pytest.raises(MissingResult):
        for future in as_completed(futures):
            results.append(future.result())

        # NOTE(review): the original indentation was lost; this assert is
        # assumed to sit inside the `pytest.raises` block. If so, it is only
        # reached when no MissingResult is raised, in which case
        # `pytest.raises` fails the test anyway -- confirm the intent.
        assert results == timings


class TestPrefectConcurrentFuture:
def test_wait_with_timeout(self):
Expand Down

0 comments on commit 44de8f4

Please sign in to comment.