Skip to content

Commit

Permalink
FIXME attempt to support stdin -> for now server is locked :s
Browse files Browse the repository at this point in the history
  • Loading branch information
fcollonval committed May 13, 2024
1 parent 96ff3f3 commit ffcfb02
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 40 deletions.
6 changes: 5 additions & 1 deletion jupyter_server_nbmodel/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from jupyter_server.extension.application import ExtensionApp
from jupyter_server.services.kernels.handlers import _kernel_id_regex

from .handlers import ExecuteHandler, ExecutionStack, RequestHandler
from .handlers import ExecuteHandler, ExecutionStack, InputHandler, RequestHandler
from .log import get_logger

RTC_EXTENSIONAPP_NAME = "jupyter_server_ydoc"
Expand Down Expand Up @@ -36,6 +36,10 @@ def initialize_handlers(self):
ExecuteHandler,
{"ydoc_extension": rtc_extension, "execution_stack": self.__tasks},
),
(
f"/api/kernels/{_kernel_id_regex}/input",
InputHandler,
),
(
f"/api/kernels/{_kernel_id_regex}/requests/{_request_id_regex}",
RequestHandler,
Expand Down
129 changes: 90 additions & 39 deletions jupyter_server_nbmodel/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class ExecutionStack:
"""

def __init__(self):
self.__pending_inputs: dict[str, dict] = {}
self.__tasks: dict[str, asyncio.Task] = {}

def __del__(self):
Expand All @@ -55,11 +56,12 @@ def cancel(self, uid: str) -> None:

self.__tasks[uid].cancel()

def get(self, uid: str) -> t.Any:
def get(self, kernel_id: str, uid: str) -> t.Any:
"""Get the request ``uid`` results or None.
Args:
uid (str): Request index
kernel_id : Kernel identifier
uid : Request index
Returns:
Any: None if the request is pending else its result
Expand All @@ -71,13 +73,16 @@ def get(self, uid: str) -> t.Any:
if uid not in self.__tasks:
raise ValueError(f"Request {uid} does not exists.")

if kernel_id in self.__pending_inputs:
return self.__pending_inputs.pop(kernel_id)

if self.__tasks[uid].done():
task = self.__tasks.pop(uid)
return task.result()
else:
return None

def put(self, task: t.Awaitable, *args) -> str:
def put(self, km: jupyter_client.manager.KernelManager, snippet: str, ycell: y.Map) -> str:
"""Add a asynchronous execution request.
Args:
Expand All @@ -87,36 +92,57 @@ def put(self, task: t.Awaitable, *args) -> str:
Returns:
Request identifier
"""
uid = uuid.uuid4()
uid = str(uuid.uuid4())

async def execute_task(uid, f, *args) -> t.Any:
try:
get_logger().debug(f"Will execute request {uid}.")
result = await f(*args)
except asyncio.CancelledError:
raise
except Exception as e:
exception_type, _, tb = sys.exc_info()
result = {
"type": exception_type.__qualname__,
"error": str(e),
"message": repr(e),
"traceback": traceback.format_tb(tb),
}
get_logger().error("Error for request %s.", result)
else:
get_logger().debug(f"Has executed request {uid}.")
self.__tasks[uid] = asyncio.create_task(
execute_task(uid, km, snippet, ycell, partial(self._stdin_hook, km.kernel_id))
)
return uid

return result
def _stdin_hook(self, kernel_id, msg) -> None:
get_logger().info(f"Execution request {kernel_id} received a input request {msg!s}")
if kernel_id in self.__pending_inputs:
get_logger().error(f"Execution request {kernel_id} received a input request while waiting for an input.\n{msg}")

header = msg["header"].copy()
header["date"] = header["date"].isoformat()
self.__pending_inputs[kernel_id] = {"parent_header": header, "input_request": msg["content"]}

self.__tasks[uid] = asyncio.create_task(execute_task(uid, task, *args))
return uid

async def execute_task(
uid, km: jupyter_client.manager.KernelManager, snippet: str, ycell: y.Map, stdin_hook
) -> t.Any:
try:
get_logger().debug(f"Will execute request {uid}.")
result = await _execute_snippet(uid, km, snippet, ycell, stdin_hook)
except asyncio.CancelledError:
raise
except Exception as e:
exception_type, _, tb = sys.exc_info()
result = {
"type": exception_type.__qualname__,
"error": str(e),
"message": repr(e),
"traceback": traceback.format_tb(tb),
}
get_logger().error("Error for request %s.", result)
else:
get_logger().debug(f"Has executed request {uid}.")

return result


async def execute_snippet(
km: jupyter_client.manager.KernelManager, snippet: str, ycell: y.Map
async def _execute_snippet(
uid: str,
km: jupyter_client.client.KernelClient,
snippet: str,
ycell: y.Map,
stdin_hook,
) -> dict[str, t.Any]:
client = km.client()
client.session.session = uid
# FIXME
# client.session.username = username

if ycell is not None:
# Reset cell
Expand All @@ -125,15 +151,14 @@ async def execute_snippet(

outputs = []

# FIXME set the username of client.session to server user
# FIXME we don't check if the session is consistent (aka the kernel is linked to the document)
# - should we?
try:
reply = await ensure_async(
client.execute_interactive(
snippet,
output_hook=partial(_output_hook, ycell, outputs),
stdin_hook=_stdin_hook if client.allow_stdin else None,
stdin_hook=stdin_hook if client.allow_stdin else None,
)
)

Expand Down Expand Up @@ -191,9 +216,6 @@ def _output_hook(ycell, outputs, msg) -> None:
# FIXME
...

def _stdin_hook(msg) -> None:
get_logger().info("Code snippet execution is waiting for an input.")


class ExecuteHandler(ExtensionHandlerMixin, APIHandler):
"""Handle request for snippet execution."""
Expand Down Expand Up @@ -288,13 +310,38 @@ async def post(self, kernel_id: str) -> None:
get_logger().error(msg, exc_info=e)
raise tornado.web.HTTPError(status_code=HTTPStatus.NOT_FOUND, reason=msg) from e

uid = self._execution_stack.put(execute_snippet, km, snippet, ycell)
uid = self._execution_stack.put(km, snippet, ycell)

self.set_status(HTTPStatus.ACCEPTED)
self.set_header("Location", f"/api/kernels/{kernel_id}/requests/{uid}")
self.finish("{}")


class InputHandler(ExtensionHandlerMixin, APIHandler):
"""Handle request for input reply."""

@tornado.web.authenticated
async def post(self, kernel_id: str) -> None:
body = self.get_json_body()

try:
km = self.kernel_manager.get_kernel(kernel_id)
except KeyError as e:
msg = f"Unknown kernel with id: {kernel_id}"
get_logger().error(msg, exc_info=e)
raise tornado.web.HTTPError(status_code=HTTPStatus.NOT_FOUND, reason=msg) from e

client = km.client()

try:
# only send stdin reply if there *was not* another request
# or execution finished while we were reading.
if not (await client.stdin_channel.msg_ready() or await client.shell_channel.msg_ready()):
client.input(body["input"])
finally:
del client


class RequestHandler(ExtensionHandlerMixin, APIHandler):
"""Handler for /api/kernels/<kernel_id>/requests/<request_id>"""

Expand All @@ -305,14 +352,15 @@ def initialize(
self._stack = execution_stack

@tornado.web.authenticated
def get(self, kernel_id: str, uid: str) -> None:
def get(self, kernel_id: str, request_id: str) -> None:
"""`GET /api/kernels/<kernel_id>/requests/<id>` Returns the request ``uid`` status.
Status are:
* 200: Task result is returned
* 202: Task is pending
* 500: Task ends with errors
* 200: Request result is returned
* 202: Request is pending
* 300: Request has a pending input
* 500: Request ends with errors
Args:
index: Request identifier
Expand All @@ -321,7 +369,7 @@ def get(self, kernel_id: str, uid: str) -> None:
404 if request ``uid`` does not exist
"""
try:
r = self._stack.get(uid)
r = self._stack.get(kernel_id, request_id)
except ValueError as err:
raise tornado.web.HTTPError(404, reason=str(err)) from err
else:
Expand All @@ -332,12 +380,15 @@ def get(self, kernel_id: str, uid: str) -> None:
if "error" in r:
self.set_status(500)
self.log.debug(f"{r}")
elif "input_request" in r:
self.set_status(300)
self.set_header("Location", f"/api/kernels/{kernel_id}/input")
else:
self.set_status(200)
self.finish(json.dumps(r))

@tornado.web.authenticated
def delete(self, kernel_id: str, uid: str) -> None:
def delete(self, kernel_id: str, request_id: str) -> None:
"""`DELETE /api/kernels/<kernel_id>/requests/<id>` cancels the request ``uid``.
Status are:
Expand All @@ -350,7 +401,7 @@ def delete(self, kernel_id: str, uid: str) -> None:
404 if request ``uid`` does not exist
"""
try:
self._stack.cancel(int(uid))
self._stack.cancel(request_id)
except ValueError as err:
raise tornado.web.HTTPError(404, reason=str(err)) from err
else:
Expand Down

0 comments on commit ffcfb02

Please sign in to comment.