Skip to content

Commit

Permalink
Add trio.from_thread.check_cancelled api to allow threads to efficien…
Browse files Browse the repository at this point in the history
…tly poll for cancellation
  • Loading branch information
richardsheridan committed Aug 8, 2022
1 parent 2d62ff0 commit 357e00e
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 6 deletions.
38 changes: 32 additions & 6 deletions trio/_threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
)
from ._util import coroutine_or_error

# Global due to Threading API, thread local storage for trio token
TOKEN_LOCAL = threading.local()
# Global due to Threading API, thread local storage for trio token and raise_cancel
THREAD_LOCAL = threading.local()

_limiter_local = RunVar("limiter")
# I pulled this number out of the air; it isn't based on anything. Probably we
Expand Down Expand Up @@ -146,6 +146,9 @@ async def to_thread_run_sync(sync_fn, *args, cancellable=False, limiter=None):
# for the result – or None if this function was cancelled and we should
# discard the result.
task_register = [trio.lowlevel.current_task()]
# Holds a reference to the raise_cancel function provided if a cancellation
# is attempted against this task - or None if no such delivery has happened.
cancel_register = [None]
name = f"trio.to_thread.run_sync-{next(_thread_counter)}"
placeholder = ThreadPlaceholder(name)

Expand All @@ -170,7 +173,8 @@ def do_release_then_return_result():

def worker_fn():
current_async_library_cvar.set(None)
TOKEN_LOCAL.token = current_trio_token
THREAD_LOCAL.token = current_trio_token
THREAD_LOCAL.cancel_register = cancel_register
try:
ret = sync_fn(*args)

Expand All @@ -184,7 +188,8 @@ def worker_fn():

return ret
finally:
del TOKEN_LOCAL.token
del THREAD_LOCAL.token
del THREAD_LOCAL.cancel_register

context = contextvars.copy_context()
contextvars_aware_worker_fn = functools.partial(context.run, worker_fn)
Expand All @@ -205,8 +210,11 @@ def deliver_worker_fn_result(result):
limiter.release_on_behalf_of(placeholder)
raise

def abort(_):
def abort(raise_cancel):
# fill so from_thread_check_cancelled can raise
cancel_register[0] = raise_cancel
if cancellable:
# empty so report_back_in_trio_thread_fn cannot reschedule
task_register[0] = None
return trio.lowlevel.Abort.SUCCEEDED
else:
Expand All @@ -215,6 +223,24 @@ def abort(_):
return await trio.lowlevel.wait_task_rescheduled(abort)


def from_thread_check_cancelled():
"""Raise trio.Cancelled if the associated Trio task entered a cancelled status.
Only applicable to threads spawned by `trio.to_thread.run_sync`. Poll to allow
``cancellable=False`` threads to raise :exc:`trio.Cancelled` at a suitable
place, or to end abandoned ``cancellable=True`` sooner than they may otherwise.
Raises:
Cancelled: If the corresponding call to `trio.to_thread.run_sync` has had a
delivery of cancellation attempted against it, regardless of the value of
``cancellable`` supplied as an argument to it.
AttributeError: If this thread is not spawned from `trio.to_thread.run_sync`.
"""
raise_cancel = THREAD_LOCAL.cancel_register[0]
if raise_cancel is not None:
raise_cancel()


def _run_fn_as_system_task(cb, fn, *args, context, trio_token=None):
"""Helper function for from_thread.run and from_thread.run_sync.
Expand All @@ -227,7 +253,7 @@ def _run_fn_as_system_task(cb, fn, *args, context, trio_token=None):

if not trio_token:
try:
trio_token = TOKEN_LOCAL.token
trio_token = THREAD_LOCAL.token
except AttributeError:
raise RuntimeError(
"this thread wasn't created by Trio, pass kwarg trio_token=..."
Expand Down
1 change: 1 addition & 0 deletions trio/from_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@

from ._threads import from_thread_run as run
from ._threads import from_thread_run_sync as run_sync
from ._threads import from_thread_check_cancelled as check_cancelled
84 changes: 84 additions & 0 deletions trio/tests/test_threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
current_default_thread_limiter,
from_thread_run,
from_thread_run_sync,
from_thread_check_cancelled,
)

from .._core.tests.test_ki import ki_self
Expand Down Expand Up @@ -750,3 +751,86 @@ def __bool__(self):

with pytest.raises(NotImplementedError):
await to_thread_run_sync(int, cancellable=BadBool())


async def test_from_thread_check_cancelled():
q = stdlib_queue.Queue()

async def child(cancellable):
record.append("start")
try:
return await to_thread_run_sync(f, cancellable=cancellable)
except _core.Cancelled:
record.append("cancel")
finally:
record.append("exit")

def f():
try:
from_thread_check_cancelled()
except _core.Cancelled: # pragma: no cover, test failure path
q.put("Cancelled")
else:
q.put("Not Cancelled")
ev.wait()
return from_thread_check_cancelled()

# Base case: nothing cancelled so we shouldn't see cancels anywhere
record = []
ev = threading.Event()
async with _core.open_nursery() as nursery:
nursery.start_soon(child, False)
await wait_all_tasks_blocked()
assert record[0] == "start"
assert q.get(timeout=1) == "Not Cancelled"
ev.set()
# implicit assertion, Cancelled not raised via nursery
assert record[1] == "exit"

# cancellable=False case: a cancel will pop out but be handled by
# the appropriate cancel scope
record = []
ev = threading.Event()
async with _core.open_nursery() as nursery:
nursery.start_soon(child, False)
await wait_all_tasks_blocked()
assert record[0] == "start"
assert q.get(timeout=1) == "Not Cancelled"
nursery.cancel_scope.cancel()
ev.set()
assert nursery.cancel_scope.cancelled_caught
assert "cancel" in record
assert record[-1] == "exit"

# cancellable=True case: slightly different thread behavior needed
# check thread is cancelled "soon" after abandonment
def f():
ev.wait()
try:
from_thread_check_cancelled()
except _core.Cancelled:
q.put("Cancelled")
else: # pragma: no cover, test failure path
q.put("Not Cancelled")

record = []
ev = threading.Event()
async with _core.open_nursery() as nursery:
nursery.start_soon(child, True)
await wait_all_tasks_blocked()
assert record[0] == "start"
nursery.cancel_scope.cancel()
ev.set()
assert nursery.cancel_scope.cancelled_caught
assert "cancel" in record
assert record[-1] == "exit"
assert q.get(timeout=1) == "Cancelled"


async def test_from_thread_check_cancelled_raises_in_foreign_threads():
with pytest.raises(AttributeError):
from_thread_check_cancelled()
q = stdlib_queue.Queue()
_core.start_thread_soon(from_thread_check_cancelled, lambda _: q.put(_))
with pytest.raises(AttributeError):
q.get(timeout=1).unwrap()

0 comments on commit 357e00e

Please sign in to comment.