Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: output widget support #68

Merged
merged 2 commits into from
Apr 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion solara/server/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from zmq.eventloop.zmqstream import ZMQStream

import solara
from solara.server.shell import SolaraInteractiveShell

from . import settings, websocket

Expand Down Expand Up @@ -257,12 +258,14 @@ def __init__(self):
ipywidgets.widgets.widget.Widget.comm.klass = Comm
else:
self.comm_manager = CommManager(parent=self, kernel=self)
self.shell = None
self.log = logging.getLogger("fake")

comm_msg_types = ["comm_open", "comm_msg", "comm_close"]
for msg_type in comm_msg_types:
self.shell_handlers[msg_type] = getattr(self.comm_manager, msg_type)
self.shell = SolaraInteractiveShell()
self.shell.display_pub.session = self.session
self.shell.display_pub.pub_socket = self.iopub_socket

async def _flush_control_queue(self):
pass
Expand All @@ -275,3 +278,11 @@ def pre_handler_hook(self, *args):

def post_handler_hook(self, *args):
pass

def set_parent(self, ident, parent, channel="shell"):
"""Overridden from parent to tell the display hook and output streams
about the parent message.
"""
super().set_parent(ident, parent, channel)
if channel == "shell":
self.shell.set_parent(parent)
30 changes: 29 additions & 1 deletion solara/server/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import ipykernel.kernelbase
import IPython.display
import ipywidgets
import ipywidgets.widgets.widget_output
from IPython.core.interactiveshell import InteractiveShell

from . import app, reload, settings
from .utils import pdb_guard
Expand All @@ -28,7 +30,7 @@ class FakeIPython:
def __init__(self, context: app.AppContext):
self.context = context
self.kernel = context.kernel
self.display_pub = mock.MagicMock()
self.display_pub = self.kernel.shell.display_pub
# needed for the pyplot interface of matplotlib
# (although we don't really support it)
self.events = mock.MagicMock()
Expand Down Expand Up @@ -68,6 +70,11 @@ def kernel_instance_dispatch(cls, *args, **kwargs):
return context.kernel


def interactive_shell_instance_dispatch(cls, *args, **kwargs):
context = app.get_current_context()
return context.kernel.shell


def kernel_initialized_dispatch(cls):
try:
app.get_current_context()
Expand Down Expand Up @@ -222,6 +229,22 @@ def Thread_debug_run(self):
_patched = False


def Output_enter(self):
self._flush()

def hook(msg):
if msg["msg_type"] == "display_data":
self.outputs += ({"output_type": "display_data", "data": msg["content"]["data"], "metadata": msg["content"]["metadata"]},)
return None
return msg

get_ipython().display_pub.register_hook(hook)


def Output_exit(self, exc_type, exc_value, traceback):
get_ipython().display_pub._hooks.pop()


def patch():
global _patched
if _patched:
Expand Down Expand Up @@ -261,14 +284,19 @@ def patch():
# variable has type "Callable[[VarArg(Any), KwArg(Any)], Any]")
# not sure why we cannot reproduce that locally
ipykernel.kernelbase.Kernel.instance = classmethod(kernel_instance_dispatch) # type: ignore
InteractiveShell.instance = classmethod(interactive_shell_instance_dispatch) # type: ignore
# on CI we get a mypy error:
# solara/server/patch.py:211: error: Cannot assign to a method
# solara/server/patch.py:211: error: Incompatible types in assignment (expression has type "classmethod[Any]", variable has type "Callable[[], Any]")
# not sure why we cannot reproduce that locally
ipykernel.kernelbase.Kernel.initialized = classmethod(kernel_initialized_dispatch) # type: ignore
ipywidgets.widgets.widget.get_ipython = get_ipython
# TODO: find a way to actually monkeypatch get_ipython
IPython.get_ipython = get_ipython

ipywidgets.widgets.widget_output.Output.__enter__ = Output_enter
ipywidgets.widgets.widget_output.Output.__exit__ = Output_exit

def model_id_debug(self: ipywidgets.widgets.widget.Widget):
from ipyvue.ForceLoad import force_load_instance

Expand Down
206 changes: 206 additions & 0 deletions solara/server/shell.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
import sys
from threading import local
from unittest.mock import Mock

import reacton.patch_display
from IPython.core.displaypub import DisplayPublisher
from IPython.core.interactiveshell import InteractiveShell, InteractiveShellABC
from jupyter_client.session import Session, extract_header
from traitlets import Any, CBytes, Dict, Instance, Type, default


def encode_images(obj):
# no-op in ipykernel
return obj


def json_clean(obj):
# no-op in ipykernel
return obj


# based on the zmq display publisher from ipykernel
# ideally this goes out of ipykernel
class SolaraDisplayPublisher(DisplayPublisher):
"""A display publisher that publishes data using a ZeroMQ PUB socket."""

session = Instance(Session, allow_none=True)
pub_socket = Any(allow_none=True)
parent_header = Dict({})
topic = CBytes(b"display_data")

