Skip to content

Commit

Permalink
Merge pull request #68 from widgetti/feat_output_widget_support
Browse files Browse the repository at this point in the history
Feat: output widget support
  • Loading branch information
maartenbreddels authored Apr 14, 2023
2 parents 61348e0 + 6802aa5 commit 67a4e1e
Show file tree
Hide file tree
Showing 4 changed files with 272 additions and 2 deletions.
13 changes: 12 additions & 1 deletion solara/server/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from zmq.eventloop.zmqstream import ZMQStream

import solara
from solara.server.shell import SolaraInteractiveShell

from . import settings, websocket

Expand Down Expand Up @@ -257,12 +258,14 @@ def __init__(self):
ipywidgets.widgets.widget.Widget.comm.klass = Comm
else:
self.comm_manager = CommManager(parent=self, kernel=self)
self.shell = None
self.log = logging.getLogger("fake")

comm_msg_types = ["comm_open", "comm_msg", "comm_close"]
for msg_type in comm_msg_types:
self.shell_handlers[msg_type] = getattr(self.comm_manager, msg_type)
self.shell = SolaraInteractiveShell()
self.shell.display_pub.session = self.session
self.shell.display_pub.pub_socket = self.iopub_socket

async def _flush_control_queue(self):
pass
Expand All @@ -275,3 +278,11 @@ def pre_handler_hook(self, *args):

def post_handler_hook(self, *args):
pass

def set_parent(self, ident, parent, channel="shell"):
"""Overridden from parent to tell the display hook and output streams
about the parent message.
"""
super().set_parent(ident, parent, channel)
if channel == "shell":
self.shell.set_parent(parent)
30 changes: 29 additions & 1 deletion solara/server/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import ipykernel.kernelbase
import IPython.display
import ipywidgets
import ipywidgets.widgets.widget_output
from IPython.core.interactiveshell import InteractiveShell

from . import app, reload, settings
from .utils import pdb_guard
Expand All @@ -28,7 +30,7 @@ class FakeIPython:
def __init__(self, context: app.AppContext):
self.context = context
self.kernel = context.kernel
self.display_pub = mock.MagicMock()
self.display_pub = self.kernel.shell.display_pub
# needed for the pyplot interface of matplotlib
# (although we don't really support it)
self.events = mock.MagicMock()
Expand Down Expand Up @@ -68,6 +70,11 @@ def kernel_instance_dispatch(cls, *args, **kwargs):
return context.kernel


def interactive_shell_instance_dispatch(cls, *args, **kwargs):
context = app.get_current_context()
return context.kernel.shell


def kernel_initialized_dispatch(cls):
try:
app.get_current_context()
Expand Down Expand Up @@ -222,6 +229,22 @@ def Thread_debug_run(self):
_patched = False


def Output_enter(self):
self._flush()

def hook(msg):
if msg["msg_type"] == "display_data":
self.outputs += ({"output_type": "display_data", "data": msg["content"]["data"], "metadata": msg["content"]["metadata"]},)
return None
return msg

get_ipython().display_pub.register_hook(hook)


def Output_exit(self, exc_type, exc_value, traceback):
get_ipython().display_pub._hooks.pop()


def patch():
global _patched
if _patched:
Expand Down Expand Up @@ -261,14 +284,19 @@ def patch():
# variable has type "Callable[[VarArg(Any), KwArg(Any)], Any]")
# not sure why we cannot reproduce that locally
ipykernel.kernelbase.Kernel.instance = classmethod(kernel_instance_dispatch) # type: ignore
InteractiveShell.instance = classmethod(interactive_shell_instance_dispatch) # type: ignore
# on CI we get a mypy error:
# solara/server/patch.py:211: error: Cannot assign to a method
# solara/server/patch.py:211: error: Incompatible types in assignment (expression has type "classmethod[Any]", variable has type "Callable[[], Any]")
# not sure why we cannot reproduce that locally
ipykernel.kernelbase.Kernel.initialized = classmethod(kernel_initialized_dispatch) # type: ignore
ipywidgets.widgets.widget.get_ipython = get_ipython
# TODO: find a way to actually monkeypatch get_ipython
IPython.get_ipython = get_ipython

ipywidgets.widgets.widget_output.Output.__enter__ = Output_enter
ipywidgets.widgets.widget_output.Output.__exit__ = Output_exit

def model_id_debug(self: ipywidgets.widgets.widget.Widget):
from ipyvue.ForceLoad import force_load_instance

Expand Down
206 changes: 206 additions & 0 deletions solara/server/shell.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
import sys
from threading import local
from unittest.mock import Mock

import reacton.patch_display
from IPython.core.displaypub import DisplayPublisher
from IPython.core.interactiveshell import InteractiveShell, InteractiveShellABC
from jupyter_client.session import Session, extract_header
from traitlets import Any, CBytes, Dict, Instance, Type, default


def encode_images(obj):
# no-op in ipykernel
return obj


def json_clean(obj):
# no-op in ipykernel
return obj


# based on the zmq display publisher from ipykernel
# ideally this goes out of ipykernel
class SolaraDisplayPublisher(DisplayPublisher):
"""A display publisher that publishes data using a ZeroMQ PUB socket."""

