Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[typing] prefect._internal, prefect.server.utilities #16497

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions src/prefect/_internal/_logging.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
import logging
import sys

from typing_extensions import Self

if sys.version_info < (3, 11):

def getLevelNamesMapping() -> dict[str, int]:
return getattr(logging, "_nameToLevel").copy()
else:
getLevelNamesMapping = logging.getLevelNamesMapping # novermin


class SafeLogger(logging.Logger):
Expand All @@ -11,11 +21,13 @@ def isEnabledFor(self, level: int):
# deadlocks during complex concurrency handling
from prefect.settings import PREFECT_LOGGING_INTERNAL_LEVEL

return level >= logging._nameToLevel[PREFECT_LOGGING_INTERNAL_LEVEL.value()]
internal_level = getLevelNamesMapping()[PREFECT_LOGGING_INTERNAL_LEVEL.value()]

return level >= internal_level

def getChild(self, suffix: str):
def getChild(self, suffix: str) -> Self:
logger = super().getChild(suffix)
logger.__class__ = SafeLogger
logger.__class__ = self.__class__
return logger


Expand Down
38 changes: 22 additions & 16 deletions src/prefect/_internal/compatibility/async_dispatch.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import asyncio
import inspect
from collections.abc import Coroutine
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Coroutine, Optional, TypeVar, Union
from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, cast

from typing_extensions import ParamSpec

if TYPE_CHECKING:
from prefect.flows import Flow
from prefect.tasks import Task

R = TypeVar("R")
Expand Down Expand Up @@ -41,7 +43,9 @@ def is_in_async_context() -> bool:


def _is_acceptable_callable(
obj: Union[Callable[P, R], "Task[P, R]", classmethod],
obj: Union[
Callable[P, R], "Flow[P, R]", "Task[P, R]", "classmethod[type[Any], P, R]"
],
) -> bool:
if inspect.iscoroutinefunction(obj):
return True
Expand All @@ -58,35 +62,37 @@ def _is_acceptable_callable(


def async_dispatch(
async_impl: Callable[P, Coroutine[Any, Any, R]],
async_impl: Union[
Callable[P, Coroutine[Any, Any, R]],
"classmethod[type[Any], P, Coroutine[Any, Any, R]]",
],
) -> Callable[[Callable[P, R]], Callable[P, Union[R, Coroutine[Any, Any, R]]]]:
"""
Decorator that dispatches to either sync or async implementation based on context.

Args:
async_impl: The async implementation to dispatch to when in async context
"""
if not _is_acceptable_callable(async_impl):
raise TypeError("async_impl must be an async function")
if isinstance(async_impl, classmethod):
async_impl = cast(Callable[P, Coroutine[Any, Any, R]], async_impl.__func__)

def decorator(
sync_fn: Callable[P, R],
) -> Callable[P, Union[R, Coroutine[Any, Any, R]]]:
if not _is_acceptable_callable(async_impl):
raise TypeError("async_impl must be an async function")

@wraps(sync_fn)
def wrapper(
*args: P.args,
_sync: Optional[bool] = None, # type: ignore
**kwargs: P.kwargs,
) -> Union[R, Coroutine[Any, Any, R]]:
should_run_sync = _sync if _sync is not None else not is_in_async_context()

if should_run_sync:
return sync_fn(*args, **kwargs)
if isinstance(async_impl, classmethod):
return async_impl.__func__(*args, **kwargs)
return async_impl(*args, **kwargs)

return wrapper # type: ignore
_sync = kwargs.pop("_sync", None)
should_run_sync = (
bool(_sync) if _sync is not None else not is_in_async_context()
)
fn = sync_fn if should_run_sync else async_impl
return fn(*args, **kwargs)

return wrapper

return decorator
58 changes: 41 additions & 17 deletions src/prefect/_internal/compatibility/deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
import functools
import sys
import warnings
from typing import Any, Callable, List, Optional, Type, TypeVar
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

import pendulum
from pydantic import BaseModel
from typing_extensions import ParamSpec, TypeAlias, TypeVar

from prefect.utilities.callables import get_call_parameters
from prefect.utilities.importtools import (
Expand All @@ -25,8 +26,10 @@
to_qualified_name,
)

T = TypeVar("T", bound=Callable)
P = ParamSpec("P")
R = TypeVar("R", infer_variance=True)
M = TypeVar("M", bound=BaseModel)
T = TypeVar("T")


DEPRECATED_WARNING = (
Expand All @@ -38,7 +41,7 @@
"path after {end_date}. {help}"
)
DEPRECATED_DATEFMT = "MMM YYYY" # e.g. Feb 2023
DEPRECATED_MODULE_ALIASES: List[AliasedModuleDefinition] = []
DEPRECATED_MODULE_ALIASES: list[AliasedModuleDefinition] = []


class PrefectDeprecationWarning(DeprecationWarning):
Expand All @@ -61,6 +64,8 @@ def generate_deprecation_message(
)

if not end_date:
if TYPE_CHECKING:
assert start_date is not None
parsed_start_date = pendulum.from_format(start_date, DEPRECATED_DATEFMT)
parsed_end_date = parsed_start_date.add(months=6)
end_date = parsed_end_date.format(DEPRECATED_DATEFMT)
Expand All @@ -83,8 +88,8 @@ def deprecated_callable(
end_date: Optional[str] = None,
stacklevel: int = 2,
help: str = "",
) -> Callable[[T], T]:
def decorator(fn: T):
) -> Callable[[Callable[P, R]], Callable[P, R]]:
def decorator(fn: Callable[P, R]) -> Callable[P, R]:
message = generate_deprecation_message(
name=to_qualified_name(fn),
start_date=start_date,
Expand All @@ -93,7 +98,7 @@ def decorator(fn: T):
)

