diff --git a/solara/server/kernel.py b/solara/server/kernel.py index 51109fc60..fff81914a 100644 --- a/solara/server/kernel.py +++ b/solara/server/kernel.py @@ -16,6 +16,7 @@ from zmq.eventloop.zmqstream import ZMQStream import solara +from solara.server.shell import SolaraInteractiveShell from . import settings, websocket @@ -257,12 +258,14 @@ def __init__(self): ipywidgets.widgets.widget.Widget.comm.klass = Comm else: self.comm_manager = CommManager(parent=self, kernel=self) - self.shell = None self.log = logging.getLogger("fake") comm_msg_types = ["comm_open", "comm_msg", "comm_close"] for msg_type in comm_msg_types: self.shell_handlers[msg_type] = getattr(self.comm_manager, msg_type) + self.shell = SolaraInteractiveShell() + self.shell.display_pub.session = self.session + self.shell.display_pub.pub_socket = self.iopub_socket async def _flush_control_queue(self): pass @@ -275,3 +278,11 @@ def pre_handler_hook(self, *args): def post_handler_hook(self, *args): pass + + def set_parent(self, ident, parent, channel="shell"): + """Overridden from parent to tell the display hook and output streams + about the parent message. + """ + super().set_parent(ident, parent, channel) + if channel == "shell": + self.shell.set_parent(parent) diff --git a/solara/server/patch.py b/solara/server/patch.py index 2928c779c..c5f898ea4 100644 --- a/solara/server/patch.py +++ b/solara/server/patch.py @@ -10,6 +10,8 @@ import ipykernel.kernelbase import IPython.display import ipywidgets +import ipywidgets.widgets.widget_output +from IPython.core.interactiveshell import InteractiveShell from . import app, reload, settings from .utils import pdb_guard @@ -28,7 +30,7 @@ class FakeIPython: def __init__(self, context: app.AppContext): self.context = context self.kernel = context.kernel - self.display_pub = mock.MagicMock() + self.display_pub = self.kernel.shell.display_pub # needed for the pyplot interface of matplotlib # (although we don't really support it) self.events = mock.MagicMock() @@ -68,6 +70,11 @@ def kernel_instance_dispatch(cls, *args, **kwargs): return context.kernel +def interactive_shell_instance_dispatch(cls, *args, **kwargs): + context = app.get_current_context() + return context.kernel.shell + + def kernel_initialized_dispatch(cls): try: app.get_current_context() @@ -222,6 +229,22 @@ def Thread_debug_run(self): _patched = False +def Output_enter(self): + self._flush() + + def hook(msg): + if msg["msg_type"] == "display_data": + self.outputs += ({"output_type": "display_data", "data": msg["content"]["data"], "metadata": msg["content"]["metadata"]},) + return None + return msg + + get_ipython().display_pub.register_hook(hook) + + +def Output_exit(self, exc_type, exc_value, traceback): + get_ipython().display_pub._hooks.pop() + + def patch(): global _patched if _patched: @@ -261,14 +284,19 @@ def patch(): # variable has type "Callable[[VarArg(Any), KwArg(Any)], Any]") # not sure why we cannot reproduce that locally ipykernel.kernelbase.Kernel.instance = classmethod(kernel_instance_dispatch) # type: ignore + InteractiveShell.instance = classmethod(interactive_shell_instance_dispatch) # type: ignore # on CI we get a mypy error: # solara/server/patch.py:211: error: Cannot assign to a method # solara/server/patch.py:211: error: Incompatible types in assignment (expression has type "classmethod[Any]", variable has type "Callable[[], Any]") # not sure why we cannot reproduce that locally ipykernel.kernelbase.Kernel.initialized = classmethod(kernel_initialized_dispatch) # type: ignore ipywidgets.widgets.widget.get_ipython = get_ipython + # TODO: find a way to actually monkeypatch get_ipython IPython.get_ipython = get_ipython + ipywidgets.widgets.widget_output.Output.__enter__ = Output_enter + ipywidgets.widgets.widget_output.Output.__exit__ = Output_exit + def model_id_debug(self: ipywidgets.widgets.widget.Widget): from ipyvue.ForceLoad import force_load_instance diff --git a/solara/server/shell.py b/solara/server/shell.py new file mode 100644 index 000000000..2a8efce2b --- /dev/null +++ b/solara/server/shell.py @@ -0,0 +1,206 @@ +import sys +from threading import local +from unittest.mock import Mock + +import reacton.patch_display +from IPython.core.displaypub import DisplayPublisher +from IPython.core.interactiveshell import InteractiveShell, InteractiveShellABC +from jupyter_client.session import Session, extract_header +from traitlets import Any, CBytes, Dict, Instance, Type, default + + +def encode_images(obj): + # no-op in ipykernel + return obj + + +def json_clean(obj): + # no-op in ipykernel + return obj + + +# based on the zmq display publisher from ipykernel +# ideally this goes out of ipykernel +class SolaraDisplayPublisher(DisplayPublisher): + """A display publisher that publishes data using a ZeroMQ PUB socket.""" + + session = Instance(Session, allow_none=True) + pub_socket = Any(allow_none=True) + parent_header = Dict({}) + topic = CBytes(b"display_data") + + _thread_local = Any() + + def set_parent(self, parent): + """Set the parent for outbound messages.""" + self.parent_header = extract_header(parent) + + def _flush_streams(self): + """flush IO Streams prior to display""" + sys.stdout.flush() + sys.stderr.flush() + + @default("_thread_local") + def _default_thread_local(self): + """Initialize our thread local storage""" + return local() + + @property + def _hooks(self): + if not hasattr(self._thread_local, "hooks"): + # create new list for a new thread + self._thread_local.hooks = [] + return self._thread_local.hooks + + def publish( + self, + data, + metadata=None, + transient=None, + update=False, + ): + """Publish a display-data message + + Parameters + ---------- + data : dict + A mime-bundle dict, keyed by mime-type. + metadata : dict, optional + Metadata associated with the data. + transient : dict, optional, keyword-only + Transient data that may only be relevant during a live display, + such as display_id. + Transient data should not be persisted to documents. + update : bool, optional, keyword-only + If True, send an update_display_data message instead of display_data. + """ + self._flush_streams() + if metadata is None: + metadata = {} + if transient is None: + transient = {} + self._validate_data(data, metadata) + content = {} + content["data"] = encode_images(data) + content["metadata"] = metadata + content["transient"] = transient + + msg_type = "update_display_data" if update else "display_data" + + # Use 2-stage process to send a message, + # in order to put it through the transform + # hooks before potentially sending. + msg = self.session.msg(msg_type, json_clean(content), parent=self.parent_header) + + # Each transform either returns a new + # message or None. If None is returned, + # the message has been 'used' and we return. + for hook in self._hooks: + msg = hook(msg) + if msg is None: + return + + self.session.send( + self.pub_socket, + msg, + ident=self.topic, + ) + + def clear_output(self, wait=False): + """Clear output associated with the current execution (cell). + + Parameters + ---------- + wait : bool (default: False) + If True, the output will not be cleared immediately, + instead waiting for the next display before clearing. + This reduces bounce during repeated clear & display loops. + + """ + content = dict(wait=wait) + self._flush_streams() + self.session.send( + self.pub_socket, + "clear_output", + content, + parent=self.parent_header, + ident=self.topic, + ) + + def register_hook(self, hook): + """ + Registers a hook with the thread-local storage. + + Parameters + ---------- + hook : Any callable object + + Returns + ------- + Either a publishable message, or `None`. + The DisplayHook objects must return a message from + the __call__ method if they still require the + `session.send` method to be called after transformation. + Returning `None` will halt that execution path, and + session.send will not be called. + """ + self._hooks.append(hook) + + def unregister_hook(self, hook): + """ + Un-registers a hook with the thread-local storage. + + Parameters + ---------- + hook : Any callable object which has previously been + registered as a hook. + + Returns + ------- + bool - `True` if the hook was removed, `False` if it wasn't + found. + """ + try: + self._hooks.remove(hook) + return True + except ValueError: + return False + + +class SolaraInteractiveShell(InteractiveShell): + display_pub_class = Type(SolaraDisplayPublisher) + history_manager = Any() # type: ignore + + def set_parent(self, parent): + """Tell the children about the parent message.""" + self.display_pub.set_parent(parent) + + def init_history(self): + self.history_manager = Mock() # type: ignore + + def init_display_formatter(self): + super().init_display_formatter() + self.display_formatter.ipython_display_formatter = reacton.patch_display.ReactonDisplayFormatter() + + def init_display_pub(self): + super().init_display_pub() + self.display_pub.register_hook(self.display_in_reacton_hook) + + def display_in_reacton_hook(self, msg): + """Will intercept a display call and add the display data to an output widget when in a reacton context/render function.""" + # similar to reacton.patch_display.publish + from reacton.core import get_render_context + + rc = get_render_context(required=False) + # only during the render phase we want to capture the display calls + # during the reconsolidation phase we want to let the original display publisher do its thing + # such as adding it to a output widget + if rc is not None and not rc.reconsolidating and msg["msg_type"] == "display_data": + from reacton.ipywidgets import Output + + Output(outputs=[{"output_type": "display_data", "data": msg["content"]["data"], "metadata": msg["content"]["metadata"]}]) + return None # do not send to the frontend + return msg + + +InteractiveShellABC.register(SolaraInteractiveShell) diff --git a/tests/unit/shell_test.py b/tests/unit/shell_test.py new file mode 100644 index 000000000..c1de3e4ef --- /dev/null +++ b/tests/unit/shell_test.py @@ -0,0 +1,25 @@ +from unittest.mock import Mock + +import IPython.display + +from solara.server import app, kernel + + +def test_shell(no_app_context): + ws1 = Mock() + ws2 = Mock() + kernel1 = kernel.Kernel() + kernel2 = kernel.Kernel() + kernel1.session.websockets.add(ws1) + kernel2.session.websockets.add(ws2) + context1 = app.AppContext(id="1", kernel=kernel1) + context2 = app.AppContext(id="2", kernel=kernel2) + + with context1: + IPython.display.display("test1") + assert ws1.send.call_count == 1 + assert ws2.send.call_count == 0 + with context2: + IPython.display.display("test1") + assert ws1.send.call_count == 1 + assert ws2.send.call_count == 1