From bdd742d28dd87405c0358b670adcbe9f9e2e724c Mon Sep 17 00:00:00 2001 From: David Brochart Date: Fri, 24 May 2024 16:41:22 +0200 Subject: [PATCH] Support stdin in server-side execution --- .../fps_kernels/kernel_driver/driver.py | 80 +++++++++++++++++-- 1 file changed, 75 insertions(+), 5 deletions(-) diff --git a/plugins/kernels/fps_kernels/kernel_driver/driver.py b/plugins/kernels/fps_kernels/kernel_driver/driver.py index d1761b77..3e808e31 100644 --- a/plugins/kernels/fps_kernels/kernel_driver/driver.py +++ b/plugins/kernels/fps_kernels/kernel_driver/driver.py @@ -2,9 +2,10 @@ import os import time import uuid +from functools import partial from typing import Any, Dict, List, Optional, cast -from pycrdt import Array, Map +from pycrdt import Array, Map, Text from jupyverse_api.yjs import Yjs @@ -46,6 +47,7 @@ def __init__( self.execute_requests: Dict[str, Dict[str, asyncio.Queue]] = {} self.comm_messages: asyncio.Queue = asyncio.Queue() self.tasks: List[asyncio.Task] = [] + self._background_tasks: set[asyncio.Task] = set() async def restart(self, startup_timeout: float = float("inf")) -> None: for task in self.tasks: @@ -80,13 +82,23 @@ async def connect(self, startup_timeout: float = float("inf")) -> None: def connect_channels(self, connection_cfg: Optional[cfg_t] = None): connection_cfg = connection_cfg or self.connection_cfg - self.shell_channel = connect_channel("shell", connection_cfg) + self.shell_channel = connect_channel( + "shell", + connection_cfg, + identity=self.session_id.encode(), + ) self.control_channel = connect_channel("control", connection_cfg) self.iopub_channel = connect_channel("iopub", connection_cfg) + self.stdin_channel = connect_channel( + "stdin", + connection_cfg, + identity=self.session_id.encode(), + ) def listen_channels(self): self.tasks.append(asyncio.create_task(self.listen_iopub())) self.tasks.append(asyncio.create_task(self.listen_shell())) + self.tasks.append(asyncio.create_task(self.listen_stdin())) async def stop(self) -> None: self.kernel_process.kill() @@ -111,6 +123,13 @@ async def listen_shell(self): if msg_id in self.execute_requests.keys(): self.execute_requests[msg_id]["shell_msg"].put_nowait(msg) + async def listen_stdin(self): + while True: + msg = await receive_message(self.stdin_channel, change_str_to_date=True) + msg_id = msg["parent_header"].get("msg_id") + if msg_id in self.execute_requests.keys(): + self.execute_requests[msg_id]["stdin_msg"].put_nowait(msg) + async def execute( self, ycell: Map, @@ -121,7 +140,7 @@ async def execute( if ycell["cell_type"] != "code": return ycell["execution_state"] = "busy" - content = {"code": str(ycell["source"]), "silent": False} + content = {"code": str(ycell["source"]), "silent": False, "allow_stdin": True} msg = create_message( "execute_request", content, session_id=self.session_id, msg_id=str(self.msg_cnt) ) @@ -134,6 +153,7 @@ async def execute( self.execute_requests[msg_id] = { "iopub_msg": asyncio.Queue(), "shell_msg": asyncio.Queue(), + "stdin_msg": asyncio.Queue(), } if wait_for_executed: deadline = time.time() + timeout @@ -165,9 +185,11 @@ async def execute( ycell["execution_state"] = "idle" del self.execute_requests[msg_id] else: - self.tasks.append(asyncio.create_task(self._handle_iopub(msg_id, ycell))) + stdin_task = asyncio.create_task(self._handle_stdin(msg_id, ycell)) + self.tasks.append(stdin_task) + self.tasks.append(asyncio.create_task(self._handle_iopub(msg_id, ycell, stdin_task))) - async def _handle_iopub(self, msg_id: str, ycell: Map) -> None: + async def _handle_iopub(self, msg_id: str, ycell: Map, stdin_task: asyncio.Task) -> None: while True: msg = await self.execute_requests[msg_id]["iopub_msg"].get() await self._handle_outputs(ycell["outputs"], msg) @@ -175,11 +197,59 @@ async def _handle_iopub(self, msg_id: str, ycell: Map) -> None: (msg["header"]["msg_type"] == "status" and msg["content"]["execution_state"] == "idle") ): + stdin_task.cancel() msg = await self.execute_requests[msg_id]["shell_msg"].get() with ycell.doc.transaction(): ycell["execution_count"] = msg["content"]["execution_count"] ycell["execution_state"] = "idle" + async def _handle_stdin(self, msg_id: str, ycell: Map) -> None: + while True: + msg = await self.execute_requests[msg_id]["stdin_msg"].get() + if msg["msg_type"] == "input_request": + content = msg["content"] + outputs = ycell["outputs"] + with outputs.doc.transaction(): + text = Text() + stdin_output = Map( + { + "output_type": "stdin", + "submitted": False, + "password": content["password"], + "prompt": content["prompt"], + "value": text, + } + ) + stdin_idx = len(outputs) + outputs.append(stdin_output) + stdin_output.observe(partial(self._handle_stdin_submission, outputs, stdin_idx)) + + def _handle_stdin_submission(self, outputs, stdin_idx, event): + if event.target["submitted"]: + # send input reply to kernel + value = str(event.target["value"]) + content = {"value": value} + msg = create_message( + "input_reply", content, session_id=self.session_id, msg_id=str(self.msg_cnt) + ) + msg_id = msg["header"]["msg_id"] + task0 = asyncio.create_task( + send_message(msg, self.stdin_channel, self.key, change_date_to_str=True) + ) + task1 = asyncio.create_task(self._change_stdin_to_stream(outputs, stdin_idx, value)) + self._background_tasks.add(task0) + self._background_tasks.add(task1) + task0.add_done_callback(self._background_tasks.discard) + task1.add_done_callback(self._background_tasks.discard) + + async def _change_stdin_to_stream(self, outputs, stdin_idx, value): + # replace stdin output with stream output + outputs[stdin_idx] = { + "output_type": "stream", + "name": "stdout", + "text": [value], + } + async def _handle_comms(self) -> None: if self.yjs is None or self.yjs.widgets is None: # type: ignore return