Skip to content

Commit

Permalink
Fix default arguments bug and improve MethodRedefinitionWarning (#168)
Browse files Browse the repository at this point in the history
* Fix bug and improve warning

* Test default arguments bug

* Test implementation unwrapping

* Make redefinition warnings opt in

* Describe `warn_redefinition` in docs
  • Loading branch information
wesselb authored Jul 6, 2024
1 parent 40ac70d commit 3eaf417
Show file tree
Hide file tree
Showing 9 changed files with 207 additions and 25 deletions.
37 changes: 36 additions & 1 deletion docs/scope.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Scope of Functions


## Dispatchers

% skip: start "Example code"

Consider the following package design.
Expand Down Expand Up @@ -114,4 +117,36 @@ NotFoundLookupError: For function `f`, `(1,)` could not be resolved.
'float'
```

% skip end
% skip: end

## Redefinition Warnings

Whenever you create a dispatcher, you can set `warn_redefinition=True` to throw a warning whenever a method of a function overwrites another.
It is recommended to use this setting.

% invisible-code-block: python
%
% import warnings

```python
>>> from plum import Dispatcher

>>> dispatch = Dispatcher(warn_redefinition=True)

>>> @dispatch
... def f(x: int):
... return x

>>> @dispatch
... def f(x: int):
... return x

>>> with warnings.catch_warnings(record=True) as w: # doctest:+ELLIPSIS
... f(1)
... print(w[0].message)
1
`Method(function_name='f', signature=Signature(int), return_type=typing.Any, implementation=<function f at 0x...>)` (`<doctest .../scope.md[0]>:1`) overwrites the earlier definition `Method(function_name='f', signature=Signature(int), return_type=typing.Any, implementation=<function f at 0x...>)` (`<doctest .../scope.md[0]>:1`).
```

Note that the redefinition warning is thrown whenever the function is run for the first
time, because methods are only registered whenever they are needed.
12 changes: 11 additions & 1 deletion plum/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,18 @@
class Dispatcher:
"""A namespace for functions.
Args:
warn_redefinition (bool, optional): Throw a warning whenever a method is
redefined. Defaults to `False`.
Attributes:
functions (dict[str, :class:`.function.Function`]): Functions by name.
classes (dict[str, dict[str, :class:`.function.Function`]]): Methods of
all classes by the qualified name of a class.
warn_redefinition (bool): Throw a warning whenever a method is redefined.
"""

warn_redefinition: bool = False
functions: Dict[str, Function] = field(default_factory=dict)
classes: Dict[str, Dict[str, Function]] = field(default_factory=dict)

Expand Down Expand Up @@ -115,7 +121,11 @@ def _get_function(self, method: Callable) -> Function:
# Create a new function only if the function does not already exist.
name = method.__name__
if name not in namespace:
namespace[name] = Function(method, owner=owner)
namespace[name] = Function(
method,
owner=owner,
warn_redefinition=self.warn_redefinition,
)

return namespace[name]

Expand Down
28 changes: 25 additions & 3 deletions plum/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ class Function(metaclass=_FunctionMeta):
Args:
f (function): Function that is wrapped.
owner (str, optional): Name of the class that owns the function.
warn_redefinition (bool, optional): Throw a warning whenever a method is
redefined. Defaults to `False`.
"""

# When we set `__doc__`, we will lose the docstring of the class, so we save it now.
Expand All @@ -71,7 +73,12 @@ class Function(metaclass=_FunctionMeta):

_instances = []

def __init__(self, f: Callable, owner: Optional[str] = None) -> None:
def __init__(
self,
f: Callable,
owner: Optional[str] = None,
warn_redefinition: bool = False,
) -> None:
Function._instances.append(self)

self._f: Callable = f
Expand All @@ -86,9 +93,14 @@ def __init__(self, f: Callable, owner: Optional[str] = None) -> None:
self._owner_name: Optional[str] = owner
self._owner: Optional[type] = None

self._warn_redefinition = warn_redefinition

# Initialise pending and resolved methods.
self._pending: List[Tuple[Callable, Optional[Signature], int]] = []
self._resolver = Resolver(self.__name__)
self._resolver = Resolver(
self.__name__,
warn_redefinition=self._warn_redefinition,
)
self._resolved: List[Tuple[Callable, Signature, int]] = []

@property
Expand Down Expand Up @@ -233,7 +245,10 @@ def clear_cache(self, reregister: bool = True) -> None:

# Clear resolved.
self._resolved = []
self._resolver = Resolver(self._resolver.function_name)
self._resolver = Resolver(
self._resolver.function_name,
warn_redefinition=self._warn_redefinition,
)

def register(
self, f: Callable, signature: Optional[Signature] = None, precedence=0
Expand Down Expand Up @@ -417,6 +432,8 @@ def invoke(self, *types: TypeHint) -> Callable:
def wrapped_method(*args, **kw_args):
return _convert(method(*args, **kw_args), return_type)

wrapped_method.__wrapped_by_plum__ = method

return wrapped_method

def __get__(self, instance, owner):
Expand Down Expand Up @@ -492,4 +509,9 @@ def wrapped_method(*args, **kw_args):
method = self._f.invoke(type(self._instance), *types)
return method(self._instance, *args, **kw_args)

# We set `f.__wrapped_by_plum__` for :func:`Function.invoke`, but here we do
# not: this method has `self._instance` prepended to its arguments, so there
# is no "wrapped method". In addition, bound functions cannot be directly
# extended, so unwrapping is likely never desired.

return wrapped_method
6 changes: 4 additions & 2 deletions plum/method.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,10 @@ def __str__(self):
function_name = self.function_name
signature = self.signature
return_type = self.return_type
impl = self.implementation
return f"Method({function_name=}, {signature=}, {return_type=}, {impl=})"
implementation = self.implementation
return (
f"Method({function_name=}, {signature=}, {return_type=}, {implementation=})"
)

def __rich_console__(self, console, options):
yield self.repr_mismatch()
Expand Down
65 changes: 56 additions & 9 deletions plum/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from .util import argsort
from plum.method import Method, MethodList
from plum.repr import rich_repr
from plum.repr import repr_source_path, rich_repr
from plum.signature import Signature

__all__ = ["AmbiguousLookupError", "NotFoundLookupError"]
Expand Down Expand Up @@ -202,17 +202,52 @@ def _document(f: Callable, f_name: Optional[str] = None) -> str:
return "\n".join([title] + body).rstrip()


def _unwrap_invoked_methods(f):
"""Undo wrapping of :meth:`Function.invoke`d methods.
:meth:`Function.invoke` uses :func:`functools.wraps` to wrap the function and
convert the output to the right return type. This wrapping obscures where the
method was originally defined, meaning that :func:`plum.repr.repr_source_path`
gives erroneous results. This function undoes that wrapping and makes
:func:`plum.repr.repr_source_path` work correctly.
Args:
f (function): Function, possibly wrapped.
Returns:
function: `f`, but without any wrapping.
"""
while hasattr(f, "__wrapped_by_plum__"):
f = f.__wrapped_by_plum__
return f


class Resolver:
"""Method resolver.
Args:
function_name (str, optional): Name of the function.
warn_redefinition (bool, optional): Throw a warning whenever a method is
redefined. Defaults to `False`.
Attributes:
methods (list[:class:`.method.Method`]): Registered methods.
is_faithful (bool): Whether all methods are faithful or not.
warn_redefinition (bool): Throw a warning whenever a method is redefined.
"""

__slots__ = ("methods", "is_faithful", "function_name")
__slots__ = (
"function_name",
"methods",
"is_faithful",
"warn_redefinition",
)

def __init__(self, function_name: Optional[str] = None) -> None:
def __init__(
self,
function_name: Optional[str] = None,
warn_redefinition: bool = False,
) -> None:
"""Initialise the resolver.
Args:
Expand All @@ -221,6 +256,7 @@ def __init__(self, function_name: Optional[str] = None) -> None:
self.function_name = function_name
self.methods: MethodList = MethodList()
self.is_faithful: bool = True
self.warn_redefinition = warn_redefinition

def doc(self, exclude: Union[Callable, None] = None) -> str:
"""Concatenate the docstrings of all methods of this function. Remove duplicate
Expand Down Expand Up @@ -265,12 +301,23 @@ def register(self, method: Method) -> None:
f"The added method `{method}` is equal to {sum(existing)} "
f"existing methods. This should never happen."
)
previous_method = self.methods[existing.index(True)]
warnings.warn(
f"`{method}` overwrites the earlier definition `{previous_method}`.",
category=MethodRedefinitionWarning,
stacklevel=0,
)

if self.warn_redefinition:
# Determine the new and previous implementation. Unwrap possible
# wrapping by Plum from :meth:`Function.invoke`s, which can obscure the
# location where the implementation was originally defined.
previous_method = self.methods[existing.index(True)]
prev_impl = _unwrap_invoked_methods(previous_method.implementation)
impl = _unwrap_invoked_methods(method.implementation)
warnings.warn(
f"`{method}` (`{repr_source_path(impl)}`) "
f"overwrites the earlier definition "
f"`{previous_method}` "
f"(`{repr_source_path(prev_impl)}`).",
category=MethodRedefinitionWarning,
stacklevel=0,
)

self.methods[existing.index(True)] = method
else:
self.methods.append(method)
Expand Down
6 changes: 3 additions & 3 deletions plum/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,15 +393,15 @@ def append_default_args(signature: Signature, f: Callable) -> List[Signature]:
remove default arguments.
Returns:
list[:class:`.signature.Signature`]: list of signatures excluding from 0 to all
default arguments.
list[:class:`.signature.Signature`]: List of signatures excluding from no to all
default arguments.
"""
# Extract specification.
f_signature = inspect_signature(f)

signatures = [signature]

arg_names = list(f_signature.parameters.keys())
arg_names = list(f_signature.parameters.keys())[: len(signature.types)]
# We start at the end and, once we reach non-keyword-only arguments, delete the
# argument with defaults values one by one. This generates a sequence of signatures,
# which we return.
Expand Down
22 changes: 22 additions & 0 deletions tests/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
AmbiguousLookupError,
NotFoundLookupError,
_change_function_name,
_unwrap_invoked_methods,
)
from plum.signature import Signature

Expand Down Expand Up @@ -594,6 +595,27 @@ def f(x: int):
assert f.invoke(int).__doc__ == "Docs"


def test_invoke_implementation_unwrapping():
dispatch = Dispatcher()

def f(x: int):
return type(x)

f_orig = f
f = dispatch(f)

# Redirect `float`s to `int`s.
dispatch.multi((float,))(f.invoke(int))

assert f(1) == int
assert f(1.0) == float

assert f.methods[0].implementation is f_orig
assert f.methods[1].implementation is not f_orig
assert _unwrap_invoked_methods(f.methods[0].implementation) is f_orig
assert _unwrap_invoked_methods(f.methods[1].implementation) is f_orig


def test_bound():
dispatch = Dispatcher()

Expand Down
47 changes: 41 additions & 6 deletions tests/test_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,9 @@ def f(x):
assert r.resolve(m_c1.signature) == m_b2


def test_redefinition_warning():
dispatch = Dispatcher()
@pytest.mark.parametrize("warn_redefinition", [False, True])
def test_redefinition_warning(warn_redefinition):
dispatch = Dispatcher(warn_redefinition=warn_redefinition)

with warnings.catch_warnings():
warnings.simplefilter("error")
Expand All @@ -264,10 +265,44 @@ def f(x: str):
# Warnings are only emitted when all registrations are resolved.
f._resolve_pending_registrations()

with pytest.warns(MethodRedefinitionWarning):
# Perform the testonce before more after clearing the cache. This reinstantiates
# the resolver, so we check that `warn_redefinition` is then set correctly.
for _ in range(2):
if warn_redefinition:
with pytest.warns(MethodRedefinitionWarning):

@dispatch
def f(x: int):
pass
@dispatch
def f(x: int):
pass

f._resolve_pending_registrations()
else:
with warnings.catch_warnings():
warnings.simplefilter("error")

@dispatch
def f(x: int):
pass

f._resolve_pending_registrations()

dispatch.clear_cache()


def test_redefinition_warning_unwrapping():
dispatch = Dispatcher(warn_redefinition=True)

@dispatch
def f(x: int):
pass

# Write and overwrite a method derived from an invoked methods. We test that the
# unwrapping to find the location of the implementation works correctly.
f.dispatch_multi((str,))(f.invoke(int))
f.dispatch_multi((str,))(f.invoke(int))

with pytest.warns(
MethodRedefinitionWarning,
match=r".*`.*test_resolver.py:[0-9]+`.*" * 2,
):
f._resolve_pending_registrations()
9 changes: 9 additions & 0 deletions tests/test_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,15 @@ def f(a: int, b=1, c: float = 1.0, *d: complex, option=None, **other_options):
assert (sigs[1].types, sigs[1].varargs) == ((int, Any), Missing)
assert (sigs[2].types, sigs[2].varargs) == ((int,), Missing)

# Test the case of more argument names than types.
sigs = append_default_args(Sig(int, Any), f)
assert len(sigs) == 2
assert (sigs[0].types, sigs[0].varargs) == ((int, Any), Missing)
assert (sigs[1].types, sigs[1].varargs) == ((int,), Missing)
sigs = append_default_args(Sig(int), f)
assert len(sigs) == 1
assert (sigs[0].types, sigs[0].varargs) == ((int,), Missing)

# Test that `itemgetter` is supported.
f = operator.itemgetter(0)
assert len(append_default_args(Sig.from_callable(f), f)) == 1

0 comments on commit 3eaf417

Please sign in to comment.