diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index dae1deab2..004bdde0d 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -242,6 +242,8 @@ jobs: steps: - uses: actions/checkout@v4 + - uses: ts-graphviz/setup-graphviz@v1 + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: diff --git a/packages/solara-meta/pyproject.toml b/packages/solara-meta/pyproject.toml index 1792a3f88..52ea342cd 100644 --- a/packages/solara-meta/pyproject.toml +++ b/packages/solara-meta/pyproject.toml @@ -71,6 +71,7 @@ dev = [ "types-requests", "types-markdown", "types-PyYAML", + "objgraph", "pytest", "pytest-mock", "pytest-cov", diff --git a/solara/components/markdown.py b/solara/components/markdown.py index fcb1b6ecc..7ebf8df84 100644 --- a/solara/components/markdown.py +++ b/solara/components/markdown.py @@ -4,10 +4,9 @@ import textwrap import traceback import warnings -from typing import Any, Dict, List, Union, cast +from typing import Any, Callable, Dict, List, Union, cast import ipyvuetify as v - try: import pymdownx.emoji import pymdownx.highlight @@ -16,6 +15,7 @@ has_pymdownx = True except ModuleNotFoundError: has_pymdownx = False +import reacton.core import solara import solara.components.applayout @@ -50,7 +50,7 @@ def ExceptionGuard(children=[]): solara.Column(children=children) -def _run_solara(code): +def _run_solara(code, cleanups): ast = compile(code, "markdown", "exec") local_scope: Dict[Any, Any] = {} exec(ast, local_scope) @@ -63,6 +63,13 @@ def _run_solara(code): else: raise NameError("No Page of app defined") box = v.Html(tag="div") + + rc: reacton.core.RenderContext + + def cleanup(): + rc.close() + + cleanups.append(cleanup) box, rc = solara.render(cast(solara.Element, app), container=box) # type: ignore widget_id = box._model_id return ( @@ -224,7 +231,7 @@ def _markdown_template( return template -def _highlight(src, language, unsafe_solara_execute, extra, *args, **kwargs): +def _highlight(cleanups, src, language, unsafe_solara_execute, extra, *args, **kwargs): """Highlight a block of code""" if not has_pygments: @@ -243,7 +250,7 @@ def _highlight(src, language, unsafe_solara_execute, extra, *args, **kwargs): if run_src_with_solara: if unsafe_solara_execute: - html_widget = _run_solara(src) + html_widget = _run_solara(src, cleanups) return src_html + html_widget else: return src_html + html_no_execute_enabled @@ -260,8 +267,10 @@ def MarkdownIt(md_text: str, highlight: List[int] = [], unsafe_solara_execute: b from mdit_py_plugins.footnote import footnote_plugin # noqa: F401 from mdit_py_plugins.front_matter import front_matter_plugin # noqa: F401 + cleanups = solara.use_ref(cast(List[Callable[[], None]], [])) + def highlight_code(code, name, attrs): - return _highlight(code, name, unsafe_solara_execute, attrs) + return _highlight(cleanups.current, code, name, unsafe_solara_execute, attrs) md = MarkdownItMod( "js-default", @@ -274,6 +283,15 @@ def highlight_code(code, name, attrs): md = md.use(container.container_plugin, name="note") html = md.render(md_text) hash = hashlib.sha256((html + str(unsafe_solara_execute) + repr(highlight)).encode("utf-8")).hexdigest() + + def cleanup_wrapper(): + def cleanup(): + for cleanup in cleanups.current: + cleanup() + + return cleanup + + solara.use_effect(cleanup_wrapper) return v.VuetifyTemplate.element(template=_markdown_template(html)).key(hash) @@ -332,11 +350,12 @@ def Page(): md_text = textwrap.dedent(md_text) style = solara.util._flatten_style(style) + cleanups = solara.use_ref(cast(List[Callable[[], None]], [])) def make_markdown_object(): def highlight(src, language, *args, **kwargs): try: - return _highlight(src, language, unsafe_solara_execute, *args, **kwargs) + return _highlight(cleanups.current, src, language, unsafe_solara_execute, *args, **kwargs) except Exception as e: logger.exception("Error highlighting code: %s", src) return repr(e) @@ -372,6 +391,16 @@ def highlight(src, language, *args, **kwargs): md = solara.use_memo(make_markdown_object, dependencies=[unsafe_solara_execute]) html = md.convert(md_text) + + def cleanup_wrapper(): + def cleanup(): + for cleanup in cleanups.current: + cleanup() + + return cleanup + + solara.use_effect(cleanup_wrapper) + # if we update the template value, the whole vue tree will rerender (ipvue/ipyvuetify issue) # however, using the hash we simply generate a new widget each time hash = hashlib.sha256((html + str(unsafe_solara_execute)).encode("utf-8")).hexdigest() diff --git a/solara/server/app.py b/solara/server/app.py index 59f37363b..a9df0aaa7 100644 --- a/solara/server/app.py +++ b/solara/server/app.py @@ -7,6 +7,7 @@ import threading import traceback import warnings +import weakref from enum import Enum from pathlib import Path from typing import Any, Dict, List, Optional, cast @@ -132,7 +133,7 @@ def add_path(): else: # the module itself will be added by reloader # automatically - with reload.reloader.watch(): + with kernel_context.without_context(), reload.reloader.watch(): self.type = AppType.MODULE try: spec = importlib.util.find_spec(self.name) @@ -420,6 +421,9 @@ def solara_comm_target(comm, msg_first): def on_msg(msg): nonlocal app + comm = comm_ref() + assert comm is not None + context = kernel_context.get_current_context() data = msg["content"]["data"] method = data["method"] if method == "run": @@ -435,7 +439,12 @@ def on_msg(msg): themes = args.get("themes") dark = args.get("dark") load_themes(themes, dark) - load_app_widget(None, app, path) + try: + load_app_widget(None, app, path) + except Exception as e: + msg = f"Error loading app: from path {path} and app {app_name}" + logger.exception(msg) + raise RuntimeError(msg) from e comm.send({"method": "finished", "widget_id": context.container._model_id}) elif method == "app-status": context = kernel_context.get_current_context() @@ -464,9 +473,10 @@ def on_msg(msg): else: logger.error("Unknown comm method called on solara.control comm: %s", method) - comm.on_msg(on_msg) - def reload(): + comm = comm_ref() + assert comm is not None + context = kernel_context.get_current_context() # we don't reload the app ourself, we send a message to the client # this ensures that we don't run code of any client that for some reason is connected # but not working anymore. And it indirectly passes a message from the current thread @@ -474,8 +484,11 @@ def reload(): logger.debug(f"Send reload to client: {context.id}") comm.send({"method": "reload"}) - context = kernel_context.get_current_context() - context.reload = reload + comm.on_msg(on_msg) + comm_ref = weakref.ref(comm) + del comm + + kernel_context.get_current_context().reload = reload def register_solara_comm_target(kernel: Kernel): diff --git a/solara/server/kernel.py b/solara/server/kernel.py index 82ffd60ad..fb5986b96 100644 --- a/solara/server/kernel.py +++ b/solara/server/kernel.py @@ -247,6 +247,8 @@ def send( header=None, metadata=None, ): + if stream is None: + return # can happen when the kernel is closed but someone was still trying to send a message try: if isinstance(msg_or_type, dict): msg = msg_or_type @@ -313,6 +315,39 @@ def __init__(self): self.shell.display_pub.session = self.session self.shell.display_pub.pub_socket = self.iopub_socket + def close(self): + if self.comm_manager is None: + raise RuntimeError("Kernel already closed") + self.session.close() + self._cleanup_references() + + def _cleanup_references(self): + try: + # all of these reduce the circular references + # making it easier for the garbage collector to clean up + self.shell_handlers.clear() + self.control_handlers.clear() + for comm_object in list(self.comm_manager.comms.values()): # type: ignore + comm_object.close() + self.comm_manager.targets.clear() # type: ignore + # self.comm_manager.kernel points to us, but we cannot set it to None + # so we remove the circular reference by setting the comm_manager to None + self.comm_manager = None # type: ignore + self.session.parent = None # type: ignore + + self.shell.display_pub.session = None # type: ignore + self.shell.display_pub.pub_socket = None # type: ignore + del self.shell.__dict__ + self.shell = None # type: ignore + self.session.websockets.clear() + self.session.stream = None # type: ignore + self.session = None # type: ignore + self.stream.session = None # type: ignore + self.stream = None # type: ignore + self.iopub_socket = None # type: ignore + except Exception: + logger.exception("Error cleaning up references from kernel, not fatal") + async def _flush_control_queue(self): pass diff --git a/solara/server/kernel_context.py b/solara/server/kernel_context.py index 19a115bea..f53104448 100644 --- a/solara/server/kernel_context.py +++ b/solara/server/kernel_context.py @@ -6,6 +6,8 @@ except ModuleNotFoundError: contextvars = None # type: ignore +import concurrent.futures +import contextlib import dataclasses import enum import logging @@ -71,6 +73,7 @@ class VirtualKernelContext: page_status: Dict[str, PageStatus] = dataclasses.field(default_factory=dict) # only used for testing _last_kernel_cull_task: "Optional[asyncio.Future[None]]" = None + _last_kernel_cull_future: "Optional[concurrent.futures.Future[None]]" = None closed_event: threading.Event = dataclasses.field(default_factory=threading.Event) _on_close_callbacks: List[Callable[[], None]] = dataclasses.field(default_factory=list) lock: threading.RLock = dataclasses.field(default_factory=threading.RLock) @@ -89,6 +92,7 @@ def restart(self): f() self._on_close_callbacks.clear() self.__post_init__() + lock: threading.RLock = dataclasses.field(default_factory=threading.RLock) def display(self, *args): print(args) # noqa @@ -112,6 +116,10 @@ def close(self): with self, self.lock: for key in self.page_status: self.page_status[key] = PageStatus.CLOSED + if self._last_kernel_cull_task: + self._last_kernel_cull_task.cancel() + if self._last_kernel_cull_future: + self._last_kernel_cull_future.cancel() if self.closed_event.is_set(): logger.error("Tried to close a kernel context that is already closed: %s", self.id) return @@ -129,9 +137,11 @@ def close(self): # what if we reference each other # import gc # gc.collect() - self.kernel.session.close() + self.kernel.close() + self.kernel = None # type: ignore if self.id in contexts: del contexts[self.id] + del current_context[get_current_thread_key()] self.closed_event.set() def _state_reset(self): @@ -158,6 +168,8 @@ def state_save(self, state_directory: os.PathLike): pickle.dump(state, f) def page_connect(self, page_id: str): + if self.closed_event.is_set(): + raise RuntimeError("Cannot connect a page to a closed kernel") logger.info("Connect page %s for kernel %s", page_id, self.id) with self.lock: if self.closed_event.is_set(): @@ -184,13 +196,19 @@ async def kernel_cull(): logger.info("No connected pages, and timeout reached, shutting down virtual kernel %s", self.id) self.close() if current_event_loop is not None and future is not None: - current_event_loop.call_soon_threadsafe(future.set_result, None) + try: + current_event_loop.call_soon_threadsafe(future.set_result, None) + except RuntimeError: + pass # event loop already closed, happens during testing except asyncio.CancelledError: if current_event_loop is not None and future is not None: - if sys.version_info >= (3, 9): - current_event_loop.call_soon_threadsafe(future.cancel, "cancelled because a new cull task was scheduled") - else: - current_event_loop.call_soon_threadsafe(future.cancel) + try: + if sys.version_info >= (3, 9): + current_event_loop.call_soon_threadsafe(future.cancel, "cancelled because a new cull task was scheduled") + else: + current_event_loop.call_soon_threadsafe(future.cancel) + except RuntimeError: + pass # event loop already closed, happens during testing raise async def create_task(): @@ -212,7 +230,16 @@ async def create_task(): self._last_kernel_cull_task.cancel() logger.info("Scheduling kernel cull for virtual kernel %s", self.id) - asyncio.run_coroutine_threadsafe(create_task(), keep_alive_event_loop) + async def create_task(): + task = asyncio.create_task(kernel_cull()) + # create a reference to the task so we can cancel it later + self._last_kernel_cull_task = task + try: + await task + except RuntimeError: + pass # event loop already closed, happens during testing + + self._last_kernel_cull_future = asyncio.run_coroutine_threadsafe(create_task(), keep_alive_event_loop) return future def page_disconnect(self, page_id: str) -> "Optional[asyncio.Future[None]]": @@ -259,7 +286,11 @@ def page_close(self, page_id: str): pass else: future.set_result(None) + + logger.info("page status: %s", self.page_status) with self.lock: + if self.closed_event.is_set(): + raise RuntimeError("Cannot connect a page to a closed kernel") if self.page_status[page_id] == PageStatus.CLOSED: logger.info("Page %s already closed for kernel %s", page_id, self.id) return @@ -351,6 +382,11 @@ def set_context_for_thread(context: VirtualKernelContext, thread: threading.Thre current_context[key] = context +def clear_context_for_thread(thread: threading.Thread): + key = get_thread_key(thread) + current_context.pop(key, None) + + def has_current_context() -> bool: thread_key = get_current_thread_key() return (thread_key in current_context) and (current_context[thread_key] is not None) @@ -377,6 +413,21 @@ def set_current_context(context: Optional[VirtualKernelContext]): current_context[thread_key] = context +@contextlib.contextmanager +def without_context(): + context = None + try: + context = get_current_context() + except RuntimeError: + pass + thread_key = get_current_thread_key() + current_context[thread_key] = None + try: + yield + finally: + current_context[thread_key] = context + + def initialize_virtual_kernel(session_id: str, kernel_id: str, websocket: websocket.WebsocketWrapper): from solara.server import app as appmodule diff --git a/solara/server/patch.py b/solara/server/patch.py index 83ea07469..c8cf4c1e8 100644 --- a/solara/server/patch.py +++ b/solara/server/patch.py @@ -15,6 +15,8 @@ import ipywidgets.widgets.widget_output from IPython.core.interactiveshell import InteractiveShell +import solara.util + from . import app, kernel_context, reload, settings from .utils import pdb_guard @@ -248,7 +250,8 @@ def auto_watch_get_template(get_template): def wrapper(abs_path): template = get_template(abs_path) - reload.reloader.watcher.add_file(abs_path) + with kernel_context.without_context(): + reload.reloader.watcher.add_file(abs_path) return template return wrapper @@ -271,10 +274,13 @@ def WidgetContextAwareThread__init__(self, *args, **kwargs): ThreadDebugInfo.created += 1 self.current_context = None - try: - self.current_context = kernel_context.get_current_context() - except RuntimeError: - logger.debug(f"No context for thread {self}") + # if we do this for the dummy threads, we got into a recursion + # since threading.current_thread will call the _DummyThread constructor + if not ("name" in kwargs and "Dummy-" in kwargs["name"]): + try: + self.current_context = kernel_context.get_current_context() + except RuntimeError: + logger.debug(f"No context for thread {self._name}") def WidgetContextAwareThread__bootstrap(self): @@ -290,6 +296,7 @@ def WidgetContextAwareThread__bootstrap(self): def _WidgetContextAwareThread__bootstrap(self): if not hasattr(self, "current_context"): + # this happens when a thread was running before we patched return Thread__bootstrap(self) if self.current_context: # we need to call this manually, because set_context_for_thread @@ -298,15 +305,20 @@ def _WidgetContextAwareThread__bootstrap(self): if kernel_context.async_context_id is not None: kernel_context.async_context_id.set(self.current_context.id) kernel_context.set_context_for_thread(self.current_context, self) - shell = self.current_context.kernel.shell - shell.display_pub.register_hook(shell.display_in_reacton_hook) + display_pub = shell.display_pub + display_in_reacton_hook = shell.display_in_reacton_hook + display_pub.register_hook(display_in_reacton_hook) try: - with pdb_guard(): + context = self.current_context or solara.util.nullcontext() + with pdb_guard(), context: Thread__bootstrap(self) finally: - if self.current_context: - shell.display_pub.unregister_hook(shell.display_in_reacton_hook) + current_context = self.current_context + self.current_context = None + kernel_context.clear_context_for_thread(self) + if current_context: + display_pub.unregister_hook(display_in_reacton_hook) _patched = False @@ -352,6 +364,7 @@ def patch_ipyreact(): # make this a no-op, we'll create the widget when needed ipyreact.importmap._update_import_map = lambda: None + def patch_ipyvue_performance(): import functools from collections.abc import Iterable diff --git a/solara/server/server.py b/solara/server/server.py index 8d1fee774..7763c83b5 100644 --- a/solara/server/server.py +++ b/solara/server/server.py @@ -157,7 +157,8 @@ async def app_loop( message = await ws.receive() except websocket.WebSocketDisconnect: try: - context.kernel.session.websockets.remove(ws) + if context.kernel is not None and context.kernel.session is not None: + context.kernel.session.websockets.remove(ws) except KeyError: pass logger.debug("Disconnected") @@ -168,10 +169,15 @@ async def app_loop( else: msg = deserialize_binary_message(message) t1 = time.time() - if not process_kernel_messages(kernel, msg): - # if we shut down the kernel, we do not keep the page session alive - context.close() - return + # we don't want to have the kernel closed while we are processing a message + # therefore we use this mutex that is also used in the context.close method + with context.lock: + if context.closed_event.is_set(): + return + if not process_kernel_messages(kernel, msg): + # if we shut down the kernel, we do not keep the page session alive + context.close() + return t2 = time.time() if settings.main.timing: widgets_ids_after = set(patch.widgets) diff --git a/solara/server/shell.py b/solara/server/shell.py index dfabe31d0..151469dc6 100644 --- a/solara/server/shell.py +++ b/solara/server/shell.py @@ -1,3 +1,4 @@ +import atexit import io import sys from binascii import b2a_base64 @@ -180,10 +181,23 @@ class SolaraInteractiveShell(InteractiveShell): history_manager = Any() # type: ignore display_pub: SolaraDisplayPublisher + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + atexit.unregister(self.atexit_operations) + + magic = self.magics_manager.registry["ScriptMagics"] + atexit.unregister(magic.kill_bg_processes) + def set_parent(self, parent): """Tell the children about the parent message.""" self.display_pub.set_parent(parent) + def init_sys_modules(self): + pass # don't create a __main__, it will cause a mem leak + + def init_prefilter(self): + pass # avoid consuming memory + def init_history(self): self.history_manager = Mock() # type: ignore diff --git a/solara/website/pages/documentation/examples/general/live_update.py b/solara/website/pages/documentation/examples/general/live_update.py index 5302d0bcf..bfbdb36eb 100644 --- a/solara/website/pages/documentation/examples/general/live_update.py +++ b/solara/website/pages/documentation/examples/general/live_update.py @@ -32,6 +32,7 @@ def LiveUpdatingComponent(counter): """Component which will be redrawn whenever the counter value changes.""" fig, ax = plt.subplots() ax.plot(np.arange(10), np.random.random(10)) + plt.close(fig) solara.FigureMatplotlib(fig) diff --git a/tests/integration/memleak_test.py b/tests/integration/memleak_test.py new file mode 100644 index 000000000..f3ee962bd --- /dev/null +++ b/tests/integration/memleak_test.py @@ -0,0 +1,89 @@ +import gc +import threading +import time +import weakref +from pathlib import Path +from typing import Optional + +import objgraph +import playwright.sync_api +import pytest + +import solara +import solara.server.kernel_context + +HERE = Path(__file__).parent + + +set_value = None +context: Optional["solara.server.kernel_context.VirtualKernelContext"] = None + + +@pytest.fixture +def no_cull_timeout(): + cull_timeout_previous = solara.server.settings.kernel.cull_timeout + solara.server.settings.kernel.cull_timeout = "0.0001s" + try: + yield + finally: + solara.server.settings.kernel.cull_timeout = cull_timeout_previous + + +def _scoped_test_memleak( + page_session: playwright.sync_api.Page, + solara_server, + solara_app, + extra_include_path, +): + with solara_app("solara.website.pages"): + page_session.goto(solara_server.base_url) + page_session.locator("text=Examples").first.wait_for() + assert len(solara.server.kernel_context.contexts) == 1 + context = weakref.ref(list(solara.server.kernel_context.contexts.values())[0]) + # we should not have created a new context + assert len(solara.server.kernel_context.contexts) == 1 + kernel = weakref.ref(context().kernel) + shell = weakref.ref(kernel().shell) + session = weakref.ref(kernel().session) + page_session.goto("about:blank") + if context()._last_kernel_cull_task: + if not context()._last_kernel_cull_task.done(): + event = threading.Event() + context()._last_kernel_cull_task.add_done_callback(lambda _: event.set()) + assert event.wait() + assert context().closed_event.wait(10) + if shell(): + del shell().__dict__ + return context, kernel, shell, session + + +def test_memleak( + pytestconfig, + request, + browser: playwright.sync_api.Browser, + page_session: playwright.sync_api.Page, + solara_server, + solara_app, + extra_include_path, + no_cull_timeout, +): + # for unknown reasons, del does not work in CI + context_ref, kernel_ref, shell_ref, session_ref = _scoped_test_memleak(page_session, solara_server, solara_app, extra_include_path) + + for i in range(200): + time.sleep(0.1) + for gen in [2, 1, 0]: + gc.collect(gen) + if context_ref() is None and kernel_ref() is None and shell_ref() is None and session_ref() is None: + break + else: + name = solara_server.__class__.__name__ + output_path = Path(pytestconfig.getoption("--output")) / f"mem-leak-{name}.pdf" + output_path.parent.mkdir(parents=True, exist_ok=True) + print("output to", output_path, output_path.resolve()) # noqa + objgraph.show_backrefs([context_ref(), kernel_ref(), shell_ref(), session_ref()], filename=str(output_path), max_depth=15, too_many=15) + + assert context_ref() is None + assert kernel_ref() is None + assert shell_ref() is None + assert session_ref() is None diff --git a/tests/unit/patch_test.py b/tests/unit/patch_test.py index 708b4f358..4dfbd00eb 100644 --- a/tests/unit/patch_test.py +++ b/tests/unit/patch_test.py @@ -22,9 +22,10 @@ def test_widget_dict(no_kernel_context): - kernel_shared = kernel.Kernel() - context1 = kernel_context.VirtualKernelContext(id="1", kernel=kernel_shared, session_id="session-1") - context2 = kernel_context.VirtualKernelContext(id="2", kernel=kernel_shared, session_id="session-2") + kernel1 = kernel.Kernel() + kernel2 = kernel.Kernel() + context1 = kernel_context.VirtualKernelContext(id="1", kernel=kernel1, session_id="session-1") + context2 = kernel_context.VirtualKernelContext(id="2", kernel=kernel2, session_id="session-2") with context1: btn1 = widgets.Button(description="context1")