diff --git a/src/prefect/futures.py b/src/prefect/futures.py index 8fa027878bcd..f10e3ca00f36 100644 --- a/src/prefect/futures.py +++ b/src/prefect/futures.py @@ -1,17 +1,15 @@ import abc import asyncio -import collections import concurrent.futures import threading import uuid from collections.abc import Generator, Iterator from functools import partial -from typing import Any, Callable, Generic, List, Optional, Set, Union, cast +from typing import Any, Callable, Generic, Optional, Union -from typing_extensions import TypeVar +from typing_extensions import NamedTuple, Self, TypeVar from prefect.client.orchestration import get_client -from prefect.client.schemas.objects import TaskRun from prefect.exceptions import ObjectNotFound from prefect.logging.loggers import get_logger, get_run_logger from prefect.states import Pending, State @@ -50,7 +48,7 @@ def state(self) -> State: return self._final_state client = get_client(sync_client=True) try: - task_run = cast(TaskRun, client.read_task_run(task_run_id=self.task_run_id)) + task_run = client.read_task_run(task_run_id=self.task_run_id) except ObjectNotFound: # We'll be optimistic and assume this task will eventually start # TODO: Consider using task run events to wait for the task to start @@ -92,7 +90,7 @@ def result( """ @abc.abstractmethod - def add_done_callback(self, fn): + def add_done_callback(self, fn: Callable[["PrefectFuture[R]"], None]): """ Add a callback to be run when the future completes or is cancelled. @@ -102,13 +100,17 @@ def add_done_callback(self, fn): ... -class PrefectWrappedFuture(PrefectFuture, abc.ABC, Generic[R, F]): +class PrefectWrappedFuture(PrefectFuture[R], abc.ABC, Generic[R, F]): """ A Prefect future that wraps another future object. + + Type Parameters: + R: The return type of the future + F: The type of the wrapped future """ def __init__(self, task_run_id: uuid.UUID, wrapped_future: F): - self._wrapped_future = wrapped_future + self._wrapped_future: F = wrapped_future super().__init__(task_run_id) @property @@ -116,10 +118,11 @@ def wrapped_future(self) -> F: """The underlying future object wrapped by this Prefect future""" return self._wrapped_future - def add_done_callback(self, fn: Callable[[PrefectFuture[R]], None]): + def add_done_callback(self, fn: Callable[[PrefectFuture[R]], None]) -> None: + """Add a callback to be executed when the future completes.""" if not self._final_state: - def call_with_self(future): + def call_with_self(future: F): """Call the callback with self as the argument, this is necessary to ensure we remove the future from the pending set""" fn(self) @@ -128,7 +131,7 @@ def call_with_self(future): fn(self) -class PrefectConcurrentFuture(PrefectWrappedFuture[R, concurrent.futures.Future]): +class PrefectConcurrentFuture(PrefectWrappedFuture[R, concurrent.futures.Future[R]]): """ A Prefect future that wraps a concurrent.futures.Future. This future is used when the task run is submitted to a ThreadPoolExecutor. @@ -193,7 +196,7 @@ class PrefectDistributedFuture(PrefectFuture[R]): any task run scheduled in Prefect's API. """ - done_callbacks: List[Callable[[PrefectFuture[R]], None]] = [] + done_callbacks: list[Callable[[PrefectFuture[R]], None]] = [] waiter = None def wait(self, timeout: Optional[float] = None) -> None: @@ -270,7 +273,7 @@ def add_done_callback(self, fn: Callable[[PrefectFuture[R]], None]): return TaskRunWaiter.add_done_callback(self._task_run_id, partial(fn, self)) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if not isinstance(other, PrefectDistributedFuture): return False return self.task_run_id == other.task_run_id @@ -279,7 +282,7 @@ def __hash__(self): return hash(self.task_run_id) -class PrefectFutureList(list, Iterator, Generic[F]): +class PrefectFutureList(list[PrefectFuture[R]], Iterator[PrefectFuture[R]]): """ A list of Prefect futures. @@ -298,10 +301,10 @@ def wait(self, timeout: Optional[float] = None) -> None: wait(self, timeout=timeout) def result( - self: "PrefectFutureList[R]", + self: Self, timeout: Optional[float] = None, raise_on_failure: bool = True, - ) -> List[R]: + ) -> list[R]: """ Get the results of all task runs associated with the futures in the list. @@ -331,21 +334,22 @@ def result( def as_completed( - futures: List[PrefectFuture[R]], timeout: Optional[float] = None + futures: list[PrefectFuture[R]], timeout: Optional[float] = None ) -> Generator[PrefectFuture[R], None]: - unique_futures: Set[PrefectFuture[R]] = set(futures) + unique_futures: set[PrefectFuture[R]] = set(futures) total_futures = len(unique_futures) + pending = unique_futures try: with timeout_context(timeout): - done = {f for f in unique_futures if f._final_state} + done = {f for f in unique_futures if f._final_state} # type: ignore[privateUsage] pending = unique_futures - done yield from done finished_event = threading.Event() finished_lock = threading.Lock() - finished_futures = [] + finished_futures: list[PrefectFuture[R]] = [] - def add_to_done(future): + def add_to_done(future: PrefectFuture[R]): with finished_lock: finished_futures.append(future) finished_event.set() @@ -370,10 +374,19 @@ def add_to_done(future): ) -DoneAndNotDoneFutures = collections.namedtuple("DoneAndNotDoneFutures", "done not_done") +class DoneAndNotDoneFutures(NamedTuple, Generic[R]): + """A named 2-tuple of sets. + + multiple inheritance supported in 3.11+, use typing_extensions.NamedTuple + """ + + done: set[PrefectFuture[R]] + not_done: set[PrefectFuture[R]] -def wait(futures: List[PrefectFuture[R]], timeout=None) -> DoneAndNotDoneFutures: +def wait( + futures: list[PrefectFuture[R]], timeout: Optional[float] = None +) -> DoneAndNotDoneFutures[R]: """ Wait for the futures in the given sequence to complete. @@ -431,9 +444,11 @@ def resolve_futures_to_states( Unsupported object types will be returned without modification. """ - futures: Set[PrefectFuture[R]] = set() + futures: set[PrefectFuture[R]] = set() - def _collect_futures(futures, expr, context): + def _collect_futures( + futures: set[PrefectFuture[R]], expr: Any, context: Any + ) -> Union[PrefectFuture[R], Any]: # Expressions inside quotes should not be traversed if isinstance(context.get("annotation"), quote): raise StopVisiting() @@ -455,14 +470,14 @@ def _collect_futures(futures, expr, context): return expr # Get final states for each future - states = [] + states: list[State] = [] for future in futures: future.wait() states.append(future.state) states_by_future = dict(zip(futures, states)) - def replace_futures_with_states(expr, context): + def replace_futures_with_states(expr: Any, context: Any) -> Any: # Expressions inside quotes should not be modified if isinstance(context.get("annotation"), quote): raise StopVisiting()