diff --git a/solara/server/app.py b/solara/server/app.py index 9d1952e3e..6dd251862 100644 --- a/solara/server/app.py +++ b/solara/server/app.py @@ -249,6 +249,31 @@ def reload(self): solara.lab.toestand.ConnectionStore._type_counter.clear() + # we need to remove callbacks that are added in the app code + # which will be re-executed after the reload and we do not + # want to keep executing the old ones. + for kc in kernel_context._on_kernel_start_callbacks.copy(): + callback, path, module, cleanup = kc + will_reload = False + if module is not None: + module_name = module.__name__ + if module_name in reload.reloader.get_reload_module_names(): + will_reload = True + elif path is not None: + if str(path.resolve()).startswith(str(self.directory)): + will_reload = True + else: + logger.warning( + "script %s is not in the same directory as the app %s but is using on_kernel_start, " + "this might lead to multiple entries, and might indicate a bug.", + path, + self.directory, + ) + + if will_reload: + logger.info("reload: Removing on_kernel_start callback: %s (since it will be added when reloaded)", callback) + cleanup() + context_values = list(kernel_context.contexts.values()) # save states into the context so the hot reload will # keep the same state diff --git a/solara/server/kernel_context.py b/solara/server/kernel_context.py index d4227a508..429c13f64 100644 --- a/solara/server/kernel_context.py +++ b/solara/server/kernel_context.py @@ -1,13 +1,15 @@ import asyncio import dataclasses import enum +import inspect import logging import os import pickle import threading import time from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, cast +from types import FrameType, ModuleType +from typing import Any, Callable, Dict, List, NamedTuple, Optional, cast import ipywidgets as widgets import reacton @@ -36,11 +38,46 @@ class PageStatus(enum.Enum): CLOSED = "closed" -_on_kernel_start_callbacks: List[Callable[[], Optional[Callable[[], None]]]] = [] +class _on_kernel_callback_entry(NamedTuple): + callback: Callable[[], Optional[Callable[[], None]]] + callpoint: Optional[Path] + module: Optional[ModuleType] + cleanup: Callable[[], None] -def on_kernel_start(f: Callable[[], Optional[Callable[[], None]]]): - _on_kernel_start_callbacks.append(f) +_on_kernel_start_callbacks: List[_on_kernel_callback_entry] = [] + + +def _find_root_module_frame() -> Optional[FrameType]: + # basically the module where the call stack origined from + current_frame = inspect.currentframe() + root_module_frame = None + + while current_frame is not None: + if current_frame.f_code.co_name == "": + root_module_frame = current_frame + break + current_frame = current_frame.f_back + + return root_module_frame + + +def on_kernel_start(f: Callable[[], Optional[Callable[[], None]]]) -> Callable[[], None]: + root = _find_root_module_frame() + path: Optional[Path] = None + module: Optional[ModuleType] = None + if root is not None: + path_str = inspect.getsourcefile(root) + module = inspect.getmodule(root) + if path_str is not None: + path = Path(path_str) + + def cleanup(): + return _on_kernel_start_callbacks.remove(kce) + + kce = _on_kernel_callback_entry(f, path, module, cleanup) + _on_kernel_start_callbacks.append(kce) + return cleanup @dataclasses.dataclass @@ -74,7 +111,7 @@ class VirtualKernelContext: def __post_init__(self): with self: - for f in _on_kernel_start_callbacks: + for (f, *_) in _on_kernel_start_callbacks: cleanup = f() if cleanup: self.on_close(cleanup) diff --git a/solara/server/reload.py b/solara/server/reload.py index 2120d2f10..bbe48e67c 100644 --- a/solara/server/reload.py +++ b/solara/server/reload.py @@ -152,13 +152,13 @@ def start(self): self._first = False def _on_change(self, name): - # used for testing - self.reload_event_next.set() # flag that we need to reload all modules next time self.requires_reload = True # and forward callback if self.on_change: self.on_change(name) + # used for testing + self.reload_event_next.set() def close(self): self.watcher.close() diff --git a/solara/website/pages/documentation/api/utilities/on_kernel_start.py b/solara/website/pages/documentation/api/utilities/on_kernel_start.py index e36b1dd68..fcd1129f4 100644 --- a/solara/website/pages/documentation/api/utilities/on_kernel_start.py +++ b/solara/website/pages/documentation/api/utilities/on_kernel_start.py @@ -4,7 +4,8 @@ Run a function when a virtual kernel (re)starts and optionally run a cleanup function on shutdown. ```python -def on_kernel_start(f: Callable[[], Optional[Callable[[], None]]]): +def on_kernel_start(f: Callable[[], Optional[Callable[[], None]]]) -> Callable[[], None]: + ... ``` `f` will be called on each virtual kernel (re)start. This (usually) happens each time a browser tab connects to the server @@ -12,7 +13,12 @@ def on_kernel_start(f: Callable[[], Optional[Callable[[], None]]]): The (optional) function returned by `f` will be called on kernel shutdown. Note that the cleanup functions are called in reverse order with respect to the order in which they were registered -(e.g. the cleanup function of the last call to `on_kernel_start` will be called first on kernel shutdown) +(e.g. the cleanup function of the last call to `on_kernel_start` will be called first on kernel shutdown). + +The return value of on_kernel_start is a cleanup function that will remove the callback from the list of callbacks to be called on kernel start. + +During hot reload, the callbacks that are added from scripts or modules that will be reloaded will be removed before the app is loaded +again. This can cause the order of the callbacks to be different than at first run. """ from solara.website.components import NoPage diff --git a/tests/unit/reload_test.py b/tests/unit/reload_test.py new file mode 100644 index 000000000..9bf934ba6 --- /dev/null +++ b/tests/unit/reload_test.py @@ -0,0 +1,56 @@ +import shutil +from pathlib import Path + +import pytest + +import solara.lab +import solara.server.kernel_context +from solara.server import reload +from solara.server.app import AppScript + +HERE = Path(__file__).parent + +kernel_start_path = HERE / "solara_test_apps" / "kernel_start.py" + + +@pytest.mark.parametrize("as_module", [False, True]) +def test_script_reload_component(tmpdir, kernel_context, extra_include_path, no_kernel_context, as_module): + + target = Path(tmpdir) / "kernel_start.py" + shutil.copy(kernel_start_path, target) + with extra_include_path(str(tmpdir)): + on_kernel_start_callbacks = solara.server.kernel_context._on_kernel_start_callbacks.copy() + callbacks_start = [k.callback for k in solara.server.kernel_context._on_kernel_start_callbacks] + if as_module: + app = AppScript(f"{target.stem}") + else: + app = AppScript(f"{target}") + try: + app.run() + callback = app.routes[0].module.test_callback # type: ignore + callbacks = [k.callback for k in solara.server.kernel_context._on_kernel_start_callbacks] + assert callbacks == [*callbacks_start, callback] + prev = callbacks.copy() + reload.reloader.reload_event_next.clear() + target.touch() + # wait for the event to trigger + reload.reloader.reload_event_next.wait() + app.run() + callback = app.routes[0].module.test_callback # type: ignore + callbacks = [k[0] for k in solara.server.kernel_context._on_kernel_start_callbacks] + assert callbacks != prev + assert callbacks == [*callbacks_start, callback] + finally: + app.close() + solara.server.kernel_context._on_kernel_start_callbacks.clear() + solara.server.kernel_context._on_kernel_start_callbacks.extend(on_kernel_start_callbacks) + + +def test_on_kernel_start_cleanup(kernel_context, no_kernel_context): + def test_callback_cleanup(): + pass + + cleanup = solara.lab.on_kernel_start(test_callback_cleanup) + assert test_callback_cleanup in [k.callback for k in solara.server.kernel_context._on_kernel_start_callbacks] + cleanup() + assert test_callback_cleanup not in [k.callback for k in solara.server.kernel_context._on_kernel_start_callbacks] diff --git a/tests/unit/solara_test_apps/kernel_start.py b/tests/unit/solara_test_apps/kernel_start.py new file mode 100644 index 000000000..adbf577ff --- /dev/null +++ b/tests/unit/solara_test_apps/kernel_start.py @@ -0,0 +1,14 @@ +import solara +import solara.lab + + +def test_callback(): + pass + + +solara.lab.on_kernel_start(test_callback) + + +@solara.component +def Page(): + solara.Text("Hello, World!")