Skip to content

Commit

Permalink
Use a lock to prevent document access or change between threads (#16)
Browse files Browse the repository at this point in the history
This can happen if a document change occurs while another one is still in progress, because Yrs transactions are not thread-safe.

Co-authored-by: Frédéric Collonval <[email protected]>
  • Loading branch information
fcollonval and fcollonval authored Dec 17, 2024
1 parent 99ebff9 commit 20ff2d1
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 38 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,4 @@ dmypy.json
# OSX files
.DS_Store
.jupyter_ystore.db
.dot-env
6 changes: 4 additions & 2 deletions jupyter_nbmodel_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,8 @@ def start(self) -> None:
emsg = f"Unable to open a websocket connection to {self._server_url} within {self._timeout} s."
raise TimeoutError(emsg)

sync_message = create_sync_message(self._doc.ydoc)
with self._lock:
sync_message = create_sync_message(self._doc.ydoc)
self._log.debug(
"Sending SYNC_STEP1 message for document %s",
self._path,
Expand Down Expand Up @@ -201,7 +202,8 @@ def _on_message(self, websocket: WebSocket, message: bytes) -> None:
YSyncMessageType(message[1]).name,
self._path,
)
reply = handle_sync_message(message[1:], self._doc.ydoc)
with self._lock:
reply = handle_sync_message(message[1:], self._doc.ydoc)
if message[1] == YSyncMessageType.SYNC_STEP2:
self.__synced.set()
if reply is not None:
Expand Down
98 changes: 63 additions & 35 deletions jupyter_nbmodel_client/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from __future__ import annotations

import threading
import typing as t
import warnings
from collections.abc import MutableSequence
Expand Down Expand Up @@ -43,8 +44,6 @@ def output_hook(outputs: list[dict[str, t.Any]], message: dict[str, t.Any]) -> s
return set()




class KernelClient(t.Protocol):
"""Interface to be implemented by the kernel client."""

Expand Down Expand Up @@ -89,7 +88,9 @@ def execute_interactive(
...


def save_in_notebook_hook(outputs: list[dict], ycell: pycrdt.Map, msg: dict) -> None:
def save_in_notebook_hook(
lock: threading.Lock, outputs: list[dict], ycell: pycrdt.Map, msg: dict
) -> None:
"""Callback on execution request when an output is emitted.
Args:
Expand All @@ -100,15 +101,18 @@ def save_in_notebook_hook(outputs: list[dict], ycell: pycrdt.Map, msg: dict) ->
indexes = output_hook(outputs, msg)
cell_outputs = t.cast(pycrdt.Array, ycell["outputs"])
if len(indexes) == len(cell_outputs):
with cell_outputs.doc.transaction():
cell_outputs.clear()
cell_outputs.extend(outputs)
with lock:
with cell_outputs.doc.transaction():
cell_outputs.clear()
cell_outputs.extend(outputs)
else:
for index in indexes:
if index >= len(cell_outputs):
cell_outputs.append(outputs[index])
else:
cell_outputs[index] = outputs[index]
with lock:
with cell_outputs.doc.transaction():
for index in indexes:
if index >= len(cell_outputs):
cell_outputs.append(outputs[index])
else:
cell_outputs[index] = outputs[index]


class NotebookModel(MutableSequence):
Expand All @@ -122,21 +126,32 @@ class NotebookModel(MutableSequence):

def __init__(self) -> None:
self._doc = YNotebook()
self._lock = threading.Lock()
"""Lock to prevent updating the document in multiple threads simultaneously.
That may induce a Panic error; see https://github.com/datalayer/jupyter-nbmodel-client/issues/12
"""

# Initialize _doc
self._reset_y_model()

def __delitem__(self, index: int) -> NotebookNode:
raw_ycell = self._doc.ycells.pop(index)
with self._lock:
raw_ycell = self._doc.ycells.pop(index)
cell: dict[str, t.Any] = raw_ycell.to_py()
nbcell = NotebookNode(**cell)
return nbcell

def __getitem__(self, index: int) -> NotebookNode:
raw_ycell = self._doc.ycells[index]
cell = raw_ycell.to_py()
with self._lock:
cell = raw_ycell.to_py()
nbcell = NotebookNode(**cell)
return nbcell

def __setitem__(self, index: int, value: dict[str, t.Any]) -> None:
self._doc.set_cell(index, value)
with self._lock:
self._doc.set_cell(index, value)

def __len__(self) -> int:
"""Number of cells"""
Expand All @@ -145,24 +160,28 @@ def __len__(self) -> int:
@property
def nbformat(self) -> int:
"""Notebook format major version."""
return int(self._doc._ymeta.get("nbformat"))
with self._lock:
return int(self._doc._ymeta.get("nbformat") or current_api.nbformat_minor)

@property
def nbformat_minor(self) -> int:
"""Notebook format minor version."""
return int(self._doc._ymeta.get("nbformat_minor"))
with self._lock:
return int(self._doc._ymeta.get("nbformat_minor") or current_api.nbformat_minor)

@property
def metadata(self) -> dict[str, t.Any]:
"""Notebook metadata."""
return t.cast(pycrdt.Map, self._doc._ymeta["metadata"]).to_py()
with self._lock:
return t.cast(pycrdt.Map, self._doc._ymeta["metadata"]).to_py() or {}

@metadata.setter
def metadata(self, value: dict[str, t.Any]) -> None:
metadata = t.cast(pycrdt.Map, self._doc._ymeta["metadata"])
with metadata.doc.transaction():
metadata.clear()
metadata.update(value)
with self._lock:
with metadata.doc.transaction():
metadata.clear()
metadata.update(value)

def add_code_cell(self, source: str, **kwargs) -> int:
"""Add a code cell
Expand All @@ -175,7 +194,8 @@ def add_code_cell(self, source: str, **kwargs) -> int:
"""
cell = current_api.new_code_cell(source, **kwargs)

self._doc.append_cell(cell)
with self._lock:
self._doc.append_cell(cell)

return len(self) - 1

Expand All @@ -190,7 +210,8 @@ def add_markdown_cell(self, source: str, **kwargs) -> int:
"""
cell = current_api.new_markdown_cell(source, **kwargs)

self._doc.append_cell(cell)
with self._lock:
self._doc.append_cell(cell)

return len(self) - 1

Expand All @@ -205,7 +226,8 @@ def add_raw_cell(self, source: str, **kwargs) -> int:
"""
cell = current_api.new_raw_cell(source, **kwargs)

self._doc.append_cell(cell)
with self._lock:
self._doc.append_cell(cell)

return len(self) - 1

Expand All @@ -215,7 +237,8 @@ def as_dict(self) -> dict[str, t.Any]:
Returns:
The dictionary
"""
return self._doc.source
with self._lock:
return self._doc.source

def execute_cell(
self,
Expand Down Expand Up @@ -258,20 +281,22 @@ def execute_cell(
)

ycell = t.cast(pycrdt.Map, self._doc.ycells[index])
source = ycell["source"].to_py()
with self._lock:
source = ycell["source"].to_py()

# Reset cell
with ycell.doc.transaction():
del ycell["outputs"][:]
ycell["execution_count"] = None
ycell["execution_state"] = "running"
with self._lock:
with ycell.doc.transaction():
del ycell["outputs"][:]
ycell["execution_count"] = None
ycell["execution_state"] = "running"

outputs = []
reply_content = {}
try:
reply = kernel_client.execute_interactive(
source,
output_hook=partial(save_in_notebook_hook, outputs, ycell),
output_hook=partial(save_in_notebook_hook, self._lock, outputs, ycell),
allow_stdin=False,
silent=silent,
store_history=False if silent else store_history,
Expand All @@ -281,9 +306,10 @@ def execute_cell(

reply_content = reply["content"]
finally:
with ycell.doc.transaction():
ycell["execution_count"] = reply_content.get("execution_count")
ycell["execution_state"] = "idle"
with self._lock:
with ycell.doc.transaction():
ycell["execution_count"] = reply_content.get("execution_count")
ycell["execution_state"] = "idle"

return {
"execution_count": reply_content.get("execution_count"),
Expand All @@ -299,7 +325,8 @@ def insert(self, index: int, value: dict[str, t.Any]) -> None:
value: A mapping describing the cell
"""
ycell = self._doc.create_ycell(value)
self._doc.ycells.insert(index, ycell)
with self._lock:
self._doc.ycells.insert(index, ycell)

def set_cell_source(self, index: int, source: str) -> None:
"""Set a cell source.
Expand All @@ -308,7 +335,8 @@ def set_cell_source(self, index: int, source: str) -> None:
index: Cell index
source: New cell source
"""
self._doc._ycells[index].set("source", source)
with self._lock:
t.cast(pycrdt.Map, self._doc._ycells[index])["source"] = source

def _reset_y_model(self) -> None:
"""Reset the Y model."""
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ dependencies = ["jupyter_ydoc>=2.1.2,<4.0.0", "nbformat~=5.0", "pycrdt >=0.10.3,

[project.optional-dependencies]
test = ["ipykernel", "jupyter-kernel-client", "jupyter-server-ydoc~=1.0.0", "pytest>=7.0", "pytest-timeout"]
lint = ["mdformat>0.7", "mdformat-gfm>=0.3.5", "ruff"]
lint = ["pre_commit", "mdformat>0.7", "mdformat-gfm>=0.3.5", "ruff"]
typing = ["mypy>=0.990"]

[project.license]
Expand Down

0 comments on commit 20ff2d1

Please sign in to comment.