diff --git a/src/prefect/_internal/retries.py b/src/prefect/_internal/retries.py index e7c02f4b2ef1..552b40b9c4a7 100644 --- a/src/prefect/_internal/retries.py +++ b/src/prefect/_internal/retries.py @@ -29,7 +29,7 @@ def retry_async_fn( retry_on_exceptions: tuple[type[Exception], ...] = (Exception,), operation_name: Optional[str] = None, ) -> Callable[ - [Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, Optional[R]]] + [Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]] ]: """A decorator for retrying an async function. @@ -48,9 +48,9 @@ def retry_async_fn( def decorator( func: Callable[P, Coroutine[Any, Any, R]], - ) -> Callable[P, Coroutine[Any, Any, Optional[R]]]: + ) -> Callable[P, Coroutine[Any, Any, R]]: @wraps(func) - async def wrapper(*args: P.args, **kwargs: P.kwargs) -> Optional[R]: + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: name = operation_name or func.__name__ for attempt in range(max_attempts): try: @@ -67,6 +67,9 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> Optional[R]: f"Retrying in {delay:.2f} seconds..." ) await asyncio.sleep(delay) + # Technically unreachable, but this raise helps pyright know that this function + # won't return None. + raise Exception(f"Function {name!r} failed after {max_attempts} attempts") return wrapper diff --git a/src/prefect/server/events/schemas/automations.py b/src/prefect/server/events/schemas/automations.py index 8f54426990ab..d1c118c6cefa 100644 --- a/src/prefect/server/events/schemas/automations.py +++ b/src/prefect/server/events/schemas/automations.py @@ -317,7 +317,7 @@ class EventTrigger(ResourceTrigger): @model_validator(mode="before") @classmethod def enforce_minimum_within_for_proactive_triggers( - cls, data: Dict[str, Any] + cls, data: Dict[str, Any] | Any ) -> Dict[str, Any]: if not isinstance(data, dict): return data @@ -342,7 +342,7 @@ def enforce_minimum_within_for_proactive_triggers( return data - def covers(self, event: ReceivedEvent): + def covers(self, event: ReceivedEvent) -> bool: if not self.covers_resources(event.resource, event.related): return False @@ -356,10 +356,10 @@ def immediate(self) -> bool: """Does this reactive trigger fire immediately for all events?""" return self.posture == Posture.Reactive and self.within == timedelta(0) - _event_pattern: Optional[re.Pattern] = PrivateAttr(None) + _event_pattern: Optional[re.Pattern[str]] = PrivateAttr(None) @property - def event_pattern(self) -> re.Pattern: + def event_pattern(self) -> re.Pattern[str]: """A regular expression which may be evaluated against any event string to determine if this trigger would be interested in the event""" if self._event_pattern: @@ -625,13 +625,15 @@ class Firing(PrefectBaseModel): id: UUID = Field(default_factory=uuid4) - trigger: ServerTriggerTypes = Field(..., description="The trigger that is firing") + trigger: Union[ServerTriggerTypes, CompositeTrigger] = Field( + default=..., description="The trigger that is firing" + ) trigger_states: Set[TriggerState] = Field( - ..., + default=..., description="The state changes represented by this Firing", ) triggered: DateTime = Field( - ..., + default=..., description=( "The time at which this trigger fired, which may differ from the " "occurred time of the associated event (as events processing may always " @@ -654,7 +656,7 @@ class Firing(PrefectBaseModel): ), ) triggering_event: Optional[ReceivedEvent] = Field( - None, + default=None, description=( "The most recent event associated with this Firing. This may be the " "event that caused the trigger to fire (for Reactive triggers), or the " @@ -662,8 +664,8 @@ class Firing(PrefectBaseModel): "change event (for a Metric trigger)." ), ) - triggering_value: Any = Field( - None, + triggering_value: Optional[Any] = Field( + default=None, description=( "A value associated with this firing of a trigger. Maybe used to " "convey additional information at the point of firing, like the value of " diff --git a/src/prefect/server/events/services/actions.py b/src/prefect/server/events/services/actions.py index 36d7b77a8968..0a2b38b88d73 100644 --- a/src/prefect/server/events/services/actions.py +++ b/src/prefect/server/events/services/actions.py @@ -1,11 +1,16 @@ +from __future__ import annotations + import asyncio -from typing import Optional +from typing import TYPE_CHECKING, NoReturn from prefect.logging import get_logger from prefect.server.events import actions -from prefect.server.utilities.messaging import create_consumer +from prefect.server.utilities.messaging import Consumer, create_consumer + +if TYPE_CHECKING: + import logging -logger = get_logger(__name__) +logger: "logging.Logger" = get_logger(__name__) class Actions: @@ -13,11 +18,11 @@ class Actions: name: str = "Actions" - consumer_task: Optional[asyncio.Task] = None + consumer_task: asyncio.Task[None] | None = None - async def start(self): + async def start(self) -> NoReturn: assert self.consumer_task is None, "Actions already started" - self.consumer = create_consumer("actions") + self.consumer: Consumer = create_consumer("actions") async with actions.consumer() as handler: self.consumer_task = asyncio.create_task(self.consumer.run(handler)) @@ -28,7 +33,7 @@ async def start(self): except asyncio.CancelledError: pass - async def stop(self): + async def stop(self) -> None: assert self.consumer_task is not None, "Actions not started" self.consumer_task.cancel() try: diff --git a/src/prefect/server/events/services/event_logger.py b/src/prefect/server/events/services/event_logger.py index b4feb8964e29..80ddc17a0a89 100644 --- a/src/prefect/server/events/services/event_logger.py +++ b/src/prefect/server/events/services/event_logger.py @@ -1,14 +1,19 @@ +from __future__ import annotations + import asyncio -from typing import Optional +from typing import TYPE_CHECKING, NoReturn import pendulum import rich from prefect.logging import get_logger from prefect.server.events.schemas.events import ReceivedEvent -from prefect.server.utilities.messaging import Message, create_consumer +from prefect.server.utilities.messaging import Consumer, Message, create_consumer + +if TYPE_CHECKING: + import logging -logger = get_logger(__name__) +logger: "logging.Logger" = get_logger(__name__) class EventLogger: @@ -16,11 +21,11 @@ class EventLogger: name: str = "EventLogger" - consumer_task: Optional[asyncio.Task] = None + consumer_task: asyncio.Task[None] | None = None - async def start(self): + async def start(self) -> NoReturn: assert self.consumer_task is None, "Logger already started" - self.consumer = create_consumer("events") + self.consumer: Consumer = create_consumer("events") console = rich.console.Console() @@ -46,7 +51,7 @@ async def handler(message: Message): except asyncio.CancelledError: pass - async def stop(self): + async def stop(self) -> None: assert self.consumer_task is not None, "Logger not started" self.consumer_task.cancel() try: diff --git a/src/prefect/server/events/services/event_persister.py b/src/prefect/server/events/services/event_persister.py index 2d810f37c218..fe02741088b9 100644 --- a/src/prefect/server/events/services/event_persister.py +++ b/src/prefect/server/events/services/event_persister.py @@ -3,10 +3,12 @@ storage as fast as it can. Never gets tired. """ +from __future__ import annotations + import asyncio from contextlib import asynccontextmanager from datetime import timedelta -from typing import AsyncGenerator, List, Optional +from typing import TYPE_CHECKING, AsyncGenerator, List, NoReturn import pendulum import sqlalchemy as sa @@ -15,14 +17,22 @@ from prefect.server.database import provide_database_interface from prefect.server.events.schemas.events import ReceivedEvent from prefect.server.events.storage.database import write_events -from prefect.server.utilities.messaging import Message, MessageHandler, create_consumer +from prefect.server.utilities.messaging import ( + Consumer, + Message, + MessageHandler, + create_consumer, +) from prefect.settings import ( PREFECT_API_SERVICES_EVENT_PERSISTER_BATCH_SIZE, PREFECT_API_SERVICES_EVENT_PERSISTER_FLUSH_INTERVAL, PREFECT_EVENTS_RETENTION_PERIOD, ) -logger = get_logger(__name__) +if TYPE_CHECKING: + import logging + +logger: "logging.Logger" = get_logger(__name__) class EventPersister: @@ -30,10 +40,10 @@ class EventPersister: name: str = "EventLogger" - consumer_task: Optional[asyncio.Task] = None + consumer_task: asyncio.Task[None] | None = None def __init__(self): - self._started_event: Optional[asyncio.Event] = None + self._started_event: asyncio.Event | None = None @property def started_event(self) -> asyncio.Event: @@ -45,9 +55,9 @@ def started_event(self) -> asyncio.Event: def started_event(self, value: asyncio.Event) -> None: self._started_event = value - async def start(self): + async def start(self) -> NoReturn: assert self.consumer_task is None, "Event persister already started" - self.consumer = create_consumer("events") + self.consumer: Consumer = create_consumer("events") async with create_handler( batch_size=PREFECT_API_SERVICES_EVENT_PERSISTER_BATCH_SIZE.value(), @@ -64,7 +74,7 @@ async def start(self): except asyncio.CancelledError: pass - async def stop(self): + async def stop(self) -> None: assert self.consumer_task is not None, "Event persister not started" self.consumer_task.cancel() try: diff --git a/src/prefect/server/events/services/triggers.py b/src/prefect/server/events/services/triggers.py index 1cba60476a3a..81a536f8f3e1 100644 --- a/src/prefect/server/events/services/triggers.py +++ b/src/prefect/server/events/services/triggers.py @@ -1,13 +1,18 @@ +from __future__ import annotations + import asyncio -from typing import Optional +from typing import TYPE_CHECKING, Any, NoReturn, Optional from prefect.logging import get_logger from prefect.server.events import triggers from prefect.server.services.loop_service import LoopService -from prefect.server.utilities.messaging import create_consumer +from prefect.server.utilities.messaging import Consumer, create_consumer from prefect.settings import PREFECT_EVENTS_PROACTIVE_GRANULARITY -logger = get_logger(__name__) +if TYPE_CHECKING: + import logging + +logger: "logging.Logger" = get_logger(__name__) class ReactiveTriggers: @@ -15,11 +20,11 @@ class ReactiveTriggers: name: str = "ReactiveTriggers" - consumer_task: Optional[asyncio.Task] = None + consumer_task: asyncio.Task[None] | None = None - async def start(self): + async def start(self) -> NoReturn: assert self.consumer_task is None, "Reactive triggers already started" - self.consumer = create_consumer("events") + self.consumer: Consumer = create_consumer("events") async with triggers.consumer() as handler: self.consumer_task = asyncio.create_task(self.consumer.run(handler)) @@ -30,7 +35,7 @@ async def start(self): except asyncio.CancelledError: pass - async def stop(self): + async def stop(self) -> None: assert self.consumer_task is not None, "Reactive triggers not started" self.consumer_task.cancel() try: @@ -43,7 +48,7 @@ async def stop(self): class ProactiveTriggers(LoopService): - def __init__(self, loop_seconds: Optional[float] = None, **kwargs): + def __init__(self, loop_seconds: Optional[float] = None, **kwargs: Any): super().__init__( loop_seconds=( loop_seconds @@ -52,5 +57,5 @@ def __init__(self, loop_seconds: Optional[float] = None, **kwargs): **kwargs, ) - async def run_once(self): + async def run_once(self) -> None: await triggers.evaluate_proactive_triggers() diff --git a/src/prefect/server/events/triggers.py b/src/prefect/server/events/triggers.py index ae25b8bd95dc..c7140aa93045 100644 --- a/src/prefect/server/events/triggers.py +++ b/src/prefect/server/events/triggers.py @@ -55,13 +55,14 @@ from prefect.server.events.schemas.events import ReceivedEvent from prefect.server.utilities.messaging import Message, MessageHandler from prefect.settings import PREFECT_EVENTS_EXPIRED_BUCKET_BUFFER -from prefect.types import DateTime if TYPE_CHECKING: + import logging + from prefect.server.database.orm_models import ORMAutomationBucket -logger = get_logger(__name__) +logger: "logging.Logger" = get_logger(__name__) AutomationID: TypeAlias = UUID TriggerID: TypeAlias = UUID @@ -74,9 +75,9 @@ async def evaluate( session: AsyncSession, trigger: EventTrigger, bucket: "ORMAutomationBucket", - now: DateTime, + now: pendulum.DateTime, triggering_event: Optional[ReceivedEvent], -) -> Optional["ORMAutomationBucket"]: +) -> "ORMAutomationBucket | None": """Evaluates an Automation, either triggered by a specific event or proactively on a time interval. Evaluating a Automation updates the associated counters for each automation, and will fire the associated action if it has met the threshold.""" @@ -249,7 +250,7 @@ async def evaluate( bucket = await start_new_bucket( session, trigger, - bucketing_key=bucket.bucketing_key, + bucketing_key=tuple(bucket.bucketing_key), start=start, end=end, count=0, @@ -259,7 +260,7 @@ async def evaluate( return await start_new_bucket( session, trigger, - bucketing_key=bucket.bucketing_key, + bucketing_key=tuple(bucket.bucketing_key), start=start, end=end, count=count, @@ -321,8 +322,10 @@ async def evaluate_composite_trigger(session: AsyncSession, firing: Firing): # what the current state of the world is. If we have enough firings, we'll # fire the parent trigger. await upsert_child_firing(session, firing) - firings = [cf.child_firing for cf in await get_child_firings(session, trigger)] - firing_ids = {f.id for f in firings} + firings: list[Firing] = [ + cf.child_firing for cf in await get_child_firings(session, trigger) + ] + firing_ids: set[UUID] = {f.id for f in firings} # If our current firing no longer exists when we read firings # another firing has superseded it, and we should defer to that one @@ -345,7 +348,8 @@ async def evaluate_composite_trigger(session: AsyncSession, firing: Firing): ) # clear by firing id - await clear_child_firings(session, trigger, firing_ids=firing_ids) + await clear_child_firings(session, trigger, firing_ids=list(firing_ids)) + await fire( session, Firing( @@ -581,12 +585,12 @@ async def reactive_evaluation(event: ReceivedEvent, depth: int = 0): # retry on operational errors to account for db flakiness with sqlite @retry_async_fn(max_attempts=3, retry_on_exceptions=(sa.exc.OperationalError,)) -async def get_lost_followers(): +async def get_lost_followers() -> List[ReceivedEvent]: """Get followers that have been sitting around longer than our lookback""" return await causal_ordering().get_lost_followers() -async def periodic_evaluation(now: DateTime): +async def periodic_evaluation(now: pendulum.DateTime): """Periodic tasks that should be run regularly, but not as often as every event""" offset = await get_events_clock_offset() as_of = now + timedelta(seconds=offset) @@ -626,7 +630,7 @@ async def evaluate_periodically(periodic_granularity: timedelta): # account and workspace automations_by_id: Dict[UUID, Automation] = {} triggers: Dict[TriggerID, EventTrigger] = {} -next_proactive_runs: Dict[TriggerID, DateTime] = {} +next_proactive_runs: Dict[TriggerID, pendulum.DateTime] = {} # This lock governs any changes to the set of loaded automations; any routine that will # add/remove automations must be holding this lock when it does so. It's best to use @@ -775,7 +779,7 @@ async def read_bucket_by_trigger_id( automation_id: UUID, trigger_id: UUID, bucketing_key: Tuple[str, ...], -) -> Optional["ORMAutomationBucket"]: +) -> "ORMAutomationBucket | None": """Gets the bucket this event would fall into for the given Automation, if there is one currently""" query = sa.select(db.AutomationBucket).where( @@ -800,7 +804,9 @@ async def increment_bucket( last_event: Optional[ReceivedEvent], ) -> "ORMAutomationBucket": """Adds the given count to the bucket, returning the new bucket""" - additional_updates: dict = {"last_event": last_event} if last_event else {} + additional_updates: dict[str, ReceivedEvent] = ( + {"last_event": last_event} if last_event else {} + ) await session.execute( db.queries.insert(db.AutomationBucket) .values( @@ -827,13 +833,18 @@ async def increment_bucket( ) ) - return await read_bucket_by_trigger_id( + read_bucket = await read_bucket_by_trigger_id( session, bucket.automation_id, bucket.trigger_id, - bucket.bucketing_key, + tuple(bucket.bucketing_key), ) + if TYPE_CHECKING: + assert read_bucket is not None + + return read_bucket + @db_injector async def start_new_bucket( @@ -841,10 +852,10 @@ async def start_new_bucket( session: AsyncSession, trigger: EventTrigger, bucketing_key: Tuple[str, ...], - start: DateTime, - end: DateTime, + start: pendulum.DateTime, + end: pendulum.DateTime, count: int, - triggered_at: Optional[DateTime] = None, + triggered_at: Optional[pendulum.DateTime] = None, ) -> "ORMAutomationBucket": """Ensures that a bucket with the given start and end exists with the given count, returning the new bucket""" @@ -879,13 +890,18 @@ async def start_new_bucket( ) ) - return await read_bucket_by_trigger_id( + read_bucket = await read_bucket_by_trigger_id( session, automation.id, trigger.id, - bucketing_key, + tuple(bucketing_key), ) + if TYPE_CHECKING: + assert read_bucket is not None + + return read_bucket + @db_injector async def ensure_bucket( @@ -893,15 +909,17 @@ async def ensure_bucket( session: AsyncSession, trigger: EventTrigger, bucketing_key: Tuple[str, ...], - start: DateTime, - end: DateTime, + start: pendulum.DateTime, + end: pendulum.DateTime, last_event: Optional[ReceivedEvent], initial_count: int = 0, ) -> "ORMAutomationBucket": """Ensures that a bucket has been started for the given automation and key, returning the current bucket. Will not modify the existing bucket.""" automation = trigger.automation - additional_updates: dict = {"last_event": last_event} if last_event else {} + additional_updates: dict[str, ReceivedEvent] = ( + {"last_event": last_event} if last_event else {} + ) await session.execute( db.queries.insert(db.AutomationBucket) .values( @@ -928,10 +946,15 @@ async def ensure_bucket( ) ) - return await read_bucket_by_trigger_id( - session, automation.id, trigger.id, bucketing_key + read_bucket = await read_bucket_by_trigger_id( + session, automation.id, trigger.id, tuple(bucketing_key) ) + if TYPE_CHECKING: + assert read_bucket is not None + + return read_bucket + @db_injector async def remove_bucket( @@ -949,7 +972,7 @@ async def remove_bucket( @db_injector async def sweep_closed_buckets( - db: PrefectDBInterface, session: AsyncSession, older_than: DateTime + db: PrefectDBInterface, session: AsyncSession, older_than: pendulum.DateTime ) -> None: await session.execute( sa.delete(db.AutomationBucket).where(db.AutomationBucket.end <= older_than) @@ -958,7 +981,7 @@ async def sweep_closed_buckets( async def reset(): """Resets the in-memory state of the service""" - reset_events_clock() + await reset_events_clock() automations_by_id.clear() triggers.clear() next_proactive_runs.clear() @@ -1022,7 +1045,9 @@ async def message_handler(message: Message): proactive_task.cancel() -async def proactive_evaluation(trigger: EventTrigger, as_of: DateTime) -> DateTime: +async def proactive_evaluation( + trigger: EventTrigger, as_of: pendulum.DateTime +) -> pendulum.DateTime: """The core proactive evaluation operation for a single Automation""" assert isinstance(trigger, EventTrigger), repr(trigger) automation = trigger.automation @@ -1070,7 +1095,7 @@ async def proactive_evaluation(trigger: EventTrigger, as_of: DateTime) -> DateTi await session.commit() -async def evaluate_proactive_triggers(): +async def evaluate_proactive_triggers() -> None: for trigger in triggers.values(): if trigger.posture != Posture.Proactive: continue diff --git a/src/prefect/server/services/loop_service.py b/src/prefect/server/services/loop_service.py index 43d18b41f747..f337e7ad2f08 100644 --- a/src/prefect/server/services/loop_service.py +++ b/src/prefect/server/services/loop_service.py @@ -2,9 +2,12 @@ The base class for all Prefect REST API loop services. """ +from __future__ import annotations + import asyncio import signal -from typing import List, Optional +from operator import methodcaller +from typing import TYPE_CHECKING, Any, List, NoReturn, Optional, overload import anyio import pendulum @@ -13,6 +16,9 @@ from prefect.settings import PREFECT_API_LOG_RETRYABLE_ERRORS from prefect.utilities.processutils import _register_signal +if TYPE_CHECKING: + import logging + class LoopService: """ @@ -35,11 +41,15 @@ def __init__( gracefully intercepted and shut down the running service. """ if loop_seconds: - self.loop_seconds = loop_seconds # seconds between runs - self._should_stop = False # flag for whether the service should stop running - self._is_running = False # flag for whether the service is running - self.name = type(self).__name__ - self.logger = get_logger(f"server.services.{self.name.lower()}") + self.loop_seconds: int = loop_seconds # seconds between runs + self._should_stop: bool = ( + False # flag for whether the service should stop running + ) + self._is_running: bool = False # flag for whether the service is running + self.name: str = type(self).__name__ + self.logger: "logging.Logger" = get_logger( + f"server.services.{self.name.lower()}" + ) if handle_signals: _register_signal(signal.SIGINT, self._stop) @@ -61,7 +71,15 @@ async def _on_stop(self) -> None: # reset the _is_running flag self._is_running = False - async def start(self, loops=None) -> None: + @overload + async def start(self, loops: None = None) -> NoReturn: + ... + + @overload + async def start(self, loops: int) -> None: + ... + + async def start(self, loops: int | None = None) -> None | NoReturn: """ Run the service `loops` time. Pass loops=None to run forever. @@ -128,7 +146,7 @@ async def start(self, loops=None) -> None: await self._on_stop() - async def stop(self, block=True) -> None: + async def stop(self, block: bool = True) -> None: """ Gracefully stops a running LoopService and optionally blocks until the service stops. @@ -157,7 +175,7 @@ async def stop(self, block=True) -> None: " inside the loop service, use `stop(block=False)` instead." ) - def _stop(self, *_) -> None: + def _stop(self, *_: Any) -> None: """ Private, synchronous method for setting the `_should_stop` flag. Takes arbitrary arguments so it can be used as a signal handler. @@ -177,15 +195,16 @@ async def run_once(self) -> None: raise NotImplementedError("LoopService subclasses must implement this method.") -async def run_multiple_services(loop_services: List[LoopService]): +async def run_multiple_services(loop_services: List[LoopService]) -> NoReturn: """ Only one signal handler can be active at a time, so this function takes a list of loop services and runs all of them with a global signal handler. """ - def stop_all_services(self, *_): + def stop_all_services(*_: Any) -> None: for service in loop_services: - service._stop() + stop = methodcaller("_stop") + stop(service) signal.signal(signal.SIGINT, stop_all_services) signal.signal(signal.SIGTERM, stop_all_services)