From c910d01d15f9ea13ae0d610e15a90da806b6f299 Mon Sep 17 00:00:00 2001 From: Alex Streed Date: Thu, 16 Jan 2025 08:53:35 -0600 Subject: [PATCH] Raise type completeness to 95% (#16723) --- src/prefect/docker/docker_image.py | 22 +++-- src/prefect/engine.py | 29 ++++-- src/prefect/events/schemas/automations.py | 20 ++-- .../infrastructure/provisioners/ecs.py | 98 +++++++++++-------- .../infrastructure/provisioners/modal.py | 9 +- src/prefect/logging/configuration.py | 6 +- src/prefect/logging/filters.py | 4 +- src/prefect/logging/formatters.py | 26 ++--- src/prefect/logging/handlers.py | 31 +++--- src/prefect/logging/highlighters.py | 10 +- src/prefect/logging/loggers.py | 20 ++-- src/prefect/task_runs.py | 15 +-- src/prefect/testing/cli.py | 2 +- src/prefect/testing/docker.py | 8 +- src/prefect/testing/fixtures.py | 54 ++++++---- .../testing/standard_test_suites/blocks.py | 10 +- src/prefect/testing/utilities.py | 19 ++-- src/prefect/workers/base.py | 4 +- 18 files changed, 227 insertions(+), 160 deletions(-) diff --git a/src/prefect/docker/docker_image.py b/src/prefect/docker/docker_image.py index c58442cd0e94..a2de9980d52a 100644 --- a/src/prefect/docker/docker_image.py +++ b/src/prefect/docker/docker_image.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Optional +from typing import Any, Optional from pendulum import now as pendulum_now @@ -34,7 +34,11 @@ class DockerImage: """ def __init__( - self, name: str, tag: Optional[str] = None, dockerfile="auto", **build_kwargs + self, + name: str, + tag: Optional[str] = None, + dockerfile: str = "auto", + **build_kwargs: Any, ): image_name, image_tag = parse_image_tag(name) if tag and image_tag: @@ -49,16 +53,16 @@ def __init__( namespace = PREFECT_DEFAULT_DOCKER_BUILD_NAMESPACE.value() # join the namespace and repository to create the full image name # ignore namespace if it is None - self.name = "/".join(filter(None, [namespace, repository])) - self.tag = tag or image_tag or slugify(pendulum_now("utc").isoformat()) - self.dockerfile = dockerfile - self.build_kwargs = build_kwargs + self.name: str = "/".join(filter(None, [namespace, repository])) + self.tag: str = tag or image_tag or slugify(pendulum_now("utc").isoformat()) + self.dockerfile: str = dockerfile + self.build_kwargs: dict[str, Any] = build_kwargs @property - def reference(self): + def reference(self) -> str: return f"{self.name}:{self.tag}" - def build(self): + def build(self) -> None: full_image_name = self.reference build_kwargs = self.build_kwargs.copy() build_kwargs["context"] = Path.cwd() @@ -72,7 +76,7 @@ def build(self): build_kwargs["dockerfile"] = self.dockerfile build_image(**build_kwargs) - def push(self): + def push(self) -> None: with docker_client() as client: events = client.api.push( repository=self.name, tag=self.tag, stream=True, decode=True diff --git a/src/prefect/engine.py b/src/prefect/engine.py index 262fdaa67071..a2829a7a2001 100644 --- a/src/prefect/engine.py +++ b/src/prefect/engine.py @@ -1,6 +1,6 @@ import os import sys -from typing import Any, Callable +from typing import TYPE_CHECKING, Any, Callable from uuid import UUID from prefect._internal.compatibility.migration import getattr_migration @@ -15,12 +15,19 @@ run_coro_as_sync, ) -engine_logger = get_logger("engine") +if TYPE_CHECKING: + import logging + + from prefect.flow_engine import FlowRun + from prefect.flows import Flow + from prefect.logging.loggers import LoggingAdapter + +engine_logger: "logging.Logger" = get_logger("engine") if __name__ == "__main__": try: - flow_run_id = UUID( + flow_run_id: UUID = UUID( sys.argv[1] if len(sys.argv) > 1 else os.environ.get("PREFECT__FLOW_RUN_ID") ) except Exception: @@ -37,11 +44,11 @@ run_flow, ) - flow_run = load_flow_run(flow_run_id=flow_run_id) - run_logger = flow_run_logger(flow_run=flow_run) + flow_run: "FlowRun" = load_flow_run(flow_run_id=flow_run_id) + run_logger: "LoggingAdapter" = flow_run_logger(flow_run=flow_run) try: - flow = load_flow(flow_run) + flow: "Flow[..., Any]" = load_flow(flow_run) except Exception: run_logger.error( "Unexpected exception encountered when trying to load flow", @@ -55,15 +62,17 @@ else: run_flow(flow, flow_run=flow_run, error_logger=run_logger) - except Abort as exc: + except Abort as abort_signal: + abort_signal: Abort engine_logger.info( f"Engine execution of flow run '{flow_run_id}' aborted by orchestrator:" - f" {exc}" + f" {abort_signal}" ) exit(0) - except Pause as exc: + except Pause as pause_signal: + pause_signal: Pause engine_logger.info( - f"Engine execution of flow run '{flow_run_id}' is paused: {exc}" + f"Engine execution of flow run '{flow_run_id}' is paused: {pause_signal}" ) exit(0) except Exception: diff --git a/src/prefect/events/schemas/automations.py b/src/prefect/events/schemas/automations.py index 97fe5c7e9a50..2eeea40214c6 100644 --- a/src/prefect/events/schemas/automations.py +++ b/src/prefect/events/schemas/automations.py @@ -52,7 +52,7 @@ def describe_for_cli(self, indent: int = 0) -> str: _deployment_id: Optional[UUID] = PrivateAttr(default=None) - def set_deployment_id(self, deployment_id: UUID): + def set_deployment_id(self, deployment_id: UUID) -> None: self._deployment_id = deployment_id def owner_resource(self) -> Optional[str]: @@ -277,7 +277,7 @@ class MetricTriggerQuery(PrefectBaseModel): ) @field_validator("range", "firing_for") - def enforce_minimum_range(cls, value: timedelta): + def enforce_minimum_range(cls, value: timedelta) -> timedelta: if value < timedelta(seconds=300): raise ValueError("The minimum range is 300 seconds (5 minutes)") return value @@ -404,13 +404,17 @@ class AutomationCore(PrefectBaseModel, extra="ignore"): # type: ignore[call-arg """Defines an action a user wants to take when a certain number of events do or don't happen to the matching resources""" - name: str = Field(..., description="The name of this automation") - description: str = Field("", description="A longer description of this automation") + name: str = Field(default=..., description="The name of this automation") + description: str = Field( + default="", description="A longer description of this automation" + ) - enabled: bool = Field(True, description="Whether this automation will be evaluated") + enabled: bool = Field( + default=True, description="Whether this automation will be evaluated" + ) trigger: TriggerTypes = Field( - ..., + default=..., description=( "The criteria for which events this Automation covers and how it will " "respond to the presence or absence of those events" @@ -418,7 +422,7 @@ class AutomationCore(PrefectBaseModel, extra="ignore"): # type: ignore[call-arg ) actions: List[ActionTypes] = Field( - ..., + default=..., description="The actions to perform when this Automation triggers", ) @@ -438,4 +442,4 @@ class AutomationCore(PrefectBaseModel, extra="ignore"): # type: ignore[call-arg class Automation(AutomationCore): - id: UUID = Field(..., description="The ID of this automation") + id: UUID = Field(default=..., description="The ID of this automation") diff --git a/src/prefect/infrastructure/provisioners/ecs.py b/src/prefect/infrastructure/provisioners/ecs.py index a851055e11d3..2f6cb1460f87 100644 --- a/src/prefect/infrastructure/provisioners/ecs.py +++ b/src/prefect/infrastructure/provisioners/ecs.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import base64 import contextlib import contextvars @@ -9,9 +11,11 @@ from copy import deepcopy from functools import partial from textwrap import dedent -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional +from types import ModuleType +from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional import anyio +import anyio.to_thread from anyio import run_process from rich.console import Console from rich.panel import Panel @@ -33,13 +37,15 @@ if TYPE_CHECKING: from prefect.client.orchestration import PrefectClient -boto3 = lazy_import("boto3") +boto3: ModuleType = lazy_import("boto3") -current_console = contextvars.ContextVar("console", default=Console()) +current_console: contextvars.ContextVar[Console] = contextvars.ContextVar( + "console", default=Console() +) @contextlib.contextmanager -def console_context(value: Console): +def console_context(value: Console) -> Generator[None, None, None]: token = current_console.set(value) try: yield @@ -73,7 +79,7 @@ async def get_task_count(self) -> int: """ return 1 if await self.requires_provisioning() else 0 - def _get_policy_by_name(self, name): + def _get_policy_by_name(self, name: str) -> dict[str, Any] | None: paginator = self._iam_client.get_paginator("list_policies") page_iterator = paginator.paginate(Scope="Local") @@ -119,9 +125,9 @@ async def get_planned_actions(self) -> List[str]: async def provision( self, - policy_document: Dict[str, Any], + policy_document: dict[str, Any], advance: Callable[[], None], - ): + ) -> str: """ Provisions an IAM policy. @@ -153,7 +159,7 @@ async def provision( return policy["Arn"] @property - def next_steps(self): + def next_steps(self) -> list[str]: return [] @@ -215,7 +221,7 @@ async def get_planned_actions(self) -> List[str]: async def provision( self, advance: Callable[[], None], - ): + ) -> None: """ Provisions an IAM user. @@ -231,7 +237,7 @@ async def provision( advance() @property - def next_steps(self): + def next_steps(self) -> list[str]: return [] @@ -241,7 +247,7 @@ def __init__(self, user_name: str, block_document_name: str): self._user_name = user_name self._requires_provisioning = None - async def get_task_count(self): + async def get_task_count(self) -> int: """ Returns the number of tasks that will be executed to provision this resource. @@ -357,7 +363,7 @@ async def provision( } @property - def next_steps(self): + def next_steps(self) -> list[str]: return [] @@ -374,7 +380,7 @@ def __init__( credentials_block_name or f"{work_pool_name}-aws-credentials" ) self._policy_name = policy_name - self._policy_document = { + self._policy_document: dict[str, Any] = { "Version": "2012-10-17", "Statement": [ { @@ -417,7 +423,11 @@ def __init__( self._execution_role_resource = ExecutionRoleResource() @property - def resources(self): + def resources( + self, + ) -> list[ + "ExecutionRoleResource | IamUserResource | IamPolicyResource | CredentialsBlockResource" + ]: return [ self._execution_role_resource, self._iam_user_resource, @@ -425,7 +435,7 @@ def resources(self): self._credentials_block_resource, ] - async def get_task_count(self): + async def get_task_count(self) -> int: """ Returns the number of tasks that will be executed to provision this resource. @@ -461,9 +471,9 @@ async def get_planned_actions(self) -> List[str]: async def provision( self, - base_job_template: Dict[str, Any], + base_job_template: dict[str, Any], advance: Callable[[], None], - ): + ) -> None: """ Provisions the authentication resources. @@ -507,7 +517,7 @@ async def provision( ) @property - def next_steps(self): + def next_steps(self) -> list[str]: return [ next_step for resource in self.resources @@ -521,7 +531,7 @@ def __init__(self, cluster_name: str = "prefect-ecs-cluster"): self._cluster_name = cluster_name self._requires_provisioning = None - async def get_task_count(self): + async def get_task_count(self) -> int: """ Returns the number of tasks that will be executed to provision this resource. @@ -566,9 +576,9 @@ async def get_planned_actions(self) -> List[str]: async def provision( self, - base_job_template: Dict[str, Any], + base_job_template: dict[str, Any], advance: Callable[[], None], - ): + ) -> None: """ Provisions an ECS cluster. @@ -592,7 +602,7 @@ async def provision( ] = self._cluster_name @property - def next_steps(self): + def next_steps(self) -> list[str]: return [] @@ -608,7 +618,7 @@ def __init__( self._requires_provisioning = None self._ecs_security_group_name = ecs_security_group_name - async def get_task_count(self): + async def get_task_count(self) -> int: """ Returns the number of tasks that will be executed to provision this resource. @@ -642,7 +652,9 @@ async def _get_existing_vpc_cidrs(self): response = await anyio.to_thread.run_sync(self._ec2_client.describe_vpcs) return [vpc["CidrBlock"] for vpc in response["Vpcs"]] - async def _find_non_overlapping_cidr(self, default_cidr="172.31.0.0/16"): + async def _find_non_overlapping_cidr( + self, default_cidr: str = "172.31.0.0/16" + ) -> str: """Find a non-overlapping CIDR block""" response = await anyio.to_thread.run_sync(self._ec2_client.describe_vpcs) existing_cidrs = [vpc["CidrBlock"] for vpc in response["Vpcs"]] @@ -708,9 +720,9 @@ async def get_planned_actions(self) -> List[str]: async def provision( self, - base_job_template: Dict[str, Any], + base_job_template: dict[str, Any], advance: Callable[[], None], - ): + ) -> None: """ Provisions a VPC. @@ -768,7 +780,7 @@ async def provision( ) )["AvailabilityZones"] zones = [az["ZoneName"] for az in azs] - subnets = [] + subnets: list[Any] = [] for i, subnet_cidr in enumerate(subnet_cidrs[0:3]): subnets.append( await anyio.to_thread.run_sync( @@ -828,7 +840,7 @@ async def provision( ) @property - def next_steps(self): + def next_steps(self) -> list[str]: return [] @@ -838,9 +850,9 @@ def __init__(self, work_pool_name: str, repository_name: str = "prefect-flows"): self._repository_name = repository_name self._requires_provisioning = None self._work_pool_name = work_pool_name - self._next_steps = [] + self._next_steps: list[str | Panel] = [] - async def get_task_count(self): + async def get_task_count(self) -> int: """ Returns the number of tasks that will be executed to provision this resource. @@ -895,9 +907,9 @@ async def get_planned_actions(self) -> List[str]: async def provision( self, - base_job_template: Dict[str, Any], + base_job_template: dict[str, Any], advance: Callable[[], None], - ): + ) -> None: """ Provisions an ECR repository. @@ -978,7 +990,7 @@ def my_flow(name: str = "world"): ) @property - def next_steps(self): + def next_steps(self) -> list[str | Panel]: return self._next_steps @@ -1000,7 +1012,7 @@ def __init__(self, execution_role_name: str = "PrefectEcsTaskExecutionRole"): ) self._requires_provisioning = None - async def get_task_count(self): + async def get_task_count(self) -> int: """ Returns the number of tasks that will be executed to provision this resource. @@ -1046,9 +1058,9 @@ async def get_planned_actions(self) -> List[str]: async def provision( self, - base_job_template: Dict[str, Any], + base_job_template: dict[str, Any], advance: Callable[[], None], - ): + ) -> str: """ Provisions an IAM role. @@ -1087,7 +1099,7 @@ async def provision( return response["Role"]["Arn"] @property - def next_steps(self): + def next_steps(self) -> list[str]: return [] @@ -1100,11 +1112,11 @@ def __init__(self): self._console = Console() @property - def console(self): + def console(self) -> Console: return self._console @console.setter - def console(self, value): + def console(self, value: Console) -> None: self._console = value async def _prompt_boto3_installation(self): @@ -1115,7 +1127,7 @@ async def _prompt_boto3_installation(self): boto3 = importlib.import_module("boto3") @staticmethod - def is_boto3_installed(): + def is_boto3_installed() -> bool: """ Check if boto3 is installed. """ @@ -1157,8 +1169,8 @@ def _generate_resources( async def provision( self, work_pool_name: str, - base_job_template: dict, - ) -> Dict[str, Any]: + base_job_template: dict[str, Any], + ) -> dict[str, Any]: """ Provisions the infrastructure for an ECS push work pool. @@ -1310,7 +1322,7 @@ async def provision( # provision calls will be no-ops, but update the base job template base_job_template_copy = deepcopy(base_job_template) - next_steps = [] + next_steps: list[str | Panel] = [] with Progress(console=self._console, disable=num_tasks == 0) as progress: task = progress.add_task( "Provisioning Infrastructure", diff --git a/src/prefect/infrastructure/provisioners/modal.py b/src/prefect/infrastructure/provisioners/modal.py index 274775817f86..04960d9c5d89 100644 --- a/src/prefect/infrastructure/provisioners/modal.py +++ b/src/prefect/infrastructure/provisioners/modal.py @@ -2,6 +2,7 @@ import shlex import sys from copy import deepcopy +from types import ModuleType from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple from anyio import run_process @@ -19,7 +20,7 @@ from prefect.client.orchestration import PrefectClient -modal = lazy_import("modal") +modal: ModuleType = lazy_import("modal") class ModalPushProvisioner: @@ -28,14 +29,14 @@ class ModalPushProvisioner: """ def __init__(self, client: Optional["PrefectClient"] = None): - self._console = Console() + self._console: Console = Console() @property - def console(self): + def console(self) -> Console: return self._console @console.setter - def console(self, value): + def console(self, value: Console) -> None: self._console = value @staticmethod diff --git a/src/prefect/logging/configuration.py b/src/prefect/logging/configuration.py index 9b666668e33d..73216645cd99 100644 --- a/src/prefect/logging/configuration.py +++ b/src/prefect/logging/configuration.py @@ -6,7 +6,7 @@ import warnings from functools import partial from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any, Callable, Dict, Optional import yaml @@ -24,10 +24,10 @@ PROCESS_LOGGING_CONFIG: Optional[Dict[str, Any]] = None # Regex call to replace non-alphanumeric characters to '_' to create a valid env var -to_envvar = partial(re.sub, re.compile(r"[^0-9a-zA-Z]+"), "_") +to_envvar: Callable[[str], str] = partial(re.sub, re.compile(r"[^0-9a-zA-Z]+"), "_") -def load_logging_config(path: Path) -> dict: +def load_logging_config(path: Path) -> dict[str, Any]: """ Loads logging configuration from a path allowing override from the environment """ diff --git a/src/prefect/logging/filters.py b/src/prefect/logging/filters.py index 43deb4847a6a..013025063d2a 100644 --- a/src/prefect/logging/filters.py +++ b/src/prefect/logging/filters.py @@ -5,7 +5,7 @@ from prefect.utilities.names import obfuscate -def redact_substr(obj: Any, substr: str): +def redact_substr(obj: Any, substr: str) -> Any: """ Redact a string from a potentially nested object. @@ -17,7 +17,7 @@ def redact_substr(obj: Any, substr: str): Any: The object with the API key redacted. """ - def redact_item(item): + def redact_item(item: Any) -> Any: if isinstance(item, str): return item.replace(substr, obfuscate(substr)) return item diff --git a/src/prefect/logging/formatters.py b/src/prefect/logging/formatters.py index a8eb92c8170d..9fe2c752984b 100644 --- a/src/prefect/logging/formatters.py +++ b/src/prefect/logging/formatters.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import logging.handlers import sys import traceback from types import TracebackType -from typing import Optional, Tuple, Type, Union +from typing import Any, Literal, Optional, Tuple, Type, Union import orjson @@ -14,7 +16,7 @@ ] -def format_exception_info(exc_info: ExceptionInfoType) -> dict: +def format_exception_info(exc_info: ExceptionInfoType) -> dict[str, Any]: # if sys.exc_info() returned a (None, None, None) tuple, # then there's nothing to format if exc_info[0] is None: @@ -40,13 +42,15 @@ class JsonFormatter(logging.Formatter): newlines. """ - def __init__(self, fmt, dmft, style) -> None: # noqa + def __init__( + self, fmt: Literal["pretty", "default"], dmft: str, style: str + ) -> None: # noqa super().__init__() if fmt not in ["pretty", "default"]: raise ValueError("Format must be either 'pretty' or 'default'.") - self.serializer = JSONSerializer( + self.serializer: JSONSerializer = JSONSerializer( jsonlib="orjson", dumps_kwargs={"option": orjson.OPT_INDENT_2} if fmt == "pretty" else {}, ) @@ -72,13 +76,13 @@ def format(self, record: logging.LogRecord) -> str: class PrefectFormatter(logging.Formatter): def __init__( self, - format=None, - datefmt=None, - style="%", - validate=True, + format: str | None = None, + datefmt: str | None = None, + style: str = "%", + validate: bool = True, *, - defaults=None, - task_run_fmt: Optional[str] = None, + defaults: dict[str, Any] | None = None, + task_run_fmt: str | None = None, flow_run_fmt: Optional[str] = None, ) -> None: """ @@ -118,7 +122,7 @@ def __init__( self._flow_run_style.validate() self._task_run_style.validate() - def formatMessage(self, record: logging.LogRecord): + def formatMessage(self, record: logging.LogRecord) -> str: if record.name == "prefect.flow_runs": style = self._flow_run_style elif record.name == "prefect.task_runs": diff --git a/src/prefect/logging/handlers.py b/src/prefect/logging/handlers.py index 0e254a9e58a9..7edd7aa14294 100644 --- a/src/prefect/logging/handlers.py +++ b/src/prefect/logging/handlers.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import logging import sys @@ -6,7 +8,7 @@ import uuid import warnings from contextlib import asynccontextmanager -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Dict, List, TextIO, Type import pendulum from rich.console import Console @@ -34,6 +36,11 @@ PREFECT_LOGGING_TO_API_WHEN_MISSING_FLOW, ) +if sys.version_info >= (3, 12): + StreamHandler = logging.StreamHandler[TextIO] +else: + StreamHandler = logging.StreamHandler + class APILogWorker(BatchedQueueService[Dict[str, Any]]): @property @@ -90,7 +97,7 @@ class APILogHandler(logging.Handler): """ @classmethod - def flush(cls): + def flush(cls) -> None: """ Tell the `APILogWorker` to send any currently enqueued logs and block until completion. @@ -118,7 +125,7 @@ def flush(cls): return APILogWorker.drain_all(timeout=5) @classmethod - async def aflush(cls): + async def aflush(cls) -> bool: """ Tell the `APILogWorker` to send any currently enqueued logs and block until completion. @@ -126,7 +133,7 @@ async def aflush(cls): return await APILogWorker.drain_all() - def emit(self, record: logging.LogRecord): + def emit(self, record: logging.LogRecord) -> None: """ Send a log to the `APILogWorker` """ @@ -239,7 +246,7 @@ def _get_payload_size(self, log: Dict[str, Any]) -> int: class WorkerAPILogHandler(APILogHandler): - def emit(self, record: logging.LogRecord): + def emit(self, record: logging.LogRecord) -> None: # Open-source API servers do not currently support worker logs, and # worker logs only have an associated worker ID when connected to Cloud, # so we won't send worker logs to the API unless they have a worker ID. @@ -278,13 +285,13 @@ def prepare(self, record: logging.LogRecord) -> Dict[str, Any]: return log -class PrefectConsoleHandler(logging.StreamHandler): +class PrefectConsoleHandler(StreamHandler): def __init__( self, - stream=None, - highlighter: Highlighter = PrefectConsoleHighlighter, - styles: Optional[Dict[str, str]] = None, - level: Union[int, str] = logging.NOTSET, + stream: TextIO | None = None, + highlighter: type[Highlighter] = PrefectConsoleHighlighter, + styles: dict[str, str] | None = None, + level: int | str = logging.NOTSET, ): """ The default console handler for Prefect, which highlights log levels, @@ -307,14 +314,14 @@ def __init__( theme = Theme(inherit=False) self.level = level - self.console = Console( + self.console: Console = Console( highlighter=highlighter, theme=theme, file=self.stream, markup=markup_console, ) - def emit(self, record: logging.LogRecord): + def emit(self, record: logging.LogRecord) -> None: try: message = self.format(record) self.console.print(message, soft_wrap=True) diff --git a/src/prefect/logging/highlighters.py b/src/prefect/logging/highlighters.py index b842f7c95240..ac9b84a273d0 100644 --- a/src/prefect/logging/highlighters.py +++ b/src/prefect/logging/highlighters.py @@ -7,7 +7,7 @@ class LevelHighlighter(RegexHighlighter): """Apply style to log levels.""" base_style = "level." - highlights = [ + highlights: list[str] = [ r"(?PDEBUG)", r"(?PINFO)", r"(?PWARNING)", @@ -20,7 +20,7 @@ class UrlHighlighter(RegexHighlighter): """Apply style to urls.""" base_style = "url." - highlights = [ + highlights: list[str] = [ r"(?P(https|http|ws|wss):\/\/[0-9a-zA-Z\$\-\_\+\!`\(\)\,\.\?\/\;\:\&\=\%\#]*)", r"(?P(file):\/\/[0-9a-zA-Z\$\-\_\+\!`\(\)\,\.\?\/\;\:\&\=\%\#]*)", ] @@ -30,7 +30,7 @@ class NameHighlighter(RegexHighlighter): """Apply style to names.""" base_style = "name." - highlights = [ + highlights: list[str] = [ # ?i means case insensitive # ?<= means find string right after the words: flow run r"(?i)(?P(?<=flow run) \'(.*?)\')", @@ -44,7 +44,7 @@ class StateHighlighter(RegexHighlighter): """Apply style to states.""" base_style = "state." - highlights = [ + highlights: list[str] = [ rf"(?P<{state.lower()}_state>{state.title()})" for state in StateType ] + [ r"(?PCached)(?=\(type=COMPLETED\))" # Highlight only "Cached" @@ -55,7 +55,7 @@ class PrefectConsoleHighlighter(RegexHighlighter): """Applies style from multiple highlighters.""" base_style = "log." - highlights = ( + highlights: list[str] = ( LevelHighlighter.highlights + UrlHighlighter.highlights + NameHighlighter.highlights diff --git a/src/prefect/logging/loggers.py b/src/prefect/logging/loggers.py index 45a2f6195f73..7b061d2787c8 100644 --- a/src/prefect/logging/loggers.py +++ b/src/prefect/logging/loggers.py @@ -7,7 +7,7 @@ from contextlib import contextmanager from functools import lru_cache from logging import LogRecord -from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Union +from typing import TYPE_CHECKING, Any, List, Mapping, MutableMapping, Optional, Union from typing_extensions import Self @@ -39,9 +39,9 @@ class PrefectLogAdapter(LoggingAdapter): not a bug in the LoggingAdapter and subclassing is the intended workaround. """ - extra: Mapping[str, object] | None - - def process(self, msg: str, kwargs: dict[str, Any]) -> tuple[str, dict[str, Any]]: # type: ignore[incompatibleMethodOverride] + def process( + self, msg: str, kwargs: MutableMapping[str, Any] + ) -> tuple[str, MutableMapping[str, Any]]: kwargs["extra"] = {**(self.extra or {}), **(kwargs.get("extra") or {})} return (msg, kwargs) @@ -192,7 +192,7 @@ def task_run_logger( flow_run: Optional["FlowRun"] = None, flow: Optional["Flow[Any, Any]"] = None, **kwargs: Any, -): +) -> LoggingAdapter: """ Create a task run logger with the run's metadata attached. @@ -228,7 +228,9 @@ def task_run_logger( ) -def get_worker_logger(worker: "BaseWorker", name: Optional[str] = None): +def get_worker_logger( + worker: "BaseWorker", name: Optional[str] = None +) -> logging.Logger | LoggingAdapter: """ Create a worker logger with the worker's metadata attached. @@ -364,7 +366,9 @@ def __init__(self, eavesdrop_on: str, level: int = logging.NOTSET): # It's important that we use a very minimalistic formatter for use cases where # we may present these logs back to the user. We shouldn't leak filenames, # versions, or other environmental information. - self.formatter = logging.Formatter("[%(levelname)s]: %(message)s") + self.formatter: logging.Formatter | None = logging.Formatter( + "[%(levelname)s]: %(message)s" + ) def __enter__(self) -> Self: self._target_logger = logging.getLogger(self.eavesdrop_on) @@ -374,7 +378,7 @@ def __enter__(self) -> Self: self._lines = [] return self - def __exit__(self, *_): + def __exit__(self, *_: Any) -> None: if self._target_logger: self._target_logger.removeHandler(self) self._target_logger.level = self._original_level diff --git a/src/prefect/task_runs.py b/src/prefect/task_runs.py index c76ea3f6418a..bdbce7e6138b 100644 --- a/src/prefect/task_runs.py +++ b/src/prefect/task_runs.py @@ -17,6 +17,9 @@ from prefect.events.filters import EventFilter, EventNameFilter from prefect.logging.loggers import get_logger +if TYPE_CHECKING: + import logging + class TaskRunWaiter: """ @@ -70,8 +73,8 @@ async def main(): _instance_lock = threading.Lock() def __init__(self): - self.logger = get_logger("TaskRunWaiter") - self._consumer_task: asyncio.Task[None] | None = None + self.logger: "logging.Logger" = get_logger("TaskRunWaiter") + self._consumer_task: "asyncio.Task[None] | None" = None self._observed_completed_task_runs: TTLCache[uuid.UUID, bool] = TTLCache( maxsize=10000, ttl=600 ) @@ -82,7 +85,7 @@ def __init__(self): self._completion_events_lock = threading.Lock() self._started = False - def start(self): + def start(self) -> None: """ Start the TaskRunWaiter service. """ @@ -145,7 +148,7 @@ async def _consume_events(self, consumer_started: asyncio.Event): except Exception as exc: self.logger.error(f"Error processing event: {exc}") - def stop(self): + def stop(self) -> None: """ Stop the TaskRunWaiter service. """ @@ -159,7 +162,7 @@ def stop(self): @classmethod async def wait_for_task_run( cls, task_run_id: uuid.UUID, timeout: Optional[float] = None - ): + ) -> None: """ Wait for a task run to finish. @@ -225,7 +228,7 @@ def add_done_callback( instance._completion_callbacks[task_run_id] = callback @classmethod - def instance(cls): + def instance(cls) -> Self: """ Get the singleton instance of TaskRunWaiter. """ diff --git a/src/prefect/testing/cli.py b/src/prefect/testing/cli.py index 7a660119257b..03c893527336 100644 --- a/src/prefect/testing/cli.py +++ b/src/prefect/testing/cli.py @@ -13,7 +13,7 @@ from prefect.utilities.asyncutils import in_async_main_thread -def check_contains(cli_result: Result, content: str, should_contain: bool): +def check_contains(cli_result: Result, content: str, should_contain: bool) -> None: """ Utility function to see if content is or is not in a CLI result. diff --git a/src/prefect/testing/docker.py b/src/prefect/testing/docker.py index d70a1a2e5542..1e0de02624ad 100644 --- a/src/prefect/testing/docker.py +++ b/src/prefect/testing/docker.py @@ -1,18 +1,18 @@ from contextlib import contextmanager -from typing import Generator, List +from typing import Any, Generator from unittest import mock from prefect.utilities.dockerutils import ImageBuilder @contextmanager -def capture_builders() -> Generator[List[ImageBuilder], None, None]: +def capture_builders() -> Generator[list[ImageBuilder], None, None]: """Captures any instances of ImageBuilder created while this context is active""" - builders = [] + builders: list[ImageBuilder] = [] original_init = ImageBuilder.__init__ - def capture(self, *args, **kwargs): + def capture(self: ImageBuilder, *args: Any, **kwargs: Any): builders.append(self) original_init(self, *args, **kwargs) diff --git a/src/prefect/testing/fixtures.py b/src/prefect/testing/fixtures.py index 07352f872afc..3b59ab64440e 100644 --- a/src/prefect/testing/fixtures.py +++ b/src/prefect/testing/fixtures.py @@ -4,7 +4,7 @@ import socket import sys from contextlib import contextmanager -from typing import AsyncGenerator, Generator, List, Optional, Union +from typing import Any, AsyncGenerator, Callable, Generator, List, Optional, Union from unittest import mock from uuid import UUID @@ -39,7 +39,9 @@ @pytest.fixture(autouse=True) -def add_prefect_loggers_to_caplog(caplog): +def add_prefect_loggers_to_caplog( + caplog: pytest.LogCaptureFixture, +) -> Generator[None, None, None]: import logging logger = logging.getLogger("prefect") @@ -57,7 +59,9 @@ def is_port_in_use(port: int) -> bool: @pytest.fixture(scope="session") -async def hosted_api_server(unused_tcp_port_factory): +async def hosted_api_server( + unused_tcp_port_factory: Callable[[], int], +) -> AsyncGenerator[str, None]: """ Runs an instance of the Prefect API server in a subprocess instead of the using the ephemeral application. @@ -134,7 +138,7 @@ async def hosted_api_server(unused_tcp_port_factory): @pytest.fixture(autouse=True) -def use_hosted_api_server(hosted_api_server): +def use_hosted_api_server(hosted_api_server: str) -> Generator[str, None, None]: """ Sets `PREFECT_API_URL` to the test session's hosted API endpoint. """ @@ -148,7 +152,7 @@ def use_hosted_api_server(hosted_api_server): @pytest.fixture -def disable_hosted_api_server(): +def disable_hosted_api_server() -> Generator[None, None, None]: """ Disables the hosted API server by setting `PREFECT_API_URL` to `None`. """ @@ -157,11 +161,13 @@ def disable_hosted_api_server(): PREFECT_API_URL: None, } ): - yield hosted_api_server + yield @pytest.fixture -def enable_ephemeral_server(disable_hosted_api_server): +def enable_ephemeral_server( + disable_hosted_api_server: None, +) -> Generator[None, None, None]: """ Enables the ephemeral server by setting `PREFECT_SERVER_ALLOW_EPHEMERAL_MODE` to `True`. """ @@ -170,13 +176,15 @@ def enable_ephemeral_server(disable_hosted_api_server): PREFECT_SERVER_ALLOW_EPHEMERAL_MODE: True, } ): - yield hosted_api_server + yield SubprocessASGIServer().stop() @pytest.fixture -def mock_anyio_sleep(monkeypatch): +def mock_anyio_sleep( + monkeypatch: pytest.MonkeyPatch, +) -> Generator[Callable[[float], None], None, None]: """ Mock sleep used to not actually sleep but to set the current time to now + sleep delay seconds while still yielding to other tasks in the event loop. @@ -188,18 +196,18 @@ def mock_anyio_sleep(monkeypatch): original_sleep = anyio.sleep time_shift = 0.0 - async def callback(delay_in_seconds): + async def callback(delay_in_seconds: float) -> None: nonlocal time_shift time_shift += float(delay_in_seconds) # Preserve yield effects of sleep await original_sleep(0) - def latest_now(*args): + def latest_now(*args: Any) -> pendulum.DateTime: # Fast-forwards the time by the total sleep time return original_now(*args).add( # Ensure we retain float precision seconds=int(time_shift), - microseconds=(time_shift - int(time_shift)) * 1000000, + microseconds=int((time_shift - int(time_shift)) * 1000000), ) monkeypatch.setattr("pendulum.now", latest_now) @@ -368,7 +376,7 @@ def events_cloud_api_url(events_server: WebSocketServer, unused_tcp_port: int) - @pytest.fixture -def mock_should_emit_events(monkeypatch) -> mock.Mock: +def mock_should_emit_events(monkeypatch: pytest.MonkeyPatch) -> mock.Mock: m = mock.Mock() m.return_value = True monkeypatch.setattr("prefect.events.utilities.should_emit_events", m) @@ -376,7 +384,9 @@ def mock_should_emit_events(monkeypatch) -> mock.Mock: @pytest.fixture -def asserting_events_worker(monkeypatch) -> Generator[EventsWorker, None, None]: +def asserting_events_worker( + monkeypatch: pytest.MonkeyPatch, +) -> Generator[EventsWorker, None, None]: worker = EventsWorker.instance(AssertingEventsClient) # Always yield the asserting worker when new instances are retrieved monkeypatch.setattr(EventsWorker, "instance", lambda *_: worker) @@ -388,7 +398,7 @@ def asserting_events_worker(monkeypatch) -> Generator[EventsWorker, None, None]: @pytest.fixture def asserting_and_emitting_events_worker( - monkeypatch, + monkeypatch: pytest.MonkeyPatch, ) -> Generator[EventsWorker, None, None]: worker = EventsWorker.instance(AssertingPassthroughEventsClient) # Always yield the asserting worker when new instances are retrieved @@ -400,7 +410,9 @@ def asserting_and_emitting_events_worker( @pytest.fixture -async def events_pipeline(asserting_events_worker: EventsWorker): +async def events_pipeline( + asserting_events_worker: EventsWorker, +) -> AsyncGenerator[EventsPipeline, None]: class AssertingEventsPipeline(EventsPipeline): @sync_compatible async def process_events( @@ -435,7 +447,9 @@ async def wait_for_min_events(): @pytest.fixture -async def emitting_events_pipeline(asserting_and_emitting_events_worker: EventsWorker): +async def emitting_events_pipeline( + asserting_and_emitting_events_worker: EventsWorker, +) -> AsyncGenerator[EventsPipeline, None]: class AssertingAndEmittingEventsPipeline(EventsPipeline): @sync_compatible async def process_events(self): @@ -449,14 +463,16 @@ async def process_events(self): @pytest.fixture -def reset_worker_events(asserting_events_worker: EventsWorker): +def reset_worker_events( + asserting_events_worker: EventsWorker, +) -> Generator[None, None, None]: yield assert isinstance(asserting_events_worker._client, AssertingEventsClient) asserting_events_worker._client.events = [] @pytest.fixture -def enable_lineage_events(): +def enable_lineage_events() -> Generator[None, None, None]: """A fixture that ensures lineage events are enabled.""" with temporary_settings(updates={PREFECT_EXPERIMENTS_LINEAGE_EVENTS_ENABLED: True}): yield diff --git a/src/prefect/testing/standard_test_suites/blocks.py b/src/prefect/testing/standard_test_suites/blocks.py index 0fa04776784e..f71f80c1244e 100644 --- a/src/prefect/testing/standard_test_suites/blocks.py +++ b/src/prefect/testing/standard_test_suites/blocks.py @@ -14,13 +14,13 @@ class BlockStandardTestSuite(ABC): def block(self) -> type[Block]: pass - def test_has_a_description(self, block: type[Block]): + def test_has_a_description(self, block: type[Block]) -> None: assert block.get_description() - def test_has_a_documentation_url(self, block: type[Block]): + def test_has_a_documentation_url(self, block: type[Block]) -> None: assert block._documentation_url - def test_all_fields_have_a_description(self, block: type[Block]): + def test_all_fields_have_a_description(self, block: type[Block]) -> None: for name, field in block.model_fields.items(): if Block.annotation_refers_to_block_class(field.annotation): # TODO: Block field descriptions aren't currently handled by the UI, so @@ -34,7 +34,7 @@ def test_all_fields_have_a_description(self, block: type[Block]): "." ), f"{name} description on {block.__name__} does not end with a period" - def test_has_a_valid_code_example(self, block: type[Block]): + def test_has_a_valid_code_example(self, block: type[Block]) -> None: code_example = block.get_code_example() assert code_example is not None, f"{block.__name__} is missing a code example" @@ -55,7 +55,7 @@ def test_has_a_valid_code_example(self, block: type[Block]): f" matching the pattern {block_load_pattern}" ) - def test_has_a_valid_image(self, block: type[Block]): + def test_has_a_valid_image(self, block: type[Block]) -> None: logo_url = block._logo_url assert ( logo_url is not None diff --git a/src/prefect/testing/utilities.py b/src/prefect/testing/utilities.py index 2b874240ec39..3d76c9dc4fda 100644 --- a/src/prefect/testing/utilities.py +++ b/src/prefect/testing/utilities.py @@ -2,6 +2,8 @@ Internal utilities for tests. """ +from __future__ import annotations + import atexit import shutil import warnings @@ -9,7 +11,7 @@ from pathlib import Path from pprint import pprint from tempfile import mkdtemp -from typing import TYPE_CHECKING, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import prefect.context import prefect.settings @@ -30,6 +32,7 @@ if TYPE_CHECKING: from prefect.client.orchestration import PrefectClient + from prefect.client.schemas.objects import FlowRun from prefect.filesystems import ReadableFileSystem @@ -175,7 +178,7 @@ def cleanup_temp_dir(temp_dir): test_server.stop() -async def get_most_recent_flow_run(client: "PrefectClient" = None): +async def get_most_recent_flow_run(client: "PrefectClient | None" = None) -> "FlowRun": if client is None: client = get_client() @@ -187,8 +190,8 @@ async def get_most_recent_flow_run(client: "PrefectClient" = None): def assert_blocks_equal( - found: Block, expected: Block, exclude_private: bool = True, **kwargs -) -> bool: + found: Block, expected: Block, exclude_private: bool = True, **kwargs: Any +) -> None: assert isinstance( found, type(expected) ), f"Unexpected type {type(found).__name__}, expected {type(expected).__name__}" @@ -205,7 +208,7 @@ def assert_blocks_equal( async def assert_uses_result_serializer( state: State, serializer: Union[str, Serializer], client: "PrefectClient" -): +) -> None: assert isinstance(state.data, (ResultRecord, ResultRecordMetadata)) if isinstance(state.data, ResultRecord): result_serializer = state.data.metadata.serializer @@ -241,7 +244,7 @@ async def assert_uses_result_serializer( @inject_client async def assert_uses_result_storage( state: State, storage: Union[str, "ReadableFileSystem"], client: "PrefectClient" -): +) -> None: assert isinstance(state.data, (ResultRecord, ResultRecordMetadata)) if isinstance(state.data, ResultRecord): assert_blocks_equal( @@ -267,11 +270,11 @@ async def assert_uses_result_storage( ) -def a_test_step(**kwargs): +def a_test_step(**kwargs: Any) -> dict[str, Any]: kwargs.update({"output1": 1, "output2": ["b", 2, 3]}) return kwargs -def b_test_step(**kwargs): +def b_test_step(**kwargs: Any) -> dict[str, Any]: kwargs.update({"output1": 1, "output2": ["b", 2, 3]}) return kwargs diff --git a/src/prefect/workers/base.py b/src/prefect/workers/base.py index ebc67c05d794..563c1cc2eb66 100644 --- a/src/prefect/workers/base.py +++ b/src/prefect/workers/base.py @@ -432,7 +432,7 @@ def __init__( self._prefetch_seconds: float = ( prefetch_seconds or PREFECT_WORKER_PREFETCH_SECONDS.value() ) - self.heartbeat_interval_seconds = ( + self.heartbeat_interval_seconds: int = ( heartbeat_interval_seconds or PREFECT_WORKER_HEARTBEAT_SECONDS.value() ) @@ -640,7 +640,7 @@ async def setup(self): async def teardown(self, *exc_info): """Cleans up resources after the worker is stopped.""" self._logger.debug("Tearing down worker...") - self.is_setup = False + self.is_setup: bool = False for scope in self._scheduled_task_scopes: scope.cancel()