Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support YRoom forking #383

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions jupyverse_api/jupyverse_api/yjs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any

Expand Down Expand Up @@ -35,6 +37,20 @@ async def create_roomid(
):
return await self.create_roomid(path, request, response, user)

@router.put("/api/collaboration/fork_room/{roomid}", status_code=201)
async def fork_room(
roomid,
user: User = Depends(auth.current_user(permissions={"contents": ["read"]})),
):
return await self.fork_room(roomid, user)

@router.put("/api/collaboration/merge_room", status_code=200)
async def merge_room(
request: Request,
user: User = Depends(auth.current_user(permissions={"contents": ["read"]})),
):
return await self.merge_room(request, user)

self.include_router(router)

@abstractmethod
Expand All @@ -55,6 +71,22 @@ async def create_roomid(
):
...

@abstractmethod
async def fork_room(
self,
roomid,
user: User,
):
...

@abstractmethod
async def merge_room(
self,
request: Request,
user: User,
):
...

@abstractmethod
def get_document(
self,
Expand Down
5 changes: 5 additions & 0 deletions jupyverse_api/jupyverse_api/yjs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,8 @@
class CreateDocumentSession(BaseModel):
format: str
type: str


class MergeRoom(BaseModel):
fork_roomid: str
root_roomid: str
4 changes: 1 addition & 3 deletions plugins/contents/fps_contents/fileid.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
from anyio import Path
from watchfiles import Change, awatch

from jupyverse_api import Singleton

logger = logging.getLogger("contents")


Expand All @@ -30,7 +28,7 @@ def notify(self, change):
self._event.set()


class FileIdManager(metaclass=Singleton):
class FileIdManager:
db_path: str
initialized: asyncio.Event
watchers: Dict[str, List[Watcher]]
Expand Down
8 changes: 7 additions & 1 deletion plugins/contents/fps_contents/routes.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import base64
import json
import os
Expand Down Expand Up @@ -25,6 +27,8 @@


class _Contents(Contents):
_file_id_manager: FileIdManager | None = None

async def create_checkpoint(
self,
path,
Expand Down Expand Up @@ -245,7 +249,9 @@ async def write_content(self, content: Union[SaveContent, Dict]) -> None:

@property
def file_id_manager(self):
return FileIdManager()
if self._file_id_manager is None:
self._file_id_manager = FileIdManager()
return self._file_id_manager


def get_available_path(path: Path, sep: str = "") -> Path:
Expand Down
2 changes: 1 addition & 1 deletion plugins/webdav/fps_webdav/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(self, app: App, webdav_config: WebDAVConfig):

for account in webdav_config.account_mapping:
logger.info(f"WebDAV user {account.username} has password {account.password}")
webdav_conf = webdav_config.dict()
webdav_conf = webdav_config.model_dump()
init_config_from_obj(webdav_conf)
webdav_aep = AppEntryParameters()
webdav_app = get_asgi_app(aep=webdav_aep, config_obj=webdav_conf)
Expand Down
39 changes: 36 additions & 3 deletions plugins/yjs/fps_yjs/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from jupyverse_api.auth import Auth, User
from jupyverse_api.contents import Contents
from jupyverse_api.yjs import Yjs
from jupyverse_api.yjs.models import CreateDocumentSession
from jupyverse_api.yjs.models import CreateDocumentSession, MergeRoom

from .ydocs import ydocs as YDOCS
from .ydocs.ybasedoc import YBaseDoc
Expand Down Expand Up @@ -95,6 +95,39 @@ async def create_roomid(
res["fileId"] = idx
return res

async def fork_room(
self,
roomid: str,
user: User,
):
idx = uuid4().hex

root_room = await self.room_manager.websocket_server.get_room(roomid)
update = root_room.ydoc.get_update()
fork_ydoc = Doc()
fork_ydoc.apply_update(update)
await self.room_manager.websocket_server.get_room(idx, ydoc=fork_ydoc)
root_room.fork_ydocs.add(fork_ydoc)

res = {
"sessionId": SERVER_SESSION,
"roomId": idx,
}
return res

async def merge_room(
self,
request: Request,
user: User,
):
# we need to process the request manually
# see https://github.com/tiangolo/fastapi/issues/3373#issuecomment-1306003451
merge_room = MergeRoom(**(await request.json()))
fork_room = await self.room_manager.websocket_server.get_room(merge_room.fork_roomid)
root_room = await self.room_manager.websocket_server.get_room(merge_room.root_roomid)
update = fork_room.ydoc.get_update()
root_room.ydoc.apply_update(update)

def get_document(self, document_id: str) -> YBaseDoc:
return self.room_manager.documents[document_id]

Expand Down Expand Up @@ -124,14 +157,14 @@ def __aiter__(self):
async def __anext__(self):
try:
message = await self._websocket.receive_bytes()
except WebSocketDisconnect:
except (ConnectionClosedOK, WebSocketDisconnect):
raise StopAsyncIteration()
return message

async def send(self, message):
try:
await self._websocket.send_bytes(message)
except ConnectionClosedOK:
except (ConnectionClosedOK, WebSocketDisconnect):
return

async def recv(self):
Expand Down
23 changes: 15 additions & 8 deletions plugins/yjs/fps_yjs/ywebsocket/yroom.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@


class YRoom:
clients: list
clients: set[Websocket]
fork_ydocs: set[Doc]
ydoc: Doc
ystore: BaseYStore | None
_on_message: Callable[[bytes], Awaitable[bool] | bool] | None
Expand All @@ -42,10 +43,10 @@ class YRoom:

def __init__(
self,
ydoc: Doc | None = None,
ready: bool = True,
ystore: BaseYStore | None = None,
log: Logger | None = None,
ydoc: Doc | None = None,
):
"""Initialize the object.

Expand Down Expand Up @@ -76,7 +77,8 @@ def __init__(
self.ready = ready
self.ystore = ystore
self.log = log or getLogger(__name__)
self.clients = []
self.clients = set()
self.fork_ydocs = set()
self._on_message = None
self._started = None
self._starting = False
Expand Down Expand Up @@ -133,10 +135,15 @@ async def _broadcast_updates(self):
return
# broadcast internal ydoc's update to all clients, that includes changes from the
# clients and changes from the backend (out-of-band changes)
for client in self.clients:
self.log.debug("Sending Y update to client with endpoint: %s", client.path)
for ydoc in self.fork_ydocs:
ydoc.apply_update(update)
if self.clients:
message = create_update_message(update)
self._task_group.start_soon(client.send, message)
for client in self.clients:
self.log.debug(
"Sending Y update to remote client with endpoint: %s", client.path
)
self._task_group.start_soon(client.send, message)
if self.ystore:
self.log.debug("Writing Y update to YStore")
self._task_group.start_soon(self.ystore.write, update)
Expand Down Expand Up @@ -197,7 +204,7 @@ async def serve(self, websocket: Websocket):
websocket: The WebSocket through which to serve the client.
"""
async with create_task_group() as tg:
self.clients.append(websocket)
self.clients.add(websocket)
await sync(self.ydoc, websocket, self.log)
try:
async for message in websocket:
Expand Down Expand Up @@ -236,4 +243,4 @@ async def serve(self, websocket: Websocket):
self.log.debug("Error serving endpoint: %s", websocket.path, exc_info=e)

# remove this client
self.clients = [c for c in self.clients if c != websocket]
self.clients.remove(websocket)
5 changes: 4 additions & 1 deletion tests/test_app.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os

import pytest
from asphalt.core import Context
from fastapi import APIRouter
Expand All @@ -17,7 +19,8 @@
"/foo",
),
)
async def test_mount_path(mount_path, unused_tcp_port):
async def test_mount_path(mount_path, unused_tcp_port, tmp_path):
os.chdir(tmp_path)
components = configure({"app": {"type": "app"}}, {"app": {"mount_path": mount_path}})

async with Context() as ctx, AsyncClient() as http:
Expand Down
20 changes: 14 additions & 6 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os

import pytest
from asphalt.core import Context
from httpx import AsyncClient
Expand All @@ -20,7 +22,8 @@


@pytest.mark.asyncio
async def test_kernel_channels_unauthenticated(unused_tcp_port):
async def test_kernel_channels_unauthenticated(unused_tcp_port, tmp_path):
os.chdir(tmp_path)
async with Context() as ctx:
await JupyverseComponent(
components=COMPONENTS,
Expand All @@ -35,7 +38,8 @@ async def test_kernel_channels_unauthenticated(unused_tcp_port):


@pytest.mark.asyncio
async def test_kernel_channels_authenticated(unused_tcp_port):
async def test_kernel_channels_authenticated(unused_tcp_port, tmp_path):
os.chdir(tmp_path)
async with Context() as ctx, AsyncClient() as http:
await JupyverseComponent(
components=COMPONENTS,
Expand All @@ -52,7 +56,8 @@ async def test_kernel_channels_authenticated(unused_tcp_port):

@pytest.mark.asyncio
@pytest.mark.parametrize("auth_mode", ("noauth", "token", "user"))
async def test_root_auth(auth_mode, unused_tcp_port):
async def test_root_auth(auth_mode, unused_tcp_port, tmp_path):
os.chdir(tmp_path)
components = configure(COMPONENTS, {"auth": {"mode": auth_mode}})
async with Context() as ctx, AsyncClient() as http:
await JupyverseComponent(
Expand All @@ -72,7 +77,8 @@ async def test_root_auth(auth_mode, unused_tcp_port):

@pytest.mark.asyncio
@pytest.mark.parametrize("auth_mode", ("noauth",))
async def test_no_auth(auth_mode, unused_tcp_port):
async def test_no_auth(auth_mode, unused_tcp_port, tmp_path):
os.chdir(tmp_path)
components = configure(COMPONENTS, {"auth": {"mode": auth_mode}})
async with Context() as ctx, AsyncClient() as http:
await JupyverseComponent(
Expand All @@ -86,7 +92,8 @@ async def test_no_auth(auth_mode, unused_tcp_port):

@pytest.mark.asyncio
@pytest.mark.parametrize("auth_mode", ("token",))
async def test_token_auth(auth_mode, unused_tcp_port):
async def test_token_auth(auth_mode, unused_tcp_port, tmp_path):
os.chdir(tmp_path)
components = configure(COMPONENTS, {"auth": {"mode": auth_mode}})
async with Context() as ctx, AsyncClient() as http:
await JupyverseComponent(
Expand All @@ -113,7 +120,8 @@ async def test_token_auth(auth_mode, unused_tcp_port):
{"admin": ["read"], "foo": ["bar", "baz"]},
),
)
async def test_permissions(auth_mode, permissions, unused_tcp_port):
async def test_permissions(auth_mode, permissions, unused_tcp_port, tmp_path):
os.chdir(tmp_path)
components = configure(COMPONENTS, {"auth": {"mode": auth_mode}})
async with Context() as ctx, AsyncClient() as http:
await JupyverseComponent(
Expand Down
2 changes: 0 additions & 2 deletions tests/test_contents.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
@pytest.mark.asyncio
@pytest.mark.parametrize("auth_mode", ("noauth",))
async def test_tree(auth_mode, tmp_path, unused_tcp_port):
prev_dir = os.getcwd()
os.chdir(tmp_path)
dname = Path(".")
expected = []
Expand Down Expand Up @@ -81,4 +80,3 @@ async def test_tree(auth_mode, tmp_path, unused_tcp_port):
sort_content_by_name(actual)
sort_content_by_name(expected)
assert actual == expected
os.chdir(prev_dir)
14 changes: 11 additions & 3 deletions tests/test_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
from functools import partial
from pathlib import Path
from shutil import copytree, rmtree

import pytest
from asphalt.core import Context
Expand All @@ -27,6 +28,8 @@
"yjs": {"type": "yjs"},
}

HERE = Path(__file__).parent


class Websocket:
def __init__(self, websocket, roomid: str):
Expand Down Expand Up @@ -57,7 +60,11 @@ async def recv(self) -> bytes:

@pytest.mark.asyncio
@pytest.mark.parametrize("auth_mode", ("noauth",))
async def test_execute(auth_mode, unused_tcp_port):
async def test_execute(auth_mode, unused_tcp_port, tmp_path):
os.chdir(tmp_path)
if Path("data").exists():
rmtree("data")
copytree(HERE / "data", "data")
url = f"http://127.0.0.1:{unused_tcp_port}"
components = configure(COMPONENTS, {
"auth": {"mode": auth_mode},
Expand All @@ -71,7 +78,7 @@ async def test_execute(auth_mode, unused_tcp_port):

ws_url = url.replace("http", "ws", 1)
name = "notebook1.ipynb"
path = (Path("tests") / "data" / name).as_posix()
path = (Path("data") / name).as_posix()
# create a session to launch a kernel
response = await http.post(
f"{url}/api/sessions",
Expand All @@ -92,7 +99,8 @@ async def test_execute(auth_mode, unused_tcp_port):
"type": "notebook",
}
)
file_id = response.json()["fileId"]
r = response.json()
file_id = r["fileId"]
document_id = f"json:notebook:{file_id}"
ynb = ydocs["notebook"]()
def callback(aevent, events, event):
Expand Down
3 changes: 2 additions & 1 deletion tests/test_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@

@pytest.mark.asyncio
@pytest.mark.parametrize("auth_mode", ("noauth",))
async def test_kernel_messages(auth_mode, capfd, unused_tcp_port):
async def test_kernel_messages(auth_mode, capfd, unused_tcp_port, tmp_path):
os.chdir(tmp_path)
kernel_id = "kernel_id_0"
kernel_name = "python3"
kernelspec_path = (
Expand Down
Loading
Loading