Skip to content

Commit

Permalink
Merge pull request #44 from PhilipVinc/pv/autoreload
Browse files Browse the repository at this point in the history
autoreload support
  • Loading branch information
wesselb authored Jun 9, 2022
2 parents 4e24872 + 8682157 commit 08d6f16
Show file tree
Hide file tree
Showing 6 changed files with 182 additions and 0 deletions.
26 changes: 26 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ Everybody likes multiple dispatch, just like everybody likes plums.
- [Add Multiple Methods](#add-multiple-methods)
- [Extend a Function From Another Package](#extend-a-function-from-another-package)
- [Directly Invoke a Method](#directly-invoke-a-method)
* [IPython's autoreload support](#support-for-ipython-autoreload)

## Installation

Expand Down Expand Up @@ -1165,3 +1166,28 @@ def f(x: str):
>>> f.invoke(str)(1)
'str'
```

### Support for IPython autoreload

Plum does not work out of the box with IPython's autoreload, and if you reload a file where a class is defined, you will most likely break your dispatch table.

However, experimental support for IPython's autoreload is included into plum but it is not enabled by default, as it overrides some internal methods of IPython.
To activate it, either set the environment variable `PLUM_AUTORELOAD=1` **before** loading plum

```bash
export PLUM_AUTORELOAD=1
```

or manually call the `autoreload.activate` method in an interactive session.

```python
import plum
plum.autoreload.activate()
```

If there are issues with autoreload, please open a bug report.





2 changes: 2 additions & 0 deletions plum/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,5 @@ def _function_dispatch(*args, **kw_args):
from .resolvable import *
from .signature import *
from .type import *

from . import autoreload
87 changes: 87 additions & 0 deletions plum/autoreload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import gc
import os

from .type import Type, Union
from .dispatcher import Dispatcher

__all__ = ["activate", "deactivate"]


def _update_instances(old, new):
"""Use garbage collector to find all instances that refer to the old
class definition and update their __class__ to point to the new class
definition"""

refs = gc.get_referrers(old)

updated_plum_type = False

for ref in refs:
if type(ref) is old:
ref.__class__ = new
elif type(ref) == Type:
updated_plum_type = True
ref._type = new

# if we updated a plum type, then
# use the gc to get all dispatchers and clear
# their cache
if updated_plum_type:
refs = gc.get_referrers(Dispatcher)
for ref in refs:
if type(ref) is Dispatcher:
ref.clear_cache()


_update_instances_original = None


def activate():
"""
Pirate autoreload's `update_instance` function to handle Plum types.
"""
from IPython.extensions import autoreload
from IPython.extensions.autoreload import update_instances

# First, cache the original method so we can deactivate ourselves.
global _update_instances_original
if _update_instances_original is None:
_update_instances_original = autoreload.update_instances

# Then, override the update_instance method
setattr(autoreload, "update_instances", _update_instances)


def deactivate():
"""
Disable Plum's autoreload hack.
"""
global _update_instances_original
if _update_instances_original is None: # pragma: no cover
raise RuntimeError("Plum Autoreload module was never activated.")

from IPython.extensions import autoreload

setattr(autoreload, "update_instances", _update_instances_original)


# Detect `PLUM_AUTORELOAD` env variable
_autoload = os.environ.get("PLUM_AUTORELOAD", "0").lower()
if _autoload in ("y", "yes", "t", "true", "on", "1"): # pragma: no cover
_autoload = True
else:
_autoload = False

if _autoload: # pragma: no cover
try:
# Try to load IPython and get the iPython session, but don't crash if
# this does not work (for example IPython not installed, or python shell)
from IPython import get_ipython

ip = get_ipython()
if ip is not None:
if "IPython.extensions.storemagic" in ip.extension_manager.loaded:
activate()

except ImportError:
pass
6 changes: 6 additions & 0 deletions plum/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ class VarArgs(AbstractType):
number of types. Defaults to `object`.
"""

__slots__ = ["type"]

def __init__(self, type=object):
self.type = ptype(type)

Expand Down Expand Up @@ -159,6 +161,8 @@ class Union(ComparableType):
alias (str, optional): Give the union a name.
"""

__slots__ = ["_types"]

def __init__(self, *types, alias=None):
# Lazily convert to a set to avoid resolution errors.
self._types = tuple(ptype(t) for t in types)
Expand Down Expand Up @@ -220,6 +224,8 @@ class Type(ComparableType):
type (type): Type to encapsulate.
"""

__slots__ = ["_type"]

def __init__(self, type):
self._type = type

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ black==22.3.0
pre-commit
setuptools_scm[toml]
setuptools_scm_git_archive
IPython
60 changes: 60 additions & 0 deletions tests/test_autoreload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import numpy as np
import pytest

from pathlib import Path

from plum import Dispatcher, autoreload as p_autoreload
from plum.function import NotFoundLookupError


def test_autoreload_activate_deactivate():
p_autoreload.activate()

assert p_autoreload._update_instances_original is not None
assert (
p_autoreload._update_instances_original.__module__
== "IPython.extensions.autoreload"
)

from IPython.extensions import autoreload

assert autoreload.update_instances.__module__ == "plum.autoreload"

p_autoreload.deactivate()

assert (
p_autoreload._update_instances_original.__module__
== "IPython.extensions.autoreload"
)
assert autoreload.update_instances.__module__ == "IPython.extensions.autoreload"
assert autoreload.update_instances == p_autoreload._update_instances_original


def test_autoreload_works():
dispatch = Dispatcher()

class A1:
pass

class A2:
pass

@dispatch
def test(x: A1):
return 1

assert test(A1()) == 1

with pytest.raises(NotFoundLookupError):
test(A2())

a1 = A1()

p_autoreload._update_instances(A1, A2)

assert test(A2()) == 1

with pytest.raises(NotFoundLookupError):
test(A1())

assert isinstance(a1, A2)

0 comments on commit 08d6f16

Please sign in to comment.