Skip to content

Commit

Permalink
tighten addon_toolkit types
Browse files Browse the repository at this point in the history
  • Loading branch information
aaxelb committed May 3, 2024
1 parent 59253d2 commit bce59e4
Show file tree
Hide file tree
Showing 13 changed files with 120 additions and 74 deletions.
2 changes: 1 addition & 1 deletion addon_imps/storage/box_dot_com.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def _params_from_cursor(self, cursor: str = "") -> dict[str, str]:
# https://developer.box.com/guides/api-calls/pagination/offset-based/
try:
_cursor = OffsetCursor.from_str(cursor)
return {"offset": _cursor.offset, "limit": _cursor.limit}
return {"offset": str(_cursor.offset), "limit": str(_cursor.limit)}
except ValueError:
return {}

Expand Down
2 changes: 1 addition & 1 deletion addon_service/common/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(

# abstract method from HttpRequestor:
@contextlib.asynccontextmanager
async def do_send(self, request: HttpRequestInfo):
async def send_request(self, request: HttpRequestInfo):
try:
async with self._try_send(request) as _response:
yield _response
Expand Down
22 changes: 11 additions & 11 deletions addon_toolkit/constrained_network/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,25 +61,25 @@ class HttpRequestor(typing.Protocol):
def response_info_cls(self) -> type[HttpResponseInfo]: ...

# abstract method for subclasses
def do_send(
def send_request(
self, request: HttpRequestInfo
) -> contextlib.AbstractAsyncContextManager[HttpResponseInfo]: ...

@contextlib.asynccontextmanager
async def request(
async def _request(
self,
http_method: HTTPMethod,
uri_path: str,
query: Multidict | KeyValuePairs | None = None,
headers: Multidict | KeyValuePairs | None = None,
):
) -> typing.Any: # loose type; method-specific methods below are more accurate
_request_info = HttpRequestInfo(
http_method=http_method,
uri_path=uri_path,
query=(query if isinstance(query, Multidict) else Multidict(query)),
headers=(headers if isinstance(headers, Multidict) else Multidict(headers)),
)
async with self.do_send(_request_info) as _response:
async with self.send_request(_request_info) as _response:
yield _response

# TODO: streaming send/receive (only if/when needed)
Expand All @@ -88,10 +88,10 @@ async def request(
# convenience methods for http methods
# (same call signature as self.request, minus `http_method`)

OPTIONS: _MethodRequestMethod = partialmethod(request, HTTPMethod.OPTIONS)
HEAD: _MethodRequestMethod = partialmethod(request, HTTPMethod.HEAD)
GET: _MethodRequestMethod = partialmethod(request, HTTPMethod.GET)
PATCH: _MethodRequestMethod = partialmethod(request, HTTPMethod.PATCH)
POST: _MethodRequestMethod = partialmethod(request, HTTPMethod.POST)
PUT: _MethodRequestMethod = partialmethod(request, HTTPMethod.PUT)
DELETE: _MethodRequestMethod = partialmethod(request, HTTPMethod.DELETE)
OPTIONS: _MethodRequestMethod = partialmethod(_request, HTTPMethod.OPTIONS)
HEAD: _MethodRequestMethod = partialmethod(_request, HTTPMethod.HEAD)
GET: _MethodRequestMethod = partialmethod(_request, HTTPMethod.GET)
PATCH: _MethodRequestMethod = partialmethod(_request, HTTPMethod.PATCH)
POST: _MethodRequestMethod = partialmethod(_request, HTTPMethod.POST)
PUT: _MethodRequestMethod = partialmethod(_request, HTTPMethod.PUT)
DELETE: _MethodRequestMethod = partialmethod(_request, HTTPMethod.DELETE)
7 changes: 3 additions & 4 deletions addon_toolkit/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,18 @@

@dataclasses.dataclass(frozen=True)
class Credentials(typing.Protocol):
def asdict(self):
def asdict(self) -> dict[str, typing.Any]:
return dataclasses.asdict(self)

def iter_headers(self) -> typing.Iterator[tuple[str, str]]:
return
yield
yield from () # no headers unless implemented by subclass


@dataclasses.dataclass(frozen=True, kw_only=True)
class AccessTokenCredentials(Credentials):
access_token: str

def iter_headers(self):
def iter_headers(self) -> typing.Iterator[tuple[str, str]]:
yield ("Authorization", f"Bearer {self.access_token}")


Expand Down
24 changes: 15 additions & 9 deletions addon_toolkit/cursor.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,32 @@
import base64
import dataclasses
import json
from typing import (
ClassVar,
Protocol,
)
import typing


def encode_cursor_dataclass(dataclass_instance) -> str:
class DataclassInstance(typing.Protocol):
__dataclass_fields__: typing.ClassVar[dict[str, typing.Any]]


SomeDataclassInstance = typing.TypeVar("SomeDataclassInstance", bound=DataclassInstance)


def encode_cursor_dataclass(dataclass_instance: DataclassInstance) -> str:
_as_json = json.dumps(dataclasses.astuple(dataclass_instance))
_cursor_bytes = base64.b64encode(_as_json.encode())
return _cursor_bytes.decode()


def decode_cursor_dataclass(cursor: str, dataclass_class):
def decode_cursor_dataclass(
cursor: str, dataclass_class: type[SomeDataclassInstance]
) -> SomeDataclassInstance:
_as_list = json.loads(base64.b64decode(cursor))
return dataclass_class(*_as_list)


class Cursor(Protocol):
class Cursor(DataclassInstance, typing.Protocol):
@classmethod
def from_str(cls, cursor: str):
def from_str(cls, cursor: str) -> typing.Self:
return decode_cursor_dataclass(cursor, cls)

@property
Expand Down Expand Up @@ -52,7 +58,7 @@ class OffsetCursor(Cursor):
limit: int
total_count: int # use -1 to mean "many more"

MAX_INDEX: ClassVar[int] = 9999
MAX_INDEX: typing.ClassVar[int] = 9999

@property
def next_cursor_str(self) -> str | None:
Expand Down
31 changes: 17 additions & 14 deletions addon_toolkit/declarator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
TypeVar,
)

from addon_toolkit.typing import DataclassInstance


DecoratorTarget = TypeVar("DecoratorTarget")
DeclarationDataclass = TypeVar("DeclarationDataclass")


@dataclasses.dataclass
class Declarator(Generic[DeclarationDataclass]):
class Declarator(Generic[DataclassInstance]):
"""Declarator: add declarative metadata in python using decorators and dataclasses
define a dataclass with fields you want declared in your decorator, plus a field
Expand Down Expand Up @@ -48,15 +49,15 @@ class Declarator(Generic[DeclarationDataclass]):
TwoPartGreetingDeclaration(a='kia', b='ora', on=<function _kia_ora at 0x...>)
"""

declaration_dataclass: type[DeclarationDataclass]
declaration_dataclass: type[DataclassInstance]
field_for_target: str
static_kwargs: dict[str, Any] | None = None

# private storage linking a decorated class or function to data gleaned from its decorator
__declarations_by_target: weakref.WeakKeyDictionary[
object, DeclarationDataclass
] = dataclasses.field(
default_factory=weakref.WeakKeyDictionary,
__declarations_by_target: weakref.WeakKeyDictionary[object, DataclassInstance] = (
dataclasses.field(
default_factory=weakref.WeakKeyDictionary,
)
)

def __post_init__(self) -> None:
Expand All @@ -69,7 +70,7 @@ def __post_init__(self) -> None:
), f'expected field "{self.field_for_target}" on dataclass "{self.declaration_dataclass}"'

def __call__(
self, **declaration_dataclass_kwargs
self, **declaration_dataclass_kwargs: Any
) -> Callable[[DecoratorTarget], DecoratorTarget]:
"""for using a Declarator as a decorator"""

Expand All @@ -79,13 +80,13 @@ def _decorator(decorator_target: DecoratorTarget) -> DecoratorTarget:

return _decorator

def with_kwargs(self, **static_kwargs) -> "Declarator[DeclarationDataclass]":
def with_kwargs(self, **static_kwargs: Any) -> "Declarator[DataclassInstance]":
"""convenience for decorators that differ only by static field values"""
# note: shared __declarations_by_target
return dataclasses.replace(self, static_kwargs=static_kwargs)

def set_declaration(
self, declaration_target: DecoratorTarget, **declaration_dataclass_kwargs
self, declaration_target: DecoratorTarget, **declaration_dataclass_kwargs: Any
) -> None:
"""create a declaration associated with the target
Expand All @@ -98,14 +99,14 @@ def set_declaration(
**{self.field_for_target: declaration_target},
)

def get_declaration(self, target) -> DeclarationDataclass:
def get_declaration(self, target: DecoratorTarget) -> DataclassInstance:
try:
return self.__declarations_by_target[target]
except KeyError:
raise ValueError(f"no declaration found for {target}")


class ClassDeclarator(Declarator[DeclarationDataclass]):
class ClassDeclarator(Declarator[DataclassInstance]):
"""add declarative metadata to python classes using decorators
(same as Declarator but with additional methods that only make
Expand Down Expand Up @@ -157,13 +158,15 @@ class ClassDeclarator(Declarator[DeclarationDataclass]):
SemanticVersionDeclaration(major=4, minor=2, patch=9, subj=<class 'addon_toolkit.declarator.MyLongLivedBaseClass'>)
"""

def get_declaration_for_class_or_instance(self, type_or_object: type | object):
def get_declaration_for_class_or_instance(
self, type_or_object: type | object
) -> DataclassInstance:
_cls = (
type_or_object if isinstance(type_or_object, type) else type(type_or_object)
)
return self.get_declaration_for_class(_cls)

def get_declaration_for_class(self, cls: type):
def get_declaration_for_class(self, cls: type) -> DataclassInstance:
for _cls in cls.__mro__:
try:
return self.get_declaration(_cls)
Expand Down
25 changes: 11 additions & 14 deletions addon_toolkit/imp.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import dataclasses
import enum
import inspect
from typing import (
Iterable,
Iterator,
)
import typing

from asgiref.sync import (
async_to_sync,
Expand Down Expand Up @@ -37,16 +34,16 @@ class AddonImp:
imp_number: int
addon_protocol: AddonProtocolDeclaration = dataclasses.field(init=False)

def __post_init__(self, addon_protocol_cls):
def __post_init__(self, addon_protocol_cls: type) -> None:
object.__setattr__( # using __setattr__ to bypass dataclass frozenness
self,
"addon_protocol",
addon_protocol.get_declaration(addon_protocol_cls),
)

def get_operation_imps(
self, *, capabilities: Iterable[enum.Enum] = ()
) -> Iterator["AddonOperationImp"]:
self, *, capabilities: typing.Iterable[enum.Enum] = ()
) -> typing.Iterator["AddonOperationImp"]:
for _declaration in self.addon_protocol.get_operation_declarations(
capabilities=capabilities
):
Expand Down Expand Up @@ -74,26 +71,26 @@ class AddonOperationImp:
addon_imp: AddonImp
declaration: AddonOperationDeclaration

def __post_init__(self):
def __post_init__(self) -> None:
_protocol_fn = getattr(
self.addon_imp.addon_protocol.protocol_cls, self.declaration.name
)
try:
_imp_fn = self.imp_function
except AttributeError:
except Exception:
_imp_fn = _protocol_fn
if _imp_fn is _protocol_fn:
raise NotImplementedError( # TODO: helpful exception type
f"operation '{self.declaration}' not implemented by {self.addon_imp}"
)

@property
def imp_function(self):
def imp_function(self) -> typing.Any: # TODO: less typing.Any
return getattr(self.addon_imp.imp_cls, self.declaration.name)

async def invoke_thru_addon(
self, addon_instance: object, json_kwargs: JsonableDict
):
) -> typing.Any: # TODO: less typing.Any
_method = self._get_instance_method(addon_instance)
_kwargs = kwargs_from_json(self.declaration.call_signature, json_kwargs)
if not inspect.iscoroutinefunction(_method):
Expand All @@ -104,7 +101,7 @@ async def invoke_thru_addon(

invoke_thru_addon__blocking = async_to_sync(invoke_thru_addon)

def _get_instance_method(self, addon_instance: object):
def _get_instance_method(
self, addon_instance: object
) -> typing.Any: # TODO: less typing.Any
return getattr(addon_instance, self.declaration.name)

# TODO: async def async_call_with_json_kwargs(self, addon_instance: object, json_kwargs: dict):
6 changes: 3 additions & 3 deletions addon_toolkit/iri_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,14 @@ def __init__(self, key_value_pairs: KeyValuePairs | None = None):
_headerslist = list(key_value_pairs)
super().__init__(_headerslist)

def add(self, key: str, value: str, **mediatype_params):
def add(self, key: str, value: str) -> None:
"""add a key-value pair (allowing other values to exist)
alias of `wsgiref.headers.Headers.add_header`
"""
super().add_header(key, value, **mediatype_params)
super().add_header(key, value)

def add_many(self, pairs: Iterable[tuple[str, str]]):
def add_many(self, pairs: Iterable[tuple[str, str]]) -> None:
for _key, _value in pairs:
self.add(_key, _value)

Expand Down
Loading

0 comments on commit bce59e4

Please sign in to comment.