diff --git a/plugins/kernels/fps_kernels/kernel_driver/connect.py b/plugins/kernels/fps_kernels/kernel_driver/connect.py index b137bbff..cc69b183 100644 --- a/plugins/kernels/fps_kernels/kernel_driver/connect.py +++ b/plugins/kernels/fps_kernels/kernel_driver/connect.py @@ -4,7 +4,7 @@ import socket import tempfile import uuid -from typing import Dict, Tuple, Union +from typing import Dict, Optional, Tuple, Union import zmq import zmq.asyncio @@ -62,7 +62,7 @@ def write_connection_file( return fname, cfg -def read_connection_file(fname: str = "") -> cfg_t: +def read_connection_file(fname: str) -> cfg_t: with open(fname, "rt") as f: cfg: cfg_t = json.load(f) @@ -70,33 +70,40 @@ def read_connection_file(fname: str = "") -> cfg_t: async def launch_kernel( - kernelspec_path: str, connection_file_path: str, capture_output: bool + kernelspec_path: str, connection_file_path: str, kernel_cwd: str, capture_output: bool ) -> asyncio.subprocess.Process: with open(kernelspec_path) as f: kernelspec = json.load(f) cmd = [s.format(connection_file=connection_file_path) for s in kernelspec["argv"]] + if kernel_cwd: + prev_dir = os.getcwd() + os.chdir(kernel_cwd) if capture_output: p = await asyncio.create_subprocess_exec( *cmd, stdout=asyncio.subprocess.DEVNULL, stderr=asyncio.subprocess.STDOUT ) else: p = await asyncio.create_subprocess_exec(*cmd) + if kernel_cwd: + os.chdir(prev_dir) return p -def create_socket(channel: str, cfg: cfg_t) -> Socket: +def create_socket(channel: str, cfg: cfg_t, identity: Optional[bytes] = None) -> Socket: ip = cfg["ip"] port = cfg[f"{channel}_port"] url = f"tcp://{ip}:{port}" socket_type = channel_socket_types[channel] sock = context.socket(socket_type) sock.linger = 1000 # set linger to 1s to prevent hangs at exit + if identity: + sock.identity = identity sock.connect(url) return sock -def connect_channel(channel_name: str, cfg: cfg_t) -> Socket: - sock = create_socket(channel_name, cfg) +def connect_channel(channel_name: str, cfg: cfg_t, identity: Optional[bytes] = None) -> Socket: + sock = create_socket(channel_name, cfg, identity) if channel_name == "iopub": sock.setsockopt(zmq.SUBSCRIBE, b"") return sock diff --git a/plugins/kernels/fps_kernels/kernel_driver/driver.py b/plugins/kernels/fps_kernels/kernel_driver/driver.py index ce304a1f..85a03ba1 100644 --- a/plugins/kernels/fps_kernels/kernel_driver/driver.py +++ b/plugins/kernels/fps_kernels/kernel_driver/driver.py @@ -2,53 +2,31 @@ import os import time import uuid -from typing import Any, Dict, List, Optional, Tuple, cast - -from zmq.asyncio import Socket +from typing import Any, Dict, List, Optional, cast from .connect import cfg_t, connect_channel, launch_kernel, read_connection_file from .connect import write_connection_file as _write_connection_file from .kernelspec import find_kernelspec -from .message import create_message, deserialize, serialize - -DELIM = b"" +from .message import create_message, receive_message, send_message def deadline_to_timeout(deadline: float) -> float: return max(0, deadline - time.time()) -def feed_identities(msg_list: List[bytes]) -> Tuple[List[bytes], List[bytes]]: - idx = msg_list.index(DELIM) - return msg_list[:idx], msg_list[idx + 1 :] # noqa - - -def send_message(msg: Dict[str, Any], sock: Socket, key: str) -> None: - to_send = serialize(msg, key) - sock.send_multipart(to_send, copy=True) - - -async def receive_message(sock: Socket, timeout: float = float("inf")) -> Optional[Dict[str, Any]]: - timeout *= 1000 # in ms - ready = await sock.poll(timeout) - if ready: - msg_list = await sock.recv_multipart() - idents, msg_list = feed_identities(msg_list) - return deserialize(msg_list) - return None - - class KernelDriver: def __init__( self, kernel_name: str = "", kernelspec_path: str = "", + kernel_cwd: str = "", connection_file: str = "", write_connection_file: bool = True, capture_kernel_output: bool = True, ) -> None: self.capture_kernel_output = capture_kernel_output self.kernelspec_path = kernelspec_path or find_kernelspec(kernel_name) + self.kernel_cwd = kernel_cwd if not self.kernelspec_path: raise RuntimeError("Could not find a kernel, maybe you forgot to install one?") if write_connection_file: @@ -66,9 +44,11 @@ async def restart(self, startup_timeout: float = float("inf")) -> None: for task in self.channel_tasks: task.cancel() msg = create_message("shutdown_request", content={"restart": True}) - send_message(msg, self.control_channel, self.key) + await send_message(msg, self.control_channel, self.key, change_date_to_str=True) while True: - msg = cast(Dict[str, Any], await receive_message(self.control_channel)) + msg = cast( + Dict[str, Any], await receive_message(self.control_channel, change_str_to_date=True) + ) if msg["msg_type"] == "shutdown_reply" and msg["content"]["restart"]: break await self._wait_for_ready(startup_timeout) @@ -77,7 +57,10 @@ async def restart(self, startup_timeout: float = float("inf")) -> None: async def start(self, startup_timeout: float = float("inf"), connect: bool = True) -> None: self.kernel_process = await launch_kernel( - self.kernelspec_path, self.connection_file_path, self.capture_kernel_output + self.kernelspec_path, + self.connection_file_path, + self.kernel_cwd, + self.capture_kernel_output, ) if connect: await self.connect(startup_timeout) @@ -106,14 +89,14 @@ async def stop(self) -> None: async def listen_iopub(self): while True: - msg = await receive_message(self.iopub_channel) # type: ignore + msg = await receive_message(self.iopub_channel, change_str_to_date=True) # type: ignore msg_id = msg["parent_header"].get("msg_id") if msg_id in self.execute_requests.keys(): self.execute_requests[msg_id]["iopub_msg"].set_result(msg) async def listen_shell(self): while True: - msg = await receive_message(self.shell_channel) # type: ignore + msg = await receive_message(self.shell_channel, change_str_to_date=True) # type: ignore msg_id = msg["parent_header"].get("msg_id") if msg_id in self.execute_requests.keys(): self.execute_requests[msg_id]["shell_msg"].set_result(msg) @@ -129,14 +112,14 @@ async def execute( return content = {"code": cell["source"], "silent": False} msg = create_message( - "execute_request", content, session_id=self.session_id, msg_cnt=self.msg_cnt + "execute_request", content, session_id=self.session_id, msg_id=str(self.msg_cnt) ) if msg_id: msg["header"]["msg_id"] = msg_id else: msg_id = msg["header"]["msg_id"] self.msg_cnt += 1 - send_message(msg, self.shell_channel, self.key) + await send_message(msg, self.shell_channel, self.key, change_date_to_str=True) if wait_for_executed: deadline = time.time() + timeout self.execute_requests[msg_id] = { @@ -177,16 +160,20 @@ async def _wait_for_ready(self, timeout): new_timeout = timeout while True: msg = create_message( - "kernel_info_request", session_id=self.session_id, msg_cnt=self.msg_cnt + "kernel_info_request", session_id=self.session_id, msg_id=str(self.msg_cnt) ) self.msg_cnt += 1 - send_message(msg, self.shell_channel, self.key) - msg = await receive_message(self.shell_channel, new_timeout) + await send_message(msg, self.shell_channel, self.key, change_date_to_str=True) + msg = await receive_message( + self.shell_channel, timeout=new_timeout, change_str_to_date=True + ) if msg is None: error_message = f"Kernel didn't respond in {timeout} seconds" raise RuntimeError(error_message) if msg["msg_type"] == "kernel_info_reply": - msg = await receive_message(self.iopub_channel, 0.2) + msg = await receive_message( + self.iopub_channel, timeout=0.2, change_str_to_date=True + ) if msg is not None: break new_timeout = deadline_to_timeout(deadline) diff --git a/plugins/kernels/fps_kernels/kernel_driver/message.py b/plugins/kernels/fps_kernels/kernel_driver/message.py index e6231dfe..69f05595 100644 --- a/plugins/kernels/fps_kernels/kernel_driver/message.py +++ b/plugins/kernels/fps_kernels/kernel_driver/message.py @@ -1,10 +1,11 @@ import hashlib import hmac -import uuid from datetime import datetime, timezone -from typing import Any, Dict, List, cast +from typing import Any, Dict, List, Optional, Tuple, cast +from uuid import uuid4 from dateutil.parser import parse as dateutil_parse # type: ignore +from zmq.asyncio import Socket from zmq.utils import jsonapi protocol_version_info = (5, 3) @@ -13,6 +14,11 @@ DELIM = b"" +def feed_identities(msg_list: List[bytes]) -> Tuple[List[bytes], List[bytes]]: + idx = msg_list.index(DELIM) + return msg_list[:idx], msg_list[idx + 1 :] # noqa + + def str_to_date(obj: Dict[str, Any]) -> Dict[str, Any]: if "date" in obj: obj["date"] = dateutil_parse(obj["date"]) @@ -29,17 +35,17 @@ def utcnow() -> datetime: return datetime.utcnow().replace(tzinfo=timezone.utc) -def create_message_header(msg_type: str, session_id: str, msg_cnt: int) -> Dict[str, Any]: +def create_message_header(msg_type: str, session_id: str, msg_id: str) -> Dict[str, Any]: if not session_id: - session_id = msg_id = uuid.uuid4().hex + session_id = msg_id = uuid4().hex else: - msg_id = f"{session_id}_{msg_cnt}" + msg_id = f"{session_id}_{msg_id}" header = { - "date": utcnow(), + "date": utcnow().isoformat().replace("+00:00", "Z"), "msg_id": msg_id, "msg_type": msg_type, "session": session_id, - "username": "david", + "username": "", "version": protocol_version, } return header @@ -49,9 +55,9 @@ def create_message( msg_type: str, content: Dict = {}, session_id: str = "", - msg_cnt: int = 0, + msg_id: str = "", ) -> Dict[str, Any]: - header = create_message_header(msg_type, session_id, msg_cnt) + header = create_message_header(msg_type, session_id, msg_id) msg = { "header": header, "msg_id": header["msg_id"], @@ -59,6 +65,7 @@ def create_message( "parent_header": {}, "content": content, "metadata": {}, + "buffers": [], } return msg @@ -79,25 +86,52 @@ def sign(msg_list: List[bytes], key: str) -> bytes: return h.hexdigest().encode() -def serialize(msg: Dict[str, Any], key: str) -> List[bytes]: +def serialize(msg: Dict[str, Any], key: str, change_date_to_str: bool = False) -> List[bytes]: + _date_to_str = date_to_str if change_date_to_str else lambda x: x message = [ - pack(date_to_str(msg["header"])), - pack(date_to_str(msg["parent_header"])), - pack(date_to_str(msg["metadata"])), - pack(date_to_str(msg.get("content", {}))), + pack(_date_to_str(msg["header"])), + pack(_date_to_str(msg["parent_header"])), + pack(_date_to_str(msg["metadata"])), + pack(_date_to_str(msg.get("content", {}))), ] - to_send = [DELIM, sign(message, key)] + message + to_send = [DELIM, sign(message, key)] + message + msg.get("buffers", []) return to_send -def deserialize(msg_list: List[bytes]) -> Dict[str, Any]: +def deserialize( + msg_list: List[bytes], + parent_header: Optional[Dict[str, Any]] = None, + change_str_to_date: bool = False, +) -> Dict[str, Any]: + _str_to_date = str_to_date if change_str_to_date else lambda x: x message: Dict[str, Any] = {} header = unpack(msg_list[1]) - message["header"] = str_to_date(header) + message["header"] = _str_to_date(header) message["msg_id"] = header["msg_id"] message["msg_type"] = header["msg_type"] - message["parent_header"] = str_to_date(unpack(msg_list[2])) + if parent_header: + message["parent_header"] = parent_header + else: + message["parent_header"] = _str_to_date(unpack(msg_list[2])) message["metadata"] = unpack(msg_list[3]) message["content"] = unpack(msg_list[4]) message["buffers"] = [memoryview(b) for b in msg_list[5:]] return message + + +async def send_message( + msg: Dict[str, Any], sock: Socket, key: str, change_date_to_str: bool = False +) -> None: + await sock.send_multipart(serialize(msg, key, change_date_to_str=change_date_to_str), copy=True) + + +async def receive_message( + sock: Socket, timeout: float = float("inf"), change_str_to_date: bool = False +) -> Optional[Dict[str, Any]]: + timeout *= 1000 # in ms + ready = await sock.poll(timeout) + if ready: + msg_list = await sock.recv_multipart() + idents, msg_list = feed_identities(msg_list) + return deserialize(msg_list, change_str_to_date=change_str_to_date) + return None diff --git a/plugins/kernels/fps_kernels/kernel_server/connect.py b/plugins/kernels/fps_kernels/kernel_server/connect.py deleted file mode 100644 index dcb4edbf..00000000 --- a/plugins/kernels/fps_kernels/kernel_server/connect.py +++ /dev/null @@ -1,122 +0,0 @@ -import asyncio -import json -import os -import socket -import tempfile -import uuid -from typing import Dict, Optional, Tuple, Union - -import zmq -import zmq.asyncio -from fastapi import WebSocket -from zmq.asyncio import Socket - -channel_socket_types = { - "hb": zmq.REQ, - "shell": zmq.DEALER, - "iopub": zmq.SUB, - "stdin": zmq.DEALER, - "control": zmq.DEALER, -} - -context = zmq.asyncio.Context() - -cfg_t = Dict[str, Union[str, int]] - - -def get_port(ip: str) -> int: - sock = socket.socket() - sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, b"\0" * 8) - sock.bind((ip, 0)) - port = sock.getsockname()[1] - sock.close() - return port - - -def write_connection_file( - fname: str = "", - ip: str = "", - transport: str = "tcp", - signature_scheme: str = "hmac-sha256", - kernel_name: str = "", -) -> Tuple[str, cfg_t]: - ip = ip or "127.0.0.1" - - if not fname: - fd, fname = tempfile.mkstemp(suffix=".json") - os.close(fd) - f = open(fname, "wt") - - channels = ["shell", "iopub", "stdin", "control", "hb"] - - cfg: cfg_t = {f"{c}_port": get_port(ip) for c in channels} - - cfg["ip"] = ip - cfg["key"] = uuid.uuid4().hex - cfg["transport"] = transport - cfg["signature_scheme"] = signature_scheme - cfg["kernel_name"] = kernel_name - - f.write(json.dumps(cfg, indent=2)) - f.close() - - return fname, cfg - - -def read_connection_file(fname: str) -> cfg_t: - with open(fname, "rt") as f: - cfg: cfg_t = json.load(f) - - return cfg - - -async def launch_kernel( - kernelspec_path: str, connection_file_path: str, capture_output: bool -) -> asyncio.subprocess.Process: - with open(kernelspec_path) as f: - kernelspec = json.load(f) - cmd = [s.format(connection_file=connection_file_path) for s in kernelspec["argv"]] - if capture_output: - p = await asyncio.create_subprocess_exec( - *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE - ) - else: - p = await asyncio.create_subprocess_exec(*cmd) - return p - - -def create_socket(channel: str, cfg: cfg_t, identity: Optional[bytes] = None) -> Socket: - ip = cfg["ip"] - port = cfg[f"{channel}_port"] - url = f"tcp://{ip}:{port}" - socket_type = channel_socket_types[channel] - sock = context.socket(socket_type) - sock.linger = 1000 # set linger to 1s to prevent hangs at exit - if identity: - sock.identity = identity - sock.connect(url) - return sock - - -def connect_channel(channel_name: str, cfg: cfg_t, identity: Optional[bytes] = None) -> Socket: - sock = create_socket(channel_name, cfg, identity) - if channel_name == "iopub": - sock.setsockopt(zmq.SUBSCRIBE, b"") - return sock - - -class AcceptedWebSocket: - _websocket: WebSocket - _accepted_subprotocol: Optional[str] - - def __init__(self, websocket, accepted_subprotocol): - self._websocket = websocket - self._accepted_subprotocol = accepted_subprotocol - - @property - def websocket(self): - return self._websocket - - @property - def accepted_subprotocol(self): - return self._accepted_subprotocol diff --git a/plugins/kernels/fps_kernels/kernel_server/message.py b/plugins/kernels/fps_kernels/kernel_server/message.py index 4138db77..568efb80 100644 --- a/plugins/kernels/fps_kernels/kernel_server/message.py +++ b/plugins/kernels/fps_kernels/kernel_server/message.py @@ -1,18 +1,10 @@ -import hashlib -import hmac import json import struct -from datetime import datetime, timezone -from typing import Any, Dict, List, Optional, Tuple, cast -from uuid import uuid4 +from typing import Any, Dict, List, Optional, Tuple from zmq.asyncio import Socket -from zmq.utils import jsonapi -protocol_version_info = (5, 3) -protocol_version = "%i.%i" % protocol_version_info - -DELIM = b"" +from ..kernel_driver.message import DELIM, deserialize, feed_identities, sign, unpack def to_binary(msg: Dict[str, Any]) -> Optional[bytes]: @@ -42,60 +34,6 @@ def from_binary(bmsg: bytes) -> Dict[str, Any]: return msg -def pack(obj: Dict[str, Any]) -> bytes: - return jsonapi.dumps(obj) - - -def unpack(s: bytes) -> Dict[str, Any]: - return cast(Dict[str, Any], jsonapi.loads(s)) - - -def sign(msg_list: List[bytes], key: str) -> bytes: - auth = hmac.new(key.encode("ascii"), digestmod=hashlib.sha256) - h = auth.copy() - for m in msg_list: - h.update(m) - return h.hexdigest().encode() - - -def serialize(msg: Dict[str, Any], key: str) -> List[bytes]: - message = [ - pack(msg["header"]), - pack(msg["parent_header"]), - pack(msg["metadata"]), - pack(msg.get("content", {})), - ] - to_send = [DELIM, sign(message, key)] + message + msg.get("buffers", []) - return to_send - - -def deserialize( - msg_list: List[bytes], parent_header: Optional[Dict[str, Any]] = None -) -> Dict[str, Any]: - message: Dict[str, Any] = {} - header = unpack(msg_list[1]) - message["header"] = header - message["msg_id"] = header["msg_id"] - message["msg_type"] = header["msg_type"] - if parent_header: - message["parent_header"] = parent_header - else: - message["parent_header"] = unpack(msg_list[2]) - message["metadata"] = unpack(msg_list[3]) - message["content"] = unpack(msg_list[4]) - message["buffers"] = [memoryview(b) for b in msg_list[5:]] - return message - - -def feed_identities(msg_list: List[bytes]) -> Tuple[List[bytes], List[bytes]]: - idx = msg_list.index(DELIM) - return msg_list[:idx], msg_list[idx + 1 :] # noqa - - -async def send_message(msg: Dict[str, Any], sock: Socket, key: str) -> None: - await sock.send_multipart(serialize(msg, key), copy=True) - - async def send_raw_message(parts: List[bytes], sock: Socket, key: str) -> None: msg = parts[:4] buffers = parts[4:] @@ -114,16 +52,6 @@ def deserialize_msg_from_ws_v1(ws_msg: bytes) -> Tuple[str, List[bytes]]: return channel, msg_list -async def receive_message(sock: Socket, timeout: float = float("inf")) -> Optional[Dict[str, Any]]: - timeout *= 1000 # in ms - ready = await sock.poll(timeout) - if ready: - msg_list = await sock.recv_multipart() - idents, msg_list = feed_identities(msg_list) - return deserialize(msg_list) - return None - - async def get_zmq_parts(socket: Socket) -> List[bytes]: parts = await socket.recv_multipart() idents, parts = feed_identities(parts) @@ -152,42 +80,3 @@ def serialize_msg_to_ws_v1(msg_list: List[bytes], channel: str) -> List[bytes]: def get_parent_header(parts: List[bytes]) -> Dict[str, Any]: return unpack(parts[2]) - - -def utcnow() -> datetime: - return datetime.utcnow().replace(tzinfo=timezone.utc) - - -def create_message_header(msg_type: str, session_id: str, msg_id: str) -> Dict[str, Any]: - if not session_id: - session_id = uuid4().hex - if not msg_id: - msg_id = uuid4().hex - header = { - "date": utcnow().isoformat().replace("+00:00", "Z"), - "msg_id": msg_id, - "msg_type": msg_type, - "session": session_id, - "username": "", - "version": protocol_version, - } - return header - - -def create_message( - msg_type: str, - content: Dict = {}, - session_id: str = "", - msg_id: str = "", -) -> Dict[str, Any]: - header = create_message_header(msg_type, session_id, msg_id) - msg = { - "header": header, - "msg_id": header["msg_id"], - "msg_type": header["msg_type"], - "parent_header": {}, - "content": content, - "metadata": {}, - "buffers": [], - } - return msg diff --git a/plugins/kernels/fps_kernels/kernel_server/server.py b/plugins/kernels/fps_kernels/kernel_server/server.py index b388e694..27f7c75e 100644 --- a/plugins/kernels/fps_kernels/kernel_server/server.py +++ b/plugins/kernels/fps_kernels/kernel_server/server.py @@ -6,26 +6,25 @@ from datetime import datetime from typing import Dict, Iterable, List, Optional, cast -from fastapi import WebSocketDisconnect # type: ignore +from fastapi import WebSocket, WebSocketDisconnect # type: ignore from starlette.websockets import WebSocketState -from .connect import ( - AcceptedWebSocket, +from ..kernel_driver.connect import ( cfg_t, connect_channel, launch_kernel, read_connection_file, ) -from .connect import write_connection_file as _write_connection_file # type: ignore +from ..kernel_driver.connect import ( + write_connection_file as _write_connection_file, # type: ignore +) +from ..kernel_driver.message import create_message, receive_message, send_message from .message import ( # type: ignore - create_message, deserialize_msg_from_ws_v1, from_binary, get_msg_from_parts, get_parent_header, get_zmq_parts, - receive_message, - send_message, send_raw_message, serialize_msg_to_ws_v1, to_binary, @@ -34,10 +33,28 @@ kernels: dict = {} +class AcceptedWebSocket: + _websocket: WebSocket + _accepted_subprotocol: Optional[str] + + def __init__(self, websocket, accepted_subprotocol): + self._websocket = websocket + self._accepted_subprotocol = accepted_subprotocol + + @property + def websocket(self): + return self._websocket + + @property + def accepted_subprotocol(self): + return self._accepted_subprotocol + + class KernelServer: def __init__( self, kernelspec_path: str = "", + kernel_cwd: str = "", connection_cfg: Optional[cfg_t] = None, connection_file: str = "", write_connection_file: bool = True, @@ -45,6 +62,7 @@ def __init__( ) -> None: self.capture_kernel_output = capture_kernel_output self.kernelspec_path = kernelspec_path + self.kernel_cwd = kernel_cwd self.connection_cfg = connection_cfg self.connection_file = connection_file self.write_connection_file = write_connection_file @@ -98,7 +116,10 @@ async def start(self) -> None: "execution_state": "starting", } self.kernel_process = await launch_kernel( - self.kernelspec_path, self.connection_file_path, self.capture_kernel_output + self.kernelspec_path, + self.connection_file_path, + self.kernel_cwd, + self.capture_kernel_output, ) assert self.connection_cfg is not None identity = uuid.uuid4().hex.encode("ascii") @@ -144,7 +165,9 @@ async def serve( self.sessions[session_id] = websocket self.can_execute = permissions is None or "execute" in permissions.get("kernels", []) await self.listen_web(websocket) - del self.sessions[session_id] + # the session could have been removed through the REST API, so check if it still exists + if session_id in self.sessions: + del self.sessions[session_id] async def listen_web(self, websocket: AcceptedWebSocket): try: @@ -179,9 +202,9 @@ async def _wait_for_ready(self): while True: msg = create_message("kernel_info_request") await send_message(msg, self.shell_channel, self.key) - msg = await receive_message(self.shell_channel, 0.2) + msg = await receive_message(self.shell_channel, timeout=0.2) if msg is not None and msg["msg_type"] == "kernel_info_reply": - msg = await receive_message(self.iopub_channel, 0.2) + msg = await receive_message(self.iopub_channel, timeout=0.2) if msg is None: # IOPub not connected, start over pass diff --git a/plugins/kernels/fps_kernels/routes.py b/plugins/kernels/fps_kernels/routes.py index b65237c1..cf0a697c 100644 --- a/plugins/kernels/fps_kernels/routes.py +++ b/plugins/kernels/fps_kernels/routes.py @@ -1,8 +1,8 @@ import json -import pathlib import sys import uuid from http import HTTPStatus +from pathlib import Path from fastapi import APIRouter, Depends, Response from fastapi.responses import FileResponse @@ -18,13 +18,13 @@ KernelServer, kernels, ) -from .models import Execution, Session +from .models import CreateSession, Execution, Session router = APIRouter() kernelspecs: dict = {} sessions: dict = {} -prefix_dir: pathlib.Path = pathlib.Path(sys.prefix) +prefix_dir: Path = Path(sys.prefix) @router.on_event("shutdown") @@ -124,12 +124,13 @@ async def create_session( request: Request, user: User = Depends(current_user(permissions={"sessions": ["write"]})), ): - create_session = await request.json() - kernel_name = create_session["kernel"]["name"] + create_session = CreateSession(**(await request.json())) + kernel_name = create_session.kernel.name kernel_server = KernelServer( kernelspec_path=( prefix_dir / "share" / "jupyter" / "kernels" / kernel_name / "kernel.json" ).as_posix(), + kernel_cwd=str(Path(create_session.path).parent), ) kernel_id = str(uuid.uuid4()) kernels[kernel_id] = {"name": kernel_name, "server": kernel_server, "driver": None} @@ -137,17 +138,17 @@ async def create_session( session_id = str(uuid.uuid4()) session = { "id": session_id, - "path": create_session["path"], - "name": create_session["name"], - "type": create_session["type"], + "path": create_session.path, + "name": create_session.name, + "type": create_session.type, "kernel": { "id": kernel_id, - "name": create_session["kernel"]["name"], + "name": create_session.kernel.name, "connections": kernel_server.connections, "last_activity": kernel_server.last_activity["date"], "execution_state": kernel_server.last_activity["execution_state"], }, - "notebook": {"path": create_session["path"], "name": create_session["name"]}, + "notebook": {"path": create_session.path, "name": create_session.name}, } sessions[session_id] = session return Session(**session)