Skip to content

Commit

Permalink
[typing] prefect._internal, prefect.server.utilities (#16497)
Browse files Browse the repository at this point in the history
  • Loading branch information
mjpieters authored Dec 26, 2024
1 parent bafed7c commit 4386cb5
Show file tree
Hide file tree
Showing 22 changed files with 489 additions and 606 deletions.
18 changes: 15 additions & 3 deletions src/prefect/_internal/_logging.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
import logging
import sys

from typing_extensions import Self

if sys.version_info < (3, 11):

def getLevelNamesMapping() -> dict[str, int]:
return getattr(logging, "_nameToLevel").copy()
else:
getLevelNamesMapping = logging.getLevelNamesMapping # novermin


class SafeLogger(logging.Logger):
Expand All @@ -11,11 +21,13 @@ def isEnabledFor(self, level: int):
# deadlocks during complex concurrency handling
from prefect.settings import PREFECT_LOGGING_INTERNAL_LEVEL

return level >= logging._nameToLevel[PREFECT_LOGGING_INTERNAL_LEVEL.value()]
internal_level = getLevelNamesMapping()[PREFECT_LOGGING_INTERNAL_LEVEL.value()]

return level >= internal_level

def getChild(self, suffix: str):
def getChild(self, suffix: str) -> Self:
logger = super().getChild(suffix)
logger.__class__ = SafeLogger
logger.__class__ = self.__class__
return logger


Expand Down
38 changes: 22 additions & 16 deletions src/prefect/_internal/compatibility/async_dispatch.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import asyncio
import inspect
from collections.abc import Coroutine
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Coroutine, Optional, TypeVar, Union
from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, cast

from typing_extensions import ParamSpec

if TYPE_CHECKING:
from prefect.flows import Flow
from prefect.tasks import Task

R = TypeVar("R")
Expand Down Expand Up @@ -41,7 +43,9 @@ def is_in_async_context() -> bool:


def _is_acceptable_callable(
obj: Union[Callable[P, R], "Task[P, R]", classmethod],
obj: Union[
Callable[P, R], "Flow[P, R]", "Task[P, R]", "classmethod[type[Any], P, R]"
],
) -> bool:
if inspect.iscoroutinefunction(obj):
return True
Expand All @@ -58,35 +62,37 @@ def _is_acceptable_callable(


def async_dispatch(
async_impl: Callable[P, Coroutine[Any, Any, R]],
async_impl: Union[
Callable[P, Coroutine[Any, Any, R]],
"classmethod[type[Any], P, Coroutine[Any, Any, R]]",
],
) -> Callable[[Callable[P, R]], Callable[P, Union[R, Coroutine[Any, Any, R]]]]:
"""
Decorator that dispatches to either sync or async implementation based on context.
Args:
async_impl: The async implementation to dispatch to when in async context
"""
if not _is_acceptable_callable(async_impl):
raise TypeError("async_impl must be an async function")
if isinstance(async_impl, classmethod):
async_impl = cast(Callable[P, Coroutine[Any, Any, R]], async_impl.__func__)

def decorator(
sync_fn: Callable[P, R],
) -> Callable[P, Union[R, Coroutine[Any, Any, R]]]:
if not _is_acceptable_callable(async_impl):
raise TypeError("async_impl must be an async function")

@wraps(sync_fn)
def wrapper(
*args: P.args,
_sync: Optional[bool] = None, # type: ignore
**kwargs: P.kwargs,
) -> Union[R, Coroutine[Any, Any, R]]:
should_run_sync = _sync if _sync is not None else not is_in_async_context()

if should_run_sync:
return sync_fn(*args, **kwargs)
if isinstance(async_impl, classmethod):
return async_impl.__func__(*args, **kwargs)
return async_impl(*args, **kwargs)

return wrapper # type: ignore
_sync = kwargs.pop("_sync", None)
should_run_sync = (
bool(_sync) if _sync is not None else not is_in_async_context()
)
fn = sync_fn if should_run_sync else async_impl
return fn(*args, **kwargs)

return wrapper

return decorator
58 changes: 41 additions & 17 deletions src/prefect/_internal/compatibility/deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
import functools
import sys
import warnings
from typing import Any, Callable, List, Optional, Type, TypeVar
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

import pendulum
from pydantic import BaseModel
from typing_extensions import ParamSpec, TypeAlias, TypeVar

from prefect.utilities.callables import get_call_parameters
from prefect.utilities.importtools import (
Expand All @@ -25,8 +26,10 @@
to_qualified_name,
)

T = TypeVar("T", bound=Callable)
P = ParamSpec("P")
R = TypeVar("R", infer_variance=True)
M = TypeVar("M", bound=BaseModel)
T = TypeVar("T")


DEPRECATED_WARNING = (
Expand All @@ -38,7 +41,7 @@
"path after {end_date}. {help}"
)
DEPRECATED_DATEFMT = "MMM YYYY" # e.g. Feb 2023
DEPRECATED_MODULE_ALIASES: List[AliasedModuleDefinition] = []
DEPRECATED_MODULE_ALIASES: list[AliasedModuleDefinition] = []


class PrefectDeprecationWarning(DeprecationWarning):
Expand All @@ -61,6 +64,8 @@ def generate_deprecation_message(
)

if not end_date:
if TYPE_CHECKING:
assert start_date is not None
parsed_start_date = pendulum.from_format(start_date, DEPRECATED_DATEFMT)
parsed_end_date = parsed_start_date.add(months=6)
end_date = parsed_end_date.format(DEPRECATED_DATEFMT)
Expand All @@ -83,8 +88,8 @@ def deprecated_callable(
end_date: Optional[str] = None,
stacklevel: int = 2,
help: str = "",
) -> Callable[[T], T]:
def decorator(fn: T):
) -> Callable[[Callable[P, R]], Callable[P, R]]:
def decorator(fn: Callable[P, R]) -> Callable[P, R]:
message = generate_deprecation_message(
name=to_qualified_name(fn),
start_date=start_date,
Expand All @@ -93,7 +98,7 @@ def decorator(fn: T):
)

@functools.wraps(fn)
def wrapper(*args, **kwargs):
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
warnings.warn(message, PrefectDeprecationWarning, stacklevel=stacklevel)
return fn(*args, **kwargs)

Expand All @@ -108,8 +113,8 @@ def deprecated_class(
end_date: Optional[str] = None,
stacklevel: int = 2,
help: str = "",
) -> Callable[[T], T]:
def decorator(cls: T):
) -> Callable[[type[T]], type[T]]:
def decorator(cls: type[T]) -> type[T]:
message = generate_deprecation_message(
name=to_qualified_name(cls),
start_date=start_date,
Expand All @@ -120,7 +125,7 @@ def decorator(cls: T):
original_init = cls.__init__

@functools.wraps(original_init)
def new_init(self, *args, **kwargs):
def new_init(self: T, *args: Any, **kwargs: Any) -> None:
warnings.warn(message, PrefectDeprecationWarning, stacklevel=stacklevel)
original_init(self, *args, **kwargs)

Expand All @@ -139,7 +144,7 @@ def deprecated_parameter(
help: str = "",
when: Optional[Callable[[Any], bool]] = None,
when_message: str = "",
) -> Callable[[T], T]:
) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""
Mark a parameter in a callable as deprecated.
Expand All @@ -155,7 +160,7 @@ def foo(x, y = None):

