From 24c7f082dbe99fb09e56349473320f1828dae1c3 Mon Sep 17 00:00:00 2001 From: Martijn Pieters Date: Sun, 8 Dec 2024 02:24:58 +0000 Subject: [PATCH] [typing] prefect.client (#16265) --- src/prefect/_internal/schemas/validators.py | 3 +- src/prefect/client/__init__.py | 4 +- src/prefect/client/base.py | 55 +- src/prefect/client/cloud.py | 36 +- src/prefect/client/orchestration.py | 572 ++++++++++---------- src/prefect/client/schemas/__init__.py | 24 + src/prefect/client/schemas/actions.py | 246 +++++---- src/prefect/client/schemas/objects.py | 187 ++++--- src/prefect/client/schemas/responses.py | 36 +- src/prefect/client/schemas/schedules.py | 229 ++++---- src/prefect/client/subscriptions.py | 16 +- src/prefect/client/utilities.py | 72 ++- src/prefect/main.py | 33 +- 13 files changed, 816 insertions(+), 697 deletions(-) diff --git a/src/prefect/_internal/schemas/validators.py b/src/prefect/_internal/schemas/validators.py index 9bda7fc5edff..cff72820e19b 100644 --- a/src/prefect/_internal/schemas/validators.py +++ b/src/prefect/_internal/schemas/validators.py @@ -13,6 +13,7 @@ from copy import copy from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union +from uuid import UUID import jsonschema import pendulum @@ -653,7 +654,7 @@ def validate_message_template_variables(v: Optional[str]) -> Optional[str]: return v -def validate_default_queue_id_not_none(v: Optional[str]) -> Optional[str]: +def validate_default_queue_id_not_none(v: Optional[UUID]) -> UUID: if v is None: raise ValueError( "`default_queue_id` is a required field. If you are " diff --git a/src/prefect/client/__init__.py b/src/prefect/client/__init__.py index 5d2fc25a2a9f..df0bfd34dcab 100644 --- a/src/prefect/client/__init__.py +++ b/src/prefect/client/__init__.py @@ -16,6 +16,8 @@ """ +from collections.abc import Callable +from typing import Any from prefect._internal.compatibility.migration import getattr_migration -__getattr__ = getattr_migration(__name__) +__getattr__: Callable[[str], Any] = getattr_migration(__name__) diff --git a/src/prefect/client/base.py b/src/prefect/client/base.py index 5071387668aa..7eed8a92a497 100644 --- a/src/prefect/client/base.py +++ b/src/prefect/client/base.py @@ -4,22 +4,11 @@ import time import uuid from collections import defaultdict +from collections.abc import AsyncGenerator, Awaitable, MutableMapping from contextlib import asynccontextmanager from datetime import datetime, timezone -from typing import ( - Any, - AsyncGenerator, - Awaitable, - Callable, - Dict, - MutableMapping, - Optional, - Protocol, - Set, - Tuple, - Type, - runtime_checkable, -) +from logging import Logger +from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, runtime_checkable import anyio import httpx @@ -46,14 +35,14 @@ # Datastores for lifespan management, keys should be a tuple of thread and app # identities. -APP_LIFESPANS: Dict[Tuple[int, int], LifespanManager] = {} -APP_LIFESPANS_REF_COUNTS: Dict[Tuple[int, int], int] = {} +APP_LIFESPANS: dict[tuple[int, int], LifespanManager] = {} +APP_LIFESPANS_REF_COUNTS: dict[tuple[int, int], int] = {} # Blocks concurrent access to the above dicts per thread. The index should be the thread # identity. 
-APP_LIFESPANS_LOCKS: Dict[int, anyio.Lock] = defaultdict(anyio.Lock) +APP_LIFESPANS_LOCKS: dict[int, anyio.Lock] = defaultdict(anyio.Lock) -logger = get_logger("client") +logger: Logger = get_logger("client") # Define ASGI application types for type checking @@ -174,9 +163,9 @@ def raise_for_status(self) -> Response: raise PrefectHTTPStatusError.from_httpx_error(exc) from exc.__cause__ @classmethod - def from_httpx_response(cls: Type[Self], response: httpx.Response) -> Response: + def from_httpx_response(cls: type[Self], response: httpx.Response) -> Response: """ - Create a `PrefectReponse` from an `httpx.Response`. + Create a `PrefectResponse` from an `httpx.Response`. By changing the `__class__` attribute of the Response, we change the method resolution order to look for methods defined in PrefectResponse, while leaving @@ -222,10 +211,10 @@ async def _send_with_retry( self, request: Request, send: Callable[[Request], Awaitable[Response]], - send_args: Tuple[Any, ...], - send_kwargs: Dict[str, Any], - retry_codes: Set[int] = set(), - retry_exceptions: Tuple[Type[Exception], ...] = tuple(), + send_args: tuple[Any, ...], + send_kwargs: dict[str, Any], + retry_codes: set[int] = set(), + retry_exceptions: tuple[type[Exception], ...] = tuple(), ): """ Send a request and retry it if it fails. @@ -240,6 +229,11 @@ async def _send_with_retry( try_count = 0 response = None + if TYPE_CHECKING: + # older httpx versions type method as str | bytes | Unknown + # but in reality it is always a string. + assert isinstance(request.method, str) # type: ignore + is_change_request = request.method.lower() in {"post", "put", "patch", "delete"} if self.enable_csrf_support and is_change_request: @@ -436,10 +430,10 @@ def _send_with_retry( self, request: Request, send: Callable[[Request], Response], - send_args: Tuple[Any, ...], - send_kwargs: Dict[str, Any], - retry_codes: Set[int] = set(), - retry_exceptions: Tuple[Type[Exception], ...] = tuple(), + send_args: tuple[Any, ...], + send_kwargs: dict[str, Any], + retry_codes: set[int] = set(), + retry_exceptions: tuple[type[Exception], ...] = tuple(), ): """ Send a request and retry it if it fails. @@ -454,6 +448,11 @@ def _send_with_retry( try_count = 0 response = None + if TYPE_CHECKING: + # older httpx versions type method as str | bytes | Unknown + # but in reality it is always a string. 
+ assert isinstance(request.method, str) # type: ignore + is_change_request = request.method.lower() in {"post", "put", "patch", "delete"} if self.enable_csrf_support and is_change_request: diff --git a/src/prefect/client/cloud.py b/src/prefect/client/cloud.py index 6542393ed4b7..90cae81b0f51 100644 --- a/src/prefect/client/cloud.py +++ b/src/prefect/client/cloud.py @@ -1,11 +1,12 @@ import re -from typing import Any, Dict, List, Optional, cast +from typing import Any, NoReturn, Optional, cast from uuid import UUID import anyio import httpx import pydantic from starlette import status +from typing_extensions import Self import prefect.context import prefect.settings @@ -30,7 +31,7 @@ def get_cloud_client( host: Optional[str] = None, api_key: Optional[str] = None, - httpx_settings: Optional[Dict[str, Any]] = None, + httpx_settings: Optional[dict[str, Any]] = None, infer_cloud_url: bool = False, ) -> "CloudClient": """ @@ -62,11 +63,14 @@ class CloudUnauthorizedError(PrefectException): class CloudClient: + account_id: Optional[str] = None + workspace_id: Optional[str] = None + def __init__( self, host: str, api_key: str, - httpx_settings: Optional[Dict[str, Any]] = None, + httpx_settings: Optional[dict[str, Any]] = None, ) -> None: httpx_settings = httpx_settings or dict() httpx_settings.setdefault("headers", dict()) @@ -79,7 +83,7 @@ def __init__( **httpx_settings, enable_csrf_support=False ) - api_url = prefect.settings.PREFECT_API_URL.value() or "" + api_url: str = prefect.settings.PREFECT_API_URL.value() or "" if match := ( re.search(PARSE_API_URL_REGEX, host) or re.search(PARSE_API_URL_REGEX, api_url) @@ -100,7 +104,7 @@ def workspace_base_url(self) -> str: return f"{self.account_base_url}/workspaces/{self.workspace_id}" - async def api_healthcheck(self): + async def api_healthcheck(self) -> None: """ Attempts to connect to the Cloud API and raises the encountered exception if not successful. 
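The `_send_with_retry` hunks above work around older httpx stubs by narrowing `request.method` with a `TYPE_CHECKING`-guarded assert, which the type checker sees but the runtime never executes. A minimal sketch of the idiom; `lowered_method` is a hypothetical helper, not part of the patch:

```python
from typing import TYPE_CHECKING

import httpx


def lowered_method(request: httpx.Request) -> str:
    # Older httpx stubs type `request.method` loosely (str | bytes); the
    # runtime value is always str. The TYPE_CHECKING-guarded assert narrows
    # the type for the checker at zero runtime cost.
    if TYPE_CHECKING:
        assert isinstance(request.method, str)
    return request.method.lower()


print(lowered_method(httpx.Request("POST", "https://example.invalid/")))  # post
```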
@@ -110,8 +114,8 @@ async def api_healthcheck(self): with anyio.fail_after(10): await self.read_workspaces() - async def read_workspaces(self) -> List[Workspace]: - workspaces = pydantic.TypeAdapter(List[Workspace]).validate_python( + async def read_workspaces(self) -> list[Workspace]: + workspaces = pydantic.TypeAdapter(list[Workspace]).validate_python( await self.get("/me/workspaces") ) return workspaces @@ -124,17 +128,17 @@ async def read_current_workspace(self) -> Workspace: return workspace raise ValueError("Current workspace not found") - async def read_worker_metadata(self) -> Dict[str, Any]: + async def read_worker_metadata(self) -> dict[str, Any]: response = await self.get( f"{self.workspace_base_url}/collections/work_pool_types" ) - return cast(Dict[str, Any], response) + return cast(dict[str, Any], response) - async def read_account_settings(self) -> Dict[str, Any]: + async def read_account_settings(self) -> dict[str, Any]: response = await self.get(f"{self.account_base_url}/settings") - return cast(Dict[str, Any], response) + return cast(dict[str, Any], response) - async def update_account_settings(self, settings: Dict[str, Any]): + async def update_account_settings(self, settings: dict[str, Any]) -> None: await self.request( "PATCH", f"{self.account_base_url}/settings", @@ -145,7 +149,7 @@ async def read_account_ip_allowlist(self) -> IPAllowlist: response = await self.get(f"{self.account_base_url}/ip_allowlist") return IPAllowlist.model_validate(response) - async def update_account_ip_allowlist(self, updated_allowlist: IPAllowlist): + async def update_account_ip_allowlist(self, updated_allowlist: IPAllowlist) -> None: await self.request( "PUT", f"{self.account_base_url}/ip_allowlist", @@ -175,20 +179,20 @@ async def update_flow_run_labels( json=labels, ) - async def __aenter__(self): + async def __aenter__(self) -> Self: await self._client.__aenter__() return self async def __aexit__(self, *exc_info: Any) -> None: return await self._client.__aexit__(*exc_info) - def __enter__(self): + def __enter__(self) -> NoReturn: raise RuntimeError( "The `CloudClient` must be entered with an async context. 
Use 'async " "with CloudClient(...)' not 'with CloudClient(...)'" ) - def __exit__(self, *_): + def __exit__(self, *_: object) -> NoReturn: assert False, "This should never be called but must be defined for __enter__" async def get(self, route: str, **kwargs: Any) -> Any: diff --git a/src/prefect/client/orchestration.py b/src/prefect/client/orchestration.py index ab83a4dcbeb9..f244ccde5708 100644 --- a/src/prefect/client/orchestration.py +++ b/src/prefect/client/orchestration.py @@ -2,21 +2,10 @@ import datetime import ssl import warnings +from collections.abc import Iterable from contextlib import AsyncExitStack -from typing import ( - TYPE_CHECKING, - Any, - Dict, - Iterable, - List, - Literal, - Optional, - Set, - Tuple, - TypeVar, - Union, - overload, -) +from logging import Logger +from typing import TYPE_CHECKING, Any, Literal, NoReturn, Optional, Union, overload from uuid import UUID, uuid4 import certifi @@ -27,7 +16,7 @@ from asgi_lifespan import LifespanManager from packaging import version from starlette import status -from typing_extensions import ParamSpec +from typing_extensions import ParamSpec, Self, TypeVar import prefect import prefect.exceptions @@ -152,26 +141,29 @@ ) P = ParamSpec("P") -R = TypeVar("R") +R = TypeVar("R", infer_variance=True) +T = TypeVar("T") @overload def get_client( - httpx_settings: Optional[Dict[str, Any]] = None, sync_client: Literal[False] = False + *, + httpx_settings: Optional[dict[str, Any]] = ..., + sync_client: Literal[False] = False, ) -> "PrefectClient": ... @overload def get_client( - httpx_settings: Optional[Dict[str, Any]] = None, sync_client: Literal[True] = True + *, httpx_settings: Optional[dict[str, Any]] = ..., sync_client: Literal[True] = ... ) -> "SyncPrefectClient": ... def get_client( - httpx_settings: Optional[Dict[str, Any]] = None, sync_client: bool = False -): + httpx_settings: Optional[dict[str, Any]] = None, sync_client: bool = False +) -> Union["SyncPrefectClient", "PrefectClient"]: """ Retrieve a HTTP client for communicating with the Prefect REST API. 
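The paired `@overload` declarations above let the `sync_client` flag select the return type at each call site. A self-contained sketch of the pattern, with `SyncClient` and `AsyncClient` as illustrative stand-ins for the real client classes:

```python
from typing import Literal, Union, overload


class AsyncClient:  # illustrative stand-in for PrefectClient
    pass


class SyncClient:  # illustrative stand-in for SyncPrefectClient
    pass


@overload
def get_client(*, sync_client: Literal[False] = False) -> AsyncClient: ...


@overload
def get_client(*, sync_client: Literal[True]) -> SyncClient: ...


def get_client(*, sync_client: bool = False) -> Union[SyncClient, AsyncClient]:
    # The implementation takes a plain bool; the overloads let the checker
    # resolve each call site to the concrete client type.
    return SyncClient() if sync_client else AsyncClient()


sync = get_client(sync_client=True)  # inferred as SyncClient
async_ = get_client()                # inferred as AsyncClient
```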
@@ -200,18 +192,21 @@ def get_client( if sync_client: if client_ctx := prefect.context.SyncClientContext.get(): - if client_ctx.client and client_ctx._httpx_settings == httpx_settings: + if ( + client_ctx.client + and getattr(client_ctx, "_httpx_settings", None) == httpx_settings + ): return client_ctx.client else: if client_ctx := prefect.context.AsyncClientContext.get(): if ( client_ctx.client - and client_ctx._httpx_settings == httpx_settings - and loop in (client_ctx.client._loop, None) + and getattr(client_ctx, "_httpx_settings", None) == httpx_settings + and loop in (getattr(client_ctx.client, "_loop", None), None) ): return client_ctx.client - api = PREFECT_API_URL.value() + api: str = PREFECT_API_URL.value() server_type = None if not api and PREFECT_SERVER_ALLOW_EPHEMERAL_MODE: @@ -277,7 +272,7 @@ def __init__( *, api_key: Optional[str] = None, api_version: Optional[str] = None, - httpx_settings: Optional[Dict[str, Any]] = None, + httpx_settings: Optional[dict[str, Any]] = None, server_type: Optional[ServerType] = None, ) -> None: httpx_settings = httpx_settings.copy() if httpx_settings else {} @@ -357,7 +352,7 @@ def __init__( ) # Connect to an in-process application - elif isinstance(api, ASGIApp): + else: self._ephemeral_app = api self.server_type = ServerType.EPHEMERAL @@ -377,12 +372,6 @@ def __init__( ) httpx_settings.setdefault("base_url", "http://ephemeral-prefect/api") - else: - raise TypeError( - f"Unexpected type {type(api).__name__!r} for argument `api`. Expected" - " 'str' or 'ASGIApp/FastAPI'" - ) - # See https://www.python-httpx.org/advanced/#timeout-configuration httpx_settings.setdefault( "timeout", @@ -426,9 +415,9 @@ def __init__( if isinstance(server_transport, httpx.AsyncHTTPTransport): pool = getattr(server_transport, "_pool", None) if isinstance(pool, httpcore.AsyncConnectionPool): - pool._retries = 3 + setattr(pool, "_retries", 3) - self.logger = get_logger("client") + self.logger: Logger = get_logger("client") @property def api_url(self) -> httpx.URL: @@ -458,7 +447,7 @@ async def hello(self) -> httpx.Response: """ return await self._client.get("/hello") - async def create_flow(self, flow: "FlowObject") -> UUID: + async def create_flow(self, flow: "FlowObject[Any, Any]") -> UUID: """ Create a flow in the Prefect API. @@ -514,16 +503,16 @@ async def read_flow(self, flow_id: UUID) -> Flow: async def read_flows( self, *, - flow_filter: FlowFilter = None, - flow_run_filter: FlowRunFilter = None, - task_run_filter: TaskRunFilter = None, - deployment_filter: DeploymentFilter = None, - work_pool_filter: WorkPoolFilter = None, - work_queue_filter: WorkQueueFilter = None, - sort: FlowSort = None, + flow_filter: Optional[FlowFilter] = None, + flow_run_filter: Optional[FlowRunFilter] = None, + task_run_filter: Optional[TaskRunFilter] = None, + deployment_filter: Optional[DeploymentFilter] = None, + work_pool_filter: Optional[WorkPoolFilter] = None, + work_queue_filter: Optional[WorkQueueFilter] = None, + sort: Optional[FlowSort] = None, limit: Optional[int] = None, offset: int = 0, - ) -> List[Flow]: + ) -> list[Flow]: """ Query the Prefect API for flows. Only flows matching all criteria will be returned. 
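Signatures such as `read_flows` above drop implicit-`Optional` defaults (`flow_filter: FlowFilter = None`) in favor of explicit `Optional[...]`, alongside the move to builtin generics like `dict[str, Any]`. A sketch under those conventions; `FlowFilter` here is a simplified stand-in for the pydantic filter model:

```python
from typing import Any, Optional


class FlowFilter:  # simplified stand-in for the pydantic filter model
    def model_dump(self, mode: str = "python") -> dict[str, Any]:
        return {}


def read_flows_body(
    flow_filter: Optional[FlowFilter] = None, offset: int = 0
) -> dict[str, Any]:
    # `flow_filter: FlowFilter = None` would rely on implicit Optional,
    # which current type checkers reject; the Optional must be spelled out.
    return {
        "flows": flow_filter.model_dump(mode="json") if flow_filter else None,
        "offset": offset,
    }


print(read_flows_body())  # {'flows': None, 'offset': 0}
```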
@@ -542,7 +531,7 @@ async def read_flows( Returns: a list of Flow model representations of the flows """ - body = { + body: dict[str, Any] = { "flows": flow_filter.model_dump(mode="json") if flow_filter else None, "flow_runs": ( flow_run_filter.model_dump(mode="json", exclude_unset=True) @@ -567,7 +556,7 @@ async def read_flows( } response = await self._client.post("/flows/filter", json=body) - return pydantic.TypeAdapter(List[Flow]).validate_python(response.json()) + return pydantic.TypeAdapter(list[Flow]).validate_python(response.json()) async def read_flow_by_name( self, @@ -589,15 +578,15 @@ async def create_flow_run_from_deployment( self, deployment_id: UUID, *, - parameters: Optional[Dict[str, Any]] = None, - context: Optional[Dict[str, Any]] = None, - state: Optional[prefect.states.State] = None, + parameters: Optional[dict[str, Any]] = None, + context: Optional[dict[str, Any]] = None, + state: Optional[prefect.states.State[Any]] = None, name: Optional[str] = None, tags: Optional[Iterable[str]] = None, idempotency_key: Optional[str] = None, parent_task_run_id: Optional[UUID] = None, work_queue_name: Optional[str] = None, - job_variables: Optional[Dict[str, Any]] = None, + job_variables: Optional[dict[str, Any]] = None, ) -> FlowRun: """ Create a flow run for a deployment. @@ -638,7 +627,7 @@ async def create_flow_run_from_deployment( parameters=parameters, context=context, state=state.to_state_create(), - tags=tags, + tags=list(tags), name=name, idempotency_key=idempotency_key, parent_task_run_id=parent_task_run_id, @@ -657,13 +646,13 @@ async def create_flow_run_from_deployment( async def create_flow_run( self, - flow: "FlowObject", + flow: "FlowObject[Any, R]", name: Optional[str] = None, - parameters: Optional[Dict[str, Any]] = None, - context: Optional[Dict[str, Any]] = None, + parameters: Optional[dict[str, Any]] = None, + context: Optional[dict[str, Any]] = None, tags: Optional[Iterable[str]] = None, parent_task_run_id: Optional[UUID] = None, - state: Optional["prefect.states.State"] = None, + state: Optional["prefect.states.State[R]"] = None, ) -> FlowRun: """ Create a flow run for a flow. @@ -705,7 +694,7 @@ async def create_flow_run( state=state.to_state_create(), empirical_policy=FlowRunPolicy( retries=flow.retries, - retry_delay=flow.retry_delay_seconds, + retry_delay=int(flow.retry_delay_seconds or 0), ), ) @@ -723,12 +712,12 @@ async def update_flow_run( self, flow_run_id: UUID, flow_version: Optional[str] = None, - parameters: Optional[dict] = None, + parameters: Optional[dict[str, Any]] = None, name: Optional[str] = None, tags: Optional[Iterable[str]] = None, empirical_policy: Optional[FlowRunPolicy] = None, infrastructure_pid: Optional[str] = None, - job_variables: Optional[dict] = None, + job_variables: Optional[dict[str, Any]] = None, ) -> httpx.Response: """ Update a flow run's details. @@ -749,7 +738,7 @@ async def update_flow_run( Returns: an `httpx.Response` object from the PATCH request """ - params = {} + params: dict[str, Any] = {} if flow_version is not None: params["flow_version"] = flow_version if parameters is not None: @@ -832,7 +821,7 @@ async def create_concurrency_limit( async def read_concurrency_limit_by_tag( self, tag: str, - ): + ) -> ConcurrencyLimit: """ Read the concurrency limit set on a specific tag. @@ -868,7 +857,7 @@ async def read_concurrency_limits( self, limit: int, offset: int, - ): + ) -> list[ConcurrencyLimit]: """ Lists concurrency limits set on task run tags. 
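Response parsing throughout the patch swaps `List[Model]` for the builtin `list[Model]` inside `pydantic.TypeAdapter`. A runnable sketch assuming pydantic v2, with `Flow` reduced to two fields for illustration:

```python
from uuid import UUID, uuid4

import pydantic


class Flow(pydantic.BaseModel):  # simplified stand-in
    id: UUID
    name: str


payload = [{"id": str(uuid4()), "name": "etl"}]

# One adapter call validates the whole decoded JSON payload into a typed
# list; list[Flow] replaces the typing.List[Flow] spelling.
flows: list[Flow] = pydantic.TypeAdapter(list[Flow]).validate_python(payload)
print(flows[0].name)  # etl
```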
@@ -886,15 +875,15 @@ async def read_concurrency_limits( } response = await self._client.post("/concurrency_limits/filter", json=body) - return pydantic.TypeAdapter(List[ConcurrencyLimit]).validate_python( + return pydantic.TypeAdapter(list[ConcurrencyLimit]).validate_python( response.json() ) async def reset_concurrency_limit_by_tag( self, tag: str, - slot_override: Optional[List[Union[UUID, str]]] = None, - ): + slot_override: Optional[list[Union[UUID, str]]] = None, + ) -> None: """ Resets the concurrency limit slots set on a specific tag. @@ -927,7 +916,7 @@ async def reset_concurrency_limit_by_tag( async def delete_concurrency_limit_by_tag( self, tag: str, - ): + ) -> None: """ Delete the concurrency limit set on a specific tag. @@ -951,7 +940,7 @@ async def delete_concurrency_limit_by_tag( async def increment_v1_concurrency_slots( self, - names: List[str], + names: list[str], task_run_id: UUID, ) -> httpx.Response: """ @@ -961,7 +950,7 @@ async def increment_v1_concurrency_slots( names (List[str]): A list of limit names for which to increment limits. task_run_id (UUID): The task run ID incrementing the limits. """ - data = { + data: dict[str, Any] = { "names": names, "task_run_id": str(task_run_id), } @@ -973,7 +962,7 @@ async def increment_v1_concurrency_slots( async def decrement_v1_concurrency_slots( self, - names: List[str], + names: list[str], task_run_id: UUID, occupancy_seconds: float, ) -> httpx.Response: @@ -989,7 +978,7 @@ async def decrement_v1_concurrency_slots( Returns: httpx.Response: The HTTP response from the server. """ - data = { + data: dict[str, Any] = { "names": names, "task_run_id": str(task_run_id), "occupancy_seconds": occupancy_seconds, @@ -1089,7 +1078,7 @@ async def read_work_queue_by_name( return WorkQueue.model_validate(response.json()) - async def update_work_queue(self, id: UUID, **kwargs): + async def update_work_queue(self, id: UUID, **kwargs: Any) -> None: """ Update properties of a work queue. @@ -1119,8 +1108,8 @@ async def get_runs_in_work_queue( self, id: UUID, limit: int = 10, - scheduled_before: datetime.datetime = None, - ) -> List[FlowRun]: + scheduled_before: Optional[datetime.datetime] = None, + ) -> list[FlowRun]: """ Read flow runs off a work queue. @@ -1153,7 +1142,7 @@ async def get_runs_in_work_queue( raise prefect.exceptions.ObjectNotFound(http_exc=e) from e else: raise - return pydantic.TypeAdapter(List[FlowRun]).validate_python(response.json()) + return pydantic.TypeAdapter(list[FlowRun]).validate_python(response.json()) async def read_work_queue( self, @@ -1209,9 +1198,9 @@ async def read_work_queue_status( async def match_work_queues( self, - prefixes: List[str], + prefixes: list[str], work_pool_name: Optional[str] = None, - ) -> List[WorkQueue]: + ) -> list[WorkQueue]: """ Query the Prefect API for work queues with names with a specific prefix. @@ -1225,7 +1214,7 @@ async def match_work_queues( """ page_length = 100 current_page = 0 - work_queues = [] + work_queues: list[WorkQueue] = [] while True: new_queues = await self.read_work_queues( @@ -1246,7 +1235,7 @@ async def match_work_queues( async def delete_work_queue_by_id( self, id: UUID, - ): + ) -> None: """ Delete a work queue by its ID. @@ -1343,7 +1332,7 @@ async def update_block_document( self, block_document_id: UUID, block_document: BlockDocumentUpdate, - ): + ) -> None: """ Update a block document in the Prefect API. 
""" @@ -1362,7 +1351,7 @@ async def update_block_document( else: raise - async def delete_block_document(self, block_document_id: UUID): + async def delete_block_document(self, block_document_id: UUID) -> None: """ Delete a block document. """ @@ -1405,7 +1394,9 @@ async def read_block_schema_by_checksum( raise return BlockSchema.model_validate(response.json()) - async def update_block_type(self, block_type_id: UUID, block_type: BlockTypeUpdate): + async def update_block_type( + self, block_type_id: UUID, block_type: BlockTypeUpdate + ) -> None: """ Update a block document in the Prefect API. """ @@ -1424,7 +1415,7 @@ async def update_block_type(self, block_type_id: UUID, block_type: BlockTypeUpda else: raise - async def delete_block_type(self, block_type_id: UUID): + async def delete_block_type(self, block_type_id: UUID) -> None: """ Delete a block type. """ @@ -1444,7 +1435,7 @@ async def delete_block_type(self, block_type_id: UUID): else: raise - async def read_block_types(self) -> List[BlockType]: + async def read_block_types(self) -> list[BlockType]: """ Read all block types Raises: @@ -1454,9 +1445,9 @@ async def read_block_types(self) -> List[BlockType]: List of BlockTypes. """ response = await self._client.post("/block_types/filter", json={}) - return pydantic.TypeAdapter(List[BlockType]).validate_python(response.json()) + return pydantic.TypeAdapter(list[BlockType]).validate_python(response.json()) - async def read_block_schemas(self) -> List[BlockSchema]: + async def read_block_schemas(self) -> list[BlockSchema]: """ Read all block schemas Raises: @@ -1466,7 +1457,7 @@ async def read_block_schemas(self) -> List[BlockSchema]: A BlockSchema. """ response = await self._client.post("/block_schemas/filter", json={}) - return pydantic.TypeAdapter(List[BlockSchema]).validate_python(response.json()) + return pydantic.TypeAdapter(list[BlockSchema]).validate_python(response.json()) async def get_most_recent_block_schema_for_block_type( self, @@ -1502,7 +1493,7 @@ async def read_block_document( self, block_document_id: UUID, include_secrets: bool = True, - ): + ) -> BlockDocument: """ Read the block document with the specified ID. @@ -1580,7 +1571,7 @@ async def read_block_documents( offset: Optional[int] = None, limit: Optional[int] = None, include_secrets: bool = True, - ): + ) -> list[BlockDocument]: """ Read block documents @@ -1607,7 +1598,7 @@ async def read_block_documents( include_secrets=include_secrets, ), ) - return pydantic.TypeAdapter(List[BlockDocument]).validate_python( + return pydantic.TypeAdapter(list[BlockDocument]).validate_python( response.json() ) @@ -1617,7 +1608,7 @@ async def read_block_documents_by_type( offset: Optional[int] = None, limit: Optional[int] = None, include_secrets: bool = True, - ) -> List[BlockDocument]: + ) -> list[BlockDocument]: """Retrieve block documents by block type slug. 
Args: @@ -1638,7 +1629,7 @@ async def read_block_documents_by_type( ), ) - return pydantic.TypeAdapter(List[BlockDocument]).validate_python( + return pydantic.TypeAdapter(list[BlockDocument]).validate_python( response.json() ) @@ -1647,23 +1638,23 @@ async def create_deployment( flow_id: UUID, name: str, version: Optional[str] = None, - schedules: Optional[List[DeploymentScheduleCreate]] = None, + schedules: Optional[list[DeploymentScheduleCreate]] = None, concurrency_limit: Optional[int] = None, concurrency_options: Optional[ConcurrencyOptions] = None, - parameters: Optional[Dict[str, Any]] = None, + parameters: Optional[dict[str, Any]] = None, description: Optional[str] = None, work_queue_name: Optional[str] = None, work_pool_name: Optional[str] = None, - tags: Optional[List[str]] = None, + tags: Optional[list[str]] = None, storage_document_id: Optional[UUID] = None, path: Optional[str] = None, entrypoint: Optional[str] = None, infrastructure_document_id: Optional[UUID] = None, - parameter_openapi_schema: Optional[Dict[str, Any]] = None, + parameter_openapi_schema: Optional[dict[str, Any]] = None, paused: Optional[bool] = None, - pull_steps: Optional[List[dict]] = None, + pull_steps: Optional[list[dict[str, Any]]] = None, enforce_parameter_schema: Optional[bool] = None, - job_variables: Optional[Dict[str, Any]] = None, + job_variables: Optional[dict[str, Any]] = None, ) -> UUID: """ Create a deployment. @@ -1743,7 +1734,9 @@ async def create_deployment( return UUID(deployment_id) - async def set_deployment_paused_state(self, deployment_id: UUID, paused: bool): + async def set_deployment_paused_state( + self, deployment_id: UUID, paused: bool + ) -> None: await self._client.patch( f"/deployments/{deployment_id}", json={"paused": paused} ) @@ -1752,7 +1745,7 @@ async def update_deployment( self, deployment_id: UUID, deployment: DeploymentUpdate, - ): + ) -> None: await self._client.patch( f"/deployments/{deployment_id}", json=deployment.model_dump(mode="json", exclude_unset=True), @@ -1775,7 +1768,7 @@ async def _create_deployment_from_schema(self, schema: DeploymentCreate) -> UUID async def read_deployment( self, - deployment_id: UUID, + deployment_id: Union[UUID, str], ) -> DeploymentResponse: """ Query the Prefect API for a deployment by id. @@ -1868,7 +1861,7 @@ async def read_deployments( limit: Optional[int] = None, sort: Optional[DeploymentSort] = None, offset: int = 0, - ) -> List[DeploymentResponse]: + ) -> list[DeploymentResponse]: """ Query the Prefect API for deployments. Only deployments matching all the provided criteria will be returned. @@ -1887,7 +1880,7 @@ async def read_deployments( a list of Deployment model representations of the deployments """ - body = { + body: dict[str, Any] = { "flows": flow_filter.model_dump(mode="json") if flow_filter else None, "flow_runs": ( flow_run_filter.model_dump(mode="json", exclude_unset=True) @@ -1912,14 +1905,14 @@ async def read_deployments( } response = await self._client.post("/deployments/filter", json=body) - return pydantic.TypeAdapter(List[DeploymentResponse]).validate_python( + return pydantic.TypeAdapter(list[DeploymentResponse]).validate_python( response.json() ) async def delete_deployment( self, deployment_id: UUID, - ): + ) -> None: """ Delete deployment by id. 
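`read_deployment` above now accepts `Union[UUID, str]`, which is safe because the ID is only ever formatted into the request path. A sketch of the reasoning; `deployment_path` is a hypothetical helper:

```python
from typing import Union
from uuid import UUID, uuid4


def deployment_path(deployment_id: Union[UUID, str]) -> str:
    # The ID is only interpolated into the URL, so widening the parameter
    # to Union[UUID, str] spares callers a str()/UUID() round trip without
    # changing runtime behaviour.
    return f"/deployments/{deployment_id}"


some_id = uuid4()
assert deployment_path(some_id) == deployment_path(str(some_id))
```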
@@ -1940,8 +1933,8 @@ async def delete_deployment( async def create_deployment_schedules( self, deployment_id: UUID, - schedules: List[Tuple[SCHEDULE_TYPES, bool]], - ) -> List[DeploymentSchedule]: + schedules: list[tuple[SCHEDULE_TYPES, bool]], + ) -> list[DeploymentSchedule]: """ Create deployment schedules. @@ -1968,14 +1961,14 @@ async def create_deployment_schedules( response = await self._client.post( f"/deployments/{deployment_id}/schedules", json=json ) - return pydantic.TypeAdapter(List[DeploymentSchedule]).validate_python( + return pydantic.TypeAdapter(list[DeploymentSchedule]).validate_python( response.json() ) async def read_deployment_schedules( self, deployment_id: UUID, - ) -> List[DeploymentSchedule]: + ) -> list[DeploymentSchedule]: """ Query the Prefect API for a deployment's schedules. @@ -1992,7 +1985,7 @@ async def read_deployment_schedules( raise prefect.exceptions.ObjectNotFound(http_exc=e) from e else: raise - return pydantic.TypeAdapter(List[DeploymentSchedule]).validate_python( + return pydantic.TypeAdapter(list[DeploymentSchedule]).validate_python( response.json() ) @@ -2002,7 +1995,7 @@ async def update_deployment_schedule( schedule_id: UUID, active: Optional[bool] = None, schedule: Optional[SCHEDULE_TYPES] = None, - ): + ) -> None: """ Update a deployment schedule by ID. @@ -2012,7 +2005,7 @@ async def update_deployment_schedule( active: whether or not the schedule should be active schedule: the cron, rrule, or interval schedule this deployment schedule should use """ - kwargs = {} + kwargs: dict[str, Any] = {} if active is not None: kwargs["active"] = active if schedule is not None: @@ -2076,8 +2069,8 @@ async def read_flow_run(self, flow_run_id: UUID) -> FlowRun: return FlowRun.model_validate(response.json()) async def resume_flow_run( - self, flow_run_id: UUID, run_input: Optional[Dict] = None - ) -> OrchestrationResult: + self, flow_run_id: UUID, run_input: Optional[dict[str, Any]] = None + ) -> OrchestrationResult[Any]: """ Resumes a paused flow run. @@ -2095,21 +2088,24 @@ async def resume_flow_run( except httpx.HTTPStatusError: raise - return OrchestrationResult.model_validate(response.json()) + result: OrchestrationResult[Any] = OrchestrationResult.model_validate( + response.json() + ) + return result async def read_flow_runs( self, *, - flow_filter: FlowFilter = None, - flow_run_filter: FlowRunFilter = None, - task_run_filter: TaskRunFilter = None, - deployment_filter: DeploymentFilter = None, - work_pool_filter: WorkPoolFilter = None, - work_queue_filter: WorkQueueFilter = None, - sort: FlowRunSort = None, + flow_filter: Optional[FlowFilter] = None, + flow_run_filter: Optional[FlowRunFilter] = None, + task_run_filter: Optional[TaskRunFilter] = None, + deployment_filter: Optional[DeploymentFilter] = None, + work_pool_filter: Optional[WorkPoolFilter] = None, + work_queue_filter: Optional[WorkQueueFilter] = None, + sort: Optional[FlowRunSort] = None, limit: Optional[int] = None, offset: int = 0, - ) -> List[FlowRun]: + ) -> list[FlowRun]: """ Query the Prefect API for flow runs. Only flow runs matching all criteria will be returned. 
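`resume_flow_run` above (and the `set_*_state` methods below) annotate an intermediate variable so the generic parameter of `OrchestrationResult` survives `model_validate`. A minimal sketch assuming pydantic v2, with `Result` standing in for the real generic model:

```python
from typing import Generic, TypeVar

import pydantic

T = TypeVar("T")


class Result(pydantic.BaseModel, Generic[T]):  # stand-in for OrchestrationResult
    status: str
    value: T


def parse(payload: dict[str, object]) -> Result[int]:
    # model_validate on the unparametrized class is typed loosely, so an
    # annotated intermediate pins the type argument before returning.
    result: Result[int] = Result.model_validate(payload)
    return result


print(parse({"status": "ok", "value": 3}).value)  # 3
```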
@@ -2129,7 +2125,7 @@ async def read_flow_runs( a list of Flow Run model representations of the flow runs """ - body = { + body: dict[str, Any] = { "flows": flow_filter.model_dump(mode="json") if flow_filter else None, "flow_runs": ( flow_run_filter.model_dump(mode="json", exclude_unset=True) @@ -2154,14 +2150,14 @@ async def read_flow_runs( } response = await self._client.post("/flow_runs/filter", json=body) - return pydantic.TypeAdapter(List[FlowRun]).validate_python(response.json()) + return pydantic.TypeAdapter(list[FlowRun]).validate_python(response.json()) async def set_flow_run_state( self, - flow_run_id: UUID, - state: "prefect.states.State", + flow_run_id: Union[UUID, str], + state: "prefect.states.State[T]", force: bool = False, - ) -> OrchestrationResult: + ) -> OrchestrationResult[T]: """ Set the state of a flow run. @@ -2194,11 +2190,14 @@ async def set_flow_run_state( else: raise - return OrchestrationResult.model_validate(response.json()) + result: OrchestrationResult[T] = OrchestrationResult.model_validate( + response.json() + ) + return result async def read_flow_run_states( self, flow_run_id: UUID - ) -> List[prefect.states.State]: + ) -> list[prefect.states.State]: """ Query for the states of a flow run @@ -2212,18 +2211,18 @@ async def read_flow_run_states( response = await self._client.get( "/flow_run_states/", params=dict(flow_run_id=str(flow_run_id)) ) - return pydantic.TypeAdapter(List[prefect.states.State]).validate_python( + return pydantic.TypeAdapter(list[prefect.states.State]).validate_python( response.json() ) - async def set_flow_run_name(self, flow_run_id: UUID, name: str): + async def set_flow_run_name(self, flow_run_id: UUID, name: str) -> httpx.Response: flow_run_data = FlowRunUpdate(name=name) return await self._client.patch( f"/flow_runs/{flow_run_id}", json=flow_run_data.model_dump(mode="json", exclude_unset=True), ) - async def set_task_run_name(self, task_run_id: UUID, name: str): + async def set_task_run_name(self, task_run_id: UUID, name: str) -> httpx.Response: task_run_data = TaskRunUpdate(name=name) return await self._client.patch( f"/task_runs/{task_run_id}", @@ -2240,9 +2239,9 @@ async def create_task_run( extra_tags: Optional[Iterable[str]] = None, state: Optional[prefect.states.State[R]] = None, task_inputs: Optional[ - Dict[ + dict[ str, - List[ + list[ Union[ TaskRunResult, Parameter, @@ -2276,6 +2275,12 @@ async def create_task_run( if state is None: state = prefect.states.Pending() + retry_delay = task.retry_delay_seconds + if isinstance(retry_delay, list): + retry_delay = [int(rd) for rd in retry_delay] + elif isinstance(retry_delay, float): + retry_delay = int(retry_delay) + task_run_data = TaskRunCreate( id=id, name=name, @@ -2286,7 +2291,7 @@ async def create_task_run( task_version=task.version, empirical_policy=TaskRunPolicy( retries=task.retries, - retry_delay=task.retry_delay_seconds, + retry_delay=retry_delay, retry_jitter_factor=task.retry_jitter_factor, ), state=state.to_state_create(), @@ -2319,14 +2324,14 @@ async def read_task_run(self, task_run_id: UUID) -> TaskRun: async def read_task_runs( self, *, - flow_filter: FlowFilter = None, - flow_run_filter: FlowRunFilter = None, - task_run_filter: TaskRunFilter = None, - deployment_filter: DeploymentFilter = None, - sort: TaskRunSort = None, + flow_filter: Optional[FlowFilter] = None, + flow_run_filter: Optional[FlowRunFilter] = None, + task_run_filter: Optional[TaskRunFilter] = None, + deployment_filter: Optional[DeploymentFilter] = None, + sort: Optional[TaskRunSort] = None, 
limit: Optional[int] = None, offset: int = 0, - ) -> List[TaskRun]: + ) -> list[TaskRun]: """ Query the Prefect API for task runs. Only task runs matching all criteria will be returned. @@ -2344,7 +2349,7 @@ async def read_task_runs( a list of Task Run model representations of the task runs """ - body = { + body: dict[str, Any] = { "flows": flow_filter.model_dump(mode="json") if flow_filter else None, "flow_runs": ( flow_run_filter.model_dump(mode="json", exclude_unset=True) @@ -2362,7 +2367,7 @@ async def read_task_runs( "offset": offset, } response = await self._client.post("/task_runs/filter", json=body) - return pydantic.TypeAdapter(List[TaskRun]).validate_python(response.json()) + return pydantic.TypeAdapter(list[TaskRun]).validate_python(response.json()) async def delete_task_run(self, task_run_id: UUID) -> None: """ @@ -2385,9 +2390,9 @@ async def delete_task_run(self, task_run_id: UUID) -> None: async def set_task_run_state( self, task_run_id: UUID, - state: prefect.states.State, + state: prefect.states.State[T], force: bool = False, - ) -> OrchestrationResult: + ) -> OrchestrationResult[T]: """ Set the state of a task run. @@ -2406,11 +2411,14 @@ async def set_task_run_state( f"/task_runs/{task_run_id}/set_state", json=dict(state=state_create.model_dump(mode="json"), force=force), ) - return OrchestrationResult.model_validate(response.json()) + result: OrchestrationResult[T] = OrchestrationResult.model_validate( + response.json() + ) + return result async def read_task_run_states( self, task_run_id: UUID - ) -> List[prefect.states.State]: + ) -> list[prefect.states.State]: """ Query for the states of a task run @@ -2423,11 +2431,13 @@ async def read_task_run_states( response = await self._client.get( "/task_run_states/", params=dict(task_run_id=str(task_run_id)) ) - return pydantic.TypeAdapter(List[prefect.states.State]).validate_python( + return pydantic.TypeAdapter(list[prefect.states.State]).validate_python( response.json() ) - async def create_logs(self, logs: Iterable[Union[LogCreate, dict]]) -> None: + async def create_logs( + self, logs: Iterable[Union[LogCreate, dict[str, Any]]] + ) -> None: """ Create logs for a flow or task run @@ -2444,8 +2454,8 @@ async def create_flow_run_notification_policy( self, block_document_id: UUID, is_active: bool = True, - tags: List[str] = None, - state_names: List[str] = None, + tags: Optional[list[str]] = None, + state_names: Optional[list[str]] = None, message_template: Optional[str] = None, ) -> UUID: """ @@ -2507,8 +2517,8 @@ async def update_flow_run_notification_policy( id: UUID, block_document_id: Optional[UUID] = None, is_active: Optional[bool] = None, - tags: Optional[List[str]] = None, - state_names: Optional[List[str]] = None, + tags: Optional[list[str]] = None, + state_names: Optional[list[str]] = None, message_template: Optional[str] = None, ) -> None: """ @@ -2525,7 +2535,7 @@ async def update_flow_run_notification_policy( prefect.exceptions.ObjectNotFound: If request returns 404 httpx.RequestError: If requests fails """ - params = {} + params: dict[str, Any] = {} if block_document_id is not None: params["block_document_id"] = block_document_id if is_active is not None: @@ -2555,7 +2565,7 @@ async def read_flow_run_notification_policies( flow_run_notification_policy_filter: FlowRunNotificationPolicyFilter, limit: Optional[int] = None, offset: int = 0, - ) -> List[FlowRunNotificationPolicy]: + ) -> list[FlowRunNotificationPolicy]: """ Query the Prefect API for flow run notification policies. 
Only policies matching all criteria will be returned. @@ -2569,7 +2579,7 @@ async def read_flow_run_notification_policies( a list of FlowRunNotificationPolicy model representations of the notification policies """ - body = { + body: dict[str, Any] = { "flow_run_notification_policy_filter": ( flow_run_notification_policy_filter.model_dump(mode="json") if flow_run_notification_policy_filter @@ -2581,7 +2591,7 @@ async def read_flow_run_notification_policies( response = await self._client.post( "/flow_run_notification_policies/filter", json=body ) - return pydantic.TypeAdapter(List[FlowRunNotificationPolicy]).validate_python( + return pydantic.TypeAdapter(list[FlowRunNotificationPolicy]).validate_python( response.json() ) @@ -2591,11 +2601,11 @@ async def read_logs( limit: Optional[int] = None, offset: Optional[int] = None, sort: LogSort = LogSort.TIMESTAMP_ASC, - ) -> List[Log]: + ) -> list[Log]: """ Read flow and task run logs. """ - body = { + body: dict[str, Any] = { "logs": log_filter.model_dump(mode="json") if log_filter else None, "limit": limit, "offset": offset, @@ -2603,7 +2613,7 @@ async def read_logs( } response = await self._client.post("/logs/filter", json=body) - return pydantic.TypeAdapter(List[Log]).validate_python(response.json()) + return pydantic.TypeAdapter(list[Log]).validate_python(response.json()) async def send_worker_heartbeat( self, @@ -2622,7 +2632,7 @@ async def send_worker_heartbeat( return_id: Whether to return the worker ID. Note: will return `None` if the connected server does not support returning worker IDs, even if `return_id` is `True`. worker_metadata: Metadata about the worker to send to the server. """ - params = { + params: dict[str, Any] = { "name": worker_name, "heartbeat_interval_seconds": heartbeat_interval_seconds, } @@ -2654,7 +2664,7 @@ async def read_workers_for_work_pool( worker_filter: Optional[WorkerFilter] = None, offset: Optional[int] = None, limit: Optional[int] = None, - ) -> List[Worker]: + ) -> list[Worker]: """ Reads workers for a given work pool. @@ -2678,7 +2688,7 @@ async def read_workers_for_work_pool( }, ) - return pydantic.TypeAdapter(List[Worker]).validate_python(response.json()) + return pydantic.TypeAdapter(list[Worker]).validate_python(response.json()) async def read_work_pool(self, work_pool_name: str) -> WorkPool: """ @@ -2705,7 +2715,7 @@ async def read_work_pools( limit: Optional[int] = None, offset: int = 0, work_pool_filter: Optional[WorkPoolFilter] = None, - ) -> List[WorkPool]: + ) -> list[WorkPool]: """ Reads work pools. @@ -2718,7 +2728,7 @@ async def read_work_pools( A list of work pools. """ - body = { + body: dict[str, Any] = { "limit": limit, "offset": offset, "work_pools": ( @@ -2726,7 +2736,7 @@ async def read_work_pools( ), } response = await self._client.post("/work_pools/filter", json=body) - return pydantic.TypeAdapter(List[WorkPool]).validate_python(response.json()) + return pydantic.TypeAdapter(list[WorkPool]).validate_python(response.json()) async def create_work_pool( self, @@ -2776,7 +2786,7 @@ async def update_work_pool( self, work_pool_name: str, work_pool: WorkPoolUpdate, - ): + ) -> None: """ Updates a work pool. @@ -2798,7 +2808,7 @@ async def update_work_pool( async def delete_work_pool( self, work_pool_name: str, - ): + ) -> None: """ Deletes a work pool. 
@@ -2819,7 +2829,7 @@ async def read_work_queues( work_queue_filter: Optional[WorkQueueFilter] = None, limit: Optional[int] = None, offset: Optional[int] = None, - ) -> List[WorkQueue]: + ) -> list[WorkQueue]: """ Retrieves queues for a work pool. @@ -2832,7 +2842,7 @@ async def read_work_queues( Returns: List of queues for the specified work pool. """ - json = { + json: dict[str, Any] = { "work_queues": ( work_queue_filter.model_dump(mode="json", exclude_unset=True) if work_queue_filter @@ -2856,15 +2866,15 @@ async def read_work_queues( else: response = await self._client.post("/work_queues/filter", json=json) - return pydantic.TypeAdapter(List[WorkQueue]).validate_python(response.json()) + return pydantic.TypeAdapter(list[WorkQueue]).validate_python(response.json()) async def get_scheduled_flow_runs_for_deployments( self, - deployment_ids: List[UUID], + deployment_ids: list[UUID], scheduled_before: Optional[datetime.datetime] = None, limit: Optional[int] = None, - ) -> List[FlowRunResponse]: - body: Dict[str, Any] = dict(deployment_ids=[str(id) for id in deployment_ids]) + ) -> list[FlowRunResponse]: + body: dict[str, Any] = dict(deployment_ids=[str(id) for id in deployment_ids]) if scheduled_before: body["scheduled_before"] = str(scheduled_before) if limit: @@ -2875,16 +2885,16 @@ async def get_scheduled_flow_runs_for_deployments( json=body, ) - return pydantic.TypeAdapter(List[FlowRunResponse]).validate_python( + return pydantic.TypeAdapter(list[FlowRunResponse]).validate_python( response.json() ) async def get_scheduled_flow_runs_for_work_pool( self, work_pool_name: str, - work_queue_names: Optional[List[str]] = None, + work_queue_names: Optional[list[str]] = None, scheduled_before: Optional[datetime.datetime] = None, - ) -> List[WorkerFlowRunResponse]: + ) -> list[WorkerFlowRunResponse]: """ Retrieves scheduled flow runs for the provided set of work pool queues. @@ -2900,7 +2910,7 @@ async def get_scheduled_flow_runs_for_work_pool( A list of worker flow run responses containing information about the retrieved flow runs. """ - body: Dict[str, Any] = {} + body: dict[str, Any] = {} if work_queue_names is not None: body["work_queue_names"] = list(work_queue_names) if scheduled_before: @@ -2910,7 +2920,7 @@ async def get_scheduled_flow_runs_for_work_pool( f"/work_pools/{work_pool_name}/get_scheduled_flow_runs", json=body, ) - return pydantic.TypeAdapter(List[WorkerFlowRunResponse]).validate_python( + return pydantic.TypeAdapter(list[WorkerFlowRunResponse]).validate_python( response.json() ) @@ -2956,13 +2966,13 @@ async def update_artifact( async def read_artifacts( self, *, - artifact_filter: ArtifactFilter = None, - flow_run_filter: FlowRunFilter = None, - task_run_filter: TaskRunFilter = None, - sort: ArtifactSort = None, + artifact_filter: Optional[ArtifactFilter] = None, + flow_run_filter: Optional[FlowRunFilter] = None, + task_run_filter: Optional[TaskRunFilter] = None, + sort: Optional[ArtifactSort] = None, limit: Optional[int] = None, offset: int = 0, - ) -> List[Artifact]: + ) -> list[Artifact]: """ Query the Prefect API for artifacts. Only artifacts matching all criteria will be returned. 
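`get_scheduled_flow_runs_for_deployments` above builds its request body as `dict[str, Any]` and stringifies UUIDs and datetimes for the JSON wire format. A small sketch of that shape; `scheduled_runs_body` is a hypothetical helper:

```python
import datetime
from typing import Any, Optional
from uuid import UUID, uuid4


def scheduled_runs_body(
    deployment_ids: list[UUID],
    scheduled_before: Optional[datetime.datetime] = None,
) -> dict[str, Any]:
    # Annotating the body as dict[str, Any] up front admits heterogeneous
    # values; UUIDs and datetimes are stringified because the payload is JSON.
    body: dict[str, Any] = {"deployment_ids": [str(i) for i in deployment_ids]}
    if scheduled_before:
        body["scheduled_before"] = str(scheduled_before)
    return body


print(scheduled_runs_body([uuid4()], datetime.datetime(2024, 12, 8)))
```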
@@ -2976,7 +2986,7 @@ async def read_artifacts( Returns: a list of Artifact model representations of the artifacts """ - body = { + body: dict[str, Any] = { "artifacts": ( artifact_filter.model_dump(mode="json") if artifact_filter else None ), @@ -2991,18 +3001,18 @@ async def read_artifacts( "offset": offset, } response = await self._client.post("/artifacts/filter", json=body) - return pydantic.TypeAdapter(List[Artifact]).validate_python(response.json()) + return pydantic.TypeAdapter(list[Artifact]).validate_python(response.json()) async def read_latest_artifacts( self, *, - artifact_filter: ArtifactCollectionFilter = None, - flow_run_filter: FlowRunFilter = None, - task_run_filter: TaskRunFilter = None, - sort: ArtifactCollectionSort = None, + artifact_filter: Optional[ArtifactCollectionFilter] = None, + flow_run_filter: Optional[FlowRunFilter] = None, + task_run_filter: Optional[TaskRunFilter] = None, + sort: Optional[ArtifactCollectionSort] = None, limit: Optional[int] = None, offset: int = 0, - ) -> List[ArtifactCollection]: + ) -> list[ArtifactCollection]: """ Query the Prefect API for artifacts. Only artifacts matching all criteria will be returned. @@ -3016,7 +3026,7 @@ async def read_latest_artifacts( Returns: a list of Artifact model representations of the artifacts """ - body = { + body: dict[str, Any] = { "artifacts": ( artifact_filter.model_dump(mode="json") if artifact_filter else None ), @@ -3031,7 +3041,7 @@ async def read_latest_artifacts( "offset": offset, } response = await self._client.post("/artifacts/latest/filter", json=body) - return pydantic.TypeAdapter(List[ArtifactCollection]).validate_python( + return pydantic.TypeAdapter(list[ArtifactCollection]).validate_python( response.json() ) @@ -3090,7 +3100,7 @@ async def read_variable_by_name(self, name: str) -> Optional[Variable]: else: raise - async def delete_variable_by_name(self, name: str): + async def delete_variable_by_name(self, name: str) -> None: """Deletes a variable by name.""" try: await self._client.delete(f"/variables/name/{name}") @@ -3100,12 +3110,12 @@ async def delete_variable_by_name(self, name: str): else: raise - async def read_variables(self, limit: Optional[int] = None) -> List[Variable]: + async def read_variables(self, limit: Optional[int] = None) -> list[Variable]: """Reads all variables.""" response = await self._client.post("/variables/filter", json={"limit": limit}) - return pydantic.TypeAdapter(List[Variable]).validate_python(response.json()) + return pydantic.TypeAdapter(list[Variable]).validate_python(response.json()) - async def read_worker_metadata(self) -> Dict[str, Any]: + async def read_worker_metadata(self) -> dict[str, Any]: """Reads worker metadata stored in Prefect collection registry.""" response = await self._client.get("collections/views/aggregate-worker-metadata") response.raise_for_status() @@ -3113,7 +3123,7 @@ async def read_worker_metadata(self) -> Dict[str, Any]: async def increment_concurrency_slots( self, - names: List[str], + names: list[str], slots: int, mode: str, create_if_missing: Optional[bool] = None, @@ -3129,7 +3139,7 @@ async def increment_concurrency_slots( ) async def release_concurrency_slots( - self, names: List[str], slots: int, occupancy_seconds: float + self, names: list[str], slots: int, occupancy_seconds: float ) -> httpx.Response: """ Release concurrency slots for the specified limits. 
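Mutating endpoints such as `delete_variable_by_name` above return `None` and translate HTTP 404s into `ObjectNotFound`, so callers handle one domain exception instead of inspecting status codes. A sketch of the shape, with an illustrative stand-in exception and no Prefect imports:

```python
import httpx


class ObjectNotFound(Exception):
    """Illustrative stand-in for prefect.exceptions.ObjectNotFound."""


def delete_variable(client: httpx.Client, name: str) -> None:
    # Raise for non-2xx responses, then remap the 404 case into the
    # domain exception while re-raising everything else unchanged.
    try:
        response = client.delete(f"/variables/name/{name}")
        response.raise_for_status()
    except httpx.HTTPStatusError as e:
        if e.response.status_code == 404:
            raise ObjectNotFound() from e
        raise
```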
@@ -3201,7 +3211,9 @@ async def read_global_concurrency_limit_by_name( else: raise - async def upsert_global_concurrency_limit_by_name(self, name: str, limit: int): + async def upsert_global_concurrency_limit_by_name( + self, name: str, limit: int + ) -> None: """Creates a global concurrency limit with the given name and limit if one does not already exist. If one does already exist matching the name then update it's limit if it is different. @@ -3227,7 +3239,7 @@ async def upsert_global_concurrency_limit_by_name(self, name: str, limit: int): async def read_global_concurrency_limits( self, limit: int = 10, offset: int = 0 - ) -> List[GlobalConcurrencyLimitResponse]: + ) -> list[GlobalConcurrencyLimitResponse]: response = await self._client.post( "/v2/concurrency_limits/filter", json={ @@ -3236,12 +3248,12 @@ async def read_global_concurrency_limits( }, ) return pydantic.TypeAdapter( - List[GlobalConcurrencyLimitResponse] + list[GlobalConcurrencyLimitResponse] ).validate_python(response.json()) async def create_flow_run_input( self, flow_run_id: UUID, key: str, value: str, sender: Optional[str] = None - ): + ) -> None: """ Creates a flow run input. @@ -3262,8 +3274,8 @@ async def create_flow_run_input( response.raise_for_status() async def filter_flow_run_input( - self, flow_run_id: UUID, key_prefix: str, limit: int, exclude_keys: Set[str] - ) -> List[FlowRunInput]: + self, flow_run_id: UUID, key_prefix: str, limit: int, exclude_keys: set[str] + ) -> list[FlowRunInput]: response = await self._client.post( f"/flow_runs/{flow_run_id}/input/filter", json={ @@ -3273,7 +3285,7 @@ async def filter_flow_run_input( }, ) response.raise_for_status() - return pydantic.TypeAdapter(List[FlowRunInput]).validate_python(response.json()) + return pydantic.TypeAdapter(list[FlowRunInput]).validate_python(response.json()) async def read_flow_run_input(self, flow_run_id: UUID, key: str) -> str: """ @@ -3287,7 +3299,7 @@ async def read_flow_run_input(self, flow_run_id: UUID, key: str) -> str: response.raise_for_status() return response.content.decode() - async def delete_flow_run_input(self, flow_run_id: UUID, key: str): + async def delete_flow_run_input(self, flow_run_id: UUID, key: str) -> None: """ Deletes a flow run input. 
@@ -3307,7 +3319,9 @@ async def create_automation(self, automation: AutomationCore) -> UUID: return UUID(response.json()["id"]) - async def update_automation(self, automation_id: UUID, automation: AutomationCore): + async def update_automation( + self, automation_id: UUID, automation: AutomationCore + ) -> None: """Updates an automation in Prefect Cloud.""" response = await self._client.put( f"/automations/{automation_id}", @@ -3315,21 +3329,23 @@ async def update_automation(self, automation_id: UUID, automation: AutomationCor ) response.raise_for_status - async def read_automations(self) -> List[Automation]: + async def read_automations(self) -> list[Automation]: response = await self._client.post("/automations/filter") response.raise_for_status() - return pydantic.TypeAdapter(List[Automation]).validate_python(response.json()) + return pydantic.TypeAdapter(list[Automation]).validate_python(response.json()) async def find_automation( self, id_or_name: Union[str, UUID] ) -> Optional[Automation]: if isinstance(id_or_name, str): + name = id_or_name try: id = UUID(id_or_name) except ValueError: id = None - elif isinstance(id_or_name, UUID): + else: id = id_or_name + name = str(id) if id: try: @@ -3343,24 +3359,26 @@ async def find_automation( # Look for it by an exact name for automation in automations: - if automation.name == id_or_name: + if automation.name == name: return automation # Look for it by a case-insensitive name for automation in automations: - if automation.name.lower() == id_or_name.lower(): + if automation.name.lower() == name.lower(): return automation return None - async def read_automation(self, automation_id: UUID) -> Optional[Automation]: + async def read_automation( + self, automation_id: Union[UUID, str] + ) -> Optional[Automation]: response = await self._client.get(f"/automations/{automation_id}") if response.status_code == 404: return None response.raise_for_status() return Automation.model_validate(response.json()) - async def read_automations_by_name(self, name: str) -> List[Automation]: + async def read_automations_by_name(self, name: str) -> list[Automation]: """ Query the Prefect API for an automation by name. Only automations matching the provided name will be returned. 
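`find_automation` above now binds `name` in both branches, so the later name comparisons are well-typed whether the caller passed a `str` or a `UUID`. The control flow, extracted into a hypothetical `split_id_or_name` helper:

```python
from typing import Optional, Union
from uuid import UUID


def split_id_or_name(id_or_name: Union[str, UUID]) -> tuple[Optional[UUID], str]:
    # Binding `name` in both branches means later comparisons always see a
    # str, so the checker accepts `name.lower()` unconditionally; `id`
    # stays Optional[UUID] for the exact-ID lookup fast path.
    if isinstance(id_or_name, str):
        name = id_or_name
        try:
            id = UUID(id_or_name)
        except ValueError:
            id = None
    else:
        id = id_or_name
        name = str(id)
    return id, name


print(split_id_or_name("nightly-refresh"))  # (None, 'nightly-refresh')
```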
@@ -3370,7 +3388,9 @@ async def read_automations_by_name(self, name: str) -> List[Automation]: Returns: a list of Automation model representations of the automations """ - automation_filter = filters.AutomationFilter(name=dict(any_=[name])) + automation_filter = filters.AutomationFilter( + name=filters.AutomationFilterName(any_=[name]) + ) response = await self._client.post( "/automations/filter", @@ -3384,21 +3404,21 @@ async def read_automations_by_name(self, name: str) -> List[Automation]: response.raise_for_status() - return pydantic.TypeAdapter(List[Automation]).validate_python(response.json()) + return pydantic.TypeAdapter(list[Automation]).validate_python(response.json()) - async def pause_automation(self, automation_id: UUID): + async def pause_automation(self, automation_id: UUID) -> None: response = await self._client.patch( f"/automations/{automation_id}", json={"enabled": False} ) response.raise_for_status() - async def resume_automation(self, automation_id: UUID): + async def resume_automation(self, automation_id: UUID) -> None: response = await self._client.patch( f"/automations/{automation_id}", json={"enabled": True} ) response.raise_for_status() - async def delete_automation(self, automation_id: UUID): + async def delete_automation(self, automation_id: UUID) -> None: response = await self._client.delete(f"/automations/{automation_id}") if response.status_code == 404: return @@ -3407,12 +3427,12 @@ async def delete_automation(self, automation_id: UUID): async def read_resource_related_automations( self, resource_id: str - ) -> List[Automation]: + ) -> list[Automation]: response = await self._client.get(f"/automations/related-to/{resource_id}") response.raise_for_status() - return pydantic.TypeAdapter(List[Automation]).validate_python(response.json()) + return pydantic.TypeAdapter(list[Automation]).validate_python(response.json()) - async def delete_resource_owned_automations(self, resource_id: str): + async def delete_resource_owned_automations(self, resource_id: str) -> None: await self._client.delete(f"/automations/owned-by/{resource_id}") async def api_version(self) -> str: @@ -3422,7 +3442,7 @@ async def api_version(self) -> str: def client_version(self) -> str: return prefect.__version__ - async def raise_for_api_version_mismatch(self): + async def raise_for_api_version_mismatch(self) -> None: # Cloud is always compatible as a server if self.server_type == ServerType.CLOUD: return @@ -3441,7 +3461,7 @@ async def raise_for_api_version_mismatch(self): f"Major versions must match." ) - async def __aenter__(self): + async def __aenter__(self) -> Self: """ Start the client. @@ -3488,7 +3508,7 @@ async def __aenter__(self): return self - async def __aexit__(self, *exc_info): + async def __aexit__(self, *exc_info: Any) -> Optional[bool]: """ Shutdown the client. """ @@ -3499,13 +3519,13 @@ async def __aexit__(self, *exc_info): self._closed = True return await self._exit_stack.__aexit__(*exc_info) - def __enter__(self): + def __enter__(self) -> NoReturn: raise RuntimeError( "The `PrefectClient` must be entered with an async context. 
Use 'async " "with PrefectClient(...)' not 'with PrefectClient(...)'" ) - def __exit__(self, *_): + def __exit__(self, *_: object) -> NoReturn: assert False, "This should never be called but must be defined for __enter__" @@ -3541,7 +3561,7 @@ def __init__( *, api_key: Optional[str] = None, api_version: Optional[str] = None, - httpx_settings: Optional[Dict[str, Any]] = None, + httpx_settings: Optional[dict[str, Any]] = None, server_type: Optional[ServerType] = None, ) -> None: httpx_settings = httpx_settings.copy() if httpx_settings else {} @@ -3617,16 +3637,10 @@ def __init__( ) # Connect to an in-process application - elif isinstance(api, ASGIApp): + else: self._ephemeral_app = api self.server_type = ServerType.EPHEMERAL - else: - raise TypeError( - f"Unexpected type {type(api).__name__!r} for argument `api`. Expected" - " 'str' or 'ASGIApp/FastAPI'" - ) - # See https://www.python-httpx.org/advanced/#timeout-configuration httpx_settings.setdefault( "timeout", @@ -3669,9 +3683,9 @@ def __init__( if isinstance(server_transport, httpx.HTTPTransport): pool = getattr(server_transport, "_pool", None) if isinstance(pool, httpcore.ConnectionPool): - pool._retries = 3 + setattr(pool, "_retries", 3) - self.logger = get_logger("client") + self.logger: Logger = get_logger("client") @property def api_url(self) -> httpx.URL: @@ -3709,7 +3723,7 @@ def __enter__(self) -> "SyncPrefectClient": return self - def __exit__(self, *exc_info) -> None: + def __exit__(self, *exc_info: Any) -> None: """ Shutdown the client. """ @@ -3747,7 +3761,7 @@ def api_version(self) -> str: def client_version(self) -> str: return prefect.__version__ - def raise_for_api_version_mismatch(self): + def raise_for_api_version_mismatch(self) -> None: # Cloud is always compatible as a server if self.server_type == ServerType.CLOUD: return @@ -3766,7 +3780,7 @@ def raise_for_api_version_mismatch(self): f"Major versions must match." ) - def create_flow(self, flow: "FlowObject") -> UUID: + def create_flow(self, flow: "FlowObject[Any, Any]") -> UUID: """ Create a flow in the Prefect API. @@ -3806,13 +3820,13 @@ def create_flow_from_name(self, flow_name: str) -> UUID: def create_flow_run( self, - flow: "FlowObject", + flow: "FlowObject[Any, R]", name: Optional[str] = None, - parameters: Optional[Dict[str, Any]] = None, - context: Optional[Dict[str, Any]] = None, + parameters: Optional[dict[str, Any]] = None, + context: Optional[dict[str, Any]] = None, tags: Optional[Iterable[str]] = None, parent_task_run_id: Optional[UUID] = None, - state: Optional["prefect.states.State"] = None, + state: Optional["prefect.states.State[R]"] = None, ) -> FlowRun: """ Create a flow run for a flow. @@ -3854,7 +3868,7 @@ def create_flow_run( state=state.to_state_create(), empirical_policy=FlowRunPolicy( retries=flow.retries, - retry_delay=flow.retry_delay_seconds, + retry_delay=int(flow.retry_delay_seconds or 0), ), ) @@ -3872,12 +3886,12 @@ def update_flow_run( self, flow_run_id: UUID, flow_version: Optional[str] = None, - parameters: Optional[dict] = None, + parameters: Optional[dict[str, Any]] = None, name: Optional[str] = None, tags: Optional[Iterable[str]] = None, empirical_policy: Optional[FlowRunPolicy] = None, infrastructure_pid: Optional[str] = None, - job_variables: Optional[dict] = None, + job_variables: Optional[dict[str, Any]] = None, ) -> httpx.Response: """ Update a flow run's details. 
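`__enter__` on the async-only clients above is annotated `-> NoReturn` because it exists only to raise a helpful error, and `__exit__` must still be defined for `__enter__` to type-check as a context manager. A sketch of the guard pair with an illustrative stand-in class:

```python
from typing import NoReturn


class AsyncOnlyClient:  # illustrative stand-in for the async-only clients
    def __enter__(self) -> NoReturn:
        # NoReturn tells checkers this always raises, so a synchronous
        # `with AsyncOnlyClient(): ...` body is flagged as unreachable.
        raise RuntimeError("use 'async with', not 'with'")

    def __exit__(self, *_: object) -> NoReturn:
        assert False, "never called; defined only so __enter__ may exist"
```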
@@ -3898,7 +3912,7 @@ def update_flow_run( Returns: an `httpx.Response` object from the PATCH request """ - params = {} + params: dict[str, Any] = {} if flow_version is not None: params["flow_version"] = flow_version if parameters is not None: @@ -3954,7 +3968,7 @@ def read_flow_runs( sort: Optional[FlowRunSort] = None, limit: Optional[int] = None, offset: int = 0, - ) -> List[FlowRun]: + ) -> list[FlowRun]: """ Query the Prefect API for flow runs. Only flow runs matching all criteria will be returned. @@ -3974,7 +3988,7 @@ def read_flow_runs( a list of Flow Run model representations of the flow runs """ - body = { + body: dict[str, Any] = { "flows": flow_filter.model_dump(mode="json") if flow_filter else None, "flow_runs": ( flow_run_filter.model_dump(mode="json", exclude_unset=True) @@ -3999,14 +4013,14 @@ def read_flow_runs( } response = self._client.post("/flow_runs/filter", json=body) - return pydantic.TypeAdapter(List[FlowRun]).validate_python(response.json()) + return pydantic.TypeAdapter(list[FlowRun]).validate_python(response.json()) def set_flow_run_state( self, flow_run_id: UUID, - state: "prefect.states.State", + state: "prefect.states.State[T]", force: bool = False, - ) -> OrchestrationResult: + ) -> OrchestrationResult[T]: """ Set the state of a flow run. @@ -4036,16 +4050,19 @@ def set_flow_run_state( else: raise - return OrchestrationResult.model_validate(response.json()) + result: OrchestrationResult[T] = OrchestrationResult.model_validate( + response.json() + ) + return result - def set_flow_run_name(self, flow_run_id: UUID, name: str): + def set_flow_run_name(self, flow_run_id: UUID, name: str) -> httpx.Response: flow_run_data = FlowRunUpdate(name=name) return self._client.patch( f"/flow_runs/{flow_run_id}", json=flow_run_data.model_dump(mode="json", exclude_unset=True), ) - def set_task_run_name(self, task_run_id: UUID, name: str): + def set_task_run_name(self, task_run_id: UUID, name: str) -> httpx.Response: task_run_data = TaskRunUpdate(name=name) return self._client.patch( f"/task_runs/{task_run_id}", @@ -4062,9 +4079,9 @@ def create_task_run( extra_tags: Optional[Iterable[str]] = None, state: Optional[prefect.states.State[R]] = None, task_inputs: Optional[ - Dict[ + dict[ str, - List[ + list[ Union[ TaskRunResult, Parameter, @@ -4098,6 +4115,12 @@ def create_task_run( if state is None: state = prefect.states.Pending() + retry_delay = task.retry_delay_seconds + if isinstance(retry_delay, list): + retry_delay = [int(rd) for rd in retry_delay] + elif isinstance(retry_delay, float): + retry_delay = int(retry_delay) + task_run_data = TaskRunCreate( id=id, name=name, @@ -4108,7 +4131,7 @@ def create_task_run( task_version=task.version, empirical_policy=TaskRunPolicy( retries=task.retries, - retry_delay=task.retry_delay_seconds, + retry_delay=retry_delay, retry_jitter_factor=task.retry_jitter_factor, ), state=state.to_state_create(), @@ -4142,14 +4165,14 @@ def read_task_run(self, task_run_id: UUID) -> TaskRun: def read_task_runs( self, *, - flow_filter: FlowFilter = None, - flow_run_filter: FlowRunFilter = None, - task_run_filter: TaskRunFilter = None, - deployment_filter: DeploymentFilter = None, - sort: TaskRunSort = None, + flow_filter: Optional[FlowFilter] = None, + flow_run_filter: Optional[FlowRunFilter] = None, + task_run_filter: Optional[TaskRunFilter] = None, + deployment_filter: Optional[DeploymentFilter] = None, + sort: Optional[TaskRunSort] = None, limit: Optional[int] = None, offset: int = 0, - ) -> List[TaskRun]: + ) -> list[TaskRun]: """ Query the Prefect 
API for task runs. Only task runs matching all criteria will be returned. @@ -4167,7 +4190,7 @@ def read_task_runs( a list of Task Run model representations of the task runs """ - body = { + body: dict[str, Any] = { "flows": flow_filter.model_dump(mode="json") if flow_filter else None, "flow_runs": ( flow_run_filter.model_dump(mode="json", exclude_unset=True) @@ -4185,14 +4208,14 @@ def read_task_runs( "offset": offset, } response = self._client.post("/task_runs/filter", json=body) - return pydantic.TypeAdapter(List[TaskRun]).validate_python(response.json()) + return pydantic.TypeAdapter(list[TaskRun]).validate_python(response.json()) def set_task_run_state( self, task_run_id: UUID, - state: prefect.states.State, + state: prefect.states.State[Any], force: bool = False, - ) -> OrchestrationResult: + ) -> OrchestrationResult[Any]: """ Set the state of a task run. @@ -4211,9 +4234,12 @@ def set_task_run_state( f"/task_runs/{task_run_id}/set_state", json=dict(state=state_create.model_dump(mode="json"), force=force), ) - return OrchestrationResult.model_validate(response.json()) + result: OrchestrationResult[Any] = OrchestrationResult.model_validate( + response.json() + ) + return result - def read_task_run_states(self, task_run_id: UUID) -> List[prefect.states.State]: + def read_task_run_states(self, task_run_id: UUID) -> list[prefect.states.State]: """ Query for the states of a task run @@ -4226,7 +4252,7 @@ def read_task_run_states(self, task_run_id: UUID) -> List[prefect.states.State]: response = self._client.get( "/task_run_states/", params=dict(task_run_id=str(task_run_id)) ) - return pydantic.TypeAdapter(List[prefect.states.State]).validate_python( + return pydantic.TypeAdapter(list[prefect.states.State]).validate_python( response.json() ) @@ -4300,7 +4326,7 @@ def create_artifact( return Artifact.model_validate(response.json()) def release_concurrency_slots( - self, names: List[str], slots: int, occupancy_seconds: float + self, names: list[str], slots: int, occupancy_seconds: float ) -> httpx.Response: """ Release concurrency slots for the specified limits. @@ -4324,7 +4350,7 @@ def release_concurrency_slots( ) def decrement_v1_concurrency_slots( - self, names: List[str], occupancy_seconds: float, task_run_id: UUID + self, names: list[str], occupancy_seconds: float, task_run_id: UUID ) -> httpx.Response: """ Release the specified concurrency limits. 
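The repeated `pydantic.TypeAdapter(list[Model]).validate_python(response.json())` calls above are the pydantic v2 idiom for validating a JSON array into a typed list, which is what lets these client methods return `list[TaskRun]` or `list[FlowRun]` rather than an untyped list. A self-contained sketch with a hypothetical stand-in model:

import pydantic


class TaskRunStub(pydantic.BaseModel):
    # Hypothetical stand-in for the real TaskRun schema.
    id: int
    name: str


payload = [{"id": 1, "name": "extract"}, {"id": 2, "name": "load"}]

# TypeAdapter wraps an arbitrary annotation (here list[TaskRunStub]) and
# validates plain Python data against it, returning a fully typed value.
runs = pydantic.TypeAdapter(list[TaskRunStub]).validate_python(payload)
assert runs[1].name == "load"
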
diff --git a/src/prefect/client/schemas/__init__.py b/src/prefect/client/schemas/__init__.py index c5335d4906b0..2a35e6a1f3c0 100644 --- a/src/prefect/client/schemas/__init__.py +++ b/src/prefect/client/schemas/__init__.py @@ -25,3 +25,27 @@ StateAcceptDetails, StateRejectDetails, ) + +__all__ = ( + "BlockDocument", + "BlockSchema", + "BlockType", + "BlockTypeUpdate", + "DEFAULT_BLOCK_SCHEMA_VERSION", + "FlowRun", + "FlowRunPolicy", + "OrchestrationResult", + "SetStateStatus", + "State", + "StateAbortDetails", + "StateAcceptDetails", + "StateCreate", + "StateDetails", + "StateRejectDetails", + "StateType", + "TaskRun", + "TaskRunInput", + "TaskRunPolicy", + "TaskRunResult", + "Workspace", +) diff --git a/src/prefect/client/schemas/actions.py b/src/prefect/client/schemas/actions.py index 9e0dd4bd3052..6f17c7cd8cc8 100644 --- a/src/prefect/client/schemas/actions.py +++ b/src/prefect/client/schemas/actions.py @@ -1,5 +1,5 @@ from copy import deepcopy -from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union from uuid import UUID, uuid4 import jsonschema @@ -51,7 +51,7 @@ class StateCreate(ActionBaseModel): name: Optional[str] = Field(default=None) message: Optional[str] = Field(default=None, examples=["Run started"]) state_details: StateDetails = Field(default_factory=StateDetails) - data: Union["BaseResult[R]", "ResultRecordMetadata", Any] = Field( + data: Union["BaseResult[Any]", "ResultRecordMetadata", Any] = Field( default=None, ) @@ -62,18 +62,19 @@ class FlowCreate(ActionBaseModel): name: str = Field( default=..., description="The name of the flow", examples=["my-flow"] ) - tags: List[str] = Field( + tags: list[str] = Field( default_factory=list, description="A list of flow tags", examples=[["tag-1", "tag-2"]], ) - labels: KeyValueLabelsField + + labels: KeyValueLabelsField = Field(default_factory=dict) class FlowUpdate(ActionBaseModel): """Data used by the Prefect REST API to update a flow.""" - tags: List[str] = Field( + tags: list[str] = Field( default_factory=list, description="A list of flow tags", examples=[["tag-1", "tag-2"]], @@ -94,7 +95,7 @@ class DeploymentScheduleCreate(ActionBaseModel): @field_validator("max_scheduled_runs") @classmethod - def validate_max_scheduled_runs(cls, v): + def validate_max_scheduled_runs(cls, v: Optional[int]) -> Optional[int]: return validate_schedule_max_scheduled_runs( v, PREFECT_DEPLOYMENT_SCHEDULE_MAX_SCHEDULED_RUNS.value() ) @@ -115,7 +116,7 @@ class DeploymentScheduleUpdate(ActionBaseModel): @field_validator("max_scheduled_runs") @classmethod - def validate_max_scheduled_runs(cls, v): + def validate_max_scheduled_runs(cls, v: Optional[int]) -> Optional[int]: return validate_schedule_max_scheduled_runs( v, PREFECT_DEPLOYMENT_SCHEDULE_MAX_SCHEDULED_RUNS.value() ) @@ -126,18 +127,20 @@ class DeploymentCreate(ActionBaseModel): @model_validator(mode="before") @classmethod - def remove_old_fields(cls, values): + def remove_old_fields(cls, values: dict[str, Any]) -> dict[str, Any]: return remove_old_deployment_fields(values) @field_validator("description", "tags", mode="before") @classmethod - def convert_to_strings(cls, values): + def convert_to_strings( + cls, values: Optional[Union[str, list[str]]] + ) -> Union[str, list[str]]: return convert_to_strings(values) name: str = Field(..., description="The name of the deployment.") flow_id: UUID = Field(..., description="The ID of the flow to deploy.") - paused: Optional[bool] = Field(None) - schedules: 
List[DeploymentScheduleCreate] = Field( + paused: Optional[bool] = Field(default=None) + schedules: list[DeploymentScheduleCreate] = Field( default_factory=list, description="A list of schedules for the deployment.", ) @@ -155,33 +158,33 @@ def convert_to_strings(cls, values): "Whether or not the deployment should enforce the parameter schema." ), ) - parameter_openapi_schema: Optional[Dict[str, Any]] = Field(default_factory=dict) - parameters: Dict[str, Any] = Field( + parameter_openapi_schema: Optional[dict[str, Any]] = Field(default_factory=dict) + parameters: dict[str, Any] = Field( default_factory=dict, description="Parameters for flow runs scheduled by the deployment.", ) - tags: List[str] = Field(default_factory=list) - labels: KeyValueLabelsField - pull_steps: Optional[List[dict]] = Field(None) + tags: list[str] = Field(default_factory=list) + labels: KeyValueLabelsField = Field(default_factory=dict) + pull_steps: Optional[list[dict[str, Any]]] = Field(default=None) - work_queue_name: Optional[str] = Field(None) + work_queue_name: Optional[str] = Field(default=None) work_pool_name: Optional[str] = Field( default=None, description="The name of the deployment's work pool.", examples=["my-work-pool"], ) - storage_document_id: Optional[UUID] = Field(None) - infrastructure_document_id: Optional[UUID] = Field(None) - description: Optional[str] = Field(None) - path: Optional[str] = Field(None) - version: Optional[str] = Field(None) - entrypoint: Optional[str] = Field(None) - job_variables: Dict[str, Any] = Field( + storage_document_id: Optional[UUID] = Field(default=None) + infrastructure_document_id: Optional[UUID] = Field(default=None) + description: Optional[str] = Field(default=None) + path: Optional[str] = Field(default=None) + version: Optional[str] = Field(default=None) + entrypoint: Optional[str] = Field(default=None) + job_variables: dict[str, Any] = Field( default_factory=dict, description="Overrides to apply to flow run infrastructure at runtime.", ) - def check_valid_configuration(self, base_job_template: dict): + def check_valid_configuration(self, base_job_template: dict[str, Any]) -> None: """Check that the combination of base_job_template defaults and job_variables conforms to the specified schema. """ @@ -206,19 +209,19 @@ class DeploymentUpdate(ActionBaseModel): @model_validator(mode="before") @classmethod - def remove_old_fields(cls, values): + def remove_old_fields(cls, values: dict[str, Any]) -> dict[str, Any]: return remove_old_deployment_fields(values) - version: Optional[str] = Field(None) - description: Optional[str] = Field(None) - parameters: Optional[Dict[str, Any]] = Field( + version: Optional[str] = Field(default=None) + description: Optional[str] = Field(default=None) + parameters: Optional[dict[str, Any]] = Field( default=None, description="Parameters for flow runs scheduled by the deployment.", ) paused: Optional[bool] = Field( default=None, description="Whether or not the deployment is paused." 
) - schedules: Optional[List[DeploymentScheduleCreate]] = Field( + schedules: Optional[list[DeploymentScheduleCreate]] = Field( default=None, description="A list of schedules for the deployment.", ) @@ -230,21 +233,21 @@ def remove_old_fields(cls, values): default=None, description="The concurrency options for the deployment.", ) - tags: List[str] = Field(default_factory=list) - work_queue_name: Optional[str] = Field(None) + tags: list[str] = Field(default_factory=list) + work_queue_name: Optional[str] = Field(default=None) work_pool_name: Optional[str] = Field( default=None, description="The name of the deployment's work pool.", examples=["my-work-pool"], ) - path: Optional[str] = Field(None) - job_variables: Optional[Dict[str, Any]] = Field( + path: Optional[str] = Field(default=None) + job_variables: Optional[dict[str, Any]] = Field( default_factory=dict, description="Overrides to apply to flow run infrastructure at runtime.", ) - entrypoint: Optional[str] = Field(None) - storage_document_id: Optional[UUID] = Field(None) - infrastructure_document_id: Optional[UUID] = Field(None) + entrypoint: Optional[str] = Field(default=None) + storage_document_id: Optional[UUID] = Field(default=None) + infrastructure_document_id: Optional[UUID] = Field(default=None) enforce_parameter_schema: Optional[bool] = Field( default=None, description=( @@ -252,7 +255,7 @@ def remove_old_fields(cls, values): ), ) - def check_valid_configuration(self, base_job_template: dict): + def check_valid_configuration(self, base_job_template: dict[str, Any]) -> None: """Check that the combination of base_job_template defaults and job_variables conforms to the specified schema. """ @@ -276,15 +279,15 @@ def check_valid_configuration(self, base_job_template: dict): class FlowRunUpdate(ActionBaseModel): """Data used by the Prefect REST API to update a flow run.""" - name: Optional[str] = Field(None) - flow_version: Optional[str] = Field(None) - parameters: Optional[Dict[str, Any]] = Field(default_factory=dict) + name: Optional[str] = Field(default=None) + flow_version: Optional[str] = Field(default=None) + parameters: Optional[dict[str, Any]] = Field(default_factory=dict) empirical_policy: objects.FlowRunPolicy = Field( default_factory=objects.FlowRunPolicy ) - tags: List[str] = Field(default_factory=list) - infrastructure_pid: Optional[str] = Field(None) - job_variables: Optional[Dict[str, Any]] = Field(None) + tags: list[str] = Field(default_factory=list) + infrastructure_pid: Optional[str] = Field(default=None) + job_variables: Optional[dict[str, Any]] = Field(default=None) class TaskRunCreate(ActionBaseModel): @@ -300,7 +303,7 @@ class TaskRunCreate(ActionBaseModel): default=None, description="The name of the task run", ) - flow_run_id: Optional[UUID] = Field(None) + flow_run_id: Optional[UUID] = Field(default=None) task_key: str = Field( default=..., description="A unique identifier for the task being run." ) @@ -311,17 +314,17 @@ class TaskRunCreate(ActionBaseModel): " within the same flow run." 
), ) - cache_key: Optional[str] = Field(None) - cache_expiration: Optional[objects.DateTime] = Field(None) - task_version: Optional[str] = Field(None) + cache_key: Optional[str] = Field(default=None) + cache_expiration: Optional[objects.DateTime] = Field(default=None) + task_version: Optional[str] = Field(default=None) empirical_policy: objects.TaskRunPolicy = Field( default_factory=objects.TaskRunPolicy, ) - tags: List[str] = Field(default_factory=list) - labels: KeyValueLabelsField - task_inputs: Dict[ + tags: list[str] = Field(default_factory=list) + labels: KeyValueLabelsField = Field(default_factory=dict) + task_inputs: dict[ str, - List[ + list[ Union[ objects.TaskRunResult, objects.Parameter, @@ -334,7 +337,7 @@ class TaskRunCreate(ActionBaseModel): class TaskRunUpdate(ActionBaseModel): """Data used by the Prefect REST API to update a task run""" - name: Optional[str] = Field(None) + name: Optional[str] = Field(default=None) class FlowRunCreate(ActionBaseModel): @@ -347,22 +350,23 @@ class FlowRunCreate(ActionBaseModel): name: Optional[str] = Field(default=None, description="The name of the flow run.") flow_id: UUID = Field(default=..., description="The id of the flow being run.") - deployment_id: Optional[UUID] = Field(None) - flow_version: Optional[str] = Field(None) - parameters: Dict[str, Any] = Field( + deployment_id: Optional[UUID] = Field(default=None) + flow_version: Optional[str] = Field(default=None) + parameters: dict[str, Any] = Field( default_factory=dict, description="The parameters for the flow run." ) - context: Dict[str, Any] = Field( + context: dict[str, Any] = Field( default_factory=dict, description="The context for the flow run." ) - parent_task_run_id: Optional[UUID] = Field(None) - infrastructure_document_id: Optional[UUID] = Field(None) + parent_task_run_id: Optional[UUID] = Field(default=None) + infrastructure_document_id: Optional[UUID] = Field(default=None) empirical_policy: objects.FlowRunPolicy = Field( default_factory=objects.FlowRunPolicy ) - tags: List[str] = Field(default_factory=list) - labels: KeyValueLabelsField - idempotency_key: Optional[str] = Field(None) + tags: list[str] = Field(default_factory=list) + idempotency_key: Optional[str] = Field(default=None) + + labels: KeyValueLabelsField = Field(default_factory=dict) class DeploymentFlowRunCreate(ActionBaseModel): @@ -374,32 +378,32 @@ class DeploymentFlowRunCreate(ActionBaseModel): ) name: Optional[str] = Field(default=None, description="The name of the flow run.") - parameters: Dict[str, Any] = Field( + parameters: dict[str, Any] = Field( default_factory=dict, description="The parameters for the flow run." ) enforce_parameter_schema: Optional[bool] = Field( default=None, description="Whether or not to enforce the parameter schema on this run.", ) - context: Dict[str, Any] = Field( + context: dict[str, Any] = Field( default_factory=dict, description="The context for the flow run." 
) - infrastructure_document_id: Optional[UUID] = Field(None) + infrastructure_document_id: Optional[UUID] = Field(default=None) empirical_policy: objects.FlowRunPolicy = Field( default_factory=objects.FlowRunPolicy ) - tags: List[str] = Field(default_factory=list) - idempotency_key: Optional[str] = Field(None) - parent_task_run_id: Optional[UUID] = Field(None) - work_queue_name: Optional[str] = Field(None) - job_variables: Optional[dict] = Field(None) + tags: list[str] = Field(default_factory=list) + idempotency_key: Optional[str] = Field(default=None) + parent_task_run_id: Optional[UUID] = Field(default=None) + work_queue_name: Optional[str] = Field(default=None) + job_variables: Optional[dict[str, Any]] = Field(default=None) class SavedSearchCreate(ActionBaseModel): """Data used by the Prefect REST API to create a saved search.""" name: str = Field(default=..., description="The name of the saved search.") - filters: List[objects.SavedSearchFilter] = Field( + filters: list[objects.SavedSearchFilter] = Field( default_factory=list, description="The filter set for the saved search." ) @@ -436,12 +440,12 @@ class ConcurrencyLimitV2Create(ActionBaseModel): class ConcurrencyLimitV2Update(ActionBaseModel): """Data used by the Prefect REST API to update a v2 concurrency limit.""" - active: Optional[bool] = Field(None) - name: Optional[Name] = Field(None) - limit: Optional[NonNegativeInteger] = Field(None) - active_slots: Optional[NonNegativeInteger] = Field(None) - denied_slots: Optional[NonNegativeInteger] = Field(None) - slot_decay_per_second: Optional[NonNegativeFloat] = Field(None) + active: Optional[bool] = Field(default=None) + name: Optional[Name] = Field(default=None) + limit: Optional[NonNegativeInteger] = Field(default=None) + active_slots: Optional[NonNegativeInteger] = Field(default=None) + denied_slots: Optional[NonNegativeInteger] = Field(default=None) + slot_decay_per_second: Optional[NonNegativeFloat] = Field(default=None) class BlockTypeCreate(ActionBaseModel): @@ -471,24 +475,24 @@ class BlockTypeCreate(ActionBaseModel): class BlockTypeUpdate(ActionBaseModel): """Data used by the Prefect REST API to update a block type.""" - logo_url: Optional[objects.HttpUrl] = Field(None) - documentation_url: Optional[objects.HttpUrl] = Field(None) - description: Optional[str] = Field(None) - code_example: Optional[str] = Field(None) + logo_url: Optional[objects.HttpUrl] = Field(default=None) + documentation_url: Optional[objects.HttpUrl] = Field(default=None) + description: Optional[str] = Field(default=None) + code_example: Optional[str] = Field(default=None) @classmethod - def updatable_fields(cls) -> set: + def updatable_fields(cls) -> set[str]: return get_class_fields_only(cls) class BlockSchemaCreate(ActionBaseModel): """Data used by the Prefect REST API to create a block schema.""" - fields: Dict[str, Any] = Field( + fields: dict[str, Any] = Field( default_factory=dict, description="The block schema's field schema" ) - block_type_id: Optional[UUID] = Field(None) - capabilities: List[str] = Field( + block_type_id: Optional[UUID] = Field(default=None) + capabilities: list[str] = Field( default_factory=list, description="A list of Block capabilities", ) @@ -504,7 +508,7 @@ class BlockDocumentCreate(ActionBaseModel): name: Optional[Name] = Field( default=None, description="The name of the block document" ) - data: Dict[str, Any] = Field( + data: dict[str, Any] = Field( default_factory=dict, description="The block document's data" ) block_schema_id: UUID = Field( @@ -524,7 +528,9 @@ class 
BlockDocumentCreate(ActionBaseModel): _validate_name_format = field_validator("name")(validate_block_document_name) @model_validator(mode="before") - def validate_name_is_present_if_not_anonymous(cls, values): + def validate_name_is_present_if_not_anonymous( + cls, values: dict[str, Any] + ) -> dict[str, Any]: return validate_name_present_on_nonanonymous_blocks(values) @@ -534,7 +540,7 @@ class BlockDocumentUpdate(ActionBaseModel): block_schema_id: Optional[UUID] = Field( default=None, description="A block schema ID" ) - data: Dict[str, Any] = Field( + data: dict[str, Any] = Field( default_factory=dict, description="The block document's data" ) merge_existing_data: bool = Field( @@ -565,11 +571,11 @@ class LogCreate(ActionBaseModel): level: int = Field(default=..., description="The log level.") message: str = Field(default=..., description="The log message.") timestamp: DateTime = Field(default=..., description="The log timestamp.") - flow_run_id: Optional[UUID] = Field(None) - task_run_id: Optional[UUID] = Field(None) - worker_id: Optional[UUID] = Field(None) + flow_run_id: Optional[UUID] = Field(default=None) + task_run_id: Optional[UUID] = Field(default=None) + worker_id: Optional[UUID] = Field(default=None) - def model_dump(self, *args, **kwargs): + def model_dump(self, *args: Any, **kwargs: Any) -> dict[str, Any]: """ The worker_id field is only included in logs sent to Prefect Cloud. If it's unset, we should not include it in the log payload. @@ -586,11 +592,11 @@ class WorkPoolCreate(ActionBaseModel): name: NonEmptyishName = Field( description="The name of the work pool.", ) - description: Optional[str] = Field(None) + description: Optional[str] = Field(default=None) type: str = Field( description="The work pool type.", default="prefect-agent" ) # TODO: change default - base_job_template: Dict[str, Any] = Field( + base_job_template: dict[str, Any] = Field( default_factory=dict, description="The base job template for the work pool.", ) @@ -606,17 +612,17 @@ class WorkPoolCreate(ActionBaseModel): class WorkPoolUpdate(ActionBaseModel): """Data used by the Prefect REST API to update a work pool.""" - description: Optional[str] = Field(None) - is_paused: Optional[bool] = Field(None) - base_job_template: Optional[Dict[str, Any]] = Field(None) - concurrency_limit: Optional[int] = Field(None) + description: Optional[str] = Field(default=None) + is_paused: Optional[bool] = Field(default=None) + base_job_template: Optional[dict[str, Any]] = Field(default=None) + concurrency_limit: Optional[int] = Field(default=None) class WorkQueueCreate(ActionBaseModel): """Data used by the Prefect REST API to create a work queue.""" name: str = Field(default=..., description="The name of the work queue.") - description: Optional[str] = Field(None) + description: Optional[str] = Field(default=None) is_paused: bool = Field( default=False, description="Whether the work queue is paused.", @@ -644,16 +650,16 @@ class WorkQueueCreate(ActionBaseModel): class WorkQueueUpdate(ActionBaseModel): """Data used by the Prefect REST API to update a work queue.""" - name: Optional[str] = Field(None) - description: Optional[str] = Field(None) + name: Optional[str] = Field(default=None) + description: Optional[str] = Field(default=None) is_paused: bool = Field( default=False, description="Whether or not the work queue is paused." 
) - concurrency_limit: Optional[NonNegativeInteger] = Field(None) + concurrency_limit: Optional[NonNegativeInteger] = Field(default=None) priority: Optional[PositiveInteger] = Field( None, description="The queue's priority." ) - last_polled: Optional[DateTime] = Field(None) + last_polled: Optional[DateTime] = Field(default=None) # DEPRECATED @@ -670,10 +676,10 @@ class FlowRunNotificationPolicyCreate(ActionBaseModel): is_active: bool = Field( default=True, description="Whether the policy is currently active" ) - state_names: List[str] = Field( + state_names: list[str] = Field( default=..., description="The flow run states that trigger notifications" ) - tags: List[str] = Field( + tags: list[str] = Field( default=..., description="The flow run tags that trigger notifications (set [] to disable)", ) @@ -695,7 +701,7 @@ class FlowRunNotificationPolicyCreate(ActionBaseModel): @field_validator("message_template") @classmethod - def validate_message_template_variables(cls, v): + def validate_message_template_variables(cls, v: Optional[str]) -> Optional[str]: return validate_message_template_variables(v) @@ -703,8 +709,8 @@ class FlowRunNotificationPolicyUpdate(ActionBaseModel): """Data used by the Prefect REST API to update a flow run notification policy.""" is_active: Optional[bool] = Field(default=None) - state_names: Optional[List[str]] = Field(default=None) - tags: Optional[List[str]] = Field(default=None) + state_names: Optional[list[str]] = Field(default=None) + tags: Optional[list[str]] = Field(default=None) block_document_id: Optional[UUID] = Field(default=None) message_template: Optional[str] = Field(default=None) @@ -715,8 +721,8 @@ class ArtifactCreate(ActionBaseModel): key: Optional[str] = Field(default=None) type: Optional[str] = Field(default=None) description: Optional[str] = Field(default=None) - data: Optional[Union[Dict[str, Any], Any]] = Field(default=None) - metadata_: Optional[Dict[str, str]] = Field(default=None) + data: Optional[Union[dict[str, Any], Any]] = Field(default=None) + metadata_: Optional[dict[str, str]] = Field(default=None) flow_run_id: Optional[UUID] = Field(default=None) task_run_id: Optional[UUID] = Field(default=None) @@ -726,9 +732,9 @@ class ArtifactCreate(ActionBaseModel): class ArtifactUpdate(ActionBaseModel): """Data used by the Prefect REST API to update an artifact.""" - data: Optional[Union[Dict[str, Any], Any]] = Field(None) - description: Optional[str] = Field(None) - metadata_: Optional[Dict[str, str]] = Field(None) + data: Optional[Union[dict[str, Any], Any]] = Field(default=None) + description: Optional[str] = Field(default=None) + metadata_: Optional[dict[str, str]] = Field(default=None) class VariableCreate(ActionBaseModel): @@ -745,7 +751,7 @@ class VariableCreate(ActionBaseModel): description="The value of the variable", examples=["my-value"], ) - tags: Optional[List[str]] = Field(default=None) + tags: Optional[list[str]] = Field(default=None) # validators _validate_name_format = field_validator("name")(validate_variable_name) @@ -765,7 +771,7 @@ class VariableUpdate(ActionBaseModel): description="The value of the variable", examples=["my-value"], ) - tags: Optional[List[str]] = Field(default=None) + tags: Optional[list[str]] = Field(default=None) # validators _validate_name_format = field_validator("name")(validate_variable_name) @@ -801,8 +807,8 @@ class GlobalConcurrencyLimitCreate(ActionBaseModel): class GlobalConcurrencyLimitUpdate(ActionBaseModel): """Data used by the Prefect REST API to update a global concurrency limit.""" - name: 
Optional[Name] = Field(None) - limit: Optional[NonNegativeInteger] = Field(None) - active: Optional[bool] = Field(None) - active_slots: Optional[NonNegativeInteger] = Field(None) - slot_decay_per_second: Optional[NonNegativeFloat] = Field(None) + name: Optional[Name] = Field(default=None) + limit: Optional[NonNegativeInteger] = Field(default=None) + active: Optional[bool] = Field(default=None) + active_slots: Optional[NonNegativeInteger] = Field(default=None) + slot_decay_per_second: Optional[NonNegativeFloat] = Field(default=None) diff --git a/src/prefect/client/schemas/objects.py b/src/prefect/client/schemas/objects.py index ccd802b3dda4..087cd5b78ee3 100644 --- a/src/prefect/client/schemas/objects.py +++ b/src/prefect/client/schemas/objects.py @@ -1,15 +1,16 @@ import datetime import warnings +from collections.abc import Callable, Mapping from functools import partial from typing import ( TYPE_CHECKING, Annotated, Any, - Dict, + ClassVar, Generic, - List, Optional, Union, + cast, overload, ) from uuid import UUID, uuid4 @@ -23,13 +24,12 @@ HttpUrl, IPvAnyNetwork, SerializationInfo, + SerializerFunctionWrapHandler, Tag, field_validator, model_serializer, model_validator, ) -from pydantic.functional_validators import ModelWrapValidatorHandler -from pydantic_extra_types.pendulum_dt import DateTime from typing_extensions import Literal, Self, TypeVar from prefect._internal.compatibility import deprecated @@ -64,8 +64,13 @@ from prefect.utilities.pydantic import handle_secret_render if TYPE_CHECKING: + from prefect.client.schemas.actions import StateCreate from prefect.results import BaseResult, ResultRecordMetadata + DateTime = pendulum.DateTime +else: + from pydantic_extra_types.pendulum_dt import DateTime + R = TypeVar("R", default=Any) @@ -180,7 +185,7 @@ class StateDetails(PrefectBaseModel): pause_timeout: Optional[DateTime] = None pause_reschedule: bool = False pause_key: Optional[str] = None - run_input_keyset: Optional[Dict[str, str]] = None + run_input_keyset: Optional[dict[str, str]] = None refresh_cache: Optional[bool] = None retriable: Optional[bool] = None transition_id: Optional[UUID] = None @@ -215,11 +220,21 @@ class State(ObjectBaseModel, Generic[R]): ] = Field(default=None) @overload - def result(self: "State[R]", raise_on_failure: bool = True) -> R: + def result( + self: "State[R]", + raise_on_failure: Literal[True] = ..., + fetch: bool = ..., + retry_result_failure: bool = ..., + ) -> R: ... @overload - def result(self: "State[R]", raise_on_failure: bool = False) -> Union[R, Exception]: + def result( + self: "State[R]", + raise_on_failure: Literal[False] = False, + fetch: bool = ..., + retry_result_failure: bool = ..., + ) -> Union[R, Exception]: ... @deprecated.deprecated_parameter( @@ -311,7 +326,7 @@ def result( retry_result_failure=retry_result_failure, ) - def to_state_create(self): + def to_state_create(self) -> "StateCreate": """ Convert this state to a `StateCreate` type which can be used to set the state of a run in the API. 
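The `State.result` overloads above use `Literal[True]`/`Literal[False]` for `raise_on_failure` so the inferred return type narrows at the call site. Roughly, the shape of the pattern (pared down and hypothetical, ignoring the real method's extra parameters):

from typing import Generic, Literal, TypeVar, Union, overload

R = TypeVar("R")


class StateSketch(Generic[R]):
    def __init__(self, value: Union[R, Exception]) -> None:
        self._value = value

    @overload
    def result(self, raise_on_failure: Literal[True] = ...) -> R: ...

    @overload
    def result(self, raise_on_failure: Literal[False]) -> Union[R, Exception]: ...

    def result(self, raise_on_failure: bool = True) -> Union[R, Exception]:
        # Raise stored exceptions when asked; otherwise hand them back.
        if raise_on_failure and isinstance(self._value, Exception):
            raise self._value
        return self._value


# The checker infers `int` here (not `int | Exception`) because the
# first overload matches the zero-argument call.
value: int = StateSketch[int](42).result()
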
@@ -327,7 +342,7 @@ def to_state_create(self): ) if isinstance(self.data, BaseResult): - data = self.data + data = cast(BaseResult[R], self.data) elif isinstance(self.data, ResultRecord) and should_persist_result(): data = self.data.metadata else: @@ -348,14 +363,14 @@ def default_name_from_type(self) -> Self: # validation check and an error will be raised after this function is called name = self.name if name is None and self.type: - self.name = " ".join([v.capitalize() for v in self.type.value.split("_")]) + self.name = " ".join([v.capitalize() for v in self.type.split("_")]) return self @model_validator(mode="after") def default_scheduled_start_time(self) -> Self: if self.type == StateType.SCHEDULED: if not self.state_details.scheduled_time: - self.state_details.scheduled_time = DateTime.now("utc") + self.state_details.scheduled_time = pendulum.DateTime.now("utc") return self @model_validator(mode="after") @@ -395,17 +410,19 @@ def is_paused(self) -> bool: return self.type == StateType.PAUSED def model_copy( - self, *, update: Optional[Dict[str, Any]] = None, deep: bool = False - ): + self, *, update: Optional[Mapping[str, Any]] = None, deep: bool = False + ) -> Self: """ Copying API models should return an object that could be inserted into the database again. The 'timestamp' is reset using the default factory. """ - update = update or {} - update.setdefault("timestamp", self.model_fields["timestamp"].get_default()) + update = { + "timestamp": self.model_fields["timestamp"].get_default(), + **(update or {}), + } return super().model_copy(update=update, deep=deep) - def fresh_copy(self, **kwargs) -> Self: + def fresh_copy(self, **kwargs: Any) -> Self: """ Return a fresh copy of the state with a new ID. """ @@ -443,12 +460,14 @@ def __str__(self) -> str: `MyCompletedState("my message", type=COMPLETED)` """ - display = [] + display: list[str] = [] if self.message: display.append(repr(self.message)) - if self.type.value.lower() != self.name.lower(): + if TYPE_CHECKING: + assert self.name is not None + if self.type.lower() != self.name.lower(): display.append(f"type={self.type.value}") return f"{self.name}({', '.join(display)})" @@ -487,7 +506,7 @@ class FlowRunPolicy(PrefectBaseModel): retry_delay: Optional[int] = Field( default=None, description="The delay time between retries, in seconds." ) - pause_keys: Optional[set] = Field( + pause_keys: Optional[set[str]] = Field( default_factory=set, description="Tracks pauses this run has observed." ) resuming: Optional[bool] = Field( @@ -499,7 +518,7 @@ class FlowRunPolicy(PrefectBaseModel): @model_validator(mode="before") @classmethod - def populate_deprecated_fields(cls, values: Any): + def populate_deprecated_fields(cls, values: Any) -> Any: if isinstance(values, dict): return set_run_policy_deprecated_fields(values) return values @@ -536,7 +555,7 @@ class FlowRun(ObjectBaseModel): description="The version of the flow executed in this flow run.", examples=["1.0"], ) - parameters: Dict[str, Any] = Field( + parameters: dict[str, Any] = Field( default_factory=dict, description="Parameters for the flow run." ) idempotency_key: Optional[str] = Field( @@ -546,7 +565,7 @@ class FlowRun(ObjectBaseModel): " run is not created multiple times." 
), ) - context: Dict[str, Any] = Field( + context: dict[str, Any] = Field( default_factory=dict, description="Additional context for the flow run.", examples=[{"my_var": "my_val"}], @@ -554,7 +573,7 @@ class FlowRun(ObjectBaseModel): empirical_policy: FlowRunPolicy = Field( default_factory=FlowRunPolicy, ) - tags: List[str] = Field( + tags: list[str] = Field( default_factory=list, description="A list of tags on the flow run", examples=[["tag-1", "tag-2"]], @@ -632,7 +651,7 @@ class FlowRun(ObjectBaseModel): description="The state of the flow run.", examples=["State(type=StateType.COMPLETED)"], ) - job_variables: Optional[dict] = Field( + job_variables: Optional[dict[str, Any]] = Field( default=None, description="Job variables for the flow run.", ) @@ -663,7 +682,7 @@ def __eq__(self, other: Any) -> bool: @field_validator("name", mode="before") @classmethod - def set_default_name(cls, name): + def set_default_name(cls, name: Optional[str]) -> str: return get_or_create_run_name(name) @@ -687,7 +706,7 @@ class TaskRunPolicy(PrefectBaseModel): deprecated=True, ) retries: Optional[int] = Field(default=None, description="The number of retries.") - retry_delay: Union[None, int, List[int]] = Field( + retry_delay: Union[None, int, list[int]] = Field( default=None, description="A delay time or list of delay times between retries, in seconds.", ) @@ -710,18 +729,20 @@ def populate_deprecated_fields(self): self.retries = self.max_retries if not self.retry_delay and self.retry_delay_seconds != 0: - self.retry_delay = self.retry_delay_seconds + self.retry_delay = int(self.retry_delay_seconds) return self @field_validator("retry_delay") @classmethod - def validate_configured_retry_delays(cls, v): + def validate_configured_retry_delays( + cls, v: Optional[list[float]] + ) -> Optional[list[float]]: return list_length_50_or_less(v) @field_validator("retry_jitter_factor") @classmethod - def validate_jitter_factor(cls, v): + def validate_jitter_factor(cls, v: Optional[float]) -> Optional[float]: return validate_not_negative(v) @@ -731,9 +752,11 @@ class TaskRunInput(PrefectBaseModel): could include, constants, parameters, or other task runs. """ - model_config = ConfigDict(frozen=True) + model_config: ClassVar[ConfigDict] = ConfigDict(frozen=True) - input_type: str + if not TYPE_CHECKING: + # subclasses provide the concrete type for this field + input_type: str class TaskRunResult(TaskRunInput): @@ -791,7 +814,7 @@ class TaskRun(ObjectBaseModel): empirical_policy: TaskRunPolicy = Field( default_factory=TaskRunPolicy, ) - tags: List[str] = Field( + tags: list[str] = Field( default_factory=list, description="A list of tags for the task run.", examples=[["tag-1", "tag-2"]], @@ -800,7 +823,7 @@ class TaskRun(ObjectBaseModel): state_id: Optional[UUID] = Field( default=None, description="The id of the current task run state." ) - task_inputs: Dict[str, List[Union[TaskRunResult, Parameter, Constant]]] = Field( + task_inputs: dict[str, list[Union[TaskRunResult, Parameter, Constant]]] = Field( default_factory=dict, description=( "Tracks the source of inputs to a task run. Used for internal bookkeeping. 
" @@ -865,7 +888,7 @@ class TaskRun(ObjectBaseModel): @field_validator("name", mode="before") @classmethod - def set_default_name(cls, name): + def set_default_name(cls, name: Optional[str]) -> Name: return get_or_create_run_name(name) @@ -883,7 +906,7 @@ class Workspace(PrefectBaseModel): workspace_name: str = Field(..., description="The workspace name.") workspace_description: str = Field(..., description="Description of the workspace.") workspace_handle: str = Field(..., description="The workspace's unique handle.") - model_config = ConfigDict(extra="ignore") + model_config: ClassVar[ConfigDict] = ConfigDict(extra="ignore") @property def handle(self) -> str: @@ -912,7 +935,7 @@ def ui_url(self) -> str: f"/workspace/{self.workspace_id}" ) - def __hash__(self): + def __hash__(self) -> int: return hash(self.handle) @@ -935,7 +958,7 @@ class IPAllowlist(PrefectBaseModel): Expected payload for an IP allowlist from the Prefect Cloud API. """ - entries: List[IPAllowlistEntry] + entries: list[IPAllowlistEntry] class IPAllowlistMyAccessResponse(PrefectBaseModel): @@ -973,14 +996,14 @@ class BlockSchema(ObjectBaseModel): """A representation of a block schema.""" checksum: str = Field(default=..., description="The block schema's unique checksum") - fields: Dict[str, Any] = Field( + fields: dict[str, Any] = Field( default_factory=dict, description="The block schema's field schema" ) block_type_id: Optional[UUID] = Field(default=..., description="A block type ID") block_type: Optional[BlockType] = Field( default=None, description="The associated block type" ) - capabilities: List[str] = Field( + capabilities: list[str] = Field( default_factory=list, description="A list of Block capabilities", ) @@ -999,7 +1022,7 @@ class BlockDocument(ObjectBaseModel): "The block document's name. Not required for anonymous block documents." ), ) - data: Dict[str, Any] = Field( + data: dict[str, Any] = Field( default_factory=dict, description="The block document's data" ) block_schema_id: UUID = Field(default=..., description="A block schema ID") @@ -1011,7 +1034,7 @@ class BlockDocument(ObjectBaseModel): block_type: Optional[BlockType] = Field( default=None, description="The associated block type" ) - block_document_references: Dict[str, Dict[str, Any]] = Field( + block_document_references: dict[str, dict[str, Any]] = Field( default_factory=dict, description="Record of the block document's references" ) is_anonymous: bool = Field( @@ -1026,13 +1049,15 @@ class BlockDocument(ObjectBaseModel): @model_validator(mode="before") @classmethod - def validate_name_is_present_if_not_anonymous(cls, values): + def validate_name_is_present_if_not_anonymous( + cls, values: dict[str, Any] + ) -> dict[str, Any]: return validate_name_present_on_nonanonymous_blocks(values) @model_serializer(mode="wrap") def serialize_data( - self, handler: ModelWrapValidatorHandler, info: SerializationInfo - ): + self, handler: SerializerFunctionWrapHandler, info: SerializationInfo + ) -> Any: self.data = visit_collection( self.data, visit_fn=partial(handle_secret_render, context=info.context or {}), @@ -1047,7 +1072,7 @@ class Flow(ObjectBaseModel): name: Name = Field( default=..., description="The name of the flow", examples=["my-flow"] ) - tags: List[str] = Field( + tags: list[str] = Field( default_factory=list, description="A list of flow tags", examples=[["tag-1", "tag-2"]], @@ -1091,22 +1116,22 @@ class Deployment(ObjectBaseModel): concurrency_limit: Optional[int] = Field( default=None, description="The concurrency limit for the deployment." 
) - schedules: List[DeploymentSchedule] = Field( + schedules: list[DeploymentSchedule] = Field( default_factory=list, description="A list of schedules for the deployment." ) - job_variables: Dict[str, Any] = Field( + job_variables: dict[str, Any] = Field( default_factory=dict, description="Overrides to apply to flow run infrastructure at runtime.", ) - parameters: Dict[str, Any] = Field( + parameters: dict[str, Any] = Field( default_factory=dict, description="Parameters for flow runs scheduled by the deployment.", ) - pull_steps: Optional[List[dict]] = Field( + pull_steps: Optional[list[dict[str, Any]]] = Field( default=None, description="Pull steps for cloning and running this deployment.", ) - tags: List[str] = Field( + tags: list[str] = Field( default_factory=list, description="A list of tags for the deployment", examples=[["tag-1", "tag-2"]], @@ -1123,7 +1148,7 @@ class Deployment(ObjectBaseModel): default=None, description="The last time the deployment was polled for status updates.", ) - parameter_openapi_schema: Optional[Dict[str, Any]] = Field( + parameter_openapi_schema: Optional[dict[str, Any]] = Field( default=None, description="The parameter schema of the flow, including defaults.", ) @@ -1177,7 +1202,7 @@ class ConcurrencyLimit(ObjectBaseModel): default=..., description="A tag the concurrency limit is applied to." ) concurrency_limit: int = Field(default=..., description="The concurrency limit.") - active_slots: List[UUID] = Field( + active_slots: list[UUID] = Field( default_factory=list, description="A list of active run ids using a concurrency slot", ) @@ -1224,7 +1249,7 @@ class BlockDocumentReference(ObjectBaseModel): @model_validator(mode="before") @classmethod - def validate_parent_and_ref_are_different(cls, values): + def validate_parent_and_ref_are_different(cls, values: Any) -> Any: if isinstance(values, dict): return validate_parent_and_ref_diff(values) return values @@ -1234,7 +1259,7 @@ class Configuration(ObjectBaseModel): """An ORM representation of account info.""" key: str = Field(default=..., description="Account info key") - value: Dict[str, Any] = Field(default=..., description="Account info") + value: dict[str, Any] = Field(default=..., description="Account info") class SavedSearchFilter(PrefectBaseModel): @@ -1258,7 +1283,7 @@ class SavedSearch(ObjectBaseModel): """An ORM representation of saved search data. Represents a set of filter criteria.""" name: str = Field(default=..., description="The name of the saved search.") - filters: List[SavedSearchFilter] = Field( + filters: list[SavedSearchFilter] = Field( default_factory=list, description="The filter set for the saved search." ) @@ -1281,11 +1306,11 @@ class Log(ObjectBaseModel): class QueueFilter(PrefectBaseModel): """Filter criteria definition for a work queue.""" - tags: Optional[List[str]] = Field( + tags: Optional[list[str]] = Field( default=None, description="Only include flow runs with these tags in the work queue.", ) - deployment_ids: Optional[List[UUID]] = Field( + deployment_ids: Optional[list[UUID]] = Field( default=None, description="Only include flow runs from these deployments in the work queue.", ) @@ -1345,7 +1370,7 @@ class WorkQueueHealthPolicy(PrefectBaseModel): ) def evaluate_health_status( - self, late_runs_count: int, last_polled: Optional[DateTime] = None + self, late_runs_count: int, last_polled: Optional[pendulum.DateTime] = None ) -> bool: """ Given empirical information about the state of the work queue, evaluate its health status. 
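`validate_parent_and_ref_are_different` above narrows with `isinstance(values, dict)` before delegating, because a pydantic `mode="before"` model validator can receive arbitrary raw input: a dict, an existing model instance, or whatever was passed to `model_validate`. The shape of that guard, on a hypothetical model:

from typing import Any

from pydantic import BaseModel, model_validator


class ReferenceSketch(BaseModel):
    parent_id: int
    ref_id: int

    @model_validator(mode="before")
    @classmethod
    def check_distinct(cls, values: Any) -> Any:
        # Only inspect keys when the raw input is actually a dict; other
        # input shapes pass through to normal field validation.
        if isinstance(values, dict) and values.get("parent_id") == values.get("ref_id"):
            raise ValueError("parent and reference must differ")
        return values


ReferenceSketch(parent_id=1, ref_id=2)  # ok; equal ids would raise
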
@@ -1397,10 +1422,10 @@ class FlowRunNotificationPolicy(ObjectBaseModel): is_active: bool = Field( default=True, description="Whether the policy is currently active" ) - state_names: List[str] = Field( + state_names: list[str] = Field( default=..., description="The flow run states that trigger notifications" ) - tags: List[str] = Field( + tags: list[str] = Field( default=..., description="The flow run tags that trigger notifications (set [] to disable)", ) @@ -1422,7 +1447,7 @@ class FlowRunNotificationPolicy(ObjectBaseModel): @field_validator("message_template") @classmethod - def validate_message_template_variables(cls, v): + def validate_message_template_variables(cls, v: Optional[str]) -> Optional[str]: return validate_message_template_variables(v) @@ -1454,7 +1479,7 @@ class WorkPool(ObjectBaseModel): default=None, description="A description of the work pool." ) type: str = Field(description="The work pool type.") - base_job_template: Dict[str, Any] = Field( + base_job_template: dict[str, Any] = Field( default_factory=dict, description="The work pool's base job template." ) is_paused: bool = Field( @@ -1469,10 +1494,12 @@ class WorkPool(ObjectBaseModel): ) # this required field has a default of None so that the custom validator - # below will be called and produce a more helpful error message - default_queue_id: UUID = Field( - None, description="The id of the pool's default queue." - ) + # below will be called and produce a more helpful error message. Because + # the field metadata is attached via an annotation, the default is hidden + # from type checkers. + default_queue_id: Annotated[ + UUID, Field(default=None, description="The id of the pool's default queue.") + ] @property def is_push_pool(self) -> bool: @@ -1484,7 +1511,7 @@ def is_managed_pool(self) -> bool: @field_validator("default_queue_id") @classmethod - def helpful_error_for_missing_default_queue_id(cls, v): + def helpful_error_for_missing_default_queue_id(cls, v: Optional[UUID]) -> UUID: return validate_default_queue_id_not_none(v) @@ -1495,8 +1522,8 @@ class Worker(ObjectBaseModel): work_pool_id: UUID = Field( description="The work pool with which the queue is associated." ) - last_heartbeat_time: datetime.datetime = Field( - None, description="The last time the worker process sent a heartbeat." + last_heartbeat_time: Optional[datetime.datetime] = Field( + default=None, description="The last time the worker process sent a heartbeat." ) heartbeat_interval_seconds: Optional[int] = Field( default=None, @@ -1529,14 +1556,14 @@ class Artifact(ObjectBaseModel): default=None, description="A markdown-enabled description of the artifact." ) # data will eventually be typed as `Optional[Union[Result, Any]]` - data: Optional[Union[Dict[str, Any], Any]] = Field( + data: Optional[Union[dict[str, Any], Any]] = Field( default=None, description=( "Data associated with the artifact, e.g. a result.; structure depends on" " the artifact type." ), ) - metadata_: Optional[Dict[str, str]] = Field( + metadata_: Optional[dict[str, str]] = Field( default=None, description=( "User-defined artifact metadata. 
Content must be string key and value" @@ -1552,7 +1579,9 @@ class Artifact(ObjectBaseModel): @field_validator("metadata_") @classmethod - def validate_metadata_length(cls, v): + def validate_metadata_length( + cls, v: Optional[dict[str, str]] + ) -> Optional[dict[str, str]]: return validate_max_metadata_length(v) @@ -1571,14 +1600,14 @@ class ArtifactCollection(ObjectBaseModel): description: Optional[str] = Field( default=None, description="A markdown-enabled description of the artifact." ) - data: Optional[Union[Dict[str, Any], Any]] = Field( + data: Optional[Union[dict[str, Any], Any]] = Field( default=None, description=( "Data associated with the artifact, e.g. a result.; structure depends on" " the artifact type." ), ) - metadata_: Optional[Dict[str, str]] = Field( + metadata_: Optional[dict[str, str]] = Field( default=None, description=( "User-defined artifact metadata. Content must be string key and value" @@ -1605,7 +1634,7 @@ class Variable(ObjectBaseModel): description="The value of the variable", examples=["my_value"], ) - tags: List[str] = Field( + tags: list[str] = Field( default_factory=list, description="A list of variable tags", examples=[["tag-1", "tag-2"]], @@ -1630,7 +1659,7 @@ def decoded_value(self) -> Any: @field_validator("key", check_fields=False) @classmethod - def validate_name_characters(cls, v): + def validate_name_characters(cls, v: str) -> str: raise_on_name_alphanumeric_dashes_only(v) return v @@ -1675,7 +1704,7 @@ class CsrfToken(ObjectBaseModel): ) -__getattr__ = getattr_migration(__name__) +__getattr__: Callable[[str], Any] = getattr_migration(__name__) class Integration(PrefectBaseModel): @@ -1693,7 +1722,7 @@ class WorkerMetadata(PrefectBaseModel): should support flexible metadata. """ - integrations: List[Integration] = Field( + integrations: list[Integration] = Field( default=..., description="Prefect integrations installed in the worker." ) - model_config = ConfigDict(extra="allow") + model_config: ClassVar[ConfigDict] = ConfigDict(extra="allow") diff --git a/src/prefect/client/schemas/responses.py b/src/prefect/client/schemas/responses.py index 29102b65f022..cb27a6f55392 100644 --- a/src/prefect/client/schemas/responses.py +++ b/src/prefect/client/schemas/responses.py @@ -1,5 +1,5 @@ import datetime -from typing import Any, Dict, List, Optional, TypeVar, Union +from typing import Any, ClassVar, Generic, Optional, TypeVar, Union from uuid import UUID from pydantic import ConfigDict, Field @@ -13,7 +13,7 @@ from prefect.utilities.collections import AutoEnum from prefect.utilities.names import generate_slug -R = TypeVar("R") +T = TypeVar("T") class SetStateStatus(AutoEnum): @@ -120,7 +120,7 @@ class HistoryResponse(PrefectBaseModel): interval_end: DateTime = Field( default=..., description="The end date of the interval." ) - states: List[HistoryResponseState] = Field( + states: list[HistoryResponseState] = Field( default=..., description="A list of state histories during the interval." ) @@ -130,18 +130,18 @@ class HistoryResponse(PrefectBaseModel): ] -class OrchestrationResult(PrefectBaseModel): +class OrchestrationResult(PrefectBaseModel, Generic[T]): """ A container for the output of state orchestration. 
""" - state: Optional[objects.State] + state: Optional[objects.State[T]] status: SetStateStatus details: StateResponseDetails class WorkerFlowRunResponse(PrefectBaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) + model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True) work_pool_id: UUID work_queue_id: UUID @@ -179,7 +179,7 @@ class FlowRunResponse(ObjectBaseModel): description="The version of the flow executed in this flow run.", examples=["1.0"], ) - parameters: Dict[str, Any] = Field( + parameters: dict[str, Any] = Field( default_factory=dict, description="Parameters for the flow run." ) idempotency_key: Optional[str] = Field( @@ -189,7 +189,7 @@ class FlowRunResponse(ObjectBaseModel): " run is not created multiple times." ), ) - context: Dict[str, Any] = Field( + context: dict[str, Any] = Field( default_factory=dict, description="Additional context for the flow run.", examples=[{"my_var": "my_val"}], @@ -197,7 +197,7 @@ class FlowRunResponse(ObjectBaseModel): empirical_policy: objects.FlowRunPolicy = Field( default_factory=objects.FlowRunPolicy, ) - tags: List[str] = Field( + tags: list[str] = Field( default_factory=list, description="A list of tags on the flow run", examples=[["tag-1", "tag-2"]], @@ -275,7 +275,7 @@ class FlowRunResponse(ObjectBaseModel): description="The state of the flow run.", examples=["objects.State(type=objects.StateType.COMPLETED)"], ) - job_variables: Optional[dict] = Field( + job_variables: Optional[dict[str, Any]] = Field( default=None, description="Job variables for the flow run." ) @@ -335,22 +335,22 @@ class DeploymentResponse(ObjectBaseModel): default=None, description="The concurrency options for the deployment.", ) - schedules: List[objects.DeploymentSchedule] = Field( + schedules: list[objects.DeploymentSchedule] = Field( default_factory=list, description="A list of schedules for the deployment." ) - job_variables: Dict[str, Any] = Field( + job_variables: dict[str, Any] = Field( default_factory=dict, description="Overrides to apply to flow run infrastructure at runtime.", ) - parameters: Dict[str, Any] = Field( + parameters: dict[str, Any] = Field( default_factory=dict, description="Parameters for flow runs scheduled by the deployment.", ) - pull_steps: Optional[List[dict]] = Field( + pull_steps: Optional[list[dict[str, Any]]] = Field( default=None, description="Pull steps for cloning and running this deployment.", ) - tags: List[str] = Field( + tags: list[str] = Field( default_factory=list, description="A list of tags for the deployment", examples=[["tag-1", "tag-2"]], @@ -367,7 +367,7 @@ class DeploymentResponse(ObjectBaseModel): default=None, description="The last time the deployment was polled for status updates.", ) - parameter_openapi_schema: Optional[Dict[str, Any]] = Field( + parameter_openapi_schema: Optional[dict[str, Any]] = Field( default=None, description="The parameter schema of the flow, including defaults.", ) @@ -400,7 +400,7 @@ class DeploymentResponse(ObjectBaseModel): default=None, description="Optional information about the updater of this deployment.", ) - work_queue_id: UUID = Field( + work_queue_id: Optional[UUID] = Field( default=None, description=( "The id of the work pool queue to which this deployment is assigned." 
@@ -423,7 +423,7 @@ class DeploymentResponse(ObjectBaseModel): class MinimalConcurrencyLimitResponse(PrefectBaseModel): - model_config = ConfigDict(extra="ignore") + model_config: ClassVar[ConfigDict] = ConfigDict(extra="ignore") id: UUID name: str diff --git a/src/prefect/client/schemas/schedules.py b/src/prefect/client/schemas/schedules.py index 1a2b97a74f8f..4b9cf1b3cf5b 100644 --- a/src/prefect/client/schemas/schedules.py +++ b/src/prefect/client/schemas/schedules.py @@ -3,13 +3,13 @@ """ import datetime -from typing import Annotated, Any, Optional, Union +from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Optional, Union import dateutil import dateutil.rrule +import dateutil.tz import pendulum from pydantic import AfterValidator, ConfigDict, Field, field_validator, model_validator -from pydantic_extra_types.pendulum_dt import DateTime from typing_extensions import TypeAlias, TypeGuard from prefect._internal.schemas.bases import PrefectBaseModel @@ -20,6 +20,14 @@ validate_rrule_string, ) +if TYPE_CHECKING: + # type checkers have difficulty accepting that + # pydantic_extra_types.pendulum_dt and pendulum.DateTime can be used + # together. + DateTime = pendulum.DateTime +else: + from pydantic_extra_types.pendulum_dt import DateTime + MAX_ITERATIONS = 1000 # approx. 1 years worth of RDATEs + buffer MAX_RRULE_LENGTH = 6500 @@ -54,7 +62,7 @@ class IntervalSchedule(PrefectBaseModel): timezone (str, optional): a valid timezone string """ - model_config = ConfigDict(extra="forbid", exclude_none=True) + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") interval: datetime.timedelta = Field(gt=datetime.timedelta(0)) anchor_date: Annotated[DateTime, AfterValidator(default_anchor_date)] = Field( @@ -68,6 +76,19 @@ def validate_timezone(self): self.timezone = default_timezone(self.timezone, self.model_dump()) return self + if TYPE_CHECKING: + # The model accepts str or datetime values for `anchor_date` + def __init__( + self, + /, + interval: datetime.timedelta, + anchor_date: Optional[ + Union[pendulum.DateTime, datetime.datetime, str] + ] = None, + timezone: Optional[str] = None, + ) -> None: + ... 
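The `if TYPE_CHECKING:` `__init__` stub above advertises a looser constructor signature to type checkers (here, that `anchor_date` may be a string or a datetime) while leaving pydantic's runtime `__init__`, which already coerces such inputs, untouched. In isolation, on a hypothetical model:

import datetime
from typing import TYPE_CHECKING, Optional, Union

from pydantic import BaseModel


class WindowSketch(BaseModel):
    interval: datetime.timedelta
    anchor: Optional[datetime.datetime] = None

    if TYPE_CHECKING:
        # Never executed: pydantic generates the real __init__. The stub
        # only teaches checkers that strings are acceptable for `anchor`.
        def __init__(
            self,
            /,
            interval: datetime.timedelta,
            anchor: Optional[Union[datetime.datetime, str]] = None,
        ) -> None: ...


w = WindowSketch(interval=datetime.timedelta(hours=1), anchor="2020-01-01T06:00:00")
assert isinstance(w.anchor, datetime.datetime)
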
+ class CronSchedule(PrefectBaseModel): """ @@ -94,7 +115,7 @@ class CronSchedule(PrefectBaseModel): """ - model_config = ConfigDict(extra="forbid") + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") cron: str = Field(default=..., examples=["0 0 * * *"]) timezone: Optional[str] = Field(default=None, examples=["America/New_York"]) @@ -107,18 +128,36 @@ class CronSchedule(PrefectBaseModel): @field_validator("timezone") @classmethod - def valid_timezone(cls, v): + def valid_timezone(cls, v: Optional[str]) -> str: return default_timezone(v) @field_validator("cron") @classmethod - def valid_cron_string(cls, v): + def valid_cron_string(cls, v: str) -> str: return validate_cron_string(v) DEFAULT_ANCHOR_DATE = pendulum.date(2020, 1, 1) +def _rrule_dt( + rrule: dateutil.rrule.rrule, name: str = "_dtstart" +) -> Optional[datetime.datetime]: + return getattr(rrule, name, None) + + +def _rrule( + rruleset: dateutil.rrule.rruleset, name: str = "_rrule" +) -> list[dateutil.rrule.rrule]: + return getattr(rruleset, name, []) + + +def _rdates( + rrule: dateutil.rrule.rruleset, name: str = "_rdate" +) -> list[datetime.datetime]: + return getattr(rrule, name, []) + + class RRuleSchedule(PrefectBaseModel): """ RRule schedule, based on the iCalendar standard @@ -139,7 +178,7 @@ class RRuleSchedule(PrefectBaseModel): timezone (str, optional): a valid timezone string """ - model_config = ConfigDict(extra="forbid") + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") rrule: str timezone: Optional[str] = Field( @@ -148,58 +187,60 @@ class RRuleSchedule(PrefectBaseModel): @field_validator("rrule") @classmethod - def validate_rrule_str(cls, v): + def validate_rrule_str(cls, v: str) -> str: return validate_rrule_string(v) @classmethod - def from_rrule(cls, rrule: dateutil.rrule.rrule): + def from_rrule( + cls, rrule: Union[dateutil.rrule.rrule, dateutil.rrule.rruleset] + ) -> "RRuleSchedule": if isinstance(rrule, dateutil.rrule.rrule): - if rrule._dtstart.tzinfo is not None: - timezone = rrule._dtstart.tzinfo.name + dtstart = _rrule_dt(rrule) + if dtstart and dtstart.tzinfo is not None: + timezone = dtstart.tzinfo.tzname(dtstart) else: timezone = "UTC" return RRuleSchedule(rrule=str(rrule), timezone=timezone) - elif isinstance(rrule, dateutil.rrule.rruleset): - dtstarts = [rr._dtstart for rr in rrule._rrule if rr._dtstart is not None] - unique_dstarts = set(pendulum.instance(d).in_tz("UTC") for d in dtstarts) - unique_timezones = set(d.tzinfo for d in dtstarts if d.tzinfo is not None) - - if len(unique_timezones) > 1: - raise ValueError( - f"rruleset has too many dtstart timezones: {unique_timezones}" - ) - - if len(unique_dstarts) > 1: - raise ValueError(f"rruleset has too many dtstarts: {unique_dstarts}") - - if unique_dstarts and unique_timezones: - timezone = dtstarts[0].tzinfo.name - else: - timezone = "UTC" - - rruleset_string = "" - if rrule._rrule: - rruleset_string += "\n".join(str(r) for r in rrule._rrule) - if rrule._exrule: - rruleset_string += "\n" if rruleset_string else "" - rruleset_string += "\n".join(str(r) for r in rrule._exrule).replace( - "RRULE", "EXRULE" - ) - if rrule._rdate: - rruleset_string += "\n" if rruleset_string else "" - rruleset_string += "RDATE:" + ",".join( - rd.strftime("%Y%m%dT%H%M%SZ") for rd in rrule._rdate - ) - if rrule._exdate: - rruleset_string += "\n" if rruleset_string else "" - rruleset_string += "EXDATE:" + ",".join( - exd.strftime("%Y%m%dT%H%M%SZ") for exd in rrule._exdate - ) - return RRuleSchedule(rrule=rruleset_string, timezone=timezone) 
+        rrules = _rrule(rrule)
+        dtstarts = [dts for rr in rrules if (dts := _rrule_dt(rr)) is not None]
+        unique_dstarts = set(pendulum.instance(d).in_tz("UTC") for d in dtstarts)
+        unique_timezones = set(d.tzinfo for d in dtstarts if d.tzinfo is not None)
+
+        if len(unique_timezones) > 1:
+            raise ValueError(
+                f"rruleset has too many dtstart timezones: {unique_timezones}"
+            )
+
+        if len(unique_dstarts) > 1:
+            raise ValueError(f"rruleset has too many dtstarts: {unique_dstarts}")
+
+        if unique_dstarts and unique_timezones:
+            [unique_tz] = unique_timezones
+            timezone = unique_tz.tzname(dtstarts[0])
         else:
-            raise ValueError(f"Invalid RRule object: {rrule}")
-
-    def to_rrule(self) -> dateutil.rrule.rrule:
+            timezone = "UTC"
+
+        rruleset_string = ""
+        if rrules:
+            rruleset_string += "\n".join(str(r) for r in rrules)
+        if exrule := _rrule(rrule, "_exrule"):
+            rruleset_string += "\n" if rruleset_string else ""
+            rruleset_string += "\n".join(str(r) for r in exrule).replace(
+                "RRULE", "EXRULE"
+            )
+        if rdates := _rdates(rrule):
+            rruleset_string += "\n" if rruleset_string else ""
+            rruleset_string += "RDATE:" + ",".join(
+                rd.strftime("%Y%m%dT%H%M%SZ") for rd in rdates
+            )
+        if exdates := _rdates(rrule, "_exdate"):
+            rruleset_string += "\n" if rruleset_string else ""
+            rruleset_string += "EXDATE:" + ",".join(
+                exd.strftime("%Y%m%dT%H%M%SZ") for exd in exdates
+            )
+        return RRuleSchedule(rrule=rruleset_string, timezone=timezone)
+
+    def to_rrule(self) -> Union[dateutil.rrule.rrule, dateutil.rrule.rruleset]:
         """
         Since rrule doesn't properly serialize/deserialize timezones, we localize dates
         here
@@ -211,51 +252,53 @@ def to_rrule(self) -> dateutil.rrule.rrule:
         )
         timezone = dateutil.tz.gettz(self.timezone)
         if isinstance(rrule, dateutil.rrule.rrule):
-            kwargs = dict(dtstart=rrule._dtstart.replace(tzinfo=timezone))
-            if rrule._until:
+            dtstart = _rrule_dt(rrule)
+            assert dtstart is not None
+            kwargs: dict[str, Any] = dict(dtstart=dtstart.replace(tzinfo=timezone))
+            if until := _rrule_dt(rrule, "_until"):
                 kwargs.update(
-                    until=rrule._until.replace(tzinfo=timezone),
+                    until=until.replace(tzinfo=timezone),
                 )
             return rrule.replace(**kwargs)
-        elif isinstance(rrule, dateutil.rrule.rruleset):
-            # update rrules
-            localized_rrules = []
-            for rr in rrule._rrule:
-                kwargs = dict(dtstart=rr._dtstart.replace(tzinfo=timezone))
-                if rr._until:
-                    kwargs.update(
-                        until=rr._until.replace(tzinfo=timezone),
-                    )
-                localized_rrules.append(rr.replace(**kwargs))
-            rrule._rrule = localized_rrules
-
-            # update exrules
-            localized_exrules = []
-            for exr in rrule._exrule:
-                kwargs = dict(dtstart=exr._dtstart.replace(tzinfo=timezone))
-                if exr._until:
-                    kwargs.update(
-                        until=exr._until.replace(tzinfo=timezone),
-                    )
-                localized_exrules.append(exr.replace(**kwargs))
-            rrule._exrule = localized_exrules
-
-            # update rdates
-            localized_rdates = []
-            for rd in rrule._rdate:
-                localized_rdates.append(rd.replace(tzinfo=timezone))
-            rrule._rdate = localized_rdates
-
-            # update exdates
-            localized_exdates = []
-            for exd in rrule._exdate:
-                localized_exdates.append(exd.replace(tzinfo=timezone))
-            rrule._exdate = localized_exdates
-
-            return rrule
+
+        # update rrules
+        localized_rrules: list[dateutil.rrule.rrule] = []
+        for rr in _rrule(rrule):
+            dtstart = _rrule_dt(rr)
+            assert dtstart is not None
+            kwargs: dict[str, Any] = dict(dtstart=dtstart.replace(tzinfo=timezone))
+            if until := _rrule_dt(rr, "_until"):
+                kwargs.update(until=until.replace(tzinfo=timezone))
+            localized_rrules.append(rr.replace(**kwargs))
+        setattr(rrule, "_rrule", localized_rrules)
+
+        # update exrules
+        localized_exrules: list[dateutil.rrule.rrule] = []
+        for exr in _rrule(rrule, "_exrule"):
+            dtstart = _rrule_dt(exr)
+            assert dtstart is not None
+            kwargs = dict(dtstart=dtstart.replace(tzinfo=timezone))
+            if until := _rrule_dt(exr, "_until"):
+                kwargs.update(until=until.replace(tzinfo=timezone))
+            localized_exrules.append(exr.replace(**kwargs))
+        setattr(rrule, "_exrule", localized_exrules)
+
+        # update rdates
+        localized_rdates: list[datetime.datetime] = []
+        for rd in _rdates(rrule):
+            localized_rdates.append(rd.replace(tzinfo=timezone))
+        setattr(rrule, "_rdate", localized_rdates)
+
+        # update exdates
+        localized_exdates: list[datetime.datetime] = []
+        for exd in _rdates(rrule, "_exdate"):
+            localized_exdates.append(exd.replace(tzinfo=timezone))
+        setattr(rrule, "_exdate", localized_exdates)
+
+        return rrule
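The rewritten `from_rrule`/`to_rrule` pair works around one gap: dateutil keeps rrule state in private, unannotated attributes (`_dtstart`, `_rrule`, `_rdate`, ...). Reads therefore go through the small typed `getattr` helpers, and writes go through `setattr`, which keeps the deliberate runtime mutation out of the checker's view. A hedged, standalone illustration of that read/localize/write round trip on a plain rruleset, independent of this module:

import datetime

import dateutil.rrule
import dateutil.tz

UTC = dateutil.tz.gettz("UTC")

rs = dateutil.rrule.rruleset()
rs.rdate(datetime.datetime(2024, 1, 1, 9, 0))
rs.rdate(datetime.datetime(2024, 1, 2, 9, 0))

# Typed read with a safe default, the same shape as the _rdates helper:
rdates: list[datetime.datetime] = getattr(rs, "_rdate", [])

# Localize and write the list back via setattr, the same shape as
# to_rrule's updates; the mutation is identical at runtime but is not
# flagged as private-attribute access by the checker.
setattr(rs, "_rdate", [rd.replace(tzinfo=UTC) for rd in rdates])
print(list(rs))  # both dates, now timezone-aware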
 
     @field_validator("timezone")
-    def valid_timezone(cls, v):
+    def valid_timezone(cls, v: Optional[str]) -> str:
         """
         Validate that the provided timezone is a valid IANA timezone.
 
@@ -277,7 +320,7 @@ def valid_timezone(cls, v):
 
 
 class NoSchedule(PrefectBaseModel):
-    model_config = ConfigDict(extra="forbid")
+    model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
 
 
 SCHEDULE_TYPES: TypeAlias = Union[
@@ -326,7 +369,7 @@ def construct_schedule(
     if isinstance(interval, (int, float)):
         interval = datetime.timedelta(seconds=interval)
     if not anchor_date:
-        anchor_date = DateTime.now()
+        anchor_date = pendulum.DateTime.now()
     schedule = IntervalSchedule(
         interval=interval, anchor_date=anchor_date, timezone=timezone
     )
diff --git a/src/prefect/client/subscriptions.py b/src/prefect/client/subscriptions.py
index d13873e14b05..8e04b3735e8a 100644
--- a/src/prefect/client/subscriptions.py
+++ b/src/prefect/client/subscriptions.py
@@ -1,5 +1,7 @@
 import asyncio
-from typing import Any, Dict, Generic, Iterable, Optional, Type, TypeVar
+from collections.abc import Iterable
+from logging import Logger
+from typing import Any, Generic, Optional, TypeVar
 
 import orjson
 import websockets
@@ -11,7 +13,7 @@
 from prefect.logging import get_logger
 from prefect.settings import PREFECT_API_KEY
 
-logger = get_logger(__name__)
+logger: Logger = get_logger(__name__)
 
 S = TypeVar("S", bound=IDBaseModel)
 
@@ -19,7 +21,7 @@ class Subscription(Generic[S]):
     def __init__(
         self,
-        model: Type[S],
+        model: type[S],
         path: str,
         keys: Iterable[str],
         client_id: Optional[str] = None,
@@ -28,9 +30,9 @@ def __init__(
         self.model = model
         self.client_id = client_id
         base_url = base_url.replace("http", "ws", 1) if base_url else None
-        self.subscription_url = f"{base_url}{path}"
+        self.subscription_url: str = f"{base_url}{path}"
 
-        self.keys = list(keys)
+        self.keys: list[str] = list(keys)
 
         self._connect = websockets.connect(
             self.subscription_url,
@@ -78,10 +80,10 @@ async def _ensure_connected(self):
                     ).decode()
                 )
 
-                auth: Dict[str, Any] = orjson.loads(await websocket.recv())
+                auth: dict[str, Any] = orjson.loads(await websocket.recv())
                 assert auth["type"] == "auth_success", auth.get("message")
 
-                message = {"type": "subscribe", "keys": self.keys}
+                message: dict[str, Any] = {"type": "subscribe", "keys": self.keys}
                 if self.client_id:
                     message.update({"client_id": self.client_id})
 
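The `Subscription` edits apply the same discipline to plain attributes: `logger: Logger`, `self.subscription_url: str`, and `self.keys: list[str]` pin the types down explicitly instead of leaning on inference. A toy class (not the real `Subscription`) showing the shape:

from collections.abc import Iterable
from logging import Logger, getLogger
from typing import Optional

logger: Logger = getLogger(__name__)  # explicit type, not left to inference


class Channel:
    def __init__(
        self, path: str, keys: Iterable[str], base_url: Optional[str] = None
    ) -> None:
        ws_base = base_url.replace("http", "ws", 1) if base_url else None
        # The f-string always yields str; the annotation records that
        # contract on the attribute itself for readers and checkers.
        self.url: str = f"{ws_base}{path}"
        # list(keys) infers list[str] too, but the annotation makes the
        # attribute's type part of the class's visible interface.
        self.keys: list[str] = list(keys)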
diff --git a/src/prefect/client/utilities.py b/src/prefect/client/utilities.py
index 81ff31199e6e..86e7be152f65 100644
--- a/src/prefect/client/utilities.py
+++ b/src/prefect/client/utilities.py
@@ -5,32 +5,31 @@
 # This module must not import from `prefect.client` when it is imported to avoid
 # circular imports for decorators such as `inject_client` which are widely used.
 
+from collections.abc import Awaitable, Coroutine
 from functools import wraps
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Awaitable,
-    Callable,
-    Coroutine,
-    Optional,
-    Tuple,
-    TypeVar,
-    Union,
-    cast,
-)
-
-from typing_extensions import Concatenate, ParamSpec
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+
+from typing_extensions import Concatenate, ParamSpec, TypeIs, TypeVar
 
 if TYPE_CHECKING:
     from prefect.client.orchestration import PrefectClient, SyncPrefectClient
 
 P = ParamSpec("P")
-R = TypeVar("R")
+R = TypeVar("R", infer_variance=True)
+
+
+def _current_async_client(
+    client: Union["PrefectClient", "SyncPrefectClient"],
+) -> TypeIs["PrefectClient"]:
+    from prefect._internal.concurrency.event_loop import get_running_loop
+
+    # Only a PrefectClient will have a _loop attribute that is the current loop
+    return getattr(client, "_loop", None) == get_running_loop()
 
 
 def get_or_create_client(
     client: Optional["PrefectClient"] = None,
-) -> Tuple[Union["PrefectClient", "SyncPrefectClient"], bool]:
+) -> tuple["PrefectClient", bool]:
     """
     Returns provided client, infers a client from context if available, or creates a new
     client.
@@ -42,29 +41,22 @@
     """
     if client is not None:
         return client, True
-    from prefect._internal.concurrency.event_loop import get_running_loop
+
     from prefect.context import AsyncClientContext, FlowRunContext, TaskRunContext
 
     async_client_context = AsyncClientContext.get()
     flow_run_context = FlowRunContext.get()
     task_run_context = TaskRunContext.get()
 
-    if async_client_context and async_client_context.client._loop == get_running_loop():  # type: ignore[reportPrivateUsage]
-        return async_client_context.client, True
-    elif (
-        flow_run_context
-        and getattr(flow_run_context.client, "_loop", None) == get_running_loop()
-    ):
-        return flow_run_context.client, True
-    elif (
-        task_run_context
-        and getattr(task_run_context.client, "_loop", None) == get_running_loop()
-    ):
-        return task_run_context.client, True
-    else:
-        from prefect.client.orchestration import get_client as get_httpx_client
+    for context in (async_client_context, flow_run_context, task_run_context):
+        if context is None:
+            continue
+        if _current_async_client(context_client := context.client):
+            return context_client, True
+
+    from prefect.client.orchestration import get_client as get_httpx_client
 
-        return get_httpx_client(), False
+    return get_httpx_client(), False
 
 
 def client_injector(
@@ -73,7 +65,7 @@ def client_injector(
     @wraps(func)
     async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
         client, _ = get_or_create_client()
-        return await func(cast("PrefectClient", client), *args, **kwargs)
+        return await func(client, *args, **kwargs)
 
     return wrapper
 
@@ -91,18 +83,18 @@ def inject_client(
 
     @wraps(fn)
     async def with_injected_client(*args: P.args, **kwargs: P.kwargs) -> R:
-        client, inferred = get_or_create_client(
-            cast(Optional["PrefectClient"], kwargs.pop("client", None))
-        )
-        _client = cast("PrefectClient", client)
+        given = kwargs.pop("client", None)
+        if TYPE_CHECKING:
+            assert given is None or isinstance(given, PrefectClient)
+        client, inferred = get_or_create_client(given)
         if not inferred:
-            context = _client
+            context = client
         else:
             from prefect.utilities.asyncutils import asyncnullcontext
 
-            context = asyncnullcontext()
+            context = asyncnullcontext(client)
         async with context as new_client:
-            kwargs.setdefault("client", new_client or _client)
+            kwargs |= {"client": new_client}
             return await fn(*args, **kwargs)
 
     return with_injected_client
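`TypeIs` rather than `TypeGuard` is what keeps the new lookup loop in `get_or_create_client` honest: a `TypeIs` predicate narrows in both branches, so the checker knows the client is a `PrefectClient` when the test passes and knows what it is not otherwise. A standalone sketch with stand-in classes, not Prefect's real types:

from typing import Union

from typing_extensions import TypeIs, reveal_type


class AsyncClient: ...


class SyncClient: ...


def is_async(client: Union[AsyncClient, SyncClient]) -> TypeIs[AsyncClient]:
    return isinstance(client, AsyncClient)


def use(client: Union[AsyncClient, SyncClient]) -> None:
    if is_async(client):
        reveal_type(client)  # AsyncClient
    else:
        reveal_type(client)  # SyncClient -- TypeGuard would not narrow here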
diff --git a/src/prefect/main.py b/src/prefect/main.py
index 4fea3999e2ad..0d56990c829d 100644
--- a/src/prefect/main.py
+++ b/src/prefect/main.py
@@ -1,4 +1,6 @@
 # Import user-facing API
+from typing import Any
+
 from prefect.deployments import deploy
 from prefect.states import State
 from prefect.logging import get_run_logger
@@ -25,28 +27,17 @@
 # Perform any forward-ref updates needed for Pydantic models
 import prefect.client.schemas
 
-prefect.context.FlowRunContext.model_rebuild(
-    _types_namespace={
-        "Flow": Flow,
-        "BaseResult": BaseResult,
-        "ResultRecordMetadata": ResultRecordMetadata,
-    }
-)
-prefect.context.TaskRunContext.model_rebuild(
-    _types_namespace={"Task": Task, "BaseResult": BaseResult}
-)
-prefect.client.schemas.State.model_rebuild(
-    _types_namespace={
-        "BaseResult": BaseResult,
-        "ResultRecordMetadata": ResultRecordMetadata,
-    }
-)
-prefect.client.schemas.StateCreate.model_rebuild(
-    _types_namespace={
-        "BaseResult": BaseResult,
-        "ResultRecordMetadata": ResultRecordMetadata,
-    }
+_types: dict[str, Any] = dict(
+    Task=Task,
+    Flow=Flow,
+    BaseResult=BaseResult,
+    ResultRecordMetadata=ResultRecordMetadata,
 )
+prefect.context.FlowRunContext.model_rebuild(_types_namespace=_types)
+prefect.context.TaskRunContext.model_rebuild(_types_namespace=_types)
+prefect.client.schemas.State.model_rebuild(_types_namespace=_types)
+prefect.client.schemas.StateCreate.model_rebuild(_types_namespace=_types)
+prefect.client.schemas.OrchestrationResult.model_rebuild(_types_namespace=_types)
 
 Transaction.model_rebuild()
 
 # Configure logging
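The consolidation into one `_types` dict works because `model_rebuild` only looks up the names a model's forward references actually use, so a single superset namespace can serve every model. A minimal sketch of the mechanism with toy models, not Prefect's:

from typing import Any, Optional

from pydantic import BaseModel


class Node(BaseModel):
    child: Optional["Leaf"] = None  # forward reference, unresolved here


class Leaf(BaseModel):
    value: int


# One shared namespace; names a given model does not reference are ignored.
_types: dict[str, Any] = dict(Leaf=Leaf, Unused=int)
Node.model_rebuild(_types_namespace=_types)
print(Node(child=Leaf(value=1)))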