Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds as_completed utility for PrefectFuture #14641

Merged
merged 18 commits into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/integrations/prefect-dask/prefect_dask/task_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,16 @@ def result(
_result = run_coro_as_sync(_result)
return _result

def add_done_callback(self, fn: Callable[[PrefectFuture], None]):
    """
    Register ``fn`` to be called with this future once it is done.

    If a final state is already recorded, ``fn`` is invoked immediately.
    Otherwise a small adapter is attached to the wrapped future so that
    ``fn`` receives this Prefect future rather than the wrapped one.

    Args:
        fn: Callable invoked with this future as its only argument.
    """
    if self._final_state:
        fn(self)
        return

    def _notify(_wrapped):
        # The wrapped future passes itself; callers expect the Prefect future.
        fn(self)

    self._wrapped_future.add_done_callback(_notify)

def __del__(self):
if self._final_state or self._wrapped_future.done():
return
Expand Down
20 changes: 20 additions & 0 deletions src/integrations/prefect-dask/tests/test_task_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from prefect_dask import DaskTaskRunner

from prefect import flow, task
from prefect.futures import as_completed
from prefect.server.schemas.states import StateType
from prefect.states import State
from prefect.testing.fixtures import ( # noqa: F401
Expand Down Expand Up @@ -248,6 +249,25 @@ def test_flow():
assert bx.type == StateType.PENDING
assert cx.type == StateType.COMPLETED

def test_as_completed_yields_correct_order(self, task_runner):
    """Futures submitted with increasing sleep times are yielded in that order."""

    @task
    def sleep_task(seconds):
        time.sleep(seconds)
        return seconds

    timings = [1, 5, 10]
    with task_runner:
        futures = []
        for delay in timings:
            submitted = task_runner.submit(
                sleep_task, parameters={"seconds": delay}, wait_for=[]
            )
            futures.append(submitted)

        # Collect results in the order as_completed hands the futures back.
        done_futures = [
            finished.result() for finished in as_completed(futures=futures)
        ]
        assert done_futures == timings

async def test_wait_captures_exceptions_as_crashed_state(self, task_runner):
"""
Dask wraps the exception, interrupts will result in "Cancelled" tasks
Expand Down
10 changes: 10 additions & 0 deletions src/integrations/prefect-ray/prefect_ray/task_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,16 @@ def result(
_result = run_coro_as_sync(_result)
return _result

def add_done_callback(self, fn):
    """
    Register ``fn`` to be called with this future once it is done.

    If a final state is already recorded, ``fn`` is invoked immediately.
    Otherwise an adapter is attached to the object returned by the wrapped
    future's ``future()`` method so that ``fn`` receives this Prefect future.

    Args:
        fn: Callable invoked with this future as its only argument.
    """
    if self._final_state:
        fn(self)
        return

    def _notify(_wrapped):
        # The underlying future passes itself; hand back the Prefect future.
        fn(self)

    self._wrapped_future.future().add_done_callback(_notify)

def __del__(self):
if self._final_state:
return
Expand Down
19 changes: 19 additions & 0 deletions src/integrations/prefect-ray/tests/test_task_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import prefect.task_engine
import tests
from prefect import flow, task
from prefect.futures import as_completed
from prefect.states import State, StateType
from prefect.testing.fixtures import ( # noqa: F401
hosted_api_server,
Expand Down Expand Up @@ -371,6 +372,24 @@ def test_flow():
assert bx.type == StateType.PENDING
assert cx.type == StateType.COMPLETED

def test_as_completed_yields_correct_order(self, task_runner):
    """Futures submitted with increasing sleep times are yielded in that order."""

    @task
    def task_a(seconds):
        time.sleep(seconds)
        return seconds

    timings = [1, 5, 10]

    @flow(version="test", task_runner=task_runner)
    def test_flow():
        futures = []
        for delay in timings:
            futures.append(task_a.submit(delay))

        # Collect results in the order as_completed hands the futures back.
        done_futures = [
            finished.result() for finished in as_completed(futures=futures)
        ]
        assert done_futures == timings

    test_flow(return_state=False)

def get_sleep_time(self) -> float:
"""
Return an amount of time to sleep for concurrency tests.
Expand Down
85 changes: 83 additions & 2 deletions src/prefect/futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
import collections
import concurrent.futures
import inspect
import threading
import uuid
from collections.abc import Iterator
from collections.abc import Generator, Iterator
from functools import partial
from typing import Any, Generic, List, Optional, Set, Union, cast
from typing import Any, Callable, Generic, List, Optional, Set, Union, cast

from typing_extensions import TypeVar

Expand Down Expand Up @@ -91,6 +92,16 @@ def result(
The result of the task run.
"""

@abc.abstractmethod
def add_done_callback(self, fn: Callable[["PrefectFuture"], None]) -> None:
    """
    Add a callback to be run when the future completes or is cancelled.

    Args:
        fn: A callable that will be called with this future as its only
            argument when the future completes or is cancelled.
    """
    # The docstring must be the first statement in the body; placed after the
    # ellipsis it is only a discarded string expression and __doc__ stays None.
    ...
jeanluciano marked this conversation as resolved.
Show resolved Hide resolved


class PrefectWrappedFuture(PrefectFuture, abc.ABC, Generic[R, F]):
"""
Expand Down Expand Up @@ -138,6 +149,7 @@ def result(

if isinstance(future_result, State):
self._final_state = future_result

else:
return future_result

Expand All @@ -150,6 +162,16 @@ def result(
_result = run_coro_as_sync(_result)
return _result

def add_done_callback(self, fn: Callable[[PrefectFuture], None]):
    """
    Register ``fn`` to be called with this future once it is done.

    If a final state is already recorded, ``fn`` is invoked immediately.
    Otherwise an adapter is attached to the wrapped future so that ``fn``
    receives this Prefect future rather than the wrapped one.

    Args:
        fn: Callable invoked with this future as its only argument.
    """
    if self._final_state:
        fn(self)
        return

    def _notify(_wrapped):
        # The wrapped future passes itself; callers expect the Prefect future.
        fn(self)

    self._wrapped_future.add_done_callback(_notify)

def __del__(self):
if self._final_state or self._wrapped_future.done():
return
Expand All @@ -172,6 +194,9 @@ class PrefectDistributedFuture(PrefectFuture[R]):
any task run scheduled in Prefect's API.
"""

# NOTE(review): these appear to be class-level attributes, so the single list
# object below would be shared by every instance rather than created per
# future. Neither name is referenced elsewhere in the visible code — confirm
# whether they are still needed or should move into per-instance state.
done_callbacks: List[Callable[[PrefectFuture], None]] = []
waiter = None

@deprecated_async_method
def wait(self, timeout: Optional[float] = None) -> None:
    """
    Run ``wait_async`` to completion from synchronous code.

    Args:
        timeout: Forwarded unchanged to ``wait_async``.
    """
    wait_coro = self.wait_async(timeout=timeout)
    return run_coro_as_sync(wait_coro)
Expand Down Expand Up @@ -235,11 +260,27 @@ async def result_async(
raise_on_failure=raise_on_failure, fetch=True
)

def add_done_callback(self, fn: Callable[[PrefectFuture], None]):
    """
    Register ``fn`` to be called with this future once its task run is final.

    ``fn`` is invoked immediately when a final state is already cached on the
    future or when the task run read from the API is already in a final state
    (which is then cached). Otherwise the callback is handed to
    ``TaskRunWaiter`` keyed by the task run id.

    Args:
        fn: Callable invoked with this future as its only argument.
    """
    if self._final_state:
        fn(self)
        return

    # Obtain the waiter instance before reading the run's state.
    # NOTE(review): presumably this ensures the waiter is already running so a
    # completion between the read and the registration is not missed — confirm.
    TaskRunWaiter.instance()
    with get_client(sync_client=True) as api_client:
        current_run = api_client.read_task_run(task_run_id=self._task_run_id)
        if not current_run.state.is_final():
            TaskRunWaiter.add_done_callback(self._task_run_id, partial(fn, self))
            return
        self._final_state = current_run.state
        fn(self)

def __eq__(self, other):
    """Two distributed futures are equal when they share a task run id."""
    same_kind = isinstance(other, PrefectDistributedFuture)
    return same_kind and self.task_run_id == other.task_run_id

def __hash__(self):
    """Hash the future by its task run id."""
    run_id = self.task_run_id
    return hash(run_id)


class PrefectFutureList(list, Iterator, Generic[F]):
"""
Expand Down Expand Up @@ -292,6 +333,46 @@ def result(
) from exc


def as_completed(
    futures: List[PrefectFuture], timeout: Optional[float] = None
) -> Generator[PrefectFuture, None, None]:
    """
    Yield futures as they finish.

    Duplicate futures are collapsed via a set. Futures that already carry a
    final state are yielded first; the remaining ones are yielded as their
    done callbacks fire.

    Args:
        futures: The futures to iterate over.
        timeout: Passed to ``timeout_context``.
            NOTE(review): presumably the maximum number of seconds to wait
            before ``TimeoutError`` is raised — confirm against
            ``timeout_context``.

    Yields:
        Each unique future once it is done.

    Raises:
        TimeoutError: Re-raised with a count of unfinished futures when a
            ``TimeoutError`` propagates out of the waiting block.
    """
    unique_futures: Set[PrefectFuture] = set(futures)
    total_futures = len(unique_futures)
    # Bind before the try block so the timeout handler can always report a
    # count, even if the timeout fires before the done/pending split happens.
    pending: Set[PrefectFuture] = set(unique_futures)
    try:
        with timeout_context(timeout):
            done = {f for f in unique_futures if f._final_state}
            pending = unique_futures - done
            yield from done

            finished_event = threading.Event()
            finished_lock = threading.Lock()
            finished_futures = []

            def add_to_done(future):
                # Runs from whichever thread completes the future.
                with finished_lock:
                    finished_futures.append(future)
                    finished_event.set()

            for future in pending:
                future.add_done_callback(add_to_done)

            while pending:
                finished_event.wait()
                with finished_lock:
                    # Swap out the batch under the lock so callbacks that
                    # fire while we are yielding land in a fresh list.
                    done = finished_futures
                    finished_futures = []
                    finished_event.clear()

                for future in done:
                    pending.remove(future)
                    yield future

    except TimeoutError:
        raise TimeoutError(
            "%d (of %d) futures unfinished" % (len(pending), total_futures)
        )


DoneAndNotDoneFutures = collections.namedtuple("DoneAndNotDoneFutures", "done not_done")


Expand Down
25 changes: 24 additions & 1 deletion src/prefect/task_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import atexit
import threading
import uuid
from typing import Dict, Optional
from typing import Callable, Dict, Optional

import anyio
from cachetools import TTLCache
Expand Down Expand Up @@ -74,6 +74,7 @@ def __init__(self):
maxsize=10000, ttl=600
)
self._completion_events: Dict[uuid.UUID, asyncio.Event] = {}
self._completion_callbacks: Dict[uuid.UUID, Callable] = {}
self._loop: Optional[asyncio.AbstractEventLoop] = None
self._observed_completed_task_runs_lock = threading.Lock()
self._completion_events_lock = threading.Lock()
Expand Down Expand Up @@ -135,6 +136,8 @@ async def _consume_events(self, consumer_started: asyncio.Event):
# so the waiter can wake up the waiting coroutine
if task_run_id in self._completion_events:
self._completion_events[task_run_id].set()
if task_run_id in self._completion_callbacks:
self._completion_callbacks[task_run_id]()
except Exception as exc:
self.logger.error(f"Error processing event: {exc}")

Expand Down Expand Up @@ -195,6 +198,26 @@ async def wait_for_task_run(
# Remove the event from the cache after it has been waited on
instance._completion_events.pop(task_run_id, None)

@classmethod
def add_done_callback(cls, task_run_id: uuid.UUID, callback: Callable[[], None]):
    """
    Add a callback to be called when a task run finishes.

    If the task run has already been observed as completed, the callback is
    invoked immediately instead of being stored.

    Args:
        task_run_id: The ID of the task run to wait for.
        callback: The zero-argument callback to call when the task run
            finishes.
    """
    instance = cls.instance()
    with instance._observed_completed_task_runs_lock:
        if task_run_id in instance._observed_completed_task_runs:
            callback()
            return

    with instance._completion_events_lock:
        # Store the callback in the callbacks mapping, which the consumer
        # looks up and calls. Writing it into _completion_events (as before)
        # made the consumer call `.set()` on a plain callable and left
        # _completion_callbacks empty, so the callback never ran.
        instance._completion_callbacks[task_run_id] = callback
jeanluciano marked this conversation as resolved.
Show resolved Hide resolved

@classmethod
def instance(cls):
"""
Expand Down
78 changes: 78 additions & 0 deletions tests/test_futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
PrefectFuture,
PrefectFutureList,
PrefectWrappedFuture,
as_completed,
resolve_futures_to_states,
wait,
)
from prefect.states import Completed, Failed
from prefect.task_engine import run_task_async, run_task_sync
from prefect.task_runners import ThreadPoolTaskRunner


class MockFuture(PrefectWrappedFuture):
Expand All @@ -37,6 +39,9 @@ def result(
) -> Any:
return self._final_state.result()

def add_done_callback(self, fn):
    """Delegate callback registration to the parent class implementation."""
    outcome = super().add_done_callback(fn)
    return outcome
jeanluciano marked this conversation as resolved.
Show resolved Hide resolved


class TestUtilityFunctions:
def test_wait(self):
Expand All @@ -55,6 +60,79 @@ def test_wait_with_timeout(self):
futures = wait(mock_futures, timeout=0.01)
assert futures.not_done == {mock_futures[-1]}

def test_as_completed(self):
    """Every future yielded by as_completed is in a completed state."""
    mock_futures = []
    for value in range(5):
        mock_futures.append(MockFuture(data=value))

    for finished in as_completed(mock_futures):
        assert finished.state.is_completed()

@pytest.mark.timeout(method="thread")
def test_as_completed_with_timeout(self):
    """as_completed raises TimeoutError reporting the single unfinished future."""
    mock_futures = [MockFuture(data=i) for i in range(5)]

    # A bare concurrent future that is never resolved keeps one item pending.
    never_done = PrefectConcurrentFuture(uuid.uuid4(), Future())
    mock_futures.append(never_done)

    with pytest.raises(TimeoutError) as exc_info:
        for finished in as_completed(mock_futures, timeout=0.01):
            assert finished.state.is_completed()

    expected_message = f"1 (of {len(mock_futures)}) futures unfinished"
    assert exc_info.value.args[0] == expected_message

# @pytest.mark.timeout(method="thread")
@pytest.mark.usefixtures("use_hosted_api_server")
def test_as_completed_yields_correct_order(self):
    """Futures submitted with increasing sleep times are yielded in that order."""

    @task
    def my_test_task(seconds):
        import time

        time.sleep(seconds)
        return seconds

    with ThreadPoolTaskRunner() as runner:
        timings = [1, 5, 10]
        futures = []

        for delay in timings:
            parameters = {"seconds": delay}
            submitted = runner.submit(my_test_task, parameters)
            submitted.parameters = parameters
            futures.append(submitted)

        # Collect results in the order as_completed hands the futures back.
        results = [finished.result() for finished in as_completed(futures)]
        assert results == timings

async def test_as_completed_yields_correct_order_dist(self, task_run):
    """Exercise as_completed over distributed futures backed by created task runs."""

    @task
    async def my_task(seconds):
        import time

        time.sleep(seconds)
        return seconds

    timings = [1, 5, 10]
    futures = []
    for delay in timings:
        task_run = await my_task.create_run(parameters={"seconds": delay})
        future = PrefectDistributedFuture(task_run_id=task_run.id)
        futures.append(future)

        # Kick off the run in the background on the current event loop.
        run_coro = run_task_async(
            task=my_task,
            task_run_id=future.task_run_id,
            task_run=task_run,
            parameters={"seconds": delay},
            return_type="state",
        )
        asyncio.create_task(run_coro)

    results = []
    with pytest.raises(MissingResult):
        for finished in as_completed(futures):
            results.append(finished.result())
        # NOTE(review): assumed to sit inside the raises block (original
        # indentation was lost); if result() raises as expected this line is
        # never reached — confirm the intended placement.
        assert results == timings


class TestPrefectConcurrentFuture:
def test_wait_with_timeout(self):
Expand Down
Loading
Loading