diff --git a/jupyter_server_nbmodel/extension.py b/jupyter_server_nbmodel/extension.py index 4793d7b..2a5b273 100644 --- a/jupyter_server_nbmodel/extension.py +++ b/jupyter_server_nbmodel/extension.py @@ -3,7 +3,7 @@ from jupyter_server.extension.application import ExtensionApp from jupyter_server.services.kernels.handlers import _kernel_id_regex -from .handlers import ExecuteHandler, ExecutionStack, RequestHandler +from .handlers import ExecuteHandler, ExecutionStack, InputHandler, RequestHandler from .log import get_logger RTC_EXTENSIONAPP_NAME = "jupyter_server_ydoc" @@ -36,6 +36,10 @@ def initialize_handlers(self): ExecuteHandler, {"ydoc_extension": rtc_extension, "execution_stack": self.__tasks}, ), + ( + f"/api/kernels/{_kernel_id_regex}/input", + InputHandler, + ), ( f"/api/kernels/{_kernel_id_regex}/requests/{_request_id_regex}", RequestHandler, diff --git a/jupyter_server_nbmodel/handlers.py b/jupyter_server_nbmodel/handlers.py index 62ccfba..e143748 100644 --- a/jupyter_server_nbmodel/handlers.py +++ b/jupyter_server_nbmodel/handlers.py @@ -37,6 +37,7 @@ class ExecutionStack: """ def __init__(self): + self.__pending_inputs: dict[str, dict] = {} self.__tasks: dict[str, asyncio.Task] = {} def __del__(self): @@ -55,11 +56,12 @@ def cancel(self, uid: str) -> None: self.__tasks[uid].cancel() - def get(self, uid: str) -> t.Any: + def get(self, kernel_id: str, uid: str) -> t.Any: """Get the request ``uid`` results or None. Args: - uid (str): Request index + kernel_id : Kernel identifier + uid : Request index Returns: Any: None if the request is pending else its result @@ -71,13 +73,16 @@ def get(self, uid: str) -> t.Any: if uid not in self.__tasks: raise ValueError(f"Request {uid} does not exists.") + if kernel_id in self.__pending_inputs: + return self.__pending_inputs.pop(kernel_id) + if self.__tasks[uid].done(): task = self.__tasks.pop(uid) return task.result() else: return None - def put(self, task: t.Awaitable, *args) -> str: + def put(self, km: jupyter_client.manager.KernelManager, snippet: str, ycell: y.Map) -> str: """Add a asynchronous execution request. Args: @@ -87,36 +92,57 @@ def put(self, task: t.Awaitable, *args) -> str: Returns: Request identifier """ - uid = uuid.uuid4() + uid = str(uuid.uuid4()) - async def execute_task(uid, f, *args) -> t.Any: - try: - get_logger().debug(f"Will execute request {uid}.") - result = await f(*args) - except asyncio.CancelledError: - raise - except Exception as e: - exception_type, _, tb = sys.exc_info() - result = { - "type": exception_type.__qualname__, - "error": str(e), - "message": repr(e), - "traceback": traceback.format_tb(tb), - } - get_logger().error("Error for request %s.", result) - else: - get_logger().debug(f"Has executed request {uid}.") + self.__tasks[uid] = asyncio.create_task( + execute_task(uid, km, snippet, ycell, partial(self._stdin_hook, km.kernel_id)) + ) + return uid - return result + def _stdin_hook(self, kernel_id, msg) -> None: + get_logger().info(f"Execution request {kernel_id} received a input request {msg!s}") + if kernel_id in self.__pending_inputs: + get_logger().error(f"Execution request {kernel_id} received a input request while waiting for an input.\n{msg}") + + header = msg["header"].copy() + header["date"] = header["date"].isoformat() + self.__pending_inputs[kernel_id] = {"parent_header": header, "input_request": msg["content"]} - self.__tasks[uid] = asyncio.create_task(execute_task(uid, task, *args)) - return uid + +async def execute_task( + uid, km: jupyter_client.manager.KernelManager, snippet: str, ycell: y.Map, stdin_hook +) -> t.Any: + try: + get_logger().debug(f"Will execute request {uid}.") + result = await _execute_snippet(uid, km, snippet, ycell, stdin_hook) + except asyncio.CancelledError: + raise + except Exception as e: + exception_type, _, tb = sys.exc_info() + result = { + "type": exception_type.__qualname__, + "error": str(e), + "message": repr(e), + "traceback": traceback.format_tb(tb), + } + get_logger().error("Error for request %s.", result) + else: + get_logger().debug(f"Has executed request {uid}.") + + return result -async def execute_snippet( - km: jupyter_client.manager.KernelManager, snippet: str, ycell: y.Map +async def _execute_snippet( + uid: str, + km: jupyter_client.client.KernelClient, + snippet: str, + ycell: y.Map, + stdin_hook, ) -> dict[str, t.Any]: client = km.client() + client.session.session = uid + # FIXME + # client.session.username = username if ycell is not None: # Reset cell @@ -125,7 +151,6 @@ async def execute_snippet( outputs = [] - # FIXME set the username of client.session to server user # FIXME we don't check if the session is consistent (aka the kernel is linked to the document) # - should we? try: @@ -133,7 +158,7 @@ async def execute_snippet( client.execute_interactive( snippet, output_hook=partial(_output_hook, ycell, outputs), - stdin_hook=_stdin_hook if client.allow_stdin else None, + stdin_hook=stdin_hook if client.allow_stdin else None, ) ) @@ -191,9 +216,6 @@ def _output_hook(ycell, outputs, msg) -> None: # FIXME ... -def _stdin_hook(msg) -> None: - get_logger().info("Code snippet execution is waiting for an input.") - class ExecuteHandler(ExtensionHandlerMixin, APIHandler): """Handle request for snippet execution.""" @@ -288,13 +310,38 @@ async def post(self, kernel_id: str) -> None: get_logger().error(msg, exc_info=e) raise tornado.web.HTTPError(status_code=HTTPStatus.NOT_FOUND, reason=msg) from e - uid = self._execution_stack.put(execute_snippet, km, snippet, ycell) + uid = self._execution_stack.put(km, snippet, ycell) self.set_status(HTTPStatus.ACCEPTED) self.set_header("Location", f"/api/kernels/{kernel_id}/requests/{uid}") self.finish("{}") +class InputHandler(ExtensionHandlerMixin, APIHandler): + """Handle request for input reply.""" + + @tornado.web.authenticated + async def post(self, kernel_id: str) -> None: + body = self.get_json_body() + + try: + km = self.kernel_manager.get_kernel(kernel_id) + except KeyError as e: + msg = f"Unknown kernel with id: {kernel_id}" + get_logger().error(msg, exc_info=e) + raise tornado.web.HTTPError(status_code=HTTPStatus.NOT_FOUND, reason=msg) from e + + client = km.client() + + try: + # only send stdin reply if there *was not* another request + # or execution finished while we were reading. + if not (await client.stdin_channel.msg_ready() or await client.shell_channel.msg_ready()): + client.input(body["input"]) + finally: + del client + + class RequestHandler(ExtensionHandlerMixin, APIHandler): """Handler for /api/kernels//requests/""" @@ -305,14 +352,15 @@ def initialize( self._stack = execution_stack @tornado.web.authenticated - def get(self, kernel_id: str, uid: str) -> None: + def get(self, kernel_id: str, request_id: str) -> None: """`GET /api/kernels//requests/` Returns the request ``uid`` status. Status are: - * 200: Task result is returned - * 202: Task is pending - * 500: Task ends with errors + * 200: Request result is returned + * 202: Request is pending + * 300: Request has a pending input + * 500: Request ends with errors Args: index: Request identifier @@ -321,7 +369,7 @@ def get(self, kernel_id: str, uid: str) -> None: 404 if request ``uid`` does not exist """ try: - r = self._stack.get(uid) + r = self._stack.get(kernel_id, request_id) except ValueError as err: raise tornado.web.HTTPError(404, reason=str(err)) from err else: @@ -332,12 +380,15 @@ def get(self, kernel_id: str, uid: str) -> None: if "error" in r: self.set_status(500) self.log.debug(f"{r}") + elif "input_request" in r: + self.set_status(300) + self.set_header("Location", f"/api/kernels/{kernel_id}/input") else: self.set_status(200) self.finish(json.dumps(r)) @tornado.web.authenticated - def delete(self, kernel_id: str, uid: str) -> None: + def delete(self, kernel_id: str, request_id: str) -> None: """`DELETE /api/kernels//requests/` cancels the request ``uid``. Status are: @@ -350,7 +401,7 @@ def delete(self, kernel_id: str, uid: str) -> None: 404 if request ``uid`` does not exist """ try: - self._stack.cancel(int(uid)) + self._stack.cancel(request_id) except ValueError as err: raise tornado.web.HTTPError(404, reason=str(err)) from err else: