Skip to content

Commit

Permalink
[typing] update prefect.futures (#16381)
Browse files Browse the repository at this point in the history
  • Loading branch information
zzstoatzz authored Dec 20, 2024
1 parent 8bc0923 commit 6f5d463
Showing 1 changed file with 42 additions and 27 deletions.
69 changes: 42 additions & 27 deletions src/prefect/futures.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
import abc
import asyncio
import collections
import concurrent.futures
import threading
import uuid
from collections.abc import Generator, Iterator
from functools import partial
from typing import Any, Callable, Generic, List, Optional, Set, Union, cast
from typing import Any, Callable, Generic, Optional, Union

from typing_extensions import TypeVar
from typing_extensions import NamedTuple, Self, TypeVar

from prefect.client.orchestration import get_client
from prefect.client.schemas.objects import TaskRun
from prefect.exceptions import ObjectNotFound
from prefect.logging.loggers import get_logger, get_run_logger
from prefect.states import Pending, State
Expand Down Expand Up @@ -50,7 +48,7 @@ def state(self) -> State:
return self._final_state
client = get_client(sync_client=True)
try:
task_run = cast(TaskRun, client.read_task_run(task_run_id=self.task_run_id))
task_run = client.read_task_run(task_run_id=self.task_run_id)
except ObjectNotFound:
# We'll be optimistic and assume this task will eventually start
# TODO: Consider using task run events to wait for the task to start
Expand Down Expand Up @@ -92,7 +90,7 @@ def result(
"""

@abc.abstractmethod
def add_done_callback(self, fn):
def add_done_callback(self, fn: Callable[["PrefectFuture[R]"], None]):
"""
Add a callback to be run when the future completes or is cancelled.
Expand All @@ -102,24 +100,29 @@ def add_done_callback(self, fn):
...


class PrefectWrappedFuture(PrefectFuture, abc.ABC, Generic[R, F]):
class PrefectWrappedFuture(PrefectFuture[R], abc.ABC, Generic[R, F]):
"""
A Prefect future that wraps another future object.
Type Parameters:
R: The return type of the future
F: The type of the wrapped future
"""

def __init__(self, task_run_id: uuid.UUID, wrapped_future: F):
self._wrapped_future = wrapped_future
self._wrapped_future: F = wrapped_future
super().__init__(task_run_id)

@property
def wrapped_future(self) -> F:
"""The underlying future object wrapped by this Prefect future"""
return self._wrapped_future

def add_done_callback(self, fn: Callable[[PrefectFuture[R]], None]):
def add_done_callback(self, fn: Callable[[PrefectFuture[R]], None]) -> None:
"""Add a callback to be executed when the future completes."""
if not self._final_state:

def call_with_self(future):
def call_with_self(future: F):
"""Call the callback with self as the argument, this is necessary to ensure we remove the future from the pending set"""
fn(self)

Expand All @@ -128,7 +131,7 @@ def call_with_self(future):
fn(self)


class PrefectConcurrentFuture(PrefectWrappedFuture[R, concurrent.futures.Future]):
class PrefectConcurrentFuture(PrefectWrappedFuture[R, concurrent.futures.Future[R]]):
"""
A Prefect future that wraps a concurrent.futures.Future. This future is used
when the task run is submitted to a ThreadPoolExecutor.
Expand Down Expand Up @@ -193,7 +196,7 @@ class PrefectDistributedFuture(PrefectFuture[R]):
any task run scheduled in Prefect's API.
"""

done_callbacks: List[Callable[[PrefectFuture[R]], None]] = []
done_callbacks: list[Callable[[PrefectFuture[R]], None]] = []
waiter = None

def wait(self, timeout: Optional[float] = None) -> None:
Expand Down Expand Up @@ -270,7 +273,7 @@ def add_done_callback(self, fn: Callable[[PrefectFuture[R]], None]):
return
TaskRunWaiter.add_done_callback(self._task_run_id, partial(fn, self))

def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
if not isinstance(other, PrefectDistributedFuture):
return False
return self.task_run_id == other.task_run_id
Expand All @@ -279,7 +282,7 @@ def __hash__(self):
return hash(self.task_run_id)


class PrefectFutureList(list, Iterator, Generic[F]):
class PrefectFutureList(list[PrefectFuture[R]], Iterator[PrefectFuture[R]]):
"""
A list of Prefect futures.
Expand All @@ -298,10 +301,10 @@ def wait(self, timeout: Optional[float] = None) -> None:
wait(self, timeout=timeout)

def result(
self: "PrefectFutureList[R]",
self: Self,
timeout: Optional[float] = None,
raise_on_failure: bool = True,
) -> List[R]:
) -> list[R]:
"""
Get the results of all task runs associated with the futures in the list.
Expand Down Expand Up @@ -331,21 +334,22 @@ def result(


def as_completed(
futures: List[PrefectFuture[R]], timeout: Optional[float] = None
futures: list[PrefectFuture[R]], timeout: Optional[float] = None
) -> Generator[PrefectFuture[R], None]:
unique_futures: Set[PrefectFuture[R]] = set(futures)
unique_futures: set[PrefectFuture[R]] = set(futures)
total_futures = len(unique_futures)
pending = unique_futures
try:
with timeout_context(timeout):
done = {f for f in unique_futures if f._final_state}
done = {f for f in unique_futures if f._final_state} # type: ignore[privateUsage]
pending = unique_futures - done
yield from done

finished_event = threading.Event()
finished_lock = threading.Lock()
finished_futures = []
finished_futures: list[PrefectFuture[R]] = []

def add_to_done(future):
def add_to_done(future: PrefectFuture[R]):
with finished_lock:
finished_futures.append(future)
finished_event.set()
Expand All @@ -370,10 +374,19 @@ def add_to_done(future):
)


DoneAndNotDoneFutures = collections.namedtuple("DoneAndNotDoneFutures", "done not_done")
class DoneAndNotDoneFutures(NamedTuple, Generic[R]):
"""A named 2-tuple of sets.
multiple inheritance supported in 3.11+, use typing_extensions.NamedTuple
"""

done: set[PrefectFuture[R]]
not_done: set[PrefectFuture[R]]


def wait(futures: List[PrefectFuture[R]], timeout=None) -> DoneAndNotDoneFutures:
def wait(
futures: list[PrefectFuture[R]], timeout: Optional[float] = None
) -> DoneAndNotDoneFutures[R]:
"""
Wait for the futures in the given sequence to complete.
Expand Down Expand Up @@ -431,9 +444,11 @@ def resolve_futures_to_states(
Unsupported object types will be returned without modification.
"""
futures: Set[PrefectFuture[R]] = set()
futures: set[PrefectFuture[R]] = set()

def _collect_futures(futures, expr, context):
def _collect_futures(
futures: set[PrefectFuture[R]], expr: Any, context: Any
) -> Union[PrefectFuture[R], Any]:
# Expressions inside quotes should not be traversed
if isinstance(context.get("annotation"), quote):
raise StopVisiting()
Expand All @@ -455,14 +470,14 @@ def _collect_futures(futures, expr, context):
return expr

# Get final states for each future
states = []
states: list[State] = []
for future in futures:
future.wait()
states.append(future.state)

states_by_future = dict(zip(futures, states))

def replace_futures_with_states(expr, context):
def replace_futures_with_states(expr: Any, context: Any) -> Any:
# Expressions inside quotes should not be modified
if isinstance(context.get("annotation"), quote):
raise StopVisiting()
Expand Down

0 comments on commit 6f5d463

Please sign in to comment.