Skip to content

Commit

Permalink
Automatically unquote to offer better support
Browse files Browse the repository at this point in the history
  • Loading branch information
wesselb committed Jun 4, 2022
1 parent 69d8f6a commit 9971ab0
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 7 deletions.
6 changes: 4 additions & 2 deletions plum/function.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import inspect
import logging
from functools import wraps
import typing
from functools import wraps

from .resolvable import Promise
from .signature import Signature
from .type import ptype, is_object, VarArgs, deliver_forward_reference
from .util import check_future_annotations
from .util import unquote

__all__ = [
"extract_signature",
Expand Down Expand Up @@ -56,6 +56,8 @@ def extract_signature(f, get_type_hints=False):
function.
"""
if get_type_hints:
# Unquote type hints so that they are resolved to the right types.
f.__annotations__ = {k: unquote(v) for k, v in f.__annotations__.items()}
f.__annotations__ = typing.get_type_hints(f)

# Extract specification.
Expand Down
6 changes: 2 additions & 4 deletions plum/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging

from .resolvable import Resolvable, Promise
from .util import multihash, Comparable
from .util import multihash, Comparable, unquote

__all__ = [
"TypeMeta",
Expand Down Expand Up @@ -286,9 +286,7 @@ def get_forward_reference(name):
# The name can possibly be wrapped in an extra set of quotes if the future
# `annotations` is used and the type is given as a string. E.g., see
# https://github.com/wesselb/plum/issues/41
# We simply remove any extra quotes.
while len(name) >= 2 and name[0] == name[-1] and name[0] in {'"', "'"}:
name = name[1:-1]
name = unquote(name)
reference = ForwardReferencedType(name)
_unresolved_forward_references.append(reference)
return reference
Expand Down
15 changes: 15 additions & 0 deletions plum/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"get_class",
"get_context",
"check_future_annotations",
"unquote",
]

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -129,3 +130,17 @@ def check_future_annotations():
"""
frame = inspect.currentframe()
return "annotations" in frame.f_back.f_back.f_globals


def unquote(x):
"""Remove quotes from a string at the outermost level.
Args:
x (str): String to remove quotes from.
Return:
str: `x` but without quotes.
"""
while len(x) >= 2 and x[0] == x[-1] and x[0] in {'"', "'"}:
x = x[1:-1]
return x
2 changes: 1 addition & 1 deletion tests/dispatcher/test_future_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def __init__(self, value):
self.value = value

@dispatch
def __add__(self, other: Union[Number, int]) -> Number:
def __add__(self, other: Union[Number, int]) -> "Number":
if isinstance(other, int):
other_value = other
else:
Expand Down

0 comments on commit 9971ab0

Please sign in to comment.