diff --git a/jupyverse_api/jupyverse_api/yjs/__init__.py b/jupyverse_api/jupyverse_api/yjs/__init__.py index 7a8dfa5b..6f834afe 100644 --- a/jupyverse_api/jupyverse_api/yjs/__init__.py +++ b/jupyverse_api/jupyverse_api/yjs/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from abc import ABC, abstractmethod from typing import Any @@ -35,6 +37,20 @@ async def create_roomid( ): return await self.create_roomid(path, request, response, user) + @router.put("/api/collaboration/fork_room/{roomid}", status_code=201) + async def fork_room( + roomid, + user: User = Depends(auth.current_user(permissions={"contents": ["read"]})), + ): + return await self.fork_room(roomid, user) + + @router.put("/api/collaboration/merge_room", status_code=200) + async def merge_room( + request: Request, + user: User = Depends(auth.current_user(permissions={"contents": ["read"]})), + ): + return await self.merge_room(request, user) + self.include_router(router) @abstractmethod @@ -55,6 +71,22 @@ async def create_roomid( ): ... + @abstractmethod + async def fork_room( + self, + roomid, + user: User, + ): + ... + + @abstractmethod + async def merge_room( + self, + request: Request, + user: User, + ): + ... + @abstractmethod def get_document( self, diff --git a/jupyverse_api/jupyverse_api/yjs/models.py b/jupyverse_api/jupyverse_api/yjs/models.py index 0fe8aab6..9e112661 100644 --- a/jupyverse_api/jupyverse_api/yjs/models.py +++ b/jupyverse_api/jupyverse_api/yjs/models.py @@ -4,3 +4,8 @@ class CreateDocumentSession(BaseModel): format: str type: str + + +class MergeRoom(BaseModel): + fork_roomid: str + root_roomid: str diff --git a/plugins/contents/fps_contents/fileid.py b/plugins/contents/fps_contents/fileid.py index f489c59d..6cac7dd5 100644 --- a/plugins/contents/fps_contents/fileid.py +++ b/plugins/contents/fps_contents/fileid.py @@ -7,8 +7,6 @@ from anyio import Path from watchfiles import Change, awatch -from jupyverse_api import Singleton - logger = logging.getLogger("contents") @@ -30,7 +28,7 @@ def notify(self, change): self._event.set() -class FileIdManager(metaclass=Singleton): +class FileIdManager: db_path: str initialized: asyncio.Event watchers: Dict[str, List[Watcher]] diff --git a/plugins/contents/fps_contents/routes.py b/plugins/contents/fps_contents/routes.py index 7f759b9e..b6cbb44b 100644 --- a/plugins/contents/fps_contents/routes.py +++ b/plugins/contents/fps_contents/routes.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import base64 import json import os @@ -25,6 +27,8 @@ class _Contents(Contents): + _file_id_manager: FileIdManager | None = None + async def create_checkpoint( self, path, @@ -245,7 +249,9 @@ async def write_content(self, content: Union[SaveContent, Dict]) -> None: @property def file_id_manager(self): - return FileIdManager() + if self._file_id_manager is None: + self._file_id_manager = FileIdManager() + return self._file_id_manager def get_available_path(path: Path, sep: str = "") -> Path: diff --git a/plugins/webdav/fps_webdav/routes.py b/plugins/webdav/fps_webdav/routes.py index b31b1fc9..bf67903a 100644 --- a/plugins/webdav/fps_webdav/routes.py +++ b/plugins/webdav/fps_webdav/routes.py @@ -46,7 +46,7 @@ def __init__(self, app: App, webdav_config: WebDAVConfig): for account in webdav_config.account_mapping: logger.info(f"WebDAV user {account.username} has password {account.password}") - webdav_conf = webdav_config.dict() + webdav_conf = webdav_config.model_dump() init_config_from_obj(webdav_conf) webdav_aep = AppEntryParameters() webdav_app = get_asgi_app(aep=webdav_aep, config_obj=webdav_conf) diff --git a/plugins/yjs/fps_yjs/routes.py b/plugins/yjs/fps_yjs/routes.py index 1a023d95..73216488 100644 --- a/plugins/yjs/fps_yjs/routes.py +++ b/plugins/yjs/fps_yjs/routes.py @@ -21,7 +21,7 @@ from jupyverse_api.auth import Auth, User from jupyverse_api.contents import Contents from jupyverse_api.yjs import Yjs -from jupyverse_api.yjs.models import CreateDocumentSession +from jupyverse_api.yjs.models import CreateDocumentSession, MergeRoom from .ydocs import ydocs as YDOCS from .ydocs.ybasedoc import YBaseDoc @@ -95,6 +95,39 @@ async def create_roomid( res["fileId"] = idx return res + async def fork_room( + self, + roomid: str, + user: User, + ): + idx = uuid4().hex + + root_room = await self.room_manager.websocket_server.get_room(roomid) + update = root_room.ydoc.get_update() + fork_ydoc = Doc() + fork_ydoc.apply_update(update) + await self.room_manager.websocket_server.get_room(idx, ydoc=fork_ydoc) + root_room.fork_ydocs.add(fork_ydoc) + + res = { + "sessionId": SERVER_SESSION, + "roomId": idx, + } + return res + + async def merge_room( + self, + request: Request, + user: User, + ): + # we need to process the request manually + # see https://github.com/tiangolo/fastapi/issues/3373#issuecomment-1306003451 + merge_room = MergeRoom(**(await request.json())) + fork_room = await self.room_manager.websocket_server.get_room(merge_room.fork_roomid) + root_room = await self.room_manager.websocket_server.get_room(merge_room.root_roomid) + update = fork_room.ydoc.get_update() + root_room.ydoc.apply_update(update) + def get_document(self, document_id: str) -> YBaseDoc: return self.room_manager.documents[document_id] @@ -124,14 +157,14 @@ def __aiter__(self): async def __anext__(self): try: message = await self._websocket.receive_bytes() - except WebSocketDisconnect: + except (ConnectionClosedOK, WebSocketDisconnect): raise StopAsyncIteration() return message async def send(self, message): try: await self._websocket.send_bytes(message) - except ConnectionClosedOK: + except (ConnectionClosedOK, WebSocketDisconnect): return async def recv(self): diff --git a/plugins/yjs/fps_yjs/ywebsocket/yroom.py b/plugins/yjs/fps_yjs/ywebsocket/yroom.py index 15fd41de..d4905ae9 100644 --- a/plugins/yjs/fps_yjs/ywebsocket/yroom.py +++ b/plugins/yjs/fps_yjs/ywebsocket/yroom.py @@ -29,7 +29,8 @@ class YRoom: - clients: list + clients: set[Websocket] + fork_ydocs: set[Doc] ydoc: Doc ystore: BaseYStore | None _on_message: Callable[[bytes], Awaitable[bool] | bool] | None @@ -42,10 +43,10 @@ class YRoom: def __init__( self, - ydoc: Doc | None = None, ready: bool = True, ystore: BaseYStore | None = None, log: Logger | None = None, + ydoc: Doc | None = None, ): """Initialize the object. @@ -76,7 +77,8 @@ def __init__( self.ready = ready self.ystore = ystore self.log = log or getLogger(__name__) - self.clients = [] + self.clients = set() + self.fork_ydocs = set() self._on_message = None self._started = None self._starting = False @@ -133,10 +135,15 @@ async def _broadcast_updates(self): return # broadcast internal ydoc's update to all clients, that includes changes from the # clients and changes from the backend (out-of-band changes) - for client in self.clients: - self.log.debug("Sending Y update to client with endpoint: %s", client.path) + for ydoc in self.fork_ydocs: + ydoc.apply_update(update) + if self.clients: message = create_update_message(update) - self._task_group.start_soon(client.send, message) + for client in self.clients: + self.log.debug( + "Sending Y update to remote client with endpoint: %s", client.path + ) + self._task_group.start_soon(client.send, message) if self.ystore: self.log.debug("Writing Y update to YStore") self._task_group.start_soon(self.ystore.write, update) @@ -197,7 +204,7 @@ async def serve(self, websocket: Websocket): websocket: The WebSocket through which to serve the client. """ async with create_task_group() as tg: - self.clients.append(websocket) + self.clients.add(websocket) await sync(self.ydoc, websocket, self.log) try: async for message in websocket: @@ -236,4 +243,4 @@ async def serve(self, websocket: Websocket): self.log.debug("Error serving endpoint: %s", websocket.path, exc_info=e) # remove this client - self.clients = [c for c in self.clients if c != websocket] + self.clients.remove(websocket) diff --git a/tests/test_app.py b/tests/test_app.py index dfd97365..ca9b5600 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -1,3 +1,5 @@ +import os + import pytest from asphalt.core import Context from fastapi import APIRouter @@ -17,7 +19,8 @@ "/foo", ), ) -async def test_mount_path(mount_path, unused_tcp_port): +async def test_mount_path(mount_path, unused_tcp_port, tmp_path): + os.chdir(tmp_path) components = configure({"app": {"type": "app"}}, {"app": {"mount_path": mount_path}}) async with Context() as ctx, AsyncClient() as http: diff --git a/tests/test_auth.py b/tests/test_auth.py index e8a3b5ed..7b9ccbd4 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -1,3 +1,5 @@ +import os + import pytest from asphalt.core import Context from httpx import AsyncClient @@ -20,7 +22,8 @@ @pytest.mark.asyncio -async def test_kernel_channels_unauthenticated(unused_tcp_port): +async def test_kernel_channels_unauthenticated(unused_tcp_port, tmp_path): + os.chdir(tmp_path) async with Context() as ctx: await JupyverseComponent( components=COMPONENTS, @@ -35,7 +38,8 @@ async def test_kernel_channels_unauthenticated(unused_tcp_port): @pytest.mark.asyncio -async def test_kernel_channels_authenticated(unused_tcp_port): +async def test_kernel_channels_authenticated(unused_tcp_port, tmp_path): + os.chdir(tmp_path) async with Context() as ctx, AsyncClient() as http: await JupyverseComponent( components=COMPONENTS, @@ -52,7 +56,8 @@ async def test_kernel_channels_authenticated(unused_tcp_port): @pytest.mark.asyncio @pytest.mark.parametrize("auth_mode", ("noauth", "token", "user")) -async def test_root_auth(auth_mode, unused_tcp_port): +async def test_root_auth(auth_mode, unused_tcp_port, tmp_path): + os.chdir(tmp_path) components = configure(COMPONENTS, {"auth": {"mode": auth_mode}}) async with Context() as ctx, AsyncClient() as http: await JupyverseComponent( @@ -72,7 +77,8 @@ async def test_root_auth(auth_mode, unused_tcp_port): @pytest.mark.asyncio @pytest.mark.parametrize("auth_mode", ("noauth",)) -async def test_no_auth(auth_mode, unused_tcp_port): +async def test_no_auth(auth_mode, unused_tcp_port, tmp_path): + os.chdir(tmp_path) components = configure(COMPONENTS, {"auth": {"mode": auth_mode}}) async with Context() as ctx, AsyncClient() as http: await JupyverseComponent( @@ -86,7 +92,8 @@ async def test_no_auth(auth_mode, unused_tcp_port): @pytest.mark.asyncio @pytest.mark.parametrize("auth_mode", ("token",)) -async def test_token_auth(auth_mode, unused_tcp_port): +async def test_token_auth(auth_mode, unused_tcp_port, tmp_path): + os.chdir(tmp_path) components = configure(COMPONENTS, {"auth": {"mode": auth_mode}}) async with Context() as ctx, AsyncClient() as http: await JupyverseComponent( @@ -113,7 +120,8 @@ async def test_token_auth(auth_mode, unused_tcp_port): {"admin": ["read"], "foo": ["bar", "baz"]}, ), ) -async def test_permissions(auth_mode, permissions, unused_tcp_port): +async def test_permissions(auth_mode, permissions, unused_tcp_port, tmp_path): + os.chdir(tmp_path) components = configure(COMPONENTS, {"auth": {"mode": auth_mode}}) async with Context() as ctx, AsyncClient() as http: await JupyverseComponent( diff --git a/tests/test_contents.py b/tests/test_contents.py index b44a4aac..78eb1e72 100644 --- a/tests/test_contents.py +++ b/tests/test_contents.py @@ -19,7 +19,6 @@ @pytest.mark.asyncio @pytest.mark.parametrize("auth_mode", ("noauth",)) async def test_tree(auth_mode, tmp_path, unused_tcp_port): - prev_dir = os.getcwd() os.chdir(tmp_path) dname = Path(".") expected = [] @@ -81,4 +80,3 @@ async def test_tree(auth_mode, tmp_path, unused_tcp_port): sort_content_by_name(actual) sort_content_by_name(expected) assert actual == expected - os.chdir(prev_dir) diff --git a/tests/test_execute.py b/tests/test_execute.py index d423f1a1..06e398ac 100644 --- a/tests/test_execute.py +++ b/tests/test_execute.py @@ -2,6 +2,7 @@ import os from functools import partial from pathlib import Path +from shutil import copytree, rmtree import pytest from asphalt.core import Context @@ -27,6 +28,8 @@ "yjs": {"type": "yjs"}, } +HERE = Path(__file__).parent + class Websocket: def __init__(self, websocket, roomid: str): @@ -57,7 +60,11 @@ async def recv(self) -> bytes: @pytest.mark.asyncio @pytest.mark.parametrize("auth_mode", ("noauth",)) -async def test_execute(auth_mode, unused_tcp_port): +async def test_execute(auth_mode, unused_tcp_port, tmp_path): + os.chdir(tmp_path) + if Path("data").exists(): + rmtree("data") + copytree(HERE / "data", "data") url = f"http://127.0.0.1:{unused_tcp_port}" components = configure(COMPONENTS, { "auth": {"mode": auth_mode}, @@ -71,7 +78,7 @@ async def test_execute(auth_mode, unused_tcp_port): ws_url = url.replace("http", "ws", 1) name = "notebook1.ipynb" - path = (Path("tests") / "data" / name).as_posix() + path = (Path("data") / name).as_posix() # create a session to launch a kernel response = await http.post( f"{url}/api/sessions", @@ -92,7 +99,8 @@ async def test_execute(auth_mode, unused_tcp_port): "type": "notebook", } ) - file_id = response.json()["fileId"] + r = response.json() + file_id = r["fileId"] document_id = f"json:notebook:{file_id}" ynb = ydocs["notebook"]() def callback(aevent, events, event): diff --git a/tests/test_kernels.py b/tests/test_kernels.py index dba726b9..baac1632 100644 --- a/tests/test_kernels.py +++ b/tests/test_kernels.py @@ -28,7 +28,8 @@ @pytest.mark.asyncio @pytest.mark.parametrize("auth_mode", ("noauth",)) -async def test_kernel_messages(auth_mode, capfd, unused_tcp_port): +async def test_kernel_messages(auth_mode, capfd, unused_tcp_port, tmp_path): + os.chdir(tmp_path) kernel_id = "kernel_id_0" kernel_name = "python3" kernelspec_path = ( diff --git a/tests/test_server.py b/tests/test_server.py index bc2325d4..11e06965 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,5 +1,6 @@ import asyncio import json +import os from functools import partial from pathlib import Path @@ -16,7 +17,8 @@ @pytest.mark.parametrize("auth_mode", ("noauth",)) @pytest.mark.parametrize("clear_users", (False,)) -def test_settings_persistence_put(start_jupyverse): +def test_settings_persistence_put(start_jupyverse, tmp_path): + os.chdir(tmp_path) url = start_jupyverse # get previous theme response = requests.get(url + "/lab/api/settings/@jupyterlab/apputils-extension:themes") @@ -31,7 +33,8 @@ def test_settings_persistence_put(start_jupyverse): @pytest.mark.parametrize("auth_mode", ("noauth",)) @pytest.mark.parametrize("clear_users", (False,)) -def test_settings_persistence_get(start_jupyverse): +def test_settings_persistence_get(start_jupyverse, tmp_path): + os.chdir(tmp_path) url = start_jupyverse # get new theme response = requests.get( @@ -50,7 +53,8 @@ def test_settings_persistence_get(start_jupyverse): @pytest.mark.asyncio @pytest.mark.parametrize("auth_mode", ("noauth",)) @pytest.mark.parametrize("clear_users", (False,)) -async def test_rest_api(start_jupyverse): +async def test_rest_api(start_jupyverse, tmp_path): + os.chdir(tmp_path) url = start_jupyverse ws_url = url.replace("http", "ws", 1) name = "notebook0.ipynb" @@ -128,7 +132,8 @@ async def test_rest_api(start_jupyverse): @pytest.mark.asyncio @pytest.mark.parametrize("auth_mode", ("noauth",)) @pytest.mark.parametrize("clear_users", (False,)) -async def test_ywidgets(start_jupyverse): +async def test_ywidgets(start_jupyverse, tmp_path): + os.chdir(tmp_path) url = start_jupyverse ws_url = url.replace("http", "ws", 1) name = "notebook1.ipynb" diff --git a/tests/test_settings.py b/tests/test_settings.py index 03cb6a60..29752018 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -1,4 +1,5 @@ import json +import os import pytest from asphalt.core import Context @@ -23,7 +24,8 @@ @pytest.mark.asyncio @pytest.mark.parametrize("auth_mode", ("noauth",)) -async def test_settings(auth_mode, unused_tcp_port): +async def test_settings(auth_mode, unused_tcp_port, tmp_path): + os.chdir(tmp_path) components = configure(COMPONENTS, {"auth": {"mode": auth_mode}}) async with Context() as ctx, AsyncClient() as http: await JupyverseComponent( diff --git a/tests/test_yjs.py b/tests/test_yjs.py new file mode 100644 index 00000000..e8b617db --- /dev/null +++ b/tests/test_yjs.py @@ -0,0 +1,117 @@ +import os +from asyncio import sleep +from pathlib import Path + +import pytest +from asphalt.core import Context +from fps_yjs.ywebsocket import WebsocketProvider +from httpx import AsyncClient +from httpx_ws import aconnect_ws +from pycrdt import Doc, Text + +from jupyverse_api.main import JupyverseComponent +from jupyverse_api.yjs.models import CreateDocumentSession, MergeRoom + + +@pytest.mark.asyncio +async def test_fork_room(tmp_path, unused_tcp_port): + os.chdir(tmp_path) + path = Path("foo.txt") + path.write_text("Hello") + + components = { + "app": {"type": "app"}, + "auth": {"type": "auth", "test": True, "mode": "noauth"}, + "contents": {"type": "contents"}, + "frontend": {"type": "frontend"}, + "yjs": {"type": "yjs"}, + } + async with Context() as ctx, AsyncClient() as http: + await JupyverseComponent( + components=components, + port=unused_tcp_port, + ).start(ctx) + await sleep(1) + + create_document_session = CreateDocumentSession(format="text", type="file") + response = await http.put( + f"http://127.0.0.1:{unused_tcp_port}/api/collaboration/session/{path}", + json=create_document_session.model_dump(), + ) + r = response.json() + file_id = r["fileId"] + + # connect to root room + async with aconnect_ws( + f"http://127.0.0.1:{unused_tcp_port}/api/collaboration/room/text:file:{file_id}" + ) as root_ws: + # create a root room client + root_ydoc = Doc() + root_ydoc["source"] = root_ytext = Text() + async with WebsocketProvider(root_ydoc, Websocket(root_ws, file_id)): + await sleep(0.1) + assert str(root_ytext) == "Hello" + # fork room + root_roomid = f"text:file:{file_id}" + response = await http.put( + f"http://127.0.0.1:{unused_tcp_port}/api/collaboration/fork_room/{root_roomid}" + ) + r = response.json() + fork_roomid = r["roomId"] + # connect to forked room + async with aconnect_ws( + f"http://127.0.0.1:{unused_tcp_port}/api/collaboration/room/{fork_roomid}" + ) as fork_ws: + # create a forked room client + fork_ydoc = Doc() + fork_ydoc["source"] = fork_ytext = Text() + async with WebsocketProvider(fork_ydoc, Websocket(fork_ws, fork_roomid)): + await sleep(0.1) + assert str(fork_ytext) == "Hello" + # check that the forked room is synced with the root room + root_ytext += ", World!" + await sleep(0.1) + assert str(root_ytext) == "Hello, World!" + assert str(fork_ytext) == "Hello, World!" + # check that the root room is not synced with the forked room + fork_ytext += " Bye!" + await sleep(0.1) + assert str(root_ytext) == "Hello, World!" + assert str(fork_ytext) == "Hello, World! Bye!" + # merge forked room into root room + merge_room = MergeRoom(fork_roomid=fork_roomid, root_roomid=root_roomid) + response = await http.put( + f"http://127.0.0.1:{unused_tcp_port}/api/collaboration/merge_room", + json=merge_room.model_dump(), + ) + # check that the root room is synced with the forked room + await sleep(0.1) + assert str(root_ytext) == "Hello, World! Bye!" + assert str(fork_ytext) == "Hello, World! Bye!" + + +class Websocket: + def __init__(self, websocket, roomid: str): + self.websocket = websocket + self.roomid = roomid + + @property + def path(self) -> str: + return self.roomid + + def __aiter__(self): + return self + + async def __anext__(self) -> bytes: + try: + message = await self.recv() + except BaseException: + raise StopAsyncIteration() + return message + + async def send(self, message: bytes): + await self.websocket.send_bytes(message) + + async def recv(self) -> bytes: + b = await self.websocket.receive_bytes() + return bytes(b)