From 0e18c93c9d4915ab33dda4ec2a5ae98e44fc4cff Mon Sep 17 00:00:00 2001 From: richardsheridan Date: Sun, 8 Oct 2023 19:19:46 -0400 Subject: [PATCH] implement cancellation semantics suggestions from code review --- trio/_tests/test_threads.py | 5 +++-- trio/_threads.py | 18 ++++++++++++------ 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/trio/_tests/test_threads.py b/trio/_tests/test_threads.py index 118d272ba1..fb1682984b 100644 --- a/trio/_tests/test_threads.py +++ b/trio/_tests/test_threads.py @@ -879,6 +879,7 @@ def sync_check(): try: from_thread_run_sync(bool) except _core.Cancelled: + # pragma: no cover, sync functions don't raise Cancelled queue.put(True) else: queue.put(False) @@ -893,7 +894,7 @@ def sync_check(): await to_thread_run_sync(sync_check, cancellable=True) assert cancel_scope.cancelled_caught - assert await to_thread_run_sync(partial(queue.get, timeout=1)) + assert not await to_thread_run_sync(partial(queue.get, timeout=1)) async def no_checkpoint(): return True @@ -917,7 +918,7 @@ def async_check(): await to_thread_run_sync(async_check, cancellable=True) assert cancel_scope.cancelled_caught - assert await to_thread_run_sync(partial(queue.get, timeout=1)) + assert not await to_thread_run_sync(partial(queue.get, timeout=1)) async def async_time_bomb(): cancel_scope.cancel() diff --git a/trio/_threads.py b/trio/_threads.py index dff4485342..7620bef55c 100644 --- a/trio/_threads.py +++ b/trio/_threads.py @@ -17,6 +17,7 @@ from trio._core._traps import RaiseCancelT from ._core import ( + CancelScope, RunVar, TrioToken, disable_ki_protection, @@ -86,6 +87,7 @@ class Run(Generic[RetT]): queue: stdlib_queue.SimpleQueue[outcome.Outcome[RetT]] = attr.ib( init=False, factory=stdlib_queue.SimpleQueue ) + scope: CancelScope = attr.ib(init=False, factory=CancelScope) @disable_ki_protection async def unprotected_afn(self) -> RetT: @@ -106,7 +108,12 @@ async def run(self) -> None: await trio.lowlevel.cancel_shielded_checkpoint() async def run_system(self) -> None: - result = await outcome.acapture(self.unprotected_afn) + # NOTE: There is potential here to only conditionally enter a CancelScope + # when we need it, sparing some computation. But doing so adds substantial + # complexity, so we'll leave it until real need is demonstrated. + with self.scope: + result = await outcome.acapture(self.unprotected_afn) + assert not self.scope.cancelled_caught, "any Cancelled should go to our parent" self.queue.put_nowait(result) @@ -403,13 +410,14 @@ def _send_message_to_host_task( message: Run[RetT] | RunSync[RetT], trio_token: TrioToken ) -> None: task_register = PARENT_TASK_DATA.task_register - cancel_register = PARENT_TASK_DATA.cancel_register def in_trio_thread() -> None: task = task_register[0] if task is None: - raise_cancel = cancel_register[0] - message.queue.put_nowait(outcome.capture(raise_cancel)) + # Our parent task is gone! Punt to a system task. + if isinstance(message, Run): + message.scope.cancel() + _send_message_to_system_task(message, trio_token) else: trio.lowlevel.reschedule(task, outcome.Value(message)) @@ -509,8 +517,6 @@ def from_thread_run_sync( Raises: RunFinishedError: if the corresponding call to `trio.run` has already completed. - Cancelled: if the corresponding `trio.to_thread.run_sync` task is - cancellable and exits before this function is called. RuntimeError: if you try calling this from inside the Trio thread, which would otherwise cause a deadlock or if no ``trio_token`` was provided, and we can't infer one from context.