Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve typing in prefect.server.events #16692

Merged
merged 4 commits into from
Jan 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions src/prefect/_internal/retries.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def retry_async_fn(
retry_on_exceptions: tuple[type[Exception], ...] = (Exception,),
operation_name: Optional[str] = None,
) -> Callable[
[Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, Optional[R]]]
[Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]]
]:
"""A decorator for retrying an async function.

Expand All @@ -48,9 +48,9 @@ def retry_async_fn(

def decorator(
func: Callable[P, Coroutine[Any, Any, R]],
) -> Callable[P, Coroutine[Any, Any, Optional[R]]]:
) -> Callable[P, Coroutine[Any, Any, R]]:
@wraps(func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> Optional[R]:
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
name = operation_name or func.__name__
for attempt in range(max_attempts):
try:
Expand All @@ -67,6 +67,9 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> Optional[R]:
f"Retrying in {delay:.2f} seconds..."
)
await asyncio.sleep(delay)
# Technically unreachable, but this raise helps pyright know that this function
# won't return None.
raise Exception(f"Function {name!r} failed after {max_attempts} attempts")

return wrapper

Expand Down
22 changes: 12 additions & 10 deletions src/prefect/server/events/schemas/automations.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ class EventTrigger(ResourceTrigger):
@model_validator(mode="before")
@classmethod
def enforce_minimum_within_for_proactive_triggers(
cls, data: Dict[str, Any]
cls, data: Dict[str, Any] | Any
) -> Dict[str, Any]:
if not isinstance(data, dict):
return data
Expand All @@ -342,7 +342,7 @@ def enforce_minimum_within_for_proactive_triggers(

return data

def covers(self, event: ReceivedEvent):
def covers(self, event: ReceivedEvent) -> bool:
if not self.covers_resources(event.resource, event.related):
return False

Expand All @@ -356,10 +356,10 @@ def immediate(self) -> bool:
"""Does this reactive trigger fire immediately for all events?"""
return self.posture == Posture.Reactive and self.within == timedelta(0)

_event_pattern: Optional[re.Pattern] = PrivateAttr(None)
_event_pattern: Optional[re.Pattern[str]] = PrivateAttr(None)

@property
def event_pattern(self) -> re.Pattern:
def event_pattern(self) -> re.Pattern[str]:
"""A regular expression which may be evaluated against any event string to
determine if this trigger would be interested in the event"""
if self._event_pattern:
Expand Down Expand Up @@ -625,13 +625,15 @@ class Firing(PrefectBaseModel):

id: UUID = Field(default_factory=uuid4)

trigger: ServerTriggerTypes = Field(..., description="The trigger that is firing")
trigger: Union[ServerTriggerTypes, CompositeTrigger] = Field(
default=..., description="The trigger that is firing"
)
trigger_states: Set[TriggerState] = Field(
...,
default=...,
description="The state changes represented by this Firing",
)
triggered: DateTime = Field(
...,
default=...,
description=(
"The time at which this trigger fired, which may differ from the "
"occurred time of the associated event (as events processing may always "
Expand All @@ -654,16 +656,16 @@ class Firing(PrefectBaseModel):
),
)
triggering_event: Optional[ReceivedEvent] = Field(
None,
default=None,
description=(
"The most recent event associated with this Firing. This may be the "
"event that caused the trigger to fire (for Reactive triggers), or the "
"last event to match the trigger (for Proactive triggers), or the state "
"change event (for a Metric trigger)."
),
)
triggering_value: Any = Field(
None,
triggering_value: Optional[Any] = Field(
default=None,
description=(
"A value associated with this firing of a trigger. Maybe used to "
"convey additional information at the point of firing, like the value of "
Expand Down
19 changes: 12 additions & 7 deletions src/prefect/server/events/services/actions.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,28 @@
from __future__ import annotations

import asyncio
from typing import Optional
from typing import TYPE_CHECKING, NoReturn

from prefect.logging import get_logger
from prefect.server.events import actions
from prefect.server.utilities.messaging import create_consumer
from prefect.server.utilities.messaging import Consumer, create_consumer

if TYPE_CHECKING:
import logging

logger = get_logger(__name__)
logger: "logging.Logger" = get_logger(__name__)


class Actions:
"""Runs actions triggered by Automatinos"""

name: str = "Actions"

consumer_task: Optional[asyncio.Task] = None
consumer_task: asyncio.Task[None] | None = None

async def start(self):
async def start(self) -> NoReturn:
assert self.consumer_task is None, "Actions already started"
self.consumer = create_consumer("actions")
self.consumer: Consumer = create_consumer("actions")

async with actions.consumer() as handler:
self.consumer_task = asyncio.create_task(self.consumer.run(handler))
Expand All @@ -28,7 +33,7 @@ async def start(self):
except asyncio.CancelledError:
pass

async def stop(self):
async def stop(self) -> None:
assert self.consumer_task is not None, "Actions not started"
self.consumer_task.cancel()
try:
Expand Down
19 changes: 12 additions & 7 deletions src/prefect/server/events/services/event_logger.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,31 @@
from __future__ import annotations

import asyncio
from typing import Optional
from typing import TYPE_CHECKING, NoReturn

import pendulum
import rich

from prefect.logging import get_logger
from prefect.server.events.schemas.events import ReceivedEvent
from prefect.server.utilities.messaging import Message, create_consumer
from prefect.server.utilities.messaging import Consumer, Message, create_consumer

if TYPE_CHECKING:
import logging

logger = get_logger(__name__)
logger: "logging.Logger" = get_logger(__name__)


class EventLogger:
"""A debugging service that logs events to the console as they arrive."""

name: str = "EventLogger"

consumer_task: Optional[asyncio.Task] = None
consumer_task: asyncio.Task[None] | None = None

async def start(self):
async def start(self) -> NoReturn:
assert self.consumer_task is None, "Logger already started"
self.consumer = create_consumer("events")
self.consumer: Consumer = create_consumer("events")

console = rich.console.Console()

Expand All @@ -46,7 +51,7 @@ async def handler(message: Message):
except asyncio.CancelledError:
pass

async def stop(self):
async def stop(self) -> None:
assert self.consumer_task is not None, "Logger not started"
self.consumer_task.cancel()
try:
Expand Down
26 changes: 18 additions & 8 deletions src/prefect/server/events/services/event_persister.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
storage as fast as it can. Never gets tired.
"""