when = when or (lambda _: True)

def decorator(fn: T):
def decorator(fn: Callable[P, R]) -> Callable[P, R]:
message = generate_deprecation_message(
name=f"The parameter {name!r} for {fn.__name__!r}",
start_date=start_date,
Expand All @@ -165,7 +170,7 @@ def decorator(fn: T):
)

@functools.wraps(fn)
def wrapper(*args, **kwargs):
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
try:
parameters = get_call_parameters(fn, args, kwargs, apply_defaults=False)
except Exception:
Expand All @@ -182,6 +187,10 @@ def wrapper(*args, **kwargs):
return decorator


JsonValue: TypeAlias = Union[int, float, str, bool, None, list["JsonValue"], "JsonDict"]
JsonDict: TypeAlias = dict[str, JsonValue]


def deprecated_field(
name: str,
*,
Expand All @@ -191,7 +200,7 @@ def deprecated_field(
help: str = "",
when: Optional[Callable[[Any], bool]] = None,
stacklevel: int = 2,
):
) -> Callable[[type[M]], type[M]]:
"""
Mark a field in a Pydantic model as deprecated.
Expand All @@ -212,7 +221,7 @@ class Model(BaseModel)

# Replaces the model's __init__ method with one that performs an additional warning
# check
def decorator(model_cls: Type[M]) -> Type[M]:
def decorator(model_cls: type[M]) -> type[M]:
message = generate_deprecation_message(
name=f"The field {name!r} in {model_cls.__name__!r}",
start_date=start_date,
Expand All @@ -224,16 +233,31 @@ def decorator(model_cls: Type[M]) -> Type[M]:
cls_init = model_cls.__init__

@functools.wraps(model_cls.__init__)
def __init__(__pydantic_self__, **data: Any) -> None:
def __init__(__pydantic_self__: M, **data: Any) -> None:
if name in data.keys() and when(data[name]):
warnings.warn(message, PrefectDeprecationWarning, stacklevel=stacklevel)

cls_init(__pydantic_self__, **data)

field = __pydantic_self__.model_fields.get(name)
if field is not None:
field.json_schema_extra = field.json_schema_extra or {}
field.json_schema_extra["deprecated"] = True
json_schema_extra = field.json_schema_extra or {}

if not isinstance(json_schema_extra, dict):
# json_schema_extra is a hook function; wrap it to add the deprecated flag.
extra_func = json_schema_extra

@functools.wraps(extra_func)
def wrapped(__json_schema: JsonDict) -> None:
extra_func(__json_schema)
__json_schema["deprecated"] = True

json_schema_extra = wrapped

else:
json_schema_extra["deprecated"] = True

field.json_schema_extra = json_schema_extra

# Patch the model's init method
model_cls.__init__ = __init__
Expand Down
4 changes: 2 additions & 2 deletions src/prefect/_internal/compatibility/migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
"""