session = Instance(Session, allow_none=True)
pub_socket = Any(allow_none=True)
parent_header = Dict({})
topic = CBytes(b"display_data")

_thread_local = Any()

def set_parent(self, parent):
"""Set the parent for outbound messages."""
self.parent_header = extract_header(parent)

def _flush_streams(self):
"""flush IO Streams prior to display"""
sys.stdout.flush()
sys.stderr.flush()

@default("_thread_local")
def _default_thread_local(self):
"""Initialize our thread local storage"""
return local()

@property
def _hooks(self):
if not hasattr(self._thread_local, "hooks"):
# create new list for a new thread
self._thread_local.hooks = []
return self._thread_local.hooks

def publish(
self,
data,
metadata=None,
transient=None,
update=False,
):
"""Publish a display-data message
Parameters
----------
data : dict
A mime-bundle dict, keyed by mime-type.
metadata : dict, optional
Metadata associated with the data.
transient : dict, optional, keyword-only
Transient data that may only be relevant during a live display,
such as display_id.
Transient data should not be persisted to documents.
update : bool, optional, keyword-only
If True, send an update_display_data message instead of display_data.
"""
self._flush_streams()
if metadata is None:
metadata = {}
if transient is None:
transient = {}
self._validate_data(data, metadata)
content = {}
content["data"] = encode_images(data)
content["metadata"] = metadata
content["transient"] = transient

msg_type = "update_display_data" if update else "display_data"

# Use 2-stage process to send a message,
# in order to put it through the transform
# hooks before potentially sending.
msg = self.session.msg(msg_type, json_clean(content), parent=self.parent_header)

# Each transform either returns a new
# message or None. If None is returned,
# the message has been 'used' and we return.
for hook in self._hooks:
msg = hook(msg)
if msg is None:
return

self.session.send(
self.pub_socket,
msg,
ident=self.topic,
)

def clear_output(self, wait=False):
"""Clear output associated with the current execution (cell).
Parameters
----------
wait : bool (default: False)
If True, the output will not be cleared immediately,
instead waiting for the next display before clearing.
This reduces bounce during repeated clear & display loops.
"""
content = dict(wait=wait)
self._flush_streams()
self.session.send(
self.pub_socket,
"clear_output",
content,
parent=self.parent_header,
ident=self.topic,
)

def register_hook(self, hook):
"""
Registers a hook with the thread-local storage.
Parameters
----------
hook : Any callable object
Returns
-------
Either a publishable message, or `None`.
The DisplayHook objects must return a message from
the __call__ method if they still require the
`session.send` method to be called after transformation.
Returning `None` will halt that execution path, and
session.send will not be called.
"""
self._hooks.append(hook)

def unregister_hook(self, hook):
"""
Un-registers a hook with the thread-local storage.
Parameters
----------
hook : Any callable object which has previously been
registered as a hook.
Returns
-------
bool - `True` if the hook was removed, `False` if it wasn't
found.
"""
try:
self._hooks.remove(hook)
return True
except ValueError:
return False


class SolaraInteractiveShell(InteractiveShell):
display_pub_class = Type(SolaraDisplayPublisher)
history_manager = Any() # type: ignore

def set_parent(self, parent):
"""Tell the children about the parent message."""
self.display_pub.set_parent(parent)

def init_history(self):
self.history_manager = Mock() # type: ignore

def init_display_formatter(self):
super().init_display_formatter()
self.display_formatter.ipython_display_formatter = reacton.patch_display.ReactonDisplayFormatter()

def init_display_pub(self):
super().init_display_pub()
self.display_pub.register_hook(self.display_in_reacton_hook)

def display_in_reacton_hook(self, msg):
"""Will intercept a display call and add the display data to an output widget when in a reacton context/render function."""
# similar to reacton.patch_display.publish
from reacton.core import get_render_context

rc = get_render_context(required=False)
# only during the render phase we want to capture the display calls
# during the reconsolidation phase we want to let the original display publisher do its thing
# such as adding it to a output widget
if rc is not None and not rc.reconsolidating and msg["msg_type"] == "display_data":
from reacton.ipywidgets import Output

Output(outputs=[{"output_type": "display_data", "data": msg["content"]["data"], "metadata": msg["content"]["metadata"]}])
return None # do not send to the frontend
return msg


InteractiveShellABC.register(SolaraInteractiveShell)
25 changes: 25 additions & 0 deletions tests/unit/shell_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from unittest.mock import Mock

import IPython.display

from solara.server import app, kernel


def test_shell(no_app_context):
ws1 = Mock()
ws2 = Mock()
kernel1 = kernel.Kernel()
kernel2 = kernel.Kernel()
kernel1.session.websockets.add(ws1)
kernel2.session.websockets.add(ws2)
context1 = app.AppContext(id="1", kernel=kernel1)
context2 = app.AppContext(id="2", kernel=kernel2)

with context1:
IPython.display.display("test1")
assert ws1.send.call_count == 1
assert ws2.send.call_count == 0
with context2:
IPython.display.display("test1")
assert ws1.send.call_count == 1
assert ws2.send.call_count == 1

0 comments on commit 67a4e1e

Please sign in to comment.