From ca9b856e497e73207b48141bfb7351d15b5b281d Mon Sep 17 00:00:00 2001 From: richardsheridan Date: Sun, 15 Oct 2023 12:23:08 -0400 Subject: [PATCH] revise and document cancellation semantics in short, cancellable threads always use system tasks. normal threads use the host task, unless passed a token --- docs/source/reference-core.rst | 13 +- trio/_tests/test_threads.py | 20 ++-- trio/_tests/verify_types_darwin.json | 2 +- trio/_tests/verify_types_linux.json | 2 +- trio/_tests/verify_types_windows.json | 2 +- trio/_threads.py | 163 ++++++++++++-------------- 6 files changed, 105 insertions(+), 97 deletions(-) diff --git a/docs/source/reference-core.rst b/docs/source/reference-core.rst index db1a93f121..160fa0fe97 100644 --- a/docs/source/reference-core.rst +++ b/docs/source/reference-core.rst @@ -1823,9 +1823,20 @@ to spawn a child thread, and then use a :ref:`memory channel .. literalinclude:: reference-core/from-thread-example.py +.. note:: + + The ``from_thread.run*`` functions reuse the host task that called + :func:`trio.to_thread.run_sync` to run your provided function in the typical case, + namely when ``cancellable=False`` so Trio can be sure that the task will always be + around to perform the work. If you pass ``cancellable=True`` at the outset, or if + you provide a :class:`~trio.lowlevel.TrioToken` when calling back in to Trio, your + functions will be executed in a new system task. Therefore, the + :func:`~trio.lowlevel.current_task`, :func:`current_effective_deadline`, or other + task-tree specific values may differ depending on keyword argument values. + You can also use :func:`trio.from_thread.check_cancelled` to check for cancellation from a thread that was spawned by :func:`trio.to_thread.run_sync`. If the call to -:func:`~trio.to_thread.run_sync` was cancelled, then +:func:`~trio.to_thread.run_sync` was cancelled (even if ``cancellable=False``!), then :func:`~trio.from_thread.check_cancelled` will raise :func:`trio.Cancelled`. It's like ``trio.from_thread.run(trio.sleep, 0)``, but much faster. diff --git a/trio/_tests/test_threads.py b/trio/_tests/test_threads.py index 24b450cc59..bfe5204f4a 100644 --- a/trio/_tests/test_threads.py +++ b/trio/_tests/test_threads.py @@ -933,12 +933,14 @@ async def async_time_bomb(): async def test_from_thread_check_cancelled(): q = stdlib_queue.Queue() - async def child(cancellable): + async def child(cancellable, scope): + with scope: record.append("start") try: return await to_thread_run_sync(f, cancellable=cancellable) except _core.Cancelled: record.append("cancel") + raise finally: record.append("exit") @@ -956,7 +958,7 @@ def f(): record = [] ev = threading.Event() async with _core.open_nursery() as nursery: - nursery.start_soon(child, False) + nursery.start_soon(child, False, _core.CancelScope()) await wait_all_tasks_blocked() assert record[0] == "start" assert q.get(timeout=1) == "Not Cancelled" @@ -968,14 +970,15 @@ def f(): # the appropriate cancel scope record = [] ev = threading.Event() + scope = _core.CancelScope() # Nursery cancel scope gives false positives async with _core.open_nursery() as nursery: - nursery.start_soon(child, False) + nursery.start_soon(child, False, scope) await wait_all_tasks_blocked() assert record[0] == "start" assert q.get(timeout=1) == "Not Cancelled" - nursery.cancel_scope.cancel() + scope.cancel() ev.set() - assert nursery.cancel_scope.cancelled_caught + assert scope.cancelled_caught assert "cancel" in record assert record[-1] == "exit" @@ -992,13 +995,14 @@ def f(): # noqa: F811 record = [] ev = threading.Event() + scope = _core.CancelScope() async with _core.open_nursery() as nursery: - nursery.start_soon(child, True) + nursery.start_soon(child, True, scope) await wait_all_tasks_blocked() assert record[0] == "start" - nursery.cancel_scope.cancel() + scope.cancel() ev.set() - assert nursery.cancel_scope.cancelled_caught + assert scope.cancelled_caught assert "cancel" in record assert record[-1] == "exit" assert q.get(timeout=1) == "Cancelled" diff --git a/trio/_tests/verify_types_darwin.json b/trio/_tests/verify_types_darwin.json index 6625368f20..e83a324714 100644 --- a/trio/_tests/verify_types_darwin.json +++ b/trio/_tests/verify_types_darwin.json @@ -40,7 +40,7 @@ ], "exportedSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 631, + "withKnownType": 630, "withUnknownType": 0 }, "ignoreUnknownTypesFromImports": true, diff --git a/trio/_tests/verify_types_linux.json b/trio/_tests/verify_types_linux.json index 73ee6f7855..7c9d745dba 100644 --- a/trio/_tests/verify_types_linux.json +++ b/trio/_tests/verify_types_linux.json @@ -28,7 +28,7 @@ ], "exportedSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 628, + "withKnownType": 627, "withUnknownType": 0 }, "ignoreUnknownTypesFromImports": true, diff --git a/trio/_tests/verify_types_windows.json b/trio/_tests/verify_types_windows.json index f2c1f0dd6c..a58416fe76 100644 --- a/trio/_tests/verify_types_windows.json +++ b/trio/_tests/verify_types_windows.json @@ -64,7 +64,7 @@ ], "exportedSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 631, + "withKnownType": 630, "withUnknownType": 0 }, "ignoreUnknownTypesFromImports": true, diff --git a/trio/_threads.py b/trio/_threads.py index 7620bef55c..e5636b1861 100644 --- a/trio/_threads.py +++ b/trio/_threads.py @@ -17,7 +17,6 @@ from trio._core._traps import RaiseCancelT from ._core import ( - CancelScope, RunVar, TrioToken, disable_ki_protection, @@ -35,6 +34,7 @@ class _ParentTaskData(threading.local): parent task of native Trio threads.""" token: TrioToken + abandon_on_cancel: bool cancel_register: list[RaiseCancelT | None] task_register: list[trio.lowlevel.Task | None] @@ -74,11 +74,6 @@ class ThreadPlaceholder: # Types for the to_thread_run_sync message loop -@attr.s(frozen=True, eq=False) -class ThreadDone(Generic[RetT]): - result: outcome.Outcome[RetT] = attr.ib() - - @attr.s(frozen=True, eq=False) class Run(Generic[RetT]): afn: Callable[..., Awaitable[RetT]] = attr.ib() @@ -87,7 +82,6 @@ class Run(Generic[RetT]): queue: stdlib_queue.SimpleQueue[outcome.Outcome[RetT]] = attr.ib( init=False, factory=stdlib_queue.SimpleQueue ) - scope: CancelScope = attr.ib(init=False, factory=CancelScope) @disable_ki_protection async def unprotected_afn(self) -> RetT: @@ -108,14 +102,32 @@ async def run(self) -> None: await trio.lowlevel.cancel_shielded_checkpoint() async def run_system(self) -> None: - # NOTE: There is potential here to only conditionally enter a CancelScope - # when we need it, sparing some computation. But doing so adds substantial - # complexity, so we'll leave it until real need is demonstrated. - with self.scope: - result = await outcome.acapture(self.unprotected_afn) - assert not self.scope.cancelled_caught, "any Cancelled should go to our parent" + result = await outcome.acapture(self.unprotected_afn) self.queue.put_nowait(result) + def run_in_host_task(self, token: TrioToken) -> None: + task_register = PARENT_TASK_DATA.task_register + + def in_trio_thread() -> None: + task = task_register[0] + assert task is not None, "guaranteed by abandon_on_cancel semantics" + trio.lowlevel.reschedule(task, outcome.Value(self)) + + token.run_sync_soon(in_trio_thread) + + def run_in_system_nursery(self, token: TrioToken) -> None: + def in_trio_thread() -> None: + try: + trio.lowlevel.spawn_system_task( + self.run, name=self.afn, context=self.context + ) + except RuntimeError: # system nursery is closed + self.queue.put_nowait( + outcome.Error(trio.RunFinishedError("system nursery is closed")) + ) + + token.run_sync_soon(in_trio_thread) + @attr.s(frozen=True, eq=False) class RunSync(Generic[RetT]): @@ -144,6 +156,19 @@ def run_sync(self) -> None: result = outcome.capture(self.context.run, self.unprotected_fn) self.queue.put_nowait(result) + def run_in_host_task(self, token: TrioToken) -> None: + task_register = PARENT_TASK_DATA.task_register + + def in_trio_thread() -> None: + task = task_register[0] + assert task is not None, "guaranteed by abandon_on_cancel semantics" + trio.lowlevel.reschedule(task, outcome.Value(self)) + + token.run_sync_soon(in_trio_thread) + + def run_in_system_nursery(self, token: TrioToken) -> None: + token.run_sync_soon(self.run_sync) + @enable_ki_protection # Decorator used on function with Coroutine[Any, Any, RetT] async def to_thread_run_sync( # type: ignore[misc] @@ -237,7 +262,7 @@ async def to_thread_run_sync( # type: ignore[misc] """ await trio.lowlevel.checkpoint_if_cancelled() - cancellable = bool(cancellable) # raise early if cancellable.__bool__ raises + abandon_on_cancel = bool(cancellable) # raise early if cancellable.__bool__ raises if limiter is None: limiter = current_default_thread_limiter() @@ -266,9 +291,7 @@ def do_release_then_return_result() -> RetT: result = outcome.capture(do_release_then_return_result) if task_register[0] is not None: - trio.lowlevel.reschedule( - task_register[0], outcome.Value(ThreadDone(result)) - ) + trio.lowlevel.reschedule(task_register[0], outcome.Value(result)) current_trio_token = trio.lowlevel.current_trio_token() @@ -283,6 +306,7 @@ def worker_fn() -> RetT: current_async_library_cvar.set(None) PARENT_TASK_DATA.token = current_trio_token + PARENT_TASK_DATA.abandon_on_cancel = abandon_on_cancel PARENT_TASK_DATA.cancel_register = cancel_register PARENT_TASK_DATA.task_register = task_register try: @@ -299,6 +323,7 @@ def worker_fn() -> RetT: return ret finally: del PARENT_TASK_DATA.token + del PARENT_TASK_DATA.abandon_on_cancel del PARENT_TASK_DATA.cancel_register del PARENT_TASK_DATA.task_register @@ -336,11 +361,11 @@ def abort(raise_cancel: RaiseCancelT) -> trio.lowlevel.Abort: while True: # wait_task_rescheduled return value cannot be typed - msg_from_thread: ThreadDone[RetT] | Run[object] | RunSync[ + msg_from_thread: outcome.Outcome[RetT] | Run[object] | RunSync[ object ] = await trio.lowlevel.wait_task_rescheduled(abort) - if isinstance(msg_from_thread, ThreadDone): - return msg_from_thread.result.unwrap() # type: ignore[no-any-return] + if isinstance(msg_from_thread, outcome.Outcome): + return msg_from_thread.unwrap() # type: ignore[no-any-return] elif isinstance(msg_from_thread, Run): await msg_from_thread.run() elif isinstance(msg_from_thread, RunSync): @@ -354,10 +379,10 @@ def abort(raise_cancel: RaiseCancelT) -> trio.lowlevel.Abort: def from_thread_check_cancelled() -> None: - """Raise trio.Cancelled if the associated Trio task entered a cancelled status. + """Raise `trio.Cancelled` if the associated Trio task entered a cancelled status. Only applicable to threads spawned by `trio.to_thread.run_sync`. Poll to allow - ``cancellable=False`` threads to raise :exc:`trio.Cancelled` at a suitable + ``cancellable=False`` threads to raise :exc:`~trio.Cancelled` at a suitable place, or to end abandoned ``cancellable=True`` threads sooner than they may otherwise. @@ -366,6 +391,13 @@ def from_thread_check_cancelled() -> None: delivery of cancellation attempted against it, regardless of the value of ``cancellable`` supplied as an argument to it. RuntimeError: If this thread is not spawned from `trio.to_thread.run_sync`. + + .. note:: + + The check for cancellation attempts of ``cancellable=False`` threads is + interrupted while executing ``from_thread.run*`` functions, which can lead to + edge cases where this function may raise or not depending on the timing of + :class:`~trio.CancelScope` shields being raised or lowered in the Trio threads. """ try: raise_cancel = PARENT_TASK_DATA.cancel_register[0] @@ -406,49 +438,6 @@ def _check_token(trio_token: TrioToken | None) -> TrioToken: return trio_token -def _send_message_to_host_task( - message: Run[RetT] | RunSync[RetT], trio_token: TrioToken -) -> None: - task_register = PARENT_TASK_DATA.task_register - - def in_trio_thread() -> None: - task = task_register[0] - if task is None: - # Our parent task is gone! Punt to a system task. - if isinstance(message, Run): - message.scope.cancel() - _send_message_to_system_task(message, trio_token) - else: - trio.lowlevel.reschedule(task, outcome.Value(message)) - - trio_token.run_sync_soon(in_trio_thread) - - -def _send_message_to_system_task( - message: Run[RetT] | RunSync[RetT], trio_token: TrioToken -) -> None: - if type(message) is RunSync: - run_sync = message.run_sync - elif type(message) is Run: - - def run_sync() -> None: - try: - trio.lowlevel.spawn_system_task( - message.run_system, name=message.afn, context=message.context - ) - except RuntimeError: # system nursery is closed - message.queue.put_nowait( - outcome.Error(trio.RunFinishedError("system nursery is closed")) - ) - - else: # pragma: no cover, internal debugging guard TODO: use assert_never - raise TypeError( - "trio.to_thread.run_sync received unrecognized thread message {!r}." - "".format(message) - ) - trio_token.run_sync_soon(run_sync) - - def from_thread_run( afn: Callable[..., Awaitable[RetT]], *args: object, @@ -467,17 +456,15 @@ def from_thread_run( RunFinishedError: if the corresponding call to :func:`trio.run` has already completed, or if the run has started its final cleanup phase and can no longer spawn new system tasks. - Cancelled: if the corresponding `trio.to_thread.run_sync` task is - cancellable and exits before this function is called, or - if the task enters cancelled status or call to :func:`trio.run` completes - while ``afn(*args)`` is running, then ``afn`` is likely to raise + Cancelled: if the task enters cancelled status or call to :func:`trio.run` + completes while ``afn(*args)`` is running, then ``afn`` is likely to raise :exc:`trio.Cancelled`. RuntimeError: if you try calling this from inside the Trio thread, which would otherwise cause a deadlock, or if no ``trio_token`` was provided, and we can't infer one from context. TypeError: if ``afn`` is not an asynchronous function. - **Locating a Trio Token**: There are two ways to specify which + **Locating a TrioToken**: There are two ways to specify which `trio.run` loop to reenter: - Spawn this thread from `trio.to_thread.run_sync`. Trio will @@ -486,17 +473,20 @@ def from_thread_run( - Pass a keyword argument, ``trio_token`` specifying a specific `trio.run` loop to re-enter. This is useful in case you have a "foreign" thread, spawned using some other framework, and still want - to enter Trio, or if you want to avoid the cancellation context of - `trio.to_thread.run_sync`. + to enter Trio, or if you want to use a new system task to call ``afn``, + maybe to avoid the cancellation context of a corresponding + `trio.to_thread.run_sync` task. """ - if trio_token is None: - send_message = _send_message_to_host_task - else: - send_message = _send_message_to_system_task + token_provided = trio_token is not None + trio_token = _check_token(trio_token) message_to_trio = Run(afn, args, contextvars.copy_context()) - send_message(message_to_trio, _check_token(trio_token)) + if token_provided or PARENT_TASK_DATA.abandon_on_cancel: + message_to_trio.run_in_system_nursery(trio_token) + else: + message_to_trio.run_in_host_task(trio_token) + return message_to_trio.queue.get().unwrap() # type: ignore[no-any-return] @@ -522,7 +512,7 @@ def from_thread_run_sync( provided, and we can't infer one from context. TypeError: if ``fn`` is an async function. - **Locating a Trio Token**: There are two ways to specify which + **Locating a TrioToken**: There are two ways to specify which `trio.run` loop to reenter: - Spawn this thread from `trio.to_thread.run_sync`. Trio will @@ -531,15 +521,18 @@ def from_thread_run_sync( - Pass a keyword argument, ``trio_token`` specifying a specific `trio.run` loop to re-enter. This is useful in case you have a "foreign" thread, spawned using some other framework, and still want - to enter Trio, or if you want to avoid the cancellation context of - `trio.to_thread.run_sync`. + to enter Trio, or if you want to use a new system task to call ``fn``, + maybe to avoid the cancellation context of a corresponding + `trio.to_thread.run_sync` task. """ - if trio_token is None: - send_message = _send_message_to_host_task - else: - send_message = _send_message_to_system_task + token_provided = trio_token is not None + trio_token = _check_token(trio_token) message_to_trio = RunSync(fn, args, contextvars.copy_context()) - send_message(message_to_trio, _check_token(trio_token)) + if token_provided or PARENT_TASK_DATA.abandon_on_cancel: + message_to_trio.run_in_system_nursery(trio_token) + else: + message_to_trio.run_in_host_task(trio_token) + return message_to_trio.queue.get().unwrap() # type: ignore[no-any-return]