From 33711d0ad8357b5501a5a5259c1d8aaac4ba8935 Mon Sep 17 00:00:00 2001 From: "Maarten A. Breddels" Date: Thu, 14 Mar 2024 11:54:50 +0100 Subject: [PATCH] fix: on_kernel_start callbacks acculumated after hot reload Introduced in #471 We should remove the on_kernel_start callbacks on a hot reload, but not remove the ones added by the hot reload itself. --- solara/server/app.py | 32 +++++++++++++++++ solara/server/kernel_context.py | 39 ++++++++++++++++++--- solara/server/reload.py | 4 +-- tests/unit/reload_test.py | 33 +++++++++++++++++ tests/unit/solara_test_apps/kernel_start.py | 14 ++++++++ 5 files changed, 116 insertions(+), 6 deletions(-) create mode 100644 tests/unit/reload_test.py create mode 100644 tests/unit/solara_test_apps/kernel_start.py diff --git a/solara/server/app.py b/solara/server/app.py index 9d1952e3e..5029cab38 100644 --- a/solara/server/app.py +++ b/solara/server/app.py @@ -51,6 +51,9 @@ def __init__(self, name, default_app_name="Page"): if reload.reloader.on_change: raise RuntimeError("Previous reloader still had a on_change attached, no cleanup?") reload.reloader.on_change = self.on_file_change + # create a snapshot of the current callbacks, so we can remove the ones we added + # so we don't keep adding them after hot reload + self._on_kernel_start_callbacks_before_run = kernel_context._on_kernel_start_callbacks.copy() self.app_name = default_app_name if ":" in self.fullname: @@ -249,6 +252,35 @@ def reload(self): solara.lab.toestand.ConnectionStore._type_counter.clear() + # we need to remove callbacks that are added in the app code + # which will be re-executed after the reload and we do not + # want to keep executing the old ones. + keep_on_kernel_start_callbacks = [] + for kc in kernel_context._on_kernel_start_callbacks.copy(): + callback, path, module = kc + will_reload = False + if module is not None: + module_name = module.__name__ + if module_name in reload.reloader.get_reload_module_names(): + will_reload = True + elif path is not None: + if str(path.resolve()).startswith(str(self.directory)): + will_reload = True + else: + logger.warning( + "script %s is not in the same directory as the app %s but is using on_kernel_start, " + "this might lead to multiple entries, and might indicate a bug.", + path, + self.directory, + ) + + if will_reload: + logger.info("reload: Removing on_kernel_start callback: %s (since it will be added when reloaded)", callback) + else: + keep_on_kernel_start_callbacks.append(kc) + kernel_context._on_kernel_start_callbacks.clear() + kernel_context._on_kernel_start_callbacks.extend(keep_on_kernel_start_callbacks) + context_values = list(kernel_context.contexts.values()) # save states into the context so the hot reload will # keep the same state diff --git a/solara/server/kernel_context.py b/solara/server/kernel_context.py index d4227a508..239871da3 100644 --- a/solara/server/kernel_context.py +++ b/solara/server/kernel_context.py @@ -1,13 +1,15 @@ import asyncio import dataclasses import enum +import inspect import logging import os import pickle import threading import time from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, cast +from types import FrameType, ModuleType +from typing import Any, Callable, Dict, List, NamedTuple, Optional, cast import ipywidgets as widgets import reacton @@ -36,11 +38,40 @@ class PageStatus(enum.Enum): CLOSED = "closed" -_on_kernel_start_callbacks: List[Callable[[], Optional[Callable[[], None]]]] = [] +class _on_kernel_callback_entry(NamedTuple): + f: Callable[[], Optional[Callable[[], None]]] + callpoint: Optional[Path] + module: Optional[ModuleType] + + +_on_kernel_start_callbacks: List[_on_kernel_callback_entry] = [] + + +def _find_root_module_frame() -> Optional[FrameType]: + # basically the module where the call stack origined from + current_frame = inspect.currentframe() + root_module_frame = None + + while current_frame is not None: + if current_frame.f_code.co_name == "": + root_module_frame = current_frame + break + current_frame = current_frame.f_back + + return root_module_frame def on_kernel_start(f: Callable[[], Optional[Callable[[], None]]]): - _on_kernel_start_callbacks.append(f) + root = _find_root_module_frame() + path: Optional[Path] = None + module: Optional[ModuleType] = None + if root is not None: + path_str = inspect.getsourcefile(root) + module = inspect.getmodule(root) + if path_str is not None: + path = Path(path_str) + + _on_kernel_start_callbacks.append(_on_kernel_callback_entry(f, path, module)) @dataclasses.dataclass @@ -74,7 +105,7 @@ class VirtualKernelContext: def __post_init__(self): with self: - for f in _on_kernel_start_callbacks: + for (f, *_) in _on_kernel_start_callbacks: cleanup = f() if cleanup: self.on_close(cleanup) diff --git a/solara/server/reload.py b/solara/server/reload.py index 2120d2f10..bbe48e67c 100644 --- a/solara/server/reload.py +++ b/solara/server/reload.py @@ -152,13 +152,13 @@ def start(self): self._first = False def _on_change(self, name): - # used for testing - self.reload_event_next.set() # flag that we need to reload all modules next time self.requires_reload = True # and forward callback if self.on_change: self.on_change(name) + # used for testing + self.reload_event_next.set() def close(self): self.watcher.close() diff --git a/tests/unit/reload_test.py b/tests/unit/reload_test.py new file mode 100644 index 000000000..4bdb325d2 --- /dev/null +++ b/tests/unit/reload_test.py @@ -0,0 +1,33 @@ +import shutil +from pathlib import Path + +import solara.server.kernel_context +from solara.server import reload +from solara.server.app import AppScript + +HERE = Path(__file__).parent + +kernel_start_path = HERE / "solara_test_apps" / "kernel_start.py" + + +def test_script_reload_component(tmpdir, kernel_context, extra_include_path, no_kernel_context): + + target = Path(tmpdir) / "kernel_start.py" + shutil.copy(kernel_start_path, target) + with extra_include_path(str(tmpdir)): + on_kernel_start_callbacks = solara.server.kernel_context._on_kernel_start_callbacks.copy() + app = AppScript(f"{target}") + try: + app.run() + callback = app.routes[0].module.test_callback # type: ignore + assert solara.server.kernel_context._on_kernel_start_callbacks == [*on_kernel_start_callbacks, callback] + prev = solara.server.kernel_context._on_kernel_start_callbacks.copy() + target.touch() + # wait for the event to trigger + reload.reloader.reload_event_next.wait() + app.run() + callback = app.routes[0].module.test_callback # type: ignore + assert solara.server.kernel_context._on_kernel_start_callbacks != prev + assert solara.server.kernel_context._on_kernel_start_callbacks == [*on_kernel_start_callbacks, callback] + finally: + app.close() diff --git a/tests/unit/solara_test_apps/kernel_start.py b/tests/unit/solara_test_apps/kernel_start.py new file mode 100644 index 000000000..adbf577ff --- /dev/null +++ b/tests/unit/solara_test_apps/kernel_start.py @@ -0,0 +1,14 @@ +import solara +import solara.lab + + +def test_callback(): + pass + + +solara.lab.on_kernel_start(test_callback) + + +@solara.component +def Page(): + solara.Text("Hello, World!")