Skip to content

Commit

Permalink
strict type client modules (#16223)
Browse files Browse the repository at this point in the history
  • Loading branch information
zzstoatzz authored Dec 5, 2024
1 parent be581bb commit 6308207
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 43 deletions.
32 changes: 16 additions & 16 deletions src/prefect/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ class PrefectResponse(httpx.Response):
Provides more informative error messages.
"""

def raise_for_status(self) -> None:
def raise_for_status(self) -> Response:
"""
Raise an exception if the response contains an HTTPStatusError.
Expand All @@ -174,7 +174,7 @@ def raise_for_status(self) -> None:
raise PrefectHTTPStatusError.from_httpx_error(exc) from exc.__cause__

@classmethod
def from_httpx_response(cls: Type[Self], response: httpx.Response) -> Self:
def from_httpx_response(cls: Type[Self], response: httpx.Response) -> Response:
"""
Create a `PrefectReponse` from an `httpx.Response`.
Expand All @@ -200,10 +200,10 @@ class PrefectHttpxAsyncClient(httpx.AsyncClient):

def __init__(
self,
*args,
*args: Any,
enable_csrf_support: bool = False,
raise_on_all_errors: bool = True,
**kwargs,
**kwargs: Any,
):
self.enable_csrf_support: bool = enable_csrf_support
self.csrf_token: Optional[str] = None
Expand All @@ -222,10 +222,10 @@ async def _send_with_retry(
self,
request: Request,
send: Callable[[Request], Awaitable[Response]],
send_args: Tuple,
send_kwargs: Dict,
send_args: Tuple[Any, ...],
send_kwargs: Dict[str, Any],
retry_codes: Set[int] = set(),
retry_exceptions: Tuple[Exception, ...] = tuple(),
retry_exceptions: Tuple[Type[Exception], ...] = tuple(),
):
"""
Send a request and retry it if it fails.
Expand Down Expand Up @@ -297,7 +297,7 @@ async def _send_with_retry(
if exc_info
else (
"Received response with retryable status code"
f" {response.status_code}. "
f" {response.status_code if response else 'unknown'}. "
)
)
+ f"Another attempt will be made in {retry_seconds}s. "
Expand All @@ -314,7 +314,7 @@ async def _send_with_retry(
# We ran out of retries, return the failed response
return response

async def send(self, request: Request, *args, **kwargs) -> Response:
async def send(self, request: Request, *args: Any, **kwargs: Any) -> Response:
"""
Send a request with automatic retry behavior for the following status codes:
Expand Down Expand Up @@ -414,10 +414,10 @@ class PrefectHttpxSyncClient(httpx.Client):

def __init__(
self,
*args,
*args: Any,
enable_csrf_support: bool = False,
raise_on_all_errors: bool = True,
**kwargs,
**kwargs: Any,
):
self.enable_csrf_support: bool = enable_csrf_support
self.csrf_token: Optional[str] = None
Expand All @@ -436,10 +436,10 @@ def _send_with_retry(
self,
request: Request,
send: Callable[[Request], Response],
send_args: Tuple,
send_kwargs: Dict,
send_args: Tuple[Any, ...],
send_kwargs: Dict[str, Any],
retry_codes: Set[int] = set(),
retry_exceptions: Tuple[Exception, ...] = tuple(),
retry_exceptions: Tuple[Type[Exception], ...] = tuple(),
):
"""
Send a request and retry it if it fails.
Expand Down Expand Up @@ -511,7 +511,7 @@ def _send_with_retry(
if exc_info
else (
"Received response with retryable status code"
f" {response.status_code}. "
f" {response.status_code if response else 'unknown'}. "
)
)
+ f"Another attempt will be made in {retry_seconds}s. "
Expand All @@ -528,7 +528,7 @@ def _send_with_retry(
# We ran out of retries, return the failed response
return response

def send(self, request: Request, *args, **kwargs) -> Response:
def send(self, request: Request, *args: Any, **kwargs: Any) -> Response:
"""
Send a request with automatic retry behavior for the following status codes:
Expand Down
11 changes: 7 additions & 4 deletions src/prefect/client/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
def get_cloud_client(
host: Optional[str] = None,
api_key: Optional[str] = None,
httpx_settings: Optional[dict] = None,
httpx_settings: Optional[Dict[str, Any]] = None,
infer_cloud_url: bool = False,
) -> "CloudClient":
"""
Expand All @@ -45,6 +45,9 @@ def get_cloud_client(
configured_url = prefect.settings.PREFECT_API_URL.value()
host = re.sub(PARSE_API_URL_REGEX, "", configured_url)

if host is None:
raise ValueError("Host was not provided and could not be inferred")

return CloudClient(
host=host,
api_key=api_key or PREFECT_API_KEY.value(),
Expand Down Expand Up @@ -176,7 +179,7 @@ async def __aenter__(self):
await self._client.__aenter__()
return self

async def __aexit__(self, *exc_info):
async def __aexit__(self, *exc_info: Any) -> None:
return await self._client.__aexit__(*exc_info)

def __enter__(self):
Expand All @@ -188,10 +191,10 @@ def __enter__(self):
def __exit__(self, *_):
assert False, "This should never be called but must be defined for __enter__"

async def get(self, route, **kwargs):
async def get(self, route: str, **kwargs: Any) -> Any:
return await self.request("GET", route, **kwargs)

async def request(self, method, route, **kwargs):
async def request(self, method: str, route: str, **kwargs: Any) -> Any:
try:
res = await self._client.request(method, route, **kwargs)
res.raise_for_status()
Expand Down
4 changes: 2 additions & 2 deletions src/prefect/client/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ async def read_worker_metadata(self) -> Dict[str, Any]:
async def __aenter__(self) -> "CollectionsMetadataClient":
...

async def __aexit__(self, *exc_info) -> Any:
async def __aexit__(self, *exc_info: Any) -> Any:
...


def get_collections_metadata_client(
httpx_settings: Optional[Dict] = None,
httpx_settings: Optional[Dict[str, Any]] = None,
) -> "CollectionsMetadataClient":
"""
Creates a client that can be used to fetch metadata for
Expand Down
26 changes: 19 additions & 7 deletions src/prefect/client/subscriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,27 +27,33 @@ def __init__(
):
self.model = model
self.client_id = client_id
base_url = base_url.replace("http", "ws", 1)
base_url = base_url.replace("http", "ws", 1) if base_url else None
self.subscription_url = f"{base_url}{path}"

self.keys = list(keys)

self._connect = websockets.connect(
self.subscription_url,
subprotocols=["prefect"],
subprotocols=[websockets.Subprotocol("prefect")],
)
self._websocket = None

def __aiter__(self) -> Self:
return self

@property
def websocket(self) -> websockets.WebSocketClientProtocol:
if not self._websocket:
raise RuntimeError("Subscription is not connected")
return self._websocket

async def __anext__(self) -> S:
while True:
try:
await self._ensure_connected()
message = await self._websocket.recv()
message = await self.websocket.recv()

await self._websocket.send(orjson.dumps({"type": "ack"}).decode())
await self.websocket.send(orjson.dumps({"type": "ack"}).decode())

return self.model.model_validate_json(message)
except (
Expand Down Expand Up @@ -84,13 +90,19 @@ async def _ensure_connected(self):
AssertionError,
websockets.exceptions.ConnectionClosedError,
) as e:
if isinstance(e, AssertionError) or e.rcvd.code == WS_1008_POLICY_VIOLATION:
if isinstance(e, AssertionError) or (
e.rcvd and e.rcvd.code == WS_1008_POLICY_VIOLATION
):
if isinstance(e, AssertionError):
reason = e.args[0]
elif isinstance(e, websockets.exceptions.ConnectionClosedError):
elif e.rcvd and e.rcvd.reason:
reason = e.rcvd.reason
else:
reason = "unknown"
else:
reason = None

if isinstance(e, AssertionError) or e.rcvd.code == WS_1008_POLICY_VIOLATION:
if reason:
raise Exception(
"Unable to authenticate to the subscription. Please "
"ensure the provided `PREFECT_API_KEY` you are using is "
Expand Down
19 changes: 11 additions & 8 deletions src/prefect/client/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,22 @@
Optional,
Tuple,
TypeVar,
Union,
cast,
)

from typing_extensions import Concatenate, ParamSpec

if TYPE_CHECKING:
from prefect.client.orchestration import PrefectClient
from prefect.client.orchestration import PrefectClient, SyncPrefectClient

P = ParamSpec("P")
R = TypeVar("R")


def get_or_create_client(
client: Optional["PrefectClient"] = None,
) -> Tuple["PrefectClient", bool]:
) -> Tuple[Union["PrefectClient", "SyncPrefectClient"], bool]:
"""
Returns provided client, infers a client from context if available, or creates a new client.
Expand All @@ -48,7 +49,7 @@ def get_or_create_client(
flow_run_context = FlowRunContext.get()
task_run_context = TaskRunContext.get()

if async_client_context and async_client_context.client._loop == get_running_loop():
if async_client_context and async_client_context.client._loop == get_running_loop(): # type: ignore[reportPrivateUsage]
return async_client_context.client, True
elif (
flow_run_context
Expand All @@ -72,7 +73,7 @@ def client_injector(
@wraps(func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
client, _ = get_or_create_client()
return await func(client, *args, **kwargs)
return await func(cast("PrefectClient", client), *args, **kwargs)

return wrapper

Expand All @@ -90,16 +91,18 @@ def inject_client(

@wraps(fn)
async def with_injected_client(*args: P.args, **kwargs: P.kwargs) -> R:
client = cast(Optional["PrefectClient"], kwargs.pop("client", None))
client, inferred = get_or_create_client(client)
client, inferred = get_or_create_client(
cast(Optional["PrefectClient"], kwargs.pop("client", None))
)
_client = cast("PrefectClient", client)
if not inferred:
context = client
context = _client
else:
from prefect.utilities.asyncutils import asyncnullcontext

context = asyncnullcontext()
async with context as new_client:
kwargs.setdefault("client", new_client or client)
kwargs.setdefault("client", new_client or _client)
return await fn(*args, **kwargs)

return with_injected_client
5 changes: 4 additions & 1 deletion src/prefect/utilities/asyncutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from functools import partial, wraps
from typing import (
Any,
AsyncGenerator,
Awaitable,
Callable,
Coroutine,
Expand Down Expand Up @@ -410,7 +411,9 @@ async def ctx_call():


@asynccontextmanager
async def asyncnullcontext(value=None, *args, **kwargs):
async def asyncnullcontext(
value: Optional[Any] = None, *args: Any, **kwargs: Any
) -> AsyncGenerator[Any, None]:
yield value


Expand Down
14 changes: 9 additions & 5 deletions src/prefect/utilities/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import random


def poisson_interval(average_interval, lower=0, upper=1):
def poisson_interval(
average_interval: float, lower: float = 0, upper: float = 1
) -> float:
"""
Generates an "inter-arrival time" for a Poisson process.
Expand All @@ -16,12 +18,12 @@ def poisson_interval(average_interval, lower=0, upper=1):
return -math.log(max(1 - random.uniform(lower, upper), 1e-10)) * average_interval


def exponential_cdf(x, average_interval):
def exponential_cdf(x: float, average_interval: float) -> float:
ld = 1 / average_interval
return 1 - math.exp(-ld * x)


def lower_clamp_multiple(k):
def lower_clamp_multiple(k: float) -> float:
"""
Computes a lower clamp multiple that can be used to bound a random variate drawn
from an exponential distribution.
Expand All @@ -38,7 +40,9 @@ def lower_clamp_multiple(k):
return math.log(max(2**k / (2**k - 1), 1e-10), 2)


def clamped_poisson_interval(average_interval, clamping_factor=0.3):
def clamped_poisson_interval(
average_interval: float, clamping_factor: float = 0.3
) -> float:
"""
Bounds Poisson "inter-arrival times" to a range defined by the clamping factor.
Expand All @@ -57,7 +61,7 @@ def clamped_poisson_interval(average_interval, clamping_factor=0.3):
return poisson_interval(average_interval, lower_rv, upper_rv)


def bounded_poisson_interval(lower_bound, upper_bound):
def bounded_poisson_interval(lower_bound: float, upper_bound: float) -> float:
"""
Bounds Poisson "inter-arrival times" to a range.
Expand Down

0 comments on commit 6308207

Please sign in to comment.