Skip to content

Commit

Permalink
Merge pull request #245 from davidbrochart/kernel_path
Browse files Browse the repository at this point in the history
Add kernel cwd
  • Loading branch information
davidbrochart authored Nov 10, 2022
2 parents 4a24ca1 + 286528d commit 9e5009f
Show file tree
Hide file tree
Showing 7 changed files with 136 additions and 317 deletions.
19 changes: 13 additions & 6 deletions plugins/kernels/fps_kernels/kernel_driver/connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import socket
import tempfile
import uuid
from typing import Dict, Tuple, Union
from typing import Dict, Optional, Tuple, Union

import zmq
import zmq.asyncio
Expand Down Expand Up @@ -62,41 +62,48 @@ def write_connection_file(
return fname, cfg


def read_connection_file(fname: str = "") -> cfg_t:
def read_connection_file(fname: str) -> cfg_t:
with open(fname, "rt") as f:
cfg: cfg_t = json.load(f)

return cfg


async def launch_kernel(
kernelspec_path: str, connection_file_path: str, capture_output: bool
kernelspec_path: str, connection_file_path: str, kernel_cwd: str, capture_output: bool
) -> asyncio.subprocess.Process:
with open(kernelspec_path) as f:
kernelspec = json.load(f)
cmd = [s.format(connection_file=connection_file_path) for s in kernelspec["argv"]]
if kernel_cwd:
prev_dir = os.getcwd()
os.chdir(kernel_cwd)
if capture_output:
p = await asyncio.create_subprocess_exec(
*cmd, stdout=asyncio.subprocess.DEVNULL, stderr=asyncio.subprocess.STDOUT
)
else:
p = await asyncio.create_subprocess_exec(*cmd)
if kernel_cwd:
os.chdir(prev_dir)
return p


def create_socket(channel: str, cfg: cfg_t) -> Socket:
def create_socket(channel: str, cfg: cfg_t, identity: Optional[bytes] = None) -> Socket:
ip = cfg["ip"]
port = cfg[f"{channel}_port"]
url = f"tcp://{ip}:{port}"
socket_type = channel_socket_types[channel]
sock = context.socket(socket_type)
sock.linger = 1000 # set linger to 1s to prevent hangs at exit
if identity:
sock.identity = identity
sock.connect(url)
return sock


def connect_channel(channel_name: str, cfg: cfg_t) -> Socket:
sock = create_socket(channel_name, cfg)
def connect_channel(channel_name: str, cfg: cfg_t, identity: Optional[bytes] = None) -> Socket:
sock = create_socket(channel_name, cfg, identity)
if channel_name == "iopub":
sock.setsockopt(zmq.SUBSCRIBE, b"")
return sock
61 changes: 24 additions & 37 deletions plugins/kernels/fps_kernels/kernel_driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,53 +2,31 @@
import os
import time
import uuid
from typing import Any, Dict, List, Optional, Tuple, cast

from zmq.asyncio import Socket
from typing import Any, Dict, List, Optional, cast

from .connect import cfg_t, connect_channel, launch_kernel, read_connection_file
from .connect import write_connection_file as _write_connection_file
from .kernelspec import find_kernelspec
from .message import create_message, deserialize, serialize

DELIM = b"<IDS|MSG>"
from .message import create_message, receive_message, send_message


def deadline_to_timeout(deadline: float) -> float:
return max(0, deadline - time.time())


def feed_identities(msg_list: List[bytes]) -> Tuple[List[bytes], List[bytes]]:
idx = msg_list.index(DELIM)
return msg_list[:idx], msg_list[idx + 1 :] # noqa


def send_message(msg: Dict[str, Any], sock: Socket, key: str) -> None:
to_send = serialize(msg, key)
sock.send_multipart(to_send, copy=True)


async def receive_message(sock: Socket, timeout: float = float("inf")) -> Optional[Dict[str, Any]]:
timeout *= 1000 # in ms
ready = await sock.poll(timeout)
if ready:
msg_list = await sock.recv_multipart()
idents, msg_list = feed_identities(msg_list)
return deserialize(msg_list)
return None


class KernelDriver:
def __init__(
self,
kernel_name: str = "",
kernelspec_path: str = "",
kernel_cwd: str = "",
connection_file: str = "",
write_connection_file: bool = True,
capture_kernel_output: bool = True,
) -> None:
self.capture_kernel_output = capture_kernel_output
self.kernelspec_path = kernelspec_path or find_kernelspec(kernel_name)
self.kernel_cwd = kernel_cwd
if not self.kernelspec_path:
raise RuntimeError("Could not find a kernel, maybe you forgot to install one?")
if write_connection_file:
Expand All @@ -66,9 +44,11 @@ async def restart(self, startup_timeout: float = float("inf")) -> None:
for task in self.channel_tasks:
task.cancel()
msg = create_message("shutdown_request", content={"restart": True})
send_message(msg, self.control_channel, self.key)
await send_message(msg, self.control_channel, self.key, change_date_to_str=True)
while True:
msg = cast(Dict[str, Any], await receive_message(self.control_channel))
msg = cast(
Dict[str, Any], await receive_message(self.control_channel, change_str_to_date=True)
)
if msg["msg_type"] == "shutdown_reply" and msg["content"]["restart"]:
break
await self._wait_for_ready(startup_timeout)
Expand All @@ -77,7 +57,10 @@ async def restart(self, startup_timeout: float = float("inf")) -> None:

async def start(self, startup_timeout: float = float("inf"), connect: bool = True) -> None:
self.kernel_process = await launch_kernel(
self.kernelspec_path, self.connection_file_path, self.capture_kernel_output
self.kernelspec_path,
self.connection_file_path,
self.kernel_cwd,
self.capture_kernel_output,
)
if connect:
await self.connect(startup_timeout)
Expand Down Expand Up @@ -106,14 +89,14 @@ async def stop(self) -> None:

async def listen_iopub(self):
while True:
msg = await receive_message(self.iopub_channel) # type: ignore
msg = await receive_message(self.iopub_channel, change_str_to_date=True) # type: ignore
msg_id = msg["parent_header"].get("msg_id")
if msg_id in self.execute_requests.keys():
self.execute_requests[msg_id]["iopub_msg"].set_result(msg)

async def listen_shell(self):
while True:
msg = await receive_message(self.shell_channel) # type: ignore
msg = await receive_message(self.shell_channel, change_str_to_date=True) # type: ignore
msg_id = msg["parent_header"].get("msg_id")
if msg_id in self.execute_requests.keys():
self.execute_requests[msg_id]["shell_msg"].set_result(msg)
Expand All @@ -129,14 +112,14 @@ async def execute(
return
content = {"code": cell["source"], "silent": False}
msg = create_message(
"execute_request", content, session_id=self.session_id, msg_cnt=self.msg_cnt
"execute_request", content, session_id=self.session_id, msg_id=str(self.msg_cnt)
)
if msg_id:
msg["header"]["msg_id"] = msg_id
else:
msg_id = msg["header"]["msg_id"]
self.msg_cnt += 1
send_message(msg, self.shell_channel, self.key)
await send_message(msg, self.shell_channel, self.key, change_date_to_str=True)
if wait_for_executed:
deadline = time.time() + timeout
self.execute_requests[msg_id] = {
Expand Down Expand Up @@ -177,16 +160,20 @@ async def _wait_for_ready(self, timeout):
new_timeout = timeout
while True:
msg = create_message(
"kernel_info_request", session_id=self.session_id, msg_cnt=self.msg_cnt
"kernel_info_request", session_id=self.session_id, msg_id=str(self.msg_cnt)
)
self.msg_cnt += 1
send_message(msg, self.shell_channel, self.key)
msg = await receive_message(self.shell_channel, new_timeout)
await send_message(msg, self.shell_channel, self.key, change_date_to_str=True)
msg = await receive_message(
self.shell_channel, timeout=new_timeout, change_str_to_date=True
)
if msg is None:
error_message = f"Kernel didn't respond in {timeout} seconds"
raise RuntimeError(error_message)
if msg["msg_type"] == "kernel_info_reply":
msg = await receive_message(self.iopub_channel, 0.2)
msg = await receive_message(
self.iopub_channel, timeout=0.2, change_str_to_date=True
)
if msg is not None:
break
new_timeout = deadline_to_timeout(deadline)
Expand Down
70 changes: 52 additions & 18 deletions plugins/kernels/fps_kernels/kernel_driver/message.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import hashlib
import hmac
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, List, cast
from typing import Any, Dict, List, Optional, Tuple, cast
from uuid import uuid4

from dateutil.parser import parse as dateutil_parse # type: ignore
from zmq.asyncio import Socket
from zmq.utils import jsonapi

protocol_version_info = (5, 3)
Expand All @@ -13,6 +14,11 @@
DELIM = b"<IDS|MSG>"


def feed_identities(msg_list: List[bytes]) -> Tuple[List[bytes], List[bytes]]:
idx = msg_list.index(DELIM)
return msg_list[:idx], msg_list[idx + 1 :] # noqa


def str_to_date(obj: Dict[str, Any]) -> Dict[str, Any]:
if "date" in obj:
obj["date"] = dateutil_parse(obj["date"])
Expand All @@ -29,17 +35,17 @@ def utcnow() -> datetime:
return datetime.utcnow().replace(tzinfo=timezone.utc)


def create_message_header(msg_type: str, session_id: str, msg_cnt: int) -> Dict[str, Any]:
def create_message_header(msg_type: str, session_id: str, msg_id: str) -> Dict[str, Any]:
if not session_id:
session_id = msg_id = uuid.uuid4().hex
session_id = msg_id = uuid4().hex
else:
msg_id = f"{session_id}_{msg_cnt}"
msg_id = f"{session_id}_{msg_id}"
header = {
"date": utcnow(),
"date": utcnow().isoformat().replace("+00:00", "Z"),
"msg_id": msg_id,
"msg_type": msg_type,
"session": session_id,
"username": "david",
"username": "",
"version": protocol_version,
}
return header
Expand All @@ -49,16 +55,17 @@ def create_message(
msg_type: str,
content: Dict = {},
session_id: str = "",
msg_cnt: int = 0,
msg_id: str = "",
) -> Dict[str, Any]:
header = create_message_header(msg_type, session_id, msg_cnt)
header = create_message_header(msg_type, session_id, msg_id)
msg = {
"header": header,
"msg_id": header["msg_id"],
"msg_type": header["msg_type"],
"parent_header": {},
"content": content,
"metadata": {},
"buffers": [],
}
return msg

Expand All @@ -79,25 +86,52 @@ def sign(msg_list: List[bytes], key: str) -> bytes:
return h.hexdigest().encode()


def serialize(msg: Dict[str, Any], key: str) -> List[bytes]:
def serialize(msg: Dict[str, Any], key: str, change_date_to_str: bool = False) -> List[bytes]:
_date_to_str = date_to_str if change_date_to_str else lambda x: x
message = [
pack(date_to_str(msg["header"])),
pack(date_to_str(msg["parent_header"])),
pack(date_to_str(msg["metadata"])),
pack(date_to_str(msg.get("content", {}))),
pack(_date_to_str(msg["header"])),
pack(_date_to_str(msg["parent_header"])),
pack(_date_to_str(msg["metadata"])),
pack(_date_to_str(msg.get("content", {}))),
]
to_send = [DELIM, sign(message, key)] + message
to_send = [DELIM, sign(message, key)] + message + msg.get("buffers", [])
return to_send


def deserialize(msg_list: List[bytes]) -> Dict[str, Any]:
def deserialize(
msg_list: List[bytes],
parent_header: Optional[Dict[str, Any]] = None,
change_str_to_date: bool = False,
) -> Dict[str, Any]:
_str_to_date = str_to_date if change_str_to_date else lambda x: x
message: Dict[str, Any] = {}
header = unpack(msg_list[1])
message["header"] = str_to_date(header)
message["header"] = _str_to_date(header)
message["msg_id"] = header["msg_id"]
message["msg_type"] = header["msg_type"]
message["parent_header"] = str_to_date(unpack(msg_list[2]))
if parent_header:
message["parent_header"] = parent_header
else:
message["parent_header"] = _str_to_date(unpack(msg_list[2]))
message["metadata"] = unpack(msg_list[3])
message["content"] = unpack(msg_list[4])
message["buffers"] = [memoryview(b) for b in msg_list[5:]]
return message


async def send_message(
msg: Dict[str, Any], sock: Socket, key: str, change_date_to_str: bool = False
) -> None:
await sock.send_multipart(serialize(msg, key, change_date_to_str=change_date_to_str), copy=True)


async def receive_message(
sock: Socket, timeout: float = float("inf"), change_str_to_date: bool = False
) -> Optional[Dict[str, Any]]:
timeout *= 1000 # in ms
ready = await sock.poll(timeout)
if ready:
msg_list = await sock.recv_multipart()
idents, msg_list = feed_identities(msg_list)
return deserialize(msg_list, change_str_to_date=change_str_to_date)
return None
Loading

0 comments on commit 9e5009f

Please sign in to comment.