From 043c21b2bb6a4091d6f5af3f2ee506d4bbb0ae5a Mon Sep 17 00:00:00 2001 From: Alex Streed Date: Mon, 13 Jan 2025 09:31:29 -0600 Subject: [PATCH] Improve type completeness of `prefect.server.services` (#16701) --- src/prefect/server/events/schemas/events.py | 4 +- src/prefect/server/models/flow_runs.py | 6 ++- src/prefect/server/schemas/actions.py | 8 +-- src/prefect/server/schemas/core.py | 4 +- .../server/services/cancellation_cleanup.py | 30 +++++++---- .../server/services/flow_run_notifications.py | 32 ++++++++---- src/prefect/server/services/foreman.py | 4 +- src/prefect/server/services/late_runs.py | 20 +++++--- src/prefect/server/services/loop_service.py | 2 +- .../server/services/pause_expirations.py | 20 +++++--- src/prefect/server/services/scheduler.py | 51 ++++++++++--------- .../server/services/task_run_recorder.py | 34 +++++++++---- src/prefect/server/services/telemetry.py | 19 +++---- 13 files changed, 146 insertions(+), 88 deletions(-) diff --git a/src/prefect/server/events/schemas/events.py b/src/prefect/server/events/schemas/events.py index b5d703898ee5..36f628f11b42 100644 --- a/src/prefect/server/events/schemas/events.py +++ b/src/prefect/server/events/schemas/events.py @@ -160,7 +160,9 @@ def resources_in_role(self) -> Mapping[str, Sequence[RelatedResource]]: @field_validator("related") @classmethod - def enforce_maximum_related_resources(cls, value: List[RelatedResource]): + def enforce_maximum_related_resources( + cls, value: List[RelatedResource] + ) -> List[RelatedResource]: if len(value) > PREFECT_EVENTS_MAXIMUM_RELATED_RESOURCES.value(): raise ValueError( "The maximum number of related resources " diff --git a/src/prefect/server/models/flow_runs.py b/src/prefect/server/models/flow_runs.py index 87d6a92b136b..fca96360dbfa 100644 --- a/src/prefect/server/models/flow_runs.py +++ b/src/prefect/server/models/flow_runs.py @@ -34,7 +34,9 @@ from prefect.server.exceptions import ObjectNotFoundError from prefect.server.orchestration.core_policy import MinimalFlowPolicy from prefect.server.orchestration.global_policy import GlobalFlowPolicy -from prefect.server.orchestration.policies import BaseOrchestrationPolicy +from prefect.server.orchestration.policies import ( + FlowRunOrchestrationPolicy, +) from prefect.server.orchestration.rules import FlowOrchestrationContext from prefect.server.schemas.core import TaskRunResult from prefect.server.schemas.graph import Graph @@ -506,7 +508,7 @@ async def set_flow_run_state( flow_run_id: UUID, state: schemas.states.State, force: bool = False, - flow_policy: Optional[Type[BaseOrchestrationPolicy]] = None, + flow_policy: Optional[Type[FlowRunOrchestrationPolicy]] = None, orchestration_parameters: Optional[Dict[str, Any]] = None, ) -> OrchestrationResult: """ diff --git a/src/prefect/server/schemas/actions.py b/src/prefect/server/schemas/actions.py index b4a2e87a9ada..3d5bdff9b9cd 100644 --- a/src/prefect/server/schemas/actions.py +++ b/src/prefect/server/schemas/actions.py @@ -836,10 +836,10 @@ class WorkPoolCreate(ActionBaseModel): class WorkPoolUpdate(ActionBaseModel): """Data used by the Prefect REST API to update a work pool.""" - description: Optional[str] = Field(None) - is_paused: Optional[bool] = Field(None) - base_job_template: Optional[Dict[str, Any]] = Field(None) - concurrency_limit: Optional[NonNegativeInteger] = Field(None) + description: Optional[str] = Field(default=None) + is_paused: Optional[bool] = Field(default=None) + base_job_template: Optional[Dict[str, Any]] = Field(default=None) + concurrency_limit: Optional[NonNegativeInteger] = Field(default=None) _validate_base_job_template = field_validator("base_job_template")( validate_base_job_template diff --git a/src/prefect/server/schemas/core.py b/src/prefect/server/schemas/core.py index fe86a21efd9b..3b37a1115ea6 100644 --- a/src/prefect/server/schemas/core.py +++ b/src/prefect/server/schemas/core.py @@ -532,12 +532,12 @@ class TaskRun(ORMBaseModel): @field_validator("name", mode="before") @classmethod - def set_name(cls, name): + def set_name(cls, name: str) -> str: return get_or_create_run_name(name) @field_validator("cache_key") @classmethod - def validate_cache_key(cls, cache_key): + def validate_cache_key(cls, cache_key: str) -> str: return validate_cache_key_length(cache_key) diff --git a/src/prefect/server/services/cancellation_cleanup.py b/src/prefect/server/services/cancellation_cleanup.py index 221201d57a8b..95d4b7f1b612 100644 --- a/src/prefect/server/services/cancellation_cleanup.py +++ b/src/prefect/server/services/cancellation_cleanup.py @@ -3,7 +3,7 @@ """ import asyncio -from typing import Optional +from typing import Any, Optional from uuid import UUID import pendulum @@ -11,7 +11,8 @@ from sqlalchemy.sql.expression import or_ import prefect.server.models as models -from prefect.server.database import PrefectDBInterface, inject_db, orm_models +from prefect.server.database import PrefectDBInterface, orm_models +from prefect.server.database.dependencies import db_injector from prefect.server.schemas import filters, states from prefect.server.services.loop_service import LoopService from prefect.settings import PREFECT_API_SERVICES_CANCELLATION_CLEANUP_LOOP_SECONDS @@ -25,7 +26,7 @@ class CancellationCleanup(LoopService): cancelling flow runs """ - def __init__(self, loop_seconds: Optional[float] = None, **kwargs): + def __init__(self, loop_seconds: Optional[float] = None, **kwargs: Any): super().__init__( loop_seconds=loop_seconds or PREFECT_API_SERVICES_CANCELLATION_CLEANUP_LOOP_SECONDS.value(), @@ -35,8 +36,8 @@ def __init__(self, loop_seconds: Optional[float] = None, **kwargs): # query for this many runs to mark failed at once self.batch_size = 200 - @inject_db - async def run_once(self, db: PrefectDBInterface): + @db_injector + async def run_once(self, db: PrefectDBInterface) -> None: """ - cancels active tasks belonging to recently cancelled flow runs - cancels any active subflow that belongs to a cancelled flow @@ -49,7 +50,9 @@ async def run_once(self, db: PrefectDBInterface): self.logger.info("Finished cleaning up cancelled flow runs.") - async def clean_up_cancelled_flow_run_task_runs(self, db: PrefectDBInterface): + async def clean_up_cancelled_flow_run_task_runs( + self, db: PrefectDBInterface + ) -> None: while True: cancelled_flow_query = ( sa.select(db.FlowRun) @@ -72,7 +75,7 @@ async def clean_up_cancelled_flow_run_task_runs(self, db: PrefectDBInterface): if len(flow_runs) < self.batch_size: break - async def clean_up_cancelled_subflow_runs(self, db: PrefectDBInterface): + async def clean_up_cancelled_subflow_runs(self, db: PrefectDBInterface) -> None: high_water_mark = UUID(int=0) while True: subflow_query = ( @@ -110,9 +113,13 @@ async def _cancel_child_runs( async with db.session_context() as session: child_task_runs = await models.task_runs.read_task_runs( session, - flow_run_filter=filters.FlowRunFilter(id={"any_": [flow_run.id]}), + flow_run_filter=filters.FlowRunFilter( + id=filters.FlowRunFilterId(any_=[flow_run.id]) + ), task_run_filter=filters.TaskRunFilter( - state={"type": {"any_": NON_TERMINAL_STATES}} + state=filters.TaskRunFilterState( + type=filters.TaskRunFilterStateType(any_=NON_TERMINAL_STATES) + ) ), limit=100, ) @@ -131,7 +138,7 @@ async def _cancel_child_runs( async def _cancel_subflow( self, db: PrefectDBInterface, flow_run: orm_models.FlowRun ) -> Optional[bool]: - if not flow_run.parent_task_run_id: + if not flow_run.parent_task_run_id or not flow_run.state: return False if flow_run.state.type in states.TERMINAL_STATES: @@ -142,7 +149,7 @@ async def _cancel_subflow( session, task_run_id=flow_run.parent_task_run_id ) - if not parent_task_run: + if not parent_task_run or not parent_task_run.flow_run_id: # Global orchestration policy will prevent further orchestration return False @@ -152,6 +159,7 @@ async def _cancel_subflow( if ( containing_flow_run + and containing_flow_run.state and containing_flow_run.state.type != states.StateType.CANCELLED ): # Nothing to do here; the parent is not cancelled diff --git a/src/prefect/server/services/flow_run_notifications.py b/src/prefect/server/services/flow_run_notifications.py index 132c81c4c9fa..e350bb6b17f8 100644 --- a/src/prefect/server/services/flow_run_notifications.py +++ b/src/prefect/server/services/flow_run_notifications.py @@ -2,13 +2,18 @@ A service that checks for flow run notifications and sends them. """ +from __future__ import annotations + import asyncio +from typing import TYPE_CHECKING, Any from uuid import UUID import sqlalchemy as sa from prefect.server import models, schemas -from prefect.server.database import PrefectDBInterface, inject_db +from prefect.server.database import PrefectDBInterface +from prefect.server.database.dependencies import db_injector +from prefect.server.database.query_components import FlowRunNotificationsFromQueue from prefect.server.services.loop_service import LoopService from prefect.utilities import urls @@ -23,10 +28,10 @@ class FlowRunNotifications(LoopService): # check queue every 4 seconds # note: a tight loop is executed until the queue is exhausted - loop_seconds: int = 4 + loop_seconds: float = 4 - @inject_db - async def run_once(self, db: PrefectDBInterface): + @db_injector + async def run_once(self, db: PrefectDBInterface) -> None: while True: async with db.session_context(begin_transaction=True) as session: # Drain the queue one entry at a time, because if a transient @@ -68,13 +73,12 @@ async def run_once(self, db: PrefectDBInterface): await session.rollback() assert not connection.invalidated - @inject_db async def send_flow_run_notification( self, - session: sa.orm.session, db: PrefectDBInterface, - notification, - ): + session: sa.orm.session, + notification: FlowRunNotificationsFromQueue, + ) -> None: try: orm_block_document = await session.get( db.BlockDocument, notification.block_document_id @@ -97,6 +101,10 @@ async def send_flow_run_notification( ) message = self.construct_notification_message(notification=notification) + if TYPE_CHECKING: + from prefect.blocks.abstract import NotificationBlock + + assert isinstance(block, NotificationBlock) await block.notify( subject="Prefect flow run notification", body=message, @@ -118,7 +126,9 @@ async def send_flow_run_notification( exc_info=True, ) - def construct_notification_message(self, notification) -> str: + def construct_notification_message( + self, notification: FlowRunNotificationsFromQueue + ) -> str: """ Construct the message for a flow run notification, including templating any variables. @@ -129,7 +139,7 @@ def construct_notification_message(self, notification) -> str: ) # create a dict from the sqlalchemy object for templating - notification_dict = dict(notification._mapping) + notification_dict: dict[str, Any] = dict(notification._mapping) # add the flow run url to the info notification_dict["flow_run_url"] = self.get_ui_url_for_flow_run_id( flow_run_id=notification_dict["flow_run_id"] @@ -143,7 +153,7 @@ def construct_notification_message(self, notification) -> str: ) return message - def get_ui_url_for_flow_run_id(self, flow_run_id: UUID) -> str: + def get_ui_url_for_flow_run_id(self, flow_run_id: UUID) -> str | None: """ Returns a link to the flow run view of the given flow run id. diff --git a/src/prefect/server/services/foreman.py b/src/prefect/server/services/foreman.py index 5d8b79b033bd..11c04ff26fde 100644 --- a/src/prefect/server/services/foreman.py +++ b/src/prefect/server/services/foreman.py @@ -3,7 +3,7 @@ """ from datetime import timedelta -from typing import Optional +from typing import Any, Optional import pendulum import sqlalchemy as sa @@ -43,7 +43,7 @@ def __init__( fallback_heartbeat_interval_seconds: Optional[int] = None, deployment_last_polled_timeout_seconds: Optional[int] = None, work_queue_last_polled_timeout_seconds: Optional[int] = None, - **kwargs, + **kwargs: Any, ): super().__init__( loop_seconds=loop_seconds diff --git a/src/prefect/server/services/late_runs.py b/src/prefect/server/services/late_runs.py index dfc2abf6dd83..c393588bccc1 100644 --- a/src/prefect/server/services/late_runs.py +++ b/src/prefect/server/services/late_runs.py @@ -3,9 +3,11 @@ The threshold for a late run can be configured by changing `PREFECT_API_SERVICES_LATE_RUNS_AFTER_SECONDS`. """ +from __future__ import annotations + import asyncio import datetime -from typing import Optional +from typing import TYPE_CHECKING, Any import pendulum import sqlalchemy as sa @@ -13,6 +15,7 @@ import prefect.server.models as models from prefect.server.database import PrefectDBInterface, inject_db +from prefect.server.database.dependencies import db_injector from prefect.server.exceptions import ObjectNotFoundError from prefect.server.orchestration.core_policy import MarkLateRunsPolicy from prefect.server.schemas import states @@ -22,6 +25,9 @@ PREFECT_API_SERVICES_LATE_RUNS_LOOP_SECONDS, ) +if TYPE_CHECKING: + from uuid import UUID + class MarkLateRuns(LoopService): """ @@ -32,7 +38,7 @@ class MarkLateRuns(LoopService): Prefect REST API Settings. """ - def __init__(self, loop_seconds: Optional[float] = None, **kwargs): + def __init__(self, loop_seconds: float | None = None, **kwargs: Any): super().__init__( loop_seconds=loop_seconds or PREFECT_API_SERVICES_LATE_RUNS_LOOP_SECONDS.value(), @@ -47,8 +53,8 @@ def __init__(self, loop_seconds: Optional[float] = None, **kwargs): # query for this many runs to mark as late at once self.batch_size = 400 - @inject_db - async def run_once(self, db: PrefectDBInterface): + @db_injector + async def run_once(self, db: PrefectDBInterface) -> None: """ Mark flow runs as late by: @@ -81,7 +87,7 @@ async def run_once(self, db: PrefectDBInterface): @inject_db def _get_select_late_flow_runs_query( self, scheduled_to_start_before: datetime.datetime, db: PrefectDBInterface - ): + ) -> sa.Select[tuple["UUID", pendulum.DateTime | None]]: """ Returns a sqlalchemy query for late flow runs. @@ -106,7 +112,9 @@ def _get_select_late_flow_runs_query( return query async def _mark_flow_run_as_late( - self, session: AsyncSession, flow_run: PrefectDBInterface.FlowRun + self, + session: AsyncSession, + flow_run: sa.Row[tuple["UUID", pendulum.DateTime | None]], ) -> None: """ Mark a flow run as late. diff --git a/src/prefect/server/services/loop_service.py b/src/prefect/server/services/loop_service.py index f337e7ad2f08..ea71819aab02 100644 --- a/src/prefect/server/services/loop_service.py +++ b/src/prefect/server/services/loop_service.py @@ -41,7 +41,7 @@ def __init__( gracefully intercepted and shut down the running service. """ if loop_seconds: - self.loop_seconds: int = loop_seconds # seconds between runs + self.loop_seconds: float = loop_seconds # seconds between runs self._should_stop: bool = ( False # flag for whether the service should stop running ) diff --git a/src/prefect/server/services/pause_expirations.py b/src/prefect/server/services/pause_expirations.py index 390955bf7f9c..b9949210ea09 100644 --- a/src/prefect/server/services/pause_expirations.py +++ b/src/prefect/server/services/pause_expirations.py @@ -3,14 +3,16 @@ """ import asyncio -from typing import Optional +from typing import Any, Optional import pendulum import sqlalchemy as sa from sqlalchemy.ext.asyncio import AsyncSession import prefect.server.models as models -from prefect.server.database import PrefectDBInterface, inject_db +from prefect.server.database import PrefectDBInterface +from prefect.server.database.dependencies import db_injector +from prefect.server.database.orm_models import FlowRun from prefect.server.schemas import states from prefect.server.services.loop_service import LoopService from prefect.settings import PREFECT_API_SERVICES_PAUSE_EXPIRATIONS_LOOP_SECONDS @@ -21,7 +23,7 @@ class FailExpiredPauses(LoopService): A simple loop service responsible for identifying Paused flow runs that no longer can be resumed. """ - def __init__(self, loop_seconds: Optional[float] = None, **kwargs): + def __init__(self, loop_seconds: Optional[float] = None, **kwargs: Any): super().__init__( loop_seconds=loop_seconds or PREFECT_API_SERVICES_PAUSE_EXPIRATIONS_LOOP_SECONDS.value(), @@ -31,8 +33,8 @@ def __init__(self, loop_seconds: Optional[float] = None, **kwargs): # query for this many runs to mark failed at once self.batch_size = 200 - @inject_db - async def run_once(self, db: PrefectDBInterface): + @db_injector + async def run_once(self, db: PrefectDBInterface) -> None: """ Mark flow runs as failed by: @@ -63,14 +65,18 @@ async def run_once(self, db: PrefectDBInterface): self.logger.info("Finished monitoring for late runs.") async def _mark_flow_run_as_failed( - self, session: AsyncSession, flow_run: PrefectDBInterface.FlowRun + self, session: AsyncSession, flow_run: FlowRun ) -> None: """ Mark a flow run as failed. Pass-through method for overrides. """ - if flow_run.state.state_details.pause_timeout < pendulum.now("UTC"): + if ( + flow_run.state is not None + and flow_run.state.state_details.pause_timeout is not None + and flow_run.state.state_details.pause_timeout < pendulum.now("UTC") + ): await models.flow_runs.set_flow_run_state( session=session, flow_run_id=flow_run.id, diff --git a/src/prefect/server/services/scheduler.py b/src/prefect/server/services/scheduler.py index 69013ce7dbe5..39e9862611f1 100644 --- a/src/prefect/server/services/scheduler.py +++ b/src/prefect/server/services/scheduler.py @@ -2,16 +2,19 @@ The Scheduler service. """ +from __future__ import annotations + import asyncio import datetime -from typing import Dict, List, Optional +from typing import Any, Sequence from uuid import UUID import pendulum import sqlalchemy as sa import prefect.server.models as models -from prefect.server.database import PrefectDBInterface, inject_db +from prefect.server.database import PrefectDBInterface +from prefect.server.database.dependencies import db_injector from prefect.server.schemas.states import StateType from prefect.server.services.loop_service import LoopService, run_multiple_services from prefect.settings import ( @@ -37,9 +40,9 @@ class Scheduler(LoopService): # the main scheduler takes its loop interval from # PREFECT_API_SERVICES_SCHEDULER_LOOP_SECONDS - loop_seconds = None + loop_seconds: float - def __init__(self, loop_seconds: Optional[float] = None, **kwargs): + def __init__(self, loop_seconds: float | None = None, **kwargs: Any): super().__init__( loop_seconds=( loop_seconds @@ -59,12 +62,12 @@ def __init__(self, loop_seconds: Optional[float] = None, **kwargs): self.min_scheduled_time: datetime.timedelta = ( PREFECT_API_SERVICES_SCHEDULER_MIN_SCHEDULED_TIME.value() ) - self.insert_batch_size = ( + self.insert_batch_size: int = ( PREFECT_API_SERVICES_SCHEDULER_INSERT_BATCH_SIZE.value() ) - @inject_db - async def run_once(self, db: PrefectDBInterface): + @db_injector + async def run_once(self, db: PrefectDBInterface) -> None: """ Schedule flow runs by: @@ -101,7 +104,7 @@ async def run_once(self, db: PrefectDBInterface): for batch in batched_iterable(runs_to_insert, self.insert_batch_size): async with db.session_context(begin_transaction=True) as session: inserted_runs = await self._insert_scheduled_flow_runs( - session=session, runs=batch + session=session, runs=list(batch) ) total_inserted_runs += len(inserted_runs) @@ -114,8 +117,10 @@ async def run_once(self, db: PrefectDBInterface): self.logger.info(f"Scheduled {total_inserted_runs} runs.") - @inject_db - def _get_select_deployments_to_schedule_query(self, db: PrefectDBInterface): + @db_injector + def _get_select_deployments_to_schedule_query( + self, db: PrefectDBInterface + ) -> sa.Select[tuple[UUID]]: """ Returns a sqlalchemy query for selecting deployments to schedule. @@ -180,9 +185,9 @@ def _get_select_deployments_to_schedule_query(self, db: PrefectDBInterface): async def _collect_flow_runs( self, session: sa.orm.Session, - deployment_ids: List[UUID], - ) -> List[Dict]: - runs_to_insert = [] + deployment_ids: Sequence[UUID], + ) -> list[dict[str, Any]]: + runs_to_insert: list[dict[str, Any]] = [] for deployment_id in deployment_ids: now = pendulum.now("UTC") # guard against erroneously configured schedules @@ -224,9 +229,10 @@ async def _collect_flow_runs( raise TryAgain() return runs_to_insert - @inject_db + @db_injector async def _generate_scheduled_flow_runs( self, + db: PrefectDBInterface, session: sa.orm.Session, deployment_id: UUID, start_time: datetime.datetime, @@ -234,8 +240,7 @@ async def _generate_scheduled_flow_runs( min_time: datetime.timedelta, min_runs: int, max_runs: int, - db: PrefectDBInterface, - ) -> List[Dict]: + ) -> list[dict[str, Any]]: """ Given a `deployment_id` and schedule params, generates a list of flow run objects and associated scheduled states that represent scheduled flow runs. @@ -274,13 +279,11 @@ async def _generate_scheduled_flow_runs( max_runs=max_runs, ) - @inject_db async def _insert_scheduled_flow_runs( self, session: sa.orm.Session, - runs: List[Dict], - db: PrefectDBInterface, - ) -> List[UUID]: + runs: list[dict[str, Any]], + ) -> Sequence[UUID]: """ Given a list of flow runs to schedule, as generated by `_generate_scheduled_flow_runs`, inserts them into the database. Note this is a @@ -306,10 +309,12 @@ class RecentDeploymentsScheduler(Scheduler): """ # this scheduler runs on a tight loop - loop_seconds = 5 + loop_seconds: float = 5 - @inject_db - def _get_select_deployments_to_schedule_query(self, db: PrefectDBInterface): + @db_injector + def _get_select_deployments_to_schedule_query( + self, db: PrefectDBInterface + ) -> sa.Select[tuple[UUID]]: """ Returns a sqlalchemy query for selecting deployments to schedule """ diff --git a/src/prefect/server/services/task_run_recorder.py b/src/prefect/server/services/task_run_recorder.py index ed51953132a2..33ae2f86e9c1 100644 --- a/src/prefect/server/services/task_run_recorder.py +++ b/src/prefect/server/services/task_run_recorder.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import asyncio from contextlib import asynccontextmanager -from typing import Any, AsyncGenerator, Dict, Optional +from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Optional from uuid import UUID import pendulum @@ -17,12 +19,20 @@ from prefect.server.events.schemas.events import ReceivedEvent from prefect.server.schemas.core import TaskRun from prefect.server.schemas.states import State -from prefect.server.utilities.messaging import Message, MessageHandler, create_consumer +from prefect.server.utilities.messaging import ( + Consumer, + Message, + MessageHandler, + create_consumer, +) + +if TYPE_CHECKING: + import logging -logger = get_logger(__name__) +logger: "logging.Logger" = get_logger(__name__) -def causal_ordering(): +def causal_ordering() -> CausalOrdering: return CausalOrdering( "task-run-recorder", ) @@ -35,6 +45,8 @@ async def _insert_task_run( task_run: TaskRun, task_run_attributes: Dict[str, Any], ): + if TYPE_CHECKING: + assert task_run.state is not None await session.execute( db.queries.insert(db.TaskRun) .values( @@ -58,6 +70,8 @@ async def _insert_task_run( async def _insert_task_run_state( db: PrefectDBInterface, session: AsyncSession, task_run: TaskRun ): + if TYPE_CHECKING: + assert task_run.state is not None await session.execute( db.queries.insert(db.TaskRunState) .values( @@ -80,6 +94,8 @@ async def _update_task_run_with_state( task_run: TaskRun, denormalized_state_attributes: Dict[str, Any], ): + if TYPE_CHECKING: + assert task_run.state is not None await session.execute( sa.update(db.TaskRun) .where( @@ -121,7 +137,7 @@ def task_run_from_event(event: ReceivedEvent) -> TaskRun: ) -async def record_task_run_event(event: ReceivedEvent): +async def record_task_run_event(event: ReceivedEvent) -> None: task_run = task_run_from_event(event) task_run_attributes = task_run.model_dump_for_orm( @@ -201,7 +217,7 @@ class TaskRunRecorder: name: str = "TaskRunRecorder" - consumer_task: Optional[asyncio.Task] = None + consumer_task: asyncio.Task[None] | None = None def __init__(self): self._started_event: Optional[asyncio.Event] = None @@ -216,9 +232,9 @@ def started_event(self) -> asyncio.Event: def started_event(self, value: asyncio.Event) -> None: self._started_event = value - async def start(self): + async def start(self) -> None: assert self.consumer_task is None, "TaskRunRecorder already started" - self.consumer = create_consumer("events") + self.consumer: Consumer = create_consumer("events") async with consumer() as handler: self.consumer_task = asyncio.create_task(self.consumer.run(handler)) @@ -230,7 +246,7 @@ async def start(self): except asyncio.CancelledError: pass - async def stop(self): + async def stop(self) -> None: assert self.consumer_task is not None, "Logger not started" self.consumer_task.cancel() try: diff --git a/src/prefect/server/services/telemetry.py b/src/prefect/server/services/telemetry.py index 0fced956a9fc..a0773aa70545 100644 --- a/src/prefect/server/services/telemetry.py +++ b/src/prefect/server/services/telemetry.py @@ -5,14 +5,15 @@ import asyncio import os import platform -from typing import Optional +from typing import Any, Optional from uuid import uuid4 import httpx import pendulum import prefect -from prefect.server.database import PrefectDBInterface, inject_db +from prefect.server.database import PrefectDBInterface +from prefect.server.database.dependencies import db_injector from prefect.server.models import configuration from prefect.server.schemas.core import Configuration from prefect.server.services.loop_service import LoopService @@ -25,15 +26,15 @@ class Telemetry(LoopService): improve. It can be toggled off with the PREFECT_SERVER_ANALYTICS_ENABLED setting. """ - loop_seconds: int = 600 + loop_seconds: float = 600 - def __init__(self, loop_seconds: Optional[int] = None, **kwargs): + def __init__(self, loop_seconds: Optional[int] = None, **kwargs: Any): super().__init__(loop_seconds=loop_seconds, **kwargs) - self.telemetry_environment = os.environ.get( + self.telemetry_environment: str = os.environ.get( "PREFECT_API_TELEMETRY_ENVIRONMENT", "production" ) - @inject_db + @db_injector async def _fetch_or_set_telemetry_session(self, db: PrefectDBInterface): """ This method looks for a telemetry session in the configuration table. If there @@ -66,8 +67,8 @@ async def _fetch_or_set_telemetry_session(self, db: PrefectDBInterface): self.session_start_timestamp = session_start_timestamp else: self.logger.debug("Session information retrieved from database") - self.session_id = telemetry_session.value["session_id"] - self.session_start_timestamp = telemetry_session.value[ + self.session_id: str = telemetry_session.value["session_id"] + self.session_start_timestamp: str = telemetry_session.value[ "session_start_timestamp" ] self.logger.debug( @@ -75,7 +76,7 @@ async def _fetch_or_set_telemetry_session(self, db: PrefectDBInterface): ) return (self.session_start_timestamp, self.session_id) - async def run_once(self): + async def run_once(self) -> None: """ Sends a heartbeat to the sens-o-matic """