Skip to content

Commit

Permalink
Improve type completeness of prefect.server.services (#16701)
Browse files Browse the repository at this point in the history
  • Loading branch information
desertaxle authored Jan 13, 2025
1 parent 5d47ba3 commit 043c21b
Show file tree
Hide file tree
Showing 13 changed files with 146 additions and 88 deletions.
4 changes: 3 additions & 1 deletion src/prefect/server/events/schemas/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,9 @@ def resources_in_role(self) -> Mapping[str, Sequence[RelatedResource]]:

@field_validator("related")
@classmethod
def enforce_maximum_related_resources(cls, value: List[RelatedResource]):
def enforce_maximum_related_resources(
cls, value: List[RelatedResource]
) -> List[RelatedResource]:
if len(value) > PREFECT_EVENTS_MAXIMUM_RELATED_RESOURCES.value():
raise ValueError(
"The maximum number of related resources "
Expand Down
6 changes: 4 additions & 2 deletions src/prefect/server/models/flow_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@
from prefect.server.exceptions import ObjectNotFoundError
from prefect.server.orchestration.core_policy import MinimalFlowPolicy
from prefect.server.orchestration.global_policy import GlobalFlowPolicy
from prefect.server.orchestration.policies import BaseOrchestrationPolicy
from prefect.server.orchestration.policies import (
FlowRunOrchestrationPolicy,
)
from prefect.server.orchestration.rules import FlowOrchestrationContext
from prefect.server.schemas.core import TaskRunResult
from prefect.server.schemas.graph import Graph
Expand Down Expand Up @@ -506,7 +508,7 @@ async def set_flow_run_state(
flow_run_id: UUID,
state: schemas.states.State,
force: bool = False,
flow_policy: Optional[Type[BaseOrchestrationPolicy]] = None,
flow_policy: Optional[Type[FlowRunOrchestrationPolicy]] = None,
orchestration_parameters: Optional[Dict[str, Any]] = None,
) -> OrchestrationResult:
"""
Expand Down
8 changes: 4 additions & 4 deletions src/prefect/server/schemas/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,10 +836,10 @@ class WorkPoolCreate(ActionBaseModel):
class WorkPoolUpdate(ActionBaseModel):
"""Data used by the Prefect REST API to update a work pool."""

description: Optional[str] = Field(None)
is_paused: Optional[bool] = Field(None)
base_job_template: Optional[Dict[str, Any]] = Field(None)
concurrency_limit: Optional[NonNegativeInteger] = Field(None)
description: Optional[str] = Field(default=None)
is_paused: Optional[bool] = Field(default=None)
base_job_template: Optional[Dict[str, Any]] = Field(default=None)
concurrency_limit: Optional[NonNegativeInteger] = Field(default=None)

_validate_base_job_template = field_validator("base_job_template")(
validate_base_job_template
Expand Down
4 changes: 2 additions & 2 deletions src/prefect/server/schemas/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,12 +532,12 @@ class TaskRun(ORMBaseModel):

@field_validator("name", mode="before")
@classmethod
def set_name(cls, name):
def set_name(cls, name: str) -> str:
return get_or_create_run_name(name)

@field_validator("cache_key")
@classmethod
def validate_cache_key(cls, cache_key):
def validate_cache_key(cls, cache_key: str) -> str:
return validate_cache_key_length(cache_key)


Expand Down
30 changes: 19 additions & 11 deletions src/prefect/server/services/cancellation_cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@
"""

import asyncio
from typing import Optional
from typing import Any, Optional
from uuid import UUID

import pendulum
import sqlalchemy as sa
from sqlalchemy.sql.expression import or_

import prefect.server.models as models
from prefect.server.database import PrefectDBInterface, inject_db, orm_models
from prefect.server.database import PrefectDBInterface, orm_models
from prefect.server.database.dependencies import db_injector
from prefect.server.schemas import filters, states
from prefect.server.services.loop_service import LoopService
from prefect.settings import PREFECT_API_SERVICES_CANCELLATION_CLEANUP_LOOP_SECONDS
Expand All @@ -25,7 +26,7 @@ class CancellationCleanup(LoopService):
cancelling flow runs
"""

def __init__(self, loop_seconds: Optional[float] = None, **kwargs):
def __init__(self, loop_seconds: Optional[float] = None, **kwargs: Any):
super().__init__(
loop_seconds=loop_seconds
or PREFECT_API_SERVICES_CANCELLATION_CLEANUP_LOOP_SECONDS.value(),
Expand All @@ -35,8 +36,8 @@ def __init__(self, loop_seconds: Optional[float] = None, **kwargs):
# query for this many runs to mark failed at once
self.batch_size = 200

@inject_db
async def run_once(self, db: PrefectDBInterface):
@db_injector
async def run_once(self, db: PrefectDBInterface) -> None:
"""
- cancels active tasks belonging to recently cancelled flow runs
- cancels any active subflow that belongs to a cancelled flow
Expand All @@ -49,7 +50,9 @@ async def run_once(self, db: PrefectDBInterface):

self.logger.info("Finished cleaning up cancelled flow runs.")

async def clean_up_cancelled_flow_run_task_runs(self, db: PrefectDBInterface):
async def clean_up_cancelled_flow_run_task_runs(
self, db: PrefectDBInterface
) -> None:
while True:
cancelled_flow_query = (
sa.select(db.FlowRun)
Expand All @@ -72,7 +75,7 @@ async def clean_up_cancelled_flow_run_task_runs(self, db: PrefectDBInterface):
if len(flow_runs) < self.batch_size:
break

async def clean_up_cancelled_subflow_runs(self, db: PrefectDBInterface):
async def clean_up_cancelled_subflow_runs(self, db: PrefectDBInterface) -> None:
high_water_mark = UUID(int=0)
while True:
subflow_query = (
Expand Down Expand Up @@ -110,9 +113,13 @@ async def _cancel_child_runs(
async with db.session_context() as session:
child_task_runs = await models.task_runs.read_task_runs(
session,
flow_run_filter=filters.FlowRunFilter(id={"any_": [flow_run.id]}),
flow_run_filter=filters.FlowRunFilter(
id=filters.FlowRunFilterId(any_=[flow_run.id])
),
task_run_filter=filters.TaskRunFilter(
state={"type": {"any_": NON_TERMINAL_STATES}}
state=filters.TaskRunFilterState(
type=filters.TaskRunFilterStateType(any_=NON_TERMINAL_STATES)
)
),
limit=100,
)
Expand All @@ -131,7 +138,7 @@ async def _cancel_child_runs(
async def _cancel_subflow(
self, db: PrefectDBInterface, flow_run: orm_models.FlowRun
) -> Optional[bool]:
if not flow_run.parent_task_run_id:
if not flow_run.parent_task_run_id or not flow_run.state:
return False

if flow_run.state.type in states.TERMINAL_STATES:
Expand All @@ -142,7 +149,7 @@ async def _cancel_subflow(
session, task_run_id=flow_run.parent_task_run_id
)

if not parent_task_run:
if not parent_task_run or not parent_task_run.flow_run_id:
# Global orchestration policy will prevent further orchestration
return False

Expand All @@ -152,6 +159,7 @@ async def _cancel_subflow(

if (
containing_flow_run
and containing_flow_run.state
and containing_flow_run.state.type != states.StateType.CANCELLED
):
# Nothing to do here; the parent is not cancelled
Expand Down
32 changes: 21 additions & 11 deletions src/prefect/server/services/flow_run_notifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,18 @@
A service that checks for flow run notifications and sends them.
"""

from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING, Any
from uuid import UUID

import sqlalchemy as sa

from prefect.server import models, schemas
from prefect.server.database import PrefectDBInterface, inject_db
from prefect.server.database import PrefectDBInterface
from prefect.server.database.dependencies import db_injector
from prefect.server.database.query_components import FlowRunNotificationsFromQueue
from prefect.server.services.loop_service import LoopService
from prefect.utilities import urls

Expand All @@ -23,10 +28,10 @@ class FlowRunNotifications(LoopService):

# check queue every 4 seconds
# note: a tight loop is executed until the queue is exhausted
loop_seconds: int = 4
loop_seconds: float = 4

@inject_db
async def run_once(self, db: PrefectDBInterface):
@db_injector
async def run_once(self, db: PrefectDBInterface) -> None:
while True:
async with db.session_context(begin_transaction=True) as session:
# Drain the queue one entry at a time, because if a transient
Expand Down Expand Up @@ -68,13 +73,12 @@ async def run_once(self, db: PrefectDBInterface):
await session.rollback()
assert not connection.invalidated

@inject_db
async def send_flow_run_notification(
self,
session: sa.orm.session,
db: PrefectDBInterface,
notification,
):
session: sa.orm.session,
notification: FlowRunNotificationsFromQueue,
) -> None:
try:
orm_block_document = await session.get(
db.BlockDocument, notification.block_document_id
Expand All @@ -97,6 +101,10 @@ async def send_flow_run_notification(
)

message = self.construct_notification_message(notification=notification)
if TYPE_CHECKING:
from prefect.blocks.abstract import NotificationBlock

assert isinstance(block, NotificationBlock)
await block.notify(
subject="Prefect flow run notification",
body=message,
Expand All @@ -118,7 +126,9 @@ async def send_flow_run_notification(
exc_info=True,
)

def construct_notification_message(self, notification) -> str:
def construct_notification_message(
self, notification: FlowRunNotificationsFromQueue
) -> str:
"""
Construct the message for a flow run notification, including
templating any variables.
Expand All @@ -129,7 +139,7 @@ def construct_notification_message(self, notification) -> str:
)

# create a dict from the sqlalchemy object for templating
notification_dict = dict(notification._mapping)
notification_dict: dict[str, Any] = dict(notification._mapping)
# add the flow run url to the info
notification_dict["flow_run_url"] = self.get_ui_url_for_flow_run_id(
flow_run_id=notification_dict["flow_run_id"]
Expand All @@ -143,7 +153,7 @@ def construct_notification_message(self, notification) -> str:
)
return message

def get_ui_url_for_flow_run_id(self, flow_run_id: UUID) -> str:
def get_ui_url_for_flow_run_id(self, flow_run_id: UUID) -> str | None:
"""
Returns a link to the flow run view of the given flow run id.
Expand Down
4 changes: 2 additions & 2 deletions src/prefect/server/services/foreman.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

from datetime import timedelta
from typing import Optional
from typing import Any, Optional

import pendulum
import sqlalchemy as sa
Expand Down Expand Up @@ -43,7 +43,7 @@ def __init__(
fallback_heartbeat_interval_seconds: Optional[int] = None,
deployment_last_polled_timeout_seconds: Optional[int] = None,
work_queue_last_polled_timeout_seconds: Optional[int] = None,
**kwargs,
**kwargs: Any,
):
super().__init__(
loop_seconds=loop_seconds
Expand Down
20 changes: 14 additions & 6 deletions src/prefect/server/services/late_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,19 @@
The threshold for a late run can be configured by changing `PREFECT_API_SERVICES_LATE_RUNS_AFTER_SECONDS`.
"""

from __future__ import annotations

import asyncio
import datetime
from typing import Optional
from typing import TYPE_CHECKING, Any

import pendulum
import sqlalchemy as sa
from sqlalchemy.ext.asyncio import AsyncSession

import prefect.server.models as models
from prefect.server.database import PrefectDBInterface, inject_db
from prefect.server.database.dependencies import db_injector
from prefect.server.exceptions import ObjectNotFoundError
from prefect.server.orchestration.core_policy import MarkLateRunsPolicy
from prefect.server.schemas import states
Expand All @@ -22,6 +25,9 @@
PREFECT_API_SERVICES_LATE_RUNS_LOOP_SECONDS,
)

if TYPE_CHECKING:
from uuid import UUID


class MarkLateRuns(LoopService):
"""
Expand All @@ -32,7 +38,7 @@ class MarkLateRuns(LoopService):
Prefect REST API Settings.
"""

def __init__(self, loop_seconds: Optional[float] = None, **kwargs):
def __init__(self, loop_seconds: float | None = None, **kwargs: Any):
super().__init__(
loop_seconds=loop_seconds
or PREFECT_API_SERVICES_LATE_RUNS_LOOP_SECONDS.value(),
Expand All @@ -47,8 +53,8 @@ def __init__(self, loop_seconds: Optional[float] = None, **kwargs):
# query for this many runs to mark as late at once
self.batch_size = 400

@inject_db
async def run_once(self, db: PrefectDBInterface):
@db_injector
async def run_once(self, db: PrefectDBInterface) -> None:
"""
Mark flow runs as late by:
Expand Down Expand Up @@ -81,7 +87,7 @@ async def run_once(self, db: PrefectDBInterface):
@inject_db
def _get_select_late_flow_runs_query(
self, scheduled_to_start_before: datetime.datetime, db: PrefectDBInterface
):
) -> sa.Select[tuple["UUID", pendulum.DateTime | None]]:
"""
Returns a sqlalchemy query for late flow runs.
Expand All @@ -106,7 +112,9 @@ def _get_select_late_flow_runs_query(
return query

async def _mark_flow_run_as_late(
self, session: AsyncSession, flow_run: PrefectDBInterface.FlowRun
self,
session: AsyncSession,
flow_run: sa.Row[tuple["UUID", pendulum.DateTime | None]],
) -> None:
"""
Mark a flow run as late.
Expand Down
2 changes: 1 addition & 1 deletion src/prefect/server/services/loop_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(
gracefully intercepted and shut down the running service.
"""
if loop_seconds:
self.loop_seconds: int = loop_seconds # seconds between runs
self.loop_seconds: float = loop_seconds # seconds between runs
self._should_stop: bool = (
False # flag for whether the service should stop running
)
Expand Down
Loading

0 comments on commit 043c21b

Please sign in to comment.