Skip to content

Commit

Permalink
Support renaming kernelspecs at runtime.
Browse files Browse the repository at this point in the history
This change adds support for kernel spec managers that rename kernel specs
based on configured traits.

This is a necessary step in the work to support multiplexing between
multiple kernel spec managers (jupyter-server#1187), as we need to be able to rename
kernel specs in order to prevent collisions between the kernel specs
provided by multiple kernel spec managers.
  • Loading branch information
ojarjur committed Apr 29, 2023
1 parent c53e658 commit e5ad55f
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 5 deletions.
32 changes: 27 additions & 5 deletions jupyter_server/gateway/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@
from jupyter_core.utils import ensure_async
from tornado import web
from tornado.escape import json_decode, json_encode, url_escape, utf8
from traitlets import DottedObjectName, Instance, Type, default
from traitlets import DottedObjectName, Instance, Type, Unicode, default

from .._tz import UTC, utcnow
from ..services.kernels.kernelmanager import (
AsyncMappingKernelManager,
ServerKernelManager,
emit_kernel_action_event,
)
from ..services.kernelspecs.renaming import RenamingKernelSpecManagerMixin, normalize_kernel_name
from ..services.sessions.sessionmanager import SessionManager
from ..utils import url_path_join
from .gateway_client import GatewayClient, gateway_request
Expand All @@ -52,6 +53,13 @@ def __init__(self, **kwargs):
self.kernels_url = url_path_join(
GatewayClient.instance().url, GatewayClient.instance().kernels_endpoint
)
if hasattr(self.kernel_spec_manager, "default_kernel_name"):
self.kernel_spec_manager.observe(
self.on_default_kernel_name_change, "default_kernel_name"
)

def on_default_kernel_name_change(self, change):
self.default_kernel_name = change.new

def remove_kernel(self, kernel_id):
"""Complete override since we want to be more tolerant of missing keys"""
Expand All @@ -60,6 +68,7 @@ def remove_kernel(self, kernel_id):
except KeyError:
pass

@normalize_kernel_name
async def start_kernel(self, *, kernel_id=None, path=None, **kwargs):
"""Start a kernel for a session and return its kernel_id.
Expand Down Expand Up @@ -210,6 +219,8 @@ async def cull_kernels(self):
class GatewayKernelSpecManager(KernelSpecManager):
"""A gateway kernel spec manager."""

default_kernel_name = Unicode(allow_none=True)

def __init__(self, **kwargs):
"""Initialize a gateway kernel spec manager."""
super().__init__(**kwargs)
Expand Down Expand Up @@ -273,14 +284,13 @@ async def get_all_specs(self):
# If different log a warning and reset the default. However, the
# caller of this method will still return this server's value until
# the next fetch of kernelspecs - at which time they'll match.
km = self.parent.kernel_manager
remote_default_kernel_name = fetched_kspecs.get("default")
if remote_default_kernel_name != km.default_kernel_name:
if remote_default_kernel_name != self.default_kernel_name:
self.log.info(
f"Default kernel name on Gateway server ({remote_default_kernel_name}) differs from "
f"Notebook server ({km.default_kernel_name}). Updating to Gateway server's value."
f"Notebook server ({self.default_kernel_name}). Updating to Gateway server's value."
)
km.default_kernel_name = remote_default_kernel_name
self.default_kernel_name = remote_default_kernel_name

remote_kspecs = fetched_kspecs.get("kernelspecs")
return remote_kspecs
Expand Down Expand Up @@ -345,6 +355,18 @@ async def get_kernel_spec_resource(self, kernel_name, path):
return kernel_spec_resource


class GatewayRenamingKernelSpecManager(RenamingKernelSpecManagerMixin, GatewayKernelSpecManager):
spec_name_prefix = Unicode(
"remote-", config=True, help="Prefix to be added onto the front of kernel spec names."
)

display_name_suffix = Unicode(
" (Remote)",
config=True,
help="Suffix to be added onto the end of kernel spec display names.",
)


class GatewaySessionManager(SessionManager):
"""A gateway session manager."""

Expand Down
15 changes: 15 additions & 0 deletions jupyter_server/services/kernels/kernelmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
TraitError,
Unicode,
default,
observe,
validate,
)

Expand All @@ -46,6 +47,8 @@
from jupyter_server.prometheus.metrics import KERNEL_CURRENTLY_RUNNING_TOTAL
from jupyter_server.utils import ApiPath, import_item, to_os_path

from ..kernelspecs.renaming import normalize_kernel_name


class MappingKernelManager(MultiKernelManager):
"""A KernelManager that handles
Expand Down Expand Up @@ -206,6 +209,7 @@ async def _remove_kernel_when_ready(self, kernel_id, kernel_awaitable):

# TODO DEC 2022: Revise the type-ignore once the signatures have been changed upstream
# https://github.com/jupyter/jupyter_client/pull/905
@normalize_kernel_name
async def _async_start_kernel( # type:ignore[override]
self, *, kernel_id: Optional[str] = None, path: Optional[ApiPath] = None, **kwargs: str
) -> str:
Expand Down Expand Up @@ -700,12 +704,23 @@ def _validate_kernel_manager_class(self, proposal):
)
return km_class_value

@observe("default_kernel_name")
def _observe_default_kernel_name(self, change):
if hasattr(self.kernel_spec_manager, "maybe_rename_kernel"):
renamed_kernel = self.kernel_spec_manager.maybe_rename_kernel(change.new)
if renamed_kernel is not change.new:
self.default_kernel_name = renamed_kernel

