Skip to content

Commit

Permalink
feat(mypy): more overloads on ASyncDecorator (#264)
Browse files Browse the repository at this point in the history
  • Loading branch information
BobTheBuidler authored Apr 30, 2024
1 parent c2471fa commit 7a6fdbc
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
9 changes: 7 additions & 2 deletions a_sync/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@

from typing_extensions import Concatenate, ParamSpec, Self, Unpack

if TYPE_CHECKING:
from a_sync import ASyncGenericBase
B = TypeVar("B", bound=ASyncGenericBase)
else:
B = TypeVar("B")

T = TypeVar("T")
K = TypeVar("K")
Expand Down Expand Up @@ -37,10 +42,10 @@ class SyncBoundMethod(Protocol[I, P, T]):

@runtime_checkable
class AsyncUnboundMethod(Protocol[I, P, T]):
__get__: Callable[[I, None], CoroBoundMethod[I, P, T]]
__get__: Callable[[I, Type], CoroBoundMethod[I, P, T]]
@runtime_checkable
class SyncUnboundMethod(Protocol[I, P, T]):
__get__: Callable[[I, None], SyncBoundMethod[I, P, T]]
__get__: Callable[[I, Type], SyncBoundMethod[I, P, T]]
AnyUnboundMethod = Union[AsyncUnboundMethod[I, P, T], SyncUnboundMethod[I, P, T]]

AsyncGetterFunction = Callable[[I], Awaitable[T]]
Expand Down
12 changes: 10 additions & 2 deletions a_sync/a_sync/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

if TYPE_CHECKING:
from a_sync import TaskMapping
from a_sync.a_sync.method import (ASyncBoundMethod, ASyncBoundMethodAsyncDefault,
ASyncBoundMethodSyncDefault)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -176,10 +178,10 @@ def validate_inputs(self) -> None:
raise ValueError(f"'default' must be either 'sync', 'async', or None. You passed {self.modifiers.default}.")

@overload
def __call__(self, func: CoroFn[P, T]) -> ASyncFunction[P, T]: # type: ignore [override]
def __call__(self, func: AnyFn[Concatenate[B, P], T]) -> "ASyncBoundMethod[B, P, T]": # type: ignore [override]
...
@overload
def __call__(self, func: SyncFn[P, T]) -> ASyncFunction[P, T]: # type: ignore [override]
def __call__(self, func: AnyFn[P, T]) -> ASyncFunction[P, T]: # type: ignore [override]
...
def __call__(self, func: AnyFn[P, T]) -> ASyncFunction[P, T]: # type: ignore [override]
if self.default == "async":
Expand Down Expand Up @@ -234,6 +236,9 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> MaybeAwaitable[T]:

class ASyncDecoratorSyncDefault(ASyncDecorator):
@overload
def __call__(self, func: AnyFn[Concatenate[B, P], T]) -> "ASyncBoundMethodSyncDefault[P, T]": # type: ignore [override]
...
@overload
def __call__(self, func: AnyBoundMethod[P, T]) -> ASyncFunctionSyncDefault[P, T]: # type: ignore [override]
...
@overload
Expand All @@ -244,6 +249,9 @@ def __call__(self, func: AnyFn[P, T]) -> ASyncFunctionSyncDefault[P, T]:

class ASyncDecoratorAsyncDefault(ASyncDecorator):
@overload
def __call__(self, func: AnyFn[Concatenate[B, P], T]) -> "ASyncBoundMethodAsyncDefault[P, T]": # type: ignore [override]
...
@overload
def __call__(self, func: AnyBoundMethod[P, T]) -> ASyncFunctionAsyncDefault[P, T]: # type: ignore [override]
...
@overload
Expand Down

0 comments on commit 7a6fdbc

Please sign in to comment.