_thread_local = Any()

def set_parent(self, parent):
"""Set the parent for outbound messages."""
self.parent_header = extract_header(parent)

def _flush_streams(self):
"""flush IO Streams prior to display"""
sys.stdout.flush()
sys.stderr.flush()

@default("_thread_local")
def _default_thread_local(self):
"""Initialize our thread local storage"""
return local()

@property
def _hooks(self):
if not hasattr(self._thread_local, "hooks"):
# create new list for a new thread
self._thread_local.hooks = []
return self._thread_local.hooks

def publish(
self,
data,
metadata=None,
transient=None,
update=False,
):
"""Publish a display-data message
Parameters
----------
data : dict
A mime-bundle dict, keyed by mime-type.
metadata : dict, optional
Metadata associated with the data.
transient : dict, optional, keyword-only
Transient data that may only be relevant during a live display,
such as display_id.
Transient data should not be persisted to documents.
update : bool, optional, keyword-only
If True, send an update_display_data message instead of display_data.
"""
self._flush_streams()
if metadata is None:
metadata = {}
if transient is None:
transient = {}
self._validate_data(data, metadata)
content = {}
content["data"] = encode_images(data)
content["metadata"] = metadata
content["transient"] = transient

msg_type = "update_display_data" if update else "display_data"

# Use 2-stage process to send a message,
# in order to put it through the transform
# hooks before potentially sending.
msg = self.session.msg(msg_type, json_clean(content), parent=self.parent_header)

# Each transform either returns a new
# message or None. If None is returned,
# the message has been 'used' and we return.
for hook in self._hooks:
msg = hook(msg)
if msg is None:
return

self.session.send(
self.pub_socket,
msg,
ident=self.topic,
)

def clear_output(self, wait=False):
"""Clear output associated with the current execution (cell).
Parameters
----------
wait : bool (default: False)
If True, the output will not be cleared immediately,
instead waiting for the next display before clearing.
This reduces bounce during repeated clear & display loops.
"""
content = dict(wait=wait)
self._flush_streams()
self.session.send(
self.pub_socket,
"clear_output",
content,
parent=self.parent_header,
ident=self.topic,
)

def register_hook(self, hook):
"""
Registers a hook with the thread-local storage.
Parameters
----------
hook : Any callable object
Returns
-------
Either a publishable message, or `None`.
The DisplayHook objects must return a message from
the __call__ method if they still require the
`session.send` method to be called after transformation.
Returning `None` will halt that execution path, and
session.send will not be called.
"""
self._hooks.append(hook)

def unregister_hook(self, hook):
"""
Un-registers a hook with the thread-local storage.
Parameters
----------
hook : Any callable object which has previously been
registered as a hook.
Returns
-------
bool - `True` if the hook was removed, `False` if it wasn't
found.
"""
try:
self._hooks.remove(hook)
return True
except ValueError:
return False


class SolaraInteractiveShell(InteractiveShell):
display_pub_class = Type(SolaraDisplayPublisher)
history_manager = Any() # type: ignore

def set_parent(self, parent):
"""Tell the children about the parent message."""
self.display_pub.set_parent(parent)

def init_history(self):
self.history_manager = Mock() # type: ignore

def init_display_formatter(self):
super().init_display_formatter()
self.display_formatter.ipython_display_formatter = reacton.patch_display.ReactonDisplayFormatter()

def init_display_pub(self):
super().init_display_pub()
self.display_pub.register_hook(self.display_in_reacton_hook)

def display_in_reacton_hook(self, msg):
"""Will intercept a display call and add the display data to an output widget when in a reacton context/render function."""
# similar to reacton.patch_display.publish
from reacton.core import get_render_context

rc = get_render_context(required=False)
# only during the render phase we want to capture the display calls
# during the reconsolidation phase we want to let the original display publisher do its thing
# such as adding it to a output widget
if rc is not None and not rc.reconsolidating and msg["msg_type"] == "display_data":
from reacton.ipywidgets import Output

Output(outputs=[{"output_type": "display_data", "data": msg["content"]["data"], "metadata": msg["content"]["metadata"]}])
return None # do not send to the frontend
return msg


InteractiveShellABC.register(SolaraInteractiveShell)
25 changes: 25 additions & 0 deletions tests/unit/shell_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from unittest.mock import Mock

import IPython.display

from solara.server import app, kernel


def test_shell(no_app_context):
ws1 = Mock()
ws2 = Mock()
kernel1 = kernel.Kernel()
kernel2 = kernel.Kernel()
kernel1.session.websockets.add(ws1)
kernel2.session.websockets.add(ws2)
context1 = app.AppContext(id="1", kernel=kernel1)
context2 = app.AppContext(id="2", kernel=kernel2)

with context1:
IPython.display.display("test1")
assert ws1.send.call_count == 1
assert ws2.send.call_count == 0
with context2:
IPython.display.display("test1")
assert ws1.send.call_count == 1
assert ws2.send.call_count == 1