from __future__ import annotations

import asyncio
from contextlib import asynccontextmanager
from datetime import timedelta
from typing import AsyncGenerator, List, Optional
from typing import TYPE_CHECKING, AsyncGenerator, List, NoReturn

import pendulum
import sqlalchemy as sa
Expand All @@ -15,25 +17,33 @@
from prefect.server.database import provide_database_interface
from prefect.server.events.schemas.events import ReceivedEvent
from prefect.server.events.storage.database import write_events
from prefect.server.utilities.messaging import Message, MessageHandler, create_consumer
from prefect.server.utilities.messaging import (
Consumer,
Message,
MessageHandler,
create_consumer,
)
from prefect.settings import (
PREFECT_API_SERVICES_EVENT_PERSISTER_BATCH_SIZE,
PREFECT_API_SERVICES_EVENT_PERSISTER_FLUSH_INTERVAL,
PREFECT_EVENTS_RETENTION_PERIOD,
)

logger = get_logger(__name__)
if TYPE_CHECKING:
import logging

logger: "logging.Logger" = get_logger(__name__)


class EventPersister:
"""A service that persists events to the database as they arrive."""

name: str = "EventLogger"

consumer_task: Optional[asyncio.Task] = None
consumer_task: asyncio.Task[None] | None = None

def __init__(self):
self._started_event: Optional[asyncio.Event] = None
self._started_event: asyncio.Event | None = None

@property
def started_event(self) -> asyncio.Event:
Expand All @@ -45,9 +55,9 @@ def started_event(self) -> asyncio.Event:
def started_event(self, value: asyncio.Event) -> None:
self._started_event = value

async def start(self):
async def start(self) -> NoReturn:
assert self.consumer_task is None, "Event persister already started"
self.consumer = create_consumer("events")
self.consumer: Consumer = create_consumer("events")

async with create_handler(
batch_size=PREFECT_API_SERVICES_EVENT_PERSISTER_BATCH_SIZE.value(),
Expand All @@ -64,7 +74,7 @@ async def start(self):
except asyncio.CancelledError:
pass

async def stop(self):
async def stop(self) -> None:
assert self.consumer_task is not None, "Event persister not started"
self.consumer_task.cancel()
try:
Expand Down
23 changes: 14 additions & 9 deletions src/prefect/server/events/services/triggers.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,30 @@
from __future__ import annotations

import asyncio
from typing import Optional
from typing import TYPE_CHECKING, Any, NoReturn, Optional

from prefect.logging import get_logger
from prefect.server.events import triggers
from prefect.server.services.loop_service import LoopService
from prefect.server.utilities.messaging import create_consumer
from prefect.server.utilities.messaging import Consumer, create_consumer
from prefect.settings import PREFECT_EVENTS_PROACTIVE_GRANULARITY

logger = get_logger(__name__)
if TYPE_CHECKING:
import logging

logger: "logging.Logger" = get_logger(__name__)


class ReactiveTriggers:
"""Runs the reactive triggers consumer"""

name: str = "ReactiveTriggers"

consumer_task: Optional[asyncio.Task] = None
consumer_task: asyncio.Task[None] | None = None

async def start(self):
async def start(self) -> NoReturn:
assert self.consumer_task is None, "Reactive triggers already started"
self.consumer = create_consumer("events")
self.consumer: Consumer = create_consumer("events")

async with triggers.consumer() as handler:
self.consumer_task = asyncio.create_task(self.consumer.run(handler))
Expand All @@ -30,7 +35,7 @@ async def start(self):
except asyncio.CancelledError:
pass

async def stop(self):
async def stop(self) -> None:
assert self.consumer_task is not None, "Reactive triggers not started"
self.consumer_task.cancel()
try:
Expand All @@ -43,7 +48,7 @@ async def stop(self):


class ProactiveTriggers(LoopService):
def __init__(self, loop_seconds: Optional[float] = None, **kwargs):
def __init__(self, loop_seconds: Optional[float] = None, **kwargs: Any):
super().__init__(
loop_seconds=(
loop_seconds
Expand All @@ -52,5 +57,5 @@ def __init__(self, loop_seconds: Optional[float] = None, **kwargs):
**kwargs,
)

async def run_once(self):
async def run_once(self) -> None:
await triggers.evaluate_proactive_triggers()
Loading
Loading