import sys
from typing import Any, Callable, Dict
from typing import Any, Callable

from pydantic_core import PydanticCustomError

Expand Down Expand Up @@ -157,7 +157,7 @@ def wrapper(name: str) -> object:
f"`{import_path}` has been removed. {error_message}"
)

globals: Dict[str, Any] = sys.modules[module_name].__dict__
globals: dict[str, Any] = sys.modules[module_name].__dict__
if name in globals:
return globals[name]

Expand Down
26 changes: 12 additions & 14 deletions src/prefect/_internal/concurrency/inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import sys
import threading
from types import FrameType
from typing import List, Optional

"""
The following functions are derived from dask/distributed which is licensed under the
Expand Down Expand Up @@ -72,26 +71,25 @@ def repr_frame(frame: FrameType) -> str:
return text + "\n\t" + line


def call_stack(frame: FrameType) -> List[str]:
def call_stack(frame: FrameType) -> list[str]:
"""Create a call text stack from a frame"""
L = []
cur_frame: Optional[FrameType] = frame
frames: list[str] = []
cur_frame = frame
while cur_frame:
L.append(repr_frame(cur_frame))
frames.append(repr_frame(cur_frame))
cur_frame = cur_frame.f_back
return L[::-1]
return frames[::-1]


def stack_for_threads(*threads: threading.Thread) -> List[str]:
frames = sys._current_frames()
def stack_for_threads(*threads: threading.Thread) -> list[str]:
frames = sys._current_frames() # pyright: ignore[reportPrivateUsage]
try:
lines = []
lines: list[str] = []
for thread in threads:
lines.append(
f"------ Call stack of {thread.name} ({hex(thread.ident)}) -----"
)
thread_frames = frames.get(thread.ident)
if thread_frames:
ident = thread.ident
hex_ident = hex(ident) if ident is not None else "<unknown>"
lines.append(f"------ Call stack of {thread.name} ({hex_ident}) -----")
if ident is not None and (thread_frames := frames.get(ident)):
lines.append("".join(call_stack(thread_frames)))
else:
lines.append("No stack frames found")
Expand Down
4 changes: 2 additions & 2 deletions src/prefect/_internal/concurrency/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class Event:
"""

def __init__(self) -> None:
self._waiters = collections.deque()
self._waiters: collections.deque[asyncio.Future[bool]] = collections.deque()
self._value = False
self._lock = threading.Lock()

Expand Down Expand Up @@ -69,7 +69,7 @@ async def wait(self) -> Literal[True]:
if self._value:
return True

fut = asyncio.get_running_loop().create_future()
fut: asyncio.Future[bool] = asyncio.get_running_loop().create_future()
self._waiters.append(fut)

try:
Expand Down
Loading

0 comments on commit 4386cb5

Please sign in to comment.