@functools.wraps(fn)
def wrapper(*args, **kwargs):
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
warnings.warn(message, PrefectDeprecationWarning, stacklevel=stacklevel)
return fn(*args, **kwargs)

Expand All @@ -108,8 +113,8 @@ def deprecated_class(
end_date: Optional[str] = None,
stacklevel: int = 2,
help: str = "",
) -> Callable[[T], T]:
def decorator(cls: T):
) -> Callable[[type[T]], type[T]]:
def decorator(cls: type[T]) -> type[T]:
message = generate_deprecation_message(
name=to_qualified_name(cls),
start_date=start_date,
Expand All @@ -120,7 +125,7 @@ def decorator(cls: T):
original_init = cls.__init__

@functools.wraps(original_init)
def new_init(self, *args, **kwargs):
def new_init(self: T, *args: Any, **kwargs: Any) -> None:
warnings.warn(message, PrefectDeprecationWarning, stacklevel=stacklevel)
original_init(self, *args, **kwargs)

Expand All @@ -139,7 +144,7 @@ def deprecated_parameter(
help: str = "",
when: Optional[Callable[[Any], bool]] = None,
when_message: str = "",
) -> Callable[[T], T]:
) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""
Mark a parameter in a callable as deprecated.
Expand All @@ -155,7 +160,7 @@ def foo(x, y = None):

when = when or (lambda _: True)

def decorator(fn: T):
def decorator(fn: Callable[P, R]) -> Callable[P, R]:
message = generate_deprecation_message(
name=f"The parameter {name!r} for {fn.__name__!r}",
start_date=start_date,
Expand All @@ -165,7 +170,7 @@ def decorator(fn: T):
)

@functools.wraps(fn)
def wrapper(*args, **kwargs):
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
try:
parameters = get_call_parameters(fn, args, kwargs, apply_defaults=False)
except Exception:
Expand All @@ -182,6 +187,10 @@ def wrapper(*args, **kwargs):
return decorator


JsonValue: TypeAlias = Union[int, float, str, bool, None, list["JsonValue"], "JsonDict"]
JsonDict: TypeAlias = dict[str, JsonValue]


