Skip to content

Commit

Permalink
[typing] prefect.client
Browse files Browse the repository at this point in the history
Code now passes pyright checking in strict mode.
  • Loading branch information
mjpieters committed Dec 7, 2024
1 parent 1c30778 commit 735ecb1
Show file tree
Hide file tree
Showing 13 changed files with 817 additions and 697 deletions.
3 changes: 2 additions & 1 deletion src/prefect/_internal/schemas/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from copy import copy
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union
from uuid import UUID

import jsonschema
import pendulum
Expand Down Expand Up @@ -653,7 +654,7 @@ def validate_message_template_variables(v: Optional[str]) -> Optional[str]:
return v


def validate_default_queue_id_not_none(v: Optional[str]) -> Optional[str]:
def validate_default_queue_id_not_none(v: Optional[UUID]) -> UUID:
if v is None:
raise ValueError(
"`default_queue_id` is a required field. If you are "
Expand Down
4 changes: 3 additions & 1 deletion src/prefect/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
</div>
"""

from collections.abc import Callable
from typing import Any
from prefect._internal.compatibility.migration import getattr_migration

__getattr__ = getattr_migration(__name__)
__getattr__: Callable[[str], Any] = getattr_migration(__name__)
55 changes: 27 additions & 28 deletions src/prefect/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,11 @@
import time
import uuid
from collections import defaultdict
from collections.abc import AsyncGenerator, Awaitable, MutableMapping
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from typing import (
Any,
AsyncGenerator,
Awaitable,
Callable,
Dict,
MutableMapping,
Optional,
Protocol,
Set,
Tuple,
Type,
runtime_checkable,
)
from logging import Logger
from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, runtime_checkable

import anyio
import httpx
Expand All @@ -46,14 +35,14 @@

# Datastores for lifespan management, keys should be a tuple of thread and app
# identities.
APP_LIFESPANS: Dict[Tuple[int, int], LifespanManager] = {}
APP_LIFESPANS_REF_COUNTS: Dict[Tuple[int, int], int] = {}
APP_LIFESPANS: dict[tuple[int, int], LifespanManager] = {}
APP_LIFESPANS_REF_COUNTS: dict[tuple[int, int], int] = {}
# Blocks concurrent access to the above dicts per thread. The index should be the thread
# identity.
APP_LIFESPANS_LOCKS: Dict[int, anyio.Lock] = defaultdict(anyio.Lock)
APP_LIFESPANS_LOCKS: dict[int, anyio.Lock] = defaultdict(anyio.Lock)


logger = get_logger("client")
logger: Logger = get_logger("client")


# Define ASGI application types for type checking
Expand Down Expand Up @@ -174,9 +163,9 @@ def raise_for_status(self) -> Response:
raise PrefectHTTPStatusError.from_httpx_error(exc) from exc.__cause__

@classmethod
def from_httpx_response(cls: Type[Self], response: httpx.Response) -> Response:
def from_httpx_response(cls: type[Self], response: httpx.Response) -> Response:
"""
Create a `PrefectReponse` from an `httpx.Response`.
Create a `PrefectResponse` from an `httpx.Response`.
By changing the `__class__` attribute of the Response, we change the method
resolution order to look for methods defined in PrefectResponse, while leaving
Expand Down Expand Up @@ -222,10 +211,10 @@ async def _send_with_retry(
self,
request: Request,
send: Callable[[Request], Awaitable[Response]],
send_args: Tuple[Any, ...],
send_kwargs: Dict[str, Any],
retry_codes: Set[int] = set(),
retry_exceptions: Tuple[Type[Exception], ...] = tuple(),
send_args: tuple[Any, ...],
send_kwargs: dict[str, Any],
retry_codes: set[int] = set(),
retry_exceptions: tuple[type[Exception], ...] = tuple(),
):
"""
Send a request and retry it if it fails.
Expand All @@ -240,6 +229,11 @@ async def _send_with_retry(
try_count = 0
response = None

if TYPE_CHECKING:
# older httpx versions type method as str | bytes | Unknown
# but in reality it is always a string.
assert isinstance(request.method, str) # type: ignore

is_change_request = request.method.lower() in {"post", "put", "patch", "delete"}

if self.enable_csrf_support and is_change_request:
Expand Down Expand Up @@ -436,10 +430,10 @@ def _send_with_retry(
self,
request: Request,
send: Callable[[Request], Response],
send_args: Tuple[Any, ...],
send_kwargs: Dict[str, Any],
retry_codes: Set[int] = set(),
retry_exceptions: Tuple[Type[Exception], ...] = tuple(),
send_args: tuple[Any, ...],
send_kwargs: dict[str, Any],
retry_codes: set[int] = set(),
retry_exceptions: tuple[type[Exception], ...] = tuple(),
):
"""
Send a request and retry it if it fails.
Expand All @@ -454,6 +448,11 @@ def _send_with_retry(
try_count = 0
response = None

if TYPE_CHECKING:
# older httpx versions type method as str | bytes | Unknown
# but in reality it is always a string.
assert isinstance(request.method, str) # type: ignore

is_change_request = request.method.lower() in {"post", "put", "patch", "delete"}

if self.enable_csrf_support and is_change_request:
Expand Down
36 changes: 20 additions & 16 deletions src/prefect/client/cloud.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import re
from typing import Any, Dict, List, Optional, cast
from typing import Any, NoReturn, Optional, cast
from uuid import UUID

import anyio
import httpx
import pydantic
from starlette import status
from typing_extensions import Self

import prefect.context
import prefect.settings
Expand All @@ -30,7 +31,7 @@
def get_cloud_client(
host: Optional[str] = None,
api_key: Optional[str] = None,
httpx_settings: Optional[Dict[str, Any]] = None,
httpx_settings: Optional[dict[str, Any]] = None,
infer_cloud_url: bool = False,
) -> "CloudClient":
"""
Expand Down Expand Up @@ -62,11 +63,14 @@ class CloudUnauthorizedError(PrefectException):


class CloudClient:
account_id: Optional[str] = None
workspace_id: Optional[str] = None

def __init__(
self,
host: str,
api_key: str,
httpx_settings: Optional[Dict[str, Any]] = None,
httpx_settings: Optional[dict[str, Any]] = None,
) -> None:
httpx_settings = httpx_settings or dict()
httpx_settings.setdefault("headers", dict())
Expand All @@ -79,7 +83,7 @@ def __init__(
**httpx_settings, enable_csrf_support=False
)

api_url = prefect.settings.PREFECT_API_URL.value() or ""
api_url: str = prefect.settings.PREFECT_API_URL.value() or ""
if match := (
re.search(PARSE_API_URL_REGEX, host)
or re.search(PARSE_API_URL_REGEX, api_url)
Expand All @@ -100,7 +104,7 @@ def workspace_base_url(self) -> str:

return f"{self.account_base_url}/workspaces/{self.workspace_id}"

async def api_healthcheck(self):
async def api_healthcheck(self) -> None:
"""
Attempts to connect to the Cloud API and raises the encountered exception if not
successful.
Expand All @@ -110,8 +114,8 @@ async def api_healthcheck(self):
with anyio.fail_after(10):
await self.read_workspaces()

async def read_workspaces(self) -> List[Workspace]:
workspaces = pydantic.TypeAdapter(List[Workspace]).validate_python(
async def read_workspaces(self) -> list[Workspace]:
workspaces = pydantic.TypeAdapter(list[Workspace]).validate_python(
await self.get("/me/workspaces")
)
return workspaces
Expand All @@ -124,17 +128,17 @@ async def read_current_workspace(self) -> Workspace:
return workspace
raise ValueError("Current workspace not found")

async def read_worker_metadata(self) -> Dict[str, Any]:
async def read_worker_metadata(self) -> dict[str, Any]:
response = await self.get(
f"{self.workspace_base_url}/collections/work_pool_types"
)
return cast(Dict[str, Any], response)
return cast(dict[str, Any], response)

async def read_account_settings(self) -> Dict[str, Any]:
async def read_account_settings(self) -> dict[str, Any]:
response = await self.get(f"{self.account_base_url}/settings")
return cast(Dict[str, Any], response)
return cast(dict[str, Any], response)

async def update_account_settings(self, settings: Dict[str, Any]):
async def update_account_settings(self, settings: dict[str, Any]) -> None:
await self.request(
"PATCH",
f"{self.account_base_url}/settings",
Expand All @@ -145,7 +149,7 @@ async def read_account_ip_allowlist(self) -> IPAllowlist:
response = await self.get(f"{self.account_base_url}/ip_allowlist")
return IPAllowlist.model_validate(response)

async def update_account_ip_allowlist(self, updated_allowlist: IPAllowlist):
async def update_account_ip_allowlist(self, updated_allowlist: IPAllowlist) -> None:
await self.request(
"PUT",
f"{self.account_base_url}/ip_allowlist",
Expand Down Expand Up @@ -175,20 +179,20 @@ async def update_flow_run_labels(
json=labels,
)

async def __aenter__(self):
async def __aenter__(self) -> Self:
await self._client.__aenter__()
return self

async def __aexit__(self, *exc_info: Any) -> None:
return await self._client.__aexit__(*exc_info)

def __enter__(self):
def __enter__(self) -> NoReturn:
raise RuntimeError(
"The `CloudClient` must be entered with an async context. Use 'async "
"with CloudClient(...)' not 'with CloudClient(...)'"
)

def __exit__(self, *_):
def __exit__(self, *_: object) -> NoReturn:
assert False, "This should never be called but must be defined for __enter__"

async def get(self, route: str, **kwargs: Any) -> Any:
Expand Down
Loading

0 comments on commit 735ecb1

Please sign in to comment.