-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #44 from PhilipVinc/pv/autoreload
autoreload support
- Loading branch information
Showing
6 changed files
with
182 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import gc | ||
import os | ||
|
||
from .type import Type, Union | ||
from .dispatcher import Dispatcher | ||
|
||
__all__ = ["activate", "deactivate"] | ||
|
||
|
||
def _update_instances(old, new): | ||
"""Use garbage collector to find all instances that refer to the old | ||
class definition and update their __class__ to point to the new class | ||
definition""" | ||
|
||
refs = gc.get_referrers(old) | ||
|
||
updated_plum_type = False | ||
|
||
for ref in refs: | ||
if type(ref) is old: | ||
ref.__class__ = new | ||
elif type(ref) == Type: | ||
updated_plum_type = True | ||
ref._type = new | ||
|
||
# if we updated a plum type, then | ||
# use the gc to get all dispatchers and clear | ||
# their cache | ||
if updated_plum_type: | ||
refs = gc.get_referrers(Dispatcher) | ||
for ref in refs: | ||
if type(ref) is Dispatcher: | ||
ref.clear_cache() | ||
|
||
|
||
_update_instances_original = None | ||
|
||
|
||
def activate(): | ||
""" | ||
Pirate autoreload's `update_instance` function to handle Plum types. | ||
""" | ||
from IPython.extensions import autoreload | ||
from IPython.extensions.autoreload import update_instances | ||
|
||
# First, cache the original method so we can deactivate ourselves. | ||
global _update_instances_original | ||
if _update_instances_original is None: | ||
_update_instances_original = autoreload.update_instances | ||
|
||
# Then, override the update_instance method | ||
setattr(autoreload, "update_instances", _update_instances) | ||
|
||
|
||
def deactivate(): | ||
""" | ||
Disable Plum's autoreload hack. | ||
""" | ||
global _update_instances_original | ||
if _update_instances_original is None: # pragma: no cover | ||
raise RuntimeError("Plum Autoreload module was never activated.") | ||
|
||
from IPython.extensions import autoreload | ||
|
||
setattr(autoreload, "update_instances", _update_instances_original) | ||
|
||
|
||
# Detect `PLUM_AUTORELOAD` env variable | ||
_autoload = os.environ.get("PLUM_AUTORELOAD", "0").lower() | ||
if _autoload in ("y", "yes", "t", "true", "on", "1"): # pragma: no cover | ||
_autoload = True | ||
else: | ||
_autoload = False | ||
|
||
if _autoload: # pragma: no cover | ||
try: | ||
# Try to load IPython and get the iPython session, but don't crash if | ||
# this does not work (for example IPython not installed, or python shell) | ||
from IPython import get_ipython | ||
|
||
ip = get_ipython() | ||
if ip is not None: | ||
if "IPython.extensions.storemagic" in ip.extension_manager.loaded: | ||
activate() | ||
|
||
except ImportError: | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,3 +10,4 @@ black==22.3.0 | |
pre-commit | ||
setuptools_scm[toml] | ||
setuptools_scm_git_archive | ||
IPython |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import numpy as np | ||
import pytest | ||
|
||
from pathlib import Path | ||
|
||
from plum import Dispatcher, autoreload as p_autoreload | ||
from plum.function import NotFoundLookupError | ||
|
||
|
||
def test_autoreload_activate_deactivate(): | ||
p_autoreload.activate() | ||
|
||
assert p_autoreload._update_instances_original is not None | ||
assert ( | ||
p_autoreload._update_instances_original.__module__ | ||
== "IPython.extensions.autoreload" | ||
) | ||
|
||
from IPython.extensions import autoreload | ||
|
||
assert autoreload.update_instances.__module__ == "plum.autoreload" | ||
|
||
p_autoreload.deactivate() | ||
|
||
assert ( | ||
p_autoreload._update_instances_original.__module__ | ||
== "IPython.extensions.autoreload" | ||
) | ||
assert autoreload.update_instances.__module__ == "IPython.extensions.autoreload" | ||
assert autoreload.update_instances == p_autoreload._update_instances_original | ||
|
||
|
||
def test_autoreload_works(): | ||
dispatch = Dispatcher() | ||
|
||
class A1: | ||
pass | ||
|
||
class A2: | ||
pass | ||
|
||
@dispatch | ||
def test(x: A1): | ||
return 1 | ||
|
||
assert test(A1()) == 1 | ||
|
||
with pytest.raises(NotFoundLookupError): | ||
test(A2()) | ||
|
||
a1 = A1() | ||
|
||
p_autoreload._update_instances(A1, A2) | ||
|
||
assert test(A2()) == 1 | ||
|
||
with pytest.raises(NotFoundLookupError): | ||
test(A1()) | ||
|
||
assert isinstance(a1, A2) |