def deprecated_field(
name: str,
*,
Expand All @@ -191,7 +200,7 @@ def deprecated_field(
help: str = "",
when: Optional[Callable[[Any], bool]] = None,
stacklevel: int = 2,
):
) -> Callable[[type[M]], type[M]]:
"""
Mark a field in a Pydantic model as deprecated.
Expand All @@ -212,7 +221,7 @@ class Model(BaseModel)

# Replaces the model's __init__ method with one that performs an additional warning
# check
def decorator(model_cls: Type[M]) -> Type[M]:
def decorator(model_cls: type[M]) -> type[M]:
message = generate_deprecation_message(
name=f"The field {name!r} in {model_cls.__name__!r}",
start_date=start_date,
Expand All @@ -224,16 +233,31 @@ def decorator(model_cls: Type[M]) -> Type[M]:
cls_init = model_cls.__init__

@functools.wraps(model_cls.__init__)
def __init__(__pydantic_self__, **data: Any) -> None:
def __init__(__pydantic_self__: M, **data: Any) -> None:
if name in data.keys() and when(data[name]):
warnings.warn(message, PrefectDeprecationWarning, stacklevel=stacklevel)

cls_init(__pydantic_self__, **data)

field = __pydantic_self__.model_fields.get(name)
if field is not None:
field.json_schema_extra = field.json_schema_extra or {}
field.json_schema_extra["deprecated"] = True
json_schema_extra = field.json_schema_extra or {}

if not isinstance(json_schema_extra, dict):
# json_schema_extra is a hook function; wrap it to add the deprecated flag.
extra_func = json_schema_extra

@functools.wraps(extra_func)
def wrapped(__json_schema: JsonDict) -> None:
extra_func(__json_schema)
__json_schema["deprecated"] = True

json_schema_extra = wrapped

else:
json_schema_extra["deprecated"] = True

field.json_schema_extra = json_schema_extra

# Patch the model's init method
model_cls.__init__ = __init__
Expand Down
4 changes: 2 additions & 2 deletions src/prefect/_internal/compatibility/migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
"""

import sys
from typing import Any, Callable, Dict
from typing import Any, Callable

from pydantic_core import PydanticCustomError

Expand Down Expand Up @@ -157,7 +157,7 @@ def wrapper(name: str) -> object:
f"`{import_path}` has been removed. {error_message}"
)

globals: Dict[str, Any] = sys.modules[module_name].__dict__
globals: dict[str, Any] = sys.modules[module_name].__dict__
if name in globals:
return globals[name]

Expand Down
26 changes: 12 additions & 14 deletions src/prefect/_internal/concurrency/inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import sys
import threading
from types import FrameType
from typing import List, Optional

"""
The following functions are derived from dask/distributed which is licensed under the
Expand Down Expand Up @@ -72,26 +71,25 @@ def repr_frame(frame: FrameType) -> str:
return text + "\n\t" + line


def call_stack(frame: FrameType) -> List[str]:
def call_stack(frame: FrameType) -> list[str]:
"""Create a call text stack from a frame"""
L = []
cur_frame: Optional[FrameType] = frame
frames: list[str] = []
cur_frame = frame
while cur_frame:
L.append(repr_frame(cur_frame))
frames.append(repr_frame(cur_frame))
cur_frame = cur_frame.f_back
return L[::-1]
return frames[::-1]


def stack_for_threads(*threads: threading.Thread) -> List[str]:
frames = sys._current_frames()
def stack_for_threads(*threads: threading.Thread) -> list[str]:
frames = sys._current_frames() # pyright: ignore[reportPrivateUsage]
try:
lines = []
lines: list[str] = []
for thread in threads:
lines.append(
f"------ Call stack of {thread.name} ({hex(thread.ident)}) -----"
)
thread_frames = frames.get(thread.ident)
if thread_frames:
ident = thread.ident
hex_ident = hex(ident) if ident is not None else "<unknown>"
lines.append(f"------ Call stack of {thread.name} ({hex_ident}) -----")
if ident is not None and (thread_frames := frames.get(ident)):
lines.append("".join(call_stack(thread_frames)))
else:
lines.append("No stack frames found")
Expand Down
4 changes: 2 additions & 2 deletions src/prefect/_internal/concurrency/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class Event:
"""

def __init__(self) -> None:
self._waiters = collections.deque()
self._waiters: collections.deque[asyncio.Future[bool]] = collections.deque()
self._value = False
self._lock = threading.Lock()

Expand Down Expand Up @@ -69,7 +69,7 @@ async def wait(self) -> Literal[True]:
if self._value:
return True

fut = asyncio.get_running_loop().create_future()
fut: asyncio.Future[bool] = asyncio.get_running_loop().create_future()
self._waiters.append(fut)

try:
Expand Down
Loading
Loading