diff --git a/trio/_threads.py b/trio/_threads.py index 5556e50538..cd76f1907a 100644 --- a/trio/_threads.py +++ b/trio/_threads.py @@ -367,12 +367,12 @@ def _send_message_to_host_task(message, trio_token): cancel_register = THREAD_LOCAL.cancel_register def in_trio_thread(): - raise_cancel = cancel_register[0] - if raise_cancel is None: - task = task_register[0] - trio.lowlevel.reschedule(task, outcome.Value(message)) - else: + task = task_register[0] + if task is None: + raise_cancel = cancel_register[0] message.queue.put_nowait(outcome.capture(raise_cancel)) + else: + trio.lowlevel.reschedule(task, outcome.Value(message)) trio_token.run_sync_soon(in_trio_thread) return message.queue.get().unwrap() @@ -417,7 +417,9 @@ def from_thread_run(afn, *args, trio_token=None): RunFinishedError: if the corresponding call to :func:`trio.run` has already completed, or if the run has started its final cleanup phase and can no longer spawn new system tasks. - Cancelled: if the corresponding task or call to :func:`trio.run` completes + Cancelled: if the corresponding `trio.to_thread.run_sync` task is + cancellable and exits before this function is called, or + if the task enters cancelled status or call to :func:`trio.run` completes while ``afn(*args)`` is running, then ``afn`` is likely to raise :exc:`trio.Cancelled`. RuntimeError: if you try calling this from inside the Trio thread, @@ -460,7 +462,7 @@ def from_thread_run_sync(fn, *args, trio_token=None): RunFinishedError: if the corresponding call to `trio.run` has already completed. Cancelled: if the corresponding `trio.to_thread.run_sync` task is - cancellable and exits before this function is called + cancellable and exits before this function is called. RuntimeError: if you try calling this from inside the Trio thread, which would otherwise cause a deadlock or if no ``trio_token`` was provided, and we can't infer one from context. diff --git a/trio/tests/test_threads.py b/trio/tests/test_threads.py index ba4f3d6f39..920b3d95f0 100644 --- a/trio/tests/test_threads.py +++ b/trio/tests/test_threads.py @@ -890,16 +890,52 @@ def get_tid_then_reenter(): async def test_from_thread_host_cancelled(): - def sync_time_bomb(): - deadline = time.perf_counter() + 10 - while time.perf_counter() < deadline: - from_thread_run_sync(cancel_scope.cancel) - assert False # pragma: no cover + queue = stdlib_queue.Queue() + + def sync_check(): + from_thread_run_sync(cancel_scope.cancel) + try: + from_thread_run_sync(bool) + except _core.Cancelled: + queue.put(True) + else: + queue.put(False) + + with _core.CancelScope() as cancel_scope: + await to_thread_run_sync(sync_check) + + assert not cancel_scope.cancelled_caught + assert not queue.get_nowait() + + with _core.CancelScope() as cancel_scope: + await to_thread_run_sync(sync_check, cancellable=True) + + assert cancel_scope.cancelled_caught + assert await to_thread_run_sync(partial(queue.get, timeout=1)) + + async def no_checkpoint(): + return True + + def async_check(): + from_thread_run_sync(cancel_scope.cancel) + try: + assert from_thread_run(no_checkpoint) + except _core.Cancelled: + queue.put(True) + else: + queue.put(False) + + with _core.CancelScope() as cancel_scope: + await to_thread_run_sync(async_check) + + assert not cancel_scope.cancelled_caught + assert not queue.get_nowait() with _core.CancelScope() as cancel_scope: - await to_thread_run_sync(sync_time_bomb) + await to_thread_run_sync(async_check, cancellable=True) assert cancel_scope.cancelled_caught + assert await to_thread_run_sync(partial(queue.get, timeout=1)) async def async_time_bomb(): cancel_scope.cancel()