def __init__(self, **kwargs):
"""Initialize an async mapping kernel manager."""
self.pinned_superclass = MultiKernelManager
self._pending_kernel_tasks = {}
self.pinned_superclass.__init__(self, **kwargs)
self.last_kernel_activity = utcnow()
if hasattr(self.kernel_spec_manager, "rename_kernel"):
self.default_kernel_name = self.kernel_spec_manager.rename_kernel(
self.default_kernel_name
)


def emit_kernel_action_event(success_msg: str = ""): # type: ignore
Expand Down
141 changes: 141 additions & 0 deletions jupyter_server/services/kernelspecs/renaming.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
"""Support for renaming kernel specs at runtime."""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from functools import wraps
from typing import Any

from jupyter_client.kernelspec import KernelSpecManager
from jupyter_core.utils import ensure_async, run_sync
from traitlets import HasTraits, Unicode, default


def normalize_kernel_name(method):
@wraps(method)
async def wrapped_method(self, *args, **kwargs):
kernel_name = kwargs.get("kernel_name", None)
if (
kernel_name
and hasattr(self, "kernel_spec_manager")
and hasattr(self.kernel_spec_manager, "original_kernel_name")
):
kwargs["kernel_name"] = self.kernel_spec_manager.original_kernel_name(kernel_name)
return await method(self, *args, **kwargs)

return wrapped_method


class RenamingKernelSpecManagerMixin(HasTraits):
"""KernelSpecManager mixin that renames kernel specs.
The base KernelSpecManager class only has synchronous methods, but some child
classes (in particular, GatewayKernelManager) change those methods to be async.
In order to support both versions, we provide both synchronous and async versions
of all the relevant kernel spec manager methods. We first do the renaming in the
async version, but override the KernelSpecManager base methods using the
synchronous versions.
"""

spec_name_prefix = Unicode(
config=True, help="Prefix to be added onto the front of kernel spec names."
)

spec_name_format = Unicode(
config=True,
help="""Format for rewritten kernel spec names.
Defaults to prefixing the kernel spec name with the value of the
`spec_name_prefix` attribute if it has been set.
""",
)

@default("spec_name_format")
def _default_spec_name_format(self):
if self.spec_name_prefix:
return self.spec_name_prefix + "{}"
return "{}"

display_name_suffix = Unicode(
config=True, help="Suffix to be added onto the end of kernel spec display names."
)

display_name_format = Unicode(
config=True, help="Format for rewritten kernel spec display names."
)

@default("display_name_format")
def _default_display_name_format(self):
if self.display_name_suffix:
return "{}" + self.display_name_suffix
return "{}"

default_kernel_name = Unicode(allow_none=True)

def rename_kernel(self, kernel_name: str) -> str:
"""Rename the supplied kernel spec based on the configured format string."""
if not hasattr(self, "original_kernel_names"):
self.original_kernel_names = {}

renamed = self.spec_name_format.format(kernel_name)
self.original_kernel_names[renamed] = kernel_name
return renamed

def maybe_rename_kernel(self, kernel_name: str) -> str:
"""Rename the supplied kernel if it is not already the result of a rename."""
if not hasattr(self, "original_kernel_names"):
self.original_kernel_names = {}
if kernel_name in self.original_kernel_names:
# The kernel was already renamed
return kernel_name
return self.rename_kernel(kernel_name)

def original_kernel_name(self, kernel_name: str) -> str:
if not hasattr(self, "original_kernel_names"):
return kernel_name

return self.original_kernel_names.get(kernel_name, kernel_name)

async def async_get_all_specs(self):
ks = {}
original_ks = await ensure_async(super().get_all_specs())
for s, k in original_ks.items():
spec_name = s
kernel_spec = k
original_prefix = f"/kernelspecs/{spec_name}"
spec_name = self.rename_kernel(spec_name)
new_prefix = f"/kernelspecs/{spec_name}"

ks[spec_name] = kernel_spec
kernel_spec["name"] = spec_name
kernel_spec["spec"] = kernel_spec.get("spec", {})
kernel_spec["resources"] = kernel_spec.get("resources", {})

spec = kernel_spec["spec"]
spec["display_name"] = self.display_name_format.format(spec.get("display_name"))

resources = kernel_spec["resources"]
for name, value in resources.items():
resources[name] = value.replace(original_prefix, new_prefix)
if hasattr(super(), "default_kernel_name"):
self.default_kernel_name = self.rename_kernel(super().default_kernel_name)
return ks

def get_all_specs(self):
return run_sync(self.async_get_all_specs)()

async def async_get_kernel_spec(self, kernel_name: str, *args: Any, **kwargs: Any):
kernel_name = self.original_kernel_name(kernel_name)
return await ensure_async(super().get_kernel_spec(kernel_name, *args, **kwargs))

def get_kernel_spec(self, kernel_name: str, *args: Any, **kwargs: Any):
return run_sync(self.async_get_kernel_spec)(kernel_name, *args, **kwargs)

async def get_kernel_spec_resource(self, kernel_name: str, *args: Any, **kwargs: Any):
if not hasattr(super(), "get_kernel_spec_resource"):
return None
kernel_name = self.original_kernel_name(kernel_name)
return await ensure_async(super().get_kernel_spec_resource(kernel_name, *args, **kwargs))


class RenamingKernelSpecManager(RenamingKernelSpecManagerMixin, KernelSpecManager):
"""KernelSpecManager that renames kernels"""

0 comments on commit e5ad55f

Please sign in to comment.